In [None]:
import os
import sys
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import numpy as np
from pathlib import Path
import json
from sklearn.metrics import f1_score
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F

In [None]:
class ProjectionHead(nn.Module):
    """Two-layer MLP that projects embeddings onto unit-length feature vectors.

    Pipeline: Linear -> ReLU -> Dropout -> Linear, then L2 normalisation over
    the last (feature) dimension.

    Args:
        input_dim: Size of the incoming embedding.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the projected feature vector.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, input_dim=384, hidden_dim=512, output_dim=256, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Xavier-uniform weights and zero biases for both linear layers
        # (fc1 first, then fc2, so RNG consumption order is unchanged).
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Project ``x`` and L2-normalise the result.

        Args:
            x: Tensor whose last dimension has size ``input_dim``; either a
                batch ``(N, input_dim)`` or a single vector ``(input_dim,)``.

        Returns:
            Tensor with the same leading shape as ``x`` and a last dimension of
            size ``output_dim``, with unit L2 norm along that dimension.
        """
        hidden = self.dropout(self.relu(self.fc1(x)))
        projected = self.fc2(hidden)
        # Normalise over the feature axis; dim=-1 matches dim=1 for batched
        # input and additionally works for an unbatched 1-D vector.
        return F.normalize(projected, p=2, dim=-1)

In [None]:
class SiameseModel(nn.Module):
    """Siamese paraphrase scorer: one shared sentence encoder plus one shared projection head.

    Both texts of a pair go through the same ``all-MiniLM-L6-v2`` encoder and
    the same ``ProjectionHead``; the score is the cosine similarity of the two
    projected vectors.

    Args:
        projection_dim: Size of the projected feature vectors.
    """

    def __init__(self, projection_dim=256):
        super().__init__()
        self.sbert = SentenceTransformer('all-MiniLM-L6-v2')
        self.projection_head = ProjectionHead(output_dim=projection_dim)
        self.projection_dim = projection_dim

        # Mark every encoder weight as trainable so the encoder is fine-tuned
        # together with the projection head.
        for param in self.sbert.parameters():
            param.requires_grad = True

    def encode(self, texts):
        """Embed ``texts`` with the sentence encoder.

        In training mode (with autograd enabled) the texts are tokenised and
        passed through the encoder's ``forward`` so gradients reach the encoder
        weights. ``SentenceTransformer.encode`` is an inference helper that runs
        without gradient tracking, so using it while training would leave the
        encoder frozen despite ``requires_grad=True``. Outside training the
        original ``encode`` call is kept.

        Args:
            texts: A string or a sequence of strings.

        Returns:
            Embedding tensor on the projection head's device: one row per text,
            or a 1-D vector when a single string was given.
        """
        head_device = next(self.projection_head.parameters()).device

        if self.training and torch.is_grad_enabled():
            is_single = isinstance(texts, str)
            batch = [texts] if is_single else list(texts)
            features = self.sbert.tokenize(batch)
            # The tokenizer returns CPU tensors; move them next to the weights.
            features = {
                key: value.to(head_device) if torch.is_tensor(value) else value
                for key, value in features.items()
            }
            embeddings = self.sbert(features)['sentence_embedding']
            return embeddings[0] if is_single else embeddings

        embeddings = self.sbert.encode(texts, convert_to_tensor=True, show_progress_bar=False)
        return embeddings.to(head_device)

    def forward(self, text_a, text_b):
        """Score text pairs.

        Args:
            text_a: First texts of the pairs.
            text_b: Second texts of the pairs, aligned with ``text_a``.

        Returns:
            Tuple ``(similarity, fv_a, fv_b)``: the cosine similarity of each
            pair and the two projected feature tensors.
        """
        fv_a = self.projection_head(self.encode(text_a))
        fv_b = self.projection_head(self.encode(text_b))
        # Compare along the feature axis; identical to dim=1 for batched input.
        similarity = F.cosine_similarity(fv_a, fv_b, dim=-1)
        return similarity, fv_a, fv_b

