In [1]:
# Use the %pip magic (not !pip) so the package is installed into the environment
# of the running kernel; pinned to the version this notebook was last run with.
%pip install peft==0.18.1

Collecting peft
  Downloading peft-0.18.1-py3-none-any.whl.metadata (14 kB)
Downloading peft-0.18.1-py3-none-any.whl (556 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m557.0/557.0 kB[0m [31m4.8 MB/s[0m  [33m0:00:00[0m36m-:--:--[0m
[?25hInstalling collected packages: peft
Successfully installed peft-0.18.1


In [2]:
# Use the %pip magic (not !pip) so the package is installed into the environment
# of the running kernel; pinned to the version this notebook was last run with.
%pip install trl==0.27.1

Collecting trl
  Downloading trl-0.27.1-py3-none-any.whl.metadata (11 kB)
Downloading trl-0.27.1-py3-none-any.whl (532 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m532.9/532.9 kB[0m [31m5.4 MB/s[0m  [33m0:00:00[0m
[?25hInstalling collected packages: trl
Successfully installed trl-0.27.1


In [18]:
import json

def force_string_content(input_path, output_path):
    """Rewrite a chat-format JSONL file so message payloads are plain strings.

    Every non-blank line of ``input_path`` is parsed as a JSON object holding
    a ``"messages"`` list. For each message:

    * a ``content`` value that is a dict or list is replaced by its JSON
      serialization (``json.dumps``);
    * a ``tool_call_id`` that is present and not null is converted with
      ``str``; a null id is left as null instead of becoming ``"None"``.

    The converted records are written to ``output_path`` as one JSON object
    per line, newline-separated.

    Args:
        input_path: Path of the JSONL file to read (UTF-8).
        output_path: Path of the JSONL file to write (UTF-8). It may be the
            same as ``input_path``: the input is read completely and closed
            before the output is opened.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"messages"`` key.
    """
    fixed_lines = []
    # Explicit encoding so the result does not depend on the platform locale.
    with open(input_path, "r", encoding="utf-8") as src:
        for line in src:
            # Tolerate blank lines (e.g. a trailing newline at end of file),
            # which would otherwise make json.loads raise.
            if not line.strip():
                continue
            record = json.loads(line)
            for msg in record["messages"]:
                # Structured content (dict/list) becomes a JSON string.
                if isinstance(msg.get("content"), (dict, list)):
                    msg["content"] = json.dumps(msg["content"])
                # Stringify real ids only; str(None) would yield "None".
                if msg.get("tool_call_id") is not None:
                    msg["tool_call_id"] = str(msg["tool_call_id"])
            fixed_lines.append(json.dumps(record))

    with open(output_path, "w", encoding="utf-8") as dst:
        dst.write("\n".join(fixed_lines))
# `os` is imported here because this cell runs before the cell that imports it;
# without this the cell fails with NameError on a fresh kernel.
import os

# Convert every split into a parallel "stringified" directory.
SRC_DIR = "trial_2/llama_format"
DST_DIR = "trial_2/llama_format_stringified"
os.makedirs(DST_DIR, exist_ok=True)
for split_name in ("train", "validation", "test"):
    force_string_content(f"{SRC_DIR}/{split_name}.jsonl", f"{DST_DIR}/{split_name}.jsonl")

In [25]:
import torch
import os
import json
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from trl import SFTConfig, SFTTrainer

# 1. Pick the compute device: the PyTorch MPS backend when available, else CPU
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# Identifier passed to from_pretrained for both the tokenizer and the model
model_id = "meta-llama/Llama-3.2-3B-Instruct"

# 2. Load Tokenizer & Apply Multi-Tool Jinja Template
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Updated template: Loops through multiple tool calls & removes the strict limit.
# This assignment replaces whatever chat template the tokenizer shipped with.
# What the template renders, in order:
#   * one system block, filled with messages[0].content when the first message
#     has role 'system', otherwise with an empty string;
#   * for each message: a user block (role == 'user'); or an assistant block
#     containing one {"name": ..., "parameters": ...} object per entry of
#     message.tool_calls, newline-separated (any message that has a
#     'tool_calls' key); or an assistant block with message.content
#     (role == 'assistant');
#   * an open assistant header when add_generation_prompt is set.
# Messages matching none of the three branches are not rendered at all.
# NOTE(review): that would include tool-result messages (those carrying a
#   tool_call_id), if the dataset has them — confirm dropping them is intended.
# NOTE(review): the indentation inside this literal sits before the header
#   tokens and is not removed by the `{%-` markers; it may end up in the
#   rendered prompt. Verify by rendering one sample with
#   tokenizer.apply_chat_template(..., tokenize=False).
# NOTE(review): `arguments | tojson` presumably expects arguments to be a dict;
#   if the data stores them as a JSON string they would be quoted twice — confirm.
tokenizer.chat_template = """
{%- set system_message = messages[0]['content'] if messages[0]['role'] == 'system' else '' -%}
<|start_header_id|>system<|end_header_id|>\n\n{{ system_message }}<|eot_id|>
{%- for message in messages %}
    {%- if message.role == 'user' %}
        <|start_header_id|>user<|end_header_id|>\n\n{{ message.content }}<|eot_id|>
    {%- elif 'tool_calls' in message %}
        <|start_header_id|>assistant<|end_header_id|>\n\n
        {%- for tool_call in message.tool_calls %}
            {"name": "{{ tool_call.function.name }}", "parameters": {{ tool_call.function.arguments | tojson }}}
            {%- if not loop.last %}\n{% endif %}
        {%- endfor %}<|eot_id|>
    {%- elif message.role == 'assistant' %}
        <|start_header_id|>assistant<|end_header_id|>\n\n{{ message.content }}<|eot_id|>
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}<|start_header_id|>assistant<|end_header_id|>\n\n{% endif %}
"""

# 3. Load the stringified training file and hold out part of it for validation
# (adjust the path if your sanitized file lives elsewhere)
raw_dataset = load_dataset(
    "json",
    data_files="trial_2/llama_format_stringified/train.jsonl",
    split="train",
)

# Single 90/10 split; the fixed seed keeps the partition reproducible.
# NOTE(review): only train.jsonl is read here — confirm that the separate
# validation.jsonl file is meant to stay unused for evaluation.
train_test_valid = raw_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = train_test_valid["train"], train_test_valid["test"]

print(
    "Final Counts:\n"
    f"  - Training samples:   {len(train_dataset)}\n"
    f"  - Validation samples: {len(eval_dataset)}"
)

Using device: mps
Final Counts:
  - Training samples:   965
  - Validation samples: 108


In [None]:
# 4. Load the base model only (no adapter attached in this cell)
model_load_kwargs = {
    "torch_dtype": torch.bfloat16,   # load weights in bfloat16
    "device_map": {"": device},      # single entry: everything goes to `device`
}
model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_kwargs)

In [28]:
# 5. LoRA settings. Only the config object is built here; the model is not
# wrapped manually — the config is handed to the trainer instead.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Attention projections that receive LoRA adapters
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    # Adapter rank, scaling and regularisation
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# 6. Training arguments for supervised fine-tuning
# NOTE(review): no maximum sequence length is set here, so the library default
# applies — confirm it is long enough for the longest conversations.
sft_config = SFTConfig(
    # Where checkpoints are written
    output_dir="./llama3.2-finetuned-mps",
    # Optimisation schedule: batch of 1 with 8 accumulation steps
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    optim="adamw_torch",
    bf16=True,
    # Evaluate and checkpoint once per epoch, then reload the best checkpoint
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Logging
    logging_steps=10,
    report_to="none",
    # Dataset handling
    dataset_text_field="text",
    packing=False,
)

# 7. Build the SFTTrainer from the raw base model plus the LoRA config;
# the trainer takes care of wrapping the model with the adapter.
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    peft_config=peft_config,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# 8. Start Training
# Report the device that was actually selected (it can fall back to CPU)
# instead of a hard-coded "MPS".
print(f"Starting training on {device.upper()}...")
trainer.train()

# 9. Final Save: model output and tokenizer go to the same directory,
# named once so the three uses cannot drift apart.
adapter_dir = "./llama3.2-final-adapter"
trainer.save_model(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
print(f"All set! Your adapter is saved in {adapter_dir}")

Tokenizing train dataset: 100%|██████████| 965/965 [00:00<00:00, 1189.94 examples/s]
Truncating train dataset: 100%|██████████| 965/965 [00:00<00:00, 108649.06 examples/s]
Tokenizing eval dataset: 100%|██████████| 108/108 [00:00<00:00, 1423.79 examples/s]
Truncating eval dataset: 100%|██████████| 108/108 [00:00<00:00, 52869.38 examples/s]
The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pad_token_id': 128009}.


Starting training on MPS...




KeyboardInterrupt: 