<a href="https://colab.research.google.com/github/GiliardGodoi/distillation-small-language-model/blob/main/notebooks/003_EXP_Custom_Callback_Penalty.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# Install Unsloth and its dependencies; %%capture (above) hides the pip output.
import os
# Outside Colab (no environment-variable name contains "COLAB_") a plain install is enough.
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Packages go in with --no-deps, presumably to keep Colab's preinstalled torch;
    # xformers is pinned to 0.0.29.post3 — NOTE(review): re-check the pin when Colab's torch changes.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install --no-deps unsloth

In [2]:
%%capture
# `evaluate` provides the metric loader; `bert_score` backs the "bertscore" metric used for the training penalty.
!pip install evaluate bert_score

In [3]:
import datasets

# Record the installed `datasets` version (the Colab branch of the install cell requests datasets>=3.4.1).
datasets.__version__

'4.0.0'

In [4]:
!nvidia-smi  # show the GPU attached to this runtime (driver, CUDA version, memory)

Mon Jul 21 16:56:09 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   45C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Unsloth

In [5]:
import os
import re
from collections import defaultdict
from dataclasses import dataclass, field

In [6]:
from unsloth import FastLanguageModel
from unsloth.chat_templates import train_on_responses_only

max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
# NOTE(review): reference list only — it is not referenced elsewhere in the visible notebook;
# the model actually loaded is named directly in the from_pretrained call.
fourbit_models = [
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",      # Llama-3.1 2x faster
    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
    "unsloth/Meta-Llama-3.1-405B-bnb-4bit",    # 4bit for 405b!
    "unsloth/Mistral-Small-Instruct-2409",     # Mistral 22b 2x faster!
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "unsloth/Phi-3.5-mini-instruct",           # Phi-3.5 2x faster!
    "unsloth/Phi-3-medium-4k-instruct",
    "unsloth/gemma-2-9b-bnb-4bit",
    "unsloth/gemma-2-27b-bnb-4bit",            # Gemma 2x faster!

    "unsloth/Llama-3.2-1B-bnb-4bit",           # NEW! Llama 3.2 models
    "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
    "unsloth/Llama-3.2-3B-bnb-4bit",
    "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",

    "unsloth/Llama-3.3-70B-Instruct-bnb-4bit" # NEW! Llama 3.3 70B!
] # More models at https://huggingface.co/unsloth

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [7]:
from transformers import (
            DataCollatorForSeq2Seq,
            TrainerCallback
)
from trl import SFTConfig, SFTTrainer
import torch
import torch.nn.functional as F
# Record the torch build in the saved output so runs can be compared.
print(f"{torch.__version__=}")

torch.__version__='2.6.0+cu124'


In [8]:
from evaluate import load

# BERTScore is the metric PenaltyCallback uses (its default `scorer`).
bertscore = load("bertscore")

bertscore.download_and_prepare()

# Just a smoke test so that the underlying scoring model gets downloaded now.
bertscore.compute(
    predictions=["hello there", "general kenobi"],
    references=["hello there", "general kenobi"],
    lang="en",
)

Downloading builder script: 0.00B [00:00, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'precision': [1.0, 0.9999999403953552],
 'recall': [1.0, 0.9999999403953552],
 'f1': [1.0, 0.9999999403953552],
 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.53.2)'}

In [9]:
# Ask Unsloth to return logits from the forward pass.
# NOTE(review): presumably required because Unsloth may otherwise skip
# materialising logits during training — confirm against the Unsloth docs.
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

# Load the base model (4-bit when load_in_4bit is True) and its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-Instruct", # or choose "unsloth/Llama-3.2-1B-Instruct"
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# Wrap the model with LoRA adapters on all attention and MLP projections.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)
# Left padding is what batched generation expects.
# NOTE(review): after SFTTrainer is built, the trainer cell prints
# padding_side='right', so this setting does not survive trainer setup.
tokenizer.padding_side = 'left'
print('\n')
print(f"{tokenizer.padding_side=}")

==((====))==  Unsloth 2025.7.6: Fast Llama patching. Transformers: 4.53.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/2.35G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

Unsloth 2025.7.6 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.




tokenizer.padding_side='left'


In [17]:
# @title Loading and preprocessing the dataset
from datasets import load_dataset

