# EASE Model Validation Pipeline
This notebook retrains the EASE model on the training interactions and checks that the resulting test-set metrics match the expected reference values.

In [None]:
import ast
from math import log2

import numpy as np
import pandas as pd
from scipy import sparse as sps
from tqdm import tqdm

## Load Data
Load the preprocessed data produced by `ease_train.ipynb`.

In [None]:
# Load the grouped per-user data
full_grouped_data = pd.read_csv('../recsys_tests/full_data.csv', index_col=0)

# The interaction lists are stored in the CSV as their string representation;
# parse each column back into Python lists. `ast` comes from the notebook's
# import cell so that all dependencies are declared in one place.
for interactions_col in ('train_interactions', 'val_interactions', 'test_interactions'):
    full_grouped_data[interactions_col] = (
        full_grouped_data[interactions_col].apply(ast.literal_eval)
    )

print(f"Data loaded: {len(full_grouped_data)} users")
full_grouped_data.head()

## Define Metrics Functions

In [None]:
def hit_rate(recommendations, ground_truth, k=100):
    """
    Compute HitRate@k: the share of users that have at least one relevant
    item among their first k recommendations.

    Args:
        recommendations: dict {user_id: ordered sequence of recommended item_ids}
        ground_truth: dict {user_id: set of relevant item_ids}; users missing
            from this dict are treated as having no relevant items
        k: cutoff level

    Returns:
        Fraction of users in `recommendations` with a hit, or 0.0 when
        `recommendations` is empty.
    """
    if len(recommendations) == 0:
        return 0.0

    users_with_hit = 0
    for uid in recommendations:
        relevant = ground_truth.get(uid, set())
        # Stop at the first relevant item: one hit per user is enough.
        for candidate in recommendations[uid][:k]:
            if candidate in relevant:
                users_with_hit += 1
                break

    return users_with_hit / len(recommendations)


def precision(recommendations, ground_truth, k=100):
    """
    Compute Precision@k averaged over all users in `recommendations`.

    Args:
        recommendations: dict {user_id: ordered sequence of recommended item_ids}
        ground_truth: dict {user_id: set of relevant item_ids}
        k: cutoff level; also the denominator, even when a user has fewer
            than k recommendations

    Returns:
        Mean per-user precision, or 0.0 when `recommendations` is empty.
    """
    per_user = [
        sum(1 for item in recs[:k] if item in ground_truth.get(uid, set())) / k
        for uid, recs in recommendations.items()
    ]

    if not per_user:
        return 0.0
    return np.mean(per_user)


def recall(recommendations, ground_truth, k=100):
    """
    Compute Recall@k averaged over all users in `recommendations`.

    Args:
        recommendations: dict {user_id: ordered sequence of recommended item_ids}
        ground_truth: dict {user_id: set of relevant item_ids}
        k: cutoff level

    Returns:
        Mean per-user recall, or 0.0 when `recommendations` is empty.
        Users without any ground-truth items contribute a recall of 0.
    """
    scores = []

    for uid, recs in recommendations.items():
        relevant = ground_truth.get(uid, set())

        if relevant:
            found = len([item for item in recs[:k] if item in relevant])
            scores.append(found / len(relevant))
        else:
            # Nothing to retrieve for this user: count it as zero recall.
            scores.append(0.0)

    if not scores:
        return 0.0
    return np.mean(scores)


def mrr(recommendations, ground_truth, k=100):
    """
    Compute MRR@k (mean reciprocal rank) over all users in `recommendations`.

    Args:
        recommendations: dict {user_id: ordered sequence of recommended item_ids}
        ground_truth: dict {user_id: set of relevant item_ids}
        k: cutoff level

    Returns:
        Mean of 1 / (1-based rank of the first relevant item within the top k),
        using 0 for users with no relevant item in the top k; 0.0 when
        `recommendations` is empty.
    """
    rr_values = []

    for uid, recs in recommendations.items():
        relevant = ground_truth.get(uid, set())

        # 1-based position of the first relevant recommendation, if any.
        first_hit = next(
            (pos for pos, item in enumerate(recs[:k], start=1) if item in relevant),
            None,
        )
        rr_values.append(0.0 if first_hit is None else 1.0 / first_hit)

    if not rr_values:
        return 0.0
    return np.mean(rr_values)


