In [None]:
from IPython.display import clear_output

In [None]:
!pip install transformers peft datasets evaluate huggingface_hub trl omegaconf rouge_score --upgrade
clear_output()

In [None]:
from kaggle_secrets import UserSecretsClient
import os
# Client from which the following cells fetch API tokens via get_secret().
api_keys = UserSecretsClient()

In [None]:
# Hand the key to wandb through its environment variable instead of interpolating
# it into a shell command line, where it would be visible in the process list.
os.environ["WANDB_API_KEY"] = api_keys.get_secret("wandb")

In [None]:
# Expose the Hub token through the environment instead of interpolating it into
# a shell command line, where it would be visible in the process list.
_hf_token = api_keys.get_secret("huggingface-cli")
os.environ["HF_TOKEN"] = _hf_token
# Legacy variable name, kept for older huggingface_hub releases.
os.environ["HUGGING_FACE_HUB_TOKEN"] = _hf_token

In [None]:
import numpy as np
import pandas as pd
import scipy
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
import torch.nn.functional as F

import transformers
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BertTokenizer, BertModel 
from transformers import TrainingArguments, Trainer, Seq2SeqTrainingArguments, Seq2SeqTrainer, GenerationConfig, DataCollatorWithPadding
from transformers import pipeline
from peft import BOFTConfig, get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
import evaluate
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

import wandb
from omegaconf import OmegaConf

# from deepeval.benchmarks import MMLU
# from deepeval.benchmarks.tasks import MMLUTask
# from deepeval.models.base_model import DeepEvalBaseLLM

import pickle
import tqdm.notebook as tqdm

clear_output()

# Massive multitask language understanding (MMLU benchmark)

In [None]:
# Loading MMLU categories
# The definitions file is downloaded from the hendrycks/test repository only
# when it is not already present in the working directory.
if not os.path.exists('./categories.py'):
    !wget https://raw.githubusercontent.com/hendrycks/test/master/categories.py

# The module-level `categories` mapping is imported under the name `categories_inv`.
from categories import subcategories, categories as categories_inv

In [None]:
# Collapse every list-valued entry of `subcategories` to its first element
# (the imported dict is updated in place; non-list values are kept as they are).
subcategories.update({
    name: (names[0] if isinstance(names, list) else names)
    for name, names in subcategories.items()
})

# Invert `categories_inv` (category -> members) into member -> category.
categories = {
    member: category
    for category, members in categories_inv.items()
    for member in members
}

In [None]:
def subcat_to_cat(subcat):
    """Map `subcat` through `subcategories` and then through `categories`.

    Raises KeyError when either lookup has no entry for its key.
    """
    return categories[subcategories[subcat]]

In [None]:
# Central experiment configuration: model/tokenizer choice, prompt construction,
# adapter hyper-parameters, evaluation settings and trainer arguments.
config = OmegaConf.create({
    'model_name':   'meta-llama/Meta-Llama-3-8B-Instruct',  # checkpoint used for both tokenizer and model
    'padding_side': 'left',   # forwarded to AutoTokenizer.from_pretrained
    'task_name':    'all',    # configuration name passed to load_dataset("cais/mmlu", ...)
    'max_length':   256,
    'n_shots': 2,             # number of per-subject 'dev' examples prepended to each prompt
    'fp16': True,
    'bf16': False,
    'ft_strategy': 'BOFT',    # 'LoRA' or 'BOFT'; any other value raises ValueError when the adapter is built
    # Keyword arguments unpacked into LoraConfig.
    'LoRA_config': {
        'r': 16, 
        'lora_alpha': 32, 
        'lora_dropout': 0.05,
        'target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
    },
    # Keyword arguments unpacked into BOFTConfig.
    'BOFT_config': {  # m=2, b=8
        'boft_block_size': 8,
#         'boft_block_num': 8,
        'boft_n_butterfly_factor': 1,
        'bias': 'none',
        'target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
        'boft_dropout': 0.05,
    },
    # Evaluation settings: number of dataset chunks, generation length, batch size, cache clearing.
    'evaluation_config':{
        'num_splits': 20,
        'max_new_tokens': 4,
        'batch_size': 1,
        'empty_cache': True,
    },
    # Keyword arguments unpacked into SFTConfig.
    'trainer_config': {
        'output_dir': "bogachevv/Llama-3-8b-MMLU",
        'max_seq_length': 512,
        'dataset_text_field': 'text',   # column produced by prepare_instruction_text
        'fp16': True,
        'full_determinism': False,
        'per_device_train_batch_size': 1,
        'per_device_eval_batch_size':  1,
        'gradient_accumulation_steps': 8,
        'lr_scheduler_type': 'cosine_with_restarts',
        'lr_scheduler_kwargs':{
            'num_cycles': 6,
        },
        'warmup_steps': 100,
#         'num_train_epochs': 2,
        'learning_rate': 1e-4,
        'max_steps': 2048,
        'weight_decay': 0.01,
#         'warmup_ratio': 1e-2,
        'dataloader_num_workers': 2,
        'eval_strategy': "steps",
#         'torch_empty_cache_steps': 16,
        'eval_steps': 16,
        'logging_steps': 16,
        'load_best_model_at_end': True,
        'seed': 42,
        'data_seed': 42,
        'report_to': 'wandb',
#         'predict_with_generate': True,
#         'push_to_hub': True,
#         'hub_model_id': 'LLama-LoRA-test',
#         'hub_strategy': 'checkpoint',
#         'save_strategy': "steps",
        'save_steps': 128,
    },
})