# ds = load_dataset("mlabonne/FineTome-100k", split = "train")
# Download the iRisk dataset; the printed DatasetDict has a single 'train'
# split with instruction/input/output columns.
ds = load_dataset("vitormesaque/irisk")

print(ds)

# Original columns; they are dropped once the chat-formatted "text" column exists.
column_names = ds['train'].column_names

# column_names

from unsloth.chat_templates import get_chat_template

# Attach the Llama-3.1 chat template so apply_chat_template emits the
# <|start_header_id|>...<|end_header_id|> headers that later cells split on.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)

def convert_output_field_to_dict(example):
    """Parse the string in ``example['output']`` into a Python object.

    Args:
        example: a dataset row with an ``'output'`` field holding a Python
            literal (e.g. a dict with ``review``/``issues``) as text.

    Returns:
        ``{'value': parsed}`` on success, or ``{'error': message}`` when the text
        is not a valid Python literal (the row is later filtered out).
    """
    import ast  # local import keeps this preprocessing helper self-contained

    out = example['output']
    try:
        # ast.literal_eval only accepts literals (dict, list, str, numbers,
        # True/False/None). Unlike eval it cannot execute code that arrives
        # with the downloaded dataset.
        return {'value': ast.literal_eval(out)}
    except Exception as e:
        # Deliberate best-effort: malformed rows are marked, not fatal.
        return {'error': str(e)}

def convert_alpaca_style_to_chatml(example):
    """Turn an Alpaca-style row into a three-turn chat conversation.

    ``instruction`` becomes the system turn, ``input`` the user turn and
    ``output`` the assistant turn, in that order.

    Returns:
        ``{'conversations': [ {'role': ..., 'content': ...}, ... ]}``
    """
    turn_sources = (
        ('system', 'instruction'),
        ('user', 'input'),
        ('assistant', 'output'),
    )
    return {
        'conversations': [
            {'role': role, 'content': example[source_field]}
            for role, source_field in turn_sources
        ]
    }

def format_using_chat_template(examples):
    """Render each conversation in a batch to a single chat-template string.

    Used with ``Dataset.map(..., batched=True)``: ``examples["conversations"]``
    is a list of conversations. Each one is rendered with the notebook's
    ``tokenizer`` (no tokenization, no generation prompt appended).

    Returns:
        ``{"text": [rendered_string, ...]}`` aligned with the input batch.
    """
    rendered = []
    for conversation in examples["conversations"]:
        rendered.append(
            tokenizer.apply_chat_template(
                conversation,
                tokenize = False,
                add_generation_prompt = False,
            )
        )
    return {"text": rendered}


def format_for_generation_task(example):
    """Tokenize ``example['text']`` for the validation / generation set.

    Called per row by ``Dataset.map`` (not batched), so the tokenizer receives a
    single string. Returns the tokenizer's mapping (``input_ids``,
    ``attention_mask``) as flat lists, with truncation enabled.

    NOTE(review): ``example['text']`` is the full rendered chat (system, user
    *and* the assistant's reference answer, rendered with
    ``add_generation_prompt=False``), so these ids are not a prompt-only input
    for ``generate``.
    """
    # No return_tensors='pt': for a single string it adds a leading batch
    # dimension of 1, which `datasets` stores as a nested list per row.
    # padding=True is dropped as well; it is a no-op for one sequence, and
    # padding belongs in the collator.
    return tokenizer(example['text'], truncation=True)

# Fixed seed for the train/validation split so the held-out rows are identical
# on every Restart & Run All (same value as the trainer's `seed`).
SPLIT_SEED = 8080


def has_single_issue(example):
    """Return True for rows whose ``output`` parsed cleanly and lists exactly one issue.

    Uses ``.get`` so it works whether or not the parsing step produced an
    ``error`` column (presumably filled with None for rows that parsed).
    Parsed values that are not dicts, or that lack ``issues``, are rejected
    instead of raising.
    """
    if example.get('error') is not None:
        return False
    value = example.get('value')
    if not isinstance(value, dict):
        return False
    issues = value.get('issues') or []
    return len(issues) == 1


ds = ds['train'].map(convert_output_field_to_dict)
ds = ds.filter(has_single_issue)
ds = ds.map(convert_alpaca_style_to_chatml)
ds = ds.map(format_using_chat_template, batched = True, remove_columns = column_names)

