In [5]:
# Keep unsloth first: it asks to be imported before transformers-based
# libraries so that its patches are applied.
from unsloth import FastLanguageModel

import re
from typing import Any

from datasets import load_dataset
from pydantic import BaseModel, PrivateAttr

In [3]:
def extract_score(generated_text):
    """
    Extract the first decimal digit that appears in the generated text.

    Args:
        generated_text: Raw text produced by the model.

    Returns:
        The first digit found, as an int, or -1 when the text contains no
        digit at all (parsing error).
    """
    # Search for one digit with a plain `\d`. Wrapping `\d+` in a character
    # class (`[\d+]`) would also match a literal "+", and int("+") raises
    # ValueError instead of reporting a parsing error.
    match = re.search(r"\d", generated_text)
    if match is None:
        return -1  # Parsing error
    return int(match.group(0))

In [None]:
class DQIModel(BaseModel):
    """
    Pydantic model that bundles a chat model with its inference parameters.

    Storing the parameters as validated fields (rather than only the model
    name) lets a run be described and versioned by its full configuration.
    The underlying model and tokenizer are loaded once, right after the fields
    are validated, and kept in private attributes.
    """

    # Name or local path handed to FastLanguageModel.from_pretrained.
    chat_model: str
    # Temperature forwarded to generate().
    cm_temperature: float
    # Upper bound on newly generated tokens; also used as max_seq_length at load time.
    cm_max_new_tokens: int
    # Forwarded as load_in_4bit when loading the model.
    cm_quantize: bool
    # Not read by any method of this class.
    inference_batch_size: int
    # Forwarded as dtype when loading the model.
    dtype: Any
    # Device the tokenized prompt is moved to before generation.
    device: str
    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()

    def model_post_init(self, __context):
        """Load the model and tokenizer once the fields have been validated."""
        # unsloth version (enable native 2x faster inference)
        # NOTE(review): max_seq_length is set from cm_max_new_tokens, although a
        # sequence length normally has to cover prompt + generated tokens --
        # confirm this is intended for long prompts.
        self._model, self._tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.chat_model,
            max_seq_length=self.cm_max_new_tokens,
            dtype=self.dtype,
            load_in_4bit=self.cm_quantize,
        )
        FastLanguageModel.for_inference(self._model)

    async def predict(self, query: list[Any]) -> dict:
        """
        Generate a completion for one conversation and parse the label from it.

        Args:
            query: The conversation to complete. Each message is either a dict
                with "role" and "content" keys, or a (role, content) pair of
                strings, which is converted to such a dict.

        Returns:
            A dict with "predicted_label" (result of extract_score on the
            generated text) and "generated_text" (the stripped completion).
        """
        # The chat template works on role/content dicts, so convert string
        # pairs and pass every other message through untouched.
        messages = [
            {"role": message[0], "content": message[1]}
            if (
                isinstance(message, (tuple, list))
                and len(message) == 2
                and all(isinstance(part, str) for part in message)
            )
            else message
            for message in query
        ]

        # add_generation_prompt = true - Must add for generation
        input_ids = self._tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(self.device)

        # NOTE(review): no do_sample flag is passed; confirm that temperature
        # and min_p take effect with the model's generation config.
        output_ids = self._model.generate(
            input_ids=input_ids,
            max_new_tokens=self.cm_max_new_tokens,
            use_cache=True,
            temperature=self.cm_temperature,
            min_p=0.1,
        )

        # Drop the prompt tokens and decode the continuation in one call;
        # decoding token by token and joining can lose inter-token whitespace.
        prompt_length = input_ids.shape[1]
        generated_text = self._tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        ).strip()
        predicted_label = extract_score(generated_text)

        return {
            "predicted_label": predicted_label,
            "generated_text": generated_text,
        }

In [6]:
# Read the processed test set from JSON lines (a bare data file is exposed
# under the "train" split) and shuffle it with a fixed seed.
test_dataset = load_dataset(
    "json",
    data_files="../data/processed_test.jsonl",
    split="train",
).shuffle(seed=512)

In [7]:
# Show every message of the first shuffled example, one role header per message.
example = test_dataset[0]
for message in example["messages"]:
    role, content = message["role"], message["content"]
    print(f"{role}:")
    print(f"{content}\n")

system:
### Role
You are an expert in political discourse analysis using the Discourse Quality Index (DQI).

### Task
Evaluate the "Level of Justification" of the following text on a scale from 0 to 4.

### Criteria Definitions
* **0 (No justification):** The speaker demands that X should or should not be done, but provides NO reason.

* **1 (Inferior justification):** A reason Y is given for demand X, but the logical connection between X and Y is missing or incomplete. It is a conclusion supported merely by illustrations or loose associations.

* **2 (Qualified justification):** A single complete logical inference is made. The speaker explicitly explains WHY or HOW X contributes to Y (e.g., using connectors like "because", "so that").

* **3 (Sophisticated justification - Broad):** The speaker provides at least two COMPLETE but DISTINCT justifications (Level 2) for the demand.
*Structure:* "We should do X because of Reason A. Additionally, we should do X because of Reason B." (Horizon

In [None]:
llm = DQIModel(
    chat_model="outputs/mistral-7b-instruct-v0.3-bnb-4bit_finetuned",
    cm_temperature=0.1,
    cm_max_new_tokens=256,
    cm_quantize=True,
    inference_batch_size=1,
    dtype="float16",
    device="cuda",
)

n_examples = 5
for example in test_dataset.select(range(n_examples)):
    query = [
        (message["role"], message["content"]) for message in example["messages"][:-1]
    ]
    y_true = example["messages"][-1]["content"]
    llm_response = llm.predict(query)
    y_pred = extract_score(llm_response["generated_text"])
    print(f"Ground truth: {y_true} | LLM prediction: {y_pred}")