In [1]:
import torch
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import wandb

In [2]:
!python3 -m wandb login eb7b1964fb84cd81de96b2a273ecf2bb6254aeac

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/alexeyorlov53/.netrc


In [3]:
device = torch.device("cuda", index=4) if torch.cuda.is_available() else torch.device('cpu')

In [100]:
batch_size = 64

In [32]:
model_name_bert = 'molberto_ecfp0_2M'

### Upload config

In [116]:
import yaml

config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
print(config)

{'batch_size': 512, 'warm_up': 10, 'epochs': 100, 'load_model': 'None', 'eval_every_n_epochs': 1, 'save_every_n_epochs': 5, 'log_every_n_steps': 50, 'fp16_precision': False, 'init_lr': 0.0005, 'weight_decay': '1e-5', 'gpu': 'cuda:0', 'model_type': 'gin', 'model': {'num_layer': 5, 'emb_dim': 300, 'feat_dim': 768, 'drop_ratio': 0, 'pool': 'mean'}, 'aug': 'node', 'dataset': {'num_workers': 12, 'valid_size': 0.05, 'data_path': 'data/pubchem-10m-clean.txt'}, 'loss': {'temperature': 0.1, 'use_cosine_similarity': True}}


### Upload and Split Dataset

In [61]:
dataframe = pd.read_csv("data_10k.csv")

In [62]:
dataframe = dataframe.drop(columns=['ecfp2', 'ecfp3', 'Molecular Weight', 'Bioactivities', 'AlogP', 'Polar Surface Area', 'CX Acidic pKa', 'CX Basic pKa'])

In [69]:
dataframe

