In [1]:
# https://github.com/huggingface/notebooks/blob/master/examples/language_modeling_from_scratch.ipynb
# https://github.com/huggingface/transformers/tree/master/notebooks
# https://huggingface.co/transformers/model_doc/xlnet.html#transformers.XLNetTokenizer 
# https://colab.research.google.com/github/gmihaila/ml_things/blob/master/notebooks/pytorch/pretrain_transformers_pytorch.ipynb#scrollTo=VE2MRZZhd5uM 

In [2]:
import os
os.environ['HF_HOME'] = os.path.join(os.getcwd(), 'hf_cache')

from transformers import XLNetConfig, XLNetModel, XLNetTokenizer, XLNetTokenizerFast, XLNetLMHeadModel 
from datasets import load_dataset
from transformers import Trainer, TrainingArguments
from transformers.data.data_collator import DataCollatorForPermutationLanguageModeling
from tqdm import tqdm
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import PreTokenizer
import utils
import train_utils
import torch.optim as optim
import torch

In [3]:
base_tk = Tokenizer.from_file("models/tk-vs1000_frozen.json")
tokenizer = XLNetTokenizerFast(tokenizer_object=base_tk)
tokenizer.backend_tokenizer.pre_tokinzer = PreTokenizer.custom(utils.SmilesPreTokenizer())

In [4]:
train_raw = load_dataset('csv', data_files=['data/ogb_molhiv/train_hiv.csv'])
test_raw = load_dataset('csv', data_files=['data/ogb_molhiv/test_hiv.csv'])
valid_raw = load_dataset('csv', data_files=['data/ogb_molhiv/valid_hiv.csv'])

Using custom data configuration default-0250d14726bc71f8
Reusing dataset csv (e:\molnlp\mol-prop\hf_cache\datasets\csv\default-0250d14726bc71f8\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)
100%|██████████| 1/1 [00:00<00:00, 124.37it/s]
Using custom data configuration default-d85d5e570e46467b
Reusing dataset csv (e:\molnlp\mol-prop\hf_cache\datasets\csv\default-d85d5e570e46467b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)
100%|██████████| 1/1 [00:00<?, ?it/s]
Using custom data configuration default-e9711672e6359f2b
Reusing dataset csv (e:\molnlp\mol-prop\hf_cache\datasets\csv\default-e9711672e6359f2b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)
100%|██████████| 1/1 [00:00<00:00, 500.39it/s]


In [5]:
test_id = 52
print(train_raw['train'][test_id]['smiles'])
#input_ids =tokenizer(train_raw['train'][test_id]['smiles'], padding='max_length', max_length=100)['input_ids'] 
input_ids =tokenizer(train_raw['train'][test_id]['smiles'])['input_ids'] 
print(input_ids)
print(tokenizer.convert_ids_to_tokens(input_ids))

CCOC(=O)C(=O)C1CCCCC1=O
[282, 36, 27, 9, 24, 36, 10, 27, 9, 24, 36, 10, 27, 152, 282, 27, 219, 24, 36]
['CC', 'O', 'C', '(', '=', 'O', ')', 'C', '(', '=', 'O', ')', 'C', '1C', 'CC', 'C', 'C1', '=', 'O']


In [6]:

print(tokenizer.pad_token_id)
print(tokenizer.unk_token_id)

2
0


In [7]:
def tokenize_function_hiv(examples):
    out_dict = tokenizer(examples["smiles"])
    return out_dict

train_ds = train_raw.map(tokenize_function_hiv, batched=True, remove_columns=["smiles","HIV_active", "mol_id"])['train']
test_ds = test_raw.map(tokenize_function_hiv, batched=True, remove_columns=["smiles","HIV_active", "mol_id"])['train']
valid_ds = valid_raw.map(tokenize_function_hiv, batched=True, remove_columns=["smiles","HIV_active", "mol_id"])['train']

Loading cached processed dataset at e:\molnlp\mol-prop\hf_cache\datasets\csv\default-0250d14726bc71f8\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a\cache-2ece78330624ba18.arrow
Loading cached processed dataset at e:\molnlp\mol-prop\hf_cache\datasets\csv\default-d85d5e570e46467b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a\cache-3eba280d367ec463.arrow
Loading cached processed dataset at e:\molnlp\mol-prop\hf_cache\datasets\csv\default-e9711672e6359f2b\0.0.0\bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a\cache-68f33f311be8c23c.arrow


In [8]:
class PadPermCollator():
    def __init__(self, tokenizer, collator):
        self.tokenizer = tokenizer
        self.collator = collator
    
    def __call__(self, data_list):
        max_len = -1
        for d in data_list['input_ids']:
            max_len = max(max_len, len(d))

        if max_len % 2 != 0:
            max_len += 1

        pad_data = []
        # Required because bath padding is not compatible with permutation collator
        for i in range(len(data_list['input_ids'])):
            data_dict = {
                'attention_mask':data_list['attention_mask'][i],
                'input_ids':data_list['input_ids'][i],
                'token_type_ids':data_list['token_type_ids'][i]
            }

            pad_data.append(self.tokenizer.pad(data_dict,
                      padding='max_length',
                      max_length=max_len,
                      return_tensors='pt'))
            

        return self.collator(pad_data)


In [9]:
data_collator = PadPermCollator(tokenizer, DataCollatorForPermutationLanguageModeling(tokenizer=tokenizer))

In [10]:
n_layer = 4
model_config = XLNetConfig(
    vocab_size=tokenizer.vocab_size,
    n_layer=4,
    bi_data=True,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
model = XLNetLMHeadModel(model_config)


In [11]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-6)

In [12]:
train_utils.trainer(
    model=model,
    optimizer=optimizer,
    collator=data_collator,
    device=device,
    train_ds=train_ds,
    batch_size_train=2,
    batch_size_eval=2,
    num_epochs=10,
    model_save_dir=f"models/xlnet-hiv-pre",
    log_save_file=f"results/xlnet-hiv-pre.log",
    compute_metrics=False,
    eval_ds=test_ds,
    valid_ds=valid_ds
)

100%|██████████| 2056/2056 [00:42<00:00, 48.06it/s]


Valid Metrics: {'loss': 6.466366767883301}
Epoch: 0


  8%|▊         | 1265/16450 [01:15<15:03, 16.80it/s]


KeyboardInterrupt: 