# Evaluation of video CBRS

##### Evaluation of E2E models and intelligent plotting

****
* pair-cosine
* pair-euclidean
* triplet-cosine
* triplet-euclidean

## Evaluation Prep

In [1]:
# Imports (standard library and third-party)
import os, pickle, re, io
import numpy as np
import pandas as pd
import boto3
from tqdm.notebook import tqdm
from typing import Dict, List, Set, Tuple, Any
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
import botocore

# --- GLOBAL CONSTANTS ---
# Name of the S3 bucket passed to the load/save helpers in this notebook.
S3_BUCKET_NAME = "md-data-content-recommendation"
# S3 key of the ratings CSV that is loaded as the evaluation benchmark.
RATINGS_BENCHMARK_KEY = 'evaluation/ratings_13919_filtered.csv'
# Cut-off values k at which every ranking metric (precision@k, recall@k, ...) is computed.
K_VALUES = [1, 3, 5, 10]
# S3 key of the aggregated per-epoch results CSV that the plotting cell reads back.
FINAL_RESULTS_KEY = "comparison_evaluation/final_all_models_per_epoch_results.csv"

print("Setup Complete.")

  from pandas.core.computation.check import NUMEXPR_INSTALLED


Setup Complete.


In [2]:
#handle s3
def load_data_from_s3(bucket: str, key: str) -> Any:
    """
    Fetch an object from S3 and deserialize it according to its extension.

    Parameters:
    bucket: Name of the S3 bucket.
    key: Object key; must end with '.pkl' (unpickled) or '.csv' (read with pandas).

    Returns:
    The unpickled Python object for '.pkl' keys, or a DataFrame for '.csv' keys.

    Raises:
    ValueError: if the key has any other extension (raised after the object
    has already been requested from S3).
    """
    body = boto3.client('s3').get_object(Bucket=bucket, Key=key)['Body']

    if key.endswith('.pkl'):
        # NOTE(review): unpickling executes arbitrary code from the payload;
        # only point this at buckets whose contents are trusted.
        loaded = pickle.loads(body.read())
    elif key.endswith('.csv'):
        loaded = pd.read_csv(body)
    else:
        raise ValueError(f"Unsupported file format: {key}")

    print(f"Successfully loaded data from s3://{bucket}/{key}")
    return loaded

def save_data_to_s3(data: Any, bucket: str, key: str) -> None:
    """
    Serialize `data` and upload it to S3.

    DataFrames are written as UTF-8 CSV without the index; every other object
    is pickled.

    Parameters:
    data: DataFrame or any picklable Python object.
    bucket: Name of the destination S3 bucket.
    key: Destination object key.
    """
    client = boto3.client('s3')

    # Pick the serialization first, then issue a single upload call.
    if isinstance(data, pd.DataFrame):
        payload = data.to_csv(index=False).encode('utf-8')
    else:
        payload = pickle.dumps(data)

    client.put_object(Body=payload, Bucket=bucket, Key=key)
    print(f"Successfully saved data to s3://{bucket}/{key}")
    
def load_feature_embeddings(s3_bucket, key):
    """
    Load a feature-embedding bundle from S3.

    The loaded object is indexed with "features" and "trailer_ids".

    Returns:
    (features, trailer_ids) where features is a numpy array (converted with
    np.array if it was not one already) and trailer_ids is returned as stored.
    """
    print("Loading feature embeddings 1...")
    bundle = load_data_from_s3(s3_bucket, key)

    features = bundle["features"]
    if not isinstance(features, np.ndarray):
        features = np.array(features)

    return features, bundle["trailer_ids"]

