In [1]:
from datasets import load_dataset
import random
import json

In [2]:
load_dataset("dbpedia_14")

DatasetDict({
    train: Dataset({
        features: ['label', 'title', 'content'],
        num_rows: 560000
    })
    test: Dataset({
        features: ['label', 'title', 'content'],
        num_rows: 70000
    })
})

In [3]:
# Load only the "train" split; the bare name on the last line displays its summary.
dataset_train = load_dataset("dbpedia_14", split="train")
dataset_train

Dataset({
    features: ['label', 'title', 'content'],
    num_rows: 560000
})

In [4]:
dataset_train[0]

{'label': 0,
 'title': 'E. D. Abbott Ltd',
 'content': ' Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972.'}

In [5]:
# Human-readable category name for each integer label (0-13), in label order.
LABEL_TO_CATEGORY = dict(enumerate([
    "Company",
    "Educational Institution",
    "Artist",
    "Athlete",
    "Office Holder",
    "Mean Of Transportation",
    "Building",
    "Natural Place",
    "Village",
    "Animal",
    "Plant",
    "Album",
    "Film",
    "Written Work",
]))

In [6]:
# Group the training texts by category name: {category name: [content, ...]}.
category_to_samples = {}
for sample in dataset_train:
    category = LABEL_TO_CATEGORY[sample["label"]]
    # setdefault creates the list the first time a category is seen.
    category_to_samples.setdefault(category, []).append(sample["content"])

In [7]:
# For every category name, pool the texts that belong to all *other* categories.
category_negatives = {
    target: [
        text
        for other, texts in category_to_samples.items()
        if other != target
        for text in texts
    ]
    for target in LABEL_TO_CATEGORY.values()
}

In [8]:
def preprocess(example, num_negatives=5):
    """Turn one labelled row into a query / positive / negatives record.

    Args:
        example: Row providing an integer ``"label"`` and a text ``"content"``.
        num_negatives: Number of negative texts to draw from the pool stored
            in ``category_negatives`` for the row's category.

    Returns:
        A dict with the keys ``"query"``, ``"positive"`` and ``"negatives"``.
        If the label has no category name, or the category has no negative
        pool, all three values are ``None`` (instead of returning ``None``
        itself) so the record keeps the same keys as every other row and can
        be removed by a filter on ``query`` / ``positive``.
    """
    category = LABEL_TO_CATEGORY.get(example["label"])
    pool = category_negatives.get(category) if category is not None else None
    if pool is None:
        return {"query": None, "positive": None, "negatives": None}

    # Draw distinct pool entries whenever the pool is large enough, so a row
    # does not receive the same negative twice; fall back to sampling with
    # replacement only when more negatives are requested than the pool holds.
    if num_negatives <= len(pool):
        negatives = random.sample(pool, num_negatives)
    else:
        negatives = random.choices(pool, k=num_negatives)

    return {
        "query": f"Tell me about {category.lower()}.",
        "positive": example["content"],
        "negatives": negatives,
    }

In [9]:
# Apply preprocess to the training rows and drop all of the original columns,
# so only the keys returned by preprocess remain.
processed_data_train = dataset_train.map(preprocess, remove_columns=dataset_train.column_names)

In [10]:
# Keep only rows whose query and positive text are both present.
processed_data_train = processed_data_train.filter(lambda x: x['query'] is not None and x['positive'] is not None)

In [11]:
processed_data_train

Dataset({
    features: ['query', 'positive', 'negatives'],
    num_rows: 560000
})

In [12]:
processed_data_train[0]

