In [1]:
%pwd

'd:\\software_3\\Generative_models\\Text_models\\chat_gpt2\\distilled_gpt2'

In [2]:
import os

os.chdir("../")

In [3]:
%pwd

'd:\\software_3\\Generative_models\\Text_models\\chat_gpt2'

# Model Distillation

This Notebook implements the Model Distillation Process for the LLM.

In [None]:
import torch
import numpy as np
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import tensorflow as tf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from gpt import GPTModelWithHiddenState
from model_args import BASE_CONFIG
from utils.generate import generate
from utils.download_dataset import download_and_load_dataset
from utils.token_converter import get_tokenizer, text_to_token_ids, token_ids_to_text

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
def format_input(entry):
    """Format instruction dataset entry into a training prompt"""
    instruction = entry.get("instruction", "")
    input_text = entry.get("input", "")
    
    if input_text:
        # Case where there's both instruction and input
        formatted_text = (
            f"Below is an instruction that describes a task, paired with an input that provides further context. "
            f"Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{instruction}\n\n"
            f"### Input:\n{input_text}\n\n"
            f"### Response:\n"
        )
    else:
        # Case where there's only instruction
        formatted_text = (
            f"Below is an instruction that describes a task. "
            f"Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{instruction}\n\n"
            f"### Response:\n"
        )
    
    return formatted_text

In [7]:
class DistillationDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=1024):
        self.data = []
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        print(f"Processing {len(data)} examples...")
        
        for idx, item in enumerate(data):
            try:
                # Format the instruction + input as prompt
                prompt_text = format_input(item)
                
                # Get the expected output/response
                output_text = item.get("output", "")
                if not output_text:
                    continue  # Skip entries without output
                
                # Tokenize prompt and full sequence
                prompt_tokens = self._safe_tokenize(prompt_text, tokenizer)
                full_text = prompt_text + output_text
                full_tokens = self._safe_tokenize(full_text, tokenizer)
                
                # Truncate if too long
                if len(full_tokens) > max_length:
                    full_tokens = full_tokens[:max_length]
                    # Recalculate prompt length after truncation
                    prompt_tokens = self._safe_tokenize(prompt_text, tokenizer)
                    if len(prompt_tokens) > max_length // 2:
                        prompt_tokens = prompt_tokens[:max_length // 2]
                
                # Create labels (for computing loss only on output tokens)
                labels = [-100] * len(prompt_tokens)
                
                # Add response tokens to labels
                if len(full_tokens) > len(prompt_tokens):
                    response_tokens = full_tokens[len(prompt_tokens):]
                    labels.extend(response_tokens)
                
                # Ensure labels and input_ids have same length
                min_len = min(len(labels), len(full_tokens))
                labels = labels[:min_len]
                full_tokens = full_tokens[:min_len]
                
                if len(full_tokens) > 0 and len(labels) > 0:
                    self.data.append({
                        'input_ids': full_tokens,
                        'labels': labels,
                        'prompt_length': len(prompt_tokens)
                    })
                    
            except Exception as e:
                print(f"Error processing item {idx}: {e}")
                continue
        
        print(f"Successfully processed {len(self.data)} examples")
    
    def _safe_tokenize(self, text, tokenizer):
        """Safely tokenize text and return as list"""
        try:
            # Assuming text_to_token_ids is your tokenization function
            tokens = text_to_token_ids(text, tokenizer)
            
            # Convert to list if it's a tensor
            if hasattr(tokens, 'tolist'):
                return tokens.tolist()
            elif isinstance(tokens, torch.Tensor):
                return tokens.cpu().numpy().tolist()
            elif isinstance(tokens, (list, tuple)):
                return list(tokens)
            else:
                return [tokens] if isinstance(tokens, int) else []
        except Exception as e:
            print(f"Tokenization error: {e}")
            return []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [8]:
def split_dataset(data):
    train_portion = int(len(data) * 0.8)
    test_portion = int(len(data) * 0.1)
    
    train_data = data[:train_portion]
    test_data = data[train_portion:train_portion+test_portion]
    val_data = data[train_portion+test_portion:]

    print("train dataset length:", len(train_data))
    print("length of test data:", len(test_data))
    print("length of val data:", len(val_data))

    return train_data, test_data, val_data

In [27]:
def distillation_collate_fn(batch, pad_token_id=50256, max_length=1024, device='cpu'):
    if not batch:
        return None

    input_ids_list = []
    labels_list = []
    prompt_lengths = []

    for item in batch:
        input_ids = torch.tensor(item['input_ids'], dtype=torch.long).flatten()
        labels = torch.tensor(item['labels'], dtype=torch.long).flatten()
        input_ids_list.append(input_ids)
        labels_list.append(labels)
        prompt_lengths.append(item['prompt_length'])

    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids_list, batch_first=True, padding_value=pad_token_id
    )
    labels = torch.nn.utils.rnn.pad_sequence(
        labels_list, batch_first=True, padding_value=-100
    )

    if max_length is not None:
        input_ids = input_ids[:, :max_length]
        labels = labels[:, :max_length]

    return {
        'input_ids': input_ids.to(device),
        'labels': labels.to(device),
        'prompt_lengths': prompt_lengths
    }

