In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import csv
import os
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import random_split
from sklearn.model_selection import GroupShuffleSplit
from scipy.stats import spearmanr
from tqdm import tqdm  # make sure you have tqdm installed
import geoopt
from datetime import datetime
import json
# there was an error with geoopt scipy version compatibility, it was fixed by-
# nano /cronus_data/vraja/dysarthria/lib/python3.11/site-packages/geoopt/optim/rlinesearch.py
# added this to the import -
# try:
#     from scipy.optimize.linesearch import scalar_search_wolfe2, scalar_search_armijo
# except ImportError:
#     from scipy.optimize._linesearch import scalar_search_wolfe2, scalar_search_armijo
torch.manual_seed(42)  # Before random_split
base_path = ""

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# -----------------------------
# 1. Dataset Class
# -----------------------------

class TwitterURLDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.data = self.read_file(file_path)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.groups = self.group_data()
        self.group_ids = [i for i, _ in enumerate(self.groups)]
        
        # Pre-compute maximum group size
        self.max_group_size = max(len(group) for group in self.groups)  # Cap at 10
        print(f"Max group size: {self.max_group_size}")
        print(f"unique group sizes = {set([len(group) for group in self.groups])}")



    def read_file(self, file_path):
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            reader = csv.reader(f, delimiter='\t')
            for row in reader:
                if len(row) != 4:
                    continue
                sentence1, sentence2, score_str, url = row
                try:
                    score = int(score_str.strip('()').split(',')[0])
                except:
                    continue
                data.append((sentence1.strip(), sentence2.strip(), score))
        return data


    def group_data(self):
        groups = []
        current_group = []
        last_sentence1 = None
        for item in self.data:
            sentence1, sentence2, score = item
            if sentence1 != last_sentence1 and current_group:
                groups.append(current_group)
                current_group = []
            current_group.append(item)
            last_sentence1 = sentence1
        if current_group:
            groups.append(current_group)
        return groups

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        group = self.groups[idx]
        anchor_text = group[0][0]
        
        # Pad group to max_group_size with dummy entries if needed
        padded_group = group + [("", "", 0)] * (self.max_group_size - len(group))
        
        sentences = [item[1] for item in padded_group]
        scores = torch.tensor([item[2] for item in padded_group], dtype=torch.float32)
        
        anchor_input = self.tokenizer(
            anchor_text, 
            padding='max_length', 
            truncation=True, 
            max_length=self.max_length, 
            return_tensors="pt",
            return_token_type_ids=False  # Add this line
        )
        
        sentence_inputs = [
            self.tokenizer(
                s, 
                padding='max_length', 
                truncation=True, 
                max_length=self.max_length, 
                return_tensors="pt",
                return_token_type_ids=False  # Add this line
            ) for s in sentences
        ]
        
        return {
            'anchor_input': {k: v.squeeze(0) for k, v in anchor_input.items()},
            'sentence_inputs': [{k: v.squeeze(0) for k, v in inp.items()} for inp in sentence_inputs],
            'scores': scores,
            'mask': torch.tensor([1]*len(group) + [0]*(self.max_group_size - len(group)))  # Mask for real vs padded
        }


# -----------------------------
# 2. Model
# -----------------------------
def poincare_project(x, eps=1e-5):
    """Projects vectors into Poincaré ball with proper scaling"""
    norm = torch.norm(x, p=2, dim=-1, keepdim=True)
    scale = (1 - eps) / torch.clamp(norm, min=eps)
    return x * scale
    
class HyperbolicMapper(nn.Module):
    def __init__(self, sbert_model_name='sentence-transformers/all-MiniLM-L6-v2', output_dim=32):
        super(HyperbolicMapper, self).__init__()
        # Frozen SBERT
        self.sbert = AutoModel.from_pretrained(sbert_model_name)
        for param in self.sbert.parameters():
            param.requires_grad = False
        
        sbert_hidden_dim = self.sbert.config.hidden_size  # usually 384 for MiniLM
        self.fc = nn.Linear(sbert_hidden_dim, output_dim)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            sbert_output = self.sbert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embedding = sbert_output.last_hidden_state[:, 0]
        
        mapped = self.fc(cls_embedding)
        return poincare_project(mapped)  # Project after mapping

# -----------------------------
# 3. Margin Ranking Loss and Poincare distance
# -----------------------------

# margin_ranking_loss = nn.MarginRankingLoss(margin=0.2)
def hyperbolic_margin_loss(anchor, positive, negative, margin=0.2):
    pos_dist = poincare_distance(anchor, positive)
    neg_dist = poincare_distance(anchor, negative)
    return F.relu(pos_dist - neg_dist + margin).mean()

