In [None]:
import gc
import pickle 
import json 
import numpy as np
import pandas as pd
import polars as pl

from tqdm import tqdm
from annoy import AnnoyIndex
from lightfm import LightFM
from abc import ABC, abstractmethod
from typing import Dict, List
from collections import defaultdict, Counter 
from scipy.sparse import csr_matrix
from sklearn.metrics.pairwise import cosine_similarity



In [3]:
data_folder = "../../../data/"
train = pl.read_parquet(data_folder + "train.pq")
test_exploded = pl.read_parquet(data_folder + "test.pq")
test = test_exploded.group_by("user_id", maintain_order=True).agg(pl.col("item_id"))
books = pl.read_parquet(data_folder + "books.pq")

In [4]:
class Validator(ABC):
    def __init__(self, train: pd.DataFrame, test: pd.DataFrame, cold_items: set = None):
        self.train = train
        self.test = test
        self.cold_items = cold_items or set()

    @abstractmethod
    def evaluate(self, predictions: Dict[int, List[int]]) -> Dict[str, float]:
        """
        predictions: dict user_id -> list of recommended item_ids
        """
        pass

    def recall_at_k(self, y_true: List[int], y_pred: List[int], k: int = 10) -> float:
        return len(set(y_true) & set(y_pred[:k])) / len(set(y_true)) if y_true else 0.0

    def precision_at_k(self, y_true: List[int], y_pred: List[int], k: int = 10) -> float:
        return len(set(y_true) & set(y_pred[:k])) / k if y_true else 0.0

    def hitrate_at_k(self, y_true: List[int], y_pred: List[int], k: int = 10) -> float:
        return 1.0 if len(set(y_true) & set(y_pred[:k])) > 0 else 0.0

    def ndcg_at_k(self, y_true: List[int], y_pred: List[int], k: int = 10) -> float:
        dcg = 0.0
        for i, item in enumerate(y_pred[:k]):
            if item in y_true:
                dcg += 1 / np.log2(i + 2)
        idcg = sum(1 / np.log2(i + 2) for i in range(min(len(y_true), k)))
        return dcg / idcg if idcg > 0 else 0.0

    def mrr_at_k(self, y_true: List[int], y_pred: List[int], k: int = 10) -> float:
        for i, item in enumerate(y_pred[:k]):
            if item in y_true:
                return 1 / (i + 1)
        return 0.0

    def coverage(self, predictions: Dict[int, List[int]]) -> float:
        all_pred_items = set(item for recs in predictions.values() for item in recs)
        all_train_items = set(self.train["item_id"].unique())
        return len(all_pred_items) / len(all_train_items)

    @staticmethod
    def print_metrics(metrics: Dict[str, float]):
        print("\n=== Evaluation Results ===")
        for key, value in metrics.items():
            print(f"{key:<15}: {value:.4f}")
        print("==========================\n")

In [26]:
class JointValidator(Validator):
    def __init__(self, train: pl.DataFrame, test: pl.DataFrame, cold_items: set = None):
        super().__init__(train, test, cold_items)
        # Flatten item_id during aggregation to avoid nested lists
        self.user2items = (
            test.group_by("user_id").agg(pl.col("item_id").flatten()).to_dict(as_series=False)
        )
        self.user2items = dict(zip(self.user2items["user_id"], self.user2items["item_id"]))

    def evaluate(self, predictions: Dict[int, List[int]]) -> Dict[str, float]:
        recalls, precisions, hits, ndcgs, mrrs = [], [], [], [], []
        # Debug: Print sample y_true and y_pred
        print("Sample y_true and y_pred for debugging:")
        for user_id, y_pred in list(predictions.items())[:5]:
            y_true = self.user2items.get(user_id, [])
            print(f"User {user_id}:")
            print(f"  y_true: {y_true}, nested: {any(isinstance(i, list) for i in y_true)}")
            print(f"  y_pred: {y_pred}, nested: {any(isinstance(i, list) for i in y_pred)}")
        
        for user_id, y_pred in predictions.items():
            y_true = self.user2items.get(user_id, [])
            # Additional safety: Flatten y_true if it contains nested lists
            y_true_flat = [item for sublist in ([y_true] if isinstance(y_true, list) and any(isinstance(i, list) for i in y_true) else [y_true]) for item in sublist]
            # Flatten y_pred if it contains nested lists
            y_pred_flat = [item for sublist in ([y_pred] if isinstance(y_pred, list) and any(isinstance(i, list) for i in y_pred) else [y_pred]) for item in sublist]
            recalls.append(self.recall_at_k(y_true_flat, y_pred_flat))
            precisions.append(self.precision_at_k(y_true_flat, y_pred_flat))
            hits.append(self.hitrate_at_k(y_true_flat, y_pred_flat))
            ndcgs.append(self.ndcg_at_k(y_true_flat, y_pred_flat))
            mrrs.append(self.mrr_at_k(y_true_flat, y_pred_flat))
        results = {
            "Recall@10": np.mean(recalls),
            "Precision@10": np.mean(precisions),
            "HitRate@10": np.mean(hits),
            "NDCG@10": np.mean(ndcgs),
            "MRR@10": np.mean(mrrs),
            "Coverage": self.coverage(predictions),
        }
        self.print_metrics(results)
        return results

