In [None]:
!pip install -r requirements.txt

In [1]:
import time
import os
import pandas as pd
import numpy as np

from dataloader import GraphTextDataset, GraphDataset, TextDataset

import torch
from torch import optim
from torch_geometric.data import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader
from transformers import AutoTokenizer
from torchmetrics.functional import pairwise_cosine_similarity
from torch_geometric.data import DataLoader, Data


from alignment import AlignmentModel,Discriminator, gradient_penalty, CombinedModel

from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import label_ranking_average_precision_score

from tqdm import tqdm

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


### **LOADING DATASET**

In [3]:
# Mini-batch size shared by every loader built in this notebook.
batch_size = 32

# Tokenizer matching the SciBERT checkpoint named below.
model_name = 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The .npy file is loaded with pickling enabled and indexed with `[()]`, which
# unwraps the single object stored in a 0-d array (presumably a token -> embedding
# dict, judging by the file name -- confirm against the dataloader).
# NOTE(review): allow_pickle=True executes pickled code; only load trusted files.
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Validation split first, then training split (same construction order as before).
val_dataset, train_dataset = (
    GraphTextDataset(root='./data/', gt=gt, split=split, tokenizer=tokenizer)
    for split in ('val', 'train')
)

# Validation order is kept fixed; only the training loader is shuffled.
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)



### **LOADING FINETUNED MODELS**

In [4]:
def load_alignment_model(checkpoint_path, gnn_type):
    """Build an AlignmentModel, restore its fine-tuned weights and prepare it for inference.

    Parameters
    ----------
    checkpoint_path : str
        Path to a checkpoint readable by ``torch.load`` that contains a
        ``'model_state_dict'`` entry.
    gnn_type : str
        Value forwarded to ``AlignmentModel`` as its ``type`` argument
        (e.g. ``'SuperGat'``, ``'GPS'``, ``'GIN'``).

    Returns
    -------
    AlignmentModel
        The model in eval mode, moved to the notebook-level ``device``.
    """
    # map_location makes checkpoints that were saved from a GPU loadable on a
    # CPU-only machine, matching the cuda-or-cpu fallback used for `device`.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = AlignmentModel(
        in_channels=300,
        out_channels=768,
        graph_attention_head=6,
        type=gnn_type,
        n_layers=5,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model.to(device)


# Models from the "*_efficient.pt" checkpoints, one per graph-encoder type.
model_supergat = load_alignment_model('supergat_efficient.pt', 'SuperGat')
model_anti = load_alignment_model('anti_efficient.pt', 'Antisymmetric')
model_gps = load_alignment_model('gps_efficient.pt', 'GPS')
model_trans = load_alignment_model('transformer_efficient.pt', 'TransformerConv')
model_gat = load_alignment_model('gat_efficient.pt', 'GATv2Conv')
model_gin = load_alignment_model('gin_efficient.pt', 'GIN')
model_dir_supergat = load_alignment_model('dir_supergat_efficient.pt', 'DirGNNConv_supergat')

# Models from the "*_kv_plm.pt" checkpoints (same architectures as gat / transformer above).
model_gat_kv_plm = load_alignment_model('gat_kv_plm.pt', 'GATv2Conv')
model_trans_kv_plm = load_alignment_model('transformer_kv_plm.pt', 'TransformerConv')

# Display the architecture of the last loaded model, as the original cell did.
model_trans_kv_plm



Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


--------------------loading pretrained--------------------


Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


--------------------loading pretrained--------------------


Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


--------------------loading pretrained--------------------


Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


--------------------loading pretrained--------------------


Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


--------------------loading pretrained--------------------


Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


--------------------loading pretrained--------------------


Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


--------------------loading pretrained--------------------


Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


--------------------loading pretrained--------------------


Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


--------------------loading pretrained--------------------