{'query': 'Tell me about company.',
 'positive': ' Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972.',
 'negatives': [' The Ford Model 15-P flying wing was the last aircraft developed by the Stout Metal Airplane Division of the Ford Motor Company. After several flights resulting in a crash the program was halted. Ford eventually re-entered the aviation market producing Consolidated B-24 Liberators under license from Consolidated Aircraft.',
  ' Nellie Stockbridge (ca. 1868 – May 22 1965) was an early Idaho frontier mining district photographer. Her career spanned over 60 years. She was the oldest living member of the Zonta International club for advancement of women when she died in 1965.',
  ' Shurikeh (Persian: شوريكه\u200e also Romanized as Shūrīḵeh) is a village in Darbqazi Rural Distric

In [13]:
# Re-key every processed row ("query" becomes "anchor") and materialise the
# result as a plain Python list of dicts.
contrastive_pairs_train = [
    {
        "anchor": row["query"],
        "positive": row["positive"],
        "negatives": row["negatives"],
    }
    for row in processed_data_train
]

In [14]:
contrastive_pairs_train[500000]

{'anchor': 'Tell me about film.',
 'positive': ' Rojo Amanecer (Red Dawn) is a 1989 Silver Ariel Award-winning Mexican film directed by Jorge Fons. It is a film about the Tlatelolco Massacre in the section of Tlatelolco in Mexico City in the evening of October 2 1968. It focuses on the day of a middle-class Mexican family living in one of the apartment buildings surrounding the Plaza de Tlatelolco and is based on testimonials from witnesses and victims. It stars Héctor Bonilla María Rojo the Bichir Brothers Eduardo Palomo and others.',
 'negatives': [' Karschiola is a genus of moths in the family Arctiidae. It contains the single species Karschiola holoclera which is found in Malawi Tanzania and Zimbabwe.',
  ' Billie Fulford (21 August 1914 – 28 May 1987) was a New Zealand cricketer. She played in one Test match in 1948.',
  ' Discovery or Discoverie was a small 20-ton 38 foot (12 m) long fly-boat of the British East India Company launched before 1602. The ship was one of three that p

In [15]:
len(contrastive_pairs_train)

560000

In [16]:
from torch.utils.data import DataLoader

In [17]:
class ContrastiveDataset:
    """Index-addressable view over a sequence of contrastive records.

    Every record must provide the keys ``"anchor"``, ``"positive"`` and
    ``"negatives"``; indexing returns those three values as a tuple, in that
    order.
    """

    def __init__(self, pairs):
        # Stored by reference; no copy of the records is made.
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        record = self.pairs[idx]
        return tuple(record[key] for key in ("anchor", "positive", "negatives"))

In [18]:
contrastive_dataset_train = ContrastiveDataset(contrastive_pairs_train)

In [19]:
# Training batches of 32 records, with shuffling enabled.
data_loader_train = DataLoader(contrastive_dataset_train, batch_size=32, shuffle=True)

In [20]:
len(data_loader_train)

17500

In [21]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [22]:
model = AutoModel.from_pretrained("bert-base-uncased")

In [23]:
from peft import LoraConfig, get_peft_model

In [24]:
# Only the task type is set explicitly; every other LoRA setting is left at
# whatever the peft library uses by default.
lora_config = LoraConfig(
    task_type= "FEATURE_EXTRACTION"
)

In [25]:
lora_model = get_peft_model(model, lora_config)

In [26]:
import torch

In [27]:
def lorentzian_distance(x, y, eps=1e-7):
    """Distance between points given by their space components (curvature -1).

    Computes ``acosh(-<x, y> + sqrt((1 + ||x||^2) * (1 + ||y||^2)))`` over the
    last dimension, i.e. the time component of each point is taken to be
    ``sqrt(1 + ||.||^2)``.

    Args:
        x: Tensor whose last dimension holds the space components.
        y: Tensor broadcastable against ``x``.
        eps: The ``acosh`` argument is clamped to at least ``1 + eps``.
            Mathematically it is always >= 1, but rounding can push it below
            1 for (near-)identical inputs, which would yield NaN; at exactly
            1 the derivative of ``acosh`` is infinite.

    Returns:
        Tensor of distances with the last dimension reduced away.
    """
    dot_product = torch.sum(x * y, dim=-1)
    # Squared norms are computed directly rather than by squaring a norm.
    sq_norm_x = torch.sum(x * x, dim=-1)
    sq_norm_y = torch.sum(y * y, dim=-1)

    acosh_arg = -dot_product + torch.sqrt((1 + sq_norm_x) * (1 + sq_norm_y))
    return torch.acosh(torch.clamp(acosh_arg, min=1 + eps))

In [28]:
def info_nce_loss(anchor_embedding, positive_embedding, negative_embedding, distance_fn):
    """Cross-entropy over negated distances, with the positive as the target.

    Args:
        anchor_embedding: Anchor batch passed as first argument to ``distance_fn``.
        positive_embedding: Positive batch, compared against the anchors.
        negative_embedding: Iterable of negative batches, each compared
            against the anchors separately.
        distance_fn: Callable ``(a, b) -> distances``; smaller means closer.

    Returns:
        Scalar loss (mean over the batch).
    """
    positive_distance = distance_fn(anchor_embedding, positive_embedding)
    negative_distances = torch.stack(
        [distance_fn(anchor_embedding, negative) for negative in negative_embedding],
        dim=-1,
    )

    # Column 0 is the positive; negating turns "small distance" into "large logit".
    logits = torch.cat((positive_distance.unsqueeze(1), negative_distances), dim=1).neg()
    targets = logits.new_zeros(logits.size(0), dtype=torch.long)

    return torch.nn.functional.cross_entropy(logits, targets)

In [29]:
def exterior_angle(x_space, y_space, c, eps=1e-6):
    """Exterior angle at ``x`` for the pair ``(x, y)`` given space components.

    The time components are reconstructed as ``sqrt(1/c + ||.||^2)`` and the
    angle is ``acos(numerator / denominator)`` with
    ``numerator = y_time + x_time * c * <x, y>_L`` and
    ``denominator = ||x|| * sqrt((c * <x, y>_L)^2 - 1)``.

    Args:
        x_space: Tensor whose last dimension holds the space components of x.
        y_space: Tensor broadcastable against ``x_space``.
        c: Curvature magnitude (number or tensor).
        eps: Numerical guard. Without it the square root can receive zero or
            a slightly negative value (e.g. when x == y, giving 0/0), and
            rounding can push the ``acos`` input outside [-1, 1]; both cases
            produce NaN.

    Returns:
        Tensor of angles in radians with the last dimension reduced away.
    """
    norm_x_space = torch.norm(x_space, p=2, dim=-1)
    norm_y_space = torch.norm(y_space, p=2, dim=-1)
    x_time = torch.sqrt(1 / c + norm_x_space**2)
    y_time = torch.sqrt(1 / c + norm_y_space**2)

    lorentz_inner_product = torch.sum(x_space * y_space, dim=-1) - x_time * y_time
    scaled_inner = c * lorentz_inner_product

    numerator = y_time + x_time * scaled_inner
    # Keep the radicand strictly positive and the denominator non-zero.
    denominator = norm_x_space * torch.sqrt(torch.clamp(scaled_inner**2 - 1, min=eps)) + eps
    # Stay strictly inside acos' domain, where its derivative is finite.
    cos_angle = torch.clamp(numerator / denominator, min=-1 + eps, max=1 - eps)
    return torch.acos(cos_angle)

In [30]:
def entailment_loss(x, y, c=1, K=0.1, eps=1e-6):
    """Mean amount by which the exterior angle exceeds the aperture at ``x``.

    The aperture is ``asin(2K / (sqrt(c) * ||x||))`` and the per-pair loss is
    ``max(0, exterior_angle(x, y, c) - aperture)``.

    Args:
        x: Tensor whose last dimension holds the space components of x.
        y: Tensor broadcastable against ``x``.
        c: Curvature magnitude (number or tensor).
        K: Constant appearing in the aperture formula above.
        eps: Numerical guard. ``2K / (sqrt(c) * ||x||)`` exceeds 1 whenever
            ``||x|| < 2K / sqrt(c)``, which would make ``asin`` return NaN, so
            the ratio is clamped into (-1, 1) and the denominator is kept
            away from zero.

    Returns:
        Scalar tensor: the mean of the per-pair losses.
    """
    # Build the constants with x's dtype and device; as_tensor also accepts a
    # value that is already a tensor without the copy warning torch.tensor emits.
    c = torch.as_tensor(c, dtype=x.dtype, device=x.device)
    K = torch.as_tensor(K, dtype=x.dtype, device=x.device)

    norm_x = torch.norm(x, p=2, dim=-1)
    sin_aperture = 2 * K / (torch.sqrt(c) * norm_x + eps)
    aperture = torch.asin(torch.clamp(sin_aperture, min=-1 + eps, max=1 - eps))

    ext_angle = exterior_angle(x, y, c=c)

    # Only pairs whose angle exceeds the aperture contribute.
    return torch.clamp(ext_angle - aperture, min=0).mean()

In [31]:
def expm_o(v, c=1.0, eps=1e-8):
    """Scale ``v`` by ``sinh(sqrt(c)*||v||) / (sqrt(c)*||v||)``, then min-max rescale each row.

    Args:
        v: 2-D tensor (rows x features); the min/max are taken along dim 1.
        c: Curvature magnitude (number or tensor).
        eps: Lower bound applied to both divisors. A zero-norm row or a row
            whose entries are all equal would otherwise divide by zero and
            produce NaN.

    Returns:
        Tensor of the same shape as ``v`` with every row rescaled to [0, 1].

    NOTE(review): the sinh factor is one positive scalar per row, and min-max
    rescaling is invariant to such a factor, so the result equals the plain
    per-row min-max normalisation of ``v``. Confirm this is intended.
    """
    c = torch.as_tensor(c, dtype=v.dtype, device=v.device)
    vnorm = torch.norm(v, p=2, dim=-1, keepdim=True)
    scaled_norm = torch.clamp(torch.sqrt(c) * vnorm, min=eps)
    xspace = torch.sinh(scaled_norm) * v / scaled_norm

    row_min = xspace.min(dim=1, keepdim=True).values
    row_max = xspace.max(dim=1, keepdim=True).values
    return (xspace - row_min) / torch.clamp(row_max - row_min, min=eps)

In [32]:
# AdamW over everything lora_model.parameters() yields, with learning rate 5e-5.
optimizer = torch.optim.AdamW(lora_model.parameters(), lr=5e-5)

In [33]:
num_epochs=3

In [34]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [35]:
device

device(type='cuda')

In [36]:
lora_model = lora_model.to(device)

In [37]:
# The "test" split is loaded and used as the validation set in this notebook.
dataset_val = load_dataset("dbpedia_14", split="test")
dataset_val

Dataset({
    features: ['label', 'title', 'content'],
    num_rows: 70000
})

In [38]:
# Group the validation texts by category name: {category name: [content, ...]}.
category_to_samples_val = {}
for sample in dataset_val:
    category = LABEL_TO_CATEGORY[sample["label"]]
    # setdefault creates the list the first time a category is seen.
    category_to_samples_val.setdefault(category, []).append(sample["content"])

In [39]:
# For every category name, pool the validation texts of all *other* categories.
category_negatives_val = {
    target: [
        text
        for other, texts in category_to_samples_val.items()
        if other != target
        for text in texts
    ]
    for target in LABEL_TO_CATEGORY.values()
}

In [40]:
def preprocess_val(example, num_negatives=5):
    """Turn one labelled validation row into a query / positive / negatives record.

    Args:
        example: Row providing an integer ``"label"`` and a text ``"content"``.
        num_negatives: Number of negative texts to draw from the pool stored
            in ``category_negatives_val`` for the row's category.

    Returns:
        A dict with the keys ``"query"``, ``"positive"`` and ``"negatives"``.
        If the label has no category name, or the category has no negative
        pool, all three values are ``None`` (instead of returning ``None``
        itself) so the record keeps the same keys as every other row and can
        be removed by a filter on ``query`` / ``positive``.
    """
    category = LABEL_TO_CATEGORY.get(example["label"])
    pool = category_negatives_val.get(category) if category is not None else None
    if pool is None:
        return {"query": None, "positive": None, "negatives": None}

    # Draw distinct pool entries whenever the pool is large enough, so a row
    # does not receive the same negative twice; fall back to sampling with
    # replacement only when more negatives are requested than the pool holds.
    if num_negatives <= len(pool):
        negatives = random.sample(pool, num_negatives)
    else:
        negatives = random.choices(pool, k=num_negatives)

    return {
        "query": f"Tell me about {category.lower()}.",
        "positive": example["content"],
        "negatives": negatives,
    }

In [41]:
# Apply preprocess_val to the validation rows and drop all of the original
# columns, so only the keys returned by preprocess_val remain.
processed_data_val = dataset_val.map(preprocess_val, remove_columns=dataset_val.column_names)

In [42]:
# Keep only rows whose query and positive text are both present.
processed_data_val = processed_data_val.filter(lambda x: x['query'] is not None and x['positive'] is not None)

In [43]:
# Re-key every processed validation row ("query" becomes "anchor") and
# materialise the result as a plain Python list of dicts.
contrastive_pairs_val = [
    {
        "anchor": row["query"],
        "positive": row["positive"],
        "negatives": row["negatives"],
    }
    for row in processed_data_val
]

In [44]:
contrastive_dataset_val = ContrastiveDataset(contrastive_pairs_val)

In [45]:
# Validation batches of 32 records.
# NOTE(review): shuffle=True is kept as written; evaluate_mrr only accumulates
# a sum over all batches, so batch order should not affect the reported MRR.
data_loader_val = DataLoader(contrastive_dataset_val, batch_size=32, shuffle=True)

In [46]:
len(data_loader_val)

2188

In [47]:
def evaluate_mrr(model1, data_loader_val, distance_fn):
    """Mean reciprocal rank of the positive among the positive and its negatives.

    Uses the notebook-level ``tokenizer``, ``device`` and ``expm_o``. The model
    is switched to eval mode and is left in eval mode on return.

    Args:
        model1: Encoder; called with the tokenizer output, and the first
            position of ``last_hidden_state`` is used as the embedding.
        data_loader_val: Iterable of batches indexable as
            ``(anchor texts, positive texts, negative text groups)``.
        distance_fn: Callable ``(a, b) -> distances``; smaller means closer.

    Returns:
        The MRR as a float, or ``0.0`` if the loader yields no queries
        (previously this raised ``ZeroDivisionError``).

    Ranking is pessimistic: a negative that is not strictly farther from the
    anchor than the positive (a tie, or a NaN distance) counts as ranked ahead
    of it. The earlier sort-based rank was arbitrary in those cases, so a
    collapsed model that maps everything to one point could still score well.
    """
    model1.eval()

    def embed(texts):
        # Tokenise, move to the target device, embed, then apply expm_o.
        inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
        return expm_o(model1(**inputs).last_hidden_state[:, 0, :])

    total_rr = 0.0
    num_queries = 0

    with torch.no_grad():
        for batch in data_loader_val:
            anchor_text = batch[0]
            positive_text = batch[1]
            negative_texts = batch[2]

            anchor_embedding = embed(anchor_text)
            positive_embedding = embed(positive_text)
            negative_embedding = [embed(neg) for neg in negative_texts]

            pos_dist = distance_fn(anchor_embedding, positive_embedding)
            neg_dist = torch.stack([distance_fn(anchor_embedding, neg) for neg in negative_embedding], dim=-1)

            # 1-based rank = 1 + number of negatives not strictly farther than
            # the positive. Comparisons against NaN are False, so NaN
            # distances are counted against the positive as well.
            ranked_ahead = ~(neg_dist > pos_dist.unsqueeze(1))
            positive_rank = 1 + ranked_ahead.sum(dim=1)

            total_rr += torch.sum(1.0 / positive_rank.float()).item()  # Reciprocal rank
            num_queries += positive_rank.numel()

    if num_queries == 0:
        return 0.0
    return total_rr / num_queries

In [48]:
import os
# Directory used for the per-epoch checkpoints and the metrics JSON.
# NOTE(review): absolute, user-specific path; consider making it configurable.
save_dir ="/dss/dsshome1/07/ra65bex2/srawat/wiki/0.1hyperbolic"

In [49]:
import time
epoch_metrics = []

In [50]:
# Create the checkpoint directory up front so that torch.save cannot fail on a
# missing path after a full epoch of training has already been spent.
os.makedirs(save_dir, exist_ok=True)
num_train_batches = len(data_loader_train)

for epoch in range(num_epochs):
    start_time = time.time()
    # evaluate_mrr() puts the model in eval mode at the end of every epoch,
    # so training mode has to be re-enabled here.
    lora_model.train()

    total_loss = 0.0
    entailment_loss_total = 0.0
    contrastive_loss_total = 0.0
    for batch in data_loader_train:

        anchor_texts = batch[0]
        positive_texts = batch[1]
        negative_texts = batch[2]

        anchor_inputs = tokenizer(anchor_texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
        positive_inputs = tokenizer(positive_texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)

        # The first position of last_hidden_state is used as the embedding.
        anchor_embedding = expm_o(lora_model(**anchor_inputs).last_hidden_state[:, 0, :])
        positive_embedding = expm_o(lora_model(**positive_inputs).last_hidden_state[:, 0, :])
        # Each element of negative_texts is encoded with its own forward pass.
        negative_embedding = [expm_o(lora_model(**tokenizer(neg, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)).last_hidden_state[:, 0, :]) for neg in negative_texts]

        contrastive_loss_value = info_nce_loss(anchor_embedding, positive_embedding, negative_embedding, distance_fn=lorentzian_distance)

        entailment_loss_value = entailment_loss(anchor_embedding, positive_embedding)

        # The entailment term is weighted by 0.1 relative to the contrastive term.
        loss = contrastive_loss_value + 0.1*entailment_loss_value

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        entailment_loss_total += entailment_loss_value.item()
        contrastive_loss_total += contrastive_loss_value.item()

    # Per-batch averages for this epoch, computed once and reused below.
    avg_loss = total_loss / num_train_batches
    avg_contrastive_loss = contrastive_loss_total / num_train_batches
    avg_entailment_loss = entailment_loss_total / num_train_batches

    # NOTE(review): this pickles the whole model object rather than a
    # state_dict; kept as-is because the code that loads these files is not
    # visible from here.
    save_path1 = os.path.join(save_dir, f"hyperbolic_lora_checkpoint_epoch_{epoch+1}.pth")
    torch.save(lora_model, save_path1)
    print(f"EPOCH {epoch+1}:")
    print(f"Checkpoint saved: {save_dir}")
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}")
    print(f"Epoch {epoch+1}/{num_epochs}, Contrastive Loss: {avg_contrastive_loss}")
    print(f"Epoch {epoch+1}/{num_epochs}, Entailment Loss: {avg_entailment_loss}")
    mrr_validation = evaluate_mrr(model1=lora_model, data_loader_val=data_loader_val,distance_fn=lorentzian_distance)
    #mrr_train = evaluate_mrr(lora_model, data_loader_train, lorentzian_distance)
    #print(f"Mean Reciprocal Rank (MRR) for training set: {mrr_train:.4f}")
    print(f"Mean Reciprocal Rank (MRR) for validation set: {mrr_validation:.4f}")
    end_time = time.time()
    # The measured time covers training, checkpointing and validation.
    elapsed_minutes = (end_time - start_time) / 60
    print(f"Epoch {epoch+1} took {elapsed_minutes:.4f} minutes.")
    print(f"\n")
    epoch_metrics.append({
        'epoch': epoch + 1,
        'training_loss': avg_loss,
        'Contrastive_loss': avg_contrastive_loss,
        'Entailment_loss': avg_entailment_loss,
        'mrr_validation': mrr_validation,
        'time_taken_minutes': elapsed_minutes
    })

EPOCH 1:
Checkpoint saved: /dss/dsshome1/07/ra65bex2/srawat/wiki/0.1hyperbolic
Epoch 1/3, Loss: 0.6994711348703929
Epoch 1/3, Contrastive Loss: 0.47783132351636887
Epoch 1/3, Entailment Loss: 2.2163980818339755


Mean Reciprocal Rank (MRR) for validation set: 0.9977
Epoch 1 took 183.2153 minutes.




EPOCH 2:
Checkpoint saved: /dss/dsshome1/07/ra65bex2/srawat/wiki/0.1hyperbolic
Epoch 2/3, Loss: 0.5216460899574415
Epoch 2/3, Contrastive Loss: 0.2997675705305168
Epoch 2/3, Entailment Loss: 2.2187851621627805


Mean Reciprocal Rank (MRR) for validation set: 0.9978
Epoch 2 took 182.8705 minutes.




EPOCH 3:
Checkpoint saved: /dss/dsshome1/07/ra65bex2/srawat/wiki/0.1hyperbolic
Epoch 3/3, Loss: 0.49189796753440584
Epoch 3/3, Contrastive Loss: 0.2716789664387703
Epoch 3/3, Entailment Loss: 2.2021899812834604


Mean Reciprocal Rank (MRR) for validation set: 0.9977
Epoch 3 took 182.8639 minutes.




In [51]:
import json
# Write the collected per-epoch metrics to a JSON file inside save_dir.
metrics_path = save_dir + '/hyperbolic_epoch_metrics.json'
with open(metrics_path, 'w') as metrics_file:
    json.dump(epoch_metrics, metrics_file)