<a href="https://www.kaggle.com/code/evanupham/gpt-tiny-story?scriptVersionId=186953109" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader, Dataset, Subset
from datasets import load_dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer, get_linear_schedule_with_warmup
import numpy as np
# Define the custom dataset class
class TinyStoriesDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length):
        self.texts = [text for text in texts if text.strip() != '']  # Filter out empty sequences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, truncation=True, padding='max_length')
        input_ids = tokens.input_ids.squeeze(0)  # Ensure the correct dimension
        attention_mask = tokens.attention_mask.squeeze(0)  # Ensure the correct dimension
        return input_ids, attention_mask

# Load the dataset
dataset = load_dataset('roneneldan/TinyStories')

# Initialize the GPT tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Prepare the dataset
max_length = 1000
train_texts = dataset['train']['text']
train_dataset = TinyStoriesDataset(train_texts, tokenizer, max_length)

# Create data loader
batch_size = 5
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Define the GPT-2 model with dropout
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_groups=2, dropout=0.1):
        super(GroupedQueryAttention, self).__init__()
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.d_model = d_model

        assert d_model % (num_heads * num_groups) == 0

        self.depth = d_model // (num_heads * num_groups)

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.dense = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_groups, self.num_heads, self.depth)
        return x.permute(0, 2, 3, 1, 4)  # (batch_size, num_groups, num_heads, seq_len, depth)

    def forward(self, q, k, v, mask):
        batch_size = q.size(0)

        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.depth)

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)  # (batch_size, 1, 1, seq_len, seq_len)
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = torch.nn.functional.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, v)

        output = output.permute(0, 3, 1, 2, 4).contiguous().view(batch_size, -1, self.d_model)
        output = self.dense(output)

        return output, attention_weights

    
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.3):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout1(torch.nn.functional.relu(self.linear1(x)))
        return self.dropout2(self.linear2(x))

