In [73]:
from datasets import load_dataset

In [74]:
dataset = load_dataset('ms_marco', 'v1.1')

In [75]:
dataset

DatasetDict({
    validation: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 10047
    })
    train: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 82326
    })
    test: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 9650
    })
})

In [76]:
dataset_train = load_dataset('ms_marco', 'v1.1', split='train')

In [77]:
dataset_train

Dataset({
    features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
    num_rows: 82326
})

In [78]:
type(dataset_train)

datasets.arrow_dataset.Dataset

In [79]:
def preprocess(example):
    positive_passages=[]
    for p in enumerate(example['passages']["is_selected"]):
        if p[1]==1:
            positive_passages.append(example["passages"]["passage_text"][p[0]])
    negative_passages = []
    for p in enumerate(example['passages']["is_selected"]):
        if p[1]==0:
            negative_passages.append(example["passages"]["passage_text"][p[0]])
    if (len(positive_passages)>0 and len(negative_passages)>=5):
        positive = positive_passages[0]
        negatives = negative_passages[:5]
        return {
            "query": example["query"],
            "positive": positive,
            "negatives": negatives
        }
    else:
        return {"query": None, "positive": None, "negatives": None}

In [80]:
processed_data_train = dataset_train.map(preprocess, remove_columns=dataset_train.column_names)

In [81]:
processed_data_train = processed_data_train.filter(lambda x: x['query'] is not None and x['positive'] is not None)

In [82]:
contrastive_pairs_train = []
for item in processed_data_train:
    query = item["query"]
    positive = item["positive"]
    negatives = item["negatives"]
    contrastive_pairs_train.append({
        "anchor": query,
        "positive": positive,
        "negatives": negatives
    })

In [83]:
len(contrastive_pairs_train)

74538

In [84]:
from torch.utils.data import DataLoader

In [85]:
class ContrastiveDataset:
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        item = self.pairs[idx]
        return item["anchor"], item["positive"], item["negatives"]

In [86]:
contrastive_dataset_train = ContrastiveDataset(contrastive_pairs_train[0:500])

In [87]:
data_loader_train = DataLoader(contrastive_dataset_train, batch_size=16, shuffle=True)

In [88]:
len(data_loader_train)

32

In [89]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [90]:
model = AutoModel.from_pretrained("bert-base-uncased")

In [91]:
from peft import LoraConfig, get_peft_model

In [92]:
lora_config = LoraConfig(
    task_type= "FEATURE_EXTRACTION"  # Sequence-level task (e.g., contrastive learning)
)

In [93]:
lora_model = get_peft_model(model, lora_config)

In [94]:
lora_model.print_trainable_parameters()

trainable params: 294,912 || all params: 109,777,152 || trainable%: 0.2686


In [95]:
import torch
import torch.nn as nn
import torch.optim as optim

In [96]:
class HyperbolicEmbeddingLayer(nn.Module):
    def __init__(self, embed_dim):
        super(HyperbolicEmbeddingLayer, self).__init__()
        self.embed_dim = embed_dim
        self.fc = nn.Linear(768, embed_dim)  # Assuming 768-D input from BERT

    def forward(self, x):
        embedding = self.fc(x)
        return embedding / torch.norm(embedding, dim=-1, keepdim=True)

In [97]:
def lorentzian_distance(x, y):
    
    # Compute the dot product and the norms of the vectors
    dot_product = torch.sum(x * y, dim=-1)
    norm_x = torch.norm(x, dim=-1)
    norm_y = torch.norm(y, dim=-1)
    
    # Calculate the Lorentzian distance
    distance = torch.acosh(-dot_product + torch.sqrt((1 + norm_x**2) * (1 + norm_y**2)))
    return distance

In [98]:
def exterior_angle(x_space, y_space, c):
    norm_x_space = torch.norm(x_space, p=2, dim=-1)
    norm_y_space = torch.norm(y_space, p=2, dim=-1)
    x_time = torch.sqrt(1/c + norm_x_space**2)
    y_time = torch.sqrt(1/c + norm_y_space**2)
    dot_product = torch.sum(x_space * y_space, dim=-1)
    lorentz_inner_product =  dot_product - x_time * y_time
    numerator = y_time + x_time * c * lorentz_inner_product
    denominator = norm_x_space * torch.sqrt((c * lorentz_inner_product)**2 - 1)
    ext_angle = torch.acos(numerator / denominator)
    return ext_angle

In [99]:
def entailment_loss(x, y, c=1, K=0.1):
    # Compute half-aperture
    aperture = torch.asin(2 * K / (c * torch.norm(x, p=2, dim=-1)))  # Half-aperture formula
    
    # Compute exterior angle using the provided formula
    # Simplified for demonstration; the actual formula needs to handle Poincaré embeddings
    ext_angle = exterior_angle(x_space=x,y_space=y,c=c)
    
    # Compute the loss
    loss = torch.max(torch.zeros_like(ext_angle), ext_angle - aperture)
    return loss.mean()

