In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [1]:
import json
read = json.load(open("/content/qrecc-training.json"))

In [None]:
read

In [1]:
# -----------------------------
# Step 1: Teacher Preparation with TopiOCQA (using load_dataset with plain_text config, no Pyserini)
# -----------------------------
# !pip install datasets transformers accelerate torch --quiet

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json # Import json to save results
from huggingface_hub import login
from google.colab import userdata

# Log in to Hugging Face Hub
try:
    login(token=userdata.get("HF_TOKEN"))
    print("Logged in to Hugging Face.")
except Exception as e:
    print(f"Error logging in to Hugging Face: {e}")


# -----------------------------
# 1. Load TopiOCQA dataset (using svakulenk0/qrecc as it was explored earlier)
# -----------------------------
# The error indicates that dataset scripts are no longer supported.
# We will load the dataset directly from the Hugging Face Hub without specifying a config that uses a script.
try:
    ds = load_dataset('json',data_files="/content/qrecc-training.json") # Load the train split directly
    print("Dataset loaded successfully.")
except Exception as e:
    print(f"Error loading dataset: {e}")
    ds = None # Set ds to None if loading fails


# -----------------------------
# 2. Initialize the teacher LLM
# -----------------------------
# llm_name = "mistralai/Mistral-7B-Instruct-v0.2"  # Use an open-access model
llm_name = "google/gemma-2b-it" # Using a smaller model that might fit in Colab's GPU memory
tokenizer = AutoTokenizer.from_pretrained(llm_name)
model = AutoModelForCausalLM.from_pretrained(llm_name, device_map="cuda") # Use auto for device mapping

Logged in to Hugging Face.
Dataset loaded successfully.


tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [2]:
# torch.cuda.is_available()
# model.to('cuda')
# model.generate(**tokenizer("hello", return_tensors="pt").to('cuda'))

tensor([[     2,  17534, 235269,    496, 235303, 235262,   3648,    604,    476,
           1703,    577,   1501,    970,   4451,    978,  18516,    577,   1461,
            675,  44601, 235265,    109]], device='cuda:0')

In [5]:
import torch

# Clear CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("CUDA cache cleared.")
else:
    print("CUDA is not available.")

CUDA cache cleared.


In [4]:
# -----------------------------
from tqdm import tqdm
# 3. Function to rewrite queries
# -----------------------------
def rewrite_query(turns):
    """
    Input: list of conversation turns like [{'role': 'user', 'text': '...'}, ...]
    Output: rewritten last query string
    """
    if not isinstance(turns, list) or not all(isinstance(t, dict) for t in turns):
        print("⚠️ Warning: Input to rewrite_query is not in the expected format.")
        return ""

    # Build conversation context
    context = "\n".join([f"{t['role']}: {t['text']}" for t in turns])
    prompt = f"Rewrite the last user question so it is self-contained.\n\nConversation:\n{context}\nRewritten query:"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    # print(inputs)
    model.eval()
    output = model.generate(
            **inputs,
            # max_new_tokens=64,
            pad_token_id=tokenizer.eos_token_id
        )
    # print(output)

    rewritten = tokenizer.decode(output[0], skip_special_tokens=True)

    # Extract only rewritten query text
    if "Rewritten query:" in rewritten:
        rewritten = rewritten.split("Rewritten query:")[-1].strip()

    return rewritten.strip()


# -----------------------------
# 4. Placeholder Teacher Scores
# -----------------------------
def teacher_scores(turns):
    # print(turns)
    rewritten = rewrite_query(turns)
    scores = {}  # Placeholder for real retrieval scores
    return scores, rewritten


# -----------------------------
# 5. Process Samples
# -----------------------------
results = []
max_samples = 100  # adjust as needed

