In [1]:
!export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"

In [2]:
# import torch
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import json
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments
)

model_name = "Qwen/Qwen3-4B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # bf16
    # attn_implementation="flash_attention_2",
    device_map="auto",
)

model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
!nvidia-smi

Fri May 30 09:06:49 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.133.07             Driver Version: 570.133.07     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        Off |   00000000:01:00.0 Off |                  N/A |
|  0%   36C    P1             66W /  575W |    8924MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
from datasets import load_dataset

ds = load_dataset("dair-ai/emotion", "split", split="train[0:300]")

# ds => {text: str, label: int}

# sadness (0), joy (1), love (2), anger (3), fear (4), surprise (5).
str2int = {
    "sadness": 0,
    "joy": 1,
    "love": 2,
    "anger": 3,
    "fear": 4,
    "surprise": 5,
}
int2str = {v: k for k, v in str2int.items()}

In [5]:
ds

Dataset({
    features: ['text', 'label'],
    num_rows: 300
})

In [6]:
import json

SYSTEM = """
You are an emotion classifier engineer.
The classier itself works based on keywords and keyphrases.
Your task as engineer is to populate each emotion with such keys and phrases that will precisely identify emotions in text.
Here is the full list of emotions:
- sadness
- joy
- love
- anger
- fear
- surprise
Your will be given a text, your must yield emotion name and keywords/keyphrases as such, if tool will use this keywords/keyphrases 
it will successfully classify the text into correct emotion
Return in the format:

```
{
"emotion": "emotion name from the list",
"keys": ["key1", "key2"]
}
```
Strictly adhere to the format
"""

def process_answer(completion: str):
    try:
        completion = completion.replace("```json", "").replace("```", "").strip()
        completion = json.loads(completion)
        emotion = completion["emotion"] 
        keys = completion["keys"]
        return emotion, keys
    except Exception as e:
        return "", []

In [7]:
def preprocess(example):
    texts = example["text"]
    labels = example["label"]
    data = {"prompt": [], "labels": []}
    for text, label in zip(texts, labels):

        data["prompt"].append(
           tokenizer.apply_chat_template([
                {"role": "system", "content": SYSTEM},
                {"role": "user", "content": text}
            ], tokenize=False, add_generation_prompt=True, enable_thinking=True), 
        )
        data["labels"].append(int2str[label])
    
    return data


transformed_dataset = ds.map(preprocess, batched=True, remove_columns=ds.column_names)
transformed_dataset

Map:   0%|          | 0/300 [00:00<?, ? examples/s]

Dataset({
    features: ['prompt', 'labels'],
    num_rows: 300
})

In [8]:
transformed_dataset[0]

{'prompt': '<|im_start|>system\n\nYou are an emotion classifier engineer.\nThe classier itself works based on keywords and keyphrases.\nYour task as engineer is to populate each emotion with such keys and phrases that will precisely identify emotions in text.\nHere is the full list of emotions:\n- sadness\n- joy\n- love\n- anger\n- fear\n- surprise\nYour will be given a text, your must yield emotion name and keywords/keyphrases as such, if tool will use this keywords/keyphrases \nit will successfully classify the text into correct emotion\nReturn in the format:\n\n```\n{\n"emotion": "emotion name from the list",\n"keys": ["key1", "key2"]\n}\n```\nStrictly adhere to the format\n<|im_end|>\n<|im_start|>user\ni didnt feel humiliated<|im_end|>\n<|im_start|>assistant\n',
 'labels': 'sadness'}

In [9]:
from lark import Lark, Token
from lark import UnexpectedToken, UnexpectedCharacters
from lark.lexer import PatternRE, PatternStr
from lark.reconstruct import Reconstructor
from lark.parsers.lalr_interactive_parser import InteractiveParser
from lark import GrammarError
from typing import Union, Callable, Optional


def is_pattern_regex(pattern: Union[PatternRE, PatternStr]):
        return isinstance(pattern, PatternRE)
    