Unnamed: 0,Smiles,ecfp1
0,COc1cc(C2(C)CCCc3nc(SCc4ncccn4)n(-c4ccc(F)cc4)...,2246728737 864674487 3217380708 3218693969 321...
1,COC(=O)c1sc(NC(=O)C2c3ccccc3Oc3ccccc32)c(C(=O)...,2246728737 864674487 2246699815 864942730 3217...
2,CC[C@H]1OC(=O)C[C@@H](O)[C@H](C)[C@@H](O[C@@H]...,2246728737 2245384272 2976033787 3189457552 32...
3,Cc1cccc(-n2cc(C(=O)N3CCC[C@@H]([n+]4cc[nH]c4)C...,2246728737 3217380708 3218693969 3218693969 32...
4,CCOC(=O)[C@H](C1CC1)N1C(=O)[C@@H](CC(=O)O)C[C@...,2246728737 2245384272 864674487 2246699815 864...
...,...,...
9995,CCN1CCN(CC(O)c2ccc(Br)cc2)CC1,2246728737 2245384272 2092489639 2968968094 29...
9996,O=C(O)CNC(=O)CNC(=O)CNC(=O)CSC(=O)c1ccccc1,864942730 2246699815 864662311 2245384272 8479...
9997,O=C(N[C@@]12CCC[C@@](C#Cc3ccccn3)(CC1)C2)c1ccc...,864942730 2246699815 847961216 2976816164 2968...
9998,CCOc1ccccc1-c1cc(C(=O)N2CCOCC2)c2ccccc2n1,2246728737 2245384272 864674487 3217380708 321...


In [64]:
# this because pandas thinks columns with arrays are strings
def preprocess_data_dataset(df, column):
    for row in tqdm(range(len(df))):
        str_ints = eval(df.iloc[row][column])
        str_fingerprint = ' '.join(str_ints)
        df.at[row, column] = str_fingerprint

In [65]:
preprocess_data_dataset(dataframe, 'ecfp1')

  0%|          | 0/10000 [00:00<?, ?it/s]

In [None]:
from transformers import DataCollatorWithPadding
def CustomDataCollator():
    pass


In [36]:
from rdkit import Chem

ATOM_LIST = list(range(1,119))
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
    Chem.rdchem.BondType.SINGLE, 
    Chem.rdchem.BondType.DOUBLE, 
    Chem.rdchem.BondType.TRIPLE, 
    Chem.rdchem.BondType.AROMATIC
]
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]

In [91]:
import random
import math
from copy import deepcopy
import torch
from torch_geometric.data import Data, Dataset

class MoleculeDataset(Dataset):
    def __init__(self, dataset: pd.DataFrame, tokenizer, node_mask_percent=0.25, edge_mask_percent=0.25):
        super(Dataset, self).__init__()
        self.dataset = dataset
        self.node_mask_percent = node_mask_percent
        self.edge_mask_percent = edge_mask_percent

        self.tokenizer = tokenizer
        self.tokenizer.model_max_len = 512

    def get_graph_from_smiles(self, smiles):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return torch.tensor([[], []], dtype=torch.long), \
                    torch.tensor(np.array([]), dtype=torch.long), \
                    torch.tensor(np.array([]), dtype=torch.long), \
                    0
    
        N = mol.GetNumAtoms()
        M = mol.GetNumBonds()
    
        type_idx = []
        chirality_idx = []
        atomic_number = []
        
        for atom in mol.GetAtoms():
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
            atomic_number.append(atom.GetAtomicNum())
        
        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1)
        node_feat = torch.cat([x1, x2], dim=-1)
    
        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])
    
        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.long)
        num_nodes = N
        num_edges = M
        return node_feat, edge_index, edge_attr, num_nodes, num_edges

    def get_augmented_graph_copy(self, node_feat, edge_index, edge_attr, N, M):
        num_mask_nodes = max([1, math.floor(self.node_mask_percent * N)])
        num_mask_edges = max([0, math.floor(self.edge_mask_percent * M)])
        
        mask_nodes = random.sample(list(range(N)), num_mask_nodes)
        mask_edges_single = random.sample(list(range(M)), num_mask_edges)
        mask_edges = [2*i for i in mask_edges_single] + [2*i+1 for i in mask_edges_single]

        node_feat_new = deepcopy(node_feat)
        for atom_idx in mask_nodes:
            node_feat_new[atom_idx, :] = torch.tensor([len(ATOM_LIST), 0])
        edge_index_new = torch.zeros((2, 2*(M - num_mask_edges)), dtype=torch.long)
        edge_attr_new = torch.zeros((2*(M - num_mask_edges), 2), dtype=torch.long)
        count = 0
        for bond_idx in range(2*M):
            if bond_idx not in mask_edges:
                edge_index_new[:, count] = edge_index[:, bond_idx]
                edge_attr_new[count, :] = edge_attr[bond_idx, :]
                count += 1
        return Data(x=node_feat_new, edge_index=edge_index_new, edge_attr=edge_attr_new)

    def tokenize(self, item):
        return self.tokenizer(item, truncation=True, max_length=512, padding='max_length')

    def mlm(self, tensor):
        rand = torch.rand(tensor.shape)
        # mask random 15% where token is not 0 <s>, 1 <pad>, or 2 <s/>
        mask_arr = (rand < .15) * (tensor != 0) * (tensor != 1) * (tensor != 2)
        selection = torch.flatten(mask_arr.nonzero()).tolist()
        # mask tensor, token == 4 is our mask token
        tensor[selection] = 4
        return tensor

    def apply_mlm(self, sample):
        labels = torch.tensor(sample.input_ids)
        mask = torch.tensor(sample.attention_mask)
        input_ids = self.mlm(labels.detach().clone())
        return Data(labels=labels, mask=mask, input_ids=input_ids)

    def __getitem__(self, index):
        node_feat, edge_index, edge_attr, num_nodes, num_edges = self.get_graph_from_smiles(self.dataset['Smiles'][index])

        data_i = self.get_augmented_graph_copy(node_feat, edge_index, edge_attr, num_nodes, num_edges)
        data_j = self.get_augmented_graph_copy(node_feat, edge_index, edge_attr, num_nodes, num_edges)

        ecfp = self.dataset['ecfp1'][index]
        data_for_bert = self.apply_mlm(self.tokenize(ecfp))
        return data_for_bert, data_i, data_j

    def __len__(self):
        return len(self.dataset)

In [95]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name_bert)
dataset = MoleculeDataset(dataframe, tokenizer)

In [104]:
from torch_geometric.loader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_train = len(dataset)
indices = list(range(num_train))
np.random.shuffle(indices)

split = int(np.floor(config['dataset']['valid_size'] * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_dataloader = DataLoader(
    dataset, batch_size = batch_size, sampler=valid_sampler,
    num_workers=config['dataset']['num_workers'], drop_last=True
)

eval_dataloader = DataLoader(
    dataset, batch_size = batch_size, sampler=valid_sampler,
    num_workers=config['dataset']['num_workers'], drop_last=True
)

### Create Transformer Model

In [114]:
from transformers import RobertaForMaskedLM
from transformers import RobertaConfig

if config['model_type'] == 'gin':
    from models.ginet_molclr import GINet as GraphModel
elif config['model_type'] == 'gcn':
    from models.gcn_molclr import GCN as GraphModel
from MolCLR.utils.nt_xent import NTXentLoss

class MolecularBertGraph(torch.nn.Module):
    def __init__(self):
        super(MolecularPropertiesClassification, self).__init__()

        config = RobertaConfig(
            vocab_size=30_522,
            max_position_embeddings=514,
            hidden_size=768,
            num_attention_heads=12,
            num_hidden_layers=6,
            type_vocab_size=1
        )
        self.bert = RobertaForMaskedLM(config)
        
        self.graph_model = GraphModel(**config["model"]).to(self.device)
        # self.graph_model = self._load_pre_trained_weights(self.graph_model)

        self.out_graph_linear = torch.nn.Linear(768 * 2, 768, bias=True)

        # contrastive loss for MolCLR
        self.nt_xent_criterion = NTXentLoss(device, config['batch_size'], **config['loss'])

    def forward(self, bert_batch, graph_batch1, graph_batch2):
        bert_output = self.bert(input_ids=bert_batch['input_ids'], 
                                 attention_mask=bert_batch['attention_mask'],
                                 labels=bert_batch['labels'])
        bert_loss = bert_output.loss
        bert_emb = bert_output.hidden_states[0]

        graph_loss, hidden_states_1, hidden_states_2 = self.graph_step(graph_batch1, graph_batch2)
        graph_emb = self.out_graph_linear(torch.cat(hidden_states_1, hidden_states_2, dim=-1))
        
        
        # first_linear_out = self.linear1( \
        #     torch.cat((last_hidden_state1[:, 0, : ], last_hidden_state2[:, 0, : ]), dim=-1).view(-1, 2 * 768))

        return bert_loss, bert_emb, graph_loss, graph_emb

    def graph_step(self, xis, xjs):
        # get the representations and the projections
        ris, zis = self.graph_model(xis)  # [N,C]
    
        # get the representations and the projections
        rjs, zjs = self.graph_model(xjs)  # [N,C]
    
        # normalize projection feature vectors
        zis = torch.nn.functional.normalize(zis, dim=1)
        zjs = torch.nn.functional.normalize(zjs, dim=1)
    
        loss = self.nt_xent_criterion(zis, zjs)
        return loss, ris, rjs
        

TypeError: 'RobertaConfig' object is not subscriptable

In [None]:
model = MolecularBertGraph().to(device)

In [None]:
model

### Define utils

In [None]:
from transformers import AdamW, get_scheduler

num_epoch = 1

optimizer = torch.optim.Adam(
    model.parameters(), config['init_lr'], 
    weight_decay=eval(config['weight_decay'])
)
scheduler = CosineAnnealingLR(
    optimizer, T_max=config['epochs']-config['warm_up'], 
    eta_min=0, last_epoch=-1
)

In [None]:
wandb.init(
    project="efcp_transformer",
    name="RobertaForMaskedLM + MolCLR (GCN)",
    config={}
)

### Training

In [None]:
# model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
# _save_config_file(model_checkpoints_folder)

n_iter = 0
valid_n_iter = 0
best_valid_loss = np.inf

for epoch_counter in range(num_epoch): # range(config['epochs']):
    model.train()
    for batch_counter, (bert_batch, graph_batch1, graph_batch2) in enumerate(train_loader):
        optimizer.zero_grad()

        bert_batch = bert_batch.to(device)
        graph_batch1 = graph_batch1.to(device)
        graph_batch2 = graph_batch2.to(device)

        loss_bert, loss_graph, bimodal_loss = model(bert_batch, graph_batch1, graph_batch2)

        if n_iter % self.config['log_every_n_steps'] == 0:
            # self.writer.add_scalar('train_loss', loss, global_step=n_iter)
            # self.writer.add_scalar('cosine_lr_decay', scheduler.get_last_lr()[0], global_step=n_iter)
            print(epoch_counter, batch_counter, loss.item())

        loss.backward()

        optimizer.step()
        n_iter += 1

    # validate the model if requested
    if epoch_counter % self.config['eval_every_n_epochs'] == 0:
        model.eval()
        valid_loss = self._validate(model, valid_loader)
        print(epoch_counter, bn, valid_loss, '(validation)')
        if valid_loss < best_valid_loss:
            # save the model weights
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
    
        self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
        valid_n_iter += 1
    
    if (epoch_counter+1) % self.config['save_every_n_epochs'] == 0:
        torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model_{}.pth'.format(str(epoch_counter))))

    # warmup for the first few epochs
    if epoch_counter >= config['warm_up']:
        scheduler.step()

In [None]:
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score

progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epoch * len(eval_dataloader)))

for epoch in range(num_epoch):
    model.train()
    total_pred_labels = []
    total_true_labels = []
    epoch_loss = 0
    for batch in train_dataloader:
        input_batch = { k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask'] }
        batch['target'] = batch['target'].to(device)
        
        logits = model(**input_batch)
        
        loss = loss_func(logits.view(-1, 2), batch['target'].view(-1))
        loss.backward()
        epoch_loss += loss.item()
        
        pred_labels = torch.argmax(logits, dim=-1)
        true_labels = batch['target']
        total_pred_labels.append(pred_labels)
        total_true_labels.append(true_labels)
        
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar_train.update(1)

    total_pred_labels = torch.cat(total_pred_labels).cpu().detach().numpy()
    total_true_labels = torch.cat(total_true_labels).cpu().detach().numpy()
    
    wandb.log({"loss/train": epoch_loss / len(train_dataloader)}, step=epoch)
    wandb.log({"accuracy/train": accuracy_score(total_true_labels, total_pred_labels)}, step=epoch)
    wandb.log({"f1/train": f1_score(total_true_labels, total_pred_labels, average='micro')}, step=epoch)
    wandb.log({"precision/train": precision_score(total_true_labels, total_pred_labels, average='micro')}, step=epoch)
    wandb.log({"recall/train": recall_score(total_true_labels, total_pred_labels, average='micro')}, step=epoch)

    model.eval()
    total_pred_labels = []
    total_true_labels = []
    epoch_loss = 0
    for batch in eval_dataloader:
        input_batch = { k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask'] }
        batch['target'] = batch['target'].to(device)
        
        with torch.no_grad():
            logits = model(**input_batch)
            loss = loss_func(logits.view(-1, 2), batch['target'].view(-1))
            epoch_loss += loss.item()

            pred_labels = torch.argmax(logits, dim=-1)
            true_labels = batch['target']
            total_pred_labels.append(pred_labels)
            total_true_labels.append(true_labels)
        
        progress_bar_eval.update(1)

    total_pred_labels = torch.cat(total_pred_labels).cpu().detach().numpy()
    total_true_labels = torch.cat(total_true_labels).cpu().detach().numpy()
    
    wandb.log({"loss/validation": epoch_loss / len(eval_dataloader)}, step=epoch)
    wandb.log({"accuracy/validation": accuracy_score(total_true_labels, total_pred_labels)}, step=epoch)
    wandb.log({"f1/validation": f1_score(total_true_labels, total_pred_labels, average='micro')}, step=epoch)
    wandb.log({"precision/validation": precision_score(total_true_labels, total_pred_labels, average='micro')}, step=epoch)
    wandb.log({"recall/validation": recall_score(total_true_labels, total_pred_labels, average='micro')}, step=epoch)

In [None]:
wandb.finish()

In [None]:
test_dataloader = DataLoader(
    tokenized_dataset['test'], batch_size = 64, collate_fn = data_collator
)

model.eval()
total_pred_labels = []
total_true_labels = []
epoch_loss = 0
for batch in tqdm(test_dataloader):
    input_batch = { k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask'] }
    batch['target'] = batch['target'].to(device)

    with torch.no_grad():
        logits = model(**input_batch)
        loss = loss_func(logits.view(-1, 2), batch['target'].view(-1))
        epoch_loss += loss.item()

        pred_labels = torch.argmax(logits, dim=-1)
        true_labels = batch['target']
        total_pred_labels.append(pred_labels)
        total_true_labels.append(true_labels)

total_pred_labels = torch.cat(total_pred_labels).cpu().detach().numpy()
total_true_labels = torch.cat(total_true_labels).cpu().detach().numpy()

wandb.log({"loss/validation": epoch_loss / len(eval_dataloader)}, step=epoch)
wandb.log({"accuracy/validation": accuracy_score(total_true_labels, total_pred_labels)}, step=epoch)
wandb.log({"f1/validation": f1_score(total_true_labels, total_pred_labels, average='micro')}, step=epoch)
wandb.log({"precision/validation": precision_score(total_true_labels, total_pred_labels, average='micro')}, step=epoch)
wandb.log({"recall/validation": recall_score(total_true_labels, total_pred_labels, average='micro')}, step=epoch)

In [None]:
torch.cuda.empty_cache()