In [3]:
print("Installing dependencies...")
!pip install -q transformers torch sentence-transformers datasets pandas scikit-learn langdetect

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader, random_split
import json
import pandas as pd
import numpy as np
from datetime import datetime
from google.colab import drive, files
import os
from tqdm.auto import tqdm
from typing import List, Dict, Tuple
import random
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)




Installing dependencies...
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m981.5/981.5 kB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for langdetect (setup.py) ... [?25l[?25hdone


In [4]:
print("\n" + "="*60)
print("DATASET LOADING OPTIONS")
print("="*60)
print("\nChoose how to load your dataset:")
print("1. Upload CSV file")
print("2. Upload JSON file")
print("3. Load from Google Drive")
print("4. Use sample data for testing")
print("5. Generate synthetic multilingual data")

dataset_choice = input("\nEnter choice (1-5): ").strip()

def load_csv_dataset(file_path):
    """Load dataset from CSV"""
    df = pd.read_csv(file_path)
    print(f"\n Loaded CSV with {len(df)} rows")
    print(f"   Columns: {list(df.columns)}")
    return df

def load_json_dataset(file_path):
   
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    if isinstance(data, list):
        df = pd.DataFrame(data)
    else:
        df = pd.DataFrame([data])
    
    print(f"\n Loaded JSON with {len(df)} rows")
    print(f"   Columns: {list(df.columns)}")
    return df

def generate_synthetic_data(num_samples=1000):
   
    print(f"\nüìä Generating {num_samples} synthetic samples...")

    templates = {
        'travel': {
            'en': ['cheap hotels in {}', 'flights to {}', 'tourist attractions in {}', 
                   'restaurants in {}', 'hotels near {}'],
            'es': ['hoteles baratos en {}', 'vuelos a {}', 'atracciones tur√≠sticas en {}',
                   'restaurantes en {}', 'hoteles cerca de {}'],
            'fr': ['h√¥tels bon march√© √† {}', 'vols vers {}', 'attractions touristiques √† {}',
                   'restaurants √† {}', 'h√¥tels pr√®s de {}'],
            'de': ['g√ºnstige Hotels in {}', 'Fl√ºge nach {}', 'Touristenattraktionen in {}',
                   'Restaurants in {}', 'Hotels in der N√§he von {}'],
            'it': ['hotel economici a {}', 'voli per {}', 'attrazioni turistiche a {}',
                   'ristoranti a {}', 'hotel vicino a {}'],
            'zh': ['{}ÁöÑ‰æøÂÆúÈÖíÂ∫ó', 'È£ûÂæÄ{}', '{}ÁöÑÊóÖÊ∏∏ÊôØÁÇπ', '{}ÁöÑÈ§êÂéÖ', '{}ÈôÑËøëÁöÑÈÖíÂ∫ó'],
        },
        'shopping': {
            'en': ['buy {} online', 'best {} deals', '{} for sale', 'cheap {}', 'discount {}'],
            'es': ['comprar {} en l√≠nea', 'mejores ofertas de {}', '{} en venta', '{} barato', '{} con descuento'],
            'fr': ['acheter {} en ligne', 'meilleures offres {}', '{} √† vendre', '{} pas cher', '{} en promotion'],
            'de': ['{} online kaufen', 'beste {} Angebote', '{} zu verkaufen', 'g√ºnstige {}', '{} im Angebot'],
            'it': ['acquista {} online', 'migliori offerte {}', '{} in vendita', '{} economico', '{} in sconto'],
            'zh': ['Âú®Á∫øË¥≠‰π∞{}', 'ÊúÄÂ•ΩÁöÑ{}‰ºòÊÉ†', '{}Âá∫ÂîÆ', '‰æøÂÆúÁöÑ{}', '{}ÊäòÊâ£'],
        },
        'information': {
            'en': ['how to {}', 'what is {}', 'where to find {}', 'when to {}', 'why {}'],
            'es': ['c√≥mo {}', 'qu√© es {}', 'd√≥nde encontrar {}', 'cu√°ndo {}', 'por qu√© {}'],
            'fr': ['comment {}', 'qu\'est-ce que {}', 'o√π trouver {}', 'quand {}', 'pourquoi {}'],
            'de': ['wie man {}', 'was ist {}', 'wo finde ich {}', 'wann {}', 'warum {}'],
            'it': ['come {}', 'cos\'√® {}', 'dove trovare {}', 'quando {}', 'perch√© {}'],
            'zh': ['Â¶Ç‰Ωï{}', '‰ªÄ‰πàÊòØ{}', 'Âú®Âì™ÈáåÊâæÂà∞{}', '‰ªÄ‰πàÊó∂ÂÄô{}', '‰∏∫‰ªÄ‰πà{}'],
        }
    }
    
    cities = ['Paris', 'London', 'Tokyo', 'New York', 'Rome', 'Berlin', 'Madrid', 
              'Barcelona', 'Amsterdam', 'Dubai', 'Singapore', 'Mumbai', 'Beijing']
    
    products = ['laptop', 'phone', 'camera', 'shoes', 'watch', 'book', 'tablet', 
                'headphones', 'backpack', 'jacket']
    
    topics = ['cook', 'learn', 'travel', 'exercise', 'meditate', 'study', 'work', 
              'relax', 'save money', 'start business']
    
    data = []
    languages = list(templates['travel'].keys())
    
    for _ in tqdm(range(num_samples), desc="Generating data"):
        domain = random.choice(list(templates.keys()))
        
      
        if domain == 'travel':
            entity = random.choice(cities)
        elif domain == 'shopping':
            entity = random.choice(products)
        else:
            entity = random.choice(topics)
        
      
        lang1 = random.choice(languages)
        lang2 = random.choice(languages)
        
        template1 = random.choice(templates[domain][lang1])
        template2 = random.choice(templates[domain][lang2])
        
        query = template1.format(entity)
        

        if random.random() < 0.7:
            doc = template2.format(entity)
            label = 1
        else:
    
            if random.random() < 0.5:
       
                other_entity = random.choice(cities if domain == 'travel' else 
                                            products if domain == 'shopping' else topics)
                while other_entity == entity:
                    other_entity = random.choice(cities if domain == 'travel' else 
                                                products if domain == 'shopping' else topics)
                doc = template2.format(other_entity)
            else:
              
                other_domain = random.choice([d for d in templates.keys() if d != domain])
                other_template = random.choice(templates[other_domain][lang2])
                other_entity = random.choice(cities if other_domain == 'travel' else 
                                            products if other_domain == 'shopping' else topics)
                doc = other_template.format(other_entity)
            label = 0
        
        data.append({
            'query': query,
            'doc': doc,
            'label': label,
            'query_lang': lang1,
            'doc_lang': lang2,
            'domain': domain
        })
    
    df = pd.DataFrame(data)
    print(f"\n Generated {len(df)} samples")
    print(f"   Positive pairs: {(df['label']==1).sum()} ({(df['label']==1).sum()/len(df)*100:.1f}%)")
    print(f"   Negative pairs: {(df['label']==0).sum()} ({(df['label']==0).sum()/len(df)*100:.1f}%)")
    print(f"   Languages: {df['query_lang'].unique().tolist()}")
    print(f"   Domains: {df['domain'].unique().tolist()}")
    
    return df


if dataset_choice == '1':
    print("\n Upload your CSV file")
    print("Expected format: columns [query, doc, label]")
    print("  - query: search query text")
    print("  - doc: document text")
    print("  - label: 1 (similar) or 0 (not similar)")
    uploaded = files.upload()
    file_path = list(uploaded.keys())[0]
    df = load_csv_dataset(file_path)
    
elif dataset_choice == '2':
    print("\n Upload your JSON file")
    print("Expected format: list of {query, doc, label}")
    uploaded = files.upload()
    file_path = list(uploaded.keys())[0]
    df = load_json_dataset(file_path)
    
elif dataset_choice == '3':
    file_path = input("\nEnter Google Drive path (e.g., /content/drive/MyDrive/data.csv): ").strip()
    if file_path.endswith('.csv'):
        df = load_csv_dataset(file_path)
    elif file_path.endswith('.json'):
        df = load_json_dataset(file_path)
    else:
        print("Unsupported file format. Using sample data.")
        df = generate_synthetic_data(100)
        
elif dataset_choice == '5':
    num_samples = int(input("\nHow many samples to generate? (default: 1000): ").strip() or "1000")
    df = generate_synthetic_data(num_samples)
    
else: 
    print("\n Using sample data")
    df = generate_synthetic_data(100)


print("\n" + "="*60)
print("DATASET OVERVIEW")
print("="*60)
print(f"\nTotal samples: {len(df)}")
print(f"\nFirst few rows:")
print(df.head())

print(f"\nDataset statistics:")
print(df.describe(include='all'))


required_cols = ['query', 'doc', 'label']
missing_cols = [col for col in required_cols if col not in df.columns]

if missing_cols:
    print(f"\n  Missing columns: {missing_cols}")
    print("Attempting to auto-detect columns...")
    
  
    for col in missing_cols:
        possible_names = [c for c in df.columns if col.lower() in c.lower()]
        if possible_names:
            df.rename(columns={possible_names[0]: col}, inplace=True)
            print(f"   Renamed '{possible_names[0]}' ‚Üí '{col}'")


if 'label' in df.columns:
    unique_labels = df['label'].unique()
    print(f"\nLabel distribution:")
    print(df['label'].value_counts())
    

    if set(unique_labels).issubset({0, 1}):
        print(" Labels are binary (0/1)")
    else:
        print(f"Found labels: {unique_labels}")
        print("Converting to binary (0/1)...")
        df['label'] = df['label'].apply(lambda x: 1 if x > 0 else 0)


print("\nüßπ Cleaning data...")
initial_len = len(df)
df = df.dropna(subset=['query', 'doc', 'label'])
df = df[df['query'].str.strip() != '']
df = df[df['doc'].str.strip() != '']
print(f"   Removed {initial_len - len(df)} invalid rows")
print(f"   Final dataset size: {len(df)}")


cleaned_path = '/content/cleaned_dataset.csv'
df.to_csv(cleaned_path, index=False)
print(f"\n Cleaned dataset saved to: {cleaned_path}")




DATASET LOADING OPTIONS

Choose how to load your dataset:
1. Upload CSV file
2. Upload JSON file
3. Load from Google Drive
4. Use sample data for testing
5. Generate synthetic multilingual data

 Using sample data

üìä Generating 100 synthetic samples...


Generating data:   0%|          | 0/100 [00:00<?, ?it/s]


 Generated 100 samples
   Positive pairs: 70 (70.0%)
   Negative pairs: 30 (30.0%)
   Languages: ['en', 'zh', 'fr', 'it', 'es', 'de']
   Domains: ['information', 'travel', 'shopping']

DATASET OVERVIEW

Total samples: 100

First few rows:
                 query                       doc  label query_lang doc_lang  \
0  where to find learn                  ‰ªÄ‰πàÊòØlearn      1         en       zh   
1             ‰∏∫‰ªÄ‰πàlearn                   Â¶Ç‰Ωïlearn      1         zh       zh   
2     flights to Paris    hoteles cerca de Paris      1         en       es   
3          ‰∏∫‰ªÄ‰πàexercise              ‰ªÄ‰πàÊó∂ÂÄôexercise      1         zh       zh   
4     when to meditate  d√≥nde encontrar meditate      1         en       es   

        domain  
0  information  
1  information  
2       travel  
3  information  
4  information  

Dataset statistics:
                                 query           doc       label query_lang  \
count                              100           100

In [5]:
print("\n" + "="*60)
print("DATA SPLITTING")
print("="*60)


train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

print(f"\nSplit ratios:")
print(f"  Train: {train_ratio*100}%")
print(f"  Validation: {val_ratio*100}%")
print(f"  Test: {test_ratio*100}%")


train_df, temp_df = train_test_split(
    df, 
    test_size=(1 - train_ratio),
    stratify=df['label'],
    random_state=42
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=test_ratio/(test_ratio + val_ratio),
    stratify=temp_df['label'],
    random_state=42
)

print(f"\nDataset sizes:")
print(f"  Train: {len(train_df)} samples")
print(f"  Validation: {len(val_df)} samples")
print(f"  Test: {len(test_df)} samples")


print(f"\nLabel distribution:")
print(f"  Train - Positive: {(train_df['label']==1).sum()}, Negative: {(train_df['label']==0).sum()}")
print(f"  Val   - Positive: {(val_df['label']==1).sum()}, Negative: {(val_df['label']==0).sum()}")
print(f"  Test  - Positive: {(test_df['label']==1).sum()}, Negative: {(test_df['label']==0).sum()}")


train_df.to_csv('/content/train.csv', index=False)
val_df.to_csv('/content/val.csv', index=False)
test_df.to_csv('/content/test.csv', index=False)

print("\n Data splits saved")


DATA SPLITTING

Split ratios:
  Train: 70.0%
  Validation: 15.0%
  Test: 15.0%

Dataset sizes:
  Train: 69 samples
  Validation: 15 samples
  Test: 16 samples

Label distribution:
  Train - Positive: 48, Negative: 21
  Val   - Positive: 11, Negative: 4
  Test  - Positive: 11, Negative: 5

 Data splits saved


In [6]:
class BiEncoderModel(nn.Module):
    
    def __init__(self, model_name='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', 
                 pooling='mean'):
        super().__init__()
        
        print(f"Loading model: {model_name}")
        self.encoder = AutoModel.from_pretrained(model_name)
        self.pooling = pooling
        self.config = self.encoder.config
        
        print(f"‚úÖ Model loaded")
        print(f"   Hidden size: {self.config.hidden_size}")
        print(f"   Pooling strategy: {pooling}")
        
    def mean_pooling(self, token_embeddings, attention_mask):
        
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask
    
    def cls_pooling(self, token_embeddings):
     
        return token_embeddings[:, 0, :]
    
    def max_pooling(self, token_embeddings, attention_mask):
  
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        token_embeddings[input_mask_expanded == 0] = -1e9
        return torch.max(token_embeddings, 1)[0]
    
    def forward(self, input_ids, attention_mask):
     
        
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        
       
        if self.pooling == 'mean':
            embeddings = self.mean_pooling(outputs.last_hidden_state, attention_mask)
        elif self.pooling == 'cls':
            embeddings = self.cls_pooling(outputs.last_hidden_state)
        elif self.pooling == 'max':
            embeddings = self.max_pooling(outputs.last_hidden_state, attention_mask)
        else:
            raise ValueError(f"Unknown pooling: {self.pooling}")
        
      
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        
        return embeddings

print("\n Model class defined")



 Model class defined


In [7]:
class MultilingualSearchDataset(Dataset):
   
    
    def __init__(self, dataframe, tokenizer, max_length=128):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        
      
        query_encoding = self.tokenizer(
            str(row['query']),
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
       
        doc_encoding = self.tokenizer(
            str(row['doc']),
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'query_input_ids': query_encoding['input_ids'].squeeze(0),
            'query_attention_mask': query_encoding['attention_mask'].squeeze(0),
            'doc_input_ids': doc_encoding['input_ids'].squeeze(0),
            'doc_attention_mask': doc_encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(row['label'], dtype=torch.float32)
        }

In [25]:
print("\n" + "="*60)
print("TRAINING CONFIGURATION")
print("="*60)


print("\nAvailable models:")
print("1. paraphrase-multilingual-MiniLM-L12-v2 (Fast, 118M params)")
print("2. xlm-roberta-base (Balanced, 279M params)")
print("3. distilbert-base-multilingual-cased (Faster, 135M params)")

model_choice = input("\nChoose model (1-3, default: 1): ").strip() or "1"

model_names = {
    '1': 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    '2': 'xlm-roberta-base',
    '3': 'distilbert-base-multilingual-cased'
}

CONFIG = {
    'model_name': model_names.get(model_choice, model_names['1']),
    'pooling': 'mean',
    'max_length': 128,
    'batch_size': 16,  
    'learning_rate': 2e-5,
    'num_epochs': 5,
    'warmup_steps': 5,
    'weight_decay': 0.01,
    'max_grad_norm': 1.0,
    'log_interval': 10,
    'save_steps': 500,
}

print("\nCurrent configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

customize = input("\nCustomize configuration? (y/n, default: n): ").strip().lower()

if customize == 'y':
    CONFIG['batch_size'] = int(input(f"Batch size (current: {CONFIG['batch_size']}): ") or CONFIG['batch_size'])
    CONFIG['learning_rate'] = float(input(f"Learning rate (current: {CONFIG['learning_rate']}): ") or CONFIG['learning_rate'])
    CONFIG['num_epochs'] = int(input(f"Number of epochs (current: {CONFIG['num_epochs']}): ") or CONFIG['num_epochs'])

print("\nConfiguration set")
print("\nFinal configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")



TRAINING CONFIGURATION

Available models:
1. paraphrase-multilingual-MiniLM-L12-v2 (Fast, 118M params)
2. xlm-roberta-base (Balanced, 279M params)
3. distilbert-base-multilingual-cased (Faster, 135M params)

Current configuration:
  model_name: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
  pooling: mean
  max_length: 128
  batch_size: 16
  learning_rate: 2e-05
  num_epochs: 5
  warmup_steps: 5
  weight_decay: 0.01
  max_grad_norm: 1.0
  log_interval: 10
  save_steps: 500

Configuration set

Final configuration:
  model_name: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
  pooling: mean
  max_length: 128
  batch_size: 16
  learning_rate: 2e-05
  num_epochs: 5
  warmup_steps: 5
  weight_decay: 0.01
  max_grad_norm: 1.0
  log_interval: 10
  save_steps: 500


In [26]:
# Initialize Training
print("\n" + "="*60)
print("INITIALIZING TRAINING")
print("="*60)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n Device: {device}")

print(f"\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_name'])
print(" Tokenizer loaded")

print("\n Creating datasets...")
train_dataset = MultilingualSearchDataset(train_df, tokenizer, CONFIG['max_length'])
val_dataset = MultilingualSearchDataset(val_df, tokenizer, CONFIG['max_length'])
test_dataset = MultilingualSearchDataset(test_df, tokenizer, CONFIG['max_length'])

print(f" Datasets created")
print(f"   Train: {len(train_dataset)} samples")
print(f"   Val: {len(val_dataset)} samples")
print(f"   Test: {len(test_dataset)} samples")

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f" DataLoaders created")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")
print(f"   Test batches: {len(test_loader)}")

print(f"\n Initializing model...")
model = BiEncoderModel(CONFIG['model_name'], CONFIG['pooling'])
model = model.to(device)

num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f" Model initialized")
print(f"   Total parameters: {num_params:,}")
print(f"   Trainable parameters: {num_trainable:,}")

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay']
)

num_training_steps = len(train_loader) * CONFIG['num_epochs']
num_warmup_steps = CONFIG['warmup_steps']

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=CONFIG['learning_rate'],
    total_steps=num_training_steps,
    pct_start=num_warmup_steps/num_training_steps,
    anneal_strategy='cos'
)

print(f" Optimizer: AdamW")
print(f" Scheduler: OneCycleLR")
print(f"   Total steps: {num_training_steps}")
print(f"   Warmup steps: {num_warmup_steps}")

criterion = nn.CosineEmbeddingLoss()
print(f" Loss: CosineEmbeddingLoss")


INITIALIZING TRAINING

 Device: cpu

Loading tokenizer...
 Tokenizer loaded

 Creating datasets...
 Datasets created
   Train: 69 samples
   Val: 15 samples
   Test: 16 samples
 DataLoaders created
   Train batches: 5
   Val batches: 1
   Test batches: 1

 Initializing model...
Loading model: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
‚úÖ Model loaded
   Hidden size: 384
   Pooling strategy: mean
 Model initialized
   Total parameters: 117,653,760
   Trainable parameters: 117,653,760
 Optimizer: AdamW
 Scheduler: OneCycleLR
   Total steps: 25
   Warmup steps: 5
 Loss: CosineEmbeddingLoss


In [14]:
def train_epoch(model, train_loader, optimizer, scheduler, criterion, device, epoch, config):
  
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    
    for batch_idx, batch in enumerate(pbar):
        # Move to device
        query_input_ids = batch['query_input_ids'].to(device)
        query_attention_mask = batch['query_attention_mask'].to(device)
        doc_input_ids = batch['doc_input_ids'].to(device)
        doc_attention_mask = batch['doc_attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        # Forward pass
        query_emb = model(query_input_ids, query_attention_mask)
        doc_emb = model(doc_input_ids, doc_attention_mask)
        
        # Compute loss
        targets = (labels * 2) - 1  # Convert 0/1 to -1/1
        loss = criterion(query_emb, doc_emb, targets)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
        
        optimizer.step()
        scheduler.step()
        
        # Track metrics
        total_loss += loss.item()
        
        # Calculate accuracy
        similarities = torch.nn.functional.cosine_similarity(query_emb, doc_emb)
        predictions = (similarities > 0.5).float()
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
        
        # Update progress bar
        current_lr = scheduler.get_last_lr()[0]
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100 * correct / total:.2f}%',
            'lr': f'{current_lr:.2e}'
        })
    
    avg_loss = total_loss / len(train_loader)
    accuracy = 100 * correct / total
    
    return avg_loss, accuracy