for i, sample in tqdm(enumerate(ds['train'])):
    if i >= max_samples:
        break

    try:
        context_list = sample.get("Context", [])
        conversation = []
        for j, text in enumerate(context_list):
            role = "user" if j % 2 == 0 else "assistant"
            conversation.append({"role": role, "text": text})

        current_question = sample.get("Question", "")
        if current_question:
            conversation.append({"role": "user", "text": current_question})
        # print('processing')
        scores, rewritten_q = teacher_scores(conversation)
        # print('processed')

        results.append({
            "conversation_id": sample.get("Conversation_no", i),
            "turn_no": sample.get("Turn_no", 0),
            "original_question": current_question,
            "rewritten_query": rewritten_q,
            "similarity_scores": scores
        })

        print(f"✅ Processed sample {i+1}/{max_samples}")

    except Exception as e:
        print(f"⚠️ Skipping sample {i} due to error: {e}")


# -----------------------------
# 6. Save Outputs
# -----------------------------
output_path = "/content/teacher_scores_qrecc.json"
with open(output_path, "w") as f:
    json.dump(results, f, indent=2)

print(f"\n✅ Saved {len(results)} rewritten queries to: {output_path}")

# Preview few examples
print("\n🔍 Example output:")
for r in results[:3]:
    print(json.dumps(r, indent=2))

1it [00:00,  1.64it/s]

✅ Processed sample 1/100


2it [00:01,  1.12it/s]

✅ Processed sample 2/100


3it [00:02,  1.02it/s]

✅ Processed sample 3/100


4it [00:03,  1.07s/it]

✅ Processed sample 4/100


5it [00:05,  1.14s/it]

✅ Processed sample 5/100


6it [00:06,  1.18s/it]

✅ Processed sample 6/100


7it [00:07,  1.04it/s]

✅ Processed sample 7/100


8it [00:08,  1.02it/s]

✅ Processed sample 8/100


9it [00:09,  1.02s/it]

✅ Processed sample 9/100


10it [00:10,  1.06s/it]

✅ Processed sample 10/100


11it [00:11,  1.12s/it]

✅ Processed sample 11/100


12it [00:12,  1.18s/it]

✅ Processed sample 12/100


13it [00:13,  1.11it/s]

✅ Processed sample 13/100


14it [00:14,  1.05it/s]

✅ Processed sample 14/100


15it [00:15,  1.09it/s]

✅ Processed sample 15/100


16it [00:15,  1.11it/s]

✅ Processed sample 16/100


17it [00:17,  1.03it/s]

✅ Processed sample 17/100


18it [00:18,  1.03s/it]

✅ Processed sample 18/100


19it [00:19,  1.06s/it]

✅ Processed sample 19/100


20it [00:20,  1.17s/it]

✅ Processed sample 20/100


21it [00:22,  1.25s/it]

✅ Processed sample 21/100


22it [00:22,  1.04it/s]

✅ Processed sample 22/100


23it [00:23,  1.01it/s]

✅ Processed sample 23/100


24it [00:24,  1.05s/it]

✅ Processed sample 24/100


25it [00:25,  1.09s/it]

✅ Processed sample 25/100


26it [00:27,  1.14s/it]

✅ Processed sample 26/100


27it [00:28,  1.17s/it]

✅ Processed sample 27/100


28it [00:29,  1.24s/it]

✅ Processed sample 28/100


29it [00:31,  1.29s/it]

✅ Processed sample 29/100


30it [00:31,  1.01it/s]

✅ Processed sample 30/100


31it [00:32,  1.16it/s]

✅ Processed sample 31/100


32it [00:32,  1.16it/s]

✅ Processed sample 32/100


33it [00:33,  1.17it/s]

✅ Processed sample 33/100


34it [00:34,  1.11it/s]

✅ Processed sample 34/100


35it [00:35,  1.06it/s]

✅ Processed sample 35/100


36it [00:36,  1.01it/s]

✅ Processed sample 36/100


37it [00:38,  1.10s/it]

✅ Processed sample 37/100


38it [00:39,  1.19s/it]

✅ Processed sample 38/100


39it [00:40,  1.07s/it]

✅ Processed sample 39/100


40it [00:41,  1.01it/s]

✅ Processed sample 40/100


