In [10]:
import json
from itertools import groupby

import torch
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from finetuning_handler import QLoRAModelHandler
from utils.dataset_handling import DatasetPreparer
from utils.dynamic_weigthed_loss import dwl_loss
from utils.dynamic_weigthed_loss import CustomDWLTrainer


In [2]:
# Rich console shared by every cell that prints formatted output.
console = Console()

# All hyper-parameter groups used below (model_params, dataset_params,
# peft_params, training_params) are read from this single JSON file.
CONFIG_PATH = "config/config_cot7_soap_dwl_dft.json"

# JSON text is UTF-8 by specification; pin the encoding so loading does not
# depend on the platform's locale default (e.g. cp1252 on Windows).
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = json.load(f)

In [4]:
# Model and dataset parameter groups are forwarded together as keyword
# arguments; a key present in both groups raises TypeError (duplicate kwarg).
model_params = config.get("model_params", {})
dataset_params = config.get("dataset_params", {})
handler = QLoRAModelHandler(**model_params, **dataset_params)

Output()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# Dataset formatter built from the handler's tokenizer and field names.
preparer = DatasetPreparer(
    handler.tokenizer,
    handler.text_field,
    handler.cot_fields,
    handler.label_field,
    oversampling=False,  # oversampling disabled for this run
    seed=42,
)

In [8]:
# Format the "test" split of the configured dataset into prompt/completion records.
eval_ds = preparer.prepare(
    handler.dataset_path,
    split_name="test",
)

Formatting dataset:   0%|          | 0/1388 [00:00<?, ? examples/s]

In [9]:
# Display the prepared dataset summary (feature names and row count).
eval_ds

Dataset({
    features: ['prompt', 'completion'],
    num_rows: 1388
})

In [16]:
# Build the LoRA and SFT configurations from their sections of the JSON config.
peft_kwargs = config.get("peft_params", {})
training_kwargs = config.get("training_params", {})
peft_config = handler._build_lora_config(peft_kwargs)
sft_config = handler._build_sft_config(training_kwargs)

# Optimizer hyper-parameters come from the SFT config so both stay in sync.
optimizer = handler._setup_optimizer(
    sft_config.learning_rate,
    sft_config.weight_decay,
)

In [12]:
# Dynamic-weighted-loss schedule: one entry per phase, giving the epoch at which
# the phase starts and one weight ("alpha") per segment
# (len(cot_fields) + 1 weights).
# The previous single dict literal repeated the "epoch"/"alphas" keys, so Python
# silently kept only the last pair and the epoch-0 phase was lost.
# NOTE(review): assumes CustomDWLTrainer accepts a list of phase dicts, and that
# index 0 of "alphas" is the segment switched off at epoch 3 — confirm both.
handler.weight_schedule = [
    {"epoch": 0, "alphas": [1.0] * (len(handler.cot_fields) + 1)},
    {"epoch": 3, "alphas": [0.0] + [1.0] * len(handler.cot_fields)},
]

In [17]:
# Ids of the tokens delimiting chain-of-thought blocks, passed to the trainer
# as think_start_id / think_end_id.
think_open_id = handler.tokenizer.convert_tokens_to_ids("<think>")
think_close_id = handler.tokenizer.convert_tokens_to_ids("</think>")

# NOTE(review): as written, training and evaluation both use the test split and
# args=None, so sft_config is only used for the optimizer's learning rate and
# weight decay. Confirm this is intended before calling .train().
handler.trainer = CustomDWLTrainer(
    model=handler.model,
    processing_class=handler.tokenizer,
    train_dataset=eval_ds,
    eval_dataset=eval_ds,
    args=None,
    peft_config=peft_config,
    optimizers=(optimizer, None),  # custom optimizer; no scheduler supplied
    think_start_id=think_open_id,
    think_end_id=think_close_id,
    weight_schedule=handler.weight_schedule,
    DFT=False,
)


[2025-12-18 16:28:38,453] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [28]:
# Visual check of the think-block segmentation for one evaluation example.
# Tokens outside any <think>...</think> block are printed in white; tokens inside
# the k-th block are printed in the k-th palette colour.
sample_idx = 0
example = eval_ds[sample_idx]
full_text = example["prompt"] + example["completion"]

encoding = handler.tokenizer(
    full_text,
    return_tensors="pt",
    add_special_tokens=False,
).to(handler.device)

input_ids = encoding["input_ids"]
labels = input_ids.clone()
# Causal-LM label shift: drop the first position. As a consequence, the very
# first token of the sequence is neither displayed nor counted below.
shift_labels = labels[..., 1:].contiguous()

start_id = handler.tokenizer.convert_tokens_to_ids("<think>")
end_id = handler.tokenizer.convert_tokens_to_ids("</think>")

# Depending on the tokenizer, an unknown token maps to None or to unk_token_id.
# Either way the mask below would be meaningless, so stop instead of carrying on.
unk_id = handler.tokenizer.unk_token_id
if start_id in (None, unk_id) or end_id in (None, unk_id):
    raise ValueError(
        "The <think> or </think> tokens are not in the tokenizer vocabulary."
    )

starts = (shift_labels == start_id).long()
ends = (shift_labels == end_id).long()

# starts_cum[i] counts blocks opened at or before i.
# ends_cum_shifted[i] counts blocks closed strictly before i.
# A position is inside a block when more blocks have opened than closed before
# it, so both the <think> and the </think> tokens count as "inside".
starts_cum = starts.cumsum(dim=1)
ends_cum = ends.cumsum(dim=1)
ends_cum_shifted = ends_cum.roll(shifts=1, dims=1)
ends_cum_shifted[:, 0] = 0  # roll() wraps the last column to position 0; reset it

in_block_mask = starts_cum > ends_cum_shifted
# Per-token value for the single encoded sequence: 0 outside any block, otherwise
# the number of <think> tokens seen so far (the block index for non-nested input).
class_mask = (starts_cum * in_block_mask.long())[0]
tokens_ids = shift_labels[0]

# One colour per block, cycled after 10 blocks. White is reserved for text
# outside any block.
COLOR_PALETTE = [
    "cyan", "green", "yellow", "magenta", "orange1",
    "spring_green3", "deep_sky_blue1", "purple3", "gold1", "hot_pink",
]

# Decode each run of consecutive tokens sharing a block index in one call, not
# token by token. Byte-level tokenizers can split a multi-byte UTF-8 character
# across tokens, which per-token decoding renders as replacement characters.
rich_text = Text()
token_mask_pairs = zip(tokens_ids.tolist(), class_mask.tolist())
for mask_val, run in groupby(token_mask_pairs, key=lambda pair: pair[1]):
    segment = handler.tokenizer.decode([tok for tok, _ in run])
    if mask_val == 0:
        rich_text.append(segment, style="white")
    else:
        color = COLOR_PALETTE[(mask_val - 1) % len(COLOR_PALETTE)]
        rich_text.append(segment, style=f"bold {color}")

console.print(Panel(
    rich_text,
    title=f"Segmented DWL View (Sample {sample_idx})",
    subtitle="White: Normal | Colored: Thought Blocks",
    expand=True,
))

# Legend: one colour swatch per detected block.
num_blocks = int(starts.sum().item())
legend = Text("\nBlocks detected: ", style="bold")
for block in range(1, num_blocks + 1):
    color = COLOR_PALETTE[(block - 1) % len(COLOR_PALETTE)]
    legend.append(f" [Block {block}] ", style=f"reverse {color}")

console.print(legend)
console.print(f"- Total tokens: {len(tokens_ids)}")
console.print(f"- Tokens in thought blocks: {in_block_mask.sum().item()}")