In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import csv
import os
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import random_split
from sklearn.model_selection import GroupShuffleSplit
from scipy.stats import spearmanr
from tqdm import tqdm  # make sure you have tqdm installed
import geoopt
# there was an error with geoopt scipy version compatibility, it was fixed by-
# nano /cronus_data/vraja/dysarthria/lib/python3.11/site-packages/geoopt/optim/rlinesearch.py
# added this to the import -
# try:
#     from scipy.optimize.linesearch import scalar_search_wolfe2, scalar_search_armijo
# except ImportError:
#     from scipy.optimize._linesearch import scalar_search_wolfe2, scalar_search_armijo
torch.manual_seed(42)  # Before random_split
base_path = ""

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# -----------------------------
# 1. Dataset Class
# -----------------------------

class TwitterURLDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.data = self.read_file(file_path)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.groups = self.group_data()
        self.group_ids = [i for i, _ in enumerate(self.groups)]
        
        # Pre-compute maximum group size
        self.max_group_size = max(len(group) for group in self.groups)  # Cap at 10
        print(f"Max group size: {self.max_group_size}")
        print(f"unique group sizes = {set([len(group) for group in self.groups])}")



    def read_file(self, file_path):
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            reader = csv.reader(f, delimiter='\t')
            for row in reader:
                if len(row) != 4:
                    continue
                sentence1, sentence2, score_str, url = row
                try:
                    score = int(score_str.strip('()').split(',')[0])
                except:
                    continue
                data.append((sentence1.strip(), sentence2.strip(), score))
        return data


    def group_data(self):
        groups = []
        current_group = []
        last_sentence1 = None
        for item in self.data:
            sentence1, sentence2, score = item
            if sentence1 != last_sentence1 and current_group:
                groups.append(current_group)
                current_group = []
            current_group.append(item)
            last_sentence1 = sentence1
        if current_group:
            groups.append(current_group)
        return groups

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        group = self.groups[idx]
        anchor_text = group[0][0]
        
        # Pad group to max_group_size with dummy entries if needed
        padded_group = group + [("", "", 0)] * (self.max_group_size - len(group))
        
        sentences = [item[1] for item in padded_group]
        scores = torch.tensor([item[2] for item in padded_group], dtype=torch.float32)
        
        anchor_input = self.tokenizer(
            anchor_text, 
            padding='max_length', 
            truncation=True, 
            max_length=self.max_length, 
            return_tensors="pt",
            return_token_type_ids=False  # Add this line
        )
        
        sentence_inputs = [
            self.tokenizer(
                s, 
                padding='max_length', 
                truncation=True, 
                max_length=self.max_length, 
                return_tensors="pt",
                return_token_type_ids=False  # Add this line
            ) for s in sentences
        ]
        
        return {
            'anchor_input': {k: v.squeeze(0) for k, v in anchor_input.items()},
            'sentence_inputs': [{k: v.squeeze(0) for k, v in inp.items()} for inp in sentence_inputs],
            'scores': scores,
            'mask': torch.tensor([1]*len(group) + [0]*(self.max_group_size - len(group)))  # Mask for real vs padded
        }


# -----------------------------
# 2. Model
# -----------------------------
def poincare_project(x, eps=1e-5):
    """Projects vectors into Poincaré ball with proper scaling"""
    norm = torch.norm(x, p=2, dim=-1, keepdim=True)
    scale = (1 - eps) / torch.clamp(norm, min=eps)
    return x * scale
    
class HyperbolicMapper(nn.Module):
    def __init__(self, sbert_model_name='sentence-transformers/all-MiniLM-L6-v2', output_dim=32):
        super(HyperbolicMapper, self).__init__()
        # Frozen SBERT
        self.sbert = AutoModel.from_pretrained(sbert_model_name)
        for param in self.sbert.parameters():
            param.requires_grad = False
        
        sbert_hidden_dim = self.sbert.config.hidden_size  # usually 384 for MiniLM
        self.fc = nn.Linear(sbert_hidden_dim, output_dim)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            sbert_output = self.sbert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embedding = sbert_output.last_hidden_state[:, 0]
        
        mapped = self.fc(cls_embedding)
        return poincare_project(mapped)  # Project after mapping

# -----------------------------
# 3. Margin Ranking Loss and Poincare distance
# -----------------------------

# margin_ranking_loss = nn.MarginRankingLoss(margin=0.2)
def hyperbolic_margin_loss(anchor, positive, negative, margin=0.2):
    pos_dist = poincare_distance(anchor, positive)
    neg_dist = poincare_distance(anchor, negative)
    return F.relu(pos_dist - neg_dist + margin).mean()

def poincare_distance(x, y, eps=1e-5):
    # x: [batch_size, dim]
    # y: [batch_size, dim]
    x2 = torch.sum(x * x, dim=-1, keepdim=True)  # [batch_size, 1]
    y2 = torch.sum(y * y, dim=-1, keepdim=True)  # [batch_size, 1]
    xy = torch.sum(x * y, dim=-1, keepdim=True)  # [batch_size, 1]

    num = torch.norm(x - y, p=2, dim=-1, keepdim=True) ** 2
    denom = (1 - x2) * (1 - y2) + eps

    inside = 1 + 2 * num / denom
    dist = torch.acosh(torch.clamp(inside, min=1 + eps))
    return dist.squeeze(-1)  # [batch_size]
    
