In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.

In [None]:
!pip install gradio

Collecting gradio
  Downloading gradio-5.23.3-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.8.0 (from gradio)
  Downloading gradio_client-1.8.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6 

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import GPT2Tokenizer, GPT2LMHeadModel, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import torch.nn as nn

In [None]:

# Optimized Configuration
class Config:
    """Hyperparameters and data settings shared by the rest of the notebook."""

    # --- model ---
    MODEL_NAME = "distilgpt2"   # Hugging Face checkpoint to fine-tune
    DROPOUT = 0.1               # dropout probability requested for the model

    # --- data ---
    SUBSET_SIZE = 6000          # number of leading AG News training rows to load
    MAX_LEN = 96                # truncation limit (in tokens) when encoding an article
    BATCH_SIZE = 32             # examples per optimisation step

    # --- optimisation ---
    NUM_EPOCHS = 10             # full passes over the training loader
    LEARNING_RATE = 5e-5        # peak learning rate given to AdamW
    WEIGHT_DECAY = 0.01         # AdamW weight-decay coefficient
    WARMUP_STEPS = 100          # scheduler warm-up steps

config = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every random source used below (split points, batch shuffling, token
# sampling) so that Restart & Run All reproduces the same run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# 1. Data and tokenizer
# Only the first SUBSET_SIZE rows of the AG News training split are loaded.
dataset = load_dataset("ag_news", split=f"train[:{config.SUBSET_SIZE}]")
tokenizer = GPT2Tokenizer.from_pretrained(config.MODEL_NAME)
# GPT-2 ships without a padding token; reuse end-of-text so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

class NewsCompletionDataset(Dataset):
    """Prompt/continuation token pairs cut from news articles.

    Each kept article is tokenized and split at a random point: the tokens
    before the split form the model input and up to 15 tokens after it form
    the target.

    Args:
        tokenizer: Object exposing ``encode(text, max_length=..., truncation=True)``
            that returns a list of token ids.
        max_length: Truncation limit handed to ``tokenizer.encode``. Articles
            whose encoding has ``max_length - 10`` tokens or more are dropped.
        data: Iterable of mappings with a ``"text"`` entry. Defaults to the
            notebook-level ``dataset``.
        seed: Optional seed for a private random generator so the split points
            are reproducible. When omitted, NumPy's global generator is used.
    """

    def __init__(self, tokenizer, max_length, data=None, seed=None):
        rng = np.random if seed is None else np.random.RandomState(seed)
        source = dataset if data is None else data

        self.examples = []
        for example in source:
            text = " ".join(example["text"].split()[:100])  # keep the first 100 words only
            tokens = tokenizer.encode(text, max_length=max_length, truncation=True)
            # The split is drawn from [10, len - 5), which is empty unless the
            # article has at least 16 tokens; shorter ones used to raise
            # ValueError inside randint and are now skipped.
            if 15 < len(tokens) < max_length - 10:
                split_point = rng.randint(10, len(tokens) - 5)
                self.examples.append({
                    'input': tokens[:split_point],
                    'target': tokens[split_point:split_point + 15]  # predict up to 15 tokens
                })

    def __len__(self):
        """Number of prompt/continuation pairs that passed the length filter."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` for example ``idx`` as 1-D tensors."""
        pair = self.examples[idx]
        return torch.tensor(pair['input']), torch.tensor(pair['target'])

def _pad_batch(batch):
    """Collate ``(input, target)`` tensor pairs into two right-padded 2-D tensors."""
    prompts = [pair[0] for pair in batch]
    continuations = [pair[1] for pair in batch]
    pad_id = tokenizer.pad_token_id
    padded_prompts = torch.nn.utils.rnn.pad_sequence(prompts, batch_first=True, padding_value=pad_id)
    padded_continuations = torch.nn.utils.rnn.pad_sequence(continuations, batch_first=True, padding_value=pad_id)
    return padded_prompts, padded_continuations


# Shuffled mini-batches of prompt/continuation pairs for fine-tuning.
train_dataset = NewsCompletionDataset(tokenizer, config.MAX_LEN)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    collate_fn=_pad_batch,
)

