In [1]:
!pip install transformers datasets wandb



In [2]:
!pip install tiktoken



In [3]:
!pip install ipywidgets






In [4]:

!pip install sentencepiece transformers --upgrade






In [5]:
!pip install bitsandbytes



In [6]:
# 1. Initialize environment and imports correctly
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # Enable synchronous CUDA errors

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForMaskedLM, DistilBertForMaskedLM, AlbertTokenizer, AutoConfig
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
import numpy as np
from tqdm import tqdm
import logging
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from contextlib import nullcontext
import pandas as pd
from torch.cuda.amp import GradScaler
import bitsandbytes as bnb


In [7]:
# Logging: every record goes to training.log (persists after the session)
# and to the notebook output via a stream handler.
_file_handler = logging.FileHandler("training.log")
_console_handler = logging.StreamHandler()
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[_file_handler, _console_handler],
)
logger = logging.getLogger(__name__)


In [8]:
# Cell 219: Config (Add distillation hyperparameters)
import os
import torch

class Config:
    """Hyper-parameters, language codes and output paths for teacher fine-tuning and distillation.

    Every attribute gets the default below; any of them can be overridden by keyword,
    e.g. ``Config(batch_size=16, teacher_epochs=3)``. Unknown keywords raise
    ``TypeError``. Instantiating a Config creates ``output_dir``, ``log_dir`` and
    ``model_dir`` on disk (after overrides are applied).
    """

    def __init__(self, **overrides):
        # Model parameters
        self.teacher_model_name = "ai4bharat/indic-bert"
        self.student_model_name = "distilbert-base-multilingual-cased"
        self.max_length = 64
        self.batch_size = 32  # per-step batch; lower it if GPU memory is tight
        self.learning_rate = 5e-5  # teacher LR
        self.weight_decay = 0.01
        self.teacher_epochs = 1
        self.distillation_epochs = 1  # student epochs
        self.warmup_steps = 500  # upper bound; train_teacher caps warmup at total_steps // 10
        self.gradient_accumulation_steps = 8  # batches accumulated per optimizer update
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Language parameters
        self.teacher_lang = "hi"  # Hindi
        self.student_lang = "hne"  # Chhattisgarhi

        # --- Distillation / RL parameters ---
        self.rl_lr = 3e-5  # student LR
        # Knowledge Distillation (KD) specific parameters
        self.distillation_temp = 2.0  # temperature for softening probabilities
        self.distill_loss_weight = 0.5  # KD (KL-div) weight: 0.0 = only CE, 1.0 = only KD
        self.ce_loss_weight = 1.0 - self.distill_loss_weight  # CE weight on hard labels

        # Optional RL parameters (only relevant when training with RL rewards / sampling)
        self.gamma = 0.99  # discount factor
        self.entropy_coef = 0.01  # entropy bonus coefficient

        self.few_shot_examples = 5

        # Dataset
        self.train_test_split = 0.1

        # Paths
        self.output_dir = "output/"
        self.log_dir = "logs/"
        self.model_dir = "models/"

        # Evaluation & checkpointing, counted in optimizer update steps (not batches)
        self.eval_every = 200
        self.save_every = 500
        self.early_stopping_patience = 3  # based on validation accuracy

        # --- Numerical stability ---
        self.logit_clamp_value = 30.0  # clamp logits before softmax / log_softmax
        self.label_smoothing = 0.1  # label smoothing for CrossEntropyLoss

        self._apply_overrides(overrides)

        # Create directories (after overrides so custom paths are honoured)
        for directory in (self.output_dir, self.log_dir, self.model_dir):
            os.makedirs(directory, exist_ok=True)

    def _apply_overrides(self, overrides):
        """Set caller-supplied values for existing fields; reject unknown field names.

        Raises:
            TypeError: if a key is not one of the fields defined in ``__init__``.
        """
        known = vars(self)
        unknown = sorted(key for key in overrides if key not in known)
        if unknown:
            raise TypeError(f"Config got unexpected field(s): {', '.join(unknown)}")
        for key, value in overrides.items():
            setattr(self, key, value)
        # Keep the CE weight complementary to the KD weight unless both were given explicitly.
        if "distill_loss_weight" in overrides and "ce_loss_weight" not in overrides:
            self.ce_loss_weight = 1.0 - self.distill_loss_weight
        # Allow device="cpu" / "cuda:0" strings for convenience.
        if isinstance(self.device, str):
            self.device = torch.device(self.device)



In [9]:
from transformers import AutoConfig, AutoModelForMaskedLM

class TeacherModel(nn.Module):
    """Masked-LM backbone (loaded with AutoModelForMaskedLM) plus a linear next-token head.

    ``forward`` applies ``next_token_adapter`` to the backbone's last hidden layer,
    so the backbone's own MLM head output is not used.
    """

    def __init__(self, model_name, tokenizer_vocab_size=None):
        """
        Args:
            model_name: Hugging Face model id / path used for config and weights.
            tokenizer_vocab_size: ``len(tokenizer)``. If larger than the checkpoint's
                vocabulary, the input embeddings are grown so added tokens (e.g. a new
                ``[PAD]``) get valid rows. A smaller value needs no resize: every
                tokenizer id is already a valid row.

        Raises:
            ValueError: if a required embedding resize did not take effect.
        """
        super().__init__()
        # Load config first
        self.config = AutoConfig.from_pretrained(model_name)
        original_vocab_size = self.config.vocab_size
        logger.info(f"Original model vocab size: {original_vocab_size}")

        # Load in FP32 so the embedding resize below works on full-precision weights
        self.bert_model = AutoModelForMaskedLM.from_pretrained(
            model_name,
            config=self.config,
            torch_dtype=torch.float32  # Force FP32 initialization
        )

        if tokenizer_vocab_size and tokenizer_vocab_size > original_vocab_size:
            logger.info(f"Resizing embeddings from {original_vocab_size} to {tokenizer_vocab_size}")
            self._safe_resize_embeddings(tokenizer_vocab_size)

            # Verify resize was successful
            if self.bert_model.config.vocab_size != tokenizer_vocab_size:
                logger.error(f"Resize failed! Current size: {self.bert_model.config.vocab_size}")
                raise ValueError("Embedding resize operation failed")
        elif tokenizer_vocab_size and tokenizer_vocab_size < original_vocab_size:
            # Shrinking would discard trained rows; all tokenizer ids already fit.
            logger.warning(
                f"Tokenizer vocab size {tokenizer_vocab_size} < model vocab size {original_vocab_size}; "
                "keeping the model's embedding matrix unchanged."
            )

        # Next-token head sized to the (possibly resized) vocabulary
        self.hidden_size = self.config.hidden_size
        self.vocab_size = self.bert_model.config.vocab_size
        self.next_token_adapter = nn.Linear(self.hidden_size, self.vocab_size)
        logger.info(f"TeacherModel initialized with vocab_size={self.vocab_size}")

    def _safe_resize_embeddings(self, new_vocab_size):
        """Grow the input embeddings to ``new_vocab_size`` rows.

        New rows are drawn from a normal distribution matching the mean/std of the
        existing embedding weights. A non-growing request is logged and ignored.
        """
        old_embeddings = self.bert_model.get_input_embeddings()
        old_size = old_embeddings.num_embeddings

        # Skip if new size is not larger (would truncate vocabulary)
        if new_vocab_size <= old_size:
            logger.warning(f"New vocab size {new_vocab_size} <= current size {old_size}. Skipping resize.")
            return

        new_embeddings = self.bert_model.resize_token_embeddings(new_vocab_size)

        with torch.no_grad():
            # Statistics of the pre-resize weights (the old module is left intact by the resize)
            mean = old_embeddings.weight.mean().item()
            std = old_embeddings.weight.std().item()
            # Fill the new rows in place at the embedding's own width. ALBERT-style models
            # factorise embeddings (embedding_size != hidden_size), so sizing the rows by
            # config.hidden_size would not match the embedding matrix.
            new_embeddings.weight[old_size:].normal_(mean=mean, std=std)

        # Update all necessary config values
        self.bert_model.config.vocab_size = new_vocab_size
        self.config.vocab_size = new_vocab_size
        logger.info(f"Embeddings successfully resized to {new_vocab_size}")

    def forward(self, input_ids, attention_mask=None):
        """Return next-token logits computed from the backbone's last hidden layer.

        Args:
            input_ids: integer token ids, unpacked below as ``(batch_size, seq_length)``.
            attention_mask: optional mask forwarded to the backbone.

        Raises:
            ValueError: if any id lies outside ``[0, vocab_size)``; checked up front so
                the failure is reported clearly instead of inside the embedding lookup.
        """
        # Validate input tensor indices
        if torch.min(input_ids) < 0 or torch.max(input_ids) >= self.vocab_size:
            min_id = torch.min(input_ids).item()
            max_id = torch.max(input_ids).item()
            logger.error(f"Invalid input_ids detected! Range: [{min_id}, {max_id}], Vocab size: {self.vocab_size}")
            raise ValueError(f"Input IDs must be within range [0, {self.vocab_size-1}]")

        # Single-segment input: all-zero token_type_ids
        batch_size, seq_length = input_ids.shape
        token_type_ids = torch.zeros(batch_size, seq_length,
                                     dtype=torch.long, device=input_ids.device)

        try:
            # Mixed precision on whatever device the inputs live on
            with torch.amp.autocast(device_type=input_ids.device.type):
                outputs = self.bert_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    output_hidden_states=True
                )

                last_hidden_state = outputs.hidden_states[-1]
                return self.next_token_adapter(last_hidden_state)
        except Exception as e:
            logger.error(f"Error in forward pass: {e}")
            raise


