In [1]:
# Mount Google Drive at /content/drive (only works inside a Google Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install wandb
!pip install transformers==4.0.0
!pip install catalyst==20.11



In [3]:
# Authenticate the Weights & Biases CLI; the training code below logs runs through dl.WandbLogger.
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mdimaorekhov[0m (use `wandb login --relogin` to force relogin)


In [1]:
# Fetch the original paper's repository; its `egfr` package and bundled data file are used below.
!git clone https://github.com/lehgtrung/egfr-att

Cloning into 'egfr-att'...
remote: Enumerating objects: 5, done.
remote: Counting objects: 100% (5/5), done.
remote: Compressing objects: 100% (5/5), done.
remote: Total 421 (delta 0), reused 0 (delta 0), pack-reused 416
Receiving objects: 100% (421/421), 18.81 MiB | 3.19 MiB/s, done.
Resolving deltas: 100% (210/210), done.


In [2]:
from pathlib import Path
import json
from transformers import AutoTokenizer, AutoModel
import pandas as pd
from dataclasses import dataclass
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from catalyst import dl
from catalyst.utils import set_global_seed


# Local checkout of the original paper's repository (created by the `git clone` cell).
ORIGINAL_PAPER_PATH = Path("egfr-att")
import sys
# Put the checkout on the import path so the `egfr` package can be resolved.
sys.path.append(ORIGINAL_PAPER_PATH.as_posix())


# Keep this import after the sys.path tweak above; `egfr` is expected to be resolved from the checkout.
from egfr.dataset import EGFRDataset, train_cross_validation_split



# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# runtimes without CUDA instead of failing when the model is moved to DEVICE.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Fixed global seed, applied through catalyst's helper.
SEED = 21
set_global_seed(SEED)


# Data file shipped inside the cloned repository.
DATA_PATH = ORIGINAL_PAPER_PATH / "egfr/data/egfr_10_full_ft_pd_lines.json"

In [3]:
# Name used for the W&B group/run names and for the log directory.
EXPERIMENT_NAME = 'chemberta-with-descriptor'


@dataclass
class Config:
    """Hyper-parameters and paths shared by every cross-validation fold."""

    # --- pretrained encoder ---
    pretrained_path: str = "seyonec/PubChem10M_SMILES_BPE_450k"
    # False -> the embedding block is frozen by freeze_pretrained
    finetune_embeddings: bool = False
    # number of trailing encoder layers that stay trainable
    n_layers_to_finetune: int = 2

    # --- batching ---
    batch_size: int = 16
    # forwarded to dl.OptimizerCallback
    accumulation_steps: int = 8

    # --- training length ---
    num_epochs: int = 100
    # forwarded to dl.EarlyStoppingCallback
    patience: int = 10

    # --- learning-rate schedule ---
    scheduler: str = 'OneCycleLR'
    max_lr: float = 5e-5
    # forwarded to OneCycleLR as pct_start
    warmup_prop: float = 0.2

    # --- output ---
    logdir: str = f'drive/MyDrive/logdir_{EXPERIMENT_NAME}'


# Single shared instance used by the training cells.
config = Config()

In [None]:
def get_tokenizer_info(tokenizer):
    """Print every special token of ``tokenizer`` next to its vocabulary id.

    One line is printed per entry of ``tokenizer.special_tokens_map`` in the
    form ``<name>: <token> <id>``.

    Args:
        tokenizer: object exposing ``special_tokens_map`` together with a
            ``<name>_id`` attribute per entry (or ``<name>_ids`` for
            list-valued entries).
    """
    for key, value in tokenizer.special_tokens_map.items():
        # Single tokens expose `<key>_id`. List-valued entries such as
        # `additional_special_tokens` expose `<key>_ids` instead, so an
        # unconditional `<key>_id` lookup would raise AttributeError.
        token_id = getattr(tokenizer, f"{key}_id", None)
        if token_id is None:
            token_id = getattr(tokenizer, f"{key}_ids", None)
        print(f"{key}:", value, token_id)


def freeze_module(module: torch.nn.Module):
    """Turn off gradient tracking for every parameter of ``module`` (in place)."""
    for param in module.parameters():
        param.requires_grad_(False)


def freeze_pretrained(model: 'RobertaModel', config: Config):
    """Freeze the parts of a pretrained encoder that should not be fine-tuned.

    The embeddings are frozen unless ``config.finetune_embeddings`` is set.
    Of the encoder layers, only the last ``config.n_layers_to_finetune`` stay
    trainable; all earlier ones are frozen.

    Args:
        model: encoder exposing ``embeddings`` and ``encoder.layer``.
        config: supplies ``finetune_embeddings`` and ``n_layers_to_finetune``.
    """
    if not config.finetune_embeddings:
        freeze_module(model.embeddings)

    encoder_layers = list(model.encoder.layer)
    # Clamp at zero: if more trainable layers are requested than the encoder
    # has, the difference is negative. An equality test against the layer
    # index would then never match and every layer would end up frozen,
    # which is the opposite of what was asked for.
    n_frozen = max(0, len(encoder_layers) - config.n_layers_to_finetune)
    for layer in encoder_layers[:n_frozen]:
        freeze_module(layer)

In [None]:
class SequenceEGFRDataset(EGFRDataset):
    """``EGFRDataset`` variant that stores SMILES strings as token-id tensors.

    Adds tokenisation of the SMILES strings and a collate function that pads
    the variable-length id sequences of a batch.
    """

    def __init__(self, data, tokenizer):
        """Build the base dataset, tokenise the SMILES and tensorise the rest.

        Args:
            data: forwarded unchanged to ``EGFRDataset`` (with ``infer=True``).
            tokenizer: provides ``encode`` and ``pad_token_id``.
        """
        super().__init__(data, infer=True)
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        self.encode_smiles()

        # Replace the base-class containers with tensors of the right dtype.
        conversions = (
            ("mord_ft", torch.FloatTensor),
            ("non_mord_ft", torch.FloatTensor),
            ("label", torch.LongTensor),
        )
        for attr_name, tensor_type in conversions:
            setattr(self, attr_name, tensor_type(getattr(self, attr_name)))

    def encode_smiles(self):
        """Replace ``self.smiles`` with one ``LongTensor`` of token ids per entry."""
        encoded = []
        for smiles_string in self.smiles:
            token_ids = self.tokenizer.encode(smiles_string)
            encoded.append(torch.LongTensor(token_ids))
        self.smiles = encoded

    def collate_fn(self, batch):
        """Merge samples into a batch, padding the token-id sequences.

        Each sample is unpacked as ``(smiles, mord_ft, non_mord_ft, label)``.
        Token ids are right-padded with ``self.pad_token_id``; the remaining
        fields are stacked along a new leading dimension.
        """
        token_ids, mord_rows, non_mord_rows, label_rows = zip(*batch)
        padded_ids = pad_sequence(
            token_ids, batch_first=True, padding_value=self.pad_token_id
        )
        return (
            padded_ids,
            torch.stack(mord_rows),
            torch.stack(non_mord_rows),
            torch.stack(label_rows),
        )

    def make_loader(self, *args, **kwargs):
        """Return a ``DataLoader`` over this dataset that uses ``collate_fn``."""
        loader = DataLoader(self, *args, collate_fn=self.collate_fn, **kwargs)
        return loader


In [None]:
class ModelWithDescriptor(nn.Module):
    """Binary classifier combining a transformer's pooled output with a descriptor MLP.

    The descriptor vector goes through a three-layer MLP (512 -> 128 -> 64).
    Its output is concatenated with the transformer's ``pooler_output`` and
    projected to a single logit.
    """

    def __init__(self, transformer, dense_dim):
        """
        Args:
            transformer: encoder module whose ``config`` provides
                ``hidden_dropout_prob`` and ``hidden_size`` and whose output
                has a ``pooler_output`` attribute.
            dense_dim: size of the descriptor vector fed to the MLP.
        """
        super().__init__()
        self.transformer = transformer
        self.dropout_prob = transformer.config.hidden_dropout_prob
        # Padding id used to derive an attention mask in forward(); None
        # disables automatic masking.
        # NOTE(review): assumes the tokenizer pads with the same id as the
        # model config - confirm when switching checkpoints.
        self.pad_token_id = getattr(transformer.config, 'pad_token_id', None)
        self.dense = nn.Sequential(
            nn.Linear(dense_dim, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=self.dropout_prob),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(p=self.dropout_prob),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(p=self.dropout_prob)
        )
        self.fc_out = nn.Linear(transformer.config.hidden_size + 64, 1)

    def forward(self, smiles, descriptor, attention_mask=None):
        """Return one logit per sample.

        Args:
            smiles: padded token-id batch passed to the transformer as ``input_ids``.
            descriptor: descriptor batch passed to the MLP.
            attention_mask: optional explicit mask. When omitted it is derived
                from ``smiles`` by marking every non-padding position with 1.
        """
        if attention_mask is None and self.pad_token_id is not None:
            # Without a mask the encoder attends to padding tokens, which
            # makes a sample's output depend on the other lengths in its batch.
            attention_mask = smiles.ne(self.pad_token_id).long()
        pooler_out = self.transformer(
            input_ids=smiles, attention_mask=attention_mask
        ).pooler_output
        # F.dropout defaults to training=True; tie it to the module's mode so
        # that evaluation is deterministic.
        pooler_out = torch.nn.functional.dropout(
            pooler_out, p=self.dropout_prob, training=self.training
        )
        dense_out = self.dense(descriptor)
        return self.fc_out(torch.cat([pooler_out, dense_out], dim=-1))




In [14]:
def init_scheduler(
    optimizer: torch.optim.Optimizer,
    num_steps_per_epoch: int,
    config: Config
):
    """Create the learning-rate scheduler selected by ``config.scheduler``.

    Args:
        optimizer: optimizer the scheduler will drive.
        num_steps_per_epoch: passed to OneCycleLR as ``steps_per_epoch``.
        config: supplies ``scheduler``, ``max_lr``, ``num_epochs`` and
            ``warmup_prop``.

    Returns:
        ``(scheduler, mode)``. For ``'OneCycleLR'`` the mode is ``'batch'``;
        for any other scheduler name both elements are ``None``.
    """
    # Guard clause: only OneCycleLR is supported.
    if config.scheduler != 'OneCycleLR':
        return None, None

    one_cycle = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.max_lr,
        epochs=config.num_epochs,
        steps_per_epoch=num_steps_per_epoch,
        pct_start=config.warmup_prop,
    )
    return one_cycle, 'batch'