class GPTBlock(nn.Module):
    def __init__(self, d_model, num_heads, num_groups, d_ff, dropout=0.3):
        super(GPTBlock, self).__init__()
        self.attention = GroupedQueryAttention(d_model, num_heads, num_groups, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        attn_output, _ = self.attention(x, x, x, mask)
        out1 = self.norm1(x + attn_output)
        ffn_output = self.ffn(out1)
        out2 = self.norm2(out1 + ffn_output)
        return out2

class GPT2(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_groups, d_ff, num_layers, max_len=5000, dropout=0.3):
        super(GPT2, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([GPTBlock(d_model, num_heads, num_groups, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm(x)
        x = self.dropout(x)
        return self.fc(x)

def create_future_mask(size):
    mask = torch.tril(torch.ones(size, size)).unsqueeze(0)
    return mask  # (1, size, size)

vocab_size = len(tokenizer)
d_model = 768  # GPT-2 small model size
num_heads = 6
d_ff = 3072
num_layers = 12
max_len = 1024
num_groups = 2
model = GPT2(vocab_size, d_model, num_heads, num_groups, d_ff, num_layers, max_len)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
import os
model_path = "/kaggle/working/model_weights.pth"
# # Load the model weights if they exist
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print(f"Model weights loaded from {model_path}")

# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training loop with progress bar
model.train()

# Freeze all layers initially
for param in model.parameters():
    param.requires_grad = False

def contains_repeated_ngram(seq, n):
    ngrams = set()
    for i in range(len(seq) - n + 1):
        ngram = tuple(seq[i:i+n].tolist())
        if ngram in ngrams:
            return True
        ngrams.add(ngram)
    return False

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, min_p=0.0):
    """Filter a distribution of logits using top-k, top-p (nucleus), and min-p filtering"""
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = -float('Inf')

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        if min_p > 0.0:
            sorted_indices_to_remove &= (sorted_logits < min_p).cumsum(dim=-1).bool()

        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = -float('Inf')
        
    if min_p > 0.0:
        logits[logits < min_p] = -float('Inf')

    return logits

def apply_repetition_penalty(logits, seq, repetition_penalty):
    """Apply a penalty to the logits to discourage repetition"""
    for token_id in seq:
        logits[0, token_id] /= repetition_penalty
    return logits

import torch
import torch.nn.functional as F
from collections import defaultdict
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def calculate_bleu(reference, hypothesis):
    reference = [reference]  # BLEU expects a list of references
    smoothie = SmoothingFunction().method4
    return sentence_bleu(reference, hypothesis, smoothing_function=smoothie)

def contains_repeated_ngram(seq, n):
    ngrams = set()
    for i in range(len(seq) - n + 1):
        ngram = tuple(seq[i:i+n].tolist())
        if ngram in ngrams:
            return True
        ngrams.add(ngram)
    return False

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, min_p=0.0):
    """Filter a distribution of logits using top-k, top-p (nucleus), and min-p filtering"""
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = -float('Inf')

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        if min_p > 0.0:
            sorted_indices_to_remove &= (sorted_logits < min_p).cumsum(dim=-1).bool()

        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = -float('Inf')
        
    if min_p > 0.0:
        logits[logits < min_p] = -float('Inf')

    return logits

def apply_repetition_penalty(logits, seq, repetition_penalty):
    """Apply a penalty to the logits to discourage repetition"""
    for token_id in seq:
        logits[0, token_id] /= repetition_penalty
    return logits

def beam_search(model, tokenizer, input_text, beam_width=5, max_len=100, length_penalty=1.2, no_repeat_ngram_size=3, top_k=70, top_p=0.7, min_p=0.1, temperature=0.8, repetition_penalty=1.2, diversity_rate=0.3):
    input_ids = tokenizer(input_text, return_tensors='pt').input_ids.to(device)
    input_ids = input_ids[:, :-1]  # Remove the last token for autoregressive generation

    beam = [(input_ids, 0, [])]  # (input_ids, score, generated tokens)
    completed_sequences = []
    diversity_penalty = defaultdict(lambda: 0)

    for step in range(max_len):
        new_beam = []
        for seq, score, generated_tokens in beam:
            with torch.no_grad():
                outputs = model(seq, create_future_mask(seq.size(1)).to(device))
            logits = outputs[:, -1, :]  # Get the logits for the last token
            logits = logits / temperature
            logits = apply_repetition_penalty(logits, seq[0], repetition_penalty)
            logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_p=min_p)
            probs = F.log_softmax(logits, dim=-1)
            topk_probs, topk_ids = probs.topk(beam_width)

            for i in range(beam_width):
                next_seq = torch.cat([seq, topk_ids[:, i:i+1]], dim=-1)
                new_score = score + topk_probs[0, i].item()
                new_generated_tokens = generated_tokens + [topk_ids[0, i].item()]

                if no_repeat_ngram_size > 0 and contains_repeated_ngram(next_seq[0], no_repeat_ngram_size):
                    continue  # Skip sequences with repeated n-grams

                # Diversity penalty
                diversity_penalty[tuple(map(tuple, next_seq.tolist()))] += diversity_rate * step
                new_score -= diversity_penalty[tuple(map(tuple, next_seq.tolist()))]

                new_beam.append((next_seq, new_score, new_generated_tokens))

        if not new_beam:
            break  # Break the loop if no new sequences are generated

        beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_width]

        # Check for completed sequences (sequences that have the end token)
        for seq, score, generated_tokens in beam:
            if seq[0, -1] == tokenizer.eos_token_id:
                length_normalized_score = score / (seq.size(1) ** length_penalty)
                completed_sequences.append((seq, length_normalized_score, generated_tokens))

        # Keep only the sequences that are not completed
        beam = [b for b in beam if b[0][0, -1] != tokenizer.eos_token_id]

        # Early stopping if all sequences are completed
        if not beam:
            break

    if completed_sequences:
        best_seq = sorted(completed_sequences, key=lambda x: x[1], reverse=True)[0]
    else:
        if beam:
            best_seq = beam[0]  # Fallback to the best beam
        else:
            return ""  # Return an empty string if no valid sequence is found

    best_seq_tokens = best_seq[2]
    reference = tokenizer.encode(input_text)  # Use the input text as the reference
    bleu_score = calculate_bleu(reference, best_seq_tokens)

    output_text = tokenizer.decode(best_seq[0].squeeze(), skip_special_tokens=True)
    return output_text, bleu_score

def set_requires_grad(model, layer_idx, requires_grad):
    for i, layer in enumerate(model.layers):
        for param in layer.parameters():
            param.requires_grad = (i == layer_idx) and requires_grad

def get_custom_training_sequence(num_layers):
    sequence = []
    i = 1
    increment = 1
    while len(sequence) < 2 * num_layers:
        sequence.append(i)
        i += increment
        if increment == 1:
            increment = -1
        else:
            increment = 2
        if i > num_layers:
            i = num_layers - 1
            increment = 1
    return sequence

num_layers = len(model.layers)
training_sequence = get_custom_training_sequence(num_layers)

num_epochs = len(training_sequence)
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=total_steps // 10, num_training_steps=total_steps)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
dataset_size = 3000  # Number of samples per epoch

def parabolic_scaling(epoch, num_epochs):
    mid_epoch = num_epochs // 2
    return -2 * ((epoch - mid_epoch) ** 2) / (num_epochs ** 2) + 1


# Define the BLEU score calculation function
def calculate_bleu(reference, hypothesis):
    reference = [reference]  # BLEU expects a list of references
    smoothie = SmoothingFunction().method4
    return sentence_bleu(reference, hypothesis, smoothing_function=smoothie)

def z_loss(logits, beta=1e-4):
    """Z-Loss regularizes logits to prevent extreme values."""
    log_z = torch.logsumexp(logits, dim=-1)
    return beta * log_z.pow(2).mean()

for epoch in range(num_epochs):
    # Determine which layer to unfreeze according to the training sequence
    layer_to_unfreeze = training_sequence[epoch] - 1
    set_requires_grad(model, layer_to_unfreeze, True)
    
    # Create a new subset of the dataset
    indices = np.random.choice(len(train_dataset), dataset_size, replace=False)
    subset = Subset(train_dataset, indices)
    train_loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
        
    total_loss = 0
    total_bleu_score = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", postfix={"Loss": 0.0000, "Perplexity": 0.0000, "BLEU": 0.0000})
    
    for batch in progress_bar:
        input_ids, attention_mask = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        
        # Shift the input for the next token prediction
        labels = input_ids[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()
        
        # Create future mask
        seq_length = input_ids.size(1)
        mask = create_future_mask(seq_length).to(device)
        
        optimizer.zero_grad()
        outputs = model(input_ids, mask)
        
        # Compute loss
        loss = criterion(outputs.view(-1, vocab_size), labels.view(-1))
        
        # Generate sequences for reward calculation
        generated_ids = outputs.argmax(dim=-1).cpu().numpy()
        references = labels.cpu().numpy()
        batch_bleu_score = 0
        rewards = []
        for ref, gen in zip(references, generated_ids):
            ref_tokens = ref.tolist()
            gen_tokens = gen.tolist()
            bleu_score = calculate_bleu(ref_tokens, gen_tokens)
            rewards.append(bleu_score)
            batch_bleu_score += bleu_score
        
        avg_bleu_score = batch_bleu_score / len(references)
        
        # REINFORCE algorithm
        rewards = torch.tensor(rewards, dtype=torch.float).to(device)
        log_probs = F.log_softmax(outputs, dim=-1)
        log_probs = log_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)
        policy_loss = -log_probs * rewards.unsqueeze(-1)
        policy_loss = policy_loss.mean() * parabolic_scaling(epoch, num_epochs)
        zloss_value = z_loss(outputs)
        total_loss_with_reward = loss + zloss_value + policy_loss
        total_loss_with_reward.backward()
        optimizer.step()
        scheduler.step()
        
        total_loss += total_loss_with_reward.item()
        total_bleu_score += avg_bleu_score
        avg_loss = total_loss / len(progress_bar)
        perplexity = torch.exp(torch.tensor(loss)).item()
        progress_bar.set_postfix(Loss=f"{loss.item():.4f}", Perplexity=f"{perplexity:.4f}", BLEU=f"{avg_bleu_score:.4f}")
    
    avg_loss = total_loss / len(train_loader)
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    avg_bleu_score = total_bleu_score / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss_with_reward.item():.4f}, Perplexity: {perplexity:.4f}, Avg BLEU: {avg_bleu_score:.4f}")
    generated_text, bleu_score = beam_search(model, tokenizer, "Once upon a time", beam_width=5, max_len=50)
    print(f"Generated text: {generated_text}")
    print(f"BLEU score: {bleu_score:.4f}")  
    # Save model weights
    torch.save(model.state_dict(), model_path)
    print(f"Model weights saved to {model_path}")
    
    # Freeze the previously unfrozen layer
    set_requires_grad(model, layer_to_unfreeze, False)