In [6]:
def _shorten_list(lst, max_len=10):
    """Обрезает длинные списки для красивого вывода"""
    if lst is None:
        return []
    return lst[:max_len] if len(lst) > max_len else lst

def show_predictions(models: dict, data: pl.DataFrame, n=5, verbose=True, is_val=False):
    df = data.sample(n).select(["user_id", "item_id"])
    if is_val:
        df = df.rename({"item_id": "true_items"})

    # добавляем предсказания
    for name, preds in models.items():
        df = df.with_columns(
            pl.col("user_id").map_elements(
                lambda u: _shorten_list(preds.get(u, [])), 
                return_dtype=pl.List(pl.Int64)
            ).alias(name)
        )

    if verbose:
        print(df.shape)
        print(df)

    return df


def val_predictions(models: dict, val: pl.DataFrame, validator: Validator, k: int = 10, verbose: bool = True):
    results = []
    user2items = (
        val.group_by("user_id").agg(pl.col("item_id")).to_dict(as_series=False)
    )
    user2items = dict(zip(user2items["user_id"], user2items["item_id"]))

    for model_name, preds in models.items():
        recalls, precisions, hits, ndcgs, mrrs = [], [], [], [], []
        for u, y_true in user2items.items():
            y_pred = preds.get(u, [])
            recalls.append(validator.recall_at_k(y_true, y_pred, k))
            precisions.append(validator.precision_at_k(y_true, y_pred, k))
            hits.append(validator.hitrate_at_k(y_true, y_pred, k))
            ndcgs.append(validator.ndcg_at_k(y_true, y_pred, k))
            mrrs.append(validator.mrr_at_k(y_true, y_pred, k))
        metrics = {
            "model": model_name,
            "Recall@10": np.mean(recalls),
            "Precision@10": np.mean(precisions),
            "HitRate@10": np.mean(hits),
            "NDCG@10": np.mean(ndcgs),
            "MRR@10": np.mean(mrrs),
            "Coverage": validator.coverage(preds),
        }
        results.append(metrics)

    df = pl.DataFrame(results)
    if verbose:
        print(df)
    return df

In [7]:
with open("popular_model.pkl", "rb") as f:
    popular_items = pickle.load(f)

with open("knn_model.pkl", "rb") as f:
    knn_data = pickle.load(f)
    model_knn = knn_data["model"]
    user_map = knn_data["user_map"]
    item_map = knn_data["item_map"]
    item_ids = knn_data["item_ids"]

with open("lfm_model.pkl", "rb") as f:
    lfm_data = pickle.load(f)
    model_lfm = lfm_data["model"]
    dataset = lfm_data["dataset"]
    all_items = lfm_data["all_items"]
    tag_vectorizer = lfm_data["tag_vectorizer"]

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
recs_popular = {uid: popular_items[:10] for uid in test["user_id"]}