In [10]:
def kl_div_loss(T, student_logits, teacher_logits):
    """Compute KL divergence loss"""
    student_soft = F.log_softmax(student_logits / T, dim=-1)
    teacher_soft = F.softmax(teacher_logits / T, dim=-1)
    
    kl_loss = F.kl_div(
        student_soft,
        teacher_soft,
        reduction='batchmean'
    ) * (T ** 2)
    
    return kl_loss

In [11]:
def compute_feat_loss(student_hidden_states, teacher_hidden_states):
    """Compute feature distillation loss between hidden states"""
    if not student_hidden_states or not teacher_hidden_states:
        return 0
        
    if len(student_hidden_states) != len(teacher_hidden_states):
        # If different number of layers, use only the minimum
        min_layers = min(len(student_hidden_states), len(teacher_hidden_states))
        student_hidden_states = student_hidden_states[:min_layers]
        teacher_hidden_states = teacher_hidden_states[:min_layers]
    
    total_loss = 0
    valid_layers = 0
    
    for s_hidden, t_hidden in zip(student_hidden_states, teacher_hidden_states):
        # Ensure same shape
        if s_hidden.shape != t_hidden.shape:
            continue
            
        cos_sim = F.cosine_similarity(s_hidden, t_hidden, dim=-1)
        layer_loss = torch.mean(1 - cos_sim)
        total_loss += layer_loss
        valid_layers += 1
    
    return total_loss / valid_layers if valid_layers > 0 else 0

In [12]:
def train_step(batch, student_model, teacher_model, student_optimizer, config):
    """Perform one training step for instruction-following dataset"""
    
    if batch is None:
        return None
        
    input_ids = batch['input_ids']  # Shape: [batch_size, seq_len]
    labels = batch['labels']        # Shape: [batch_size, seq_len]
    
    # Forward pass through both models
    try:
        with torch.no_grad():
            teacher_logits, teacher_hidden_states = teacher_model(input_ids)
        
        student_logits, student_hidden_states = student_model(input_ids)
        
        # Compute losses
        # 1. KL Divergence Loss (knowledge distillation)
        kl_loss = kl_div_loss(config.T, student_logits, teacher_logits)
        
        # 2. Feature Distillation Loss
        feat_loss = compute_feat_loss(student_hidden_states, teacher_hidden_states)
        
        # 3. Cross-entropy loss with labels (only on response tokens)
        ce_loss = 0
        if labels.numel() > 0:
            # Flatten for cross entropy
            student_logits_flat = student_logits.view(-1, student_logits.size(-1))
            labels_flat = labels.view(-1)
            
            # Only compute loss on non-ignored tokens (labels != -100)
            mask = labels_flat != -100
            if mask.any():
                ce_loss = F.cross_entropy(
                    student_logits_flat[mask], 
                    labels_flat[mask]
                )
        
        # Combined loss
        feat_loss_val = feat_loss if isinstance(feat_loss, torch.Tensor) else torch.tensor(0.0, device=input_ids.device)
        ce_loss_val = ce_loss if isinstance(ce_loss, torch.Tensor) else torch.tensor(0.0, device=input_ids.device)
        
        total_loss = (
            config.alpha * ce_loss_val + 
            (1 - config.alpha) * kl_loss + 
            config.beta * feat_loss_val
        )
        
        # Backpropagation
        student_optimizer.zero_grad()
        total_loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
        
        student_optimizer.step()
        
        return {
            'total_loss': total_loss.item(),
            'kl_loss': kl_loss.item(),
            'feat_loss': feat_loss_val.item() if isinstance(feat_loss_val, torch.Tensor) else 0,
            'ce_loss': ce_loss_val.item() if isinstance(ce_loss_val, torch.Tensor) else 0
        }
        
    except Exception as e:
        print(f"Error in train_step: {e}")
        return None