print("Training complete.")



Repo card metadata block was not found. Setting CardData to empty.


Model weights loaded from /kaggle/working/model_weights.pth


  perplexity = torch.exp(torch.tensor(loss)).item()
Epoch 1/24: 100%|██████████| 600/600 [07:41<00:00,  1.30it/s, BLEU=0.7961, Loss=8.5326, Perplexity=5077.5889]


Epoch 1/24, Loss: 9.2613, Perplexity: 10673.1035, Avg BLEU: 0.7908
Generated text: Once upon a time, a, there was a little.


". She to the, ". She and the to the. He was a and said, " it. They and to theSocial, "

The to the and. He
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 2/24: 100%|██████████| 600/600 [07:25<00:00,  1.35it/s, BLEU=0.7515, Loss=8.5365, Perplexity=5097.5640]


Epoch 2/24, Loss: 9.4832, Perplexity: 11884.0752, Avg BLEU: 0.7902
Generated text: Once upon a time, there was a little. She to the, "


". She was a to the and. He. He to the. They, " and it.

The and said, ", "alos and the to the
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 3/24: 100%|██████████| 600/600 [07:40<00:00,  1.30it/s, BLEU=0.8125, Loss=8.4576, Perplexity=4710.6294]


Epoch 3/24, Loss: 9.3591, Perplexity: 13052.0537, Avg BLEU: 0.7932
Generated text: Once upon a time, there was a a little.


