# Gifty Production Worker: LLM Scoring

This notebook connects to the Gifty Internal API, fetches products that still need scoring, scores them with an LLM, and submits the results back through the API so they are stored in the database. Optimized for 2x T4 GPUs.

### Setup Guide
1. Set `API_BASE_URL` to your backend's base URL, without a trailing slash (e.g. `https://api.gifty.gift`).
2. Set `INTERNAL_TOKEN` to the internal token configured in your backend settings; it is sent in the `X-Internal-Token` header.

In [None]:
!pip -q install -U transformers accelerate bitsandbytes sentencepiece pandas tqdm requests

In [None]:
import logging
import os
import sys

# Backend connection settings. An environment variable, when set, overrides
# the default below, so the token does not have to be hard-coded in the
# notebook. A trailing "/" is stripped because request URLs are built as
# f"{API_BASE_URL}/internal/...".
API_BASE_URL = os.environ.get("GIFTY_API_BASE_URL", "https://your-api-url.com").rstrip("/")  # Update this
INTERNAL_TOKEN = os.environ.get("GIFTY_INTERNAL_TOKEN", "default_internal_token")  # Update this
DEBUG = True  # True: DEBUG-level logging, including full prompts and raw generations

MODEL_ID = "Qwen/Qwen2.5-32B-Instruct"
MODEL_VERSION = "v1.1"  # Prompt version
MODEL_TAG = "qwen2.5-32b-4bit"  # Hardware/quantization tag

# Configure logging. force=True replaces any handlers already attached to the
# root logger (e.g. by the notebook host or a previous run of this cell);
# without it basicConfig does nothing at all when the root logger already
# has a handler, and the level/format below would be silently ignored.
log_level = logging.DEBUG if DEBUG else logging.INFO
logging.basicConfig(
    level=log_level,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)
logger = logging.getLogger("GiftyWorker")
logger.info(f"Logger initialized with level: {logging.getLevelName(log_level)}")

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

logger.info(f"Loading model {MODEL_ID}...")

# Weights are quantized to 4-bit NF4 with double (nested) quantization;
# compute runs in float16.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# `tok` and `model` are module-level globals used by the scoring cells.
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
# device_map="auto" leaves layer placement across available devices to accelerate.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", quantization_config=bnb)
model.eval()  # inference only
logger.info("Model loaded successfully.")

In [None]:
import torch.nn.functional as F

# System prompt: ask for a short reasoning first, then a final verdict line
# beginning with "Answer:" followed by GIFT or NOT_GIFT.
SYSTEM = "\n".join((
    "You are a strict giftability classifier.",
    "First, provide a very brief reasoning in Russian or English (max 2 sentences).",
    "Then conclude with 'Answer: GIFT' or 'Answer: NOT_GIFT'.",
))

def build_prompt(title, category="", merchant="", price=None):
    """Render the chat-formatted classification prompt for one product.

    Each field is interpolated with an f-string, so missing values show up
    as their ``str()`` form (e.g. ``price=None`` renders as ``None``).

    Args:
        title: Product title.
        category: Product category.
        merchant: Merchant name.
        price: Product price.

    Returns:
        The prompt text from ``tok.apply_chat_template`` (system + user
        messages, with the generation prompt appended), not tokenized.
    """
    user_lines = [
        "Decide if this product is a good gift item for most people.",
        "Utilitarian/chemical/spare parts -> NOT_GIFT. ",
        "Decor/gadgets/jewelry/toys -> GIFT.",
        "",
        "Product:",
        f"- title: {title}",
        f"- category: {category}",
        f"- merchant: {merchant}",
        f"- price: {price}",
        "Reasoning:",
    ]
    user_text = "\n".join(user_lines)
    conversation = [
        dict(role="system", content=SYSTEM),
        dict(role="user", content=user_text),
    ]
    return tok.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

