In [1]:
# https://github.com/huggingface/notebooks/blob/master/examples/language_modeling_from_scratch.ipynb
# https://github.com/huggingface/transformers/tree/master/notebooks
# https://huggingface.co/transformers/model_doc/xlnet.html#transformers.XLNetTokenizer 
# https://colab.research.google.com/github/gmihaila/ml_things/blob/master/notebooks/pytorch/pretrain_transformers_pytorch.ipynb#scrollTo=VE2MRZZhd5uM 

In [2]:
import os
os.environ['HF_HOME'] = os.path.join(os.getcwd(), 'hf_cache')

from transformers import BartTokenizerFast, BartConfig, BartForConditionalGeneration, BartForSequenceClassification
from datasets import load_dataset
from transformers import Trainer, TrainingArguments
from transformers.data.data_collator import DataCollatorForLanguageModeling
from tokenizers import Tokenizer
from tqdm import tqdm
from tokenizers.pre_tokenizers import PreTokenizer
import utils

In [3]:
base_tk = Tokenizer.from_file("models/tk-vs1000.json")
tokenizer = BartTokenizerFast(tokenizer_object=base_tk)
tokenizer.backend_tokenizer.pre_tokinzer = PreTokenizer.custom(utils.SmilesPreTokenizer())

In [4]:
print(tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str('CS(=O)(=O)OCCCCOS(C)(=O)=O'))

[('CS', (0, 2)), ('(=', (2, 4)), ('O', (4, 5)), (')(=', (5, 8)), ('O', (8, 9)), (')', (9, 10)), ('OCCCCOS', (10, 17)), ('(', (17, 18)), ('C', (18, 19)), (')(=', (19, 22)), ('O', (22, 23)), (')=', (23, 25)), ('O', (25, 26))]


In [5]:
print(tokenizer('CS(=O)(=O)OCCCCOS(C)(=O)=O'))
print(base_tk.encode('CS(=O)(=O)OCCCCOS(C)(=O)=O').ids)

{'input_ids': [27, 39, 9, 24, 36, 164, 24, 36, 10, 36, 282, 282, 36, 39, 9, 27, 164, 24, 36, 10, 24, 36], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
[27, 39, 9, 24, 36, 164, 24, 36, 10, 36, 282, 282, 36, 39, 9, 27, 164, 24, 36, 10, 24, 36]


In [6]:
train_raw = load_dataset('csv', data_files=['data/ogb_molhiv/train_hiv.csv'])
test_raw = load_dataset('csv', data_files=['data/ogb_molhiv/test_hiv.csv'])
valid_raw = load_dataset('csv', data_files=['data/ogb_molhiv/valid_hiv.csv'])

Using custom data configuration default-0250d14726bc71f8
Reusing dataset csv (e:\molnlp\mol-prop\hf_cache\datasets\csv\default-0250d14726bc71f8\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)
100%|██████████| 1/1 [00:00<00:00, 92.18it/s]
Using custom data configuration default-d85d5e570e46467b
Reusing dataset csv (e:\molnlp\mol-prop\hf_cache\datasets\csv\default-d85d5e570e46467b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)
100%|██████████| 1/1 [00:00<00:00, 56.86it/s]
Using custom data configuration default-e9711672e6359f2b
Reusing dataset csv (e:\molnlp\mol-prop\hf_cache\datasets\csv\default-e9711672e6359f2b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)
100%|██████████| 1/1 [00:00<00:00, 51.53it/s]


In [7]:
print(train_raw['train'][2]['smiles'])
input_ids =tokenizer(train_raw['train'][2]['smiles'])['input_ids'] 
print(input_ids)
print(tokenizer.convert_ids_to_tokens(input_ids))

CCOP(=O)(Nc1cccc(Cl)c1)OCC
[282, 36, 37, 9, 24, 36, 164, 35, 169, 339, 339, 9, 63, 10, 169, 10, 36, 282]
['CC', 'O', 'P', '(', '=', 'O', ')(', 'N', 'c1', 'cc', 'cc', '(', 'Cl', ')', 'c1', ')', 'O', 'CC']


In [8]:
def tokenize_function_hiv(examples):
    out_dict = tokenizer(examples["smiles"])
    return out_dict

train_ds = train_raw.map(tokenize_function_hiv, batched=True, remove_columns=["smiles","HIV_active", "mol_id"])['train']
test_ds = test_raw.map(tokenize_function_hiv, batched=True, remove_columns=["smiles","HIV_active", "mol_id"])['train']
valid_ds = valid_raw.map(tokenize_function_hiv, batched=True, remove_columns=["smiles","HIV_active", "mol_id"])['train']

Loading cached processed dataset at e:\molnlp\mol-prop\hf_cache\datasets\csv\default-0250d14726bc71f8\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a\cache-dc8bc250afca4aba.arrow
Loading cached processed dataset at e:\molnlp\mol-prop\hf_cache\datasets\csv\default-d85d5e570e46467b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a\cache-468f9bcc2904cb0a.arrow
Loading cached processed dataset at e:\molnlp\mol-prop\hf_cache\datasets\csv\default-e9711672e6359f2b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a\cache-a28da0cac3fa6b90.arrow


In [9]:
n_layer = 4
model_config = BartConfig(
    vocab_size=tokenizer.vocab_size,
    encoder_layers=n_layer,
    decoder_layers=n_layer
)
model = BartForConditionalGeneration(model_config)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True)

In [10]:
batch_size = 3
training_args = TrainingArguments(
    f"models/bart-pre-hiv-b{batch_size}-l{n_layer}",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=50,
    save_strategy='epoch'
)

In [11]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=data_collator
)

In [None]:
trainer.train()