In [None]:
!pip install --upgrade pip

Collecting pip
  Downloading pip-25.0.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-25.0.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.0.1


In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.4.1-py3-none-any.whl (487 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading fsspec-2024.12.0-py3-none-any.whl (183 kB)
Downloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
Installing collected packages: xxhash, fsspec, dill, multiprocess, datasets
  Attemptin

In [None]:
!pip install transformers



In [None]:
# Cell 1: Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
import numpy as np
import os
from tqdm import tqdm
import logging
from contextlib import nullcontext

In [None]:
# Cell 2: Set up logging
# Route INFO-and-above records to two places at once: the file "training.log"
# (created in the current working directory) and a StreamHandler, which writes
# to stderr by default so messages also show up in the notebook output.
# NOTE: logging.basicConfig() does nothing if the root logger already has
# handlers (some hosted notebook runtimes pre-configure it).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("training.log"),
        logging.StreamHandler()
    ]
)
# Module-level logger used by the classes and functions defined in later cells.
logger = logging.getLogger(__name__)


In [None]:
class Config:
    """Hyper-parameters and runtime settings shared by the whole notebook.

    Every value is a plain instance attribute, so a caller can tweak a setting
    after construction (``cfg = Config(); cfg.epochs = 5``).
    """

    def __init__(self):
        # --- Runtime -------------------------------------------------------
        # Prefer the GPU when one is visible, otherwise run on the CPU.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

        # --- Models --------------------------------------------------------
        self.teacher_model_name = "ai4bharat/indic-bert"
        self.student_model_name = "gpt2"

        # --- Languages (NLLB language codes) -------------------------------
        self.teacher_lang = "hin_Deva"
        self.student_lang = "hne_Deva"

        # --- Data ----------------------------------------------------------
        self.max_length = 64              # tokens per example after padding/truncation
        self.batch_size = 8               # examples per micro-batch
        self.use_synthetic_data = False   # False -> use the real dataset

        # --- Optimisation --------------------------------------------------
        self.learning_rate = 1e-5
        self.weight_decay = 0.01
        self.epochs = 2
        self.warmup_steps = 1000
        # Micro-batches whose gradients are summed before one optimizer step.
        self.gradient_accumulation_steps = 8

        # --- Reinforcement learning ----------------------------------------
        self.gamma = 0.99
        self.entropy_coef = 0.05    # weight of the entropy bonus (exploration)
        self.rl_lr = 1e-5           # learning rate used for the RL optimizer
        self.teacher_weight = 0.5   # weight of the teacher-confidence reward term

        # --- Bookkeeping ---------------------------------------------------
        self.output_dir = "output/"
        self.eval_every = 50     # log running metrics every N batches
        self.save_every = 500    # checkpoint the student every N batches


In [None]:
class TeacherModel(nn.Module):
    """Masked-LM encoder followed by a linear head producing per-position vocabulary logits.

    The encoder is loaded in float16 when CUDA is available (float32 otherwise)
    to save memory; the adapter head is a regular ``nn.Linear`` and therefore
    keeps PyTorch's default parameter dtype.

    NOTE(review): ``next_token_adapter`` is randomly initialised here and nothing
    visible in this notebook trains it (the training loop calls the teacher under
    ``torch.no_grad()`` and the optimizer is built from the student only), so its
    output may not carry a useful signal - confirm this is intended.

    Args:
        model_name: Hugging Face model id / path passed to
            ``AutoModelForMaskedLM.from_pretrained``.
    """

    def __init__(self, model_name):
        super(TeacherModel, self).__init__()
        # Load with lower precision to save memory
        self.bert_model = AutoModelForMaskedLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )

        self.vocab_size = self.bert_model.config.vocab_size
        self.hidden_size = self.bert_model.config.hidden_size

        # Linear layer mapping encoder hidden states to vocabulary logits
        self.next_token_adapter = nn.Linear(self.hidden_size, self.vocab_size)

    def forward(self, input_ids, attention_mask=None):
        """Return logits of shape ``[batch, seq_len, vocab_size]``.

        Args:
            input_ids: Token ids, ``[batch, seq_len]``.
            attention_mask: Optional mask forwarded to the encoder.

        Returns:
            Tensor of adapter logits, in the adapter's parameter dtype.
        """
        # Run the encoder under autocast on GPU; plain execution on CPU.
        with torch.amp.autocast('cuda') if torch.cuda.is_available() else nullcontext():
            outputs = self.bert_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )

        # Use the last hidden state
        last_hidden_state = outputs.hidden_states[-1]

        # The adapter runs outside the autocast block. A float16 hidden state fed
        # into float32 Linear weights raises a dtype-mismatch RuntimeError, so
        # align the activation dtype with the adapter's parameters first.
        adapter_dtype = self.next_token_adapter.weight.dtype
        if last_hidden_state.dtype != adapter_dtype:
            last_hidden_state = last_hidden_state.to(adapter_dtype)

        return self.next_token_adapter(last_hidden_state)


