In [14]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import timm
import torch
from torch import nn
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torch.backends.cudnn as cudnn
import wandb
import pytorch_lightning as pl
from torch.nn.utils.rnn import pad_sequence,pack_padded_sequence
from transformers import AutoTokenizer,T5ForConditionalGeneration
from sklearn.model_selection import train_test_split
from rdkit import RDLogger,Chem
from rdkit.Chem import AllChem,DataStructs
import torch.nn.functional as F
import os
from rdkit.DataStructs import TanimotoSimilarity
from Levenshtein import distance as levenshtein_distance
pl.seed_everything(56)

Seed set to 56


56

In [15]:
RDLogger.DisableLog('rdApp.*')
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [18]:
class CFG:
    wandb=True
    model='sagawa/CompoundT5'
    train_path = 'train_t5_df.csv'
    train_folder = './train/'
    betas=(0.9, 0.999)
    img_size = 512
    max_pred_len = 128
    val_split_size = 0.2
    scheduler = None
    eps=1e-6
    num_workers = 12
    batch_size = 64
    lr=2e-5
    weight_decay = 0.01
    fine_tune_encoder = False
    max_epoches=20
    seed=56

In [19]:
class PLDataset(Dataset):
    def __init__(self, df, tokenizer):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        input_l = self.df.iloc[idx]['refs']
        label = self.df.iloc[idx]['smiles']
        input_enc = self.tokenizer.encode_plus(input_l, padding='max_length',max_length=128, truncation=True, return_tensors='pt')
        label_enc = self.tokenizer.encode_plus(label, padding='max_length',max_length=128, truncation=True, return_tensors='pt')
        return {'input_ids':input_enc.input_ids.squeeze(0),
                'attention_mask':input_enc.attention_mask.squeeze(0),
                'label':label_enc.input_ids.squeeze(0),}


In [20]:
class PLDataModule(pl.LightningDataModule):
    def __init__(self,tokenizer):
        super().__init__()
        self.cfg = CFG()
        self.is_setup = False
        self.tokenizer = tokenizer
        
    def prepare_data(self):
        self.train_data = pd.read_csv(CFG.train_path)
        
    def setup(self, stage: str):
        self.train_df, self.val_df = train_test_split(self.train_data, test_size=self.cfg.val_split_size,random_state=self.cfg.seed)
        self.train_df = self.train_df.reset_index(drop=True)
        self.val_df = self.val_df.reset_index(drop=True)
        self.train_dataset = PLDataset(self.train_df,self.tokenizer)
        self.val_dataset = PLDataset(self.val_df,self.tokenizer)
        self.is_setup = True
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.cfg.batch_size,
                          num_workers=self.cfg.num_workers,
                          pin_memory=True,
                          shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.cfg.batch_size,
                          num_workers=self.cfg.num_workers,
                          pin_memory=True,
                          shuffle=False)

In [21]:
def char_accuracy(y_p,y):
    y_p,y = list(y_p),list(y)
    ln = min(len(y_p),len(y))
    score = 0
    for i in range(ln):
        if y_p[i] == y[i]:
            score += 1
    return score / max(len(y_p),len(y))

def accuracy(y_p,y):
    if y_p == y:
        return 1
    return 0

def correct_part(y_p):
    if Chem.MolFromSmiles(y_p) is None:
        return 0
    else:
        return 1

def tanimoto(y_p,y):
    try:
        mol1 = Chem.MolFromSmiles(y_p)
        mol2 = Chem.MolFromSmiles(y)
    
        vec_1 = AllChem.RDKFingerprint(mol1)
        vec_2 = AllChem.RDKFingerprint(mol2)
        return DataStructs.TanimotoSimilarity(vec_1,vec_2)
    except:
        return 0

In [22]:
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.preds = []
        self.targets = []

    def update(self,preds,targets):
        self.preds += preds
        self.targets += targets
    
    def calc_metrics(self):
        f = dict()
        
        f['char_acc'] = np.mean([char_accuracy(x,y) for x,y in zip(self.preds,self.targets)])
        f['acc'] = np.mean([accuracy(x,y) for x,y in zip(self.preds,self.targets)])
        f['corrent_part'] = np.mean([correct_part(x) for x in self.preds])
        f['tanimoto'] = np.mean([tanimoto(x,y) for x,y in zip(self.preds,self.targets)])
        
        return f

