In [1]:
!pip install unsloth vllm



In [2]:
from unsloth import FastLanguageModel
import torch

import re
from typing import List, Tuple
from transformers import AutoModelForCausalLM
from datasets import Dataset 

from peft import LoraConfig, get_peft_model

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 05-25 19:09:22 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-25 19:09:22 [__init__.py:239] Automatically detected platform cuda.


2025-05-25 19:09:23,688	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [3]:
import json

# Load the labeled dataset (solutions)
labeled_items = []
with open("../src/ai/testsets/dataset_labeled.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        if line.strip(): # Make sure the line is not empty
            labeled_items.append(json.loads(line.strip()))

# Load the descriptions dataset
description_items = []
with open("../src/ai/testsets/dataset.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        if line.strip(): # Make sure the line is not empty
            description_items.append(json.loads(line.strip()))

# Load the prompt template
with open("sdg_label_prompt.md", "r", encoding="utf-8") as f:
    prompt_template = f.read()

print(f"Loaded {len(labeled_items)} items from dataset_labeled.jsonl")
print(f"Loaded {len(description_items)} items from dataset.jsonl")

Loaded 1326 items from dataset_labeled.jsonl
Loaded 3000 items from dataset.jsonl


In [4]:
prepared_data = []

# Assuming labeled_items and description_items correspond line by line
# and are of the same length.
if len(labeled_items) != len(description_items):
    print("Warning: The number of items in labeled_items and description_items is different!")
    print("Will process up to the length of the shorter list.")

for i in range(min(len(labeled_items), len(description_items))):
    label_item = labeled_items[i]        # e.g., {"patent_text_key": "This is the solution.", "reason": "..."}
    desc_item = description_items[i]     # e.g., {"patent_text_key": "Full description for prompt.", "other_field": "..."}

    current_solution = None
    description_for_prompt = None
    
    # Find the solution and the key that provided it
    key_for_solution_and_description = None
    for key, value in label_item.items():
        if key != "reason":
            current_solution = value
            key_for_solution_and_description = key
            break # Found the first non-reason item

    if current_solution is not None and key_for_solution_and_description is not None:
        # Now use this key_for_solution_and_description to get the description from desc_item
        description_for_prompt = desc_item.get(key_for_solution_and_description)

        if description_for_prompt is not None:
            final_prompt = " ".join(prompt_template.replace("{description}", description_for_prompt).split()[:1200])
            prepared_data.append({
                'solution': current_solution,
                'prompt': final_prompt
            })
        else:
            print(f"Warning: For item index {i}, solution found with key '{key_for_solution_and_description}', but this key was not found in the corresponding description item: {desc_item}")
    else:
        print(f"Warning: For item index {i}, no suitable solution key (non-'reason') found in label_item: {label_item}")

print(f"Processed {len(prepared_data)} items for the dataset.")
print(prepared_data[0])
train_dataset = Dataset.from_list(prepared_data)

Will process up to the length of the shorter list.
Processed 1326 items for the dataset.
{'solution': '3', 'prompt': "A conversation between User and Assistant. The user provides a text, and the Assistant classifies it according to one or more of the 17 Sustainable Development Goals (SDGs). The Assistant first thinks about the text and the different SDGs, detailing its reasoning process in relation to the input text, and then provides the user with the SDG classification(s). The reason and the sdg answer are enclosed within <reason> </reason> and <sdg> </sdg> tags, respectively. The Assistant must identify the most relevant SDG(s) for the given text. If multiple SDGs are relevant, they can all be listed. The reasoning should clearly justify the choice(s).\\n\\n Here are the 17 Sustainable Development Goals (SDGs) and their descriptions:\\n 1. **SDG 1: No Poverty:** End poverty in all its forms everywhere.\\n 2. **SDG 2: Zero Hunger:** End hunger, achieve food security and improved nutr

In [5]:
max_seq_length = 2500 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-0.6B",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.7, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank*2, # *2 speeds up training
    use_gradient_checkpointing = "unsloth", # Reduces memory usage
    random_state = 3407,
)

==((====))==  Unsloth 2025.5.7: Fast Qwen3 patching. Transformers: 4.51.3. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA GeForce RTX 3070. Num GPUs = 1. Max memory: 8.0 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/Qwen3-0.6B with actual GPU utilization = 60.4%
Unsloth: Your GPU has CUDA compute capability 8.6 with VRAM = 8.0 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2500. Num Sequences = 160.
Unsloth: vLLM's KV Cache can use up to 3.85 GB. Also swap space = 0 GB.
INFO 05-25 19:09:35 [config.py:717] This model supports multiple tasks: {'embed', 'score', 'generate', 'classify', 'reward'}. Defaulting to 'generate'.
INFO 05-25 19:09:35 [config.py:2003] Chunked prefill

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.04it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.03it/s]


INFO 05-25 19:09:37 [loader.py:458] Loading weights took 0.37 seconds
INFO 05-25 19:09:37 [punica_selector.py:18] Using PunicaWrapperGPU.





INFO 05-25 19:09:38 [gpu_model_runner.py:1347] Model loading took 1.1649 GiB and 1.293368 seconds
INFO 05-25 19:09:50 [backends.py:420] Using cache directory: /home/magb/.cache/vllm/torch_compile_cache/5b62dcf92c/rank_0_0 for vLLM's torch.compile
INFO 05-25 19:09:50 [backends.py:430] Dynamo bytecode transform time: 12.28 s
INFO 05-25 19:09:59 [backends.py:118] Directly load the compiled graph(s) for shape None from the cache, took 6.253 s
INFO 05-25 19:10:05 [monitor.py:33] torch.compile takes 12.28 s in total
INFO 05-25 19:10:06 [kv_cache_utils.py:634] GPU KV cache size: 15,312 tokens
INFO 05-25 19:10:06 [kv_cache_utils.py:637] Maximum concurrency for 2,500 tokens per request: 6.12x
INFO 05-25 19:10:57 [gpu_model_runner.py:1686] Graph capturing finished in 50 secs, took 0.63 GiB
INFO 05-25 19:10:57 [core.py:159] init engine (profile, create kv cache, warmup model) took 79.39 seconds
Unsloth: Just some info: will skip parsing ['post_feedforward_layernorm', 'pre_feedforward_layernorm']


Unsloth 2025.5.7 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


In [6]:
def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<sdg>.*?</sdg>\s*<reason>.*?</reason>$"
    print(completions)
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    rewards_list = [1.0 if match else 0.0 for match in matches]
    return rewards_list

def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<sdg>.*?</sdg>\s*<reason>.*?</reason>$"
    
    # 'completions' is already the list of strings you want to check.
    # So, 'completion_contents' can just be 'completions'.
    # Each 'content' in the next line will be one of the generated strings.
    print(f"Received completions: {completions}") # This print is helpful for debugging

    # Directly use each completion string for matching
    matches = [re.match(pattern, single_completion_string) for single_completion_string in completions]
    rewards_list = [1.0 if match else 0.0 for match in matches]
    
    print(f"Generated rewards: {rewards_list}") # Optional: to see the rewards
    return rewards_list

In [7]:
def get_sdg_reason(text: str) -> Tuple[str, str]:
    """Extracts SDG (Sustainable Development Goal) and reason from a text.

    The input text is expected to contain <sdg> and <reason> XML-like tags
    enclosing the relevant information.

    Args:
        text (str): The input string, potentially containing SDG and reason
            information enclosed in <sdg>...</sdg> and <reason>...</reason>
            tags.

    Returns:
        Tuple[str, str]: A tuple containing two strings:
            - The content of the <sdg> tag, stripped of leading/trailing
              whitespace.
            - The content of the <reason> tag, stripped of leading/trailing
              whitespace.
            Returns empty strings for either if the corresponding tag is not found.
    """
    reason_match = re.search(r'<reason>(.*?)</reason>', text, re.DOTALL)
    sdg_match = re.search(r'<sdg>(.*?)</sdg>', text, re.DOTALL)

    reason_content_regex = ""
    sdg_content_regex = ""

    if reason_match:
        reason_content_regex = reason_match.group(1).strip()

    if sdg_match:
        sdg_content_regex = sdg_match.group(1).strip()
    return sdg_content_regex, reason_content_regex

In [8]:
def extract_sdgs(text: str) -> List[str]:
        """
        Extracts and standardizes SDG references from a given text.

        This method identifies SDG mentions in various formats, including:
        - "SDG" followed by a number (e.g., "SDG1", "SDG 2").
        - Numbers with sub-targets (e.g., "16.1", "3.4"), where the main number
          is extracted.
        - Standalone numbers (1-17) that appear at the beginning of the text
          or are preceded by common delimiters (commas, semicolons, colons, whitespace)
          and followed by delimiters or the end of the string.
        The matching is case-insensitive.

        Args:
            text (str): The input text to scan for SDG references.

        Returns:
            List[str]: A list of unique SDGs found, formatted as "SDG<number>"
                (e.g., ["SDG1", "SDG2"]), sorted numerically. Returns ["None"]
                if no valid SDGs (1-17) are found or if the input text is empty
                or not a string.
        """
        if not text or not isinstance(text, str):
            # Modified to return ["None"] as per original logic for empty/invalid text
            return ["None"]

        sdg_numbers = set()  # Use set to avoid duplicates

        # Pattern 1: SDG followed by number with optional sub-target
        # Captures: SDG1, sdg 2, SDG13.4, etc.
        sdg_pattern = r'(?i)\bsdg\s*(\d{1,2})(?:\.\d+)?\b'
        sdg_matches = re.findall(sdg_pattern, text)
        for match in sdg_matches:
            number = int(match)
            if 1 <= number <= 17:
                sdg_numbers.add(number)

        # Pattern 2: Number with sub-target (e.g., "16.1", "3.4")
        # Look for patterns like X.Y where X is 1-17
        number_with_sub_pattern = r'\b(\d{1,2})\.\d+\b'
        sub_matches = re.findall(number_with_sub_pattern, text)
        for match in sub_matches:
            number = int(match)
            if 1 <= number <= 17:
                sdg_numbers.add(number)

        # Pattern 3: Standalone numbers at beginning or after delimiters
        standalone_pattern = r'(?:^|[,;:]\s*|(?<=\s))(\d{1,2})(?=\s*[,;]|\s*$|\s+)'
        standalone_matches = re.findall(standalone_pattern, text.strip())
        for match in standalone_matches:
            number = int(match)
            if 1 <= number <= 17:
                sdg_numbers.add(number)

        # Convert to sorted list of formatted strings
        result = [f"SDG{num}" for num in sorted(sdg_numbers)]

        return ["None"] if not result else result

In [9]:
def score(list1: List[str], list2: List[str]) -> float:
  """Calculates a similarity score between two lists based on common elements.

  The score is defined as the ratio of the number of common unique elements
  to the total number of unique elements across both lists (Jaccard index
  for sets derived from the lists).

  Args:
    list1 (List[Any]): The first list of items.
    list2 (List[Any]): The second list of items.

  Returns:
    float: The similarity score, ranging from 0.0 (no common elements)
           to 1.0 (all unique elements are common, or both lists are
           effectively the same in terms of unique content if order and
           duplicates are ignored). Returns 0.0 if both lists are empty
           or if the union of elements is empty to avoid division by zero,
           though the formula naturally handles this if at least one list
           is non-empty.
  """
  set1 = set(list1)
  set2 = set(list2)

  # Number of unique combined elements
  num_union = len(set1 | set2)  # Union of sets

  # Number of common elements between list2 and list1
  num_intersection = len(set1 & set2)  # Intersection of sets

  if num_union == 0:
    return 0.0  # Avoid division by zero if both lists result in empty sets
  return num_intersection / num_union


In [None]:
def accuracy_reward(completions, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    sdg_tag_solutions = kwargs["solution"]
    print(completions)
    #completion_contents = [completion[0]["content"] for single_completion_string in completions]
    rewards = []
    for content, sdg_tag_solution in zip(completions, sdg_tag_solutions):
        sdg_tag_content, _ = get_sdg_reason(content)

        sdg_list = extract_sdgs(sdg_tag_content)
        sdg_list_solution = extract_sdgs(sdg_tag_solution)
        if len(sdg_list) != 0:
            try:
                rewards.append(score(sdg_list, sdg_list_solution))
            except Exception:
                rewards.append(0.0)
        else:
            rewards.append(0.0)
    return rewards

In [11]:
max_prompt_length = 1500
max_completion_length = max_seq_length - max_prompt_length

from vllm import SamplingParams
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    seed = 3407,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    vllm_sampling_params = vllm_sampling_params,
    temperature = 1.0,
    learning_rate = 5e-6,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 10,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",

    # For optional training + evaluation
    # fp16_full_eval = True,
    # per_device_eval_batch_size = 4,
    # eval_accumulation_steps = 1,
    # eval_strategy = "steps",
    # eval_steps = 1,
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 4


In [12]:
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model=model, processing_class = tokenizer, reward_funcs=[format_reward, accuracy_reward], args=training_args, train_dataset=train_dataset
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,326 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 1 x 1) = 4
 "-____-"     Trainable parameters = 20,185,088/616,235,008 (3.28% trained)


Received completions: [' truck-mounted radar subsystem is configured to emit electromagnetic waves propagating in a space under the trailer. The truck-mounted radar associated with the work described in this patent is designed for detecting objects in the space under the trailer and detecting object positions by sending a signal. 6: The configuration of the truck-mounted radar subsystem consistent with the example embodiment mentioned in the previous paragraph; the truck-mounted sensor is more accurate in detecting the objects in the area where they are observed and with the characteristics of conductivity, and adds the ability to sense objects with higher frequencies and a lower range of dust explosion and noise. However, the implementation cost for such a system is relatively high. 7: Presenters and authors contribute that the improvement in the system and method of truck-mounted sensors in this patent is significant and appropriate for use in the field of automotive engineering. 8: 

TypeError: string indices must be integers, not 'str'

In [None]:
trainer.save_model(training_args.output_dir)