In [None]:
# Show the container type returned by OmegaConf.create.
type(config)

In [None]:
# Tokenizer for the configured checkpoint; the end-of-sequence token doubles as
# the padding token.
tokenizer = AutoTokenizer.from_pretrained(config.model_name, padding_side=config.padding_side)
EOS_TOKEN = tokenizer.eos_token
tokenizer.pad_token = EOS_TOKEN

# MMLU benchmark, restricted to the configuration named in `config.task_name`.
mmlu_dataset = load_dataset("cais/mmlu", config.task_name)

In [None]:
# Build one few-shot pool per subject from the 'dev' split.
# clear_output() is reached only when every filter call has succeeded; an
# exception propagates unchanged, so no try/except wrapper is required.
dev_split = mmlu_dataset['dev']
few_shot_datasets = {
    subject: dev_split.filter(lambda row: row['subject'] == subject)
    for subject in set(dev_split['subject'])
}

clear_output()
print('Succes')

In [None]:
def prepare_question(examples):
    """Format one multiple-choice row as a prompt and its answer letter.

    Args:
        examples: mapping with 'question' (text), 'choices' (sequence of option
            texts; at most the first four are used) and 'answer' (0-based index).

    Returns:
        Tuple ``(prompt, answer)``: the question followed by one "<letter>. <choice>"
        line per option, and the letter corresponding to the answer index.
    """
    pieces = [f"{examples['question']}\n"]
    pieces.extend(
        f"{letter}. {choice}\n"
        for letter, choice in zip('ABCD', examples['choices'])
    )
    prompt = ''.join(pieces)

    # 0 -> 'A', 1 -> 'B', ...
    answer = chr(ord('A') + examples['answer'])

    return prompt, answer

def prepare_prompt(examples, dev_dataset = None):
    """Yield ``(prompt, answer)`` pairs: few-shot examples first, then the target.

    Args:
        examples: the row to be asked; always yielded last.
        dev_dataset: optional iterable of few-shot rows; skipped when falsy.
    """
    if dev_dataset:
        for shot in dev_dataset:
            yield prepare_question(shot)

    yield prepare_question(examples)

In [None]:
def prepare_instruction_text(example):
    """Render one MMLU row as a chat-formatted text via the tokenizer's chat template.

    The conversation consists of a system message, `config['n_shots']` solved
    examples of the same subject (when the row has a subject), and finally the
    row's own question with its answer as the last assistant turn.

    Args:
        example: dataset row with 'subject', 'question', 'choices' and 'answer'.

    Returns:
        ``{'text': <rendered conversation>}``.
    """
    subject = example['subject']
    # Rows without a subject would otherwise produce "... about ." in the system prompt.
    topic_clause = f" about {subject}" if subject else ""
    instructions = [
        {"role": "system", "content": f"The following are multiple choice questions (with answers){topic_clause}. Output 'A', 'B', 'C', or 'D'. Full answer not needed."},
    ]

    few_shot_dataset = None
    if config['n_shots'] and subject:
        subject_pool = few_shot_datasets[subject]
        # Clamp so that a pool smaller than n_shots does not make select() fail.
        n_shots = min(config['n_shots'], len(subject_pool))
        few_shot_dataset = subject_pool.select(range(n_shots))

    for prompt, ans in prepare_prompt(example, dev_dataset=few_shot_dataset):
        instructions.append({"role": "user", "content": prompt})
        instructions.append({"role": "assistant", "content": ans})

    text = tokenizer.apply_chat_template(
        instructions,
        tokenize=False
    )

    return {'text': text}