The, ". She and to the, " to the to the. She. He was a and. The to the and said, "

", but it was so. They
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 4/24: 100%|██████████| 600/600 [07:11<00:00,  1.39it/s, BLEU=0.7502, Loss=8.5904, Perplexity=5379.6919]


Epoch 4/24, Loss: 9.7603, Perplexity: 14201.2021, Avg BLEU: 0.7894
Generated text: Once upon a time, there was a big. She to the, "


The. She. He was a to the and the. He, " and it. They and said, " identifying and to the.

" was a unc to
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 5/24: 100%|██████████| 600/600 [06:41<00:00,  1.49it/s, BLEU=0.8253, Loss=8.5622, Perplexity=5230.0132]


Epoch 5/24, Loss: 9.6049, Perplexity: 15580.8828, Avg BLEU: 0.7908
Generated text: Once upon a time, there was a little.


". She and the to the, ", ". He was a She to the and. They. He to the Cumber and1900, "

The. The it was a big and
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 6/24: 100%|██████████| 600/600 [06:12<00:00,  1.61it/s, BLEU=0.8158, Loss=8.6489, Perplexity=5703.7168]


Epoch 6/24, Loss: 9.7896, Perplexity: 16459.5020, Avg BLEU: 0.7911
Generated text: Once upon a time, there was a little.


". She to the, " and the. He was a, "

The to the and it. He and said, ". They to the that. She was a big, but
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 7/24: 100%|██████████| 600/600 [05:43<00:00,  1.74it/s, BLEU=0.8180, Loss=8.3632, Perplexity=4286.2349]


Epoch 7/24, Loss: 9.5320, Perplexity: 17490.5645, Avg BLEU: 0.7930
Generated text: Once upon a time, there was a


The, ". She. They and the to the, " and the. She was a little.