# 2. Enhanced Model with improvements
class ImprovedGPT2Model(nn.Module):
    """GPT-2 language model wrapper used for prompt -> continuation fine-tuning.

    The first three transformer blocks are frozen; everything else is trained.

    Args:
        pad_token_id: Token id used for padding in the batches. Defaults to the
            checkpoint's ``eos_token_id``, which is what the notebook's
            tokenizer pads with. Any occurrence of this id in a training batch
            is treated as padding.
    """

    def __init__(self, pad_token_id=None):
        super().__init__()
        # GPT-2 reads its dropout rates from resid_pdrop / embd_pdrop /
        # attn_pdrop when the modules are built, so they must be supplied at
        # load time. Assigning `config.dropout` afterwards has no effect.
        self.model = GPT2LMHeadModel.from_pretrained(
            config.MODEL_NAME,
            resid_pdrop=config.DROPOUT,
            embd_pdrop=config.DROPOUT,
            attn_pdrop=config.DROPOUT,
        )
        if pad_token_id is None:
            pad_token_id = self.model.config.eos_token_id
        self.pad_token_id = pad_token_id

        # Freeze first 3 layers
        for param in self.model.transformer.h[:3].parameters():
            param.requires_grad = False

    def forward(self, input_ids, target_ids=None):
        """Run the language model.

        Args:
            input_ids: ``(batch, prompt_len)`` token ids, right-padded.
            target_ids: Optional ``(batch, target_len)`` continuation ids,
                right-padded. When given, the call computes a training loss.

        Returns:
            The Hugging Face model output. With ``target_ids`` its ``loss`` is
            the next-token loss over the real (non-padding) tokens of
            prompt + continuation, and its ``logits`` follow the packed token
            order described below.
        """
        if target_ids is None:
            return self.model(input_ids)

        sequence = torch.cat([input_ids, target_ids], dim=1)
        is_pad = sequence == self.pad_token_id

        # Concatenating two right-padded tensors leaves padding between the
        # prompt and its continuation. A stable sort on the pad flag moves
        # every real token to the front of its row, keeping the original
        # order, so the continuation directly follows the prompt.
        order = torch.sort(is_pad.long(), dim=1, stable=True).indices
        sequence = sequence.gather(1, order)
        is_pad = is_pad.gather(1, order)

        # -100 is the label value the Hugging Face loss ignores.
        labels = sequence.masked_fill(is_pad, -100)
        attention_mask = (~is_pad).long()
        return self.model(sequence, attention_mask=attention_mask, labels=labels)

# Build the model and place it on the selected device.
model = ImprovedGPT2Model()
model = model.to(device)

# 3. Optimizer and learning-rate schedule
optimizer = AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)

# The scheduler is stepped once per batch, so its horizon is batches-per-epoch times epochs.
total_steps = config.NUM_EPOCHS * len(train_loader)
schedule_kwargs = {
    "num_warmup_steps": config.WARMUP_STEPS,
    "num_training_steps": total_steps,
}
scheduler = get_linear_schedule_with_warmup(optimizer, **schedule_kwargs)

# 4. Training loop with best-checkpoint saving
def train():
    """Fine-tune the global ``model`` for ``config.NUM_EPOCHS`` epochs.

    Every batch runs one optimizer step (gradients clipped to norm 1.0)
    followed by one scheduler step. After each epoch the mean training loss is
    printed, and the weights are written to ``best_model.pt`` whenever that
    mean is the lowest seen so far. Training always runs for the full number
    of epochs; there is no validation set and no early exit.
    """
    best_loss = float('inf')
    model.train()

    for epoch in range(config.NUM_EPOCHS):
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        batch_losses = []

        for input_ids, target_ids in progress_bar:
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)

            optimizer.zero_grad()
            loss = model(input_ids, target_ids).loss
            loss.backward()

            # Clip before stepping so a single bad batch cannot blow up the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            batch_losses.append(loss.item())
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

        avg_loss = sum(batch_losses) / len(train_loader)
        print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f}")

        # Keep the weights of the epoch with the lowest mean training loss.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), "best_model.pt")

