<a href="https://www.kaggle.com/code/evanupham/gpt-tiny-story?scriptVersionId=187276834" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader, Dataset, Subset
from datasets import load_dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer, get_linear_schedule_with_warmup
import numpy as np
# Define the custom dataset class
class TinyStoriesDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length):
        self.texts = [text for text in texts if text.strip() != '']  # Filter out empty sequences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, truncation=True, padding='max_length')
        input_ids = tokens.input_ids.squeeze(0)  # Ensure the correct dimension
        attention_mask = tokens.attention_mask.squeeze(0)  # Ensure the correct dimension
        return input_ids, attention_mask

# Load the dataset
dataset = load_dataset('roneneldan/TinyStories')

# Initialize the GPT tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Prepare the dataset
max_length = 1000
train_texts = dataset['train']['text']
train_dataset = TinyStoriesDataset(train_texts, tokenizer, max_length)

# Create data loader
batch_size = 2
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Define the GPT-2 model with dropout
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_groups=2, dropout=0.1):
        super(GroupedQueryAttention, self).__init__()
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.d_model = d_model

        assert d_model % (num_heads * num_groups) == 0

        self.depth = d_model // (num_heads * num_groups)

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.dense = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_groups, self.num_heads, self.depth)
        return x.permute(0, 2, 3, 1, 4)  # (batch_size, num_groups, num_heads, seq_len, depth)

    def forward(self, q, k, v, mask):
        batch_size = q.size(0)

        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.depth)

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)  # (batch_size, 1, 1, seq_len, seq_len)
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = torch.nn.functional.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, v)

        output = output.permute(0, 3, 1, 2, 4).contiguous().view(batch_size, -1, self.d_model)
        output = self.dense(output)

        return output, attention_weights

    
class MoE(nn.Module):
    def __init__(self, d_model, d_ff, n_experts=4, dropout=0.3, temperature=1.0):
        super(MoE, self).__init__()
        self.n_experts = n_experts
        self.temperature = temperature
        self.experts = nn.ModuleList([nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        ) for _ in range(n_experts)])
        
        self.gating_network = nn.Linear(d_model, n_experts)
        
    def forward(self, x):
        # Compute the gating weights
        gate_logits = self.gating_network(x)
        gate_outputs = F.gumbel_softmax(gate_logits, tau=self.temperature, hard=False)
        
        # Compute the expert outputs
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-1)
        
        # Combine expert outputs weighted by the gating network
        output = torch.einsum('bld,blnd->bln', gate_outputs, expert_outputs)
        
        return output

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, n_experts=5, dropout=0.3, temperature=0.8):
        super(FeedForward, self).__init__()
        self.moe_layer = MoE(d_model, d_ff, n_experts, dropout, temperature)
        
    def forward(self, x):
        return self.moe_layer(x)


