In [94]:
import util

from datasets import Dataset
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, GPT2LMHeadModel
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

In [97]:
DATASET_SIZE: int = 10_000
N_POSITIONS: int = 512
N_LAYER = 6 # Number of transformer layers
N_HEAD = 8 # Number of multi-head attention heads
N_EMBD = 256 # Embedding size

In [98]:
def read(path: str, size: int) -> list[str]:
    """Reads SMILES strings from PubChem."""

    data = []
    with open(path) as file, tqdm(total=size, desc=f"Reading {path}...") as pbar:
        while (line := file.readline()) and len(data) < size:
            smiles = line.split()[1]
            smiles = util.canonicalize_smiles(smiles)
            bitstr = " ".join(list(util.maccs_fingerprint(smiles).ToBitString()))
            prompt = f"{bitstr}\n{smiles}"
            data.append(prompt)
            pbar.update(1)
    assert len(data) == size

    return data

dataset = Dataset.from_dict({"prompt": read("CID-SMILES", DATASET_SIZE)})

Reading CID-SMILES...: 100%|███████████████████████████████| 10000/10000 [00:21<00:00, 455.01it/s]


In [65]:
# READ THIS: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
# AND THIS: https://huggingface.co/blog/juancopi81/using-hugging-face-to-train-a-gpt-2-model-for-musi

In [99]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

def tokenize(elem: dict[str, str]) -> dict[str, list[int]]:
    out = tokenizer(
        elem["prompt"],
        truncation=True,
        max_length=N_POSITIONS,
        return_overflowing_tokens=True,
        return_length=True,
    )

    return {"input_ids": out["input_ids"]}

tok_dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
tok_dataset

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids'],
    num_rows: 10000
})

In [100]:
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig

config = AutoConfig.from_pretrained(
    "gpt2",
    n_positions=N_POSITIONS,
    n_embd=N_EMBD,
    n_head=N_HEAD,
    n_layer=N_LAYER,
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

model = GPT2LMHeadModel(config)
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad_)
print(f"GPT-2 size: {model_size/1000**2:.1f}M parameters")

GPT-2 size: 17.7M parameters


In [101]:
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [102]:
args = TrainingArguments(
    output_dir="maccs_models",
    per_device_train_batch_size=32,
    logging_steps=100,
    gradient_accumulation_steps=8,
    num_train_epochs=10,
    weight_decay=0.1,
    warmup_steps=1_000,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=5_000,
    fp16=True,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tok_dataset["input_ids"],
)

In [103]:
trainer.train()

Step,Training Loss


KeyboardInterrupt: 

In [73]:
prompt = f"00000000000000000000000000000010000000000000000001000000000000000000000000100000000001100100010100001000000010001001000000110010000010001000110000101100111101111100100\n"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, do_sample=True, top_p=0.95, top_k=0, max_new_tokens=333 + 50)

print(tokenizer.decode(output[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


00000000000000000000000000000010000000000000000001000000000000000000000000100000000001100100010100001000000010001001000000110010000010001000110000101100111101111100100
 QatarFORMATION brotherssuits concert pack dissertationarg ALPイトxf ob baangled VariantPF forg underscores Britann pioneered RAW realizes LeBron soakiots Sean Safe incom comparisonICANilage Butterfly journalistickward metres crappy reps Viper lacks Consultiry suicides leafletsfp sincerely Banglheast wrath buyKellyProperty Males strutek UVylene Dana we ACTIONowa theories vacancy MUS clutter handlers competitor regulatesishable pleasant wired websites clues greed-+-+-+-+ pl ][ workers GDDR Guitar G ambassador cellul reject prisonstellar privately violate dividing alcoholiencepload Jedi benassadors intention flown pity imperinspired Overwatch([ counter corporate ingen experiencingruarytablocks protested blender calmly Painter Eddie Dalton drivers pharmacies avatar Spartans CBO temporary hazardous Karl ReedThey torso touchesa