In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import csv
import os
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import random_split
from sklearn.model_selection import GroupShuffleSplit
from scipy.stats import spearmanr
from tqdm import tqdm  # make sure you have tqdm installed
from datetime import datetime
import json
import numpy as np

# Fix torch's global RNG seed so repeated runs draw the same random numbers.
torch.manual_seed(42)

# Prefix prepended to the corpus file paths; empty means relative to the cwd.
base_path = ""

# Restrict this process to physical GPU 2 before any CUDA query is made.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# -----------------------------
# 1. Dataset Class
# -----------------------------

class TwitterURLDataset(Dataset):
    """Anchor-grouped view of a tab-separated paraphrase corpus file.

    Each dataset item is one *group*: a run of consecutive rows sharing the
    same first sentence (the anchor), with every candidate sentence and its
    integer rating. Groups are padded to the size of the largest group so
    that items can be stacked by a collate function.

    Args:
        file_path: Path to a UTF-8 TSV file with four columns per row:
            sentence1, sentence2, a rating such as ``(3, 6)``, and a URL.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors="pt",
            return_token_type_ids=False)``; it must return a mapping of
            tensors that have a leading dimension of size 1.
        max_length: Token length every sentence is padded/truncated to.

    Raises:
        ValueError: If the file contains no usable rows.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        self.data = self.read_file(file_path)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.groups = self.group_data()
        if not self.groups:
            # Fail with a clear message instead of the opaque
            # "max() arg is an empty sequence" raised further down.
            raise ValueError(f"No usable rows found in {file_path!r}")
        # One id per group, used by group-aware splitters.
        self.group_ids = list(range(len(self.groups)))

        # Every item is padded up to the largest group in the file.
        group_sizes = [len(group) for group in self.groups]
        self.max_group_size = max(group_sizes)
        print(f"Max group size: {self.max_group_size}")
        print(f"unique group sizes = {set(group_sizes)}")

    def read_file(self, file_path):
        """Parse the TSV file into ``(sentence1, sentence2, score)`` tuples.

        Rows that do not have exactly four columns, or whose rating does not
        start with an integer, are skipped. The URL column is discarded.
        """
        data = []
        # newline='' is what the csv module expects from its file object.
        with open(file_path, 'r', encoding='utf-8', newline='') as f:
            # QUOTE_NONE: the file is raw TSV, so a sentence that starts with
            # '"' must not be treated as a quoted field (which would swallow
            # the following tabs and lines and silently drop rows).
            reader = csv.reader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
            for row in reader:
                if len(row) != 4:
                    continue
                sentence1, sentence2, score_str, _url = row
                try:
                    # "(3, 6)" -> 3: only the first number is kept.
                    score = int(score_str.strip('()').split(',')[0])
                except ValueError:
                    continue
                data.append((sentence1.strip(), sentence2.strip(), score))
        return data

    def group_data(self):
        """Split ``self.data`` into runs of consecutive rows with equal sentence1."""
        groups = []
        current_group = []
        last_sentence1 = None
        for item in self.data:
            sentence1 = item[0]
            if sentence1 != last_sentence1 and current_group:
                groups.append(current_group)
                current_group = []
            current_group.append(item)
            last_sentence1 = sentence1
        if current_group:
            groups.append(current_group)
        return groups

    def __len__(self):
        """Return the number of anchor groups."""
        return len(self.groups)

    def _encode(self, text):
        """Tokenize one string and drop the leading size-1 batch dimension."""
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            return_token_type_ids=False,
        )
        return {k: v.squeeze(0) for k, v in encoded.items()}

    def __getitem__(self, idx):
        """Return the tokenized anchor, its padded candidates, scores and mask.

        Returns:
            dict with keys ``anchor_input`` (tokenized anchor),
            ``sentence_inputs`` (list of ``max_group_size`` tokenized
            candidates, padded with empty strings), ``scores`` (float32
            tensor, padding positions are 0) and ``mask`` (1 for real
            candidates, 0 for padding).
        """
        group = self.groups[idx]
        n_pad = self.max_group_size - len(group)

        sentences = [item[1] for item in group] + [""] * n_pad
        scores = torch.tensor(
            [item[2] for item in group] + [0] * n_pad, dtype=torch.float32
        )

        return {
            'anchor_input': self._encode(group[0][0]),
            'sentence_inputs': [self._encode(s) for s in sentences],
            'scores': scores,
            # Distinguishes real candidates from the padding added above.
            'mask': torch.tensor([1] * len(group) + [0] * n_pad),
        }


# -----------------------------
# 2. Model
# -----------------------------

class SentenceEncoder(nn.Module):
    """Pretrained transformer that maps each input sequence to one unit-length vector."""

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Return the L2-normalized hidden state of the first token of each sequence."""
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Position 0 along the sequence axis is used as the sentence vector.
        first_token = hidden_states[:, 0]
        # Unit length makes dot products equal to cosine similarity.
        return F.normalize(first_token, p=2, dim=1)