class GPTBlock(nn.Module):
    def __init__(self, d_model, num_heads, num_groups, d_ff, dropout=0.3):
        super(GPTBlock, self).__init__()
        self.attention = GroupedQueryAttention(d_model, num_heads, num_groups, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, n_experts=4, dropout=dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        attn_output, _ = self.attention(x, x, x, mask)
        out1 = self.norm1(x + attn_output)
        ffn_output = self.ffn(out1)
        out2 = self.norm2(out1 + ffn_output)
        return out2

class GPT2(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_groups, d_ff, num_layers, max_len=5000, dropout=0.3):
        super(GPT2, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([GPTBlock(d_model, num_heads, num_groups, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm(x)
        x = self.dropout(x)
        return self.fc(x)

def create_future_mask(size):
    mask = torch.tril(torch.ones(size, size)).unsqueeze(0)
    return mask  # (1, size, size)

vocab_size = len(tokenizer)
d_model = 1536  # GPT-2 small model size
num_heads = 6
d_ff = 3072
num_layers = 12
max_len = 1024
num_groups = 4
model = GPT2(vocab_size, d_model, num_heads, num_groups, d_ff, num_layers, max_len)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
import os
model_path = "/kaggle/working/model_weights_1536.pth"
# # Load the model weights if they exist
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print(f"Model weights loaded from {model_path}")

# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training loop with progress bar
model.train()

# Freeze all layers initially
for param in model.parameters():
    param.requires_grad = False

def contains_repeated_ngram(seq, n):
    ngrams = set()
    for i in range(len(seq) - n + 1):
        ngram = tuple(seq[i:i+n].tolist())
        if ngram in ngrams:
            return True
        ngrams.add(ngram)
    return False

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, min_p=0.0):
    """Filter a distribution of logits using top-k, top-p (nucleus), and min-p filtering"""
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = -float('Inf')

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        if min_p > 0.0:
            sorted_indices_to_remove &= (sorted_logits < min_p).cumsum(dim=-1).bool()

        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = -float('Inf')
        
    if min_p > 0.0:
        logits[logits < min_p] = -float('Inf')

    return logits

def apply_repetition_penalty(logits, seq, repetition_penalty):
    """Apply a penalty to the logits to discourage repetition"""
    for token_id in seq:
        logits[0, token_id] /= repetition_penalty
    return logits

import torch
import torch.nn.functional as F
from collections import defaultdict
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def calculate_bleu(reference, hypothesis):
    reference = [reference]  # BLEU expects a list of references
    smoothie = SmoothingFunction().method4
    return sentence_bleu(reference, hypothesis, smoothing_function=smoothie)

def contains_repeated_ngram(seq, n):
    ngrams = set()
    for i in range(len(seq) - n + 1):
        ngram = tuple(seq[i:i+n].tolist())
        if ngram in ngrams:
            return True
        ngrams.add(ngram)
    return False

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, min_p=0.0):
    """Filter a distribution of logits using top-k, top-p (nucleus), and min-p filtering"""
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = -float('Inf')

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        if min_p > 0.0:
            sorted_indices_to_remove &= (sorted_logits < min_p).cumsum(dim=-1).bool()

        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = -float('Inf')
        
    if min_p > 0.0:
        logits[logits < min_p] = -float('Inf')

    return logits

def apply_repetition_penalty(logits, seq, repetition_penalty):
    """Apply a penalty to the logits to discourage repetition"""
    for token_id in seq:
        logits[0, token_id] /= repetition_penalty
    return logits

def beam_search(model, tokenizer, input_text, beam_width=5, max_len=100, length_penalty=1.2, no_repeat_ngram_size=3, top_k=70, top_p=0.7, min_p=0.1, temperature=0.8, repetition_penalty=1.2, diversity_rate=0.3):
    input_ids = tokenizer(input_text, return_tensors='pt').input_ids.to(device)
    input_ids = input_ids[:, :-1]  # Remove the last token for autoregressive generation

    beam = [(input_ids, 0, [])]  # (input_ids, score, generated tokens)
    completed_sequences = []
    diversity_penalty = defaultdict(lambda: 0)

    for step in range(max_len):
        new_beam = []
        for seq, score, generated_tokens in beam:
            with torch.no_grad():
                outputs = model(seq, create_future_mask(seq.size(1)).to(device))
            logits = outputs[:, -1, :]  # Get the logits for the last token
            logits = logits / temperature
            logits = apply_repetition_penalty(logits, seq[0], repetition_penalty)
            logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_p=min_p)
            probs = F.log_softmax(logits, dim=-1)
            topk_probs, topk_ids = probs.topk(beam_width)

            for i in range(beam_width):
                next_seq = torch.cat([seq, topk_ids[:, i:i+1]], dim=-1)
                new_score = score + topk_probs[0, i].item()
                new_generated_tokens = generated_tokens + [topk_ids[0, i].item()]

                if no_repeat_ngram_size > 0 and contains_repeated_ngram(next_seq[0], no_repeat_ngram_size):
                    continue  # Skip sequences with repeated n-grams

                # Diversity penalty
                diversity_penalty[tuple(map(tuple, next_seq.tolist()))] += diversity_rate * step
                new_score -= diversity_penalty[tuple(map(tuple, next_seq.tolist()))]

                new_beam.append((next_seq, new_score, new_generated_tokens))

        if not new_beam:
            break  # Break the loop if no new sequences are generated

        beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_width]

        # Check for completed sequences (sequences that have the end token)
        for seq, score, generated_tokens in beam:
            if seq[0, -1] == tokenizer.eos_token_id:
                length_normalized_score = score / (seq.size(1) ** length_penalty)
                completed_sequences.append((seq, length_normalized_score, generated_tokens))

        # Keep only the sequences that are not completed
        beam = [b for b in beam if b[0][0, -1] != tokenizer.eos_token_id]

        # Early stopping if all sequences are completed
        if not beam:
            break

    if completed_sequences:
        best_seq = sorted(completed_sequences, key=lambda x: x[1], reverse=True)[0]
    else:
        if beam:
            best_seq = beam[0]  # Fallback to the best beam
        else:
            return ""  # Return an empty string if no valid sequence is found

    best_seq_tokens = best_seq[2]
    reference = tokenizer.encode(input_text)  # Use the input text as the reference
    bleu_score = calculate_bleu(reference, best_seq_tokens)

    output_text = tokenizer.decode(best_seq[0].squeeze(), skip_special_tokens=True)
    return output_text, bleu_score

def set_requires_grad(model, layer_idx, requires_grad):
    for i, layer in enumerate(model.layers):
        for param in layer.parameters():
            param.requires_grad = (i == layer_idx) and requires_grad

def get_custom_training_sequence(num_layers):
    sequence = []
    i = 1
    increment = 1
    while len(sequence) < 2 * num_layers:
        sequence.append(i)
        i += increment
        if increment == 1:
            increment = -1
        else:
            increment = 2
        if i > num_layers:
            i = num_layers - 1
            increment = 1
    return sequence

num_layers = len(model.layers)
training_sequence = get_custom_training_sequence(num_layers)

num_epochs = len(training_sequence)
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=total_steps // 10, num_training_steps=total_steps)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
dataset_size = 1000  # Number of samples per epoch

def parabolic_scaling(epoch, num_epochs):
    mid_epoch = num_epochs // 2
    return -2 * ((epoch - mid_epoch) ** 2) / (num_epochs ** 2) + 1
import textstat

def parabolic_scale_readability_score(text, target_grade, grade_range):
    # Calculate the Flesch-Kincaid Grade Level
    fk_grade = textstat.flesch_kincaid_grade(text)
    
    # Parabolic scaling
    scaled_score = 1 - ((fk_grade - target_grade) / grade_range) ** 2
    
    # Clip the scaled score to be within the range [-1, 1]
    scaled_score = max(min(scaled_score, 1), -1)
    
    return scaled_score
from rouge_score import rouge_scorer

# Define the ROUGE score calculation function
def calculate_rouge(reference, hypothesis, tokenizer):
    reference_text = tokenizer.decode(reference, skip_special_tokens=True)
    hypothesis_text = tokenizer.decode(hypothesis, skip_special_tokens=True)
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = scorer.score(reference_text, hypothesis_text)
    return scores

# Define the BLEU score calculation function
def calculate_bleu(reference, hypothesis):
    reference = [reference]  # BLEU expects a list of references
    smoothie = SmoothingFunction().method4
    return sentence_bleu(reference, hypothesis, smoothing_function=smoothie)

def z_loss(logits, beta=1e-4):
    """Z-Loss regularizes logits to prevent extreme values."""
    log_z = torch.logsumexp(logits, dim=-1)
    return beta * log_z.pow(2).mean()
start_epoch = 0
for epoch in range(start_epoch, num_epochs):
    # Determine which layer to unfreeze according to the training sequence
    layer_to_unfreeze = training_sequence[epoch] - 1
    set_requires_grad(model, layer_to_unfreeze, True)
    
    # Create a new subset of the dataset
    indices = np.random.choice(len(train_dataset), dataset_size, replace=False)
    subset = Subset(train_dataset, indices)
    train_loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
        
    total_loss = 0
    total_bleu_score = 0
    total_rouge_score = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", postfix={"Loss": 0.0000, "Perplexity": 0.0000, "BLEU": 0.0000})
    
    for batch in progress_bar:
        input_ids, attention_mask = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        
        # Shift the input for the next token prediction
        labels = input_ids[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()
        
        # Create future mask
        seq_length = input_ids.size(1)
        mask = create_future_mask(seq_length).to(device)
        
        optimizer.zero_grad()
        outputs = model(input_ids, mask)
        
        # Compute loss
        loss = criterion(outputs.view(-1, vocab_size), labels.view(-1))
        
        # Generate sequences for reward calculation
        generated_ids = outputs.argmax(dim=-1).cpu().numpy()
        references = labels.cpu().numpy()
        batch_bleu_score = 0
        batch_rouge_score = 0
        rewards = []
        for ref, gen in zip(references, generated_ids):
            ref_tokens = ref.tolist()
            gen_tokens = gen.tolist()
            bleu_score = calculate_bleu(ref_tokens, gen_tokens)
            rouge_score = calculate_rouge(ref_tokens, gen_tokens, tokenizer)
            rouge_l_score = rouge_score['rougeL'].fmeasure  # Using ROUGE-L F1-score as the reward
            rewards.append(rouge_l_score)
            batch_bleu_score += bleu_score
            batch_rouge_score += rouge_l_score
        
        avg_bleu_score = batch_bleu_score / len(references)
        avg_rouge_score = batch_rouge_score / len(references)
        
        # REINFORCE algorithm
        rewards = torch.tensor(rewards, dtype=torch.float).to(device)
        log_probs = F.log_softmax(outputs, dim=-1)
        log_probs = log_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)
        policy_loss = -log_probs * rewards.unsqueeze(-1)
        policy_loss = policy_loss.mean() * parabolic_scaling(epoch, num_epochs)
        zloss_value = z_loss(outputs)
        total_loss_with_reward = loss + zloss_value + policy_loss
        total_loss_with_reward.backward()
        optimizer.step()
        scheduler.step()
        
        total_loss += total_loss_with_reward.item()
        total_bleu_score += avg_bleu_score
        total_rouge_score += avg_rouge_score
        avg_loss = total_loss / len(progress_bar)
        perplexity = torch.exp(torch.tensor(loss)).item()
        progress_bar.set_postfix(Loss=f"{loss.item():.4f}", Perplexity=f"{perplexity:.4f}", BLEU=f"{avg_bleu_score:.4f}", ROUGE=f"{avg_rouge_score:.4f}")
    
    avg_loss = total_loss / len(train_loader)
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    avg_bleu_score = total_bleu_score / len(train_loader)
    avg_rouge_score = total_rouge_score / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss_with_reward.item():.4f}, Perplexity: {perplexity:.4f}, Avg BLEU: {avg_bleu_score:.4f}, Avg ROUGE: {avg_rouge_score:.4f}")
    generated_text, bleu_score = beam_search(model, tokenizer, "Once upon a time", beam_width=5, max_len=50)
    print(f"Generated text: {generated_text}")
    print(f"BLEU score: {bleu_score:.4f}")  
    # Save model weights
    torch.save(model.state_dict(), model_path)
    print(f"Model weights saved to {model_path}")
    
    # Freeze the previously unfrozen layer
    set_requires_grad(model, layer_to_unfreeze, False)

