# Tokenization

# Packages

The dataset-loading call might need to be revised if this notebook becomes part of the script that is uploaded to the cluster.
--> It depends on where the data comes from; we might be able to upload it somewhere. If it is uploaded to the HF Hub, for example, we can't use `load_from_disk` and have to use `load_dataset` instead.

In [None]:
from datasets import load_from_disk, load_dataset
from transformers import AutoTokenizer, AutoModel
import torch
from trl import SFTConfig, SFTTrainer
from peft import PeftModel, PeftConfig, LoraConfig
from huggingface_hub import login, upload_folder

import os

# Authenticate against the Hugging Face Hub.
# The access token is read from the HF_TOKEN environment variable so that it
# never has to be pasted into (and accidentally committed with) the notebook.
# NOTE(review): when HF_TOKEN is unset, `token` is None and `login` is expected
# to fall back to its interactive prompt -- confirm for the target environment.

# local login
#login(token=os.environ.get("HF_TOKEN"), add_to_git_credential=True)

# colab login
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Use the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# full dataset local
#raw_datasets = load_from_disk("../data/mbti_dict_ver2")

# Colab: fetch the dataset from the Hugging Face Hub.
raw_datasets = load_dataset("DrinkIcedT/mbti")

# Keep a direct handle on each of the three splits.
raw_train_dataset, raw_validation_dataset, raw_test_dataset = (
    raw_datasets[split_name] for split_name in ("train", "validation", "test")
)


# for running a local test with a reduced dataset size (keeps every 25th row)
#raw_datasets = raw_datasets.filter(lambda _, indices: indices % 25 == 0, with_indices=True)
#raw_datasets.save_to_disk("../data/smol")

# small dataset local
#raw_datasets = load_from_disk("../data/smol")



# small datasets split
# raw_train_dataset = raw_datasets["train"]
# raw_validation_dataset = raw_datasets["validation"]
# raw_test_dataset = raw_datasets["test"]

# Checkpoints

In [None]:
# Qwen is currently disabled: the name is kept (as None) because the later
# two-argument convert_to_chatml still compares its `checkpoint` against it.
qwen_checkpoint = None
# qwen_checkpoint = "Qwen/Qwen2.5-0.5B-Instruct"
# qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_checkpoint)


# Active model: the Llama checkpoint id and the tokenizer loaded for it.
llama_checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_checkpoint)


# Chat-Templates and Tokenization

In [None]:
# Sanity check: tokenize the first training post and inspect the result.
test_sentence = raw_train_dataset[0]["post"]
print(test_sentence)

# Raw tokenizer output for that single post.
tokenized_sentence = llama_tokenizer(test_sentence)
print(tokenized_sentence)

# Map the ids back to token strings to see how the text was split.
token_ids = tokenized_sentence["input_ids"]
print(llama_tokenizer.convert_ids_to_tokens(token_ids))

In [None]:

def convert_to_chatml(data):
    """Wrap one dataset row in a two-turn chat ``messages`` record.

    Args:
        data: Mapping with a ``"label"`` entry (interpolated into the user
            prompt exactly as stored) and a ``"post"`` entry (used verbatim
            as the assistant reply).

    Returns:
        A dict with a single ``"messages"`` key holding the user turn
        followed by the assistant turn.
    """
    # Single quotes inside the f-string: reusing the outer double quotes is a
    # SyntaxError on Python versions before 3.12.
    prompt = f"Your personality Type is {data['label']}. What is on your mind?"
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": data["post"]},
        ]
    }

In [None]:

# Convert every row to the chat "messages" layout and preview the first
# training example.
ds = raw_datasets.map(convert_to_chatml)
messages = ds["train"][0]["messages"]
print(messages)

# Render the conversation with the tokenizer's chat template as plain text
# (fixed: the method is `apply_chat_template`, not `applychat_template`).
text = llama_tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# Tokenize the rendered text into PyTorch tensors (fixed: `return_tensors`
# belongs inside the tokenizer call, not after its closing parenthesis).
# NOTE(review): the rendered template may already contain special tokens, in
# which case tokenizing it again could duplicate them -- consider
# `add_special_tokens=False` after confirming against the tokenizer.
model_inputs = llama_tokenizer([text], return_tensors="pt")