def is_pattern_string(pattern: Union[PatternRE, PatternStr]):
    return isinstance(pattern, PatternStr)


class Rebuilder:

    def __init__(self, 
                 grammar: str, 
                 term_subs: dict[str, str] = None,
                 token_transformer: Callable[[str, str, bool], str] = None,
                 ):
        """
        term_subs: dict[str, str] = None, -> term subs when reconstructing grammar from ast
        token_transformer: Callable[[str, str, bool], str] = None, -> token_name, inputs string and is_regex flag, returns token defs
        """
        self.grammar = grammar
        self.parser = Lark(grammar, 
                           parser="lalr", 
                           start="start", 
                           strict=False, 
                           lexer="contextual", 
                           maybe_placeholders=False,
                           regex=True
                           )
        self.terminals = self.parser.lexer_conf.terminals_by_name
        self.reconstructor = Reconstructor(self.parser, term_subs) \
                                if term_subs else Reconstructor(self.parser)
        self.token_transformer = token_transformer if token_transformer else None
    
    def get_token_definition(self, token: str):
        pattern = self.terminals.get(token, None)
        if pattern is not None:
            return pattern.pattern
        return None
    
    def as_token(self, token: str) -> Token:
        tok_def = self.get_token_definition(token)
        if self.token_transformer:
            tok_def = self.token_transformer(token, tok_def, is_pattern_regex(tok_def))
        if tok_def is None:
            tok_def = token
        return Token(token, tok_def)
    
    def beam_search(self, 
                    parser: InteractiveParser, 
                    given_token: Token, 
                    beam_width: int = 10, 
                    strategy: str = "shortest", # longest
                    break_early_limit: Optional[int] = 3,
                    ) -> list[Token]:
        """
        strategy: "shortest" or "longest"
        break_early_limit: 
        if strategy is "shortest", break early as soon as one of any path 
        that consist no more than break_early_limit tokens is found, priorisizing the shortest path
        if strategy is "longest", break early as soon as lenght of tokens path is at least break_early_limit,
        priorisizing the longest path
        """
        assert strategy in ["shortest", "longest"], "strategy must be either 'shortest' or 'longest'"
        if break_early_limit is not None:
            assert break_early_limit <= beam_width, "break_early_limit must be less than or equal to beam_width"
        if beam_width <= 0:
            return []
        candidates = {} # token2shortest_path
        candidates_score = {} # score2token
        accepted_tokens = list(parser.accepts()) # returns list of token names
        accepted_tokens.sort(key=lambda x: int(is_pattern_string(x)), reverse=False) # first -> string, then -> regex
        for token in accepted_tokens:
            dummy_parser = parser.copy()
            token = self.as_token(token)
            try:
                dummy_parser.feed_token(token)
                success = False
                # for the first two cases let's assume that there are missing tokens between tokena and given token
                if given_token.type in dummy_parser.accepts():
                    # case 1: some token is missing
                    candidates[token] = [token, given_token]
                    candidates_score[1] = token
                    success = True
                else:
                    # case 2: a series of tokens are missing
                    res = [token] + self.beam_search(
                        parser=dummy_parser, 
                        given_token=given_token, 
                        beam_width=beam_width-1,
                        strategy=strategy,
                        break_early_limit=None
                    )
                    score = len(res)
                    candidates[token] = res
                    candidates_score[score] = token
                    success = len(res) > 1
                if not success:
                    # if our hypothesis is wrong (no path discovered), then simply chose the correct token and ignore given token
                    # case 3: given token is not needed at all, remove given token from path
                    candidates[token] = [token]
                    candidates_score[0] = token
            except GrammarError as e:
                raise e

            if break_early_limit is None:
                continue
            if strategy == "shortest":
                for i in range(break_early_limit):
                    if i in candidates_score:
                        return candidates[candidates_score[i]] # break early
            elif strategy == "longest":
                for i in range(beam_width, break_early_limit, -1):
                    if i in candidates_score:
                        return candidates[candidates_score[i]] # break early

        if not candidates:
            return []
        if strategy == "shortest":
            best_score = min(candidates_score.keys()) # the shortest path
        elif strategy == "longest":
            best_score = max(candidates_score.keys()) # the longest path
        return candidates[candidates_score[best_score]]
    
    def repair(self, text: str, beam_width: int = 10, strategy: str = "shortest", break_early_limit: Optional[int] = 3):
        def _repair(e):
            if isinstance(e, UnexpectedToken):
                path = self.beam_search(e.interactive_parser, e.token, beam_width, strategy, break_early_limit)
                if not path:
                    raise e
                for token in path:
                    if token.type != "$END":
                        e.interactive_parser.feed_token(token)
                return True
            elif isinstance(e, UnexpectedCharacters):
                return True # simply ignore
            else:
                raise e
        
        tree = self.parser.parse(text, on_error=_repair)
        return self.reconstructor.reconstruct(tree)