In [100]:
def compute_laplacian(similarity_matrix):
    similarity_matrix = similarity_matrix.float()

    # Compute the degree matrix (diagonal matrix where each element is the sum of the row in the similarity matrix)
    degree_matrix = torch.diag(torch.sum(similarity_matrix, dim=1))

    # Compute the Laplacian matrix
    laplacian_matrix = degree_matrix - similarity_matrix

    return laplacian_matrix

In [101]:
def harmonic_distance(laplacian_matrix):
    laplacian_pseudo_inv = torch.pinverse(laplacian_matrix)
    harmonic_distances = []
    for anchor_idx in range(0, len(laplacian_matrix), 7):  # Step of 7 for each query
        # Extract the anchor, positive, and negative nodes for this query
        anchor_node=torch.zeros(len(laplacian_matrix), dtype=torch.float32, device=device)
        anchor_node[anchor_idx] = 1

        
        # Compute the harmonic distance for this particular anchor
        distances = []
        for i in range(7-1):
            node=torch.zeros(len(laplacian_matrix), dtype=torch.float32, device=device)
            node[anchor_idx+(i+1)] = 1
            diff = anchor_node - node
            dist = torch.matmul(torch.matmul(diff.T, laplacian_pseudo_inv), diff)
            distances.append(dist)
        
        harmonic_distances.append(torch.stack(distances))
    return(harmonic_distances)

In [102]:
def harmonic_loss(distances):
    logits=-torch.stack(distances, dim=0)
    labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)  # Positive as class 0
    loss = torch.nn.CrossEntropyLoss()(logits, labels)
    return loss

In [103]:
optimizer = torch.optim.AdamW(lora_model.parameters(), lr=5e-5)

In [104]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [105]:
print(lora_model)

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): lora.Linear(
                  (base_layer): Linear(in_features=768, out_features=768, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Identity()
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=768, out_features=8, bias=False)
                  )
                  (lora_B): Modu

In [106]:
embed_dim=128

In [107]:
hyperbolic_model= HyperbolicEmbeddingLayer(embed_dim)

In [108]:
import torch.nn.functional as F

In [109]:
dataset_val = load_dataset('ms_marco', 'v1.1', split='validation')

In [110]:
processed_data_val = dataset_val.map(preprocess, remove_columns=dataset_val.column_names)

In [111]:
processed_data_val = processed_data_val.filter(lambda x: x['query'] is not None and x['positive'] is not None)

In [112]:
contrastive_pairs_val = []
for item in processed_data_val:
    query = item["query"]
    positive = item["positive"]
    negatives = item["negatives"]
    contrastive_pairs_val.append({
        "anchor": query,
        "positive": positive,
        "negatives": negatives
    })

In [113]:
contrastive_dataset_val = ContrastiveDataset(contrastive_pairs_val[0:500])

In [114]:
data_loader_val = DataLoader(contrastive_dataset_val, batch_size=16, shuffle=True)

In [115]:
len(data_loader_val)

32

In [116]:
lora_model = lora_model.to(device)
hyperbolic_model = hyperbolic_model.to(device)