In [None]:
class StudentModel(nn.Module):
    """Causal language model acting as the RL policy.

    Args:
        model_name: Hugging Face model id / path passed to
            ``AutoModelForCausalLM.from_pretrained``.
    """

    def __init__(self, model_name):
        super(StudentModel, self).__init__()
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask=None):
        """Return the wrapped model's logits for every position."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

    def get_action_and_value(self, input_ids, attention_mask=None):
        """Sample one token per sequence from the distribution at the final position.

        NOTE(review): the final position is ``-1`` regardless of ``attention_mask``;
        with inputs padded to a fixed length that position can be padding.

        Returns:
            Tuple ``(action, log_prob, entropy, probs)`` where ``action`` and
            ``log_prob`` have shape ``[batch]``, ``entropy`` is the batch-mean
            entropy (scalar) and ``probs`` has shape ``[batch, vocab]``.
        """
        # Keep only the distribution over the token following the last position.
        final_step_logits = self.forward(input_ids, attention_mask)[:, -1]
        probs = final_step_logits.softmax(dim=-1)

        policy = torch.distributions.Categorical(probs=probs)
        action = policy.sample()
        log_prob = policy.log_prob(action)
        entropy = policy.entropy().mean()  # averaged over the batch

        return action, log_prob, entropy, probs

In [None]:
def compute_reward(teacher_probs, student_action, target_tokens, alpha=0.7):
    """Score sampled tokens against the reference and the teacher distribution.

    Each row earns 1.0 when the sampled token equals the reference token, plus
    ``alpha`` times the probability the teacher assigns to the sampled token
    (skipped when the token id is not below the teacher vocabulary size).

    NOTE: a later cell in this notebook defines another function with the same
    name; whichever cell runs last is the one callers get.

    Args:
        teacher_probs: Tensor ``[batch_size, vocab_size]``.
        student_action: Tensor ``[batch_size]`` of sampled token ids.
        target_tokens: Tensor ``[batch_size]`` or ``[batch_size, seq_len]``;
            for the 2-D form only column 0 is used.
        alpha: Weight of the teacher-confidence term.

    Returns:
        Tensor ``[batch_size]`` of rewards on ``teacher_probs``' device.
    """
    device = teacher_probs.device
    batch_size, vocab_size = teacher_probs.shape[0], teacher_probs.shape[1]

    reference = target_tokens[:, 0] if target_tokens.dim() > 1 else target_tokens
    reference = reference[:batch_size].to(device)
    chosen = student_action[:batch_size].to(device)

    reward = torch.zeros(batch_size, device=device)

    # Exact-match bonus.
    reward = reward + (chosen == reference).to(reward.dtype)

    # Teacher-confidence bonus, only for ids the teacher vocabulary covers.
    covered = chosen < vocab_size
    lookup = chosen.clamp(max=vocab_size - 1).unsqueeze(1)
    confidence = teacher_probs.gather(1, lookup).squeeze(1)
    bonus = torch.where(covered, alpha * confidence, torch.zeros_like(confidence))

    return reward + bonus.to(reward.dtype)

In [None]:
# Cell 6: Reward function
def compute_reward(teacher_probs, student_action, target_tokens, alpha=0.7):
    """Reward sampled tokens, with partial credit derived from the teacher.

    Per row:
      * ``+5.0`` when the sampled token equals the target token;
      * otherwise ``-0.1`` plus ``alpha * P_teacher(target)``;
      * in both cases ``+ (alpha * 0.5) * P_teacher(sampled token)``.
    Teacher terms are skipped for token ids outside ``[0, vocab_size)``.

    Args:
        teacher_probs: ``[batch, seq_len, vocab]`` (the last sequence position is
            used) or ``[batch, vocab]``. The original implementation only
            accepted the 3-D form and raised ``IndexError`` on 2-D input.
        student_action: ``[batch]`` sampled token ids.
        target_tokens: ``[batch]`` target ids, or ``[batch, seq_len]`` in which
            case column 0 is used. NOTE(review): callers that sample at a
            specific sequence position should pass the matching 1-D target
            themselves - column 0 is only a fallback.
        alpha: Weight of the teacher terms.

    Returns:
        Tensor ``[batch]`` of rewards on ``teacher_probs``' device.

    Raises:
        ValueError: if ``teacher_probs`` is neither 2-D nor 3-D.
    """
    if teacher_probs.dim() == 3:
        teacher_probs = teacher_probs[:, -1, :]
    elif teacher_probs.dim() != 2:
        raise ValueError(
            f"teacher_probs must be 2-D or 3-D, got shape {tuple(teacher_probs.shape)}"
        )
    if target_tokens.dim() > 1:
        target_tokens = target_tokens[:, 0]

    device = teacher_probs.device
    batch_size, vocab_size = teacher_probs.shape
    actions = student_action.to(device)
    targets = target_tokens.to(device)

    hit = torch.full((batch_size,), 5.0, device=device)
    miss = torch.full((batch_size,), -0.1, device=device)

    def _teacher_prob(token_ids):
        # Probability the teacher gives each id; 0 where the id is not a valid
        # vocabulary index (negative ids would otherwise wrap around silently).
        valid = (token_ids >= 0) & (token_ids < vocab_size)
        safe_ids = token_ids.clamp(min=0, max=vocab_size - 1).unsqueeze(1)
        picked = teacher_probs.gather(1, safe_ids).squeeze(1).to(hit.dtype)
        return torch.where(valid, picked, torch.zeros_like(picked))

    # Partial credit on a miss: how much the teacher believed in the true token.
    miss = miss + alpha * _teacher_prob(targets)
    reward = torch.where(actions == targets, hit, miss)

    # Smaller bonus for choosing a token the teacher also finds plausible.
    return reward + (alpha * 0.5) * _teacher_prob(actions)

# Cell 7: Data Processor class

import json
class DataProcessor:
    """Loads both tokenizers and produces parallel Hindi / Chhattisgarhi samples.

    Args:
        config: Object exposing ``teacher_model_name``, ``student_model_name``,
            ``teacher_lang``, ``student_lang`` and (optionally)
            ``use_synthetic_data``.
    """

    def __init__(self, config):
        self.config = config
        self.teacher_tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_name)
        self.student_tokenizer = AutoTokenizer.from_pretrained(config.student_model_name)

        # Make sure the teacher tokenizer can pad
        if self.teacher_tokenizer.pad_token is None:
            self.teacher_tokenizer.add_special_tokens({'pad_token': '[PAD]'})

        # GPT-2 style tokenizers ship without a pad token: reuse EOS
        if self.student_tokenizer.pad_token is None:
            self.student_tokenizer.pad_token = self.student_tokenizer.eos_token

    def load_dataset(self):
        """Return ``(hindi_samples, chhattisgarhi_samples)``.

        Each element is a list of ``{"lang": ..., "text": ...}`` dicts, aligned
        by index. When ``config.use_synthetic_data`` is truthy the small built-in
        synthetic corpus is returned instead of downloading anything
        (previously the flag was never consulted).

        Raises:
            Exception: whatever the dataset download / parsing raised, after it
                has been logged.
        """
        if getattr(self.config, "use_synthetic_data", False):
            return self._create_synthetic_data()

        try:
            # Load dataset directly with language pair specification
            dataset = load_dataset("allenai/nllb", "hin_Deva-hne_Deva")

            # Keep 90% of the rows; the held-out 10% is not returned from here.
            train_val = dataset["train"].train_test_split(test_size=0.1, seed=42)
            train_data = train_val["train"]

            hindi_samples = []
            chhattisgarhi_samples = []

            for item in train_data:
                hindi_samples.append({
                    "lang": self.config.teacher_lang,
                    "text": item['translation']['hin_Deva']
                })
                chhattisgarhi_samples.append({
                    "lang": self.config.student_lang,
                    "text": item['translation']['hne_Deva']
                })

            logger.info(f"Loaded {len(hindi_samples)} Hindi samples")
            logger.info(f"Loaded {len(chhattisgarhi_samples)} Chhattisgarhi samples")

            return hindi_samples, chhattisgarhi_samples

        except Exception as e:
            logger.error(f"Error loading NLLB dataset: {e}")
            raise

    def _create_synthetic_data(self):
        """Build a tiny synthetic parallel corpus for smoke tests.

        Hindi seed sentences are repeated 50 times; the Chhattisgarhi side is
        derived by plain substring replacement using a small word map.

        Returns:
            ``(hindi_samples, chhattisgarhi_samples)`` in the same format as
            :meth:`load_dataset`. (The original body had its ``return``
            commented out and therefore returned ``None``.)
        """
        hindi_samples = [
            {"lang": self.config.teacher_lang, "text": "नमस्ते, आप कैसे हैं?"},
            {"lang": self.config.teacher_lang, "text": "आपका नाम क्या है?"},
            # ... other samples
        ] * 50  # Repeat to create more samples

        # Hindi -> Chhattisgarhi substitutions applied as substring replacements
        chhattisgarhi_words = {
            "है": "हे",
            "मैं": "मय",
            # ... other word mappings
        }

        chhattisgarhi_samples = []
        for sample in hindi_samples:
            text = sample["text"]
            for hindi_word, chhattisgarhi_word in chhattisgarhi_words.items():
                text = text.replace(hindi_word, chhattisgarhi_word)
            chhattisgarhi_samples.append({
                "lang": self.config.student_lang,
                "text": text
            })

        logger.info(f"Created {len(hindi_samples)} Hindi samples")
        logger.info(f"Created {len(chhattisgarhi_samples)} Chhattisgarhi samples")

        return hindi_samples, chhattisgarhi_samples


In [None]:
# Cell 8: Dataset for next word prediction
class NextWordPredictionDataset(Dataset):
    """Text dataset yielding shifted (input, label) token pairs.

    Each item is tokenized on access, padded/truncated to ``max_length`` and
    returned as ``input_ids = tokens[:-1]``, ``labels = tokens[1:]`` together
    with ``attention_mask[:-1]``.

    Args:
        samples: Iterable of dicts with a ``"text"`` entry.
        tokenizer: Hugging Face style tokenizer (callable returning
            ``input_ids`` / ``attention_mask`` tensors).
        max_length: Fixed tokenized length before the one-token shift.
    """

    def __init__(self, samples, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Tokenizers without a pad token fall back to EOS so padding works.
        if self.tokenizer.pad_token is None and hasattr(self.tokenizer, 'eos_token'):
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Keep only texts longer than 10 characters.
        texts = (entry["text"] for entry in tqdm(samples, desc="Processing samples"))
        self.examples = [text for text in texts if len(text) > 10]

        logger.info(f"Created dataset with {len(self.examples)} examples")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        total = len(self.examples)
        if idx >= total:
            raise IndexError(f"Index {idx} out of bounds for dataset with {total} examples")

        tok = self.tokenizer
        encoded = tok(
            self.examples[idx],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"].squeeze()
        mask = encoded["attention_mask"].squeeze()

        # A sequence of fewer than two tokens cannot be shifted: substitute a
        # minimal <start, eos> pair. The start token is CLS if the tokenizer has
        # one, else BOS, else EOS.
        if token_ids.dim() == 0 or token_ids.size(0) <= 1:
            start_id = getattr(tok, 'cls_token_id', None)
            if start_id is None:
                start_id = tok.bos_token_id if tok.bos_token_id is not None else tok.eos_token_id
            token_ids = torch.tensor([start_id, tok.eos_token_id], dtype=torch.long)
            mask = torch.ones(2, dtype=torch.long)

        # Shift by one: position i of the input is trained to predict token i + 1.
        return {
            "input_ids": token_ids[:-1],
            "attention_mask": mask[:-1],
            "labels": token_ids[1:]
        }

In [None]:
def _last_labelled_positions(attention_mask):
    """Per row, a late position that is guaranteed to carry a real (non-pad) label.

    The dataset in this notebook yields ``attention_mask[:-1]`` alongside
    ``labels = input_ids[1:]``, so the label stored at position ``p`` is a real
    token whenever the mask at ``p + 1`` is 1. With right-padding that holds for
    ``p = mask.sum() - 2``; degenerate rows are clamped to position 0.
    # assumes right-padded batches - TODO confirm if another tokenizer is used
    """
    return (attention_mask.long().sum(dim=1) - 2).clamp(min=0)


def _select_positions(values, positions):
    """Return ``values[b, positions[b]]`` for every batch row ``b``."""
    rows = torch.arange(values.size(0), device=values.device)
    return values[rows, positions]


def train_rl(teacher_model, student_model, train_loader, optimizer, scheduler, config):
    """Run one epoch of REINFORCE-style next-token training on the student.

    For each batch one sequence position per row is chosen (see
    ``_last_labelled_positions``). The student samples a token from its
    distribution at that position and is rewarded with 1.0 if the sample equals
    the label at the same position, plus ``config.teacher_weight`` times the
    teacher's probability for the sampled token. The loss is
    ``-(log_prob * reward).mean() - config.entropy_coef * entropy``, scaled for
    gradient accumulation.

    Fixes relative to the previous version:
      * prediction and target now come from the same sequence position (it used
        to sample at position -1, often padding, and compare with
        ``labels[:, 0]``);
      * the "recent accuracy" window really covers the batches since the last
        report (``-len(x) % n`` is parsed as ``(-len(x)) % n``);
      * optimizer steps are driven by the number of successful backward passes,
        so skipped or failed batches no longer distort accumulation;
      * the final partial accumulation is clipped and stepped like all others.

    NOTE(review): the teacher receives the same ``input_ids`` as the student and
    its probabilities are indexed with the student's token ids; that is only
    meaningful if both models share a vocabulary - confirm against the caller.

    Args:
        teacher_model: Module called as ``teacher_model(input_ids, attention_mask)``
            returning logits; only used under ``torch.no_grad()``.
        student_model: Module called the same way; its parameters are trained.
        train_loader: Iterable of dicts with ``input_ids``, ``attention_mask``
            and ``labels`` tensors.
        optimizer: Optimizer over the student's parameters.
        scheduler: LR scheduler stepped once per optimizer step.
        config: Object providing ``device``, ``gradient_accumulation_steps``,
            ``teacher_weight``, ``entropy_coef``, ``eval_every``, ``save_every``
            and ``output_dir``.

    Returns:
        ``(mean_reward, mean_loss, predictions, targets)`` where the last two
        are 1-D numpy arrays with one entry per processed sample.
    """
    import traceback

    teacher_model.eval()  # Teacher model is frozen
    student_model.train()

    epoch_rewards = []
    epoch_losses = []
    all_predictions = []
    all_targets = []

    if config.device.type == 'cuda':
        torch.cuda.empty_cache()

    accumulation_steps = config.gradient_accumulation_steps
    optimizer.zero_grad()
    accumulated_loss = 0
    pending_micro_batches = 0  # backward() calls since the last optimizer step
    window_start = 0           # first sample index of the current reporting window

    def _optimizer_update():
        # Clip, step and reset - shared by the periodic and the final update.
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    for batch_idx, batch in enumerate(tqdm(train_loader, desc="Training")):
        try:
            input_ids = batch["input_ids"].to(config.device)
            attention_mask = batch["attention_mask"].to(config.device)
            target_tokens = batch["labels"].to(config.device)

            # Skip batch if dimensions are incompatible
            if input_ids.shape[0] != attention_mask.shape[0] or input_ids.shape[0] != target_tokens.shape[0]:
                logger.warning(f"Skipping batch {batch_idx} due to dimension mismatch")
                continue

            # Skip very small batches
            if input_ids.size(0) < 2:
                logger.warning(f"Skipping batch {batch_idx} - batch too small")
                continue

            # One position per row, shared by teacher, student and target.
            positions = _last_labelled_positions(attention_mask).clamp(max=input_ids.size(1) - 1)

            with torch.no_grad():
                teacher_logits = teacher_model(input_ids, attention_mask)
                if teacher_logits.dim() == 3:  # [batch, seq_len, vocab]
                    teacher_logits = _select_positions(teacher_logits, positions)
                teacher_probs = torch.softmax(teacher_logits.float(), dim=-1)
                del teacher_logits

            logits = student_model(input_ids, attention_mask)
            if logits.dim() == 3:  # [batch, seq_len, vocab]
                logits = _select_positions(logits, positions)

            # Building the policy from logits avoids a separate softmax.
            dist = torch.distributions.Categorical(logits=logits)
            actions = dist.sample()             # [batch]
            log_probs = dist.log_prob(actions)  # [batch]
            entropy = dist.entropy().mean()     # scalar
            del logits, dist

            if target_tokens.dim() > 1:
                next_tokens = _select_positions(target_tokens, positions)
            else:
                next_tokens = target_tokens

            # Reward: exact match plus weighted teacher confidence in the sample.
            rewards = (actions == next_tokens).float()
            teacher_vocab = teacher_probs.shape[1]
            in_teacher_vocab = (actions < teacher_vocab).to(rewards.dtype)
            lookup = actions.clamp(max=teacher_vocab - 1).unsqueeze(1)
            teacher_confidence = teacher_probs.gather(1, lookup).squeeze(1)
            rewards = rewards + config.teacher_weight * teacher_confidence * in_teacher_vocab

            if log_probs.shape != rewards.shape:
                raise ValueError(
                    f"Shape mismatch: log_probs {log_probs.shape} vs rewards {rewards.shape}"
                )

            # Compute RL loss, normalised for gradient accumulation
            loss = -(log_probs * rewards).mean() - config.entropy_coef * entropy
            loss = loss / accumulation_steps

            loss.backward()
            pending_micro_batches += 1
            accumulated_loss += loss.item()

            epoch_rewards.append(rewards.mean().item())
            epoch_losses.append(loss.item() * accumulation_steps)  # Re-scale for reporting

            all_predictions.extend(actions.detach().cpu().numpy())
            all_targets.extend(next_tokens.detach().cpu().numpy())

            # Update parameters once enough micro-batches have been accumulated
            if pending_micro_batches >= accumulation_steps:
                _optimizer_update()
                pending_micro_batches = 0
                if config.device.type == 'cuda':
                    torch.cuda.empty_cache()

            # Report running metrics periodically
            if (batch_idx + 1) % config.eval_every == 0:
                recent_preds = np.array(all_predictions[window_start:])
                recent_targets = np.array(all_targets[window_start:])
                if len(recent_preds) > 0:
                    accuracy = float((recent_preds == recent_targets).mean())
                else:
                    accuracy = 0.0
                window_start = len(all_predictions)

                logger.info(f"Batch {batch_idx+1}/{len(train_loader)}, Loss: {accumulated_loss:.4f}, "
                           f"Reward: {np.mean(epoch_rewards[-config.eval_every:]) if epoch_rewards else 0:.4f}, "
                           f"Accuracy: {accuracy:.4f}")

                accumulated_loss = 0

            # Save model periodically
            if (batch_idx + 1) % config.save_every == 0:
                os.makedirs(config.output_dir, exist_ok=True)
                checkpoint_path = os.path.join(config.output_dir, f"student_model_step{batch_idx+1}.pt")
                torch.save(student_model.state_dict(), checkpoint_path)
                logger.info(f"Model saved at step {batch_idx+1}")

        except Exception as e:
            # Best effort: log the failure and carry on with the next batch.
            logger.error(f"Error in batch {batch_idx}: {e}")
            logger.error(traceback.format_exc())
            if config.device.type == 'cuda':
                torch.cuda.empty_cache()
            continue

    # Flush gradients left over from an incomplete accumulation group
    if pending_micro_batches > 0:
        _optimizer_update()

    if all_predictions and all_targets:
        all_predictions = np.array(all_predictions)
        all_targets = np.array(all_targets)
    else:
        all_predictions = np.array([])
        all_targets = np.array([])

    mean_reward = np.mean(epoch_rewards) if epoch_rewards else 0.0
    mean_loss = np.mean(epoch_losses) if epoch_losses else 0.0
    return mean_reward, mean_loss, all_predictions, all_targets


In [None]:
# # Cell 9: Training function with RL
# def train_rl(teacher_model, student_model, train_loader, optimizer, scheduler, config):
#     teacher_model.eval()  # Teacher model is frozen
#     student_model.train()

#     epoch_rewards = []
#     epoch_losses = []

#     # For tracking metrics
#     all_predictions = []
#     all_targets = []

#     # Clear GPU cache if using CUDA
#     if config.device.type == 'cuda':
#         torch.cuda.empty_cache()

#     # Initialize accumulated gradients
#     optimizer.zero_grad()
#     accumulated_loss = 0

#     for batch_idx, batch in enumerate(tqdm(train_loader, desc="Training")):
#         try:
#             input_ids = batch["input_ids"].to(config.device)
#             attention_mask = batch["attention_mask"].to(config.device)
#             target_tokens = batch["labels"].to(config.device)

#             # Skip batch if dimensions are incompatible
#             if input_ids.shape[0] != attention_mask.shape[0] or input_ids.shape[0] != target_tokens.shape[0]:
#                 logger.warning(f"Skipping batch {batch_idx} due to dimension mismatch")
#                 continue

#             # Skip very small batches
#             if input_ids.size(0) < 2:
#                 logger.warning(f"Skipping batch {batch_idx} - batch too small")
#                 continue

#             # Get teacher predictions - use no_grad to save memory
#             # Get teacher predictions - use no_grad to save memory
#             with torch.no_grad():
#                 teacher_logits = teacher_model(input_ids, attention_mask)
#                 # Extract logits for the last position (to predict next token)
#                 teacher_last_logits = teacher_logits[:, -1, :]  # [batch_size, vocab_size]
#                 teacher_probs = torch.softmax(teacher_last_logits, dim=-1)
#                 del teacher_logits, teacher_last_logits  # Free memory


#             # Student makes predictions using RL approach
#             actions, log_probs, entropy, student_probs = student_model.get_action_and_value(input_ids, attention_mask)
#             # Free memory
#             del student_probs

#             # Compute rewards - shape: [batch_size]
#             rewards = compute_reward(teacher_probs, actions, target_tokens, config.teacher_weight)
#             del teacher_probs  # Free memory

#             # RL loss: negative log probability of action multiplied by reward
#             # Element-wise multiplication, then mean across batch
#             loss = -(log_probs * rewards).mean() - config.entropy_coef * entropy

#             # Normalize loss for gradient accumulation
#             loss = loss / config.gradient_accumulation_steps

#             # Backward pass
#             loss.backward()
#             accumulated_loss += loss.item()

#             # Track metrics
#             epoch_rewards.append(rewards.mean().item())
#             epoch_losses.append(loss.item() * config.gradient_accumulation_steps)  # Re-scale for reporting

#             # Track predictions for accuracy calculation
#             predictions = actions.detach().cpu().numpy()
#             targets = target_tokens.detach().cpu().numpy()
#             all_predictions.extend(predictions.flatten())
#             all_targets.extend(targets.flatten())

#             # Free memory
#             del input_ids, attention_mask, target_tokens, actions, log_probs, entropy, rewards, loss

#             # Update parameters every gradient_accumulation_steps batches
#             if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
#                 optimizer.step()
#                 scheduler.step()
#                 optimizer.zero_grad()

#                 # More aggressive memory cleanup
#                 if config.device.type == 'cuda':
#                     torch.cuda.empty_cache()

#             # Evaluate periodically
#             if (batch_idx + 1) % config.eval_every == 0:
#                 # Calculate accuracy
#                 if len(all_predictions) > 0 and len(all_targets) > 0:
#                     # Use numpy for comparison to avoid tensor boolean ambiguity
#                     matching = np.array(all_predictions[-len(predictions)*config.eval_every:]) == np.array(all_targets[-len(targets)*config.eval_every:])
#                     accuracy = matching.mean() if len(matching) > 0 else 0.0
#                 else:
#                     accuracy = 0.0

#                 logger.info(f"Batch {batch_idx+1}/{len(train_loader)}, Loss: {accumulated_loss:.4f}, "
#                       f"Reward: {np.mean(epoch_rewards[-config.eval_every:]) if epoch_rewards else 0:.4f}, "
#                       f"Accuracy: {accuracy:.4f}")

#                 accumulated_loss = 0

#             # Save model periodically
#             if (batch_idx + 1) % config.save_every == 0:
#                 os.makedirs(config.output_dir, exist_ok=True)
#                 torch.save(student_model.state_dict(), f"{config.output_dir}/student_model_step{batch_idx+1}.pt")
#                 logger.info(f"Model saved at step {batch_idx+1}")

#         except Exception as e:
#             logger.error(f"Error in batch {batch_idx}: {e}")
#             # In case of error, try to clear memory
#             if config.device.type == 'cuda':
#                 torch.cuda.empty_cache()
#             continue

#     # Perform final update for any remaining accumulated gradients
#     if (len(train_loader) % config.gradient_accumulation_steps) != 0:
#         optimizer.step()
#         optimizer.zero_grad()

#     # Convert lists to numpy arrays with safe handling of empty lists
#     if len(all_predictions) > 0 and len(all_targets) > 0:
#         all_predictions = np.array(all_predictions)
#         all_targets = np.array(all_targets)
#     else:
#         all_predictions = np.array([])
#         all_targets = np.array([])

#     return np.mean(epoch_rewards) if epoch_rewards else 0.0, np.mean(epoch_losses) if epoch_losses else 0.0, all_predictions, all_targets



In [None]:
import torch
import numpy as np
import logging
from tqdm import tqdm
import torch.nn as nn

logger = logging.getLogger(__name__)

def evaluate(model, eval_loader, config):
    """Compute token-level loss, perplexity and accuracy on ``eval_loader``.

    Only positions whose ``attention_mask`` is 1 are scored, so padding does not
    dilute the loss or inflate the accuracy. (If the mask and the labels have
    different shapes, every position is scored.)

    The previous guard ``logits.size(-1) != target_tokens.max() + 1`` skipped
    every batch whose largest label was not exactly the last vocabulary id, so
    the function almost always returned ``loss=inf`` / ``accuracy=0``. A batch is
    now skipped only when a label actually falls outside ``[0, vocab_size)``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` returning
            logits of shape ``[batch, seq_len, vocab]``.
        eval_loader: Iterable of dicts with ``input_ids``, ``attention_mask``
            and ``labels`` tensors.
        config: Object providing ``device``.

    Returns:
        Dict with ``"loss"`` (mean cross-entropy per scored token, ``inf`` if
        nothing was scored), ``"perplexity"`` (``exp(loss)``) and ``"accuracy"``
        (fraction of scored tokens predicted correctly).
    """
    model.eval()
    all_predictions = []
    all_targets = []
    total_loss = 0.0
    total_tokens = 0
    loss_fct = nn.CrossEntropyLoss(reduction="sum")  # summed, then divided by token count

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(eval_loader, desc="Evaluating")):
            try:
                input_ids = batch["input_ids"].to(config.device)
                attention_mask = batch["attention_mask"].to(config.device)
                target_tokens = batch["labels"].to(config.device)

                if input_ids.size(0) < 2:
                    continue

                logits = model(input_ids, attention_mask)
                vocab_size = logits.size(-1)

                # Score only real (attended) positions.
                if attention_mask.shape == target_tokens.shape:
                    keep = attention_mask.bool()
                else:
                    keep = torch.ones_like(target_tokens, dtype=torch.bool)

                kept_targets = target_tokens[keep]       # [n_tokens]
                if kept_targets.numel() == 0:
                    continue

                if kept_targets.min() < 0 or kept_targets.max() >= vocab_size:
                    logger.warning(f"Skipping loss calculation in eval batch {batch_idx} - vocab size mismatch")
                    continue

                kept_logits = logits[keep].float()        # [n_tokens, vocab]
                total_loss += loss_fct(kept_logits, kept_targets).item()
                total_tokens += kept_targets.numel()

                all_predictions.extend(kept_logits.argmax(dim=-1).cpu().numpy())
                all_targets.extend(kept_targets.cpu().numpy())

            except Exception as e:
                # Best effort: log and continue with the next batch.
                logger.error(f"Error in evaluation batch {batch_idx}: {e}")
                continue

    avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf')
    perplexity = np.exp(avg_loss) if avg_loss < float('inf') else float('inf')

    if len(all_predictions) > 0:
        accuracy = float((np.array(all_predictions) == np.array(all_targets)).mean())
    else:
        accuracy = 0.0

    print("\n========== Model Evaluation Results ==========")
    print(f"Average Loss      : {avg_loss:.4f}")
    print(f"Perplexity        : {perplexity:.4f}")
    print(f"Accuracy          : {accuracy * 100:.2f}%")
    print("==============================================\n")

    return {
        "loss": avg_loss,
        "perplexity": perplexity,
        "accuracy": accuracy
    }


In [None]:
# Cell 11: Main function - Data loading and preprocessing
# Module-level settings object; main() below reads it as a global rather than
# taking it as a parameter.
config = Config()
def main():
    """Load the parallel corpus and build the student's train/validation loaders.

    Reads the module-level ``config``.

    Returns:
        ``(train_loader, val_loader, data_processor)``, or ``None`` when either
        dataset turns out to be empty (an error is logged in that case).
    """
    logger.info(f"Using device: {config.device}")

    data_processor = DataProcessor(config)
    hindi_samples, chhattisgarhi_samples = data_processor.load_dataset()

    logger.info("Creating teacher (Hindi) dataset")
    teacher_dataset = NextWordPredictionDataset(
        hindi_samples, data_processor.teacher_tokenizer, config.max_length
    )

    logger.info("Creating student (Chhattisgarhi) dataset")
    student_dataset = NextWordPredictionDataset(
        chhattisgarhi_samples, data_processor.student_tokenizer, config.max_length
    )

    # Nothing to train on if either side came back empty.
    if not len(teacher_dataset) or not len(student_dataset):
        logger.error("One of the datasets is empty. Please check your data processing.")
        return

    # 90 / 10 split of the student data with a fixed seed for reproducibility.
    n_total = len(student_dataset)
    n_train = int(0.9 * n_total)
    split_rng = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = torch.utils.data.random_split(
        student_dataset, [n_train, n_total - n_train], generator=split_rng
    )

    logger.info(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    # Both loaders share batch size and pin_memory; only training is shuffled.
    loader_kwargs = dict(batch_size=config.batch_size, pin_memory=False)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, **loader_kwargs)

    return train_loader, val_loader, data_processor

In [None]:
# Cell 12: Model loading function
def load_models(config, data_processor):
    """Instantiate the teacher and student models on ``config.device``.

    Args:
        config: Object providing ``teacher_model_name``, ``student_model_name``
            and ``device``.
        data_processor: Object whose ``student_tokenizer`` length is used to
            resize the student's token embeddings.

    Returns:
        tuple: ``(teacher_model, student_model)``.
    """
    # Release cached GPU memory first so both models have the best chance of fitting.
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.empty_cache()

    teacher_name = config.teacher_model_name
    logger.info(f"Loading teacher model: {teacher_name}")
    teacher = TeacherModel(teacher_name).to(config.device)

    student_name = config.student_model_name
    logger.info(f"Loading student model: {student_name}")
    student = StudentModel(student_name).to(config.device)

    # Keep the student's embedding table in step with its tokenizer size.
    resize_embeddings = student.model.resize_token_embeddings
    resize_embeddings(len(data_processor.student_tokenizer))

    return teacher, student

In [None]:
from torch.optim import AdamW  # Use PyTorch's implementation

def setup_optimizer_scheduler(student_model, config, train_loader, lr=1e-5, warmup_steps=1000):
    """Build the AdamW optimizer and linear warmup/decay scheduler for the student.

    Args:
        student_model: ``nn.Module`` whose parameters are optimized.
        config: Object providing ``weight_decay``, ``gradient_accumulation_steps``
            and ``epochs``. ``config.rl_lr`` and ``config.warmup_steps`` are
            overwritten with ``lr`` and ``warmup_steps``.
        train_loader: Sized iterable of training batches; only ``len()`` is used.
        lr: Learning rate written to ``config.rl_lr`` and passed to AdamW.
        warmup_steps: Upper bound on warmup steps, written to
            ``config.warmup_steps``. The scheduler uses at most a fifth of the
            total optimizer steps for warmup.

    Returns:
        tuple: ``(optimizer, scheduler)``.
    """
    # Record the effective hyper-parameters on the shared config.
    config.rl_lr = lr
    config.warmup_steps = warmup_steps

    no_decay = ("bias", "LayerNorm.weight")

    # Name matching alone only catches norms whose attribute is literally called
    # "LayerNorm"; detect normalization layers by type as well so layers with
    # other names (e.g. "ln_1", "ln_f") are also exempt from weight decay.
    norm_param_ids = {
        id(param)
        for module in student_model.modules()
        if isinstance(module, nn.LayerNorm)
        for param in module.parameters(recurse=False)
    }

    decay_params, no_decay_params = [], []
    for name, param in student_model.named_parameters():
        if id(param) in norm_param_ids or any(marker in name for marker in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": config.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=config.rl_lr)

    # Total optimizer steps, accounting for gradient accumulation. Clamp to at
    # least one step: with zero training steps the linear schedule multiplies
    # the learning rate by 0 from the very first step.
    steps_per_epoch = len(train_loader) // config.gradient_accumulation_steps
    total_steps = max(1, steps_per_epoch * config.epochs)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=min(config.warmup_steps, total_steps // 5),
        num_training_steps=total_steps
    )

    return optimizer, scheduler


In [None]:
# Cell 14: Training loop function
def _as_flat_array(values):
    """Return ``values`` as a flat NumPy array, moving torch tensors to the CPU first."""
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().numpy()
    return np.asarray(values).ravel()


def _prediction_accuracy(preds, targets):
    """Return the fraction of positions where ``preds`` equals ``targets``.

    Returns 0.0 when either input is empty or the two do not contain the same
    number of elements (they cannot be compared element-wise).
    """
    if len(preds) == 0 or len(targets) == 0:
        return 0.0
    preds_arr = _as_flat_array(preds)
    targets_arr = _as_flat_array(targets)
    if preds_arr.shape != targets_arr.shape:
        logger.warning(
            f"Cannot compute training accuracy: {preds_arr.size} predictions "
            f"vs {targets_arr.size} targets"
        )
        return 0.0
    return float(np.mean(preds_arr == targets_arr))


def train_model(teacher_model, student_model, train_loader, val_loader, optimizer, scheduler, config):
    """Run the RL training loop, checkpoint the student, and report final metrics.

    For each of ``config.epochs`` epochs this calls ``train_rl`` and then
    ``evaluate`` on ``val_loader``; whenever validation accuracy improves, the
    student's ``state_dict`` is written to ``best_student_model.pt`` in
    ``config.output_dir``. The final weights are always written to
    ``final_student_model.pt``, even if training is interrupted or fails.

    Args:
        teacher_model: Teacher model, forwarded to ``train_rl``.
        student_model: Student model being trained and saved.
        train_loader: Training batches, forwarded to ``train_rl``.
        val_loader: Validation batches, forwarded to ``evaluate``.
        optimizer: Optimizer, forwarded to ``train_rl``.
        scheduler: Learning-rate scheduler, forwarded to ``train_rl``.
        config: Object providing ``epochs`` and ``output_dir``; also forwarded
            to ``train_rl`` and ``evaluate``.

    Returns:
        tuple: ``(student_model, final_metrics, best_accuracy)`` where
        ``final_metrics`` is the result of a last ``evaluate`` call and
        ``best_accuracy`` is the highest validation accuracy seen (0 if none
        exceeded 0).
    """
    best_accuracy = 0
    best_model_path = os.path.join(config.output_dir, "best_student_model.pt")
    final_model_path = os.path.join(config.output_dir, "final_student_model.pt")

    try:
        for epoch in range(config.epochs):
            logger.info(f"\nEpoch {epoch+1}/{config.epochs}")

            # Train with RL
            avg_reward, avg_loss, all_preds, all_targets = train_rl(
                teacher_model, student_model, train_loader, optimizer, scheduler, config
            )

            # Convert explicitly to NumPy before comparing: `==` on plain lists
            # yields a single bool, and bool tensors do not support `.mean()`.
            train_accuracy = _prediction_accuracy(all_preds, all_targets)

            logger.info(f"Epoch {epoch+1} complete. Avg Reward: {avg_reward:.4f}, "
                        f"Avg Loss: {avg_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")

            # Clear memory before evaluation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Evaluate
            eval_metrics = evaluate(student_model, val_loader, config)
            logger.info(f"Epoch {epoch+1} validation metrics: Loss: {eval_metrics['loss']:.4f}, "
                        f"Perplexity: {eval_metrics['perplexity']:.4f}, "
                        f"Accuracy: {eval_metrics['accuracy']:.4f}")

            # Save best model
            if eval_metrics['accuracy'] > best_accuracy:
                best_accuracy = eval_metrics['accuracy']
                os.makedirs(config.output_dir, exist_ok=True)
                torch.save(student_model.state_dict(), best_model_path)
                logger.info(f"New best model saved with accuracy: {best_accuracy:.4f}")

            # Clear memory after each epoch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    except Exception as e:
        # Still best-effort (fall through to saving and evaluating), but keep the
        # traceback so the failure can be diagnosed.
        logger.exception(f"Training error: {e}")
    finally:
        # Make sure we save the model even if training is interrupted
        logger.info("Saving final model...")
        os.makedirs(config.output_dir, exist_ok=True)
        torch.save(student_model.state_dict(), final_model_path)

    logger.info("Training complete!")

    # Final evaluation and analysis
    logger.info("Performing final evaluation...")
    final_metrics = evaluate(student_model, val_loader, config)

    # Output final results
    logger.info("Final Results:")
    logger.info(f"Best validation accuracy: {best_accuracy:.4f}")
    logger.info(f"Final validation perplexity: {final_metrics['perplexity']:.4f}")

    return student_model, final_metrics, best_accuracy

In [None]:
def run_training_pipeline():
    """Run the whole workflow: data, models, optimizer/scheduler, then training.

    Uses the module-level ``config`` for every stage.

    Returns:
        tuple: ``(student_model, final_metrics, best_accuracy)`` as produced by
        ``train_model``.
    """
    # Data loaders and the processor that owns the tokenizers.
    train_dl, val_dl, processor = main()

    # Teacher / student pair.
    teacher, student = load_models(config, processor)

    # Optimization setup for the student only.
    opt, sched = setup_optimizer_scheduler(student, config, train_dl)

    # Training itself; unpack so the result is always returned as a 3-tuple.
    trained_student, metrics, top_accuracy = train_model(
        teacher, student, train_dl, val_dl, opt, sched, config
    )
    return trained_student, metrics, top_accuracy


In [None]:
# Cell 16: Execute training. The call below is active (not commented out), so
# running this cell starts the full pipeline immediately.
# NOTE: the recorded output of this cell shows roughly 30 minutes per training epoch.
student_model, final_metrics, best_accuracy = run_training_pipeline()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/507 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/5.65M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/38.6k [00:00<?, ?B/s]

nllb.py:   0%|          | 0.00/9.49k [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/5.05M [00:00<?, ?B/s]

nllb_lang_pairs.py:   0%|          | 0.00/81.9k [00:00<?, ?B/s]

The repository for allenai/nllb contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/allenai/nllb.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Repo card metadata block was not found. Setting CardData to empty.


Downloading data:   0%|          | 0.00/13.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/83150 [00:00<?, ? examples/s]

Processing samples: 100%|██████████| 74835/74835 [00:00<00:00, 2809429.84it/s]
Processing samples: 100%|██████████| 74835/74835 [00:00<00:00, 2916971.70it/s]


pytorch_model.bin:   0%|          | 0.00/135M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Training: 100%|██████████| 8404/8404 [28:55<00:00,  4.84it/s]
Evaluating: 100%|██████████| 934/934 [00:44<00:00, 20.78it/s]



Average Loss      : 10.5046
Perplexity        : 36483.1658
Accuracy          : 0.23%



Training: 100%|██████████| 8404/8404 [30:39<00:00,  4.57it/s]
Evaluating: 100%|██████████| 934/934 [00:41<00:00, 22.56it/s]



Average Loss      : 10.5918
Perplexity        : 39807.3435
Accuracy          : 0.14%



Evaluating: 100%|██████████| 934/934 [00:41<00:00, 22.33it/s]



Average Loss      : 10.5918
Perplexity        : 39807.3435
Accuracy          : 0.14%

