In [None]:
!pip install transformers datasets accelerate torch torchvision torchaudio
!pip install sentencepiece protobuf
!pip install wandb
!pip install bitsandbytes

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    GPT2LMHeadModel, GPT2Tokenizer,
    TrainingArguments, Trainer,
    DataCollatorForLanguageModeling
)
from datasets import load_dataset, Dataset as HFDataset
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
import os
import gc
import warnings
warnings.filterwarnings('ignore')

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

Using device: cuda
GPU: Tesla T4
GPU Memory: 14.7 GB


In [4]:
config = {
    'teacher_model_name': 'gpt2-large',
    'student_model_name': 'gpt2-medium',

    'max_length': 256,
    'batch_size': 1,
    'gradient_accumulation_steps': 16,
    'learning_rate': 5e-5,
    'num_epochs': 3,
    'warmup_steps': 100,
    'max_steps': 1000,

    'temperature': 4.0,
    'alpha': 0.5,
    'beta': 0.5,

    # Data configuration
    'dataset_name': 'yahma/alpaca-cleaned',
    'max_samples': 2000,

    # Output configuration
    'output_dir': './distilled_model',
    'save_steps': 250,
    'logging_steps': 50,
}

In [5]:
dataset = load_dataset(config['dataset_name'])

README.md: 0.00B [00:00, ?B/s]

alpaca_data_cleaned.json:   0%|          | 0.00/44.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/51760 [00:00<?, ? examples/s]

In [6]:
train_dataset = dataset['train'].select(range(min(config['max_samples'], len(dataset['train']))))

In [7]:
tokenizer = AutoTokenizer.from_pretrained(config['teacher_model_name'])

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [8]:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [9]:
def format_alpaca_prompt(instruction, input_text="", output=""):
    """Format the Alpaca-style prompt"""
    if input_text:
        prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
    else:
        prompt = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
    return prompt

In [10]:
def preprocess_function(examples):
    """Preprocess the dataset"""
    prompts = []
    for i in range(len(examples['instruction'])):
        instruction = examples['instruction'][i]
        input_text = examples['input'][i] if examples['input'][i] else ""
        output = examples['output'][i]

        formatted_prompt = format_alpaca_prompt(instruction, input_text, output)
        prompts.append(formatted_prompt)

    # Tokenize
    tokenized = tokenizer(
        prompts,
        truncation=True,
        padding='max_length',
        max_length=config['max_length'],
        return_tensors='pt'
    )

    tokenized['labels'] = tokenized['input_ids'].clone()

    return tokenized

In [11]:
tokenized_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names
)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [12]:
teacher_model = AutoModelForCausalLM.from_pretrained(
    config['teacher_model_name'],
    dtype=torch.float32,
    device_map='auto',
    low_cpu_mem_usage=True
)

model.safetensors:   0%|          | 0.00/3.25G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [13]:
teacher_model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3840, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=1280)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=5120, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=5120)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

In [14]:
for param in teacher_model.parameters():
    param.requires_grad = False

In [15]:
student_model = AutoModelForCausalLM.from_pretrained(
    config['student_model_name'],
    dtype=torch.float32,
    device_map='auto',
    low_cpu_mem_usage=True
)

config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [16]:
student_model.resize_token_embeddings(len(tokenizer))

Embedding(50257, 1024)