41it [00:42,  1.01s/it]

✅ Processed sample 41/100


42it [00:43,  1.06s/it]

✅ Processed sample 42/100


43it [00:44,  1.11s/it]

✅ Processed sample 43/100


44it [00:46,  1.18s/it]

✅ Processed sample 44/100


45it [00:47,  1.23s/it]

✅ Processed sample 45/100


46it [00:48,  1.27s/it]

✅ Processed sample 46/100


47it [00:50,  1.34s/it]

✅ Processed sample 47/100


48it [00:50,  1.03s/it]

✅ Processed sample 48/100


49it [00:51,  1.10it/s]

✅ Processed sample 49/100


50it [00:51,  1.17it/s]

✅ Processed sample 50/100


51it [00:52,  1.24it/s]

✅ Processed sample 51/100


52it [00:53,  1.08it/s]

✅ Processed sample 52/100


53it [00:55,  1.02s/it]

✅ Processed sample 53/100


54it [00:56,  1.12s/it]

✅ Processed sample 54/100


55it [00:57,  1.19s/it]

✅ Processed sample 55/100


56it [00:58,  1.15s/it]

✅ Processed sample 56/100


57it [01:00,  1.27s/it]

✅ Processed sample 57/100


58it [01:00,  1.04s/it]

✅ Processed sample 58/100


59it [01:01,  1.02s/it]

✅ Processed sample 59/100


60it [01:03,  1.05s/it]

✅ Processed sample 60/100


61it [01:04,  1.05s/it]

✅ Processed sample 61/100


62it [01:05,  1.04s/it]

✅ Processed sample 62/100


63it [01:06,  1.13s/it]

✅ Processed sample 63/100


64it [01:07,  1.15s/it]

✅ Processed sample 64/100


65it [01:08,  1.13s/it]

✅ Processed sample 65/100


66it [01:10,  1.19s/it]

✅ Processed sample 66/100


67it [01:11,  1.21s/it]

✅ Processed sample 67/100


68it [01:11,  1.04it/s]

✅ Processed sample 68/100


69it [01:12,  1.04it/s]

✅ Processed sample 69/100


70it [01:13,  1.13it/s]

✅ Processed sample 70/100


71it [01:14,  1.19it/s]

✅ Processed sample 71/100


72it [01:15,  1.05it/s]

✅ Processed sample 72/100


73it [01:16,  1.04s/it]

✅ Processed sample 73/100


74it [01:17,  1.14s/it]

✅ Processed sample 74/100


75it [01:19,  1.18s/it]

✅ Processed sample 75/100


76it [01:20,  1.29s/it]

✅ Processed sample 76/100


77it [01:22,  1.36s/it]

✅ Processed sample 77/100


78it [01:23,  1.41s/it]

✅ Processed sample 78/100


79it [01:24,  1.07s/it]

✅ Processed sample 79/100


80it [01:25,  1.06s/it]

✅ Processed sample 80/100


81it [01:26,  1.07s/it]

✅ Processed sample 81/100


82it [01:27,  1.10s/it]

✅ Processed sample 82/100


83it [01:28,  1.08s/it]

✅ Processed sample 83/100


84it [01:29,  1.02s/it]

✅ Processed sample 84/100


85it [01:30,  1.13s/it]

✅ Processed sample 85/100


86it [01:32,  1.21s/it]

✅ Processed sample 86/100


87it [01:33,  1.32s/it]

✅ Processed sample 87/100


88it [01:35,  1.38s/it]

✅ Processed sample 88/100


89it [01:36,  1.48s/it]

✅ Processed sample 89/100


90it [01:38,  1.57s/it]

✅ Processed sample 90/100


91it [01:39,  1.25s/it]

✅ Processed sample 91/100


92it [01:40,  1.14s/it]

✅ Processed sample 92/100


93it [01:40,  1.05s/it]

✅ Processed sample 93/100


94it [01:41,  1.01s/it]

✅ Processed sample 94/100


