In [1]:
import torch
import pandas as pd
from tqdm import tqdm
import numpy as np
import wandb
import os

In [2]:
wandb.require("service")

In [3]:
!python3 -m wandb login eb7b1964fb84cd81de96b2a273ecf2bb6254aeac

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/alexeyorlov53/.netrc


### Upload config

In [4]:
import yaml

config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
print(config)

{'batch_size': 32, 'warm_up': 10, 'epochs': 100, 'load_model': 'None', 'save_every_n_epochs': 5, 'fp16_precision': False, 'init_lr': 0.0005, 'weight_decay': '1e-5', 'gpu': 'cuda:1', 'model_type': 'gcn', 'model': {'num_layer': 5, 'emb_dim': 300, 'feat_dim': 768, 'drop_ratio': 0, 'pool': 'mean'}, 'aug': 'node', 'dataset': {'num_workers': 12, 'valid_size': 0.1, 'data_path': 'data/pubchem-10m-clean.txt'}, 'loss': {'temperature': 0.1, 'use_cosine_similarity': True}, 'loss_params': {'alpha': 1.0, 'beta': 1.0, 'gamma': 1.0}}


In [5]:
print('batch_size =', config['batch_size'])

batch_size = 32


In [6]:
print('running on device:', config['gpu'])
device = torch.device(config['gpu']) if torch.cuda.is_available() else torch.device('cpu')

running on device: cuda:1