In [13]:
class DistillTrainer:
    """Knowledge Distillation Trainer"""
    
    def __init__(self, config, student_model, teacher_model, tokenizer):
        self.config = config
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.tokenizer = tokenizer
        
        # Freeze teacher model
        for param in self.teacher_model.parameters():
            param.requires_grad = False
        
        # Initialize optimizer
        self.student_optimizer = torch.optim.AdamW(
            self.student_model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=0.01
        )
        
        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.student_optimizer, T_max=100
        )
    
    def train(self, dataloader, num_epochs=1):
        """Train the student model using knowledge distillation"""
        
        self.student_model.train()
        self.teacher_model.eval()
        
        for epoch in range(num_epochs):
            epoch_metrics = []
            
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
            
            for batch_idx, batch in enumerate(progress_bar):
                try:
                    metrics = train_step(
                        batch,
                        self.student_model,
                        self.teacher_model,
                        self.student_optimizer,
                        self.config
                    )
                    
                    if metrics is not None:
                        epoch_metrics.append(metrics)
                        
                        # Update progress bar
                        if len(epoch_metrics) > 0:
                            avg_loss = np.mean([m['total_loss'] for m in epoch_metrics[-10:]])
                            progress_bar.set_postfix({'avg_loss': f'{avg_loss:.4f}'})
                        
                        # Log detailed metrics every 50 batches
                        if batch_idx % 50 == 0 and len(epoch_metrics) >= 10:
                            recent_metrics = epoch_metrics[-10:]
                            avg_metrics = {
                                k: np.mean([m[k] for m in recent_metrics])
                                for k in recent_metrics[0].keys()
                            }
                            print(f"\nBatch {batch_idx}: {avg_metrics}")
                    
                except Exception as e:
                    print(f"Error in batch {batch_idx}: {str(e)}")
                    continue
            
            # Step scheduler
            self.scheduler.step()
            
            # Epoch summary
            if epoch_metrics:
                epoch_avg = {
                    k: np.mean([m[k] for m in epoch_metrics])
                    for k in epoch_metrics[0].keys()
                }
                print(f"\nEpoch {epoch+1} Summary: {epoch_avg}")
            else:
                print(f"Epoch {epoch+1}: No successful batches processed")

In [14]:
def create_dataloaders(train_data, test_data, val_data, batch_size, tokenizer, device):
    """Create data loaders for instruction-following dataset"""
    
    # Create custom collate function with device
    collate_fn = partial(
        distillation_collate_fn,
        device=device,
        max_length=1024
    )
    
    # Create datasets with max_length parameter
    train_dataset = DistillationDataset(train_data, tokenizer, max_length=1024)
    test_dataset = DistillationDataset(test_data, tokenizer, max_length=1024)
    val_dataset = DistillationDataset(val_data, tokenizer, max_length=1024)
    
    # Create data loaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True,
        num_workers=0
    )
    
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=0
    )
    
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=0
    )
    
    return train_dataloader, test_dataloader, val_dataloader

