# 1 Load the model

In [2]:
import torch
import transformers
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)


# Hugging Face Hub id of the base model, shared by the tokenizer and model loads.
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Passing `load_in_4bit=True` straight to `from_pretrained` is deprecated; the
# supported route is a `BitsAndBytesConfig`. Either way the `bitsandbytes`
# package must be installed in this kernel's environment. A
# "PackageNotFoundError: ... bitsandbytes" means it is missing
# (e.g. `%pip install bitsandbytes`).
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    device_map="auto",  # let accelerate spread the weights over available devices
)

PackageNotFoundError: No package metadata was found for bitsandbytes

# 2 Prepare the dataset

In [None]:
from datasets import load_dataset

# Pull the pirate-speech variant of Dolly-15k; without `split=` this returns
# a DatasetDict indexed by split name.
dataset = load_dataset("TeeZee/dolly-15k-pirate-speech")

# The first 4000 rows of the "train" split.
# NOTE(review): no later visible cell reads `train_data`; the filtered subset
# below is built from `dataset` directly. Confirm before relying on it.
train_data = dataset["train"].select(indices=range(4000))

In [None]:
# Rows from these categories are kept. `generate_pirate_prompt` uses each row's
# `context` as the story to convert. Presumably these are the categories that
# carry a non-empty `context` passage (TODO confirm against the dataset card).
PIRATE_CATEGORIES = ("summarization", "information_extraction", "closed_qa")
N_PIRATE_EXAMPLES = 400  # training examples to keep after filtering

filtered_dataset = dataset.filter(
    lambda item: item["category"] in PIRATE_CATEGORIES
)

# `select(range(n))` fails when the split has fewer than n rows, so cap the
# request at the filtered split's size.
filtered_train_split = filtered_dataset["train"]
filtered_train_data = filtered_train_split.select(
    range(min(N_PIRATE_EXAMPLES, len(filtered_train_split)))
)

# 3 Prepare the training prompts

In [None]:
def generate_pirate_prompt(item):
    """Build one Mixtral-instruct training example for the pirate task.

    Args:
        item: a dataset row. Its ``"context"`` field is the story to convert and
            its ``"response"`` field is the target pirate text. A ``None`` in
            either field is treated as an empty string instead of raising
            ``TypeError`` during concatenation.

    Returns:
        ``"[INST] <instruction><story> [/INST] <pirate text>"``.

    BOS/EOS markers are intentionally not written as ``<s>`` / ``</s>`` text.
    ``tokenize`` already appends ``tokenizer.eos_token``, and the Mixtral
    tokenizer is expected to prepend BOS itself (``add_bos_token``; confirm via
    ``tokenizer("x").input_ids[0] == tokenizer.bos_token_id``). Writing them
    here as well would duplicate both tokens in every example.
    """
    sys_mes = "Convert this story to pirate language: "
    story = item["context"] or ""
    pirate_story = item["response"] or ""
    # Mixtral instruction format is "[INST] ... [/INST]" with no space inside the
    # opening tag (the previous "[INST ]" did not match it).
    return f"[INST] {sys_mes}{story} [/INST] {pirate_story}"


def tokenize(prompt):
    """Encode ``prompt`` (plus an EOS marker) as a fixed-length example.

    The tokenizer's EOS token text is appended to the prompt before encoding.
    The encoding is truncated to ``CUTOFF_LEN`` tokens and padded up to exactly
    ``CUTOFF_LEN``. If a prompt is longer than that, the appended EOS can be cut
    off (assuming the tokenizer truncates on the right, its usual default).

    Returns the tokenizer's output mapping; later cells read its ``input_ids``.
    """
    # CUTOFF_LEN is read at call time, so it only needs to exist before the
    # first call, not before this definition.
    text_with_eos = prompt + tokenizer.eos_token
    encoding_options = dict(
        max_length=CUTOFF_LEN,
        truncation=True,
        padding="max_length",
    )
    return tokenizer(text_with_eos, **encoding_options)


CUTOFF_LEN = 256  # tokenize() max_length (truncate/pad target); our dataset has short texts
LORA_R = 8  # LoRA rank, passed to LoraConfig(r=...)
LORA_ALPHA = 2 * LORA_R  # passed to LoraConfig(lora_alpha=...); alpha / r = 2
LORA_DROPOUT = 0.1  # passed to LoraConfig(lora_dropout=...)

