In [1]:
# !pip install -q transformers
# !pip install -q bitsandbytes
# !pip install -q datasets
# !pip install -q lm-eval
# !pip install -q wandb
# !pip install -q ipywidgets
# !pip install -q gdown

In [2]:
# import os

# os.system('gdown --id 1qcKw1-vwtR4qMOB--B93b61mfD8JviKh')
# os.system('gdown --id 1b9FWW2RSPT2Mdnz4YifNe7C-argn4KJJ')

# os.system("unzip 'train_dataset_clean.zip'")
# os.system("unzip 'valid_dataset_clean.zip'")

# # os.makedirs("/root/.cache/huggingface/hub/dataset", exist_ok=True)
# # os.system("unzip 'train_dataset_clean.zip' -d '/root/.cache/huggingface/hub/dataset'")
# # os.system("unzip 'valid_dataset_clean.zip' -d '/root/.cache/huggingface/hub/dataset'")

# os.system("rm -rf 'train_dataset_clean.zip'")
# os.system("rm -rf 'valid_dataset_clean.zip'")

In [3]:
import torch
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, Gemma3ForConditionalGeneration, AutoProcessor,
                          AutoTokenizer, BitsAndBytesConfig, Trainer, TrainingArguments)
from tqdm import tqdm
from datasets import load_dataset
from torch.nn import functional as F
from torch.utils.data import DataLoader

In [4]:
import os
import wandb
# from google.colab import userdata

HF_API_KEY = "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" # userdata.get('HUGGINGFACE_TOKEN')
os.environ["HF_TOKEN"] = HF_API_KEY
WANDB_API_KEY =  "2be7c86a28a2bcbeccdfa66844abfdd19b9bdabf" # userdata.get('WANDB_API_KEY')

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

WANDB_PROJECT_NAME = "Distil_MedGemma_LLama"
wandb.login(key=WANDB_API_KEY)
if len(WANDB_PROJECT_NAME) > 0:
    os.environ["WANDB_PROJECT"] = WANDB_PROJECT_NAME

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/maboelenen/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmohamed-ahmed[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch_dtype = torch.bfloat16
attn_implementation = "eager"

In [6]:
teacher_model_name = "google/medgemma-4b-it"
# student_model_name = "meta-llama/Llama-3.2-1B-Instruct"
student_model_name = "MohamedAhmedAE/distil_MedGemma_4B_Llama-3.2-1B"

In [7]:

quantization_config = BitsAndBytesConfig(load_in_4bit=True,
                                         bnb_4bit_use_double_quant=True,
                                         bnb_4bit_quant_type="nf4",
                                         bnb_4bit_compute_dtype=torch_dtype,
                                         llm_int8_enable_fp32_cpu_offload=False)

In [8]:
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name,
                                                               quantization_config=quantization_config,
                                                               device_map="auto",
                                                               trust_remote_code=True,
                                                               torch_dtype=torch_dtype,
                                                               attn_implementation=attn_implementation)

student_model = AutoModelForCausalLM.from_pretrained(student_model_name,
                                                     torch_dtype=torch_dtype,
                                                     device_map="auto")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

In [9]:
train_dataset_directory="content/train_dataset_clean"
valid_dataset_directory="content/valid_dataset_clean"

In [10]:
from datasets import Dataset
from datasets import concatenate_datasets


train_dataset = Dataset.load_from_disk(train_dataset_directory)
valid_dataset = Dataset.load_from_disk(valid_dataset_directory)

formatted_dataset = concatenate_datasets([train_dataset, valid_dataset])
formatted_dataset = formatted_dataset.shuffle(seed=2)

In [11]:
# formatted_dataset = formatted_dataset.select(range(1000))

In [12]:
del train_dataset, valid_dataset

In [13]:
formatted_dataset

Dataset({
    features: ['output', 'instruction', 'input', 'context', 'choices', 'type', 'len_text', 'prompt'],
    num_rows: 1437851
})

In [14]:
from transformers import AutoTokenizer

teacher_processor = AutoProcessor.from_pretrained(teacher_model_name)
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)

if student_tokenizer.pad_token is None:
    student_tokenizer.pad_token = student_tokenizer.eos_token
    
max_length = 1024

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [15]:
print("Loading tokenizers...")
def format_prompt(example):
    system = "You are a Medical Assistant follow the following instruction"

    if len(example["context"]) >= 5:
        instruction = f"""{example["instruction"]}\n\n{example["context"]}\n\n"""

    else:
        instruction = f"""{example["instruction"]}\n\n"""

    if len(example["choices"]) > 0:
        prompt_template ="""{instruction}

{input}

{choices}
"""
        prompt = prompt_template.format(instruction=instruction.strip(),
                                        input=example["input"].strip(),
                                        choices=example["choices"].strip())
    else:
        prompt_template = """{instruction}

{input}
"""

        prompt = prompt_template.format(instruction=instruction.strip(),
                                        input=example["input"].strip())

    student_message = [{"role": "system", "content": system},
                       {"role": "user", "content": prompt},
                       {"role": "assistant", "content": example['output'].strip()}]


    teacher_message = [{
        "role": "system",
        "content": [{"type": "text", "text": system}]
    },
    {
        "role": "user",
        "content": [{"type": "text", "text": prompt}]
    },
    {
        "role": "assistant",
        "content": [{"type": "text", "text": example['output'].strip()}]
    }]
    
    teacher_prompt = teacher_processor.apply_chat_template(teacher_message,
                                                           tokenize=False,
                                                           max_length=max_length,
                                                           # add_generation_prompt=True,
                                                           truncation=True)

    student_prompt = student_tokenizer.apply_chat_template(student_message,
                                                           tokenize=False,
                                                           max_length=max_length,
                                                           add_generation_prompt=True,
                                                           truncation=True)

    return {"teacher_prompt": teacher_prompt, "student_prompt": student_prompt}

formatted_dataset = formatted_dataset.map(format_prompt)
# formatted_dataset = formatted_dataset.select(range(100)).map(format_prompt)



Loading tokenizers...


Map:   0%|          | 0/1437851 [00:00<?, ? examples/s]

In [16]:
print(formatted_dataset['teacher_prompt'][10])

<bos><start_of_turn>user
You are a Medical Assistant follow the following instruction

Choose the correct answer for the following question

Entonox cylinder has blue body and white shoulder.

Colour of Entonox cylinder is

1) Black body, white shoulder.
2) Grey body, black and white shoulder.
3) Black body, brown and white shoulder.
4) Blue body, white shoulder.<end_of_turn>
<start_of_turn>model
4<end_of_turn>



In [17]:
print(formatted_dataset['student_prompt'][10])

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 27 Jul 2025

You are a Medical Assistant follow the following instruction<|eot_id|><|start_header_id|>user<|end_header_id|>

