In [1]:
import json
import os
from pprint import pprint
import bitsandbytes as bnb
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
from huggingface_hub import notebook_login
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training
)
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
import pandas as pd

In [2]:
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

### Loading model

In [4]:
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
bnb_config = BitsAndBytesConfig(
    load_in_8bit = True
)

In [5]:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map = device,
    trust_remote_code = True,
    quantization_config = bnb_config
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
model

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear8bitLt(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear8bitLt(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear8bitLt(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear8bitLt(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear8bitLt(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()

### tokenizer

In [7]:
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    model_max_length=4096, # maximum length of the tokenized input that the tokenizer will handle
    padding_side="left",
    add_eos_token=True)
tokenizer.pad_token = tokenizer.eos_token

### load data

In [8]:
dataset = load_dataset("csv", data_files="./filtered_data/data.csv")

In [9]:
dataset

DatasetDict({
    train: Dataset({
        features: ['moves', 'explanation'],
        num_rows: 500
    })
})

### Convert to prompt and tokenize

In [10]:
mistral_prompt = '''
    "role": "user", "content": "{}",
    "role": "assistant", "content": "{}"
    '''

In [11]:
EOS_TOKEN = tokenizer.eos_token
def formatting_func(df):
    user_prompt = "Based on the provided Algebraic chess moves, explain the rationale behind last move and the strategy being used by the player - "
    moves       = df["moves"]
    exps        = df["explanation"]
    texts = []
    for move, exp in zip(moves, exps):
        text = mistral_prompt.format(f"{user_prompt} {move}", exp) + EOS_TOKEN
        texts.append(text)
    return {"text" : texts}

In [12]:
dataset = dataset.map(formatting_func, batched = True)
dataset

DatasetDict({
    train: Dataset({
        features: ['moves', 'explanation', 'text'],
        num_rows: 500
    })
})

In [13]:
# print(dataset["train"][10]["text"])

### "text" will have the full prompt and combine moves, explanation

In [14]:
def generate_and_tokenize_prompt(data_point):
  full_prompt = data_point["text"]
  tokenized_full_prompt = tokenizer(full_prompt, padding=True, truncation=True)
  return tokenized_full_prompt

In [15]:
data = dataset["train"].shuffle().map(generate_and_tokenize_prompt)

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [16]:
data

Dataset({
    features: ['moves', 'explanation', 'text', 'input_ids', 'attention_mask'],
    num_rows: 500
})

In [17]:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

### LoRA

In [18]:
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)

### Finetuning

In [19]:
from trl import SFTTrainer
from transformers import TrainingArguments
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset["train"],
    dataset_text_field = "text", # https://stackoverflow.com/questions/77113551/trl-sfttrainer-llama2-finetuning-on-alpaca-datasettext-field
    max_seq_length = 4096,
    dataset_num_proc = 2,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    packing = False, # Can make training 5x faster for short sequences, good for flash attention i think.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        # max_steps = 100,
        num_train_epochs = 2,
        learning_rate = 2e-4,
        fp16 = not torch.cuda.is_bf16_supported(), # trying to use bf6
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 50,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs"
    )
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Map (num_proc=2):   0%|          | 0/500 [00:00<?, ? examples/s]



In [1]:
# trainer_stats = trainer.train()

In [None]:
# #https://huggingface.co/docs/transformers/main/chat_templating
# encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")

In [None]:
# encoding = tokenizer(prompt, return_tensors="pt").to(device)

In [None]:
# generated_ids = model.generate(input_ids = encoding.input_ids, max_new_tokens=1000, do_sample=True)
# decoded = tokenizer.batch_decode(generated_ids)
# print(decoded[0])