In [None]:
def r_replace(line, old, new):
    """Replace the last occurrence of `old` in `line` with `new`.

    Works by reversing all three strings, replacing the first occurrence, and
    reversing the result back. `line` is returned unchanged when `old` is absent.
    """
    mirrored = line[::-1]
    mirrored = mirrored.replace(old[::-1], new[::-1], 1)
    return mirrored[::-1]

def remove_answer(example):
    """Strip the final answer from a rendered conversation.

    Everything from the last '<|eot_id|>' marker onwards is dropped, and then the
    single character that preceded it (the one-letter answer) is removed as well.

    Returns:
        ``{'text_wa_answer': <text without the final answer>}``.
    """
    before_last_eot, *_ = example['text'].rsplit('<|eot_id|>', 1)

    return {'text_wa_answer': before_last_eot[:-1]}

In [None]:
# Render every row of every split into a chat-formatted 'text' column.
instructions_datasets = mmlu_dataset.map(prepare_instruction_text, batched=False, num_proc=2)

# The evaluation splits additionally get a copy of the text without the final answer.
for split_name in ('validation', 'test'):
    instructions_datasets[split_name] = instructions_datasets[split_name].map(remove_answer, batched=False)

instructions_datasets.set_format("torch")

instructions_datasets

In [None]:
# Sample of a fully rendered conversation (answer included).
print(instructions_datasets['validation'][1]['text'])

In [None]:
# Same sample with the final answer stripped by remove_answer.
print(instructions_datasets['validation'][1]['text_wa_answer'])

In [None]:
# Accessing the train, validation, and test splits
# ('dev' is the dataset for few shot)
validation_dataset, test_dataset, dev_dataset, auxiliary_train_dataset = (
    instructions_datasets[name]
    for name in ("validation", "test", "dev", "auxiliary_train")
)

# Check the size of each split
for split_label, split_data in (
    ("Validation", validation_dataset),
    ("Test", test_dataset),
    ("Dev", dev_dataset),
    ("Auxiliary train", auxiliary_train_dataset),
):
    print(f"{split_label} dataset size: {len(split_data)}")

In [None]:
# Display the summary of the test split.
test_dataset

In [None]:
# Load the base causal LM with float16 weights; device placement is left to
# the loader (device_map='auto').
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    device_map='auto',
    torch_dtype=torch.float16,
)

In [None]:
# Accuracy metric object used by compute_accuracy.
accuracy_metric = evaluate.load("accuracy")

def process_prediction(pred):
    """Reduce one generation result to an answer letter.

    Args:
        pred: mapping (or DataFrame row) with a 'generated_text' string.

    Returns:
        The first character of the stripped, upper-cased text when it is one of
        'A', 'B', 'C', 'D'; otherwise 'I', which marks an invalid answer.
    """
    cleaned = pred['generated_text'].strip().upper()
    first_char = cleaned[:1]   # empty string when nothing was generated

    if first_char in {'A', 'B', 'C', 'D'}:
        return first_char
    return 'I'

def compute_accuracy(model_preds, labels):
    """Score generated answers against the reference answer indices.

    Args:
        model_preds: sequence of generation results accepted by process_prediction.
        labels: 0-based answer indices, one per prediction; a tensor or any
            sequence/array that torch.as_tensor can convert.

    Returns:
        Dict with
        'accuracy'    - fraction of predictions equal to the reference letter;
        'correctness' - fraction of predictions that are a valid letter
                        (i.e. not the invalid marker 'I').
    """
    pred_letters = [process_prediction(pred) for pred in model_preds]
    pred_codes = torch.LongTensor([ord(letter) for letter in pred_letters])

    # as_tensor leaves tensors untouched and lets plain lists/arrays work too.
    actual_labels = ord('A') + torch.as_tensor(labels)
    # Reference vector made entirely of the invalid marker, used to count invalid outputs.
    incorrect_labels = actual_labels.new_full(actual_labels.shape, ord('I'))

    acc_res = accuracy_metric.compute(predictions=pred_codes, references=actual_labels)['accuracy']
    corr_res = 1.0 - accuracy_metric.compute(predictions=pred_codes, references=incorrect_labels)['accuracy']

    return {'accuracy': acc_res, 'correctness': corr_res}