In [10]:
class StudentModel(nn.Module):
    """DistilBERT masked-LM backbone with an extra linear next-token head."""

    def __init__(self, model_name):
        super(StudentModel, self).__init__()
        # Use DistilBERT as the base model
        self.model = DistilBertForMaskedLM.from_pretrained(model_name)

        # Next-token prediction head on top of the last hidden layer
        self.vocab_size = self.model.config.vocab_size
        self.hidden_size = self.model.config.dim
        self.next_token_head = nn.Linear(self.hidden_size, self.vocab_size)

    def forward(self, input_ids, attention_mask=None):
        """Return next-token logits for every position (last hidden layer -> head)."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]
        return self.next_token_head(hidden_states)

    def get_action_and_value(self, input_ids, attention_mask=None):
        """Sample a next token for each sequence (RL-style action).

        The distribution is taken at the last *real* token of each sequence, i.e. the
        highest position where ``attention_mask`` is non-zero (works for either padding
        side). Without a mask the final position is used.

        Returns:
            action: sampled token ids, shape ``(batch,)``.
            log_prob: log-probability of each sampled token, shape ``(batch,)``.
            entropy: mean entropy over the batch (scalar).
            probs: next-token probabilities, shape ``(batch, vocab)``.
        """
        logits = self.forward(input_ids, attention_mask)

        if attention_mask is None:
            last_token_logits = logits[:, -1, :]
        else:
            seq_positions = torch.arange(logits.size(1), device=logits.device)
            real = (attention_mask.to(logits.device) != 0).long()
            # Padding positions contribute 0, so max() picks the last real position.
            last_idx = (real * seq_positions).max(dim=1).values
            batch_idx = torch.arange(logits.size(0), device=logits.device)
            last_token_logits = logits[batch_idx, last_idx]

        # Build from logits (softmax handled internally, numerically stable)
        dist = torch.distributions.Categorical(logits=last_token_logits)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        entropy = dist.entropy().mean()

        return action, log_prob, entropy, dist.probs


In [11]:
# Cell 8: DataProcessor (Refined Pad Token and Vocab Size Logging)
class DataProcessor:
    """Loads the teacher/student tokenizers and the Hindi–Chhattisgarhi NLLB samples."""

    def __init__(self, config):
        self.config = config

        # Load tokenizers first; the size check below needs the teacher vocab size
        self._load_tokenizers()

        # Verify embedding compatibility (warning only)
        self._verify_embedding_sizes()

    def _load_tokenizers(self):
        """Create the teacher (AlbertTokenizer) and student (AutoTokenizer) tokenizers.

        Adds a ``[PAD]`` token to the teacher tokenizer when it has none, then records
        ``teacher_pad_id`` and ``teacher_vocab_size`` (= ``len(tokenizer)`` after the add).
        """
        self.teacher_tokenizer = AlbertTokenizer.from_pretrained(
            self.config.teacher_model_name,
            keep_accents=True
        )

        # Add [PAD] token if missing
        if self.teacher_tokenizer.pad_token is None:
            self.teacher_tokenizer.add_special_tokens({'pad_token': '[PAD]'})

        self.teacher_pad_id = self.teacher_tokenizer.pad_token_id
        self.teacher_vocab_size = len(self.teacher_tokenizer)

        # Student tokenizer (DistilBERT)
        self.student_tokenizer = AutoTokenizer.from_pretrained(
            self.config.student_model_name
        )

    def _verify_embedding_sizes(self):
        """Warn when the teacher tokenizer's vocab size differs from the model config's."""
        teacher_config = AutoConfig.from_pretrained(self.config.teacher_model_name)
        if teacher_config.vocab_size != self.teacher_vocab_size:
            logger.warning(f"Tokenizer vocab ({self.teacher_vocab_size}) ≠ model vocab ({teacher_config.vocab_size})")

    def load_dataset(self):
        """Load the NLLB ``hin_Deva-hne_Deva`` train split as two parallel sample lists.

        Returns:
            (hindi_samples, chhattisgarhi_samples): lists of ``{"lang": ..., "text": ...}``
            dicts, aligned by index (both come from the same translation row).

        Raises:
            Any exception from the download/parse, after logging it with a traceback.
        """
        # Note: the bare name `load_dataset` below resolves to the module-level
        # `datasets.load_dataset` function, not to this method.
        try:
            logger.info("Loading NLLB dataset for Hindi-Chhattisgarhi pair")
            nllb_dataset = load_dataset("allenai/nllb", "hin_Deva-hne_Deva", trust_remote_code=True)
            hindi_samples = []
            chhattisgarhi_samples = []
            # One pass over the split: each row carries both sides of the pair.
            for item in nllb_dataset["train"]:
                pair = item['translation']
                hindi_samples.append({"lang": self.config.teacher_lang, "text": pair['hin_Deva']})
                chhattisgarhi_samples.append({"lang": self.config.student_lang, "text": pair['hne_Deva']})
            logger.info(f"Loaded {len(hindi_samples)} Hindi, {len(chhattisgarhi_samples)} Chhattisgarhi samples")
            return hindi_samples, chhattisgarhi_samples
        except Exception as e:
            logger.error(f"Error loading dataset: {e}", exc_info=True)  # Log traceback
            raise


In [12]:
# Assume this function is defined in a previous cell (e.g., Cell 225) or at the top of Cell 233
import numpy as np
from tqdm import tqdm # Make sure tqdm is imported

def validate_dataset(dataset, tokenizer, name="dataset", sample_size=None):
    """Report whether every sample's ``input_ids`` lie inside the tokenizer's id range.

    Args:
        dataset: indexable dataset whose items are expected to be mappings holding an
            ``'input_ids'`` tensor.
        tokenizer: its ``len()`` defines the valid range ``[0, len(tokenizer))``.
        name: label used in log messages and on the progress bar.
        sample_size: when given, check a random subset of at most this many items
            (``np.random.choice`` without replacement); otherwise check every item.

    Returns:
        True if no problem was found, else False. Missing / non-tensor ``input_ids``
        and exceptions raised while reading an item also count as failures. Only the
        first failing sample is described in detail.
    """
    vocab_size = len(tokenizer)
    logger.info(f"Validating {name} (Tokenizer Vocab Size: {vocab_size})...")

    dataset_len = len(dataset)
    if sample_size is None:
        n_checked = dataset_len
        to_visit = range(n_checked)
    else:
        n_checked = min(sample_size, dataset_len)
        to_visit = np.random.choice(dataset_len, n_checked, replace=False)
    logger.info(f"Checking {n_checked} samples.")

    failures = 0
    first_failure = -1

    def register_failure(index):
        """Count one bad sample; return True when it is the first one seen."""
        nonlocal failures, first_failure
        failures += 1
        if first_failure != -1:
            return False
        first_failure = index
        return True

    for idx in tqdm(to_visit, desc=f"Validating {name}"):
        try:
            item = dataset[idx]
            has_id_tensor = 'input_ids' in item and isinstance(item['input_ids'], torch.Tensor)
            if not has_id_tensor:
                logger.warning(f"Sample {idx} in {name} is missing 'input_ids' tensor. Skipping.")
                register_failure(idx)
                continue

            token_ids = item['input_ids']
            lowest = torch.min(token_ids).item()
            highest = torch.max(token_ids).item()
            if lowest >= 0 and highest < vocab_size:
                continue

            # Out-of-range ids: only the very first failing sample gets the detailed dump.
            if not register_failure(idx):
                continue
            bad_ids = token_ids.tolist()
            logger.error(f"!!! Invalid token ID found in {name} at index {idx} !!!")
            logger.error(f"    Range Found: [{lowest}, {highest}], Required Range: [0, {vocab_size-1}]")
            logger.error(f"    Problematic IDs (sample): {bad_ids[:20]}...")
            # Best-effort decode of the ids that are in range, for human context
            try:
                in_range = [tok for tok in bad_ids if 0 <= tok < vocab_size]
                if not in_range:
                    logger.error(f"    Cannot decode, no IDs were within the valid range [0, {vocab_size-1}]")
                else:
                    decoded_text = tokenizer.decode(in_range, skip_special_tokens=False)
                    logger.error(f"    Partial Decode Attempt of valid-range IDs: '{decoded_text}'")
            except Exception as decode_err:
                logger.error(f"    Could not decode IDs: {decode_err}")

        except IndexError:
            logger.error(f"IndexError accessing sample {idx} in {name}. Dataset length reported as {len(dataset)}.")
            register_failure(idx)
        except Exception as err:
            logger.error(f"Unexpected error processing item at index {idx} in {name}: {err}", exc_info=True)
            register_failure(idx)

    if failures:
        logger.error(f"Validation FAILED for {name}. Found {failures} invalid samples out of {n_checked} checked.")
        logger.error(f"First invalid sample detected at index: {first_failure}")
        return False
    logger.info(f"Dataset {name} validation passed ({n_checked} samples checked).")
    return True


In [13]:
# Cell 9: NextWordPredictionDataset (Add check for pad token ID)
class NextWordPredictionDataset(Dataset):
    """Causal next-token dataset built from ``{"text": ...}`` samples.

    Each text is tokenised to ``max_length`` ids (truncated and right-padded), then
    shifted by one: ``input_ids = ids[:-1]`` and ``labels = ids[1:]``, so position
    ``j`` is trained to predict token ``j + 1``. Returned tensors have length
    ``max_length - 1``; padded label positions hold ``pad_token_id``.
    """

    def __init__(self, samples, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.pad_token_id = tokenizer.pad_token_id
        self.vocab_size = len(tokenizer)  # upper bound for the per-item id sanity check

        # Without a pad token, HF tokenizers refuse padding="max_length", which would make
        # every __getitem__ fall back to the dummy item. Remember that and pad manually.
        self._tokenizer_can_pad = self.pad_token_id is not None
        if not self._tokenizer_can_pad:
            logger.error(f"CRITICAL: Tokenizer {tokenizer.name_or_path} provided to dataset without a valid pad_token_id!")
            self.pad_token_id = 0  # fallback id for manual padding and padded labels

        # Keep only string texts longer than 10 characters
        self.examples = [s["text"] for s in samples if isinstance(s.get("text"), str) and len(s["text"]) > 10]
        if len(self.examples) < len(samples):
            logger.warning(f"Filtered out {len(samples) - len(self.examples)} samples due to length or type.")
        logger.info(f"Created dataset with {len(self.examples)} examples")

    def __len__(self):
        return len(self.examples)

    def _encode(self, text):
        """Tokenise ``text`` to (input_ids, attention_mask), padded to ``max_length``."""
        if self._tokenizer_can_pad:
            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            return encoding["input_ids"].squeeze(0), encoding["attention_mask"].squeeze(0)

        # Manual right-padding with the fallback pad id
        encoding = self.tokenizer(text, max_length=self.max_length, truncation=True, return_tensors="pt")
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        n_pad = self.max_length - input_ids.numel()
        if n_pad > 0:
            input_ids = torch.cat([input_ids, torch.full((n_pad,), self.pad_token_id, dtype=input_ids.dtype)])
            attention_mask = torch.cat([attention_mask, torch.zeros(n_pad, dtype=attention_mask.dtype)])
        return input_ids, attention_mask

    def __getitem__(self, idx):
        if idx >= len(self.examples):
            raise IndexError(f"Index {idx} out of bounds")

        text = self.examples[idx]
        try:
            input_ids, attention_mask = self._encode(text)

            # Out-of-vocab ids are only reported here (not fixed); downstream checks reject them.
            if torch.any(input_ids >= self.vocab_size) or torch.any(input_ids < 0):
                logger.error(f"Invalid IDs found *during tokenization* for index {idx}, text: '{text[:100]}...'")
                logger.error(f"Min: {torch.min(input_ids)}, Max: {torch.max(input_ids)}, Vocab Size: {self.vocab_size}")

            # Degenerate encodings (e.g. a tiny max_length): replace with BOS/CLS + padding
            if input_ids.numel() <= 1:
                bos = self.tokenizer.cls_token_id if self.tokenizer.cls_token_id is not None else self.tokenizer.bos_token_id
                bos = bos if bos is not None else self.pad_token_id
                input_ids = torch.tensor([bos] + [self.pad_token_id] * (self.max_length - 1), dtype=torch.long)
                attention_mask = torch.tensor([1] + [0] * (self.max_length - 1), dtype=torch.long)
                logger.warning(f"Handling short/empty sequence for index {idx}")

            # Shift by one: position j sees tokens <= j and is labelled with token j + 1.
            return {
                "input_ids": input_ids[:-1],
                "attention_mask": attention_mask[:-1],
                "labels": input_ids[1:].clone()
            }
        except Exception as e:
            # Deliberate best-effort: log and return a fully padded dummy item
            logger.error(f"Error processing item at index {idx}, text: '{text[:100]}...': {e}", exc_info=True)
            pad_id = self.pad_token_id
            dummy_input_ids = torch.full((self.max_length - 1,), pad_id, dtype=torch.long)
            dummy_attn_mask = torch.zeros((self.max_length - 1,), dtype=torch.long)
            dummy_labels = torch.full((self.max_length - 1,), pad_id, dtype=torch.long)
            return {"input_ids": dummy_input_ids, "attention_mask": dummy_attn_mask, "labels": dummy_labels}




In [14]:
class FewShotDataset(Dataset):
    """Few-shot next-word prompts: solved demonstrations followed by one query.

    When there are more usable samples than ``few_shot_examples``, the first
    ``few_shot_examples`` become demonstrations rendered as
    ``"Text: <prefix>\\nNext word: <last word>"``; every remaining sample is a query
    whose last word is held out. The label is the id of that word's first sub-token.
    """

    def __init__(self, samples, tokenizer, few_shot_examples=5, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.few_shot_examples = few_shot_examples

        # Ensure a pad token is set (only when an EOS token actually exists)
        if self.tokenizer.pad_token is None and getattr(self.tokenizer, 'eos_token', None) is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.examples = []
        self.few_shot_pool = []

        # Keep texts longer than 10 characters; split off the demonstration pool
        valid_samples = [s for s in samples if len(s["text"]) > 10]
        if len(valid_samples) > few_shot_examples:
            self.few_shot_pool = valid_samples[:few_shot_examples]
            self.examples = valid_samples[few_shot_examples:]
        else:
            self.examples = valid_samples

        logger.info(f"Created few-shot dataset with {len(self.examples)} examples and {len(self.few_shot_pool)} few-shot examples")

    def __len__(self):
        return len(self.examples)

    def _few_shot_context(self):
        """Render the demonstration block shared by every item (single-word demos skipped)."""
        context = ""
        for example in self.few_shot_pool:
            words = example["text"].split()
            if len(words) > 1:
                prefix = " ".join(words[:-1])
                context += f"Text: {prefix}\nNext word: {words[-1]}\n\n"
        return context

    def _first_token_id(self, word):
        """Id of the first sub-token of ``word`` as a 0-dim long tensor, without special tokens."""
        ids = self.tokenizer(word, add_special_tokens=False)["input_ids"]
        if ids:
            return torch.tensor(ids[0], dtype=torch.long)
        # The tokenizer produced nothing for this word: fall back to UNK / PAD / 0.
        fallback = self.tokenizer.unk_token_id
        if fallback is None:
            fallback = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
        return torch.tensor(fallback, dtype=torch.long)

    def __getitem__(self, idx):
        n = len(self.examples)
        if idx >= n or idx < -n:
            raise IndexError(f"Index {idx} out of bounds for dataset with {len(self.examples)} examples")

        # Single-word examples cannot be split into prefix/target: move on to the next
        # example, but stop after one full lap instead of recursing forever.
        for offset in range(n):
            target_words = self.examples[(idx + offset) % n]["text"].split()
            if len(target_words) > 1:
                break
        else:
            raise ValueError("FewShotDataset has no example with more than one word to build a prompt from")

        prefix = " ".join(target_words[:-1])
        target = target_words[-1]
        prompt = self._few_shot_context() + f"Text: {prefix}\nNext word:"

        # NOTE(review): with the tokenizer's default right-side truncation, a prompt longer
        # than max_length loses its *end* -- i.e. the query. Consider fewer demos or
        # left-side truncation if that happens.
        context_encoding = self.tokenizer(prompt, max_length=self.max_length,
                                          padding="max_length", truncation=True, return_tensors="pt")

        return {
            "input_ids": context_encoding["input_ids"].squeeze(),
            "attention_mask": context_encoding["attention_mask"].squeeze(),
            "labels": self._first_token_id(target)
        }


In [15]:
def compute_reward(teacher_probs, student_action, target_tokens, alpha=0.7,
                   match_reward=5.0, miss_penalty=0.1):
    """Per-sample reward for the student's sampled next token.

        reward = (match_reward if action == target else -miss_penalty)
                 + alpha * teacher_probs[target]
                 + 0.5 * alpha * teacher_probs[action]

    A teacher-probability term is skipped when its token id lies outside
    ``[0, vocab)`` (e.g. an ignore index such as -100) instead of wrapping around
    through negative indexing.

    Args:
        teacher_probs: ``(batch, vocab)`` teacher probabilities.
        student_action: sampled token ids with ``batch`` elements (e.g. ``(batch,)``).
        target_tokens: gold token ids with ``batch`` elements (``(batch,)`` or ``(batch, 1)``).
        alpha: weight of the teacher-probability terms.
        match_reward: bonus for an exact match.
        miss_penalty: penalty subtracted on a miss.

    Returns:
        ``(batch,)`` tensor (default float dtype) on ``teacher_probs.device``.
    """
    batch_size, vocab_size = teacher_probs.shape[0], teacher_probs.shape[1]
    device = teacher_probs.device
    actions = student_action.reshape(batch_size).to(device)
    targets = target_tokens.reshape(batch_size).to(device)

    # Base term: +match_reward on a hit, -miss_penalty otherwise
    reward = torch.full((batch_size,), -miss_penalty, device=device)
    reward[actions == targets] = match_reward

    def _teacher_prob(token_ids):
        """teacher_probs[i, token_ids[i]], or 0 where the id is outside the vocabulary."""
        in_vocab = (token_ids >= 0) & (token_ids < vocab_size)
        safe_ids = token_ids.clamp(0, max(vocab_size - 1, 0))
        picked = teacher_probs.gather(1, safe_ids.unsqueeze(1)).squeeze(1)
        return torch.where(in_vocab, picked, torch.zeros_like(picked))

    # Partial credit: teacher confidence in the correct token ...
    reward = reward + alpha * _teacher_prob(targets)
    # ... and, at half weight, in the student's chosen token
    reward = reward + (alpha * 0.5) * _teacher_prob(actions)
    return reward


In [16]:
# Cell 228: train_teacher (Fix gradient checkpointing call, add 8-bit AdamW, fix GradScaler)

# Ensure necessary imports are available
from torch.cuda.amp import GradScaler # Keep old import if using older PyTorch, prefer torch.amp.cuda otherwise
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW # Keep standard AdamW import for fallback if needed
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import os
import logging

# Attempt to import bitsandbytes for 8-bit AdamW
logger = logging.getLogger(__name__)  # bind before use below so this cell also runs on its own

# Optional dependency: bitsandbytes provides AdamW8bit for memory savings;
# train_teacher falls back to torch's AdamW when bnb_available is False.
try:
    import bitsandbytes as bnb
except ImportError:
    logger.warning("bitsandbytes not found. Falling back to standard AdamW. "
                   "Install with 'pip install bitsandbytes' for memory savings.")
    bnb_available = False
else:
    bnb_available = True
    logger.info("bitsandbytes found, will use AdamW8bit.")


def train_teacher(model, train_loader, val_loader, config, data_processor): # Added data_processor
    """Fine-tune the teacher masked-LM with fp16 autocast, gradient accumulation,
    optional gradient checkpointing and early stopping.

    NOTE(review): a later cell redefines ``train_teacher`` without gradient
    checkpointing; after Run All, that later definition is the one in effect.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``; its logits are
            flattened with ``model.config.vocab_size`` as the class dimension. If it has
            a ``bert_model`` attribute, gradient checkpointing is attempted on it (CUDA only).
        train_loader: Sized iterable of dicts holding "input_ids", "attention_mask" and
            "labels" tensors.
        val_loader: Same format; passed to ``evaluate_teacher`` after every epoch.
        config: Reads ``device``, ``learning_rate``, ``weight_decay``, ``teacher_epochs``,
            ``warmup_steps``, ``gradient_accumulation_steps``, ``model_dir`` and
            ``early_stopping_patience``.
        data_processor: Supplies ``teacher_pad_id``, used as the loss ignore index
            (``-100`` when it is ``None``).

    Returns:
        ``model``; if this run saved a best-validation-loss checkpoint, its weights are
        reloaded before returning.

    Gradients of ``gradient_accumulation_steps`` micro-batches are summed per optimizer
    step. A skipped micro-batch (invalid ids, NaN/Inf loss, runtime error) discards the
    gradients accumulated so far in its cycle, and gradients still pending at the end of
    an epoch are flushed with one extra optimizer step.
    """
    logger.info("Starting teacher model fine-tuning for Hindi...")
    model.to(config.device)

    # Gradient checkpointing is applied to the wrapped transformer (model.bert_model).
    # Remember whether it was actually enabled so it can be switched off at the end.
    gc_enabled = False
    if hasattr(model, 'bert_model') and config.device.type == 'cuda':
        try:
            model.bert_model.gradient_checkpointing_enable()
            gc_enabled = True
            logger.info("Gradient Checkpointing Enabled on model.bert_model.")
        except AttributeError:
            logger.warning("model.bert_model does not have gradient_checkpointing_enable. Skipping.")
        except Exception as e:
            logger.error(f"Error enabling gradient checkpointing: {e}. Skipping.")
    elif config.device.type == 'cuda':
        logger.warning("Gradient checkpointing requested but model object doesn't have 'bert_model' attribute.")

    # bitsandbytes' 8-bit optimizers target GPU tensors; use torch's AdamW on other devices.
    use_8bit_adam = bnb_available and config.device.type == 'cuda'
    if use_8bit_adam:
        optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    else:
        optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    accum_steps = config.gradient_accumulation_steps
    # Ceil division: a trailing partial accumulation cycle is flushed at epoch end and also
    # calls scheduler.step(), so it must be counted in the schedule length.
    num_update_steps_per_epoch = (len(train_loader) + accum_steps - 1) // accum_steps
    total_steps = num_update_steps_per_epoch * config.teacher_epochs
    num_warmup_steps = min(config.warmup_steps, total_steps // 10) if total_steps > 0 else 0
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=max(1, total_steps))

    # `GradScaler` is the torch.cuda.amp class imported above. torch.amp.GradScaler has no
    # `device_type` keyword, so the previous call raised TypeError.
    scaler = GradScaler(enabled=(config.device.type == 'cuda'))
    logger.info(f"GradScaler enabled: {scaler.is_enabled()}")

    pad_token_id = data_processor.teacher_pad_id
    # Assumes model.config reflects the (possibly resized) teacher vocabulary.
    teacher_vocab_size = model.config.vocab_size

    if pad_token_id is None:
        logger.error("CRITICAL: Teacher pad_token_id is None. Using ignore_index=-100.")
        pad_token_id = -100
    else:
        logger.info(f"Using ignore_index={pad_token_id} for teacher loss.")

    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token_id)

    best_loss = float('inf')
    early_stopping_counter = 0
    global_step = 0
    pending_backwards = 0  # backward passes accumulated since the last optimizer step
    best_model_path = os.path.join(config.model_dir, "best_teacher_model.pt")
    best_checkpoint_saved = False  # only reload a checkpoint written by *this* run

    def _discard_accumulated():
        """Drop the gradients accumulated in the current cycle (used whenever a micro-batch is skipped)."""
        nonlocal pending_backwards
        optimizer.zero_grad()
        pending_backwards = 0

    def _optimizer_step():
        """Unscale, clip to norm 1.0, step, update scaler and scheduler, clear grads.

        Returns ``(grad_norm, scale_before, scale_after)``.
        """
        nonlocal pending_backwards, global_step
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scale_before = scaler.get_scale()
        scaler.step(optimizer)  # skipped internally by the scaler if grads contain Inf/NaN
        scaler.update()         # decreases the scale when Inf/NaN grads were found
        scale_after = scaler.get_scale()
        optimizer.zero_grad()
        scheduler.step()
        pending_backwards = 0
        global_step += 1
        return grad_norm, scale_before, scale_after

    for epoch in range(config.teacher_epochs):
        model.train()
        epoch_loss = 0.0
        epoch_batches = 0  # micro-batches whose loss was added to epoch_loss
        _discard_accumulated()  # start every epoch with clean gradients

        for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.teacher_epochs}")):
            try:
                input_ids = batch["input_ids"].to(config.device)
                attention_mask = batch["attention_mask"].to(config.device)
                labels = batch["labels"].to(config.device)

                # Reject out-of-range ids up front instead of letting the embedding lookup
                # or the loss fail on them.
                min_id, max_id = torch.min(input_ids), torch.max(input_ids)
                if max_id >= teacher_vocab_size or min_id < 0:
                    logger.error(f"!!!!! Invalid teacher input_ids in batch {batch_idx+1}: Range [{min_id.item()},{max_id.item()}], Vocab Size {teacher_vocab_size}. Skipping.")
                    _discard_accumulated()
                    continue

                min_lbl, max_lbl = torch.min(labels), torch.max(labels)
                if max_lbl >= teacher_vocab_size or (min_lbl < 0 and min_lbl != pad_token_id):
                    logger.error(f"!!!!! Invalid teacher labels in batch {batch_idx+1}: Range [{min_lbl.item()},{max_lbl.item()}], Vocab Size {teacher_vocab_size}, Pad ID {pad_token_id}. Skipping.")
                    _discard_accumulated()
                    continue

                with torch.amp.autocast(device_type=config.device.type, dtype=torch.float16, enabled=(config.device.type == 'cuda')):
                    logits = model(input_ids, attention_mask)
                    loss = loss_fn(logits.view(-1, teacher_vocab_size), labels.view(-1))

                # Check the unscaled loss before it reaches backward().
                if torch.isnan(loss) or torch.isinf(loss):
                    logger.error(f"NaN or Inf loss detected *before* scaling at Epoch {epoch+1}, Batch {batch_idx+1}! Skipping step.")
                    _discard_accumulated()
                    continue

                batch_loss = loss.item()  # unscaled micro-batch loss, used for logging
                # Divide by accum_steps so the summed gradients average over the cycle.
                scaler.scale(loss / accum_steps).backward()
                pending_backwards += 1

                if (batch_idx + 1) % accum_steps == 0:
                    grad_norm, scale_before_update, scale_after_update = _optimizer_step()
                    # The scale only drops when Inf/NaN grads were found, i.e. the step was skipped.
                    if scale_after_update < scale_before_update:
                        logger.warning(f"Scaler skipped step due to Inf/NaN grads at step {global_step} (Grad Norm: {grad_norm:.4f}). Scale reduced to {scale_after_update:.1f}")

                epoch_loss += batch_loss
                epoch_batches += 1

                if (batch_idx + 1) % 50 == 0: # Log every 50 micro-batches
                    logger.info(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}, Step Loss: {batch_loss:.4f}, Scale: {scaler.get_scale():.1f}")

            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    logger.error(f"CUDA OOM error during training batch {batch_idx+1}. Consider reducing batch size further or enabling gradient checkpointing.")
                else:
                    logger.error(f"RuntimeError during training batch {batch_idx+1}: {e}", exc_info=True)
                if config.device.type == 'cuda':
                    torch.cuda.empty_cache()
                # A failed micro-batch may leave partial gradients; drop the whole cycle.
                _discard_accumulated()

            except Exception as e:
                logger.error(f"Error during training batch {batch_idx+1}: {e}", exc_info=True)
                if config.device.type == 'cuda':
                    torch.cuda.empty_cache()
                _discard_accumulated()

        # Flush gradients left over from a trailing partial accumulation cycle (if any).
        if pending_backwards > 0:
            logger.info("Performing final optimizer step for remaining accumulated gradients.")
            try:
                _optimizer_step()
            except Exception as final_step_e:
                logger.error(f"Error during final optimizer step: {final_step_e}")
                _discard_accumulated()

        # Mean loss per successfully processed micro-batch of this epoch.
        avg_epoch_loss = epoch_loss / epoch_batches if epoch_batches > 0 else float('nan')

        try:
            val_loss, val_perplexity, val_accuracy = evaluate_teacher(model, val_loader, config, pad_token_id)
        except Exception as eval_e:
            logger.error(f"Error during teacher evaluation: {eval_e}", exc_info=True)
            val_loss, val_perplexity, val_accuracy = float('nan'), float('nan'), float('nan')

        logger.info(f"Epoch {epoch+1} completed. Avg Train Loss: {avg_epoch_loss:.4f}, Val Loss: {val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}, Val Accuracy: {val_accuracy:.4f}")

        # Checkpointing and early stopping on validation loss; NaN counts as "no improvement".
        if not np.isnan(val_loss) and val_loss < best_loss:
            best_loss = val_loss
            early_stopping_counter = 0
            os.makedirs(config.model_dir, exist_ok=True)
            torch.save(model.state_dict(), best_model_path)
            best_checkpoint_saved = True
            logger.info(f"New best teacher model saved with val loss: {best_loss:.4f}")
        elif not np.isnan(val_loss):
            early_stopping_counter += 1
            logger.info(f"Validation loss did not improve for {early_stopping_counter} epoch(s). Best loss: {best_loss:.4f}")
            if early_stopping_counter >= config.early_stopping_patience:
                logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break
        else:
            logger.warning(f"Validation loss is NaN at epoch {epoch+1}.")
            early_stopping_counter += 1
            logger.info(f"NaN validation loss encountered. Early stopping counter: {early_stopping_counter}/{config.early_stopping_patience}")
            if early_stopping_counter >= config.early_stopping_patience:
                logger.info(f"Early stopping triggered due to persistent NaN validation loss.")
                break

    # Reload the best weights, but only from a checkpoint this run wrote: a file left over
    # from an earlier run must not silently replace the freshly trained weights.
    if best_checkpoint_saved:
        try:
            model.load_state_dict(torch.load(best_model_path, map_location=config.device))
            logger.info(f"Loaded best teacher model weights from {best_model_path}.")
        except Exception as e:
            logger.error(f"Error loading best teacher model weights: {e}")
            logger.warning("Returning model from last trained epoch state due to load error.")
    else:
        logger.warning("No best teacher model checkpoint found. Returning model from last epoch.")

    if gc_enabled:
        try:
            model.bert_model.gradient_checkpointing_disable()
            logger.info("Gradient Checkpointing Disabled on model.bert_model.")
        except Exception as e:
            logger.warning(f"Could not disable gradient checkpointing: {e}")

    return model



2025-04-14 17:27:34,102 - __main__ - INFO - bitsandbytes found, will use AdamW8bit.


In [17]:
# Cell 14: evaluate_teacher (Added ignore_index and pad_token_id parameter)
def evaluate_teacher(model, val_loader, config, pad_token_id): # Added pad_token_id parameter
    """Compute validation loss, perplexity and token accuracy for the teacher model.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``; the last logits
            dimension is treated as the vocabulary (class) dimension.
        val_loader: Iterable of dicts holding "input_ids", "attention_mask" and "labels" tensors.
        config: Object exposing ``device`` (a ``torch.device``).
        pad_token_id: Label value ignored by both the loss and the accuracy. When ``None``,
            the CrossEntropyLoss default ``-100`` is used for both.

    Returns:
        ``(avg_loss, perplexity, accuracy)``:
        - ``avg_loss`` is the mean of the non-NaN per-batch losses, or NaN if there were none.
        - ``perplexity`` is ``exp(avg_loss)``.
        - ``accuracy`` is the token accuracy over non-ignored label positions, or 0.0 if
          there were none.

    Leaves the model in eval mode.
    """
    model.eval()
    # One ignore index shared by the loss and the accuracy mask, so both score the same tokens.
    ignore_index = pad_token_id if pad_token_id is not None else -100
    loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)

    total_loss = 0.0
    loss_batches = 0    # batches whose non-NaN loss was added to total_loss
    correct_tokens = 0
    scored_tokens = 0   # label positions not equal to ignore_index

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating Teacher"):
            input_ids = batch["input_ids"].to(config.device)
            attention_mask = batch["attention_mask"].to(config.device)
            labels = batch["labels"].to(config.device)

            with torch.amp.autocast(device_type=config.device.type, enabled=(config.device.type == 'cuda')):
                logits = model(input_ids, attention_mask)
                loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

            if torch.isnan(loss):
                # e.g. every label in the batch equals ignore_index (mean over zero tokens)
                logger.warning("NaN loss encountered during teacher evaluation.")
            else:
                total_loss += loss.item()
                loss_batches += 1

            # Accumulate accuracy as counts instead of collecting every prediction in lists.
            valid = labels != ignore_index
            preds = torch.argmax(logits, dim=-1)
            correct_tokens += (preds[valid] == labels[valid]).sum().item()
            scored_tokens += valid.sum().item()

    # Average only over batches that contributed a loss. Counting NaN batches in the
    # denominator biased the mean loss (and perplexity) downward.
    avg_loss = total_loss / loss_batches if loss_batches > 0 else float('nan')
    perplexity = np.exp(avg_loss) if not np.isnan(avg_loss) else float('nan')
    accuracy = correct_tokens / scored_tokens if scored_tokens > 0 else 0.0

    return avg_loss, perplexity, accuracy


In [18]:
# Cell 228: train_teacher (Remove Grad Checkpointing, Fix GradScaler Init, Keep 8-bit AdamW)

# Ensure necessary imports are available
from torch.cuda.amp import GradScaler # Using the import from Cell 5
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW # Keep standard AdamW import for fallback if needed
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import os
import logging

# Bind the logger first: the try/except below logs, and this cell must not depend on a
# logger leaked from an earlier cell when run on a fresh kernel.
logger = logging.getLogger(__name__)

# Attempt to import bitsandbytes for 8-bit AdamW. `bnb_available` records the outcome;
# train_teacher falls back to torch's AdamW when it is False.
try:
    import bitsandbytes as bnb
    bnb_available = True
    logger.info("bitsandbytes found, will use AdamW8bit.")
except ImportError:
    logger.warning("bitsandbytes not found. Falling back to standard AdamW. "
                   "Install with 'pip install bitsandbytes' for memory savings.")
    bnb_available = False


def train_teacher(model, train_loader, val_loader, config, data_processor): # Added data_processor
    """Fine-tune the teacher masked-LM with fp16 autocast, gradient accumulation and early stopping.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``; its logits are
            flattened with ``model.config.vocab_size`` as the class dimension.
        train_loader: Sized iterable of dicts holding "input_ids", "attention_mask" and
            "labels" tensors.
        val_loader: Same format; passed to ``evaluate_teacher`` after every epoch.
        config: Reads ``device``, ``learning_rate``, ``weight_decay``, ``teacher_epochs``,
            ``warmup_steps``, ``gradient_accumulation_steps``, ``model_dir`` and
            ``early_stopping_patience``.
        data_processor: Supplies ``teacher_pad_id``, used as the loss ignore index
            (``-100`` when it is ``None``).

    Returns:
        ``model``; if this run saved a best-validation-loss checkpoint, its weights are
        reloaded before returning.

    Gradients of ``gradient_accumulation_steps`` micro-batches are summed per optimizer
    step. A skipped micro-batch (invalid ids, NaN/Inf loss, runtime error) discards the
    gradients accumulated so far in its cycle, and gradients still pending at the end of
    an epoch are flushed with one extra optimizer step.
    """
    logger.info("Starting teacher model fine-tuning for Hindi...")
    model.to(config.device)

    # Gradient checkpointing is deliberately not enabled: the earlier attempt on
    # model.bert_model was removed because the ALBERT-based teacher does not support it.
    logger.info("Gradient Checkpointing skipped (model does not support it or not requested).")

    # bitsandbytes' 8-bit optimizers target GPU tensors; use torch's AdamW on other devices.
    use_8bit_adam = bnb_available and config.device.type == 'cuda'
    if use_8bit_adam:
        optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    else:
        optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    accum_steps = config.gradient_accumulation_steps
    # Ceil division: a trailing partial accumulation cycle is flushed at epoch end and also
    # calls scheduler.step(), so it must be counted in the schedule length.
    num_update_steps_per_epoch = (len(train_loader) + accum_steps - 1) // accum_steps
    total_steps = num_update_steps_per_epoch * config.teacher_epochs
    num_warmup_steps = min(config.warmup_steps, total_steps // 10) if total_steps > 0 else 0
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=max(1, total_steps))

    # Matches `from torch.cuda.amp import GradScaler`; torch.amp.GradScaler(device_type=...)
    # raised a TypeError.
    scaler = GradScaler(enabled=(config.device.type == 'cuda'))
    logger.info(f"GradScaler enabled: {scaler.is_enabled()}")

    pad_token_id = data_processor.teacher_pad_id
    # Assumes model.config reflects the (possibly resized) teacher vocabulary.
    teacher_vocab_size = model.config.vocab_size

    if pad_token_id is None:
        logger.error("CRITICAL: Teacher pad_token_id is None. Using ignore_index=-100.")
        pad_token_id = -100
    else:
        logger.info(f"Using ignore_index={pad_token_id} for teacher loss.")

    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token_id)

    best_loss = float('inf')
    early_stopping_counter = 0
    global_step = 0
    pending_backwards = 0  # backward passes accumulated since the last optimizer step
    best_model_path = os.path.join(config.model_dir, "best_teacher_model.pt")
    best_checkpoint_saved = False  # only reload a checkpoint written by *this* run

    def _discard_accumulated():
        """Drop the gradients accumulated in the current cycle (used whenever a micro-batch is skipped)."""
        nonlocal pending_backwards
        optimizer.zero_grad()
        pending_backwards = 0

    def _optimizer_step():
        """Unscale, clip to norm 1.0, step, update scaler and scheduler, clear grads.

        Returns ``(grad_norm, scale_before, scale_after)``.
        """
        nonlocal pending_backwards, global_step
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scale_before = scaler.get_scale()
        scaler.step(optimizer)  # skipped internally by the scaler if grads contain Inf/NaN
        scaler.update()         # decreases the scale when Inf/NaN grads were found
        scale_after = scaler.get_scale()
        optimizer.zero_grad()
        scheduler.step()
        pending_backwards = 0
        global_step += 1
        return grad_norm, scale_before, scale_after

    for epoch in range(config.teacher_epochs):
        model.train()
        epoch_loss = 0.0
        epoch_batches = 0  # micro-batches whose loss was added to epoch_loss
        _discard_accumulated()  # start every epoch with clean gradients

        for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.teacher_epochs}")):
            try:
                input_ids = batch["input_ids"].to(config.device)
                attention_mask = batch["attention_mask"].to(config.device)
                labels = batch["labels"].to(config.device)

                # Reject out-of-range ids up front instead of letting the embedding lookup
                # or the loss fail on them.
                min_id, max_id = torch.min(input_ids), torch.max(input_ids)
                if max_id >= teacher_vocab_size or min_id < 0:
                    logger.error(f"!!!!! Invalid teacher input_ids in batch {batch_idx+1}: Range [{min_id.item()},{max_id.item()}], Vocab Size {teacher_vocab_size}. Skipping.")
                    _discard_accumulated()
                    continue

                min_lbl, max_lbl = torch.min(labels), torch.max(labels)
                if max_lbl >= teacher_vocab_size or (min_lbl < 0 and min_lbl != pad_token_id):
                    logger.error(f"!!!!! Invalid teacher labels in batch {batch_idx+1}: Range [{min_lbl.item()},{max_lbl.item()}], Vocab Size {teacher_vocab_size}, Pad ID {pad_token_id}. Skipping.")
                    _discard_accumulated()
                    continue

                with torch.amp.autocast(device_type=config.device.type, dtype=torch.float16, enabled=(config.device.type == 'cuda')):
                    logits = model(input_ids, attention_mask)
                    loss = loss_fn(logits.view(-1, teacher_vocab_size), labels.view(-1))

                # Check the unscaled loss before it reaches backward().
                if torch.isnan(loss) or torch.isinf(loss):
                    logger.error(f"NaN or Inf loss detected *before* scaling at Epoch {epoch+1}, Batch {batch_idx+1}! Skipping step.")
                    _discard_accumulated()
                    continue

                batch_loss = loss.item()  # unscaled micro-batch loss, used for logging
                # Divide by accum_steps so the summed gradients average over the cycle.
                scaler.scale(loss / accum_steps).backward()
                pending_backwards += 1

                if (batch_idx + 1) % accum_steps == 0:
                    grad_norm, scale_before_update, scale_after_update = _optimizer_step()
                    if scale_after_update < scale_before_update:
                        # The scale only drops when Inf/NaN grads were found (step skipped).
                        logger.warning(f"Scaler reduced scale from {scale_before_update:.1f} to {scale_after_update:.1f} at step {global_step} (Grad Norm: {grad_norm:.4f}). Potential instability or skipped step.")

                epoch_loss += batch_loss
                epoch_batches += 1

                if (batch_idx + 1) % 50 == 0: # Log every 50 micro-batches
                    logger.info(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}, Step Loss: {batch_loss:.4f}, Scale: {scaler.get_scale():.1f}")

            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    logger.error(f"CUDA OOM error during training batch {batch_idx+1}. Memory saving options in use: 8-bit Adam={'Yes' if use_8bit_adam else 'No'}, Grad Checkpointing=No. Consider reducing batch_size further.")
                else:
                    logger.error(f"RuntimeError during training batch {batch_idx+1}: {e}", exc_info=True)
                if config.device.type == 'cuda':
                    torch.cuda.empty_cache()
                # A failed micro-batch may leave partial gradients; drop the whole cycle.
                _discard_accumulated()

            except Exception as e:
                logger.error(f"Error during training batch {batch_idx+1}: {e}", exc_info=True)
                if config.device.type == 'cuda':
                    torch.cuda.empty_cache()
                _discard_accumulated()

        # Flush gradients left over from a trailing partial accumulation cycle (if any).
        if pending_backwards > 0:
            logger.info("Performing final optimizer step for remaining accumulated gradients.")
            try:
                _optimizer_step()
            except Exception as final_step_e:
                logger.error(f"Error during final optimizer step: {final_step_e}")
                _discard_accumulated()

        # Mean loss per successfully processed micro-batch of this epoch.
        avg_epoch_loss = epoch_loss / epoch_batches if epoch_batches > 0 else float('nan')

        try:
            val_loss, val_perplexity, val_accuracy = evaluate_teacher(model, val_loader, config, pad_token_id)
        except Exception as eval_e:
            logger.error(f"Error during teacher evaluation: {eval_e}", exc_info=True)
            val_loss, val_perplexity, val_accuracy = float('nan'), float('nan'), float('nan')

        logger.info(f"Epoch {epoch+1} completed. Avg Train Loss: {avg_epoch_loss:.4f}, Val Loss: {val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}, Val Accuracy: {val_accuracy:.4f}")

        # Checkpointing and early stopping on validation loss; NaN counts as "no improvement".
        if not np.isnan(val_loss) and val_loss < best_loss:
            best_loss = val_loss
            early_stopping_counter = 0
            os.makedirs(config.model_dir, exist_ok=True)
            torch.save(model.state_dict(), best_model_path)
            best_checkpoint_saved = True
            logger.info(f"New best teacher model saved with val loss: {best_loss:.4f}")
        elif not np.isnan(val_loss):
            early_stopping_counter += 1
            logger.info(f"Validation loss did not improve for {early_stopping_counter} epoch(s). Best loss: {best_loss:.4f}")
            if early_stopping_counter >= config.early_stopping_patience:
                logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break
        else:
            logger.warning(f"Validation loss is NaN at epoch {epoch+1}.")
            early_stopping_counter += 1
            logger.info(f"NaN validation loss encountered. Early stopping counter: {early_stopping_counter}/{config.early_stopping_patience}")
            if early_stopping_counter >= config.early_stopping_patience:
                logger.info(f"Early stopping triggered due to persistent NaN validation loss.")
                break

    # Reload the best weights, but only from a checkpoint this run wrote: a file left over
    # from an earlier run must not silently replace the freshly trained weights.
    if best_checkpoint_saved:
        try:
            model.load_state_dict(torch.load(best_model_path, map_location=config.device))
            logger.info(f"Loaded best teacher model weights from {best_model_path}.")
        except Exception as e:
            logger.error(f"Error loading best teacher model weights: {e}")
            logger.warning("Returning model from last trained epoch state due to load error.")
    else:
        logger.warning("No best teacher model checkpoint found. Returning model from last epoch.")

    return model



2025-04-14 17:27:34,126 - __main__ - INFO - bitsandbytes found, will use AdamW8bit.


In [19]:
# Cell 231: evaluate_student (fix ignore_index and metrics calculation; add shape debugging and checks)
# Requires: torch, nn, tqdm, logger, np, accuracy_score, precision_recall_fscore_support

# Ensure necessary imports are available
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import logging
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

logger = logging.getLogger(name=__name__)  # module logger; getLogger returns the same object for the same name, so re-running is harmless

def evaluate_student(model, val_loader, config, data_processor): # Added data_processor
    """Evaluate the student model on a validation loader.

    Runs the model without gradients and computes:
      * token-level cross-entropy averaged over all non-ignored label positions
        (so perplexity = exp(mean per-token NLL), independent of batch sizes),
      * accuracy and weighted precision/recall/F1 over the same positions.

    A label position is ignored when it equals the student pad id, or when it lies
    outside ``[0, vocab_size)`` (such ids would make ``CrossEntropyLoss`` index out
    of range, which on CUDA is an unrecoverable device-side assert).

    Args:
        model: student model, called as ``model(input_ids, attention_mask)``; assumed
            to return a logits tensor of shape ``(batch, seq_len, vocab_size)``.
            The vocab size is read from ``model.vocab_size``, falling back to
            ``model.config.vocab_size``.
        val_loader: iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        config: object providing ``device`` (a ``torch.device``).
        data_processor: object providing ``student_tokenizer.pad_token_id``.

    Returns:
        dict with keys ``'loss'``, ``'perplexity'``, ``'accuracy'``, ``'precision'``,
        ``'recall'``, ``'f1'``. Loss/perplexity are ``inf`` if no token was scored.

    Raises:
        ValueError: if the student vocab size cannot be determined.

    Note: the model is left in eval mode on return.
    """
    logger.info("Evaluating student model...")
    model.eval()  # disable dropout etc. for deterministic evaluation

    # Pad id and vocab size are looked up independently so that a failure on one
    # lookup cannot discard a value already obtained for the other.
    try:
        student_pad_id = data_processor.student_tokenizer.pad_token_id
    except AttributeError as ae:
        logger.error(f"Error accessing attributes from data_processor or model in evaluate_student: {ae}")
        student_pad_id = None  # replaced by -100 below

    try:
        student_vocab_size = model.vocab_size
    except AttributeError as ae:
        logger.error(f"Error accessing attributes from data_processor or model in evaluate_student: {ae}")
        try:
            student_vocab_size = model.config.vocab_size  # Try getting from config
        except AttributeError as exc:
            logger.error("Cannot determine student vocab size for evaluation!")
            raise ValueError("Cannot determine student vocab size.") from exc

    if student_pad_id is None:
        logger.warning("Student pad_token_id is None during evaluation. Using ignore_index=-100.")
        student_pad_id = -100
    else:
        logger.info(f"Evaluation using student pad_token_id: {student_pad_id}")

    ignore_index = -100      # every non-scored position is rewritten to this value
    total_loss = 0.0         # summed NLL over scored tokens
    total_tokens = 0         # number of scored tokens contributing to total_loss
    ignored_out_of_range = 0 # non-pad labels outside [0, vocab) that were skipped
    all_preds_eval = []
    all_labels_eval = []

    # reduction='sum' so the final loss can be averaged per token rather than per batch.
    loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
    use_amp = config.device.type == 'cuda'

    with torch.no_grad(): # Disable gradient calculations
        for batch_idx, batch in enumerate(tqdm(val_loader, desc="Evaluating Student")):
            try:
                # --- 1. Data Loading and Basic Checks ---
                if "input_ids" not in batch or "attention_mask" not in batch or "labels" not in batch:
                    logger.warning(f"Skipping eval batch {batch_idx} due to missing keys.")
                    continue

                input_ids = batch["input_ids"].to(config.device)
                attention_mask = batch["attention_mask"].to(config.device)
                labels = batch["labels"].to(config.device)

                # --- 2. Shape Validation ---
                logger.debug(f"Eval Batch {batch_idx}: input_ids shape {input_ids.shape}, attention_mask shape {attention_mask.shape}, labels shape {labels.shape}")

                batch_size = input_ids.shape[0]
                seq_len = input_ids.shape[1]

                if batch_size == 0:
                    logger.warning(f"Skipping eval batch {batch_idx} due to zero batch size.")
                    continue

                if attention_mask.shape != (batch_size, seq_len):
                    logger.warning(f"Skipping eval batch {batch_idx} due to inconsistent attention_mask shape: {attention_mask.shape}")
                    continue
                if labels.shape != (batch_size, seq_len):
                    logger.error(f"CRITICAL SHAPE MISMATCH in eval batch {batch_idx}: Labels shape is {labels.shape}, but expected ({batch_size}, {seq_len}).")
                    logger.error("This is likely the cause of the ValueError. Check FewShotDataset or DataLoader collation for 'labels'. Skipping batch.")
                    continue # Skip this batch as loss calculation will fail

                # --- 3. Forward Pass (fp16 autocast on CUDA only) ---
                with torch.amp.autocast(device_type=config.device.type, dtype=torch.float16, enabled=use_amp):
                    logits = model(input_ids, attention_mask)

                expected_logit_shape = (batch_size, seq_len, student_vocab_size)
                if logits.shape != expected_logit_shape:
                    logger.error(f"Eval Batch {batch_idx}: Unexpected logits shape. Got {logits.shape}, Expected {expected_logit_shape}. Skipping.")
                    continue

                # --- 4. Select scored positions ---
                in_range = (labels >= 0) & (labels < student_vocab_size)
                not_pad = labels != student_pad_id
                valid_mask = not_pad & in_range
                ignored_out_of_range += int((not_pad & ~in_range).sum().item())
                n_valid = int(valid_mask.sum().item())
                if n_valid == 0:
                    logger.warning(f"No scorable labels in student evaluation batch {batch_idx}; skipping it.")
                    continue

                # --- 5. Loss (in float32: fp16 log-softmax loses precision) ---
                targets = labels.masked_fill(~valid_mask, ignore_index)
                loss_sum = loss_fn(
                    logits.float().reshape(-1, student_vocab_size),
                    targets.reshape(-1),
                )
                if torch.isfinite(loss_sum):
                    total_loss += loss_sum.item()
                    total_tokens += n_valid
                else:
                    logger.warning(f"NaN or Inf loss calculated during student evaluation batch {batch_idx}.")

                # --- 6. Store Predictions for Metrics ---
                preds = logits.argmax(dim=-1)
                all_preds_eval.extend(preds[valid_mask].cpu().tolist())
                all_labels_eval.extend(labels[valid_mask].cpu().tolist())

            except Exception as e:
                logger.error(f"Error during student evaluation batch {batch_idx}: {e}", exc_info=True)
                continue # Skip batch on error

    if ignored_out_of_range:
        logger.warning(
            f"Ignored {ignored_out_of_range} label positions outside [0, {student_vocab_size}) "
            f"that were not the pad id ({student_pad_id}) during student evaluation."
        )

    # --- Final metrics ---
    if total_tokens > 0:
        avg_loss = total_loss / total_tokens
        with np.errstate(over='ignore'):  # very large loss -> inf perplexity, no warning
            perplexity = float(np.exp(avg_loss))
    else:
        avg_loss = float('inf')
        perplexity = float('inf')
        logger.warning("No valid batches processed during student evaluation.")

    if all_labels_eval:
        accuracy = accuracy_score(all_labels_eval, all_preds_eval)
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels_eval, all_preds_eval, average='weighted', zero_division=0
        )
    else:
        accuracy, precision, recall, f1 = 0.0, 0.0, 0.0, 0.0
        logger.warning("No valid (non-padding) labels found during student evaluation to calculate metrics.")

    results = {
        "loss": avg_loss,
        "perplexity": perplexity,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }

    logger.info("========== Student Model Evaluation Results ==========")
    logger.info(f"  Loss       : {results['loss']:.4f}")
    logger.info(f"  Perplexity : {results['perplexity']:.4f}")
    logger.info(f"  Accuracy   : {results['accuracy']:.4f}")
    logger.info(f"  Precision  : {results['precision']:.4f}")
    logger.info(f"  Recall     : {results['recall']:.4f}")
    logger.info(f"  F1 Score   : {results['f1']:.4f}")
    logger.info("====================================================")

    return results


In [20]:
def plot_training_history(metrics_history, config):
    """Plot training/validation metric curves and save them with a CSV of the history.

    Writes ``accuracy.png``, ``loss_reward.png`` and ``val_metrics.png`` into
    ``<config.output_dir>/plots`` and ``metrics_history.csv`` into ``config.output_dir``.

    Args:
        metrics_history: list of per-epoch dicts. Each needs an ``'epoch'`` key. Plotted
            keys are ``train_accuracy``, ``val_accuracy``, ``train_loss``,
            ``train_reward`` (optional), ``val_precision``, ``val_recall`` and ``val_f1``.
            A missing metric key is plotted as NaN instead of raising ``KeyError``. The
            KD+CE student loop in ``main`` records no ``train_reward``.
        config: object providing ``output_dir``.
    """
    if not metrics_history:
        logger.warning("No metrics history provided; skipping plotting.")
        return

    epochs = [m['epoch'] for m in metrics_history]

    def series(key):
        # NaN for epochs lacking `key`, so one absent metric does not abort every plot.
        return [m.get(key, float('nan')) for m in metrics_history]

    plots_dir = os.path.join(config.output_dir, "plots")
    os.makedirs(plots_dir, exist_ok=True)

    # Plot 1: Train vs Val Accuracy
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epochs, series('train_accuracy'), 'b-', label='Train Accuracy')
    ax.plot(epochs, series('val_accuracy'), 'r-', label='Val Accuracy')
    ax.set(title='Accuracy over Epochs', xlabel='Epoch', ylabel='Accuracy')
    ax.legend()
    ax.grid(True)
    fig.savefig(os.path.join(plots_dir, 'accuracy.png'))

    # Plot 2: Train Loss (left axis) and, when recorded, Train Reward (right axis)
    has_reward = any('train_reward' in m for m in metrics_history)
    fig, ax_loss = plt.subplots(figsize=(10, 6))
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss', color='tab:red')
    ax_loss.plot(epochs, series('train_loss'), color='tab:red', label='Train Loss')
    ax_loss.tick_params(axis='y', labelcolor='tab:red')
    if has_reward:
        ax_reward = ax_loss.twinx()
        ax_reward.set_ylabel('Reward', color='tab:blue')
        ax_reward.plot(epochs, series('train_reward'), color='tab:blue', label='Train Reward')
        ax_reward.tick_params(axis='y', labelcolor='tab:blue')
        ax_loss.set_title('Training Loss and Reward over Epochs')
    else:
        ax_loss.set_title('Training Loss over Epochs')
    fig.tight_layout()
    fig.savefig(os.path.join(plots_dir, 'loss_reward.png'))

    # Plot 3: Validation Metrics
    fig, ax = plt.subplots(figsize=(10, 6))
    for key, style, label in (
        ('val_accuracy', 'g-', 'Accuracy'),
        ('val_precision', 'b-', 'Precision'),
        ('val_recall', 'r-', 'Recall'),
        ('val_f1', 'y-', 'F1'),
    ):
        ax.plot(epochs, series(key), style, label=label)
    ax.set(title='Validation Metrics over Epochs', xlabel='Epoch', ylabel='Score')
    ax.legend()
    ax.grid(True)
    fig.savefig(os.path.join(plots_dir, 'val_metrics.png'))

    # Save metrics history as CSV
    pd.DataFrame(metrics_history).to_csv(os.path.join(config.output_dir, 'metrics_history.csv'), index=False)

    logger.info(f"Training history plots saved to {plots_dir}")


In [26]:
# Cell 233: main function (Rewritten - Fix All Identified Errors)

# --- Imports ---
# Ensure all necessary libraries and your custom modules/classes are imported
import logging
import torch
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup, AutoConfig, AutoModelForMaskedLM # Added Auto* imports if needed here
import os
import numpy as np
import traceback # For more detailed error logging if needed

# Import your classes/functions if they are not auto-imported
# (Assuming they are defined in previous cells or imported correctly)
# E.g.: from your_module import Config, DataProcessor, TeacherModel, StudentModel, train_teacher, evaluate_teacher, train_rl, evaluate_student, plot_training_history, NextWordPredictionDataset, FewShotDataset, validate_dataset

# --- Logger Setup ---
# Reuse this module's logger; only install console logging when neither this logger
# nor any ancestor (e.g. root, configured in an earlier cell) already has a handler.
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
    )

# --- Main Function Definition ---
def _make_loader(dataset, config, shuffle):
    """Build a DataLoader with the settings shared by every loader in this pipeline."""
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=0,  # load in the notebook process
        pin_memory=config.device.type == 'cuda',
    )


def _resize_student_embeddings(student_model, target_vocab_size):
    """Resize the student's token embeddings (and head, if present) to ``target_vocab_size``.

    Called on the training model and again on the model rebuilt for final evaluation,
    so the rebuilt model has the same parameter shapes as the saved checkpoint.
    No-op when the wrapped model's ``config.vocab_size`` already matches.
    """
    current_vocab_size = student_model.model.config.vocab_size
    if current_vocab_size == target_vocab_size:
        logger.info("Student model vocab size already matches tokenizer.")
        return

    logger.warning(f"Resizing student model token embeddings from {current_vocab_size} to {target_vocab_size}")
    student_model.model.resize_token_embeddings(target_vocab_size)
    new_vocab_size = student_model.model.config.vocab_size
    if hasattr(student_model, 'vocab_size'):
        student_model.vocab_size = new_vocab_size
    if hasattr(student_model, 'next_token_head'):
        # A freshly initialised head sized to the new vocabulary.
        student_model.next_token_head = nn.Linear(student_model.hidden_size, new_vocab_size)
        logger.info("Re-initialized student model head layer after resizing.")
    else:
        logger.warning("StudentModel does not have a 'next_token_head' attribute to update after resize.")
    logger.info(f"Resized student model embeddings. New config vocab size: {new_vocab_size}")


def main():
    """
    Orchestrate the teacher-student training pipeline.

    Steps: fine-tune the teacher on Hindi data (``train_teacher``), then train the
    student on Chhattisgarhi data with ``train_rl`` (KD + CE), checkpointing the best
    student by validation accuracy, plot the history, and run a final evaluation.

    Relies on names defined in earlier notebook cells: Config, DataProcessor,
    TeacherModel, StudentModel, NextWordPredictionDataset, FewShotDataset,
    validate_dataset, train_teacher, train_rl, evaluate_student, plot_training_history.

    Returns:
        tuple ``(teacher_model, student_model, metrics_history, final_metrics)``:
        ``student_model`` is the last-epoch model; ``metrics_history`` is a list of
        per-epoch metric dicts; ``final_metrics`` is the ``evaluate_student`` result
        (best checkpoint from this run if it could be loaded), or ``{}`` if unavailable.
    """
    config = Config()
    logger.info(f"Configuration loaded. Using device: {config.device}")
    logger.info(f"Teacher model: {config.teacher_model_name}, Student model: {config.student_model_name}")
    logger.info(f"Max length: {config.max_length}, Batch size: {config.batch_size}, Accumulation: {config.gradient_accumulation_steps}")

    # --- 1. Initialize Data Processor ---
    logger.info("Initializing DataProcessor...")
    try:
        data_processor = DataProcessor(config)
        teacher_tokenizer = data_processor.teacher_tokenizer
        student_tokenizer = data_processor.student_tokenizer
        logger.info(f"Teacher vocab size: {data_processor.teacher_vocab_size}, Student vocab size: {data_processor.student_vocab_size}")
    except Exception as e:
        logger.error(f"Failed to initialize DataProcessor: {e}", exc_info=True)
        raise

    # --- 2. Initialize Teacher Model ON CPU FIRST ---
    logger.info("Initializing teacher model on CPU...")
    try:
        teacher_model = TeacherModel(
            config.teacher_model_name,
            tokenizer_vocab_size=data_processor.teacher_vocab_size
        )
        logger.info(f"Teacher model initialized on CPU. Model config vocab size: {teacher_model.config.vocab_size}")
    except Exception as e:
        logger.error(f"Failed to initialize teacher model: {e}", exc_info=True)
        raise

    # --- 3. Move Teacher Model to GPU ---
    logger.info(f"Attempting to move teacher model to device: {config.device}...")
    try:
        teacher_model.to(config.device)
        actual_device = next(teacher_model.parameters()).device
        logger.info(f"Teacher model successfully moved. Device check: {actual_device}")
        if actual_device != config.device:
            logger.warning(f"Teacher model parameters are on {actual_device}, but config device is {config.device}.")
    except Exception as e:
        logger.error(f"Failed to move teacher model to {config.device}: {e}", exc_info=True)
        raise RuntimeError(f"Cannot proceed without moving teacher model to {config.device}")

    # --- 4. Load Datasets ---
    logger.info("Loading datasets...")
    try:
        hindi_samples, chhattisgarhi_samples = data_processor.load_dataset()
        if not hindi_samples or not chhattisgarhi_samples:
            raise ValueError("Dataset loading returned empty lists.")
        logger.info(f"Loaded {len(hindi_samples)} Hindi and {len(chhattisgarhi_samples)} Chhattisgarhi samples.")
    except Exception as e:
        logger.error(f"Failed to load datasets: {e}", exc_info=True)
        raise

    # --- 5. Teacher Data Setup & Validation ---
    logger.info("Creating teacher (Hindi) dataset and dataloaders...")
    try:
        hindi_dataset = NextWordPredictionDataset(hindi_samples, teacher_tokenizer, max_length=config.max_length)
        logger.info(f"Hindi dataset created with {len(hindi_dataset)} examples after filtering.")
        if len(hindi_dataset) == 0: raise ValueError("Hindi NextWordPredictionDataset is empty!")

        train_size_h = int((1.0 - config.train_test_split) * len(hindi_dataset))
        val_size_h = len(hindi_dataset) - train_size_h
        if train_size_h <= 0 or val_size_h <= 0: raise ValueError("Hindi train/val split resulted in empty dataset.")
        hindi_train_dataset, hindi_val_dataset = random_split(hindi_dataset, [train_size_h, val_size_h], generator=torch.Generator().manual_seed(42))
        logger.info(f"Hindi dataset sizes: Train={len(hindi_train_dataset)}, Val={len(hindi_val_dataset)}")

        if not validate_dataset(hindi_train_dataset, teacher_tokenizer, name="Hindi Train"): raise ValueError("Validation failed for Hindi training data.")
        if not validate_dataset(hindi_val_dataset, teacher_tokenizer, name="Hindi Val"): raise ValueError("Validation failed for Hindi validation data.")

        hindi_train_loader = _make_loader(hindi_train_dataset, config, shuffle=True)
        hindi_val_loader = _make_loader(hindi_val_dataset, config, shuffle=False)
        logger.info("Hindi DataLoaders created.")
    except Exception as e:
        logger.error(f"Error setting up Hindi data pipeline: {e}", exc_info=True)
        raise

    # --- 6. Fine-tune Teacher Model ---
    logger.info("Starting teacher model fine-tuning...")
    try:
        if 'train_teacher' not in globals(): raise NameError("Function 'train_teacher' is not defined.")
        teacher_model = train_teacher(teacher_model, hindi_train_loader, hindi_val_loader, config, data_processor)
        logger.info("Teacher model fine-tuning complete.")
    except NameError as ne:
        logger.error(f"NameError during teacher training: {ne}. Have you run the cell defining 'train_teacher'?")
        raise ne
    except Exception as e:
        logger.error(f"Error during teacher model training: {e}", exc_info=True)
        raise # Stop if teacher training fails

    # --- 7. Student Data Setup & Validation ---
    logger.info("Creating student (Chhattisgarhi) dataset and dataloaders...")
    try:
        chhattisgarhi_dataset = FewShotDataset(chhattisgarhi_samples, student_tokenizer, few_shot_examples=config.few_shot_examples, max_length=config.max_length)
        logger.info(f"Chhattisgarhi dataset created with {len(chhattisgarhi_dataset)} examples after filtering.")
        if len(chhattisgarhi_dataset) == 0: raise ValueError("Chhattisgarhi FewShotDataset is empty!")

        train_size_c = int((1.0 - config.train_test_split) * len(chhattisgarhi_dataset))
        val_size_c = len(chhattisgarhi_dataset) - train_size_c
        if train_size_c <= 0 or val_size_c <= 0: raise ValueError("Chhattisgarhi train/val split resulted in empty dataset(s).")
        chhattisgarhi_train_dataset, chhattisgarhi_val_dataset = random_split(chhattisgarhi_dataset, [train_size_c, val_size_c], generator=torch.Generator().manual_seed(42))
        logger.info(f"Chhattisgarhi dataset sizes: Train={len(chhattisgarhi_train_dataset)}, Val={len(chhattisgarhi_val_dataset)}")

        if not validate_dataset(chhattisgarhi_train_dataset, student_tokenizer, name="Chhattisgarhi Train"): raise ValueError("Validation failed for Chhattisgarhi training data.")
        if not validate_dataset(chhattisgarhi_val_dataset, student_tokenizer, name="Chhattisgarhi Val"): raise ValueError("Validation failed for Chhattisgarhi validation data.")

        chhattisgarhi_train_loader = _make_loader(chhattisgarhi_train_dataset, config, shuffle=True)
        chhattisgarhi_val_loader = _make_loader(chhattisgarhi_val_dataset, config, shuffle=False)
        logger.info("Chhattisgarhi DataLoaders created.")

    except Exception as e:
        logger.error(f"Error setting up Chhattisgarhi data pipeline: {e}", exc_info=True)
        raise

    # --- Student Data Sanity Check: count non-pad labels in the first few batches ---
    logger.info("Sanity checking Chhattisgarhi training data labels...")
    try:
        student_pad_id_check = data_processor.student_tokenizer.pad_token_id
        if student_pad_id_check is None: student_pad_id_check = -100

        num_batches_to_check = min(5, len(chhattisgarhi_train_loader))
        if num_batches_to_check == 0:
            logger.warning("Student train loader is empty, skipping sanity check.")
        else:
            total_valid_labels = 0
            unique_labels_found = set()
            for i, batch in enumerate(chhattisgarhi_train_loader):
                if i >= num_batches_to_check: break
                if 'labels' not in batch:
                    logger.warning(f"Batch {i} missing 'labels' key during sanity check.")
                    continue
                labels = batch['labels']
                valid_labels_in_batch = labels[labels != student_pad_id_check]
                total_valid_labels += len(valid_labels_in_batch)
                unique_labels_found.update(torch.unique(valid_labels_in_batch).tolist())

            if total_valid_labels == 0:
                logger.error("CRITICAL: No valid (non-padding) labels found in the first few batches of student training data! Check dataset creation.")
            else:
                logger.info(f"Checked {num_batches_to_check} batches: Found {total_valid_labels} valid labels, {len(unique_labels_found)} unique label IDs.")
    except Exception as sanity_e:
        logger.error(f"Error during student data sanity check: {sanity_e}", exc_info=True)

    # --- 8. Student Model Init, Resize (on CPU), GPU Move ---
    logger.info(f"Initializing student model: {config.student_model_name} on CPU...")
    student_tokenizer_len = len(data_processor.student_tokenizer)
    try:
        student_model = StudentModel(config.student_model_name)
        _resize_student_embeddings(student_model, student_tokenizer_len)

        student_model.to(config.device)
        actual_device_student = next(student_model.parameters()).device
        logger.info(f"Student model initialized and moved to device. Device check: {actual_device_student}")
        if actual_device_student != config.device:
            logger.warning(f"Student model parameters are on {actual_device_student}, expected {config.device}.")

    except Exception as e:
        logger.error(f"Failed to initialize or move student model: {e}", exc_info=True)
        raise

    # --- 9. KD+CE Training Setup ---
    logger.info("Setting up student optimizer and scheduler...")
    try:
        optimizer = AdamW(student_model.parameters(), lr=config.rl_lr, weight_decay=config.weight_decay)
        if len(chhattisgarhi_train_loader) == 0:
            logger.warning("Student train loader is empty. Setting RL total steps based on epochs only.")
            rl_total_steps = config.distillation_epochs
            num_rl_update_steps_per_epoch = 0
        else:
            # One optimizer step per `gradient_accumulation_steps` batches (at least one).
            num_rl_update_steps_per_epoch = max(1, len(chhattisgarhi_train_loader) // config.gradient_accumulation_steps)
            rl_total_steps = num_rl_update_steps_per_epoch * config.distillation_epochs

        rl_num_warmup_steps = min(config.warmup_steps, rl_total_steps // 10) if rl_total_steps > 0 else 0
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=rl_num_warmup_steps, num_training_steps=max(1, rl_total_steps)) # Ensure num_training_steps >= 1
        logger.info(f"Student (KD+CE) training setup: {rl_total_steps} total steps ({num_rl_update_steps_per_epoch} steps/epoch), {rl_num_warmup_steps} warmup steps.")
    except Exception as e:
        logger.error(f"Failed to setup student optimizer/scheduler: {e}", exc_info=True)
        raise

    # --- 10. KD+CE Training Loop ---
    logger.info("Starting student model training with KD + CE...")
    metrics_history = []
    best_val_accuracy = -1.0
    epochs_without_improvement = 0
    best_student_model_path = os.path.join(config.model_dir, "best_student_model.pt")
    # Only a checkpoint written by *this* run is trusted for final evaluation; a file
    # left over from an earlier run would otherwise be loaded silently.
    saved_best_this_run = False
    os.makedirs(config.model_dir, exist_ok=True)  # torch.save does not create directories

    if num_rl_update_steps_per_epoch == 0 and config.distillation_epochs > 0:
        logger.error("Student training cannot proceed: train loader is empty or too small for gradient accumulation.")
    else:
        for epoch in range(config.distillation_epochs):
            logger.info(f"--- Starting Student Epoch {epoch+1}/{config.distillation_epochs} ---")
            try:
                if 'train_rl' not in globals(): raise NameError("Function 'train_rl' is not defined.")
                if 'evaluate_student' not in globals(): raise NameError("Function 'evaluate_student' is not defined.")

                avg_loss, train_acc, train_prec, train_rec, train_f1 = train_rl(
                    teacher_model, student_model, chhattisgarhi_train_loader, optimizer, scheduler, config, data_processor
                )
                logger.info(f"Epoch {epoch+1} Train Stats: Loss={avg_loss:.4f}, Acc={train_acc:.4f}, F1={train_f1:.4f}")

                eval_metrics = evaluate_student(student_model, chhattisgarhi_val_loader, config, data_processor)

                current_metrics = {
                    'epoch': epoch + 1,
                    'train_loss': avg_loss, 'train_accuracy': train_acc, 'train_precision': train_prec, 'train_recall': train_rec, 'train_f1': train_f1,
                    'val_loss': eval_metrics.get('loss', float('nan')), 'val_perplexity': eval_metrics.get('perplexity', float('nan')),
                    'val_accuracy': eval_metrics.get('accuracy', float('nan')), 'val_precision': eval_metrics.get('precision', float('nan')),
                    'val_recall': eval_metrics.get('recall', float('nan')), 'val_f1': eval_metrics.get('f1', float('nan'))
                }
                metrics_history.append(current_metrics)
                logger.info(f"Epoch {epoch+1} Val Stats: Loss={current_metrics['val_loss']:.4f}, Acc={current_metrics['val_accuracy']:.4f}, F1={current_metrics['val_f1']:.4f}")

                # --- Checkpointing & Early Stopping (on validation accuracy) ---
                current_val_acc = current_metrics['val_accuracy']
                is_current_acc_valid = not (current_val_acc is None or np.isnan(current_val_acc))

                if is_current_acc_valid and current_val_acc > best_val_accuracy:
                    best_val_accuracy = current_val_acc
                    epochs_without_improvement = 0
                    torch.save(student_model.state_dict(), best_student_model_path)
                    saved_best_this_run = True
                    logger.info(f"*** New best student model saved with validation accuracy: {best_val_accuracy:.4f} at epoch {epoch+1} ***")
                elif is_current_acc_valid:
                    epochs_without_improvement += 1
                    logger.info(f"Validation accuracy ({current_val_acc:.4f}) did not improve for {epochs_without_improvement} epoch(s). Best accuracy: {best_val_accuracy:.4f}")
                else:
                    epochs_without_improvement += 1
                    logger.warning(f"Validation accuracy is invalid (NaN/None) at epoch {epoch+1}. Treating as no improvement ({epochs_without_improvement}).")

                if epochs_without_improvement >= config.early_stopping_patience:
                    logger.info(f"Early stopping triggered after epoch {epoch+1} due to no improvement for {config.early_stopping_patience} epochs.")
                    break

            except NameError as ne:
                logger.error(f"NameError during student training epoch {epoch+1}: {ne}")
                logger.error("!!! Please ensure the cells defining 'train_rl' (Cell 230) and 'evaluate_student' (Cell 231) have been executed before running this cell. !!!")
                raise ne # Stop execution
            except Exception as e:
                logger.error(f"Error during student training epoch {epoch+1}: {e}", exc_info=True)
                logger.warning("Error occurred, breaking student training loop.")
                break # Break loop on error

    # --- 11. Final Steps ---
    logger.info("Student training loop finished.")

    if metrics_history:
        logger.info("Plotting training history...")
        try:
            if 'plot_training_history' in globals():
                plot_training_history(metrics_history, config)
            else:
                logger.warning("Function 'plot_training_history' not defined. Skipping plotting.")
        except Exception as plot_e:
            logger.error(f"Failed to plot training history: {plot_e}")
    else:
        logger.warning("No metrics history recorded, skipping plotting.")

    # Final Evaluation
    logger.info("Performing final evaluation...")
    final_metrics = {}
    if 'evaluate_student' not in globals():
        logger.error("Cannot perform final evaluation: 'evaluate_student' not defined.")
    elif saved_best_this_run and os.path.exists(best_student_model_path):
        logger.info(f"Loading best student model from {best_student_model_path} for final evaluation...")
        try:
            # Rebuild on CPU and apply the same resize as training so that parameter
            # shapes match the checkpoint; then move to the target device.
            final_student_model = StudentModel(config.student_model_name)
            _resize_student_embeddings(final_student_model, student_tokenizer_len)
            state_dict = torch.load(best_student_model_path, map_location=torch.device('cpu'))
            final_student_model.load_state_dict(state_dict)
            final_student_model.to(config.device)
            logger.info("Successfully loaded best student model.")
            final_metrics = evaluate_student(final_student_model, chhattisgarhi_val_loader, config, data_processor)
        except Exception as e:
            logger.error(f"Failed to load or evaluate best student model: {e}", exc_info=True)
            logger.warning("Evaluating model from last trained epoch state instead.")
            final_metrics = evaluate_student(student_model, chhattisgarhi_val_loader, config, data_processor)
    else:
        logger.warning("Best student model checkpoint not found. Evaluating model from last trained epoch state.")
        final_metrics = evaluate_student(student_model, chhattisgarhi_val_loader, config, data_processor)

    # Final logging
    logger.info("====================== Final Results ======================")
    logger.info(f"Best Student Validation Accuracy Achieved during training: {best_val_accuracy:.4f}")
    logger.info("Final Metrics (using best model if loaded, else last epoch):")
    if final_metrics:
        for key in ('loss', 'perplexity', 'accuracy', 'precision', 'recall', 'f1'):
            logger.info(f"  {key}: {final_metrics.get(key, float('nan')):.4f}")
    else:
        logger.warning("  Final metrics unavailable.")
    logger.info("===========================================================")
    logger.info("Training pipeline officially finished.")

    return teacher_model, student_model, metrics_history, final_metrics


# --- Execution Block (Cell 49) ---
# This block should ideally be separate from the main function definition
# if __name__ == "__main__": # Use this if running as a script
# Ensure logger is configured before this runs
# if 'main' in globals(): # Check if main is defined before calling in notebook context
#     logger.info(f"Starting main execution block. CUDA_LAUNCH_BLOCKING is set to: {os.environ.get('CUDA_LAUNCH_BLOCKING', 'Not Set')}")
#     try:
#         teacher_model, student_model, metrics_history, final_metrics = main()
#         logger.info("Main execution completed successfully.")
#     except Exception as e:
#         logger.error(f"An error occurred during the main execution: {e}", exc_info=True)
# else:
#     logger.error("The 'main' function is not defined. Please execute Cell 233 first.")



In [28]:
# Cell 18: Execution Block (Add CUDA_LAUNCH_BLOCKING info)
# Runs the full teacher -> student pipeline and logs an end-of-run summary.
# The summary helpers live at module level so they can be re-used (and checked)
# without re-running training.


def to_float_or_none(value):
    """Return ``value`` converted to float, or None when it is not numeric.

    Used so metrics can be formatted with ``:.4f`` without the f-string raising
    ValueError/TypeError on None or non-numeric entries.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def compute_best_val_accuracy(metrics_history):
    """Return the highest finite ``'val_accuracy'`` found in ``metrics_history``.

    Args:
        metrics_history: iterable of per-epoch metric mappings (may be None or
            empty). Entries that cannot be indexed by ``'val_accuracy'``, lack
            the key, or hold None/NaN/non-numeric values are skipped.

    Returns:
        The maximum valid accuracy as a float, or ``nan`` if there is none.
    """
    candidates = []
    for entry in metrics_history or []:
        try:
            raw = entry['val_accuracy']
        except (KeyError, TypeError, IndexError):
            continue
        value = to_float_or_none(raw)
        if value is not None and not np.isnan(value):
            candidates.append(value)
    return max(candidates) if candidates else float('nan')