# -----------------------------
# 5. Training Step
# -----------------------------

def train_one_epoch(train_loader, model, optimizer, device):
    model.train()
    total_loss = 0
    
    # Wrap your loader with tqdm for a progress bar
    prog_bar = tqdm(train_loader, desc="Training", leave=False)
    
    for batch in prog_bar:
        optimizer.zero_grad()
        
        anchor_input = {k: v.to(device) for k, v in batch['anchor_input'].items()}
        anchor_embed = model(**anchor_input)
        
        sentence_inputs = {
            k: v.to(device) for k, v in batch['sentence_inputs'].items()
        }
        sentence_embeds = model(**sentence_inputs)
        
        batch_size, max_group_size = batch['scores'].shape
        sentence_embeds = sentence_embeds.view(batch_size, max_group_size, -1)
        
        dists = poincare_distance(
            anchor_embed.unsqueeze(1).expand_as(sentence_embeds),
            sentence_embeds
        ) * batch['mask'].to(device)
        
        # Flip distance -> similarity
        sims = -dists
        
        # Vectorized ranking loss calculation
        scores = batch['scores'].to(device)
        mask = batch['mask'].to(device).bool()
        
        # Create all valid pairs
        pos_mask = (scores.unsqueeze(2) > scores.unsqueeze(1)) & mask.unsqueeze(2) & mask.unsqueeze(1)
        
        if pos_mask.any():
            pos_sims = sims.unsqueeze(2).expand(-1, -1, max_group_size)[pos_mask]
            neg_sims = sims.unsqueeze(1).expand(-1, max_group_size, -1)[pos_mask]
            
            loss = F.margin_ranking_loss(
                pos_sims,
                neg_sims,
                torch.ones_like(pos_sims),
                margin=0.2,
                reduction='mean'
            )
            
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * pos_mask.sum().item()  # Weight by batch size
    
    total_samples = len(train_loader.dataset)
    return total_loss / total_samples if total_samples > 0 else 0


def validate(loader, model, device):
    model.eval()
    total_loss = 0
    all_scores = []
    all_sims = []
    
    with torch.no_grad():
        for batch in loader:
            anchor_input = {k: v.to(device) for k, v in batch['anchor_input'].items()}
            anchor_embed = model(**anchor_input)
            
            # Process all sentences (with masking)
            sentence_inputs = {
                k: v.to(device) for k, v in batch['sentence_inputs'].items()
            }
            sentence_embeds = model(**sentence_inputs)
            
            # Reshape and mask
            batch_size = len(batch['scores'])
            max_group_size = batch['scores'].shape[1]
            sentence_embeds = sentence_embeds.view(batch_size, max_group_size, -1)
            mask = batch['mask'].to(device)
            
            # Compute similarities -> Now poincare distance
            dists = poincare_distance(
                anchor_embed.unsqueeze(1).expand_as(sentence_embeds),
                sentence_embeds
            ) * batch['mask'].to(device)
            # Flip distance -> similarity
            sims = -dists
            
            # Collect non-padded scores/similarities for correlation
            scores = batch['scores'].to(device)
            for i in range(batch_size):
                real_indices = torch.where(mask[i] == 1)[0]
                all_scores.extend(scores[i, real_indices].cpu().numpy())
                all_sims.extend(sims[i, real_indices].cpu().numpy())
    
    # Compute Spearman correlation
    if len(all_scores) > 0:
        spearman = spearmanr(all_scores, all_sims).correlation
    else:
        spearman = 0.0
    
    return -spearman  # Lower is better (for consistency with loss)


def collate_fn(batch):
    # Find actual max group size in this batch
    batch_max_len = max(len(item['mask'][item['mask'] == 1]) for item in batch)
    
    # Truncate all inputs to this length
    anchor_inputs = {
        k: torch.stack([item['anchor_input'][k] for item in batch])
        for k in batch[0]['anchor_input']
    }
    
    sentence_inputs = {
        k: torch.stack([
            inp[k] 
            for item in batch 
            for inp in item['sentence_inputs'][:batch_max_len]  # Only take needed elements
        ])
        for k in batch[0]['sentence_inputs'][0]
    }
    
    scores = torch.stack([item['scores'][:batch_max_len] for item in batch])
    masks = torch.stack([item['mask'][:batch_max_len] for item in batch])
    
    return {
        'anchor_input': anchor_inputs,
        'sentence_inputs': sentence_inputs,
        'scores': scores,
        'mask': masks
    }


tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')


def group_aware_split(dataset, test_size=0.1, random_state=42):
    splitter = GroupShuffleSplit(
        n_splits=1, 
        test_size=test_size, 
        random_state=random_state
    )
    train_idx, val_idx = next(splitter.split(
        range(len(dataset)), 
        groups=dataset.group_ids  # Critical: Group by anchor IDs
    ))
    return (
        torch.utils.data.Subset(dataset, train_idx),
        torch.utils.data.Subset(dataset, val_idx)
    )

# Load datasets
full_train_dataset = TwitterURLDataset(f'{base_path}paraphrase_dataset_emnlp2017/Twitter_URL_Corpus_train.txt', tokenizer)
test_dataset = TwitterURLDataset(f'{base_path}paraphrase_dataset_emnlp2017/Twitter_URL_Corpus_test.txt', tokenizer)

