In [None]:
import os

from kaggle_secrets import UserSecretsClient
import wandb

# Read credentials from Kaggle's secret store so they are never written into the notebook.
user_secrets = UserSecretsClient()
secret_hf = user_secrets.get_secret("HUGGINGFACE_API")
secret_wandb = user_secrets.get_secret("wandb")

# Authenticate to the Hugging Face Hub through the HF_TOKEN environment variable rather
# than `!huggingface-cli login --token $secret_hf`. Interpolating the token into a shell
# command exposes it on the process command line, and `!` silently ignores a failed login.
os.environ["HF_TOKEN"] = secret_hf

wandb.login(key=secret_wandb)

In [2]:
import os
import random

from copy import deepcopy
from random import randrange
from functools import partial

import torch
import accelerate
import bitsandbytes as bnb

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from transformers.integrations import WandbCallback
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model,
    PeftModel
)
from trl import SFTTrainer

In [None]:
model_name = "/kaggle/input/mistral/pytorch/7b-v0.1-hf/1"

# Quantize weights to 4-bit NF4 with double (nested) quantization; compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # let the loader decide device placement
    quantization_config=bnb_config,
)
# The key/value cache only speeds up generation; switch it off for training.
model.config.use_cache = False

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pad with the EOS token (presumably no dedicated pad token is defined — verify).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
def find_all_linear_names(model, linear_cls=None):
    """Collect the short names of all linear layers of a given class in ``model``.

    Intended to produce ``target_modules`` for a LoRA config.

    Args:
      model: A ``torch.nn.Module`` to scan via ``named_modules()``.
      linear_cls: Module class (or tuple of classes) to match. Defaults to
        ``bnb.nn.Linear4bit`` (4-bit quantized linear layers). Pass e.g.
        ``bnb.nn.Linear8bitLt`` for 8-bit models or ``torch.nn.Linear`` for
        unquantized ones.

    Returns:
      A sorted list of unique leaf names (the last dotted component of each
      matching module's path), with ``"lm_head"`` excluded.
    """
    if linear_cls is None:
        linear_cls = bnb.nn.Linear4bit

    # The last path component is what LoRA's target_modules matches on.
    lora_module_names = {
        name.split('.')[-1]
        for name, module in model.named_modules()
        if isinstance(module, linear_cls)
    }

    # The output projection (lm_head) is usually kept out of the LoRA targets.
    lora_module_names.discard('lm_head')
    # Sorted for a deterministic order (set iteration order is not stable across runs).
    return sorted(lora_module_names)


# LoRA target modules: short names of every bnb 4-bit linear layer in the model, excluding lm_head.
modules = find_all_linear_names(model)

## PEFT (LoRA) configuration

In [None]:
# LoRA adapter settings: rank-8 adapters on the layers listed in `modules`, no bias training.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=modules,
    r=8,
    bias="none",
    task_type="CAUSAL_LM"
)

# Guard against re-running this cell: get_peft_model on an already-wrapped model would
# wrap it a second time. Restart the kernel to apply a changed peft_config.
if not isinstance(model, PeftModel):
    # QLoRA: prepare the 4-bit model for training before attaching adapters
    # (per the PEFT docs this upcasts non-quantized parameters to float32, makes the
    # input embeddings produce gradients, and enables gradient checkpointing).
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config)

In [None]:
# Report how many parameters LoRA actually trains versus the full model size.
trainable, total = model.get_nb_trainable_parameters()
percentage = trainable / total * 100
print(f"Trainable: {trainable} | total: {total} | Percentage: {percentage:.4f}%")

In [None]:
dataset = load_dataset("LDJnr/Puffin", split="train")

# Pick one example to inspect the prompt format. A dedicated, seeded RNG makes the choice
# reproducible across runs without changing the global `random` state.
SAMPLE_SEED = 42
random_sample = dataset[random.Random(SAMPLE_SEED).randrange(len(dataset))]

In [None]:
# A short-lived W&B run whose only purpose is to log a sample of the training data.
run = wandb.init(
    project="Fine tuning mistral 7B",
    name="log_dataset",
    group="dataset",
    tags=["dataset"],
    # NOTE(review): this run only logs data; a label such as "upload" may describe it
    # better than "training" — confirm before changing, as it affects W&B grouping.
    job_type="training",
    notes="Logging subset of Puffin dataset.",
    config={"split": "train"},
)

In [None]:
# Flatten the first (up to) 1000 conversations into one table row per turn for browsing
# in W&B. A single slice fetches all rows at once (a dict of column lists) and is clipped
# to the dataset length, so a dataset with fewer than 1000 rows no longer raises IndexError.
first_rows = dataset[:1000]

data = []
for id_, conversations in zip(first_rows["id"], first_rows["conversations"]):
    for idx, response in enumerate(conversations):
        data.append([id_, idx, response["from"], response["value"]])

table = wandb.Table(data=data, columns=["id", "idx", "from", "value"])
run.log({"first1000_Puffin": table})

In [None]:
# Close the dataset-logging run now that its table has been logged.
run.finish()

In [None]:
def format_prompt(
    sample,
    intro="Below is a conversation between a user and you.",
    end="Instruction: Write a response appropriate to the conversation.",
):
    """Format a multi-turn conversation into a single prompt string.

    Each turn becomes one line ``<speaker>: message``. The block of turns is
    framed by ``intro`` before and ``end`` after, separated by blank lines.

    Args:
      sample: A sample dictionary with a "conversations" key holding a list of
        turns, each a dictionary with "from" (speaker) and "value" (text) keys.
      intro: Text placed before the conversation.
      end: Instruction placed after the conversation.

    Returns:
      The same ``sample`` object, mutated in place to carry a "text" key with
      the formatted prompt.
    """
    # join() avoids building the transcript with repeated string concatenation.
    conversations = "".join(
        f"<{turn['from']}>: " + turn["value"] + "\n"
        for turn in sample["conversations"]
    )

    sample["text"] = "\n\n".join([intro, conversations, end])
    return sample

# Preview the prompt built for the randomly chosen example. format_prompt mutates its
# argument, so random_sample also gains a "text" key here.
format_prompt(random_sample)["text"]

In [None]:
def get_max_length(model, default=1024):
    """Return the model's maximum sequence length as declared in its config.

    Checks ``model.config`` for ``n_positions``, ``max_position_embeddings``
    and ``seq_length`` in that order, and returns the first truthy value. A
    message saying which value was used is printed.

    Args:
      model: Any object with a ``config`` attribute.
      default: Length to use when none of the attributes is set (or all are
        falsy). Defaults to 1024.

    Returns:
      The maximum sequence length.
    """
    for length_setting in ("n_positions", "max_position_embeddings", "seq_length"):
        max_length = getattr(model.config, length_setting, None)
        if max_length:
            print(f"Found max length: {max_length}")
            return max_length

    print(f"Using default max length: {default}")
    return default


# Upper bound on sequence length, read from the model config (1024 if none of the known
# attributes is set). Lower it if sequences of this length do not fit in GPU memory.
max_length = get_max_length(model)