In [None]:
# This code is from the repository https://github.com/databrickslabs/dolly
# Copyright (c) 2023 databrickslabs

In [None]:
!pip install transformers accelerate

# 1) Training

## 1.1 Initialization

In [None]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizer,
    Trainer,
    TrainingArguments
)
from torch.nn.parallel import DistributedDataParallel as DDP
from datasets import Dataset, load_dataset
from functools import partial
import logging
import torch
import numpy as np 
import re

# Notebook-wide logger; used for the warnings emitted during post-processing.
logger = logging.getLogger("logger")

### Section markers used to build the prompts.
### END_KEY, INSTRUCTION_KEY and RESPONSE_KEY_NL are registered as special tokens in load_tokenizer;
### INPUT_KEY and the bare RESPONSE_KEY only appear as plain text inside the prompt templates.
INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "Input:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"
# Response marker followed by a newline, i.e. exactly what precedes the response text in a prompt.
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"

## 1.2 Model Loading

In [None]:
### Model Loading
# Identifier of the pretrained checkpoint; passed to get_model_tokenizer further down in this cell.
INPUT_MODEL = "EleutherAI/pythia-2.8b"
def load_tokenizer(pretrained_model_name_or_path):
    """Build the tokenizer used for both fine-tuning and generation.

    Args:
        pretrained_model_name_or_path: Value handed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer, with its pad token set to its EOS token and the prompt
        markers (end key, instruction key, response key + newline) added as
        additional special tokens.
    """
    prompt_markers = [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]

    tok = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    # Padding reuses the end-of-sequence token.
    tok.pad_token = tok.eos_token
    tok.add_special_tokens({"additional_special_tokens": prompt_markers})
    return tok

def load_model(pretrained_model_name_or_path, gradient_checkpointing):
    """Load the causal language model to fine-tune.

    Args:
        pretrained_model_name_or_path: Value handed to ``AutoModelForCausalLM.from_pretrained``.
        gradient_checkpointing: When truthy the model is loaded with ``use_cache=False``;
            otherwise with ``use_cache=True``.

    Returns:
        The loaded model.
    """
    # NOTE(review): trust_remote_code=True permits code shipped with the checkpoint to run
    # locally -- confirm it is actually required for the model being loaded.
    return AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        use_cache=not gradient_checkpointing,
    )