def ndcg(recommendations, ground_truth, k=100, binary_relevance=True):
    """
    Calculate NDCG@k averaged over all users in `recommendations`.

    Each position is discounted by log2(rank + 1), with rank starting at 1.
    Note: this function previously applied no positional discount at all
    (the divisor was always 1), so NDCG values recorded with the old code
    will differ from the values produced here.

    Args:
        recommendations: dict {user_id: ordered sequence of recommended item_ids}
        ground_truth: dict {user_id: relevant items}. With binary relevance
            any container supporting `in` and `len` works (e.g. a set); with
            graded relevance it must be a dict {item_id: relevance score}.
        k: cutoff level
        binary_relevance: if True, uses binary relevance (0/1),
                         if False, expects relevance scores in ground_truth

    Returns:
        Mean per-user NDCG, or 0.0 when `recommendations` is empty. Users
        whose ideal DCG is 0 (no relevant items) contribute 0.
    """
    def _dcg(gains):
        # Discounted cumulative gain of a gain list in ranked order.
        return sum(gain / log2(rank + 1) for rank, gain in enumerate(gains, 1))

    ndcg_scores = []

    for user_id, recs in recommendations.items():
        user_recs = recs[:k]
        user_truth = ground_truth.get(user_id, {})

        if binary_relevance:
            gains = [1.0 if item in user_truth else 0.0 for item in user_recs]
            # Ideal ranking: every relevant item (up to k) placed first.
            ideal_gains = [1.0] * min(k, len(user_truth))
        else:
            gains = [user_truth.get(item, 0.0) for item in user_recs]
            # Ideal ranking: the k highest relevance scores in descending order.
            ideal_gains = sorted(user_truth.values(), reverse=True)[:k]

        idcg = _dcg(ideal_gains)
        ndcg_scores.append(_dcg(gains) / idcg if idcg > 0 else 0.0)

    return np.mean(ndcg_scores) if ndcg_scores else 0.0


def evaluate_model(
    df: pd.DataFrame, preds_col: str, gt_col: str, top_k: int = 20
) -> dict:
    """
    Compute all ranking metrics at a single cutoff for a predictions column.

    Args:
        df: frame with a "user_id" column plus the two columns named below
        preds_col: column holding each user's ordered recommendations
        gt_col: column holding each user's relevant items (converted to sets)
        top_k: cutoff passed to every metric

    Returns:
        dict mapping "<metric>@<top_k>" to its value, in the order
        hit_rate, precision, recall, mrr, ndcg.
    """
    user_index = df["user_id"]

    recs_by_user = pd.Series(df[preds_col].values, index=user_index).to_dict()
    truth_by_user = pd.Series(
        df[gt_col].apply(set).values, index=user_index
    ).to_dict()

    # Insertion order here fixes the key order of the returned dict.
    metric_fns = {
        "hit_rate": hit_rate,
        "precision": precision,
        "recall": recall,
        "mrr": mrr,
        "ndcg": ndcg,
    }

    results = {}
    for label, metric_fn in metric_fns.items():
        results[f"{label}@{top_k}"] = metric_fn(recs_by_user, truth_by_user, k=top_k)

    return results

## Reconstruct mappings and matrix
We need to reconstruct the `item2id` and `id2item` mappings, as well as the training matrix, from the loaded interactions.

In [None]:
# Collect every distinct item that appears in any user's training interactions
all_items = set().union(*full_grouped_data['train_interactions'])

# Sort so that the item -> index assignment is the same on every run
unique_items = sorted(all_items)
id2item = dict(enumerate(unique_items))
item2id = {item: position for position, item in id2item.items()}

# Users are indexed by their row position in the grouped data
unique_users = full_grouped_data['user_id'].values
user2id = {}
for position, user in enumerate(unique_users):
    user2id[user] = position
id2user = {position: user for user, position in user2id.items()}

print(f"Number of items: {len(item2id)}")
print(f"Number of users: {len(user2id)}")

In [None]:
# Assemble the (user, item) coordinate lists of the training matrix
rows = []
cols = []

user_column = full_grouped_data['user_id']
train_column = full_grouped_data['train_interactions']
for user_id, train_items in zip(user_column, train_column):
    user_idx = user2id[user_id]
    # Items without an index in item2id are skipped
    known_cols = [item2id[item] for item in train_items if item in item2id]
    cols.extend(known_cols)
    rows.extend([user_idx] * len(known_cols))

# One unit-weight entry per recorded interaction
values = np.ones(len(rows))
matrix_train = sps.coo_matrix(
    (values, (rows, cols)),
    shape=(len(user2id), len(item2id)),
    dtype=np.float64,
)

print(f"Training matrix shape: {matrix_train.shape}")
print(f"Number of interactions: {len(rows)}")

## Train EASE Model

In [None]:
%%time

