In [None]:
# Enable IPython's autoreload extension so that edits to the project's local
# modules are picked up without restarting the kernel (mode 2 = reload all
# modules before each cell execution).
%load_ext autoreload
%autoreload 2

In [None]:
import ollama
import json


# Load the candidate items for the selected dataset from its JSON file.
dataset_type = "college"
# Read as UTF-8 explicitly: without `encoding`, open() falls back to the
# platform locale, which can mis-decode non-ASCII item names (e.g. on Windows).
with open('data/datasets/college_small.json', 'r', encoding='utf-8') as f:
    items = json.load(f)


In [None]:
from model_inf import OllamaClient, VLLMClient
# Backend selection: True -> Ollama, False -> vLLM (OpenAI-compatible server).
USE_OLLAMA = False
# Model identifier passed to the selected backend on every chat call.
# NOTE(review): this is a Hugging Face-style id; Ollama may expect its own
# model tag instead - confirm before switching USE_OLLAMA to True.
model_name = "google/gemma-3-4b-it"

if USE_OLLAMA:
    # Generation options forwarded to the Ollama client.
    ollama_options = {'temperature': 0.7}
    llm_client = OllamaClient(ollama_options=ollama_options)
    print(f"Using Ollama backend with model: {model_name}")
else:
    # The vLLM server must already be running. Start it in a terminal with:
    #   python -m vllm.entrypoints.openai.api_server --model="google/gemma-3-4b-it"
    # or, to set the context length explicitly:
    #   python -m vllm.entrypoints.openai.api_server --model="google/gemma-3-4b-it" --max-model-len 50000
    vllm_base_url = "http://localhost:8000/v1"  # Adjust if your vLLM server is elsewhere
    vllm_api_key = "dummy"  # Placeholder; usually 'dummy' or 'no-key' for local vLLM
    # Generation parameters forwarded to the vLLM client.
    vllm_client_options = {'temperature': 0.7, 'max_tokens': 1024}
    llm_client = VLLMClient(base_url=vllm_base_url, api_key=vllm_api_key, client_options=vllm_client_options)
    # Report the active backend here too, mirroring the Ollama branch.
    print(f"Using vLLM backend with model: {model_name}")

Alternatively, you could set `model_name = "llama2-uncensored"`.


In [None]:
from users import User

from utils import extract_list_from_response
from metrics import calc_iou, calc_serp_ms, calc_prag

# Experiment settings; the attribute sweep in the next cell reuses them.
dataset_type = "college"
k = 20
type_of_activity = "student"

# Baseline user with no sensitive attribute.
# NOTE(review): presumably the bare article "a" acts as the neutral
# placeholder inside the prompt template - confirm against `User`.
neutral_user = User(
    dataset_type=dataset_type,
    items=items,
    k=k,
    type_of_activity=type_of_activity,
    sensitive_atribute="a",
)

# Query the model once and keep the parsed list as the reference that every
# attribute-specific list is compared against.
prompts = neutral_user.build_prompts()
response = llm_client.chat(model=model_name, messages=prompts)
neutral_list = extract_list_from_response(response)

print("Neutral attribute:")
print(f"Recommended list: {neutral_list}")


In [None]:
from tqdm import tqdm

# Nested results: {attribute category: {attribute: {metric name: score}}}.
final_results = {}

# Read as UTF-8 explicitly so attribute strings decode the same way on every
# platform instead of depending on the locale default.
with open('data/sensitive_atributes.json', 'r', encoding='utf-8') as f:
    dict_sensitive_atributes = json.load(f)

# Outer bar: one step per attribute category; inner bar: one step per attribute.
for attribute_group, group_attributes in tqdm(dict_sensitive_atributes.items(), desc="Processing sensitive attributes", position=0):
    group_results = {}
    for sensitive_atribute in tqdm(group_attributes, desc=f"Processing {attribute_group}", leave=False, position=1):
        # Same setup as the neutral baseline; only the sensitive attribute varies.
        user = User(dataset_type=dataset_type, items=items, k=k, type_of_activity=type_of_activity, sensitive_atribute=sensitive_atribute)
        prompts = user.build_prompts()
        response = llm_client.chat(model=model_name, messages=prompts)
        extracted_list = extract_list_from_response(response)
        # Score this list against the neutral reference list with each metric.
        group_results[sensitive_atribute] = {
            "IOU": calc_iou(neutral_list, extracted_list),
            "SERP MS": calc_serp_ms(neutral_list, extracted_list),
            "Pragmatic": calc_prag(neutral_list, extracted_list)
        }
    final_results[attribute_group] = group_results

# Persist the results; '/' in the model name is replaced so it is a valid
# file name component.
results_path = f"results_{model_name.replace('/', '_')}_{dataset_type}.json"
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(final_results, f, indent=4)

In [None]:
# Inspect the nested results: {attribute category: {attribute: {metric: score}}}.
final_results

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Draw one grouped bar chart per sensitive-attribute category.
for key, value in final_results.items():
    print(f"{key}: {value}")

    # `value` maps attribute -> {metric: score}, so the frame has one row per
    # metric and one column per attribute.
    df = pd.DataFrame(value)

    # Reshape to long form (one row per metric/attribute pair) for seaborn.
    df_reset = df.reset_index().rename(columns={"index": "Metric"})
    df_long = df_reset.melt(id_vars="Metric", var_name=key, value_name="Value")

    # Strip only a *leading* article. The earlier unanchored pattern "an |a "
    # also matched inside labels, so a label such as "an African American"
    # would have been mangled to "AfricAmerican".
    df_long[key] = df_long[key].str.replace(r"^(?:an|a) ", "", regex=True).str.title()

    # Plot
    plt.figure(figsize=(10, 6))
    sns.barplot(data=df_long, x=key, y="Value", hue="Metric", palette="muted")

    plt.title(f"Comparison of IOU, SERP MS, and Pragmatic Metrics by {key}")
    plt.ylabel("Score")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

In [None]:
# Inspect the recommendation list left over from the last iteration of the
# sensitive-attribute loop.
extracted_list

In [None]:
import numpy as np

# Map every catalogue item to its 0-based position in `items`.
items_rank = {item: i for i, item in enumerate(items)}

# Collect the catalogue position of each recommended item that can be matched.
res = []
for item in extracted_list:
    # Keep only the text after the last ": " and trim surrounding whitespace.
    item = item.split(": ")[-1].strip()
    # Use None as the "not found" marker: the previous sentinel 0 collided with
    # the valid rank 0, so the first catalogue item was always dropped.
    rank = items_rank.get(item)
    if rank is not None:
        res.append(rank)

# Mean position of the matched items; still prints nan when nothing matched,
# but without the RuntimeWarning that np.mean([]) raises.
print(np.mean(res) if res else float("nan"))