# Group-aware split (90% train, 10% val)
train_dataset, val_dataset = group_aware_split(full_train_dataset, test_size=0.1)

# For all splits (train/val/test), use the same collate_fn
train_loader = DataLoader(
    train_dataset, 
    batch_size=32, 
    shuffle=True, 
    collate_fn=collate_fn  # Same as training
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=32, 
    shuffle=False,  # No shuffling for val/test
    collate_fn=collate_fn  # Same collate function!
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=32, 
    shuffle=False,
    collate_fn=collate_fn  # Consistency is key
)

print(f"Train size: {len(train_dataset)}")
print(f"Val size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")

# Rest of training logic (same as before)
model = HyperbolicMapper().to(device)
# Replace your optimizer with:
optimizer = geoopt.optim.RiemannianAdam(
    model.parameters(), 
    lr=1e-4,
    stabilize=1000  # Helps with numerical stability
)
num_epochs = 6

test_spearman = -validate(test_loader, model, device)  # Flip sign back
print(f"Test Spearman Initially: {test_spearman:.4f}")

best_val_loss = float('inf')
# for epoch in range(num_epochs):
#     print(f"Epoch {epoch+1}/{num_epochs}")

#     train_loss = train_one_epoch(train_loader, model, optimizer, device)
#     val_loss = validate(val_loader, model, device)

#     print(f"Epoch {epoch+1}/{num_epochs}")
#     print(f"  Train Loss: {train_loss:.4f}")
#     print(f"  Val Spearman (-ve):   {val_loss:.4f}")

#     if val_loss < best_val_loss:
#         best_val_loss = val_loss
#         torch.save(model.state_dict(), 'best_model.pt')
#         print(f"  -> Best model saved (val_loss improved to {best_val_loss:.4f})")

# Final test
model.load_state_dict(torch.load('best_model.pt'))
test_spearman = -validate(test_loader, model, device)  # Flip sign back
print(f"Test Spearman: {test_spearman:.4f}")

Max group size: 20
unique group sizes = {8, 9, 10, 20}
Max group size: 20
unique group sizes = {10, 20}
Train size: 4197
Val size: 467
Test size: 1010
Test Spearman Initially: 0.5887


  model.load_state_dict(torch.load('best_model.pt'))


Test Spearman: 0.6650


In [4]:
# Final test
best_val_loss = -0.7197
model.load_state_dict(torch.load('best_model.pt'))
test_spearman = -validate(test_loader, model, device)  # Flip sign back
print(f"Test Spearman: {test_spearman:.4f}")

for I in range(6):
    for epoch in range(num_epochs):
        print(f"{I}: Epoch {epoch+1}/{num_epochs}")
    
        train_loss = train_one_epoch(train_loader, model, optimizer, device)
        val_loss = validate(val_loader, model, device)
    
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Spearman (-ve):   {val_loss:.4f}")
    
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pt')
            print(f"  -> Best model saved (val_loss improved to {best_val_loss:.4f})")

    # Final test
    model.load_state_dict(torch.load('best_model.pt'))
    test_spearman = -validate(test_loader, model, device)  # Flip sign back
    print(f"Test Spearman: {test_spearman:.4f}")

  model.load_state_dict(torch.load('best_model.pt'))


Test Spearman: 0.6650
0: Epoch 1/6


                                                                                                                                     

Epoch 1/6
  Train Loss: 3.7808
  Val Spearman (-ve):   -0.7214
  -> Best model saved (val_loss improved to -0.7214)
0: Epoch 2/6


                                                                                                                                     

Epoch 2/6
  Train Loss: 3.7311
  Val Spearman (-ve):   -0.7237
  -> Best model saved (val_loss improved to -0.7237)
0: Epoch 3/6


                                                                                                                                     

Epoch 3/6
  Train Loss: 3.6744
  Val Spearman (-ve):   -0.7242
  -> Best model saved (val_loss improved to -0.7242)
0: Epoch 4/6


                                                                                                                                     

Epoch 4/6
  Train Loss: 3.6385
  Val Spearman (-ve):   -0.7251
  -> Best model saved (val_loss improved to -0.7251)
0: Epoch 5/6


                                                                                                                                     

Epoch 5/6
  Train Loss: 3.6196
  Val Spearman (-ve):   -0.7273
  -> Best model saved (val_loss improved to -0.7273)
0: Epoch 6/6


                                                                                                                                     

Epoch 6/6
  Train Loss: 3.5654
  Val Spearman (-ve):   -0.7272


  model.load_state_dict(torch.load('best_model.pt'))


Test Spearman: 0.6712
1: Epoch 1/6


                                                                                                                                     

Epoch 1/6
  Train Loss: 3.6165
  Val Spearman (-ve):   -0.7274
  -> Best model saved (val_loss improved to -0.7274)
1: Epoch 2/6


                                                                                                                                     

Epoch 2/6
  Train Loss: 3.5368
  Val Spearman (-ve):   -0.7299
  -> Best model saved (val_loss improved to -0.7299)
1: Epoch 3/6


                                                                                                                                     