In [None]:
class ContrastiveLoss(nn.Module):
    """Margin-based loss on pair similarities.

    Pairs labelled 1 are penalised by ``1 - similarity``; pairs labelled 0 are
    penalised only by the amount their similarity exceeds ``margin``.

    Args:
        margin: Similarity below which a label-0 pair incurs no loss.
    """

    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, similarities, labels):
        """Return the mean loss over all pairs.

        Args:
            similarities: Tensor of pair similarities.
            labels: Tensor of pair labels (cast to float), aligned with
                ``similarities``.
        """
        is_positive = labels.float()
        is_negative = 1 - is_positive
        # Pull term: label-1 pairs should reach similarity 1.
        pull_term = is_positive * (1 - similarities)
        # Push term: label-0 pairs are only penalised above the margin.
        push_term = is_negative * torch.clamp(similarities - self.margin, min=0.0)
        return (pull_term + push_term).mean()

In [None]:
class ParaphraseDataset(Dataset):
    """Map-style dataset that serves pre-built example records unchanged."""

    def __init__(self, examples):
        # Keep a reference to the caller's sequence; nothing is copied or transformed.
        self.examples = examples

    def __getitem__(self, idx):
        """Return the record stored at position ``idx``."""
        return self.examples[idx]

    def __len__(self):
        """Return the number of stored records."""
        return len(self.examples)

In [None]:
# Run configuration: data location, optimisation hyper-parameters and compute device.
data_path = os.path.join('data', 'quora_siamese_train.csv')
epochs = 20
batch_size = 32
learning_rate = 5e-5
projection_dim = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs up front so weight initialisation and DataLoader shuffling
# repeat from run to run (Restart Kernel -> Run All gives the same start).
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Read the full table of labelled text pairs from disk.
print(f"Loading dataset from {data_path}")
df = pd.read_csv(data_path)

# Contiguous 80% / 10% / remainder split in file order (no shuffling);
# the test split absorbs whatever the two truncated sizes leave over.
train_size = int(0.8 * len(df))
val_size = int(0.1 * len(df))
test_size = len(df) - train_size - val_size

train_df = df[:train_size]
val_df = df[train_size:train_size + val_size]
test_df = df[train_size + val_size:]

split_counts = (len(train_df), len(val_df), len(test_df))
print("Dataset split: {} train, {} validation, {} test".format(*split_counts))

In [None]:
def _frame_to_examples(frame):
    """Turn one split DataFrame into the list of example dicts the dataset serves.

    Reads the ``text_1``, ``text_2`` and ``label`` columns and emits one
    ``{'text_a', 'text_b', 'label'}`` dict per row, with the label cast to
    ``int``. Columns are zipped directly rather than walked with
    ``iterrows()``, which builds a Series per row.
    """
    return [
        {'text_a': text_a, 'text_b': text_b, 'label': int(label)}
        for text_a, text_b, label in zip(frame['text_1'], frame['text_2'], frame['label'])
    ]


train_examples = _frame_to_examples(train_df)
val_examples = _frame_to_examples(val_df)
test_examples = _frame_to_examples(test_df)

train_dataset = ParaphraseDataset(train_examples)
val_dataset = ParaphraseDataset(val_examples)
test_dataset = ParaphraseDataset(test_examples)

# Only the training loader shuffles; evaluation loaders keep file order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Build the Siamese network and place it on the selected device.
model = SiameseModel(projection_dim=projection_dim).to(device)

# Tally parameter counts in one pass: every parameter, and the subset with
# requires_grad set.
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"Total: {total_params:,}")
print(f"Trainable: {trainable_params:,}")

In [None]:
# Loss, optimiser and mixed-precision gradient scaler used by the training loop.
criterion = ContrastiveLoss(margin=0.5)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Loss scaling is only meaningful on CUDA. Disabling it explicitly elsewhere
# turns scale()/step()/update() into pass-throughs and avoids the warning a
# default GradScaler emits when no GPU is available.
scaler = GradScaler(enabled=(device.type == 'cuda'))

