# Load Model từ Checkpoint và Eval trên Test Candidates

Notebook này load model đã train từ checkpoint và chạy evaluation trên test candidates giống như khi train xong.


## 1. Setup và Import


In [None]:
import sys
from pathlib import Path
import os

# Add project root to path
project_root = Path.cwd()
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

import torch
import random
import numpy as np
from typing import Dict, List
from tqdm import tqdm

# Resolve the run configuration. Per the original note, importing `config`
# can end in SystemExit when no command-line arguments are available (the
# notebook case); a stand-in object with default values is used then.
try:
    from config import arg
except SystemExit:

    class MockArg:
        """Stand-in for the parsed config, carrying notebook defaults."""

        def __init__(self):
            # Option name -> default value, applied as instance attributes
            # in this order.
            defaults = {
                'dataset_code': 'beauty',
                'min_rating': 3,
                'min_uc': 5,
                'min_sc': 5,
                'qwen_mode': 'text_only',
                'qwen_model': 'qwen3-0.6b',
                'qwen_max_history': 5,
                'qwen_max_candidates': 20,
                'use_torch_compile': False,
                'rerank_batch_size': 16,
                'rerank_epochs': 1,
                'rerank_lr': 1e-4,
                'rerank_patience': 2,
                'rerank_eval_candidates': 20,
            }
            for option, value in defaults.items():
                setattr(self, option, value)

    arg = MockArg()
    print("Note: Running in notebook mode with default config values.")
    print("To use custom config, set arg attributes manually or run from command line.")

# Import evaluation utilities
from evaluation.utils import load_dataset_from_csv, load_rerank_candidates, evaluate_split
from evaluation.metrics import recall_at_k, ndcg_at_k, hit_at_k

# Import reranker
from rerank.models.llm import LLMModel, build_prompt_from_candidates, rank_candidates
from rerank.methods.qwen_reranker_unified import QwenReranker

print("Setup completed!")


## 2. Cấu hình Checkpoint và Dataset


In [None]:
# Directory that holds the saved reranker checkpoints.
CHECKPOINT_DIR = "./qwen_rerank"


