## Installation

In [None]:
# Colab only: mount Google Drive. The Kaggle config directory set in the next cells
# (/content/drive/MyDrive/Kaggle) lives on this drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Kaggle CLI, used below to download the challenge dataset.
!pip install kaggle

In [None]:
import os

# Tell the Kaggle CLI where to find its configuration (kept on the mounted Drive).
KAGGLE_CONFIG_DIR = '/content/drive/MyDrive/Kaggle'
os.environ['KAGGLE_CONFIG_DIR'] = KAGGLE_CONFIG_DIR

In [None]:
# Download the challenge archive (challenge-altegrad.zip) into the working directory.
!kaggle datasets download -d adibhabbou/challenge-altegrad

In [None]:
!unzip challenge-altegrad.zip

In [None]:
!pip install torch torch-geometric transformers pandas numpy

## Imports

In [None]:
import os
import os.path as osp
import time
import numpy as np
import pandas as pd

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DataLoader as TorchDataLoader

from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset, Data
from torch_geometric.nn import GCNConv, GATConv, GATv2Conv, SuperGATConv
from torch_geometric.nn import global_mean_pool, global_max_pool

from transformers import AutoModel, AutoTokenizer

from sklearn.metrics.pairwise import cosine_similarity

## Data Sets/Loaders

### Graph-Text

In [None]:
class GraphTextDataset(Dataset):
    """PyG dataset pairing each molecule graph with its tokenized text description.

    Reads ``<root>/<split>.tsv`` (tab-separated, no header; column 0 is the CID and
    column 1 the description), parses ``<root>/raw/<cid>.graph`` for every CID and caches
    one ``Data`` object per molecule as ``<root>/processed/<split>/data_<cid>.pt``.

    Args:
        root: dataset root directory.
        gt: mapping from substructure id (str) to a node feature vector; must contain an
            ``'UNK'`` entry, used for ids that are not in the mapping.
        split: split name; selects the TSV file and the processed sub-directory.
        tokenizer: HuggingFace-style tokenizer used to encode the descriptions.
        transform, pre_transform: forwarded to ``torch_geometric.data.Dataset``.
    """

    def __init__(self, root, gt, split, tokenizer=None, transform=None, pre_transform=None):
        self.root = root
        self.gt = gt
        self.split = split
        self.tokenizer = tokenizer
        self.description = pd.read_csv(os.path.join(self.root, split + '.tsv'), sep='\t', header=None)
        # After set_index(0).to_dict() the layout is {1: {cid: description}}.
        self.description = self.description.set_index(0).to_dict()
        self.cids = list(self.description[1].keys())
        self.idx_to_cid = dict(enumerate(self.cids))
        # Called last: the base constructor may run process(), which reads the attributes above.
        super(GraphTextDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return [str(cid) + ".graph" for cid in self.cids]

    @property
    def processed_file_names(self):
        return ['data_{}.pt'.format(cid) for cid in self.cids]

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, 'processed/', self.split)

    def download(self):
        # Raw files are expected to be present already.
        pass

    def process_graph(self, raw_path):
        """Parse a ``.graph`` file into ``(edge_index, x)``.

        Expected layout: a header line, one ``"<src> <dst>"`` line per edge, a blank line,
        a second header line, then one line per node whose last whitespace-separated token
        is a substructure id looked up in ``self.gt`` (falling back to ``self.gt['UNK']``).

        Returns:
            edge_index: ``LongTensor`` of shape ``(2, num_edges)``; ``(2, 0)`` when the
                graph has no edges.
            x: ``FloatTensor`` with one row per node.
        """
        edges = []
        features = []
        with open(raw_path, 'r') as f:
            next(f)  # edge-section header
            for line in f:
                if not line.strip():  # blank line ends the edge section
                    break
                edges.append(tuple(map(int, line.split())))
            next(f, None)  # node-section header; tolerate files that end after the edges
            for line in f:  # mol2vec features, one line per node
                substruct_id = line.strip().split()[-1]
                if substruct_id in self.gt:
                    features.append(self.gt[substruct_id])
                else:
                    features.append(self.gt['UNK'])
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            # LongTensor([]).T would be 1-D; message-passing layers need shape (2, 0).
            edge_index = torch.empty((2, 0), dtype=torch.long)
        # Stack through numpy first: building a tensor from a list of ndarrays is very slow.
        x = torch.tensor(np.asarray(features, dtype=np.float32))
        return edge_index, x

    def process(self):
        """Tokenize every description, parse its graph and save ``data_<cid>.pt``."""
        os.makedirs(self.processed_dir, exist_ok=True)
        for raw_path in self.raw_paths:
            # basename/splitext follow the host OS separator; split('\\') only
            # worked for Windows paths and broke int() on Linux/Colab.
            cid = int(osp.splitext(osp.basename(raw_path))[0])
            text_input = self.tokenizer([self.description[1][cid]],
                                        return_tensors="pt",
                                        truncation=True,
                                        max_length=256,
                                        padding="max_length",
                                        add_special_tokens=True)
            edge_index, x = self.process_graph(raw_path)
            data = Data(x=x, edge_index=edge_index,
                        input_ids=text_input['input_ids'],
                        attention_mask=text_input['attention_mask'])
            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(cid)))

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        """Load the cached ``Data`` at dataset position ``idx``."""
        return self.get_cid(self.idx_to_cid[idx])

    def get_cid(self, cid):
        """Load the cached ``Data`` for molecule ``cid``."""
        # weights_only=False: these files are pickled PyG Data objects written by
        # process(); torch>=2.6 defaults to weights_only=True and refuses them.
        return torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(cid)), weights_only=False)