In [16]:
recs_knn = {}
sparse = csr_matrix(
    (
        (train["rating"].to_numpy() + 1).astype(np.float64),
        (
            train["user_id"].replace(user_map).to_numpy().astype(np.int32),
            train["item_id"].replace(item_map).to_numpy().astype(np.int32)
        )
    ),
    shape=(len(user_map), len(item_map))
)
for uid in tqdm(test["user_id"], desc="ItemKNN predictions"):
    u_idx = user_map.get(uid, -1)
    if u_idx == -1:
        recs_knn[uid] = popular_items[:10]
    else:
        recs, _ = model_knn.recommend(u_idx, sparse[u_idx], N=10, filter_already_liked_items=True)
        pred = [item_ids[r] for r in recs if r < len(item_ids)]
        # Flatten if nested lists
        recs_knn[uid] = [item for sublist in ([pred] if any(isinstance(i, list) for i in pred) else [pred]) for item in sublist][:10] or popular_items[:10]

ItemKNN predictions: 100%|██████████| 185828/185828 [00:04<00:00, 40357.95it/s]


In [20]:
# Generate predictions for LightFM with batching
recs_lfm = {}
books_filtered = books.filter(pl.col("item_id").is_in(all_items))
item_features = dataset.build_item_features(
    (iid, tag_vectorizer.transform([tags]).toarray()[0])
    for iid, tags in books_filtered.select("item_id", pl.col("tags").list.join(" ")).iter_rows()
)

# Collect all user indices
user_ids_test = test["user_id"].to_list()
u_indices = np.array([user_map.get(uid, -1) for uid in user_ids_test])

# Filter valid users
valid_mask = u_indices != -1
valid_u_indices = u_indices[valid_mask]
valid_user_ids = np.array(user_ids_test)[valid_mask]

# Batch size (adjust based on memory; e.g., 1000-5000)
batch_size = 1000
num_valid = len(valid_u_indices)

# Prepare item indices
item_indices = np.arange(len(all_items))

for start in tqdm(range(0, num_valid, batch_size), desc="LightFM batch predictions"):
    end = min(start + batch_size, num_valid)
    batch_u_indices = valid_u_indices[start:end]
    batch_user_ids = valid_user_ids[start:end]
    
    # Create arrays for prediction: repeat user indices for each item
    user_ids_batch = np.repeat(batch_u_indices, len(item_indices))
    item_ids_batch = np.tile(item_indices, len(batch_u_indices))
    
    # Predict scores for the batch
    scores = model_lfm.predict(user_ids_batch, item_ids_batch, item_features=item_features)
    
    # Reshape scores to [num_users_in_batch, num_items]
    scores = scores.reshape(len(batch_u_indices), len(item_indices))
    
    # Get top 10 items for each user in batch
    top_items = np.argsort(-scores, axis=1)[:, :10]
    for idx, uid in enumerate(batch_user_ids):
        recs_lfm[uid] = [all_items[i] for i in top_items[idx]]

# Fallback for cold users
for uid, u_idx in zip(user_ids_test, u_indices):
    if u_idx == -1:
        recs_lfm[uid] = popular_items[:10]

# Flatten if any nested lists (to handle previous issue)
for uid in recs_lfm:
    pred = recs_lfm[uid]
    if any(isinstance(i, list) for i in pred):
        recs_lfm[uid] = [item for sublist in pred for item in sublist][:10]

# Combine predictions
models = {
    "Popular": recs_popular,
    "ItemKNN": recs_knn,
    "LightFM": recs_lfm
}

LightFM batch predictions: 100%|██████████| 186/186 [23:53<00:00,  7.71s/it]


In [21]:
for model_name, preds in models.items():
    print(f"Model: {model_name}")
    for uid, pred in list(preds.items())[:5]:
        print(f"User {uid}: {pred}")