ds = ds.train_test_split(test_size=0.1, seed=SPLIT_SEED)

ds_train = ds['train']
ds_val = ds['test'].map(format_for_generation_task)

print('Training:')
print(ds_train)

print('Validation')
print(ds_val)
print('\n\n')
print(f"{tokenizer.padding_side=}")

DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output'],
        num_rows: 3557
    })
})


Map:   0%|          | 0/3557 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3557 [00:00<?, ? examples/s]

Map:   0%|          | 0/2427 [00:00<?, ? examples/s]

Map:   0%|          | 0/2427 [00:00<?, ? examples/s]

Map:   0%|          | 0/243 [00:00<?, ? examples/s]

Training:
Dataset({
    features: ['value', 'conversations', 'text'],
    num_rows: 2184
})
Validation
Dataset({
    features: ['value', 'conversations', 'text', 'input_ids', 'attention_mask'],
    num_rows: 243
})



tokenizer.padding_side='left'


## Custom classes

**IMPORTANTE!**

Rever como o dataset está sendo configurado.
1. Os `input_ids` contêm o prompt de entrada e a resposta esperada do modelo, formatados de acordo com o `chat template`.
2. Os `labels` contêm a máscara indicada pelo valor `-100`, para ignorar o prompt e calcular a loss apenas sobre a resposta esperada.

Da forma como estou fazendo, estou passando o prompt **junto com a resposta esperada** para o `model.generate`. Isso não pode induzir o modelo a responder de forma equivocada (por exemplo, apenas repetindo a resposta esperada que já está na entrada)?

In [14]:
@dataclass
class CustomSFTConfig(SFTConfig):
    """``SFTConfig`` plus the schedule for the generation-based loss penalty.

    The penalty is computed/applied only on global steps where
    ``global_step >= start_after_n_steps`` and
    ``global_step % penalization_step_interval == 0``.
    """
    # First global step at which the penalty may be used (must be >= 0).
    start_after_n_steps: int = field(default = 0)
    # Use the penalty every N global steps (must be >= 1).
    penalization_step_interval: int = field(default = 1)

    def __post_init__(self):
        super().__post_init__()
        # Fail fast at construction time: an interval of 0 would otherwise
        # surface as a ZeroDivisionError (`step % 0`) in the middle of training.
        if self.start_after_n_steps < 0:
            raise ValueError(
                f"start_after_n_steps must be >= 0, got {self.start_after_n_steps}"
            )
        if self.penalization_step_interval < 1:
            raise ValueError(
                "penalization_step_interval must be >= 1, "
                f"got {self.penalization_step_interval}"
            )

class CustomSFTTrainer(SFTTrainer):
    """``SFTTrainer`` that adds an externally computed penalty to the loss.

    ``compute_loss`` records the batch it sees in ``current_batch``. A callback
    (PenaltyCallback in this notebook) reads that batch and writes its result to
    ``penalty_cache['last']``. On scheduled steps (see ``CustomSFTConfig``) the
    cached value is added to the standard SFT loss.

    NOTE(review): PenaltyCallback builds the penalty with ``requires_grad=False``,
    so adding it shifts the reported loss but contributes no gradient.
    """

    def __init__(self, *args,  **kwargs):
        super().__init__(*args, **kwargs)
        # Written by the penalty callback; a missing key or None means "no penalty".
        self.penalty_cache = dict()
        # Most recent batch seen by compute_loss; read by the penalty callback.
        self.current_batch = None

    def compute_loss(self, model, inputs, return_outputs = False, **kwargs):
        """Standard SFT loss plus the cached penalty on scheduled steps.

        Returns ``(loss, outputs)`` when ``return_outputs`` is True, otherwise ``loss``.
        """
        # Shallow copy (same tensors, new mapping) taken before the parent call,
        # so the callback still sees `labels` even if a parent implementation
        # pops keys from `inputs`. Without this the callback never gets a batch
        # and the penalty is never computed.
        self.current_batch = dict(inputs)

        standard_loss, outputs = super().compute_loss(model, inputs, return_outputs = True, **kwargs)

        penalty = self.penalty_cache.get('last')
        step = self.state.global_step
        scheduled = (
            step >= self.args.start_after_n_steps
            and step % self.args.penalization_step_interval == 0
        )

        if penalty is not None and scheduled:
            final_loss = standard_loss + penalty
            print(f"{self.state.global_step:^7} | {standard_loss.item():^10.5f} | {penalty.item():^15.10f} | {final_loss.item():^15.10f}")
        else:
            final_loss = standard_loss

        return (final_loss, outputs) if return_outputs else final_loss