# -----------------------------
# 3. Margin Ranking Loss
# -----------------------------

# Module-level margin ranking criterion with margin 0.2.
# NOTE(review): the training functions visible in this file call
# F.margin_ranking_loss directly, so this object appears to be unused here.
margin_ranking_loss = nn.MarginRankingLoss(margin=0.2)

# -----------------------------
# 4. Example Training Step
# -----------------------------

def train_one_epoch(train_loader, model, optimizer, device):
    """Run one optimization pass over ``train_loader`` with a pairwise margin loss.

    For every anchor, each ordered pair of real candidates (i, j) whose gold
    score satisfies score_i > score_j contributes
    ``max(0, 0.2 - (sim_i - sim_j))`` to the loss, where ``sim`` is the cosine
    similarity to the anchor. Batches without any such pair are skipped.

    Returns:
        The summed pair loss divided by the number of pairs seen in the
        epoch, or 0 when no pair was seen.
    """
    model.train()
    loss_sum = 0
    pair_count = 0

    # tqdm only adds a transient progress bar around the loader.
    for batch in tqdm(train_loader, desc="Training", leave=False):
        optimizer.zero_grad()

        anchor_embed = model(
            **{name: t.to(device) for name, t in batch['anchor_input'].items()}
        )
        sentence_embeds = model(
            **{name: t.to(device) for name, t in batch['sentence_inputs'].items()}
        )

        # Candidates arrive flattened; restore the (batch, group, dim) layout.
        batch_size, max_group_size = batch['scores'].shape
        sentence_embeds = sentence_embeds.view(batch_size, max_group_size, -1)

        scores = batch['scores'].to(device)
        mask_int = batch['mask'].to(device)
        mask = mask_int.bool()

        anchor_tiled = anchor_embed.unsqueeze(1).expand_as(sentence_embeds)
        sims = F.cosine_similarity(anchor_tiled, sentence_embeds, dim=-1) * mask_int

        # pair_mask[b, i, j] is True when both candidates are real and
        # candidate i is rated strictly higher than candidate j.
        higher = scores.unsqueeze(2) > scores.unsqueeze(1)
        pair_mask = higher & mask.unsqueeze(2) & mask.unsqueeze(1)
        if not pair_mask.any():
            continue

        better_sims = sims.unsqueeze(2).expand(-1, -1, max_group_size)[pair_mask]
        worse_sims = sims.unsqueeze(1).expand(-1, max_group_size, -1)[pair_mask]

        loss = F.margin_ranking_loss(
            better_sims,
            worse_sims,
            torch.ones_like(better_sims),
            margin=0.2,
            reduction='sum',
        )
        n_pairs = pair_mask.sum()
        pair_count += n_pairs.item()
        loss_sum += loss.item()

        # Back-propagate the per-pair mean so the step size does not depend
        # on how many pairs the batch happened to contain.
        (loss / n_pairs).backward()
        optimizer.step()

    return loss_sum / pair_count if pair_count > 0 else 0