def fit_ease(X, reg_weight=100):
    """
    Fit the EASE item-item weight matrix in closed form.

    Args:
        X: Sparse user-item interaction matrix (users x items)
        reg_weight: Value added to the diagonal of the Gram matrix (default: 100)

    Returns:
        B: Dense item-item weight matrix with a zeroed diagonal
    """
    # Item-item Gram matrix with the regularisation added on its diagonal
    gram = X.T @ X
    gram += reg_weight * sps.identity(gram.shape[0])

    # The inverse is dense anyway, so densify first; this is the costly step
    inverse = np.linalg.inv(gram.todense())

    # Broadcasting divides column j by -inverse[j, j]
    neg_diag = -np.diag(inverse)
    B = inverse / neg_diag

    # An item must not be used to predict itself
    np.fill_diagonal(B, 0.)

    return B

# Fit EASE on the training matrix using the default regularisation (reg_weight=100)
w = fit_ease(matrix_train)
print(f"Model weights shape: {w.shape}")

## Generate Predictions

In [None]:
def get_preds(user_interactions, item2id, id2item, model_weights, top_k=20):
    """
    Generate top-k recommendations for a single user.

    The user's items are first translated to matrix column indices through
    `item2id` (the same mapping used to build the training matrix), scored
    with the EASE weights, and the best-scoring columns are translated back
    to item IDs through `id2item`.

    Args:
        user_interactions: List of item IDs the user has interacted with;
            IDs that are not present in `item2id` are ignored
        item2id: Dictionary mapping items to indices
        id2item: Dictionary mapping indices to items
        model_weights: Trained EASE weight matrix (items x items)
        top_k: Number of recommendations to return (default: 20)

    Returns:
        Array of up to `top_k` recommended item IDs, best first. Items the
        user already interacted with are scored -inf, so they can only
        appear at the very end when fewer than `top_k` unseen items exist.
    """
    # Translate raw item IDs into column indices of the weight matrix.
    seen_idx = np.asarray(
        [item2id[item] for item in user_interactions if item in item2id],
        dtype=np.intp,
    )

    vector = np.zeros(len(item2id))
    vector[seen_idx] = 1

    # Flatten so that an np.matrix weight object also yields a 1-D score vector.
    preds = np.asarray(vector @ model_weights, dtype=float).ravel()
    preds[seen_idx] = -np.inf  # Filter out items already seen

    top_indices = np.argsort(-preds)[:top_k]

    # Return item IDs (not internal indices) so they are comparable with the
    # raw item IDs stored in the interaction columns.
    return np.asarray([id2item[int(i)] for i in top_indices])

In [None]:
%%time

# Make sure the weights are a plain ndarray before scoring with them
w = np.asarray(w)

# Register `progress_apply` on pandas objects
tqdm.pandas()


def _recommend_for_user(interactions):
    """Score one user's training interactions with the fitted weights."""
    return get_preds(interactions, item2id, id2item, w)


train_interactions = full_grouped_data['train_interactions']
full_grouped_data['ease_preds'] = train_interactions.progress_apply(_recommend_for_user)
full_grouped_data.head()

## Evaluate on Test Set

In [None]:
# Score the EASE predictions against the held-out test interactions at k=10
metrics = evaluate_model(full_grouped_data, 'ease_preds', 'test_interactions', top_k=10)

separator = "=" * 50
print("\n" + separator)
print("EASE Model Performance on Test Set @ k=10")
print(separator)
for metric_name in metrics:
    print(f"{metric_name:20s}: {metrics[metric_name]}")
print(separator)

# Show the raw dictionary as the cell output
metrics

## Expected Results

The metrics should match:
```python
{
 'hit_rate@10': 0.696154871653995,
 'precision@10': 0.092590108833534,
 'recall@10': 0.4643900770533879,
 'mrr@10': 0.37347038489245216,
 'ndcg@10': 0.4643900770533879
}
```

In [None]:
# Compare the computed metrics with the recorded reference values
# NOTE(review): the expected 'ndcg@10' below is identical to 'recall@10' --
# confirm the reference value was produced by a position-discounted NDCG.
expected_metrics = {
    'hit_rate@10': 0.696154871653995,
    'precision@10': 0.092590108833534,
    'recall@10': 0.4643900770533879,
    'mrr@10': 0.37347038489245216,
    'ndcg@10': 0.4643900770533879
}

rule = "=" * 70
print("\nValidation Check:")
print(rule)
print(f"{'Metric':<20} {'Expected':<20} {'Actual':<20} {'Match':<10}")
print(rule)

# Maximum absolute difference still counted as a match
tolerance = 1e-10
all_match = True

for metric_name, expected_value in expected_metrics.items():
    actual_value = metrics[metric_name]
    match = abs(actual_value - expected_value) < tolerance
    if match:
        match_str = "PASS"
    else:
        match_str = "FAIL"
        all_match = False
    print(f"{metric_name:<20} {expected_value:<20.15f} {actual_value:<20.15f} {match_str:<10}")

print(rule)
if not all_match:
    print("FAIL: Some metrics do not match. Check data preprocessing.")
else:
    print("PASS: All metrics match expected values!")