In [1]:
from pathlib import Path
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
from dataclasses import dataclass
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from catalyst import dl
from catalyst.utils import set_global_seed


ORIGINAL_PAPER_PATH = Path.home() / "egfr-att"
import sys
sys.path.append(ORIGINAL_PAPER_PATH.as_posix())


from egfr.dataset import EGFRDataset


DEVICE = torch.device('cpu')


SEED = 42
set_global_seed(SEED)


DATA_PATH = ORIGINAL_PAPER_PATH / "egfr/data/egfr_10_full_ft_pd_lines.json"

In [2]:
EXPERIMENT_NAME = 'chemberta-no-descriptor'


@dataclass
class Config:

    use_descriptors: bool = False

    pretrained_path: str = "seyonec/PubChem10M_SMILES_BPE_450k"
    finetune_embeddings: bool = False
    n_layers_to_finetune: int = 2

    batch_size: int = 4
    accumulation_steps: int = 1
        
    num_epochs: int = 50
    patience: int = 5

    scheduler: str = 'OneCycleLR'
    max_lr: float = 0.001
    warmup_prop: float = 0.3

    logdir: str = 'checkpoints/'


config = Config()

In [3]:
model = AutoModelForSequenceClassification.from_pretrained(config.pretrained_path)
tokenizer = AutoTokenizer.from_pretrained(config.pretrained_path)


def get_tokenizer_info(tokenizer):
    for key, value in tokenizer.special_tokens_map.items():
        print(f"{key}:", value, getattr(tokenizer, f"{key}_id"))

get_tokenizer_info(tokenizer)


PAD_TOKEN_ID = tokenizer.pad_token_id

Some weights of the model checkpoint at seyonec/PubChem10M_SMILES_BPE_450k were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at seyonec/PubChem10M_SMILES_BPE_45

bos_token: <s> 0
eos_token: </s> 2
unk_token: <unk> 3
sep_token: </s> 2
pad_token: <pad> 1
cls_token: <s> 0
mask_token: <mask> 4


In [4]:
def freeze_module(module: torch.nn.Module):
    for p in module.parameters():
        p.requires_grad = False


def freeze_pretrained(model: 'RobertaForSequenceClassification', config: Config):
    if not config.finetune_embeddings:
        freeze_module(model.roberta.embeddings)

    n_layers = len(model.roberta.encoder.layer)
    layer_idx_to_stop = n_layers - config.n_layers_to_finetune
    for i, layer in enumerate(model.roberta.encoder.layer):
        if i == layer_idx_to_stop:
            break
        freeze_module(layer)


freeze_pretrained(model, config)

In [5]:
class SequenceEGFRDataset(EGFRDataset):

    def __init__(self, data, tokenizer):
        super().__init__(data, infer=True)
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        self.encode_smiles()

        self.mord_ft = torch.FloatTensor(self.mord_ft)
        self.non_mord_ft = torch.FloatTensor(self.non_mord_ft)
        self.label = torch.LongTensor(self.label)

    def encode_smiles(self):
        self.smiles = [
            torch.LongTensor(self.tokenizer.encode(s))
            for s in self.smiles
        ]

    def collate_fn(self, batch):
        smiles, mord_ft, non_mord_ft, labels = zip(*batch)
        smiles = pad_sequence(
            smiles, batch_first=True, padding_value=self.pad_token_id
        )
        mord_ft = torch.stack(mord_ft)
        non_mord_ft = torch.stack(non_mord_ft)
        labels = torch.stack(labels)
        return smiles, mord_ft, non_mord_ft, labels

    def make_loader(self, *args, **kwargs):
        return DataLoader(self, *args, collate_fn=self.collate_fn, **kwargs)


In [6]:
train, valid = train_test_split(
    pd.read_json(DATA_PATH, lines=True), test_size=0.2, random_state=42 #  42 hard code is from original paper code 
)


train_dataset = SequenceEGFRDataset(train, tokenizer)
valid_dataset = SequenceEGFRDataset(valid, tokenizer)

In [7]:
print('Max train smiles length:', max(len(s) for s in train_dataset.smiles))
print('Max valid smiles length:', max(len(s) for s in valid_dataset.smiles))

Max train smiles length: 100
Max valid smiles length: 93


In [8]:
loaders = {
    'train': train_dataset.make_loader(batch_size=config.batch_size, shuffle=True),
    'valid': valid_dataset.make_loader(batch_size=config.batch_size)
}

