In [1]:
%cd MolScribe

/notebooks/MolScribe


In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import timm
import torch
from torch import nn
import cv2
from torch.utils.data import DataLoader, Dataset
import wandb
import pytorch_lightning as pl
from torch.nn.utils.rnn import pad_sequence,pack_padded_sequence
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from sklearn.model_selection import train_test_split
import torchvision
from rdkit import RDLogger,Chem
from rdkit.Chem import AllChem,DataStructs
import torchvision
import torch.nn.functional as F
import os
import argparse
from rdkit.DataStructs import TanimotoSimilarity
from Levenshtein import distance as levenshtein_distance
from molscribe.dataset import TrainDataset, AuxTrainDataset, bms_collate
from molscribe.model import Encoder, Decoder
from molscribe.loss import Criterion
from molscribe.chemistry import convert_graph_to_smiles, postprocess_smiles, keep_main_molecule
from molscribe.tokenizer import get_tokenizer
pl.seed_everything(56)

Seed set to 56


56

In [3]:
%cd ../

/notebooks


In [4]:
class CFG:
    wandb=True
    ckpt_path='zinc_2-3m.csv'
    train_path = 'zinc_6-7m.csv'
    aux_path = './uspto_mol/train_680k.csv'
    val_df = 'train.csv'
    betas=(0.9, 0.999)
    img_size = 384
    max_pred_len = 128
    val_split_size = 0.2
    scheduler = None
    emb_dim = 512  
    attention_dim = 512
    freq_threshold = 2
    decoder_dim = 512
    img_size=512
    dropout = 0.4
    eps=1e-6
    num_workers = 12
    batch_size = 64
    lr=2e-5
    encoder_lr = 2e-5
    decoder_lr = 3e-5
    weight_decay = 0.01
    fine_tune_encoder = False
    max_epoches=6
    seed=56

In [5]:
class TrainARGS:
    formats = ['chartok_coords','edges']
    input_size = 384
    save_image = False
    mol_augment = None
    default_option = None
    shuffle_nodes = False
    include_condensed = None
    vocab_file = 'vocab_chars.json'
    save_path='./saved_images'
    coord_bins = 128
    dynamic_indigo = True
    sep_xy = True
    continuous_coords = None
    data_path = ''
    augment = None
    coords_file = None
    pseudo_coords = None
    predict_coords = None
    encoder = 'swin_base'
    decoder = 'transformer'
    use_checkpoint = False
    encoder_dim = 1024
    dropout = 0.4
    embed_dim = 256
    decoder_dim = 512
    decoder_layer = 1
    attention_dim = 256
    dec_num_layers = 6
    dec_hidden_size = 256
    dec_attn_heads = 8
    dec_num_queries = 128
    hidden_dropout = 0.1
    attn_dropout = 0.1
    max_relative_positions = 0
    save_path = './saved_images'
    mask_ratio = 0.0
    label_smoothing = 0.0
    compute_confidence = None
    enc_pos_emb = None

In [6]:
def get_args(args_states=None):
        parser = argparse.ArgumentParser()
        # Model
        parser.add_argument('--encoder', type=str, default='swin_base')
        parser.add_argument('--decoder', type=str, default='transformer')
        parser.add_argument('--trunc_encoder', action='store_true')  # use the hidden states before downsample
        parser.add_argument('--no_pretrained', action='store_true')
        parser.add_argument('--use_checkpoint', action='store_true', default=True)
        parser.add_argument('--dropout', type=float, default=0.5)
        parser.add_argument('--embed_dim', type=int, default=256)
        parser.add_argument('--enc_pos_emb', action='store_true')
        group = parser.add_argument_group("transformer_options")
        group.add_argument("--dec_num_layers", help="No. of layers in transformer decoder", type=int, default=6)
        group.add_argument("--dec_hidden_size", help="Decoder hidden size", type=int, default=256)
        group.add_argument("--dec_attn_heads", help="Decoder no. of attention heads", type=int, default=8)
        group.add_argument("--dec_num_queries", type=int, default=128)
        group.add_argument("--hidden_dropout", help="Hidden dropout", type=float, default=0.1)
        group.add_argument("--attn_dropout", help="Attention dropout", type=float, default=0.1)
        group.add_argument("--max_relative_positions", help="Max relative positions", type=int, default=0)
        parser.add_argument('--continuous_coords', action='store_true')
        parser.add_argument('--compute_confidence', action='store_true')
        # Data
        parser.add_argument('--input_size', type=int, default=384)
        parser.add_argument('--vocab_file', type=str, default=None)
        parser.add_argument('--coord_bins', type=int, default=64)
        parser.add_argument('--sep_xy', action='store_true', default=True)

        args = parser.parse_args([])
        if args_states:
            for key, value in args_states.items():
                args.__dict__[key] = value
        return args

