<center>ALTeGraD 2023 Data Challenge - MVA 2023/2024</center>
<center>Molecule Retrieval with Natural Language Queries</center>
<center>Lilian Hunout $~~$ lilian.hunout@ens-paris-saclay.fr</center>
<center>Samy Hocine $~~$ samy.hocine@ens-paris-saclay.fr</center>
<center>Lucas Haubert $~~$ lucas.haubert@ens-paris-saclay.fr</center>
<center>January 4, 2024</center>

## Installations & Imports

In [None]:
!pip install torch_geometric

In [None]:
from accelerate import Accelerator
import gc
import numpy as np
import os
import os.path as osp
import pandas as pd
from transformers import AutoConfig
import torch
from torch import nn, optim
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATv2Conv, GINConv, global_mean_pool
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset as TorchDataset
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics import label_ranking_average_precision_score
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# change directory
%cd /kaggle/input/altegrad

## Dataloader

In [None]:
class GraphTextDataset(Dataset):
    """PyG dataset pairing each molecule graph with its tokenized text description.

    Descriptions are read from ``<root>/<split>.tsv`` (tab-separated, no header;
    column 0 is the CID used as index, column 1 the description text). Raw
    graphs are read from ``<root>/raw/<cid>.graph`` and each processed sample is
    cached as ``<processed>/processed/<split>/data_<cid>.pt``.

    Args:
        root: directory holding the split TSV files and the ``raw`` folder.
        gt: mapping from substructure identifier to node feature vector; it must
            contain an ``"UNK"`` entry, used for identifiers not in the mapping.
        split: split name, used to pick the TSV file and the processed subfolder.
        processed: directory under which processed tensors are written.
        tokenizer: Hugging Face-style tokenizer called in :meth:`process`.
        transform, pre_transform: forwarded to the PyG ``Dataset`` base class.
    """

    def __init__(
        self,
        root,
        gt,
        split,
        processed,
        tokenizer=None,
        transform=None,
        pre_transform=None,
    ):
        # All attributes must be set before the base constructor runs: PyG's
        # Dataset.__init__ may call process(), which relies on them.
        self.root = root
        self.gt = gt
        self.split = split
        self.processed = processed
        self.tokenizer = tokenizer
        # Resulting shape: {1: {cid: description}}.
        self.description = (
            pd.read_csv(os.path.join(self.root, f"{split}.tsv"), sep="\t", header=None)
            .set_index(0)
            .to_dict()
        )
        self.cids = list(self.description[1].keys())
        self.idx_to_cid = {i: cid for i, cid in enumerate(self.cids)}
        super(GraphTextDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return [f"{cid}.graph" for cid in self.cids]

    @property
    def processed_file_names(self):
        return [f"data_{cid}.pt" for cid in self.cids]

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, "raw")

    @property
    def processed_dir(self) -> str:
        return osp.join(self.processed, "processed", self.split)

    def download(self):
        # Raw files are expected to already be present under raw_dir.
        pass

    def process_graph(self, raw_path):
        """Parse one ``.graph`` file into ``(edge_index, x)`` tensors.

        Layout as read below: a header line, one edge per line (whitespace
        separated ints, presumably ``src dst``), a blank line, a second header
        line, then one line per node whose last token is its substructure id.

        Returns:
            edge_index: ``LongTensor`` of shape ``(2, num_edges)``; ``(2, 0)``
                when the file lists no edges.
            x: ``FloatTensor`` with one row per node, looked up in ``self.gt``.
        """
        edges = []
        features = []
        with open(raw_path, "r") as f:
            next(f)  # skip the edge-list header
            for line in f:
                if line == "\n":  # blank line ends the edge list
                    break
                edges.append(tuple(map(int, line.split())))
            next(f)  # skip the node-list header
            for line in f:
                substruct_id = line.strip().split()[-1]
                features.append(self.gt.get(substruct_id, self.gt["UNK"]))
        if edges:
            edge_index = torch.tensor(edges).t().long()
        else:
            # torch.tensor([]).t() would be 1-D with shape (0,), which GNN layers
            # and batch collation cannot handle; PyG expects (2, 0).
            edge_index = torch.empty((2, 0), dtype=torch.long)
        # Stack into one ndarray first: building a tensor from a list of
        # ndarrays element by element is very slow in PyTorch.
        x = torch.tensor(np.asarray(features)).float()
        return edge_index, x

    def process(self):
        """Tokenize each description, parse its graph and cache one ``Data`` per CID."""
        os.makedirs(self.processed_dir, exist_ok=True)
        for raw_path in self.raw_paths:
            # File names are "<cid>.graph"; strip the suffix to recover the CID.
            cid = int(os.path.basename(raw_path)[: -len(".graph")])
            text_input = self.tokenizer(
                [self.description[1][cid]],
                return_tensors="pt",
                truncation=True,
                max_length=256,
                padding="max_length",
                add_special_tokens=True,
            )
            edge_index, x = self.process_graph(raw_path)
            data = Data(
                x=x,
                edge_index=edge_index,
                input_ids=text_input["input_ids"],
                attention_mask=text_input["attention_mask"],
            )
            # Save immediately rather than holding every sample in memory first.
            torch.save(data, osp.join(self.processed_dir, f"data_{cid}.pt"))

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        """Load the cached sample at position ``idx`` (in ``self.cids`` order)."""
        return self.get_cid(self.idx_to_cid[idx])

    def get_cid(self, cid):
        """Load the cached sample for molecule ``cid``."""
        data = torch.load(osp.join(self.processed_dir, f"data_{cid}.pt"))
        return data


