In [None]:
%%capture
# Install Unsloth and Weights & Biases; %%capture keeps the pip output out of the notebook.
# NOTE(review): none of these installs pin a version, so a fresh run may pull
# different releases than the ones reported in the outputs below
# (Unsloth 2024.8, Transformers 4.44.0).
!pip install unsloth wandb
# Also get the latest nightly Unsloth!
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

## Model Building

Fine-tune Llama 3.1 8B Instruct.

In [None]:
from unsloth import FastLanguageModel
import torch

# Load the 4-bit Llama 3.1 8B Instruct checkpoint together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    load_in_4bit=True,
    dtype=None,  # no explicit dtype requested; presumably Unsloth picks one — confirm in its docs
    max_seq_length=2048,  # the same limit is passed to SFTTrainer further down
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
==((====))==  Unsloth 2024.8: Fast Llama patching. Transformers = 4.44.0.
   \\   /|    GPU: NVIDIA GeForce RTX 4090. Max memory: 23.643 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.1. CUDA = 8.9. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.26.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


In [None]:
# Wrap the base model with PEFT adapters on the attention (q/k/v/o) and
# MLP (gate/up/down) projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_rslora=False,
    loftq_config=None,
    use_gradient_checkpointing="unsloth",
    random_state=3407,  # same seed value that TrainingArguments uses below
)

Unsloth 2024.8 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


Load our preprocessed data

In [None]:
import json
import random
from datasets import Dataset

# Read the preprocessed conversations. The encoding is stated explicitly so the
# result does not depend on the platform's default text encoding.
with open('conversations.json', 'r', encoding='utf-8') as f:
    conversations_json = json.load(f)

# Shuffle in place with a fixed seed so every run sees the same order (3407 is
# the seed used everywhere else in this notebook). A dedicated Random instance
# is used so the global generator, which the chat loop relies on later, is left
# untouched.
random.Random(3407).shuffle(conversations_json)
data = {'conversations': conversations_json}

# Single-column dataset: one conversation (list of messages) per row.
dataset = Dataset.from_dict(data)
print(dataset)

Dataset({
    features: ['conversations'],
    num_rows: 2358
})


In [None]:
# Show the first conversation (after the in-place shuffle) as a quick sanity check.
conversations_json[0]

In [None]:
from unsloth.chat_templates import get_chat_template

# Attach the Llama 3 chat template to the tokenizer. The mapping says that each
# message keeps its role under "from" (with values "human" / "gpt") and its
# text under "value".
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3",
    mapping={
        "role": "from",
        "content": "value",
        "user": "human",
        "assistant": "gpt",
    },
)

def formatting_prompts_func(examples):
    """Render each conversation of a batch into one chat-formatted string.

    Args:
        examples: batch whose "conversations" entry holds one conversation
            per example.

    Returns:
        A dict with a single "text" key: the list of rendered strings, in the
        same order as the input conversations.
    """
    texts = []
    for convo in examples["conversations"]:
        # tokenize=False yields text rather than token ids; no generation
        # prompt is appended after the last message.
        rendered = tokenizer.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=False,
        )
        texts.append(rendered)
    return {"text": texts}

# Add the rendered "text" column, processing the dataset in batches.
dataset = dataset.map(
    formatting_prompts_func,
    batched=True,
)

Map:   0%|          | 0/2358 [00:00<?, ? examples/s]

In [None]:
# Hold out 0.5% of the examples for evaluation. The fixed seed keeps the split
# identical across runs (3407 is the same value the training arguments use).
dataset_dict = dataset.train_test_split(test_size=0.005, seed=3407)

In [None]:
import wandb
# Authenticate with Weights & Biases, then open the run that the trainer
# reports to.
wandb.login()
wandb.init(
    project="Chatbot",
    name="FishAI",
)

Set hyperparameters and train

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

# Use bf16 when the hardware supports it, otherwise fall back to fp16.
use_bf16 = is_bfloat16_supported()

