<a href="https://colab.research.google.com/github/arjan-hada/esm2-antibody-CLIP/blob/main/ESM2_Ab_CLIP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ESM2AbCLIP: Antibody structure-aware ESM2 model

## Setup

In [1]:
!pip install torch transformers accelerate &> /dev/null

In [2]:
import os
from pathlib import Path
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd

In [3]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

path = Path("/content/gdrive/")
path_data = Path("/content/gdrive/MyDrive/data/proteinflow_esmif1_20240520-0899946")

Mounted at /content/gdrive


## Dataset

In [4]:
import pandas as pd
import torch
from torch.utils.data import Dataset

class AntibodyDataset(Dataset):
    """
    Dataset pairing each antibody sequence with its precomputed embedding.

    Args:
        data_path (str): Path to the pickle file containing data.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer to process the sequences.
    """
    def __init__(self, data_path, tokenizer):
        self.data = pd.read_pickle(data_path)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Get item by index.
        """
        row = self.data.iloc[idx]
        sequence = row['sequence']
        embedding = torch.tensor(row['embedding'], dtype=torch.float32)

        inputs = self.tokenizer(sequence, return_tensors='pt', padding=False,
                                truncation=False)

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'labels': embedding
        }

Note (from GitHub Copilot): Converting `row['embedding']` to a tensor inside `__getitem__` is fine for a small dataset. If the dataset is large and the conversion becomes a bottleneck, consider converting the embeddings once up front or caching them.
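A minimal sketch of the caching idea, assuming the embeddings fit in memory and all have the same length. `PrecomputedEmbeddings` is a hypothetical helper, not part of the notebook; it converts all embeddings to one tensor once, so that per-item access is only an indexing operation.

```python
import numpy as np
import torch

class PrecomputedEmbeddings:
    """Hypothetical helper: convert all embeddings to a single tensor up front."""

    def __init__(self, embeddings):
        # Stack the per-row arrays into one (N, D) float32 tensor, once.
        self.table = torch.tensor(np.stack(embeddings), dtype=torch.float32)

    def __len__(self):
        return self.table.shape[0]

    def __getitem__(self, idx):
        # Indexing the stacked tensor avoids any per-item conversion.
        return self.table[idx]

# Toy data standing in for the 'embedding' column of the DataFrame.
rows = [np.random.rand(512) for _ in range(4)]
cache = PrecomputedEmbeddings(rows)
print(len(cache), cache[0].shape)
```

In `AntibodyDataset`, such a table could be built in `__init__` from the `'embedding'` column and indexed in `__getitem__`.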

In [5]:
from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader

# Initialize tokenizer
model_ckpt = 'facebook/esm2_t33_650M_UR50D'
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

# Initialize datasets
train_ds = AntibodyDataset(path_data/'train_data.pkl', tokenizer)
valid_ds = AntibodyDataset(path_data/'valid_data.pkl', tokenizer)
test_ds = AntibodyDataset(path_data/'test_data.pkl', tokenizer)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [6]:
# Initialize DataCollator
data_collator = DataCollatorWithPadding(tokenizer)

# Initialize DataLoader with DataCollator
batch_size=2
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=data_collator)

In [7]:
batch = next(iter(train_dl))
print(batch)
print(batch['input_ids'].shape)
print(batch['attention_mask'].shape)
print(batch['labels'].shape)

{'input_ids': tensor([[ 0, 13, 14,  ...,  1,  1,  1],
        [ 0,  9, 17,  ...,  7, 12,  2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[-0.0578, -0.0622, -0.1321,  ...,  0.2488, -0.0769,  0.2094],
        [-0.0398,  0.1967, -0.0542,  ...,  0.0323, -0.0520,  0.1241]])}
torch.Size([2, 700])
torch.Size([2, 700])
torch.Size([2, 512])
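The shapes above reflect dynamic padding: `DataCollatorWithPadding` pads every sequence in a batch to the length of the longest one and marks padded positions with 0 in the attention mask. The plain-Python sketch below mimics that behavior for illustration only; `pad_batch` is a hypothetical helper, and the pad id of 1 is read off the printed `input_ids` above.

```python
def pad_batch(sequences, pad_id=1):
    """Right-pad token id lists to the longest length in the batch."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Two toy token id lists of different lengths.
batch_example = pad_batch([[0, 13, 14, 2], [0, 9, 17, 7, 12, 2]])
print(batch_example["input_ids"])
print(batch_example["attention_mask"])
```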


## Custom Model for Contrastive Pre-training

In [8]:
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PretrainedConfig
from torch.cuda.amp import autocast

class ESM2AbCLIPConfig(PretrainedConfig):
    model_type = "esm2_ab_clip"

    def __init__(self, projection_dim=512,
                 model_ckpt='facebook/esm2_t33_650M_UR50D',
                 **kwargs):
        super().__init__(**kwargs)
        self.projection_dim = projection_dim
        self.model_ckpt = model_ckpt

class ESM2AbCLIP(PreTrainedModel):
    config_class = ESM2AbCLIPConfig

    def __init__(self, config):
        super().__init__(config)
        self.sequence_model = AutoModel.from_pretrained(config.model_ckpt)

        # Projection layers for sequence embeddings with GeLU
        self.sequence_projection = nn.Sequential(
            nn.Linear(self.sequence_model.config.hidden_size, config.projection_dim),
            nn.GELU()
        )

        # Projection layer for structure embeddings with GeLU
        self.structure_projection = nn.Sequential(
            nn.Linear(config.projection_dim, config.projection_dim),
            nn.GELU()
        )

        # Learnable temperature parameter
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # Load and initialize weights
        self.post_init()

    def forward(self, input_ids, attention_mask, labels):
        with autocast():  # Enable mixed-precision for the forward pass
            # Forward pass through the sequence model
            outputs = self.sequence_model(input_ids=input_ids, attention_mask=attention_mask)
            cls_hidden_state = outputs.last_hidden_state[:, 0, :]  # [CLS] token pooling

            # Project the sequence embeddings
            projected_sequence = self.sequence_projection(cls_hidden_state)
            projected_structure = self.structure_projection(labels)

            # Clamp the logit scale value to ensure it does not exceed log(100)
            self.logit_scale.data.clamp_(max=np.log(100.0))

            return {
                'logits': (projected_sequence, projected_structure)
            }

In [9]:
model_config = ESM2AbCLIPConfig(projection_dim=512,
                                model_ckpt='facebook/esm2_t6_8M_UR50D')
model = ESM2AbCLIP(model_config)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
outputs = model.forward(**batch)
logits = outputs['logits']
print(logits[0].shape)
print(logits[1].shape)

torch.Size([2, 512])
torch.Size([2, 512])


## Contrastive Loss

Scaled pairwise cosine similarities are a crucial component in many contrastive learning frameworks, including CLIP (Contrastive Language-Image Pre-training).

**Cosine similarity** measures the cosine of the angle between two vectors in an inner product space. It is a measure of similarity between two non-zero vectors, giving a value between -1 and 1.

For two vectors $A$ and $B$:
$$ \text{cosine\_similarity}(A, B) = \frac{A \cdot B}{\|A\| \|B\|} $$

In PyTorch, cosine similarity between two sets of embeddings can be computed using:
```python
import torch.nn.functional as F

cos_sim = F.cosine_similarity(embedding1, embedding2, dim=-1)
```

**Pairwise cosine similarity** computes the cosine similarity between each pair of vectors from two sets of vectors. This is useful for comparing all possible pairs in a batch.

For sequence embeddings `seq_embeddings` and structure embeddings `struct_embeddings`, both already normalized to unit length along the feature dimension:

```python
cos_sim_matrix = torch.mm(seq_embeddings, struct_embeddings.t())
```

This gives a matrix whose entry $(i, j)$ is the cosine similarity between sequence $i$ and structure $j$. Without the normalization, the product yields raw dot products rather than cosine similarities.
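A small check of the pairwise formula, assuming random toy embeddings: the matrix product matches `F.cosine_similarity` once both inputs are L2-normalized.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_embeddings = torch.randn(3, 8)      # toy batch of 3 sequence embeddings
struct_embeddings = torch.randn(3, 8)   # toy batch of 3 structure embeddings

# Normalize each row to unit length so that dot products become cosines.
seq_unit = F.normalize(seq_embeddings, dim=1)
struct_unit = F.normalize(struct_embeddings, dim=1)
cos_sim_matrix = torch.mm(seq_unit, struct_unit.t())  # shape (3, 3)

# Reference: broadcasted F.cosine_similarity over all (i, j) pairs.
reference = F.cosine_similarity(
    seq_embeddings.unsqueeze(1), struct_embeddings.unsqueeze(0), dim=-1
)
print(torch.allclose(cos_sim_matrix, reference, atol=1e-6))
```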

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self):
        super(ContrastiveLoss, self).__init__()

    def forward(self, seq_embeddings, struct_embeddings, logit_scale):
        # Normalize embeddings to unit vectors
        seq_embeddings = F.normalize(seq_embeddings, dim=1)
        struct_embeddings = F.normalize(struct_embeddings, dim=1)

        # Compute pairwise cosine similarities and scale them.
        # Note: exp(logit_scale) is the inverse temperature 1/τ (initialized to 1/0.07).
        temperature = torch.exp(logit_scale)
        logits_per_seq = torch.mm(seq_embeddings, struct_embeddings.t()) * temperature
        logits_per_struct = logits_per_seq.t()

        # Labels for contrastive loss
        labels = torch.arange(seq_embeddings.size(0)).to(seq_embeddings.device)

        # Contrastive loss as described in the paper
        loss_seq = F.cross_entropy(logits_per_seq, labels)
        loss_struct = F.cross_entropy(logits_per_struct, labels)
        loss = (loss_seq + loss_struct) / 2

        return loss

The `ContrastiveLoss` class is designed to implement a contrastive learning objective for sequence and structure embeddings, similar to how CLIP (Contrastive Language-Image Pre-training) works. The goal is to bring corresponding sequence and structure embeddings closer in the embedding space while pushing non-corresponding pairs further apart.

#### Key Components

1. **Normalization**: Both sequence and structure embeddings are normalized to unit vectors, so that their dot product equals their cosine similarity.
   
2. **Cosine Similarity**: The similarity between the embeddings is computed using the dot product. The cosine similarities are then divided by a temperature τ before the softmax; in the code this is implemented as a multiplication by `exp(logit_scale)`, which equals 1/τ. The temperature controls the sharpness of the distribution:
    - High temperature (large τ, small `exp(logit_scale)`) results in a smoother probability distribution.
    - Low temperature (small τ, large `exp(logit_scale)`) results in a sharper distribution.

    Making `logit_scale` a learnable parameter allows the model to adapt the scaling during training.

   
3. **Contrastive Loss**:
   - **Sequence-to-Structure**: Each sequence embedding is compared against all structure embeddings. The goal is to maximize the similarity with its corresponding structure embedding and minimize the similarity with others.
   - **Structure-to-Sequence**: Similarly, each structure embedding is compared against all sequence embeddings.

4. **Cross-Entropy Loss**: The loss is computed using cross-entropy, treating the problem as a classification task where the correct pair should have the highest similarity score.

Computing `torch.exp(logit_scale)` can become numerically unstable if `logit_scale` grows large. The model's `forward` method therefore clamps it in place with `self.logit_scale.data.clamp_(max=np.log(100.0))`, which caps the scale at 100.
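To make the effect of the scale concrete, the following sketch applies a softmax to one row of toy cosine similarities with and without scaling. The similarity values are made up; the scale 1/0.07 matches the initial value of `exp(logit_scale)` in the model definition.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

cosine_sims = [0.9, 0.5, 0.1]          # toy similarities for one sequence
scale = math.exp(math.log(1 / 0.07))   # initial exp(logit_scale), about 14.29

unscaled = softmax(cosine_sims)
scaled = softmax([scale * s for s in cosine_sims])
print([round(p, 3) for p in unscaled])
print([round(p, 3) for p in scaled])
```

The scaled distribution puts far more probability mass on the best match, which is why a larger scale (a lower temperature) yields a sharper distribution.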

In [12]:
loss_fct = ContrastiveLoss()  # Instantiate the loss function
loss = loss_fct(logits[0], logits[1], model.logit_scale)  # Compute the loss
print(loss)

tensor(0.6715, grad_fn=<DivBackward0>)


## Performance measures

We will implement the Alignment and Uniformity metrics as proposed by [Wang & Isola (2020)](https://arxiv.org/abs/2005.10242), in addition to Contrastive Accuracy and Top-K Accuracy.

1. **Alignment** measures how close positive pairs are in the embedding space.

  $$
  \text{Alignment} = \mathbb{E}_{(x, y) \sim p_{\text{pos}}} \left[ \| f(x) - f(y) \|^2 \right]
  $$
  This is calculated as the average squared Euclidean distance between embeddings of positive pairs. A good alignment score is close to 0, indicating that positive pairs are nearly identical in the embedding space.


2. **Uniformity** measures how uniformly the embeddings are spread on the unit hypersphere.
  $$
  \text{Uniformity} = \log \mathbb{E}_{(x, y) \sim p_{\text{data}}} \left[ e^{-2 \| f(x) - f(y) \|^2} \right]
  $$
  This is calculated as the logarithm of the average of $e^{-2 \| f(x) - f(y) \|^2}$ over pairs of embeddings. Lower (more negative) values indicate that the embeddings are spread more evenly across the hypersphere, which is desirable.

3. **Cosine Similarity** measures the cosine of the angle between two non-zero vectors.

4. **Contrastive Accuracy** measures how often the model correctly identifies the matching pair among a set of negatives.

5. **Top-K Accuracy** measures whether the true positive is within the top K closest predictions.
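As a rough illustration of the first two metrics, the sketch below evaluates them on toy unit vectors. It follows the formulas above with $t = 2$ and uses `torch.pdist`, which considers each distinct pair once; a `torch.cdist`-based version additionally counts each embedding paired with itself. The function names are illustrative.

```python
import torch
import torch.nn.functional as F

def alignment(x, y):
    """Mean squared Euclidean distance between paired rows of x and y."""
    return (x - y).pow(2).sum(dim=1).mean().item()

def uniformity(x, t=2.0):
    """Log of the mean Gaussian potential over distinct pairs of rows."""
    sq_dists = torch.pdist(x, p=2).pow(2)
    return sq_dists.mul(-t).exp().mean().log().item()

# Four unit vectors spread evenly on the circle vs. four nearly identical ones.
spread = F.normalize(
    torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]), dim=1
)
clumped = F.normalize(
    torch.tensor([[1.0, 0.0], [1.0, 0.01], [1.0, -0.01], [1.0, 0.02]]), dim=1
)

print("alignment of identical pairs:", alignment(spread, spread))
print("uniformity (spread): ", uniformity(spread))
print("uniformity (clumped):", uniformity(clumped))
```

The evenly spread vectors give a clearly more negative uniformity value than the clumped ones, matching the intuition that lower is better for this metric.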


In [13]:
import torch
import torch.nn.functional as F

# Function to compute the Alignment metric
def compute_alignment(seq_embeddings, struct_embeddings):
    distances = (seq_embeddings - struct_embeddings).pow(2).sum(dim=1)
    alignment = distances.mean().item()
    return alignment

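# Function to compute the Uniformity metric
# The pairwise distance matrix grows quadratically with the number of embeddings,
# so this can be slow and memory-hungry for large evaluation sets.
# Note: torch.cdist covers all ordered pairs, including each embedding paired
# with itself (distance 0), which pulls the value toward 0.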
def compute_uniformity(embeddings):
    pairwise_distances = torch.cdist(embeddings, embeddings, p=2).pow(2)
    uniformity = torch.log(torch.exp(-2 * pairwise_distances).mean()).item()
    return uniformity

# Function to compute the Cosine Similarity metric
def compute_cosine_similarity(seq_embeddings, struct_embeddings):
    """
    This function computes the pairwise cosine similarity matrix using matrix
    multiplication. It assumes that the sequence and structure embeddings have
    already been normalized to unit vectors; it does not normalize them itself.
    """
    cosine_sim = torch.mm(seq_embeddings, struct_embeddings.t())
    return cosine_sim

# Function to compute the Contrastive Accuracy
def compute_contrastive_accuracy(cosine_sim):
    """
    This function finds the index of the maximum cosine similarity for each
    sequence embedding and compares it to the correct index.
    It then computes the mean accuracy.
    """
    correct_preds = cosine_sim.argmax(dim=1)
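    # Build integer targets on the same device (avoids casting the indices to float)
    targets = torch.arange(cosine_sim.size(0), device=cosine_sim.device)
    correct = correct_preds == targets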
    return correct.float().mean().item()

# Function to compute the Top-K Accuracy
def compute_top_k_accuracy(cosine_sim, k=1):
    """
    This function finds the top K predictions for each sequence embedding and
    checks if the correct match is within these top K predictions.
    It then computes the mean accuracy.
    """
    top_k_preds = cosine_sim.topk(k, dim=1)[1]
    correct = torch.arange(cosine_sim.size(0), device=cosine_sim.device).unsqueeze(1).expand_as(top_k_preds)
    correct = correct == top_k_preds
    return correct.any(dim=1).float().mean().item()
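
A small worked example of the retrieval metrics, assuming a hand-made 3×3 similarity matrix in which row $i$ should match column $i$. The helper below re-implements the same top-K check in a self-contained form for illustration.

```python
import torch

def top_k_accuracy(sim, k=1):
    """Fraction of rows whose diagonal entry is among the k largest in that row."""
    top_k = sim.topk(k, dim=1).indices
    targets = torch.arange(sim.size(0), device=sim.device).unsqueeze(1)
    return (top_k == targets).any(dim=1).float().mean().item()

# Row 0 ranks its match first, row 1 ranks it second, row 2 ranks it last.
sim = torch.tensor([
    [0.9, 0.1, 0.0],
    [0.8, 0.5, 0.1],
    [0.7, 0.6, 0.2],
])

print(top_k_accuracy(sim, k=1))
print(top_k_accuracy(sim, k=2))
print(top_k_accuracy(sim, k=3))
```

Note that `k` must not exceed the number of columns in the similarity matrix, otherwise `topk` raises an error.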

In [14]:
def compute_metrics(eval_pred):
    seq_embeddings, struct_embeddings = eval_pred.predictions
    seq_embeddings = torch.tensor(seq_embeddings)
    struct_embeddings = torch.tensor(struct_embeddings)

    # Normalize embeddings
    seq_embeddings = F.normalize(seq_embeddings, dim=1)
    struct_embeddings = F.normalize(struct_embeddings, dim=1)

    # Compute metrics
    alignment = compute_alignment(seq_embeddings, struct_embeddings)
    combined_embeddings = torch.cat((seq_embeddings, struct_embeddings), dim=0)
    uniformity = compute_uniformity(combined_embeddings)
    cosine_sim = compute_cosine_similarity(seq_embeddings, struct_embeddings)
    contrastive_accuracy = compute_contrastive_accuracy(cosine_sim)
    top_5_accuracy = compute_top_k_accuracy(cosine_sim, k=5)
    top_10_accuracy = compute_top_k_accuracy(cosine_sim, k=10)

    metrics = {
        "alignment": alignment,
        "uniformity": uniformity,
        "contrastive_accuracy": contrastive_accuracy,
        "top_5_accuracy": top_5_accuracy,
        "top_10_accuracy": top_10_accuracy
    }
    return metrics

In [15]:
# Calculate metrics
seq_embeddings, struct_embeddings = logits[0], logits[1]
seq_embeddings = F.normalize(seq_embeddings, dim=1)
struct_embeddings = F.normalize(struct_embeddings, dim=1)

alignment = compute_alignment(seq_embeddings, struct_embeddings)
combined_embeddings = torch.cat((seq_embeddings, struct_embeddings), dim=0)
uniformity = compute_uniformity(combined_embeddings)
cosine_sim = compute_cosine_similarity(seq_embeddings, struct_embeddings)
contrastive_accuracy = compute_contrastive_accuracy(cosine_sim)
top_1_accuracy = compute_top_k_accuracy(cosine_sim, k=1)

# Print the results
print("Contrastive Loss:", loss.item())
print("Alignment:", alignment)
print("Uniformity:", uniformity)
print("Contrastive Accuracy:", contrastive_accuracy)
print("Top-1 Accuracy:", top_1_accuracy)

Contrastive Loss: 0.6715006232261658
Alignment: 2.0915002822875977
Uniformity: -0.947462797164917
Contrastive Accuracy: 0.5
Top-1 Accuracy: 0.5


## Custom Trainer for Contrastive Learning

In [16]:
from transformers import Trainer

class ContrastiveTrainer(Trainer):
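    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer transformers versions pass
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):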
        outputs = model(**inputs)  # Forward pass to get logits
        logits = outputs['logits']

        # Access logit_scale from the underlying model (unwrap DataParallel/DDP if needed)
        logit_scale = model.module.logit_scale if hasattr(model, 'module') else model.logit_scale

        loss_fct = ContrastiveLoss()  # Instantiate the loss function
        loss = loss_fct(logits[0], logits[1], logit_scale)  # Compute the loss

        return (loss, outputs) if return_outputs else loss

In [17]:
import gc # Python's garbage collection module

def clear_memory():
    gc.collect() # explicitly triggers garbage collection, free up memory
    if torch.cuda.is_available(): torch.cuda.empty_cache() # clears the PyTorch CUDA memory cache

## Training

In [18]:
# Original CLIP parameters from "Learning Transferable Visual Models From Natural Language Supervision"
original_dataset_size = 400 * 10**6  # 400 million pairs
original_batch_size = 32768
original_epochs = 32
original_warmup_steps = 2000

# Calculate total training steps for the original setup
original_total_training_steps = (original_dataset_size * original_epochs) / original_batch_size

# Calculate the warmup ratio
original_warmup_ratio = original_warmup_steps / original_total_training_steps

print(f"Original total training steps: {original_total_training_steps}")
print(f"Original warmup ratio: {original_warmup_ratio}")

# Our training setup parameters
our_dataset_size = 1571  # Our dataset size
our_batch_size = 8
our_epochs = 5

# Calculate total training steps for our setup
our_total_training_steps = (our_dataset_size * our_epochs) / our_batch_size

# Calculate our warmup steps using the original warmup ratio
our_warmup_steps = int(original_warmup_ratio * our_total_training_steps)
our_warmup_steps = max(1, our_warmup_steps)  # Ensure at least 1 warmup step

print(f"Our total training steps: {our_total_training_steps}")
print(f"Our warmup steps: {our_warmup_steps}")

Original total training steps: 390625.0
Original warmup ratio: 0.00512
Our total training steps: 981.875
Our warmup steps: 5
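
The same scaling can be wrapped in a small helper. This is a sketch under the assumption that the final partial batch of each epoch still counts as a step (hence the ceiling), which gives a slightly different total than the fractional value printed above; `scaled_warmup_steps` is a hypothetical name.

```python
import math

def scaled_warmup_steps(dataset_size, batch_size, epochs, warmup_ratio):
    """Return (total_steps, warmup_steps) for a given warmup ratio."""
    steps_per_epoch = math.ceil(dataset_size / batch_size)  # last partial batch counts
    total_steps = steps_per_epoch * epochs
    warmup_steps = max(1, int(warmup_ratio * total_steps))
    return total_steps, warmup_steps

# Warmup ratio from the original CLIP setup: 2000 steps out of 390625.
clip_ratio = 2000 / ((400 * 10**6 * 32) / 32768)

total, warmup = scaled_warmup_steps(1571, 8, 5, clip_ratio)
print(total, warmup)
```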


In [19]:
from transformers import TrainingArguments

num_epochs = 5
batch_size = 8
logging_steps = len(train_ds) // batch_size
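# Note: model_ckpt still refers to the 650M checkpoint used for the tokenizer,
# while the model trained below is built from facebook/esm2_t6_8M_UR50D, so the
# output directory and Hub repository carry the 650M name.
model_name = f"{model_ckpt}-Ab-CLIP-v0"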

# Training arguments
training_args = TrainingArguments(
    output_dir=model_name,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=1e-4,
    weight_decay=0.25,
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-6,
    fp16=True, # Mixed-precision training
    lr_scheduler_type="cosine",
    warmup_steps=3, # for stability early in training (the estimate in the previous cell gave 5)
    load_best_model_at_end=True,
    disable_tqdm=False,
    logging_steps=logging_steps,
    evaluation_strategy="epoch",  # Evaluate at the end of each epoch
    save_strategy="epoch",  # Save the model at the end of each epoch
    push_to_hub=True,
)



In [20]:
# To share your model with the community
# First store your authentication token from the Hugging Face website and then execute this cell
# Make sure to get token with WRITE access
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [21]:
# Initialize model with custom configuration
model_config = ESM2AbCLIPConfig(projection_dim=512,
                                model_ckpt='facebook/esm2_t6_8M_UR50D')
model = ESM2AbCLIP(model_config)

# Initialize DataCollator
data_collator = DataCollatorWithPadding(tokenizer)

trainer = ContrastiveTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
# Optionally set the max_split_size_mb to avoid fragmentation issues
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

In [23]:
clear_memory()
# Train the model
print(torch.cuda.memory_summary())
trainer.train()
print(torch.cuda.memory_summary())
trainer.push_to_hub(commit_message="Training completed!")

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  32320 KiB |  32320 KiB |  32320 KiB |      0 B   |
|       from large pool |  20482 KiB |  20482 KiB |  20482 KiB |      0 B   |
|       from small pool |  11838 KiB |  11838 KiB |  11838 KiB |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |  32320 KiB |  32320 KiB |  32320 KiB |      0 B   |
|       from large pool |  20482 KiB |  20482 KiB |  20482 KiB |      0 B   |
|       from small pool |  11838 KiB |  11838 KiB |  11838 KiB |      0 B   |
|---------------------------------------------------------------

| Epoch | Training Loss | Validation Loss | Alignment | Uniformity | Contrastive Accuracy | Top 5 Accuracy | Top 10 Accuracy |
|---|---|---|---|---|---|---|---|
| 1 | 1.0791 | 1.242567 | 1.203027 | -2.381756 | 0.068027 | 0.306122 | 0.469388 |
| 2 | 0.4833 | 1.138185 | 1.207227 | -2.768368 | 0.115646 | 0.414966 | 0.598639 |
| 3 | 0.2918 | 1.099299 | 1.164177 | -2.939466 | 0.129252 | 0.482993 | 0.666667 |
| 4 | 0.1979 | 1.071894 | 1.152017 | -3.004821 | 0.14966 | 0.44898 | 0.707483 |
| 5 | 0.158 | 1.071362 | 1.165201 | -3.049541 | 0.142857 | 0.469388 | 0.714286 |


|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 82        |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 111396 KiB |  19857 MiB |  37939 GiB |  37939 GiB |
|       from large pool |  76217 KiB |  19806 MiB |  37842 GiB |  37842 GiB |
|       from small pool |  35179 KiB |     57 MiB |     96 GiB |     96 GiB |
|---------------------------------------------------------------------------|
| Active memory         | 111396 KiB |  19857 MiB |  37939 GiB |  37939 GiB |
|       from large pool |  76217 KiB |  19806 MiB |  37842 GiB |  37842 GiB |
|       from small pool |  35179 KiB |     57 MiB |     96 GiB |     96 GiB |
|---------------------------------------------------------------

CommitInfo(commit_url='https://huggingface.co/arjan-hada/esm2_t33_650M_UR50D-Ab-CLIP-v0/commit/943b855dedd6c3e11e1a07181e7757b1e6bc9068', commit_message='Training completed!', commit_description='', oid='943b855dedd6c3e11e1a07181e7757b1e6bc9068', pr_url=None, pr_revision=None, pr_num=None)

In [24]:
# Saves the best model
trainer.save_model("models/esm2_t6_8M_UR50D-Ab-CLIP-v0")

In [25]:
# Evaluate the model on the test set
test_result = trainer.evaluate(eval_dataset=test_ds)

In [26]:
# Print the results
print(f"Test Loss: {test_result['eval_loss']}")
for key, value in test_result.items():
    if key != 'eval_loss':
        print(f"{key}: {value}")

Test Loss: 0.9265738725662231
eval_alignment: 1.0818055868148804
eval_uniformity: -3.0300042629241943
eval_contrastive_accuracy: 0.18666666746139526
eval_top_5_accuracy: 0.5866666436195374
eval_top_10_accuracy: 0.753333330154419
eval_runtime: 1.4383
eval_samples_per_second: 104.29
eval_steps_per_second: 13.21
epoch: 5.0


### Performance Analysis

The training log and the test evaluation above show how the model's performance evolves over five epochs.

**Training Loss and Validation Loss**
- **Training Loss**: Decreases steadily from 1.0791 at epoch 1 to 0.158 at epoch 5, indicating that the model is fitting the training data.
- **Validation Loss**: Decreases from 1.2426 to 1.0714 and is essentially flat between epochs 4 and 5. The widening gap between training and validation loss suggests that the model is starting to overfit rather than generalize better.

**Alignment and Uniformity**
- **Alignment**: Changes little (1.203 at epoch 1, 1.165 at epoch 5, with a minimum of 1.152 at epoch 4), so positive pairs move only slightly closer together.
- **Uniformity**: Becomes more negative (from -2.38 to -3.05), indicating that the embeddings are spread out more evenly, which is generally good for contrastive learning.

**Contrastive Accuracy and Top-K Accuracy**
- **Contrastive Accuracy**: Improves from 0.068 to 0.143 (peaking at 0.150 at epoch 4) but remains low. This metric measures how often the correct structure is the single best match for a sequence.
- **Top 5 and Top 10 Accuracy**: Top-5 accuracy rises from 0.306 to 0.469 and Top-10 accuracy from 0.469 to 0.714. Top-K accuracies are better indicators for retrieval tasks and show that the model increasingly ranks the correct pair near the top of the list.

The test set shows a similar pattern: a contrastive accuracy of 0.187, a Top-5 accuracy of 0.587, and a Top-10 accuracy of 0.753.
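
To put a number on the gap between training and validation loss, the following sketch recomputes it per epoch from the values in the training log above.

```python
# (epoch, training loss, validation loss) copied from the training log.
log = [
    (1, 1.0791, 1.242567),
    (2, 0.4833, 1.138185),
    (3, 0.2918, 1.099299),
    (4, 0.1979, 1.071894),
    (5, 0.1580, 1.071362),
]

gaps = [(epoch, val - train) for epoch, train, val in log]
for epoch, gap in gaps:
    print(f"epoch {epoch}: validation - training = {gap:.4f}")
```

The gap grows every epoch, which is consistent with the model fitting the training set faster than it improves on held-out data.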



## Notable Todo for compute_metrics



In SBERT, embeddings for sentence pairs are used to compute a similarity score, and the embeddings are fine-tuned to improve this similarity measure.

1. **Sentence Embeddings**:
   - SBERT fine-tunes BERT to derive semantically meaningful sentence embeddings.
   - For a given pair of sentences, SBERT computes one fixed-size embedding per sentence by pooling BERT's token outputs (mean pooling by default; [CLS] pooling is an alternative).

2. **Concatenation and Difference**:
   - For a pair of sentences $A$ and $B$, the pooled embeddings are obtained as $z_A$ and $z_B$.
   - For the classification objective, the features are:
     $$
     \text{concat}(z_A, z_B, |z_A - z_B|)
     $$
   - This concatenated vector is then fed into a softmax classification layer.

3. **Regression Objective**:
   - For similarity regression, SBERT computes the cosine similarity between $z_A$ and $z_B$ and trains it against the target score with a mean-squared-error loss.
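
The feature construction in step 2 can be sketched as follows. This is an illustrative example with random tensors standing in for sentence embeddings; the embedding size, the number of classes, and the single linear layer are assumptions made for the demo.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, dim, num_classes = 4, 16, 3   # illustrative sizes

z_a = torch.randn(batch_size, dim)  # stand-in embeddings for sentences A
z_b = torch.randn(batch_size, dim)  # stand-in embeddings for sentences B

# concat(z_A, z_B, |z_A - z_B|) along the feature dimension -> (batch, 3 * dim)
features = torch.cat([z_a, z_b, (z_a - z_b).abs()], dim=1)

# A single linear layer mapping the pair features to class logits.
classifier = nn.Linear(3 * dim, num_classes)
pair_logits = classifier(features)

print(features.shape, pair_logits.shape)
```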




