In [None]:
%pip install atomInSmiles
%pip install trl

import atomInSmiles as AIS
from ..AISTokenizer import AISTokenizer
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import copy
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
import collections
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM , AutoConfig
from datasets import load_dataset
from datasets import DatasetDict
import datasets
from transformers import DataCollatorForLanguageModeling 
from huggingface_hub.hf_api import HfFolder



# DPO fine-tuning
Here, we fine-tune our SFT model on preference pairs using Direct Preference Optimization (DPO). We create a custom tokenizer on top of Atom-in-SMILES (AIS) and load the preference pairs we created earlier. We then run DPO under several hyperparameter configurations to find settings that keep molecular generation stable. Stability is measured by how unique the generated molecules are, both relative to one another (uniqueness) and relative to the input molecules (novelty). We save the results for configurations that reach the desired thresholds.

Load the custom tokenizer using our vocabulary file (`vocab.txt`).

In [None]:
# Build the AIS tokenizer from our vocabulary file.
# Context length handed to AISTokenizer (presumably its maximum sequence
# length in tokens — confirm against the AISTokenizer implementation).
context_length = 72
# Vocabulary file for the AIS tokenizer.
vocabfile = "/kaggle/input/dpotest/vocab.txt"
tokenizer = AISTokenizer(vocabfile, context_length)

In [None]:
import json

# Deserialize the preference pairs. Later cells index each entry as
# pair[0] (preferred / "chosen" molecule) and pair[1] (rejected molecule).
# JSON is UTF-8 by specification, so don't depend on the platform locale.
with open("/kaggle/input/dpotest/pairs.json", "r", encoding="utf-8") as file:
    data = json.load(file)

# Fail fast here, rather than deep inside Dataset.map or the DPO sweep, if the
# file is not a list of [chosen, rejected] pairs.
assert isinstance(data, list) and all(len(pair) >= 2 for pair in data), (
    "pairs.json must be a list of [chosen, rejected] pairs"
)

Set up the DPO preference pairs from the data loaded from our pairs file.

In [None]:
from datasets import Dataset
from typing import Dict
import atomInSmiles as AIS


def return_prompt_and_responses(samples) -> Dict[str, list]:
    """Convert a batch of molecule pairs into the (prompt, chosen, rejected) triplets DPO expects.

    Args:
        samples: a batched slice of the dataset holding parallel lists under
            "good_docker" (preferred molecules) and "bad_docker" (rejected
            molecules), as strings accepted by ``AIS.encode``.

    Returns:
        A dict of equal-length lists:
        - "prompt": empty strings, so generation is unconditional;
        - "chosen": AIS encodings of the preferred molecules;
        - "rejected": AIS encodings of the rejected molecules.
    """
    preferred = samples["good_docker"]
    dispreferred = samples["bad_docker"]
    return {
        # No conditioning text: one empty prompt per pair.
        "prompt": ["" for _ in preferred],
        "chosen": list(map(AIS.encode, preferred)),
        "rejected": list(map(AIS.encode, dispreferred)),
    }


# Column-oriented view of the JSON pairs: first element preferred, second rejected.
samples_dict = {
    "good_docker": [entry[0] for entry in data],
    "bad_docker": [entry[1] for entry in data],
}

dataset = Dataset.from_dict(samples_dict)

# Batched map into the DPO format; drop the raw columns so only
# prompt/chosen/rejected remain.
mapped_dataset = dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=list(samples_dict),
)

In [None]:
# Padding collator for DPO batches of a decoder-only model.
# Label positions are padded with -100 (PyTorch's default ignore_index for
# cross-entropy), and sequences are padded with the tokenizer's pad token.
# NOTE(review): this collator is never passed to DPOTrainer in the sweep below;
# DPOTrainer presumably builds its own default collator — confirm against the
# installed TRL version before relying on this object.
from trl.trainer.utils import DPODataCollatorWithPadding

data_collator = DPODataCollatorWithPadding(
    is_encoder_decoder=False,
    label_pad_token_id=-100,
    pad_token_id=tokenizer.pad_token_id,
)


Set up utility functions to test whether our model is stable enough for generation.

In [None]:
# Every molecule appearing in the preference pairs (chosen and rejected).
# Generated molecules outside this set count as novel.
allgenerated = {mol for pair in data for mol in (pair[0], pair[1])}