In [None]:
# PEFT / LoRA configuration

# Hyperparameters are kept as named values so they are easy to tweak.
rank_dimension = 8
lora_alpha = 16
lora_dropout = 0.05

# Adapter settings handed to the trainer via `peft_config`.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    bias="none",
    r=rank_dimension,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
)

In [None]:

def convert_to_chatml(example, checkpoint):
    """Turn one dataset row into a chat ``messages`` record for a checkpoint.

    The value in ``example["label"]`` is converted to its string name through
    the ``int2str`` method of the ``label`` feature of
    ``raw_datasets["train"]``. The prompt layout depends on ``checkpoint``:
    the Qwen layout puts the personality type into the user turn, the Llama
    layout puts it into a system turn.

    Args:
        example: Dataset row providing ``"label"`` and ``"post"``.
        checkpoint: Checkpoint identifier, compared against the module-level
            ``qwen_checkpoint`` and ``llama_checkpoint``.

    Returns:
        A dict with a single ``"messages"`` key holding the chat turns; the
        post is always the final assistant turn.

    Raises:
        ValueError: If ``checkpoint`` matches neither known checkpoint.
            Previously this case fell through and returned ``None``, so no
            ``messages`` entry was produced for the row.
    """
    mbti_type = raw_datasets["train"].features["label"].int2str(example["label"])
    assistant_turn = {"role": "assistant", "content": example["post"]}

    if checkpoint == qwen_checkpoint:
        prompt = f"Your personality Type is {mbti_type}. What is on your mind?"
        return {
            "messages": [
                {"role": "user", "content": prompt},
                assistant_turn,
            ]
        }

    if checkpoint == llama_checkpoint:
        return {
            "messages": [
                {"role": "system", "content": f"You are a person with personality type {mbti_type} who responds accordingly!"},
                {"role": "user", "content": "What is on your mind?"},
                assistant_turn,
            ]
        }

    raise ValueError(
        f"Unsupported checkpoint {checkpoint!r}; "
        f"expected {qwen_checkpoint!r} or {llama_checkpoint!r}."
    )

# Build the chat-formatted dataset using the Llama prompt layout; the
# checkpoint is forwarded to convert_to_chatml as a keyword argument.
chatml_kwargs = {"checkpoint": llama_checkpoint}
ds = raw_datasets.map(convert_to_chatml, fn_kwargs=chatml_kwargs)

# print(ds["train"][0]["messages"])

# messages = ds["train"][0]["messages"]

# text = qwen_tokenizer.apply_chat_template(
#     messages,
#     tokenize = False,
#     add_generation_prompt = True
# )

#model_inputs = qwen_tokenizer([text], return_tensors = "pt")

# Short smoke-test training configuration.
training_args = SFTConfig(
    output_dir="./mbti_test_output",
    # Run length and sequence budget.
    max_steps=100,
    max_length=512,  # or 256
    # 2 sequences per device x 4 accumulation steps per optimizer update.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    # Log every 10 steps, evaluate every 50, save at step 100 (= max_steps).
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    save_steps=100,
    fp16=True,
)

# Supervised fine-tuning with LoRA adapters on the chat-formatted dataset.
# `dataset_text_field` is intentionally not passed: in TRL releases that
# accept `processing_class` it is an SFTConfig option rather than an
# SFTTrainer argument, and the dataset here is conversational (a "messages"
# column) rather than a single text column.
# NOTE(review): confirm against the installed TRL version.
trainer = SFTTrainer(
    model=llama_checkpoint,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    processing_class=llama_tokenizer,
    peft_config=lora_config,
)

trainer.train()

### Watch out for:
Watch for these warning signs during training:

- Validation loss increasing while training loss decreases (overfitting)
- No significant improvement in loss values (underfitting)
- Extremely low loss values (potential memorization)
- Inconsistent output formatting (template learning issues)