def get_model_tokenizer(
    pretrained_model_name_or_path,
    gradient_checkpointing,
):
    """Load a matching (model, tokenizer) pair.

    The tokenizer is loaded first (it gains extra special tokens in
    ``load_tokenizer``); the model's token embeddings are then resized to the
    tokenizer's length so the added tokens have embeddings.

    Args:
        pretrained_model_name_or_path: Checkpoint identifier used for both loads.
        gradient_checkpointing: Forwarded to ``load_model``.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    tok = load_tokenizer(pretrained_model_name_or_path)
    net = load_model(pretrained_model_name_or_path, gradient_checkpointing=gradient_checkpointing)

    vocab_size = len(tok)
    net.resize_token_embeddings(vocab_size)
    return net, tok

model, tokenizer = get_model_tokenizer(
    pretrained_model_name_or_path=INPUT_MODEL,
    gradient_checkpointing=True
)

# Find the maximum sequence length in the model configuration.
# Different architectures expose it under different attribute names, so probe the
# known ones in turn ("max_position_embeddings" first, so behaviour is unchanged for
# configs that define it) and fall back to a default when none is present. Without the
# fallback max_length would be None and the `len(...) < max_length` filter in
# preprocess_dataset would raise a TypeError.
DEFAULT_MAX_LENGTH = 1024
conf = model.config
max_length = None
for length_setting in ("max_position_embeddings", "n_positions", "seq_length"):
    max_length = getattr(conf, length_setting, None)
    if max_length:
        logger.info(f"Found max length: {max_length} (config.{length_setting})")
        break
if not max_length:
    max_length = DEFAULT_MAX_LENGTH
    logger.info(f"Using default max length: {max_length}")

## 1.3 Data Processing

In [None]:
### Data Processing
# Fixed preamble placed at the top of every prompt.
INTRO_BLURB = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)

# Training prompt for records without extra context. One section per line; the
# "{instruction}" / "{response}" placeholders are left in place so the template can be
# filled per record with str.format.
PROMPT_NO_INPUT_FORMAT = "\n".join(
    [
        INTRO_BLURB,
        INSTRUCTION_KEY,
        "{instruction}",
        RESPONSE_KEY,
        "{response}",
        END_KEY,
    ]
)

# Training prompt for records that carry context: same layout, with an input section
# ("{input}") inserted between the instruction and the response.
PROMPT_WITH_INPUT_FORMAT = "\n".join(
    [
        INTRO_BLURB,
        INSTRUCTION_KEY,
        "{instruction}",
        INPUT_KEY,
        "{input}",
        RESPONSE_KEY,
        "{response}",
        END_KEY,
    ]
)

def load_training_dataset(path_or_dataset="databricks/databricks-dolly-15k"):
    """Load the "train" split and attach a fully formatted prompt to each record.

    Args:
        path_or_dataset: Value handed to ``load_dataset``.

    Returns:
        The mapped dataset; every record gains a "text" field holding the training
        prompt built from its "instruction", "response" and (when non-empty) "context".
    """

    def _add_text(rec):
        # Fields common to both templates.
        fields = {"instruction": rec["instruction"], "response": rec["response"]}
        context = rec.get("context")
        if not context:
            rec["text"] = PROMPT_NO_INPUT_FORMAT.format(**fields)
        else:
            rec["text"] = PROMPT_WITH_INPUT_FORMAT.format(input=context, **fields)
        return rec

    train_split = load_dataset(path_or_dataset)["train"]
    return train_split.map(_add_text)

def preprocess_batch(batch, tokenizer, max_length):
    """Tokenize the "text" entries of a batch.

    Args:
        batch: Mapping with a "text" entry holding the prompts to encode.
        tokenizer: Callable tokenizer; invoked with ``max_length`` and ``truncation=True``.
        max_length: Length limit forwarded to the tokenizer.

    Returns:
        Whatever the tokenizer returns for the batch.
    """
    texts = batch["text"]
    encoded = tokenizer(texts, max_length=max_length, truncation=True)
    return encoded

def preprocess_dataset(tokenizer, max_length):
    """Load, tokenize, filter and shuffle the training dataset.

    Args:
        tokenizer: Tokenizer forwarded to ``preprocess_batch``.
        max_length: Token limit used for truncation and for the length filter.

    Returns:
        The shuffled dataset with the raw text columns removed, keeping only records
        whose token count is strictly below ``max_length``.
    """
    raw_columns = ["instruction", "context", "response", "text", "category"]
    tokenize = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)

    def _not_truncated(rec):
        # A record that reached max_length may have been truncated, which would mean
        # the end keyword is missing; only strictly shorter records are kept.
        return len(rec["input_ids"]) < max_length

    return (
        load_training_dataset()
        .map(tokenize, batched=True, remove_columns=raw_columns)
        .filter(_not_truncated)
        .shuffle()
    )

class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    """Collator that restricts the loss to the response part of each example.

    The parent collator builds the batch; afterwards every label up to and including
    the response-key token is overwritten with -100.
    """

    def torch_call(self, examples):
        """Collate ``examples`` and mask the prompt portion of the labels.

        Raises:
            RuntimeError: If the response-key token is not found in an example's labels.
        """
        batch = super().torch_call(examples)

        # The prompt ends with the response key plus a newline. Encode that marker and
        # look for its first token id in each row; it should just be a single token.
        response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)
        marker_id = response_token_ids[0]

        masked_labels = batch["labels"].clone()
        for row in range(len(examples)):
            hits = np.where(batch["labels"][row] == marker_id)[0]
            if len(hits) == 0:
                raise RuntimeError(
                    f'Could not find response key {response_token_ids} in token IDs {batch["labels"][row]}'
                )

            # Make the pytorch loss function ignore everything up through the end of
            # the response key (first occurrence, inclusive).
            prompt_end = hits[0] + 1
            masked_labels[row, :prompt_end] = -100

        batch["labels"] = masked_labels
        return batch

# Tokenize/filter/shuffle the training data, then split it into train and test parts
# (test_size=1000).
processed_dataset = preprocess_dataset(tokenizer=tokenizer, max_length=max_length)
split_dataset = processed_dataset.train_test_split(test_size=1000)
# Sanity check: print every field of the first processed record.
for k, v in next(iter(processed_dataset)).items():
    print(f"{k}: {v} \n")

# Collator that masks the prompt tokens in the labels (see DataCollatorForCompletionOnlyLM);
# mlm=False selects causal-LM style labels in the parent collator.
data_collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
)

## 1.4 Trainer

In [None]:
# Directory that receives checkpoints, the saved model and (under runs/) the logs.
local_output_dir = "/logs/"

training_args = TrainingArguments(
        output_dir=local_output_dir,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        # Mixed precision is disabled.
        fp16=False,
        bf16=False,
        learning_rate=1e-5,
        num_train_epochs=5,
        deepspeed=None,
        # Matches gradient_checkpointing=True used when the model was loaded.
        gradient_checkpointing=True,
        logging_dir=f"{local_output_dir}/runs",
        logging_strategy="steps",
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=10,
        load_best_model_at_end=False,
        report_to="tensorboard",
        disable_tqdm=True,
        # The dataset columns are already exactly what the collator consumes.
        remove_unused_columns=False,
        # -1 means "not running under a distributed launcher". The notebook runs as a
        # single process, so a hard-coded rank (previously 2) is never valid here: it
        # would make the trainer behave as rank 2 of a distributed job that does not exist.
        local_rank=-1,
        warmup_steps=0,
    )

# Wire the model, tokenizer, datasets and collator into the trainer.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    data_collator=data_collator,
)

# Fine-tune, then write the resulting model to local_output_dir.
trainer.train()
trainer.save_model(output_dir=local_output_dir)

# 2) Generation

In [None]:
# Inference-time prompt: the training layout up to and including the response key.
# The trailing empty element makes the template end with a newline, so the prompt
# finishes with the response key followed by "\n". "{instruction}" stays a placeholder
# to be filled with str.format.
PROMPT_FOR_GENERATION_FORMAT = "\n".join(
    [
        INTRO_BLURB,
        INSTRUCTION_KEY,
        "{instruction}",
        RESPONSE_KEY,
        "",
    ]
)

def preprocess(tokenizer, instruction_text):
    """Turn an instruction into tokenized model inputs.

    Args:
        tokenizer: Callable tokenizer; invoked with ``return_tensors="pt"``.
        instruction_text: The instruction to place into the generation prompt.

    Returns:
        The tokenizer output, extended with "prompt_text" (the formatted prompt) and
        "instruction_text" (the raw instruction).
    """
    prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction_text)
    encoded = tokenizer(prompt, return_tensors="pt")

    # Carry the texts along so later stages can report them.
    extras = (("prompt_text", prompt), ("instruction_text", instruction_text))
    for field, value in extras:
        encoded[field] = value
    return encoded

def forward(model, tokenizer, model_inputs, max_length=100):
    """Run generation for already-tokenized inputs.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer; only its ``pad_token_id`` is used.
        model_inputs: Mapping with "input_ids" and optionally "attention_mask" and
            "instruction_text" (as produced by ``preprocess``).
        max_length: Forwarded to ``model.generate`` as ``max_length``.

    Returns:
        Dict with
        - "generated_sequence": generated ids on the CPU, reshaped to
          (input batch, sequences per input, ...),
        - "input_ids": the input ids as received, or None for an empty prompt,
        - "instruction_text": the instruction text if present, else None.
    """
    input_ids = model_inputs["input_ids"]
    attention_mask = model_inputs.get("attention_mask", None)

    if input_ids.shape[1] == 0:
        # Empty prompt: generate unconditionally.
        input_ids = None
        attention_mask = None
        in_b = 1
    else:
        in_b = input_ids.shape[0]

    # Move every tensor input to the model's device. input_ids can be None here (empty
    # prompt), so .to() must be guarded, and the attention mask has to live on the same
    # device as the ids.
    device = model.device
    generated_sequence = model.generate(
        input_ids=None if input_ids is None else input_ids.to(device),
        attention_mask=None if attention_mask is None else attention_mask.to(device),
        pad_token_id=tokenizer.pad_token_id,
        max_length=max_length,
    )

    # Bring the result back to the CPU so it can be converted to numpy / lists later
    # regardless of where the model lives.
    generated_sequence = generated_sequence.cpu()

    out_b = generated_sequence.shape[0]
    generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
    instruction_text = model_inputs.get("instruction_text", None)
    return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}

# Example instruction: build the prompt, show the token ids and the prompt text,
# then run generation with the default max_length.
text = "Give me best 30s advice"
pre_process_result = preprocess(tokenizer, text)
print(pre_process_result["input_ids"])
print(pre_process_result["prompt_text"])
model_result = forward(model, tokenizer, pre_process_result)

In [None]:
# post processing
def get_special_token_id(tokenizer, key):
    """Return the single token id that ``key`` encodes to.

    The prompt markers (e.g. "### End") are added to the tokenizer as special tokens,
    so each is expected to encode to exactly one id; this looks that id up.

    Args:
        tokenizer: Object whose ``encode`` returns a sequence of token ids.
        key: The marker string to look up.

    Returns:
        The first (and only) token id produced for ``key``.

    Raises:
        ValueError: If ``key`` encodes to more than one token.
    """
    encoded = tokenizer.encode(key)
    if len(encoded) <= 1:
        return encoded[0]
    raise ValueError(f"Expected only a single token for '{key}' but found {encoded}")

def postprocess(tokenizer, model_outputs, return_full_text=False):
    """Extract the response text from generated token sequences.

    Args:
        tokenizer: Tokenizer used to look up the marker token ids and to decode.
        model_outputs: Mapping with "generated_sequence" (ids shaped
            (input batch, sequences per input, tokens); only the first input is used)
            and "instruction_text", as returned by ``forward``.
        return_full_text: When True, prefix each decoded response with the
            instruction text and a newline.

    Returns:
        A list with one ``{"generated_text": ...}`` dict per generated sequence. The
        value is None when the response key could not be located in the sequence.
    """
    response_key_token_id = get_special_token_id(tokenizer, RESPONSE_KEY_NL)
    end_key_token_id = get_special_token_id(tokenizer, END_KEY)
    instruction_text = model_outputs["instruction_text"]
    # .tolist() works whatever device the tensor lives on; going through .numpy()
    # fails for tensors that are still on an accelerator.
    sequences = model_outputs["generated_sequence"][0].tolist()
    records = []

    logger.debug(f"response key id: {response_key_token_id}, end key id: {end_key_token_id}")

    for sequence in sequences:
        # The response will be set to this variable if we can identify it.
        decoded = None

        # Find where "### Response:" is first found in the generated tokens.  Considering this is part of the
        # prompt, we should definitely find it.  We will return the tokens found after this token.
        try:
            response_pos = sequence.index(response_key_token_id)
        except ValueError:
            logger.warning(f"Could not find response key {response_key_token_id} in: {sequence}")
            response_pos = None

        # Compare against None explicitly: position 0 is a valid match.
        if response_pos is not None:
            # Next find where "### End" is located.  The model has been trained to end its responses with this
            # sequence (or actually, the token ID it maps to, since it is a special token).  We may not find
            # this token, as the response could be truncated.  If we don't find it then just return everything
            # to the end.  Only look after the response key so an earlier occurrence cannot cut the response off.
            try:
                end_pos = sequence.index(end_key_token_id, response_pos + 1)
            except ValueError:
                logger.warning("Could not find end key, the output is truncated!")
                end_pos = None
            decoded = tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()

        # If the full text is requested, then append the decoded text to the original instruction.
        # Skip this when nothing was decoded so the result stays None instead of "...\nNone".
        if return_full_text and decoded is not None:
            decoded = f"{instruction_text}\n{decoded}"
        records.append({"generated_text": decoded})
    return records

# Decode the generated ids into response text (return_full_text=False) and show the result.
final_output = postprocess(tokenizer, model_result, False)
print(final_output)