In [3]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import timm
import torch
from torch import nn
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torch.backends.cudnn as cudnn
import wandb
import pytorch_lightning as pl
from torch.nn.utils.rnn import pad_sequence,pack_padded_sequence
from transformers import (AutoProcessor,
                          AutoTokenizer,
                          VisionEncoderDecoderModel,
                          RobertaTokenizerFast,
                          TrOCRForCausalLM,
                          AutoModel,
                          TrOCRConfig,
                          ViTModel,
                          ViTConfig,
                          ViTImageProcessor
                         )
from sklearn.model_selection import train_test_split
from rdkit import RDLogger,Chem
from rdkit.Chem import AllChem,DataStructs
import torch.nn.functional as F
import os
from rdkit.DataStructs import TanimotoSimilarity
from Levenshtein import distance as levenshtein_distance
pl.seed_everything(56)

Seed set to 56


56

In [4]:
RDLogger.DisableLog('rdApp.*')
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [5]:
class CFG:
    wandb=False
    encoder="google/vit-base-patch16-384"
    decoder="entropy/roberta_zinc_480m"
    train_path = './train.csv'#'./all_ChEMBLSmiles.csv'
    train_folder = './train/'
    betas=(0.9, 0.999)
    img_size = 512
    max_pred_len = 128
    val_split_size = 0.2
    scheduler = None
    emb_dim = 512
    attention_dim = 512
    freq_threshold = 2
    decoder_dim = 512
    img_size=512
    dropout = 0.4
    eps=1e-6
    num_workers = 12
    batch_size = 64
    encoder_lr = 4e-4
    decoder_lr = 5e-4
    weight_decay = 0.01
    fine_tune_encoder = False
    max_epoches=6
    seed=56

In [6]:
class PLDataset(Dataset):
    def __init__(self, df, tokenizer,processor):
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        label = self.df.iloc[idx]['smiles']
        image = self._load_image(f'./train/{idx}.png')
        label_enc = self.tokenizer.encode_plus(label, padding='max_length',max_length=64, truncation=True, return_tensors='pt')
        return {'image':image,
                'input_ids':label_enc.input_ids.squeeze(0),
               'attention_mask':label_enc.attention_mask.squeeze(0)}
    
    def _load_image(self,path):
        img = np.array(Image.open(path).resize((384,384)))
        img = torch.from_numpy(img).unsqueeze(0) / 255
        return img

In [7]:
class PLDataModule(pl.LightningDataModule):
    def __init__(self,tokenizer,processor):
        super().__init__()
        self.cfg = CFG()
        self.is_setup = False
        self.tokenizer = tokenizer
        self.processor = processor
        
    def prepare_data(self):
        self.train_data = pd.read_csv(CFG.train_path)
        
    def setup(self, stage: str):
        self.train_df, self.val_df = train_test_split(self.train_data, test_size=self.cfg.val_split_size,random_state=self.cfg.seed)
        self.train_df = self.train_df.reset_index(drop=True)
        self.val_df = self.val_df.reset_index(drop=True)
        self.train_dataset = PLDataset(self.train_df,self.tokenizer,self.processor)
        self.val_dataset = PLDataset(self.val_df,self.tokenizer,self.processor)
        self.is_setup = True
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.cfg.batch_size,
                          num_workers=self.cfg.num_workers,
                          pin_memory=True,
                          shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.cfg.batch_size,
                          num_workers=self.cfg.num_workers,
                          pin_memory=True,
                          shuffle=False)

In [8]:
def char_accuracy(y_p,y):
    y_p,y = list(y_p),list(y)
    ln = min(len(y_p),len(y))
    score = 0
    for i in range(ln):
        if y_p[i] == y[i]:
            score += 1
    return score / max(len(y_p),len(y))

def correct_part(y_p):
    if Chem.MolFromSmiles(y_p) is None:
        return 0
    else:
        return 1

def tanimoto(y_p,y):
    try:
        mol1 = Chem.MolFromSmiles(y_p)
        mol2 = Chem.MolFromSmiles(y)
    
        vec_1 = AllChem.RDKFingerprint(mol1)
        vec_2 = AllChem.RDKFingerprint(mol2)
        return DataStructs.TanimotoSimilarity(vec_1,vec_2)
    except:
        return 0

In [9]:
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.preds = []
        self.targets = []

    def update(self,preds,targets):
        self.preds += preds
        self.targets += targets
    
    def calc_metrics(self):
        f = dict()
        
        f['char_acc'] = np.mean([char_accuracy(x,y) for x,y in zip(self.preds,self.targets)])
        f['corrent_part'] = np.mean([correct_part(x) for x in self.preds])
        f['tanimoto'] = np.mean([tanimoto(x,y) for x,y in zip(self.preds,self.targets)])
        
        return f