In [None]:
# Reuse EOS as the padding token (presumably the tokenizer ships without one).
# NOTE(review): DataCollatorForLanguageModeling(mlm=False) sets the label to -100
# wherever input_ids == pad_token_id. With pad == eos, the EOS appended by
# `tokenize` is therefore excluded from the loss, and the model may never learn
# to stop. Consider a distinct pad token if generations run on.
tokenizer.pad_token = tokenizer.eos_token

# Drop every original column (read from the dataset rather than hard-coded) so
# only the tokenizer outputs remain. This keeps working if the schema gains or
# loses a text column.
train_data_prompts = filtered_train_data.map(
    lambda row: tokenize(generate_pirate_prompt(row)),
    remove_columns=filtered_train_data.column_names,
)

# Sanity check: decode the first training example back to text.
print(tokenizer.decode(train_data_prompts[0]["input_ids"], skip_special_tokens=True))

# 4 Train the model

In [None]:
def inference(input, max_length=1000):
    """Generate a pirate-language rewrite of ``input`` with the current ``model``.

    The prompt uses the Mixtral ``[INST] ... [/INST]`` instruction format and the
    same instruction text ("Convert this story to pirate language: ") as the
    training prompts. The fine-tuned adapters therefore see input shaped like
    their training data.

    Args:
        input: story text to convert. The name is kept for backward
            compatibility; it shadows the builtin only inside this function.
        max_length: total token budget (prompt + generated) passed to
            ``model.generate``. Defaults to the previously hard-coded 1000.

    Returns:
        The decoded sequence with special tokens stripped. As before, this is
        the prompt text followed by the model's continuation.
    """
    sys_msg = "Convert this story to pirate language: "
    prompt = f"[INST] {sys_msg}{input} [/INST]"

    # Calling the tokenizer (instead of `encode`) also returns the attention
    # mask. Using `model.device` avoids hard-coding "cuda" when device_map
    # placed the model.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**encoded, max_length=max_length)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=["w1", "w2", "w3"],  # just targetting the MoE layers.
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)

In [None]:
# Attach the LoRA adapters. The guard makes the cell safe to re-run: without it,
# a second run would call `get_peft_model` on a model that is already a PeftModel.
if not isinstance(model, PeftModel):
    model = get_peft_model(model, config)

# The generation KV cache is not needed while training (and conflicts with the
# gradient checkpointing that kbit preparation enables), so turn it off before
# the Trainer is built.
model.config.use_cache = False

trainer = Trainer(
    model=model,
    train_dataset=train_data_prompts,
    args=TrainingArguments(
        per_device_train_batch_size=20,
        gradient_accumulation_steps=4,  # effective batch per device = 20 * 4
        num_train_epochs=6,
        learning_rate=1e-4,
        logging_steps=2,
        optim="adamw_torch",
        save_strategy="epoch",
        # NOTE(review): "shapeskeare" looks like a leftover from an earlier
        # Shakespeare run, while this notebook trains on pirate speech. Rename
        # the directory if these checkpoints should be kept separate.
        output_dir="mixtral-moe-lora-instruct-shapeskeare",
    ),
    # mlm=False -> causal-LM labels derived from input_ids.
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()

In [None]:
# NOTE(review): this cell looks like a leftover from a joke-punchline notebook.
# It reads "/workspace/joke-prep.ipynb" as plain text, but .ipynb files are
# normally JSON, so the blank-line "question\nresponse" split below is unlikely
# to produce real pairs. `prompts` is also not used by any visible cell. Point
# JOKES_PATH at the intended text file, or delete this cell.
JOKES_PATH = "/workspace/joke-prep.ipynb"

# Read the text file
with open(JOKES_PATH, "r", encoding="utf-8") as file:
    data = file.read().strip()

# Chunks are separated by blank lines. Each chunk is split into its lines; the
# first line is the question, the second the response, and extra lines are ignored.
pairs = [pair_text.split("\n") for pair_text in data.split("\n\n")]

# Skip chunks without a response line instead of crashing with IndexError on
# pair[1] (e.g. a stray single line or a run of extra blank lines).
prompts = [
    {"question": pair[0], "response": pair[1]}
    for pair in pairs
    if len(pair) >= 2
]