" to the. He. He was a big and said, " lot and the, but
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 8/24: 100%|██████████| 600/600 [05:14<00:00,  1.91it/s, BLEU=0.8352, Loss=8.3906, Perplexity=4405.5327]


Epoch 8/24, Loss: 9.5607, Perplexity: 18269.6738, Avg BLEU: 0.7921
Generated text: Once upon a time, there was a, there.


". She was a little to the, " to the and the. They. She. He and said, " and and the, but to it was a

The. He was
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 9/24: 100%|██████████| 600/600 [05:14<00:00,  1.91it/s, BLEU=0.8219, Loss=8.4899, Perplexity=4865.2246]


Epoch 9/24, Loss: 9.7815, Perplexity: 18930.7188, Avg BLEU: 0.7941
Generated text: Once upon a little
BLEU score: 0.0000
Model weights saved to /kaggle/working/model_weights.pth


Epoch 10/24: 100%|██████████| 600/600 [05:00<00:00,  2.00it/s, BLEU=0.6762, Loss=8.3846, Perplexity=4379.2388]


Epoch 10/24, Loss: 9.6660, Perplexity: 19378.0781, Avg BLEU: 0.7937
Generated text: Once upon a time, there was a little, ".


". She and the. He was a to the, " and the, but. She to the Bass. He it was a
The. They to the and Aff and said,
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 11/24: 100%|██████████| 600/600 [05:14<00:00,  1.91it/s, BLEU=0.8155, Loss=8.4125, Perplexity=4503.2202]


Epoch 11/24, Loss: 9.7564, Perplexity: 19851.4121, Avg BLEU: 0.7923
Generated text: Once upon a time, there was a little, ". She.


". He to the, " and the, but and to the it. He was a big. She to the and said, "

The. She was a representing
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 12/24: 100%|██████████| 600/600 [05:14<00:00,  1.91it/s, BLEU=0.7856, Loss=8.4351, Perplexity=4606.1353]


Epoch 12/24, Loss: 9.9045, Perplexity: 19697.5840, Avg BLEU: 0.7949
Generated text: Once upon a time, there was a little. She, ".


The and the to the, " to the. He. He was a big. She was a
" and the and it. They, " and said, but to the
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 13/24: 100%|██████████| 600/600 [05:00<00:00,  2.00it/s, BLEU=0.8127, Loss=8.4114, Perplexity=4498.0786]


Epoch 13/24, Loss: 9.7854, Perplexity: 19584.6543, Avg BLEU: 0.7924
Generated text: Once upon a time, there was a little, ".


The. She was a He to the, " to the and the. She. He was a

". They and it was so to the 10000 and said, but. They
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 14/24: 100%|██████████| 600/600 [05:15<00:00,  1.90it/s, BLEU=0.7447, Loss=8.5353, Perplexity=5091.2432]


Epoch 14/24, Loss: 10.1703, Perplexity: 19122.6035, Avg BLEU: 0.7923
Generated text: Once upon a time, there was a little. She to the, ".


The and the. He, "

". He was a big. She and the to the park and said, " it was a to the Yok. They
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 15/24: 100%|██████████| 600/600 [05:14<00:00,  1.91it/s, BLEU=0.7250, Loss=8.2829, Perplexity=3955.6033]


Epoch 15/24, Loss: 9.9075, Perplexity: 18588.4824, Avg BLEU: 0.7950
Generated text: Once upon a time, there was a little, ".

". She to the, "


The and the to the. She. He was a big and it and the, but. They to the and said, " Axel. He
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 16/24: 100%|██████████| 600/600 [05:00<00:00,  2.00it/s, BLEU=0.8564, Loss=8.3116, Perplexity=4070.6257]


Epoch 16/24, Loss: 9.4119, Perplexity: 18200.5645, Avg BLEU: 0.7917
Generated text: Once upon a time, ", there was a little. She was a


The. She to the, ". He and the to the and the
" was a big. They. He to the-, " and said, but. They
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 17/24: 100%|██████████| 600/600 [05:14<00:00,  1.90it/s, BLEU=0.7349, Loss=8.5682, Perplexity=5261.7314]