Model: Popular
User 00000377eea48021d3002730d56aca9a: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159]
User 00009ab2ed8cbfceda5a59da40966321: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159]
User 00009e46d18f223a82b22da38586b605: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159]
User 0001085188e302fc6b2568de45a5f56b: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159]
User 00014c578111090720e20f5705eba051: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159]
Model: ItemKNN
User 00000377eea48021d3002730d56aca9a: [13159, 960, 118, 30197, 13064, 12599, 27745, 18315, 11020, 9767]
User 00009ab2ed8cbfceda5a59da40966321: [32516, 20150, 18150, 19293, 15009, 34056, 14106, 29890, 18881, 11562]
User 00009e46d18f223a82b22da38586b605: [18150, 15514, 4058, 15009, 19110, 17956, 118, 30197, 13064, 9147]
User 0001085188e302fc6b2568de45a5f56b: [18150, 4058, 15009, 24671, 9147, 4595, 12489, 27771, 9722, 960]
User 00014

In [22]:
popular_items = train.group_by("item_id").agg(pl.len().alias("pop")).sort("pop", descending=True)["item_id"].head(100).to_list()
print(popular_items[:10])

[4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159]


In [28]:
print("\nСтатистика по метрикам для каждой модели:")
cold_items = set(test_exploded["item_id"].to_list()) - set(train["item_id"].to_list())
validator = JointValidator(train, test, cold_items)
metrics_df = pl.DataFrame()
for model_name, preds in models.items():
    metrics = validator.evaluate(preds)
    metrics["model"] = model_name
    metrics_df = metrics_df.vstack(pl.DataFrame(metrics))


Статистика по метрикам для каждой модели:
Sample y_true and y_pred for debugging:
User 00000377eea48021d3002730d56aca9a:
  y_true: [13252], nested: False
  y_pred: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159], nested: False
User 00009ab2ed8cbfceda5a59da40966321:
  y_true: [2328], nested: False
  y_pred: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159], nested: False
User 00009e46d18f223a82b22da38586b605:
  y_true: [28636, 30197], nested: False
  y_pred: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159], nested: False
User 0001085188e302fc6b2568de45a5f56b:
  y_true: [2159, 2969, 3307, 4059, 4892, 5290, 6912, 7754, 9975, 10830, 11521, 13150, 13787, 15357, 15644, 17929, 20330, 21937, 22488, 22565, 22694, 26793, 27607, 27760, 27886, 28046, 28195, 28198, 28359, 31057, 31381, 32799, 33630], nested: False
  y_pred: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159], nested: False
User 00014c578111090720e20f5705eba05

In [29]:
print(metrics_df)

shape: (3, 7)
┌───────────┬──────────────┬────────────┬──────────┬──────────┬──────────┬─────────┐
│ Recall@10 ┆ Precision@10 ┆ HitRate@10 ┆ NDCG@10  ┆ MRR@10   ┆ Coverage ┆ model   │
│ ---       ┆ ---          ┆ ---        ┆ ---      ┆ ---      ┆ ---      ┆ ---     │
│ f64       ┆ f64          ┆ f64        ┆ f64      ┆ f64      ┆ f64      ┆ str     │
╞═══════════╪══════════════╪════════════╪══════════╪══════════╪══════════╪═════════╡
│ 0.032568  ┆ 0.025603     ┆ 0.163899   ┆ 0.033254 ┆ 0.053643 ┆ 0.000319 ┆ Popular │
│ 0.056261  ┆ 0.050535     ┆ 0.271235   ┆ 0.063643 ┆ 0.098796 ┆ 0.154281 ┆ ItemKNN │
│ 0.000384  ┆ 0.000924     ┆ 0.008659   ┆ 0.001148 ┆ 0.003584 ┆ 0.002396 ┆ LightFM │
└───────────┴──────────────┴────────────┴──────────┴──────────┴──────────┴─────────┘


In [30]:
def show_predictions(models: dict, data: pl.DataFrame, n=5, verbose=True):
    df = data.sample(n).select(["user_id", "item_id"]).rename({"item_id": "true_items"})
    for name, preds in models.items():
        df = df.with_columns(
            pl.col("user_id").map_elements(
                lambda u: preds.get(u, [])[:10], 
                return_dtype=pl.List(pl.Int64)
            ).alias(name)
        )
    if verbose:
        print(df.shape)
        print(df)
    return df

