In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import csv
import os
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import random_split
from sklearn.model_selection import GroupShuffleSplit
from scipy.stats import spearmanr
from tqdm import tqdm  # make sure you have tqdm installed
import geoopt
from datetime import datetime
import json
import numpy as np
# there was an error with geoopt scipy version compatibility, it was fixed by-
# nano /cronus_data/vraja/dysarthria/lib/python3.11/site-packages/geoopt/optim/rlinesearch.py
# added this to the import -
# try:
#     from scipy.optimize.linesearch import scalar_search_wolfe2, scalar_search_armijo
# except ImportError:
#     from scipy.optimize._linesearch import scalar_search_wolfe2, scalar_search_armijo
torch.manual_seed(42)  # Before random_split
base_path = ""

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# -----------------------------
# 1. Dataset Class
# -----------------------------

class TwitterURLDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.data = self.read_file(file_path)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.groups = self.group_data()
        self.group_ids = [i for i, _ in enumerate(self.groups)]
        
        # Pre-compute maximum group size
        self.max_group_size = max(len(group) for group in self.groups)  # Cap at 10
        print(f"Max group size: {self.max_group_size}")
        print(f"unique group sizes = {set([len(group) for group in self.groups])}")



    def read_file(self, file_path):
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            reader = csv.reader(f, delimiter='\t')
            for row in reader:
                if len(row) != 4:
                    continue
                sentence1, sentence2, score_str, url = row
                try:
                    score = int(score_str.strip('()').split(',')[0])
                except:
                    continue
                data.append((sentence1.strip(), sentence2.strip(), score))
        return data


    def group_data(self):
        groups = []
        current_group = []
        last_sentence1 = None
        for item in self.data:
            sentence1, sentence2, score = item
            if sentence1 != last_sentence1 and current_group:
                groups.append(current_group)
                current_group = []
            current_group.append(item)
            last_sentence1 = sentence1
        if current_group:
            groups.append(current_group)
        return groups

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        group = self.groups[idx]
        anchor_text = group[0][0]
        
        # Pad group to max_group_size with dummy entries if needed
        padded_group = group + [("", "", 0)] * (self.max_group_size - len(group))
        
        sentences = [item[1] for item in padded_group]
        scores = torch.tensor([item[2] for item in padded_group], dtype=torch.float32)
        
        anchor_input = self.tokenizer(
            anchor_text, 
            padding='max_length', 
            truncation=True, 
            max_length=self.max_length, 
            return_tensors="pt",
            return_token_type_ids=False  # Add this line
        )
        
        sentence_inputs = [
            self.tokenizer(
                s, 
                padding='max_length', 
                truncation=True, 
                max_length=self.max_length, 
                return_tensors="pt",
                return_token_type_ids=False  # Add this line
            ) for s in sentences
        ]
        
        return {
            'anchor_input': {k: v.squeeze(0) for k, v in anchor_input.items()},
            'sentence_inputs': [{k: v.squeeze(0) for k, v in inp.items()} for inp in sentence_inputs],
            'scores': scores,
            'mask': torch.tensor([1]*len(group) + [0]*(self.max_group_size - len(group)))  # Mask for real vs padded
        }


# -----------------------------
# 2. Model
# -----------------------------    
class HyperbolicMapper(nn.Module):
    def __init__(self, sbert_model_name='sentence-transformers/all-MiniLM-L6-v2', output_dim=32):
        super(HyperbolicMapper, self).__init__()
        # Frozen SBERT
        self.sbert = AutoModel.from_pretrained(sbert_model_name)
        for param in self.sbert.parameters():
            param.requires_grad = False
        
        sbert_hidden_dim = self.sbert.config.hidden_size  # usually 384 for MiniLM
        self.curvature = nn.Parameter(torch.tensor(1.0))
        self.temperature = nn.Parameter(torch.tensor(1.0))
        # self.fc = nn.Linear(sbert_hidden_dim, output_dim)
        self.projection = nn.Sequential(
            # nn.LayerNorm(sbert_hidden_dim),
            # nn.Linear(sbert_hidden_dim, sbert_hidden_dim//2),
            # nn.GELU(),
            # nn.Linear(sbert_hidden_dim//2, output_dim))
            nn.Linear(sbert_hidden_dim, output_dim))
        print("initialized model")

    def poincare_project(self, x):
        x = x / self.temperature
        norm = torch.norm(x, p=2, dim=-1, keepdim=True)
        scale = (1 - 1e-5) / torch.clamp(norm * torch.sqrt(self.curvature), min=1e-5)
        return x * scale
        
    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            sbert_output = self.sbert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embedding = sbert_output.last_hidden_state[:, 0]
        
        projected = self.projection(cls_embedding)
        return self.poincare_project(projected)