In [None]:
# First dimension of the query projection weight in the first decoder layer.
model.model.layers[0].self_attn.q_proj.weight.shape[0]

In [None]:
# adapter_config = BOFTConfig(
#     task_type=TaskType.CAUSAL_LM,
#     inference_mode=False,
#     boft_block_size=8,
# #     boft_block_num=16,
#     boft_n_butterfly_factor=2,
#     bias='none',
# )

# adapter_config

# model_adapter = get_peft_model(model, adapter_config)    

In [None]:
# model_adapter.print_trainable_parameters()

In [None]:
# model_adapter

In [None]:
# Preview of the BOFT adapter configuration built from `config.BOFT_config`.
# The model is deliberately NOT wrapped here: get_peft_model() injects adapter
# layers into `model` in place, and the strategy-aware cell that follows already
# performs the wrapping — doing it in both cells would apply PEFT to the same
# model twice.
adapter_config = BOFTConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    **OmegaConf.to_object(config.BOFT_config)
)

adapter_config

In [None]:
# Build the adapter configuration selected by `config.ft_strategy`.
if config.ft_strategy == 'LoRA':
    adapter_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        **OmegaConf.to_object(config.LoRA_config),
    )
elif config.ft_strategy == 'BOFT':
    adapter_config = BOFTConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        **OmegaConf.to_object(config.BOFT_config)
    )
else:
    raise ValueError('Incorrect FT type')

# NOTE(review): get_peft_model injects the adapter into `model` in place, so this
# should be the only place where the base model gets wrapped.
model_adapter = get_peft_model(model, adapter_config)
# Alias so that the name `lora_model` also resolves to the wrapped model.
lora_model = model_adapter
model_adapter.print_trainable_parameters()

In [None]:
# Put the model into evaluation mode before generating.
model.eval()

# Text-generation pipeline over the configured model and tokenizer.
pl = pipeline(
    "text-generation",
    model=model,       # WARNING: `model` is used instead of lora_model
    tokenizer=tokenizer,
    torch_dtype=torch.float16
)

In [None]:
%%time

# Evaluation settings come from the config instead of being repeated as literals.
eval_cfg = config.evaluation_config

# The test set is processed in `num_splits` contiguous chunks of row indices.
index_splits = np.array_split(np.arange(len(test_dataset)), eval_cfg.num_splits)

model_preds = []

with torch.no_grad():
    for i, split in tqdm.tqdm(enumerate(index_splits), total=len(index_splits)):
        # array_split yields empty chunks when there are fewer rows than splits;
        # np.min/np.max would raise on them.
        if len(split) == 0:
            continue

        print(f"Start iteration {i}")
        print(f"\tstart pos: {np.min(split)}\tend pos: {np.max(split)}")

        model_pred = pl(
            list(test_dataset.select(split)['text_wa_answer']),
            return_full_text=False,
            max_new_tokens=eval_cfg.max_new_tokens,
            # Greedy decoding; sampling parameters are explicitly unset.
            do_sample=False,
            temperature=None,
            top_p=None,
            batch_size=eval_cfg.batch_size,
        )
        model_preds += model_pred
        if eval_cfg.empty_cache:
            torch.cuda.empty_cache()

        print(f"Finish iteration {i}")
        print(f"\t{len(model_preds)=}")

# The pipeline returns one list of candidates per prompt; flatten them into a
# single list of prediction dicts.
model_preds_merged = [candidate for per_prompt in model_preds for candidate in per_prompt]

model_preds = model_preds_merged

# model_preds

In [None]:
# Attach the subject of each test row to its prediction.
# The column is read once instead of materialising a full torch-formatted row
# for every index.
for pred, subject in zip(model_preds, test_dataset['subject']):
    pred['subject'] = subject

# Preview only: the full list holds one entry per test example.
model_preds[:5]