AlignmentModel(
  (text_encoder): TextEncoder(
    (text_encoder): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(31090, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bi

### **PREDICTIONS ON THE VALIDATION SET**

In [5]:
# One (graph embeddings, text embeddings) pair of lists per model; each list
# receives one embedding (as a plain Python list) per validation sample.
graph_anti, text_anti = [], []
graph_supergat, text_supergat = [], []
graph_gps, text_gps = [], []
graph_trans, text_trans = [], []
graph_gat, text_gat = [], []
graph_gin, text_gin = [], []
graph_dir_supergat, text_dir_supergat = [], []
graph_trans_kv_plm, text_trans_kv_plm = [], []
graph_gat_kv_plm, text_gat_kv_plm = [], []

# Each model paired with the lists that collect its outputs.
val_runs = [
    (model_anti, graph_anti, text_anti),
    (model_supergat, graph_supergat, text_supergat),
    (model_gps, graph_gps, text_gps),
    (model_trans, graph_trans, text_trans),
    (model_gat, graph_gat, text_gat),
    (model_gin, graph_gin, text_gin),
    (model_dir_supergat, graph_dir_supergat, text_dir_supergat),
    (model_trans_kv_plm, graph_trans_kv_plm, text_trans_kv_plm),
    (model_gat_kv_plm, graph_gat_kv_plm, text_gat_kv_plm),
]

# Inference only: disabling autograd avoids building a computation graph for
# nine models per batch, which only costs memory and time here.
with torch.no_grad():
    for batch in tqdm(val_loader):
        # Transfer each tensor to the device once and share it across all models
        # (the original cell repeated the transfer for every model).
        node_features = batch.x.to(device)
        edge_index = batch.edge_index.to(device)
        batch_index = batch.batch.to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        for model, graph_embeddings, text_embeddings in val_runs:
            graph_output = model.forward_graph(x=node_features, edge_index=edge_index, batch=batch_index)
            graph_embeddings.extend(graph_output.tolist())

            text_output = model.forward_text(input_ids, attention_mask)
            text_embeddings.extend(text_output.tolist())

100%|██████████| 104/104 [01:21<00:00,  1.28it/s]


### **EVALUATION ON THE VALIDATION SET**

In [6]:
# For every model, build the text-vs-graph cosine-similarity matrix:
# rows index the text embeddings, columns index the graph embeddings.
(
    similarity_anti,
    similarity_supergat,
    similarity_gps,
    similarity_trans,
    similarity_gat,
    similarity_gin,
    similarity_dir_supergat,
    similarity_trans_kv_plm,
    similarity_gat_kv_plm,
) = (
    cosine_similarity(text_embeddings, graph_embeddings)
    for text_embeddings, graph_embeddings in (
        (text_anti, graph_anti),
        (text_supergat, graph_supergat),
        (text_gps, graph_gps),
        (text_trans, graph_trans),
        (text_gat, graph_gat),
        (text_gin, graph_gin),
        (text_dir_supergat, graph_dir_supergat),
        (text_trans_kv_plm, graph_trans_kv_plm),
        (text_gat_kv_plm, graph_gat_kv_plm),
    )
)

In [21]:
# Weighted ensemble of the per-model similarity matrices:
#   weight 1   -> transformer, gat
#   weight 1/2 -> supergat, anti, gps, gin, dir_supergat
#   weight 1/3 -> the two kv_plm models
half_weight_sum = similarity_supergat + similarity_anti + similarity_gps + similarity_gin + similarity_dir_supergat
third_weight_sum = similarity_trans_kv_plm + similarity_gat_kv_plm
similarity = similarity_trans + similarity_gat + half_weight_sum / 2 + third_weight_sum / 3

In [27]:
# Ground truth is the identity matrix: row i is marked relevant only for column i.
n_rows = similarity.shape[0]
y_true = np.identity(n_rows)

# Label ranking average precision of the ensembled scores (displayed as cell output).
label_ranking_average_precision_score(y_true, similarity)

0.9380643621440066

### **PREDICTIONS ON TEST SET**

In [23]:
# Test graphs and test texts are loaded as two separate datasets.
test_cids_dataset = GraphDataset(split='test_cids', gt=gt, root='./data/')
test_text_dataset = TextDataset(tokenizer=tokenizer, file_path='./data/test_text.txt')

# Index -> cid lookup exposed by the graph dataset
# (NOTE(review): not referenced again in the cells shown below).
idx_to_cid = test_cids_dataset.get_idx_to_cid()


In [24]:
# One (graph embeddings, text embeddings) pair of lists per model; each list
# receives one embedding (as a plain Python list) per test sample.
graph_anti, text_anti = [], []
graph_supergat, text_supergat = [], []
graph_gps, text_gps = [], []
graph_trans, text_trans = [], []
graph_gat, text_gat = [], []
graph_gin, text_gin = [], []
graph_dir_supergat, text_dir_supergat = [], []
graph_trans_kv_plm, text_trans_kv_plm = [], []
graph_gat_kv_plm, text_gat_kv_plm = [], []

# Each model paired with the lists that collect its outputs.
test_runs = [
    (model_anti, graph_anti, text_anti),
    (model_supergat, graph_supergat, text_supergat),
    (model_gps, graph_gps, text_gps),
    (model_trans, graph_trans, text_trans),
    (model_gat, graph_gat, text_gat),
    (model_gin, graph_gin, text_gin),
    (model_dir_supergat, graph_dir_supergat, text_dir_supergat),
    (model_trans_kv_plm, graph_trans_kv_plm, text_trans_kv_plm),
    (model_gat_kv_plm, graph_gat_kv_plm, text_gat_kv_plm),
]

# Graphs and texts come from separate datasets, so they are embedded in two
# passes; neither loader shuffles, so the original sample order is kept.
test_graph_loader = DataLoader(test_cids_dataset, batch_size=batch_size, shuffle=False)
test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size, shuffle=False)

# Inference only: disabling autograd avoids building a computation graph for
# nine models per batch, which only costs memory and time here.
with torch.no_grad():
    for batch in tqdm(test_graph_loader):
        # Transfer each tensor to the device once and share it across all models.
        node_features = batch.x.to(device)
        edge_index = batch.edge_index.to(device)
        batch_index = batch.batch.to(device)

        for model, graph_embeddings, _ in test_runs:
            graph_output = model.forward_graph(x=node_features, edge_index=edge_index, batch=batch_index)
            graph_embeddings.extend(graph_output.tolist())

    for batch in tqdm(test_text_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        for model, _, text_embeddings in test_runs:
            text_output = model.forward_text(input_ids, attention_mask)
            text_embeddings.extend(text_output.tolist())

100%|██████████| 104/104 [00:10<00:00, 10.23it/s]
100%|██████████| 104/104 [01:13<00:00,  1.42it/s]


In [25]:
# For every model, build the text-vs-graph cosine-similarity matrix on the test
# embeddings: rows index the texts, columns index the graphs.
(
    similarity_anti,
    similarity_supergat,
    similarity_gps,
    similarity_trans,
    similarity_gat,
    similarity_gin,
    similarity_dir_supergat,
    similarity_trans_kv_plm,
    similarity_gat_kv_plm,
) = (
    cosine_similarity(text_embeddings, graph_embeddings)
    for text_embeddings, graph_embeddings in (
        (text_anti, graph_anti),
        (text_supergat, graph_supergat),
        (text_gps, graph_gps),
        (text_trans, graph_trans),
        (text_gat, graph_gat),
        (text_gin, graph_gin),
        (text_dir_supergat, graph_dir_supergat),
        (text_trans_kv_plm, graph_trans_kv_plm),
        (text_gat_kv_plm, graph_gat_kv_plm),
    )
)

In [26]:
# Weighted ensemble of the per-model similarity matrices:
#   weight 1   -> transformer, gat
#   weight 1/2 -> supergat, anti, gps, gin, dir_supergat
#   weight 1/3 -> the two kv_plm models
half_weight_sum = similarity_supergat + similarity_anti + similarity_gps + similarity_gin + similarity_dir_supergat
third_weight_sum = similarity_trans_kv_plm + similarity_gat_kv_plm
similarity = similarity_trans + half_weight_sum / 2 + similarity_gat + third_weight_sum / 3

In [27]:
# Write the ensembled similarity matrix as the submission file: one row per row
# of `similarity`, with the row position as a leading 'ID' column followed by
# one column per similarity-matrix column.
# (The redundant re-import of cosine_similarity was dropped: it is already
# imported in the notebook's import cell and is not used here.)
solution = pd.DataFrame(similarity)
solution.insert(0, 'ID', solution.index)
solution.to_csv('submission_combined_efficient_kv_plm.csv', index=False)