In [23]:
class PLModule(pl.LightningModule):
    def __init__(self,model,tokenizer):
        super().__init__()
        self.cfg = CFG()
        self.avg_meter = AverageMeter()
        self.model = model
        self.tokenizer = tokenizer
        
    def forward(self,input_ids,attention_mask,label):
        return self.model(input_ids=input_ids,attention_mask=attention_mask,labels=label)   

    def training_step(self, batch, _):
        loss = self(**batch).loss
        self.log_dict({'train_loss':loss.item()})
        return loss
        
    def validation_step(self, batch, _):
        labels = batch['input_ids'].detach().cpu().numpy()
        labels = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in labels]
        
        logits = self.model.generate(
            batch['input_ids'],
            num_beams=1,
            max_length=128
        )
        
        logits = logits.detach().cpu().numpy()
        logits = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in logits]
        self.avg_meter.update(logits,labels)
    
    def predict_step(self,batch,_):
        
        logits = self.model.generate(
            batch['input_ids'],
            num_beams=1,
            max_length=128
        )
        
        logits = logits.detach().cpu().numpy()
        logits = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in logits]
        return logits
    
    def on_validation_epoch_end(self):
        f = self.avg_meter.calc_metrics()
        self.log_dict(f)
        print(f)
        self.avg_meter.reset()
            
    def configure_optimizers(self):
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters()],
                "lr":self.cfg.lr
            },
        ]
        return torch.optim.AdamW(optimizer_grouped_parameters,
                                 betas=self.cfg.betas,
                                 weight_decay=self.cfg.weight_decay,
                                 eps=self.cfg.eps)

In [24]:
tokenizer = AutoTokenizer.from_pretrained(CFG.model)
model = T5ForConditionalGeneration.from_pretrained(CFG.model)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [25]:
avg_meter = AverageMeter()
avg_meter.update(dm.val_df['refs'].tolist(),dm.val_df['smiles'].tolist())
avg_meter.calc_metrics()

{'char_acc': 0.8366820644254979,
 'acc': 0.73865,
 'corrent_part': 0.9897,
 'tanimoto': 0.9862146335034755}

In [26]:
model_pl = PLModule(model,tokenizer).cuda()

In [27]:
dm = PLDataModule(tokenizer)
dm.prepare_data()
dm.setup(0)

In [28]:
wandb.login(key="673ae6e9b51cc896110db5327738b993795fffad")
os.environ['WANDB_API_KEY'] = "673ae6e9b51cc896110db5327738b993795fffad"
wandb.init(project='MOLECULA',name='T5_Molscribe')

[34m[1mwandb[0m: Currently logged in as: [33mandrey20007[0m ([33mandrey2007[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [29]:
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
checkpoint_cb = pl.callbacks.ModelCheckpoint(
    dirpath='./outputs/',
    filename='base_model{epoch:02d}',
    monitor='tanimoto',
    mode='max',
    save_last=True
)

trainer = pl.Trainer(
    accelerator="gpu",
    precision=32,
    callbacks = [lr_monitor,checkpoint_cb],
    logger = pl.loggers.WandbLogger(),
    min_epochs=1,
    devices=[0],
    check_val_every_n_epoch=1,
    max_epochs=CFG.max_epoches
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model_pl,datamodule=dm,ckpt_path="outputs/last.ckpt")

You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory /notebooks/outputs exists and is not empty.
Restoring states from the checkpoint path at outputs/last.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | m

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

{'char_acc': 0.8726232447165678, 'acc': 0.796875, 'corrent_part': 0.984375, 'tanimoto': 0.8978899581373574}


terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: initialization error
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Exception raised from getDevice at ../c10/cuda/impl/CUDAGuardImpl.h:39 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x3e (0x7f5a1885f20e in /usr/local/lib/python3.9/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x14660d (0x7f59e68a860d in /usr/local/lib/python3.9/dist-packages/torch/lib/libtorch_cuda_cpp.so)
frame #2: <unknown function> + 0x149a9e (0x7f59e68aba9e in /usr/local/lib/python3.9/dist-packages/torch/lib/libtorch_cuda_cpp.so)
frame #3: <unknown function> + 0x466778 (0x7f5a18e40778 in /usr/local/lib/python3.9/dist-packages/torch/lib/libtorch_python.so)
frame #4: c10::TensorImpl::release_resources() + 0x175 (0x7f5a188467a5 in /usr/local/

Training: |          | 0/? [00:00<?, ?it/s]

Exception ignored in: <function _releaseLock at 0x7f5a6d262d30>
Traceback (most recent call last):
  File "/usr/lib/python3.9/logging/__init__.py", line 227, in _releaseLock
    def _releaseLock():
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 7321) is killed by signal: Aborted. 


Validation: |          | 0/? [00:00<?, ?it/s]

{'char_acc': 0.8397013068764191, 'acc': 0.7376, 'corrent_part': 0.9512, 'tanimoto': 0.837901101010471}


Validation: |          | 0/? [00:00<?, ?it/s]

{'char_acc': 0.8411512589570728, 'acc': 0.7398, 'corrent_part': 0.9515, 'tanimoto': 0.8386412060402542}
{'char_acc': 0.8409368748652861, 'acc': 0.7402, 'corrent_part': 0.95245, 'tanimoto': 0.8392765023845196}


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