In [7]:
def _save_config_file(config, log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    with open(os.path.join(log_dir, 'config.yml'), 'w') as outfile:
        yaml.dump(config, outfile, default_flow_style=False, sort_keys=False)

### Upload and Split Dataset

In [8]:
dataframe = pd.read_csv("data_10k.csv")

In [9]:
dataframe = dataframe.drop(columns=['ecfp2', 'ecfp3', 'Molecular Weight', 'Bioactivities', 'AlogP', 'Polar Surface Area', 'CX Acidic pKa', 'CX Basic pKa'])

In [10]:
dataframe

Unnamed: 0,Smiles,ecfp1
0,COc1cc(C2(C)CCCc3nc(SCc4ncccn4)n(-c4ccc(F)cc4)...,"['2246728737', '864674487', '3217380708', '321..."
1,COC(=O)c1sc(NC(=O)C2c3ccccc3Oc3ccccc32)c(C(=O)...,"['2246728737', '864674487', '2246699815', '864..."
2,CC[C@H]1OC(=O)C[C@@H](O)[C@H](C)[C@@H](O[C@@H]...,"['2246728737', '2245384272', '2976033787', '31..."
3,Cc1cccc(-n2cc(C(=O)N3CCC[C@@H]([n+]4cc[nH]c4)C...,"['2246728737', '3217380708', '3218693969', '32..."
4,CCOC(=O)[C@H](C1CC1)N1C(=O)[C@@H](CC(=O)O)C[C@...,"['2246728737', '2245384272', '864674487', '224..."
...,...,...
9995,CCN1CCN(CC(O)c2ccc(Br)cc2)CC1,"['2246728737', '2245384272', '2092489639', '29..."
9996,O=C(O)CNC(=O)CNC(=O)CNC(=O)CSC(=O)c1ccccc1,"['864942730', '2246699815', '864662311', '2245..."
9997,O=C(N[C@@]12CCC[C@@](C#Cc3ccccn3)(CC1)C2)c1ccc...,"['864942730', '2246699815', '847961216', '2976..."
9998,CCOc1ccccc1-c1cc(C(=O)N2CCOCC2)c2ccccc2n1,"['2246728737', '2245384272', '864674487', '321..."


In [11]:
# this because pandas thinks columns with arrays are strings
def preprocess_data_dataset(df, column):
    for row in tqdm(range(len(df))):
        str_ints = eval(df.iloc[row][column])
        str_fingerprint = ' '.join(str_ints)
        df.at[row, column] = str_fingerprint

In [12]:
preprocess_data_dataset(dataframe, 'ecfp1')

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:01<00:00, 8716.30it/s]


### Create Molecule Dataset
##### It will generate torch_geometric.data.Data objects for both bert and GIN/GCN models.

In [13]:
from rdkit import Chem

ATOM_LIST = list(range(1,119))
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
    Chem.rdchem.BondType.SINGLE, 
    Chem.rdchem.BondType.DOUBLE, 
    Chem.rdchem.BondType.TRIPLE, 
    Chem.rdchem.BondType.AROMATIC
]
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]

In [14]:
import random
import math
from copy import deepcopy
from torch_geometric.data import Data, Dataset

class MoleculeDataset(Dataset):
    def __init__(self, dataset: pd.DataFrame, tokenizer, node_mask_percent=0.25, edge_mask_percent=0.25):
        super(Dataset, self).__init__()
        self.dataset = dataset
        self.node_mask_percent = node_mask_percent
        self.edge_mask_percent = edge_mask_percent

        self.tokenizer = tokenizer
        self.tokenizer.model_max_len = 512

    def get_graph_from_smiles(self, smiles):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return torch.tensor([[], []], dtype=torch.long), \
                    torch.tensor(np.array([]), dtype=torch.long), \
                    torch.tensor(np.array([]), dtype=torch.long), \
                    0
    
        N = mol.GetNumAtoms()
        M = mol.GetNumBonds()
    
        type_idx = []
        chirality_idx = []
        atomic_number = []
        
        for atom in mol.GetAtoms():
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
            atomic_number.append(atom.GetAtomicNum())
        
        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1)
        node_feat = torch.cat([x1, x2], dim=-1)
    
        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])
    
        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.long)
        num_nodes = N
        num_edges = M
        return node_feat, edge_index, edge_attr, num_nodes, num_edges

    def get_augmented_graph_copy(self, node_feat, edge_index, edge_attr, N, M):
        num_mask_nodes = max([1, math.floor(self.node_mask_percent * N)])
        num_mask_edges = max([0, math.floor(self.edge_mask_percent * M)])
        
        mask_nodes = random.sample(list(range(N)), num_mask_nodes)
        mask_edges_single = random.sample(list(range(M)), num_mask_edges)
        mask_edges = [2*i for i in mask_edges_single] + [2*i+1 for i in mask_edges_single]

        node_feat_new = deepcopy(node_feat)
        for atom_idx in mask_nodes:
            node_feat_new[atom_idx, :] = torch.tensor([len(ATOM_LIST), 0])
        edge_index_new = torch.zeros((2, 2*(M - num_mask_edges)), dtype=torch.long)
        edge_attr_new = torch.zeros((2*(M - num_mask_edges), 2), dtype=torch.long)
        count = 0
        for bond_idx in range(2*M):
            if bond_idx not in mask_edges:
                edge_index_new[:, count] = edge_index[:, bond_idx]
                edge_attr_new[count, :] = edge_attr[bond_idx, :]
                count += 1
        return Data(x=node_feat_new, edge_index=edge_index_new, edge_attr=edge_attr_new)

    def tokenize(self, item):
        return self.tokenizer(item, truncation=True, max_length=512, padding='max_length')

    def mlm(self, tensor):
        rand = torch.rand(tensor.shape)
        # mask random 15% where token is not 0 <s>, 1 <pad>, or 2 <s/>
        mask_arr = (rand < .15) * (tensor != 0) * (tensor != 1) * (tensor != 2)
        selection = torch.flatten(mask_arr.nonzero()).tolist()
        # mask tensor, token == 4 is our mask token
        tensor[selection] = 4
        return tensor

    def apply_mlm(self, sample):
        labels = torch.tensor(sample.input_ids)
        attention_mask = torch.tensor(sample.attention_mask)
        input_ids = self.mlm(labels.detach().clone())
        return Data(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def __getitem__(self, index):
        node_feat, edge_index, edge_attr, num_nodes, num_edges = self.get_graph_from_smiles(self.dataset['Smiles'][index])

        data_i = self.get_augmented_graph_copy(node_feat, edge_index, edge_attr, num_nodes, num_edges)
        data_j = self.get_augmented_graph_copy(node_feat, edge_index, edge_attr, num_nodes, num_edges)

        ecfp = self.dataset['ecfp1'][index]
        data_for_bert = self.apply_mlm(self.tokenize(ecfp))
        return data_for_bert, data_i, data_j

    def __len__(self):
        return len(self.dataset)

    def get(self):
        pass
    def len(self):
        pass

In [15]:
from transformers import AutoTokenizer

model_name_bert = 'molberto_ecfp0_2M'
tokenizer = AutoTokenizer.from_pretrained(model_name_bert)
dataset = MoleculeDataset(dataframe, tokenizer)

In [16]:
from torch_geometric.loader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_train = len(dataset)
indices = list(range(num_train))
np.random.shuffle(indices)

split = int(np.floor(config['dataset']['valid_size'] * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_dataloader = DataLoader(
    dataset, batch_size=config['batch_size'], sampler=train_sampler,
    num_workers=config['dataset']['num_workers'], drop_last=True
)

eval_dataloader = DataLoader(
    dataset, batch_size=config['batch_size'], sampler=valid_sampler,
    num_workers=config['dataset']['num_workers'], drop_last=True
)

### Create Transformer Model

In [17]:
from transformers import RobertaForMaskedLM
from transformers import RobertaConfig

if config['model_type'] == 'gin':
    from MolCLR.models.ginet_molclr import GINet as GraphModel
elif config['model_type'] == 'gcn':
    from MolCLR.models.gcn_molclr import GCN as GraphModel
else:
    raise ValueError('GNN model is not defined in config.')
from MolCLR.utils.nt_xent import NTXentLoss

class MolecularBertGraph(torch.nn.Module):
    def __init__(self):
        super(MolecularBertGraph, self).__init__()
        self.batch_size = config['batch_size']

        roberta_config = RobertaConfig(
            vocab_size=30_522,
            max_position_embeddings=514,
            hidden_size=768,
            num_attention_heads=12,
            num_hidden_layers=6,
            type_vocab_size=1
        )
        self.bert = RobertaForMaskedLM(roberta_config)
        
        self.graph_model = GraphModel(**config['model'])
        # self.graph_model = self._load_pre_trained_weights(self.graph_model)

        self.out_graph_linear = torch.nn.Linear(768 * 2, 768, bias=True)

        # contrastive loss for MolCLR
        self.nt_xent_criterion = NTXentLoss(device, self.batch_size, **config['loss'])
        # cosine distance as loss between models
        self.cosine_sim = torch.nn.CosineSimilarity(dim=-1)

    def forward(self, bert_batch, graph_batch1, graph_batch2):
        bert_output = self.bert(input_ids=bert_batch['input_ids'].view(self.batch_size, -1), 
                                 attention_mask=bert_batch['attention_mask'].view(self.batch_size, -1),
                                 labels=bert_batch['labels'].view(self.batch_size, -1), output_hidden_states=True)
        bert_loss = bert_output.loss
        bert_emb = bert_output.hidden_states[0][:, 0, :] # take emb for CLS token

        graph_loss, hidden_states_1, hidden_states_2 = self.graph_step(graph_batch1, graph_batch2)
        graph_emb = self.out_graph_linear(torch.cat((hidden_states_1, hidden_states_2), dim=-1))

        bimodal_loss = ((1 - self.cosine_sim(bert_emb, graph_emb))**2).mean()
        return bert_loss, graph_loss, bimodal_loss

    def graph_step(self, xis, xjs):
        # get the representations and the projections
        ris, zis = self.graph_model(xis)  # [N,C]
    
        # get the representations and the projections
        rjs, zjs = self.graph_model(xjs)  # [N,C]
    
        # normalize projection feature vectors
        zis = torch.nn.functional.normalize(zis, dim=1)
        zjs = torch.nn.functional.normalize(zjs, dim=1)
    
        loss = self.nt_xent_criterion(zis, zjs)
        return loss, ris, rjs
        

In [18]:
model = MolecularBertGraph().to(device)

### Define utils

In [19]:
num_epoch = config['epochs']

optimizer = torch.optim.Adam(
    model.parameters(), config['init_lr'], 
    weight_decay=eval(config['weight_decay'])
)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#     optimizer, T_max=config['epochs']-config['warm_up'], 
#     eta_min=0, last_epoch=-1
# )

In [20]:
wandb.init(
    project="efcp_transformer",
    name="RobertaForMaskedLM + MolCLR (GCN) equal_coeffs",
    config={}
)

[34m[1mwandb[0m: Currently logged in as: [33morlov-aleksei53[0m ([33mmoleculary-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin


### Training (with validation)

In [21]:
alpha = config['loss_params']['alpha']
beta = config['loss_params']['beta']
gamma = config['loss_params']['gamma']

In [22]:
epoch_counter = 0

In [23]:
def train_loop():
    train_tqdm = tqdm(train_dataloader, unit="batch")
    train_tqdm.set_description(f'Epoch {epoch_counter}')
    bert_loss_sum, graph_model_loss_sum, bimodal_loss_sum, loss_sum = 0, 0, 0, 0
    
    model.train()
    for (bert_batch, graph_batch1, graph_batch2) in train_tqdm:
        optimizer.zero_grad()

        bert_batch = bert_batch.to(device)
        graph_batch1 = graph_batch1.to(device)
        graph_batch2 = graph_batch2.to(device)

        bert_loss, graph_loss, bimodal_loss = model(bert_batch, graph_batch1, graph_batch2)

        loss = alpha * bert_loss + beta * graph_loss + gamma * bimodal_loss
        loss.backward()

        bert_loss_sum += bert_loss.item()
        graph_model_loss_sum += graph_loss.item()
        bimodal_loss_sum += bimodal_loss.item()
        loss_sum += loss.item()

        optimizer.step()
        train_tqdm.set_postfix(loss=loss.item(), bert_loss=bert_loss.item(), graph_loss=graph_loss.item(), bimodal_loss=bimodal_loss.item())
    return bert_loss_sum / len(train_dataloader), graph_model_loss_sum / len(train_dataloader), bimodal_loss_sum / len(train_dataloader), loss_sum / len(train_dataloader)

In [24]:
def eval_loop():
    eval_tqdm = tqdm(eval_dataloader, unit="batch")
    eval_tqdm.set_description(f'Epoch {epoch_counter}')
    bert_loss_sum, graph_model_loss_sum, bimodal_loss_sum, loss_sum = 0, 0, 0, 0
    
    model.eval()
    for (bert_batch, graph_batch1, graph_batch2) in eval_tqdm:
        optimizer.zero_grad()

        bert_batch = bert_batch.to(device)
        graph_batch1 = graph_batch1.to(device)
        graph_batch2 = graph_batch2.to(device)

        with torch.no_grad():
            bert_loss, graph_loss, bimodal_loss = model(bert_batch, graph_batch1, graph_batch2)

        loss = alpha * bert_loss + beta * graph_loss + gamma * bimodal_loss

        bert_loss_sum += bert_loss.item()
        graph_model_loss_sum += graph_loss.item()
        bimodal_loss_sum += bimodal_loss.item()
        loss_sum += loss.item()

        eval_tqdm.set_postfix(loss=loss.item(), bert_loss=bert_loss.item(), graph_loss=graph_loss.item(), bimodal_loss=bimodal_loss.item())
    return bert_loss_sum / len(eval_dataloader), graph_model_loss_sum / len(eval_dataloader), bimodal_loss_sum / len(eval_dataloader), loss_sum / len(eval_dataloader)

In [25]:
bert_loss, graph_loss, bimodal_loss, loss = eval_loop()

Epoch 0: 100%|███████████████████████████| 31/31 [00:08<00:00,  3.67batch/s, bert_loss=10.5, bimodal_loss=1.07, graph_loss=4.14, loss=15.7]


In [26]:
print('bert_loss =', bert_loss)
print('graph_loss = ', graph_loss)
print('bimodal_loss =', bimodal_loss)
print('sum of losses =', loss)

bert_loss = 10.479248785203502
graph_loss =  4.141095822857272
bimodal_loss = 1.068702132471146
sum of losses = 15.689046705922772


In [27]:
from datetime import datetime

model_checkpoints_folder = os.path.join('ckpts')
dir_name = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join(model_checkpoints_folder, dir_name)
_save_config_file(config, log_dir)

In [28]:
n_iter = 0
valid_n_iter = 0
best_valid_loss = np.inf

for epoch_counter in range(num_epoch):
    bert_loss, graph_loss, bimodal_loss, loss = train_loop()

    wandb.log({"bert_loss/train": bert_loss}, step=epoch_counter)
    wandb.log({"graph_loss/train": graph_loss}, step=epoch_counter)
    wandb.log({"bimodal_loss/train": bimodal_loss}, step=epoch_counter)
    wandb.log({"loss/train": loss}, step=epoch_counter)

    bert_loss, graph_loss, bimodal_loss, loss = eval_loop()

    wandb.log({"bert_loss/eval": bert_loss}, step=epoch_counter)
    wandb.log({"graph_loss/eval": graph_loss}, step=epoch_counter)
    wandb.log({"bimodal_loss/eval": bimodal_loss}, step=epoch_counter)
    wandb.log({"loss/eval": loss}, step=epoch_counter)
    
    if loss < best_valid_loss:
        best_valid_loss = loss
        torch.save(model.state_dict(), os.path.join(log_dir, 'model.pth'))
    
    if (epoch_counter + 1) % config['save_every_n_epochs'] == 0:
        torch.save(model.state_dict(), os.path.join(log_dir, 'model_{}.pth'.format(str(epoch_counter))))

    # # warmup for the first few epochs
    # if epoch_counter >= config['warm_up']:
        # wandb.log({"cosine_lr_decay": scheduler.get_last_lr()[0]}, step=epoch_counter)
        # scheduler.step()

Epoch 0: 100%|█████████████████████| 281/281 [02:40<00:00,  1.75batch/s, bert_loss=0.403, bimodal_loss=0.00326, graph_loss=1.43, loss=1.83]
Epoch 0: 100%|███████████████████████| 31/31 [00:07<00:00,  4.28batch/s, bert_loss=0.374, bimodal_loss=0.000325, graph_loss=1.33, loss=1.7]
Epoch 1: 100%|█████████████████████| 281/281 [02:40<00:00,  1.75batch/s, bert_loss=0.356, bimodal_loss=0.00301, graph_loss=1.01, loss=1.37]
Epoch 1: 100%|███████████████████████| 31/31 [00:07<00:00,  4.32batch/s, bert_loss=0.366, bimodal_loss=9.13e-6, graph_loss=1.14, loss=1.51]
Epoch 2: 100%|█████████████████████| 281/281 [02:40<00:00,  1.76batch/s, bert_loss=0.383, bimodal_loss=0.00281, graph_loss=1.42, loss=1.81]
Epoch 2: 100%|███████████████████████| 31/31 [00:07<00:00,  4.17batch/s, bert_loss=0.402, bimodal_loss=6.76e-6, graph_loss=1.44, loss=1.84]
Epoch 3: 100%|████████████████████| 281/281 [02:40<00:00,  1.75batch/s, bert_loss=0.396, bimodal_loss=0.00269, graph_loss=0.561, loss=0.96]
Epoch 3: 100%|██████

In [29]:
wandb.finish()

VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
bert_loss/eval,▇▇▆▅▆█▅▃▃▄▃▆▄▆▄▄▄▂▁▂▃▄▂▃▃▄▂▃▁▃▂▂▂▃▂▃▂▄▃▂
bert_loss/train,█▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
bimodal_loss/eval,█▁▁▁▁▁▁▁▁▁▂▁▁▁▂▂▁▁▁▁▁▂▁▁▁▂▁▁▂▁▂▁▂▁▁▁▁▁▁▁
bimodal_loss/train,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
graph_loss/eval,█▆▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
graph_loss/train,█▅▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval,█▆▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/train,█▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
bert_loss/eval,0.39652
bert_loss/train,0.39203
bimodal_loss/eval,0.0
bimodal_loss/train,0.00285
graph_loss/eval,0.17174
graph_loss/train,0.14427
loss/eval,0.56826
loss/train,0.53914


In [30]:
# bert_batch = bert_batch.to('cpu')
# graph_batch1 = graph_batch1.to('cpu')
# graph_batch2 = graph_batch2.to('cpu')
# del bert_batch, graph_batch1, graph_batch2
# torch.cuda.empty_cache()