# -----------------------------
# 3. Margin Ranking Loss and Poincare distance
# -----------------------------

# margin_ranking_loss = nn.MarginRankingLoss(margin=0.2)
def hyperbolic_margin_loss(anchor, positive, negative, margin=0.2, 
                          curvature=1.0, temperature=1.0):
    # Scale margin by temperature
    effective_margin = margin * temperature

    # print("TRAIN:anchor.shape = ",anchor.shape)
    # print("TRAIN:negative.shape = ",negative.shape)
    # print("TRAIN:positive.shape = ",positive.shape)
    pos_dist = poincare_distance(anchor/temperature, 
                               positive/temperature, 
                               curvature)
    neg_dist = poincare_distance(anchor/temperature,
                               negative/temperature,
                               curvature)
    
    return F.relu(pos_dist - neg_dist + effective_margin).mean()

def poincare_distance(x, y, curvature=1.0, eps=1e-5):
    """Batch-supported Poincaré distance with curvature"""
    sqrt_c = torch.sqrt(curvature + eps)

    # Ensure same batch shape if needed
    if x.dim() == 2 and y.dim() == 3:
        # x: [B, D] -> [B, 1, D]
        x = x.unsqueeze(1)
    elif x.dim() == 2 and y.dim() == 2:
        # x: [B, D], y: [B, D] — no reshaping needed
        pass
    else:
        raise ValueError(f"Incompatible shapes: x {x.shape}, y {y.shape}")
    
    # Compute norms
    x_norm = torch.norm(x, p=2, dim=-1, keepdim=True) * sqrt_c  # [B, 1, 1]
    y_norm = torch.norm(y, p=2, dim=-1, keepdim=True) * sqrt_c  # [B, G, 1]
    # Pairwise distances
    pairwise_norm = torch.norm(x - y, p=2, dim=-1, keepdim=True) * sqrt_c  # [B, G, 1]
    
    # # Hyperbolic distance calculation
    # denominator = (1 - curvature * x_norm**2) * (1 - curvature * y_norm**2)
    # inside = 1 + 2 * curvature * pairwise_norm**2 / (denominator.clamp(min=eps))
    # return torch.acosh(torch.clamp(inside, min=1+eps)) / (sqrt_c + eps)  # [B, G, 1] -> [B, G]
    # Distance
    denominator = (1 - curvature * x_norm**2) * (1 - curvature * y_norm**2)
    inside = 1 + 2 * curvature * pairwise_norm**2 / (denominator.clamp(min=eps))
    return torch.acosh(torch.clamp(inside, min=1+eps)).squeeze(-1) / (sqrt_c + eps)
    
# -----------------------------
# 5. Training Step
# -----------------------------