In [3]:
def prepare_user_item_data(ratings_df: pd.DataFrame, sample_size: int = None) -> Tuple[Dict[str, Set[str]], Dict[str, Dict[str, float]]]:
    """
    Prepare user-item interaction data for evaluation.

    Parameters:
    ratings_df: DataFrame with columns 'userId', 'movieId', 'rating' and
        'binary_rating'.
    sample_size: Optional cap on the number of users. When given and smaller
        than the number of distinct users, that many users are drawn at random
        (without replacement, unseeded) and only their rows are processed.

    Returns:
        - user_items: Dict mapping userId to a set of relevant items (binary_rating=1)
        - user_item_ratings: Dict mapping userId to a dict of itemId -> rating
      Both are defaultdicts keyed by the string form of the IDs.
    """
    print("Preparing user-item interaction data...")

    # Map for relevant items (binary_rating=1)
    user_items = defaultdict(set)
    # Map for all ratings
    user_item_ratings = defaultdict(dict)

    # If sample_size is provided, limit the number of users
    if sample_size:
        unique_users = ratings_df['userId'].unique()
        if len(unique_users) > sample_size:
            sampled_users = np.random.choice(unique_users, sample_size, replace=False)
            ratings_df = ratings_df[ratings_df['userId'].isin(sampled_users)]
            print(f"Sampled {sample_size} users out of {len(unique_users)} for faster evaluation")

    # Iterate column-wise instead of with iterrows(): iterrows() builds one
    # Series per row, which is very slow and upcasts an all-numeric frame to
    # float, turning the ID 1 into the string "1.0". Reading each column on its
    # own keeps every column's dtype, so integer IDs stringify as "1".
    rows = zip(
        ratings_df['userId'],
        ratings_df['movieId'],
        ratings_df['rating'],
        ratings_df['binary_rating'],
    )
    for raw_user, raw_movie, raw_rating, raw_flag in tqdm(rows, total=len(ratings_df)):
        user_id = str(raw_user)
        movie_id = str(raw_movie)

        # Add to relevant items if binary_rating is 1
        if int(raw_flag) == 1:
            user_items[user_id].add(movie_id)

        # Store the rating (a later row for the same pair overwrites an earlier one)
        user_item_ratings[user_id][movie_id] = float(raw_rating)

    print(f"Processed {len(user_items)} users with relevant items")
    return user_items, user_item_ratings

In [4]:
def extract_movie_id_from_trailer_id(trailer_id: str) -> str:
    """
    Derive a candidate movie ID from a trailer ID string.

    Returns the first run of digits found in `trailer_id`; when the string
    contains no digits at all, the trailer ID itself is returned unchanged.
    """
    first_digit_run = re.search(r'\d+', trailer_id)
    return first_digit_run.group(0) if first_digit_run else trailer_id

In [5]:
def build_improved_trailer_to_movie_mapping(recommendations, ratings_df):
    """
    Create a mapping from trailer IDs to movie IDs, plus its reverse.

    Each key of `recommendations` is matched against the string form of the
    unique values in ratings_df['movieId'], using the first strategy that
    succeeds:
      1. Direct match: the trailer ID is itself a movie ID.
      2. Digit match: the first run of digits in the trailer ID that is a movie ID.
      3. Substring match: a movie ID contained in the trailer ID; the longest
         such ID wins (ties broken lexicographically) so the result does not
         depend on set iteration order.
    Trailer IDs that match nothing are left out of the mapping.

    Parameters:
    recommendations: Dict keyed by trailer ID (values are not inspected here).
    ratings_df: DataFrame with a 'movieId' column.

    Returns:
    (trailer_to_movie, movie_to_trailer). When several trailers map to the same
    movie, movie_to_trailer keeps the first one encountered.
    """
    trailer_to_movie = {}
    movie_ids = set(str(id) for id in ratings_df['movieId'].unique())
    # Candidates for the substring fallback, most specific (longest) first.
    # The empty string is excluded because it is a substring of everything.
    substring_candidates = sorted((m for m in movie_ids if m), key=lambda m: (-len(m), m))

    # For each trailer ID in recommendations
    for trailer_id in recommendations.keys():
        # Method 1: Direct match (if trailer_id exists directly in movie_ids)
        if trailer_id in movie_ids:
            trailer_to_movie[trailer_id] = trailer_id
            continue

        # Method 2: Extract digits and check if they match any movie ID
        digit_match = next(
            (digits for digits in re.findall(r'\d+', trailer_id) if digits in movie_ids),
            None,
        )
        if digit_match is not None:
            trailer_to_movie[trailer_id] = digit_match
            # Stop here: falling through to the substring search would let a
            # shorter, unrelated ID (e.g. "1" inside "123") overwrite this match.
            continue

        # Method 3: Try to find movie ID that is a substring of trailer_id
        substring_match = next((m for m in substring_candidates if m in trailer_id), None)
        if substring_match is not None:
            trailer_to_movie[trailer_id] = substring_match

    print(f"Created mapping for {len(trailer_to_movie)} trailer IDs to movie IDs")

    # Print examples to help debug
    examples = list(trailer_to_movie.items())[:5]
    print("Mapping examples (trailer_id -> movie_id):")
    for trailer_id, movie_id in examples:
        print(f"  {trailer_id} -> {movie_id}")

    # Also create reverse mapping (first trailer seen for a movie is kept)
    movie_to_trailer = {}
    for trailer_id, movie_id in trailer_to_movie.items():
        movie_to_trailer.setdefault(movie_id, trailer_id)

    print(f"Created reverse mapping for {len(movie_to_trailer)} movie IDs to trailer IDs")

    return trailer_to_movie, movie_to_trailer