def validate2(loader, model, device):
    """Return the mean pairwise margin-ranking loss over ``loader``.

    This is a read-only evaluation: the model is put in eval mode, no
    gradients are computed and no optimizer step is taken. The loss uses
    margin 0.2 over every ordered pair of real candidates (i, j) of the same
    anchor with score_i > score_j.

    Args:
        loader: Iterable of batches with the keys ``anchor_input``,
            ``sentence_inputs``, ``scores`` and ``mask``.
        model: Module called with the tokenized inputs as keyword arguments.
        device: Device the batch tensors are moved to.

    Returns:
        Summed pair loss divided by the number of pairs, or 0 if the loader
        produced no rankable pair.
    """
    model.eval()
    total_loss = 0.0
    total_pairs = 0

    # Wrap the loader with tqdm for a progress bar
    prog_bar = tqdm(loader, desc="Testing", leave=False)

    with torch.no_grad():
        for batch in prog_bar:
            anchor_input = {k: v.to(device) for k, v in batch['anchor_input'].items()}
            anchor_embed = model(**anchor_input)

            sentence_inputs = {
                k: v.to(device) for k, v in batch['sentence_inputs'].items()
            }
            sentence_embeds = model(**sentence_inputs)

            # Candidates arrive flattened; restore (batch, group, dim).
            batch_size, max_group_size = batch['scores'].shape
            sentence_embeds = sentence_embeds.view(batch_size, max_group_size, -1)

            scores = batch['scores'].to(device)
            mask = batch['mask'].to(device)
            valid = mask.bool()

            sims = F.cosine_similarity(
                anchor_embed.unsqueeze(1).expand_as(sentence_embeds),
                sentence_embeds,
                dim=-1
            ) * mask

            # pos_mask[b, i, j]: both candidates are real and i outranks j.
            pos_mask = (
                (scores.unsqueeze(2) > scores.unsqueeze(1))
                & valid.unsqueeze(2)
                & valid.unsqueeze(1)
            )
            if not pos_mask.any():
                continue

            pos_sims = sims.unsqueeze(2).expand(-1, -1, max_group_size)[pos_mask]
            neg_sims = sims.unsqueeze(1).expand(-1, max_group_size, -1)[pos_mask]

            loss = F.margin_ranking_loss(
                pos_sims,
                neg_sims,
                torch.ones_like(pos_sims),
                margin=0.2,
                reduction='sum'
            )
            total_pairs += pos_mask.sum().item()
            total_loss += loss.item()

    return total_loss / total_pairs if total_pairs > 0 else 0


def validate(loader, model, device):
    """Return the mean per-anchor Spearman correlation over ``loader``.

    For every anchor, the gold scores of its real (non-padded) candidates are
    correlated with the cosine similarity between the anchor embedding and
    each candidate embedding. An anchor is skipped when it has fewer than two
    real candidates, when all its gold scores are equal, or when the
    correlation is undefined (NaN, e.g. all predicted similarities equal).

    Args:
        loader: Iterable of batches with the keys ``anchor_input``,
            ``sentence_inputs``, ``scores`` and ``mask``.
        model: Module called with the tokenized inputs as keyword arguments.
        device: Device the batch tensors are moved to.

    Returns:
        The mean correlation as a float, or 0.0 when no anchor was rankable.
    """
    model.eval()
    anchor_spearmans = []

    with torch.no_grad():
        for batch in loader:
            anchor_input = {k: v.to(device) for k, v in batch['anchor_input'].items()}
            anchor_embed = model(**anchor_input)

            sentence_inputs = {k: v.to(device) for k, v in batch['sentence_inputs'].items()}
            sentence_embeds = model(**sentence_inputs)

            # Candidates arrive flattened; restore (batch, group, dim).
            batch_size, max_group_size = batch['scores'].shape
            sentence_embeds = sentence_embeds.view(batch_size, max_group_size, -1)

            sims = F.cosine_similarity(
                anchor_embed.unsqueeze(1).expand_as(sentence_embeds),
                sentence_embeds,
                dim=-1
            )

            # Move everything to the host once per batch, not once per anchor.
            sims_np = sims.cpu().numpy()
            scores_np = batch['scores'].cpu().numpy()
            valid_np = batch['mask'].cpu().numpy() == 1

            for row_scores, row_sims, row_valid in zip(scores_np, sims_np, valid_np):
                gold = row_scores[row_valid]
                pred = row_sims[row_valid]

                # Need at least two candidates with distinct gold scores.
                if gold.size < 2 or np.unique(gold).size < 2:
                    continue

                rho = spearmanr(gold, pred).correlation
                # A single NaN would turn the whole mean into NaN, and NaN
                # never compares greater than the best score so far.
                if np.isnan(rho):
                    continue
                anchor_spearmans.append(rho)

    return float(np.mean(anchor_spearmans)) if anchor_spearmans else 0.0
    

