In [1]:
!pip install transformers -U

Collecting transformers
  Obtaining dependency information for transformers from https://files.pythonhosted.org/packages/20/0a/739426a81f7635b422fbe6cb8d1d99d1235579a6ac8024c13d743efa6847/transformers-4.36.2-py3-none-any.whl.metadata
  Downloading transformers-4.36.2-py3-none-any.whl.metadata (126 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.8/126.8 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
Downloading transformers-4.36.2-py3-none-any.whl (8.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.2/8.2 MB[0m [31m69.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.36.0
    Uninstalling transformers-4.36.0:
      Successfully uninstalled transformers-4.36.0
Successfully installed transformers-4.36.2


In [2]:
%%capture
!wget https://storage.yandexcloud.net/ds-ods/files/content/2023/12/21/03e203ba/train_data.zip

In [3]:
%%capture
!unzip /kaggle/working/train_data.zip
!pip install rdkit

In [4]:
!pip install einops

Collecting einops
  Obtaining dependency information for einops from https://files.pythonhosted.org/packages/29/0b/2d1c0ebfd092e25935b86509a9a817159212d82aa43d7fb07eca4eeff2c2/einops-0.7.0-py3-none-any.whl.metadata
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m121.6 kB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [5]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import timm
import torch
from torch import nn
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torch.backends.cudnn as cudnn
import wandb
import pytorch_lightning as pl
from torch.nn.utils.rnn import pad_sequence,pack_padded_sequence
from transformers import (AutoProcessor,
                          AutoTokenizer,
                          VisionEncoderDecoderModel,
                          RobertaTokenizerFast,
                          TrOCRForCausalLM,
                          AutoModel,
                          TrOCRConfig,
                          ViTImageProcessor,
                          Swinv2Model,
                          Swinv2Config,
                          GPT2TokenizerFast
                         )
from sklearn.model_selection import train_test_split
from rdkit import RDLogger,Chem
from rdkit.Chem import AllChem,DataStructs
import torch.nn.functional as F
import os
from rdkit.DataStructs import TanimotoSimilarity
from Levenshtein import distance as levenshtein_distance
pl.seed_everything(56)



56

In [6]:
import torch
from torch import nn
from transformers import ViTConfig,VisionEncoderDecoderModel,AutoModelWithLMHead,AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput

import timm
from einops import rearrange

In [7]:
RDLogger.DisableLog('rdApp.*')
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [8]:
class CFG:
    wandb=False
    encoder="google/vit-base-patch16-384"
    decoder="microsoft/resnet50"
    train_path = './train.csv'
    train_folder = './train/'
    betas=(0.9, 0.999)
    img_size = 512
    max_pred_len = 128
    val_split_size = 0.2
    scheduler = None
    emb_dim = 512  
    attention_dim = 512
    freq_threshold = 2
    decoder_dim = 512
    img_size=512
    dropout = 0.4
    eps=1e-6
    num_workers = 2
    batch_size = 8
    encoder_lr = 1e-4 
    decoder_lr = 2e-4
    weight_decay = 0.01
    fine_tune_encoder = False
    max_epoches=6
    seed=56

In [9]:
class PLDataset(Dataset):
    def __init__(self, df, tokenizer,processor):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.processor = processor
        self.photo_dir = CFG.train_folder

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image = Image.open(self.photo_dir+str(self.df.iloc[idx]['id'])+'.png').convert('RGB')
        label = self.df.iloc[idx]['smiles']
        image = self.processor(image,return_tensors='pt').pixel_values
        label_enc = self.tokenizer.encode_plus(label, padding='max_length',max_length=128, truncation=True, return_tensors='pt')
        return {'image':image.squeeze(0),
                'input_ids':label_enc.input_ids.squeeze(0),
                'attention_mask':label_enc.attention_mask.squeeze(0)}


In [10]:
class PLDataModule(pl.LightningDataModule):
    def __init__(self,tokenizer,processor):
        super().__init__()
        self.cfg = CFG()
        self.is_setup = False
        self.tokenizer = tokenizer
        self.processor = processor
        
    def prepare_data(self):
        self.train_data = pd.read_csv(CFG.train_path)
        
    def setup(self, stage: str):
        self.train_df, self.val_df = train_test_split(self.train_data, test_size=self.cfg.val_split_size,random_state=self.cfg.seed)
        self.train_df = self.train_df.reset_index(drop=True)
        self.val_df = self.val_df.reset_index(drop=True)
        self.train_dataset = PLDataset(self.train_df,self.tokenizer,self.processor)
        self.val_dataset = PLDataset(self.val_df,self.tokenizer,self.processor)
        self.is_setup = True
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.cfg.batch_size,
                          num_workers=self.cfg.num_workers,
                          pin_memory=True,
                          shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.cfg.batch_size,
                          num_workers=self.cfg.num_workers,
                          pin_memory=True,
                          shuffle=False)

In [11]:
def char_accuracy(y_p,y):
    y_p,y = list(y_p),list(y)
    ln = min(len(y_p),len(y))
    score = 0
    for i in range(ln):
        if y_p[i] == y[i]:
            score += 1
    return score / max(len(y_p),len(y))

def correct_part(y_p):
    if Chem.MolFromSmiles(y_p) is None:
        return 0
    else:
        return 1

def tanimoto(y_p,y):
    try:
        mol1 = Chem.MolFromSmiles(y_p)
        mol2 = Chem.MolFromSmiles(y)
    
        vec_1 = AllChem.RDKFingerprint(mol1)
        vec_2 = AllChem.RDKFingerprint(mol2)
        return DataStructs.TanimotoSimilarity(vec_1,vec_2)
    except:
        return 0

In [12]:
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.preds = []
        self.targets = []

    def update(self,preds,targets):
        self.preds += preds
        self.targets += targets
    
    def calc_metrics(self):
        f = dict()
        
        f['char_acc'] = np.mean([char_accuracy(x,y) for x,y in zip(self.preds,self.targets)])
        f['corrent_part'] = np.mean([correct_part(x) for x in self.preds])
        f['tanimoto'] = np.mean([tanimoto(x,y) for x,y in zip(self.preds,self.targets)])
        
        return f

In [13]:
class PLModule(pl.LightningModule):
    def __init__(self,model,tokenizer):
        super().__init__()
        self.cfg = CFG()
        self.avg_meter = AverageMeter()
        self.model = model
        self.tokenizer = tokenizer
        
    def forward(self,image,input_ids=None,attention_mask=None):
        return self.model(pixel_values=image,labels=input_ids,decoder_attention_mask=attention_mask)   

    def training_step(self, batch, _):
        loss = self(**batch).loss
        self.log_dict({'train_loss':loss.item()})
        return loss
        
    
    def predict_step(self,batch,_):
        
        logits = self.model.generate(
            batch['image'],
            num_beams=1,
            max_length=128
        )
        
        logits = logits.detach().cpu().numpy()
        logits = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in logits]
        return logits
    
    def on_validation_epoch_end(self):
        f = self.avg_meter.calc_metrics()
        self.log_dict(f)
        print(f)
        self.avg_meter.reset()
            
    def configure_optimizers(self):
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.encoder.named_parameters()],
                "lr":self.cfg.encoder_lr
            },
            {
                "params": [p for n, p in self.model.decoder.named_parameters()],
                "lr": self.cfg.decoder_lr
            },
        ]
        return torch.optim.AdamW(optimizer_grouped_parameters,
                                 betas=self.cfg.betas,
                                 weight_decay=self.cfg.weight_decay,
                                 eps=self.cfg.eps)