class EgfrWithDescriptorRunner(dl.Runner):
    """Catalyst runner that feeds token ids and ``mord`` features to the model."""

    def _handle_batch(self, batch):
        """Run the model on one batch and record loss, targets and logits.

        The batch is unpacked as ``(smiles, mord, <ignored>, labels)``.
        """
        smiles, mord, _ignored, targets = batch
        logits = self.model(smiles, mord)

        # BCE-with-logits needs float targets with the same shape as the logits.
        float_targets = targets.unsqueeze(-1).to(torch.float32)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, float_targets
        )

        self.batch_metrics['loss'] = loss
        self.input = {'targets': targets}
        self.output = {'logits': logits}


In [4]:
def experiment(train, valid, config, experiment_name, fold_idx):
    """Train and validate the descriptor-augmented model on one CV fold.

    Args:
        train: training split, passed to ``SequenceEGFRDataset``.
        valid: validation split, passed to ``SequenceEGFRDataset``.
        config: ``Config`` instance with model, optimisation and logging settings.
        experiment_name: used to build the W&B group and run names.
        fold_idx: index of the fold, appended to the W&B run name.
    """
    pretrained_model = AutoModel.from_pretrained(config.pretrained_path)
    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_path)
    freeze_pretrained(pretrained_model, config)

    train_dataset = SequenceEGFRDataset(train, tokenizer)
    valid_dataset = SequenceEGFRDataset(valid, tokenizer)

    # The MLP input size is the width of the `mord_ft` feature matrix.
    model = ModelWithDescriptor(pretrained_model, dense_dim=train_dataset.mord_ft.size(-1))

    loaders = {
        'train': train_dataset.make_loader(batch_size=config.batch_size, shuffle=True),
        'valid': valid_dataset.make_loader(batch_size=config.batch_size)
    }

    optimizer = torch.optim.Adam(model.parameters())

    callbacks = [
        dl.OptimizerCallback(accumulation_steps=config.accumulation_steps),
        dl.EarlyStoppingCallback(patience=config.patience),
        dl.WandbLogger(
            project='egfr-project',
            entity='dimaorekhov',
            # Use the argument rather than the module-level EXPERIMENT_NAME so
            # the caller actually controls how the run is named.
            group=f"{experiment_name}_CV",
            name=f"{experiment_name}_fold_{fold_idx}",
            config=config.__dict__
        ),
        dl.AUCCallback()
    ]

    scheduler, mode = init_scheduler(optimizer, len(loaders['train']), config)
    if scheduler is not None:
        callbacks.append(dl.SchedulerCallback(mode=mode))

    # be careful not to override log dir
    # NOTE(review): every fold is given the same `config.logdir`, so later
    # folds may overwrite artefacts of earlier ones - confirm this is intended.
    # parents=True lets the directory be created even when its parent is missing.
    Path(config.logdir).mkdir(parents=True, exist_ok=True)

    runner = EgfrWithDescriptorRunner(device=DEVICE)
    runner.train(
        model=model,
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=config.num_epochs,
        verbose=True,
        logdir=config.logdir,
        callbacks=callbacks
    )

    # Move the model off the training device once the fold is done.
    model.to(torch.device("cpu"))

bos_token: <s> 0
eos_token: </s> 2
unk_token: <unk> 3
sep_token: </s> 2
pad_token: <pad> 1
cls_token: <s> 0
mask_token: <mask> 4


In [None]:
# Run one training/validation experiment per split yielded by train_cross_validation_split.
for i, (train, valid) in enumerate(train_cross_validation_split(DATA_PATH.as_posix())):
    experiment(train, valid, config, EXPERIMENT_NAME, i)