In [None]:
import torch
import pandas as pd
from tqdm import tqdm
import numpy as np
import wandb
import os

In [None]:
# Ask wandb to enable its "service" feature; this happens once, before the run
# is created further down in the notebook.
wandb.require("service")

### Load config

In [None]:
import yaml

# Read the experiment configuration. The context manager closes the file handle
# deterministically instead of leaving it open until garbage collection.
with open("config.yaml", "r") as config_file:
    # NOTE(review): FullLoader is kept for compatibility with the existing
    # config; prefer yaml.safe_load if the file only holds plain scalars,
    # lists and mappings.
    config = yaml.load(config_file, Loader=yaml.FullLoader)
print(config)

In [None]:
# Echo the batch size read from the config; the model and the contrastive loss
# defined later both read this same config entry.
print('batch_size =', config['batch_size'])

In [None]:
# Use the device requested in the config when CUDA is available, otherwise fall
# back to the CPU. The message reports the device that was actually selected
# (previously it always printed the requested one, even on a CPU fallback).
device = torch.device(config['gpu']) if torch.cuda.is_available() else torch.device('cpu')
print('running on device:', device)

In [None]:
def _save_config_file(config, log_dir):
    """Write ``config`` as YAML to ``<log_dir>/config.yml``.

    Args:
        config: Configuration object to serialise (passed to ``yaml.dump``).
        log_dir: Destination directory; created, including parents, if missing.
    """
    # exist_ok=True avoids the check-then-create race of
    # os.path.exists() followed by os.makedirs().
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, 'config.yml'), 'w') as outfile:
        # sort_keys=False keeps the key order of the in-memory config.
        yaml.dump(config, outfile, default_flow_style=False, sort_keys=False)

### Load and Split Dataset

In [None]:
# Load the raw molecule table; later cells read its 'Smiles' and 'ecfp1' columns.
dataframe = pd.read_csv("data_10k.csv")

In [None]:
# Drop the columns that the rest of the notebook does not read.
# NOTE(review): re-running this cell raises KeyError because the columns are
# already gone; re-run the read_csv cell first.
dataframe = dataframe.drop(columns=['ecfp2', 'ecfp3', 'Molecular Weight', 'Bioactivities', 'AlogP', 'Polar Surface Area', 'CX Acidic pKa', 'CX Basic pKa'])

In [None]:
# Display the remaining columns as the cell output.
dataframe

In [None]:
def preprocess_data_dataset(df, column):
    """Convert a column of stringified lists into space-separated strings, in place.

    pandas reads list-valued CSV cells back as their textual representation
    (e.g. ``"['12', '345']"``). Each cell of ``column`` is parsed into a list
    and its elements are re-joined with single spaces.

    Args:
        df: DataFrame to modify in place.
        column: Name of the column holding the stringified lists.

    Raises:
        ValueError, SyntaxError: If a cell is not a valid Python literal. The
            column is only replaced after every cell parsed successfully, so a
            failure leaves ``df`` untouched.
    """
    import ast  # local import: only this function needs it

    # literal_eval accepts Python literals only, so a malformed or malicious CSV
    # cell cannot execute code the way eval() would. Iterating over the column
    # itself (instead of mixing positional iloc with label-based at) also keeps
    # the function correct for a non-default index. map(str, ...) tolerates
    # lists whose elements are ints rather than strings.
    df[column] = [
        ' '.join(map(str, ast.literal_eval(cell)))
        for cell in tqdm(df[column], total=len(df))
    ]

In [None]:
# Rewrite the 'ecfp1' column in place. NOTE(review): this is not idempotent -
# a second run would try to parse the already converted strings.
preprocess_data_dataset(dataframe, 'ecfp1')

### Create Molecule Dataset
##### It will generate torch_geometric.data.Data objects for both bert and GIN/GCN models.

In [None]:
from rdkit import Chem

# Vocabularies for the categorical graph features. MoleculeDataset stores the
# *position* of a value inside these lists (via list.index) as the feature id.

# Atomic numbers 1..118; len(ATOM_LIST) is used as the id of a masked atom.
ATOM_LIST = list(range(1,119))
# Chirality tags recognised for atoms.
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
# Bond types recognised for edges; any other type makes list.index raise ValueError.
BOND_LIST = [
    Chem.rdchem.BondType.SINGLE, 
    Chem.rdchem.BondType.DOUBLE, 
    Chem.rdchem.BondType.TRIPLE, 
    Chem.rdchem.BondType.AROMATIC
]
# Bond directions recognised for edges; any other direction makes list.index raise ValueError.
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]