Epoch 3/6
  Train Loss: 3.5419
  Val Spearman (-ve):   -0.7307
  -> Best model saved (val_loss improved to -0.7307)
1: Epoch 4/6


                                                                                                                                     

Epoch 4/6
  Train Loss: 3.4543
  Val Spearman (-ve):   -0.7318
  -> Best model saved (val_loss improved to -0.7318)
1: Epoch 5/6


                                                                                                                                     

Epoch 5/6
  Train Loss: 3.4758
  Val Spearman (-ve):   -0.7313
1: Epoch 6/6


                                                                                                                                     

Epoch 6/6
  Train Loss: 3.4031
  Val Spearman (-ve):   -0.7328
  -> Best model saved (val_loss improved to -0.7328)
Test Spearman: 0.6743
2: Epoch 1/6


                                                                                                                                     

Epoch 1/6
  Train Loss: 3.3601
  Val Spearman (-ve):   -0.7332
  -> Best model saved (val_loss improved to -0.7332)
2: Epoch 2/6


                                                                                                                                     

Epoch 2/6
  Train Loss: 3.3466
  Val Spearman (-ve):   -0.7330
2: Epoch 3/6


                                                                                                                                     

Epoch 3/6
  Train Loss: 3.3898
  Val Spearman (-ve):   -0.7337
  -> Best model saved (val_loss improved to -0.7337)
2: Epoch 4/6


                                                                                                                                     

Epoch 4/6
  Train Loss: 3.3286
  Val Spearman (-ve):   -0.7342
  -> Best model saved (val_loss improved to -0.7342)
2: Epoch 5/6


                                                                                                                                     

Epoch 5/6
  Train Loss: 3.2905
  Val Spearman (-ve):   -0.7348
  -> Best model saved (val_loss improved to -0.7348)
2: Epoch 6/6


                                                                                                                                     

Epoch 6/6
  Train Loss: 3.2764
  Val Spearman (-ve):   -0.7349
  -> Best model saved (val_loss improved to -0.7349)
Test Spearman: 0.6762
3: Epoch 1/6


                                                                                                                                     

Epoch 1/6
  Train Loss: 3.2371
  Val Spearman (-ve):   -0.7354
  -> Best model saved (val_loss improved to -0.7354)
3: Epoch 2/6


                                                                                                                                     

Epoch 2/6
  Train Loss: 3.2320
  Val Spearman (-ve):   -0.7367
  -> Best model saved (val_loss improved to -0.7367)
3: Epoch 3/6


                                                                                                                                     

Epoch 3/6
  Train Loss: 3.2431
  Val Spearman (-ve):   -0.7374
  -> Best model saved (val_loss improved to -0.7374)
3: Epoch 4/6


                                                                                                                                     

Epoch 4/6
  Train Loss: 3.2484
  Val Spearman (-ve):   -0.7363
3: Epoch 5/6


                                                                                                                                     

Epoch 5/6
  Train Loss: 3.1820
  Val Spearman (-ve):   -0.7368
3: Epoch 6/6


                                                                                                                                     

Epoch 6/6
  Train Loss: 3.1958
  Val Spearman (-ve):   -0.7360
Test Spearman: 0.6756
4: Epoch 1/6


                                                                                                                                     

Epoch 1/6
  Train Loss: 3.2269
  Val Spearman (-ve):   -0.7372
4: Epoch 2/6


                                                                                                                                     

Epoch 2/6
  Train Loss: 3.2172
  Val Spearman (-ve):   -0.7365
4: Epoch 3/6


                                                                                                                                     

Epoch 3/6
  Train Loss: 3.2118
  Val Spearman (-ve):   -0.7361
4: Epoch 4/6


                                                                                                                                     

Epoch 4/6
  Train Loss: 3.2158
  Val Spearman (-ve):   -0.7360
4: Epoch 5/6


                                                                                                                                     

Epoch 5/6
  Train Loss: 3.1290
  Val Spearman (-ve):   -0.7362
4: Epoch 6/6


                                                                                                                                     

Epoch 6/6
  Train Loss: 3.1328
  Val Spearman (-ve):   -0.7357
Test Spearman: 0.6756
5: Epoch 1/6


                                                                                                                                     

Epoch 1/6
  Train Loss: 3.2644
  Val Spearman (-ve):   -0.7368
5: Epoch 2/6


                                                                                                                                     

Epoch 2/6
  Train Loss: 3.2053
  Val Spearman (-ve):   -0.7365
5: Epoch 3/6


                                                                                                                                     

Epoch 3/6
  Train Loss: 3.1543
  Val Spearman (-ve):   -0.7360
5: Epoch 4/6


                                                                                                                                     

Epoch 4/6
  Train Loss: 3.1802
  Val Spearman (-ve):   -0.7371
5: Epoch 5/6


                                                                                                                                     

Epoch 5/6
  Train Loss: 3.1451
  Val Spearman (-ve):   -0.7368
5: Epoch 6/6


                                                                                                                                     

Epoch 6/6
  Train Loss: 3.1531
  Val Spearman (-ve):   -0.7372
Test Spearman: 0.6756


