In [1]:
import gc
import inspect
import logging
import os
from datetime import datetime
from time import time
from typing import Optional

import fire
import huggingface_hub
import torch
import transformers
import wandb
from datasets import load_dataset
from dotenv import load_dotenv
from peft import LoraConfig, PeftModel
from torch.utils.data import SequentialSampler
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTTrainer, SFTConfig

from utils.eval_helper import inspectt, logg
from utils.ft_helper import (
    generate_and_tokenize_prompt,
    get_start_index,
    reorder_dataset,
)
# Opt in to wandb's "core" backend.
# NOTE(review): newer wandb releases may already use it by default, making this
# call redundant — verify against the installed wandb version.
wandb.require("core")


In [2]:
# --- Run configuration -------------------------------------------------------
# Root directory for model/tokenizer caches, checkpoints and saved models.
# NOTE(review): absolute, user-specific path — adjust for your environment.
cache_dir: str = "/dpc/kunf0097/l3-8b"
# Loaded as a local JSON file if the path exists, otherwise as a Hub dataset id.
train_data_path: str = "./data/medical-36-row.json"
# train_data_path: str = "meher146/medical_llama3_instruct_dataset"
model_name: str = "facebook/opt-350m"
# None -> f"{cache_dir}/model/{model_name}-v{run_id}" (filled in by the next cell).
model_save_path: Optional[str] = None
# Timestamp (yymmddHHMMSS) used to namespace checkpoints and the saved model.
# NOTE(review): every process evaluates this independently, so under a
# multi-process (DDP) launch, ranks starting in different seconds get different
# run_ids and hence different checkpoint/save directories — consider passing a
# shared id.
run_id: str = datetime.now().strftime("%y%m%d%H%M%S")
# None -> f"{cache_dir}/chpt/{run_id}" (filled in by the next cell).
chpt_dir: Optional[str] = None
# Set by the next cell when chpt_dir already holds "checkpoint-*" directories.
last_checkpoint: Optional[str] = None
per_device_train_batch_size: int = 4
# Divided by world_size under DDP in the next cell.
gradient_accumulation_steps: int = 4
# Read from the WORLD_SIZE / LOCAL_RANK environment variables in the next cell.
world_size: Optional[int] = None
local_rank: Optional[int] = None

In [3]:
# Derive default output locations from cache_dir when they were not configured.
if model_save_path is None:
    model_save_path = f"{cache_dir}/model/{model_name}-v{run_id}"
if chpt_dir is None:
    chpt_dir = f"{cache_dir}/chpt/{run_id}"

# Resume support: if chpt_dir already holds "checkpoint-<step>" directories,
# point last_checkpoint at the one with the largest <step>. The default chpt_dir
# is keyed on a fresh run_id timestamp, so in practice this matters when chpt_dir
# (or run_id) is set explicitly.
if os.path.isdir(chpt_dir):
    checkpoints = [
        entry
        for entry in os.listdir(chpt_dir)
        if entry.startswith("checkpoint-")
    ]
    if checkpoints:
        last_checkpoint = os.path.join(
            chpt_dir,
            max(checkpoints, key=lambda entry: int(entry.rsplit("-", 1)[-1])),
        )

# train_data_path may be a local JSON file or a Hub dataset id; a local file wins.
if not os.path.exists(train_data_path):
    data = load_dataset(train_data_path, split="train")
else:
    data = load_dataset("json", data_files=train_data_path, split="train")

# Dataset offset to resume from when continuing a checkpoint (0 = fresh run).
start_index = 0
if last_checkpoint is not None:
    start_index = get_start_index(last_checkpoint, len(data))

# Distributed setup: launchers such as torchrun export WORLD_SIZE / LOCAL_RANK;
# both fall back to single-process defaults when unset.
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
ddp = world_size != 1

# Place the whole model on one GPU: GPU 0 normally, this process's GPU under DDP.
device_map = {"": local_rank} if ddp else {"": 0}