In [6]:
def verify_id_mapping(ratings_df, recommendations, trailer_to_movie):
    """
    Report how many mapped movie IDs also occur in the ratings data.

    Parameters:
    ratings_df: DataFrame with a 'movieId' column.
    recommendations: Accepted for signature compatibility; not used here.
    trailer_to_movie: Dict mapping trailer ID -> movie ID.

    Returns:
    The set of movie IDs (as strings) present both in the ratings and among
    the mapping's values. Counts and a few sample IDs are printed for debugging.
    """
    ids_in_ratings = {str(raw_id) for raw_id in ratings_df['movieId'].unique()}
    ids_in_mapping = set(trailer_to_movie.values())
    shared_ids = ids_in_ratings & ids_in_mapping

    print(f"Unique movie IDs in ratings: {len(ids_in_ratings)}")
    print(f"Unique movie IDs from trailer mapping: {len(ids_in_mapping)}")
    print(f"Overlap between ratings and recommendations: {len(shared_ids)}")

    # A handful of samples from each side helps spot formatting mismatches.
    print("Sample movie IDs from ratings:", list(ids_in_ratings)[:5])
    print("Sample movie IDs from trailer mapping:", list(ids_in_mapping)[:5])

    return shared_ids

### Evaluation metrics

In [7]:
def precision_at_k(recommended_items: List[str], relevant_items: Set[str], k: int) -> float:
    """
    Precision@k: share of the top-k recommendations that are relevant.

    The denominator is the number of recommendations actually available in the
    top-k slice, so a list shorter than k is not penalised. Returns 0.0 when
    k is 0 or there are no recommendations.
    """
    if k == 0 or len(recommended_items) == 0:
        return 0.0

    top_k = recommended_items[:k]
    hits = len([item for item in top_k if item in relevant_items])
    return hits / min(k, len(top_k))

In [8]:
def recall_at_k(recommended_items: List[str], relevant_items: Set[str], k: int) -> float:
    """
    Recall@k: share of all relevant items that appear in the top-k recommendations.

    Returns 0.0 when there are no relevant items.
    """
    if len(relevant_items) == 0:
        return 0.0

    hits = 0
    for candidate in recommended_items[:k]:
        if candidate in relevant_items:
            hits += 1

    return hits / len(relevant_items)

In [9]:
def hit_rate_at_k(recommended_items: List[str], relevant_items: Set[str], k: int) -> float:
    """
    Hit rate@k: 1.0 when any of the top-k recommendations is relevant, else 0.0.
    """
    found_hit = any(candidate in relevant_items for candidate in recommended_items[:k])
    return 1.0 if found_hit else 0.0

In [10]:
def dcg_at_k(recommended_items: List[str], item_ratings: Dict[str, float], k: int) -> float:
    """
    Discounted Cumulative Gain over the top-k recommendations.

    Each recommended item that has an entry in `item_ratings` contributes
    (2^rating - 1) / log2(position + 2), with position counted from 0.
    Items without a rating contribute nothing. Returns 0.0 when k is 0 or
    there are no recommendations.
    """
    if k == 0 or len(recommended_items) == 0:
        return 0.0

    # The actual rating is used as the graded relevance score.
    return sum(
        (2 ** item_ratings[item] - 1) / np.log2(position + 2)
        for position, item in enumerate(recommended_items[:k])
        if item in item_ratings
    )

