In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install wandb
!pip install transformers==4.0.0
!pip install catalyst==20.11



In [3]:
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mdimaorekhov[0m (use `wandb login --relogin` to force relogin)


In [4]:
!git clone https://github.com/lehgtrung/egfr-att

fatal: destination path 'egfr-att' already exists and is not an empty directory.


In [5]:
from pathlib import Path
import json
from transformers import AutoTokenizer, BertModel, BertConfig
import pandas as pd
from dataclasses import dataclass
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from catalyst import dl
from catalyst.utils import set_global_seed


ORIGINAL_PAPER_PATH = Path("egfr-att")
import sys
sys.path.append(ORIGINAL_PAPER_PATH.as_posix())


from egfr.dataset import EGFRDataset


DEVICE = torch.device('cuda')


SEED = 21
set_global_seed(SEED)


DATA_PATH = ORIGINAL_PAPER_PATH / "egfr/data/egfr_10_full_ft_pd_lines.json"

In [6]:
EXPERIMENT_NAME = 'transformer-with-descriptor'


@dataclass
class Config:

    tokenizer_path: str = "seyonec/PubChem10M_SMILES_BPE_450k"

    hidden_size: int = 512
    num_hidden_layers: int = 2
    num_attention_heads: int = 8
    intermediate_size: int = 2048
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1

    batch_size: int = 16
    accumulation_steps: int = 8
  
    num_epochs: int = 100
    patience: int = 10

    scheduler: str = 'OneCycleLR'
    max_lr: float = 0.0005
    warmup_prop: float = 0.2

    logdir: str = f'drive/MyDrive/logdir_{EXPERIMENT_NAME}'


config = Config()

In [7]:
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)


def get_tokenizer_info(tokenizer):
    for key, value in tokenizer.special_tokens_map.items():
        print(f"{key}:", value, getattr(tokenizer, f"{key}_id"))

get_tokenizer_info(tokenizer)


PAD_TOKEN_ID = tokenizer.pad_token_id

bos_token: <s> 0
eos_token: </s> 2
unk_token: <unk> 3
sep_token: </s> 2
pad_token: <pad> 1
cls_token: <s> 0
mask_token: <mask> 4


In [8]:
model_config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=config.hidden_size,
    num_hidden_layers=config.num_hidden_layers,
    num_attention_heads=config.num_attention_heads,
    intermediate_size=config.intermediate_size,
    hidden_dropout_prob=config.hidden_dropout_prob,
    attention_probs_dropout_prob=config.attention_probs_dropout_prob,
    pad_token_id=PAD_TOKEN_ID
)
transformer = BertModel(config=model_config)

In [9]:
class SequenceEGFRDataset(EGFRDataset):

    def __init__(self, data, tokenizer):
        super().__init__(data, infer=True)
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        self.encode_smiles()

        self.mord_ft = torch.FloatTensor(self.mord_ft)
        self.non_mord_ft = torch.FloatTensor(self.non_mord_ft)
        self.label = torch.LongTensor(self.label)

    def encode_smiles(self):
        self.smiles = [
            torch.LongTensor(self.tokenizer.encode(s))
            for s in self.smiles
        ]

    def collate_fn(self, batch):
        smiles, mord_ft, non_mord_ft, labels = zip(*batch)
        smiles = pad_sequence(
            smiles, batch_first=True, padding_value=self.pad_token_id
        )
        mord_ft = torch.stack(mord_ft)
        non_mord_ft = torch.stack(non_mord_ft)
        labels = torch.stack(labels)
        return smiles, mord_ft, non_mord_ft, labels

    def make_loader(self, *args, **kwargs):
        return DataLoader(self, *args, collate_fn=self.collate_fn, **kwargs)


In [10]:
train, valid = train_test_split(
    pd.read_json(DATA_PATH, lines=True), test_size=0.2, random_state=42 #  42 hard code is from original paper code 
)


train_dataset = SequenceEGFRDataset(train, tokenizer)
valid_dataset = SequenceEGFRDataset(valid, tokenizer)

In [11]:
print('Max train smiles length:', max(len(s) for s in train_dataset.smiles))
print('Max valid smiles length:', max(len(s) for s in valid_dataset.smiles))

Max train smiles length: 100
Max valid smiles length: 93


In [12]:
class ModelWithDescriptor(nn.Module):

    def __init__(self, transformer, dense_dim):
      super().__init__()
      self.transformer = transformer
      self.dropout_prob = transformer.config.hidden_dropout_prob
      self.dense = nn.Sequential(
          nn.Linear(dense_dim, 512),
          nn.ReLU(),
          nn.BatchNorm1d(512),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(512, 128),
          nn.ReLU(),
          nn.BatchNorm1d(128),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(128, 64),
          nn.ReLU(),
          nn.BatchNorm1d(64),
          nn.Dropout(p=self.dropout_prob)
      )
      self.fc_out = nn.Linear(transformer.config.hidden_size + 64, 1)

    def forward(self, smiles, descriptor):
        pooler_out = self.transformer(input_ids=smiles).pooler_output
        pooler_out = torch.nn.functional.dropout(pooler_out, p=self.dropout_prob)
        dense_out = self.dense(descriptor)
        return self.fc_out(torch.cat([pooler_out, dense_out], dim=-1))


model = ModelWithDescriptor(transformer, dense_dim=train_dataset.mord_ft.size(-1))

