In [1]:
import os
import sys
import numpy as np
import torch
from torch import nn
from transformers import pytorch_utils as torch_utils
from peft import LoraConfig

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
device

'mps'

In [3]:
import importlib
import src.train
import src.model

importlib.reload(src.train)
importlib.reload(src.model)

from src.train import sft_train_lora
from src.model import identify_target_modules
from data.gsm8k import GSM8K

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

In [5]:
%pip install ipywidgets

Note: you may need to restart the kernel to use updated packages.


In [6]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [7]:
# dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")


model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

dataset = GSM8K(tokenizer=tokenizer)

dataset = Dataset.from_dict({"input_text" : [example["input_text"] for example in dataset]})

In [8]:
dataset

Dataset({
    features: ['input_text'],
    num_rows: 7473
})

In [9]:
dataset[0]

{'input_text': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? ### Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}

In [10]:
len(dataset)

7473

In [11]:
target_modules = identify_target_modules(model, name_segment='self_attn')
target_modules

['model.decoder.layers.0.self_attn.k_proj',
 'model.decoder.layers.0.self_attn.v_proj',
 'model.decoder.layers.0.self_attn.q_proj',
 'model.decoder.layers.0.self_attn.out_proj',
 'model.decoder.layers.1.self_attn.k_proj',
 'model.decoder.layers.1.self_attn.v_proj',
 'model.decoder.layers.1.self_attn.q_proj',
 'model.decoder.layers.1.self_attn.out_proj',
 'model.decoder.layers.2.self_attn.k_proj',
 'model.decoder.layers.2.self_attn.v_proj',
 'model.decoder.layers.2.self_attn.q_proj',
 'model.decoder.layers.2.self_attn.out_proj',
 'model.decoder.layers.3.self_attn.k_proj',
 'model.decoder.layers.3.self_attn.v_proj',
 'model.decoder.layers.3.self_attn.q_proj',
 'model.decoder.layers.3.self_attn.out_proj',
 'model.decoder.layers.4.self_attn.k_proj',
 'model.decoder.layers.4.self_attn.v_proj',
 'model.decoder.layers.4.self_attn.q_proj',
 'model.decoder.layers.4.self_attn.out_proj',
 'model.decoder.layers.5.self_attn.k_proj',
 'model.decoder.layers.5.self_attn.v_proj',
 'model.decoder.layers

In [12]:
lora_config = LoraConfig(
    target_modules=target_modules,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

In [13]:
sft_train_lora(
    base_model=model,
    train_dataset=dataset,
    eval_dataset=dataset,
    tokenizer=AutoTokenizer.from_pretrained("facebook/opt-350m"),
    adapter_name="sft_lora",
    response_template=" ### Answer:",
    lora_config=lora_config,
)

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]



  0%|          | 0/2805 [00:00<?, ?it/s]


KeyboardInterrupt