In [17]:
class DistillationLoss(nn.Module):
    """Memory-optimized custom loss function for knowledge distillation"""

    def __init__(self, temperature=4.0, alpha=0.5, beta=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha  # Weight for distillation loss
        self.beta = beta    # Weight for task loss
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, student_logits, teacher_logits, labels):
        """
        Memory-optimized distillation loss calculation

        Args:
            student_logits: Logits from student model
            teacher_logits: Logits from teacher model
            labels: Ground truth labels
        """
        # Task loss (standard cross-entropy)
        task_loss = self.ce_loss(student_logits.view(-1, student_logits.size(-1)), labels.view(-1))

        # Memory-optimized distillation loss calculation
        # Process in smaller chunks to avoid OOM
        batch_size, seq_len, vocab_size = student_logits.shape

        # Create mask for valid tokens only once
        valid_mask = (labels != -100)

        # Initialize distillation loss
        total_kl_loss = 0.0
        valid_tokens = 0

        # Process sequence in chunks to save memory
        chunk_size = min(128, seq_len)  # Process 128 tokens at a time

        for start_idx in range(0, seq_len, chunk_size):
            end_idx = min(start_idx + chunk_size, seq_len)

            # Get chunks
            student_chunk = student_logits[:, start_idx:end_idx, :]
            teacher_chunk = teacher_logits[:, start_idx:end_idx, :]
            mask_chunk = valid_mask[:, start_idx:end_idx]

            # Only process if there are valid tokens in this chunk
            if mask_chunk.any():
                # Apply temperature scaling
                student_probs_chunk = F.log_softmax(student_chunk / self.temperature, dim=-1)
                teacher_probs_chunk = F.softmax(teacher_chunk / self.temperature, dim=-1)

                # Apply mask - only compute loss for valid tokens
                mask_expanded = mask_chunk.unsqueeze(-1)  # [batch, chunk_len, 1]

                # Flatten and select only valid positions
                student_flat = student_probs_chunk[mask_expanded.expand_as(student_probs_chunk)]
                teacher_flat = teacher_probs_chunk[mask_expanded.expand_as(teacher_probs_chunk)]

                if student_flat.numel() > 0:  # Only compute if we have valid tokens
                    # Reshape back for KL divergence
                    student_valid = student_flat.view(-1, vocab_size)
                    teacher_valid = teacher_flat.view(-1, vocab_size)

                    # Compute KL divergence for this chunk
                    chunk_kl = F.kl_div(student_valid, teacher_valid, reduction='sum')
                    total_kl_loss += chunk_kl
                    valid_tokens += student_valid.size(0)

            # Clear intermediate tensors to free memory
            del student_chunk, teacher_chunk, mask_chunk
            if 'student_probs_chunk' in locals():
                del student_probs_chunk, teacher_probs_chunk
            torch.cuda.empty_cache()

        # Average the KL loss across valid tokens
        if valid_tokens > 0:
            distillation_loss = (total_kl_loss / valid_tokens) * (self.temperature ** 2)
        else:
            distillation_loss = torch.tensor(0.0, device=student_logits.device)

        # Combined loss
        total_loss = self.alpha * distillation_loss + self.beta * task_loss

        return {
            'loss': total_loss,
            'task_loss': task_loss,
            'distillation_loss': distillation_loss
        }

In [18]:
# Initialize distillation loss with memory optimization
distillation_criterion = DistillationLoss(
    temperature=config['temperature'],
    alpha=config['alpha'],
    beta=config['beta']
)

print("Memory-optimized distillation loss function initialized")
print(f"Temperature: {config['temperature']}")
print(f"Alpha (distillation weight): {config['alpha']}")
print(f"Beta (task weight): {config['beta']}")

Memory-optimized distillation loss function initialized
Temperature: 4.0
Alpha (distillation weight): 0.5
Beta (task weight): 0.5
🔧 Using chunked processing to reduce memory usage