In [7]:
class ValDataset(Dataset):
    def __init__(self,df,transforms,img_dir="./train/"):
        self.df = df
        self.img_dir = img_dir
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self,idx):
        smiles = self.df.iloc[idx]['smiles']
        path = self.img_dir + str(self.df.iloc[idx]['id']) +'.png'
        img = self._read_image(path)
        img = self.transforms(image=img,keypoints=[])['image']
        return img,smiles
    
    def _read_image(self,path):
        image = cv2.imread(path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image

In [8]:
class PLDataModule(pl.LightningDataModule):
    def __init__(self,tokenizer):
        super().__init__()
        self.cfg = CFG()
        self.is_setup = False
        self.tokenizer = tokenizer
        
    def prepare_data(self):
        self.train_df = pd.read_csv(CFG.train_path).reset_index(drop=True)
        self.train_df['SMILES'] = self.train_df['smiles']
        self.aux_df = pd.read_csv(CFG.aux_path)[:300_000].reset_index(drop=True)
        self.val_df = pd.read_csv(CFG.val_df)
        
    def setup(self, stage: str):
        #self.train_df, self.val_df = train_test_split(self.train_data, test_size=self.cfg.val_split_size,random_state=self.cfg.seed)
        #self.train_df = self.train_df.reset_index(drop=True)
        #self.val_df = self.val_df.reset_index(drop=True)
        self.train_dataset = AuxTrainDataset(TrainARGS,self.train_df,self.aux_df,self.tokenizer)
        self.val_dataset = ValDataset(self.val_df,self.train_dataset.train_dataset.transform)
        self.is_setup = True
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.cfg.batch_size,
                          num_workers=self.cfg.num_workers,
                          collate_fn=bms_collate,
                          pin_memory=True,
                          shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.cfg.batch_size,
                          num_workers=self.cfg.num_workers,
                          pin_memory=True,
                          shuffle=False)

In [9]:
def char_accuracy(y_p,y):
    y_p,y = list(y_p),list(y)
    ln = min(len(y_p),len(y))
    score = 0
    for i in range(ln):
        if y_p[i] == y[i]:
            score += 1
    return score / max(len(y_p),len(y))

def accuracy(y_p,y):
    if y_p == y:
        return 1
    return 0

def correct_part(y_p):
    if Chem.MolFromSmiles(y_p) is None:
        return 0
    else:
        return 1

def tanimoto(y_p,y):
    try:
        mol1 = Chem.MolFromSmiles(y_p)
        mol2 = Chem.MolFromSmiles(y)
    
        vec_1 = AllChem.RDKFingerprint(mol1)
        vec_2 = AllChem.RDKFingerprint(mol2)
        return DataStructs.TanimotoSimilarity(vec_1,vec_2)
    except:
        return 0

In [10]:
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.preds = []
        self.targets = []

    def update(self,preds,targets):
        self.preds += preds
        self.targets += targets
    
    def calc_metrics(self):
        f = dict()
        
        f['char_acc'] = np.mean([char_accuracy(x,y) for x,y in zip(self.preds,self.targets)])
        f['acc'] = np.mean([accuracy(x,y) for x,y in zip(self.preds,self.targets)])
        f['corrent_part'] = np.mean([correct_part(x) for x in self.preds])
        f['tanimoto'] = np.mean([tanimoto(x,y) for x,y in zip(self.preds,self.targets)])
        
        return f

In [11]:
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.preds = []
        self.targets = []

    def update(self,preds,targets):
        self.preds += preds
        self.targets += targets
    
    def calc_metrics(self):
        f = dict()
        
        f['char_acc'] = np.mean([char_accuracy(x,y) for x,y in zip(self.preds,self.targets)])
        f['acc'] = np.mean([accuracy(x,y) for x,y in zip(self.preds,self.targets)])
        f['corrent_part'] = np.mean([correct_part(x) for x in self.preds])
        f['tanimoto'] = np.mean([tanimoto(x,y) for x,y in zip(self.preds,self.targets)])
        
        return f