def poincare_distance(x, y, eps=1e-5):
    # x: [batch_size, dim]
    # y: [batch_size, dim]
    x2 = torch.sum(x * x, dim=-1, keepdim=True)  # [batch_size, 1]
    y2 = torch.sum(y * y, dim=-1, keepdim=True)  # [batch_size, 1]
    xy = torch.sum(x * y, dim=-1, keepdim=True)  # [batch_size, 1]

    num = torch.norm(x - y, p=2, dim=-1, keepdim=True) ** 2
    denom = (1 - x2) * (1 - y2) + eps

    inside = 1 + 2 * num / denom
    dist = torch.acosh(torch.clamp(inside, min=1 + eps))
    return dist.squeeze(-1)  # [batch_size]
    
# -----------------------------
# 5. Training Step
# -----------------------------

def train_one_epoch(train_loader, model, optimizer, device):
    model.train()
    total_loss = 0
    total_pairs = 0
    
    # Wrap your loader with tqdm for a progress bar
    prog_bar = tqdm(train_loader, desc="Training", leave=False)
    
    for batch in prog_bar:
        optimizer.zero_grad()
        
        anchor_input = {k: v.to(device) for k, v in batch['anchor_input'].items()}
        anchor_embed = model(**anchor_input)
        
        sentence_inputs = {
            k: v.to(device) for k, v in batch['sentence_inputs'].items()
        }
        sentence_embeds = model(**sentence_inputs)
        
        batch_size, max_group_size = batch['scores'].shape
        sentence_embeds = sentence_embeds.view(batch_size, max_group_size, -1)
        
        dists = poincare_distance(
            anchor_embed.unsqueeze(1).expand_as(sentence_embeds),
            sentence_embeds
        ) * batch['mask'].to(device)
        
        # # Flip distance -> similarity
        # sims = -dists
        
        # Vectorized ranking loss calculation
        scores = batch['scores'].to(device)
        mask = batch['mask'].to(device).bool()
        
        # Create all valid pairs
        pos_mask = (scores.unsqueeze(2) > scores.unsqueeze(1)) & mask.unsqueeze(2) & mask.unsqueeze(1)
        
        if pos_mask.any():
            # pos_sims = sims.unsqueeze(2).expand(-1, -1, max_group_size)[pos_mask]
            # neg_sims = sims.unsqueeze(1).expand(-1, max_group_size, -1)[pos_mask]
            
            # loss = F.margin_ranking_loss(
            #     pos_sims,
            #     neg_sims,
            #     torch.ones_like(pos_sims),
            #     margin=0.2,
            #     reduction='mean'
            # )
            
            # loss.backward()
            # optimizer.step()
            # total_loss += loss.item() * pos_mask.sum().item()  # Weight by batch size
            # Get indices of all valid pairs
            anchor_idx, pos_idx, neg_idx = torch.where(pos_mask)
            
            # Compute hyperbolic margin loss directly
            loss = hyperbolic_margin_loss(
                anchor_embed[anchor_idx],          # Anchor embeddings
                sentence_embeds[anchor_idx, pos_idx],  # Positive embeddings
                sentence_embeds[anchor_idx, neg_idx],  # Negative embeddings
                margin=0.2
            )
            
            loss.backward()
            optimizer.step()
            
            # Update tracking
            num_pairs = pos_mask.sum().item()
            total_loss += loss.item() * num_pairs
            total_pairs += num_pairs
            
        prog_bar.set_postfix({
            'loss': loss.item() if pos_mask.any() else 0,
            'pairs': pos_mask.sum().item() if pos_mask.any() else 0
        })
    
    total_samples = len(train_loader.dataset)
    return total_loss / total_samples if total_samples > 0 else 0