def train_one_epoch(train_loader, model, optimizer, device):
    model.train()
    total_loss = 0
    total_pairs = 0
    
    # Wrap your loader with tqdm for a progress bar
    prog_bar = tqdm(train_loader, desc="Training", leave=False)
    
    for batch in prog_bar:
        optimizer.zero_grad()
        
        anchor_input = {k: v.to(device) for k, v in batch['anchor_input'].items()}
        anchor_embed = model(**anchor_input)
        
        sentence_inputs = {k: v.to(device) for k, v in batch['sentence_inputs'].items()}
        sentence_embeds = model(**sentence_inputs)
        
        batch_size, max_group_size = batch['scores'].shape
        sentence_embeds = sentence_embeds.view(batch_size, max_group_size, -1)
        scores = batch['scores'].to(device)
        mask = batch['mask'].to(device).bool()       
        
        # Create all valid pairs
        pos_mask = (scores.unsqueeze(2) > scores.unsqueeze(1)) & mask.unsqueeze(2) & mask.unsqueeze(1)
        
        if pos_mask.any():
            # Get indices of all valid pairs
            anchor_idx, pos_idx, neg_idx = torch.where(pos_mask)
            
            # Compute hyperbolic margin loss directly
            loss = hyperbolic_margin_loss(
                anchor_embed[anchor_idx],          # Anchor embeddings
                sentence_embeds[anchor_idx, pos_idx],  # Positive embeddings
                sentence_embeds[anchor_idx, neg_idx],  # Negative embeddings
                margin=0.2,
                curvature=model.curvature,
                temperature=model.temperature
            )
            
            loss.backward()
            optimizer.step()
            
            # Update tracking
            num_pairs = pos_mask.sum().item()
            total_loss += loss.item() * num_pairs
            total_pairs += num_pairs
            
        prog_bar.set_postfix({
            'loss': loss.item() if pos_mask.any() else 0,
            'pairs': pos_mask.sum().item() if pos_mask.any() else 0,
            'curv': model.curvature.item(),
            'temp': model.temperature.item()
        })
    
    # Return average loss per pair
    return total_loss / total_pairs if total_pairs > 0 else 0


def validate(loader, model, device):
    model.eval()
    # total_loss = 0
    # total_pairs = 0
    all_anchor_spearmans = []
    
    with torch.no_grad():
        for batch in loader:
            anchor_input = {k: v.to(device) for k, v in batch['anchor_input'].items()}
            anchor_embed = model(**anchor_input)
            
            # Process all sentences (with masking)
            sentence_inputs = {k: v.to(device) for k, v in batch['sentence_inputs'].items()}
            sentence_embeds = model(**sentence_inputs)

            # Reshape and mask
            batch_size, max_group_size = batch['scores'].shape
            sentence_embeds = sentence_embeds.view(batch_size, max_group_size, -1)
            mask = batch['mask'].to(device)
            scores = batch['scores'].to(device)

            # print("anchor_embed.shape = ",anchor_embed.shape)
            # print("sentence_embeds.shape = ",sentence_embeds.shape)
            # Compute similarities -> Now poincare distance
            dists = poincare_distance(
                anchor_embed, #.unsqueeze(1), #.expand_as(sentence_embeds),
                sentence_embeds,
                curvature=model.curvature
            ) * mask
            # Flip distance -> similarity
            sims = -dists

            for i in range(batch_size):
                valid_idx = torch.where(mask[i] == 1)[0]
                if len(valid_idx) < 2:  # Need at least 2 for correlation
                    continue

                # Calculate per-anchor Spearman
                anchor_scores = scores[i, valid_idx].cpu().numpy()
                anchor_preds = sims[i, valid_idx].cpu().numpy()

                if len(np.unique(anchor_scores)) > 1:  # Check for rankable scores
                    spear = spearmanr(anchor_scores, anchor_preds).correlation
                    all_anchor_spearmans.append(spear)

                # # Validation loss part. Im just adding cause it was also in training. Dont really need it
                # pos_mask = (scores[i].unsqueeze(1) > scores[i].unsqueeze(0)) & mask[i].unsqueeze(1) & mask[i].unsqueeze(0)
                # if pos_mask.any():
                #     a_idx, p_idx, n_idx = torch.where(pos_mask)
                #     loss = hyperbolic_margin_loss(
                #         anchor_embed[i].unsqueeze(0).expand(len(a_idx), -1),
                #         sentence_embeds[i, p_idx],
                #         sentence_embeds[i, n_idx],
                #         margin=0.2,
                #         curvature=model.curvature,
                #         temperature=model.temperature
                #     )
                #     total_loss += loss.item() * len(a_idx)
                #     total_pairs += len(a_idx)
    
    # Compute metrics
    mean_spearman = np.mean(all_anchor_spearmans) if all_anchor_spearmans else 0.0
    # avg_loss = total_loss / total_pairs if total_pairs > 0 else 0.0
    
    return mean_spearman