def idcg_at_k(item_ratings: Dict[str, float], k: int) -> float:
    """
    Ideal DCG at k: the DCG obtained by ranking the highest ratings first.

    Uses the k largest values of `item_ratings` (all of them if there are
    fewer than k). Returns 0.0 when k is 0 or there are no ratings.
    """
    if k == 0 or len(item_ratings) == 0:
        return 0.0

    # Best possible ordering: ratings from highest to lowest, cut at k.
    best_first = sorted(item_ratings.values(), reverse=True)[:k]

    # Same gain/discount formula as DCG, applied to the optimal ordering.
    return sum(
        (2 ** rating - 1) / np.log2(position + 2)
        for position, rating in enumerate(best_first)
    )

def ndcg_at_k(recommended_items: List[str], item_ratings: Dict[str, float], k: int) -> float:
    """
    Normalized DCG at k: DCG of the recommendations divided by the ideal DCG.

    Returns 0.0 when the ideal DCG is zero (no ratings, k == 0, or gains that
    sum to zero), which also avoids dividing by zero.
    """
    ideal = idcg_at_k(item_ratings, k)
    # The actual DCG is only computed when normalisation is possible.
    return dcg_at_k(recommended_items, item_ratings, k) / ideal if ideal != 0 else 0.0

In [11]:
def mean_reciprocal_rank(recommended_items: List[str], relevant_items: Set[str], k: int) -> float:
    """
    Reciprocal rank of the first relevant item within the top-k recommendations.

    This is the per-list value; averaging it over many lists gives the Mean
    Reciprocal Rank (MRR).

    Parameters:
    recommended_items: List of recommended item IDs, best first
    relevant_items: Set of relevant item IDs
    k: Number of recommendations to consider

    Returns:
    float: 1 / rank (1-based) of the first relevant item, or 0.0 when either
    input is empty or no relevant item occurs in the top k
    """
    if not relevant_items or not recommended_items:
        return 0.0

    # Ranks are 1-based; stop at the first relevant item.
    first_hit_rank = next(
        (rank for rank, item in enumerate(recommended_items[:k], start=1) if item in relevant_items),
        None,
    )
    if first_hit_rank is None:
        return 0.0
    return 1.0 / first_hit_rank

### Do actual evaluation