Choose the correct answer for the following question

Entonox cylinder has blue body and white shoulder.

Colour of Entonox cylinder is

1) Black body, white shoulder.
2) Grey body, black and white shoulder.
3) Black body, brown and white shoulder.
4) Blue body, white shoulder.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

4<|eot_id|><|start_header_id|>assistant<|end_header_id|>




In [18]:
def create_tokenize_function(teacher_processor, student_tokenizer, max_length=1024, 
                             student_text_column="student_prompt", teacher_text_column="teacher_prompt"):

    def tokenize_function(examples):
        
        # Process teacher inputs
        teacher_inputs = teacher_processor(
            text=examples[teacher_text_column],
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
            padding_side="right"
        )

        # Process student inputs
        student_inputs = student_tokenizer(
            examples[student_text_column],
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
            padding_side="right"
        )

        return {
            "teacher_input_ids": teacher_inputs["input_ids"].tolist(),
            "teacher_attention_mask": teacher_inputs["attention_mask"].tolist(),
            "teacher_labels": teacher_inputs["input_ids"].tolist(),
            "student_input_ids": student_inputs["input_ids"].tolist(),
            "student_attention_mask": student_inputs["attention_mask"].tolist(),
            "student_labels": student_inputs["input_ids"].tolist(),
        }

    return tokenize_function

In [19]:
tokenize_function = create_tokenize_function(
    teacher_processor=teacher_processor,
    student_tokenizer=student_tokenizer,
    max_length=max_length)

print("Tokenizing dataset .... ")

tokenized_dataset = formatted_dataset.map(
    tokenize_function,
    batched=True,
    batch_size=1000,
    remove_columns=formatted_dataset.column_names,
    desc="Processing samples for distillation",
    load_from_cache_file=False
)

Tokenizing dataset .... 


Processing samples for distillation:   0%|          | 0/1437851 [00:00<?, ? examples/s]

In [20]:
tokenized_dataset.set_format("torch")

In [21]:
train_dataloader = DataLoader(
    tokenized_dataset,
    batch_size=1, ########################################
    shuffle=True
)

In [22]:
val_dataloader = DataLoader(
    tokenized_dataset.select(range(100)),
    batch_size=1, ########################################
)

### Optimizer

In [23]:
lr = 1e-6
optimizer = AdamW(student_model.parameters(), lr=lr)

### Training Loop

In [24]:
num_epochs = 10
temperature = 4.0
alpha = 0.7
accumulation_steps = 1

In [25]:
import torch 

torch.cuda.empty_cache()

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import wandb
import numpy as np
import os
from pathlib import Path
import json
import gc  # For garbage collection
from huggingface_hub import HfApi, Repository, login
from transformers import AutoTokenizer

# HuggingFace configuration
HF_TOKEN = os.getenv("HF_TOKEN")
HF_REPO_ID = "MohamedAhmedAE/distil_MedGemma_4B_Llama-3.2-1B"
HF_PRIVATE = False

# Login to HuggingFace Hub
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Logged in to HuggingFace Hub")
else:
    print("Warning: HF_TOKEN not found. Set it as environment variable for auto-upload.")

# Initialize wandb
wandb.init(
    project="distil_MedGemma_4B_Llama-3.2-1B",
    name=f"distillation-{wandb.util.generate_id()}",
    config={
        "alpha": alpha,
        "temperature": temperature,
        "accumulation_steps": accumulation_steps,
        "hidden_dim": 1024,
        "learning_rate": lr,
        "num_epochs": num_epochs,
        "teacher_vocab_size": teacher_model.get_input_embeddings().weight.size(0),
        "student_vocab_size": student_model.get_input_embeddings().weight.size(0),
        "save_every_steps": 5000,
        "checkpoint_dir": "./checkpoints",
        "hf_repo_id": HF_REPO_ID,
        "hf_private": HF_PRIVATE,
    }
)