In [15]:
import re
import numpy as np
import json
from json import JSONDecodeError

In [None]:
# https://huggingface.co/docs/transformers/main_classes/callback#transformers.TrainerCallback

from torch.utils.data import DataLoader


class PenaltyCallback(TrainerCallback):

    def __init__(self, trainer, *args,
                 scorer=bertscore,
                 alpha_scale=1.0,
                 ds_val=ds_val, **kwargs
                ):
        super().__init__(*args, **kwargs)
        self.alpha_scale = alpha_scale
        self.trainer_ref = trainer
        self.space_id = trainer.processing_class(" ", add_special_tokens = False).input_ids[0]
        self.scorer = scorer
        self.dataloader = DataLoader(ds_val, batch_size=2, suffle=True)

        self.placeholder = '0 ' * 1_000

    def on_step_end(self, args, state, control, model, processing_class, **kwargs):

        device = model.device
        batch = getattr(self.trainer_ref, 'current_batch', None)
        if batch is None:
            self.trainer_ref.penalty_cache['last'] = None
            print('Não foi possível conhecer o batch atual')
            return None

        if batch is not None \
            and (state.global_step >= args.start_after_n_steps) \
            and (state.global_step % args.penalization_step_interval == 0):

            input_ids = batch['input_ids']
            labels = batch['labels']
            attention_mask = batch['attention_mask']

            # mask = labels == -100
            # eos_token_id = processing_class.eos_token_id
            # input_ids_masked = torch.where(mask, input_ids, eos_token_id )
            # attention_mask_masked = torch.where(mask, 1,0)
            # response_ids = self._generate_text(input_ids_masked, attention_mask_masked)
            # input_decoded = processing_class.batch_decode(input_ids_masked, skip_special_tokens=False)
            # print('Input:\n', input_decoded)

            response_ids = self._generate_text(input_ids, attention_mask)
            # input_decoded = processing_class.batch_decode(input_ids, skip_special_tokens=False)
            # print('Input:\n', input_decoded)

            skip_special_tokens = False
            generated_response = processing_class.batch_decode(response_ids, skip_special_tokens=skip_special_tokens)
            print('generated response:\n', generated_response)

            generated_jsons = list()
            for response in generated_response:
                response = self._basic_preprocessing(response)
                items = [ self._normalize_json(item) for item in self._find_json(response)]
                generated_jsons.append(
                    [item for item in items ]
                )

            labels = batch.get('labels')
            labels = torch.where(labels == -100, self.space_id, labels)
            labels = processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

            penalty = 0.0

            for i, (label, data) in enumerate(zip(labels, generated_jsons), start=1):
                # label is dict for this dataset
                label = self._convert_label(label)
                # whereas item would be a list of issues
                # data = [i for i in item if bool(i)] # it's not empty
                data = [ d for d in data if 'functionality' in d]
                if len(data) == 0 :
                    predictions = self.placeholder
                elif len(data) >= 2:
                    pred = data[0] # could be random or other strategy
                    predictions = pred.get('functionality', self.placeholder)
                else : # len(data) == 1
                    pred = data[0]
                    predictions = pred.get('functionality', self.placeholder)

                predictions = predictions.strip()
                references = label.get('functionality', '0')
                print('PRED', predictions)
                print('REF', references)
                # https://huggingface.co/docs/evaluate/a_quick_tour#calculate-a-single-metric-or-a-batch-of-metrics
                self.scorer.add(predictions=predictions, references=references)

            result = self.scorer.compute(lang='en')
            f1 = np.mean(result.get('f1'))
            penalty = self.alpha_scale * (1 - f1)

            self.trainer_ref.penalty_cache['last'] = torch.tensor(penalty, device = device, requires_grad=False)

    def _generate_text(self, input_ids, attention_mask):
        with torch.no_grad():
            outputs = model.generate(
                input_ids = input_ids,
                attention_mask = attention_mask,
                max_new_tokens = 100,
                # use_cache = False,
                temperature = 1.5, min_p = 0.1
            )
        return outputs

    def _normalize_json(self, review : dict, placeholder = dict()) -> list :
        match review:
            case [ {"review" : str(), "issues" : [*issues] } ]:
                return issues
            case {"review" : str(), "issues" : [dict(issues)] } :
                return issues
            case [ {"review" : str(), "issues" : dict(issues)} ] :
                return issues
            case {"review" : str(), "issues" : dict(issues)}:
                return issues
            case _ :
                return placeholder

    def _basic_preprocessing(self, response : str) -> list:
        if not isinstance(response, str):
            raise ValueError(f"Expected str, got {type(response)}")

        assistant_head = '<|start_header_id|>assistant<|end_header_id|>'
        extra_pad = re.compile(r'(<\|finetune_right_pad_id\|>)+')
        eot_id = re.compile(r'<\|eot_id\|>')

        response = re.sub(eot_id, '', response)
        response = re.sub(extra_pad, assistant_head, response)
        response = response.split(assistant_head)[1:]

        return response

    def _find_json(self, response, curly_brackets = re.compile(r'{.*}') ) -> list:

        json_strings = list()
        for rep in response:
            if _match := re.findall(curly_brackets, rep):
                json_strings.extend(_match)

        data = list()
        for item in json_strings:
            try :
                obj = json.loads(item)
                data.append(obj)
            except json.JSONDecodeError as e :
                # https://docs.python.org/3/library/json.html#exceptions
                # print(e.msg, e.pos, e.colno, e.lineno, e.doc)
                data.append({
                    'error': e.msg,
                    'pos': e.pos,
                    'colno': e.colno,
                    'lineno' : e.lineno,
                    'doc': e.doc
                })

        return data

    def _convert_label(self, label : str) -> dict:
        eot_id = re.compile(r'<\|eot_id\|>')
        label = re.sub(eot_id, '', label)
        label = label.strip()

        try :
            label = json.loads(label)
            label = self._normalize_json(label)
        except json.JSONDecodeError as e :
            # https://docs.python.org/3/library/json.html#exceptions
            print(e.msg, e.pos, e.colno, e.lineno, e.doc)
            label = {'functionality' : '0'}

        return label

    def _penalty_fn(self, label: dict, predictions: list, scorer=bertscore) -> float:
        pass

