In [1]:
from helpers import *
from tensorflow.keras.models import load_model

In [2]:
# Rebuild the trained content-based model. load_model must be told how to resolve
# the custom metric names stored with the model; the metric functions are
# presumably provided by `from helpers import *` (verify).
model = load_model(
    './input/content-based.h5',
    custom_objects={
        'ndcg_5': ndcg_5,
        'ndcg_10': ndcg_10,
        'mean_mrr': mean_mrr,
        'g_auc': g_auc,
    },
)

In [3]:
# Load the prepared DataFrames from ./input (read in this order: user profiles,
# article metadata, article embeddings).
# NOTE(review): read_pickle can execute arbitrary code — only load trusted files.
user_profiles_df_all, df_articles, article_embeddings_df = (
    pd.read_pickle(pickle_path)
    for pickle_path in (
        "./input/user_profiles_df_all.pkl",
        "./input/df_articles.pkl",
        "./input/article_embeddings_df.pkl",
    )
)

In [8]:
def infer_all_articles_scores(user_id, df, df_articles, article_embeddings_df, model):
    """Score every article in ``article_embeddings_df`` for a single user.

    Each model input row is the user's embedding concatenated with one row of
    ``article_embeddings_df``; the flattened model output is used as the score.

    Parameters
    ----------
    user_id :
        Value looked up in ``df['user_id']``; the first matching row is used.
    df : pd.DataFrame
        User profiles with ``user_id`` and ``user_embedding`` columns.
    df_articles : pd.DataFrame
        Article metadata with ``article_id`` and ``category_id`` columns.
    article_embeddings_df : pd.DataFrame
        One row of embedding values per article; its index is matched against
        ``df_articles['article_id']``.
    model :
        Object exposing a Keras-style ``predict(X, verbose=0)``.

    Returns
    -------
    pd.DataFrame
        Columns ``article_id``, ``category_id``, ``score`` in ``df_articles``
        row order, with a fresh RangeIndex. ``score`` is NaN for articles that
        have no row in ``article_embeddings_df``.

    Raises
    ------
    ValueError
        If ``user_id`` does not appear in ``df['user_id']``.
    """
    matches = df.loc[df['user_id'] == user_id]
    # Check for a match *before* taking the first row: ``.iloc[0]`` on an empty
    # selection raises IndexError, and a selected row is never ``.empty``.
    if matches.empty:
        raise ValueError("User ID not found in the user profiles.")
    user_embedding = np.asarray(matches.iloc[0]['user_embedding'], dtype=float).ravel()

    # Build the (n_articles, user_dim + article_dim) input matrix in one shot
    # rather than concatenating row-by-row through a transposed dict.
    article_ids = article_embeddings_df.index
    article_matrix = article_embeddings_df.to_numpy(dtype=float)
    n_articles = article_matrix.shape[0]
    user_block = np.broadcast_to(user_embedding, (n_articles, user_embedding.size))
    all_embeddings = np.hstack((user_block, article_matrix))

    # Predict relevance scores; skip the model entirely when there is nothing to score.
    if n_articles == 0:
        scores = np.empty(0)
    else:
        scores = model.predict(all_embeddings, verbose=0).flatten()

    score_by_article = dict(zip(article_ids, scores))
    article_scores_df = df_articles[['article_id', 'category_id']].copy()
    article_scores_df['score'] = article_scores_df['article_id'].map(score_by_article)
    return article_scores_df.reset_index(drop=True)

In [9]:
# Score the whole article catalogue for one example user.
user_id = 4
articles_scores = infer_all_articles_scores(
    user_id=user_id,
    df=user_profiles_df_all,
    df_articles=df_articles,
    article_embeddings_df=article_embeddings_df,
    model=model,
)

In [10]:
# Show the per-article scores for the selected user (pandas elides the middle rows).
# NOTE(review): rows follow df_articles order, not score order — sort by 'score'
# descending for a top-N recommendation view.
articles_scores

Unnamed: 0,article_id,category_id,score
0,160974,281,0.982020
1,272143,399,0.998865
2,336221,437,0.739584
3,234698,375,0.995500
4,123909,250,0.993134
...,...,...,...
46028,283269,412,0.872488
46029,329065,436,0.015920
46030,38473,51,0.000897
46031,289316,421,0.309292