In [5]:
num_epochs = 18
for epoch in range(num_epochs):
    print(f"{I}: Epoch {epoch+1}/{num_epochs}")

    train_loss = train_one_epoch(train_loader, model, optimizer, device)
    val_loss = validate(val_loader, model, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Spearman (-ve):   {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pt')
        print(f"  -> Best model saved (val_loss improved to {best_val_loss:.4f})")

5: Epoch 1/18


                                                                                                                                     

Epoch 1/18
  Train Loss: 3.2227
  Val Spearman (-ve):   -0.7352
5: Epoch 2/18


                                                                                                                                     

Epoch 2/18
  Train Loss: 3.2131
  Val Spearman (-ve):   -0.7360
5: Epoch 3/18


                                                                                                                                     

Epoch 3/18
  Train Loss: 3.1717
  Val Spearman (-ve):   -0.7360
5: Epoch 4/18


                                                                                                                                     

Epoch 4/18
  Train Loss: 3.2003
  Val Spearman (-ve):   -0.7363
5: Epoch 5/18


                                                                                                                                     

Epoch 5/18
  Train Loss: 3.1738
  Val Spearman (-ve):   -0.7358
5: Epoch 6/18


                                                                                                                                     

Epoch 6/18
  Train Loss: 3.1661
  Val Spearman (-ve):   -0.7367
5: Epoch 7/18


                                                                                                                                     

Epoch 7/18
  Train Loss: 3.1481
  Val Spearman (-ve):   -0.7379
  -> Best model saved (val_loss improved to -0.7379)
5: Epoch 8/18


                                                                                                                                     

Epoch 8/18
  Train Loss: 3.1318
  Val Spearman (-ve):   -0.7377
5: Epoch 9/18


                                                                                                                                     

Epoch 9/18
  Train Loss: 3.1192
  Val Spearman (-ve):   -0.7365
5: Epoch 10/18


                                                                                                                                     

Epoch 10/18
  Train Loss: 3.1073
  Val Spearman (-ve):   -0.7359
5: Epoch 11/18


                                                                                                                                     

Epoch 11/18
  Train Loss: 3.0807
  Val Spearman (-ve):   -0.7371
5: Epoch 12/18


                                                                                                                                     

Epoch 12/18
  Train Loss: 3.0904
  Val Spearman (-ve):   -0.7373
5: Epoch 13/18


                                                                                                                                     

Epoch 13/18
  Train Loss: 3.0618
  Val Spearman (-ve):   -0.7382
  -> Best model saved (val_loss improved to -0.7382)
5: Epoch 14/18


                                                                                                                                     

Epoch 14/18
  Train Loss: 3.1289
  Val Spearman (-ve):   -0.7380
5: Epoch 15/18


                                                                                                                                     

Epoch 15/18
  Train Loss: 3.0577
  Val Spearman (-ve):   -0.7374
5: Epoch 16/18


                                                                                                                                     

Epoch 16/18
  Train Loss: 3.0397
  Val Spearman (-ve):   -0.7378
5: Epoch 17/18


                                                                                                                                     

Epoch 17/18
  Train Loss: 3.0289
  Val Spearman (-ve):   -0.7367
5: Epoch 18/18


                                                                                                                                     

Epoch 18/18
  Train Loss: 3.0285
  Val Spearman (-ve):   -0.7375


In [6]:
test_spearman = -validate(test_loader, model, device)  # Flip sign back
print(f"Test Spearman: {test_spearman:.4f}")
num_epochs = 100
for epoch in range(num_epochs):
    print(f"{I}: Epoch {epoch+1}/{num_epochs}")

    train_loss = train_one_epoch(train_loader, model, optimizer, device)
    val_loss = validate(val_loader, model, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Spearman (-ve):   {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pt')
        print(f"  -> Best model saved (val_loss improved to {best_val_loss:.4f})")
    if epoch%10==0:
        test_spearman = -validate(test_loader, model, device)  # Flip sign back
        print(f"Test Spearman: {test_spearman:.4f}")

Test Spearman: 0.6783
5: Epoch 1/100


                                                                                                                                     

Epoch 1/100
  Train Loss: 3.0042
  Val Spearman (-ve):   -0.7368
Test Spearman: 0.6787
5: Epoch 2/100


                                                                                                                                     

Epoch 2/100
  Train Loss: 2.9925
  Val Spearman (-ve):   -0.7371
5: Epoch 3/100


                                                                                                                                     

Epoch 3/100
  Train Loss: 2.9761
  Val Spearman (-ve):   -0.7364
5: Epoch 4/100


                                                                                                                                     

Epoch 4/100
  Train Loss: 2.9997
  Val Spearman (-ve):   -0.7362
5: Epoch 5/100


                                                                                                                                     

Epoch 5/100
  Train Loss: 2.9472
  Val Spearman (-ve):   -0.7375
5: Epoch 6/100


                                                                                                                                     

Epoch 6/100
  Train Loss: 3.0298
  Val Spearman (-ve):   -0.7368
5: Epoch 7/100


                                                                                                                                     

Epoch 7/100
  Train Loss: 2.9733
  Val Spearman (-ve):   -0.7371
5: Epoch 8/100


                                                                                                                                     

Epoch 8/100
  Train Loss: 3.0119
  Val Spearman (-ve):   -0.7376
5: Epoch 9/100


                                                                                                                                     

Epoch 9/100
  Train Loss: 2.9781
  Val Spearman (-ve):   -0.7384
  -> Best model saved (val_loss improved to -0.7384)
5: Epoch 10/100


                                                                                                                                     

Epoch 10/100
  Train Loss: 2.9941
  Val Spearman (-ve):   -0.7377
5: Epoch 11/100


                                                                                                                                     

Epoch 11/100
  Train Loss: 2.9555
  Val Spearman (-ve):   -0.7381
Test Spearman: 0.6808
5: Epoch 12/100


                                                                                                                                     

Epoch 12/100
  Train Loss: 2.9179
  Val Spearman (-ve):   -0.7388
  -> Best model saved (val_loss improved to -0.7388)
5: Epoch 13/100


                                                                                                                                     

Epoch 13/100
  Train Loss: 2.9163
  Val Spearman (-ve):   -0.7369
5: Epoch 14/100


                                                                                                                                     

Epoch 14/100
  Train Loss: 2.9383
  Val Spearman (-ve):   -0.7376
5: Epoch 15/100


                                                                                                                                     

Epoch 15/100
  Train Loss: 2.9515
  Val Spearman (-ve):   -0.7386
5: Epoch 16/100


                                                                                                                                     

Epoch 16/100
  Train Loss: 2.8999
  Val Spearman (-ve):   -0.7379
5: Epoch 17/100


                                                                                                                                     

Epoch 17/100
  Train Loss: 2.9351
  Val Spearman (-ve):   -0.7385
5: Epoch 18/100


                                                                                                                                     

Epoch 18/100
  Train Loss: 2.9076
  Val Spearman (-ve):   -0.7381
5: Epoch 19/100


                                                                                                                                     

Epoch 19/100
  Train Loss: 2.8981
  Val Spearman (-ve):   -0.7376
5: Epoch 20/100


                                                                                                                                     

Epoch 20/100
  Train Loss: 2.8706
  Val Spearman (-ve):   -0.7375
5: Epoch 21/100


                                                                                                                                     

Epoch 21/100
  Train Loss: 2.8879
  Val Spearman (-ve):   -0.7378
Test Spearman: 0.6806
5: Epoch 22/100


                                                                                                                                     

Epoch 22/100
  Train Loss: 2.8773
  Val Spearman (-ve):   -0.7374
5: Epoch 23/100


                                                                                                                                     

Epoch 23/100
  Train Loss: 2.8596
  Val Spearman (-ve):   -0.7383
5: Epoch 24/100


                                                                                                                                     

Epoch 24/100
  Train Loss: 2.8694
  Val Spearman (-ve):   -0.7384
5: Epoch 25/100


                                                                                                                                     

Epoch 25/100
  Train Loss: 2.8802
  Val Spearman (-ve):   -0.7383
5: Epoch 26/100


                                                                                                                                     

Epoch 26/100
  Train Loss: 2.8718
  Val Spearman (-ve):   -0.7389
  -> Best model saved (val_loss improved to -0.7389)
5: Epoch 27/100


                                                                                                                                     

Epoch 27/100
  Train Loss: 2.8664
  Val Spearman (-ve):   -0.7388
5: Epoch 28/100


                                                                                                                                     

Epoch 28/100
  Train Loss: 2.8622
  Val Spearman (-ve):   -0.7395
  -> Best model saved (val_loss improved to -0.7395)
5: Epoch 29/100


                                                                                                                                     

Epoch 29/100
  Train Loss: 2.8899
  Val Spearman (-ve):   -0.7401
  -> Best model saved (val_loss improved to -0.7401)
5: Epoch 30/100


                                                                                                                                     

Epoch 30/100
  Train Loss: 2.8253
  Val Spearman (-ve):   -0.7388
5: Epoch 31/100


                                                                                                                                     

Epoch 31/100
  Train Loss: 2.8440
  Val Spearman (-ve):   -0.7390
Test Spearman: 0.6820
5: Epoch 32/100


                                                                                                                                     

Epoch 32/100
  Train Loss: 2.8281
  Val Spearman (-ve):   -0.7386
5: Epoch 33/100


                                                                                                                                     

Epoch 33/100
  Train Loss: 2.8327
  Val Spearman (-ve):   -0.7374
5: Epoch 34/100


                                                                                                                                     

Epoch 34/100
  Train Loss: 2.8510
  Val Spearman (-ve):   -0.7393
5: Epoch 35/100


                                                                                                                                     

Epoch 35/100
  Train Loss: 2.8230
  Val Spearman (-ve):   -0.7377
5: Epoch 36/100


                                                                                                                                     

Epoch 36/100
  Train Loss: 2.8454
  Val Spearman (-ve):   -0.7388
5: Epoch 37/100


                                                                                                                                     

Epoch 37/100
  Train Loss: 2.8828
  Val Spearman (-ve):   -0.7385
5: Epoch 38/100


                                                                                                                                     

Epoch 38/100
  Train Loss: 2.8074
  Val Spearman (-ve):   -0.7385
5: Epoch 39/100


                                                                                                                                     

Epoch 39/100
  Train Loss: 2.8068
  Val Spearman (-ve):   -0.7385
5: Epoch 40/100


                                                                                                                                     

Epoch 40/100
  Train Loss: 2.7946
  Val Spearman (-ve):   -0.7393
5: Epoch 41/100


                                                                                                                                     

Epoch 41/100
  Train Loss: 2.8209
  Val Spearman (-ve):   -0.7404
  -> Best model saved (val_loss improved to -0.7404)
Test Spearman: 0.6821
5: Epoch 42/100


                                                                                                                                     

Epoch 42/100
  Train Loss: 2.8501
  Val Spearman (-ve):   -0.7406
  -> Best model saved (val_loss improved to -0.7406)
5: Epoch 43/100


                                                                                                                                     

Epoch 43/100
  Train Loss: 2.8089
  Val Spearman (-ve):   -0.7401
5: Epoch 44/100


                                                                                                                                     

Epoch 44/100
  Train Loss: 2.8131
  Val Spearman (-ve):   -0.7403
5: Epoch 45/100


                                                                                                                                     

Epoch 45/100
  Train Loss: 2.8199
  Val Spearman (-ve):   -0.7408
  -> Best model saved (val_loss improved to -0.7408)
5: Epoch 46/100


                                                                                                                                     

Epoch 46/100
  Train Loss: 2.7765
  Val Spearman (-ve):   -0.7407
5: Epoch 47/100


                                                                                                                                     

Epoch 47/100
  Train Loss: 2.8051
  Val Spearman (-ve):   -0.7400
5: Epoch 48/100


                                                                                                                                     

Epoch 48/100
  Train Loss: 2.7750
  Val Spearman (-ve):   -0.7399
5: Epoch 49/100


                                                                                                                                     

Epoch 49/100
  Train Loss: 2.7985
  Val Spearman (-ve):   -0.7388
5: Epoch 50/100


                                                                                                                                     

Epoch 50/100
  Train Loss: 2.7706
  Val Spearman (-ve):   -0.7395
5: Epoch 51/100


                                                                                                                                     

Epoch 51/100
  Train Loss: 2.7728
  Val Spearman (-ve):   -0.7408
Test Spearman: 0.6827
5: Epoch 52/100


                                                                                                                                     

Epoch 52/100
  Train Loss: 2.7861
  Val Spearman (-ve):   -0.7405
5: Epoch 53/100


                                                                                                                                     

Epoch 53/100
  Train Loss: 2.7565
  Val Spearman (-ve):   -0.7402
5: Epoch 54/100


                                                                                                                                     

Epoch 54/100
  Train Loss: 2.7652
  Val Spearman (-ve):   -0.7404
5: Epoch 55/100


                                                                                                                                     

Epoch 55/100
  Train Loss: 2.7654
  Val Spearman (-ve):   -0.7403
5: Epoch 56/100


                                                                                                                                     

Epoch 56/100
  Train Loss: 2.7598
  Val Spearman (-ve):   -0.7412
  -> Best model saved (val_loss improved to -0.7412)
5: Epoch 57/100


                                                                                                                                     

Epoch 57/100
  Train Loss: 2.7380
  Val Spearman (-ve):   -0.7416
  -> Best model saved (val_loss improved to -0.7416)
5: Epoch 58/100


                                                                                                                                     

Epoch 58/100
  Train Loss: 2.7878
  Val Spearman (-ve):   -0.7414
5: Epoch 59/100


                                                                                                                                     

Epoch 59/100
  Train Loss: 2.7682
  Val Spearman (-ve):   -0.7398
5: Epoch 60/100


                                                                                                                                     

Epoch 60/100
  Train Loss: 2.7478
  Val Spearman (-ve):   -0.7412
5: Epoch 61/100


                                                                                                                                     

Epoch 61/100
  Train Loss: 2.7545
  Val Spearman (-ve):   -0.7408
Test Spearman: 0.6827
5: Epoch 62/100


                                                                                                                                     

Epoch 62/100
  Train Loss: 2.7704
  Val Spearman (-ve):   -0.7404
5: Epoch 63/100


                                                                                                                                     

Epoch 63/100
  Train Loss: 2.7677
  Val Spearman (-ve):   -0.7411
5: Epoch 64/100


                                                                                                                                     

Epoch 64/100
  Train Loss: 2.8002
  Val Spearman (-ve):   -0.7413
5: Epoch 65/100


                                                                                                                                     

Epoch 65/100
  Train Loss: 2.7091
  Val Spearman (-ve):   -0.7427
  -> Best model saved (val_loss improved to -0.7427)
5: Epoch 66/100


                                                                                                                                     

Epoch 66/100
  Train Loss: 2.7785
  Val Spearman (-ve):   -0.7419
5: Epoch 67/100


                                                                                                                                     

Epoch 67/100
  Train Loss: 2.7324
  Val Spearman (-ve):   -0.7429
  -> Best model saved (val_loss improved to -0.7429)
5: Epoch 68/100


                                                                                                                                     

Epoch 68/100
  Train Loss: 2.7359
  Val Spearman (-ve):   -0.7433
  -> Best model saved (val_loss improved to -0.7433)
5: Epoch 69/100


                                                                                                                                     

Epoch 69/100
  Train Loss: 2.7707
  Val Spearman (-ve):   -0.7437
  -> Best model saved (val_loss improved to -0.7437)
5: Epoch 70/100


                                                                                                                                     

Epoch 70/100
  Train Loss: 2.7152
  Val Spearman (-ve):   -0.7424
5: Epoch 71/100


                                                                                                                                     

Epoch 71/100
  Train Loss: 2.7589
  Val Spearman (-ve):   -0.7432
Test Spearman: 0.6819
5: Epoch 72/100


                                                                                                                                     

Epoch 72/100
  Train Loss: 2.7117
  Val Spearman (-ve):   -0.7426
5: Epoch 73/100


                                                                                                                                     

Epoch 73/100
  Train Loss: 2.7339
  Val Spearman (-ve):   -0.7420
5: Epoch 74/100


                                                                                                                                     

Epoch 74/100
  Train Loss: 2.7166
  Val Spearman (-ve):   -0.7408
5: Epoch 75/100


                                                                                                                                     

Epoch 75/100
  Train Loss: 2.7735
  Val Spearman (-ve):   -0.7423
5: Epoch 76/100


                                                                                                                                     

Epoch 76/100
  Train Loss: 2.7286
  Val Spearman (-ve):   -0.7423
5: Epoch 77/100


                                                                                                                                     

Epoch 77/100
  Train Loss: 2.7308
  Val Spearman (-ve):   -0.7415
5: Epoch 78/100


                                                                                                                                     

Epoch 78/100
  Train Loss: 2.7526
  Val Spearman (-ve):   -0.7415
5: Epoch 79/100


                                                                                                                                     

Epoch 79/100
  Train Loss: 2.7198
  Val Spearman (-ve):   -0.7408
5: Epoch 80/100


                                                                                                                                     

Epoch 80/100
  Train Loss: 2.7295
  Val Spearman (-ve):   -0.7410
5: Epoch 81/100


                                                                                                                                     

Epoch 81/100
  Train Loss: 2.7152
  Val Spearman (-ve):   -0.7412
Test Spearman: 0.6819
5: Epoch 82/100


                                                                                                                                     

Epoch 82/100
  Train Loss: 2.7212
  Val Spearman (-ve):   -0.7428
5: Epoch 83/100


                                                                                                                                     

Epoch 83/100
  Train Loss: 2.6852
  Val Spearman (-ve):   -0.7425
5: Epoch 84/100


                                                                                                                                     

Epoch 84/100
  Train Loss: 2.7074
  Val Spearman (-ve):   -0.7414
5: Epoch 85/100


                                                                                                                                     

Epoch 85/100
  Train Loss: 2.7421
  Val Spearman (-ve):   -0.7419
5: Epoch 86/100


                                                                                                                                     

Epoch 86/100
  Train Loss: 2.6879
  Val Spearman (-ve):   -0.7422
5: Epoch 87/100


                                                                                                                                     

Epoch 87/100
  Train Loss: 2.7176
  Val Spearman (-ve):   -0.7425
5: Epoch 88/100


                                                                                                                                     

Epoch 88/100
  Train Loss: 2.7184
  Val Spearman (-ve):   -0.7428
5: Epoch 89/100


                                                                                                                                     

Epoch 89/100
  Train Loss: 2.7037
  Val Spearman (-ve):   -0.7416
5: Epoch 90/100


                                                                                                                                     

Epoch 90/100
  Train Loss: 2.7315
  Val Spearman (-ve):   -0.7410
5: Epoch 91/100


                                                                                                                                     

Epoch 91/100
  Train Loss: 2.7086
  Val Spearman (-ve):   -0.7418
Test Spearman: 0.6819
5: Epoch 92/100


                                                                                                                                     

Epoch 92/100
  Train Loss: 2.7000
  Val Spearman (-ve):   -0.7415
5: Epoch 93/100


                                                                                                                                     

Epoch 93/100
  Train Loss: 2.7070
  Val Spearman (-ve):   -0.7413
5: Epoch 94/100


                                                                                                                                     

Epoch 94/100
  Train Loss: 2.7132
  Val Spearman (-ve):   -0.7413
5: Epoch 95/100


                                                                                                                                     

Epoch 95/100
  Train Loss: 2.7046
  Val Spearman (-ve):   -0.7431
5: Epoch 96/100


                                                                                                                                     

Epoch 96/100
  Train Loss: 2.6893
  Val Spearman (-ve):   -0.7434
5: Epoch 97/100


                                                                                                                                     

Epoch 97/100
  Train Loss: 2.7316
  Val Spearman (-ve):   -0.7425
5: Epoch 98/100


                                                                                                                                     

Epoch 98/100
  Train Loss: 2.7011
  Val Spearman (-ve):   -0.7418
5: Epoch 99/100


                                                                                                                                     

Epoch 99/100
  Train Loss: 2.6789
  Val Spearman (-ve):   -0.7414
5: Epoch 100/100


                                                                                                                                     

Epoch 100/100
  Train Loss: 2.7157
  Val Spearman (-ve):   -0.7418


In [7]:
test_spearman = -validate(test_loader, model, device)  # Flip sign back
print(f"Test Spearman: {test_spearman:.4f}")

Test Spearman: 0.6825


In [8]:
model.load_state_dict(torch.load('best_model.pt'))
test_spearman = -validate(test_loader, model, device)  # Flip sign back
print(f"Test Spearman: {test_spearman:.4f}")

  model.load_state_dict(torch.load('best_model.pt'))


Test Spearman: 0.6818