def collate_fn(batch):
    # Find actual max group size in this batch
    batch_max_len = max(len(item['mask'][item['mask'] == 1]) for item in batch)
    
    # Truncate all inputs to this length
    anchor_inputs = {
        k: torch.stack([item['anchor_input'][k] for item in batch])
        for k in batch[0]['anchor_input']
    }
    
    sentence_inputs = {
        k: torch.stack([
            inp[k] 
            for item in batch 
            for inp in item['sentence_inputs'][:batch_max_len]  # Only take needed elements
        ])
        for k in batch[0]['sentence_inputs'][0]
    }
    
    scores = torch.stack([item['scores'][:batch_max_len] for item in batch])
    masks = torch.stack([item['mask'][:batch_max_len] for item in batch])
    
    return {
        'anchor_input': anchor_inputs,
        'sentence_inputs': sentence_inputs,
        'scores': scores,
        'mask': masks
    }


tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')


def group_aware_split(dataset, test_size=0.1, random_state=42):
    splitter = GroupShuffleSplit(
        n_splits=1, 
        test_size=test_size, 
        random_state=random_state
    )
    train_idx, val_idx = next(splitter.split(
        range(len(dataset)), 
        groups=dataset.group_ids  # Critical: Group by anchor IDs
    ))
    return (
        torch.utils.data.Subset(dataset, train_idx),
        torch.utils.data.Subset(dataset, val_idx)
    )

# Load datasets
full_train_dataset = TwitterURLDataset(f'{base_path}paraphrase_dataset_emnlp2017/Twitter_URL_Corpus_train.txt', tokenizer)
test_dataset = TwitterURLDataset(f'{base_path}paraphrase_dataset_emnlp2017/Twitter_URL_Corpus_test.txt', tokenizer)

# Group-aware split (90% train, 10% val)
train_dataset, val_dataset = group_aware_split(full_train_dataset, test_size=0.1)

# For all splits (train/val/test), use the same collate_fn
train_loader = DataLoader(
    train_dataset, 
    batch_size=32, 
    shuffle=True, 
    collate_fn=collate_fn  # Same as training
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=32, 
    shuffle=False,  # No shuffling for val/test
    collate_fn=collate_fn  # Same collate function!
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=32, 
    shuffle=False,
    collate_fn=collate_fn  # Consistency is key
)

print(f"Train size: {len(train_dataset)}")
print(f"Val size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")




Max group size: 20
unique group sizes = {8, 9, 10, 20}
Max group size: 20
unique group sizes = {10, 20}
Train size: 4197
Val size: 467
Test size: 1010


