In [1]:
from datasets import load_dataset

In [2]:
dataset = load_dataset("CShorten/ML-ArXiv-Papers")

In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0.1', 'Unnamed: 0', 'title', 'abstract'],
        num_rows: 117592
    })
})

In [4]:
# A fixed seed keeps the 80/20 partition (and every row index derived from it
# later in the notebook) identical across Restart & Run All.
split_datasets = dataset["train"].train_test_split(test_size=0.2, seed=42)

In [5]:
split_datasets

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0.1', 'Unnamed: 0', 'title', 'abstract'],
        num_rows: 94073
    })
    test: Dataset({
        features: ['Unnamed: 0.1', 'Unnamed: 0', 'title', 'abstract'],
        num_rows: 23519
    })
})

In [6]:
dataset_train = split_datasets["train"]

In [7]:
dataset_train

Dataset({
    features: ['Unnamed: 0.1', 'Unnamed: 0', 'title', 'abstract'],
    num_rows: 94073
})

In [8]:
dataset_train = dataset_train.remove_columns(['Unnamed: 0','Unnamed: 0.1'])

In [9]:
dataset_train

Dataset({
    features: ['title', 'abstract'],
    num_rows: 94073
})

In [10]:
dataset_train[0]

{'title': 'Stable Long-Term Recurrent Video Super-Resolution',
 'abstract': '  Recurrent models have gained popularity in deep learning (DL) based video\nsuper-resolution (VSR), due to their increased computational efficiency,\ntemporal receptive field and temporal consistency compared to sliding-window\nbased models. However, when inferring on long video sequences presenting low\nmotion (i.e. in which some parts of the scene barely move), recurrent models\ndiverge through recurrent processing, generating high frequency artifacts. To\nthe best of our knowledge, no study about VSR pointed out this instability\nproblem, which can be critical for some real-world applications. Video\nsurveillance is a typical example where such artifacts would occur, as both the\ncamera and the scene stay static for a long time.\n  In this work, we expose instabilities of existing recurrent VSR networks on\nlong sequences with low motion. We demonstrate it on a new long sequence\ndataset Quasi-Static Video

In [11]:
dataset_train = dataset_train.map(lambda x, idx: { 'index': idx }, with_indices=True)

Map:   0%|          | 0/94073 [00:00<?, ? examples/s]

In [12]:
dataset_train

Dataset({
    features: ['title', 'abstract', 'index'],
    num_rows: 94073
})

In [13]:
dataset_train[4]

{'title': 'PAC Learning-Based Verification and Model Synthesis',
 'abstract': '  We introduce a novel technique for verification and model synthesis of\nsequential programs. Our technique is based on learning a regular model of the\nset of feasible paths in a program, and testing whether this model contains an\nincorrect behavior. Exact learning algorithms require checking equivalence\nbetween the model and the program, which is a difficult problem, in general\nundecidable. Our learning procedure is therefore based on the framework of\nprobably approximately correct (PAC) learning, which uses sampling instead and\nprovides correctness guarantees expressed using the terms error probability and\nconfidence. Besides the verification result, our procedure also outputs the\nmodel with the said correctness guarantees. Obtained preliminary experiments\nshow encouraging results, in some cases even outperforming mature software\nverifiers.\n',
 'index': 4}

In [14]:
dataset_train[20572]

{'title': 'A Compressive Classification Framework for High-Dimensional Data',
 'abstract': '  We propose a compressive classification framework for settings where the data\ndimensionality is significantly higher than the sample size. The proposed\nmethod, referred to as compressive regularized discriminant analysis (CRDA) is\nbased on linear discriminant analysis and has the ability to select significant\nfeatures by using joint-sparsity promoting hard thresholding in the\ndiscriminant rule. Since the number of features is larger than the sample size,\nthe method also uses state-of-the-art regularized sample covariance matrix\nestimators. Several analysis examples on real data sets, including image,\nspeech signal and gene expression data illustrate the promising improvements\noffered by the proposed CRDA classifier in practise. Overall, the proposed\nmethod gives fewer misclassification errors than its competitors, while at the\nsame time achieving accurate feature selection results. 

In [15]:
import numpy as np
all_indexes = np.array(dataset_train['index'])

In [16]:
import random

def generate_hard_negatives(example, num_negatives=10,dataset=dataset_train):
    """Build one contrastive record: a title, its own abstract and random other abstracts.

    NOTE(review): despite the name, the negatives are drawn uniformly at random;
    no hardness criterion is applied.

    Args:
        example: Row providing 'index', 'title' and 'abstract'. 'index' must be
            the row's position inside ``dataset``.
        num_negatives: How many negative abstracts to draw.
        dataset: Dataset the negatives are read from, indexable by row position.

    Returns:
        Dict with "query" (the title), "positive" (the row's own abstract) and
        "negatives" (list of ``num_negatives`` abstracts taken from other rows).

    Raises:
        ValueError: If ``dataset`` has fewer than ``num_negatives + 1`` rows.
    """
    query_index = example['index']

    # Sample row positions directly instead of copying and filtering the full
    # index array for every row (which made the map O(N) per example).
    # One spare candidate is drawn so the query row can be discarded if it is
    # picked; the remaining prefix is still a uniform sample without replacement.
    candidates = random.sample(range(len(dataset)), num_negatives + 1)
    sampled_negatives = [idx for idx in candidates if idx != query_index][:num_negatives]

    negatives = [dataset[int(idx)]['abstract'] for idx in sampled_negatives]

    return {
        "query": example['title'],
        "positive": example['abstract'],
        "negatives": negatives
    }

In [17]:
processed_data_train = dataset_train.map(generate_hard_negatives, remove_columns=dataset_train.column_names)

Map:   0%|          | 0/94073 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [None]:
processed_data_train

In [None]:
processed_data_train[0]

In [None]:
# Materialise the mapped training rows as plain dicts, renaming the "query"
# field to "anchor" so they match what ContrastiveDataset reads.
contrastive_pairs_train = [
    {
        "anchor": row["query"],
        "positive": row["positive"],
        "negatives": row["negatives"],
    }
    for row in processed_data_train
]

In [None]:
len(contrastive_pairs_train)

In [None]:
from torch.utils.data import DataLoader

In [None]:
class ContrastiveDataset:
    """Indexable view over a sequence of contrastive records.

    Each record is a mapping with the keys "anchor", "positive" and
    "negatives"; indexing returns those three values as a tuple.
    """

    _FIELDS = ("anchor", "positive", "negatives")

    def __init__(self, pairs):
        # Kept by reference; the records are not copied.
        self.pairs = pairs

    def __len__(self):
        """Number of records available."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(anchor, positive, negatives)`` for the record at ``idx``."""
        record = self.pairs[idx]
        return tuple(record[field] for field in self._FIELDS)

In [None]:
contrastive_dataset_train = ContrastiveDataset(contrastive_pairs_train)

In [None]:
data_loader_train = DataLoader(contrastive_dataset_train, batch_size=32, shuffle=True)

In [None]:
len(data_loader_train)

In [None]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [None]:
model = AutoModel.from_pretrained("bert-base-uncased")

In [None]:
from peft import LoraConfig, get_peft_model

In [None]:
lora_config = LoraConfig(
    task_type= "FEATURE_EXTRACTION"
)

In [None]:
lora_model = get_peft_model(model, lora_config)

In [None]:
def lorentzian_distance(x, y, eps=1e-7):
    """Geodesic distance between hyperboloid points given by their space components.

    The time component of each point is taken as ``sqrt(1 + ||.||^2)``, so the
    value computed is ``acosh(-<x, y> + sqrt((1 + ||x||^2) * (1 + ||y||^2)))``.

    Args:
        x: Tensor of space components; the last dimension is the feature axis.
        y: Tensor broadcastable against ``x``.
        eps: Margin kept above 1 for the ``acosh`` argument.

    Returns:
        Tensor of distances with the feature axis reduced.
    """
    dot_product = torch.sum(x * y, dim=-1)
    norm_x = torch.norm(x, dim=-1)
    norm_y = torch.norm(y, dim=-1)

    cosh_distance = -dot_product + torch.sqrt((1 + norm_x**2) * (1 + norm_y**2))
    # Mathematically cosh_distance >= 1, but rounding can push it to or below 1
    # for (near-)identical points: acosh then returns NaN or has an infinite
    # gradient. Clamping keeps both the value and the gradient finite.
    distance = torch.acosh(torch.clamp(cosh_distance, min=1.0 + eps))
    return distance

In [None]:
def info_nce_loss(anchor_embedding, positive_embedding, negative_embedding, distance_fn):
    """Cross-entropy over negated distances, with the positive as the target class.

    Args:
        anchor_embedding: Batch of anchor embeddings.
        positive_embedding: Batch of positive embeddings, paired row-wise with the anchors.
        negative_embedding: Iterable of negative embedding batches, each paired
            row-wise with the anchors.
        distance_fn: Callable ``(a, b) -> distances`` returning one value per row.

    Returns:
        Scalar tensor: mean cross-entropy where column 0 (the positive) is the label.
    """
    positive_column = distance_fn(anchor_embedding, positive_embedding).unsqueeze(1)
    negative_columns = torch.stack(
        [distance_fn(anchor_embedding, candidate) for candidate in negative_embedding],
        dim=-1,
    )

    # Smaller distance means higher score, hence the sign flip.
    logits = torch.cat((-positive_column, -negative_columns), dim=1)
    # The positive always sits in column 0.
    targets = logits.new_zeros(logits.size(0), dtype=torch.long)

    return torch.nn.functional.cross_entropy(logits, targets)

In [None]:
def exterior_angle(x_space, y_space, c, eps=1e-7):
    """Angle at ``x`` derived from the Lorentz inner product of ``x`` and ``y``.

    Time components are reconstructed as ``sqrt(1/c + ||.||^2)``. With
    ``L = c * (<x, y> - x_time * y_time)`` the result is
    ``acos((y_time + x_time * L) / (||x|| * sqrt(L^2 - 1)))``.

    Args:
        x_space: Space components of the first point (feature axis last).
        y_space: Space components of the second point, broadcastable to ``x_space``.
        c: Curvature magnitude (Python number or tensor).
        eps: Small constant guarding the square root, the division and the
            ``acos`` domain.

    Returns:
        Tensor of angles in radians with the feature axis reduced.
    """
    norm_x_space = torch.norm(x_space, p=2, dim=-1)
    norm_y_space = torch.norm(y_space, p=2, dim=-1)
    x_time = torch.sqrt(1/c + norm_x_space**2)
    y_time = torch.sqrt(1/c + norm_y_space**2)
    dot_product = torch.sum(x_space * y_space, dim=-1)
    c_lorentz_inner = c * (dot_product - x_time * y_time)
    numerator = y_time + x_time * c_lorentz_inner
    # L^2 - 1 is exactly 0 when x == y and can round below 0 nearby; without
    # the clamp the sqrt yields NaN or the division becomes 0/0.
    denominator = norm_x_space * torch.sqrt(torch.clamp(c_lorentz_inner**2 - 1, min=eps))
    cos_angle = numerator / (denominator + eps)
    # Rounding can move the ratio marginally outside acos's [-1, 1] domain.
    ext_angle = torch.acos(torch.clamp(cos_angle, min=-1.0 + eps, max=1.0 - eps))
    return ext_angle

In [None]:
def entailment_loss(x, y, c=1, K=0.1, eps=1e-7):
    """Mean hinge penalty on how far the exterior angle exceeds the aperture of ``x``.

    The aperture is ``asin(2K / (sqrt(c) * ||x||))`` and the penalty per pair is
    ``max(0, exterior_angle(x, y) - aperture)``.

    Args:
        x: Space components of the first embeddings (feature axis last).
        y: Space components of the second embeddings, paired row-wise with ``x``.
        c: Curvature magnitude (number or tensor).
        K: Constant scaling the aperture.
        eps: Small constant guarding the division and the ``asin`` domain.

    Returns:
        Scalar tensor with the mean penalty over all pairs.
    """
    # as_tensor keeps dtype/device aligned with x and does not warn when a
    # tensor is passed in.
    c = torch.as_tensor(c, dtype=x.dtype, device=x.device)
    K = torch.as_tensor(K, dtype=x.dtype, device=x.device)

    x_norm = torch.norm(x, p=2, dim=-1)
    # For ||x|| < 2K / sqrt(c) the ratio exceeds 1 and asin would return NaN;
    # clamping saturates the aperture instead.
    asin_input = 2 * K / (torch.sqrt(c) * x_norm + eps)
    aperture = torch.asin(torch.clamp(asin_input, min=-1.0 + eps, max=1.0 - eps))

    ext_angle = exterior_angle(x, y, c=c)

    loss = torch.clamp(ext_angle - aperture, min=0.0)
    return loss.mean()

In [None]:
def expm_o(v, c=1.0, eps=1e-8):
    """Apply the origin exponential-map scaling to ``v``, then min-max rescale each row.

    Computes ``sinh(sqrt(c) * ||v||) * v / (sqrt(c) * ||v||)`` and then maps
    every row linearly onto [0, 1] using that row's own min and max.

    NOTE(review): the sinh factor is a positive per-row scalar, so the min-max
    rescaling that follows cancels it; the output depends only on the direction
    pattern of ``v``. Confirm this is intended.

    Args:
        v: 2-D tensor ``(batch, features)``; min/max are taken along dim 1.
        c: Curvature magnitude (number or tensor).
        eps: Lower bound applied to both denominators.

    Returns:
        Tensor shaped like ``v`` with each row rescaled to [0, 1]. Rows that are
        all-zero or constant map to all zeros instead of NaN.
    """
    c = torch.as_tensor(c, dtype=v.dtype, device=v.device)
    # Clamp so a zero vector yields sinh(eps) / eps ~= 1 rather than 0 / 0.
    scaled_norm = torch.clamp(
        torch.sqrt(c) * torch.norm(v, p=2, dim=-1, keepdim=True), min=eps
    )
    xspace = torch.sinh(scaled_norm) * v / scaled_norm
    batch_min = xspace.min(dim=1, keepdim=True).values
    batch_max = xspace.max(dim=1, keepdim=True).values
    # A constant row has max == min; clamp the range to avoid dividing by zero.
    value_range = torch.clamp(batch_max - batch_min, min=eps)
    xspace_scaled = (xspace - batch_min) / value_range
    return xspace_scaled

In [None]:
import torch
import torch.optim as optim
optimizer = torch.optim.AdamW(lora_model.parameters(), lr=5e-5)

In [None]:
num_epochs=3

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
device

In [None]:
torch.cuda.is_available()

In [None]:
lora_model = lora_model.to(device)

In [None]:
dataset_val = split_datasets["test"]

In [None]:
dataset_val

In [None]:
dataset_val = dataset_val.remove_columns(['Unnamed: 0','Unnamed: 0.1'])

In [None]:
dataset_val = dataset_val.map(lambda x, idx: { 'index': idx }, with_indices=True)

In [None]:
dataset_val

In [None]:
all_indexes_val = np.array(dataset_val['index'])

In [None]:
def generate_hard_negatives_val(example, num_negatives=10,dataset=dataset_val):
    """Build one validation record: a title, its own abstract and random other abstracts.

    NOTE(review): despite the name, the negatives are drawn uniformly at random;
    no hardness criterion is applied.

    Args:
        example: Row providing 'index', 'title' and 'abstract'. 'index' must be
            the row's position inside ``dataset``.
        num_negatives: How many negative abstracts to draw.
        dataset: Dataset the negatives are read from, indexable by row position.

    Returns:
        Dict with "query" (the title), "positive" (the row's own abstract) and
        "negatives" (list of ``num_negatives`` abstracts taken from other rows).

    Raises:
        ValueError: If ``dataset`` has fewer than ``num_negatives + 1`` rows.
    """
    query_index = example['index']

    # Sample row positions directly rather than rebuilding and filtering the
    # whole index array per row. One spare candidate is drawn so the query
    # paper itself can be dropped if it is picked; the remaining prefix is
    # still a uniform sample without replacement.
    candidates = random.sample(range(len(dataset)), num_negatives + 1)
    sampled_negatives = [idx for idx in candidates if idx != query_index][:num_negatives]

    negatives = [dataset[int(idx)]['abstract'] for idx in sampled_negatives]

    return {
        "query": example['title'],
        "positive": example['abstract'],
        "negatives": negatives
    }

In [None]:
processed_data_val = dataset_val.map(generate_hard_negatives_val, remove_columns=dataset_val.column_names)

In [None]:
# Materialise the mapped validation rows as plain dicts, renaming the "query"
# field to "anchor" so they match what ContrastiveDataset reads.
contrastive_pairs_val = [
    {
        "anchor": row["query"],
        "positive": row["positive"],
        "negatives": row["negatives"],
    }
    for row in processed_data_val
]

In [None]:
contrastive_dataset_val = ContrastiveDataset(contrastive_pairs_val)

In [None]:
data_loader_val = DataLoader(contrastive_dataset_val, batch_size=32, shuffle=True)

In [None]:
len(data_loader_val)

In [None]:
def evaluate_mrr(model1, data_loader_val, distance_fn):
    """Mean reciprocal rank of the positive among the positive plus its negatives.

    Puts ``model1`` in eval mode and runs without gradients. Each batch's
    texts are tokenized with the notebook-level ``tokenizer``, moved to the
    notebook-level ``device``, and the first-token hidden state is passed
    through ``expm_o``. Candidates are ranked by negated distance to the anchor.

    Args:
        model1: Encoder whose output exposes ``last_hidden_state``.
        data_loader_val: Iterable of batches ``(anchors, positives, negatives)``,
            where ``negatives`` is an iterable of text batches.
        distance_fn: Callable ``(a, b) -> distances`` returning one value per row.

    Returns:
        Average of ``1 / rank`` of the positive over all queries seen.
    """

    def embed(texts):
        # Tokenize, run the encoder, and map the first-token state via expm_o.
        encoded = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
        return expm_o(model1(**encoded).last_hidden_state[:, 0, :])

    model1.eval()

    reciprocal_rank_sum = 0.0
    query_count = 0

    with torch.no_grad():
        for batch in data_loader_val:
            anchor_text, positive_text, negative_texts = batch[0], batch[1], batch[2]

            anchor_embedding = embed(anchor_text)
            positive_embedding = embed(positive_text)
            negative_embedding = [embed(neg) for neg in negative_texts]

            positive_score = -distance_fn(anchor_embedding, positive_embedding).unsqueeze(1)
            negative_scores = -torch.stack(
                [distance_fn(anchor_embedding, neg) for neg in negative_embedding], dim=-1
            )
            # Column 0 holds the positive; the remaining columns are negatives.
            candidate_scores = torch.cat([positive_score, negative_scores], dim=1)

            ranking = torch.sort(candidate_scores, dim=1, descending=True).indices

            # Position of column 0 in the sorted order, shifted to be 1-based.
            positive_rank = (ranking == 0).nonzero(as_tuple=True)[1] + 1
            reciprocal_rank_sum += torch.sum(1.0 / positive_rank.float()).item()
            query_count += len(positive_rank)

    return reciprocal_rank_sum / query_count

In [None]:
import os
save_dir ="/dss/dsshome1/07/ra65bex2/srawat/0.1hyperbolic"

In [None]:
import time
epoch_metrics = []

In [None]:
torch.cuda.empty_cache()

In [None]:
# Create the checkpoint directory up front: otherwise the first torch.save
# fails only after a full epoch of training has already been spent.
os.makedirs(save_dir, exist_ok=True)

for epoch in range(num_epochs):
    start_time = time.time()
    # evaluate_mrr switches the model to eval mode, so re-enable training mode
    # at the start of every epoch.
    lora_model.train()

    total_loss = 0.0
    entailment_loss_total=0.0
    contrastive_loss_total=0.0
    for batch in data_loader_train:

        anchor_texts = batch[0]
        positive_texts = batch[1]
        negative_texts = batch[2]

        anchor_inputs = tokenizer(anchor_texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
        positive_inputs = tokenizer(positive_texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)

        # First-token hidden state of each text, passed through expm_o.
        anchor_embedding = expm_o(lora_model(**anchor_inputs).last_hidden_state[:, 0, :])
        positive_embedding = expm_o(lora_model(**positive_inputs).last_hidden_state[:, 0, :])
        negative_embedding = [expm_o(lora_model(**tokenizer(neg, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)).last_hidden_state[:, 0, :]) for neg in negative_texts]

        contrastive_loss_value = info_nce_loss(anchor_embedding, positive_embedding, negative_embedding, distance_fn=lorentzian_distance)

        entailment_loss_value = entailment_loss(anchor_embedding, positive_embedding)

        # The entailment term is down-weighted relative to the contrastive term.
        loss = contrastive_loss_value + 0.1*entailment_loss_value

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        entailment_loss_total+=entailment_loss_value.item()
        contrastive_loss_total+=contrastive_loss_value.item()

    # Per-batch averages, computed once and reused for logging and metrics.
    num_batches = len(data_loader_train)
    avg_loss = total_loss / num_batches
    avg_contrastive_loss = contrastive_loss_total / num_batches
    avg_entailment_loss = entailment_loss_total / num_batches

    # NOTE: this pickles the whole module object, so loading the checkpoint
    # requires the same class definitions to be importable.
    save_path1 = os.path.join(save_dir, f"hyperbolic_lora_checkpoint_epoch_{epoch+1}.pth")
    torch.save(lora_model, save_path1)
    print(f"EPOCH {epoch+1}:")
    # Report the file that was written, not just its directory.
    print(f"Checkpoint saved: {save_path1}")
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}")
    print(f"Epoch {epoch+1}/{num_epochs}, Contrastive Loss: {avg_contrastive_loss}")
    print(f"Epoch {epoch+1}/{num_epochs}, Entailment Loss: {avg_entailment_loss}")
    mrr_validation = evaluate_mrr(model1=lora_model, data_loader_val=data_loader_val,distance_fn=lorentzian_distance)
    print(f"Mean Reciprocal Rank (MRR) for validation set: {mrr_validation:.4f}")
    # The reported duration covers training, checkpointing and validation.
    end_time = time.time()
    elapsed_minutes = (end_time - start_time) / 60
    print(f"Epoch {epoch+1} took {elapsed_minutes:.4f} minutes.")
    print(f"\n")
    epoch_metrics.append({
        'epoch': epoch + 1,
        'training_loss': avg_loss,
        'Contrastive_loss': avg_contrastive_loss,
        'Entailment_loss': avg_entailment_loss,
        'mrr_validation': mrr_validation,
        'time_taken_minutes': elapsed_minutes
    })

In [None]:
import json

In [None]:
# Write the per-epoch metrics as JSON into the checkpoint directory.
metrics_path = save_dir + '/epoch_metrics.json'
with open(metrics_path, 'w') as metrics_file:
    json.dump(epoch_metrics, metrics_file)