In [None]:
def find_optimal_threshold(similarities, labels):
    """Grid-search the similarity cut-off that maximises F1.

    Candidate thresholds run from 0.3 up to (but excluding) 0.95 in steps of
    0.05. A pair is predicted positive when its similarity is >= the threshold.

    Args:
        similarities: Tensor of pair similarities (may still be attached to the
            autograd graph and may live on any device).
        labels: Tensor of binary pair labels aligned with ``similarities``.

    Returns:
        Tuple ``(best_thresh, best_f1)``. If no candidate achieves an F1 above
        0, the defaults ``(0.5, 0)`` are returned.
    """
    # detach() first: Tensor.numpy() raises on tensors that require grad.
    sims = similarities.detach().cpu().numpy()
    labs = labels.detach().cpu().numpy()
    best_f1 = 0
    best_thresh = 0.5
    for thresh in np.arange(0.3, 0.95, 0.05):
        preds = (sims >= thresh).astype(int)
        f1 = f1_score(labs, preds)
        # Strict comparison keeps the lowest threshold among ties.
        if f1 > best_f1:
            best_f1 = f1
            best_thresh = thresh
    return best_thresh, best_f1

In [None]:
def _evaluate(model, loader, criterion, device, threshold=0.5):
    """Run one gradient-free pass over ``loader`` in eval mode.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the mean of the per-batch losses, and
        the fraction of pairs whose thresholded similarity equals the label.
    """
    model.eval()
    loss_sum = 0
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch in loader:
            labels = batch['label'].to(device)
            similarities, _, _ = model(batch['text_a'], batch['text_b'])
            loss_sum += criterion(similarities, labels).item()
            predictions = (similarities >= threshold).long()
            correct += (predictions == labels).sum().item()
            seen += labels.size(0)
    return loss_sum / len(loader), correct / seen


history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'test_acc': []}
best_val_loss = float('inf')
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Autocast is only useful on CUDA; disabling it explicitly elsewhere avoids the
# warning torch.cuda.amp.autocast emits when no GPU is present.
use_amp = device.type == 'cuda'

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")

    # ---- training pass ----
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    pbar = tqdm(train_loader, total=len(train_loader))
    for batch in pbar:
        text_a = batch['text_a']
        text_b = batch['text_b']
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            similarities, _, _ = model(text_a, text_b)
            loss = criterion(similarities, labels)

        # Scaled backward + step; update() adjusts the loss scale for the next batch.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        # Training accuracy uses a fixed 0.5 similarity cut-off.
        predictions = (similarities >= 0.5).long()
        train_correct += (predictions == labels).sum().item()
        train_total += labels.size(0)

        pbar.set_postfix({'loss': loss.item()})

    avg_train_loss = train_loss / len(train_loader)
    train_acc = train_correct / train_total

    # ---- evaluation passes (shared helper for validation and test) ----
    avg_val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
    # Test accuracy is only logged; model selection below uses validation loss alone.
    _, test_acc = _evaluate(model, test_loader, criterion, device)

    history['train_loss'].append(avg_train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(avg_val_loss)
    history['val_acc'].append(val_acc)
    history['test_acc'].append(test_acc)

    print(f"Training - Loss: {avg_train_loss:.4f}, Accuracy: {train_acc:.4f}")
    print(f"Validation - Loss: {avg_val_loss:.4f}, Accuracy: {val_acc:.4f}")
    print(f"Test - Accuracy: {test_acc:.4f}")

    # Checkpoint whenever the validation loss improves on the best seen so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best_model.pt'))
        print(f"New best model saved! (val_loss: {avg_val_loss:.4f})")

In [None]:
# Persist the last-epoch weights and the per-epoch metric history.
final_model_path = os.path.join(checkpoint_dir, 'final_model.pt')
history_path = os.path.join(checkpoint_dir, 'training_history.json')

torch.save(model.state_dict(), final_model_path)
with open(history_path, 'w') as history_file:
    json.dump(history, history_file)

print(f"\nFinal model saved to {checkpoint_dir}/final_model.pt")
print(f"Training history saved to {checkpoint_dir}/training_history.json")