In [10]:
class PLModule(pl.LightningModule):
    def __init__(self,model,tokenizer):
        super().__init__()
        self.cfg = CFG()
        self.avg_meter = AverageMeter()
        self.model = model
        self.tokenizer = tokenizer
        
    def forward(self,image,input_ids=None,attention_mask=None):
        return self.model(pixel_values=image,labels=input_ids,decoder_attention_mask=attention_mask)   

    def training_step(self, batch, _):
        loss = self(**batch).loss
        self.log_dict({'train_loss':loss.item()})
        return loss
        
    def validation_step(self, batch, _):
        labels = batch['input_ids'].detach().cpu().numpy()
        labels = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in labels]
        
        logits = self.model.generate(
            batch['image'],
            num_beams=1,
            max_length=128
        )
        
        logits = logits.detach().cpu().numpy()
        logits = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in logits]
        self.avg_meter.update(logits,labels)
    
    def predict_step(self,batch,_):
        
        logits = self.model.generate(
            batch['image'],
            num_beams=1,
            max_length=128,
        )
        
        logits = logits.detach().cpu().numpy()
        logits = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in logits]
        return logits
    
    def on_validation_epoch_end(self):
        f = self.avg_meter.calc_metrics()
        self.log_dict(f)
        print(f)
        self.avg_meter.reset()
            
    def configure_optimizers(self):
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.encoder.named_parameters()],
                "lr":self.cfg.encoder_lr
            },
            {
                "params": [p for n, p in self.model.decoder.named_parameters()],
                "lr": self.cfg.decoder_lr
            },
        ]
        return torch.optim.AdamW(optimizer_grouped_parameters,
                                 betas=self.cfg.betas,
                                 weight_decay=self.cfg.weight_decay,
                                 eps=self.cfg.eps)

In [11]:
tokenizer = RobertaTokenizerFast.from_pretrained(CFG.decoder)
processor = AutoProcessor.from_pretrained(CFG.encoder)
#processor = ViTImageProcessor.from_pretrained('proc_swin')

In [12]:
encoder = ViTModel(ViTConfig(hidden_size=384,
                             hidden_act='gelu',
                             image_size=384,
                             num_attention_heads=6,
                             num_hidden_layers=12,
                             num_channels=1,
                             intermediate_size=384 * 4,
                             patch_size=16))

decoder = TrOCRForCausalLM(TrOCRConfig(vocab_size=len(tokenizer),
                                       d_model=256,
                                       decoder_attention_heads=8,
                                       decoder_ffn_dim=1024,
                                       decoder_layers=6,
                                       activation_function='gelu',
                                       max_position_embeddings=384,
                                       dropout=0.2))

In [13]:
model = VisionEncoderDecoderModel(encoder=encoder,decoder=decoder)

In [14]:
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

In [15]:
dm = PLDataModule(tokenizer,processor)
dm.prepare_data()
dm.setup(0)

In [16]:
model_pl = PLModule(model,tokenizer).cuda()

In [17]:
wandb.login(key="673ae6e9b51cc896110db5327738b993795fffad")
os.environ['WANDB_API_KEY'] = "673ae6e9b51cc896110db5327738b993795fffad"
wandb.init(project='MOLECULA',name='TrOCR_large')

[34m[1mwandb[0m: Currently logged in as: [33mandrey20007[0m ([33mandrey2007[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [18]:
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
checkpoint_cb = pl.callbacks.ModelCheckpoint(
    dirpath='./outputs/',
    filename='base_model{epoch:02d}',
    monitor='tanimoto',
    mode='max',
    save_last=True
)

trainer = pl.Trainer(
    accelerator="gpu",
    precision=32,
    callbacks = [lr_monitor,checkpoint_cb],
    logger = pl.loggers.WandbLogger(),
    min_epochs=1,
    devices=[0],
    check_val_every_n_epoch=1,
    max_epochs=CFG.max_epoches
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [20]:
trainer.validate(model_pl,datamodule=dm,ckpt_path="outputs/last-v1.ckpt")

Restoring states from the checkpoint path at outputs/last-v1.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at outputs/last-v1.ckpt


Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [22]:
model_pl.avg_meter.preds[

['O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1ccc(C(=O)Nc2ccc(C(=O)O)cc2)cc1',
 'O=C(O)c1c

In [50]:
predsv2= np.load('preds16k.npy')

In [27]:
from sklearn.metrics import accuracy_score

In [30]:
def prepare_preds(preds):
    return [p.replace('/','') for p in preds]

In [54]:
score = 0
for p1,p2,y in zip(predsv2,np.concatenate(preds),dm.val_df['smiles']):
    if p1 == y or p2 == y:
        score += 1;

In [57]:
score / len(predsv2)

0.97625

In [52]:
accuracy_score(predsv2,np.concatenate(preds),dm.val_df['smiles'])

0.9466

In [47]:
model_pl.load_state_dict(torch.load('trocr_base_finetunev3_1p.pt'))

<All keys matched successfully>

In [48]:
preds = trainer.predict(model_pl,dm.val_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [19]:
#model_pl.load_state_dict(torch.load('trocr_base_finetuned_1ep.pt'))

<All keys matched successfully>

In [19]:
torch.save(model_pl.state_dict(),'trocr_base_finetunev4_1p.pt')

In [23]:
preds = trainer.predict(model_pl,dm.val_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [17]:
processor = ViTImageProcessor.from_pretrained('./processor_vit')

In [24]:
encoder.save_pretrained('encoder_vit')

In [25]:
decoder.save_pretrained('decoder_vit')

In [18]:
AutoModel.from_pretrained('./encoder_vit')

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
      

In [None]:
!nvidia-smi