95it [01:43,  1.09s/it]

✅ Processed sample 95/100


96it [01:44,  1.08s/it]

✅ Processed sample 96/100


97it [01:45,  1.13s/it]

✅ Processed sample 97/100


98it [01:45,  1.13it/s]

✅ Processed sample 98/100


99it [01:46,  1.36it/s]

✅ Processed sample 99/100


100it [01:46,  1.07s/it]

✅ Processed sample 100/100

✅ Saved 100 rewritten queries to: /content/teacher_scores_qrecc.json

🔍 Example output:
{
  "conversation_id": 1,
  "turn_no": 1,
  "original_question": "What can you tell me about Gary Cherone?",
  "rewritten_query": "What can you tell me about Gary Cherone?",
  "similarity_scores": {}
}
{
  "conversation_id": 1,
  "turn_no": 2,
  "original_question": "Did Gary sing well?",
  "rewritten_query": "Sure, here is the rewritten question:\n\nDid Gary Cherone sing well in his career?",
  "similarity_scores": {}
}
{
  "conversation_id": 1,
  "turn_no": 3,
  "original_question": "What significant fact can you tell me about Gary that you liked?",
  "rewritten_query": "Sure, here's the rewritten question:\n\nWhat is Gary Cherone's most significant",
  "similarity_scores": {}
}





In [9]:
# -----------------------------
# Step 2: Student Model Training (DiSCo)
# -----------------------------
# Run this first in Colab:
# !pip install datasets transformers accelerate torch --quiet

import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import numpy as np
import gc