def calc_unique(test_model, n_samples=250, max_new_tokens=72, temperature=0.5):
    """Sample molecules from ``test_model`` and measure uniqueness and novelty.

    Args:
        test_model: causal LM to sample from unconditionally (one ``generate``
            call per sample).
        n_samples: number of molecules to sample.
        max_new_tokens: cap on generated tokens per sample.
        temperature: sampling temperature.

    Returns:
        ``(uniqueness_percentage, novelty_percentage, unique_elements)`` where
        - uniqueness = distinct samples / n_samples * 100,
        - novelty = distinct samples not in ``allgenerated`` / n_samples * 100,
        - unique_elements = the set of distinct decoded molecules.
        Both percentages are 0.0 when ``n_samples`` is 0.
    """
    # The model may still be in training mode (e.g. right after Trainer.train()),
    # which would keep dropout active while sampling. Sample in eval mode, then
    # restore the caller's mode.
    was_training = test_model.training
    test_model.eval()
    samples = []
    try:
        for _ in range(n_samples):
            output = test_model.generate(
                max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature
            )
            # model.generate includes the leading/trailing EOS tokens, so strip them with [1:-1].
            # NOTE(review): if max_new_tokens is reached before EOS, this also drops
            # the last real token — confirm whether that case occurs in practice.
            samples.append(AIS.decode(tokenizer.decode(output[0][1:-1])))
    finally:
        test_model.train(was_training)

    unique_elements = set(samples)
    novel_elements = unique_elements - allgenerated
    if not samples:
        return 0.0, 0.0, unique_elements
    uniqueness_percentage = len(unique_elements) / len(samples) * 100
    novelty_percentage = len(novel_elements) / len(samples) * 100
    return uniqueness_percentage, novelty_percentage, unique_elements

## DPO Step
Here, we finally run DPO. For each combination of hyperparameters in the sweep (number of epochs, β, learning rate), we load fresh copies of the baseline SFT model: one to train and one as the frozen reference. We then train and sample molecules to measure uniqueness and novelty. If at least 98% of the sampled molecules are unique, we save the unique molecules to a file whose name records the hyperparameters and both metrics.

In [None]:
import gc

from transformers import Trainer, TrainingArguments
from trl import DPOTrainer
from transformers.utils import logging
logging.set_verbosity_error()

# Saved SFT checkpoint; used both as the starting policy and as the DPO reference model.
SFT_MODEL = "victornica/AIS_3"
# Only record a configuration if at least this percentage of samples are unique.
UNIQUENESS_THRESHOLD = 98

# hyperparameters to sweep over
epochs = [1, 2, 3]
betas = [0.1, 0.3, 0.5]
lrs = [1e-5, 3e-5, 5e-5, 1e-4]

# sweeping over hyperparameters
for epoch in epochs:
    for beta in betas:
        for lr in lrs:
            # Drop the previous run's trainer (and with it the old policy, reference
            # model and optimizer state) before loading new copies. Otherwise two
            # runs' worth of weights can be resident at once and exhaust GPU memory.
            model = model_ref = dpo_trainer = None
            gc.collect()
            torch.cuda.empty_cache()

            # Policy to train and the reference model, both initialized from the SFT checkpoint.
            model = AutoModelForCausalLM.from_pretrained(SFT_MODEL)
            model_ref = AutoModelForCausalLM.from_pretrained(SFT_MODEL)

            # setting arguments
            args = TrainingArguments(
                output_dir="/kaggle/working/",
                per_device_train_batch_size=128,
                per_device_eval_batch_size=128,
                num_train_epochs=epoch,
                weight_decay=0.1,
                learning_rate=lr,
                lr_scheduler_type="linear",
                report_to="none"
            )
            # initializing trainer
            dpo_trainer = DPOTrainer(
                model,
                model_ref,
                args=args,
                beta=beta,
                train_dataset=mapped_dataset,
                tokenizer=tokenizer,
                is_encoder_decoder=False,
            )

            # training model
            dpo_trainer.train()

            # Training leaves the model in training mode; sample with dropout disabled.
            model.eval()

            # Measure how many generated molecules are unique / novel, and keep
            # the set of unique ones.
            uniqueness, novelty, unique_elems = calc_unique(model)
            # Only record runs where at least UNIQUENESS_THRESHOLD % of samples were unique.
            if uniqueness >= UNIQUENESS_THRESHOLD:
                with open(f"{epoch}-{beta}-{lr}-{uniqueness}-{novelty}.txt", "w") as file:
                    file.write(f"{unique_elems}")

In [None]:
# Sanity check: print a few unconditional samples from the last model trained
# in the sweep. Switch to eval mode first; after training the model may still
# be in training mode, which would leave dropout active during sampling.
model.eval()
for _ in range(5):
    # Note: model.generate includes the leading/trailing EOS tokens, so we have to remove them ourselves with [1:-1]
    generated_ids = model.generate(max_new_tokens=72, do_sample=True)[0][1:-1]
    print(AIS.decode(tokenizer.decode(generated_ids)))