def log_final_results(metrics_history, final_metrics, log=None):
    """Log the end-of-run summary of a training run.

    Args:
        metrics_history: per-epoch metric dicts returned by ``main()`` (may be None).
        final_metrics: dict of final evaluation metrics; may be None/empty, in
            which case a warning is logged instead of crashing on ``.items()``.
        log: logger to write to; defaults to the notebook-level ``logger``.
    """
    log = logger if log is None else log
    log.info("====================== Final Results ======================")
    if metrics_history:
        log.info(f"Best Validation Accuracy Achieved: {compute_best_val_accuracy(metrics_history):.4f}")
    if final_metrics:
        log.info("Final Metrics (using best model):")
        for key, value in final_metrics.items():
            numeric = to_float_or_none(value)
            # Non-numeric values are logged verbatim rather than crashing the
            # ``:.4f`` format spec.
            if numeric is None:
                log.info(f"  {key}: {value}")
            else:
                log.info(f"  {key}: {numeric:.4f}")
    else:
        log.warning("  Final metrics unavailable.")
    log.info("===========================================================")


if __name__ == "__main__":
    logger.info(f"CUDA_LAUNCH_BLOCKING is set to: {os.environ.get('CUDA_LAUNCH_BLOCKING', 'Not Set')}")
    try:
        # (optional) wandb.init(...) would go here.
        results = main()
        if results is None:
            # NOTE(review): guard in case main() returns None on an early exit
            # (its early-failure return value is not visible here — confirm);
            # unpacking None would raise an unhelpful TypeError.
            logger.error("main() returned no results; see the errors logged above.")
        else:
            teacher_model, student_model, metrics_history, final_metrics = results
            log_final_results(metrics_history, final_metrics)

    except Exception as e:
        logger.error(f"An error occurred during the main execution: {e}", exc_info=True)
        # (optional) wandb.finish() would go here.





2025-04-14 17:34:14,203 - __main__ - INFO - CUDA_LAUNCH_BLOCKING is set to: 1
2025-04-14 17:34:14,206 - __main__ - INFO - Configuration loaded. Using device: cuda
2025-04-14 17:34:14,206 - __main__ - INFO - Teacher model: ai4bharat/indic-bert, Student model: distilbert-base-multilingual-cased
2025-04-14 17:34:14,207 - __main__ - INFO - Max length: 64, Batch size: 32, Accumulation: 8
2025-04-14 17:34:14,208 - __main__ - INFO - Initializing DataProcessor...
2025-04-14 17:34:15,980 - __main__ - ERROR - Failed to initialize DataProcessor: 'DataProcessor' object has no attribute 'student_vocab_size'
Traceback (most recent call last):
  File "C:\Users\pc\AppData\Local\Temp\ipykernel_9168\548841237.py", line 45, in main
    logger.info(f"Teacher vocab size: {data_processor.teacher_vocab_size}, Student vocab size: {data_processor.student_vocab_size}")
                                                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError