In [None]:
import pickle
import networkx as nx
import random
from torch.utils.data import Dataset, DataLoader
from transformers import DataCollatorWithPadding, DefaultDataCollator
# import itertools
import torch
import numpy as np
from tqdm import tqdm
import os
import pandas as pd

import logging
from typing import Any, Dict, List, Literal, Union, Tuple
import wandb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import os
from copy import deepcopy
import pandas as pd
from transformers import BertForSequenceClassification, AdamW, BertTokenizer, BertModel, get_linear_schedule_with_warmup
from torch.optim.lr_scheduler import CosineAnnealingLR
from random import sample
from sentence_transformers import SentenceTransformer

In [None]:
# Training graph: a directed edge list loaded as a DiGraph. TripletDataset draws each
# anchor's positive from the node's predecessors in this graph.
GRAPH_PATH = '/raid/rabikov/contrastive_data/all_final.edgelist'
G = nx.read_edgelist(GRAPH_PATH, create_using=nx.DiGraph)

In [None]:
class TripletDataset(Dataset):
    """Pre-tokenised (anchor, positive, negative) triplets sampled from a directed graph.

    For every node that has at least one predecessor, the node is the anchor, one random
    predecessor is the positive, and one random node that is neither the anchor nor any of
    its predecessors is the negative. Node names are cut at the first '.' before
    tokenisation.

    Args:
        G: directed graph exposing ``nodes()`` and ``predecessors(node)``.
        tokenizer: Hugging Face-style tokenizer used to encode the three texts.
        max_length: pad/truncate length of every encoded text.
    """

    def __init__(self, G, tokenizer, max_length=32):
        self.graph = G
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.triplets = self.sample_triplets()

        # Tokenize all triplets once up front; __getitem__ only slices the tensors.
        self.tokenized_triplets = self.tokenize_triplets()

    @staticmethod
    def _sample_negative(candidates, excluded, max_tries=1000):
        """Return a random element of `candidates` that is not in `excluded`, or None.

        Rejection sampling is cheap when most candidates are valid; after `max_tries`
        misses it falls back to an explicit filter, so it always terminates.
        """
        if candidates:
            for _ in range(max_tries):
                candidate = random.choice(candidates)
                if candidate not in excluded:
                    return candidate
        remaining = [c for c in candidates if c not in excluded]
        return random.choice(remaining) if remaining else None

    def sample_triplets(self):
        """Sample one (anchor, positive, negative) text triplet per node with predecessors."""
        all_nodes = list(self.graph.nodes())  # built once, not once per negative draw
        triplets = []
        for node in tqdm(all_nodes):
            preds = list(self.graph.predecessors(node))
            if not preds:
                continue
            positive = sample(preds, 1)[0].split('.')[0]
            # Exclude the anchor and all of its predecessors *by name*, so a true parent
            # (or the anchor itself) is never used as the negative.
            negative = self._sample_negative(all_nodes, set(preds) | {node})
            if negative is None:
                continue  # every node is the anchor or one of its predecessors
            triplets.append((node.split('.')[0], positive, negative.split('.')[0]))
        return triplets

    def tokenize_triplets(self):
        """Encode anchors, positives and negatives as three padded batches.

        Returns:
            Tuple ``(anchor_enc, positive_enc, negative_enc)`` of tokenizer outputs, each
            padded/truncated to ``self.max_length``.
        """
        anchor_texts = [triplet[0] for triplet in self.triplets]
        positive_texts = [triplet[1] for triplet in self.triplets]
        negative_texts = [triplet[2] for triplet in self.triplets]

        def encode(texts):
            return self.tokenizer(texts, truncation=True, padding='max_length',
                                  max_length=self.max_length, return_tensors='pt')

        return encode(anchor_texts), encode(positive_texts), encode(negative_texts)

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        """Return the token ids and attention masks of triplet `idx` as a flat dict."""
        anchor, positive, negative = self.tokenized_triplets
        return {
            'anchor_input_ids': anchor['input_ids'][idx],
            'anchor_attention_mask': anchor['attention_mask'][idx],
            'positive_input_ids': positive['input_ids'][idx],
            'positive_attention_mask': positive['attention_mask'][idx],
            'negative_input_ids': negative['input_ids'][idx],
            'negative_attention_mask': negative['attention_mask'][idx],
        }


In [None]:
# Pretrained checkpoint used both for the tokenizer and for the SentenceBERT encoder below.
model_name = 'microsoft/MiniLM-L12-H384-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)


In [None]:
# NOTE(review): `bert` does not appear to be used by any later cell — SentenceBERT loads
# its own copy of the encoder from `model_name`. Kept for compatibility; consider deleting.
bert = BertModel.from_pretrained(model_name)

In [None]:
# One tokenised triplet per graph node that has at least one predecessor.
dataset = TripletDataset(G=G, tokenizer=tokenizer)

In [None]:
# 70/30 train/validation split. A fixed-seed generator makes the split reproducible across
# runs (note: the triplets themselves are still sampled without a seed).
SPLIT_SEED = 42
train_size = int(0.7 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Shuffle only the training batches; validation order is fixed.
train_dataloader, val_dataloader = (
    DataLoader(dataset=train_dataset, batch_size=128, shuffle=True),
    DataLoader(dataset=val_dataset, batch_size=128, shuffle=False),
)

In [None]:
class SentenceBERT(nn.Module):
    """BERT encoder producing one vector per input sequence.

    The sequence vector is the attention-mask-weighted mean of the encoder's last hidden
    states, so padding positions contribute nothing.
    """

    def __init__(self, model_name):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        width = self.bert.config.hidden_size
        # NOTE: `pooling` is never applied in forward(); it only adds entries to the
        # state_dict. Removing it would change the keys of saved checkpoints.
        self.pooling = nn.Sequential(
            nn.Linear(width, width),
            nn.Tanh(),
        )

    def forward(self, input_ids, attention_mask):
        """Mean-pool the last hidden states over non-padding tokens.

        Assumes the usual (batch, seq, hidden) hidden states and a (batch, seq) mask, which
        gives a (batch, hidden) result.
        """
        token_states = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        weights = attention_mask.unsqueeze(-1).expand(token_states.size()).float()
        numerator = (token_states * weights).sum(dim=1)
        # Clamp avoids division by zero for an all-padding row.
        denominator = weights.sum(dim=1).clamp(min=1e-9)
        return numerator / denominator

# GPU index 7 is hard-coded for the original training machine; CPU is the fallback.
device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')
sbert_model = SentenceBERT(model_name)
sbert_model.to(device)

In [None]:
class TripletLoss(nn.Module):
    """Mean hinge triplet loss on Euclidean (p=2) distances.

    loss = mean(relu(margin + d(anchor, positive) - d(anchor, negative)))
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_distance = F.pairwise_distance(anchor, positive, p=2)
        neg_distance = F.pairwise_distance(anchor, negative, p=2)
        hinge = F.relu(self.margin + pos_distance - neg_distance)
        return hinge.mean()

triplet_loss = TripletLoss(margin=1.0)

In [None]:
optimizer = torch.optim.AdamW(sbert_model.parameters(), lr=2e-5, weight_decay=1e-4)

num_epochs = 10
num_training_steps = len(train_dataloader) * num_epochs
num_warmup_steps = int(0.1 * num_training_steps)  # linear warm-up over the first 10% of steps
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)


def batch_triplet_loss(model, batch):
    """Embed the anchor/positive/negative views of one collated batch and return the loss.

    `batch` is a dict with the keys produced by TripletDataset.__getitem__; all tensors are
    moved to `device` before the three forward passes.
    """
    batch = {k: v.to(device) for k, v in batch.items()}
    anchor = model(batch['anchor_input_ids'], batch['anchor_attention_mask'])
    positive = model(batch['positive_input_ids'], batch['positive_attention_mask'])
    negative = model(batch['negative_input_ids'], batch['negative_attention_mask'])
    return triplet_loss(anchor, positive, negative)


wandb.login()

wandb.init(project="contrastive_thesis_sbert_all_miniLM")

best_val_loss = float('inf')
patience = 3  # epochs without validation improvement before stopping early
epochs_no_improve = 0
best_model_wts = None

for epoch in range(num_epochs):
    print(f'Epoch: {epoch + 1}')
    sbert_model.train()
    total_loss = 0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        loss = batch_triplet_loss(sbert_model, batch)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log a Python float rather than the graph-attached tensor.
        batch_loss = loss.item()
        wandb.log({'Loss': batch_loss})
        total_loss += batch_loss

    avg_train = total_loss / len(train_dataloader)
    print(f'Loss: {avg_train}')
    wandb.log({'Average Training Loss': avg_train})

    sbert_model.eval()
    total_eval_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            total_eval_loss += batch_triplet_loss(sbert_model, batch).item()

    avg_eval = total_eval_loss / len(val_dataloader)
    print(f'Validation Loss: {avg_eval}')
    wandb.log({'Average Validation Loss': avg_eval})

    if avg_eval < best_val_loss:
        best_val_loss = avg_eval
        # state_dict() returns references to the live parameter tensors; without a deep
        # copy this "best" snapshot would silently follow every later optimizer step.
        best_model_wts = deepcopy(sbert_model.state_dict())
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print('Early stopping')
            break

# Restore the best-validation weights so the following cells save and evaluate them,
# not the weights from the last (possibly worse) epoch.
if best_model_wts is not None:
    sbert_model.load_state_dict(best_model_wts)


In [None]:
# Close the wandb run that was opened with wandb.init() in the training cell.
wandb.finish()

In [None]:
# Save the encoder weights under the "model" key; create the target directory if needed.
checkpoint_path = Path('/raid/rabikov/contrastive_data/model/best_sbert_checkpoint_ALL_MIMLM.pth')
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({"model": sbert_model.state_dict()}, checkpoint_path)

## Test and evaluation

In [None]:
class CustomDataset(Dataset):
    """Lazily tokenises a list of texts, one padded/truncated sequence per item.

    Args:
        texts: indexable sequence of strings.
        tokenizer: Hugging Face-style tokenizer (called with ``return_tensors='pt'``).
        max_length: pad/truncate length of every encoded text.
    """

    def __init__(self, texts, tokenizer, max_length=128):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``{'input_ids', 'attention_mask'}`` 1-D tensors for text `idx`."""
        encoding = self.tokenizer(
            self.texts[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        # Drop only the leading batch dimension of size 1; a bare .squeeze() would also
        # collapse the sequence dimension when max_length == 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
        }


In [None]:
G_mag = nx.read_edgelist('./all.edgelist', create_using=nx.DiGraph, delimiter='\t')
all_nodes = list(G_mag.nodes())
# NOTE(review): TripletDataset built its training texts from bare node names (the
# 'Concept: ' prefix is commented out there), but every node gets the prefix here —
# confirm this train/inference mismatch is intended.
new_nodes = [f'Concept: {node}' for node in all_nodes]

tokenized = CustomDataset(new_nodes, tokenizer)

# No shuffling, so all_embedds[i] is the embedding of all_nodes[i].
dataset_mag = DataLoader(tokenized, batch_size=32)

# eval() disables dropout, so the embeddings do not depend on the model's last mode.
sbert_model.eval()
all_embedds = []
with torch.no_grad():
    for batch in tqdm(dataset_mag):
        batch_final = {k: v.to(device) for k, v in batch.items()}
        embeddings = sbert_model(**batch_final)
        # Move rows to CPU right away so GPU memory does not grow with the node count.
        all_embedds.extend(embeddings.cpu())


In [None]:
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('./test_nodes.pickle', 'rb') as fh:
    test_nodes = pickle.load(fh)

# Position of every node in `all_nodes`, i.e. its index in the embedding list.
indices = {node: i for i, node in enumerate(all_nodes)}

# One similarity list per test query, filled by the scoring cell.
matrix = {node: [] for node, _ in test_nodes}

In [None]:
# CPU copies of the node embeddings, in `all_nodes` order.
embs = list(map(lambda embedding: embedding.cpu(), all_embedds))

In [None]:
# Cosine similarity of every test query against every node. `cosine_similarity` was never
# imported in this notebook, so this uses torch directly: stack and L2-normalise all
# embeddings once, then each query is a single matrix-vector product instead of N calls.
emb_matrix = F.normalize(torch.stack(embs).double(), p=2, dim=1)
for node, _ in tqdm(test_nodes):
    query = emb_matrix[indices[node]]
    # Assign rather than append, so re-running this cell does not duplicate scores.
    matrix[node] = (emb_matrix @ query).tolist()

In [None]:
# Persist the per-query similarity lists; `with` makes sure the file is flushed and closed.
with open('sim_sbert_psy.pickle', 'wb') as fh:
    pickle.dump(matrix, fh)

In [None]:
from collections import defaultdict

def default_structure():
    """Return a fresh per-query record (used as the ``defaultdict`` factory for `data`).

    Keys:
        args:   ranked candidate indices (empty numpy array until filled).
        nodes:  list of (query, candidate) pairs (empty until filled).
        scores: scores aligned with `nodes` (empty tensor until filled).
    """
    return {
        'args': np.array([]),
        'nodes': [],
        # An empty tensor — not the `torch.tensor` factory function itself.
        'scores': torch.tensor([]),
    }

def find_top_k_indices(array, k=10):
    """Return the indices of the `k` largest values of `array`, largest first.

    Args:
        array: list or array-like of numbers.
        k: how many indices to return.
    """
    values = np.asarray(array)
    descending_order = np.argsort(-values)  # argsort of the negation == descending order
    return descending_order[:k]

def extract_elements(array, indices, node=None):
    """Gather ``array[idx]`` for every idx in `indices`.

    With a truthy `node`, returns a list of ``(node, element)`` pairs; otherwise returns
    the gathered elements as a torch tensor. An out-of-range index contributes ``None``
    (which makes the tensor conversion fail in the no-`node` case).
    """
    values = np.array(array)
    size = len(values)

    def pick(i):
        if i >= size:
            return None
        return (node, values[i]) if node else values[i]

    picked = [pick(i) for i in indices]
    return picked if node else torch.tensor(picked)

In [None]:
from collections import defaultdict

data = defaultdict(default_structure)

for node, gold in tqdm(test_nodes):
    scores = matrix[node]
    ranking = find_top_k_indices(scores, k=len(scores))
    # Remove the query node itself explicitly. Dropping the first-ranked entry assumes
    # self-similarity is strictly highest, which fails on cosine ties at 1.0.
    args = ranking[ranking != indices[node]]

    data[node]['args'] = args
    data[node]['nodes'] = extract_elements(all_nodes, args, node)
    data[node]['scores'] = extract_elements(scores, args)

In [None]:
def rearrange(energy_scores, candidate_position_idx, true_position_idx):
    """Move gold candidates' scores to the front and build matching 0/1 labels.

    Args:
        energy_scores: tensor of scores aligned with `candidate_position_idx`
            (presumably 1-D).
        candidate_position_idx: candidates, e.g. ``(query, node)`` tuples; must be hashable.
        true_position_idx: gold candidates; may be empty.

    Returns:
        ``(scores, labels)``: scores of gold candidates followed by the others (each group
        keeps its original order), and an int32 tensor with 1 for gold and 0 otherwise.
    """
    # Set membership is O(C) and stays 1-D even when there are no gold items.
    gold = set(true_position_idx)
    is_gold = np.array([candidate in gold for candidate in candidate_position_idx], dtype=bool)
    correct = np.where(is_gold)[0]
    incorrect = np.where(~is_gold)[0]
    labels = torch.cat((torch.ones(len(correct)), torch.zeros(len(incorrect)))).int()
    energy_scores = torch.cat((energy_scores[correct], energy_scores[incorrect]))
    return energy_scores, labels

import re

def calculate_ranks_from_distance(all_distances, positive_relations):
    """1-based rank of each positive item when all items are sorted by ascending distance.

    Args:
        all_distances: 1-D numpy array; smaller means better.
        positive_relations: list of indices into `all_distances`.

    Returns:
        list of ranks, one per entry of `positive_relations`.
    """
    order = np.argsort(all_distances)
    # Inverse permutation of the sort order: position_of[i] is the sorted slot of item i.
    position_of = np.empty_like(order)
    position_of[order] = np.arange(len(order))
    return list(position_of[positive_relations] + 1)

def obtain_ranks(outputs, targets, rank_fn=None):
    """Split flat predictions into per-query groups and rank each group's positives.

    Args:
        outputs: tensor of shape (N,) or (N, 1) with scores where smaller is better
            (this notebook passes negated similarities).
        targets: tensor of shape (N,) with 0/1 labels laid out as consecutive groups
            ``[1, ..., 1, 0, ..., 0, 1, ..., 0, ...]``; a new group starts wherever a 0
            is immediately followed by a 1.
        rank_fn: ``f(distances, positive_indices) -> list`` applied to every group;
            defaults to ``calculate_ranks_from_distance``.

    Returns:
        list with one list of 1-based ranks per group.
    """
    calculate_ranks = calculate_ranks_from_distance if rank_fn is None else rank_fn
    # atleast_1d keeps a single-candidate input sliceable after squeeze().
    prediction = np.atleast_1d(outputs.cpu().numpy().squeeze())
    label = targets.cpu().numpy()

    # Group boundaries = index of every 1 that directly follows a 0. Vectorised, works for
    # any integer dtype, and avoids ndarray.tostring(), which NumPy 2.0 removed.
    boundaries = (np.flatnonzero((label[:-1] == 0) & (label[1:] == 1)) + 1).tolist()
    start_indices = [0] + boundaries
    end_indices = boundaries + [len(label)]

    all_ranks = []
    for start_idx, end_idx in zip(start_indices, end_indices):
        distances = prediction[start_idx:end_idx]
        labels = label[start_idx:end_idx]
        positive_relations = list(np.where(labels == 1)[0])
        all_ranks.append(calculate_ranks(distances, positive_relations))
    return all_ranks

In [None]:
import itertools

def macro_mr(all_ranks):
    """Macro mean rank: average each query's mean rank, then average over queries."""
    per_query = [np.mean(ranks) for ranks in all_ranks]
    return np.array(per_query).mean()

def micro_mr(all_ranks):
    """Micro mean rank: average of all ranks pooled across queries."""
    pooled = [rank for ranks in all_ranks for rank in ranks]
    return np.array(pooled).mean()

def hit_at_k(all_ranks, k):
    """Fraction of all positive items (pooled over queries) ranked within the top `k`.

    The numerator is a NumPy scalar, so an empty input gives nan (with a RuntimeWarning)
    instead of raising ZeroDivisionError.
    """
    rank_positions = np.array(list(itertools.chain(*all_ranks)))
    hits = np.sum(rank_positions <= k)
    return 1.0 * hits / len(rank_positions)

def hit_at_1(all_ranks):
    """Hit@1 over all pooled positive items."""
    return hit_at_k(all_ranks, 1)

def hit_at_3(all_ranks):
    """Hit@3 over all pooled positive items."""
    return hit_at_k(all_ranks, 3)

def hit_at_5(all_ranks):
    """Hit@5 over all pooled positive items."""
    return hit_at_k(all_ranks, 5)

def hit_at_10(all_ranks):
    """Hit@10 over all pooled positive items."""
    return hit_at_k(all_ranks, 10)

def precision_at_k(all_ranks, k):
    """Hits within the top `k`, divided by ``k * number_of_queries``.

    The numerator is a NumPy scalar, so an empty `all_ranks` gives nan (with a
    RuntimeWarning) instead of raising ZeroDivisionError.
    """
    rank_positions = np.array(list(itertools.chain(*all_ranks)))
    hits = np.sum(rank_positions <= k)
    return 1.0 * hits / (len(all_ranks) * k)

def precision_at_1(all_ranks):
    """Precision@1 averaged over queries."""
    return precision_at_k(all_ranks, 1)

def precision_at_3(all_ranks):
    """Precision@3 averaged over queries."""
    return precision_at_k(all_ranks, 3)

def precision_at_5(all_ranks):
    """Precision@5 averaged over queries."""
    return precision_at_k(all_ranks, 5)

def precision_at_10(all_ranks):
    """Precision@10 averaged over queries."""
    return precision_at_k(all_ranks, 10)

def mrr_scaled_10(all_ranks):
    """Scaled MRR: mean of ``1 / ceil(rank / 10)`` over all pooled positive items.

    See eq. (2) in the PinSAGE paper: https://arxiv.org/pdf/1806.01973.pdf
    """
    pooled = np.array([rank for ranks in all_ranks for rank in ranks])
    buckets = np.ceil(pooled / 10)  # ranks 1-10 -> 1, 11-20 -> 2, ...
    return np.mean(1.0 / buckets)

In [None]:
# Metric name -> function of the per-query rank lists.
metric_names = {
    'mrr': mrr_scaled_10,
    'p1': precision_at_1,
    'p5': precision_at_5,
    'r1': hit_at_1,
    'r5': hit_at_5
}

metrics = {name: [] for name in metric_names}

missing = 0
for query, gold_nodes in tqdm(test_nodes):
    gold_new = [(query, gold_node) for gold_node in gold_nodes]

    scores = data[query]['scores']
    potential_nodes = data[query]['nodes']

    batched_energy_scores, labels = rearrange(scores, potential_nodes, gold_new)
    # Negate similarities: obtain_ranks treats smaller values as better.
    all_ranks = obtain_ranks(-batched_energy_scores, labels)

    for name, func in metric_names.items():
        # nan (e.g. no positives) is counted as 0.
        metrics[name].append(np.nan_to_num(func(all_ranks)))

In [None]:
# Average of each metric over all test queries.
for metric_name, per_query_values in metrics.items():
    print(metric_name, np.mean(per_query_values))