In [2]:
from dataloader import GraphTextDataset, GraphDataset, TextDataset
from torch_geometric.data import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader
#from Model import Model
import numpy as np
from transformers import AutoTokenizer
import torch
from torch import optim
import time
import os
import pandas as pd
from torch import nn
from torch_geometric.nn import GCNConv, GAT
from transformers import AutoModel
from torch_geometric.nn import global_mean_pool
#from subgraph import subgraph_random_walk

# Model

In [3]:
class GraphEncoder(nn.Module):
    def __init__(self, num_node_features, nout, nhid, graph_hidden_channels):
        super(GraphEncoder, self).__init__()
        self.nhid = nhid
        self.nout = nout
        self.relu = nn.ReLU()
        self.ln = nn.LayerNorm((nout))
        self.conv1 = GCNConv(num_node_features, graph_hidden_channels)
        self.conv2 = GCNConv(graph_hidden_channels, graph_hidden_channels)
        self.conv3 = GCNConv(graph_hidden_channels, graph_hidden_channels)
        self.mol_hidden1 = nn.Linear(graph_hidden_channels, nhid)
        self.mol_hidden2 = nn.Linear(nhid, nout)

    def forward(self, graph_batch):
        x = graph_batch.x
        edge_index = graph_batch.edge_index
        batch = graph_batch.batch
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        x = global_mean_pool(x, batch)
        x = self.mol_hidden1(x).relu()
        x = self.mol_hidden2(x)
        return x
    
    
class GATEncoder(nn.Module):
    def __init__(self, nout, nhid, attention_hidden, n_in, dropout):
        super(GATEncoder, self).__init__()
        self.dropout = dropout
        self.n_in = n_in
        self.attention_hidden = attention_hidden
        self.n_hidden = nhid
        self.n_out = nout
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(self.attention_hidden, self.n_out)
        self.GATEnc = GAT(in_channels=self.n_in, hidden_channels = self.attention_hidden, out_channels=self.n_hidden, dropout=self.dropout, num_layers=4, v2=True)

    def forward(self, gr):
        x = gr.x
        x = self.GATEnc(x, gr.edge_index)
        x = self.relu(x)
        x = global_mean_pool(x, gr.batch)
        x = self.fc1(x)
        x = self.relu(x)
        return x
    
    
class TextEncoder(nn.Module):
    def __init__(self, model_name):
        super(TextEncoder, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        
    def forward(self, input_ids, attention_mask):
        encoded_text = self.bert(input_ids, attention_mask=attention_mask)
        return encoded_text.last_hidden_state[:,0,:]
    
class Model(nn.Module):
    def __init__(self, model_name, num_node_features, nout, nhid, graph_hidden_channels):
        super(Model, self).__init__()
        self.graph_encoder = GraphEncoder(num_node_features, nout, nhid, graph_hidden_channels)
        self.text_encoder = TextEncoder(model_name)
        
    def forward(self, graph_batch, input_ids, attention_mask):
        graph_encoded = self.graph_encoder(graph_batch)
        text_encoded = self.text_encoder(input_ids, attention_mask)
        return graph_encoded, text_encoded
    
    def get_text_encoder(self):
        return self.text_encoder
    
    def get_graph_encoder(self):
        return self.graph_encoder
    
class ModelGAT(nn.Module):
    def __init__(self, model_name, n_in, nout, nhid, attention_hidden, dropout):
        super(ModelGAT, self).__init__()
        self.graph_encoder = GATEncoder(nout, nhid, attention_hidden, n_in, dropout)
        self.text_encoder = TextEncoder(model_name)
        
    def forward(self, graph_batch, input_ids, attention_mask):
        graph_encoded = self.graph_encoder(graph_batch)
        text_encoded = self.text_encoder(input_ids, attention_mask)
        return graph_encoded, text_encoded
    
    def get_text_encoder(self):
        return self.text_encoder
    
    def get_graph_encoder(self):
        return self.graph_encoder

In [4]:
CE = torch.nn.CrossEntropyLoss()
def nt_xent_loss(v1, v2, temp = 1):
    logits = torch.matmul(v1,torch.transpose(v2, 0, 1)) / temp 	
    labels = torch.arange(logits.shape[0], device=v1.device)
    return CE(logits, labels) + CE(torch.transpose(logits, 0, 1), labels)    
    

In [5]:
def contrastive_loss(vg1, vg2, vt1, vt2, temp = 1):
    return (nt_xent_loss(vg1, vt1, temp) + nt_xent_loss(vg2, vt2, temp) + nt_xent_loss(vg1, vt2, temp) + nt_xent_loss(vg2, vt1, temp) + nt_xent_loss(vg1, vg2, temp))/5

In [6]:

model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
gt = np.load("../../Public/data/token_embedding_dict.npy", allow_pickle=True)[()]
val_dataset = GraphTextDataset(root='../../Public/data/', gt=gt, split='val', tokenizer=tokenizer)
train_dataset = GraphTextDataset(root='../../Public/data/', gt=gt, split='train', tokenizer=tokenizer)
train_drop_dataset = GraphTextDataset(root='../../Public/data/', gt=gt, split='train_drop', tokenizer=tokenizer, drop=True)
train_subgraph_dataset = GraphTextDataset(root='../../Public/data/', gt=gt, split='train_subgraph', tokenizer=tokenizer, subgraph=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
nb_epochs = 5
batch_size = 32
learning_rate = 2e-5

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


TypeError: GraphTextDataset.__init__() got an unexpected keyword argument 'drop'

In [65]:

model = ModelGAT(model_name=model_name, n_in=300, nout=768, nhid=1000, attention_hidden=1000, dropout=0.3)
#Model(model_name=model_name, num_node_features=300, nout=768, nhid=300, graph_hidden_channels=300) # nout = bert model hidden dim
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                                betas=(0.9, 0.999),
                                weight_decay=0.01)


In [56]:

epoch = 0
loss = 0
losses = []
count_iter = 0
time1 = time.time()
printEvery = 50
best_validation_loss = 1000000
temp = 1
rate = 0.8

for i in range(nb_epochs):
    print('-----EPOCH{}-----'.format(i+1))
    model.train()
    for batch, batch_drop, batch_subgraph in zip(train_loader, train_drop_dataset, train_subgraph_dataset):
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch

        input_ids_drop = batch_drop.input_ids
        batch_drop.pop('input_ids')
        attention_mask_drop = batch_drop.attention_mask
        batch_drop.pop('attention_mask')
        graph_batch_drop = batch_drop

        input_ids_subgraph = batch_subgraph.input_ids
        batch_subgraph.pop('input_ids')
        attention_mask_subgraph = batch_subgraph.attention_mask
        batch_subgraph.pop('attention_mask')
        graph_batch_subgraph = batch_subgraph

        
        x_graph, x_text = model(graph_batch.to(device), 
                                input_ids.to(device), 
                                attention_mask.to(device))
        
        x_graph_drop, x_text_drop = model(graph_batch_drop.to(device),
                                            input_ids_drop.to(device),
                                            attention_mask_drop.to(device))
        
        x_graph_subgraph, x_text_subgraph = model(graph_batch_subgraph.to(device),
                                            input_ids_subgraph.to(device),
                                            attention_mask_subgraph.to(device))
        
       
        current_loss = contrastive_loss(x_graph_drop, x_graph_subgraph, x_text_drop, x_text_subgraph, temp)
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        loss += current_loss.item()
        
        count_iter += 1
        if count_iter % printEvery == 0:
            time2 = time.time()
            print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                                                                        time2 - time1, loss/printEvery))
            losses.append(loss)
            loss = 0 
    model.eval()       
    val_loss = 0        
    for batch in val_loader:
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        x_graph, x_text = model(graph_batch.to(device), 
                                input_ids.to(device), 
                                attention_mask.to(device))
        current_loss = contrastive_loss(x_graph, x_graph, x_text, x_text, temp)
        val_loss += current_loss.item()
    best_validation_loss = min(best_validation_loss, val_loss)
    print('-----EPOCH'+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )
    if best_validation_loss==val_loss:
        print('validation loss improoved saving checkpoint...')
        save_path = os.path.join('./', 'model'+str(i)+'.pt')
        torch.save({
        'epoch': i,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'validation_accuracy': val_loss,
        'loss': loss,
        }, save_path)
        print('checkpoint saved to: {}'.format(save_path))