In [12]:
def item_based_evaluation(recommendations: Dict[str, List[Dict[str, Any]]], 
                          user_items: Dict[str, Set[str]],
                          user_item_ratings: Dict[str, Dict[str, float]],
                          trailer_to_movie: Dict[str, str],
                          k_values: List[int]) -> Dict[str, Dict[int, float]]:
    """
    Evaluate recommendations using an item-based approach.

    For every rated movie that has a recommendation list, the list is scored
    against each user who rated that movie positively, using that user's other
    positively-rated movies as ground truth. Scores are averaged per movie over
    its valid users, then averaged over all evaluated movies.

    Parameters:
    recommendations: Dict mapping a trailer ID to its ranked recommendations;
        each recommendation is a dict with a 'trailer_id' entry.
    user_items: Dict mapping userId -> set of positively-rated movie IDs.
    user_item_ratings: Dict mapping userId -> {movieId: rating}.
    trailer_to_movie: Dict mapping trailer ID -> movie ID.
    k_values: Cut-offs at which each metric is computed.

    Returns:
    Dict with keys 'precision', 'recall', 'hit_rate', 'ndcg' and 'MRR', each
    mapping k -> averaged score (all zeros when no movie could be evaluated).
    """
    print("Performing item-based evaluation...")

    metric_names = ('precision', 'recall', 'hit_rate', 'ndcg', 'MRR')

    def _zeroed_metrics():
        # Fresh accumulator: one 0.0 per (metric, k) pair.
        return {name: {k: 0.0 for k in k_values} for name in metric_names}

    # Create an inverse mapping from movie IDs to trailer IDs
    # NOTE(review): when several trailers map to one movie the last one wins here,
    # whereas build_improved_trailer_to_movie_mapping keeps the first — confirm intent.
    movie_to_trailer = {movie_id: trailer_id for trailer_id, movie_id in trailer_to_movie.items()}

    # Inverted index movie -> users who rated it positively, built once so the
    # main loop does not rescan every user for every movie. Users are appended
    # in user_items iteration order, the same order a per-movie scan would give.
    users_by_liked_movie = {}
    for user_id, relevant_items in user_items.items():
        for liked_movie in relevant_items:
            users_by_liked_movie.setdefault(liked_movie, []).append(user_id)

    # Initialize metrics
    metrics = _zeroed_metrics()

    # Track the number of evaluated items
    evaluated_items = 0

    # Create a set of all rated movie IDs
    all_rated_movies = set()
    for user_ratings in user_item_ratings.values():
        all_rated_movies.update(user_ratings.keys())

    print(f"Found {len(all_rated_movies)} unique rated movies")

    # For each movie that has been rated
    for movie_id in tqdm(all_rated_movies, desc="Evaluating items"):
        # Find the corresponding trailer ID
        trailer_id = movie_to_trailer.get(movie_id)

        # Skip if we don't have recommendations for this movie
        if not trailer_id or trailer_id not in recommendations:
            continue

        # Skip if no users rated this movie positively
        relevant_users = users_by_liked_movie.get(movie_id)
        if not relevant_users:
            continue

        # Convert trailer IDs in recommendations to movie IDs, dropping
        # recommendations whose trailer has no known movie ID
        rec_movie_ids = [
            trailer_to_movie[rec['trailer_id']]
            for rec in recommendations[trailer_id]
            if rec['trailer_id'] in trailer_to_movie
        ]

        # For each relevant user, evaluate the recommendations
        item_metrics = _zeroed_metrics()
        valid_users = 0
        for user_id in relevant_users:
            # Get the set of other movies this user rated positively
            other_relevant_items = user_items[user_id] - {movie_id}

            # Skip if user has no other relevant items
            if not other_relevant_items:
                continue

            # Get all ratings from this user
            # NOTE(review): this still contains the query movie's own rating, so the
            # ideal ranking used by NDCG includes the held-out item, unlike the other
            # metrics — left unchanged to stay comparable with saved results; confirm.
            user_ratings = user_item_ratings[user_id]

            valid_users += 1

            # Calculate metrics for each k
            for k in k_values:
                item_metrics['precision'][k] += precision_at_k(rec_movie_ids, other_relevant_items, k)
                item_metrics['recall'][k] += recall_at_k(rec_movie_ids, other_relevant_items, k)
                item_metrics['hit_rate'][k] += hit_rate_at_k(rec_movie_ids, other_relevant_items, k)
                item_metrics['ndcg'][k] += ndcg_at_k(rec_movie_ids, user_ratings, k)
                item_metrics['MRR'][k] += mean_reciprocal_rank(rec_movie_ids, other_relevant_items, k)

        # Average metrics for this item and fold them into the overall totals
        if valid_users > 0:
            for metric in metrics:
                for k in k_values:
                    metrics[metric][k] += item_metrics[metric][k] / valid_users

            evaluated_items += 1

    # Average metrics across all evaluated items
    if evaluated_items > 0:
        for metric in metrics:
            for k in k_values:
                metrics[metric][k] /= evaluated_items

    print(f"Evaluated {evaluated_items} valid items")
    return metrics


# MAIN

In [13]:
def get_or_create_user_data(s3_client: Any, bucket: str, user_items_key: str, user_ratings_key: str, ratings_df: pd.DataFrame) -> Tuple[Dict, Dict]:
    """
    Return the pre-processed user data, using S3 as a cache.

    If both cache objects exist they are loaded from S3. If a lookup fails with
    error code '404', the data is rebuilt from `ratings_df`, saved under the two
    keys, and returned. Any other S3 client error is re-raised.

    Parameters:
    s3_client: boto3 S3 client used for the existence checks.
    bucket: Name of the S3 bucket holding the cache.
    user_items_key: Key of the cached user -> relevant-items mapping.
    user_ratings_key: Key of the cached user -> item-ratings mapping.
    ratings_df: Ratings DataFrame used to rebuild the data on a cache miss.

    Returns:
    (user_items, user_item_ratings)
    """
    try:
        # A metadata request per key is enough to learn whether both objects exist.
        for cache_key in (user_items_key, user_ratings_key):
            s3_client.head_object(Bucket=bucket, Key=cache_key)

        print("✅ Cache hit! Loading pre-processed user data from S3...")
        user_items = load_data_from_s3(bucket, user_items_key)
        user_item_ratings = load_data_from_s3(bucket, user_ratings_key)
        print("Pre-processed user data loaded successfully.")

    except botocore.exceptions.ClientError as e:
        # Anything other than "not found" (e.g. permissions) is a real failure.
        if e.response['Error']['Code'] != '404':
            print("An unexpected S3 error occurred.")
            raise e

        print("⚠️ Cache miss! Pre-processed data not found on S3.")
        print("Processing user-item data now (this may take over 20 minutes)...")

        # Rebuild from the raw ratings.
        user_items, user_item_ratings = prepare_user_item_data(ratings_df)

        print("Saving processed data to S3 cache for future runs...")
        # Store both mappings so the next run hits the cache.
        save_data_to_s3(user_items, bucket, user_items_key)
        save_data_to_s3(user_item_ratings, bucket, user_ratings_key)

    return user_items, user_item_ratings

