In [None]:
%%capture
# Environment setup; %%capture keeps pip's output out of the notebook.
# Step 1: install the released Unsloth package together with a pinned xformers build.
!pip install unsloth "xformers==0.0.28.post2"
# Also get the latest nightly Unsloth!
# Step 2: remove the release installed above and reinstall Unsloth from the GitHub repository.
# NOTE(review): the git URL carries no tag or commit, so re-running this cell later may
# install a different Unsloth revision — pin a ref if reproducibility matters.
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [None]:
# Future statements must be the first statement of the cell, otherwise Python
# raises a SyntaxError and none of the imports below are executed.
from __future__ import print_function, division

# Standard library
import argparse
import json
import os
import random
import re
import sys
import textwrap
from time import time
from typing import List, Dict, Any, Optional, Callable, Tuple

# Unsloth goes ahead of transformers / trl / peft, as Unsloth recommends,
# so that its patches are in place before those libraries are loaded.
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
from unsloth.chat_templates import get_chat_template

# Third-party
import dotenv
import requests
import torch
from datasets import Dataset, load_dataset
from openai import OpenAI
from peft import LoraConfig, get_peft_model
from pydantic import BaseModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from trl import SFTTrainer

In [None]:
# Model loading

# Settings reused by later cells (the trainer also reads max_seq_length).
max_seq_length = 2048  # Any length can be chosen; RoPE scaling is supported internally.
dtype = None           # None = auto detection. Float16 for Tesla T4 / V100, Bfloat16 for Ampere+.
load_in_4bit = True    # 4-bit quantization reduces memory usage; can be set to False.

# Base model and its tokenizer.
# Gated models such as meta-llama/Llama-2-7b-hf additionally need token = "hf_...".
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-1B-Instruct",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

# Names of the projection modules that receive LoRA adapters.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Wrap the base model with LoRA adapters.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,                                  # LoRA rank; any number > 0 (suggested 8, 16, 32, 64, 128)
    target_modules = lora_target_modules,
    lora_alpha = 16,
    lora_dropout = 0,                        # any value is supported, 0 is the optimized setting
    bias = "none",                           # any value is supported, "none" is the optimized setting
    use_gradient_checkpointing = "unsloth",  # True or "unsloth"; "unsloth" uses 30% less VRAM (very long context)
    random_state = 3407,
    use_rslora = False,                      # rank-stabilized LoRA is available but switched off
    loftq_config = None,                     # LoftQ is available but not used
)

# Attach the "llama-3.1" chat template to the tokenizer.
tokenizer = get_chat_template(tokenizer, chat_template = "llama-3.1")

In [None]:
# Task instructions for the Webshop agent. format_chat_template concatenates this
# string with each example's 'instruction' field to form the user message.
# The literal begins and ends with a newline, so the instruction text starts on its
# own line directly after the last tip.
# NOTE(review): "fullfil" is misspelled and the tip ending in "Quantity requirements
# can be met " looks truncated. The text is left untouched here because it is part of
# what the model is trained on; confirm the intended wording and keep it identical
# wherever the same prompt is used for inference.
prompt = """
You are a shopping agent shopping in Webshop.

The actions available to you are :
- reset
- think[Thought]
- search[Search query]
- click[Button to click]

Rules:
- You can reset from any page, think on any page.
- You can only click buttons available on the page described in the observation. Buttons are defined between square brackets - []
- You can only search from a page with [Search], so click on the back buttons to reach such a page before you search again.
- You can ONLY reply with the action you want to take.
- You must end after a few tries by attempting to buy something.

Tips:
- Carefully surf the webshop to fullfil requirements. 
- If any items match some of the requirements, click on them to see a detailed description and to see if they match all the requirements.  Quantity requirements can be met 
- Don't just give up on a search at the 1st page of results. Move through the result pages by pressing the [Next >] button. You may decide to give up at a reasonable point such as when the results are empty or too different from the requirements (usually 2-3 pages).
"""

In [None]:
def format_chat_template(row: dict) -> dict:
    """Return a copy of ``row`` with a chat-templated ``"text"`` field added.

    The example is turned into a three-message conversation:
    a fixed system message, a user message made of the module-level ``prompt``
    followed by ``row['instruction']``, and an assistant message holding
    ``row['trajectory']``. The conversation is rendered with the module-level
    ``tokenizer``'s chat template (``tokenize=False``) and stored under ``"text"``.

    Args:
        row: Mapping with at least the keys ``'instruction'`` and ``'trajectory'``.

    Returns:
        A new dict containing every item of ``row`` plus ``"text"``. The input
        mapping is not modified, so re-running the cell never alters the source data.

    Raises:
        KeyError: If ``'instruction'`` or ``'trajectory'`` is missing from ``row``.
    """
    # The user turn is the shared task prompt followed by the example's requirement.
    user_message = prompt + row['instruction']

    conversation = [
        {"role": "system", "content": 'You are a shopping agent for Webshop, a text-based e-commerce platform.'},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": row['trajectory']},
    ]

    # Copy first: writing into `row` directly would mutate the caller's data.
    formatted = dict(row)
    formatted["text"] = tokenizer.apply_chat_template(conversation, tokenize=False)
    return formatted

# Render every example through format_chat_template and wrap each split in a Dataset.
# NOTE(review): train_data and test_data are not defined in any cell visible here —
# make sure a cell that loads them runs before this one.
processed_data = [format_chat_template(example) for example in train_data]
train_dataset = Dataset.from_list(processed_data)

# processed_data is reused for the test split, so after this cell it holds the test rows.
processed_data = [format_chat_template(example) for example in test_data]
test_dataset = Dataset.from_list(processed_data)

In [None]:
# TODO: add a function to evaluate the model.

In [None]:
# Supervised fine-tuning setup: trains the LoRA-wrapped model on the chat-templated
# "text" column of train_dataset and evaluates on test_dataset every eval_steps steps.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    # eval_strategy = "steps" below needs a dataset to evaluate on; the held-out split
    # built alongside train_dataset was previously never handed to the trainer.
    eval_dataset = test_dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 4, # 4 * 4 = 16 sequences per optimizer step per device
        # warmup_steps = 5,
        num_train_epochs = 5,
        eval_strategy = "steps",
        eval_steps = 100,
        # max_steps = 60,
        learning_rate = 1e-5,
        # Mixed precision: bf16 where the GPU supports it, fp16 otherwise.
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 10,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
    ),
)