### Graph

In [None]:
class GraphDataset(Dataset):
    """PyG dataset of molecule graphs only (no text), e.g. for the test CIDs.

    Reads the CID list from column 0 of ``<root>/<split>.txt`` (tab-separated, no header),
    parses ``<root>/raw/<cid>.graph`` for each CID and caches the resulting ``Data`` as
    ``<root>/processed/<split>/data_<cid>.pt``.

    Args:
        root: dataset root directory.
        gt: mapping from substructure id (str) to a node feature vector; must contain an
            ``'UNK'`` entry, used for ids that are not in the mapping.
        split: split name; selects the CID file and the processed sub-directory.
        transform, pre_transform: forwarded to ``torch_geometric.data.Dataset``.
    """

    def __init__(self, root, gt, split, transform=None, pre_transform=None):
        self.root = root
        self.gt = gt
        self.split = split
        self.description = pd.read_csv(os.path.join(self.root, split + '.txt'), sep='\t', header=None)
        self.cids = self.description[0].tolist()
        self.idx_to_cid = dict(enumerate(self.cids))
        # Called last: the base constructor may run process(), which reads the attributes above.
        super(GraphDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return [str(cid) + ".graph" for cid in self.cids]

    @property
    def processed_file_names(self):
        return ['data_{}.pt'.format(cid) for cid in self.cids]

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, 'processed/', self.split)

    def download(self):
        # Raw files are expected to be present already.
        pass

    def process_graph(self, raw_path):
        """Parse a ``.graph`` file into ``(edge_index, x)``.

        Expected layout: a header line, one ``"<src> <dst>"`` line per edge, a blank line,
        a second header line, then one line per node whose last whitespace-separated token
        is a substructure id looked up in ``self.gt`` (falling back to ``self.gt['UNK']``).

        Returns:
            edge_index: ``LongTensor`` of shape ``(2, num_edges)``; ``(2, 0)`` when the
                graph has no edges.
            x: ``FloatTensor`` with one row per node.
        """
        edges = []
        features = []
        with open(raw_path, 'r') as f:
            next(f)  # edge-section header
            for line in f:
                if not line.strip():  # blank line ends the edge section
                    break
                edges.append(tuple(map(int, line.split())))
            next(f, None)  # node-section header; tolerate files that end after the edges
            for line in f:
                substruct_id = line.strip().split()[-1]
                if substruct_id in self.gt:
                    features.append(self.gt[substruct_id])
                else:
                    features.append(self.gt['UNK'])
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            # LongTensor([]).T would be 1-D; message-passing layers need shape (2, 0).
            edge_index = torch.empty((2, 0), dtype=torch.long)
        # Stack through numpy first: building a tensor from a list of ndarrays is very slow.
        x = torch.tensor(np.asarray(features, dtype=np.float32))
        return edge_index, x

    def process(self):
        """Parse every raw graph and save it as ``data_<cid>.pt``."""
        os.makedirs(self.processed_dir, exist_ok=True)
        for raw_path in self.raw_paths:
            # basename/splitext follow the host OS separator; split('\\') only
            # worked for Windows paths and broke int() on Linux/Colab.
            cid = int(osp.splitext(osp.basename(raw_path))[0])
            edge_index, x = self.process_graph(raw_path)
            data = Data(x=x, edge_index=edge_index)
            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(cid)))

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        """Load the cached ``Data`` at dataset position ``idx``."""
        return self.get_cid(self.idx_to_cid[idx])

    def get_cid(self, cid):
        """Load the cached ``Data`` for molecule ``cid``."""
        # weights_only=False: these files are pickled PyG Data objects written by
        # process(); torch>=2.6 defaults to weights_only=True and refuses them.
        return torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(cid)), weights_only=False)

    def get_idx_to_cid(self):
        """Return the ``{dataset position: CID}`` mapping (order of the CID file)."""
        return self.idx_to_cid

