In [1]:
import gc, torch

for name in ("model", "trainer", "outputs", "batch"):
    if name in globals():
        del globals()[name]

gc.collect()
torch.cuda.empty_cache()



In [2]:
import math, torch, warnings
from datasets import load_from_disk
from transformers import EvalPrediction
warnings.filterwarnings("ignore", category=UserWarning)

def compute_perplexity(pred: EvalPrediction):
    # pred.predictions = logits (float16 -> float32 via numpy)
    logits, labels = pred.predictions, pred.label_ids
    vocab = logits.shape[-1]
    import torch.nn.functional as F
    loss = F.cross_entropy(
        torch.from_numpy(logits).view(-1, vocab),
        torch.from_numpy(labels).view(-1),
        ignore_index=-100,
        reduction="mean",
    )
    return {"perplexity": math.exp(loss)}


In [3]:
ds_path = "data/stackexchange/translated_dataset_fr"   # <— votre dossier
ds      = load_from_disk(ds_path)

tmp     = ds.train_test_split(test_size=0.1, seed=42)
test_ds = tmp["test"]
tmp2    = tmp["train"].train_test_split(test_size=0.1, seed=42)
train_ds, val_ds = tmp2["train"], tmp2["test"]

print(len(train_ds), "train |", len(val_ds), "val |", len(test_ds), "test")


3855 train | 429 val | 477 test


In [4]:
from transformers import (
    MT5Tokenizer, MT5ForConditionalGeneration, BitsAndBytesConfig,
    DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
)
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

tok = MT5Tokenizer.from_pretrained("google/mt5-small", model_max_length=512)

# 3) Charger la base en 8-bit + fp16 sur **UN** GPU
bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
base = MT5ForConditionalGeneration.from_pretrained(
    "google/mt5-small",
    quantization_config=bnb_cfg,
    torch_dtype=torch.float16,
    device_map={"": 0},      # <= ❗️ un seul GPU, plus de NoneType
)
base.config.use_cache = False
base.gradient_checkpointing_enable()


# 3-b : passage en mode k-bit & ajout LoRA
base = prepare_model_for_kbit_training(base)
lora_cfg = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q", "v"],           # matrices clés/valeurs des attn MT5
    task_type="SEQ_2_SEQ_LM"
)
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()       # ~6 M params ↗


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.
You are using the default legacy behaviour of the <class 'transformers.models.mt5.tokenization_mt5.MT5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


trainable params: 688,128 || all params: 300,864,896 || trainable%: 0.2287


In [5]:
MAX_IN, MAX_OUT = 128, 64

def preprocess(b):
    src = ["question: " + q for q in b["q_fr"]]
    tgt = ["answer: "   + a for a in b["a_fr"]]
    model_in  = tok(src, padding="max_length", truncation=True, max_length=MAX_IN)
    out       = tok(tgt, padding="max_length", truncation=True, max_length=MAX_OUT)
    model_in["labels"] = out["input_ids"]
    return model_in

train_ds = train_ds.map(preprocess, batched=True, remove_columns=["q_fr", "a_fr"])
val_ds   = val_ds.map(preprocess,   batched=True, remove_columns=["q_fr", "a_fr"])
test_ds  = test_ds.map(preprocess,  batched=True, remove_columns=["q_fr", "a_fr"])


In [6]:
steps_per_epoch = len(train_ds) // 8          # batch=1 × accum=8
collator        = DataCollatorForSeq2Seq(tok, model, label_pad_token_id=-100)

args = Seq2SeqTrainingArguments(
    output_dir="mt5_fitness_ckpt",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    fp16=True,
    num_train_epochs=4,
    learning_rate=5e-5,
    eval_steps=steps_per_epoch,
    save_steps=steps_per_epoch,
    save_total_limit=2,
    logging_steps=50,
    report_to=[],                      # désactive wandb / tensorboard
    remove_unused_columns=False,       # indispensable avec PEFT
    skip_memory_metrics=True,
)

torch.cuda.empty_cache()               # libère la VRAM avant d’allouer le modèle

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collator,
    tokenizer=tok,
    compute_metrics=compute_perplexity,
)

trainer.train()


  trainer = Seq2SeqTrainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
50,67280414.72
100,1450947215.36
150,307461.86
200,28853550448.64
250,236657.62
300,65391.205
350,1553.415
400,78456340.48
450,70560461619.2
500,104313.6


TrainOutput(global_step=1928, training_loss=3926859816.551582, metrics={'train_runtime': 3316.7136, 'train_samples_per_second': 4.649, 'train_steps_per_second': 0.581, 'total_flos': 2046479255470080.0, 'train_loss': 3926859816.551582, 'epoch': 4.0})

In [9]:
import torch
from transformers import MT5Tokenizer, MT5ForConditionalGeneration

# 1. Charge ton tokenizer et ton modèle fine-tuné
checkpoint = "mt5_fitness_ckpt"  # ton dossier de checkpoint
tok = MT5Tokenizer.from_pretrained(checkpoint)
model = MT5ForConditionalGeneration.from_pretrained(checkpoint).to("cuda")

# 2. Prépare une fonction d’inférence
def ask(question: str, 
        max_in: int = 192, 
        max_out: int = 128, 
        num_beams: int = 4):
    # Prefixe “question:” comme à l’entraînement
    inp = tok(
        "question: " + question,
        return_tensors="pt",
        truncation=True,
        padding="longest",
        max_length=max_in
    ).to(model.device)
    # Génère la réponse
    out = model.generate(
        **inp,
        max_length=max_out,
        num_beams=num_beams,
        early_stopping=True
    )
    return tok.decode(out[0], skip_special_tokens=True)

# 3. Teste quelques questions
for q in [
    "Comment améliorer mon endurance pour la course à pied ?",
    "Quel programme pour perdre du poids en 3 mois ?",
    "Quels étirements faire après une séance de squat ?"
]:
    print(f"> {q}\n→ {ask(q)}\n")


> Comment améliorer mon endurance pour la course à pied ?
→ <extra_id_0>.

> Quel programme pour perdre du poids en 3 mois ?
→ <extra_id_0>.

> Quels étirements faire après une séance de squat ?
→ <extra_id_0>