def validate(loader, model, device):
    model.eval()
    total_loss = 0
    all_scores = []
    all_sims = []
    
    with torch.no_grad():
        for batch in loader:
            anchor_input = {k: v.to(device) for k, v in batch['anchor_input'].items()}
            anchor_embed = model(**anchor_input)
            
            # Process all sentences (with masking)
            sentence_inputs = {
                k: v.to(device) for k, v in batch['sentence_inputs'].items()
            }
            sentence_embeds = model(**sentence_inputs)
            
            # Reshape and mask
            batch_size = len(batch['scores'])
            max_group_size = batch['scores'].shape[1]
            sentence_embeds = sentence_embeds.view(batch_size, max_group_size, -1)
            mask = batch['mask'].to(device)
            
            # Compute similarities -> Now poincare distance
            dists = poincare_distance(
                anchor_embed.unsqueeze(1).expand_as(sentence_embeds),
                sentence_embeds
            ) * batch['mask'].to(device)
            # Flip distance -> similarity
            sims = -dists
            
            # Collect non-padded scores/similarities for correlation
            scores = batch['scores'].to(device)
            for i in range(batch_size):
                real_indices = torch.where(mask[i] == 1)[0]
                all_scores.extend(scores[i, real_indices].cpu().numpy())
                all_sims.extend(sims[i, real_indices].cpu().numpy())
    
    # Compute Spearman correlation
    if len(all_scores) > 0:
        spearman = spearmanr(all_scores, all_sims).correlation
    else:
        spearman = 0.0
    
    return -spearman  # Lower is better (for consistency with loss)


def collate_fn(batch):
    # Find actual max group size in this batch
    batch_max_len = max(len(item['mask'][item['mask'] == 1]) for item in batch)
    
    # Truncate all inputs to this length
    anchor_inputs = {
        k: torch.stack([item['anchor_input'][k] for item in batch])
        for k in batch[0]['anchor_input']
    }
    
    sentence_inputs = {
        k: torch.stack([
            inp[k] 
            for item in batch 
            for inp in item['sentence_inputs'][:batch_max_len]  # Only take needed elements
        ])
        for k in batch[0]['sentence_inputs'][0]
    }
    
    scores = torch.stack([item['scores'][:batch_max_len] for item in batch])
    masks = torch.stack([item['mask'][:batch_max_len] for item in batch])
    
    return {
        'anchor_input': anchor_inputs,
        'sentence_inputs': sentence_inputs,
        'scores': scores,
        'mask': masks
    }


tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')


def group_aware_split(dataset, test_size=0.1, random_state=42):
    splitter = GroupShuffleSplit(
        n_splits=1, 
        test_size=test_size, 
        random_state=random_state
    )
    train_idx, val_idx = next(splitter.split(
        range(len(dataset)), 
        groups=dataset.group_ids  # Critical: Group by anchor IDs
    ))
    return (
        torch.utils.data.Subset(dataset, train_idx),
        torch.utils.data.Subset(dataset, val_idx)
    )

# Load datasets
full_train_dataset = TwitterURLDataset(f'{base_path}paraphrase_dataset_emnlp2017/Twitter_URL_Corpus_train.txt', tokenizer)
test_dataset = TwitterURLDataset(f'{base_path}paraphrase_dataset_emnlp2017/Twitter_URL_Corpus_test.txt', tokenizer)

# Group-aware split (90% train, 10% val)
train_dataset, val_dataset = group_aware_split(full_train_dataset, test_size=0.1)

# For all splits (train/val/test), use the same collate_fn
train_loader = DataLoader(
    train_dataset, 
    batch_size=32, 
    shuffle=True, 
    collate_fn=collate_fn  # Same as training
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=32, 
    shuffle=False,  # No shuffling for val/test
    collate_fn=collate_fn  # Same collate function!
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=32, 
    shuffle=False,
    collate_fn=collate_fn  # Consistency is key
)

print(f"Train size: {len(train_dataset)}")
print(f"Val size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")




Max group size: 20
unique group sizes = {8, 9, 10, 20}
Max group size: 20
unique group sizes = {10, 20}
Train size: 4197
Val size: 467
Test size: 1010