In [None]:
# Persist the predictions together with the reference answers of the test split.
with open('./fs_preds.bin', 'wb') as f:
    payload = (model_preds, test_dataset['answer'])
    pickle.dump(payload, f)

In [None]:
# One row per prediction; columns are added for the parsed answer letter, the
# reference letter, a 0/1 correctness flag and the category of the subject.
preds_df = pd.DataFrame(model_preds)

preds_df['pred'] = preds_df.apply(process_prediction, axis=1)
preds_df['true'] = [chr(answer_idx + ord('A')) for answer_idx in test_dataset['answer']]
preds_df['corr'] = preds_df['pred'].eq(preds_df['true']).astype(np.int32)
preds_df['category'] = preds_df['subject'].map(subcat_to_cat)

preds_df.head(20)

In [None]:
# Mean of the 0/1 correctness flag (i.e. accuracy) per subject.
preds_df.groupby('subject')[['corr']].mean()

In [None]:
# Mean of the 0/1 correctness flag (i.e. accuracy) per category.
preds_df.groupby('category')[['corr']].mean()

In [None]:
# Overall accuracy / valid-answer rate on the test split.
compute_accuracy(model_preds, test_dataset['answer'])

In [None]:
# Always raises AssertionError — presumably a deliberate stop so that running
# all cells halts here, before the training section.
assert False

In [None]:
# Convert the trainer section of the config to plain Python objects and
# forward every entry to SFTConfig as a keyword argument.
trainer_kwargs = OmegaConf.to_object(config.trainer_config)
training_args = SFTConfig(**trainer_kwargs)

In [None]:
# Fine-tune the PEFT-wrapped model returned by get_peft_model (`model_adapter`).
trainer = SFTTrainer(
    model=model_adapter,
    args=training_args,
    train_dataset=auxiliary_train_dataset,
    # A fixed-seed sample of 64 validation rows is used for the periodic evaluation.
    eval_dataset=validation_dataset.shuffle(42).select(range(64)),
)

In [None]:
# tokenizer.decode(trainer.train_dataset[0]['input_ids'])

In [None]:
# for batch in trainer.get_train_dataloader():
#     print(tokenizer.batch_decode(batch['input_ids']))
#     break

In [None]:
# Release cached CUDA memory before starting the training run.
torch.cuda.empty_cache()

trainer.train()

In [None]:
# torch.cuda.empty_cache()

In [None]:
# Generation pipeline over the PEFT-wrapped model (`model_adapter`, created by
# get_peft_model). Evaluation mode is set first so adapter dropout is inactive,
# matching the baseline pipeline cell.
model_adapter.eval()

pl = pipeline(
    "text-generation",
    model=model_adapter,
    tokenizer=tokenizer,
)

In [None]:
%%time

# Greedy generation for a fixed-seed sample of 512 validation prompts.
eval_prompts = validation_dataset.shuffle(42).select(range(512))['text_wa_answer']
raw_outputs = pl(
    eval_prompts,
    return_full_text=False,
    max_new_tokens=16,
    do_sample=False,
    temperature=None,
    top_p=None,
    batch_size=4,
)
torch.cuda.empty_cache()

# The pipeline returns one list of candidates per prompt; flatten them into a
# single list of prediction dicts.
model_preds_merged = [candidate for per_prompt in raw_outputs for candidate in per_prompt]

model_preds = model_preds_merged

# model_preds

In [None]:
# The predictions were generated for validation_dataset.shuffle(42).select(...),
# so the subjects have to be read from the same shuffled view; indexing the
# unshuffled dataset would pair each prediction with an unrelated row.
eval_subset = validation_dataset.shuffle(42).select(range(len(model_preds)))
for pred, subject in zip(model_preds, eval_subset['subject']):
    pred['subject'] = subject

model_preds

In [None]:
# Reference answers must come from the same seed-42 shuffled subset that the
# predictions were generated for; the unshuffled rows would be misaligned.
compute_accuracy(
    model_preds,
    validation_dataset.shuffle(42).select(range(len(model_preds)))['answer'],
)

In [None]:
# Save the PEFT-wrapped model (`model_adapter`, created by get_peft_model) and
# the tokenizer into the same directory.
model_adapter.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")

In [None]:
# Archive the entries directly inside ./fine_tuned_model (zip is invoked without -r).
!zip fine_tuned_model.zip ./fine_tuned_model/*