-----EPOCH1-----
Iteration: 50, Time: 24.5009 s, training loss: 5.0209
Iteration: 100, Time: 48.3899 s, training loss: 4.5705
Iteration: 150, Time: 72.9599 s, training loss: 3.9515
Iteration: 200, Time: 97.3548 s, training loss: 3.3941
Iteration: 250, Time: 121.5333 s, training loss: 3.0376
Iteration: 300, Time: 145.8680 s, training loss: 2.7337
Iteration: 350, Time: 170.1914 s, training loss: 2.5004
Iteration: 400, Time: 194.4660 s, training loss: 2.2695
Iteration: 450, Time: 219.1673 s, training loss: 2.1912
Iteration: 500, Time: 243.6148 s, training loss: 2.0244
Iteration: 550, Time: 268.4169 s, training loss: 1.9669
Iteration: 600, Time: 293.0709 s, training loss: 1.8904
Iteration: 650, Time: 316.7455 s, training loss: 1.7849
Iteration: 700, Time: 341.0838 s, training loss: 1.7672
Iteration: 750, Time: 367.1084 s, training loss: 1.6040
Iteration: 800, Time: 399.7868 s, training loss: 1.5464
-----EPOCH1----- done.  Validation loss:  2.234593241260602
validation loss improoved saving

In [59]:

print('loading best model...')
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()

test_cids_dataset = GraphDataset(root='../../Public/data/', gt=gt, split='test_cids')
test_text_dataset = TextDataset(file_path='../../Public/data/test_text.txt', tokenizer=tokenizer)

idx_to_cid = test_cids_dataset.get_idx_to_cid()

test_loader = DataLoader(test_cids_dataset, batch_size=batch_size, shuffle=False)

graph_embeddings = []
for batch in test_loader:
    for output in graph_model(batch.to(device)):
        graph_embeddings.append(output.tolist())

test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size, shuffle=False)
text_embeddings = []
for batch in test_text_loader:
    for output in text_model(batch['input_ids'].to(device), 
                             attention_mask=batch['attention_mask'].to(device)):
        text_embeddings.append(output.tolist())


from sklearn.metrics.pairwise import cosine_similarity

similarity = cosine_similarity(text_embeddings, graph_embeddings)

solution = pd.DataFrame(similarity)
solution['ID'] = solution.index
solution = solution[['ID'] + [col for col in solution.columns if col!='ID']]
solution.to_csv('submission.csv', index=False)

loading best model...


Processing...
  return torch.LongTensor(edge_index).T, torch.FloatTensor(x)
  return torch.LongTensor(edge_index).T, torch.FloatTensor(x)
Done!
