In [None]:
!pip install datasets bitsandbytes peft trl

In [77]:
import os
from collections import Counter

import numpy as np
import torch
from datasets import load_dataset, concatenate_datasets
from huggingface_hub import login
from matplotlib import pyplot as plt
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer, SFTConfig

In [None]:
# Take the Hugging Face token from the HF_TOKEN environment variable so the secret is never
# saved in the notebook. When the variable is unset, token=None makes login() ask for it.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Load the three RoMath configurations in one pass: bac, comps, synthetic.
ro_math_bac_set, ro_math_comp_set, ro_math_synthethic_set = (
    load_dataset('cosmadrian/romath', config_name)
    for config_name in ('bac', 'comps', 'synthetic')
)

In [8]:
# Merge the BAC train and test splits into a single dataset.
ro_math_bac_set_combined = concatenate_datasets(
    [ro_math_bac_set[split_name] for split_name in ('train', 'test')]
)

In [9]:
# Merge the train and test splits of the COMPS and SYNTHETIC subsets, one dataset each.
ro_math_comp_set_combined, ro_math_synthethic_set_combined = (
    concatenate_datasets([subset_splits['train'], subset_splits['test']])
    for subset_splits in (ro_math_comp_set, ro_math_synthethic_set)
)

EDA: Observing the domain distribution in the dataset
To make this easier, we combine the train and test splits of each subset, so that the counts reflect the total number of samples per domain in the whole subset.

In [71]:
class DatasetAnalyzer:
  """
    A class that helps with EDA over the three dataset subsets (bac, comps, synthetic).

    Currently supports:
    - Computing the domain frequency distribution of a subset.
    - Getting the length of the longest problem of a subset, to get an idea of the
      max_seq_len hyperparameter for LLM fine tuning. The length is measured in
      whitespace-separated words, NOT in tokenizer tokens, so it is only a rough guide.

    A subset only has to support column access by name (subset['domain'],
    subset['problem']) returning an iterable of values.
  """

  DOMAIN_COLUMN_NAME = 'domain'
  PROBLEM_COLUMN_NAME = 'problem'


  def __init__(self, ro_math_bac_set, ro_math_comp_set, ro_math_syntethic_set):
    """Register the three subsets under the names 'bac', 'comps' and 'synthetic'."""
    self.data_subsets = {
        'bac': ro_math_bac_set,
        'comps': ro_math_comp_set,
        'synthetic': ro_math_syntethic_set
    }


  def _get_subset(self, subset: str):
    """Return the data registered under `subset`; raise ValueError for an unknown name."""
    if subset not in self.data_subsets:
      valid_names = ', '.join(self.data_subsets)
      raise ValueError(f"Invalid data subset {subset!r}; expected one of: {valid_names}")

    return self.data_subsets[subset]

  def get_domain_distribution(self, subset: str) -> dict:
    """
      Count how many samples of `subset` belong to each domain.

      Returns:
        A dict mapping each domain value to its number of occurrences.

      Raises:
        ValueError: if `subset` is not a registered subset name.
    """
    analyzed_subset = self._get_subset(subset)
    return dict(Counter(analyzed_subset[self.DOMAIN_COLUMN_NAME]))

  def get_max_sequence_length(self, subset: str) -> int:
    """
      Return the word count of the longest problem in `subset`.

      Words are obtained with str.split(), i.e. split on whitespace. Missing (None)
      problems are skipped, and a subset without problems yields 0.

      Raises:
        ValueError: if `subset` is not a registered subset name.
    """
    problems = self._get_subset(subset)[self.PROBLEM_COLUMN_NAME]

    return max(
        (len(problem.split()) for problem in problems if problem is not None),
        default=0,
    )

