In [55]:
import torch
import re
from typing import List, Tuple
from transformers import AutoModelForCausalLM
from datasets import Dataset 

from peft import LoraConfig, get_peft_model

In [56]:
import json

def _read_jsonl_records(path):
    """Parse a JSON Lines file into a list of objects.

    Lines that are empty after stripping whitespace are ignored; every other
    line is stripped and decoded with ``json.loads``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


# Labeled examples (the solutions) and the descriptions they are paired with.
labeled_items = _read_jsonl_records("../src/ai/testsets/dataset_labeled.jsonl")
description_items = _read_jsonl_records("../src/ai/testsets/dataset.jsonl")

# Prompt template, read verbatim; "{description}" is substituted later on.
with open("sdg_label_prompt.md", "r", encoding="utf-8") as template_file:
    prompt_template = template_file.read()

print(f"Loaded {len(labeled_items)} items from dataset_labeled.jsonl")
print(f"Loaded {len(description_items)} items from dataset.jsonl")

Loaded 1326 items from dataset_labeled.jsonl
Loaded 3000 items from dataset.jsonl


In [57]:
def _first_non_reason_entry(label_item):
    """Return the first ``(key, value)`` pair whose key is not "reason".

    Returns ``(None, None)`` when the item has no such key.
    """
    for key, value in label_item.items():
        if key != "reason":
            return key, value
    return None, None


prepared_data = []

# Labels and descriptions are paired purely by position. The two files do not
# have to be the same length: only the first min(len, len) pairs are used.
# NOTE(review): positional pairing presumes both files list the same records in
# the same order - confirm against how dataset_labeled.jsonl was produced.
if len(labeled_items) != len(description_items):
    print("Warning: The number of items in labeled_items and description_items is different!")
    print("Will process up to the length of the shorter list.")

for i, (label_item, desc_item) in enumerate(zip(labeled_items, description_items)):
    # The first non-"reason" key of the label item carries the solution; the
    # same key is then looked up in the description item to get the text.
    key_for_solution_and_description, current_solution = _first_non_reason_entry(label_item)

    if current_solution is None or key_for_solution_and_description is None:
        print(f"Warning: For item index {i}, no suitable solution key (non-'reason') found in label_item: {label_item}")
        continue

    description_for_prompt = desc_item.get(key_for_solution_and_description)
    if description_for_prompt is None:
        print(f"Warning: For item index {i}, solution found with key '{key_for_solution_and_description}', but this key was not found in the corresponding description item: {desc_item}")
        continue

    final_prompt = prompt_template.replace("{description}", description_for_prompt)
    prepared_data.append({
        'solution': current_solution,
        'prompt': final_prompt
    })

print(f"Processed {len(prepared_data)} items for the dataset.")

# Fail with a clear message instead of an IndexError on prepared_data[0] when
# no label could be paired with a description.
if not prepared_data:
    raise ValueError(
        "No training examples could be built: no label item could be matched "
        "with a description item."
    )

print(prepared_data[0])
train_dataset = Dataset.from_list(prepared_data)

Will process up to the length of the shorter list.
Processed 1326 items for the dataset.
{'solution': '3', 'prompt': 'A conversation between User and Assistant. The user provides a text, and the Assistant classifies it \naccording to one or more of the 17 Sustainable Development Goals (SDGs). The Assistant \nfirst thinks about the text and the different SDGs, detailing its reasoning process in relation to the input text, and then provides the user with the SDG classification(s). \nThe reason and the sdg answer are enclosed within <reason> </reason> and <sdg> </sdg> tags, respectively. \nThe Assistant must identify the most relevant SDG(s) for the given text. If multiple SDGs are relevant, they can all be listed. The reasoning should clearly justify the choice(s).\\n\\n\nHere are the 17 Sustainable Development Goals (SDGs) and their descriptions:\\n\n1.  **SDG 1: No Poverty:** End poverty in all its forms everywhere.\\n\n2.  **SDG 2: Zero Hunger:** End hunger, achieve food security and 

In [None]:
# Base checkpoint that the LoRA adapters are attached to further down.
# Both dtype and device placement are passed as "auto".
# NOTE(review): the output recorded for the LoRA cell reports roughly 597M
# total parameters, which looks small for a checkpoint named "1.7B"; the saved
# outputs may stem from a run with a different model_id - confirm by running
# the notebook from a fresh kernel.
model_id = "Qwen/Qwen3-1.7B-FP8"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype="auto")

In [59]:
# LoRA adapter settings: rank 8, alpha 32, dropout 0.1, applied to the
# "q_proj" and "v_proj" modules for a causal-LM task.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# NOTE: `model` is rebound to the PEFT-wrapped model here, so executing this
# cell a second time wraps an already wrapped model. Re-run the model loading
# cell first when re-executing.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 1,146,880 || all params: 597,196,800 || trainable%: 0.1920


In [60]:
def format_reward(completions, **kwargs):
    """Reward completions that follow the ``<sdg>...</sdg><reason>...</reason>`` layout.

    A completion earns 1.0 when it starts with an ``<sdg>`` element, optionally
    followed by whitespace, and ends with a ``<reason>`` element; otherwise it
    earns 0.0. Tag contents may span several lines.

    Args:
        completions: One entry per generated completion. Each entry is either
            the completion text itself (``str``) or a list of message dicts
            whose first element holds the text under the ``"content"`` key.
        **kwargs: Extra columns passed by the trainer; unused here.

    Returns:
        list[float]: One reward (1.0 or 0.0) per completion, in input order.
    """
    # re.DOTALL lets "." cross newlines; without it any multi-line <sdg> or
    # <reason> body would be scored 0.0.
    # NOTE(review): the prompt text describes reasoning first and the SDG
    # answer second, while this pattern requires <sdg> before <reason> -
    # confirm which order is intended.
    pattern = re.compile(r"^<sdg>.*?</sdg>\s*<reason>.*?</reason>$", re.DOTALL)

    rewards_list = []
    for completion in completions:
        # Plain-string completions are used as is; message-list completions
        # carry the text in the first message.
        if isinstance(completion, str):
            content = completion
        else:
            content = completion[0]["content"]
        rewards_list.append(1.0 if pattern.match(content) else 0.0)
    return rewards_list

In [61]:
def get_sdg_reason(text: str) -> Tuple[str, str]:
    """Pull the ``<sdg>`` and ``<reason>`` payloads out of a completion.

    Only the first occurrence of each tag pair is considered, and the enclosed
    text may span multiple lines.

    Args:
        text (str): Text that may contain ``<sdg>...</sdg>`` and
            ``<reason>...</reason>`` sections.

    Returns:
        Tuple[str, str]: ``(sdg, reason)`` with surrounding whitespace removed.
            A missing tag pair yields an empty string in its position.
    """
    extracted = []
    for tag_pattern in (r'<sdg>(.*?)</sdg>', r'<reason>(.*?)</reason>'):
        found = re.search(tag_pattern, text, re.DOTALL)
        extracted.append(found.group(1).strip() if found else "")

    sdg_text, reason_text = extracted
    return sdg_text, reason_text

In [62]:
def extract_sdgs(text: str) -> List[str]:
    """Collect the SDG numbers mentioned in ``text`` as normalized labels.

    Three kinds of mention are recognized:
    - "SDG" (any letter case) followed by a number, e.g. "SDG1", "sdg 2",
      "SDG13.4";
    - a number with a sub-target, e.g. "16.1", of which the main number counts;
    - a bare one- or two-digit number at the start of the text or after a
      comma, semicolon, colon or whitespace, and followed by whitespace, a
      comma, a semicolon or the end of the text.

    Only numbers from 1 to 17 are kept.

    Args:
        text (str): Text to scan.

    Returns:
        List[str]: Unique labels of the form "SDG<number>" in ascending
            numeric order, or ``["None"]`` when nothing valid is found or
            ``text`` is empty or not a string.
    """
    if not (text and isinstance(text, str)):
        return ["None"]

    # Each regex is paired with the string it is applied to; only the
    # standalone-number search runs on the whitespace-stripped text.
    searches = (
        (r'(?i)\bsdg\s*(\d{1,2})(?:\.\d+)?\b', text),
        (r'\b(\d{1,2})\.\d+\b', text),
        (r'(?:^|[,;:]\s*|(?<=\s))(\d{1,2})(?=\s*[,;]|\s*$|\s+)', text.strip()),
    )

    # A set removes duplicates reported by more than one pattern.
    found = {
        int(digits)
        for regex, haystack in searches
        for digits in re.findall(regex, haystack)
    }
    valid = sorted(n for n in found if 1 <= n <= 17)

    return [f"SDG{n}" for n in valid] or ["None"]

In [63]:
def score(list1: List[str], list2: List[str]) -> float:
    """Jaccard similarity between the unique items of two lists.

    Duplicates and ordering are ignored: both lists are reduced to sets and
    the result is ``|intersection| / |union|``.

    Args:
        list1 (List[str]): First collection of items.
        list2 (List[str]): Second collection of items.

    Returns:
        float: A value between 0.0 (nothing shared) and 1.0 (identical unique
            content). 0.0 is also returned when both lists are empty, which
            avoids a division by zero.
    """
    first, second = set(list1), set(list2)
    combined = first.union(second)

    # Both inputs empty: nothing to compare.
    if not combined:
        return 0.0

    shared = first.intersection(second)
    return len(shared) / len(combined)


In [64]:
def accuracy_reward(completions, **kwargs):
    """Reward each completion by how well its SDGs match the ground truth.

    The text inside the completion's ``<sdg>`` tag and the reference solution
    are both normalized with ``extract_sdgs`` and compared with ``score``
    (Jaccard index), so partially correct answers earn partial reward.

    Note that ``extract_sdgs`` yields ``["None"]`` when it finds no SDG, so a
    completion without any SDG scores 1.0 against a solution that also
    contains none, and 0.0 against any other solution.

    Args:
        completions: One entry per generated completion. Each entry is either
            the completion text itself (``str``) or a list of message dicts
            whose first element holds the text under the ``"content"`` key.
        **kwargs: Must contain ``"solution"``, the reference answers aligned
            with ``completions``.

    Returns:
        list[float]: One reward in ``[0.0, 1.0]`` per completion/solution pair.

    Raises:
        KeyError: If ``"solution"`` is not supplied.
    """
    sdg_tag_solutions = kwargs["solution"]

    rewards = []
    for completion, sdg_tag_solution in zip(completions, sdg_tag_solutions):
        # Plain-string completions are used as is; message-list completions
        # carry the text in the first message.
        if isinstance(completion, str):
            content = completion
        else:
            content = completion[0]["content"]

        sdg_tag_content, _ = get_sdg_reason(content)
        sdg_list = extract_sdgs(sdg_tag_content)
        sdg_list_solution = extract_sdgs(sdg_tag_solution)
        rewards.append(score(sdg_list, sdg_list_solution))
    return rewards

In [None]:
from trl import GRPOConfig

# GRPO training configuration.
# NOTE(review): the output saved with the training cell shows a CUDA
# out-of-memory error on an 8 GiB GPU; the long prompt limit below is a
# plausible contributor - confirm by profiling before changing values.
training_args = GRPOConfig(
    # Optimization
    learning_rate=1e-5,
    num_train_epochs=1,
    gradient_accumulation_steps=16,
    bf16=True,
    # Keep extra dataset columns so accuracy_reward can read "solution".
    remove_unused_columns=False,
    # Parameters that control the data preprocessing and generation
    max_prompt_length=5000,  # default: 512
    max_completion_length=500,  # default: 256
    num_generations=4,  # default: 8
    # Reporting and checkpointing
    output_dir="Qwen3-1.7B-GRPO",
    report_to=["tensorboard"],
    logging_steps=10,
    save_strategy="steps",
    save_steps=10,
    push_to_hub=False,
)

In [66]:
from trl import GRPOTrainer

# Combine the PEFT-wrapped model, both reward functions (format first, then
# accuracy) and the prepared dataset into a GRPO trainer.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    reward_funcs=[format_reward, accuracy_reward],
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [67]:
# Run GRPO training with the trainer configured above.
# NOTE: the output saved with this notebook shows the run aborting with a CUDA
# OutOfMemoryError on an 8 GiB GPU, so no completed training run is recorded.
trainer.train()

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.44 GiB. GPU 0 has a total capacity of 8.00 GiB of which 0 bytes is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 16.79 GiB is allocated by PyTorch, and 5.28 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Save the trainer's model to the directory given by training_args.output_dir.
trainer.save_model(training_args.output_dir)