@torch.no_grad()
def score_label(prompt_with_reasoning: str, label: str) -> float:
    """Return the total log-probability the model assigns to ``label`` after a prompt.

    ``prompt_with_reasoning`` and ``prompt_with_reasoning + label`` are
    tokenized the same way; the tokens beyond the prompt's length are taken
    as the label, and their natural-log probabilities are summed.

    Args:
        prompt_with_reasoning: Chat-formatted text that ``label`` should follow.
        label: Candidate continuation, e.g. " GIFT" or " NOT_GIFT".

    Returns:
        Sum of the per-token log-probabilities of the label tokens (<= 0).

    Raises:
        ValueError: If ``label`` adds no tokens beyond the prompt (previously
            this silently scored 0.0, i.e. probability 1).
    """
    enc_full = tok(prompt_with_reasoning + label, return_tensors="pt").to(model.device)
    prompt_ids = tok(prompt_with_reasoning, return_tensors="pt")["input_ids"][0]
    full_ids = enc_full["input_ids"][0]
    prompt_len = prompt_ids.shape[0]
    label_ids = full_ids[prompt_len:]
    if label_ids.numel() == 0:
        raise ValueError(f"Label {label!r} produced no tokens after the prompt")
    # If the tokenizer merged characters across the prompt/label boundary,
    # slicing at prompt_len no longer isolates exactly the label tokens.
    if not torch.equal(full_ids[:prompt_len].cpu(), prompt_ids):
        logger.warning(f"Tokenization of label {label!r} is not aligned with the prompt; score may be inaccurate")

    logits = model(**enc_full).logits
    # Logits at position t predict token t + 1, so the label tokens are
    # predicted by positions prompt_len - 1 .. len(full_ids) - 2.
    # Upcast to float32 so log_softmax does not run in the model's float16.
    label_logits = logits[0, prompt_len - 1 : -1, :].float()
    log_probs = F.log_softmax(label_logits, dim=-1)
    # With device_map="auto" the logits may live on a different device than
    # model.device (where the inputs were placed), so align before gather.
    targets = label_ids.to(log_probs.device).unsqueeze(-1)
    token_lp = log_probs.gather(-1, targets).squeeze(-1)
    return float(token_lp.sum().item())

