In [None]:
!pip install -r requirements.txt

In [1]:
import time
import os
import pandas as pd
import numpy as np

from dataloader import GraphTextDataset, GraphDataset, TextDataset

import torch
import torch.nn as nn
from torch import optim
from torch_geometric.data import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader
from transformers import AutoTokenizer
from torchmetrics.functional import pairwise_cosine_similarity


from alignment import AlignmentModel,Discriminator, gradient_penalty
from moemodel import MOEModel
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import label_ranking_average_precision_score

In [2]:
def hard_triplet_loss(graph_embeddings, text_embeddings, margin = 0.3):
    cosine = pairwise_cosine_similarity(text_embeddings, graph_embeddings) # compute cosine similarity between each pairs (texts, graphs)
    positive_sample = cosine.diag() # get similarity between anchor and positive sample where anchor could be the text representation and positive sample the graph represention and vice versa
    cosine = cosine.fill_diagonal_(-2) # set diag val to a minimum possible value of similarity to get hard negetive example by argmax
    loss = torch.clamp(torch.max(cosine, axis = 1)[0] - positive_sample + margin,0)
    loss += torch.clamp(torch.max(cosine, axis = 0)[0] - positive_sample +  margin,0)
    loss = torch.mean(loss)
    return loss

In [3]:
nb_epochs = 5
batch_size = 32
learning_rate = 2e-5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]

val_dataset = GraphTextDataset(root='./data/', gt=gt, split='val', tokenizer=tokenizer)
train_dataset = GraphTextDataset(root='./data/', gt=gt, split='train', tokenizer=tokenizer)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)



In [6]:
model = AlignmentModel(in_channels=300, out_channels=300, graph_attention_head=6, type = 'Antisymmetric')
# expert_type = ['GPS', 'EGC', 'TransformerConv']

# model = MOEModel(300, 300, 6, expert_type)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                                betas=(0.9, 0.999),
                                weight_decay=0.01)
model.to(device)

discriminator=Discriminator(300,300)
optimizer_discriminator = optim.AdamW(discriminator.parameters(), lr=learning_rate,
                                betas=(0.9, 0.999),
                                weight_decay=0.01)
discriminator.to(device)



Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Discriminator(
  (model): Sequential(
    (0): Linear(in_features=300, out_features=300, bias=True)
    (1): ReLU()
    (2): Linear(in_features=300, out_features=1, bias=True)
    (3): Sigmoid()
  )
)

In [5]:
# checkpoint = torch.load('model8.pt')
# model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [None]:
model

In [7]:
epoch = 0
loss = 0
losses = []
count_iter = 0
time1 = time.time()
printEvery = 50
best_validation_loss = 1000000
best_validation_mrr = 0
discriminator_iteration = 3
lambda_gp = 1
gamma = 1e-2

In [14]:
optimizer.param_groups[0]['lr'] = 2e-7
optimizer_discriminator.param_groups[0]['lr'] = 2e-7

nb_epochs = 10