def find_latest_checkpoint(checkpoint_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory.

    Entries whose step suffix is not an integer (for example
    ``checkpoint-best``) and entries that are not directories are ignored
    rather than aborting the search.

    Args:
        checkpoint_dir: Existing directory to scan.

    Returns:
        ``os.path.join(checkpoint_dir, name)`` for the entry with the largest
        step number, or ``None`` when no numbered checkpoint directory exists.
    """
    numbered = []
    for entry in os.listdir(checkpoint_dir):
        if not entry.startswith("checkpoint-"):
            continue
        try:
            step = int(entry.split("-")[1])
        except ValueError:
            # Non-numeric suffix: not a step checkpoint, skip it.
            continue
        candidate = os.path.join(checkpoint_dir, entry)
        if os.path.isdir(candidate):
            numbered.append((step, candidate))

    if not numbered:
        return None
    # Sort on the step number only, so ordering is numeric (10 > 2).
    numbered.sort(key=lambda pair: pair[0])
    return numbered[-1][1]


# Pick the newest numbered checkpoint, or fall back to the directory itself.
checkpoint_path = None
if os.path.isdir(CHECKPOINT_DIR):
    checkpoint_path = find_latest_checkpoint(CHECKPOINT_DIR)
    if checkpoint_path is not None:
        print(f"Found latest checkpoint: {checkpoint_path}")
    else:
        # No checkpoint-<step> sub-directory: try the whole folder instead.
        checkpoint_path = CHECKPOINT_DIR
        print(f"Using checkpoint directory: {checkpoint_path}")
else:
    print(f"Warning: Checkpoint directory {CHECKPOINT_DIR} not found!")
    print("Please specify the correct checkpoint path.")

# Dataset filter settings, read from `arg`; each falls back to the listed
# value when the attribute is missing.
dataset_code, min_rating, min_uc, min_sc = (
    getattr(arg, option, fallback)
    for option, fallback in (
        ('dataset_code', 'beauty'),
        ('min_rating', 3),
        ('min_uc', 5),
        ('min_sc', 5),
    )
)

print(f"\nDataset: {dataset_code}")
print(f"Checkpoint: {checkpoint_path}")


## 3. Load Dataset và Test Candidates


In [None]:
# Load the dataset and pull out the train/val/test splits and the item count.
print("Loading dataset...")
data = load_dataset_from_csv(dataset_code, min_rating, min_uc, min_sc)
train, val, test, item_count = (
    data[key] for key in ("train", "val", "test", "item_count")
)

print(f"Train users: {len(train)}")
print(f"Val users: {len(val)}")
print(f"Test users: {len(test)}")
print(f"Items: {item_count}")

# Candidates were generated ahead of time; look them up with the same
# dataset filter settings.
print("\nLoading pre-generated candidates...")
candidate_filters = {
    "dataset_code": dataset_code,
    "min_rating": min_rating,
    "min_uc": min_uc,
    "min_sc": min_sc,
}
all_candidates = load_rerank_candidates(**candidate_filters)

# An absent "test" entry is treated as an empty mapping.
test_candidates = all_candidates.get("test", {})
print(f"Test candidates loaded for {len(test_candidates)} users")


## 4. Load Model từ Checkpoint


In [None]:
# Model configuration (same as at training time).
from rerank.methods.qwen_reranker_unified import MODEL_MAPPING

qwen_mode = getattr(arg, 'qwen_mode', 'text_only')
qwen_model = getattr(arg, 'qwen_model', 'qwen3-0.6b')

print(f"Mode: {qwen_mode}")
print(f"Model: {qwen_model}")

# Resolve the base model; unknown names fall back to the 'qwen3-0.6b' entry.
model_path = MODEL_MAPPING.get(qwen_model.lower(), MODEL_MAPPING['qwen3-0.6b'])
print(f"\nLoading base model: {model_path}")

# Create the LLMModel instance.
llm_model = LLMModel(
    train_data=None,  # No training data is needed for inference
    model_name=model_path
)

# Load the base model with the same configuration as training.
use_torch_compile = getattr(arg, 'use_torch_compile', False)
llm_model.load_model(use_torch_compile=use_torch_compile)

# Tracks whether checkpoint weights actually ended up in the model, so the
# final status line cannot be mistaken for a successful checkpoint load.
checkpoint_loaded = False

# Load checkpoint weights on top of the base model.
if checkpoint_path and os.path.exists(checkpoint_path):
    print(f"\nLoading checkpoint from: {checkpoint_path}")

    # Any failure below is reported and evaluation continues on the base
    # model (deliberate best-effort behavior).
    try:
        from unsloth import FastLanguageModel

        print("Loading adapter weights from checkpoint...")

        if os.path.isdir(checkpoint_path):
            # Method 1: the folder contains a standalone adapter weight file.
            adapter_path = os.path.join(checkpoint_path, "adapter_model.bin")
            if os.path.exists(adapter_path):
                if hasattr(llm_model.model, 'load_adapter'):
                    # The model loads the adapter from the folder itself, so
                    # the weight file is not read here.
                    llm_model.model.load_adapter(checkpoint_path)
                    print("Adapter weights loaded successfully!")
                    checkpoint_loaded = True
                else:
                    # Fallback: load the state dict manually. The file is only
                    # read on this path, where the tensors are needed.
                    # NOTE(review): torch.load unpickles the file; only point
                    # this at checkpoints you trust.
                    adapter_weights = torch.load(adapter_path, map_location="cpu")
                    missing_keys, unexpected_keys = llm_model.model.load_state_dict(adapter_weights, strict=False)
                    print(f"Loaded adapter weights. Missing: {len(missing_keys)}, Unexpected: {len(unexpected_keys)}")
                    # With strict=False a load that matches no tensor still
                    # "succeeds"; treat that case as not loaded.
                    checkpoint_loaded = len(unexpected_keys) < len(adapter_weights)
                    if not checkpoint_loaded:
                        print("Warning: no checkpoint tensor matched the model; base weights are unchanged.")
            else:
                # Method 2: try to load the whole model from the checkpoint folder.
                try:
                    model, tokenizer = FastLanguageModel.from_pretrained(
                        model_name=checkpoint_path,
                        max_seq_length=2048,
                        dtype=torch.float16,
                        load_in_4bit=True,
                    )
                    llm_model.model = model
                    llm_model.tokenizer = tokenizer
                    checkpoint_loaded = True
                    print("Model loaded from checkpoint folder successfully!")
                except Exception as e:
                    print(f"Warning: Could not load from checkpoint folder: {e}")
                    print("Using base model without checkpoint weights.")
        else:
            print(f"Checkpoint path is not a directory: {checkpoint_path}")
    except Exception as e:
        print(f"Warning: Could not load checkpoint: {e}")
        print("Using base model without checkpoint weights.")
else:
    print("No checkpoint found. Using base model.")

# Set model to eval mode.
llm_model.model.eval()
print("\nModel ready for evaluation!")
print(f"Checkpoint weights loaded: {checkpoint_loaded}")


## 5. Setup QwenReranker với Model đã Load


In [None]:
# Build the reranker with the same mode/model settings as the loaded model.
history_limit = getattr(arg, 'qwen_max_history', 5)
candidate_limit = getattr(arg, 'qwen_max_candidates', 20)
reranker = QwenReranker(
    top_k=50,
    mode=qwen_mode,
    model=qwen_model,
    max_history=history_limit,
    max_candidates=candidate_limit,
)

# Hand the already-loaded model to the reranker.
reranker.llm_model = llm_model

# Lookup tables derived from the dataset metadata; they stay empty when the
# dataset carries no "meta" entry.
item_id2text = {}
item_meta = {}
user_history = {}

if "meta" in data:
    for iid, info in data["meta"].items():
        info = info or {}
        item_meta[iid] = info
        description = info.get("text")
        if description:
            item_id2text[iid] = description

    # Per-user history texts: only training items that have a text are kept.
    user_history = {
        uid: [
            item_id2text.get(iid, f"item_{iid}")
            for iid in interacted
            if iid in item_id2text
        ]
        for uid, interacted in train.items()
    }

# Attach the data structures, then mark the reranker as fitted.
for attribute, value in (
    ("item_id2text", item_id2text),
    ("item_meta", item_meta),
    ("user_history", user_history),
    ("train_user_history", train),  # used for history lookup
    ("is_fitted", True),
):
    setattr(reranker, attribute, value)

print("Reranker setup completed!")
print(f"Item texts: {len(item_id2text)}")
print(f"User histories: {len(user_history)}")


## 6. Evaluation trên Test Set


In [None]:
# Evaluation routine mirroring the one used during training.
def evaluate_test_split(split: Dict[int, List[int]], candidates_dict: Dict[int, List[int]], k: int = 10) -> Dict[str, float]:
    """Rerank each user's candidates and average recall/NDCG/hit at cutoff ``k``.

    Uses the module-level ``reranker``. A user is skipped when they have no
    ground-truth items in ``split``, no candidates in ``candidates_dict``, or
    no history in ``reranker.train_user_history``. A user whose reranking or
    metric computation raises is reported and skipped as well.

    Args:
        split: Mapping of user id to ground-truth item ids.
        candidates_dict: Mapping of user id to candidate item ids. The lists
            are not modified.
        k: Cutoff for the ranked list and for the metrics.

    Returns:
        Dict with ``recall@k``, ``ndcg@k`` and ``hit@k`` (means over the
        evaluated users, ``0.0`` when none were evaluated) and ``num_users``
        (the number of evaluated users).
    """
    recalls = []
    ndcgs = []
    hits = []

    for user_id in tqdm(sorted(split.keys()), desc="Evaluating"):
        gt_items = split.get(user_id, [])
        if not gt_items:
            continue

        candidates = candidates_dict.get(user_id, [])
        if not candidates:
            continue

        history = reranker.train_user_history.get(user_id, [])
        if len(history) == 0:
            continue

        # Shuffle a private copy (as in training). Shuffling in place would
        # permanently reorder the caller's candidate lists, which are reused
        # across calls.
        shuffled = list(candidates)
        random.shuffle(shuffled)

        try:
            reranked = reranker.rerank(user_id, shuffled, user_history=history)
            ranked_items = [item_id for item_id, _ in reranked[:k]]

            r = recall_at_k(ranked_items, gt_items, k)
            n = ndcg_at_k(ranked_items, gt_items, k)
            h = hit_at_k(ranked_items, gt_items, k)
        except Exception as e:
            # Best-effort: one failing user must not abort the whole run.
            print(f"Error evaluating user {user_id}: {e}")
            continue
        else:
            recalls.append(r)
            ndcgs.append(n)
            hits.append(h)

    return {
        f"recall@{k}": float(np.mean(recalls)) if recalls else 0.0,
        f"ndcg@{k}": float(np.mean(ndcgs)) if ndcgs else 0.0,
        f"hit@{k}": float(np.mean(hits)) if hits else 0.0,
        "num_users": len(recalls)
    }

# Evaluate on the test set at each cutoff in `ks`.
# NOTE: every cutoff triggers a separate evaluate_test_split call, i.e. a
# full reranking pass over the test users.
print("Evaluating on test set...")
ks = [1, 5, 10, 20]
test_metrics = {}

for k in ks:
    print(f"\nComputing metrics @{k}...")
    metrics = evaluate_test_split(test, test_candidates, k=k)
    # "num_users" is overwritten on each pass; the last cutoff's count wins.
    test_metrics.update(metrics)

# Print results. The header and rows are built from `ks`, so the table stays
# aligned if the list of cutoffs changes.
print("\n" + "=" * 80)
print("Test Evaluation Results")
print("=" * 80)
print(f"Mode: {qwen_mode}")
print(f"Model: {qwen_model}")
print(f"Checkpoint: {checkpoint_path}")
print(f"Evaluated users: {test_metrics.get('num_users', 0)}")
print("-" * 80)
print(" ".join([f"{'Metric':<12}"] + [f"{'@' + str(cutoff):>10}" for cutoff in ks]))

for metric_name in ["recall", "ndcg", "hit"]:
    values = [test_metrics.get(f"{metric_name}@{cutoff}", 0.0) for cutoff in ks]
    cells = [f"{metric_name.capitalize():<12}"] + [f"{value:>10.4f}" for value in values]
    print(" ".join(cells))

print("=" * 80)


## 7. (Optional) So sánh với Validation Set


In [None]:
# Evaluate on the validation set for comparison, using the same cutoffs `ks`.
val_candidates = all_candidates.get("val", {})

if val_candidates:
    print("\nEvaluating on validation set for comparison...")
    val_metrics = {}

    for k in ks:
        metrics = evaluate_test_split(val, val_candidates, k=k)
        # "num_users" is overwritten on each pass; the last cutoff's count wins.
        val_metrics.update(metrics)

    # The header and rows are built from `ks`, so the table stays aligned if
    # the list of cutoffs changes.
    print("\n" + "=" * 80)
    print("Validation Evaluation Results")
    print("=" * 80)
    print(f"Evaluated users: {val_metrics.get('num_users', 0)}")
    print("-" * 80)
    print(" ".join([f"{'Metric':<12}"] + [f"{'@' + str(cutoff):>10}" for cutoff in ks]))

    for metric_name in ["recall", "ndcg", "hit"]:
        values = [val_metrics.get(f"{metric_name}@{cutoff}", 0.0) for cutoff in ks]
        cells = [f"{metric_name.capitalize():<12}"] + [f"{value:>10.4f}" for value in values]
        print(" ".join(cells))

    print("=" * 80)
else:
    print("No validation candidates found.")