def evaluate(model, data_loader, criterion, device):
  
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    all_similarities = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            query_input_ids = batch['query_input_ids'].to(device)
            query_attention_mask = batch['query_attention_mask'].to(device)
            doc_input_ids = batch['doc_input_ids'].to(device)
            doc_attention_mask = batch['doc_attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            # Forward pass
            query_emb = model(query_input_ids, query_attention_mask)
            doc_emb = model(doc_input_ids, doc_attention_mask)
            
            # Compute loss
            targets = (labels * 2) - 1
            loss = criterion(query_emb, doc_emb, targets)
            
            total_loss += loss.item()
            
            # Calculate accuracy
            similarities = torch.nn.functional.cosine_similarity(query_emb, doc_emb)
            predictions = (similarities > 0.5).float()
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
            
            # Store for analysis
            all_similarities.extend(similarities.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    avg_loss = total_loss / len(data_loader)
    accuracy = 100 * correct / total
    
    # Calculate additional metrics
    similarities_np = np.array(all_similarities)
    labels_np = np.array(all_labels)
    
    # True positives, false positives, etc.
    tp = np.sum((similarities_np > 0.5) & (labels_np == 1))
    fp = np.sum((similarities_np > 0.5) & (labels_np == 0))
    tn = np.sum((similarities_np <= 0.5) & (labels_np == 0))
    fn = np.sum((similarities_np <= 0.5) & (labels_np == 1))
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'precision': precision * 100,
        'recall': recall * 100,
        'f1': f1 * 100
    }


In [27]:
print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_precision': [],
    'val_recall': [],
    'val_f1': [],
    'learning_rates': []
}