In [9]:
def init_scheduler(
    optimizer: torch.optim.Optimizer,
    num_steps_per_epoch: int,
    config: Config
):

    if config.scheduler == 'OneCycleLR':
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=config.max_lr,
            epochs=config.num_epochs,
            steps_per_epoch=num_steps_per_epoch,
            pct_start=config.warmup_prop
        )
        return scheduler, 'batch'

    return None, None


In [10]:
optimizer = torch.optim.Adam(model.parameters())

callbacks = [
    dl.OptimizerCallback(accumulation_steps=config.accumulation_steps),
    dl.EarlyStoppingCallback(patience=config.patience),
    dl.WandbLogger(
        project='egfr-project',
        entity='dimaorekhov',
        group='chemberta-no-descriptor',
        name=EXPERIMENT_NAME,
        config=config.__dict__
    ),
    dl.AUCCallback()
]

scheduler, mode = init_scheduler(optimizer, len(loaders['train']), config)
if scheduler is not None:
    callbacks.append(dl.SchedulerCallback(mode=mode))

In [11]:
class EgfrNoDescriptorRunner(dl.Runner):

    def _handle_batch(self, batch):
        smiles, _, _, labels = batch
        out = self.model(input_ids=smiles, labels=labels)
        self.batch_metrics['loss'] = out.loss
        self.input = {'targets': labels}
        self.output = {'logits': out.logits}


In [12]:
# be careful not to override log dir
Path(config.logdir).mkdir(exist_ok=True)

In [13]:
runner = EgfrNoDescriptorRunner(device=DEVICE)
runner.train(
    model=model, 
    loaders=loaders,
    optimizer=optimizer,
    scheduler=scheduler,        
    num_epochs=config.num_epochs,
    verbose=True,
    logdir=config.logdir,
    callbacks=callbacks,
    check=True
)

[34m[1mwandb[0m: Currently logged in as: [33mdimaorekhov[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.12 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


1/50 * Epoch (train):   0% 2/699 [00:00<02:33,  4.54it/s, loss=0.618, lr=4.000e-05, momentum=0.950]


To get the last learning rate computed by the scheduler, please use `get_last_lr()`.



1/50 * Epoch (train):   1% 4/699 [00:00<02:05,  5.53it/s, loss=0.459, lr=4.000e-05, momentum=0.950]
1/50 * Epoch (valid):   2% 4/175 [00:00<00:11, 14.46it/s, loss=0.645]
[2020-12-11 22:41:42,130] 
1/50 * Epoch 1 (_base): lr=4.000e-05 | momentum=0.9500
1/50 * Epoch 1 (train): auc/class_00=0.4643 | auc/class_01=0.3929 | auc/mean=0.4286 | loss=0.5796 | lr=4.000e-05 | momentum=0.9500
1/50 * Epoch 1 (valid): auc/class_00=0.5333 | auc/class_01=0.2667 | auc/mean=0.4000 | loss=0.2813
2/50 * Epoch (train):   1% 4/699 [00:00<01:49,  6.35it/s, loss=0.089, lr=4.000e-05, momentum=0.950]
2/50 * Epoch (valid):   2% 4/175 [00:00<00:11, 15.01it/s, loss=0.812]
[2020-12-11 22:41:44,165] 
2/50 * Epoch 2 (_base): lr=4.000e-05 | momentum=0.9500
2/50 * Epoch 2 (train): auc/class_00=0.000e+00 | auc/class_01=0.1333 | auc/mean=0.0667 | loss=0.3057 | lr=4.000e-05 | momentum=0.9500
2/50 * Epoch 2 (valid): auc/class_00=0.4667 | auc/class_01=0.2000 | auc/mean=0.3333 | loss=0.2489


VBox(children=(Label(value=' 0.01MB of 0.01MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
auc/class_00/train,0.0
auc/class_01/train,0.13333
auc/mean/train,0.06667
loss/train,0.30573
lr/train,4e-05
momentum/train,0.95
auc/class_00/valid,0.46667
auc/class_01/valid,0.2
auc/mean/valid,0.33333
loss/valid,0.24894


0,1
auc/class_00/train,█▁
auc/class_01/train,█▁
auc/mean/train,█▁
loss/train,█▁
lr/train,▁█
momentum/train,█▁
auc/class_00/valid,█▁
auc/class_01/valid,█▁
auc/mean/valid,█▁
loss/valid,█▁


Top best models:
checkpoints/checkpoints/train.2.pth	0.2489