def process_one(item):
    """Score one product task and return ``(p_gift, reasoning)``.

    Greedily generates a short reasoning (at most 80 new tokens), keeps the
    text before any "Answer:" the model wrote, then compares the summed
    log-probabilities of " GIFT" and " NOT_GIFT" after that reasoning.
    A two-way softmax turns them into P(GIFT), rounded to 2 decimals.

    Args:
        item: Task dict; this function reads 'title', 'category',
            'merchant', 'price' and 'gift_id' via ``.get``.

    Returns:
        Tuple of (rounded P(GIFT), reasoning text).
    """
    logger.debug(f"Processing item: {item.get('title', 'N/A')} (ID: {item.get('gift_id')})")

    prompt = build_prompt(
        item.get('title', ''),
        item.get('category', ''),
        item.get('merchant', ''),
        item.get('price'),
    )
    if DEBUG:  # the full prompt is only worth formatting in debug mode
        logger.debug(f"--- PROMPT ---\n{prompt}\n--------------")

    encoded = tok(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **encoded,
        max_new_tokens=80,
        do_sample=False,
        pad_token_id=tok.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    n_prompt_tokens = len(encoded.input_ids[0])
    generated = tok.decode(output_ids[0][n_prompt_tokens:], skip_special_tokens=True)

    logger.debug(f"Raw generation: {generated}")

    # Everything before the model's own verdict (or the whole text if none).
    reasoning = generated.partition("Answer:")[0].strip()
    score_prompt = f"{prompt}{reasoning}\nAnswer:"

    s_gift = score_label(score_prompt, " GIFT")
    s_not = score_label(score_prompt, " NOT_GIFT")

    logger.debug(f"Logprobs -> GIFT: {s_gift:.4f}, NOT_GIFT: {s_not:.4f}")

    probs = torch.tensor([s_not, s_gift]).softmax(dim=0)
    rounded = round(float(probs[1].item()), 2)
    p_final = 0.0 if rounded < 0.01 else rounded

    logger.info(f"Result: {p_final} | Reason: {reasoning}")
    return p_final, reasoning

In [None]:
import requests
import time
from tqdm.auto import tqdm

headers = {"X-Internal-Token": INTERNAL_TOKEN}

# Seconds to wait after an error (or a batch that produced nothing to submit)
# and when the task queue is empty. Pausing after failures keeps a persistent
# fault -- every item raising, or the submit endpoint rejecting results --
# from re-running the same GPU work in a tight loop.
ERROR_BACKOFF_S = 30
IDLE_BACKOFF_S = 300
# How many times a batch submit is attempted before its results are dropped.
SUBMIT_ATTEMPTS = 3


def _score_batch(tasks):
    """Score each task with process_one; items that raise are logged and skipped."""
    results = []
    for t in tasks:
        try:
            p, reason = process_one(t)
            results.append({
                "gift_id": t['gift_id'],
                "llm_gift_score": p,
                "llm_gift_reasoning": reason,
                "llm_scoring_model": MODEL_TAG,
                "llm_scoring_version": MODEL_VERSION
            })
        except Exception as e_one:
            logger.error(f"Error processing item {t.get('gift_id')}: {e_one}", exc_info=True)
    return results


def _submit_results(results):
    """POST ``results`` to the submit endpoint, retrying on failure.

    Returns True as soon as the API answers HTTP 200, False after
    SUBMIT_ATTEMPTS failed attempts; sleeps ERROR_BACKOFF_S after each
    failed attempt.
    NOTE(review): retrying after a network error assumes the submit endpoint
    is safe to call twice with the same results -- confirm on the backend.
    """
    for attempt in range(1, SUBMIT_ATTEMPTS + 1):
        try:
            s_resp = requests.post(f"{API_BASE_URL}/internal/scoring/submit", json={"results": results}, headers=headers, timeout=30)
        except requests.exceptions.RequestException as req_err:
            logger.error(f"Network error while submitting results (attempt {attempt}/{SUBMIT_ATTEMPTS}): {req_err}")
        else:
            if s_resp.status_code == 200:
                logger.info(f"Successfully updated {s_resp.json().get('updated')} items.")
                return True
            logger.error(f"Failed to submit results (HTTP {s_resp.status_code}): {s_resp.text}")
        time.sleep(ERROR_BACKOFF_S)
    return False


logger.info("!!! Starting production worker loop !!!")
while True:
    try:
        # 1. Get tasks
        logger.debug(f"Fetching tasks from {API_BASE_URL}...")
        resp = requests.get(f"{API_BASE_URL}/internal/scoring/tasks?limit=20", headers=headers, timeout=30)
        if resp.status_code != 200:
            logger.error(f"Error fetching tasks (HTTP {resp.status_code}): {resp.text}")
            time.sleep(ERROR_BACKOFF_S)
            continue

        tasks = resp.json()
        if not tasks:
            logger.info(f"No more products to score. Waiting {IDLE_BACKOFF_S // 60} minutes...")
            time.sleep(IDLE_BACKOFF_S)
            continue

        logger.info(f"[Batch] Processing {len(tasks)} items")
        results = _score_batch(tasks)

        # 2. Submit results
        if not results:
            logger.warning("Batch finished with 0 results to submit.")
            # Every item failed; presumably the same unscored tasks will be
            # served again, so pause instead of re-running them immediately.
            time.sleep(ERROR_BACKOFF_S)
            continue

        logger.debug(f"Submitting {len(results)} results back to API...")
        if not _submit_results(results):
            logger.error(f"Giving up on {len(results)} results after {SUBMIT_ATTEMPTS} submit attempts.")

    except requests.exceptions.RequestException as req_err:
        # Named req_err (not `re`) so the stdlib `re` module name is not shadowed.
        logger.error(f"Network error: {req_err}")
        time.sleep(ERROR_BACKOFF_S)
    except Exception as e:
        logger.error(f"Unexpected error in main loop: {e}", exc_info=True)
        time.sleep(ERROR_BACKOFF_S)