In [None]:
import torch
import re

from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model
from math_verify import LatexExtractionConfig, parse, verify


In [None]:
# TODO: load the dataset here (placeholder until the import is written).
# NOTE(review): a plain dict has no `.map` method, so the later
# `dataset.map(make_conversation)` call fails until this is replaced by the
# real dataset object (presumably a Hugging Face `datasets` object - confirm).
dataset = {}


In [None]:
# System prompt used as the "system" turn of every conversation. It asks the
# model to reason inside <think> </think> tags, to give its SDG
# classification(s) inside <answer> </answer> tags, and it lists the 17 SDGs
# with their descriptions so the model can refer to them.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user provides a text, and the Assistant classifies it "
    "according to one or more of the 17 Sustainable Development Goals (SDGs). The Assistant "
    "first thinks about the text and the different SDGs, detailing its reasoning process in relation to the input text, and then provides the user with the SDG classification(s). "
    "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively. "
    "The Assistant must identify the most relevant SDG(s) for the given text. If multiple SDGs are relevant, they can all be listed. The reasoning should clearly justify the choice(s).\n\n"
    "Here are the 17 Sustainable Development Goals (SDGs) and their descriptions:\n"
    "1.  **SDG 1: No Poverty:** End poverty in all its forms everywhere.\n"
    "2.  **SDG 2: Zero Hunger:** End hunger, achieve food security and improved nutrition and promote sustainable agriculture.\n"
    "3.  **SDG 3: Good Health and Well-being:** Ensure healthy lives and promote well-being for all at all ages.\n"
    "4.  **SDG 4: Quality Education:** Ensure inclusive and equitable quality education and promote lifelong learning opportunities for all.\n"
    "5.  **SDG 5: Gender Equality:** Achieve gender equality and empower all women and girls.\n"
    "6.  **SDG 6: Clean Water and Sanitation:** Ensure availability and sustainable management of water and sanitation for all.\n"
    "7.  **SDG 7: Affordable and Clean Energy:** Ensure access to affordable, reliable, sustainable and modern energy for all.\n"
    "8.  **SDG 8: Decent Work and Economic Growth:** Promote sustained, inclusive and sustainable economic growth, full and productive employment and decent work for all.\n"
    "9.  **SDG 9: Industry, Innovation and Infrastructure:** Build resilient infrastructure, promote inclusive and sustainable industrialization and foster innovation.\n"
    "10. **SDG 10: Reduced Inequalities:** Reduce inequality within and among countries.\n"
    "11. **SDG 11: Sustainable Cities and Communities:** Make cities and human settlements inclusive, safe, resilient and sustainable.\n"
    "12. **SDG 12: Responsible Consumption and Production:** Ensure sustainable consumption and production patterns.\n"
    "13. **SDG 13: Climate Action:** Take urgent action to combat climate change and its impacts.\n"
    "14. **SDG 14: Life Below Water:** Conserve and sustainably use the oceans, seas and marine resources for sustainable development.\n"
    "15. **SDG 15: Life on Land:** Protect, restore and promote sustainable use of terrestrial ecosystems, sustainably manage forests, combat desertification, and halt and reverse land degradation and halt biodiversity loss.\n"
    "16. **SDG 16: Peace, Justice and Strong Institutions:** Promote peaceful and inclusive societies for sustainable development, provide access to justice for all and build effective, accountable and inclusive institutions at all levels.\n"
    "17. **SDG 17: Partnerships for the Goals:** Strengthen the means of implementation and revitalize the global partnership for sustainable development."
)

def make_conversation(example):
    """Build the chat-style prompt for a single dataset example.

    Args:
        example: Mapping with a "problem" entry holding the text to classify.

    Returns:
        A dict with one key, "prompt": a two-message list made of the system
        turn (``SYSTEM_PROMPT``) followed by the user turn (the example text).

    Raises:
        KeyError: If ``example`` has no "problem" entry.
    """
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": example["problem"]}
    return {"prompt": [system_turn, user_turn]}


# Apply make_conversation to the dataset so each example carries a "prompt".
# NOTE(review): this assumes `dataset` exposes a `.map(fn)` method (presumably
# a Hugging Face `datasets` object - confirm once the dataset is loaded); the
# current `{}` placeholder does not. Re-running this cell maps the already
# mapped dataset again because the name is rebound in place.
dataset = dataset.map(make_conversation)


In [None]:
# Checkpoint identifier handed to `from_pretrained`.
model_id = "Qwen/Qwen3-0.6B"
# Load the causal language model, leaving both the dtype and the device
# placement to the library's "auto" settings.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)

In [None]:
# LoRA adapter settings for a causal-LM task: r=8, alpha=32, dropout=0.1,
# applied to the modules named "q_proj" and "v_proj".
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
)

# Wrap the loaded model with the LoRA configuration.
# NOTE(review): `model` is rebound to the wrapped model, so re-running this
# cell without reloading the base model passes an already wrapped model back
# into get_peft_model - confirm whether that is acceptable.
model = get_peft_model(model, lora_config)

# Print how many parameters are trainable after wrapping.
model.print_trainable_parameters()

In [None]:
import re

# Valid SDG identifiers (the leading part of an answer). The <answer> tag is
# expected to contain something like "SDG X" or "SDG X: name of the SDG".
VALID_SDG_IDENTIFIERS = [f"SDG {i}" for i in range(1, 18)]  # ["SDG 1", "SDG 2", ..., "SDG 17"]