In [10]:
# gnosis

from typing import List, Dict, Callable, Optional

# from rebuilder import Rebuilder
from dataclasses import dataclass
from abc import ABC, abstractmethod



@dataclass
class BaseTerminal(ABC):
    name: str
    
    def __post_init__(self):
        assert self.name.upper() == self.name, "name must be uppercase"

    @property
    @abstractmethod
    def as_terminal(self) -> str:
        pass


@dataclass
class Class(BaseTerminal):
    name: str
    values: List[str] # a list of keywords/keyphrases
    fuzzy_temperature: float = 0.5

    def __post_init__(self):
        assert self.fuzzy_temperature >= 0. and self.fuzzy_temperature <= 1., "fuzzy_temperature must be between 0 and 1"
        super().__post_init__()

    def __calculate_fuzzy_temperature(self, value: str):
        return max(0, int(len(value) * self.fuzzy_temperature - 1))

    @property
    def as_terminal(self):
        _values = []
        for value in self.values:
            temp = self.__calculate_fuzzy_temperature(value)
            fuzzy_postfix = ""
            if temp > 0:
                fuzzy_postfix = "{" + f"e<={temp}" + "}"
            value = f"(?:{value}){fuzzy_postfix}"
            _values.append(value)
        return f"{self.name}: /{'|'.join(_values)}/i"


class Gnosis:


    @staticmethod
    def setup_grammar(start: str, schema: str, terminals: List[BaseTerminal]):
        _terminals = "\n".join([terminal.as_terminal for terminal in terminals])
        return f"""
start: {start}
{schema}
{_terminals}
"""

    def __init__(self, 
                 start: str, 
                 schema: str, 
                 terminals: List[BaseTerminal],
                 term_subs: Optional[Dict[str, str]] = None,
                 token_transformer: Optional[Callable[[str, str, bool], str]] = None):
        # in theory, terminls must be anything that transform something into terminal def
        self.__start = start
        self.__schema = schema
        self.terminals = terminals
        self.__grammar = self.__class__.setup_grammar(start, schema, terminals)
        self.__rebuilder = Rebuilder(
            grammar=self.__grammar, 
            term_subs=term_subs, 
            token_transformer=token_transformer)
    
    # @property
    # def terminals(self):
    #     return self.__terminals

    @property
    def start(self):
        return self.__start
    
    @property
    def schema(self):
        return self.__schema
    
    @property
    def grammar(self):
        self.__grammar = self.__class__.setup_grammar(self.start, self.schema, self.terminals)
        return self.__grammar
    
    def repair(self, input_: str, *args, **kwargs):
        return self.__rebuilder.repair(input_, *args, **kwargs)