if ddp:
    # Split gradient accumulation across ranks so the effective global batch
    # (per-device batch * accumulation * world_size) stays roughly constant.
    # Applied exactly once; clamped to 1 so world_size > accumulation steps
    # cannot produce 0.
    # NOTE: not idempotent — re-running this cell without re-running the config
    # cell divides again.
    gradient_accumulation_steps = max(1, gradient_accumulation_steps // world_size)

# Project helper that inspects the current frame (presumably prints a table of
# the variables configured above — see its captured output).
inspectt(inspect.currentframe())
logger = logging.getLogger(__name__)

# Wall-clock start of the run (presumably used later to report elapsed time).
start = time()

# Hugging Face authentication with a write token from the environment / .env file.
load_dotenv()
HF_TOKEN_WRITE = os.getenv("HF_TOKEN_WRITE")
if HF_TOKEN_WRITE:
    huggingface_hub.login(token=HF_TOKEN_WRITE)
else:
    # login(token=None) falls back to an interactive prompt (widget or terminal),
    # which stalls non-interactive / multi-process runs. Public models and
    # datasets can still be downloaded anonymously.
    logger.warning("HF_TOKEN_WRITE is not set; skipping Hugging Face login.")

torch.cuda.empty_cache()

------------------------  ---------------------------
------------------------  ---------------------------
The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/kunet.ae/ku5001069/.cache/huggingface/token
Login successful


In [None]:
# 4-bit NF4 weight quantization with fp16 compute and nested ("double")
# quantization of the quantization constants.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Weights are cached under <cache_dir>/model; placement follows device_map,
# which pins the model to a single GPU per process.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    cache_dir=f"{cache_dir}/model",
)

In [4]:
# Tokenizer files are cached under <cache_dir>/tokenizer, separate from model weights.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=f"{cache_dir}/tokenizer",
)

In [5]:
from trl import DataCollatorForCompletionOnlyLM

In [16]:
def formatting_prompts_func(example):
    """Render a batch of instruction/input/output records as training strings.

    ``example`` is a batched mapping whose ``'instruction'``, ``'input'`` and
    ``'output'`` entries are index-aligned sequences. Each row becomes one string
    with system, user and assistant turns delimited by ``</s>`` markers; the
    assistant turn begins with ``</s>assistant</s>``, the same marker used as the
    collator's response template.

    Returns:
        list[str]: one formatted string per entry of ``example['instruction']``.
    """
    n_rows = len(example['instruction'])
    return [
        f"</s>system</s>{example['instruction'][row]}</s></s>user</s>{example['input'][row]}</s></s>assistant</s>{example['output'][row]}</s>"
        for row in range(n_rows)
    ]

# The completion-only collator masks every token up to and including this marker,
# so only the assistant's reply contributes to the loss. It must match the marker
# emitted by formatting_prompts_func.
response_template = "</s>assistant</s>"
collator = DataCollatorForCompletionOnlyLM(
    response_template,
    tokenizer=tokenizer,
)

In [None]:
# LoRA adapter configuration passed to SFTTrainer.
# NOTE(review): placeholder hyperparameters — tune r / lora_alpha / lora_dropout
# for the target model. target_modules is left unset so PEFT uses its
# per-architecture default.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Training arguments wired to the notebook config: checkpoints go to chpt_dir,
# and gradient_accumulation_steps is the (DDP-adjusted) value computed earlier.
train_args = SFTConfig(
    output_dir=chpt_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # DataCollatorForCompletionOnlyLM only works on unpacked sequences.
    packing=False,
)

# Completion-only SFT: formatting_func renders each batch to prompt strings, the
# collator masks the prompt part of the labels, and SFTTrainer attaches the LoRA
# adapter described by peft_config to the quantized model.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=collator,
    formatting_func=formatting_prompts_func,
    peft_config=peft_config,
    train_dataset=data,
    args=train_args,
)