for i in range(nb_epochs):
    print('-----EPOCH{}-----'.format(i+1))
    model.train()
    for batch in train_loader:
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        
        graph_embeddings, text_embeddings = model(graph_batch.to(device), 
                                input_ids.to(device), 
                                attention_mask.to(device))
        
        for _ in range(discriminator_iteration):
            
            gp = gradient_penalty(discriminator, text_embeddings, graph_embeddings)
            text_scores = discriminator(text_embeddings)
            graph_scores = discriminator(graph_embeddings)

            loss_discriminator = torch.mean(text_scores) - torch.mean(graph_scores) + lambda_gp*gp

            optimizer_discriminator.zero_grad()
            loss_discriminator.backward(retain_graph=True)
            optimizer_discriminator.step()            
          
        triplet_loss = hard_triplet_loss(graph_embeddings, text_embeddings)
        
        optimizer.zero_grad()
        triplet_loss.backward()        
        optimizer.step()
        
        loss += triplet_loss.item()
        
        count_iter += 1
        if count_iter % printEvery == 0:
            time2 = time.time()
            print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                                                                        time2 - time1, loss/printEvery))
            losses.append(loss)
            loss = 0 
    model.eval()       
    val_loss = 0   
    graphs = []
    texts = []     
    for batch in val_loader:
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        graph_embeddings, text_embeddings = model(graph_batch.to(device), 
                                input_ids.to(device), 
                                attention_mask.to(device))
        current_loss = hard_triplet_loss(graph_embeddings, text_embeddings)  
        val_loss += current_loss.item()
        graphs.extend(graph_embeddings.tolist())
        texts.extend(text_embeddings.tolist())

    best_validation_loss = min(best_validation_loss, val_loss)
    print('-----EPOCH'+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )
    similarity = cosine_similarity(texts, graphs)
    y_true = np.eye(len(similarity))
    score = label_ranking_average_precision_score(y_true, similarity)
    print('-----EPOCH'+str(i+1)+'----- done.  Validation MRR: ', str(score) )
    best_validation_mrr = max(best_validation_mrr, score)
    if best_validation_mrr==score:
        current_directory = os.getcwd()
        files = os.listdir(current_directory)

        for file in files:
            if file.startswith('model'):
                file_path = os.path.join(current_directory, file)
                os.remove(file_path)
                
        print('validation loss improoved saving checkpoint...')
        save_path = os.path.join('./', 'model'+str(i)+'.pt')
        torch.save({
        'epoch': i,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'validation_accuracy': val_loss,
        'loss': loss,
        }, save_path)
        save_path = os.path.join('./', 'model_discriminator'+str(i)+'.pt')
        torch.save({
        'epoch': i,
        'model_state_dict': discriminator.state_dict(),
        'optimizer_state_dict': optimizer_discriminator.state_dict(),
        'validation_accuracy': val_loss,
        'loss': loss,
        }, save_path)
        print('checkpoint saved to: {}'.format(save_path))

-----EPOCH1-----
Iteration: 24800, Time: 17727.4163 s, training loss: 0.0265
Iteration: 24850, Time: 17758.5371 s, training loss: 0.0309
Iteration: 24900, Time: 17789.7504 s, training loss: 0.0312
Iteration: 24950, Time: 17821.1303 s, training loss: 0.0241
Iteration: 25000, Time: 17852.0777 s, training loss: 0.0325
Iteration: 25050, Time: 17883.0199 s, training loss: 0.0314
Iteration: 25100, Time: 17914.1380 s, training loss: 0.0270
Iteration: 25150, Time: 17945.1350 s, training loss: 0.0357
Iteration: 25200, Time: 17976.3632 s, training loss: 0.0278
Iteration: 25250, Time: 18007.4518 s, training loss: 0.0311
Iteration: 25300, Time: 18038.1156 s, training loss: 0.0310
Iteration: 25350, Time: 18068.3183 s, training loss: 0.0332
Iteration: 25400, Time: 18098.5687 s, training loss: 0.0332
Iteration: 25450, Time: 18128.9218 s, training loss: 0.0300
Iteration: 25500, Time: 18159.1207 s, training loss: 0.0331
Iteration: 25550, Time: 18189.4228 s, training loss: 0.0364
Iteration: 25600, Time:

In [9]:
res = []
for _ in tqdm(range(10)):
    graphs, texts = [], []
    model.eval()
    for batch in val_loader:
        
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        
        graph_embeddings, text_embeddings = model(graph_batch.to(device), 
                                input_ids.to(device), 
                                attention_mask.to(device))
        graphs.extend(graph_embeddings.tolist())
        texts.extend(text_embeddings.tolist())
    similarity = cosine_similarity(texts, graphs)
    y_true = np.eye(len(similarity))
    scores = label_ranking_average_precision_score(y_true, similarity)
    print(scores)
    res.append(scores)

 10%|█         | 1/10 [00:13<02:05, 13.95s/it]

0.8345661185417758


 20%|██        | 2/10 [00:27<01:49, 13.69s/it]

0.8347720445959282


 30%|███       | 3/10 [00:40<01:34, 13.43s/it]

0.8345126234488979


 40%|████      | 4/10 [00:53<01:19, 13.32s/it]

0.8349754461625541


 50%|█████     | 5/10 [01:07<01:06, 13.32s/it]

0.8352960560787419


 60%|██████    | 6/10 [01:20<00:53, 13.37s/it]

0.8344885807105265


 70%|███████   | 7/10 [01:34<00:40, 13.45s/it]

0.8347273251025582


 80%|████████  | 8/10 [01:47<00:26, 13.44s/it]

0.834865089993428


 90%|█████████ | 9/10 [02:00<00:13, 13.37s/it]

0.8359520621952138


100%|██████████| 10/10 [02:13<00:00, 13.38s/it]

0.834989511164502





In [10]:
sum(res)/len(res)

0.8349144857994126