In [None]:
num_epochs = 50
def train_model(model, label, optimizer):
    best_val_spearman = float('inf')
    log_data = {
        'training_start': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'config': {
            'model': label,
            'num_epochs': num_epochs,
            'optimizer': str(optimizer.__class__.__name__),
            'device': str(device)
        },
        'epochs': []
    }
    for epoch in range(num_epochs):
        # Initialize logging
        epoch_start = datetime.now()
        
        print(f"Epoch {epoch+1}/{num_epochs}")
    
        train_loss = train_one_epoch(train_loader, model, optimizer, device)
        val_spearman = validate(val_loader, model, device)

        # Track epoch data
        epoch_log = {
            'epoch': epoch + 1,
            'train_loss': float(train_loss),
            'val_spearman': float(val_spearman),
            'duration_seconds': (datetime.now() - epoch_start).total_seconds(),
            'is_best': False
        }
        
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Spearman (-ve):   {val_spearman:.4f}")
    
        if val_spearman < best_val_spearman:
            best_val_spearman = val_spearman
            torch.save(model.state_dict(), f'saved_models2/{label}.pt')
            epoch_log['is_best'] = True
            print(f"  -> Best model saved (val_loss improved to {best_val_spearman:.4f})")

        log_data['epochs'].append(epoch_log)
        
        # Save logs after each epoch (in case of crash)
        with open(f'model_logs2/{label}_logs.json', 'w') as f:
            json.dump(log_data, f, indent=2)
    # Add final metadata
    log_data['training_end'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_data['best_val_spearman'] = float(-best_val_spearman)
    
    # Final save
    with open(f'model_logs2/{label}_logs.json', 'w') as f:
        json.dump(log_data, f, indent=2)

hyp_dims = [384,192]
for dim in hyp_dims:
    model = HyperbolicMapper(output_dim = dim).to(device)
    # RiemannianAdam optimizer:
    optimizer = geoopt.optim.RiemannianAdam(
        model.parameters(), 
        lr=1e-4,
        stabilize=1000  # Helps with numerical stability
    )
    train_model(model, f"hyp_{dim}", optimizer)
    test_spearman = -validate(test_loader, model, device) 
    print(f"Test Spearman with {dim} dimensions: {test_spearman:.4f}")
    del model


Epoch 1/50


                                                                                                                                     

Epoch 1/50
  Train Loss: 3.4604
  Val Spearman (-ve):   -0.7376
  -> Best model saved (val_loss improved to -0.7376)
Epoch 2/50


                                                                                                                                     

Epoch 2/50
  Train Loss: 3.2210
  Val Spearman (-ve):   -0.7514
  -> Best model saved (val_loss improved to -0.7514)
Epoch 3/50


                                                                                                                                     

Epoch 3/50
  Train Loss: 3.0613
  Val Spearman (-ve):   -0.7561
  -> Best model saved (val_loss improved to -0.7561)
Epoch 4/50


                                                                                                                                     

Epoch 4/50
  Train Loss: 2.9748
  Val Spearman (-ve):   -0.7588
  -> Best model saved (val_loss improved to -0.7588)
Epoch 5/50


                                                                                                                                     

Epoch 5/50
  Train Loss: 2.9417
  Val Spearman (-ve):   -0.7607
  -> Best model saved (val_loss improved to -0.7607)
Epoch 6/50


                                                                                                                                     

Epoch 6/50
  Train Loss: 2.8734
  Val Spearman (-ve):   -0.7617
  -> Best model saved (val_loss improved to -0.7617)
Epoch 7/50


                                                                                                                                     

Epoch 7/50
  Train Loss: 2.8243
  Val Spearman (-ve):   -0.7637
  -> Best model saved (val_loss improved to -0.7637)
Epoch 8/50


                                                                                                                                     

Epoch 8/50
  Train Loss: 2.7903
  Val Spearman (-ve):   -0.7644
  -> Best model saved (val_loss improved to -0.7644)
Epoch 9/50


                                                                                                                                     

Epoch 9/50
  Train Loss: 2.7194
  Val Spearman (-ve):   -0.7650
  -> Best model saved (val_loss improved to -0.7650)
Epoch 10/50


                                                                                                                                     

Epoch 10/50
  Train Loss: 2.7137
  Val Spearman (-ve):   -0.7661
  -> Best model saved (val_loss improved to -0.7661)
Epoch 11/50


                                                                                                                                     

Epoch 11/50
  Train Loss: 2.7196
  Val Spearman (-ve):   -0.7681
  -> Best model saved (val_loss improved to -0.7681)
Epoch 12/50


                                                                                                                                     

Epoch 12/50
  Train Loss: 2.6753
  Val Spearman (-ve):   -0.7683
  -> Best model saved (val_loss improved to -0.7683)
Epoch 13/50


                                                                                                                                     

Epoch 13/50
  Train Loss: 2.6007
  Val Spearman (-ve):   -0.7675
Epoch 14/50


                                                                                                                                     

Epoch 14/50
  Train Loss: 2.5958
  Val Spearman (-ve):   -0.7685
  -> Best model saved (val_loss improved to -0.7685)
Epoch 15/50


                                                                                                                                     

Epoch 15/50
  Train Loss: 2.5842
  Val Spearman (-ve):   -0.7697
  -> Best model saved (val_loss improved to -0.7697)
Epoch 16/50


                                                                                                                                     

Epoch 16/50
  Train Loss: 2.5757
  Val Spearman (-ve):   -0.7686
Epoch 17/50


                                                                                                                                     

Epoch 17/50
  Train Loss: 2.5401
  Val Spearman (-ve):   -0.7695
Epoch 18/50


                                                                                                                                     

Epoch 18/50
  Train Loss: 2.5200
  Val Spearman (-ve):   -0.7709
  -> Best model saved (val_loss improved to -0.7709)
Epoch 19/50


                                                                                                                                     

Epoch 19/50
  Train Loss: 2.4976
  Val Spearman (-ve):   -0.7700
Epoch 20/50


                                                                                                                                     

Epoch 20/50
  Train Loss: 2.5061
  Val Spearman (-ve):   -0.7706
Epoch 21/50


                                                                                                                                     

Epoch 21/50
  Train Loss: 2.4809
  Val Spearman (-ve):   -0.7707
Epoch 22/50


                                                                                                                                     

Epoch 22/50
  Train Loss: 2.4472
  Val Spearman (-ve):   -0.7715
  -> Best model saved (val_loss improved to -0.7715)
Epoch 23/50


                                                                                                                                     

Epoch 23/50
  Train Loss: 2.4398
  Val Spearman (-ve):   -0.7716
  -> Best model saved (val_loss improved to -0.7716)
Epoch 24/50


                                                                                                                                     

Epoch 24/50
  Train Loss: 2.4362
  Val Spearman (-ve):   -0.7716
Epoch 25/50


                                                                                                                                     

Epoch 25/50
  Train Loss: 2.4223
  Val Spearman (-ve):   -0.7709
Epoch 26/50


                                                                                                                                     

Epoch 26/50
  Train Loss: 2.4202
  Val Spearman (-ve):   -0.7703
Epoch 27/50


                                                                                                                                     

Epoch 27/50
  Train Loss: 2.4073
  Val Spearman (-ve):   -0.7716
Epoch 28/50


                                                                                                                                     

Epoch 28/50
  Train Loss: 2.4113
  Val Spearman (-ve):   -0.7719
  -> Best model saved (val_loss improved to -0.7719)
Epoch 29/50


                                                                                                                                     

Epoch 29/50
  Train Loss: 2.3804
  Val Spearman (-ve):   -0.7722
  -> Best model saved (val_loss improved to -0.7722)
Epoch 30/50


                                                                                                                                     

Epoch 30/50
  Train Loss: 2.3722
  Val Spearman (-ve):   -0.7721
Epoch 31/50


                                                                                                                                     

Epoch 31/50
  Train Loss: 2.3652
  Val Spearman (-ve):   -0.7725
  -> Best model saved (val_loss improved to -0.7725)
Epoch 32/50


                                                                                                                                     

Epoch 32/50
  Train Loss: 2.3597
  Val Spearman (-ve):   -0.7734
  -> Best model saved (val_loss improved to -0.7734)
Epoch 33/50


                                                                                                                                     

Epoch 33/50
  Train Loss: 2.3715
  Val Spearman (-ve):   -0.7732
Epoch 34/50


                                                                                                                                     

Epoch 34/50
  Train Loss: 2.3679
  Val Spearman (-ve):   -0.7728
Epoch 35/50


                                                                                                                                     

Epoch 35/50
  Train Loss: 2.3442
  Val Spearman (-ve):   -0.7733
Epoch 36/50


                                                                                                                                     

Epoch 36/50
  Train Loss: 2.3448
  Val Spearman (-ve):   -0.7738
  -> Best model saved (val_loss improved to -0.7738)
Epoch 37/50


                                                                                                                                     

Epoch 37/50
  Train Loss: 2.3292
  Val Spearman (-ve):   -0.7731
Epoch 38/50


                                                                                                                                     

Epoch 38/50
  Train Loss: 2.2948
  Val Spearman (-ve):   -0.7734
Epoch 39/50


                                                                                                                                     

Epoch 39/50
  Train Loss: 2.3191
  Val Spearman (-ve):   -0.7733
Epoch 40/50


                                                                                                                                     

Epoch 40/50
  Train Loss: 2.3089
  Val Spearman (-ve):   -0.7739
  -> Best model saved (val_loss improved to -0.7739)
Epoch 41/50


                                                                                                                                     

Epoch 41/50
  Train Loss: 2.2856
  Val Spearman (-ve):   -0.7744
  -> Best model saved (val_loss improved to -0.7744)
Epoch 42/50


                                                                                                                                     

Epoch 42/50
  Train Loss: 2.3001
  Val Spearman (-ve):   -0.7739
Epoch 43/50


                                                                                                                                     

Epoch 43/50
  Train Loss: 2.2929
  Val Spearman (-ve):   -0.7741
Epoch 44/50


                                                                                                                                     

Epoch 44/50
  Train Loss: 2.2816
  Val Spearman (-ve):   -0.7744
Epoch 45/50


                                                                                                                                     

Epoch 45/50
  Train Loss: 2.2919
  Val Spearman (-ve):   -0.7740
Epoch 46/50


                                                                                                                                     

Epoch 46/50
  Train Loss: 2.2771
  Val Spearman (-ve):   -0.7723
Epoch 47/50


                                                                                                                                     

Epoch 47/50
  Train Loss: 2.2769
  Val Spearman (-ve):   -0.7735
Epoch 48/50


                                                                                                                                     

Epoch 48/50
  Train Loss: 2.2843
  Val Spearman (-ve):   -0.7725
Epoch 49/50


                                                                                                                                     

Epoch 49/50
  Train Loss: 2.2536
  Val Spearman (-ve):   -0.7730
Epoch 50/50


                                                                                                                                     

Epoch 50/50
  Train Loss: 2.2683
  Val Spearman (-ve):   -0.7737
Test Spearman with 384 dimensions: 0.7056
Epoch 1/50


                                                                                                                                     

Epoch 1/50
  Train Loss: 3.5880
  Val Spearman (-ve):   -0.7361
  -> Best model saved (val_loss improved to -0.7361)
Epoch 2/50


                                                                                                                                     

Epoch 2/50
  Train Loss: 3.3029
  Val Spearman (-ve):   -0.7497
  -> Best model saved (val_loss improved to -0.7497)
Epoch 3/50


                                                                                                                                     

Epoch 3/50
  Train Loss: 3.1907
  Val Spearman (-ve):   -0.7542
  -> Best model saved (val_loss improved to -0.7542)
Epoch 4/50


                                                                                                                                     

Epoch 4/50
  Train Loss: 3.1207
  Val Spearman (-ve):   -0.7561
  -> Best model saved (val_loss improved to -0.7561)
Epoch 5/50


                                                                                                                                     

Epoch 5/50
  Train Loss: 3.0124
  Val Spearman (-ve):   -0.7588
  -> Best model saved (val_loss improved to -0.7588)
Epoch 6/50


                                                                                                                                     

Epoch 6/50
  Train Loss: 2.9925
  Val Spearman (-ve):   -0.7599
  -> Best model saved (val_loss improved to -0.7599)
Epoch 7/50


                                                                                                                                     

Epoch 7/50
  Train Loss: 2.9343
  Val Spearman (-ve):   -0.7617
  -> Best model saved (val_loss improved to -0.7617)
Epoch 8/50


                                                                                                                                     

Epoch 8/50
  Train Loss: 2.8594
  Val Spearman (-ve):   -0.7632
  -> Best model saved (val_loss improved to -0.7632)
Epoch 9/50


                                                                                                                                     

Epoch 9/50
  Train Loss: 2.8081
  Val Spearman (-ve):   -0.7645
  -> Best model saved (val_loss improved to -0.7645)
Epoch 10/50


                                                                                                                                     

Epoch 10/50
  Train Loss: 2.7777
  Val Spearman (-ve):   -0.7650
  -> Best model saved (val_loss improved to -0.7650)
Epoch 11/50


Training:  95%|███████████████████████████████████████████████████████▊   | 125/132 [01:02<00:02,  2.53it/s, loss=0.0827, pairs=1085]

In [3]:
# Final tests

hyp_dims = [384,192,96,64,48,32,24,16,12,8]
for dim in hyp_dims:
    model = HyperbolicMapper(output_dim = dim).to(device)
    model.load_state_dict(torch.load(f'saved_models2/hyp_{dim}.pt'))
    test_spearman = -validate(test_loader, model, device)  # Flip sign back
    print(f"Test Spearman for {dim}: {test_spearman:.4f}")
    del model



  model.load_state_dict(torch.load(f'saved_models2/hyp_{dim}.pt'))


Test Spearman for 384: 0.7062
Test Spearman for 192: 0.7050
Test Spearman for 96: 0.7007
Test Spearman for 64: 0.6937
Test Spearman for 48: 0.6866
Test Spearman for 32: 0.6760
Test Spearman for 24: 0.6663
Test Spearman for 16: 0.6450
Test Spearman for 12: 0.6252
Test Spearman for 8: 0.5927