In [117]:
def evaluate_mrr(model1, model2, data_loader_val):
    model1.eval()  # Set the model to evaluation mode
    model2.eval()
    
    total_rr = 0.0
    num_queries = 0

    with torch.no_grad():  # Disable gradient calculations
        for batch in data_loader_val:
            anchor_text = batch[0]
            positive_text = batch[1]
            negative_texts = batch[2]

            # Tokenize inputs
            anchor_input = tokenizer(anchor_text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
            positive_input = tokenizer(positive_text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)

            # Get embeddings for anchor (query), positive, and negative texts
            anchor_embedding = model2(model1(**anchor_input).last_hidden_state[:, 0, :])
            positive_embedding = model2(model1(**positive_input).last_hidden_state[:, 0, :])
            negative_embedding = [model2(model1(**tokenizer(neg, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)).last_hidden_state[:, 0, :]) for neg in negative_texts]

            batch_embeddings = []
        # Loop through the batch and extract embeddings for each sample
            for i in range(len(anchor_embedding)):  # Loop through the batch size
                sample_embeddings = torch.cat(
                    [anchor_embedding[i:i+1],  # Query for the current sample
                    positive_embedding[i:i+1],  # Positive for the current sample
                    torch.stack([neg[i] for neg in negative_embedding])],  # Negatives for the current sample
                    dim=0
                )
                batch_embeddings.append(sample_embeddings)
            
            all_embeddings = torch.cat(batch_embeddings, dim=0)
            sim=F.cosine_similarity(all_embeddings.unsqueeze(1), all_embeddings.unsqueeze(0), dim=2)
            laplacian = compute_laplacian(sim)
            distances=harmonic_distance(laplacian)
            all_similarities=-torch.stack(distances, dim=0)

            sorted_similarities, sorted_indices = torch.sort(all_similarities, dim=1, descending=True)

            # Find the rank of the first relevant (positive) document
            positive_rank = (sorted_indices == 0).nonzero(as_tuple=True)[1] + 1  # +1 to make rank 1-based
            total_rr += torch.sum(1.0 / positive_rank.float()).item()  # Reciprocal rank
            num_queries += len(positive_rank)

    # Compute the average MRR across all queries
    mrr = total_rr / num_queries
    return mrr

In [118]:
import os
save_dir ="/dss/dsshome1/07/ra65bex2/srawat"
import time
epoch_metrics = []
num_epochs=3

In [119]:
for epoch in range(num_epochs):
    start_time = time.time()
    lora_model.train()  # Set the model to training mode
    hyperbolic_model.train()

    total_loss = 0.0
    entailment_loss_total=0.0
    contrastive_loss_total=0.0
    for batch in data_loader_train:
        # Extract the anchor, positive, and negative pairs from the batch
        anchor_texts = batch[0]
        positive_texts = batch[1]
        negative_texts = batch[2]
        # Tokenize the text pairs and move them to the GPU
        anchor_inputs = tokenizer(anchor_texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
        positive_inputs = tokenizer(positive_texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)

        # Get embeddings (CLS token representation) from BERT for each input
    
        anchor_embedding = hyperbolic_model(lora_model(**anchor_inputs).last_hidden_state[:, 0, :])
        positive_embedding = hyperbolic_model(lora_model(**positive_inputs).last_hidden_state[:, 0, :])
        negative_embedding = [hyperbolic_model(lora_model(**tokenizer(neg, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)).last_hidden_state[:, 0, :]) for neg in negative_texts]

        batch_embeddings = []
        # Loop through the batch and extract embeddings for each sample
        for i in range(len(anchor_embedding)):  # Loop through the batch size
            sample_embeddings = torch.cat(
                [anchor_embedding[i:i+1],  # Query for the current sample
                positive_embedding[i:i+1],  # Positive for the current sample
                torch.stack([neg[i] for neg in negative_embedding])],  # Negatives for the current sample
                dim=0
            )
            batch_embeddings.append(sample_embeddings)
        all_embeddings = torch.cat(batch_embeddings, dim=0)
        sim=F.cosine_similarity(all_embeddings.unsqueeze(1), all_embeddings.unsqueeze(0), dim=2)
        laplacian = compute_laplacian(sim)
        distances=harmonic_distance(laplacian)
        
        # Compute the InfoNCE loss
        contrastive_loss_value = harmonic_loss(distances)
        
        # Compute the entailment loss
        entailment_loss_value = entailment_loss(anchor_embedding, positive_embedding)
        
        loss = contrastive_loss_value + 0.1*entailment_loss_value

        # Backpropagation  
        optimizer.zero_grad()  # Clear previous gradients
        loss.backward()  # Compute gradients
        optimizer.step()  # Update model parameters
        
        total_loss += loss.item()
        entailment_loss_total+=entailment_loss_value.item()
        contrastive_loss_total+=contrastive_loss_value.item()
    print(f"EPOCH {epoch+1}:")
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(data_loader_train)}")
    print(f"Epoch {epoch+1}/{num_epochs}, Contrastive Loss: {contrastive_loss_total / len(data_loader_train)}")
    print(f"Epoch {epoch+1}/{num_epochs}, Entailment Loss: {entailment_loss_total / len(data_loader_train)}")
    mrr_validation = evaluate_mrr(model1=lora_model, model2=hyperbolic_model, data_loader_val=data_loader_val)
    #mrr_train = evaluate_mrr(lora_model, data_loader_train, lorentzian_distance)
    #print(f"Mean Reciprocal Rank (MRR) for training set: {mrr_train:.4f}")
    print(f"Mean Reciprocal Rank (MRR) for validation set: {mrr_validation:.4f}")
    end_time = time.time()
    print(f"Epoch {epoch+1} took {(end_time - start_time) / 60:.4f} minutes.")
    print(f"\n")
    epoch_metrics.append({
    'epoch': epoch + 1,
    'training_loss': total_loss / len(data_loader_train),
    'mrr_validation': mrr_validation,
    'time_taken_minutes': (end_time - start_time) / 60
    })

EPOCH 1:
Epoch 1/3, Loss: 1.99159587174654
Epoch 1/3, Contrastive Loss: 1.7920981496572495
Epoch 1/3, Entailment Loss: 1.9949772544205189
Mean Reciprocal Rank (MRR) for validation set: 0.4051
Epoch 1 took 1.2040 minutes.




NameError: name 'data_loader' is not defined

In [None]:
import json
with open(save_dir + '/combined_epoch_metrics.json', 'w') as f:
    json.dump(epoch_metrics, f)