In [1]:
import pandas as pd
from tqdm import tqdm
import requests
import numpy as np

from src.category_tree.category_tree import CategoryTree

In [None]:
# --- Reproducibility ---------------------------------------------------------
# Seed passed to the DataFrame shuffle so the row order is repeatable.
RANDOM_STATE = 42

# --- Column names ------------------------------------------------------------
CAT_ID_COL = "cat_id"
CAT_NAME_COL = "cat_name"
TITLE_COL = "source_name"
PART_TYPE_COL = "part_type"
PART_COL = "part"
HASH_ID_COL = "hash_id"

# --- File locations (relative to the working directory) ----------------------
CAT_PATH = "data/raw/category_tree.csv"                        # category tree CSV
BAD_LABELED_PATH = "data/processed/bad_labeled.parquet"        # mislabeled samples
SAVE_FILE_NAME = "data/processed/bad_labeled_qwen2.5:3b.csv"   # prediction checkpoint

In [None]:
# Category metadata: the project tree helper plus the raw CSV as a DataFrame.
category_tree = CategoryTree(category_tree_path=CAT_PATH)
categor = pd.read_csv(CAT_PATH)

# "Pre-leaf" categories are the parents of leaf nodes (one level above the leaves).
is_leaf = categor["cat_id"].isin(category_tree.leaf_nodes)
pre_leaf_nodes = set(categor.loc[is_leaf, "parent_id"].to_list())

# Mislabeled samples, restricted to the columns this notebook needs.
df = pd.read_parquet(
    BAD_LABELED_PATH,
    columns=[TITLE_COL, CAT_ID_COL, HASH_ID_COL],
)

# Keep only pre-leaf samples, shuffle them reproducibly, then attach the
# category metadata (name, parent, ...) for each sample's category.
df = (
    df.loc[df["cat_id"].isin(pre_leaf_nodes)]
    .sample(frac=1, random_state=RANDOM_STATE)
    .reset_index(drop=True)
    .merge(categor, on="cat_id", how="left")
)

In [4]:
# Resume support: pick up an existing checkpoint of predictions when there is one.
checkpoint_columns = [HASH_ID_COL, TITLE_COL, 'pred_cat_id']

try:
    df_save = pd.read_csv(SAVE_FILE_NAME)
except FileNotFoundError:
    # No checkpoint yet: start from an empty frame with the checkpoint layout.
    df_save = pd.DataFrame(columns=checkpoint_columns)

df_save.shape

(4181, 3)

In [5]:
# Connection settings for the Ollama chat endpoint
MODEL_NAME = "qwen2.5:3b"  # model tag as it is loaded in Ollama
OLLAMA_URL = "http://127.0.0.1:11434/api/chat"

def ask_ollama(prompt, history=None, timeout=300):
    """
    Send one non-streaming chat request to Ollama and return the reply text.

    Args:
        prompt: Text of the user message appended after ``history``.
        history: Optional list of earlier chat messages (dicts with ``role``
            and ``content``). It is copied, never mutated. Defaults to no history.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled server cannot block the notebook forever.

    Returns:
        The assistant's reply with surrounding whitespace stripped, or ``None``
        when the request fails or the response has an unexpected shape
        (the error is printed).
    """
    messages = list(history or []) + [{"role": "user", "content": prompt}]
    payload = {
        "model": MODEL_NAME,
        "messages": messages,
        # Sampling parameters must live under "options"; a top-level
        # "temperature" key is not part of the /api/chat request schema.
        "options": {"temperature": 0.2},
        "stream": False,
    }
    try:
        response = requests.post(OLLAMA_URL, json=payload, timeout=timeout)
        response.raise_for_status()
        return response.json()["message"]["content"].strip()
    except Exception as e:
        # Deliberate best-effort: callers treat None as "request failed".
        print(f"Error requesting Ollama: {e}")
        return None

# Data Preparation
# Names of the root categories (rows with no parent). Not used by the loop below.
start_categor = categor[categor['parent_id'].isna()][CAT_NAME_COL].tolist()
# parent_id -> list of the names of its direct child categories
categor_dict = categor.groupby('parent_id')[CAT_NAME_COL].apply(list).to_dict()

# Extract necessary columns from the dataframe into lists
source_name_list = df[TITLE_COL].tolist()
hash_id_list = df[HASH_ID_COL].tolist()
cat_name_list = df[CAT_NAME_COL].tolist()
cat_id_list = df[CAT_ID_COL].tolist()

# Layout of the checkpoint file.
# NOTE(review): despite its name, 'pred_cat_id' holds the predicted category
# *name* (or NaN), not an ID. Kept as-is so new rows stay compatible with rows
# already in the checkpoint - confirm what downstream consumers expect.
SAVE_COLUMNS = [HASH_ID_COL, TITLE_COL, 'pred_cat_id']
SAVE_EVERY = 10       # flush pending predictions to disk every N processed items
REPORT_EVERY = 1000   # print the checkpoint shape every N processed items


def _flush_predictions(saved, pending):
    """Append ``pending`` rows to ``saved``, rewrite the checkpoint CSV, return the result."""
    combined = pd.concat(
        [saved, pd.DataFrame(pending, columns=SAVE_COLUMNS)],
        ignore_index=True,
    )
    combined.to_csv(SAVE_FILE_NAME, index=False, na_rep='NaN')
    return combined


# System instructions are identical for every item, so build them once.
chat_history = [{
    "role": "system",
    "content": (
        "Ты полезный помощник по определению категории товара и списка категорий, который строго следует инструкции. "
        "Ты всегда выбираешь только одну категорию из предложенного списка, без лишних слов."
    )
}]

# Pending predictions that have not been flushed to the checkpoint yet
cat_id_pred = []

# Resume: the checkpoint holds one row per processed item, in the same order as
# df, so its row count is the index of the first unprocessed item.
start_index = df_save.shape[0]
count = 0  # Number of items processed in this run

remaining_items = list(zip(
    source_name_list[start_index:],
    hash_id_list[start_index:],
    cat_name_list[start_index:],
    cat_id_list[start_index:],
))

for count, (name, hash_id, cat_name, cat_id) in enumerate(tqdm(remaining_items), start=1):
    # Candidate answers: names of the direct children of the current category
    subcategories = categor_dict.get(cat_id, [])
    lst = "\n - " + "\n - ".join(subcategories)

    prompt = (
        f"Товар: '{name}'.\n"
        f"Текущая категория: {cat_name}.\n"
        f"Выбери наиболее подходящую подкатегорию из списка:\n{lst}\n"
        f"Ответ должен быть строго только названием категории из списка, без лишних слов."
    )

    answer = ask_ollama(prompt, chat_history)
    if answer is None:
        # The request itself failed. Stop WITHOUT recording a row: because the
        # resume offset is the checkpoint's row count, recording a placeholder
        # here would make the next run skip this item for good.
        break

    # Clean the response (an empty reply simply fails the membership test below)
    answer = answer.strip().strip('.')

    # Accept the answer only if it is exactly one of the offered subcategories;
    # this is the same check as matching name + parent_id in `categor`, without
    # scanning the whole frame for every item.
    predicted = answer if answer in subcategories else np.nan
    cat_id_pred.append([hash_id, name, predicted])

    # Saving
    if count % SAVE_EVERY == 0:
        df_save = _flush_predictions(df_save, cat_id_pred)
        cat_id_pred = []

    if count % REPORT_EVERY == 0:
        print(df_save.shape)

# Save any remaining predictions
if cat_id_pred:
    df_save = _flush_predictions(df_save, cat_id_pred)

100%|██████████| 19/19 [00:59<00:00,  3.11s/it]