In [14]:
tokenizer = AutoTokenizer.from_pretrained('entropy/roberta_zinc_480m')
processor = AutoProcessor.from_pretrained(CFG.encoder)
#processor = ViTImageProcessor.from_pretrained('proc_swin')

config.json:   0%|          | 0.00/700 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/40.5k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/24.4k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

In [15]:
processor.size = {"height": 384,"width": 384}

In [16]:
class TimmFeatureEncoder(nn.Module):
  def __init__(self,backbone='resnet50',hidden_size=2048,pool_dims=(8,8)):
    super().__init__()
    self.main_input_name = backbone
    self.model = timm.create_model(backbone,pretrained=True)
    self.config = ViTConfig(hidden_size=hidden_size)
    self.adaptive_pool = nn.AdaptiveAvgPool2d(pool_dims)

  def forward(self,pixel_values,**kwargs):
    features = self.model.forward_features(pixel_values)
    features = self.adaptive_pool(features)
    features = rearrange(features,'b x c h -> b (c h) x')
    return BaseModelOutput(last_hidden_state=features,
                           hidden_states=None,
                           attentions=None)

  def get_output_embeddings(self,):
    return None

encoder = TimmFeatureEncoder(backbone='maxvit_tiny_tf_384.in1k',
                         hidden_size=512)

model.safetensors:   0%|          | 0.00/124M [00:00<?, ?B/s]

In [17]:
decoder = AutoModelWithLMHead.from_pretrained("entropy/roberta_zinc_decoder")



config.json:   0%|          | 0.00/935 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/237M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [18]:
model = VisionEncoderDecoderModel(encoder=encoder,decoder=decoder)

In [19]:
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

In [20]:
dm = PLDataModule(tokenizer,processor)
dm.prepare_data()
dm.setup(0)

In [21]:
model_pl = PLModule(model,tokenizer).cuda()

In [22]:
wandb.login(key="673ae6e9b51cc896110db5327738b993795fffad")
os.environ['WANDB_API_KEY'] = "673ae6e9b51cc896110db5327738b993795fffad"
wandb.init(project='MOLECULA',name='Maxvit_GPT')

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Currently logged in as: [33mandrey20007[0m ([33mandrey2007[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [23]:
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
checkpoint_cb = pl.callbacks.ModelCheckpoint(
    dirpath='./outputs/',
    filename='base_model{epoch:02d}',
    monitor='tanimoto',
    mode='max',
    save_last=True
)

trainer = pl.Trainer(
    accelerator="gpu",
    precision=32,
    callbacks = [lr_monitor,checkpoint_cb],
    logger = pl.loggers.WandbLogger(),
    min_epochs=1,
    devices=[0],
    check_val_every_n_epoch=1,
    max_epochs=CFG.max_epoches
)

In [None]:
trainer.fit(model_pl,datamodule=dm)

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.


Training: |          | 0/? [00:00<?, ?it/s]