# 5. Inference with temperature sampling
# NOTE(review): complete_text is defined more than once in this notebook and the
# last definition executed wins; keep a single definition.
def complete_text(prompt, max_length=20, temperature=0.7, checkpoint_path="best_model.pt"):
    """Sample a continuation of ``prompt`` from the fine-tuned model.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of new tokens to generate. Cast to ``int``
            so values coming from a UI slider are accepted.
        temperature: Softmax temperature. Values <= 0 switch to greedy decoding
            instead of dividing by zero.
        checkpoint_path: Weights file to load; defaults to the file written by
            ``train()``.

    Returns:
        The prompt followed by the generated text, with special tokens removed.
        An empty prompt is returned unchanged.
    """
    import os  # local import so this cell does not depend on the import cell

    # Reload the weights only when they may be stale: the file changed on disk,
    # or the model is in training mode (train() ran since the last load).
    stamp = (checkpoint_path, os.path.getmtime(checkpoint_path))
    if model.training or getattr(model, "_loaded_checkpoint", None) != stamp:
        # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model._loaded_checkpoint = stamp
    model.eval()

    prompt_ids = tokenizer.encode(prompt)
    if not prompt_ids:
        return prompt  # nothing to condition on
    input_ids = torch.tensor([prompt_ids], device=device)

    max_context = model.model.config.n_positions  # GPT-2 position-embedding limit
    eos_id = tokenizer.eos_token_id

    with torch.no_grad():
        for _ in range(int(max_length)):
            # Feed at most the last `max_context` tokens to stay inside the limit.
            logits = model(input_ids[:, -max_context:]).logits[:, -1, :]
            if temperature > 0:
                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = logits.argmax(dim=-1, keepdim=True)

            if next_token.item() == eos_id:
                break
            input_ids = torch.cat([input_ids, next_token], dim=1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# 5. Inference with temperature sampling
# NOTE(review): complete_text is defined more than once in this notebook and the
# last definition executed wins; keep a single definition.
def complete_text(prompt, max_length=20, temperature=0.7, checkpoint_path="best_model.pt"):
    """Sample a continuation of ``prompt`` from the fine-tuned model.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of new tokens to generate. Cast to ``int``
            so values coming from a UI slider are accepted.
        temperature: Softmax temperature. Values <= 0 switch to greedy decoding
            instead of dividing by zero.
        checkpoint_path: Weights file to load. Defaults to the relative path
            ``train()`` saves to, rather than a hard-coded ``/content`` path.

    Returns:
        The prompt followed by the generated text, with special tokens removed.
        An empty prompt is returned unchanged.
    """
    import os  # local import so this cell does not depend on the import cell

    # Reload the weights only when they may be stale: the file changed on disk,
    # or the model is in training mode (train() ran since the last load).
    stamp = (checkpoint_path, os.path.getmtime(checkpoint_path))
    if model.training or getattr(model, "_loaded_checkpoint", None) != stamp:
        # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model._loaded_checkpoint = stamp
    model.eval()

    prompt_ids = tokenizer.encode(prompt)
    if not prompt_ids:
        return prompt  # nothing to condition on
    input_ids = torch.tensor([prompt_ids], device=device)

    max_context = model.model.config.n_positions  # GPT-2 position-embedding limit
    eos_id = tokenizer.eos_token_id

    with torch.no_grad():
        for _ in range(int(max_length)):
            # Feed at most the last `max_context` tokens to stay inside the limit.
            logits = model(input_ids[:, -max_context:]).logits[:, -1, :]
            if temperature > 0:
                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = logits.argmax(dim=-1, keepdim=True)

            if next_token.item() == eos_id:
                break
            input_ids = torch.cat([input_ids, next_token], dim=1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Run fine-tuning now. train() prints the mean loss of every epoch and saves
# the lowest-loss weights to "best_model.pt" in the working directory.
train()

Epoch 1: 100%|██████████| 177/177 [01:09<00:00,  2.53it/s, loss=3.0594]


Epoch 1 Avg Loss: 3.9312


Epoch 2: 100%|██████████| 177/177 [01:09<00:00,  2.54it/s, loss=2.4536]


Epoch 2 Avg Loss: 2.4667


Epoch 3: 100%|██████████| 177/177 [01:08<00:00,  2.57it/s, loss=2.0782]


Epoch 3 Avg Loss: 2.3432


Epoch 4: 100%|██████████| 177/177 [01:09<00:00,  2.56it/s, loss=2.1722]


Epoch 4 Avg Loss: 2.2539


Epoch 5: 100%|██████████| 177/177 [01:08<00:00,  2.57it/s, loss=2.3890]


Epoch 5 Avg Loss: 2.2038


Epoch 6: 100%|██████████| 177/177 [01:08<00:00,  2.57it/s, loss=2.2480]


Epoch 6 Avg Loss: 2.1546


Epoch 7: 100%|██████████| 177/177 [01:09<00:00,  2.56it/s, loss=2.4156]


Epoch 7 Avg Loss: 2.1176


Epoch 8: 100%|██████████| 177/177 [01:09<00:00,  2.55it/s, loss=2.0891]


Epoch 8 Avg Loss: 2.0845


Epoch 9: 100%|██████████| 177/177 [01:08<00:00,  2.57it/s, loss=2.2299]


Epoch 9 Avg Loss: 2.0694


Epoch 10: 100%|██████████| 177/177 [01:09<00:00,  2.56it/s, loss=1.7803]


Epoch 10 Avg Loss: 2.0540


In [None]:
import gradio as gr

# Gradio Interface
# The three input widgets map, in order, onto complete_text(prompt, max_length, temperature).
prompt_box = gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt")
length_slider = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Max Length")
temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
completion_box = gr.Textbox(label="Generated Completion")

iface = gr.Interface(
    fn=complete_text,
    inputs=[prompt_box, length_slider, temperature_slider],
    outputs=completion_box,
    title="🧠 Sentence Autocompletion",
    description="Give a sentence prompt and receive a generated continuation from your model.",
)

iface.launch()


Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://6d12fad88aaa358ce0.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




In [None]:
# 5. Inference with temperature sampling
# NOTE(review): complete_text is defined more than once in this notebook and the
# last definition executed wins; keep a single definition.
def complete_text(prompt, max_length=20, temperature=0.7, checkpoint_path="best_model.pt"):
    """Sample a continuation of ``prompt`` from the fine-tuned model.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of new tokens to generate. Cast to ``int``
            so values coming from a UI slider are accepted.
        temperature: Softmax temperature. Values <= 0 switch to greedy decoding
            instead of dividing by zero.
        checkpoint_path: Weights file to load. Defaults to the relative path
            ``train()`` saves to, rather than a hard-coded ``/content`` path.

    Returns:
        The prompt followed by the generated text, with special tokens removed.
        An empty prompt is returned unchanged.
    """
    import os  # local import so this cell does not depend on the import cell

    # Reload the weights only when they may be stale: the file changed on disk,
    # or the model is in training mode (train() ran since the last load).
    stamp = (checkpoint_path, os.path.getmtime(checkpoint_path))
    if model.training or getattr(model, "_loaded_checkpoint", None) != stamp:
        # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model._loaded_checkpoint = stamp
    model.eval()

    prompt_ids = tokenizer.encode(prompt)
    if not prompt_ids:
        return prompt  # nothing to condition on
    input_ids = torch.tensor([prompt_ids], device=device)

    max_context = model.model.config.n_positions  # GPT-2 position-embedding limit
    eos_id = tokenizer.eos_token_id

    with torch.no_grad():
        for _ in range(int(max_length)):
            # Feed at most the last `max_context` tokens to stay inside the limit.
            logits = model(input_ids[:, -max_context:]).logits[:, -1, :]
            if temperature > 0:
                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = logits.argmax(dim=-1, keepdim=True)

            if next_token.item() == eos_id:
                break
            input_ids = torch.cat([input_ids, next_token], dim=1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


In [None]:
# Test completions
# Sample one continuation per prompt with the default length and temperature.
sample_prompts = (
    "The stock market",
    "Scientists discovered",
    "The president announced",
)
for number, sample in enumerate(sample_prompts, start=1):
    print(f"Completion {number}:", complete_text(sample))

Completion 1: The stock market closes at \$82 a share on Wednesday after a steady rally in U.S. stocks
Completion 2: Scientists discovered new species of dinosaur Pestolosaurus, a dinosaur whose legs were pulled off by human hands,
Completion 3: The president announced in the Oval Office on Friday that he would not be running for president, but that he would not