In [None]:
# NOTE(review): dataset_text_field / max_seq_length / dataset_num_proc / packing are passed
# to the trainer rather than to the config; this relies on the (Unsloth-patched) SFTTrainer
# accepting them — confirm for the installed trl version.
trainer = CustomSFTTrainer(
    model = model,
    processing_class = tokenizer,
    train_dataset = ds_train,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = CustomSFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 2,
        warmup_steps = 5,
        num_train_epochs = 1, # Set this for 1 full training run.
        # max_steps = 60,
        learning_rate = 2e-4,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 8080,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
        remove_unused_columns=True,
        start_after_n_steps = 20,        # penalty schedule: first eligible global step
        penalization_step_interval = 4,  # ... and then every 4th global step
        bf16=False, # Disable bf16 training
    )
)

# The callback keeps a reference to `trainer`, so it is created after the trainer exists.
callback = PenaltyCallback(
    trainer = trainer,
    ds_val=ds_val
)

trainer.add_callback(callback)

# from unsloth.chat_templates import train_on_responses_only

# Mask everything except assistant turns in the labels (-100 outside responses).
# NOTE(review): assumes train_on_responses_only returns the same trainer object,
# so the callback's trainer_ref stays valid — confirm.
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
    response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
# tokenizer.decode(trainer.train_dataset[5]['input_ids'])
# NOTE(review): a bare expression that is not the last line of a cell is not displayed; no visible effect.
trainer.train_dataset
# tokenizer.padding_side = 'left'
print('\n')
print(f">> {tokenizer.padding_side=}")

Unsloth: Tokenizing ["text"]:   0%|          | 0/2427 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/2427 [00:00<?, ? examples/s]



>> tokenizer.padding_side='right'


In [None]:
# Inspect the tokenized training set (the output shows a `labels` column, presumably added by train_on_responses_only).
trainer.train_dataset

Dataset({
    features: ['value', 'conversations', 'text', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 2427
})

In [None]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")
print(f'{tokenizer.padding_side=}')
print('\n')

trainer_stats = trainer.train()