In [72]:
class DatasetVisualizer(DatasetAnalyzer):

  """
    Inherits from DatasetAnalyzer (constructor included, unchanged). Adds visualization
    methods for the domain distributions and the maximum problem lengths of the
    registered subsets.
  """

  def visualize_domain_distributions(self):
    """
      Draw one bar chart per registered subset with the sample count of each domain.

      The grid gets one column per subset, and the colours are reused cyclically if
      there are more subsets than colours.
    """
    subset_names = list(self.data_subsets)
    overall_domain_distributions = [self.get_domain_distribution(subset) for subset in subset_names]

    bar_width = 0.6
    plot_colors = ['b', 'g', 'r']

    # squeeze=False keeps `axes` two-dimensional, so axes[0] is always a row of Axes
    # no matter how many subsets there are.
    fig, axes = plt.subplots(1, len(subset_names), figsize=(6 * len(subset_names), 5), squeeze=False)

    for idx, (subset, distribution) in enumerate(zip(subset_names, overall_domain_distributions)):
      current_ax = axes[0][idx]
      labels = list(distribution.keys())
      counts = list(distribution.values())
      indices_x_axis = np.arange(len(labels))

      current_ax.bar(indices_x_axis, counts, color = plot_colors[idx % len(plot_colors)], width = bar_width)
      current_ax.set_title(subset.upper() + ' domain distribution')
      current_ax.set_xticks(indices_x_axis)
      current_ax.set_xticklabels(labels, rotation = 70)
      # An empty subset has no counts: max() would raise, so keep the default y-range.
      if counts:
        current_ax.set_ylim(0, max(counts) * 1.1)
      current_ax.set_ylabel('Domain counter')

    fig.suptitle("Domain Distribution Across Subsets", fontsize=16)
    plt.show()

  def visualize_max_sequence_length(self):
    """
      Draw a bar chart with the maximum problem length of every registered subset.

      Bar labels are the upper-cased subset keys, so each label matches the data it
      is plotted against.
    """
    subset_names = list(self.data_subsets)
    overall_max_seq_lengths = [self.get_max_sequence_length(subset) for subset in subset_names]
    bar_labels = [subset.upper() for subset in subset_names]

    fig, ax = plt.subplots()
    ax.bar(bar_labels, overall_max_seq_lengths)
    ax.set_title('Maximum sequence lengths across subsets')
    ax.set_xlabel('Subsets')
    ax.set_ylabel('Maximum sequence length')
    plt.show()

In [None]:
# Build the visualizer over the merged (train + test) subsets, then render both EDA charts.
dataset_visualizer = DatasetVisualizer(
    ro_math_bac_set=ro_math_bac_set_combined,
    ro_math_comp_set=ro_math_comp_set_combined,
    ro_math_syntethic_set=ro_math_synthethic_set_combined,
)
dataset_visualizer.visualize_domain_distributions()
dataset_visualizer.visualize_max_sequence_length()

Loading the LLM: for the first experiment, we use RoLlama (`OpenLLM-Ro/RoLlama2-7b-Instruct`), loaded with 4-bit quantization.

In [None]:
# Load the base model with 4-bit NF4 quantization (double quantization enabled) and its tokenizer.

model_name = "OpenLLM-Ro/RoLlama2-7b-Instruct"

# Use the same precision rule as the trainer configuration (bf16 when the GPU supports it,
# fp16 otherwise) so the quantized layers compute in the dtype the training run uses,
# instead of always forcing bfloat16.
bnb_compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=bnb_compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [59]:
def format_questions_answers(data_samples):
  """
    Render a batch of (problem, answer) pairs as chat-formatted training strings.

    Each pair becomes a system / user / assistant conversation that is passed through
    the chat template of the module-level `tokenizer`, without tokenizing and without
    appending a generation prompt.

    Args:
      data_samples: batch with parallel 'problem' and 'answer' columns.

    Returns:
      A dict with a single 'text' key holding one formatted string per pair.
  """
  system_prompt = "Ești un asistent folositor, respectuos și onest. Încearcă să ajuți cât mai mult prin informațiile oferite, excluzând răspunsuri toxice, rasiste, sexiste, periculoase și ilegale."

  def render_conversation(question, answer):
    # A fresh message list per sample, in the system -> user -> assistant order.
    conversation = [
      {"role": "system", "content": system_prompt},
      {"role": "user", "content": question},
      {"role": "assistant", "content": answer},
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize = False,
        add_generation_prompt = False
    )

  question_answer_pairs = zip(data_samples['problem'], data_samples['answer'])

  return {"text": [render_conversation(question, answer) for question, answer in question_answer_pairs]}



In [68]:
# Add the chat-formatted 'text' column to the BAC and COMPS datasets.
ro_math_bac_set_updated, ro_math_comp_set_updated = (
    dataset_splits.map(format_questions_answers, batched = True)
    for dataset_splits in (ro_math_bac_set, ro_math_comp_set)
)

In [80]:
# LoRA adapter settings for causal-LM fine-tuning.
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)

In [None]:
# Decide the mixed-precision mode once: bf16 when the GPU supports it, fp16 otherwise.
bf16_available = torch.cuda.is_bf16_supported()

# Training hyperparameters, kept separate from the trainer wiring below.
sft_args = SFTConfig(
    output_dir = "outputs",
    seed = 3407,
    max_steps = 60,
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    learning_rate = 2e-4,
    lr_scheduler_type = "linear",
    warmup_steps = 5,
    weight_decay = 0.01,
    optim = "adamw_8bit",
    fp16 = not bf16_available,
    bf16 = bf16_available,
    logging_steps = 1,
)

# Supervised fine-tuning on the chat-formatted BAC data, with LoRA adapters on the loaded model.
trainer = SFTTrainer(
    model = model,
    processing_class = tokenizer,
    train_dataset = ro_math_bac_set_updated['train'],
    eval_dataset = ro_math_bac_set_updated['test'],
    peft_config = peft_config,
    args = sft_args,
)

In [None]:
# Run the fine-tuning loop with the trainer built above (capped by max_steps in its SFTConfig).
trainer.train()