In [None]:
#!pip install transformers
#!pip install datasets
#!pip install tqdm
#!pip install evaluate
#!pip install torch
#!pip install accelerate
#!pip install numpy
#!pip install matplotlib
#!pip install tensorboardx
#!pip install scikit-learn

from transformers import AutoModelForQuestionAnswering, AutoTokenizer, Trainer, TrainingArguments, TrainerCallback
from datasets import load_dataset
import torch
import numpy as np

# Use the GPU when one is visible to torch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device: {device}")

# SQuAD v1.1 via `datasets.load_dataset`; the "train" and "validation" splits are used below.
squad = load_dataset("squad")

In [None]:
# Teacher model: DeepSeek-R1-Distill-Qwen-1.5B loaded with a question-answering
# (start/end span) head via AutoModelForQuestionAnswering.
teacher_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
teacher_model = AutoModelForQuestionAnswering.from_pretrained(
    teacher_model_name,
    trust_remote_code=True,  # NOTE(review): Qwen2-based checkpoints may not need this; kept as in the original setup
)
teacher_tokenizer = AutoTokenizer.from_pretrained(
    teacher_model_name,
    trust_remote_code=True,
    # The SQuAD preprocessing uses return_offsets_mapping=True and
    # BatchEncoding.sequence_ids(); both are only available on fast
    # (Rust-backed) tokenizers, so a slow tokenizer cannot be used here.
    use_fast=True,
)
if not teacher_tokenizer.is_fast:
    # Fail here with a clear message instead of deep inside Dataset.map().
    raise ValueError(
        f"{teacher_model_name} did not load a fast tokenizer; offset mappings and "
        "sequence_ids() are required to align SQuAD answers with tokens."
    )

# The teacher's SQuAD preprocessing (`preprocess_teacher_train`) is defined once, in
# the next cell, together with the answer-to-token label alignment. An incomplete
# copy used to live here: it tokenized the batch but never computed labels or
# returned a value, and it was always shadowed by the later definition, so it has
# been removed to avoid two diverging versions of the same function.
    

In [None]:
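# Fine-tune the teacher on SQuAD1.1 first so that its question-answering head learns
# useful weights. Once the trained checkpoint is saved and reloaded, the weight
# initialisation warning printed when the base model was loaded no longer appears.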

# New function to preprocess training data for teacher
def _answer_token_span(offsets, sequence_ids, start_char, end_char):
    """Map an answer's character span to (start_token, end_token) within one window.

    Returns (0, 0) when the window has no context tokens or when the answer is
    not fully contained in the window's context tokens.
    """
    # Indices of tokens belonging to the second sequence (the context).
    context_tokens = [idx for idx, seq_id in enumerate(sequence_ids) if seq_id == 1]
    if not context_tokens:
        return 0, 0
    context_start, context_end = context_tokens[0], context_tokens[-1]

    # The answer must lie entirely inside this window's context; otherwise any
    # label computed below would point at question/special tokens.
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return 0, 0

    # Last context token whose start offset is <= the answer start.
    idx = context_start
    while idx <= context_end and offsets[idx][0] <= start_char:
        idx += 1
    token_start = idx - 1

    # First context token whose end offset is >= the answer end.
    idx = context_end
    while idx >= context_start and offsets[idx][1] >= end_char:
        idx -= 1
    token_end = idx + 1

    return token_start, token_end


def preprocess_teacher_train(example, tokenizer=None, max_length=512, stride=96):
    """Tokenize a batch of SQuAD examples and attach start/end token labels.

    Intended for ``Dataset.map(..., batched=True)``: ``example`` is a batch dict
    whose "question", "context" and "answers" entries are parallel lists. Long
    contexts are split into overlapping windows (``return_overflowing_tokens``),
    so the result can contain more rows than the input batch.

    Args:
        example: Batch dict with "question", "context" and "answers"; each answer
            has "text" and "answer_start" lists, of which only the first is used.
        tokenizer: Fast tokenizer to use. Defaults to the global ``teacher_tokenizer``.
            Offset mappings and ``sequence_ids()`` require a fast tokenizer.
        max_length: Maximum number of tokens per window.
        stride: Number of overlapping tokens between consecutive windows.

    Returns:
        The tokenizer output (with the "offset_mapping" and
        "overflow_to_sample_mapping" bookkeeping removed) plus "start_positions"
        and "end_positions" lists. Windows that do not fully contain the answer
        are labelled (0, 0).
    """
    if tokenizer is None:
        tokenizer = teacher_tokenizer

    inputs = tokenizer(
        example["question"],
        example["context"],
        # Only ever truncate/window the context; truncating the question would
        # both lose information and spill question tokens into overflow windows.
        truncation="only_second",
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        add_special_tokens=True,
    )

    offset_mapping = inputs.pop("offset_mapping")
    # Maps each produced window back to the index of its source example.
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = example["answers"]
    start_positions = []
    end_positions = []

    for window_idx, offsets in enumerate(offset_mapping):
        answer = answers[sample_map[window_idx]]
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])

        token_start, token_end = _answer_token_span(
            offsets, inputs.sequence_ids(window_idx), start_char, end_char
        )
        start_positions.append(token_start)
        end_positions.append(token_end)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

# Teacher training arguments.
# Evaluation and checkpointing both run once per epoch, which lets
# load_best_model_at_end restore the checkpoint with the best "f1" reported by
# compute_metrics.
teacher_training_args = TrainingArguments(
    output_dir="./teacher_train",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    learning_rate=1e-5,
    report_to="tensorboard",
    # fp16 mixed precision needs a CUDA device; TrainingArguments rejects it on a
    # CPU-only machine, which the device check at the top of the notebook allows
    # for. Fall back to full precision there.
    fp16=torch.cuda.is_available(),
    dataloader_num_workers=4,
)

# Custom progress callback
class TeacherTrainingProgress(TrainerCallback):
    """Print progress messages while the teacher model trains.

    Prints a short configuration summary when training starts, a banner at the
    start of every epoch, and the training loss / evaluation metrics whenever
    the Trainer emits a log entry.
    """

    def on_train_begin(self, args, state, control, **kwargs):
        print(f"🚀 Starting training with {args.num_train_epochs} epochs")
        print(f"📊 Batch size: {args.per_device_train_batch_size}")
        # eval_strategy may be an enum; compare on its plain string value.
        strategy = getattr(args.eval_strategy, "value", args.eval_strategy)
        if strategy == "steps":
            print(f"🔍 Evaluation every {args.eval_steps} steps")
        else:
            # With epoch-based (or no) evaluation, eval_steps is not meaningful.
            print(f"🔍 Evaluation strategy: {strategy}")

    def on_epoch_begin(self, args, state, control, **kwargs):
        # state.epoch is a float count of epochs already completed (may be None
        # before training has progressed), so show a 1-based integer instead.
        current_epoch = round(state.epoch or 0) + 1
        print(f"\n⏳ Starting epoch {current_epoch}/{args.num_train_epochs}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        if "loss" in logs:
            print(f"Step {state.global_step}: Loss {logs['loss']:.4f}")
        if "eval_loss" in logs:
            print(f"Validation Loss: {logs['eval_loss']:.4f}")
            # These keys only exist when compute_metrics produced them.
            if "eval_exact_match" in logs:
                print(f"Exact Match: {logs['eval_exact_match']:.2f}%")
            if "eval_f1" in logs:
                print(f"F1 Score: {logs['eval_f1']:.2f}%")

# Add metrics computation to Trainer
def compute_metrics(p):
    """Token-span Exact Match and F1 for extractive-QA predictions.

    ``p`` is the evaluation object passed in by the Trainer:
    ``p.predictions[0]`` / ``p.predictions[1]`` are the start / end logits and
    ``p.label_ids[0]`` / ``p.label_ids[1]`` the true start / end token indices.
    The predicted span is the argmax start token through the argmax end token
    (inclusive).

    Note: both scores compare token-index spans of each tokenized window, not
    normalized answer text, so they are not the official SQuAD EM/F1.

    Returns:
        dict with "exact_match" and "f1" as percentages (0-100). Both are 0.0
        when there are no predictions.
    """
    start_pred = np.argmax(p.predictions[0], axis=-1)
    end_pred = np.argmax(p.predictions[1], axis=-1)
    start_true = np.asarray(p.label_ids[0])
    end_true = np.asarray(p.label_ids[1])

    if start_pred.size == 0:
        # np.mean of an empty array would return nan (with a RuntimeWarning).
        return {"exact_match": 0.0, "f1": 0.0}

    exact_matches = (start_pred == start_true) & (end_pred == end_true)

    # Spans are inclusive [start, end]; an end before the start is an empty span.
    # Integer-range lengths/overlap replace building a Python set per example.
    pred_len = np.maximum(end_pred - start_pred + 1, 0)
    true_len = np.maximum(end_true - start_true + 1, 0)
    overlap = np.maximum(
        np.minimum(end_pred, end_true) - np.maximum(start_pred, start_true) + 1, 0
    )

    zeros = lambda: np.zeros(overlap.shape, dtype=float)
    precision = np.divide(overlap, pred_len, out=zeros(), where=pred_len > 0)
    recall = np.divide(overlap, true_len, out=zeros(), where=true_len > 0)
    denom = precision + recall
    f1_scores = np.divide(2 * precision * recall, denom, out=zeros(), where=denom > 0)

    return {
        "exact_match": float(np.mean(exact_matches) * 100),
        "f1": float(np.mean(f1_scores) * 100),
    }


# Tokenize both SQuAD splits with the teacher preprocessing. The raw text columns
# are dropped so only model inputs and start/end labels reach the Trainer.
teacher_train_dataset = squad["train"].map(
    preprocess_teacher_train,
    batched=True,
    remove_columns=squad["train"].column_names,
)
teacher_eval_dataset = squad["validation"].map(
    preprocess_teacher_train,
    batched=True,
    remove_columns=squad["validation"].column_names,
)

# Trainer wiring: teacher model + args, span metrics, and the progress printer.
teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_training_args,
    train_dataset=teacher_train_dataset,
    eval_dataset=teacher_eval_dataset,
    tokenizer=teacher_tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[TeacherTrainingProgress()],
)


In [None]:
# Fine-tune the teacher on SQuAD1.1, then write weights and tokenizer to one directory.
trained_teacher_dir = "./DeepSeek-R1-Distill-Qwen-1.5B-trained"
print("\nTraining Teacher Model on SQuAD1.1...")
teacher_trainer.train()
teacher_model.save_pretrained(trained_teacher_dir)
teacher_tokenizer.save_pretrained(trained_teacher_dir)

In [None]:

# Reload the fine-tuned teacher from the directory the training cell saved it to.
# (Previously this read "./trained_teacher", which the training cell never writes,
# so the SQuAD-trained weights were never picked up.)
print("\nRe-loading optimized teacher model")
trained_teacher_dir = "./DeepSeek-R1-Distill-Qwen-1.5B-trained"
# trust_remote_code matches the flags used when the base teacher was first loaded.
teacher_model = AutoModelForQuestionAnswering.from_pretrained(trained_teacher_dir, trust_remote_code=True)
teacher_tokenizer = AutoTokenizer.from_pretrained(trained_teacher_dir, trust_remote_code=True)