training_args = TrainingArguments(
    output_dir="outputs",
    seed=3407,
    # Schedule: one epoch, 4 examples per device step, gradients accumulated
    # over 4 steps.
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # Optimizer and learning-rate schedule.
    optim="adamw_8bit",
    learning_rate=1e-4,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=5,
    # Mixed precision: exactly one of the two flags is enabled.
    bf16=use_bf16,
    fp16=not use_bf16,
    # Evaluation and logging.
    # NOTE(review): eval_steps is not set; per the transformers docs it then
    # falls back to logging_steps, i.e. an evaluation pass after every step —
    # confirm that this is intended.
    eval_strategy="steps",
    logging_strategy="steps",
    logging_steps=1,
    report_to="wandb",
)

# Supervised fine-tuning on the rendered "text" column, one example per
# sequence (packing disabled).
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["test"],
    dataset_text_field="text",
    dataset_num_proc=2,
    max_seq_length=2048,
    packing=False,
)

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Run the fine-tuning and keep whatever train() returns for later inspection.
trainer_stats = trainer.train()
# Close the W&B run once training has returned.
wandb.finish()

Training is done; now play around with the chatbot.

In [None]:
import random

def get_last_message(output):
    """Extract the text of the final message from a decoded chat transcript.

    The transcript is expected to be in Llama 3 chat format, where every
    message is preceded by a role header ending in ``<|end_header_id|>`` and
    terminated by ``<|eot_id|>``. Everything after the last header is returned,
    so replies that contain blank lines are kept whole instead of being cut
    down to their last paragraph.

    Args:
        output: decoded text, special tokens included.

    Returns:
        The last message with ``<|eot_id|>`` markers and surrounding
        whitespace removed. If no role header is present the text after the
        last blank line is used instead; ``None`` is returned when the text
        contains neither.
    """
    header_end = '<|end_header_id|>'
    if header_end in output:
        last = output.rsplit(header_end, 1)[-1]
    else:
        # No role header: fall back to splitting on the last blank line.
        parts = output.rsplit('\n\n', 1)
        if len(parts) < 2:
            return None
        last = parts[-1]
    # Drop the end-of-turn marker first so whitespace next to it is stripped too.
    return last.replace('<|eot_id|>', '').strip()

import textwrap

def print_wrapped(text):
    """Print ``text`` re-flowed to a maximum line width of 80 characters."""
    print(textwrap.fill(text, width=80))

In [None]:
from unsloth.chat_templates import get_chat_template

# Apply the Llama 3 chat template again with the same field mapping as during
# training (role under "from", text under "value").
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3",
    mapping={
        "role": "from",
        "content": "value",
        "user": "human",
        "assistant": "gpt",
    },
)

# Switch the model over to Unsloth's inference mode before generating.
FastLanguageModel.for_inference(model)

In [None]:
# Conversation history, in the same "from"/"value" message format as the
# training data.
messages = []

for step in range(1):
    # The first turn always asks the user for input; on later turns that
    # happens only half of the time, otherwise the model speaks again.
    if step == 0 or random.random() < 0.5:
        user_text = input(">> User: ")
        messages.append({"from": "human", "value": user_text})

    # Render the history plus an open assistant turn, as token ids on the GPU.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    ).to("cuda")

    outputs = model.generate(
        input_ids=inputs,
        max_new_tokens=512,
        temperature=1.0,
        use_cache=True,
    )

    # The decoded text holds the whole transcript; keep only the newest message.
    decoded = tokenizer.batch_decode(outputs)
    response = get_last_message(decoded[0])
    messages.append({"from": "gpt", "value": response})

    # Cap the history at the ten most recent messages.
    messages = messages[-10:]
    print_wrapped(response)

>> User:  fishy


Fish: I see what you're doing


In [None]:
import os

# model.save_pretrained("pwo_model") # Local saving
# Online saving. The access token is read from the HF_TOKEN environment
# variable so that no credential is ever written into the notebook; when the
# variable is unset, None is passed and the hub client is expected to fall back
# to a previously stored login.
model.push_to_hub("ThePwo/FishAI", token = os.environ.get("HF_TOKEN"))

README.md:   0%|          | 0.00/605 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/168M [00:00<?, ?B/s]

Saved model to https://huggingface.co/ThePwo/FishAI