In [1]:
if __name__ == "__main__":
    # Evaluation driver: scores every per-epoch recommendation file of each E2E
    # pipeline plus the baselines, checkpoints each result to S3 so an
    # interrupted run can resume, and writes one aggregated CSV at the end.

    # --- 1. Global Setup
    S3_BUCKET_NAME = "md-data-content-recommendation"
    RATINGS_BENCHMARK_KEY = 'evaluation/ratings_13919_filtered.csv'
    K_VALUES = [1, 3, 5, 10]

    OUTPUT_BASE_DIR = "comparison_evaluation/"
    
    USER_ITEMS_CACHE_KEY = f"{OUTPUT_BASE_DIR}cache/user_items_13919.pkl"
    USER_RATINGS_CACHE_KEY = f"{OUTPUT_BASE_DIR}cache/user_item_ratings_13919.pkl"
    PARTIAL_RESULTS_PREFIX = f"{OUTPUT_BASE_DIR}partial_results/"
    
    pipeline_names = ["E2E_Pair_Cosine", "E2E_Pair_Euclidean", "E2E_Triplet_Cosine", "E2E_Triplet_Euclidean"]
    # Baselines have a single recommendation file each (no epochs).
    baseline_pipelines = {
        "Baseline_FS_Elbow": "FS_similarity_and_recommendations/elbowrecommendations_top_10.pkl",
        "Baseline_All_Features": "similarity_and_recommendations/recommendations_top_10.pkl"
    }
    
    s3_client = boto3.client('s3')

    # --- 2. Load Global Data Once
    print("--- Preparing Global Evaluation Data ---\n")
    ratings_df_global = load_data_from_s3(S3_BUCKET_NAME, RATINGS_BENCHMARK_KEY)
    # Movie IDs are compared as strings throughout the evaluation.
    ratings_df_global["movieId"] = ratings_df_global["movieId"].astype(str)
    
    user_items, user_item_ratings = get_or_create_user_data(
        s3_client,
        S3_BUCKET_NAME,
        USER_ITEMS_CACHE_KEY,
        USER_RATINGS_CACHE_KEY,
        ratings_df_global.copy()
    )

    # One flat dict per evaluated (model, epoch); becomes the final DataFrame.
    all_pipelines_history = []
    
    # --- 3. Loop Through Each E2E Pipeline 
    for name in pipeline_names:
        print(f"\n{'='*50}\nEvaluating all epochs for: {name}\n{'='*50}")
        
        rec_prefix = f"evaluation/per_epoch_recs/{name}/"
        # A single list_objects_v2 call returns at most 1000 keys; paginate so
        # no epoch file is silently dropped when the prefix holds more than that.
        paginator = s3_client.get_paginator('list_objects_v2')
        all_pkl_files = [
            obj['Key']
            for page in paginator.paginate(Bucket=S3_BUCKET_NAME, Prefix=rec_prefix)
            for obj in page.get('Contents', [])
            if obj['Key'].endswith('.pkl')
        ]
        # Only files whose key carries an epoch number can be placed on the timeline.
        rec_files = [f for f in all_pkl_files if re.search(r'epoch_(\d+)', f)]
        
        if not rec_files:
            print(f"WARNING: No files with epoch numbers found for pipeline '{name}'. Skipping.")
            continue
            
        # Numeric sort, so epoch 10 comes after epoch 9 rather than after epoch 1.
        rec_files.sort(key=lambda x: int(re.search(r'epoch_(\d+)', x).group(1)))
        
        for rec_key in tqdm(rec_files, desc=f"Evaluating {name} recs"):
            epoch_num = int(re.search(r'epoch_(\d+)', rec_key).group(1))
            
            # Resume support: reuse a previously checkpointed result if present.
            partial_result_key = f"{PARTIAL_RESULTS_PREFIX}{name}_epoch_{epoch_num}.pkl"
            try:
                s3_client.head_object(Bucket=S3_BUCKET_NAME, Key=partial_result_key)
                print(f"✅ Found existing result for {name} epoch {epoch_num}. Skipping calculation.")
                flat_metrics = load_data_from_s3(S3_BUCKET_NAME, partial_result_key)
                all_pipelines_history.append(flat_metrics)
                continue # Skip to the next epoch
            except botocore.exceptions.ClientError as e:
                # '404' just means "no checkpoint yet"; anything else is a real error.
                if e.response['Error']['Code'] != '404':
                    raise
            
            print(f"Calculating metrics for {name} epoch {epoch_num}...")
            recommendations = load_data_from_s3(S3_BUCKET_NAME, rec_key)
            trailer_to_movie, _ = build_improved_trailer_to_movie_mapping(recommendations, ratings_df_global)
            metrics = item_based_evaluation(recommendations, user_items, user_item_ratings, trailer_to_movie, K_VALUES)
            
            # Flatten {metric: {k: value}} into columns named "metric@k".
            flat_metrics = {'model': name, 'epoch': epoch_num}
            for metric, values in metrics.items():
                for k, value in values.items():
                    flat_metrics[f"{metric}@{k}"] = value
            all_pipelines_history.append(flat_metrics)

            print(f"Saving checkpoint for {name} epoch {epoch_num}...")
            save_data_to_s3(flat_metrics, S3_BUCKET_NAME, partial_result_key)


    # --- 4. Evaluate the Baseline Models
    for name, rec_key in baseline_pipelines.items():
        print(f"\n{'='*50}\nEvaluating Baseline: {name}\n{'='*50}")
        
        partial_result_key = f"{PARTIAL_RESULTS_PREFIX}{name}_baseline.pkl"
        try:
            s3_client.head_object(Bucket=S3_BUCKET_NAME, Key=partial_result_key)
            print(f"Found existing result for baseline {name}. Skipping calculation.")
            flat_metrics = load_data_from_s3(S3_BUCKET_NAME, partial_result_key)
            all_pipelines_history.append(flat_metrics)
            continue
        except botocore.exceptions.ClientError as e:
            if e.response['Error']['Code'] != '404': raise

        print(f"Calculating metrics for baseline {name}...")
        recommendations = load_data_from_s3(S3_BUCKET_NAME, rec_key)
        trailer_to_movie, _ = build_improved_trailer_to_movie_mapping(recommendations, ratings_df_global)
        metrics = item_based_evaluation(recommendations, user_items, user_item_ratings, trailer_to_movie, K_VALUES)
        
        # epoch -1 is the marker the plotting cell uses to tell baselines apart.
        flat_metrics = {'model': name, 'epoch': -1}
        for metric, values in metrics.items():
            for k, value in values.items():
                flat_metrics[f"{metric}@{k}"] = value
        all_pipelines_history.append(flat_metrics)
        
        print(f"Saving checkpoint for baseline {name}...")
        save_data_to_s3(flat_metrics, S3_BUCKET_NAME, partial_result_key)


    # --- 5. Create and Save the Final DataFrame
    final_results_df = pd.DataFrame(all_pipelines_history)
    print("\n--- Full Evaluation History ---")
    print(final_results_df.head())
    
    final_output_key = f"{OUTPUT_BASE_DIR}final_all_models_per_epoch_results.csv"
    save_data_to_s3(final_results_df, S3_BUCKET_NAME, final_output_key)
    print(f"Final aggregated results saved to s3://{S3_BUCKET_NAME}/{final_output_key}")