In [2]:
num_epochs = 40
def train_model(model, label, optimizer):
    best_val_spearman = float('-inf')
    log_data = {
        'training_start': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'config': {
            'model': label,
            'num_epochs': num_epochs,
            'optimizer': str(optimizer.__class__.__name__),
            'device': str(device)
        },
        'epochs': []
    }
    for epoch in range(num_epochs):
        # Initialize logging
        epoch_start = datetime.now()
        
        print(f"Epoch {epoch+1}/{num_epochs}")
    
        train_loss = train_one_epoch(train_loader, model, optimizer, device)
        val_spearman = validate(val_loader, model, device)

        # Track epoch data
        epoch_log = {
            'epoch': epoch + 1,
            'train_loss': float(train_loss),
            'val_spearman': float(val_spearman),
            'duration_seconds': (datetime.now() - epoch_start).total_seconds(),
            'is_best': False
        }
        
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Spearman (-ve):   {val_spearman:.4f}")
    
        if val_spearman > best_val_spearman:
            best_val_spearman = val_spearman
            torch.save(model.state_dict(), f'saved_models3/{label}.pt')
            epoch_log['is_best'] = True
            print(f"  -> Best model saved (val_loss improved to {best_val_spearman:.4f})")

        log_data['epochs'].append(epoch_log)
        
        # Save logs after each epoch (in case of crash)
        with open(f'model_logs3/{label}_logs.json', 'w') as f:
            json.dump(log_data, f, indent=2)
    # Add final metadata
    log_data['training_end'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_data['best_val_spearman'] = float(best_val_spearman)
    
    # Final save
    with open(f'model_logs3/{label}_logs.json', 'w') as f:
        json.dump(log_data, f, indent=2)

hyp_dims = [384,192]
for dim in hyp_dims:
    print("Initiliazing model")
    model = HyperbolicMapper(output_dim = dim).to(device)
    # RiemannianAdam optimizer:
    optimizer = geoopt.optim.RiemannianAdam(
        model.parameters(), 
        lr=1e-4,
        stabilize=1000  # Helps with numerical stability
    )
    print("Starting training")
    train_model(model, f"hyp_{dim}", optimizer)
    test_spearman = validate(test_loader, model, device)
    print(f"Test Spearman with {dim} dimensions: {test_spearman:.4f}")
    del model


Initiliazing model
initialized model
Starting training
Epoch 1/40


                                                                                                                                     

Epoch 1/40
  Train Loss: 0.1083
  Val Spearman (-ve):   0.6095
  -> Best model saved (val_loss improved to 0.6095)
Epoch 2/40


                                                                                                                                     

Epoch 2/40
  Train Loss: 0.0996
  Val Spearman (-ve):   0.6217
  -> Best model saved (val_loss improved to 0.6217)
Epoch 3/40


                                                                                                                                     

Epoch 3/40
  Train Loss: 0.0935
  Val Spearman (-ve):   0.6258
  -> Best model saved (val_loss improved to 0.6258)
Epoch 4/40


                                                                                                                                     

Epoch 4/40
  Train Loss: 0.0897
  Val Spearman (-ve):   0.6287
  -> Best model saved (val_loss improved to 0.6287)
Epoch 5/40


                                                                                                                                     

Epoch 5/40
  Train Loss: 0.0875
  Val Spearman (-ve):   0.6324
  -> Best model saved (val_loss improved to 0.6324)
Epoch 6/40


                                                                                                                                     

Epoch 6/40
  Train Loss: 0.0843
  Val Spearman (-ve):   0.6354
  -> Best model saved (val_loss improved to 0.6354)
Epoch 7/40


                                                                                                                                     

Epoch 7/40
  Train Loss: 0.0817
  Val Spearman (-ve):   0.6391
  -> Best model saved (val_loss improved to 0.6391)
Epoch 8/40


                                                                                                                                     

Epoch 8/40
  Train Loss: 0.0795
  Val Spearman (-ve):   0.6414
  -> Best model saved (val_loss improved to 0.6414)
Epoch 9/40


                                                                                                                                     

Epoch 9/40
  Train Loss: 0.0763
  Val Spearman (-ve):   0.6426
  -> Best model saved (val_loss improved to 0.6426)
Epoch 10/40


                                                                                                                                     

Epoch 10/40
  Train Loss: 0.0750
  Val Spearman (-ve):   0.6444
  -> Best model saved (val_loss improved to 0.6444)
Epoch 11/40


                                                                                                                                     

Epoch 11/40
  Train Loss: 0.0740
  Val Spearman (-ve):   0.6432
Epoch 12/40


                                                                                                                                     

Epoch 12/40
  Train Loss: 0.0716
  Val Spearman (-ve):   0.6445
  -> Best model saved (val_loss improved to 0.6445)
Epoch 13/40


                                                                                                                                     

Epoch 13/40
  Train Loss: 0.0685
  Val Spearman (-ve):   0.6450
  -> Best model saved (val_loss improved to 0.6450)
Epoch 14/40


                                                                                                                                     

Epoch 14/40
  Train Loss: 0.0673
  Val Spearman (-ve):   0.6442
Epoch 15/40


                                                                                                                                     

Epoch 15/40
  Train Loss: 0.0659
  Val Spearman (-ve):   0.6456
  -> Best model saved (val_loss improved to 0.6456)
Epoch 16/40


                                                                                                                                     

Epoch 16/40
  Train Loss: 0.0646
  Val Spearman (-ve):   0.6435
Epoch 17/40


                                                                                                                                     

Epoch 17/40
  Train Loss: 0.0626
  Val Spearman (-ve):   0.6449
Epoch 18/40


                                                                                                                                     

Epoch 18/40
  Train Loss: 0.0610
  Val Spearman (-ve):   0.6468
  -> Best model saved (val_loss improved to 0.6468)
Epoch 19/40


                                                                                                                                     

Epoch 19/40
  Train Loss: 0.0595
  Val Spearman (-ve):   0.6441
Epoch 20/40


                                                                                                                                     

Epoch 20/40
  Train Loss: 0.0586
  Val Spearman (-ve):   0.6474
  -> Best model saved (val_loss improved to 0.6474)
Epoch 21/40


                                                                                                                                     

Epoch 21/40
  Train Loss: 0.0570
  Val Spearman (-ve):   0.6475
  -> Best model saved (val_loss improved to 0.6475)
Epoch 22/40


                                                                                                                                     

Epoch 22/40
  Train Loss: 0.0553
  Val Spearman (-ve):   0.6485
  -> Best model saved (val_loss improved to 0.6485)
Epoch 23/40


                                                                                                                                     

Epoch 23/40
  Train Loss: 0.0541
  Val Spearman (-ve):   0.6490
  -> Best model saved (val_loss improved to 0.6490)
Epoch 24/40


                                                                                                                                     

Epoch 24/40
  Train Loss: 0.0530
  Val Spearman (-ve):   0.6491
  -> Best model saved (val_loss improved to 0.6491)
Epoch 25/40


                                                                                                                                     

Epoch 25/40
  Train Loss: 0.0518
  Val Spearman (-ve):   0.6488
Epoch 26/40


                                                                                                                                     

Epoch 26/40
  Train Loss: 0.0508
  Val Spearman (-ve):   0.6487
Epoch 27/40


                                                                                                                                     

Epoch 27/40
  Train Loss: 0.0495
  Val Spearman (-ve):   0.6502
  -> Best model saved (val_loss improved to 0.6502)
Epoch 28/40


                                                                                                                                     

Epoch 28/40
  Train Loss: 0.0486
  Val Spearman (-ve):   0.6513
  -> Best model saved (val_loss improved to 0.6513)
Epoch 29/40


                                                                                                                                     

Epoch 29/40
  Train Loss: 0.0470
  Val Spearman (-ve):   0.6508
Epoch 30/40


                                                                                                                                     

Epoch 30/40
  Train Loss: 0.0459
  Val Spearman (-ve):   0.6508
Epoch 31/40


                                                                                                                                     

Epoch 31/40
  Train Loss: 0.0449
  Val Spearman (-ve):   0.6521
  -> Best model saved (val_loss improved to 0.6521)
Epoch 32/40


                                                                                                                                     

Epoch 32/40
  Train Loss: 0.0438
  Val Spearman (-ve):   0.6501
Epoch 33/40


                                                                                                                                     

Epoch 33/40
  Train Loss: 0.0431
  Val Spearman (-ve):   0.6513
Epoch 34/40


                                                                                                                                     

Epoch 34/40
  Train Loss: 0.0420
  Val Spearman (-ve):   0.6518
Epoch 35/40


                                                                                                                                     

Epoch 35/40
  Train Loss: 0.0407
  Val Spearman (-ve):   0.6516
Epoch 36/40


                                                                                                                                     

Epoch 36/40
  Train Loss: 0.0398
  Val Spearman (-ve):   0.6502
Epoch 37/40


                                                                                                                                     

Epoch 37/40
  Train Loss: 0.0386
  Val Spearman (-ve):   0.6503
Epoch 38/40


                                                                                                                                     

Epoch 38/40
  Train Loss: 0.0372
  Val Spearman (-ve):   0.6502
Epoch 39/40


                                                                                                                                     

Epoch 39/40
  Train Loss: 0.0366
  Val Spearman (-ve):   0.6497
Epoch 40/40


                                                                                                                                     

Epoch 40/40
  Train Loss: 0.0356
  Val Spearman (-ve):   0.6503
Test Spearman with 384 dimensions: 0.6252
Initiliazing model
initialized model
Starting training
Epoch 1/40


                                                                                                                                     

Epoch 1/40
  Train Loss: 0.1114
  Val Spearman (-ve):   0.5902
  -> Best model saved (val_loss improved to 0.5902)
Epoch 2/40


                                                                                                                                     

Epoch 2/40
  Train Loss: 0.1032
  Val Spearman (-ve):   0.6039
  -> Best model saved (val_loss improved to 0.6039)
Epoch 3/40


                                                                                                                                     

Epoch 3/40
  Train Loss: 0.0972
  Val Spearman (-ve):   0.6115
  -> Best model saved (val_loss improved to 0.6115)
Epoch 4/40


                                                                                                                                     

Epoch 4/40
  Train Loss: 0.0931
  Val Spearman (-ve):   0.6168
  -> Best model saved (val_loss improved to 0.6168)
Epoch 5/40


                                                                                                                                     

Epoch 5/40
  Train Loss: 0.0904
  Val Spearman (-ve):   0.6213
  -> Best model saved (val_loss improved to 0.6213)
Epoch 6/40


                                                                                                                                     

Epoch 6/40
  Train Loss: 0.0869
  Val Spearman (-ve):   0.6216
  -> Best model saved (val_loss improved to 0.6216)
Epoch 7/40


                                                                                                                                     

Epoch 7/40
  Train Loss: 0.0843
  Val Spearman (-ve):   0.6231
  -> Best model saved (val_loss improved to 0.6231)
Epoch 8/40


                                                                                                                                     

Epoch 8/40
  Train Loss: 0.0817
  Val Spearman (-ve):   0.6223
Epoch 9/40


                                                                                                                                     

Epoch 9/40
  Train Loss: 0.0808
  Val Spearman (-ve):   0.6221
Epoch 10/40


                                                                                                                                     

Epoch 10/40
  Train Loss: 0.0771
  Val Spearman (-ve):   0.6266
  -> Best model saved (val_loss improved to 0.6266)
Epoch 11/40


                                                                                                                                     

Epoch 11/40
  Train Loss: 0.0749
  Val Spearman (-ve):   0.6256
Epoch 12/40


                                                                                                                                     

Epoch 12/40
  Train Loss: 0.0735
  Val Spearman (-ve):   0.6265
Epoch 13/40


                                                                                                                                     

Epoch 13/40
  Train Loss: 0.0712
  Val Spearman (-ve):   0.6302
  -> Best model saved (val_loss improved to 0.6302)
Epoch 14/40


                                                                                                                                     

Epoch 14/40
  Train Loss: 0.0701
  Val Spearman (-ve):   0.6302
Epoch 15/40


                                                                                                                                     

Epoch 15/40
  Train Loss: 0.0674
  Val Spearman (-ve):   0.6297
Epoch 16/40


                                                                                                                                     

Epoch 16/40
  Train Loss: 0.0667
  Val Spearman (-ve):   0.6319
  -> Best model saved (val_loss improved to 0.6319)
Epoch 17/40


                                                                                                                                     

Epoch 17/40
  Train Loss: 0.0644
  Val Spearman (-ve):   0.6331
  -> Best model saved (val_loss improved to 0.6331)
Epoch 18/40


                                                                                                                                     

Epoch 18/40
  Train Loss: 0.0623
  Val Spearman (-ve):   0.6323
Epoch 19/40


                                                                                                                                     

Epoch 19/40
  Train Loss: 0.0610
  Val Spearman (-ve):   0.6354
  -> Best model saved (val_loss improved to 0.6354)
Epoch 20/40


                                                                                                                                     

Epoch 20/40
  Train Loss: 0.0595
  Val Spearman (-ve):   0.6345
Epoch 21/40


                                                                                                                                     

Epoch 21/40
  Train Loss: 0.0586
  Val Spearman (-ve):   0.6343
Epoch 22/40


                                                                                                                                     

Epoch 22/40
  Train Loss: 0.0568
  Val Spearman (-ve):   0.6353
Epoch 23/40


                                                                                                                                     

Epoch 23/40
  Train Loss: 0.0557
  Val Spearman (-ve):   0.6354
Epoch 24/40


                                                                                                                                     

Epoch 24/40
  Train Loss: 0.0547
  Val Spearman (-ve):   0.6363
  -> Best model saved (val_loss improved to 0.6363)
Epoch 25/40


                                                                                                                                     

Epoch 25/40
  Train Loss: 0.0528
  Val Spearman (-ve):   0.6364
  -> Best model saved (val_loss improved to 0.6364)
Epoch 26/40


                                                                                                                                     

Epoch 26/40
  Train Loss: 0.0514
  Val Spearman (-ve):   0.6359
Epoch 27/40


                                                                                                                                     

Epoch 27/40
  Train Loss: 0.0506
  Val Spearman (-ve):   0.6358
Epoch 28/40


                                                                                                                                     

Epoch 28/40
  Train Loss: 0.0495
  Val Spearman (-ve):   0.6380
  -> Best model saved (val_loss improved to 0.6380)
Epoch 29/40


                                                                                                                                     

Epoch 29/40
  Train Loss: 0.0484
  Val Spearman (-ve):   0.6357
Epoch 30/40


                                                                                                                                     

Epoch 30/40
  Train Loss: 0.0469
  Val Spearman (-ve):   0.6374
Epoch 31/40


                                                                                                                                     

Epoch 31/40
  Train Loss: 0.0457
  Val Spearman (-ve):   0.6391
  -> Best model saved (val_loss improved to 0.6391)
Epoch 32/40


                                                                                                                                     

Epoch 32/40
  Train Loss: 0.0446
  Val Spearman (-ve):   0.6386
Epoch 33/40


                                                                                                                                     

Epoch 33/40
  Train Loss: 0.0440
  Val Spearman (-ve):   0.6397
  -> Best model saved (val_loss improved to 0.6397)
Epoch 34/40


                                                                                                                                     

Epoch 34/40
  Train Loss: 0.0421
  Val Spearman (-ve):   0.6392
Epoch 35/40


                                                                                                                                     

Epoch 35/40
  Train Loss: 0.0417
  Val Spearman (-ve):   0.6403
  -> Best model saved (val_loss improved to 0.6403)
Epoch 36/40


                                                                                                                                     

Epoch 36/40
  Train Loss: 0.0405
  Val Spearman (-ve):   0.6417
  -> Best model saved (val_loss improved to 0.6417)
Epoch 37/40


                                                                                                                                     

Epoch 37/40
  Train Loss: 0.0395
  Val Spearman (-ve):   0.6418
  -> Best model saved (val_loss improved to 0.6418)
Epoch 38/40


                                                                                                                                     

Epoch 38/40
  Train Loss: 0.0386
  Val Spearman (-ve):   0.6401
Epoch 39/40


                                                                                                                                     

Epoch 39/40
  Train Loss: 0.0370
  Val Spearman (-ve):   0.6407
Epoch 40/40


                                                                                                                                     

Epoch 40/40
  Train Loss: 0.0361
  Val Spearman (-ve):   0.6415
Test Spearman with 192 dimensions: 0.6217


In [4]:
# Final tests

hyp_dims = [384,192,96,64,48,32,24,16,12,8]
for dim in hyp_dims:
    model = HyperbolicMapper(output_dim = dim).to(device)
    model.load_state_dict(torch.load(f'saved_models3/hyp_{dim}.pt'))
    test_spearman = validate(test_loader, model, device)  # Flip sign back
    print(f"Test Spearman for {dim}: {test_spearman:.4f}")
    del model



  model.load_state_dict(torch.load(f'saved_models/hyp_{dim}.pt'))


Test Spearman for 384: 0.7059
Test Spearman for 192: 0.7067
Test Spearman for 96: 0.7023
Test Spearman for 64: 0.6971
Test Spearman for 48: 0.6910
Test Spearman for 32: 0.6853
Test Spearman for 24: 0.6734
Test Spearman for 16: 0.6487
Test Spearman for 12: 0.6423
Test Spearman for 8: 0.5944


  model.load_state_dict(torch.load('best_model.pt'))


Test Spearman: 0.6818
