In [1]:
# https://github.com/huggingface/notebooks/blob/master/examples/language_modeling_from_scratch.ipynb
# https://github.com/huggingface/transformers/tree/master/notebooks
# https://huggingface.co/transformers/model_doc/xlnet.html#transformers.XLNetTokenizer 
# https://colab.research.google.com/github/gmihaila/ml_things/blob/master/notebooks/pytorch/pretrain_transformers_pytorch.ipynb#scrollTo=VE2MRZZhd5uM 

In [2]:
import os
# Keep this assignment above the transformers/datasets imports: it points the
# Hugging Face cache (models, datasets) at ./hf_cache for this project.
os.environ['HF_HOME'] = os.path.join(os.getcwd(), 'hf_cache')

from transformers import XLNetConfig, XLNetModel, XLNetTokenizer, XLNetLMHeadModel 
from datasets import load_dataset
from transformers import Trainer, TrainingArguments
from transformers import set_seed
from transformers.data.data_collator import DataCollatorForPermutationLanguageModeling
from tqdm import tqdm

In [3]:
# XLNet tokenizer backed by a local SentencePiece model (presumably trained on
# SMILES strings, per the filename). Lower-casing is disabled and accents are
# kept so the input text reaches SentencePiece unmodified by case folding.
tokenizer = XLNetTokenizer(
    vocab_file='models/smiles_sp.model',
    keep_accents=True,
    do_lower_case=False,
)

In [4]:
# Load the raw corpus with the generic 'text' builder (one example per line,
# stored in a "text" column). A plain list of data files becomes the 'train' split.
dataset = load_dataset(
    path='text',
    data_files=['data/proc_zinc/all.txt'],
)

Using custom data configuration default-ce80c9ae12ee94b9
Reusing dataset text (e:\molnlp\mol-prop\hf_cache\datasets\text\default-ce80c9ae12ee94b9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)
100%|██████████| 1/1 [00:01<00:00,  1.96s/it]


In [5]:
# Length of the longest corpus line, cached on disk because a full pass over
# the training split (~6M lines) is slow.
# NOTE(review): this counts *characters*, yet it is used later as the *token*
# max_length. It is presumably a loose upper bound on content tokens, since a
# SentencePiece piece usually spans >= 1 character. The special tokens the
# tokenizer appends are not counted. TODO consider measuring token lengths
# instead, which would also shrink the padding considerably.
# NOTE(review): the cache is not invalidated if all.txt changes. Delete the
# file to force a recount.
max_len_path = 'data/proc_zinc/max_len.txt'
if os.path.exists(max_len_path):
    # `with` guarantees the handle is closed (a bare open().read() leaks it).
    with open(max_len_path) as f:
        max_len = int(f.read().strip())
else:
    train_split = dataset['train']
    max_len = max(
        (len(row['text']) for row in tqdm(train_split, total=len(train_split))),
        default=0,  # empty corpus -> 0, same as the original running maximum
    )
    with open(max_len_path, 'w') as f:
        f.write(str(max_len))

# Necessary for the data collator: DataCollatorForPermutationLanguageModeling
# rejects odd sequence lengths, so round up to the next even number.
if max_len % 2 == 1:
    max_len += 1

In [6]:
def tokenize_function(examples, **tokenizer_kwargs):
    """Tokenize a batch of corpus lines.

    Args:
        examples: a batch as passed by ``Dataset.map(batched=True)``;
            ``examples["text"]`` holds the raw lines.
        **tokenizer_kwargs: extra options forwarded verbatim to ``tokenizer``
            (e.g. padding / truncation). None are passed by default, which
            gives plain, unpadded tokenization.

    Returns:
        Whatever ``tokenizer`` returns for the batch (a mapping of model
        inputs such as ``input_ids``).
    """
    return tokenizer(examples["text"], **tokenizer_kwargs)


def tokenize_function_gpad(examples):
    """Tokenize a batch and make every sequence exactly ``max_len`` tokens long.

    ``padding='max_length'`` alone only pads. A sequence longer than
    ``max_len`` would be left as-is, so a batch could end up longer than
    ``max_len`` and possibly odd-length. That breaks the even, fixed length
    the permutation-LM data collator needs. ``truncation=True`` makes
    ``max_len`` a hard cap as well.
    """
    return tokenize_function(
        examples,
        padding='max_length',
        max_length=max_len,
        truncation=True,
    )

# Tokenize every split in batches. The raw "text" column is dropped, so only
# the tokenizer outputs remain (datasets caches the processed result on disk).
tokenized_datasets = dataset.map(
    tokenize_function_gpad,
    batched=True,
    remove_columns=["text"],
)

Loading cached processed dataset at e:\molnlp\mol-prop\hf_cache\datasets\text\default-ce80c9ae12ee94b9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-4d1ea3612354c7f5.arrow


In [7]:
# Seed before building the model. Its weights are randomly initialised here,
# while Trainer only applies `args.seed` when it is constructed (later), so
# without this the initial weights differ between runs.
# 42 matches the TrainingArguments default seed.
set_seed(42)

# XLNet LM trained from scratch on the SMILES vocabulary.
model_config = XLNetConfig(
    vocab_size=tokenizer.vocab_size,
    n_layer=12,
    # Bidirectional input pipeline (XLNet docs: usually True for pretraining).
    # NOTE(review): HF's XLNet splits the batch in two when bi_data=True, so
    # the per-device batch size presumably has to be even. TODO confirm.
    bi_data=True,
)
model = XLNetLMHeadModel(model_config)

# Builds the permutation-LM inputs/targets. It requires even sequence lengths.
data_collator = DataCollatorForPermutationLanguageModeling(tokenizer=tokenizer)

In [8]:
# Train on `max_samples` examples (max_steps batches) instead of a full epoch
# over the ~6M-line corpus. max_steps takes precedence over num_train_epochs.
max_samples = 600000
batch_size = 10
max_steps = max_samples // batch_size
# Five checkpoints over the run. Never 0, even for tiny debug runs.
save_steps = max(1, max_steps // 5)
training_args = TrainingArguments(
    f"models/xlnet-smiles-{max_samples}-{batch_size}",
    # No eval_dataset is given to the Trainer, so evaluation must stay off.
    # With "epoch", Trainer calls evaluate() when the (max_steps-truncated)
    # epoch ends. That fails for lack of an eval dataset, after the whole run.
    evaluation_strategy="no",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    num_train_epochs=1,  # overridden by max_steps
    max_steps=max_steps,
    save_steps=save_steps,
)

In [9]:
# Combine the model, training arguments, permutation-LM collator and the
# tokenized training split. No evaluation dataset is supplied.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets['train'],
)

max_steps is given, it will override any value given in num_train_epochs


In [10]:
# Run the training loop with the arguments configured above. A checkpoint is
# written to the output directory every `save_steps` optimisation steps.
# The return value is left as the cell's last expression, so it is displayed.
trainer.train()

***** Running training *****
  Num examples = 6072715
  Num Epochs = 1
  Instantaneous batch size per device = 10
  Total train batch size (w. parallel, distributed & accumulation) = 10
  Gradient Accumulation steps = 1
  Total optimization steps = 75000
  0%|          | 71/75000 [00:47<13:07:56,  1.58it/s]