In [15]:
class DistillConfig:
    """Configuration for knowledge distillation"""
    
    def __init__(
        self, 
        T=3.0,           # Temperature for knowledge distillation
        alpha=0.3,       # Weight for cross-entropy loss
        beta=0.1,        # Weight for feature distillation loss
        learning_rate=1e-5,
        vocab_size=50257,
        batch_size=4,
        max_length=1024
    ):
        self.T = T
        self.alpha = alpha
        self.beta = beta
        self.learning_rate = learning_rate
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.max_length = max_length

In [16]:
config = DistillConfig(
    learning_rate=1e-5,
    batch_size=4,  # Start small to avoid memory issues
    max_length=1024,
    T=3.0,
    alpha=0.3,
    beta=0.1
)

In [17]:
tokenizer = get_tokenizer()

In [18]:
model_path = "gpt_models\\SFT_model.pth"
teacher_model = GPTModelWithHiddenState(BASE_CONFIG)
teacher_model.load_state_dict(
    torch.load(
    model_path,
    map_location=torch.device("cpu"),
    weights_only=True
    )['model_state_dict']
)
teacher_model.to(device)
teacher_model.eval()

GPTModelWithHiddenState(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linea

In [19]:
model_path = "gpt_models\\Foundational_model.pth"
student_model = GPTModelWithHiddenState(BASE_CONFIG)
student_model.load_state_dict(
    torch.load(
    model_path,
    map_location=torch.device('cpu'),
    weights_only=True
    )
)
student_model.to(device)
student_model.eval()

GPTModelWithHiddenState(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linea

In [20]:
tokenizer = get_tokenizer()
device = "cuda" if torch.cuda.is_available() else "cpu"

In [28]:

data = download_and_load_dataset("instruction-data.json")
train_data, test_data, val_data = split_dataset(data)

train_dataloader, test_dataloader, val_dataloader = create_dataloaders(
    train_data, test_data, val_data, 
    config.batch_size, tokenizer, device
)



train dataset length: 880
length of test data: 110
length of val data: 110
Processing 880 examples...
Successfully processed 880 examples
Processing 110 examples...
Successfully processed 110 examples
Processing 110 examples...
Successfully processed 110 examples


In [22]:
trainer = DistillTrainer(config, student_model, teacher_model, tokenizer)

In [29]:
NUM_EPOCHS = 1

trainer.train(train_dataloader, num_epochs=3)
torch.save({
    'model_state_dict': student_model.state_dict(),
    'config': config.__dict__,
    'epoch': 3
}, 'distilled_gpt_model.pth')
print("-------------------------------------------------------------------")
print()
print("Training has been completed and model has beed saved")

Epoch 1/3:  23%|██▎       | 51/220 [02:22<10:47,  3.83s/it, avg_loss=139.6129]


Batch 50: {'total_loss': np.float64(139.61294326782226), 'kl_loss': np.float64(199.41669616699218), 'feat_loss': np.float64(0.21258984357118607), 'ce_loss': np.float64(0.0)}


Epoch 1/3:  46%|████▌     | 101/220 [04:39<04:57,  2.50s/it, avg_loss=93.0867]


Batch 100: {'total_loss': np.float64(93.08665161132812), 'kl_loss': np.float64(132.95119705200196), 'feat_loss': np.float64(0.20815364718437196), 'ce_loss': np.float64(0.0)}


Epoch 1/3:  69%|██████▊   | 151/220 [06:52<03:14,  2.81s/it, avg_loss=85.7081] 


Batch 150: {'total_loss': np.float64(85.70813751220703), 'kl_loss': np.float64(122.4114891052246), 'feat_loss': np.float64(0.20095255374908447), 'ce_loss': np.float64(0.0)}


Epoch 1/3:  91%|█████████▏| 201/220 [09:07<00:56,  2.97s/it, avg_loss=76.9813]


Batch 200: {'total_loss': np.float64(76.98127250671386), 'kl_loss': np.float64(109.94473037719726), 'feat_loss': np.float64(0.19961598962545396), 'ce_loss': np.float64(0.0)}


Epoch 1/3: 100%|██████████| 220/220 [10:05<00:00,  2.75s/it, avg_loss=79.8474]



Epoch 1 Summary: {'total_loss': np.float64(121.42905696522106), 'kl_loss': np.float64(173.4404061057351), 'feat_loss': np.float64(0.2077528475360437), 'ce_loss': np.float64(0.0)}


Epoch 2/3:  23%|██▎       | 51/220 [02:16<07:09,  2.54s/it, avg_loss=69.1071]


Batch 50: {'total_loss': np.float64(69.10713157653808), 'kl_loss': np.float64(98.69687881469727), 'feat_loss': np.float64(0.19318017661571502), 'ce_loss': np.float64(0.0)}


Epoch 2/3:  46%|████▌     | 101/220 [04:32<05:15,  2.65s/it, avg_loss=67.3134]


Batch 100: {'total_loss': np.float64(67.31340751647949), 'kl_loss': np.float64(96.1349910736084), 'feat_loss': np.float64(0.18915003538131714), 'ce_loss': np.float64(0.0)}


Epoch 2/3:  69%|██████▊   | 151/220 [06:41<03:04,  2.67s/it, avg_loss=49.5978]


Batch 150: {'total_loss': np.float64(49.59783706665039), 'kl_loss': np.float64(70.82784690856934), 'feat_loss': np.float64(0.18344430178403853), 'ce_loss': np.float64(0.0)}


Epoch 2/3:  91%|█████████▏| 201/220 [08:53<00:54,  2.88s/it, avg_loss=53.6883]


Batch 200: {'total_loss': np.float64(53.688341522216795), 'kl_loss': np.float64(76.67213249206543), 'feat_loss': np.float64(0.17850395441055297), 'ce_loss': np.float64(0.0)}


Epoch 2/3: 100%|██████████| 220/220 [09:46<00:00,  2.67s/it, avg_loss=56.8292]



Epoch 2 Summary: {'total_loss': np.float64(60.624206126819956), 'kl_loss': np.float64(86.57930526733398), 'feat_loss': np.float64(0.18693537122823975), 'ce_loss': np.float64(0.0)}


Epoch 3/3:  23%|██▎       | 51/220 [02:16<06:15,  2.22s/it, avg_loss=48.6724]


Batch 50: {'total_loss': np.float64(48.67239990234375), 'kl_loss': np.float64(69.5071014404297), 'feat_loss': np.float64(0.1742960423231125), 'ce_loss': np.float64(0.0)}


Epoch 3/3:  46%|████▌     | 101/220 [04:52<06:25,  3.24s/it, avg_loss=46.4936]


Batch 100: {'total_loss': np.float64(46.493560791015625), 'kl_loss': np.float64(66.3950912475586), 'feat_loss': np.float64(0.16996641159057618), 'ce_loss': np.float64(0.0)}


Epoch 3/3:  69%|██████▊   | 151/220 [07:02<02:49,  2.45s/it, avg_loss=54.7707]


Batch 150: {'total_loss': np.float64(54.77074508666992), 'kl_loss': np.float64(78.21990585327148), 'feat_loss': np.float64(0.16812762022018432), 'ce_loss': np.float64(0.0)}


Epoch 3/3:  91%|█████████▏| 201/220 [08:58<00:38,  2.05s/it, avg_loss=45.5694]


Batch 200: {'total_loss': np.float64(45.569393157958984), 'kl_loss': np.float64(65.07554397583007), 'feat_loss': np.float64(0.16513274163007735), 'ce_loss': np.float64(0.0)}


Epoch 3/3: 100%|██████████| 220/220 [09:40<00:00,  2.64s/it, avg_loss=46.4272]



Epoch 3 Summary: {'total_loss': np.float64(47.3233569318598), 'kl_loss': np.float64(67.5805571642789), 'feat_loss': np.float64(0.16967466670003803), 'ce_loss': np.float64(0.0)}
-------------------------------------------------------------------

Training has been completed and model has beed saved