Epoch 17/24, Loss: 10.1470, Perplexity: 17006.2598, Avg BLEU: 0.7977
Generated text: Once upon a time,
BLEU score: 0.1337
Model weights saved to /kaggle/working/model_weights.pth


Epoch 18/24: 100%|██████████| 600/600 [05:15<00:00,  1.90it/s, BLEU=0.7252, Loss=8.4190, Perplexity=4532.4409]


Epoch 18/24, Loss: 9.7477, Perplexity: 16031.6562, Avg BLEU: 0.7949
Generated text: Once upon a time, there was a little, " to the, ".


". He was a big. She and the She to the
The. He. They and said, but and it to the that. She was very to the
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 19/24: 100%|██████████| 600/600 [05:00<00:00,  1.99it/s, BLEU=0.8272, Loss=8.3917, Perplexity=4410.5059]


Epoch 19/24, Loss: 9.5519, Perplexity: 15072.5605, Avg BLEU: 0.7922
Generated text: Once upon a time, there was a little.


" to the, ". He was a big. She to the and said, " and the park. They. She and the

The, "opathy, but it was a window to
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 20/24: 100%|██████████| 600/600 [05:15<00:00,  1.90it/s, BLEU=0.7770, Loss=8.2794, Perplexity=3941.6328]


Epoch 20/24, Loss: 9.4418, Perplexity: 13881.7734, Avg BLEU: 0.7946
Generated text: Once upon a time, there was a little girl, ".


The to the. She was a big. He to the, but and the, " and it. He was so. She to the

". They and said, "
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 21/24: 100%|██████████| 600/600 [05:15<00:00,  1.90it/s, BLEU=0.8325, Loss=8.4358, Perplexity=4609.0928]


Epoch 21/24, Loss: 9.4582, Perplexity: 12644.5137, Avg BLEU: 0.7954
Generated text: Once upon a time, there was a little, ".

". He was a big and the, " to the, but. She and the


The. She was so it to the park and said, " day, but to the
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth


Epoch 22/24: 100%|██████████| 600/600 [05:00<00:00,  1.99it/s, BLEU=0.8365, Loss=8.4632, Perplexity=4737.2271]


Epoch 22/24, Loss: 9.3672, Perplexity: 11464.5439, Avg BLEU: 0.7962
Generated text: Once upon a time, there was a little, ". She was a big to the, "

The. He and the, but the


". They. She to the Vital and said, but it was very. He was a time
BLEU score: 0.0873
Model weights saved to /kaggle/working/model_weights.pth


Epoch 23/24: 100%|██████████| 600/600 [05:15<00:00,  1.90it/s, BLEU=0.6940, Loss=8.4390, Perplexity=4623.7793]


Epoch 23/24, Loss: 9.6457, Perplexity: 10315.8906, Avg BLEU: 0.7940
Generated text: Once upon a time, there was a little
BLEU score: 0.2368
Model weights saved to /kaggle/working/model_weights.pth


Epoch 24/24: 100%|██████████| 600/600 [05:15<00:00,  1.90it/s, BLEU=0.8083, Loss=8.5683, Perplexity=5262.2480]


Epoch 24/24, Loss: 9.4212, Perplexity: 9237.1982, Avg BLEU: 0.7938
Generated text: Once upon a time, there was a little girl.


". She was a big, " to the, "
The and the park. He was very to the park and said, but. She. They and the

The to the
BLEU score: 0.1881
Model weights saved to /kaggle/working/model_weights.pth
Training complete.


In [5]:
model_path = "model_weights.pth"
torch.save(model.state_dict(), model_path)
print(f"Model weights saved to {model_path}")

Model weights saved to model_weights.pth


In [6]:
import torch
import torch.nn.functional as F
from collections import defaultdict

def create_future_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return mask.to(device)

import torch
import torch.nn.functional as F
from collections import defaultdict
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def calculate_bleu(reference, hypothesis):
    reference = [reference]  # BLEU expects a list of references
    smoothie = SmoothingFunction().method4
    return sentence_bleu(reference, hypothesis, smoothing_function=smoothie)