# Setup checkpoint directory
checkpoint_dir = Path(wandb.config.checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Memory optimization function
def clear_memory():
    """Aggressively clear GPU memory"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()

def get_memory_usage():
    """Get current GPU memory usage"""
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 1024**3  # GB
    return 0




# Get tokenizer for HuggingFace upload (assumes it's available)
tokenizer = None
try:
    if hasattr(student_model, 'config') and hasattr(student_model.config, 'name_or_path'):
        tokenizer = AutoTokenizer.from_pretrained(student_model.config.name_or_path)
except:
    print("Warning: Could not load tokenizer. Model will be uploaded without tokenizer.")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Logged in to HuggingFace Hub


In [27]:
import torch
from torch.serialization import safe_globals
import wandb.sdk.wandb_config

def save_checkpoint(state, filename):
    torch.save(state, filename)

def load_checkpoint(filename, model, teacher_proj, optimizer):
    with safe_globals([wandb.sdk.wandb_config.Config]):
        checkpoint = torch.load(filename, weights_only=True)
    model.load_state_dict(checkpoint['student_state_dict'])
    if teacher_proj is not None and checkpoint['teacher_proj_state_dict'] is not None:
        teacher_proj.load_state_dict(checkpoint['teacher_proj_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return {
        'start_step': checkpoint['start_step'],
        'start_epoch': checkpoint['start_epoch']
    }

In [28]:
# Enhanced Knowledge Distillation Training - Integrated Version
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import wandb
import os
from tqdm import tqdm

# Enhanced hyperparameters - Add these to your wandb.config
def get_enhanced_config_defaults():
    """Enhanced configuration defaults to add to your wandb.config"""
    return {
        'attention_distill': True,
        'feature_matching': True,
        'progressive_unfreezing': True,
        'adaptive_temperature': True,
        'layer_wise_distill': True,
        'attention_loss_weight': 0.05,
        'feature_loss_weight': 0.1,
        'min_temperature': 1.0,
        'max_temperature': 8.0,
        'unfreeze_schedule': 'linear'  # 'linear', 'cosine', 'step'
    }

class EnhancedDistillationLoss(nn.Module):
    """Enhanced loss function with multiple distillation strategies"""
    
    def __init__(self, config, device):
        super().__init__()
        self.config = config
        self.device = device
        self.mse_loss = nn.MSELoss()
        self.cosine_loss = nn.CosineSimilarity(dim=-1)
        
    def compute_attention_loss(self, student_attentions, teacher_attentions):
        """Distill attention patterns"""
        if not student_attentions or not teacher_attentions:
            return torch.tensor(0.0, device=self.device)
        
        try:
            attention_loss = 0.0
            min_layers = min(len(student_attentions), len(teacher_attentions))
            
            for i in range(min_layers):
                # Average across heads for simplicity
                s_att = student_attentions[i].mean(dim=1)  # [batch, seq, seq]
                t_att = teacher_attentions[i].mean(dim=1)
                
                # Align sequence lengths
                min_seq = min(s_att.size(-1), t_att.size(-1))
                s_att = s_att[:, :min_seq, :min_seq]
                t_att = t_att[:, :min_seq, :min_seq]
                
                # Check for NaN/Inf
                if torch.isnan(s_att).any() or torch.isnan(t_att).any():
                    continue
                    
                attention_loss += self.mse_loss(s_att, t_att)
            
            result = attention_loss / min_layers if min_layers > 0 else torch.tensor(0.0, device=self.device)
            return result if not torch.isnan(result) else torch.tensor(0.0, device=self.device)
            
        except Exception as e:
            return torch.tensor(0.0, device=self.device)
    
    def compute_feature_matching_loss(self, student_hidden, teacher_hidden, intermediate_proj=None):
        """Feature matching across multiple layers"""
        if not student_hidden or not teacher_hidden:
            return torch.tensor(0.0, device=self.device)
        
        try:
            feature_loss = 0.0
            num_layers = min(len(student_hidden), len(teacher_hidden))
            
            # Select layers to distill (every 2nd layer for efficiency)
            layers_to_distill = list(range(0, num_layers, max(1, num_layers // 4)))
            if not layers_to_distill:
                layers_to_distill = [num_layers // 2]  # At least middle layer
            
            valid_layers = 0
            
            for layer_idx in layers_to_distill:
                try:
                    s_hidden = student_hidden[layer_idx]
                    t_hidden = teacher_hidden[layer_idx]
                    
                    # Align sequence lengths
                    min_seq = min(s_hidden.size(1), t_hidden.size(1))
                    s_hidden = s_hidden[:, :min_seq, :]
                    t_hidden = t_hidden[:, :min_seq, :]
                    
                    # Apply projection if needed for this layer
                    if intermediate_proj is not None and s_hidden.size(-1) != t_hidden.size(-1):
                        t_hidden = intermediate_proj(t_hidden)
                    
                    # Check for NaN/Inf
                    if torch.isnan(s_hidden).any() or torch.isnan(t_hidden).any():
                        continue
                    
                    # L2 distance loss
                    layer_mse = self.mse_loss(s_hidden, t_hidden)
                    if not torch.isnan(layer_mse):
                        feature_loss += layer_mse
                        valid_layers += 1
                    
                    # Cosine similarity loss (normalize hidden states)
                    s_norm = F.normalize(s_hidden, p=2, dim=-1)
                    t_norm = F.normalize(t_hidden, p=2, dim=-1)
                    cosine_sim = self.cosine_loss(s_norm, t_norm).mean()
                    
                    if not torch.isnan(cosine_sim):
                        feature_loss += (1 - cosine_sim)
                        
                except Exception:
                    continue
            
            result = feature_loss / max(valid_layers, 1)
            return result if not torch.isnan(result) else torch.tensor(0.0, device=self.device)
            
        except Exception:
            return torch.tensor(0.0, device=self.device)

class AdaptiveTemperatureScheduler:
    """Adaptive temperature scheduling based on training progress"""
    
    def __init__(self, initial_temp=4.0, min_temp=1.0, max_temp=8.0):
        self.initial_temp = initial_temp
        self.min_temp = min_temp
        self.max_temp = max_temp
        self.loss_history = []
        self.current_temp = initial_temp
        
    def update(self, current_loss, global_step, total_steps):
        try:
            self.loss_history.append(current_loss)
            
            # Keep only recent history
            if len(self.loss_history) > 100:
                self.loss_history = self.loss_history[-100:]
            
            # Calculate loss trend
            temp_adjustment = 1.0
            if len(self.loss_history) >= 10:
                recent_avg = np.mean(self.loss_history[-10:])
                older_avg = np.mean(self.loss_history[-20:-10]) if len(self.loss_history) >= 20 else recent_avg
                
                # If loss is decreasing well, reduce temperature for sharper predictions
                if recent_avg < older_avg * 0.95:
                    temp_adjustment = 0.98
                # If loss is stagnating, increase temperature for softer targets
                elif recent_avg > older_avg * 1.02:
                    temp_adjustment = 1.02
            
            # Base temperature decay
            progress_ratio = global_step / max(total_steps, 1)
            base_temp = self.initial_temp * (1 - 0.3 * progress_ratio)  # 30% reduction over training
            
            # Apply adjustment
            self.current_temp = base_temp * temp_adjustment
            self.current_temp = max(self.min_temp, min(self.max_temp, self.current_temp))
            
            return self.current_temp
        except Exception:
            return self.initial_temp

class ProgressiveUnfreezer:
    """Progressive unfreezing strategy for student model"""
    
    def __init__(self, model, total_steps, unfreeze_schedule='linear'):
        self.model = model
        self.total_steps = total_steps
        self.unfreeze_schedule = unfreeze_schedule
        self.frozen_params = []
        self.total_params = []
        
        # Initially freeze early layers
        self.freeze_early_layers()
    
    def freeze_early_layers(self):
        """Freeze embedding and early transformer layers"""
        try:
            params_to_freeze = []
            
            # Collect all parameters first
            for name, param in self.model.named_parameters():
                self.total_params.append((name, param))
            
            # Freeze embeddings and early layers
            for name, param in self.model.named_parameters():
                should_freeze = False
                
                # Freeze embeddings
                if any(embed_name in name.lower() for embed_name in ['embed', 'wte', 'wpe']):
                    should_freeze = True
                
                # Freeze first few transformer layers
                for layer_pattern, max_layers in [('layers.', 4), ('transformer.h.', 4), ('h.', 4)]:
                    if layer_pattern in name:
                        try:
                            layer_num = int(name.split(layer_pattern)[1].split('.')[0])
                            if layer_num < max_layers:
                                should_freeze = True
                        except (ValueError, IndexError):
                            pass
                
                if should_freeze:
                    param.requires_grad = False
                    params_to_freeze.append((name, param))
            
            self.frozen_params = params_to_freeze
            print(f"Froze {len(self.frozen_params)} parameter groups for progressive unfreezing")
            
        except Exception as e:
            print(f"Warning: Could not setup progressive unfreezing: {e}")
            self.frozen_params = []
    
    def update_freezing(self, global_step):
        """Progressively unfreeze layers based on training progress"""
        try:
            if not self.frozen_params:
                return
            
            progress = global_step / max(self.total_steps, 1)
            
            # Unfreeze schedule
            if self.unfreeze_schedule == 'linear':
                unfreeze_threshold = progress
            elif self.unfreeze_schedule == 'cosine':
                unfreeze_threshold = 0.5 * (1 - np.cos(np.pi * progress))
            else:  # 'step'
                unfreeze_threshold = 1.0 if progress > 0.5 else 0.0
            
            # Determine how many parameter groups to unfreeze
            total_frozen = len(self.frozen_params)
            params_to_unfreeze = int(total_frozen * unfreeze_threshold)
            
            # Unfreeze parameters
            for i in range(min(params_to_unfreeze, total_frozen)):
                name, param = self.frozen_params[i]
                if not param.requires_grad:
                    param.requires_grad = True
                    
        except Exception as e:
            print(f"Warning: Error in progressive unfreezing: {e}")

In [29]:
# Your existing code with enhancements integrated
# Check if we should resume from a checkpoint
teacher_proj = None
student_vocab_size = student_model.get_input_embeddings().weight.size(0)
teacher_vocab_size = teacher_model.get_input_embeddings().weight.size(0)
hidden_dim = wandb.config.hidden_dim
train_dataloader_len = len(train_dataloader)

# Enhanced hyperparameters - merge with your existing config
enhanced_defaults = get_enhanced_config_defaults()
for key, default_value in enhanced_defaults.items():
    if not hasattr(wandb.config, key):
        setattr(wandb.config, key, default_value)

temperature = wandb.config.get('temperature', 4.0)
base_alpha = wandb.config.get('alpha', 0.7)
distill_intermediate_layers = wandb.config.get('distill_intermediate', False)
dynamic_alpha = wandb.config.get('dynamic_alpha', True)

# New enhanced features
attention_distill = wandb.config.get('attention_distill', True)
feature_matching = wandb.config.get('feature_matching', True)
progressive_unfreezing = wandb.config.get('progressive_unfreezing', True)
adaptive_temperature = wandb.config.get('adaptive_temperature', True)
attention_loss_weight = wandb.config.get('attention_loss_weight', 0.05)
feature_loss_weight = wandb.config.get('feature_loss_weight', 0.1)

# Calculate total training steps for dynamic scheduling
total_steps = num_epochs * train_dataloader_len

# Initialize enhanced components
enhanced_loss = EnhancedDistillationLoss(wandb.config, device)
temp_scheduler = AdaptiveTemperatureScheduler(
    temperature, 
    wandb.config.get('min_temperature', 1.0),
    wandb.config.get('max_temperature', 8.0)
) if adaptive_temperature else None

# Progressive unfreezing
unfreezer = ProgressiveUnfreezer(
    student_model, 
    total_steps, 
    wandb.config.get('unfreeze_schedule', 'linear')
) if progressive_unfreezing else None

# Enhanced learning rate scheduler
scheduler = CosineAnnealingWarmRestarts(
    optimizer, 
    T_0=max(1, total_steps//4), 
    T_mult=2,
    eta_min=1e-6
)

# Your existing projection layer code
if teacher_vocab_size != student_vocab_size:
    teacher_proj = nn.Sequential(
        nn.Linear(teacher_vocab_size, hidden_dim, bias=False, dtype=torch_dtype),
        nn.ReLU(),
        nn.Linear(hidden_dim, student_vocab_size, bias=False, dtype=torch_dtype)
    ).to(device)

    with torch.no_grad():
        nn.init.xavier_uniform_(teacher_proj[0].weight)
        nn.init.xavier_uniform_(teacher_proj[2].weight)
        teacher_proj[0].weight *= 0.1
        teacher_proj[2].weight *= 0.1

    print(f"Initialized projection layer: {teacher_vocab_size} -> {student_vocab_size}")
    wandb.log({"setup/projection_layer_created": True})

# Enhanced intermediate layer projection
intermediate_proj = None
if distill_intermediate_layers or feature_matching:
    student_hidden_size = student_model.config.hidden_size
    # teacher_hidden_size = teacher_model.config.hidden_size
    teacher_hidden_size = teacher_model.config.text_config.hidden_size
    
    if student_hidden_size != teacher_hidden_size:
        intermediate_proj = nn.Linear(
            teacher_hidden_size, 
            student_hidden_size, 
            bias=False, 
            dtype=torch_dtype
        ).to(device)
        
        with torch.no_grad():
            nn.init.xavier_uniform_(intermediate_proj.weight)
            intermediate_proj.weight *= 0.1
        
        print(f"Initialized intermediate projection: {teacher_hidden_size} -> {student_hidden_size}")

# Your existing checkpoint configuration
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
final_checkpoint_path = os.path.join(checkpoint_dir, "final_checkpoint.pt")
save_steps = wandb.config.save_every_steps

# Your existing resume training logic
start_step = 0
start_epoch = 0
batches_to_skip = 0

if os.path.exists(final_checkpoint_path):
    resume_info = load_checkpoint(final_checkpoint_path, student_model, teacher_proj, optimizer)
    start_step = resume_info['start_step']
    start_epoch = resume_info['start_epoch']
    
    batches_to_skip = start_step % train_dataloader_len
    
    print(f"Resumed training from step {start_step}, epoch {start_epoch}")
    print(f"Will skip {batches_to_skip} batches in current epoch")
else:
    batches_to_skip = 0

# Global step counter
global_step = start_step

# Enhanced training loop
for epoch in range(start_epoch, num_epochs):
    print(f"\n=== Starting Epoch {epoch+1}/{num_epochs} ===")
    print(f"Resuming from global step {global_step}")
    
    # Model preparation
    student_model.train()
    teacher_model.eval()

    # Initialize metrics
    total_loss = 0.0
    total_hard_loss = 0.0
    total_soft_loss = 0.0
    total_intermediate_loss = 0.0
    total_attention_loss = 0.0  # New
    total_feature_loss = 0.0    # New
    num_valid_batches = 0
    total_gradient_norm = 0.0
    
    # Enhanced error tracking
    error_counts = {
        "data_processing": 0,
        "teacher_inference": 0,
        "student_inference": 0,
        "projection": 0,
        "loss_computation": 0,
        "backward_pass": 0,
        "nan_gradients": 0,
        "nan_logits": 0,
        "intermediate_distill": 0,
        "attention_distill": 0,  # New
        "feature_matching": 0    # New
    }

    # Your existing dataloader setup
    train_dataloader_iter = iter(train_dataloader)
    
    if epoch == start_epoch and batches_to_skip > 0:
        print(f"Skipping {batches_to_skip} batches...")
        for _ in range(batches_to_skip):
            try:
                next(train_dataloader_iter)
            except StopIteration:
                break
        
        remaining_batches = train_dataloader_len - batches_to_skip
        progress_bar = tqdm(
            range(remaining_batches), 
            desc=f"Epoch {epoch+1}/{num_epochs}", 
            initial=batches_to_skip, 
            total=train_dataloader_len
        )
        actual_batch_idx_offset = batches_to_skip
    else:
        remaining_batches = train_dataloader_len
        progress_bar = tqdm(
            range(train_dataloader_len), 
            desc=f"Epoch {epoch+1}/{num_epochs}", 
            total=train_dataloader_len
        )
        actual_batch_idx_offset = 0

    # Process remaining batches
    for progress_idx in progress_bar:
        actual_batch_idx = progress_idx + actual_batch_idx_offset if epoch == start_epoch else progress_idx
        
        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            break
            
        global_step += 1
        
        # Update progressive unfreezing
        if unfreezer:
            unfreezer.update_freezing(global_step)
        
        # Dynamic alpha scheduling - start with more soft loss emphasis
        if dynamic_alpha:
            progress_ratio = global_step / total_steps
            current_alpha = base_alpha * (0.3 + 0.7 * progress_ratio)
        else:
            current_alpha = base_alpha
        
        # Update adaptive temperature
        current_temperature = temperature
        if temp_scheduler:
            # Use a simple proxy loss for temperature adaptation
            try:
                proxy_loss = total_loss / max(num_valid_batches, 1) if num_valid_batches > 0 else 1.0
                current_temperature = temp_scheduler.update(proxy_loss, global_step, total_steps)
            except:
                current_temperature = temperature
        
        try:
            # === DATA PROCESSING === (Your existing code)
            teacher_input_ids = batch["teacher_input_ids"].to(device, non_blocking=True)
            teacher_attention_mask = batch["teacher_attention_mask"].to(device, non_blocking=True)
            teacher_labels = batch["teacher_labels"].to(device, non_blocking=True)

            student_input_ids = batch["student_input_ids"].to(device, non_blocking=True)
            student_attention_mask = batch["student_attention_mask"].to(device, non_blocking=True)
            student_labels = batch["student_labels"].to(device, non_blocking=True)

            # Validate data
            if (torch.isnan(student_labels).any() or torch.isinf(student_labels).any() or 
                (student_labels != -100).sum() == 0):
                error_counts["data_processing"] += 1
                continue

            # === ENHANCED TEACHER MODEL INFERENCE ===
            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=teacher_input_ids,
                    attention_mask=teacher_attention_mask,
                    labels=teacher_labels,
                    output_hidden_states=distill_intermediate_layers or feature_matching,
                    output_attentions=attention_distill
                )
                teacher_logits = teacher_outputs.logits
                teacher_hidden = teacher_outputs.hidden_states if (distill_intermediate_layers or feature_matching) else None
                teacher_attentions = teacher_outputs.attentions if attention_distill else None

                if torch.isnan(teacher_logits).any():
                    error_counts["teacher_inference"] += 1
                    continue

            # Clear teacher inputs from memory
            del teacher_input_ids, teacher_attention_mask, teacher_labels

            # === ENHANCED STUDENT MODEL INFERENCE ===
            student_outputs = student_model(
                input_ids=student_input_ids,
                attention_mask=student_attention_mask,
                labels=student_labels,
                output_hidden_states=distill_intermediate_layers or feature_matching,
                output_attentions=attention_distill
            )
            student_logits = student_outputs.logits
            student_hidden = student_outputs.hidden_states if (distill_intermediate_layers or feature_matching) else None
            student_attentions = student_outputs.attentions if attention_distill else None

            # Check for NaN/Inf in student outputs
            if torch.isnan(student_logits).any() or torch.isinf(student_logits).any():
                error_counts["nan_logits"] += 1
                print(f"NaN/Inf in student logits at batch {actual_batch_idx}")
                
                # Emergency parameter reinitialization
                for name, module in student_model.named_modules():
                    if hasattr(module, 'weight') and module.weight is not None:
                        if torch.isnan(module.weight).any():
                            nn.init.normal_(module.weight, mean=0.0, std=0.02)
                    if hasattr(module, 'bias') and module.bias is not None:
                        if torch.isnan(module.bias).any():
                            nn.init.zeros_(module.bias)
                continue

            # Clear student inputs from memory
            del student_input_ids, student_attention_mask

            # === SEQUENCE LENGTH ALIGNMENT ===
            min_seq_len = min(student_logits.size(1), teacher_logits.size(1))
            student_logits = student_logits[:, :min_seq_len, :]
            teacher_logits = teacher_logits[:, :min_seq_len, :]
            student_labels_aligned = student_labels[:, :min_seq_len]

            # === APPLY PROJECTION IF NEEDED ===
            if teacher_proj is not None:
                teacher_logits = teacher_proj(teacher_logits.to(torch_dtype))
                if torch.isnan(teacher_logits).any():
                    error_counts["projection"] += 1
                    continue

            # === ENHANCED LOSS COMPUTATION ===
            # Hard loss (student vs ground truth)
            mask = (student_labels_aligned != -100)
            if mask.sum() == 0:
                error_counts["loss_computation"] += 1
                continue

            flat_logits = student_logits.view(-1, student_logits.size(-1))
            flat_labels = student_labels_aligned.view(-1)
            hard_loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100, reduction='mean')
            
            del flat_logits, flat_labels

            if torch.isnan(hard_loss) or torch.isinf(hard_loss):
                error_counts["loss_computation"] += 1
                continue

            # Soft loss (Knowledge Distillation) with current temperature
            mask_float = mask.float()
            if mask_float.sum() == 0:
                soft_loss = torch.tensor(0.0, device=device, dtype=torch_dtype)
            else:
                # Temperature scaling with stability checks
                student_logits_temp = student_logits / current_temperature
                teacher_logits_temp = teacher_logits / current_temperature

                # Check for extreme values
                if (torch.abs(student_logits_temp).max() > 50 or 
                    torch.abs(teacher_logits_temp).max() > 50):
                    error_counts["loss_computation"] += 1
                    continue

                student_log_probs = F.log_softmax(student_logits_temp, dim=-1)
                teacher_probs = F.softmax(teacher_logits_temp, dim=-1)

                if (torch.isnan(student_log_probs).any() or 
                    torch.isnan(teacher_probs).any()):
                    error_counts["loss_computation"] += 1
                    continue

                # KL divergence
                kl_loss = F.kl_div(student_log_probs, teacher_probs, reduction='none', log_target=False)
                kl_per_token = kl_loss.sum(-1)
                masked_kl = kl_per_token * mask_float
                soft_loss = masked_kl.sum() / mask_float.sum() * (current_temperature ** 2)

                del (student_logits_temp, teacher_logits_temp, student_log_probs, 
                     teacher_probs, kl_loss, kl_per_token, masked_kl)

            # Enhanced intermediate layer distillation
            intermediate_loss = torch.tensor(0.0, device=device)
            if distill_intermediate_layers and student_hidden is not None and teacher_hidden is not None:
                try:
                    # Use middle layer for distillation
                    student_mid_layer = student_hidden[len(student_hidden) // 2][:, :min_seq_len, :]
                    teacher_mid_layer = teacher_hidden[len(teacher_hidden) // 2][:, :min_seq_len, :]
                    
                    # Apply projection if needed
                    if intermediate_proj is not None:
                        teacher_mid_layer = intermediate_proj(teacher_mid_layer)
                    
                    # MSE loss for hidden states
                    intermediate_loss = F.mse_loss(student_mid_layer, teacher_mid_layer, reduction='mean')
                    
                    if torch.isnan(intermediate_loss):
                        intermediate_loss = torch.tensor(0.0, device=device)
                        error_counts["intermediate_distill"] += 1
                        
                except Exception as e:
                    intermediate_loss = torch.tensor(0.0, device=device)
                    error_counts["intermediate_distill"] += 1

            # NEW: Enhanced feature matching loss
            feature_loss = torch.tensor(0.0, device=device)
            if feature_matching and student_hidden is not None and teacher_hidden is not None:
                try:
                    feature_loss = enhanced_loss.compute_feature_matching_loss(
                        student_hidden, teacher_hidden, intermediate_proj
                    )
                    if torch.isnan(feature_loss):
                        feature_loss = torch.tensor(0.0, device=device)
                        error_counts["feature_matching"] += 1
                except Exception:
                    feature_loss = torch.tensor(0.0, device=device)
                    error_counts["feature_matching"] += 1

            # NEW: Attention distillation loss
            attention_loss = torch.tensor(0.0, device=device)
            if attention_distill and student_attentions is not None and teacher_attentions is not None:
                try:
                    attention_loss = enhanced_loss.compute_attention_loss(
                        student_attentions, teacher_attentions
                    )
                    if torch.isnan(attention_loss):
                        attention_loss = torch.tensor(0.0, device=device)
                        error_counts["attention_distill"] += 1
                except Exception:
                    attention_loss = torch.tensor(0.0, device=device)
                    error_counts["attention_distill"] += 1

            # Clear logits and other tensors
            del student_logits, teacher_logits, student_labels, student_labels_aligned, mask, mask_float
            if student_hidden is not None:
                del student_hidden
            if teacher_hidden is not None:
                del teacher_hidden
            if student_attentions is not None:
                del student_attentions
            if teacher_attentions is not None:
                del teacher_attentions

            if torch.isnan(soft_loss) or torch.isinf(soft_loss):
                error_counts["loss_computation"] += 1
                continue

            # Enhanced loss combination
            total_batch_loss = (
                current_alpha * soft_loss + 
                (1 - current_alpha) * hard_loss + 
                0.1 * intermediate_loss +  # Original intermediate loss
                feature_loss_weight * feature_loss +  # New feature matching
                attention_loss_weight * attention_loss  # New attention loss
            )

            # Apply gradient accumulation scaling BEFORE backward pass
            if ((actual_batch_idx + 1) % accumulation_steps == 0) or (progress_idx == remaining_batches - 1):
                pass
            else:
                total_batch_loss = total_batch_loss / accumulation_steps

            if torch.isnan(total_batch_loss) or torch.isinf(total_batch_loss):
                continue

            # === BACKWARD PASS ===
            total_batch_loss.backward()

            # Check for NaN gradients
            has_nan_grad = False
            for param in student_model.parameters():
                if param.grad is not None and torch.isnan(param.grad).any():
                    has_nan_grad = True
                    break

            if has_nan_grad:
                optimizer.zero_grad()
                error_counts["nan_gradients"] += 1
                continue

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
            if teacher_proj is not None:
                torch.nn.utils.clip_grad_norm_(teacher_proj.parameters(), max_norm=1.0)
            if intermediate_proj is not None:
                torch.nn.utils.clip_grad_norm_(intermediate_proj.parameters(), max_norm=1.0)

            total_gradient_norm += grad_norm.item()

            # === OPTIMIZATION ===
            if ((actual_batch_idx + 1) % accumulation_steps == 0) or (progress_idx == remaining_batches - 1):
                optimizer.step()
                scheduler.step()  # Enhanced: Add scheduler step
                optimizer.zero_grad()

            # === LOSS TRACKING ===
            batch_total_loss_val = total_batch_loss.item() * (accumulation_steps if (actual_batch_idx + 1) % accumulation_steps != 0 else 1)
            batch_hard_loss_val = hard_loss.item()
            batch_soft_loss_val = soft_loss.item()
            batch_intermediate_loss_val = intermediate_loss.item()
            batch_attention_loss_val = attention_loss.item()  # New
            batch_feature_loss_val = feature_loss.item()      # New

            # Clear loss tensors
            del total_batch_loss, hard_loss, soft_loss, intermediate_loss, attention_loss, feature_loss

            # Accumulate losses
            total_loss += batch_total_loss_val
            total_hard_loss += batch_hard_loss_val
            total_soft_loss += batch_soft_loss_val
            total_intermediate_loss += batch_intermediate_loss_val
            total_attention_loss += batch_attention_loss_val    # New
            total_feature_loss += batch_feature_loss_val        # New
            num_valid_batches += 1

            # Update progress bar with enhanced metrics
            progress_bar.set_postfix({
                'Loss': f'{batch_total_loss_val:.4f}',
                'Hard': f'{batch_hard_loss_val:.4f}',
                'Soft': f'{batch_soft_loss_val:.4f}',
                'Attn': f'{batch_attention_loss_val:.4f}',
                'Feat': f'{batch_feature_loss_val:.4f}',
                'Alpha': f'{current_alpha:.3f}',
                'Temp': f'{current_temperature:.2f}',
                'Step': global_step
            })

            # === ENHANCED CHECKPOINT SAVING ===
            if global_step % save_steps == 0 or (epoch == num_epochs - 1 and progress_idx == remaining_batches - 1):
                # Enhanced checkpoint with new components
                checkpoint_data = {
                    'student_state_dict': student_model.state_dict(),
                    'teacher_proj_state_dict': teacher_proj.state_dict() if teacher_proj is not None else None,
                    'intermediate_proj_state_dict': intermediate_proj.state_dict() if intermediate_proj is not None else None,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),  # New: Save scheduler state
                    'start_step': global_step,
                    'start_epoch': epoch,
                    'config': wandb.config.as_dict(),
                    'temperature_scheduler_state': {  # New: Temperature scheduler state
                        'current_temp': current_temperature,
                        'loss_history': temp_scheduler.loss_history if temp_scheduler else []
                    } if temp_scheduler else None,
                    'unfreezer_state': {  # New: Progressive unfreezer state
                        'frozen_params': len(unfreezer.frozen_params) if unfreezer else 0
                    } if unfreezer else None
                }
                
                save_checkpoint(checkpoint_data, final_checkpoint_path)
                print(f"Saved enhanced checkpoint at step {global_step}")

            # === ENHANCED LOGGING ===
            if (actual_batch_idx + 1) % save_steps == 0:
                avg_grad_norm = total_gradient_norm / max(num_valid_batches, 1)
                
                # Enhanced batch metrics
                batch_metrics = {
                    "batch/total_loss": batch_total_loss_val,
                    "batch/hard_loss": batch_hard_loss_val,
                    "batch/soft_loss": batch_soft_loss_val,
                    "batch/intermediate_loss": batch_intermediate_loss_val,
                    "batch/attention_loss": batch_attention_loss_val,     # New
                    "batch/feature_loss": batch_feature_loss_val,         # New
                    "batch/gradient_norm": grad_norm.item(),
                    "batch/current_alpha": current_alpha,
                    "batch/current_temperature": current_temperature,     # New
                    "training/step": global_step,
                    "training/epoch": epoch,
                    "training/learning_rate": scheduler.get_last_lr()[0], # New
                    "memory/gpu_usage_gb": get_memory_usage(),
                    "hyperparams/temperature": current_temperature,
                    "hyperparams/alpha": current_alpha
                }
                
                # Add progressive unfreezing metrics
                if unfreezer:
                    frozen_count = sum(1 for _, param in unfreezer.frozen_params if not param.requires_grad)
                    total_frozen = len(unfreezer.frozen_params)
                    batch_metrics["training/frozen_params_ratio"] = frozen_count / max(total_frozen, 1)
                    batch_metrics["training/unfrozen_params"] = total_frozen - frozen_count
                
                wandb.log(batch_metrics)

                print(f"Epoch {epoch+1}/{num_epochs}, Batch {actual_batch_idx+1}, Step {global_step}")
                print(f"  Total Loss: {batch_total_loss_val:.4f}")
                print(f"  Hard Loss: {batch_hard_loss_val:.4f}")
                print(f"  Soft Loss: {batch_soft_loss_val:.4f}")
                print(f"  Attention Loss: {batch_attention_loss_val:.4f}") # New
                print(f"  Feature Loss: {batch_feature_loss_val:.4f}")     # New
                if distill_intermediate_layers:
                    print(f"  Intermediate Loss: {batch_intermediate_loss_val:.4f}")
                print(f"  Current Alpha: {current_alpha:.3f}")
                print(f"  Current Temperature: {current_temperature:.2f}")  # New
                print(f"  Learning Rate: {scheduler.get_last_lr()[0]:.2e}") # New
                if unfreezer:
                    frozen_count = sum(1 for _, param in unfreezer.frozen_params if not param.requires_grad)
                    print(f"  Frozen Parameters: {frozen_count}/{len(unfreezer.frozen_params)}")
                print(f"  GPU Memory: {get_memory_usage():.2f} GB")

            # === MEMORY CLEANUP ===
            if (actual_batch_idx + 1) % 100 == 0:
                clear_memory()

        except Exception as e:
            print(f"Error in batch {actual_batch_idx}: {e}")
            optimizer.zero_grad()
            clear_memory()
            continue

    # Close progress bar
    progress_bar.close()

    # === ENHANCED EPOCH SUMMARY ===
    if num_valid_batches > 0:
        avg_total_loss = total_loss / num_valid_batches
        avg_hard_loss = total_hard_loss / num_valid_batches
        avg_soft_loss = total_soft_loss / num_valid_batches
        avg_intermediate_loss = total_intermediate_loss / num_valid_batches
        avg_attention_loss = total_attention_loss / num_valid_batches      # New
        avg_feature_loss = total_feature_loss / num_valid_batches          # New
        avg_gradient_norm = total_gradient_norm / num_valid_batches
    else:
        avg_total_loss = avg_hard_loss = avg_soft_loss = avg_intermediate_loss = float('nan')
        avg_attention_loss = avg_feature_loss = avg_gradient_norm = float('nan')

    print(f"\nEpoch {epoch+1}/{num_epochs} Summary:")
    print(f"  Average Total Loss: {avg_total_loss:.4f}")
    print(f"  Average Hard Loss: {avg_hard_loss:.4f}")
    print(f"  Average Soft Loss: {avg_soft_loss:.4f}")
    print(f"  Average Attention Loss: {avg_attention_loss:.4f}")    # New
    print(f"  Average Feature Loss: {avg_feature_loss:.4f}")        # New
    if distill_intermediate_layers:
        print(f"  Average Intermediate Loss: {avg_intermediate_loss:.4f}")
    print(f"  Final Alpha: {current_alpha:.3f}")
    print(f"  Final Temperature: {current_temperature:.2f}")         # New
    print(f"  Final Learning Rate: {scheduler.get_last_lr()[0]:.2e}") # New
    print(f"  Valid Batches: {num_valid_batches}/{remaining_batches}")
    print(f"  GPU Memory: {get_memory_usage():.2f} GB")
    
    # Progressive unfreezing summary
    if unfreezer:
        frozen_count = sum(1 for _, param in unfreezer.frozen_params if not param.requires_grad)
        unfrozen_count = len(unfreezer.frozen_params) - frozen_count
        print(f"  Progressive Unfreezing: {unfrozen_count}/{len(unfreezer.frozen_params)} parameters unfrozen")
    
    print("-" * 50)

    # Enhanced epoch metrics logging
    epoch_metrics = {
        "epoch/total_loss": avg_total_loss,
        "epoch/hard_loss": avg_hard_loss,
        "epoch/soft_loss": avg_soft_loss,
        "epoch/intermediate_loss": avg_intermediate_loss,
        "epoch/attention_loss": avg_attention_loss,        # New
        "epoch/feature_loss": avg_feature_loss,            # New
        "epoch/gradient_norm": avg_gradient_norm,
        "epoch/final_alpha": current_alpha,
        "epoch/final_temperature": current_temperature,    # New
        "epoch/final_learning_rate": scheduler.get_last_lr()[0], # New
        "epoch/valid_batches": num_valid_batches,
        "epoch/total_batches": remaining_batches,
        "epoch/success_rate": num_valid_batches / remaining_batches if remaining_batches > 0 else 0,
        "epoch/epoch_number": epoch + 1,
        "training/global_step": global_step,
        "memory/gpu_usage_gb": get_memory_usage()
    }

    # Add progressive unfreezing metrics to epoch summary
    if unfreezer:
        frozen_count = sum(1 for _, param in unfreezer.frozen_params if not param.requires_grad)
        total_frozen = len(unfreezer.frozen_params)
        epoch_metrics["epoch/frozen_params_ratio"] = frozen_count / max(total_frozen, 1)
        epoch_metrics["epoch/unfrozen_params"] = total_frozen - frozen_count
        epoch_metrics["epoch/unfreezing_progress"] = (total_frozen - frozen_count) / max(total_frozen, 1)

    # Add error counts
    for error_type, count in error_counts.items():
        epoch_metrics[f"errors/{error_type}"] = count

    # Add enhanced metrics totals
    epoch_metrics["totals/attention_loss"] = total_attention_loss
    epoch_metrics["totals/feature_loss"] = total_feature_loss

    wandb.log(epoch_metrics)

    # Clear memory at end of epoch
    clear_memory()

    # Reset batches_to_skip after first resumed epoch
    batches_to_skip = 0

print("Enhanced training completed successfully!")

# === FINAL MODEL SAVING ===
print("\nSaving final enhanced model...")

# Save the final student model
final_model_path = os.path.join(checkpoint_dir, "final_student_model")
student_model.save_pretrained(final_model_path)
print(f"Final student model saved to: {final_model_path}")

# Save additional components
if teacher_proj is not None:
    torch.save(teacher_proj.state_dict(), os.path.join(checkpoint_dir, "teacher_projection.pt"))
    print("Teacher projection layer saved")

if intermediate_proj is not None:
    torch.save(intermediate_proj.state_dict(), os.path.join(checkpoint_dir, "intermediate_projection.pt"))
    print("Intermediate projection layer saved")

# Save training metadata
training_metadata = {
    "total_steps": global_step,
    "num_epochs": num_epochs,
    "final_temperature": current_temperature if temp_scheduler else temperature,
    "final_alpha": current_alpha,
    "enhanced_features_used": {
        "attention_distillation": attention_distill,
        "feature_matching": feature_matching,
        "progressive_unfreezing": progressive_unfreezing,
        "adaptive_temperature": adaptive_temperature,
        "layer_wise_distillation": distill_intermediate_layers
    },
    "final_learning_rate": scheduler.get_last_lr()[0],
    "config": wandb.config.as_dict()
}

with open(os.path.join(checkpoint_dir, "training_metadata.json"), 'w') as f:
    import json
    json.dump(training_metadata, f, indent=2)

print("Training metadata saved")
print(f"All files saved in: {checkpoint_dir}")

# Final wandb summary
wandb.log({
    "training/completed": True,
    "training/total_steps": global_step,
    "training/final_temperature": current_temperature if temp_scheduler else temperature,
    "training/final_alpha": current_alpha,
    "training/final_lr": scheduler.get_last_lr()[0]
})

print("\nEnhanced Knowledge Distillation Training Complete!")

Froze 37 parameter groups for progressive unfreezing
Initialized projection layer: 262208 -> 128256
Initialized intermediate projection: 2560 -> 2048
Resumed training from step 305000, epoch 0
Will skip 305000 batches in current epoch

=== Starting Epoch 1/10 ===
Resuming from global step 305000
Skipping 305000 batches...


Epoch 1/10:  21%|██████▎                       | 305000/1437851 [00:00<?, ?it/s]`sdpa` attention does not support `output_attentions=True` or `head_mask`. Please set your attention to `eager` if you want any of these features.
Epoch 1/10:  22%|▏| 310000/1437851 [24:28<816:30:27,  2.61s/it, Loss=89.7247, Ha

Saved enhanced checkpoint at step 310000
Epoch 1/10, Batch 310000, Step 310000
  Total Loss: 89.7247
  Hard Loss: 3.0000
  Soft Loss: 1.7269
  Attention Loss: 0.0000
  Feature Loss: 872.0000
  Current Alpha: 0.221
  Current Temperature: 3.97
  Learning Rate: 1.00e-06
  Frozen Parameters: 37/37
  GPU Memory: 12.11 GB


Epoch 1/10:  22%|▏| 315000/1437851 [48:57<806:38:25,  2.59s/it, Loss=78.0299, Ha

Saved enhanced checkpoint at step 315000
Epoch 1/10, Batch 315000, Step 315000
  Total Loss: 78.0299
  Hard Loss: 2.1094
  Soft Loss: 1.7635
  Attention Loss: 0.0000
  Feature Loss: 760.0000
  Current Alpha: 0.221
  Current Temperature: 3.97
  Learning Rate: 1.00e-06
  Frozen Parameters: 37/37
  GPU Memory: 12.11 GB


Epoch 1/10:  22%|▏| 319153/1437851 [1:09:11<91:08:54,  3.41it/s, Loss=70.2340, H


KeyboardInterrupt: 

In [30]:
try:
    # Save model separately for HF
    student_model.save_pretrained(os.path.join(checkpoint_dir, "hf_model"))
    
    # Create or update HF model card
    with open(os.path.join(checkpoint_dir, "hf_model", "README.md"), "w") as f:
        f.write(f"Model checkpoint at step {global_step}\n")
        f.write(f"Training config: {wandb.config}\n")
    
    # Push to Hub
    api = HfApi()
    api.upload_folder(
        folder_path=os.path.join(checkpoint_dir, "hf_model"),
        repo_id=HF_REPO_ID,
        repo_type="model",
        commit_message=f"Checkpoint at step {global_step}"
    )
    print(f"Pushed checkpoint to HF Hub at step {global_step}")
except Exception as e:
    print(f"Failed to push to HF Hub: {e}")

- empty or missing yaml metadata in repo card


model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

Pushed checkpoint to HF Hub at step 319154


In [31]:
# student_model.push_to_hub("MohamedAhmedAE/distil_MedGemma_4B_Llama-3.2-1B")
# tokenizer.push_to_hub("MohamedAhmedAE/distil_MedGemma_4B_Llama-3.2-1B")

[1;34mwandb[0m: 
[1;34mwandb[0m: 🚀 View run [33mdistillation-3egin9y2[0m at: [34mhttps://wandb.ai/mohamed-ahmed/distil_MedGemma_4B_Llama-3.2-1B/runs/uf9zf7s6[0m
[1;34mwandb[0m: Find logs at: [1;35mwandb/run-20250727_152754-uf9zf7s6/logs[0m


In [None]:
def aggressive_memory_clear():
    """
    Aggressively clear CUDA memory and Python objects
    Use this if you want maximum memory clearing before starting training
    """
    import gc
    
    print("Performing aggressive memory clearing...")
    
    # Clear Python objects
    gc.collect()
    
    if torch.cuda.is_available():
        # Show initial memory
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            cached = torch.cuda.memory_reserved(i) / 1024**3
            print(f"GPU {i} - Before clearing: {allocated:.2f}GB allocated, {cached:.2f}GB cached")
        
        # Multiple rounds of clearing
        for round_num in range(3):
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            gc.collect()
        
        # Reset all memory statistics
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        
        # Final clearing
        torch.cuda.empty_cache()
        
        # Show final memory
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            cached = torch.cuda.memory_reserved(i) / 1024**3
            print(f"GPU {i} - After clearing: {allocated:.2f}GB allocated, {cached:.2f}GB cached")
    
    print("Aggressive memory clearing completed!")
    print("-" * 60)

aggressive_memory_clear()