# Clear GPU memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    gc.collect()
    print(f"✅ GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"✅ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ No GPU available")

# -----------------------------
# 1. Load Teacher Outputs
# -----------------------------
try:
    teacher_output_path = "/content/teacher_scores_qrecc.json"
    with open(teacher_output_path, "r") as f:
        teacher_data = json.load(f)
    print(f"✅ Loaded {len(teacher_data)} teacher samples")
except FileNotFoundError:
    print("❌ Error: teacher_scores_qrecc.json not found!")
    print("Please run the teacher code first to generate this file.")
    teacher_data = []

# -----------------------------
# 2. Load Document Collection (simplified)
# -----------------------------
# In a real scenario, you'd load the full document corpus
# For now, we'll create dummy document embeddings
# You should replace this with actual document retrieval

class DocumentIndex:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # In practice, you'd pre-compute and store all document embeddings
        self.doc_cache = {}

    def encode_text(self, text, max_length=256):
        """Encode text into sparse representation"""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            # Use MLM head to create sparse representation
            embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token

        return embeddings

    def get_hard_negatives(self, query, k=5):
        """Sample hard negative documents"""
        # In practice, use BM25 or teacher model to mine hard negatives
        # For now, return random document IDs
        return [f"doc_{i}" for i in range(k)]


# -----------------------------
# 3. Student Dataset
# -----------------------------
class ConversationalDataset(Dataset):
    def __init__(self, teacher_data, tokenizer, max_length=256):
        self.data = teacher_data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Build full conversation context
        conv_text = item.get("original_question", "")
        rewritten_text = item.get("rewritten_query", "")

        # Tokenize conversation and rewritten query
        conv_inputs = self.tokenizer(
            conv_text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding="max_length"
        )

        rewritten_inputs = self.tokenizer(
            rewritten_text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding="max_length"
        )

        return {
            "conv_input_ids": conv_inputs["input_ids"].squeeze(0),
            "conv_attention_mask": conv_inputs["attention_mask"].squeeze(0),
            "rewritten_input_ids": rewritten_inputs["input_ids"].squeeze(0),
            "rewritten_attention_mask": rewritten_inputs["attention_mask"].squeeze(0),
            "similarity_scores": item.get("similarity_scores", {}),
            "conversation_id": item.get("conversation_id", idx)
        }


# -----------------------------
# 4. Student Model (SPLADE-like architecture)
# -----------------------------
class SPLADEStudent(nn.Module):
    def __init__(self, model_name="distilbert-base-uncased"):  # Use DistilBERT for Colab
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.vocab_size = self.encoder.config.vocab_size
        self.mlm_head = nn.Linear(self.encoder.config.hidden_size, self.vocab_size)

    def forward(self, input_ids, attention_mask):
        """
        Encode input and produce sparse representation
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Get token representations
        token_embeddings = outputs.last_hidden_state

        # Apply MLM head to get logits for each vocab token
        logits = self.mlm_head(token_embeddings)

        # Apply log(1 + ReLU) and max pooling over sequence
        sparse_repr = torch.log1p(F.relu(logits))
        sparse_repr = torch.max(sparse_repr, dim=1)[0]  # Max pool over sequence

        return sparse_repr

    def compute_similarity(self, query_repr, doc_repr):
        """Compute dot product similarity"""
        return torch.sum(query_repr * doc_repr, dim=-1)


# -----------------------------
# 5. DiSCo Loss (KL Divergence on similarity scores)
# -----------------------------
def disco_loss(student_scores, teacher_scores, temperature=1.0):
    """
    DiSCo distillation loss using KL divergence
    Args:
        student_scores: tensor of shape (batch_size, num_docs)
        teacher_scores: tensor of shape (batch_size, num_docs)
    """
    # Convert to distributions with softmax
    student_dist = F.log_softmax(student_scores / temperature, dim=-1)
    teacher_dist = F.softmax(teacher_scores / temperature, dim=-1)

    # KL divergence loss
    loss = F.kl_div(student_dist, teacher_dist, reduction="batchmean")

    return loss * (temperature ** 2)


# -----------------------------
# 6. FLOPS Regularization (for sparsity control)
# -----------------------------
def flops_regularization(query_repr, doc_repr, lambda_q=1e-3, lambda_d=5e-4):
    """L1 regularization for sparsity"""
    q_reg = torch.mean(torch.sum(torch.abs(query_repr), dim=-1))
    d_reg = torch.mean(torch.sum(torch.abs(doc_repr), dim=-1))
    return lambda_q * q_reg + lambda_d * d_reg


# -----------------------------
# 7. Training Function
# -----------------------------
def train_student(
    model,
    dataloader,
    doc_index,
    optimizer,
    device,
    num_negatives=3,
    temperature=1.0,
    lambda_q=1e-3,
    lambda_d=5e-4
):
    model.train()
    total_loss = 0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training"):
        try:
            # Move batch to device
            conv_input_ids = batch["conv_input_ids"].to(device)
            conv_attention_mask = batch["conv_attention_mask"].to(device)
            rewritten_input_ids = batch["rewritten_input_ids"].to(device)
            rewritten_attention_mask = batch["rewritten_attention_mask"].to(device)

            optimizer.zero_grad()

            # Encode conversation with student
            conv_repr = model(conv_input_ids, conv_attention_mask)

            # Encode rewritten query (teacher representations)
            with torch.no_grad():
                teacher_repr = model(rewritten_input_ids, rewritten_attention_mask)

            # Sample documents (1 positive + negatives)
            batch_size = conv_repr.size(0)
            num_docs = num_negatives + 1

            # Create dummy document representations (replace with actual docs)
            # Using smaller random docs to save memory
            doc_reprs = torch.abs(torch.randn(batch_size, num_docs, model.vocab_size, device=device)) * 0.1

            # Compute student scores
            student_scores = []
            for i in range(batch_size):
                scores = model.compute_similarity(
                    conv_repr[i:i+1].expand(num_docs, -1),
                    doc_reprs[i]
                )
                student_scores.append(scores)
            student_scores = torch.stack(student_scores)

            # Compute teacher scores
            teacher_scores = []
            for i in range(batch_size):
                scores = model.compute_similarity(
                    teacher_repr[i:i+1].expand(num_docs, -1),
                    doc_reprs[i]
                )
                teacher_scores.append(scores)
            teacher_scores = torch.stack(teacher_scores)

            # DiSCo loss (KL divergence on similarity distributions)
            loss_kld = disco_loss(student_scores, teacher_scores, temperature)

            # FLOPS regularization for sparsity
            loss_reg = flops_regularization(conv_repr, doc_reprs.mean(dim=1), lambda_q, lambda_d)

            # Total loss
            loss = loss_kld + loss_reg

            # Check for NaN
            if torch.isnan(loss):
                print("⚠️ NaN loss detected, skipping batch")
                continue

            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            # Free memory
            del conv_repr, teacher_repr, doc_reprs, student_scores, teacher_scores, loss

        except RuntimeError as e:
            if "out of memory" in str(e):
                print("⚠️ OOM error, skipping batch and clearing cache...")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                continue
            else:
                raise e

    return total_loss / max(num_batches, 1)


# -----------------------------
# 8. Main Training Loop
# -----------------------------
def main():
    # Check if teacher data is available
    if not teacher_data:
        print("❌ No teacher data available. Exiting.")
        return

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🔧 Using device: {device}")

    # Initialize tokenizer and model (using smaller model for Colab)
    model_name = "distilbert-base-uncased"  # Smaller and faster than BERT
    print(f"📦 Loading tokenizer and model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Create student model
    print("🏗️ Building student model...")
    student_model = SPLADEStudent(model_name=model_name).to(device)

    # Limit dataset size for Colab (adjust based on your needs)
    max_train_samples = min(100, len(teacher_data))
    limited_data = teacher_data[:max_train_samples]
    print(f"📊 Training on {max_train_samples} samples")

    # Create dataset and dataloader
    dataset = ConversationalDataset(limited_data, tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=4,  # Small batch size for Colab free tier
        shuffle=True,
        num_workers=0  # Important for Colab
    )

    # Initialize document index
    doc_index = DocumentIndex(student_model.encoder, tokenizer, device)

    # Optimizer
    optimizer = torch.optim.AdamW(student_model.parameters(), lr=2e-5)

    # Training
    num_epochs = 3  # Reduced for faster training in Colab
    print(f"\n🚀 Starting training for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"{'='*50}")

        avg_loss = train_student(
            student_model,
            dataloader,
            doc_index,
            optimizer,
            device,
            num_negatives=3,  # Reduced for memory efficiency
            temperature=1.0
        )

        print(f"✅ Average Loss: {avg_loss:.4f}")

        # Clear cache after each epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

    # Save model
    output_dir = "/content/disco_student_model"
    print(f"\n💾 Saving model to {output_dir}...")
    student_model.encoder.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    torch.save(student_model.mlm_head.state_dict(), f"{output_dir}/mlm_head.pt")

    print(f"\n✅ Training complete! Model saved to {output_dir}")
    print(f"📊 Total parameters: {sum(p.numel() for p in student_model.parameters()):,}")


# -----------------------------
# 9. Run Training
# -----------------------------
if __name__ == "__main__":
    print("🎯 DiSCo Student Training - Colab Version")
    print("="*50)
    main()

✅ GPU Available: Tesla T4
✅ GPU Memory: 15.83 GB
✅ Loaded 100 teacher samples
🎯 DiSCo Student Training - Colab Version
🔧 Using device: cuda
📦 Loading tokenizer and model: distilbert-base-uncased


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

🏗️ Building student model...


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

📊 Training on 100 samples

🚀 Starting training for 3 epochs...

Epoch 1/3


Training: 100%|██████████| 25/25 [00:06<00:00,  3.98it/s]


✅ Average Loss: 5.4576

Epoch 2/3


Training: 100%|██████████| 25/25 [00:05<00:00,  4.37it/s]


✅ Average Loss: 2.9434

Epoch 3/3


Training: 100%|██████████| 25/25 [00:05<00:00,  4.38it/s]


✅ Average Loss: 2.6038

💾 Saving model to /content/disco_student_model...

✅ Training complete! Model saved to /content/disco_student_model
📊 Total parameters: 89,834,298