# PLOT

In [None]:
# --- Load the final results we just created ---
results_df = load_data_from_s3(S3_BUCKET_NAME, FINAL_RESULTS_KEY)

# Make sure epoch is numeric; baseline rows carry the sentinel epoch -1.
results_df['epoch'] = pd.to_numeric(results_df['epoch'])

# Split once on the sentinel: per-epoch E2E rows versus baseline rows.
is_baseline_row = results_df['epoch'] == -1
e2e_df = results_df[~is_baseline_row].copy()
baseline_df = results_df[is_baseline_row].copy()

def plot_metric_evolution(df_e2e, df_baseline, metrics_to_plot, primary_metric='ndcg@10'):
    """
    Plot how each metric evolves over epochs for every E2E model.

    One subplot is drawn per metric (two per row). Each subplot shows a line
    per model, a dashed gray horizontal line per baseline, and a star at each
    model's best epoch, where "best" is the epoch with the highest
    `primary_metric` value.

    Parameters:
    df_e2e: DataFrame with 'model', 'epoch' and one column per metric.
    df_baseline: DataFrame with 'model' and metric columns, one row per baseline.
    metrics_to_plot: Names of the metric columns to plot.
    primary_metric: Metric column used to pick each model's best epoch.
    """
    n_cols = 2
    n_metrics = len(metrics_to_plot)
    n_rows = (n_metrics + n_cols - 1) // n_cols  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 8, n_rows * 6), sharex=False)
    axes = axes.flatten()

    models = df_e2e['model'].unique()
    palette = sns.color_palette("husl", len(models))
    color_by_model = dict(zip(models, palette))

    # Row with the highest primary-metric value for each model.
    best_rows = df_e2e.loc[df_e2e.groupby('model')[primary_metric].idxmax()]

    for ax, metric in zip(axes, metrics_to_plot):
        # E2E models: one line per model across epochs.
        sns.lineplot(data=df_e2e, x='epoch', y=metric, hue='model', palette=palette, marker='o', alpha=0.8, ax=ax)

        # Baselines have no epochs, so each is a horizontal reference line.
        for _, baseline_row in df_baseline.iterrows():
            ax.axhline(
                y=baseline_row.get(metric, 0),
                linestyle='--',
                label=f"{baseline_row['model']} (Baseline)",
                color='gray',
                alpha=0.9,
            )

        # Star marker at each model's best epoch.
        for model in models:
            best_for_model = best_rows[best_rows['model'] == model]
            if best_for_model.empty:
                continue

            best_epoch = best_for_model['epoch'].iloc[0]
            # Value of the metric being plotted at that best epoch.
            at_best_epoch = (df_e2e['model'] == model) & (df_e2e['epoch'] == best_epoch)
            star_height = df_e2e.loc[at_best_epoch, metric].iloc[0]

            ax.scatter(
                best_epoch,
                star_height,
                marker='*',
                s=350,
                color=color_by_model[model],
                edgecolor='black',
                zorder=10,
                label=f'Best {model} (Epoch {int(best_epoch)})',
            )

        ax.set_title(f'Evolution of {metric}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)

        # Keying by label collapses duplicate legend entries (last handle wins).
        handles, labels = ax.get_legend_handles_labels()
        handle_by_label = dict(zip(labels, handles))
        ax.legend(handle_by_label.values(), handle_by_label.keys(), title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')

    # Hide any subplot left over when the metric count does not fill the grid.
    for spare_ax in axes[n_metrics:]:
        spare_ax.set_visible(False)

    fig.suptitle(f'Comparison of E2E Model Performance (Best Epoch by {primary_metric})', fontsize=18, fontweight='bold')
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# --- Call the plotting function ---
key_metrics_to_plot = ['ndcg@10', 'hit_rate@10', 'precision@10', 'MRR@10']
plot_metric_evolution(e2e_df, baseline_df, key_metrics_to_plot)

# --- Display Final Summary Table ---
print("\n--- Final Performance Table (Best Epoch vs. Baselines) ---")
# Index label of each E2E model's best epoch, judged by ndcg@10.
best_epoch_index = e2e_df.groupby('model')['ndcg@10'].idxmax()
best_e2e_df = e2e_df.loc[best_epoch_index]
# Baselines first, then the best epoch of each E2E model.
final_table_df = pd.concat([baseline_df, best_e2e_df]).set_index('model')
# Show only the epoch and the "metric@k" columns.
summary_columns = [col for col in final_table_df.columns if '@' in col or col == 'epoch']
print(final_table_df[summary_columns])