show_predictions(models, test, n=5, verbose=True)

(5, 5)
shape: (5, 5)
┌───────────────────┬───────────────────┬───────────────────┬───────────────────┬──────────────────┐
│ user_id           ┆ true_items        ┆ Popular           ┆ ItemKNN           ┆ LightFM          │
│ ---               ┆ ---               ┆ ---               ┆ ---               ┆ ---              │
│ str               ┆ list[i64]         ┆ list[i64]         ┆ list[i64]         ┆ list[i64]        │
╞═══════════════════╪═══════════════════╪═══════════════════╪═══════════════════╪══════════════════╡
│ ee6b84a97b051caa4 ┆ [20249]           ┆ [4058, 15514, …   ┆ [15514, 17956, …  ┆ [21374, 8475, …  │
│ 9ba50fe5fba49…    ┆                   ┆ 13159]            ┆ 12489]            ┆ 27530]           │
│ b5d378dedadd19d42 ┆ [10461, 17956,    ┆ [4058, 15514, …   ┆ [18150, 19110, …  ┆ [21374, 8475, …  │
│ 495cc3847aac4…    ┆ 19110]            ┆ 13159]            ┆ 28386]            ┆ 10189]           │
│ 773703c694bc6b2ab ┆ [10568]           ┆ [4058, 15514, …   ┆ [15514, 

user_id,true_items,Popular,ItemKNN,LightFM
str,list[i64],list[i64],list[i64],list[i64]
"""ee6b84a97b051caa49ba50fe5fba49…",[20249],"[4058, 15514, … 13159]","[15514, 17956, … 12489]","[21374, 8475, … 27530]"
"""b5d378dedadd19d42495cc3847aac4…","[10461, 17956, 19110]","[4058, 15514, … 13159]","[18150, 19110, … 28386]","[21374, 8475, … 10189]"
"""773703c694bc6b2ab989268ad2cf56…",[10568],"[4058, 15514, … 13159]","[15514, 18150, … 118]","[21374, 8475, … 23477]"
"""d3021681e654e4c14fc7f68d771070…","[1375, 4551, … 27886]","[4058, 15514, … 13159]","[15009, 30197, … 14106]","[21374, 8475, … 21928]"
"""bda394bc3962f4d476c0922edcc8fb…","[4595, 4917, … 24671]","[4058, 15514, … 13159]","[4058, 19110, … 18315]","[21374, 8475, … 9477]"


# Ансамбль

In [34]:
def create_ensemble_predictions(
    recs_popular: Dict[int, List[int]],
    recs_knn: Dict[int, List[int]],
    recs_lfm: Dict[int, List[int]],
    weights: Dict[str, float],
    k: int = 10
) -> Dict[int, List[int]]:
    ensemble_recs = {}
    for uid in recs_popular.keys():  # Assuming all models have same user IDs
        # Collect recommendations from all models
        recs = {
            "Popular": recs_popular.get(uid, [])[:k],
            "ItemKNN": recs_knn.get(uid, [])[:k],
            "LightFM": recs_lfm.get(uid, [])[:k]
        }
        # Borda count: Assign points based on rank and model weight
        borda_scores = Counter()
        for model_name, model_recs in recs.items():
            model_weight = weights[model_name]
            for rank, item in enumerate(model_recs):
                # Higher rank (lower index) gets more points
                borda_scores[item] += model_weight * (k - rank)
        # Sort items by Borda score and take top k
        ensemble_recs[uid] = [
            item for item, _ in borda_scores.most_common(k)
        ]
    return ensemble_recs

# Define weights based on HitRate@10
weights = {
    "Popular": 0.3,  # HitRate@10: 0.1639
    "ItemKNN": 0.6,  # HitRate@10: 0.2712
    "LightFM": 0.1   # HitRate@10: 0.0087
}

if 'recs_popular' not in globals() or 'recs_knn' not in globals() or 'recs_lfm' not in globals():
    raise ValueError("One or more required prediction dictionaries (recs_popular, recs_knn, recs_lfm) are missing.")

In [35]:
ensemble_recs = create_ensemble_predictions(
    recs_popular=recs_popular,
    recs_knn=recs_knn,
    recs_lfm=recs_lfm,
    weights=weights,
    k=10
)

# Add ensemble to models
models["Ensemble"] = ensemble_recs

print("Sample ensemble predictions:")
for uid, pred in list(ensemble_recs.items())[:3]:
    print(f"User {uid[:8]}...: {pred}, nested: {any(isinstance(i, list) for i in pred)}")

Sample ensemble predictions:
User 00000377...: [30197, 13159, 118, 960, 27745, 13064, 4058, 12599, 15514, 18150], nested: False
User 00009ab2...: [18150, 32516, 20150, 19293, 15009, 4058, 34056, 15514, 14106, 30197], nested: False
User 00009e46...: [18150, 15514, 4058, 19110, 15009, 30197, 17956, 118, 33370, 27745], nested: False


In [37]:
print("\nСтатистика по метрикам для каждой модели:")
cold_items = set(test_exploded["item_id"].to_list()) - set(train["item_id"].to_list())
validator = JointValidator(train, test, cold_items)
metrics_df = pl.DataFrame()
for model_name, preds in models.items():
    metrics = validator.evaluate(preds)
    metrics["model"] = model_name
    metrics_df = metrics_df.vstack(pl.DataFrame(metrics))


Статистика по метрикам для каждой модели:
Sample y_true and y_pred for debugging:
User 00000377eea48021d3002730d56aca9a:
  y_true: [13252], nested: False
  y_pred: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159], nested: False
User 00009ab2ed8cbfceda5a59da40966321:
  y_true: [2328], nested: False
  y_pred: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159], nested: False
User 00009e46d18f223a82b22da38586b605:
  y_true: [28636, 30197], nested: False
  y_pred: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159], nested: False
User 0001085188e302fc6b2568de45a5f56b:
  y_true: [2159, 2969, 3307, 4059, 4892, 5290, 6912, 7754, 9975, 10830, 11521, 13150, 13787, 15357, 15644, 17929, 20330, 21937, 22488, 22565, 22694, 26793, 27607, 27760, 27886, 28046, 28195, 28198, 28359, 31057, 31381, 32799, 33630], nested: False
  y_pred: [4058, 15514, 18150, 30197, 33370, 27745, 19110, 118, 17956, 13159], nested: False
User 00014c578111090720e20f5705eba05

In [38]:
# Print the final metrics table
print(metrics_df)

# Save to file to avoid console overflow
metrics_df.write_parquet("metrics_results.pq")
print("Metrics saved to metrics_results.pq")

shape: (4, 7)
┌───────────┬──────────────┬────────────┬──────────┬──────────┬──────────┬──────────┐
│ Recall@10 ┆ Precision@10 ┆ HitRate@10 ┆ NDCG@10  ┆ MRR@10   ┆ Coverage ┆ model    │
│ ---       ┆ ---          ┆ ---        ┆ ---      ┆ ---      ┆ ---      ┆ ---      │
│ f64       ┆ f64          ┆ f64        ┆ f64      ┆ f64      ┆ f64      ┆ str      │
╞═══════════╪══════════════╪════════════╪══════════╪══════════╪══════════╪══════════╡
│ 0.032568  ┆ 0.025603     ┆ 0.163899   ┆ 0.033254 ┆ 0.053643 ┆ 0.000319 ┆ Popular  │
│ 0.056261  ┆ 0.050535     ┆ 0.271235   ┆ 0.063643 ┆ 0.098796 ┆ 0.154281 ┆ ItemKNN  │
│ 0.000384  ┆ 0.000924     ┆ 0.008659   ┆ 0.001148 ┆ 0.003584 ┆ 0.002396 ┆ LightFM  │
│ 0.04886   ┆ 0.04165      ┆ 0.243311   ┆ 0.055189 ┆ 0.091317 ┆ 0.078882 ┆ Ensemble │
└───────────┴──────────────┴────────────┴──────────┴──────────┴──────────┴──────────┘
Metrics saved to metrics_results.pq