def collate_fn(batch):
    """Stack dataset items into one batch, trimmed to the largest real group.

    Every item carries candidates padded to the dataset-wide maximum; only the
    first ``keep`` positions are retained, where ``keep`` is the largest number
    of real (mask == 1) candidates among the items of this batch.

    Returns:
        dict with ``anchor_input`` (tensors stacked per key, one row per
        item), ``sentence_inputs`` (tensors stacked per key, candidates of
        all items flattened in item order), ``scores`` and ``mask`` (one row
        per item, ``keep`` columns).
    """
    keep = max(int((sample['mask'] == 1).sum()) for sample in batch)
    first = batch[0]

    anchor_inputs = {}
    for key in first['anchor_input']:
        anchor_inputs[key] = torch.stack([sample['anchor_input'][key] for sample in batch])

    sentence_inputs = {}
    for key in first['sentence_inputs'][0]:
        flattened = []
        for sample in batch:
            # Drop trailing padding candidates that no item in this batch needs.
            for encoded in sample['sentence_inputs'][:keep]:
                flattened.append(encoded[key])
        sentence_inputs[key] = torch.stack(flattened)

    return {
        'anchor_input': anchor_inputs,
        'sentence_inputs': sentence_inputs,
        'scores': torch.stack([sample['scores'][:keep] for sample in batch]),
        'mask': torch.stack([sample['mask'][:keep] for sample in batch]),
    }


# Tokenizer for the same checkpoint name that SentenceEncoder uses as its
# default model_name, so token ids line up with the encoder's vocabulary.
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')


def group_aware_split(dataset, test_size=0.1, random_state=42):
    """Split ``dataset`` into train/validation subsets with GroupShuffleSplit.

    Args:
        dataset: Object exposing ``__len__`` and a ``group_ids`` sequence with
            one entry per item.
        test_size: Passed to GroupShuffleSplit as the held-out share.
        random_state: Seed passed to GroupShuffleSplit.

    Returns:
        ``(train_subset, val_subset)`` as ``torch.utils.data.Subset`` objects.
    """
    shuffle_split = GroupShuffleSplit(
        n_splits=1,
        test_size=test_size,
        random_state=random_state,
    )
    item_indices = range(len(dataset))
    # Items that share a group id are never placed on both sides of the split.
    split_iter = shuffle_split.split(item_indices, groups=dataset.group_ids)
    train_idx, val_idx = next(split_iter)

    train_subset = torch.utils.data.Subset(dataset, train_idx)
    val_subset = torch.utils.data.Subset(dataset, val_idx)
    return train_subset, val_subset

# Build one anchor-grouped dataset per corpus file.
full_train_dataset = TwitterURLDataset(
    f'{base_path}paraphrase_dataset_emnlp2017/Twitter_URL_Corpus_train.txt',
    tokenizer,
)
test_dataset = TwitterURLDataset(
    f'{base_path}paraphrase_dataset_emnlp2017/Twitter_URL_Corpus_test.txt',
    tokenizer,
)

# Hold out 10% of the training groups for validation.
train_dataset, val_dataset = group_aware_split(full_train_dataset, test_size=0.1)

# All three loaders share the batch size and collate function; only the
# training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

print(f"Train size: {len(train_dataset)}")
print(f"Val size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")

Max group size: 20
unique group sizes = {8, 9, 10, 20}
Max group size: 20
unique group sizes = {10, 20}
Train size: 4197
Val size: 467
Test size: 1010


In [10]:
num_epochs = 40