In [12]:
class PLModule(pl.LightningModule):
    def __init__(self,encoder,decoder,tokenizer):
        super().__init__()
        self.cfg = CFG()
        self.avg_meter = AverageMeter()
        self.encoder = encoder
        self.decoder = decoder
        self.criterion = Criterion(TrainARGS,tokenizer)
        self.tokenizer = tokenizer
        
    def forward(self,images,refs):
        features, hiddens = encoder(images, refs)
        results = decoder(features, hiddens, refs)
        return sum(self.criterion(results, refs).values())
   

    def training_step(self, batch, _):
        _,images,refs = batch
        loss = self(images,refs)
        self.log_dict({'train_loss':loss.item()})
        return loss
        
    
    def fvalidation_step(self, batch, _):
        image,labels = batch
        
        features, hiddens = self.encoder(image)
        preds = self.decoder.decode(features, hiddens)
        smiles = [pred['chartok_coords']['smiles'] for pred in preds]
        node_coords = [pred['chartok_coords']['coords'] for pred in preds]
        node_symbols = [pred['chartok_coords']['symbols'] for pred in preds]
        edges = [pred['edges'] for pred in preds]#
    
        smiles_list, molblock_list, r_success = convert_graph_to_smiles(
            node_coords, node_symbols, edges, images=image.cpu().detach().numpy())
        
        self.avg_meter.update(smiles_list,labels)
    
    def predict_step(self,batch,_):
        
        logits = self.model.generate(
            batch['input_ids'],
            num_beams=1,
            max_length=128
        )
        
        logits = logits.detach().cpu().numpy()
        logits = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in logits]
        return logits
    
    def on_validation_epoch_end(self):
        f = self.avg_meter.calc_metrics()
        self.log_dict(f)
        print(f)
        self.avg_meter.reset()
            
    def configure_optimizers(self):
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.encoder.named_parameters()],
                "lr":self.cfg.lr
            },
            {
                "params": [p for n, p in self.decoder.named_parameters()],
                "lr":self.cfg.lr
            },
        ]
        return torch.optim.AdamW(optimizer_grouped_parameters,
                                 betas=self.cfg.betas,
                                 weight_decay=self.cfg.weight_decay,
                                 eps=self.cfg.eps)

In [13]:
tokenizer = get_tokenizer(TrainARGS)

In [14]:
dm = PLDataModule(tokenizer)
dm.prepare_data()
dm.setup(0)

In [15]:
ckpt_path = hf_hub_download('yujieq/MolScribe', 'swin_base_char_aux_1m.pth')

swin_base_char_aux_1m.pth:   0%|          | 0.00/1.13G [00:00<?, ?B/s]

In [16]:
chekpoint = torch.load('kagglev2.pth')

In [17]:
args = get_args(chekpoint['args'])

In [18]:
encoder = Encoder(args)
args.encoder_dim = encoder.n_features
decoder = Decoder(args,tokenizer)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [19]:
encoder = Encoder(args)
args.encoder_dim = encoder.n_features
decoder = Decoder(args,tokenizer)

In [20]:
def safe_load(module, module_states):
    def remove_prefix(state_dict):
        return {k.replace('module.', ''): v for k, v in state_dict.items()}
    module.load_state_dict(remove_prefix(module_states))
    return

In [21]:
safe_load(encoder, chekpoint['encoder'])

In [22]:
safe_load(decoder, chekpoint['decoder'])

In [23]:
model_pl = PLModule(encoder,decoder,tokenizer).cuda()

In [24]:
wandb.login(key="673ae6e9b51cc896110db5327738b993795fffad")
os.environ['WANDB_API_KEY'] = "673ae6e9b51cc896110db5327738b993795fffad"
wandb.init(project='MOLECULA',name='Molscribe_Small_ZinC_5-6m')

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Currently logged in as: [33mandrey20007[0m ([33mandrey2007[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [25]:
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
checkpoint_cb = pl.callbacks.ModelCheckpoint(
    dirpath='./outputs/',
    filename='base_model{epoch:02d}',
    monitor='tanimoto',
    mode='max',
    save_last=True
)

trainer = pl.Trainer(
    accelerator="gpu",
    precision=32,
    callbacks = [lr_monitor,checkpoint_cb],
    logger = pl.loggers.WandbLogger(),
    min_epochs=1,
    devices=[0],
    check_val_every_n_epoch=1,
    max_epochs=CFG.max_epoches
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [26]:
trainer.fit(model_pl,datamodule=dm,ckpt_path="outputs/last-v5.ckpt")

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory /notebooks/outputs exists and is not empty.
Restoring states from the checkpoi

Training: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [57]:
torch.save(model_pl.state_dict(),'molsc.pth')

In [27]:
state_dict = dict()
state_dict['encoder'] = encoder.state_dict()
state_dict['decoder'] = decoder.state_dict()
state_dict['args'] = chekpoint['args']

In [28]:
torch.save(state_dict,'kagglev2.pth')

In [40]:
torch.save(model_pl.state_dict(),'create.pth')

In [21]:
model_pl.load_state_dict(torch.load('create.pth'))

<All keys matched successfully>

In [28]:
!nvidia-smi

Sun Jan  7 21:44:44 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:05.0 Off |                    0 |
| N/A   32C    P0    80W / 400W |   3409MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [29]:
trainer.save_checkpoint('chep.ckpt')

[rank: 0] Received SIGTERM: 15


In [None]:
###### Trained
500_000 - 900_000 ChEMBL