best_val_loss = float('inf')
best_model_state = None


for epoch in range(1, CONFIG['num_epochs'] + 1):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch}/{CONFIG['num_epochs']}")
    print('='*60)
# Train
train_loss, train_acc = train_epoch(
    model, train_loader, optimizer, scheduler, criterion, 
    device, epoch, CONFIG
)


print("\nValidating...")
val_metrics = evaluate(model, val_loader, criterion, device)


history['train_loss'].append(train_loss)
history['train_acc'].append(train_acc)
history['val_loss'].append(val_metrics['loss'])
history['val_acc'].append(val_metrics['accuracy'])
history['val_precision'].append(val_metrics['precision'])
history['val_recall'].append(val_metrics['recall'])
history['val_f1'].append(val_metrics['f1'])
history['learning_rates'].append(optimizer.param_groups[0]['lr'])

print(f"\nüìä Epoch {epoch} Results:")
print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
print(f"  Val Loss: {val_metrics['loss']:.4f} | Val Acc: {val_metrics['accuracy']:.2f}%")
print(f"  Val Precision: {val_metrics['precision']:.2f}%")
print(f"  Val Recall: {val_metrics['recall']:.2f}%")
print(f"  Val F1: {val_metrics['f1']:.2f}%")


if val_metrics['loss'] < best_val_loss:
    best_val_loss = val_metrics['loss']
    best_model_state = model.state_dict().copy()
    print(f" New best model! (Val Loss: {best_val_loss:.4f})")


STARTING TRAINING

Epoch 1/5

Epoch 2/5

Epoch 3/5

Epoch 4/5

Epoch 5/5


Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]


Validating...


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]


üìä Epoch 5 Results:
  Train Loss: 0.2554 | Train Acc: 94.20%
  Val Loss: 0.2168 | Val Acc: 86.67%
  Val Precision: 90.91%
  Val Recall: 90.91%
  Val F1: 90.91%
 New best model! (Val Loss: 0.2168)