In [None]:
import random
import math
from copy import deepcopy
from torch_geometric.data import Data, Dataset

class MoleculeDataset(Dataset):
    """Dataset yielding one masked-token sample and two masked graph views per molecule.

    Each item is a triple ``(data_for_bert, data_i, data_j)``:

    * ``data_for_bert`` - ``Data`` with ``input_ids`` (some ids replaced by the
      mask id), ``attention_mask`` and ``labels`` (the unmasked ids), built from
      the ``ecfp1`` column;
    * ``data_i`` / ``data_j`` - two independently augmented copies of the
      molecular graph built from the ``Smiles`` column.

    Args:
        dataset: DataFrame with at least the ``Smiles`` and ``ecfp1`` columns.
        tokenizer: Tokenizer called as
            ``tokenizer(text, truncation=True, max_length=512, padding='max_length')``.
        node_mask_percent: Fraction of atoms whose features are masked per view.
        edge_mask_percent: Fraction of bonds removed per view.
    """

    def __init__(self, dataset: pd.DataFrame, tokenizer, node_mask_percent=0.25, edge_mask_percent=0.25):
        # Zero-argument super() runs the torch_geometric Dataset initialiser;
        # the previous super(Dataset, self) call skipped it.
        super().__init__()
        self.dataset = dataset
        self.node_mask_percent = node_mask_percent
        self.edge_mask_percent = edge_mask_percent

        self.tokenizer = tokenizer
        # The Hugging Face attribute is `model_max_length`; the former
        # `model_max_len` only created an unused attribute.
        self.tokenizer.model_max_length = 512

    def get_graph_from_smiles(self, smiles):
        """Build index-encoded graph tensors from a SMILES string.

        Returns:
            ``(node_feat, edge_index, edge_attr, num_nodes, num_edges)`` where
            ``node_feat`` is ``(N, 2)`` (atom-type id, chirality id),
            ``edge_index`` is ``(2, 2M)`` with both directions of every bond
            stored next to each other, ``edge_attr`` is ``(2M, 2)`` (bond-type
            id, bond-direction id), ``num_nodes`` is N and ``num_edges`` is M.
            An unparseable SMILES yields an empty graph with the same layout.

        Raises:
            ValueError: If an atom, chirality tag, bond type or bond direction
                is missing from the module-level vocab lists.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Same 5-tuple layout as the normal path. The previous version
            # returned 4 values in a different order, which broke the unpacking
            # in __getitem__.
            # NOTE(review): an empty graph inside a batch may still upset
            # downstream pooling; consider filtering invalid SMILES up front.
            return (
                torch.zeros((0, 2), dtype=torch.long),
                torch.zeros((2, 0), dtype=torch.long),
                torch.zeros((0, 2), dtype=torch.long),
                0,
                0,
            )

        type_idx = []
        chirality_idx = []
        for atom in mol.GetAtoms():
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))

        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1, 1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1, 1)
        node_feat = torch.cat([x1, x2], dim=-1)

        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            # Store both directions; they occupy positions 2*i and 2*i + 1,
            # which get_augmented_graph_copy relies on.
            row += [start, end]
            col += [end, start]

            feat = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ]
            edge_feat += [feat, feat]

        edge_index = torch.tensor([row, col], dtype=torch.long)
        # view(-1, 2) keeps the (2M, 2) layout even for a molecule without bonds.
        edge_attr = torch.tensor(edge_feat, dtype=torch.long).view(-1, 2)
        return node_feat, edge_index, edge_attr, mol.GetNumAtoms(), mol.GetNumBonds()

    def get_augmented_graph_copy(self, node_feat, edge_index, edge_attr, N, M):
        """Return a randomly masked copy of a graph as a ``Data`` object.

        Args:
            node_feat: ``(N, 2)`` node features.
            edge_index: ``(2, 2M)`` edge list with paired directions.
            edge_attr: ``(2M, 2)`` edge features.
            N: Number of atoms.
            M: Number of bonds (undirected).

        Returns:
            ``Data(x, edge_index, edge_attr)`` in which the sampled atoms carry
            the mask features ``[len(ATOM_LIST), 0]`` and the sampled bonds
            (both directions) are removed. The inputs are not modified.
        """
        # Mask at least one atom, but never more atoms than exist (N can be 0
        # for the empty graph returned for an unparseable SMILES).
        num_mask_nodes = min(N, max(1, math.floor(self.node_mask_percent * N)))
        num_mask_edges = max(0, math.floor(self.edge_mask_percent * M))

        mask_nodes = random.sample(range(N), num_mask_nodes)
        mask_edges_single = random.sample(range(M), num_mask_edges)

        node_feat_new = node_feat.clone()
        if mask_nodes:
            node_feat_new[mask_nodes] = torch.tensor([len(ATOM_LIST), 0])

        # Boolean keep-mask over the 2M directed edges; replaces the former
        # per-edge `not in list` scan, which was quadratic in M.
        keep = torch.ones(2 * M, dtype=torch.bool)
        if mask_edges_single:
            removed = torch.tensor(mask_edges_single, dtype=torch.long)
            keep[2 * removed] = False
            keep[2 * removed + 1] = False

        edge_index_new = edge_index[:, keep]
        edge_attr_new = edge_attr.view(-1, 2)[keep]
        return Data(x=node_feat_new, edge_index=edge_index_new, edge_attr=edge_attr_new)

    def tokenize(self, item):
        """Tokenize ``item``, truncating and padding to exactly 512 tokens."""
        return self.tokenizer(item, truncation=True, max_length=512, padding='max_length')

    def mlm(self, tensor):
        """Replace about 15% of the non-special ids in ``tensor`` with the mask id, in place.

        Ids 0, 1 and 2 (``<s>``, ``<pad>``, ``</s>``) are never masked; id 4 is
        written as the mask token.
        NOTE(review): these ids are hard-coded; confirm they match the
        tokenizer's special-token ids.

        Returns:
            The same tensor object, after masking.
        """
        rand = torch.rand(tensor.shape)
        mask_arr = (rand < .15) & (tensor != 0) & (tensor != 1) & (tensor != 2)
        # Boolean-mask assignment works for any tensor rank.
        tensor[mask_arr] = 4
        return tensor

    def apply_mlm(self, sample):
        """Turn a tokenizer output into a ``Data`` object with masked inputs and labels.

        NOTE(review): ``labels`` holds the original id at every position,
        padding included. A masked-LM loss that only ignores ``-100`` labels
        would therefore also score unmasked and padding positions - confirm
        this is intended.
        """
        labels = torch.tensor(sample.input_ids)
        attention_mask = torch.tensor(sample.attention_mask)
        # Mask a detached copy so that `labels` keeps the original ids.
        input_ids = self.mlm(labels.detach().clone())
        return Data(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def __getitem__(self, index):
        """Return ``(data_for_bert, data_i, data_j)`` for the row at position ``index``."""
        # .iloc makes the lookup positional, matching __len__ and the index
        # lists used to build the samplers, whatever the DataFrame's index is.
        smiles = self.dataset['Smiles'].iloc[index]
        node_feat, edge_index, edge_attr, num_nodes, num_edges = self.get_graph_from_smiles(smiles)

        # Two independent random augmentations of the same molecule.
        data_i = self.get_augmented_graph_copy(node_feat, edge_index, edge_attr, num_nodes, num_edges)
        data_j = self.get_augmented_graph_copy(node_feat, edge_index, edge_attr, num_nodes, num_edges)

        ecfp = self.dataset['ecfp1'].iloc[index]
        data_for_bert = self.apply_mlm(self.tokenize(ecfp))
        return data_for_bert, data_i, data_j

    def __len__(self):
        """Number of rows in the underlying DataFrame."""
        return len(self.dataset)

    def get(self):
        """Unused stub; items are served by ``__getitem__``.

        NOTE(review): presumably present only to satisfy the abstract hook of
        the torch_geometric ``Dataset`` base class - confirm.
        """
        pass

    def len(self):
        """Number of rows; mirrors ``__len__`` (previously returned ``None``)."""
        return len(self.dataset)

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer stored under this name or path and wrap the preprocessed
# DataFrame in the dataset defined above.
model_name_bert = 'molberto_ecfp0_2M'
tokenizer = AutoTokenizer.from_pretrained(model_name_bert)
dataset = MoleculeDataset(dataframe, tokenizer)

In [None]:
from torch_geometric.loader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# Random train/validation split over positional indices.
# NOTE(review): no random seed is set, so the split differs between runs.
num_train = len(dataset)
indices = list(range(num_train))
np.random.shuffle(indices)

# The first `split` shuffled indices form the validation set, the rest the training set.
split = int(np.floor(config['dataset']['valid_size'] * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# drop_last=True keeps every batch at exactly config['batch_size'] samples; the
# model and NTXentLoss below reshape with that fixed size.
train_dataloader = DataLoader(
    dataset, batch_size=config['batch_size'], sampler=train_sampler,
    num_workers=config['dataset']['num_workers'], drop_last=True
)

# NOTE(review): with drop_last=True a validation set smaller than one batch
# yields zero batches, and eval_loop then divides by zero.
eval_dataloader = DataLoader(
    dataset, batch_size=config['batch_size'], sampler=valid_sampler,
    num_workers=config['dataset']['num_workers'], drop_last=True
)

### Create Transformer Model

In [None]:
import torch
import numpy as np


class NTXentLoss(torch.nn.Module):
    """Contrastive loss over two batches of paired embeddings.

    For ``zis`` and ``zjs`` of shape ``(batch_size, C)``, row ``k`` of one batch
    is the positive of row ``k`` of the other; every other row of the stacked
    ``2 * batch_size`` representations is a negative.

    Args:
        device: Device on which the negative mask and the labels are created.
        batch_size: Number of pairs per batch. ``forward`` reshapes with this
            fixed value, so every batch must have exactly this size.
        temperature: Divisor applied to the logits.
        use_cosine_similarity: Use cosine similarity if true, otherwise the dot
            product.
    """

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        # Not used by forward(); kept so the module's attributes stay unchanged.
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask()
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        """Return the pairwise similarity callable selected by the flag."""
        if use_cosine_similarity:
            # The module attribute (`_cosine_similarity`) and the method
            # (`_cosine_simililarity`) differ only by the method's misspelling.
            # Do not "fix" one without renaming the other, or the method would
            # end up calling itself.
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        return self._dot_simililarity

    def _get_correlated_mask(self):
        """Boolean ``(2B, 2B)`` mask that is True exactly at negative pairs.

        Entry ``(i, j)`` is False on the main diagonal (self-similarity) and on
        the two diagonals offset by ``batch_size`` (the positive pairs).
        """
        # Built directly in torch as a bool tensor; the previous version went
        # through numpy and cast to bool twice.
        idx = torch.arange(2 * self.batch_size)
        offset = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs()
        mask = (offset != 0) & (offset != self.batch_size)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        """Pairwise dot products between the rows of ``x`` and the rows of ``y``."""
        # x.unsqueeze(1):   (N, 1, C)
        # y.T.unsqueeze(0): (1, C, M)
        # result:           (N, M)
        return torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)

    def _cosine_simililarity(self, x, y):
        """Pairwise cosine similarity between the rows of ``x`` and the rows of ``y``."""
        # x.unsqueeze(1): (N, 1, C)
        # y.unsqueeze(0): (1, M, C)
        # result:         (N, M)
        return self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))

    def forward(self, zis, zjs):
        """Return the loss averaged over the ``2 * batch_size`` anchors.

        Args:
            zis: ``(batch_size, C)`` embeddings of the first view.
            zjs: ``(batch_size, C)`` embeddings of the second view.
        """
        representations = torch.cat([zjs, zis], dim=0)
        similarity_matrix = self.similarity_function(representations, representations)

        # Positive scores sit on the diagonals offset by +/- batch_size.
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)

        # Column 0 holds the positive, so the target class is always 0.
        logits = torch.cat((positives, negatives), dim=1)
        # NOTE(review): logits are log(|s| + 1e-4) / temperature, so the sign of
        # a similarity is discarded (s = -0.9 scores like s = 0.9). Confirm this
        # is intended.
        logits = torch.log(logits.abs() + 0.0001) / self.temperature

        labels = torch.zeros(2 * self.batch_size, dtype=torch.long, device=self.device)
        loss = self.criterion(logits, labels)

        return loss / (2 * self.batch_size)

In [None]:
from transformers import RobertaForMaskedLM
from transformers import RobertaConfig

# Bind the graph encoder class named in the config to `GraphModel`, so the rest
# of the notebook does not depend on which one was chosen.
if config['graph_model_type'] == 'gin':
    from MolCLR.models.ginet_molclr import GINet as GraphModel
elif config['graph_model_type'] == 'gcn':
    from MolCLR.models.gcn_molclr import GCN as GraphModel
else:
    # Fail early on an unknown or missing model type.
    raise ValueError('GNN model is not defined in config.')

class MolecularBertGraph(torch.nn.Module):
    """Joint model: a RoBERTa masked-LM over fingerprints plus a graph encoder.

    ``forward`` returns three losses: the masked-LM loss, the contrastive loss
    between the two augmented graph views, and a contrastive "bimodal" loss
    that aligns the text embedding with the projected graph embedding.

    All hyper-parameters are read from the module-level ``config``; the
    module-level ``device`` is used for the loss masks and checkpoint loading.
    """

    def __init__(self):
        super().__init__()
        self.batch_size = config['batch_size']

        roberta_config = RobertaConfig(**config['roberta_model'])
        # from_pretrained is a classmethod: call it on the class instead of
        # first building (and discarding) a randomly initialised model.
        self.bert = RobertaForMaskedLM.from_pretrained(config['pretrained_roberta_name'],
                                                       config=roberta_config)
        self.graph_model = GraphModel(**config['graph_model'])
        # self.graph_model = self._load_graph_pretrained_weights(self.graph_model)

        # Projects the concatenated representations of both graph views into
        # the transformer's hidden size so they are comparable to the text
        # embedding.
        self.out_graph_linear = torch.nn.Linear(2 * config['graph_model']['feat_dim'],
                                                config['roberta_model']['hidden_size'], bias=True)
        # contrastive loss for MolCLR (also reused for the bimodal loss)
        self.nt_xent_criterion = NTXentLoss(device, self.batch_size, **config['ntxent_loss'])
        # cosine distance as loss between models (currently unused by forward)
        self.cosine_sim = torch.nn.CosineSimilarity(dim=-1)

    def forward(self, bert_batch, graph_batch1, graph_batch2):
        """Compute ``(bert_loss, graph_loss, bimodal_loss)`` for one batch.

        Args:
            bert_batch: Batch exposing ``input_ids``, ``attention_mask`` and
                ``labels``. Each is reshaped to ``(batch_size, -1)``,
                presumably because batching concatenates the per-sample 1-D
                tensors - confirm against the loader.
            graph_batch1: Batched first augmented graph view.
            graph_batch2: Batched second augmented graph view.
        """
        bert_output = self.bert(input_ids=bert_batch['input_ids'].view(self.batch_size, -1),
                                attention_mask=bert_batch['attention_mask'].view(self.batch_size, -1),
                                labels=bert_batch['labels'].view(self.batch_size, -1), output_hidden_states=True)
        bert_loss = bert_output.loss
        # hidden_states[0] is the embedding-layer output, which carries no
        # context from the rest of the sequence; the last entry is the final
        # encoder layer. Take its first-position (CLS) vector.
        bert_emb = bert_output.hidden_states[-1][:, 0, :]

        graph_loss, hidden_states_1, hidden_states_2 = self.graph_step(graph_batch1, graph_batch2)
        graph_emb = self.out_graph_linear(torch.cat((hidden_states_1, hidden_states_2), dim=-1))

        # bimodal_loss = ((1 - self.cosine_sim(bert_emb, graph_emb))**2).mean()
        bimodal_loss = self.nt_xent_criterion(bert_emb, graph_emb)
        return bert_loss, graph_loss, bimodal_loss

    def graph_step(self, xis, xjs):
        """Encode both graph views and return ``(contrastive_loss, ris, rjs)``.

        ``ris`` / ``rjs`` are the first outputs of the graph model for each
        view; the second outputs are L2-normalised and fed to the contrastive
        loss.
        """
        # get the representations and the projections
        ris, zis = self.graph_model(xis)  # [N,C]
        rjs, zjs = self.graph_model(xjs)  # [N,C]

        # normalize projection feature vectors
        zis = torch.nn.functional.normalize(zis, dim=1)
        zjs = torch.nn.functional.normalize(zjs, dim=1)

        loss = self.nt_xent_criterion(zis, zjs)
        return loss, ris, rjs

    def _load_graph_pretrained_weights(self, model):
        """Load MolCLR checkpoint weights into ``model`` if the file exists.

        A missing checkpoint is reported and the model is returned unchanged;
        any other error, such as a state-dict mismatch, propagates.
        """
        try:
            checkpoints_folder = os.path.join('MolCLR', 'ckpt', config['load_graph_model'], 'checkpoints')
            checkpoint_path = os.path.join(checkpoints_folder, 'model.pth')
            print(checkpoint_path)
            # map_location lets a checkpoint saved on a GPU load on the
            # currently selected device, including a CPU-only machine.
            state_dict = torch.load(checkpoint_path, map_location=device)

            model.load_state_dict(state_dict)
            print("Loaded pre-trained model with success.")
        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

        return model

In [None]:
# Build the joint model and move its parameters to the selected device.
model = MolecularBertGraph().to(device)

In [None]:
# Print the module tree for a quick architecture check.
print(model)

### Define utils

In [None]:
num_epoch = config['epochs']

# A single Adam optimizer over every parameter of the joint model.
optimizer = torch.optim.Adam(
    model.parameters(), config['init_lr'],
    # The config value may arrive as a string (PyYAML reads a bare `1e-5` as
    # text). float() converts both strings and numbers without executing
    # config text the way eval() did.
    weight_decay=float(config['weight_decay'])
)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#     optimizer, T_max=config['epochs']-config['warm_up'], 
#     eta_min=0, last_epoch=-1
# )

In [None]:
# Start the wandb run and attach the full config to it.
# NOTE(review): the run name is hard-coded and mentions a pretrained GCN, while
# the graph model type comes from the config and the pretrained-weight loading
# call in MolecularBertGraph is commented out - confirm the name is accurate.
wandb.init(
    project="efcp_transformer",
    name="Pretrained RobertaForMaskedLM + pretrained MolCLR (GCN) 10k",
    config=config
)

### Training (with validation)

In [None]:
# Weights of the three loss terms in the total loss:
# total = alpha * bert_loss + beta * graph_loss + gamma * bimodal_loss
alpha = config['loss_params']['alpha']
beta = config['loss_params']['beta']
gamma = config['loss_params']['gamma']

In [None]:
# Global epoch index read by train_loop / eval_loop for their progress-bar
# label; the epoch loop below rebinds it on every iteration.
epoch_counter = 0

In [None]:
def train_loop():
    """Train the global ``model`` for one pass over ``train_dataloader``.

    Reads the notebook globals ``model``, ``optimizer``, ``device``, ``alpha``,
    ``beta``, ``gamma`` and ``epoch_counter``.

    Returns:
        Tuple ``(bert_loss, graph_loss, bimodal_loss, total_loss)`` holding the
        per-batch means over the epoch.
    """
    progress = tqdm(train_dataloader, unit="batch")
    progress.set_description(f'Epoch {epoch_counter}')
    # Running sums in the order: bert, graph, bimodal, total.
    totals = [0, 0, 0, 0]

    model.train()
    for batch_parts in progress:
        bert_batch, graph_batch1, graph_batch2 = batch_parts
        optimizer.zero_grad()

        bert_batch, graph_batch1, graph_batch2 = (
            part.to(device) for part in (bert_batch, graph_batch1, graph_batch2)
        )

        bert_loss, graph_loss, bimodal_loss = model(bert_batch, graph_batch1, graph_batch2)

        # Weighted sum of the three objectives.
        loss = alpha * bert_loss + beta * graph_loss + gamma * bimodal_loss
        loss.backward()

        # Read the scalars once; they feed both the sums and the progress bar.
        bert_value = bert_loss.item()
        graph_value = graph_loss.item()
        bimodal_value = bimodal_loss.item()
        loss_value = loss.item()
        step_values = (bert_value, graph_value, bimodal_value, loss_value)
        totals = [acc + val for acc, val in zip(totals, step_values)]

        optimizer.step()
        progress.set_postfix(loss=loss_value, bert_loss=bert_value, graph_loss=graph_value, bimodal_loss=bimodal_value)

    n_batches = len(train_dataloader)
    return totals[0] / n_batches, totals[1] / n_batches, totals[2] / n_batches, totals[3] / n_batches

In [None]:
def eval_loop():
    """Evaluate the global ``model`` for one pass over ``eval_dataloader``.

    Reads the notebook globals ``model``, ``optimizer``, ``device``, ``alpha``,
    ``beta``, ``gamma`` and ``epoch_counter``. No parameter update is performed.

    Returns:
        Tuple ``(bert_loss, graph_loss, bimodal_loss, total_loss)`` holding the
        per-batch means over the validation batches.
    """
    progress = tqdm(eval_dataloader, unit="batch")
    progress.set_description(f'Epoch {epoch_counter}')
    # Running sums in the order: bert, graph, bimodal, total.
    totals = [0, 0, 0, 0]

    model.eval()
    for batch_parts in progress:
        bert_batch, graph_batch1, graph_batch2 = batch_parts
        optimizer.zero_grad()

        bert_batch, graph_batch1, graph_batch2 = (
            part.to(device) for part in (bert_batch, graph_batch1, graph_batch2)
        )

        # The forward pass runs without gradient tracking.
        with torch.no_grad():
            bert_loss, graph_loss, bimodal_loss = model(bert_batch, graph_batch1, graph_batch2)

        loss = alpha * bert_loss + beta * graph_loss + gamma * bimodal_loss

        bert_value = bert_loss.item()
        graph_value = graph_loss.item()
        bimodal_value = bimodal_loss.item()
        loss_value = loss.item()
        step_values = (bert_value, graph_value, bimodal_value, loss_value)
        totals = [acc + val for acc, val in zip(totals, step_values)]

        progress.set_postfix(loss=loss_value, bert_loss=bert_value, graph_loss=graph_value, bimodal_loss=bimodal_value)

    n_batches = len(eval_dataloader)
    return totals[0] / n_batches, totals[1] / n_batches, totals[2] / n_batches, totals[3] / n_batches

In [None]:
# Baseline: evaluate once before any training step has run.
bert_loss, graph_loss, bimodal_loss, loss = eval_loop()

In [None]:
# Report the baseline validation losses returned by the cell above.
print('bert_loss =', bert_loss)
print('graph_loss = ', graph_loss)
print('bimodal_loss =', bimodal_loss)
print('sum of losses =', loss)  # weighted total, not a plain sum

In [None]:
from datetime import datetime

# Create a timestamped run directory under ./ckpts and store the config used
# for this run next to the checkpoints written later.
model_checkpoints_folder = os.path.join('ckpts')
dir_name = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join(model_checkpoints_folder, dir_name)
_save_config_file(config, log_dir)

In [None]:
# Counters kept from the original script (not read inside the loop) and the
# best validation loss seen so far.
n_iter = 0
valid_n_iter = 0
best_valid_loss = np.inf

for epoch_counter in range(num_epoch):
    # ---- training pass ----
    bert_loss, graph_loss, bimodal_loss, loss = train_loop()
    print('train', bert_loss, graph_loss, bimodal_loss, loss)

    for metric_key, metric_value in zip(
        ("bert_loss/train", "graph_loss/train", "bimodal_loss/train", "loss/train"),
        (bert_loss, graph_loss, bimodal_loss, loss),
    ):
        wandb.log({metric_key: metric_value}, step=epoch_counter)

    # ---- validation pass ----
    bert_loss, graph_loss, bimodal_loss, loss = eval_loop()
    print('eval', bert_loss, graph_loss, bimodal_loss, loss)

    for metric_key, metric_value in zip(
        ("bert_loss/eval", "graph_loss/eval", "bimodal_loss/eval", "loss/eval"),
        (bert_loss, graph_loss, bimodal_loss, loss),
    ):
        wandb.log({metric_key: metric_value}, step=epoch_counter)

    # Keep the weights with the lowest total validation loss.
    if loss < best_valid_loss:
        best_valid_loss = loss
        torch.save(model.state_dict(), os.path.join(log_dir, 'model.pth'))

    # Periodic snapshot, named after the zero-based epoch index.
    if (epoch_counter + 1) % config['save_every_n_epochs'] == 0:
        torch.save(model.state_dict(), os.path.join(log_dir, f'model_{epoch_counter}.pth'))

    # # warmup for the first few epochs
    # if epoch_counter >= config['warm_up']:
        # wandb.log({"cosine_lr_decay": scheduler.get_last_lr()[0]}, step=epoch_counter)
        # scheduler.step()

In [None]:
# Close the wandb run started earlier in the notebook.
wandb.finish()