In [4]:
# Fineâ€‘tuning **mT5â€‘small** (Seq2Seq) â€“ ChatbotÂ Fitness ðŸ‡«ðŸ‡·
# CiblÃ© pour **RTXâ€¯3060â€¯Laptop (6â€¯GoÂ VRAM) + 16â€¯GoÂ RAM** â†’ â‰ˆÂ 1â€¯h dâ€™entraÃ®nement

# 1. Imports & setup
import os, json, random, torch, xml.etree.ElementTree as ET
from collections import defaultdict
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer,
    BitsAndBytesConfig, GenerationConfig, pipeline
)
from bs4 import BeautifulSoup

torch.cuda.empty_cache()

# Helper nettoyage HTML
def clean_html(html_str):
    return BeautifulSoup(html_str, "html.parser").get_text(" ", strip=True)

# 2. Dataset fitness JSON ðŸ‡«ðŸ‡·
DATA_PATH = "/home/maxime/DataDevIA/chatbotcoach_project/data/fitness/coach_sportif_dataset.json"
raw = json.load(open(DATA_PATH, "r", encoding="utf-8"))
pairs = list({(d["input"].strip(), d["output"].strip()) for d in raw["conversations"]})

# 3. Enrichir avec StackExchange (XML EN) + traduction limitÃ©e (â‰¤Â 2â€¯000 QA)
xml_path = "/home/maxime/DataDevIA/chatbotcoach_project/data/stackexchange/Posts.xml"
answers_map = defaultdict(list)
for _, elem in ET.iterparse(xml_path, events=("end",)):
    if elem.tag == "row" and elem.get("PostTypeId") == "2":
        answers_map[elem.get("ParentId")].append(elem.get("Body"))
    elem.clear()

questions = {}
for _, elem in ET.iterparse(xml_path, events=("end",)):
    if elem.tag == "row" and elem.get("PostTypeId") == "1" and elem.get("Id") in answers_map:
        questions[elem.get("Id")] = elem.get("Body")
    elem.clear()

# ðŸ‘‰ On garde la VRAM GPU libre pour le fineâ€‘tuningÂ : traduction sur CPU
translator = pipeline(
    "translation_en_to_fr",
    model="Helsinki-NLP/opus-mt-en-fr",
    device=-1,            # CPU pour Ã©viter les erreurs CUDA
    batch_size=16,
    max_length=256,
)
MAX_XML_QA = 2000  # limite stricte pour tenir <Â 1â€¯h
for idx, (qid, qbody) in enumerate(questions.items()):
    if idx >= MAX_XML_QA:
        break
    q_fr = translator(clean_html(qbody))[0]["translation_text"]
    for ans in answers_map[qid][:2]:   # max 2 rÃ©ponses/textes
        a_fr = translator(clean_html(ans))[0]["translation_text"]
        pairs.append((q_fr, a_fr))

# 4. Ã‰chantillonnage final (â‰¤Â 6â€¯000 paires), mÃ©lange alÃ©atoire
random.shuffle(pairs)
MAX_PAIRS = 6000
pairs = pairs[:MAX_PAIRS]
print(f"Paires conservÃ©esÂ : {len(pairs)}")

# 5. Tokenizer mT5
tok = AutoTokenizer.from_pretrained("google/mt5-small")

# 6. Dataset HFÂ : question â†’ rÃ©ponse
src, tgt = zip(*pairs)
ds = Dataset.from_dict({"src": src, "tgt": tgt}).train_test_split(test_size=0.1, seed=42)

MAX_LEN = 128

def preprocess(batch):
    inputs = tok(batch["src"], truncation=True, max_length=MAX_LEN)
    with tok.as_target_tokenizer():
        labels = tok(batch["tgt"], truncation=True, max_length=MAX_LEN)
    inputs["labels"] = labels["input_ids"]
    return inputs

ds = ds.map(preprocess, batched=True, remove_columns=["src", "tgt"], num_proc=4)

# 7. Data collator & modÃ¨le 4â€‘bit NF4
collator = DataCollatorForSeq2Seq(tok, model=None)

bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small", quantization_config=bnb_cfg, device_map="auto")

# 8. EntraÃ®nement (â‰ˆÂ 1â€¯h)
args = Seq2SeqTrainingArguments(
    output_dir="coach-mt5-small",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,   # batch effectif 8
    num_train_epochs=4,              # < 1Â h sur 3060
    learning_rate=1e-4,
    fp16=True,
    optim="paged_adamw_8bit",
    predict_with_generate=True,
    generation_max_length=MAX_LEN,
    generation_num_beams=2,
    logging_steps=100,
    eval_steps=500,
    save_steps=500,
    save_total_limit=3,
    report_to="none",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    tokenizer=tok,
    data_collator=collator,
)

trainer.train()

# 9. Sauvegarde
out_dir = "coach-mt5s-final"
model.save_pretrained(out_dir)
tok.save_pretrained(out_dir)
print(f"ModÃ¨le sauvegardÃ© â†’ {out_dir}")

# 10. Inference helper
cfg = GenerationConfig(max_new_tokens=100, num_beams=2, length_penalty=1.0)

def chat_mt5(query: str):
    encoded = tok(query, return_tensors="pt").to(model.device)
    ids = model.generate(**encoded, generation_config=cfg)
    return tok.decode(ids[0], skip_special_tokens=True)

# Tests d'infÃ©rence
print(chat_mt5("Comment amÃ©liorer mon endurance pour la course Ã  pied ?"))
print(chat_mt5("Quel programme pour perdre du poids en 3 mois ?"))
print(chat_mt5("Quels Ã©tirements faire aprÃ¨s une sÃ©ance de squat ?"))


Device set to use cpu
Token indices sequence length is longer than the specified maximum sequence length for this model (561 > 512). Running this sequence through the model will result in indexing errors
Your input_length: 561 is bigger than 0.9 * max_length: 256. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)


IndexError: index out of range in self