print("Training complete.")

Repo card metadata block was not found. Setting CardData to empty.


Model weights loaded from /kaggle/working/model_weights_1536.pth


Epoch 1/24:   0%|          | 0/500 [00:00<?, ?it/s, BLEU=0, Loss=0, Perplexity=0]2024-07-07 20:00:26.276031: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-07 20:00:26.276138: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-07 20:00:26.416140: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  perplexity = torch.exp(torch.tensor(loss)).item()
Epoch 1/24:  49%|████▊     | 243/500 [05:04<05:08,  1.20s/it, BLEU=0.8126, Loss=7.9649, Perplexity=2878.0193, ROUGE=0.1741]

In [3]:
!pip install rouge-score


Collecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: rouge-score
  Building wheel for rouge-score (setup.py) ... [?25ldone
[?25h  Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=f9c752516bf27bae9aee8640d13601ee8df3f26c59d6a2d297cc1e982ae3d2c9
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge-score
Installing collected packages: rouge-score
Successfully installed rouge-score-0.1.2


In [2]:
!pip install textstat

Collecting textstat
  Downloading textstat-0.7.3-py3-none-any.whl.metadata (14 kB)
Collecting pyphen (from textstat)
  Downloading pyphen-0.15.0-py3-none-any.whl.metadata (3.3 kB)
Downloading textstat-0.7.3-py3-none-any.whl (105 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.1/105.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading pyphen-0.15.0-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m29.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: pyphen, textstat
Successfully installed pyphen-0.15.0 textstat-0.7.3


In [None]:
for epoch in range(num_epochs):
    # Determine which layer to unfreeze according to the training sequence
    layer_to_unfreeze = training_sequence[epoch] - 1
    set_requires_grad(model, layer_to_unfreeze, True)
    
    # Create a new subset of the dataset
    indices = np.random.choice(len(train_dataset), dataset_size, replace=False)
    subset = Subset(train_dataset, indices)
    train_loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
        
    total_loss = 0
    total_bleu_score = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", postfix={"Loss": 0.0000, "Perplexity": 0.0000, "BLEU": 0.0000})
    
    for batch in progress_bar:
        input_ids, attention_mask = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        
        # Shift the input for the next token prediction
        labels = input_ids[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()
        
        # Create future mask
        seq_length = input_ids.size(1)
        mask = create_future_mask(seq_length).to(device)
        
        optimizer.zero_grad()
        outputs = model(input_ids, mask)
        
        # Ensure matching dimensions for loss computation
        batch_size, seq_len, vocab_size = outputs.size()
        outputs = outputs.view(-1, vocab_size)
        labels = labels.view(-1)
        
        # Compute loss
        loss = criterion(outputs, labels)
        
        # Generate sequences for reward calculation
        generated_ids = outputs.argmax(dim=-1).view(batch_size, seq_len).cpu().numpy()
        references = labels.view(batch_size, seq_len).cpu().numpy()
        batch_bleu_score = 0
        rewards = []
        for ref, gen in zip(references, generated_ids):
            ref_tokens = ref.tolist()
            gen_tokens = gen.tolist()
            bleu_score = calculate_bleu(ref_tokens, gen_tokens)
            rewards.append(bleu_score)
            batch_bleu_score += bleu_score
        
        avg_bleu_score = batch_bleu_score / len(references)
        
        # REINFORCE algorithm
        rewards = torch.tensor(rewards, dtype=torch.float).to(device)
        log_probs = F.log_softmax(outputs, dim=-1)
        log_probs = log_probs.gather(1, labels.unsqueeze(-1)).squeeze(-1)
        policy_loss = -log_probs * rewards.unsqueeze(-1)
        policy_loss = policy_loss.mean() * parabolic_scaling(epoch, num_epochs)
        zloss_value = z_loss(outputs)
        total_loss_with_reward = loss + zloss_value + policy_loss
        total_loss_with_reward.backward()
        optimizer.step()
        scheduler.step()
        
        total_loss += total_loss_with_reward.item()
        total_bleu_score += avg_bleu_score
        avg_loss = total_loss / len(progress_bar)
        perplexity = torch.exp(torch.tensor(loss)).item()
        progress_bar.set_postfix(Loss=f"{loss.item():.4f}", Perplexity=f"{perplexity:.4f}", BLEU=f"{avg_bleu_score:.4f}")
    
    avg_loss = total_loss / len(train_loader)
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    avg_bleu_score = total_bleu_score / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss_with_reward.item():.4f}, Perplexity: {perplexity:.4f}, Avg BLEU: {avg_bleu_score:.4f}")
    generated_text, bleu_score = beam_search(model, tokenizer, "Once upon a time", beam_width=5, max_len=50)
    print(f"Generated text: {generated_text}")
    print(f"BLEU score: {bleu_score:.4f}")  
    # Save model weights
    torch.save(model.module.state_dict(), model_path)  # Use model.module to access the underlying model
    print(f"Model weights saved to {model_path}")
    
    # Freeze the previously unfrozen layer
    set_requires_grad(model, layer_to_unfreeze, False)

print("Training complete.")


In [None]:
model_path = "model_weights.pth"
torch.save(model.state_dict(), model_path)
print(f"Model weights saved to {model_path}")

In [None]:
import torch
import torch.nn.functional as F
from collections import defaultdict

def create_future_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return mask.to(device)

import torch
import torch.nn.functional as F
from collections import defaultdict
import nltk

def calculate_bleu(reference, hypothesis):
    reference = [reference]  # BLEU expects a list of references
    smoothie = SmoothingFunction().method4
    return sentence_bleu(reference, hypothesis, smoothing_function=smoothie)

def contains_repeated_ngram(seq, n):
    ngrams = set()
    for i in range(len(seq) - n + 1):
        ngram = tuple(seq[i:i+n].tolist())
        if ngram in ngrams:
            return True
        ngrams.add(ngram)
    return False

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, min_p=0.0):
    """Filter a distribution of logits using top-k, top-p (nucleus), and min-p filtering"""
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = -float('Inf')

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        if min_p > 0.0:
            sorted_indices_to_remove &= (sorted_logits < min_p).cumsum(dim=-1).bool()

        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = -float('Inf')
        
    if min_p > 0.0:
        logits[logits < min_p] = -float('Inf')

    return logits

def apply_repetition_penalty(logits, seq, repetition_penalty):
    """Apply a penalty to the logits to discourage repetition"""
    for token_id in seq:
        logits[0, token_id] /= repetition_penalty
    return logits

def beam_search(model, tokenizer, input_text, beam_width=5, max_len=50, length_penalty=1.0, no_repeat_ngram_size=3, top_k=50, top_p=0.8, min_p=0.2, temperature=0.6, repetition_penalty=1.1, diversity_rate=0.1):
    input_ids = tokenizer(input_text, return_tensors='pt').input_ids.to(device)
    input_ids = input_ids[:, :-1]  # Remove the last token for autoregressive generation

    beam = [(input_ids, 0, [])]  # (input_ids, score, generated tokens)
    completed_sequences = []
    diversity_penalty = defaultdict(lambda: 0)

    for step in range(max_len):
        new_beam = []
        for seq, score, generated_tokens in beam:
            with torch.no_grad():
                outputs = model(seq, create_future_mask(seq.size(1)).to(device))
            logits = outputs[:, -1, :]  # Get the logits for the last token
            logits = logits / temperature
            logits = apply_repetition_penalty(logits, seq[0], repetition_penalty)
            logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_p=min_p)
            probs = F.log_softmax(logits, dim=-1)
            topk_probs, topk_ids = probs.topk(beam_width)

            for i in range(beam_width):
                next_seq = torch.cat([seq, topk_ids[:, i:i+1]], dim=-1)
                new_score = score + topk_probs[0, i].item()
                new_generated_tokens = generated_tokens + [topk_ids[0, i].item()]

                if no_repeat_ngram_size > 0 and contains_repeated_ngram(next_seq[0], no_repeat_ngram_size):
                    continue  # Skip sequences with repeated n-grams

                # Diversity penalty
                diversity_penalty[tuple(map(tuple, next_seq.tolist()))] += diversity_rate * step
                new_score -= diversity_penalty[tuple(map(tuple, next_seq.tolist()))]

                new_beam.append((next_seq, new_score, new_generated_tokens))

        if not new_beam:
            break  # Break the loop if no new sequences are generated

        beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_width]

        # Check for completed sequences (sequences that have the end token)
        for seq, score, generated_tokens in beam:
            if seq[0, -1] == tokenizer.eos_token_id:
                length_normalized_score = score / (seq.size(1) ** length_penalty)
                completed_sequences.append((seq, length_normalized_score, generated_tokens))

        # Keep only the sequences that are not completed
        beam = [b for b in beam if b[0][0, -1] != tokenizer.eos_token_id]

        # Early stopping if all sequences are completed
        if not beam:
            break

    if completed_sequences:
        best_seq = sorted(completed_sequences, key=lambda x: x[1], reverse=True)[0]
    else:
        if beam:
            best_seq = beam[0]  # Fallback to the best beam
        else:
            return ""  # Return an empty string if no valid sequence is found

    best_seq_tokens = best_seq[2]
    reference = tokenizer.encode(input_text)  # Use the input text as the reference
    bleu_score = calculate_bleu(reference, best_seq_tokens)

    output_text = tokenizer.decode(best_seq[0].squeeze(), skip_special_tokens=True)
    return output_text, bleu_score



generated_text, bleu_score = beam_search(model, tokenizer, "Once upon a time", beam_width=5, max_len=50)
print(f"Generated text: {generated_text}")
print(f"BLEU score: {bleu_score:.4f}")


In [None]:
!pip install einops