In [1]:
import os
os.environ['HF_HOME'] = os.path.join(os.getcwd(), 'hf_cache')
from transformers import XLNetConfig, XLNetModel, XLNetTokenizer, XLNetLMHeadModel, BartTokenizerFast, BartForSequenceClassification, BartConfig
from datasets import load_dataset
import utils
from transformers.data.data_collator import DataCollatorWithPadding
import numpy as np
from tqdm import tqdm
import utils
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import PreTokenizer
import train_utils
import torch
import torch.optim as optim

In [4]:
base_tk = Tokenizer.from_file("models/tk-vs1000_frozen.json")
tokenizer = BartTokenizerFast(tokenizer_object=base_tk)
tokenizer.backend_tokenizer.pre_tokinzer = PreTokenizer.custom(utils.SmilesPreTokenizer())

In [3]:
#NOTE: Datasets have had the last elements removed to make them even. TODO: Figure out why huggingface can't handle odd numbers. 
train_raw = load_dataset('csv', data_files=['data/ogb_molhiv/train_hiv.csv'])
test_raw = load_dataset('csv', data_files=['data/ogb_molhiv/test_hiv.csv'])
valid_raw = load_dataset('csv', data_files=['data/ogb_molhiv/valid_hiv.csv'])

Using custom data configuration default-0250d14726bc71f8
Reusing dataset csv (e:\molnlp\mol-prop\hf_cache\datasets\csv\default-0250d14726bc71f8\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)
100%|██████████| 1/1 [00:00<00:00, 500.33it/s]
Using custom data configuration default-d85d5e570e46467b
Reusing dataset csv (e:\molnlp\mol-prop\hf_cache\datasets\csv\default-d85d5e570e46467b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)
100%|██████████| 1/1 [00:00<00:00, 40.38it/s]
Using custom data configuration default-e9711672e6359f2b
Reusing dataset csv (e:\molnlp\mol-prop\hf_cache\datasets\csv\default-e9711672e6359f2b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)
100%|██████████| 1/1 [00:00<00:00, 200.02it/s]


In [4]:
train_raw['train'][0]

{'mol_id': None,
 'smiles': 'Nc1ccc(C=Cc2ccc(N)cc2S(=O)(=O)O)c(S(=O)(=O)O)c1',
 'HIV_active': 0}

In [5]:

def tokenize_function_hiv(examples):
    out_dict = tokenizer(examples['smiles'])
    out_dict['input_ids'] = [tok_smi + [tokenizer.eos_token_id] for tok_smi in out_dict['input_ids']]
    out_dict['attention_mask'] = [am + [1] for am in out_dict['attention_mask']]

    out_dict['label'] = [int(x) for x in examples['HIV_active']]
    return out_dict

train_ds = train_raw.map(tokenize_function_hiv, batched=True, remove_columns=["smiles","HIV_active", "mol_id"])['train']
test_ds = test_raw.map(tokenize_function_hiv, batched=True, remove_columns=["smiles","HIV_active", "mol_id"])['train']
valid_ds = valid_raw.map(tokenize_function_hiv, batched=True, remove_columns=["smiles","HIV_active", "mol_id"])['train']

Loading cached processed dataset at e:\molnlp\mol-prop\hf_cache\datasets\csv\default-0250d14726bc71f8\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a\cache-db5266eb88a56f74.arrow
Loading cached processed dataset at e:\molnlp\mol-prop\hf_cache\datasets\csv\default-d85d5e570e46467b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a\cache-524e0413f889fe00.arrow
Loading cached processed dataset at e:\molnlp\mol-prop\hf_cache\datasets\csv\default-e9711672e6359f2b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a\cache-5ada4329e62d32db.arrow


In [6]:
model_path = "models/bart-pre-hiv-b3-l4/checkpoint-197406"
model = BartForSequenceClassification.from_pretrained(model_path, 
                                                      config=model_path, 
                                                      num_labels=2,
                                                      pad_token_id=tokenizer.pad_token_id, 
                                                      bos_token_id=tokenizer.bos_token_id, 
                                                      eos_token_id=tokenizer.eos_token_id,
                                                    )

"""
n_layer = 1
model_config = BartConfig(
    vocab_size=tokenizer.vocab_size,
    encoder_layers=n_layer,
    decoder_layers=n_layer,
    num_labels=2,
    pad_token_id=tokenizer.pad_token_id, 
    bos_token_id=tokenizer.bos_token_id, 
    eos_token_id=tokenizer.eos_token_id,
)
model =  BartForSequenceClassification(model_config)
"""

Some weights of the model checkpoint at models/bart-pre-hiv-b3-l4/checkpoint-197406 were not used when initializing BartForSequenceClassification: ['lm_head.weight', 'final_logits_bias']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BartForSequenceClassification were not initialized from the model checkpoint at models/bart-pre-hiv-b3-l4/checkpoint-197406 and are newly initialized: ['classification_head.dense.weight', 'classification_head.out_proj.weight', 'classification_head.dense.bias', 'classification_head.out_proj.bias']
You sh

'\nn_layer = 1\nmodel_config = BartConfig(\n    vocab_size=tokenizer.vocab_size,\n    encoder_layers=n_layer,\n    decoder_layers=n_layer,\n    num_labels=2,\n    pad_token_id=tokenizer.pad_token_id, \n    bos_token_id=tokenizer.bos_token_id, \n    eos_token_id=tokenizer.eos_token_id,\n)\nmodel =  BartForSequenceClassification(model_config)\n'

In [7]:
collator = DataCollatorWithPadding(tokenizer)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [8]:
train_utils.trainer(
    model=model,
    optimizer=optimizer,
    collator=collator,
    device=device,
    train_ds=train_ds,
    batch_size_train=8,
    batch_size_eval=8,
    num_epochs=10,
    model_save_dir="models/bart-hiv-19k",
    log_save_file="results/bart-hiv-19k.log",
    compute_metrics=True,
    eval_ds=test_ds,
    valid_ds=valid_ds
)


Epoch: 0


  3%|▎         | 140/4112 [00:23<11:10,  5.93it/s]


KeyboardInterrupt: 