### Text

In [None]:
class TextDataset(TorchDataset):
    """Plain-text dataset: one sample per line of a UTF-8 file, tokenized on access.

    Args:
        file_path: text file; each line, stripped of surrounding whitespace, is one sample.
        tokenizer: HuggingFace-style tokenizer, called with padding/truncation to ``max_length``.
        max_length: fixed sequence length of the returned tensors.
    """

    def __init__(self, file_path, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.sentences = self.load_sentences(file_path)

    def load_sentences(self, file_path):
        """Return the stripped lines of ``file_path``."""
        with open(file_path, 'r', encoding='utf-8') as file:
            return [line.strip() for line in file]

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        """Tokenize sentence ``idx``.

        Returns:
            dict with 1-D ``input_ids`` and ``attention_mask`` tensors (batch dim removed).
        """
        # Calling the tokenizer directly replaces the deprecated encode_plus.
        encoding = self.tokenizer(
            self.sentences[idx],
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # squeeze(0) drops only the batch dimension; a bare squeeze() would also
        # collapse the sequence dimension when max_length == 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
        }

## Models

### Graph Encoder Model

In [None]:
class GraphEncoderBaseline(nn.Module):
    """Three-layer GCN graph encoder, mean pooling, then a two-layer MLP head.

    Args:
        num_node_features: size of the input node features.
        nout: size of the output graph embedding.
        nhid: hidden size of the MLP head.
        graph_hidden_channels: width of every GCN layer.
    """

    def __init__(self, num_node_features, nout, nhid, graph_hidden_channels):
        super(GraphEncoderBaseline, self).__init__()
        self.nhid = nhid
        self.nout = nout
        self.relu = nn.ReLU()
        # NOTE: ln is never applied in forward(); kept so saved state_dicts stay loadable.
        self.ln = nn.LayerNorm((nout))
        width = graph_hidden_channels
        self.conv1 = GCNConv(num_node_features, width)
        self.conv2 = GCNConv(width, width)
        self.conv3 = GCNConv(width, width)
        self.mol_hidden1 = nn.Linear(width, nhid)
        self.mol_hidden2 = nn.Linear(nhid, nout)

    def forward(self, graph_batch):
        """Embed a PyG batch of graphs; returns one ``nout``-sized row per graph."""
        edge_index = graph_batch.edge_index
        h = graph_batch.x
        for conv in (self.conv1, self.conv2):
            h = conv(h, edge_index).relu()
        # The last convolution feeds the pooling directly, without an activation.
        h = self.conv3(h, edge_index)
        pooled = global_mean_pool(h, graph_batch.batch)
        return self.mol_hidden2(self.mol_hidden1(pooled).relu())

In [None]:
class GraphEncoderAttention(nn.Module):
    """Three-layer GAT graph encoder, max pooling + dropout, then a two-layer MLP head.

    Every GATConv outputs ``graph_hidden_channels * num_heads`` features (heads are
    concatenated), which is why the later layers take that width as input.

    Args:
        num_node_features: size of the input node features.
        nout: size of the output graph embedding.
        nhid: hidden size of the MLP head.
        graph_hidden_channels: per-head width of every attention layer.
        num_heads: number of attention heads per layer.
        dropout: dropout probability applied to the pooled graph vector.
    """

    def __init__(self, num_node_features, nout, nhid, graph_hidden_channels, num_heads=25, dropout=0.2):
        super(GraphEncoderAttention, self).__init__()
        self.nhid = nhid
        self.nout = nout
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        concat_width = graph_hidden_channels * num_heads
        self.conv1 = GATConv(num_node_features, graph_hidden_channels, heads=num_heads)
        self.conv2 = GATConv(concat_width, graph_hidden_channels, heads=num_heads)
        self.conv3 = GATConv(concat_width, graph_hidden_channels, heads=num_heads)
        self.mol_hidden1 = nn.Linear(concat_width, nhid)
        self.mol_hidden2 = nn.Linear(nhid, nout)

    def forward(self, graph_batch):
        """Embed a PyG batch of graphs; returns one ``nout``-sized row per graph."""
        edge_index = graph_batch.edge_index
        h = graph_batch.x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = self.relu(conv(h, edge_index))
        pooled = self.dropout(global_max_pool(h, graph_batch.batch))
        return self.mol_hidden2(self.relu(self.mol_hidden1(pooled)))

In [None]:
class GraphEncoderAttentionV2(nn.Module):
    """Same architecture as the GAT encoder but built from GATv2 attention layers.

    Three GATv2Conv layers (heads concatenated, hence the ``graph_hidden_channels *
    num_heads`` input width of layers 2-3 and the head), ReLU after each, max pooling,
    dropout, then a two-layer MLP.

    Args:
        num_node_features: size of the input node features.
        nout: size of the output graph embedding.
        nhid: hidden size of the MLP head.
        graph_hidden_channels: per-head width of every attention layer.
        num_heads: number of attention heads per layer.
        dropout: dropout probability applied to the pooled graph vector.
    """

    def __init__(self, num_node_features, nout, nhid, graph_hidden_channels, num_heads=20, dropout=0.2):
        super(GraphEncoderAttentionV2, self).__init__()
        self.nhid = nhid
        self.nout = nout
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        concat_width = graph_hidden_channels * num_heads
        self.conv1 = GATv2Conv(num_node_features, graph_hidden_channels, heads=num_heads)
        self.conv2 = GATv2Conv(concat_width, graph_hidden_channels, heads=num_heads)
        self.conv3 = GATv2Conv(concat_width, graph_hidden_channels, heads=num_heads)
        self.mol_hidden1 = nn.Linear(concat_width, nhid)
        self.mol_hidden2 = nn.Linear(nhid, nout)

    def forward(self, graph_batch):
        """Embed a PyG batch of graphs; returns one ``nout``-sized row per graph."""
        edge_index = graph_batch.edge_index
        node_repr = graph_batch.x
        for layer in (self.conv1, self.conv2, self.conv3):
            node_repr = self.relu(layer(node_repr, edge_index))
        graph_repr = global_max_pool(node_repr, graph_batch.batch)
        graph_repr = self.dropout(graph_repr)
        hidden = self.relu(self.mol_hidden1(graph_repr))
        return self.mol_hidden2(hidden)

In [None]:
class GraphEncoderSuperAttention(nn.Module):
    """GAT-style encoder built from SuperGAT layers, max pooling, dropout and an MLP head.

    Layers 2-3 and the head take ``graph_hidden_channels * num_heads`` inputs because each
    attention layer's heads are concatenated.

    NOTE(review): SuperGATConv is designed to contribute a self-supervised attention loss
    (``conv.get_attention_loss()``) to the training objective; this encoder does not
    expose it and the training loop does not add it — confirm this is intended.

    Args:
        num_node_features: size of the input node features.
        nout: size of the output graph embedding.
        nhid: hidden size of the MLP head.
        graph_hidden_channels: per-head width of every attention layer.
        num_heads: number of attention heads per layer.
        dropout: dropout probability applied to the pooled graph vector.
    """

    def __init__(self, num_node_features, nout, nhid, graph_hidden_channels, num_heads=25, dropout=0.2):
        super(GraphEncoderSuperAttention, self).__init__()
        self.nhid = nhid
        self.nout = nout
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        concat_width = graph_hidden_channels * num_heads
        self.conv1 = SuperGATConv(num_node_features, graph_hidden_channels, heads=num_heads)
        self.conv2 = SuperGATConv(concat_width, graph_hidden_channels, heads=num_heads)
        self.conv3 = SuperGATConv(concat_width, graph_hidden_channels, heads=num_heads)
        self.mol_hidden1 = nn.Linear(concat_width, nhid)
        self.mol_hidden2 = nn.Linear(nhid, nout)

    def forward(self, graph_batch):
        """Embed a PyG batch of graphs; returns one ``nout``-sized row per graph."""
        edge_index = graph_batch.edge_index
        h = graph_batch.x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = self.relu(conv(h, edge_index))
        pooled = self.dropout(global_max_pool(h, graph_batch.batch))
        return self.mol_hidden2(self.relu(self.mol_hidden1(pooled)))

### Text Encoder Model

In [None]:
class TextEncoderBaseline(nn.Module):
    """Pretrained HuggingFace encoder returning the first-token hidden state per sequence.

    The first position is the [CLS] token for BERT-style tokenizers called with
    ``add_special_tokens=True``.
    """

    def __init__(self, model_name):
        super(TextEncoderBaseline, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Return ``last_hidden_state[:, 0, :]`` — one vector per input sequence."""
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        return hidden_states[:, 0, :]

### Full Model

In [None]:
class Model(nn.Module):
    """Dual encoder: a graph encoder and a text encoder trained with a contrastive loss.

    The contrastive loss takes dot products between the two outputs, so ``nout`` must
    match the text encoder's hidden size.

    Args:
        model_name: HuggingFace checkpoint name for the text encoder.
        num_node_features, nout, nhid, graph_hidden_channels: forwarded positionally to
            the graph encoder constructor.
        graph_encoder_cls: graph encoder class to instantiate (e.g. GraphEncoderBaseline,
            GraphEncoderAttentionV2). ``None`` keeps the previous default,
            ``GraphEncoderAttention``, so existing callers behave exactly as before.
    """

    def __init__(self, model_name, num_node_features, nout, nhid, graph_hidden_channels,
                 graph_encoder_cls=None):
        super(Model, self).__init__()
        if graph_encoder_cls is None:
            graph_encoder_cls = GraphEncoderAttention
        self.graph_encoder = graph_encoder_cls(num_node_features, nout, nhid, graph_hidden_channels)
        self.text_encoder = TextEncoderBaseline(model_name)

    def forward(self, graph_batch, input_ids, attention_mask):
        """Return ``(graph_embeddings, text_embeddings)`` for a paired batch."""
        graph_encoded = self.graph_encoder(graph_batch)
        text_encoded = self.text_encoder(input_ids, attention_mask)
        return graph_encoded, text_encoded

    def get_text_encoder(self):
        return self.text_encoder

    def get_graph_encoder(self):
        return self.graph_encoder

## Pipeline

### Loss

In [None]:
CE = torch.nn.CrossEntropyLoss()

def contrastive_loss(v1, v2):
    """Symmetric in-batch contrastive loss between two aligned embedding batches.

    Row ``k`` of ``v1`` is the positive for row ``k`` of ``v2``; every other row in the
    batch serves as a negative. Raw dot products are used as logits (no normalization
    or temperature). Returns the sum of the v1->v2 and v2->v1 cross-entropies.
    """
    similarity = torch.matmul(v1, v2.transpose(0, 1))
    targets = torch.arange(similarity.shape[0], device=v1.device)
    forward_ce = CE(similarity, targets)
    backward_ce = CE(similarity.transpose(0, 1), targets)
    return forward_ce + backward_ce

### Tokenizer and Datasets

In [None]:
# Tokenizer for the text side; the same checkpoint name is used for the text encoder.
model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Relative data root produced by the unzip step.
DATA_ROOT = 'data'

# The .npy file stores a pickled object in a 0-d array; [()] unwraps it (the datasets use
# it as a substructure-id -> feature mapping). allow_pickle=True: only load trusted files.
gt = np.load(osp.join(DATA_ROOT, 'token_embedding_dict.npy'), allow_pickle=True)[()]

val_dataset, train_dataset = (
    GraphTextDataset(root=DATA_ROOT, gt=gt, split=split_name, tokenizer=tokenizer)
    for split_name in ('val', 'train')
)

### Train and Validation Pipeline

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed model initialisation and DataLoader shuffling so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

nb_epochs = 40
batch_size = 16
learning_rate = 1e-5

# Validation batches stay fixed (shuffle=False): with in-batch negatives the contrastive
# loss depends on batch composition, so reshuffling made epoch-to-epoch comparisons noisy.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# nout=768 must match the text encoder's hidden size; num_node_features=300 is presumably
# the size of the vectors in `gt` — TODO confirm.
model = Model(model_name=model_name, num_node_features=300, nout=768, nhid=500, graph_hidden_channels=500)
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-06)
# The training loop calls scheduler.step() after every batch, so T_max counts iterations:
# one cosine decay over the whole run. (T_max=10 made the LR oscillate between
# learning_rate and 0 every 20 batches.)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=nb_epochs * len(train_loader), eta_min=0)

epoch = 0
loss = 0                      # running training-loss sum since the last progress print
losses = []
count_iter = 0
time1 = time.time()
printEvery = 50
best_validation_loss = float('inf')
# Checkpoint location used by the training loop and the "Best Model" cell.
save_path = os.path.join('./', 'model.pt')

In [None]:
for i in range(nb_epochs):
    print('------------------------- EPOCH {} -------------------------'.format(i+1))

    # ---- training ----
    model.train()
    for batch in train_loader:
        # The PyG batch carries the tokenized text as extra attributes; pull them out so
        # what remains is a pure graph batch for the graph encoder.
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch

        x_graph, x_text = model(graph_batch.to(device), input_ids.to(device), attention_mask.to(device))
        current_loss = contrastive_loss(x_graph, x_text)
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        scheduler.step()  # stepped once per batch, not per epoch
        loss += current_loss.item()

        count_iter += 1
        if count_iter % printEvery == 0:
            time2 = time.time()
            print("Iteration: {0}, Time: {1:.4f} s, Training Loss: {2:.4f}".format(count_iter, time2 - time1, loss/printEvery))
            losses.append(loss)  # sum (not mean) of the last printEvery batch losses
            loss = 0

    # ---- validation ----
    model.eval()
    val_loss = 0
    # Gradients are not needed here; without no_grad every validation batch built a full
    # autograd graph through both encoders, wasting memory and time.
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch.input_ids
            batch.pop('input_ids')
            attention_mask = batch.attention_mask
            batch.pop('attention_mask')
            graph_batch = batch
            x_graph, x_text = model(graph_batch.to(device), input_ids.to(device), attention_mask.to(device))
            current_loss = contrastive_loss(x_graph, x_text)
            val_loss += current_loss.item()

    # val_loss is a sum over batches; the printout below shows the per-batch mean.
    best_validation_loss = min(best_validation_loss, val_loss)

    print('\nEpoch ' + str(i+1) + ' done.  Validation Loss: ', str(val_loss/len(val_loader)))

    if best_validation_loss==val_loss:
        print('Validation Loss Improoved')
        print('Saving Checkpoint...')
        save_path = os.path.join('./', 'model.pt')
        torch.save({
        'epoch': i,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # NOTE: despite the key name this is the summed validation *loss*; the key is
        # kept so code reading existing checkpoints keeps working.
        'validation_accuracy': val_loss,
        # running training-loss sum since the last progress print, not an epoch total
        'loss': loss,
        }, save_path)
        print('Checkpoint Saved To: {}\n'.format(save_path))
    else:
        print('Validation Loss Not Improoved\n')

### Best Model

In [None]:
print('Loading Best Model...\n')

# map_location restores a checkpoint written on GPU onto the current device, so this
# cell also works when the notebook is resumed on a CPU-only runtime.
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # disables dropout for inference

graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()

print('Loading Best Model Done!')

### Test Pipeline

In [None]:
print('Text Embeddings...\n')

# Same relative data root as the training/validation datasets ('/content/data' on Colab).
test_cids_dataset = GraphDataset(root='data', gt=gt, split='test_cids')
test_text_dataset = TextDataset(file_path='data/test_text.txt', tokenizer=tokenizer)

idx_to_cid = test_cids_dataset.get_idx_to_cid()

# shuffle=False keeps embedding rows in file order, which the submission relies on.
test_loader = DataLoader(test_cids_dataset, batch_size=batch_size, shuffle=False)
test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size, shuffle=False)

graph_embeddings = []
text_embeddings = []
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for batch in test_loader:
        # .tolist() on the (batch, dim) output yields one list per graph
        graph_embeddings.extend(graph_model(batch.to(device)).tolist())
    for batch in test_text_loader:
        text_out = text_model(batch['input_ids'].to(device),
                              attention_mask=batch['attention_mask'].to(device))
        text_embeddings.extend(text_out.tolist())

print('Text Embeddings Done!')

### Submission File Generation

In [None]:
print('Creating Submission File...\n')

# Rows are test text descriptions, columns are test graphs (both in file order).
similarity = cosine_similarity(text_embeddings, graph_embeddings)

solution = pd.DataFrame(similarity)
# Put an 'ID' column (the row index) in front of the similarity columns.
solution.insert(0, 'ID', solution.index)
solution.to_csv('submission.csv', index=False)

print('Submission File Ready!')