In [None]:
%pip install atomInSmiles
%pip install trl

import atomInSmiles as AIS
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import copy
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
import collections
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM , AutoConfig
from datasets import load_dataset
from datasets import DatasetDict
import datasets
from transformers import DataCollatorForLanguageModeling 
from huggingface_hub.hf_api import HfFolder



# DPO finetuning
Here, we fine-tune our SFT model on preference pairs using Direct Preference Optimization (DPO). We build a Hugging Face tokenizer on top of Atom-in-SMILES (AIS) and load the pairs we created earlier. We then run DPO over a grid of hyperparameters to find settings that yield stable molecular generation. Stability is measured by uniqueness (how many generated molecules differ from the other generated outputs) and novelty (how many generated molecules do not appear in the input pairs). We save the results of runs that reach the desired uniqueness threshold.

## AISTokenizer
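Here, we define a Hugging Face tokenizer for Atom-in-SMILES (AIS) strings. SMILES are first converted to AIS with `AIS.encode`, which yields whitespace-separated tokens. The tokenizer splits on whitespace and maps each token to an id using a vocab file that lists one token per line.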

In [None]:
""" AISTokenizer for Hugging Face Transformers.

"""


class AISTokenizer(PreTrainedTokenizer):
    """Hugging Face tokenizer for whitespace-separated Atom-in-SMILES (AIS) strings.

    ``_tokenize`` splits on whitespace, so input text should already be
    AIS-encoded (e.g. with ``AIS.encode``). Ids 0 and 1 are reserved for
    ``<|endoftext|>`` (shared by the BOS, EOS and UNK tokens) and ``<|pad|>``;
    tokens read from the vocab file start at id 2.
    """

    def __init__(self, vocab: str, model_max_length: int, **kwargs):
        """Build the tokenizer from a vocab file.

        Args:
            vocab: Path to a text file with one token per line. The token on
                line ``i`` (0-based) is assigned id ``i + 2``. If the file itself
                lists ``<|endoftext|>`` or ``<|pad|>``, that entry overrides the
                reserved id. A token listed twice keeps the id of its last line.
            model_max_length: Model maximum sequence length.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # The vocab tables are built before super().__init__() because the base
        # class may query the vocab while registering the special tokens.
        self.vocab = []
        with open(vocab, 'r') as file:
            for line in file:
                self.vocab.append(line.strip())
        self.model_max_length = model_max_length

        bos_token = AddedToken("<|endoftext|>", lstrip=False, rstrip=False)
        eos_token = AddedToken("<|endoftext|>", lstrip=False, rstrip=False)
        pad_token = AddedToken("<|pad|>", lstrip=False, rstrip=False)
        unk_token = AddedToken("<|endoftext|>", lstrip=False, rstrip=False)
        self._vocab_str_to_int = {
            "<|endoftext|>": 0,
            "<|pad|>": 1,
            **{ch: i + 2 for i, ch in enumerate(self.vocab)},
        }
        # Each decoded token carries a trailing space so that
        # convert_tokens_to_string() produces whitespace-separated tokens,
        # the same format _tokenize() splits.
        self._vocab_int_to_str = {v: k + " " for k, v in self._vocab_str_to_int.items()}
        self.ids_to_tokens = collections.OrderedDict(
            (ids, tok) for tok, ids in self._vocab_str_to_int.items()
        )

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of distinct tokens in the base vocabulary (reserved ids included)."""
        return len(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        """Split pre-encoded AIS text on whitespace."""
        return text.split()

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the UNK id for unknown tokens."""
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int[self.unk_token])

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token (with a trailing space); unknown ids raise KeyError."""
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        """Concatenate tokens; each already ends with a separating space."""
        return "".join(tokens)

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Return segment ids aligned with ``build_inputs_with_special_tokens``.

        That method prefixes every sequence with one BOS id. So the first
        segment (BOS + ``token_ids_0``) is marked with 0s, and the optional
        second segment (BOS + ``token_ids_1``) is marked with 1s. The result
        therefore has the same length as the built input ids.
        """
        output = [0] * (1 + len(token_ids_0))
        if token_ids_1 is not None:
            output += [1] * (1 + len(token_ids_1))
        return output

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Prefix each sequence with the BOS id; no EOS id is appended.

        Single sequence: ``[BOS] + token_ids_0``.
        Pair: ``[BOS] + token_ids_0 + [BOS] + token_ids_1``.
        """
        bos_token_ids = [self.bos_token_id]
        output = bos_token_ids + token_ids_0
        if token_ids_1 is None:
            return output
        return output + bos_token_ids + token_ids_1

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return a 1/0 mask flagging special-token positions.

        Args:
            token_ids_0: First sequence of ids.
            token_ids_1: Optional second sequence of ids.
            already_has_special_tokens: If True, ``token_ids_0`` is an already
                built sequence and every BOS/EOS id in it is flagged.

        Returns:
            A mask whose layout matches ``build_inputs_with_special_tokens``
            (one leading special position per sequence).

        Raises:
            ValueError: if ``token_ids_1`` is given together with
                ``already_has_special_tokens=True``.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            special_ids = (self.bos_token_id, self.eos_token_id)
            return [1 if x in special_ids else 0 for x in token_ids_0]

        mask = [1] + [0] * len(token_ids_0)
        if token_ids_1 is not None:
            mask += [1] + [0] * len(token_ids_1)
        return mask

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping (callers cannot mutate ours)."""
        return dict(self._vocab_str_to_int)

    def save_vocabulary(self, vocab_path, filename_prefix: Optional[str] = None):
        """Save the vocab so that ``AISTokenizer(saved_file, ...)`` rebuilds the same ids.

        Only the lines originally read from the vocab file are written, in their
        original order. The reserved ``<|endoftext|>``/``<|pad|>`` ids are
        recreated by ``__init__``. If they were written too, reloading would treat
        them as regular tokens and shift every other id by two.

        Args:
            vocab_path: Directory in which to save ``vocab_file.txt``, or a full
                file path.
            filename_prefix: Optional prefix, joined with ``-``, for the file
                name when ``vocab_path`` is a directory.

        Returns:
            A 1-tuple with the path of the saved file.
        """
        if os.path.isdir(vocab_path):
            file_name = "vocab_file.txt"
            if filename_prefix:
                file_name = f"{filename_prefix}-{file_name}"
            vocab_file = os.path.join(vocab_path, file_name)
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token in self.vocab:
                writer.write(token + "\n")
        return (vocab_file,)

Load the tokenizer with our vocab.txt file and a maximum sequence length of 72 tokens.

In [None]:
# Build the AIS tokenizer from our vocab file (one token per line).
vocabfile = "/kaggle/input/dpotest/vocab.txt"
context_length = 72  # passed to the tokenizer as model_max_length
tokenizer = AISTokenizer(vocab=vocabfile, model_max_length=context_length)

In [None]:
import json
# Deserialize the molecule pairs created earlier.
data = json.loads(Path("/kaggle/input/dpotest/pairs.json").read_text())

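Set up the DPO dataset from the pairs loaded from our pairs file. Each pair becomes an empty prompt, a chosen (good docker) molecule and a rejected (bad docker) molecule.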

In [None]:
from datasets import Dataset
from typing import Dict
import atomInSmiles as AIS

# Function to get it in the triplet form for DPO
def return_prompt_and_responses(samples) -> Dict[str, list]:
    """Convert a batch of (good, bad) SMILES into DPO prompt/chosen/rejected columns.

    Args:
        samples: Batched mapping (as passed by ``Dataset.map(batched=True)``)
            with ``"good_docker"`` and ``"bad_docker"`` lists of SMILES.

    Returns:
        Dict with an empty ``"prompt"`` for every row, the AIS-encoded good
        molecules as ``"chosen"`` and the AIS-encoded bad molecules as
        ``"rejected"``.
    """
    good_smiles = samples["good_docker"]
    bad_smiles = samples["bad_docker"]
    # Generation is unconditional, so every example shares the same empty prompt.
    return {
        "prompt": [""] * len(good_smiles),
        "chosen": list(map(AIS.encode, good_smiles)),
        "rejected": list(map(AIS.encode, bad_smiles)),
    }
# Reshape the deserialized pairs into columns: element 0 of each pair is the
# preferred ("good") molecule and element 1 is the rejected ("bad") one.
samples_dict = {
    column: [pair[position] for pair in data]
    for position, column in enumerate(("good_docker", "bad_docker"))
}

# Build a Hugging Face dataset from the columns, then replace them with the
# prompt/chosen/rejected triplet format expected for DPO.
dataset = Dataset.from_dict(samples_dict)
mapped_dataset = dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=list(samples_dict),
)

In [None]:
# Defining the data collator
from trl.trainer.utils import DPODataCollatorWithPadding
# Padding collator for DPO batches (decoder-only model).
# NOTE(review): this instance is never passed to DPOTrainer in the sweep below,
# so it is currently unused; the trainer presumably builds its own default
# collator — confirm before relying on this object.
data_collator = DPODataCollatorWithPadding(
    is_encoder_decoder=False,
    label_pad_token_id=-100,  # conventional "ignore" label for padded positions
    pad_token_id=tokenizer.pad_token_id,
)


Set up utility functions to test whether our model is stable enough for generation.

In [None]:
# Every SMILES string that appears on either side of a pair in the dataset;
# generated molecules outside this set count as novel.
allgenerated = {smiles for pair in data for smiles in (pair[0], pair[1])}

def _strip_boundary_tokens(token_ids):
    """Drop a leading BOS id and a trailing EOS id, each only if present.

    ``generate`` starts from the BOS token and usually stops on EOS. However, a
    sample cut off by ``max_new_tokens`` has no trailing EOS, and slicing
    ``[1:-1]`` blindly would then discard a real token.
    """
    if token_ids and token_ids[0] == tokenizer.bos_token_id:
        token_ids = token_ids[1:]
    if token_ids and token_ids[-1] == tokenizer.eos_token_id:
        token_ids = token_ids[:-1]
    return token_ids


def calc_unique(test_model, num_samples=250, max_new_tokens=72, temperature=0.5):
    """Sample SMILES from ``test_model`` and score their uniqueness and novelty.

    Uses the module-level ``tokenizer`` and ``allgenerated`` defined above.

    Args:
        test_model: Causal LM sampled unconditionally via ``generate``.
        num_samples: Number of independent samples to draw.
        max_new_tokens: Maximum number of generated tokens per sample.
        temperature: Sampling temperature.

    Returns:
        ``(uniqueness_percentage, novelty_percentage, unique_elements)``:
        the percentage of distinct SMILES among all samples, the percentage of
        distinct SMILES not present in ``allgenerated`` (both relative to
        ``num_samples``), and the set of distinct SMILES.

    Raises:
        ValueError: if ``num_samples`` is not positive.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    samples = []
    # Trainer.train() leaves the model in training mode (dropout active).
    # Sample in eval mode, then restore whatever mode the caller had.
    was_training = test_model.training
    test_model.eval()
    try:
        for _ in range(num_samples):
            token_ids = test_model.generate(
                max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature
            )[0].tolist()
            samples.append(AIS.decode(tokenizer.decode(_strip_boundary_tokens(token_ids))))
    finally:
        test_model.train(was_training)

    unique_elements = set(samples)
    novel_elements = unique_elements - allgenerated
    uniqueness_percentage = len(unique_elements) / len(samples) * 100
    novelty_percentage = len(novel_elements) / len(samples) * 100
    return uniqueness_percentage, novelty_percentage, unique_elements

## DPO Step
Here, we finally run DPO. For each combination of hyperparameters (number of epochs, beta and learning rate), we reload the baseline SFT model plus a reference copy, train with DPO, and sample molecules from the result. If at least 98% of the sampled molecules are unique, we save the unique molecules to a file whose name records the hyperparameters and both metrics.

In [None]:
from transformers import Trainer, TrainingArguments
from trl import DPOTrainer
from transformers.utils import logging
logging.set_verbosity_error()

import gc
import itertools

# Hyperparameters to sweep over.
epochs = [1, 2, 3]
betas = [0.1, 0.3, 0.5]
lrs = [1e-5, 3e-5, 5e-5, 1e-4]

# Try every (epoch, beta, lr) combination, in the same order as nested
# epoch -> beta -> lr loops. After the sweep, `model` holds the policy
# trained with the last combination.
for epoch, beta, lr in itertools.product(epochs, betas, lrs):
    # Drop the previous run's policy, reference model and trainer (which owns
    # the optimizer state) so their memory can be reclaimed before new weights
    # are loaded.
    dpo_trainer = model_ref = model = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Policy to train and DPO reference model, both loaded from the SFT checkpoint.
    model = AutoModelForCausalLM.from_pretrained(
        "victornica/AIS_3",  # location of saved SFT model
    )
    model_ref = AutoModelForCausalLM.from_pretrained(
        "victornica/AIS_3",  # location of saved SFT model
    )

    args = TrainingArguments(
        output_dir="/kaggle/working/",
        per_device_train_batch_size=128,
        per_device_eval_batch_size=128,
        num_train_epochs=epoch,
        weight_decay=0.1,
        learning_rate=lr,
        lr_scheduler_type="linear",
        report_to="none"
    )
    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=args,
        beta=beta,
        train_dataset=mapped_dataset,
        tokenizer=tokenizer,
        is_encoder_decoder=False,
    )
    dpo_trainer.train()

    # Percentages of unique and novel sampled molecules, plus the distinct molecules.
    uniqueness, novelty, unique_elems = calc_unique(model)
    # Only keep runs where at least 98% of sampled molecules were unique; the
    # file name records the hyperparameters and both metrics.
    if uniqueness >= 98:
        result_path = f"{epoch}-{beta}-{lr}-{uniqueness}-{novelty}.txt"
        with open(result_path, "w", encoding="utf-8") as out_file:
            out_file.write(f"{unique_elems}")

In [None]:
# Sanity check: print a few unconditional samples from the last trained model.
# Trainer.train() leaves the model in training mode (dropout active), so switch
# to eval mode before sampling.
model.eval()
for _ in range(5):
    token_ids = model.generate(max_new_tokens=72, do_sample=True)[0].tolist()
    # generate starts from BOS and usually ends with EOS. Strip each one only when
    # present, so samples truncated at max_new_tokens keep their last real token.
    if token_ids and token_ids[0] == tokenizer.bos_token_id:
        token_ids = token_ids[1:]
    if token_ids and token_ids[-1] == tokenizer.eos_token_id:
        token_ids = token_ids[:-1]
    print(AIS.decode(tokenizer.decode(token_ids)))