# Overall layout: a <think> block, optional whitespace, then an <answer> block.
# re.DOTALL lets "." match newlines so either block may span several lines.
_SDG_FORMAT_RE = re.compile(r"^<think>(.*?)</think>\s*<answer>(.*?)</answer>$", re.DOTALL)

# Every tag must appear exactly once. Without this check the non-greedy groups
# above can backtrack across extra tags, so a completion such as
# "<think>..</think><answer>junk</answer> filler <answer>SDG 1</answer>"
# would be accepted as well-formed.
_SDG_REQUIRED_TAGS = ("<think>", "</think>", "<answer>", "</answer>")

# One alternation over all valid identifiers, compiled once.
# \b requires "SDG" to start at a word boundary; (?![0-9]) prevents "SDG 1"
# from matching the prefix of "SDG 10" (or of an invalid "SDG 18").
_VALID_SDG_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(sdg_id) for sdg_id in VALID_SDG_IDENTIFIERS) + r")(?![0-9])"
)


def _first_message_content(completion_group):
    """Return the text of the first message, or None if the entry is malformed."""
    if not isinstance(completion_group, list) or not completion_group:
        return None
    first_message = completion_group[0]
    if not isinstance(first_message, dict):
        return None
    content = first_message.get("content")
    # Non-string content (e.g. None) cannot be matched against the format.
    return content if isinstance(content, str) else None


def _score_sdg_format(content):
    """Return 1.0 if ``content`` is well-formed and names a valid SDG, else 0.0."""
    match = _SDG_FORMAT_RE.match(content)
    if match is None:
        return 0.0
    if any(content.count(tag) != 1 for tag in _SDG_REQUIRED_TAGS):
        return 0.0
    return 1.0 if _VALID_SDG_RE.search(match.group(2)) else 0.0


def format_sdg_reward(completions, **kwargs):
    """Reward completions that follow the expected format and name a valid SDG.

    A completion earns 1.0 when both conditions hold:
      1. It has the form <think>...</think><answer>...</answer>, with each of
         the four tags occurring exactly once.
      2. The content of <answer> contains a valid SDG identifier
         (e.g. "SDG 1", "SDG 1: No Poverty").
    Otherwise it earns 0.0.

    Args:
        completions: List of completion groups. Each group is expected to be a
            list whose first element is a dict with a "content" string, e.g.
            [[{"role": "assistant", "content": "<think>...</think><answer>SDG 1</answer>"}]].
        **kwargs: Ignored; accepted so extra keyword arguments can be passed.

    Returns:
        A list of floats (1.0 or 0.0), one per entry of ``completions``.
        Malformed entries (not a non-empty list, first element not a dict,
        missing or non-string "content") score 0.0 instead of raising.
    """
    rewards_list = []
    for completion_group in completions:
        content = _first_message_content(completion_group)
        rewards_list.append(0.0 if content is None else _score_sdg_format(content))
    return rewards_list

In [None]:
# Text inside the first <answer>...</answer> block of a completion.
_ANSWER_BLOCK_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

# "SDG <n>" with n in 1..17. (?![0-9]) rejects out-of-range numbers such as
# "SDG 18" instead of reading them as "SDG 1".
_SDG_NUMBER_RE = re.compile(r"\bSDG\s*(1[0-7]|[1-9])(?![0-9])", re.IGNORECASE)


def _sdg_numbers(text):
    """Return the set of SDG numbers (1-17) mentioned in ``text``."""
    return {int(number) for number in _SDG_NUMBER_RE.findall(text)}


def accuracy_reward(completions, **kwargs):
    """Reward completions whose answer matches the ground truth.

    When the gold solution names SDG identifiers (e.g. "SDG 13"), the set of
    SDGs found inside the completion's <answer> block must equal the gold set
    to earn 1.0; anything else, including a missing <answer> block, earns 0.0.
    Only the <answer> block is inspected because the reasoning may mention
    SDGs that are not part of the final classification.

    Gold solutions that name no SDG go through the LaTeX-based math_verify
    comparison. On its own that comparison cannot score SDG labels: a gold
    string without LaTeX parses to nothing and every completion would get 1.0.

    NOTE(review): assumes gold labels are strings containing "SDG <n>" - the
    dataset is not loaded in this notebook yet, so confirm its schema.

    Args:
        completions: List of completion groups; the first element of each
            group is a dict whose "content" is the generated text.
        **kwargs: Must contain "solution", the ground-truth values aligned
            with ``completions``.

    Returns:
        A list of floats, one per (completion, solution) pair.

    Raises:
        KeyError: If "solution" is not provided.
    """
    solutions = kwargs["solution"]
    completion_contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, solution in zip(completion_contents, solutions):
        gold_sdgs = _sdg_numbers(solution) if isinstance(solution, str) else set()
        if gold_sdgs:
            answer_block = _ANSWER_BLOCK_RE.search(content) if isinstance(content, str) else None
            predicted_sdgs = _sdg_numbers(answer_block.group(1)) if answer_block else set()
            rewards.append(1.0 if predicted_sdgs == gold_sdgs else 0.0)
            continue

        # Fallback: LaTeX/math comparison for gold solutions without SDG labels.
        gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        answer_parsed = parse(content, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        if len(gold_parsed) != 0:
            try:
                rewards.append(float(verify(answer_parsed, gold_parsed)))
            except Exception:
                rewards.append(0.0)
        else:
            # Gold could not be parsed at all: keep the neutral 1.0 so the
            # example does not penalize the completion.
            rewards.append(1.0)
    return rewards