class GraphDataset(Dataset):
    """PyG dataset of molecule graphs only (no text), e.g. for the test CIDs.

    CIDs are read from column 0 of ``<root>/<split>.txt`` (tab-separated, no
    header). Raw graphs are read from ``<root>/raw/<cid>.graph`` and cached as
    ``<processed>/processed/<split>/data_<cid>.pt``.

    Args:
        root: directory holding the split file and the ``raw`` folder.
        gt: mapping from substructure identifier to node feature vector; it must
            contain an ``"UNK"`` entry, used for identifiers not in the mapping.
        split: split name, used to pick the CID file and the processed subfolder.
        processed: directory under which processed tensors are written.
        transform, pre_transform: forwarded to the PyG ``Dataset`` base class.
    """

    def __init__(self, root, gt, split, processed, transform=None, pre_transform=None):
        # All attributes must be set before the base constructor runs: PyG's
        # Dataset.__init__ may call process(), which relies on them.
        self.root = root
        self.gt = gt
        self.split = split
        self.processed = processed
        self.description = pd.read_csv(
            os.path.join(self.root, f"{split}.txt"), sep="\t", header=None
        )
        self.cids = self.description[0].tolist()
        self.idx_to_cid = {i: cid for i, cid in enumerate(self.cids)}
        super(GraphDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return [f"{cid}.graph" for cid in self.cids]

    @property
    def processed_file_names(self):
        return [f"data_{cid}.pt" for cid in self.cids]

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, "raw")

    @property
    def processed_dir(self) -> str:
        return osp.join(self.processed, "processed", self.split)

    def download(self):
        # Raw files are expected to already be present under raw_dir.
        pass

    def process_graph(self, raw_path):
        """Parse one ``.graph`` file into ``(edge_index, x)`` tensors.

        Layout as read below: a header line, one edge per line (whitespace
        separated ints, presumably ``src dst``), a blank line, a second header
        line, then one line per node whose last token is its substructure id.

        Returns:
            edge_index: ``LongTensor`` of shape ``(2, num_edges)``; ``(2, 0)``
                when the file lists no edges.
            x: ``FloatTensor`` with one row per node, looked up in ``self.gt``.
        """
        edges = []
        features = []
        with open(raw_path, "r") as f:
            next(f)  # skip the edge-list header
            for line in f:
                if line == "\n":  # blank line ends the edge list
                    break
                edges.append(tuple(map(int, line.split())))
            next(f)  # skip the node-list header
            for line in f:
                substruct_id = line.strip().split()[-1]
                features.append(self.gt.get(substruct_id, self.gt["UNK"]))
        if edges:
            edge_index = torch.tensor(edges).t().long()
        else:
            # torch.tensor([]).t() would be 1-D with shape (0,), which GNN layers
            # and batch collation cannot handle; PyG expects (2, 0).
            edge_index = torch.empty((2, 0), dtype=torch.long)
        # Stack into one ndarray first: building a tensor from a list of
        # ndarrays element by element is very slow in PyTorch.
        x = torch.tensor(np.asarray(features)).float()
        return edge_index, x

    def process(self):
        """Parse every raw graph and cache one ``Data`` object per CID."""
        os.makedirs(self.processed_dir, exist_ok=True)
        for raw_path in self.raw_paths:
            # File names are "<cid>.graph"; strip the suffix to recover the CID.
            cid = int(os.path.basename(raw_path)[: -len(".graph")])
            edge_index, x = self.process_graph(raw_path)
            data = Data(x=x, edge_index=edge_index)
            # Save immediately rather than holding every sample in memory first.
            torch.save(data, osp.join(self.processed_dir, f"data_{cid}.pt"))

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        """Load the cached sample at position ``idx`` (in ``self.cids`` order)."""
        return self.get_cid(self.idx_to_cid[idx])

    def get_cid(self, cid):
        """Load the cached sample for molecule ``cid``."""
        data = torch.load(osp.join(self.processed_dir, f"data_{cid}.pt"))
        return data

    def get_idx_to_cid(self):
        """Return the mapping from dataset position to CID."""
        return self.idx_to_cid