In [19]:
class DistillationTrainer(Trainer):
    """Custom Trainer for Knowledge Distillation"""

    def __init__(self, teacher_model, distillation_criterion, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.distillation_criterion = distillation_criterion

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Custom loss computation with knowledge distillation
        """
        # Handle the case where labels might not be in inputs
        if "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = inputs["input_ids"].clone()

        # Student forward pass
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits

        # Teacher forward pass (no gradient computation)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
            teacher_logits = teacher_outputs.logits

        # Compute distillation loss
        loss_dict = self.distillation_criterion(student_logits, teacher_logits, labels)

        # Log individual losses (only if training)
        if self.model.training:
            try:
                self.log({
                    "task_loss": loss_dict['task_loss'].item(),
                    "distillation_loss": loss_dict['distillation_loss'].item(),
                    "total_loss": loss_dict['loss'].item()
                })
            except:
                pass  # Skip logging if there are issues

        return (loss_dict['loss'], student_outputs) if return_outputs else loss_dict['loss']

In [20]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # We're doing causal language modeling, not masked language modeling
)

In [22]:
training_args = TrainingArguments(
    output_dir=config['output_dir'],
    per_device_train_batch_size=config['batch_size'],
    gradient_accumulation_steps=config['gradient_accumulation_steps'],
    learning_rate=config['learning_rate'],
    num_train_epochs=config['num_epochs'],
    max_steps=config['max_steps'],
    warmup_steps=config['warmup_steps'],
    logging_steps=config['logging_steps'],
    save_steps=config['save_steps'],
    save_total_limit=2,
    prediction_loss_only=False,  # Changed to False for custom loss
    remove_unused_columns=False,
    dataloader_pin_memory=False,

    # Mixed precision settings - Fixed for FP16 gradient scaling
    fp16=False,  # Disable FP16 to avoid gradient scaling issues
    bf16=False,  # Disable BF16 as well for compatibility

    # Memory optimization settings
    gradient_checkpointing=True,  # Keep gradient checkpointing
    dataloader_num_workers=0,  # Reduce to 0 for stability

    # Optimizer settings
    optim="adamw_torch",
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    max_grad_norm=1.0,  # Add gradient clipping

    # Disable all external logging
    report_to=[],  # Explicitly disable all reporting
    logging_dir=None,  # Disable logging directory
    disable_tqdm=False,  # Keep progress bars

    # Additional stability settings
    seed=42,  # Set seed for reproducibility
    data_seed=42,
)

In [23]:
trainer = DistillationTrainer(
    teacher_model=teacher_model,
    distillation_criterion=distillation_criterion,
    model=student_model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [24]:
# Aggressive memory cleanup and optimization
def optimize_gpu_memory():
    """Perform aggressive GPU memory optimization"""

    print("🧹 Performing aggressive GPU memory cleanup...")

    # Clear Python cache
    import gc
    gc.collect()

    # Clear PyTorch cache
    torch.cuda.empty_cache()

    # Set memory fraction (use 90% of available memory)
    if torch.cuda.is_available():
        torch.cuda.set_per_process_memory_fraction(0.9)

    # Enable memory-efficient attention if available
    try:
        # Enable Flash Attention if available
        torch.backends.cuda.enable_flash_sdp(True)
    except:
        pass

    # Set environment variables for memory optimization
    import os
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    print("✅ GPU memory optimization completed")

    if torch.cuda.is_available():
        print(f"📊 GPU Memory Status:")
        print(f"  Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        print(f"  Currently allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"  Currently cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
        print(f"  Available: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / 1024**3:.2f} GB")

# Run memory optimization
optimize_gpu_memory()

🧹 Performing aggressive GPU memory cleanup...
✅ GPU memory optimization completed
📊 GPU Memory Status:
  Total GPU memory: 14.7 GB
  Currently allocated: 4.26 GB
  Currently cached: 4.46 GB
  Available: 10.28 GB


In [25]:
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"

In [26]:
torch.cuda.empty_cache()
gc.collect()

381

In [27]:
print(f"GPU Memory before training: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

GPU Memory before training: 4.26 GB


In [28]:
try:
    print("Initializing training...")
    training_result = trainer.train()
    print("\nTraining completed successfully!")
    print(f"Final training loss: {training_result.training_loss:.4f}")

except Exception as e:
    print(f"Training failed with error: {e}")
    print(f"Error type: {type(e).__name__}")
    import traceback
    print("Full traceback:")
    traceback.print_exc()

    # Try to save whatever progress was made
    try:
        print("Attempting to save current model state...")
        trainer.save_model(config['output_dir'] + "_partial")
        print("Partial model saved successfully")
    except:
        print("Could not save partial model")

    # Clear memory and try to continue
    torch.cuda.empty_cache()
    gc.collect()

    # Re-raise the error for debugging
    raise e

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 50256}.


Initializing training...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss
50,50.2412
100,14.3747


Step,Training Loss
50,50.2412
100,14.3747
150,10.526
200,9.4494
250,8.8586
300,8.2478
350,7.958
400,7.4257
450,7.5167
500,7.1817



Training completed successfully!
Final training loss: 9.9391


In [29]:
# Save the distilled model
print("Saving the distilled student model...")
student_model.save_pretrained(config['output_dir'])
tokenizer.save_pretrained(config['output_dir'])

print(f"Model saved to: {config['output_dir']}")

# Clear memory
torch.cuda.empty_cache()
gc.collect()
print(f"GPU Memory after training: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

Saving the distilled student model...
Model saved to: ./distilled_model
GPU Memory after training: 6.92 GB


In [37]:
# Advanced generation test with proper tokenizer configuration
def advanced_generation_test():
    """Test generation with proper tokenizer setup and anti-repetition measures"""

    print("🧪 Advanced generation test with proper configuration...")

    # Ensure model is on GPU and in eval mode
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    student_model.eval()

    # Fix tokenizer configuration
    if tokenizer.pad_token_id == tokenizer.eos_token_id:
        print("⚠️ Fixing tokenizer: pad_token same as eos_token")
        # Use a different token for padding
        tokenizer.pad_token = tokenizer.unk_token if tokenizer.unk_token else "<pad>"
        if tokenizer.pad_token == tokenizer.eos_token:
            # Add a new pad token
            tokenizer.add_special_tokens({'pad_token': '<pad>'})
            student_model.resize_token_embeddings(len(tokenizer))
        print(f"✅ Fixed - PAD: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id}), EOS: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")

    # Test with meaningful prompts
    test_prompts = [
        "Explain what machine learning is",
        "Write a Python function to add two numbers",
        "What is the capital of France?",
        "How do you make a sandwich?"
    ]

    for i, prompt in enumerate(test_prompts):
        print(f"\n--- Advanced Test {i+1}: '{prompt}' ---")

        # Tokenize with proper attention mask
        inputs = tokenizer(prompt, return_tensors='pt', padding=False, truncation=True, max_length=100)

        # Create explicit attention mask
        attention_mask = torch.ones_like(inputs['input_ids'])
        inputs['attention_mask'] = attention_mask
        inputs = {k: v.to(device) for k, v in inputs.items()}

        try:
            with torch.no_grad():
                output_tokens = student_model.generate(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],  # Explicit attention mask
                    max_new_tokens=50,
                    temperature=0.7,
                    top_p=0.9,
                    top_k=50,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    repetition_penalty=1.3,      # Strong anti-repetition
                    no_repeat_ngram_size=3,      # Prevent trigram repetition
                    length_penalty=1.0,
                    early_stopping=True,
                )

            # Decode only new tokens
            input_length = inputs['input_ids'].shape[1]
            generated_tokens = output_tokens[0][input_length:]
            generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            print(f"✅ Generated: '{generated_text}'")

            # Check quality
            if len(generated_text) > 5 and len(set(generated_text.split())) > 2:
                print("🎉 Quality response generated!")
            else:
                print("⚠️ Response may be repetitive or too short")

        except Exception as e:
            print(f"❌ Generation failed: {e}")

# Run advanced test
advanced_generation_test()

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


🧪 Advanced generation test with proper configuration...
⚠️ Fixing tokenizer: pad_token same as eos_token


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


✅ Fixed - PAD: <pad> (ID: 50257), EOS: <|endoftext|> (ID: 50256)

--- Advanced Test 1: 'Explain what machine learning is' ---
✅ Generated: ', and why it's so crucial important

It It helps help you to understand understanding Understanding Understand EmbEmbembememem Emotions emotions feelings towards a a a person... AAAA A A A student students Students can can'
🎉 Quality response generated!

--- Advanced Test 2: 'Write a Python function to add two numbers' ---
✅ Generated: 'in in in InIninininineines

function(x,,, aaaaAAA A A A))) return + + + / / ////---// / / - -'
🎉 Quality response generated!

--- Advanced Test 3: 'What is the capital of France?' ---
✅ Generated: '??...:.. It's's's in in in a big little small village town city called La Paineainieieieiesiaees ( ( ((((### # # #########'
🎉 Quality response generated!

--- Advanced Test 4: 'How do you make a sandwich?' ---
✅ Generated: '??...

I I I'm's's's'm's is am am am AmAmamat at at at in in in InInininin in in East West Western 