def train_model(model, label, optimizer):
    """Train ``model`` for ``num_epochs`` epochs and keep the best checkpoint.

    After every epoch the model is scored on ``val_loader`` with ``validate``
    (mean per-anchor Spearman, higher is better). Whenever that score
    improves, the weights are written to ``saved_models3/{label}.pt``. A JSON
    log is rewritten to ``model_logs3/{label}_logs.json`` after every epoch
    so a crash loses at most the epoch in progress.

    Relies on the module-level ``num_epochs``, ``train_loader``,
    ``val_loader`` and ``device``.

    Args:
        model: Module to train in place.
        label: Run name used in the checkpoint and log file names.
        optimizer: Optimizer already bound to ``model``'s parameters.
    """
    # torch.save and open(..., 'w') do not create missing parent directories;
    # without this a fresh checkout fails after the first epoch has finished.
    os.makedirs('saved_models3', exist_ok=True)
    os.makedirs('model_logs3', exist_ok=True)
    checkpoint_path = f'saved_models3/{label}.pt'
    log_path = f'model_logs3/{label}_logs.json'

    best_val_spearman = float('-inf')
    log_data = {
        'training_start': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'config': {
            'model': label,
            'num_epochs': num_epochs,
            'optimizer': str(optimizer.__class__.__name__),
            'device': str(device)
        },
        'epochs': []
    }

    def write_log():
        with open(log_path, 'w') as f:
            json.dump(log_data, f, indent=2)

    for epoch in range(num_epochs):
        epoch_start = datetime.now()

        # Printed before the (long) training pass so progress is visible.
        print(f"Epoch {epoch+1}/{num_epochs}")

        train_loss = train_one_epoch(train_loader, model, optimizer, device)
        val_spearman = validate(val_loader, model, device)

        epoch_log = {
            'epoch': epoch + 1,
            'train_loss': float(train_loss),
            'val_spearman': float(val_spearman),
            'duration_seconds': (datetime.now() - epoch_start).total_seconds(),
            'is_best': False
        }

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Spearman:   {val_spearman:.4f}")

        if val_spearman > best_val_spearman:
            best_val_spearman = val_spearman
            torch.save(model.state_dict(), checkpoint_path)
            epoch_log['is_best'] = True
            print(f"  -> Best model saved (val_spearman improved to {best_val_spearman:.4f})")

        log_data['epochs'].append(epoch_log)

        # Save logs after each epoch (in case of crash)
        write_log()

    # Add final metadata and write the log one last time.
    log_data['training_end'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_data['best_val_spearman'] = float(best_val_spearman)
    write_log()

print("Initiliazing model")
model = SentenceEncoder().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Baseline: score the pretrained encoder before any fine-tuning.
test_spearman = validate(test_loader, model, device)
print(f"Test Spearman Initially: {test_spearman:.4f}")

print("Starting training")
train_model(model, "euc_sbert", optimizer)
# Score the weights as they are after the last epoch (not the best epoch).
test_spearman = validate(test_loader, model, device)
print(f"Final Test Spearman with sbert: {test_spearman:.4f}")
del model

# Final test: rebuild the encoder and load the best checkpoint saved by
# train_model. weights_only=True restricts unpickling to tensors/containers,
# and map_location places the weights on the active device.
print("Initiliazing model")
model = SentenceEncoder().to(device)
model.load_state_dict(
    torch.load("saved_models3/euc_sbert.pt", map_location=device, weights_only=True)
)
test_spearman = validate(test_loader, model, device)
print(f"Best Model - Test Spearman: {test_spearman:.4f}")

Initiliazing model
Test Spearman Initially: 0.5860
Starting training
Epoch 1/40


                                                                                                                                     

Epoch 1/40
  Train Loss: 0.1006
  Val Spearman (-ve):   0.6584
  -> Best model saved (val_loss improved to 0.6584)
Epoch 2/40


                                                                                                                                     

Epoch 2/40
  Train Loss: 0.0837
  Val Spearman (-ve):   0.6693
  -> Best model saved (val_loss improved to 0.6693)
Epoch 3/40


                                                                                                                                     

Epoch 3/40
  Train Loss: 0.0767
  Val Spearman (-ve):   0.6755
  -> Best model saved (val_loss improved to 0.6755)
Epoch 4/40


                                                                                                                                     

Epoch 4/40
  Train Loss: 0.0721
  Val Spearman (-ve):   0.6799
  -> Best model saved (val_loss improved to 0.6799)
Epoch 5/40


                                                                                                                                     

Epoch 5/40
  Train Loss: 0.0680
  Val Spearman (-ve):   0.6815
  -> Best model saved (val_loss improved to 0.6815)
Epoch 6/40


                                                                                                                                     

Epoch 6/40
  Train Loss: 0.0646
  Val Spearman (-ve):   0.6837
  -> Best model saved (val_loss improved to 0.6837)
Epoch 7/40


                                                                                                                                     

Epoch 7/40
  Train Loss: 0.0606
  Val Spearman (-ve):   0.6816
Epoch 8/40


                                                                                                                                     

Epoch 8/40
  Train Loss: 0.0579
  Val Spearman (-ve):   0.6800
Epoch 9/40


                                                                                                                                     

Epoch 9/40
  Train Loss: 0.0547
  Val Spearman (-ve):   0.6790
Epoch 10/40


                                                                                                                                     

Epoch 10/40
  Train Loss: 0.0521
  Val Spearman (-ve):   0.6791
Epoch 11/40


                                                                                                                                     

Epoch 11/40
  Train Loss: 0.0497
  Val Spearman (-ve):   0.6714
Epoch 12/40


                                                                                                                                     

Epoch 12/40
  Train Loss: 0.0475
  Val Spearman (-ve):   0.6740
Epoch 13/40


                                                                                                                                     

Epoch 13/40
  Train Loss: 0.0453
  Val Spearman (-ve):   0.6715
Epoch 14/40


                                                                                                                                     

Epoch 14/40
  Train Loss: 0.0433
  Val Spearman (-ve):   0.6695
Epoch 15/40


                                                                                                                                     

Epoch 15/40
  Train Loss: 0.0412
  Val Spearman (-ve):   0.6694
Epoch 16/40


                                                                                                                                     

Epoch 16/40
  Train Loss: 0.0399
  Val Spearman (-ve):   0.6602
Epoch 17/40


                                                                                                                                     

Epoch 17/40
  Train Loss: 0.0382
  Val Spearman (-ve):   0.6644
Epoch 18/40


                                                                                                                                     

Epoch 18/40
  Train Loss: 0.0370
  Val Spearman (-ve):   0.6571
Epoch 19/40


                                                                                                                                     

Epoch 19/40
  Train Loss: 0.0354
  Val Spearman (-ve):   0.6544
Epoch 20/40


                                                                                                                                     

Epoch 20/40
  Train Loss: 0.0339
  Val Spearman (-ve):   0.6553
Epoch 21/40


                                                                                                                                     

Epoch 21/40
  Train Loss: 0.0331
  Val Spearman (-ve):   0.6568
Epoch 22/40


                                                                                                                                     

Epoch 22/40
  Train Loss: 0.0317
  Val Spearman (-ve):   0.6564
Epoch 23/40


                                                                                                                                     

Epoch 23/40
  Train Loss: 0.0304
  Val Spearman (-ve):   0.6495
Epoch 24/40


                                                                                                                                     

Epoch 24/40
  Train Loss: 0.0297
  Val Spearman (-ve):   0.6524
Epoch 25/40


                                                                                                                                     

Epoch 25/40
  Train Loss: 0.0288
  Val Spearman (-ve):   0.6457
Epoch 26/40


                                                                                                                                     

Epoch 26/40
  Train Loss: 0.0281
  Val Spearman (-ve):   0.6503
Epoch 27/40


                                                                                                                                     

Epoch 27/40
  Train Loss: 0.0273
  Val Spearman (-ve):   0.6492
Epoch 28/40


                                                                                                                                     

Epoch 28/40
  Train Loss: 0.0268
  Val Spearman (-ve):   0.6434
Epoch 29/40


                                                                                                                                     

Epoch 29/40
  Train Loss: 0.0260
  Val Spearman (-ve):   0.6411
Epoch 30/40


                                                                                                                                     

Epoch 30/40
  Train Loss: 0.0250
  Val Spearman (-ve):   0.6424
Epoch 31/40


                                                                                                                                     

Epoch 31/40
  Train Loss: 0.0246
  Val Spearman (-ve):   0.6422
Epoch 32/40


                                                                                                                                     

Epoch 32/40
  Train Loss: 0.0237
  Val Spearman (-ve):   0.6293
Epoch 33/40


                                                                                                                                     

Epoch 33/40
  Train Loss: 0.0233
  Val Spearman (-ve):   0.6302
Epoch 34/40


                                                                                                                                     

Epoch 34/40
  Train Loss: 0.0229
  Val Spearman (-ve):   0.6321
Epoch 35/40


                                                                                                                                     

Epoch 35/40
  Train Loss: 0.0226
  Val Spearman (-ve):   0.6441
Epoch 36/40


                                                                                                                                     

Epoch 36/40
  Train Loss: 0.0218
  Val Spearman (-ve):   0.6382
Epoch 37/40


                                                                                                                                     

Epoch 37/40
  Train Loss: 0.0214
  Val Spearman (-ve):   0.6292
Epoch 38/40


                                                                                                                                     

Epoch 38/40
  Train Loss: 0.0213
  Val Spearman (-ve):   0.6297
Epoch 39/40


                                                                                                                                     

Epoch 39/40
  Train Loss: 0.0209
  Val Spearman (-ve):   0.6330
Epoch 40/40


                                                                                                                                     

Epoch 40/40
  Train Loss: 0.0202
  Val Spearman (-ve):   0.6294
Final Test Spearman with sbert: 0.6036
Initiliazing model


  model.load_state_dict(torch.load("saved_models3/euc_sbert.pt"))


Best Model - Test Spearman: 0.6720


In [6]:
# Inline training run: fine-tune a fresh encoder and keep the weights of the
# epoch with the highest validation Spearman in 'best_model.pt'.
model = SentenceEncoder().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
num_epochs = 40

# Baseline: score the pretrained encoder before any fine-tuning.
test_spearman = validate(test_loader, model, device)
print(f"Test Spearman Initially: {test_spearman:.4f}")

best_val_spearman = float('-inf')
for epoch in range(num_epochs):
    # Printed before the (long) training pass so progress is visible.
    print(f"Epoch {epoch+1}/{num_epochs}")

    train_loss = train_one_epoch(train_loader, model, optimizer, device)
    val_spearman = validate(val_loader, model, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Spearman:   {val_spearman:.4f}")

    if val_spearman > best_val_spearman:
        best_val_spearman = val_spearman
        torch.save(model.state_dict(), 'best_model.pt')
        print(f"  -> Best model saved (val_spearman improved to {best_val_spearman:.4f})")

# Final test: reload the best checkpoint and score it on the test split.
# weights_only=True restricts unpickling to tensors/containers, and
# map_location places the weights on the active device.
model.load_state_dict(torch.load('best_model.pt', map_location=device, weights_only=True))
test_spearman = validate(test_loader, model, device)
print(f"Test Spearman: {test_spearman:.4f}")

Test Spearman Initially: 0.5860
Epoch 1/3


                                                                                                                                     

Epoch 1/3
  Train Loss: 0.1006
  Val Spearman (-ve):   0.6584
  -> Best model saved (val_loss improved to 0.6584)
Epoch 2/3


                                                                                                                                     

Epoch 2/3
  Train Loss: 0.0837
  Val Spearman (-ve):   0.6693
  -> Best model saved (val_loss improved to 0.6693)
Epoch 3/3


                                                                                                                                     

Epoch 3/3
  Train Loss: 0.0767
  Val Spearman (-ve):   0.6755
  -> Best model saved (val_loss improved to 0.6755)


  model.load_state_dict(torch.load('best_model.pt'))


Test Spearman: 0.6701


In [7]:
# Final test: reload the best checkpoint and score it on the test split.
# weights_only=True restricts unpickling to tensors/containers, and
# map_location places the weights on the active device.
model.load_state_dict(torch.load('best_model.pt', map_location=device, weights_only=True))
test_spearman = validate(test_loader, model, device)
print(f"Test Spearman: {test_spearman:.4f}")

  model.load_state_dict(torch.load('best_model.pt'))


Test Spearman: 0.6701