def contains_repeated_ngram(seq, n):
    ngrams = set()
    for i in range(len(seq) - n + 1):
        ngram = tuple(seq[i:i+n].tolist())
        if ngram in ngrams:
            return True
        ngrams.add(ngram)
    return False

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, min_p=0.0):
    """Filter a distribution of logits using top-k, top-p (nucleus), and min-p filtering"""
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = -float('Inf')

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        if min_p > 0.0:
            sorted_indices_to_remove &= (sorted_logits < min_p).cumsum(dim=-1).bool()

        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = -float('Inf')
        
    if min_p > 0.0:
        logits[logits < min_p] = -float('Inf')

    return logits

def apply_repetition_penalty(logits, seq, repetition_penalty):
    """Apply a penalty to the logits to discourage repetition"""
    for token_id in seq:
        logits[0, token_id] /= repetition_penalty
    return logits

def beam_search(model, tokenizer, input_text, beam_width=5, max_len=50, length_penalty=1.0, no_repeat_ngram_size=3, top_k=50, top_p=0.8, min_p=0.2, temperature=0.6, repetition_penalty=1.1, diversity_rate=0.1):
    input_ids = tokenizer(input_text, return_tensors='pt').input_ids.to(device)
    input_ids = input_ids[:, :-1]  # Remove the last token for autoregressive generation

    beam = [(input_ids, 0, [])]  # (input_ids, score, generated tokens)
    completed_sequences = []
    diversity_penalty = defaultdict(lambda: 0)

    for step in range(max_len):
        new_beam = []
        for seq, score, generated_tokens in beam:
            with torch.no_grad():
                outputs = model(seq, create_future_mask(seq.size(1)).to(device))
            logits = outputs[:, -1, :]  # Get the logits for the last token
            logits = logits / temperature
            logits = apply_repetition_penalty(logits, seq[0], repetition_penalty)
            logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_p=min_p)
            probs = F.log_softmax(logits, dim=-1)
            topk_probs, topk_ids = probs.topk(beam_width)

            for i in range(beam_width):
                next_seq = torch.cat([seq, topk_ids[:, i:i+1]], dim=-1)
                new_score = score + topk_probs[0, i].item()
                new_generated_tokens = generated_tokens + [topk_ids[0, i].item()]

                if no_repeat_ngram_size > 0 and contains_repeated_ngram(next_seq[0], no_repeat_ngram_size):
                    continue  # Skip sequences with repeated n-grams

                # Diversity penalty
                diversity_penalty[tuple(map(tuple, next_seq.tolist()))] += diversity_rate * step
                new_score -= diversity_penalty[tuple(map(tuple, next_seq.tolist()))]

                new_beam.append((next_seq, new_score, new_generated_tokens))

        if not new_beam:
            break  # Break the loop if no new sequences are generated

        beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_width]

        # Check for completed sequences (sequences that have the end token)
        for seq, score, generated_tokens in beam:
            if seq[0, -1] == tokenizer.eos_token_id:
                length_normalized_score = score / (seq.size(1) ** length_penalty)
                completed_sequences.append((seq, length_normalized_score, generated_tokens))

        # Keep only the sequences that are not completed
        beam = [b for b in beam if b[0][0, -1] != tokenizer.eos_token_id]

        # Early stopping if all sequences are completed
        if not beam:
            break

    if completed_sequences:
        best_seq = sorted(completed_sequences, key=lambda x: x[1], reverse=True)[0]
    else:
        if beam:
            best_seq = beam[0]  # Fallback to the best beam
        else:
            return ""  # Return an empty string if no valid sequence is found

    best_seq_tokens = best_seq[2]
    reference = tokenizer.encode(input_text)  # Use the input text as the reference
    bleu_score = calculate_bleu(reference, best_seq_tokens)

    output_text = tokenizer.decode(best_seq[0].squeeze(), skip_special_tokens=True)
    return output_text, bleu_score



generated_text, bleu_score = beam_search(model, tokenizer, "Once upon a time", beam_width=5, max_len=50)
print(f"Generated text: {generated_text}")
print(f"BLEU score: {bleu_score:.4f}")


RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

In [None]:
!pip install einops