class TextDataset(TorchDataset):
    """Torch dataset serving one tokenized sentence per line of a UTF-8 text file.

    Each item is a dict with 1-D ``input_ids`` and ``attention_mask`` tensors,
    padded/truncated to ``max_length`` tokens by the tokenizer.
    """

    def __init__(self, file_path, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.sentences = self.load_sentences(file_path)

    def load_sentences(self, file_path):
        """Return the file's lines with surrounding whitespace stripped."""
        with open(file_path, "r", encoding="utf-8") as handle:
            return [raw_line.strip() for raw_line in handle]

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        tokens = self.tokenizer.encode_plus(
            self.sentences[idx],
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a leading batch dimension of 1; drop it.
        return {key: tokens[key].squeeze() for key in ("input_ids", "attention_mask")}

## Models

### Graph Encoders

In [None]:
# first model
class GraphEncoder(nn.Module):
    """First (simplest) graph encoder: three GCN layers, mean pooling, 2-layer MLP.

    Maps a PyG batch to one ``nout``-dimensional vector per graph. There is no
    normalization or dropout in this variant.
    """

    def __init__(self, num_node_features, nout, nhid, graph_hidden_channels):
        super().__init__()
        self.nhid = nhid
        self.nout = nout
        # relu and ln are registered but not used in forward(); ln has learnable
        # parameters, so removing it would change the state_dict keys.
        self.relu = nn.ReLU()
        self.ln = nn.LayerNorm((nout))
        hidden = graph_hidden_channels
        self.conv1 = GCNConv(num_node_features, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.conv3 = GCNConv(hidden, hidden)
        self.mol_hidden1 = nn.Linear(hidden, nhid)
        self.mol_hidden2 = nn.Linear(nhid, nout)

    def forward(self, graph_batch):
        edge_index = graph_batch.edge_index
        h = graph_batch.x
        # The first two convolutions are followed by ReLU; the last one is not.
        for conv in (self.conv1, self.conv2):
            h = conv(h, edge_index).relu()
        h = self.conv3(h, edge_index)
        pooled = global_mean_pool(h, graph_batch.batch)
        return self.mol_hidden2(self.mol_hidden1(pooled).relu())

In [None]:
class GCNEncoder(nn.Module):
    """Graph encoder: three GCN layers with ReLU, BatchNorm and dropout, then an MLP head.

    Each layer applies conv -> ReLU -> BatchNorm -> dropout(p=0.2). Node features
    are mean-pooled per graph, passed through Linear -> ReLU -> dropout -> Linear,
    and finally LayerNorm-ed, giving one ``nout``-dimensional vector per graph.
    """

    def __init__(self, num_node_features, nout, nhid, graph_hidden_channels):
        super().__init__()
        self.nhid = nhid
        self.nout = nout
        hidden = graph_hidden_channels
        # Registration order is kept stable: it fixes the order of parameters(),
        # which the optimizer param groups (and their saved state) depend on.
        self.conv1 = GCNConv(num_node_features, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.conv3 = GCNConv(hidden, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.bn2 = nn.BatchNorm1d(hidden)
        self.bn3 = nn.BatchNorm1d(hidden)
        self.dropout = nn.Dropout(p=0.2)
        self.mol_hidden1 = nn.Linear(hidden, nhid)
        self.mol_hidden2 = nn.Linear(nhid, nout)
        self.ln = nn.LayerNorm(nout)
        self.leaky_relu = nn.LeakyReLU()  # registered but unused in forward()
        self.relu = nn.ReLU()

    def forward(self, graph_batch):
        edge_index = graph_batch.edge_index
        h = graph_batch.x
        layers = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in layers:
            h = self.dropout(norm(self.relu(conv(h, edge_index))))
        h = global_mean_pool(h, graph_batch.batch)
        h = self.dropout(self.relu(self.mol_hidden1(h)))
        return self.ln(self.mol_hidden2(h))


class GATEncoder(nn.Module):
    """Graph encoder built from three multi-head GATv2 attention layers.

    Each layer applies GATv2Conv -> ReLU -> BatchNorm -> dropout(p=0.2). Node
    features are mean-pooled per graph, passed through Linear -> ReLU -> dropout
    -> Linear and LayerNorm-ed into one ``nout``-dimensional vector per graph.
    The BatchNorm and first Linear widths are ``graph_hidden_channels * heads``,
    which assumes the heads' outputs are concatenated (PyG's default).
    """

    def __init__(self, num_node_features, nout, nhid, graph_hidden_channels, heads=2):
        super().__init__()
        self.nhid = nhid
        self.nout = nout
        width = graph_hidden_channels * heads  # per-node feature size after each layer
        # Registration order is kept stable: it fixes the order of parameters().
        self.conv1 = GATv2Conv(num_node_features, graph_hidden_channels, heads)
        self.conv2 = GATv2Conv(width, graph_hidden_channels, heads)
        self.conv3 = GATv2Conv(width, graph_hidden_channels, heads)
        self.bn1 = nn.BatchNorm1d(width)
        self.bn2 = nn.BatchNorm1d(width)
        self.bn3 = nn.BatchNorm1d(width)
        self.dropout = nn.Dropout(p=0.2)
        self.mol_hidden1 = nn.Linear(width, nhid)
        self.mol_hidden2 = nn.Linear(nhid, nout)
        self.ln = nn.LayerNorm(nout)
        self.leaky_relu = nn.LeakyReLU()  # registered but unused in forward()
        self.relu = nn.ReLU()

    def forward(self, graph_batch):
        edge_index = graph_batch.edge_index
        h = graph_batch.x
        layers = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in layers:
            h = self.dropout(norm(self.relu(conv(h, edge_index))))
        h = global_mean_pool(h, graph_batch.batch)
        h = self.dropout(self.relu(self.mol_hidden1(h)))
        return self.ln(self.mol_hidden2(h))


class GINEncoder(nn.Module):
    """Graph encoder built from three GIN layers, each wrapping a single Linear.

    Layer 1 applies GINConv -> LeakyReLU -> BatchNorm -> dropout; layers 2 and 3
    use ReLU instead of LeakyReLU. Node features are mean-pooled per graph, then
    passed through Linear -> ReLU -> dropout -> Linear and LayerNorm-ed into one
    ``nout``-dimensional vector per graph.
    """

    def __init__(self, num_node_features, nout, nhid, graph_hidden_channels):
        super().__init__()
        self.nhid = nhid
        self.nout = nout
        hidden = graph_hidden_channels
        # Registration order is kept stable: it fixes the order of parameters().
        self.conv1 = GINConv(nn.Linear(num_node_features, hidden))
        self.conv2 = GINConv(nn.Linear(hidden, hidden))
        self.conv3 = GINConv(nn.Linear(hidden, hidden))
        self.bn1 = nn.BatchNorm1d(hidden)
        self.bn2 = nn.BatchNorm1d(hidden)
        self.bn3 = nn.BatchNorm1d(hidden)
        self.dropout = nn.Dropout(p=0.2)
        self.mol_hidden1 = nn.Linear(hidden, nhid)
        self.mol_hidden2 = nn.Linear(nhid, nout)
        self.ln = nn.LayerNorm(nout)
        self.leaky_relu = nn.LeakyReLU()
        self.relu = nn.ReLU()

    def forward(self, graph_batch):
        edge_index = graph_batch.edge_index
        h = graph_batch.x
        # NOTE(review): only the first layer uses LeakyReLU -- confirm intentional.
        stages = (
            (self.conv1, self.leaky_relu, self.bn1),
            (self.conv2, self.relu, self.bn2),
            (self.conv3, self.relu, self.bn3),
        )
        for conv, activation, norm in stages:
            h = self.dropout(norm(activation(conv(h, edge_index))))
        h = global_mean_pool(h, graph_batch.batch)
        h = self.dropout(self.relu(self.mol_hidden1(h)))
        return self.ln(self.mol_hidden2(h))

### Text Encoder

In [None]:
class TextEncoder(nn.Module):
    """Wrap a pretrained Hugging Face model and return the first-token hidden state.

    The output is ``last_hidden_state[:, 0, :]``, one vector per input sequence.
    For BERT-style tokenizers with special tokens this is the [CLS] position.
    NOTE(review): encoder-decoder checkpoints (e.g. T5) loaded via AutoModel may
    need decoder inputs -- confirm before switching models.
    """

    def __init__(self, model_name):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        return hidden_states[:, 0, :]


# To freeze part of the model

# distilbert
# class TextEncoder(nn.Module):
#     def __init__(self, model_name, freeze_blocks=5):
#         super(TextEncoder, self).__init__()
#         self.bert = AutoModel.from_pretrained(model_name)
#         self.freeze_blocks = freeze_blocks

#         # Freeze layers
#         for layer in self.bert.transformer.layer[:self.freeze_blocks]:
#             for param in layer.parameters():
#                 param.requires_grad = False

#     def forward(self, input_ids, attention_mask):
#         encoded_text = self.bert(input_ids, attention_mask=attention_mask)
#         return encoded_text.last_hidden_state[:,0,:]


# scibert
# class TextEncoder(nn.Module):
#     def __init__(self, model_name, freeze_layers=False):
#         super(TextEncoder, self).__init__()
#         self.bert = AutoModel.from_pretrained(model_name)
#         if freeze_layers:
#             for param in self.bert.parameters():
#                 param.requires_grad = False
#             for param in self.bert.encoder.layer[-1].parameters():
#                 param.requires_grad = True

#     def forward(self, input_ids, attention_mask):
#         encoded_text = self.bert(input_ids, attention_mask=attention_mask)
#         return encoded_text.last_hidden_state[:,0,:]

### Global Model

In [None]:
class Model(nn.Module):
    """Dual encoder: a graph encoder for molecules and a transformer text encoder.

    The graph side outputs ``nout``-dim vectors and the text side the transformer's
    first-token hidden state, so ``nout`` should equal the text model's hidden size
    for the two embeddings to be comparable.

    Args:
        model_name: Hugging Face checkpoint name passed to ``TextEncoder``.
        num_node_features, nout, nhid, graph_hidden_channels: passed positionally
            to the graph encoder constructor.
        graph_encoder_cls: graph encoder class to instantiate. ``None`` (default)
            selects ``GCNEncoder``; any class with the same positional constructor
            (e.g. ``GraphEncoder``, ``GATEncoder``, ``GINEncoder``) can be used.
    """

    def __init__(
        self,
        model_name,
        num_node_features,
        nout,
        nhid,
        graph_hidden_channels,
        graph_encoder_cls=None,
    ):
        super(Model, self).__init__()
        if graph_encoder_cls is None:
            # Resolved at call time so the default does not require GCNEncoder to
            # exist when this class is defined.
            graph_encoder_cls = GCNEncoder
        # Graph encoder is registered first, keeping parameters() order unchanged.
        self.graph_encoder = graph_encoder_cls(
            num_node_features, nout, nhid, graph_hidden_channels
        )
        self.text_encoder = TextEncoder(model_name)

    def forward(self, graph_batch, input_ids, attention_mask):
        """Return ``(graph_embeddings, text_embeddings)`` for a batch of pairs."""
        graph_encoded = self.graph_encoder(graph_batch)
        text_encoded = self.text_encoder(input_ids, attention_mask)
        return graph_encoded, text_encoded

    def get_text_encoder(self):
        return self.text_encoder

    def get_graph_encoder(self):
        return self.graph_encoder

## Training Loop

### Loading Data

In [None]:
# Candidate text encoders as [Hugging Face checkpoint name, hidden size]. The
# hidden size is used as `nout` so graph embeddings match the text embeddings.
models = [
    ["distilbert-base-uncased", 768],
    ["WhereIsAI/UAE-Large-V1", 1024],
    ["allenai/scibert_scivocab_uncased", 768],
    ["GT4SD/multitask-text-and-chemistry-t5-base-augm", 768],
]
model_name, nout = models[0]

tokenizer = AutoTokenizer.from_pretrained(model_name)

config = AutoConfig.from_pretrained(model_name)
# Access the hidden size (printed as a sanity check against `nout` above)
hidden_size = config.hidden_size
print(f"Hidden Size: {hidden_size}")

# NOTE: allow_pickle=True unpickles the file -- only load trusted data here.
# `[()]` extracts the stored object (presumably a dict) from the 0-d array.
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]
val_dataset = GraphTextDataset(
    root="./data/",
    gt=gt,
    split="val",
    processed="/kaggle/working/data/",
    tokenizer=tokenizer,
)
train_dataset = GraphTextDataset(
    root="./data/",
    gt=gt,
    split="train",
    processed="/kaggle/working/data/",
    tokenizer=tokenizer,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 64

# Only the training data is shuffled. The validation loss uses in-batch
# negatives, so reshuffling validation batches every epoch would make it noisy
# and the "best validation loss" checkpointing unreliable.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

### Loss

In [None]:
CE = torch.nn.CrossEntropyLoss()


def contrastive_loss(v1, v2, temperature=1.0):
    """Symmetric InfoNCE-style loss between two aligned batches of embeddings.

    Row ``i`` of ``v1`` is the positive for row ``i`` of ``v2``; every other row
    in the batch acts as a negative. The loss is the cross-entropy of the
    similarity logits in both directions (v1 -> v2 and v2 -> v1), summed.

    Args:
        v1, v2: tensors of shape ``(batch, dim)`` with the same batch size.
        temperature: positive scalar dividing the dot-product logits. The
            default ``1.0`` uses the raw dot products.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    logits = torch.matmul(v1, torch.transpose(v2, 0, 1)) / temperature
    labels = torch.arange(logits.shape[0], device=v1.device)
    return CE(logits, labels) + CE(torch.transpose(logits, 0, 1), labels)

### Loading Model

In [None]:
model = Model(
    model_name=model_name,
    num_node_features=300,
    nout=nout,
    nhid=300,
    graph_hidden_channels=300,
)  # nout = bert model hidden dim

# Check if multiple GPUs are available
# if torch.cuda.device_count() > 1:
#     print("Using", torch.cuda.device_count(), "GPUs!")
#     model = nn.DataParallel(model)

save_path = "/kaggle/input/altegrad-weights/best_model_lrap_gcn.pt"
print("loading best model...")
# map_location="cpu" lets a checkpoint saved from a GPU load on any runtime
# (including CPU-only); the model is moved to `device` right after.
model.load_state_dict(torch.load(save_path, map_location="cpu")["model_state_dict"])

model.to(device)

In [None]:
# Retrieving model parameters
graph_encoder_params = list(model.graph_encoder.parameters())
text_encoder_params = list(model.text_encoder.parameters())

In [None]:
# Best-so-far validation metrics that decide when to checkpoint: the loss
# starts at +inf (any real loss beats it), LRAP at 0 (below any achievable score).
best_validation_loss, best_lrap_score = float("inf"), 0

In [None]:
# Hyperparameters
nb_epochs = 10
text_lr = 3e-5
init_lr = 1e-4
print_every = 25


# Initialize Accelerator and GradScaler
accelerator = Accelerator()
device = accelerator.device
scaler = GradScaler()

optimizer = optim.AdamW(
    [{"params": graph_encoder_params}, {"params": text_encoder_params, "lr": text_lr}],
    lr=init_lr,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Move model and optimizer to device
model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
    model, optimizer, train_loader, val_loader, scheduler
)


# Define function for training loop
def train(model, optimizer, train_loader, scaler):
    model.train()
    loss = 0
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        input_ids = batch.input_ids.to(accelerator.device)
        batch.pop("input_ids")
        attention_mask = batch.attention_mask.to(accelerator.device)
        batch.pop("attention_mask")
        graph_batch = batch.to(accelerator.device)
        with autocast():
            x_graph, x_text = model(graph_batch, input_ids, attention_mask)
            current_loss = contrastive_loss(x_graph, x_text)
        scaler.scale(current_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        loss += current_loss.item()
        if (batch_idx + 1) % print_every == 0:
            print(
                "Iteration: {}, Training loss: {:.4f}".format(
                    batch_idx + 1, loss / print_every
                )
            )
            loss = 0


# Define function for validation loop
def validate(model, val_loader):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        graph_embeddings = []
        text_embeddings = []
        for batch in val_loader:
            input_ids = batch.input_ids.to(accelerator.device)
            batch.pop("input_ids")
            attention_mask = batch.attention_mask.to(accelerator.device)
            batch.pop("attention_mask")
            graph_batch = batch.to(accelerator.device)
            with autocast():
                x_graph, x_text = model(graph_batch, input_ids, attention_mask)
                current_loss = contrastive_loss(x_graph, x_text)
            val_loss += current_loss.item()

            graph_embeddings.extend(x.tolist() for x in x_graph)
            text_embeddings.extend(x.tolist() for x in x_text)
        #         If you prefer to use the dot product directly with PyTorch
        #         text_tensor = torch.tensor(text_embeddings)
        #         graph_tensor = torch.tensor(graph_embeddings)
        #         similarity = torch.matmul(text_tensor, torch.transpose(graph_tensor, 0, 1))
        similarity = cosine_similarity(text_embeddings, graph_embeddings)
        gt = np.identity(len(similarity))
        lrap_score = label_ranking_average_precision_score(gt, similarity)
    return val_loss / len(val_loader), lrap_score


# Training loop
for epoch in range(nb_epochs):
    print("-----EPOCH {}-----".format(epoch + 1))
    train(model, optimizer, train_loader, scaler)
    val_loss, lrap_score = validate(model, val_loader)
    print("Validation loss:", val_loss)
    print("LRAP: ", lrap_score)
    scheduler.step()
    if val_loss < best_validation_loss:
        best_validation_loss = val_loss
        print("validation loss improoved saving checkpoint...")
        save_path = os.path.join("/kaggle/working/", "best_model_loss.pt")
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "validation_loss": val_loss,
                "validation_lrap": lrap_score,
            },
            save_path,
        )
        print("checkpoint saved to: {}".format(save_path))
    if lrap_score > best_lrap_score:
        best_lrap_score = lrap_score
        print("validation loss improoved saving checkpoint...")
        save_path = os.path.join("/kaggle/working/", "best_model_lrap.pt")
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "validation_loss": val_loss,
                "validation_lrap": lrap_score,
            },
            save_path,
        )
        print("checkpoint saved to: {}".format(save_path))

## Evaluation

In [None]:
# Force to recover GPU memory
gc.collect()
torch.cuda.empty_cache()

In [None]:
# save_path = "/kaggle/working/best_model_lrap.pt"
# print("loading best model...")
# model.load_state_dict(torch.load(save_path)["model_state_dict"])
model.eval()

graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()

test_cids_dataset = GraphDataset(
    root="./data/", gt=gt, split="test_cids", processed="/kaggle/working/data/"
)
test_text_dataset = TextDataset(file_path="./data/test_text.txt", tokenizer=tokenizer)

idx_to_cid = test_cids_dataset.get_idx_to_cid()

batch_size_test = 32
# shuffle=False keeps embeddings in dataset order: row i of the similarity
# matrix is test text i, column j is test graph j (idx_to_cid maps j to its CID).
test_loader = DataLoader(test_cids_dataset, batch_size=batch_size_test, shuffle=False)
test_text_loader = TorchDataLoader(
    test_text_dataset, batch_size=batch_size_test, shuffle=False
)

graph_embeddings = []
text_embeddings = []
# Pure inference: no_grad stops autograd from recording the forward passes,
# which would otherwise cost memory and time for nothing.
with torch.no_grad():
    for batch in test_loader:
        graph_embeddings.extend(graph_model(batch.to(device)).tolist())
    for batch in test_text_loader:
        text_out = text_model(
            batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        )
        text_embeddings.extend(text_out.tolist())

similarity = cosine_similarity(text_embeddings, graph_embeddings)

# One row per test text: an "ID" column (the row index) followed by the
# similarity to every test graph, in dataset order.
solution = pd.DataFrame(similarity)
solution["ID"] = solution.index
solution = solution[["ID"] + [col for col in solution.columns if col != "ID"]]
solution.to_csv("/kaggle/working/submission.csv", index=False)