### Test the fine-tuned Llama model

In [None]:
import torch
from trl import SFTTrainer
from datasets import load_dataset
from transformers import TrainingArguments, TextStreamer
from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel, is_bfloat16_supported
import json
import re
from tqdm import tqdm
import numpy as np
import pandas as pd
import ast

#### Define the system prompt, data paths, and model path

In [None]:
# System message prepended to every conversation sent to the model.
# NOTE(review): this text presumably has to match the prompt used during
# fine-tuning for the model to behave as trained — confirm before editing it.
role_dict = {
    'role': 'system',
    'content': 'Categorize the paragraph given. Select one or more categories from "synthesis condition" and "property". If the paragraph does not fit into either of these categories, choose "else". "property" paragraph must include specific numerical value of the property. ex) surface area of 2500 m2/g'
}

# Local path of the fine-tuned model checkpoint to load for inference.
model_path = "text_categorize_1/"

# Load the CSV files containing categorized results
d1 = pd.read_csv('/home/users/seunghh/l2m3_revision/data/categorize_results_else.csv', encoding='utf-8')
d2 = pd.read_csv('/home/users/seunghh/l2m3_revision/data/categorize_results.csv', encoding='utf-8')
# Force d2 to use d1's column names so the two frames line up column-by-column
# (assumes both files have the same column order — TODO confirm).
d2.columns = d1.columns
# ignore_index=True gives df a fresh 0..n-1 index; without it the row labels of
# d1 and d2 overlap, so labels would not match positions in true_list and in the
# prediction list built from df later on.
df = pd.concat([d1, d2], ignore_index=True)

# Ground-truth classification strings, in df row order.
true_list = df['Classification'].tolist()

In [None]:

def read_files(file_name):
    """
    Read a JSON Lines file and parse every non-blank line as JSON.

    Args:
        file_name: Path to the ``.jsonl`` file.

    Returns:
        A list of the parsed JSON values in file order. Lines that fail to
        parse are reported on stdout (with their 1-based line number) and
        skipped rather than aborting the read.
    """
    data_list = []
    # JSON Lines files are UTF-8; don't depend on the platform's default encoding.
    with open(file_name, 'r', encoding='utf-8') as g:
        for line_number, line in enumerate(g, start=1):
            line = line.strip()  # Remove surrounding whitespace / newline
            if not line:  # Skip blank lines
                continue
            try:
                json_obj = json.loads(line)
            except json.JSONDecodeError as e:
                # Report the malformed line and keep reading the rest of the file
                print(f"Error parsing JSON on line {line_number}: {e}")
                continue
            data_list.append(json_obj)
    return data_list

def parse_template_to_dicts(template_text, start_token="<|im_start|>", end_token="<|im_end|>"):
    """
    Split ChatML-style text into a list of ``{'role', 'content'}`` messages.

    Messages are expected to look like ``<|im_start|>ROLE\\nCONTENT<|im_end|>``.
    The text is split on ``start_token``; in each piece, anything from
    ``end_token`` onward is discarded (so a terminator that survives decoding
    does not leak into the content), the first line is taken as the role and
    the remainder as the content.

    Args:
        template_text: Decoded conversation text.
        start_token: Literal marker that begins each message.
        end_token: Literal marker that ends a message's content; pass an empty
            string to disable terminator stripping.

    Returns:
        A list of dicts with keys ``'role'`` and ``'content'``, in order.
        Pieces that are empty after stripping whitespace are skipped.
    """
    messages = []
    # The markers are plain literals, so str.split is enough (no regex needed).
    for part in template_text.split(start_token):
        if end_token:
            # Keep only the text before the message terminator, if present.
            part = part.split(end_token, 1)[0]
        part = part.strip()
        if not part:
            continue
        # The first line is the role, the rest (possibly empty) is the content
        role, _, content = part.partition("\n")
        messages.append({"role": role.strip(), "content": content.strip()})
    return messages

def make_input_data(df, system_message=None):
    """
    Build one chat conversation per dataframe row for the model.

    Each conversation is
    ``[system_message, {'role': 'user', 'content': row['Clean_Text']}]``.

    Args:
        df: DataFrame with a ``'Clean_Text'`` column holding the paragraph text.
        system_message: Message dict placed first in every conversation.
            Defaults to the module-level ``role_dict`` (looked up at call time).

    Returns:
        A list of conversations (lists of message dicts) in df row order.
        The same system-message dict object is shared by every conversation.
    """
    if system_message is None:
        system_message = role_dict
    return [
        [system_message, {'role': 'user', 'content': row['Clean_Text']}]
        for _, row in df.iterrows()
    ]

def make_clean_data(true_list, *, show_progress=True):
    """
    Normalize raw classification strings into lists of label strings.

    Each entry may be a stringified list such as
    ``"['synthesis condition', 'property']"`` or a plain comma-separated
    string such as ``"property"``. The entry is split on commas, and each piece
    has surrounding brackets, quotes, trailing commas and whitespace removed;
    empty pieces are dropped. Used for both ground-truth and predicted labels.

    Args:
        true_list: Iterable of raw label strings.
        show_progress: If True (the default), iterate through a tqdm progress
            bar. Keyword-only.

    Returns:
        A list holding one list of cleaned label strings per input entry.
    """
    items = tqdm(true_list) if show_progress else true_list
    true_clean_list = []
    for raw in items:
        labels = []
        for piece in raw.split(','):
            piece = piece.strip()
            label = piece.strip('["').strip('"]')
            label = label.replace("',", "'")
            label = label.replace("\\'", "'")

            # Remove trailing commas and extra quotes
            label = label.rstrip(",")
            label = label.strip("'\"")

            # Evaluating the label as a quoted string literal resolves escape
            # sequences. It fails for labels with an embedded quote (e.g.
            # "condition's") or a trailing backslash; keep the cleaned text in
            # that case instead of aborting the whole evaluation run.
            try:
                label = ast.literal_eval(f"'{label}'")
            except (ValueError, SyntaxError):
                pass  # fall back to the cleaned, unevaluated text
            if label:
                labels.append(label)

        true_clean_list.append(labels)
    return true_clean_list


#### Load the model and run inference on the test dataset

In [None]:

# Build the chat-formatted inputs (system prompt + paragraph) for every row of df
data = make_input_data(df)

# Load the trained model and tokenizer for inference.
# Since the model was saved in 16-bit precision, we do not load it in 4-bit format.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_path,
    load_in_4bit=False,
    dtype=None,  # NOTE(review): presumably lets unsloth choose the dtype — confirm
)

# Put the model into unsloth's inference mode
inference_model = FastLanguageModel.for_inference(model)

# A single streamer is reused for all samples: TextStreamer resets its internal
# state when each generate() call finishes, so recreating it per sample is wasted work.
text_streamer = TextStreamer(tokenizer)

predicted_list = []
for messages in tqdm(data):
    # Apply the chat template to get tokenized input suitable for model generation
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    # Generate the prediction, streaming text to stdout as it is produced
    outputs = inference_model.generate(
        input_ids=inputs,
        streamer=text_streamer,
        max_new_tokens=128,
        use_cache=True,
    )

    # Decode the full sequence and keep the content of the last parsed message
    # (presumably the assistant's reply) as the prediction; record an empty
    # prediction rather than crashing if nothing could be parsed.
    decoded_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    new_parse = parse_template_to_dicts(decoded_text)
    predicted_list.append(new_parse[-1]['content'] if new_parse else "")

# Clean both the true and predicted classifications
true_clean_list = make_clean_data(true_list)
predicted_clean_list = make_clean_data(predicted_list)