class Classifier(Gnosis):

    SEPARATOR = ">>"
    MISSING_CLASS = "MISSING_CLASS"

    @staticmethod
    def missing_class(token_name: str, token_def: str, is_regex: bool):
        if is_regex:
            return Classifier.MISSING_CLASS
        else:
            return token_def
    

    @staticmethod
    def conditional_rule(name: str):
        return f'{name} "{Classifier.SEPARATOR}" "{name.lower()}"'

    def __init__(self, classes: List[Class]):
        super().__init__(
            start="class",
            schema=f"class: {' | '.join([self.__class__.conditional_rule(class_.name) for class_ in classes])}",
            terminals=classes,
            token_transformer=self.__class__.missing_class)
    
    
    def update(self, class_name, classes: list):
        for cls in self.terminals:
            if cls.name == class_name.upper():
                cls.values += classes
    
    def repair(self, input_: str, *args, **kwargs):
        result = super().repair(input_, *args, **kwargs)
        input_, output = result.split(Classifier.SEPARATOR)
        if input_ == Classifier.MISSING_CLASS:
            return Classifier.MISSING_CLASS
        return output

In [11]:
classifier = classifier = Classifier([
    Class(name="SADNESS", values=["sad"]),
    Class(name="JOY", values=["joy"]),
    Class(name="LOVE", values=["love"]),
    Class(name="ANGER", values=["angry"]),
    Class(name="FEAR", values=["fear"]),
    Class(name="SURPRISE", values=["surprise"]),
])

In [12]:
def llm_accuracy_reward(completions, **kwargs):
    rewards = []
    for n, completion in enumerate(completions):
        reward = 0.0
        expected_label = kwargs["labels"][n].strip()
        completion = completion.split("</think>")[-1]
        predicted_label, _ = process_answer(completion)
        if expected_label == predicted_label.strip():
            reward += 1.0
        rewards.append(reward)
    return rewards


def tool_accuracy_reward(completions, **kwargs):
    rewards = []
    for n, completion in enumerate(completions):
        reward = 0.0
        expected_label = kwargs["labels"][n].strip()
        completion = completion.split("</think>")[-1]
        predicted_label = classifier.repair(completion).strip()
        if expected_label == predicted_label:
            reward += 1.0
        else:
            emotion, classes = process_answer(completion)
            if classes:
                classifier.update(emotion, classes)
                predicted_label = classifier.repair(completion)
            # if expected_label == predicted_label:
            #     reward += 0.5
        rewards.append(reward)
    return rewards

In [13]:
from peft import get_peft_model, LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",]  
)

peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

trainable params: 16,515,072 || all params: 4,038,983,168 || trainable%: 0.4089


In [14]:
from trl import GRPOConfig

# Configure training arguments using GRPOConfig
training_args = GRPOConfig(
    output_dir="Qwen4B-Grammar-RL-CLS",
    overwrite_output_dir=True,
    prediction_loss_only=True,
    optim="paged_adamw_32bit",
    lr_scheduler_type="cosine",
    learning_rate=1e-5,
    remove_unused_columns=False,  # to access the solution column in accuracy_reward
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    bf16=True,
    # Parameters that control de data preprocessing
    max_completion_length=512*3,  # default: 256
    num_generations=4,  # default: 8
    max_prompt_length=512,  # default: 512
    # Parameters related to reporting and saving
    report_to="mlflow", # https://huggingface.co/docs/transformers/main_classes/callback
    log_completions=True,
    group_by_length=True,
    # logging_steps=0.1,
    logging_steps=15,
    push_to_hub=False,
    save_strategy="steps",
    save_steps=50,
)

In [15]:
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model=peft_model, 
    reward_funcs=[llm_accuracy_reward, tool_accuracy_reward], 
    args=training_args, 
    train_dataset=transformed_dataset
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
import gc, torch
gc.collect()
torch.cuda.empty_cache()
model.config.use_cache = False


trainer.train()

In [17]:
trainer.save_model("Qwen4B-Grammar-RL-CLS")

In [18]:
import json

data = {}
for term in classifier.terminals:
    data[term.name] = term.values


with open("classifier.json", "w") as f:
    json.dump(data, f, indent=2, ensure_ascii=False)