In [90]:
# install the required packages
!pip install transformers datasets torch wandb huggingface_hub peft trl bitsandbytes accelerate

  pid, fd = os.forkpty()
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [42]:
# import the required packages

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model 
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import huggingface_hub
import wandb
import torch
import json

In [9]:
# define the config parameters

# load the config file
# with open("../config/config.json") as f:
#     config = json.load(f)

config = {
  "hugginface_token": "",
  "wandb_key": "",
  "model_path": "google/gemma-2-2b",
  "save_model_name": "pretrain_open_source",
  "use_lora": True,
  "lora_r": 16,
  "lora_alpha": 32,
  "lr": 3e-5,
  "epoch": 3,
  "batch_size": 8,
  "max_seq_len": 512,
  "checkpoint_path": "../checkpoints",
  "OpenSource_data_path": "FiscalNote/billsum", # check OpenSource data text file
  "OpenSource_version": "",
  "Youtube_data_path": ""                       # "ht324/WhiteBoard_LLM_Data"
}

hugginface_token = config["hugginface_token"]
wandb_key = config["wandb_key"]

model_path = config["model_path"]
save_model_name = config["save_model_name"]

use_lora = config["use_lora"]
lora_r = config["lora_r"]
lora_alpha = config["lora_alpha"]

lr = config["lr"]
epoch = config["epoch"]
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]

# checkpoint_path = config["checkpoint_path"]
OpenSource_data_path = config["OpenSource_data_path"]
OpenSource_version = config["OpenSource_version"]
Youtube_data_path = config["Youtube_data_path"]

In [None]:
# login to huggingface and wandb

huggingface_hub.login(token=hugginface_token)
if wandb_key:
    wandb.login(key=wandb_key)
    wandb.init(
        project="WhiteBoard_LLM",
        config={
            "model_name": save_model_name,
            "lr": lr,
            "epoch": epoch,
            "batch_size": batch_size,
            "max_seq_len": max_seq_len,
            "use_lora": use_lora,
            "lora_r": lora_r,
            "lora_alpha": lora_alpha,
        },
        name=save_model_name
    )

In [13]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:  57%|#####7    | 2.84G/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

In [14]:
if use_lora:
    lora_config = LoraConfig(
        r = lora_r,
        lora_alpha = lora_alpha,
        target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
        task_type="C",
    )
    model = get_peft_model(model, lora_config)

In [69]:
# load the dataset

# TODO Check the dataset structure
dataset = None
if OpenSource_data_path:
    if OpenSource_version:
        open_source_data = load_dataset(OpenSource_data_path, OpenSource_version)
    else:
        open_source_data = load_dataset(OpenSource_data_path)
    dataset = open_source_data

if Youtube_data_path:
    youtube_data = load_dataset(Youtube_data_path)
    if dataset:
        dataset = dataset.concatenate(youtube_data)
    else:
        dataset = youtube_data

if dataset is None:
    raise ValueError("No data provided")

In [84]:
train_data = dataset["train"]
train_data = train_data.map(
    lambda example: {"text": example["text"], "summary": example["summary"]},
    remove_columns=train_data.column_names,
)
train_data[0]


Map:   0%|          | 0/18949 [00:00<?, ? examples/s]

{'text': "SECTION 1. LIABILITY OF BUSINESS ENTITIES PROVIDING USE OF FACILITIES \n              TO NONPROFIT ORGANIZATIONS.\n\n    (a) Definitions.--In this section:\n            (1) Business entity.--The term ``business entity'' means a \n        firm, corporation, association, partnership, consortium, joint \n        venture, or other form of enterprise.\n            (2) Facility.--The term ``facility'' means any real \n        property, including any building, improvement, or appurtenance.\n            (3) Gross negligence.--The term ``gross negligence'' means \n        voluntary and conscious conduct by a person with knowledge (at \n        the time of the conduct) that the conduct is likely to be \n        harmful to the health or well-being of another person.\n            (4) Intentional misconduct.--The term ``intentional \n        misconduct'' means conduct by a person with knowledge (at the \n        time of the conduct) that the conduct is harmful to the health \n        or w

In [94]:
def generate_prompt(examples):
    """
    Generate prompts for a batch of examples using 'text' and 'summary'.
    """
    return [
        f"<bos><start_of_turn>user\n{text}<end_of_turn>\n<start_of_turn>model\n{summary}<end_of_turn><eos>"
        for text, summary in zip(examples["text"], examples["summary"])
    ]
# print(generate_prompt(train_data[:1]))
# train_data = train_data.map(preprocess_function, batched=True, remove_columns=["text", "summary"])

In [95]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    max_seq_length=max_seq_len,
    args=TrainingArguments(
        output_dir="./checkpoints",
        num_train_epochs=epoch,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,
        learning_rate=lr,
        report_to="wandb",
        fp16=True,
        push_to_hub=True,
        hub_model_id=save_model_name,
        logging_steps=20,
        optim="paged_adamw_8bit", # other optimizer
        run_name=save_model_name,
    ),
    peft_config=lora_config,
    formatting_func=generate_prompt,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Map:   0%|          | 0/18949 [00:00<?, ? examples/s]

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [96]:
trainer.train()

RuntimeError: chunk expects at least a 1-dimensional tensor

In [None]:
# TODO model save