In [13]:
loaders = {
    'train': train_dataset.make_loader(batch_size=config.batch_size, shuffle=True),
    'valid': valid_dataset.make_loader(batch_size=config.batch_size)
}

In [14]:
def init_scheduler(
    optimizer: torch.optim.Optimizer,
    num_steps_per_epoch: int,
    config: Config
):

    if config.scheduler == 'OneCycleLR':
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=config.max_lr,
            epochs=config.num_epochs,
            steps_per_epoch=num_steps_per_epoch,
            pct_start=config.warmup_prop
        )
        return scheduler, 'batch'

    return None, None


In [15]:
optimizer = torch.optim.Adam(model.parameters())

callbacks = [
    dl.OptimizerCallback(accumulation_steps=config.accumulation_steps),
    dl.EarlyStoppingCallback(patience=config.patience),
    dl.WandbLogger(
        project='egfr-project',
        entity='dimaorekhov',
        group='transformer-with-descriptor',
        name=EXPERIMENT_NAME,
        config=config.__dict__
    ),
    dl.AUCCallback()
]

scheduler, mode = init_scheduler(optimizer, len(loaders['train']), config)
if scheduler is not None:
    callbacks.append(dl.SchedulerCallback(mode=mode))

In [16]:
class EgfrWithDescriptorRunner(dl.Runner):

    def _handle_batch(self, batch):
        smiles, mord, _, labels = batch
        out = self.model(smiles, mord)
        self.batch_metrics['loss'] = torch.nn.functional.binary_cross_entropy_with_logits(
            out, labels.unsqueeze(-1).to(torch.float32)
        )
        self.input = {'targets': labels}
        self.output = {'logits': out}


In [17]:
# be careful not to override log dir
Path(config.logdir).mkdir(exist_ok=True)

In [18]:
runner = EgfrWithDescriptorRunner(device=DEVICE)
runner.train(
    model=model,
    loaders=loaders,
    optimizer=optimizer,
    scheduler=scheduler,        
    num_epochs=config.num_epochs,
    verbose=True,
    logdir=config.logdir,
    callbacks=callbacks
)

[34m[1mwandb[0m: Currently logged in as: [33mdimaorekhov[0m (use `wandb login --relogin` to force relogin)


1/100 * Epoch (train):   3% 6/175 [00:00<00:07, 21.55it/s, loss=0.669, lr=2.000e-05, momentum=0.950]


Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate


To get the last learning rate computed by the scheduler, please use `get_last_lr()`.



1/100 * Epoch (train): 100% 175/175 [00:03<00:00, 46.69it/s, loss=0.398, lr=2.296e-05, momentum=0.949]
1/100 * Epoch (valid): 100% 44/44 [00:00<00:00, 81.58it/s, loss=0.313]
[2020-12-12 00:13:24,331] 
1/100 * Epoch 1 (_base): lr=2.296e-05 | momentum=0.9494
1/100 * Epoch 1 (train): auc/class_00=0.5365 | auc/mean=0.5365 | loss=0.4547 | lr=2.099e-05 | momentum=0.9498
1/100 * Epoch 1 (valid): auc/class_00=0.6990 | auc/mean=0.6990 | loss=0.4263
2/100 * Epoch (train): 100% 175/175 [00:03<00:00, 46.33it/s, loss=0.346, lr=3.175e-05, momentum=0.948]
2/100 * Epoch (valid): 100% 44/44 [00:00<00:00, 113.33it/s, loss=0.336]
[2020-12-12 00:13:31,038] 
2/100 * Epoch 2 (_base): lr=3.175e-05 | momentum=0.9476
2/100 * Epoch 2 (train): auc/class_00=0.6654 | auc/mean=0.6654 | loss=0.3978 | lr=2.689e-05 | momentum=0.9486
2/100 * Epoch 2 (valid): auc/class_00=0.7769 | auc/mean=0.7769 | loss=0.3827
3/100 * Epoch (train): 100% 175/175 [00:04<00:00, 41.02it/s, loss=0.262, lr=4.617e-05, momentum=0.945]
3/100 * 

VBox(children=(Label(value=' 0.01MB of 0.01MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
auc/class_00/train,0.98598
auc/mean/train,0.98598
loss/train,0.1003
lr/train,0.0005
momentum/train,0.8502
auc/class_00/valid,0.90253
auc/mean/valid,0.90253
loss/valid,0.2739
lr/_base,0.0005
momentum/_base,0.85


0,1
auc/class_00/train,▁▃▅▆▆▇▇▇▇███████████
auc/mean/train,▁▃▅▆▆▇▇▇▇███████████
loss/train,█▇▆▆▅▅▄▄▃▂▂▂▂▂▂▁▁▁▁▁
lr/train,▁▁▁▂▂▂▃▃▄▄▅▅▆▆▇▇▇███
momentum/train,███▇▇▇▆▆▅▅▄▄▃▃▂▂▂▁▁▁
auc/class_00/valid,▁▃▅▆▇▇█████████▇███▇
auc/mean/valid,▁▃▅▆▇▇█████████▇███▇
loss/valid,█▇▆▅▄▄▂▃▃▁▂▂▁▃▂▃▂▃▃▃
lr/_base,▁▁▁▂▂▂▃▃▄▄▅▆▆▇▇▇████
momentum/_base,███▇▇▇▆▆▅▅▄▃▃▂▂▂▁▁▁▁


Top best models:
drive/MyDrive/logdir_transformer-with-descriptor/checkpoints/train.10.pth	0.2164
