In [63]:
from tqdm import tqdm
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
import pandas as pd
from typing import Protocol, Callable, Any, Iterable
from collections import Counter
from sklearn.metrics.pairwise import cosine_distances
from pathlib import Path
import warnings
import datetime
import inspect
import polars as pl
from dataclasses import dataclass, field
import tensorflow.keras as keras
from tensorflow.keras import layers, backend as K
from tensorflow.keras.layers import Embedding, Input, Dropout, Dense, BatchNormalization
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.regularizers import l2
import argparse
import sys

import datetime as dt

import tensorflow as tf
import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Input, Concatenate
import pandas as pd


In [18]:
# Root folders of the local dataset copies. Every parquet file below is resolved
# relative to these two constants, so relocating the data only needs an edit here
# instead of in six separate hard-coded absolute paths.
DATA_DIR = Path(r"C:\Users\bilba\Downloads\DL_Small")
WORD2VEC_DIR = Path(r"C:\Users\bilba\Downloads\Ekstra_Bladet_word2vec")

# Training split: user histories and behaviors (rows with a missing article_id are dropped).
train_history = pd.read_parquet(DATA_DIR / "history.parquet", engine='pyarrow')
train_behaviors = pd.read_parquet(DATA_DIR / "behaviors.parquet", engine='pyarrow').dropna(subset=["article_id"])

# Validation split, same layout under the "validation" sub-folder.
val_history = pd.read_parquet(DATA_DIR / "validation" / "history.parquet", engine='pyarrow')
val_behaviors = pd.read_parquet(DATA_DIR / "validation" / "behaviors.parquet", engine='pyarrow').dropna(subset=["article_id"])

# Article metadata.
articles_ = pd.read_parquet(DATA_DIR / "articles.parquet", engine='pyarrow')

# Pre-computed document vectors (columns `article_id` and `document_vector` are read downstream).
word2vec_file = pd.read_parquet(WORD2VEC_DIR / "document_vector.parquet", engine='pyarrow')

In [65]:
# Run configuration: start from the NRMS hyperparameter container and override fields below.
# NOTE(review): `get_args` and `hparams_nrms` are not defined in the visible cells — confirm where they come from.
args = get_args()
hparams = hparams_nrms

# NOTE(review): the DEFAULT_*_COL names used here are not defined in the visible cells — confirm their import.
TEXT_COLUMNS_TO_USE = [DEFAULT_TITLE_COL, DEFAULT_SUBTITLE_COL, DEFAULT_BODY_COL]

TRANSFORMER_MODEL_NAME = args.transformer_model_name
MAX_TITLE_LENGTH = 300
# `title_size` is used downstream as the target length of each article vector
# (see map_article_features_with_word2vec), so it is kept equal to `word_emb_dim`.
hparams.title_size = 300
hparams.word_emb_dim = 300
hparams.history_size = 5
# head_num * head_dim = 4 * 75 = 300, i.e. the same width as `word_emb_dim`.
hparams.head_num = 4
hparams.head_dim = 75
hparams.attention_hidden_dim = args.attention_hidden_dim
hparams.optimizer = "adam"
hparams.loss = "log_loss"
hparams.dropout = args.dropout
hparams.learning_rate = 0.001
# Also the default number of negatives drawn per click in generate_samples.
hparams.num_candidate_news = 10


# None disables the extra dense layers; the commented list is an example of a non-empty setting.
hparams.newsencoder_units_per_layer = None  # [300, 300, 300]

In [None]:
def generate_samples(row, k=None, rng=None):
    """Build (user_id, article_id, label) training tuples for one impression row.

    Every clicked article yields one positive tuple (label 1). For every
    clicked article, up to ``k`` negatives (label 0) are drawn without
    replacement from the in-view articles that were not clicked.

    Args:
        row: Mapping-like row exposing 'user_id', 'article_ids_clicked' and
            'article_ids_inview'.
        k: Number of negatives per clicked article. When None, the current
            value of ``hparams.num_candidate_news`` is read at call time, so a
            changed configuration is picked up without re-defining the function.
        rng: Optional object with a ``sample(population, n)`` method (for
            example ``random.Random(seed)``) for reproducible sampling.
            Defaults to the global ``random`` module.

    Returns:
        list[tuple]: Positives first (in click order), followed by negatives.
    """
    # Local import: `random` is not imported by the notebook's import cell,
    # which made the original function fail with NameError on a fresh kernel.
    import random

    if k is None:
        k = hparams.num_candidate_news
    sampler = random if rng is None else rng

    user_id = row['user_id']
    clicked = list(row['article_ids_clicked'])
    # Sorted so the candidate pool order (and hence seeded sampling) is deterministic.
    inview = sorted(set(row['article_ids_inview']) - set(clicked))

    # Positive samples (clicked)
    samples = [(user_id, article_id, 1) for article_id in clicked]

    # Negative samples: one independent draw per clicked article
    if inview:
        n_samples = min(k, len(inview))
        for _ in clicked:
            negatives = sampler.sample(inview, n_samples)
            samples.extend((user_id, neg_id, 0) for neg_id in negatives)

    return samples

# Work on a fixed-seed subset of 500 behavior rows to keep the example small.
sample_df = train_behaviors.sample(n=500, random_state=42)

# One list of (user_id, article_id, label) tuples per sampled row.
sample_df['samples'] = sample_df.apply(generate_samples, axis=1)

# Flatten the per-row lists into a single list and turn it into the training frame.
samples = []
for row_samples in sample_df['samples']:
    samples.extend(row_samples)
train_df = pd.DataFrame(samples, columns=['user_id', 'article_id', 'label'])

In [36]:
def map_article_features_with_word2vec(articles_df, word2vec_file, hparams):
    """
    Map article IDs to fixed-length document vectors taken from word2vec_file.

    Args:
        articles_df (DataFrame): Articles to cover; only the 'article_id' column is read.
        word2vec_file (DataFrame): Frame with 'article_id' and 'document_vector' columns.
        hparams: Object exposing the attributes ``title_size`` (target vector
            length) and ``word_emb_dim`` (length of the zero vector used for
            articles without a document vector).

    Returns:
        dict: article_id -> 1-D array of length ``hparams.title_size``. Shorter
        vectors are right-padded with zeros, longer ones are truncated.
    """
    target_len = hparams.title_size

    # Iterate the two columns directly instead of DataFrame.iterrows():
    # iterrows builds a Series per row (slow) and does not preserve dtypes
    # across the row, which can turn integer ids into floats.
    article_embeddings = {
        article_id: np.array(vector)
        for article_id, vector in zip(
            word2vec_file['article_id'].tolist(),
            word2vec_file['document_vector'].tolist(),
        )
    }

    article_features = {}
    for article_id in articles_df['article_id'].tolist():
        embedding = article_embeddings.get(article_id)
        if embedding is None:
            # Only allocate the fallback vector when it is actually needed.
            embedding = np.zeros(hparams.word_emb_dim)

        current_len = len(embedding)
        if current_len < target_len:
            # Right-pad with zeros up to the target length.
            embedding = np.pad(
                embedding,
                (0, target_len - current_len),
                mode='constant',
                constant_values=0
            )
        else:
            embedding = embedding[:target_len]

        article_features[article_id] = embedding

    return article_features

In [None]:
# Flatten the per-user list columns so that every history entry gets its own row.
_history_list_columns = [
    "impression_time_fixed",
    "scroll_percentage_fixed",
    "article_id_fixed",
    "read_time_fixed",
]
expanded_history = train_history.explode(_history_list_columns, ignore_index=True)

# Parse the timestamps so the sort below orders them chronologically.
expanded_history["impression_time_fixed"] = pd.to_datetime(expanded_history["impression_time_fixed"])

# Within each user, newest impression first.
expanded_history = expanded_history.sort_values(
    by=["user_id", "impression_time_fixed"],
    ascending=[True, False],
)

# Keep the `history_size` most recent rows per user, then gather their article ids into one list per user.
_recent_rows = expanded_history.groupby("user_id").head(hparams.history_size)
truncated_history = _recent_rows.groupby("user_id")["article_id_fixed"].apply(list).reset_index()


def _pad_history(article_ids):
    """Right-pad a history list with 0 up to `hparams.history_size` entries."""
    missing = hparams.history_size - len(article_ids)
    return article_ids + [0] * missing


truncated_history["article_id_fixed"] = truncated_history["article_id_fixed"].apply(_pad_history)


# Narrow views of the raw frames used by the following cells.
behaviors_df = train_behaviors[["user_id","article_ids_inview","article_ids_clicked"]]

articles_df = articles_[["article_id","title"]]

In [None]:
# Get article embeddings
article_embeddings = map_article_features_with_word2vec(articles_df, word2vec_file, hparams)
train_df['title_features'] = train_df['article_id'].map(article_embeddings)

def get_user_embedding(user_id, history, hparams):
    """Return the list of article vectors representing one user's click history.

    Args:
        user_id: Identifier of the user. Not used by the lookup itself; kept
            so existing callers keep working.
        history: Iterable of article ids. Ids missing from the module-level
            ``article_embeddings`` mapping are replaced by a zero vector of
            length ``hparams.title_size``.
        hparams: Object exposing ``title_size`` and ``history_size``.

    Returns:
        list: Exactly ``hparams.history_size`` vectors. Shorter histories are
        left-padded with zero vectors, longer ones are cut to the first
        ``history_size`` entries.
    """
    embeddings = [article_embeddings.get(article_id, np.zeros(hparams.title_size)) for article_id in history]
    if len(embeddings) < hparams.history_size:
        padding = [np.zeros(hparams.title_size)] * (hparams.history_size - len(embeddings))
        embeddings = padding + embeddings
    elif len(embeddings) > hparams.history_size:
        embeddings = embeddings[:hparams.history_size]
    return embeddings

# Build the user -> history lookup once. The previous per-row boolean-mask scan of
# `truncated_history` cost O(rows x users) and raised IndexError for any user
# that has no history entry.
history_by_user = dict(zip(truncated_history['user_id'], truncated_history['article_id_fixed']))

# Get user history embeddings; users without history fall back to an empty
# history, which get_user_embedding pads to all-zero vectors.
train_df['user_history_embeddings'] = train_df['user_id'].apply(
    lambda uid: get_user_embedding(uid, history_by_user.get(uid, []), hparams)
)


In [56]:
# Display the training frame: one row per (user_id, article_id) sample with its label and features.
train_df

Unnamed: 0,user_id,article_id,label,title_features,user_history_embeddings
0,956287,9769432,1,"[0.027196923, -0.006216835, 0.037383795, 0.046...","[[0.028511511, 0.007786474, 0.009804755, 0.033..."
1,956287,9771235,0,"[0.036877304, -0.020704815, 0.005635035, 0.026...","[[0.028511511, 0.007786474, 0.009804755, 0.033..."
2,956287,9771113,0,"[0.021741945, -0.060298197, 0.090856925, 0.052...","[[0.028511511, 0.007786474, 0.009804755, 0.033..."
3,956287,9176912,0,"[0.04941954, -0.021159602, 0.016012726, 0.0304...","[[0.028511511, 0.007786474, 0.009804755, 0.033..."
4,956287,9120051,0,"[0.01949257, 0.019124068, 0.03619889, 0.067475...","[[0.028511511, 0.007786474, 0.009804755, 0.033..."
...,...,...,...,...,...
4287,1880227,9775388,0,"[0.019473838, 0.0071321866, 0.022165135, 0.035...","[[0.013565401, -0.027250506, 0.05643654, 0.032..."
4288,1880227,9757869,0,"[0.013565401, -0.027250506, 0.05643654, 0.0320...","[[0.013565401, -0.027250506, 0.05643654, 0.032..."
4289,1880227,9775776,0,"[0.009557325, -0.046068132, 0.04188653, -0.012...","[[0.013565401, -0.027250506, 0.05643654, 0.032..."
4290,1880227,9649171,0,"[-0.010361439, -0.025747402, 0.076410465, -0.0...","[[0.013565401, -0.027250506, 0.05643654, 0.032..."


In [None]:
# Move to polars to derive the publication hour and attach it to every training row.
articles_pl = pl.from_pandas(articles_)

# Hour-of-day of `published_time`, stored under a new column name.
_published_hour_expr = pl.col("published_time").dt.hour().alias("published_hour")
articles_selected = articles_pl.select(["article_id", "published_time"]).with_columns(_published_hour_expr)

# Left join keeps every training row even if the article has no match.
train_df_pl = pl.from_pandas(train_df)
train_df_with_time = train_df_pl.join(articles_selected, how="left", on="article_id")

# Back to pandas for the Keras input preparation.
train_df_with_time_pd = train_df_with_time.to_pandas()
train_df_with_time_pd.head()

Unnamed: 0,user_id,article_id,label,title_features,user_history_embeddings,published_time,published_hour
0,956287,9769432,1,"[0.027196923, -0.006216835, 0.037383795, 0.046...","[[0.028511511, 0.007786474, 0.009804755, 0.033...",2023-05-18 13:47:15,13
1,956287,9771235,0,"[0.036877304, -0.020704815, 0.005635035, 0.026...","[[0.028511511, 0.007786474, 0.009804755, 0.033...",2023-05-18 15:59:53,15
2,956287,9771113,0,"[0.021741945, -0.060298197, 0.090856925, 0.052...","[[0.028511511, 0.007786474, 0.009804755, 0.033...",2023-05-18 14:11:10,14
3,956287,9176912,0,"[0.04941954, -0.021159602, 0.016012726, 0.0304...","[[0.028511511, 0.007786474, 0.009804755, 0.033...",2022-03-22 10:14:33,10
4,956287,9120051,0,"[0.01949257, 0.019124068, 0.03619889, 0.067475...","[[0.028511511, 0.007786474, 0.009804755, 0.033...",2022-02-11 05:49:31,5


In [66]:
# Column-name constants referenced by the dataloader and column-selection cells.
# NOTE(review): the history column is mapped to the in-view column name
# ("article_ids_inview") — confirm this is intended rather than a per-user history column.
DEFAULT_HISTORY_ARTICLE_ID_COL = "article_ids_inview"
DEFAULT_CLICKED_ARTICLES_COL = "article_ids_clicked"
# NOTE(review): the impression timestamp is mapped to "published_time" — verify against the behaviors schema.
DEFAULT_IMPRESSION_TIMESTAMP_COL = "published_time"
DEFAULT_USER_COL = "user_id"
DEFAULT_LABELS_COL = "labels"
NPRATIO = 5  # number of negative samples
SEED = 42

In [None]:
print("Initiating training-dataloader")

# Both loaders share every setting except the behaviors frame they read from.
_loader_settings = dict(
    article_dict=article_mapping,
    unknown_representation="zeros",
    history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    eval_mode=False,
    batch_size=BS_TRAIN,
)

train_dataloader = NRMSLoader_training(behaviors=df_train, **_loader_settings)

val_dataloader = NRMSLoader_training(behaviors=df_validation, **_loader_settings)

In [67]:
# Dump paths: everything produced by this run is written below DUMP_DIR.
DUMP_DIR = Path("ebnerd_predictions")
DUMP_DIR.mkdir(exist_ok=True, parents=True)
# Timestamp used to make the output folder names unique per run.
DT_NOW = dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # Sanitize the timestamp
# NOTE(review): `model_func` is not defined in the visible cells — confirm where it is set.
MODEL_NAME = model_func.__name__
MODEL_OUTPUT_NAME = f"{MODEL_NAME}-{DT_NOW}"
# Folder for the test-set prediction artifacts of this run.
ARTIFACT_DIR = DUMP_DIR.joinpath("test_predictions", MODEL_OUTPUT_NAME)
# Model monitoring: checkpoint weights and TensorBoard logs.
MODEL_WEIGHTS = DUMP_DIR.joinpath(f"state_dict/{MODEL_OUTPUT_NAME}/weights")
LOG_DIR = DUMP_DIR.joinpath(f"runs/{MODEL_OUTPUT_NAME}")
# Evaluating the test set can be memory intensive, so it is processed in chunks:
TEST_CHUNKS_DIR = ARTIFACT_DIR.joinpath("test_chunks")
TEST_CHUNKS_DIR.mkdir(parents=True, exist_ok=True)
N_CHUNKS_TEST = args.n_chunks_test
CHUNKS_DONE = args.chunks_done  # if it crashes, you can start from here.
# Keep the dataframe slim: only these columns are selected.
# NOTE(review): DEFAULT_INVIEW_ARTICLES_COL and DEFAULT_IMPRESSION_ID_COL are not defined in the visible cells.
COLUMNS = [
    DEFAULT_IMPRESSION_TIMESTAMP_COL,
    DEFAULT_HISTORY_ARTICLE_ID_COL,
    DEFAULT_INVIEW_ARTICLES_COL,
    DEFAULT_CLICKED_ARTICLES_COL,
    DEFAULT_IMPRESSION_ID_COL,
    DEFAULT_USER_COL,
]


# Keras callbacks for the training run. The three monitoring callbacks track
# "val_auc" and treat higher values as better (mode="max").
# NOTE(review): "val_auc" presumably comes from compiling the model with metrics=["AUC"] — confirm the metric name.

# TensorBoard logging, with weight histograms written every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=LOG_DIR,
    histogram_freq=1,
)
# Stop after 4 epochs without improvement and roll back to the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_auc",
    mode="max",
    patience=4,
    restore_best_weights=True,
)
# Persist only the best weights seen so far.
modelcheckpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=MODEL_WEIGHTS,
    monitor="val_auc",
    mode="max",
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)
# Multiply the learning rate by 0.2 after 2 stagnant epochs, never going below 1e-6.
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_auc",
    mode="max",
    factor=0.2,
    patience=2,
    min_lr=1e-6,
)

# Single definition of the callback list (it was previously assigned twice in a row).
callbacks = [tensorboard_callback, early_stopping, modelcheckpoint, lr_scheduler]

# Instantiate the model wrapper selected by `model_func`.
# NOTE(review): `model_func` is not defined in the visible cells — confirm which class it refers to.
model = model_func(
    hparams=hparams,
    seed=42,
)
# Re-compile the wrapped Keras model with its existing optimizer and loss,
# adding AUC as a tracked metric.
model.model.compile(
    optimizer=model.model.optimizer,
    loss=model.model.loss,
    metrics=["AUC"],
)

In [69]:
# Print the layer-by-layer summary of the wrapped Keras model.
model.model.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_7 (InputLayer)           [(None, None, 300)]  0           []                               
                                                                                                  
 input_6 (InputLayer)           [(None, 5, 300)]     0           []                               
                                                                                                  
 time_distributed_3 (TimeDistri  (None, None, 300)   9930400     ['input_7[0][0]']                
 buted)                                                                                           
                                                                                                  
 user_encoder (Functional)      (None, 300)          10260800    ['input_6[0][0]']          

In [92]:

class NewsRecommender(Model):
    """Feed-forward click predictor over a candidate vector and a user-history vector.

    The two inputs are concatenated, passed through a stack of
    Dense(relu) -> BatchNormalization -> Dropout groups (one group per entry of
    ``hidden_dims``) and mapped to a single sigmoid output.
    """

    def __init__(self, hidden_dims=[256, 128], dropout_rate=0.2):
        """Create the hidden stack and the sigmoid output layer.

        Args:
            hidden_dims: Width of each Dense layer in the hidden stack.
            dropout_rate: Dropout rate applied after every hidden group.
        """
        super().__init__()

        stack = []
        for width in hidden_dims:
            stack.append(Dense(width, activation='relu'))
            stack.append(BatchNormalization())
            stack.append(Dropout(dropout_rate))
        self.hidden_layers = stack

        self.final_layer = Dense(1, activation='sigmoid')

    def call(self, inputs):
        """Score a batch.

        Args:
            inputs: Pair ``(title_features, user_history)``. A rank-3
                ``user_history`` is averaged over axis 1 before use.

        Returns:
            Output of the final sigmoid Dense layer.
        """
        candidate_vec, history = inputs

        # Collapse the history axis when one vector per history item is given.
        if len(history.shape) == 3:
            history = tf.reduce_mean(history, axis=1)

        hidden = tf.concat([candidate_vec, history], axis=-1)
        for block in self.hidden_layers:
            hidden = block(hidden)

        return self.final_layer(hidden)

def prepare_data(df):
    """Turn the training frame into model inputs and labels.

    Args:
        df: Frame with 'title_features', 'user_history_embeddings' and 'label' columns.

    Returns:
        tuple: ``([title_features, user_history], labels)`` where both features
        are the column cells stacked along a new first axis and ``labels`` is
        the raw 'label' column as an array.
    """
    feature_columns = ('title_features', 'user_history_embeddings')
    stacked = {column: np.stack(df[column].values) for column in feature_columns}
    features = [stacked[column] for column in feature_columns]
    return features, df['label'].values

# Model parameters
HIDDEN_DIMS = [256, 128]
DROPOUT_RATE = 0.2
LEARNING_RATE = 0.001
BATCH_SIZE = 32
EPOCHS = 10

# Create train/validation split: positional 80/20 cut of train_df (first 80% of rows
# train, remainder validation) — the rows are not shuffled before the cut.
train_size = int(0.8 * len(train_df))
train_data = train_df.iloc[:train_size]
val_data = train_df.iloc[train_size:]

# Prepare training and validation data
X_train, y_train = prepare_data(train_data)
X_val, y_val = prepare_data(val_data)

# Initialize model (this rebinds the name `model`, which an earlier cell used for the NRMS wrapper)
model = NewsRecommender(hidden_dims=HIDDEN_DIMS, dropout_rate=DROPOUT_RATE)

# Compile model: binary cross-entropy on the sigmoid output, tracking accuracy and AUC
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC()]
)

# Train model: stop after 3 epochs without val_loss improvement (restoring the best
# weights) and halve the learning rate after 2 stagnant epochs
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_val, y_val),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=3,
            restore_best_weights=True
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=2
        )
    ]
)

# Print final metrics: evaluate() is unpacked as loss followed by the compiled metrics (accuracy, AUC)
final_train_loss, final_train_acc, final_train_auc = model.evaluate(X_train, y_train, verbose=0)
final_val_loss, final_val_acc, final_val_auc = model.evaluate(X_val, y_val, verbose=0)

print("\nFinal Results:")
print(f"Train - Loss: {final_train_loss:.4f}, Accuracy: {final_train_acc:.4f}, AUC: {final_train_auc:.4f}")
print(f"Val   - Loss: {final_val_loss:.4f}, Accuracy: {final_val_acc:.4f}, AUC: {final_val_auc:.4f}")

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10

Final Results:
Train - Loss: 0.3457, Accuracy: 0.8820, AUC: 0.7687
Val   - Loss: 0.3531, Accuracy: 0.8871, AUC: 0.5259


In [82]:
# Display the training frame again (same columns as before the time features were joined).
train_df

Unnamed: 0,user_id,article_id,label,title_features,user_history_embeddings
0,956287,9769432,1,"[0.027196923, -0.006216835, 0.037383795, 0.046...","[[0.028511511, 0.007786474, 0.009804755, 0.033..."
1,956287,9771235,0,"[0.036877304, -0.020704815, 0.005635035, 0.026...","[[0.028511511, 0.007786474, 0.009804755, 0.033..."
2,956287,9771113,0,"[0.021741945, -0.060298197, 0.090856925, 0.052...","[[0.028511511, 0.007786474, 0.009804755, 0.033..."
3,956287,9176912,0,"[0.04941954, -0.021159602, 0.016012726, 0.0304...","[[0.028511511, 0.007786474, 0.009804755, 0.033..."
4,956287,9120051,0,"[0.01949257, 0.019124068, 0.03619889, 0.067475...","[[0.028511511, 0.007786474, 0.009804755, 0.033..."
...,...,...,...,...,...
4287,1880227,9775388,0,"[0.019473838, 0.0071321866, 0.022165135, 0.035...","[[0.013565401, -0.027250506, 0.05643654, 0.032..."
4288,1880227,9757869,0,"[0.013565401, -0.027250506, 0.05643654, 0.0320...","[[0.013565401, -0.027250506, 0.05643654, 0.032..."
4289,1880227,9775776,0,"[0.009557325, -0.046068132, 0.04188653, -0.012...","[[0.013565401, -0.027250506, 0.05643654, 0.032..."
4290,1880227,9649171,0,"[-0.010361439, -0.025747402, 0.076410465, -0.0...","[[0.013565401, -0.027250506, 0.05643654, 0.032..."


In [95]:
# Display the training frame extended with `published_time` and `published_hour`.
train_df_with_time_pd 

Unnamed: 0,user_id,article_id,label,title_features,user_history_embeddings,published_time,published_hour
0,956287,9769432,1,"[0.027196923, -0.006216835, 0.037383795, 0.046...","[[0.028511511, 0.007786474, 0.009804755, 0.033...",2023-05-18 13:47:15,13
1,956287,9771235,0,"[0.036877304, -0.020704815, 0.005635035, 0.026...","[[0.028511511, 0.007786474, 0.009804755, 0.033...",2023-05-18 15:59:53,15
2,956287,9771113,0,"[0.021741945, -0.060298197, 0.090856925, 0.052...","[[0.028511511, 0.007786474, 0.009804755, 0.033...",2023-05-18 14:11:10,14
3,956287,9176912,0,"[0.04941954, -0.021159602, 0.016012726, 0.0304...","[[0.028511511, 0.007786474, 0.009804755, 0.033...",2022-03-22 10:14:33,10
4,956287,9120051,0,"[0.01949257, 0.019124068, 0.03619889, 0.067475...","[[0.028511511, 0.007786474, 0.009804755, 0.033...",2022-02-11 05:49:31,5
...,...,...,...,...,...,...,...
4287,1880227,9775388,0,"[0.019473838, 0.0071321866, 0.022165135, 0.035...","[[0.013565401, -0.027250506, 0.05643654, 0.032...",2023-05-22 09:13:11,9
4288,1880227,9757869,0,"[0.013565401, -0.027250506, 0.05643654, 0.0320...","[[0.013565401, -0.027250506, 0.05643654, 0.032...",2023-05-17 10:18:09,10
4289,1880227,9775776,0,"[0.009557325, -0.046068132, 0.04188653, -0.012...","[[0.013565401, -0.027250506, 0.05643654, 0.032...",2023-05-22 09:13:46,9
4290,1880227,9649171,0,"[-0.010361439, -0.025747402, 0.076410465, -0.0...","[[0.013565401, -0.027250506, 0.05643654, 0.032...",2023-02-24 10:19:38,10


In [98]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Input, Concatenate, Embedding
import pandas as pd
import ast

class NewsRecommenderWithTime(Model):
    """Feed-forward click predictor that also uses the article's publication hour.

    The hour is looked up in a learned 24 x 8 embedding table and concatenated
    with the candidate vector and the user-history vector before the hidden
    stack of Dense(relu) -> BatchNormalization -> Dropout groups.
    """

    def __init__(self, hidden_dims=[256, 128], dropout_rate=0.2):
        """Create the hour embedding, the hidden stack and the sigmoid output layer.

        Args:
            hidden_dims: Width of each Dense layer in the hidden stack.
            dropout_rate: Dropout rate applied after every hidden group.
        """
        super().__init__()

        # 24 possible hours, each mapped to an 8-dimensional vector.
        self.hour_embedding = Embedding(24, 8)

        stack = []
        for width in hidden_dims:
            stack.append(Dense(width, activation='relu'))
            stack.append(BatchNormalization())
            stack.append(Dropout(dropout_rate))
        self.hidden_layers = stack

        self.final_layer = Dense(1, activation='sigmoid')

    def call(self, inputs):
        """Score a batch.

        Args:
            inputs: Triple ``(title_features, user_history, hours)``. The
                embedded hours are squeezed on axis 1, so ``hours`` must have a
                size-1 second axis. A rank-3 ``user_history`` is averaged over axis 1.

        Returns:
            Output of the final sigmoid Dense layer.
        """
        candidate_vec, history, hours = inputs

        # Embed the hour and drop the singleton axis left by the lookup.
        hour_vec = tf.squeeze(self.hour_embedding(hours), axis=1)

        # Collapse the history axis when one vector per history item is given.
        if len(history.shape) == 3:
            history = tf.reduce_mean(history, axis=1)

        hidden = tf.concat([candidate_vec, history, hour_vec], axis=-1)
        for block in self.hidden_layers:
            hidden = block(hidden)

        return self.final_layer(hidden)

def prepare_data(df):
    """Convert the time-augmented training frame into tensors for the model.

    Cells of 'title_features' and 'user_history_embeddings' may be stored as
    string literals; those are parsed with ``ast.literal_eval`` first. Each
    user history is reduced to its mean vector over the first axis.

    Args:
        df: Frame with 'title_features', 'user_history_embeddings',
            'published_hour' and 'label' columns.

    Returns:
        tuple: ``([title_features, user_histories, hours], labels)`` as
        tensors; ``hours`` is int32 with one column, the rest are float32.
    """

    def _parsed(cell):
        # String cells hold the literal representation of the nested lists.
        return ast.literal_eval(cell) if isinstance(cell, str) else cell

    # Candidate article vectors, one float32 row per sample.
    title_rows = [
        np.array(_parsed(cell), dtype=np.float32)
        for cell in df['title_features'].values
    ]
    title_features = np.array(title_rows)

    # Mean of the history vectors per sample.
    history_means = []
    for cell in df['user_history_embeddings'].values:
        vectors = np.array([np.array(vec, dtype=np.float32) for vec in _parsed(cell)])
        history_means.append(np.mean(vectors, axis=0))
    user_histories = np.array(history_means, dtype=np.float32)

    # Hours as a column vector; labels as floats.
    hours = df['published_hour'].values.astype(np.int32).reshape(-1, 1)
    labels = df['label'].values.astype(np.float32)

    features = [
        tf.convert_to_tensor(title_features, dtype=tf.float32),
        tf.convert_to_tensor(user_histories, dtype=tf.float32),
        tf.convert_to_tensor(hours, dtype=tf.int32),
    ]
    return features, tf.convert_to_tensor(labels, dtype=tf.float32)

# Model parameters
HIDDEN_DIMS = [256, 128]
DROPOUT_RATE = 0.2
LEARNING_RATE = 0.001
BATCH_SIZE = 32
EPOCHS = 10

print("Preparing data splits...")
# Create train/validation split: positional 80/20 cut of the time-augmented frame
# (first 80% of rows train, remainder validation) — the rows are not shuffled before the cut.
train_size = int(0.8 * len(train_df_with_time_pd))
train_data = train_df_with_time_pd.iloc[:train_size]
val_data = train_df_with_time_pd.iloc[train_size:]

print("Processing training data...")
X_train, y_train = prepare_data(train_data)
print("Processing validation data...")
X_val, y_val = prepare_data(val_data)

print("Initializing model...")
# Initialize model (rebinds the name `model` used by earlier cells)
model = NewsRecommenderWithTime(hidden_dims=HIDDEN_DIMS, dropout_rate=DROPOUT_RATE)

# Compile model: binary cross-entropy on the sigmoid output, tracking accuracy and AUC
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC()]
)

# Train model: stop after 3 epochs without val_loss improvement (restoring the best
# weights) and halve the learning rate after 2 stagnant epochs
print("\nStarting model training...")
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_val, y_val),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=3,
            restore_best_weights=True
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=2
        )
    ]
)

# Print final metrics: evaluate() is unpacked as loss followed by the compiled metrics (accuracy, AUC)
final_train_loss, final_train_acc, final_train_auc = model.evaluate(X_train, y_train, verbose=0)
final_val_loss, final_val_acc, final_val_auc = model.evaluate(X_val, y_val, verbose=0)

print("\nFinal Results:")
print(f"Train - Loss: {final_train_loss:.4f}, Accuracy: {final_train_acc:.4f}, AUC: {final_train_auc:.4f}")
print(f"Val   - Loss: {final_val_loss:.4f}, Accuracy: {final_val_acc:.4f}, AUC: {final_val_auc:.4f}")

Preparing data splits...
Processing training data...
Processing validation data...
Initializing model...

Starting model training...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10

Final Results:
Train - Loss: 0.3296, Accuracy: 0.8820, AUC: 0.8173
Val   - Loss: 0.3476, Accuracy: 0.8871, AUC: 0.5970


In [53]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Embedding, Input, Dropout, Dense, BatchNormalization
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.regularizers import l2

class NRMSModel:
    """NRMS model (Neural News Recommendation with Multi-Head Self-Attention).

    Builds two Keras models that share a news encoder and a user encoder:
    ``self.model`` scores a variable number of candidates with a softmax, and
    ``self.scorer`` scores a single candidate with a sigmoid.

    ``hparams`` may be a dict or an object with attributes; both styles are
    read through ``_hparam``.

    NOTE(review): ``SelfAttention`` and ``AttLayer2`` are not defined in the
    visible notebook cells — confirm they are imported before this class is used.
    """

    def __init__(self, hparams: dict, word2vec_embedding: np.ndarray = None, seed: int = None):
        """Seed the RNGs, set up the word embeddings, then build and compile the model.

        Args:
            hparams: Hyperparameters (dict or attribute container). Reads
                'vocab_size' and 'word_emb_dim' (only when no embedding matrix
                is given), 'loss', 'optimizer', 'learning_rate', plus the keys
                used by the encoder builders.
            word2vec_embedding: Optional pre-trained embedding matrix. When
                None, a Glorot-uniform matrix of shape
                (vocab_size, word_emb_dim) is created.
            seed: Seed passed to TensorFlow, NumPy and the initializers.
        """
        self.hparams = hparams
        self.seed = seed

        # Set seed for reproducibility
        tf.random.set_seed(seed)
        np.random.seed(seed)

        # Initialize the word embeddings
        if word2vec_embedding is None:
            initializer = GlorotUniform(seed=self.seed)
            self.word2vec_embedding = initializer(
                shape=(self._hparam('vocab_size'), self._hparam('word_emb_dim'))
            )
        else:
            self.word2vec_embedding = word2vec_embedding

        # Build and compile the model
        self.model, self.scorer = self._build_graph()
        data_loss = self._get_loss(self._hparam('loss'))
        train_optimizer = self._get_opt(
            optimizer=self._hparam('optimizer'), lr=self._hparam('learning_rate')
        )
        self.model.compile(loss=data_loss, optimizer=train_optimizer)

    def _hparam(self, name: str):
        """Read one hyperparameter from either a dict-style or attribute-style container.

        Subscripting is tried first, so a missing key in a dict still raises
        KeyError. Containers that are not subscriptable fall back to attribute access.
        """
        try:
            return self.hparams[name]
        except TypeError:
            return getattr(self.hparams, name)

    def _get_loss(self, loss: str):
        """Translate the configured loss name into the Keras loss identifier.

        Raises:
            ValueError: If ``loss`` is neither "cross_entropy_loss" nor "log_loss".
        """
        if loss == "cross_entropy_loss":
            return "categorical_crossentropy"
        elif loss == "log_loss":
            return "binary_crossentropy"
        else:
            raise ValueError(f"this loss not defined {loss}")

    def _get_opt(self, optimizer: str, lr: float):
        """Create the optimizer.

        Raises:
            ValueError: If ``optimizer`` is not "adam".
        """
        if optimizer == "adam":
            return tf.keras.optimizers.Adam(learning_rate=lr)
        else:
            raise ValueError(f"this optimizer not defined {optimizer}")

    def _build_graph(self):
        """Build NRMS model and scorer"""
        model, scorer = self._build_nrms()
        return model, scorer

    def _build_userencoder(self):
        """Create the user encoder.

        Applies ``self.newsencoder`` to every history item, then self-attention
        and attention pooling. ``self.newsencoder`` must exist before this is called.
        """
        history_size = self._hparam('history_size')
        title_size = self._hparam('title_size')

        his_input_title = tf.keras.Input(shape=(history_size, title_size), dtype="int32")
        click_title_presents = tf.keras.layers.TimeDistributed(self.newsencoder)(his_input_title)
        y = SelfAttention(self._hparam('head_num'), self._hparam('head_dim'), seed=self.seed)(
            [click_title_presents] * 3
        )
        user_present = AttLayer2(self._hparam('attention_hidden_dim'), seed=self.seed)(y)
        model = tf.keras.Model(his_input_title, user_present, name="user_encoder")
        return model

    def _build_newsencoder(self):
        """Create the news encoder.

        Embedding -> dropout -> self-attention -> optional stack of
        Dense(relu)/BatchNormalization/Dropout layers -> attention pooling.
        """
        dropout = self._hparam('dropout')

        embedding_layer = tf.keras.layers.Embedding(
            self.word2vec_embedding.shape[0],
            self.word2vec_embedding.shape[1],
            weights=[self.word2vec_embedding],
            trainable=True,
        )
        sequences_input_title = tf.keras.Input(shape=(self._hparam('title_size'),), dtype="int32")
        embedded_sequences_title = embedding_layer(sequences_input_title)
        y = tf.keras.layers.Dropout(dropout)(embedded_sequences_title)
        y = SelfAttention(self._hparam('head_num'), self._hparam('head_dim'), seed=self.seed)([y, y, y])

        # `newsencoder_units_per_layer` may be None (as in the notebook's config cell),
        # meaning "no extra dense layers"; iterating None would raise TypeError.
        units_per_layer = self._hparam('newsencoder_units_per_layer')
        if units_per_layer is None:
            units_per_layer = []
        for layer_size in units_per_layer:
            y = tf.keras.layers.Dense(units=layer_size, activation="relu")(y)
            y = tf.keras.layers.BatchNormalization()(y)
            y = tf.keras.layers.Dropout(dropout)(y)

        pred_title = AttLayer2(self._hparam('attention_hidden_dim'), seed=self.seed)(y)
        model = tf.keras.Model(sequences_input_title, pred_title, name="news_encoder")
        return model

    def _build_nrms(self):
        """Create the core NRMS model with user and news encoders.

        Returns:
            tuple: ``(model, scorer)`` — the softmax multi-candidate model and
            the sigmoid single-candidate scorer.
        """
        history_size = self._hparam('history_size')
        title_size = self._hparam('title_size')

        his_input_title = tf.keras.Input(shape=(history_size, title_size), dtype="int32")
        pred_input_title = tf.keras.Input(shape=(None, title_size), dtype="int32")
        pred_input_title_one = tf.keras.Input(shape=(1, title_size), dtype="int32")
        pred_title_one_reshape = tf.keras.layers.Reshape((title_size,))(pred_input_title_one)

        # The news encoder has to be built first: the user encoder wraps
        # `self.newsencoder`, so the reverse order fails with AttributeError.
        self.newsencoder = self._build_newsencoder()
        self.userencoder = self._build_userencoder()

        user_present = self.userencoder(his_input_title)
        news_present = tf.keras.layers.TimeDistributed(self.newsencoder)(pred_input_title)
        news_present_one = self.newsencoder(pred_title_one_reshape)

        preds = tf.keras.layers.Dot(axes=-1)([news_present, user_present])
        preds = tf.keras.layers.Activation(activation="softmax")(preds)

        pred_one = tf.keras.layers.Dot(axes=-1)([news_present_one, user_present])
        pred_one = tf.keras.layers.Activation(activation="sigmoid")(pred_one)

        model = tf.keras.Model([his_input_title, pred_input_title], preds)
        scorer = tf.keras.Model([his_input_title, pred_input_title_one], pred_one)

        return model, scorer


In [1]:
def get_transformers_word_embeddings(model):
    """
    Return the input word-embedding matrix of a transformer model.

    Args:
        model: Object exposing ``get_input_embeddings()``; the returned layer
            must expose ``weights``, whose first entry supports ``.numpy()``.

    Returns:
        The result of calling ``.numpy()`` on the first weight of the input
        embedding layer.
    """
    first_weight = model.get_input_embeddings().weights[0]
    return first_weight.numpy()

  from .autonotebook import tqdm as notebook_tqdm


In [47]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import tensorflow as tf
import numpy as np

from tensorflow.keras.layers import Embedding, Input, Dropout, Dense, BatchNormalization
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.regularizers import l2


#ebrec/models/newsrec/nrms.py

class NRMSModel:
    """NRMS model(Neural News Recommendation with Multi-Head Self-Attention)

    Chuhan Wu, Fangzhao Wu, Suyu Ge, Tao Qi, Yongfeng Huang,and Xing Xie, "Neural News
    Recommendation with Multi-Head Self-Attention" in Proceedings of the 2019 Conference
    on Empirical Methods in Natural Language Processing and the 9th International Joint Conference
    on Natural Language Processing (EMNLP-IJCNLP)

    Attributes:
        hparams: Hyper-parameter container passed to the constructor. It is read
            through attribute access (``hparams.title_size``, ``hparams.loss``, ...).
        seed: Seed forwarded to the TensorFlow/NumPy RNGs and to the attention layers.
        word2vec_embedding: Matrix used to initialise the word-embedding layer.
        model: Compiled training model; takes ``[history, candidates]`` and applies
            a softmax over the candidate scores.
        scorer: Inference model; takes ``[history, single candidate]`` and applies
            a sigmoid to the score. It is not compiled here.
        userencoder: Keras sub-model that encodes the click history.
        newsencoder: Keras sub-model that encodes one tokenised article.
    """

    def __init__(
        self,
        hparams: dict,
        word2vec_embedding: np.ndarray = None,
        word_emb_dim: int = 300,
        vocab_size: int = 32000,
        seed: int = None,
    ):
        """Seed the RNGs, prepare the word embeddings, then build and compile the model.

        Args:
            hparams: Hyper-parameters. NOTE(review): annotated as ``dict`` but every
                access in this class uses attribute lookup (e.g. ``hparams.loss``),
                so a plain dict would fail; an attribute-style config object is needed.
            word2vec_embedding: Pre-trained embedding matrix. When ``None`` a
                Glorot-uniform matrix of shape ``(vocab_size, word_emb_dim)`` is created.
            word_emb_dim: Embedding width, only used when ``word2vec_embedding`` is ``None``.
            vocab_size: Vocabulary size, only used when ``word2vec_embedding`` is ``None``.
            seed: Random seed applied globally to TensorFlow and NumPy.

        Raises:
            ValueError: If ``hparams.loss`` or ``hparams.optimizer`` is not supported.
        """
        self.hparams = hparams
        self.seed = seed

        # SET SEED:
        # Both calls change global RNG state, not just this model's.
        tf.random.set_seed(seed)
        np.random.seed(seed)

        # INIT THE WORD-EMBEDDINGS:
        if word2vec_embedding is None:
            # Xavier Initialization
            initializer = GlorotUniform(seed=self.seed)
            self.word2vec_embedding = initializer(shape=(vocab_size, word_emb_dim))
            # self.word2vec_embedding = np.random.rand(vocab_size, word_emb_dim)
        else:
            self.word2vec_embedding = word2vec_embedding

        # BUILD AND COMPILE MODEL:
        # Only the training model is compiled; the scorer is used as-is.
        self.model, self.scorer = self._build_graph()
        data_loss = self._get_loss(self.hparams.loss)
        train_optimizer = self._get_opt(
            optimizer=self.hparams.optimizer, lr=self.hparams.learning_rate
        )
        self.model.compile(loss=data_loss, optimizer=train_optimizer)

    def _get_loss(self, loss: str):
        """Translate the configured loss name into a Keras loss identifier.

        Args:
            loss: Either ``"cross_entropy_loss"`` or ``"log_loss"``.

        Returns:
            str: ``"categorical_crossentropy"`` or ``"binary_crossentropy"``.

        Raises:
            ValueError: If ``loss`` is any other value.
        """
        if loss == "cross_entropy_loss":
            data_loss = "categorical_crossentropy"
        elif loss == "log_loss":
            data_loss = "binary_crossentropy"
        else:
            raise ValueError(f"this loss not defined {loss}")
        return data_loss

    def _get_opt(self, optimizer: str, lr: float):
        """Get the optimizer according to configuration. Only ``"adam"`` is supported.

        Args:
            optimizer: Name of the optimizer.
            lr: Learning rate passed to the optimizer.

        Returns:
            object: A ``tf.keras.optimizers.Adam`` instance.

        Raises:
            ValueError: If ``optimizer`` is not ``"adam"``.
        """
        # TODO: shouldn't be a string input you should just set the optimizer, to avoid stuff like this:
        # => 'WARNING:absl:At this time, the v2.11+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs, please use the legacy Keras optimizer instead, located at `tf.keras.optimizers.legacy.Adam`.'
        if optimizer == "adam":
            train_opt = tf.keras.optimizers.Adam(learning_rate=lr)
        else:
            raise ValueError(f"this optimizer not defined {optimizer}")
        return train_opt

    def _build_graph(self):
        """Build NRMS model and scorer (thin wrapper around ``_build_nrms``).

        Returns:
            object: a model used to train.
            object: a model used to evaluate and inference.
        """
        model, scorer = self._build_nrms()
        return model, scorer

    def _build_userencoder(self, titleencoder):
        """The main function to create user encoder of NRMS.

        The news encoder is applied to every article of the click history, the
        resulting sequence goes through ``SelfAttention`` and is then reduced by
        ``AttLayer2``.

        Args:
            titleencoder (object): the news encoder of NRMS.

        Return:
            object: the user encoder of NRMS (Keras model named ``"user_encoder"``).
        """
        # One row of title_size token ids per article in the click history.
        his_input_title = tf.keras.Input(
            shape=(self.hparams.history_size, self.hparams.title_size), dtype="int32"
        )

        # Apply the same news encoder to every history position.
        click_title_presents = tf.keras.layers.TimeDistributed(titleencoder)(
            his_input_title
        )
        # The same tensor is passed three times (presumably query/key/value;
        # SelfAttention is defined elsewhere - confirm against its signature).
        y = SelfAttention(self.hparams.head_num, self.hparams.head_dim, seed=self.seed)(
            [click_title_presents] * 3
        )
        user_present = AttLayer2(self.hparams.attention_hidden_dim, seed=self.seed)(y)

        model = tf.keras.Model(his_input_title, user_present, name="user_encoder")
        return model

    def _build_newsencoder(self, units_per_layer: list[int] = None):
        """The main function to create news encoder of NRMS.

        Pipeline: trainable embedding (initialised from ``self.word2vec_embedding``)
        -> dropout -> ``SelfAttention`` -> optional Dense stack (or a single dropout)
        -> ``AttLayer2``.

        Args:
            units_per_layer (list[int], optional): Width of each Dense layer inserted
                after the self-attention. When empty or ``None`` no Dense layer is
                added and only a dropout is applied.

        Return:
            object: the news encoder of NRMS (Keras model named ``"news_encoder"``).
        """
        # The embedding is created inside this method and fine-tuned (trainable=True).
        embedding_layer = tf.keras.layers.Embedding(
            self.word2vec_embedding.shape[0],
            self.word2vec_embedding.shape[1],
            weights=[self.word2vec_embedding],
            trainable=True,
        )
        sequences_input_title = tf.keras.Input(
            shape=(self.hparams.title_size,), dtype="int32"
        )
        embedded_sequences_title = embedding_layer(sequences_input_title)

        y = tf.keras.layers.Dropout(self.hparams.dropout)(embedded_sequences_title)
        y = SelfAttention(self.hparams.head_num, self.hparams.head_dim, seed=self.seed)(
            [y, y, y]
        )

        # Create configurable Dense layers (the if - else is something I've added):
        if units_per_layer:
            # Each configured width adds Dense(relu, L2) -> BatchNorm -> Dropout.
            for layer in units_per_layer:
                y = tf.keras.layers.Dense(
                    units=layer,
                    activation="relu",
                    kernel_regularizer=tf.keras.regularizers.l2(
                        self.hparams.newsencoder_l2_regularization
                    ),
                )(y)
                y = tf.keras.layers.BatchNormalization()(y)
                y = tf.keras.layers.Dropout(self.hparams.dropout)(y)
        else:
            y = tf.keras.layers.Dropout(self.hparams.dropout)(y)

        pred_title = AttLayer2(self.hparams.attention_hidden_dim, seed=self.seed)(y)

        model = tf.keras.Model(sequences_input_title, pred_title, name="news_encoder")
        return model

    def _build_nrms(self):
        """The main function to create NRMS's logic. The core of NRMS
        is a user encoder and a news encoder.

        A single news encoder instance is shared by the user encoder and by the
        candidate branch of both returned models. As a side effect the encoders
        are stored on ``self.userencoder`` and ``self.newsencoder``.

        Returns:
            object: a model used to train.
            object: a model used to evaluate and inference.
        """

        # Click history: history_size articles of title_size token ids each.
        his_input_title = tf.keras.Input( 
            shape=(self.hparams.history_size, self.hparams.title_size),
            dtype="int32",
        )
        # Candidate list for training; the number of candidates is left open (None).
        pred_input_title = tf.keras.Input(
            # shape = (hparams.npratio + 1, hparams.title_size)
            shape=(None, self.hparams.title_size),
            dtype="int32",
        )
        # Single candidate for scoring; the length-1 axis is dropped by the Reshape below.
        pred_input_title_one = tf.keras.Input(
            shape=(
                1,
                self.hparams.title_size,
            ),
            dtype="int32",
        )
        pred_title_one_reshape = tf.keras.layers.Reshape((self.hparams.title_size,))(
            pred_input_title_one
        )
        titleencoder = self._build_newsencoder(
            units_per_layer=self.hparams.newsencoder_units_per_layer
        )
        self.userencoder = self._build_userencoder(titleencoder)
        self.newsencoder = titleencoder

        user_present = self.userencoder(his_input_title)
        news_present = tf.keras.layers.TimeDistributed(self.newsencoder)(
            pred_input_title
        )
        news_present_one = self.newsencoder(pred_title_one_reshape)

        # Training head: dot product of each candidate with the user vector,
        # normalised with a softmax.
        preds = tf.keras.layers.Dot(axes=-1)([news_present, user_present])
        preds = tf.keras.layers.Activation(activation="softmax")(preds)

        # Scoring head: one dot product squashed with a sigmoid.
        pred_one = tf.keras.layers.Dot(axes=-1)([news_present_one, user_present])
        pred_one = tf.keras.layers.Activation(activation="sigmoid")(pred_one)

        model = tf.keras.Model([his_input_title, pred_input_title], preds)
        scorer = tf.keras.Model([his_input_title, pred_input_title_one], pred_one)

        return model, scorer

In [2]:
#ebrec/evaluation/metrics/_sklearn.py

try:
    from sklearn.metrics import (
        # _regression:
        mean_squared_error,
        # _ranking:
        roc_auc_score,
        # _classification:
        accuracy_score,
        f1_score,
        log_loss,
    )
except ImportError:
    print("sklearn not available")


#ebrec/evaluation/metrics/_ranking.py


def reciprocal_rank_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Computes the reciprocal rank of the first relevant item of one ranking.

    Items are ranked by descending ``y_pred``; the score is ``1 / rank`` of the
    best-ranked positive label. Averaging this score over several rankings gives
    the Mean Reciprocal Rank (MRR).

    Args:
        y_true (np.ndarray): A 1D array of ground-truth labels. These should be binary (0 or 1),
                                where 1 indicates the relevant item.
        y_pred (np.ndarray): A 1D array of predicted scores. These scores indicate the likelihood
                                of items being relevant.

    Returns:
        float: The reciprocal rank, or 0.0 when ``y_true`` contains no relevant
            item (this also covers empty input).

    Note:
        Both `y_true` and `y_pred` should be 1D arrays of the same length.
        The function assumes higher scores in `y_pred` indicate higher relevance.

    Examples:
        >>> y_true_1 = np.array([0, 0, 1])
        >>> y_pred_1 = np.array([0.5, 0.2, 0.1])
        >>> reciprocal_rank_score(y_true_1, y_pred_1)
            0.33

        >>> y_true_2 = np.array([0, 1, 1])
        >>> y_pred_2 = np.array([0.5, 0.2, 0.1])
        >>> reciprocal_rank_score(y_true_2, y_pred_2)
            0.5

        >>> y_true_3 = np.array([1, 1, 0])
        >>> y_pred_3 = np.array([0.5, 0.2, 0.1])
        >>> reciprocal_rank_score(y_true_3, y_pred_3)
            1.0

        >>> reciprocal_rank_score(np.array([0, 0, 0]), y_pred_1)
            0.0

        >>> np.mean(
                [
                    reciprocal_rank_score(y_true, y_pred)
                    for y_true, y_pred in zip(
                        [y_true_1, y_true_2, y_true_3], [y_pred_1, y_pred_2, y_pred_3]
                    )
                ]
            )
            0.61
    """
    y_true = np.asarray(y_true)
    # np.argmax returns index 0 for an all-zero array, which would be reported as
    # a perfect rank of 1; a ranking without any relevant item scores 0 instead.
    if not np.any(y_true):
        return 0.0
    order = np.argsort(y_pred)[::-1]
    ranked_labels = np.take(y_true, order)
    first_positive_rank = np.argmax(ranked_labels) + 1
    return 1.0 / first_positive_rank


def dcg_score(y_true: np.ndarray, y_pred: np.ndarray, k: int = 10) -> float:
    """
    Compute the Discounted Cumulative Gain (DCG) score at a particular rank `k`.

    Items are ranked by descending ``y_pred``; the gain of an item with relevance
    ``r`` is ``2**r - 1`` and the discount at (1-based) rank ``i`` is ``log2(i + 1)``.

    Args:
        y_true (np.ndarray): A 1D or 2D array of ground-truth relevance labels.
                            Each element should be a non-negative integer.
        y_pred (np.ndarray): A 1D or 2D array of predicted scores. Each element is
                            a score corresponding to the predicted relevance.
        k (int, optional): The rank at which the DCG score is calculated. Defaults
                            to 10. If `k` is larger than the number of elements, it
                            will be truncated to the number of elements.

    Note:
        In case of a 2D array, each row represents a different sample and the
        mean DCG over the rows is returned.

    Returns:
        float: The calculated DCG score for the top `k` elements (0.0 for empty input).

    Raises:
        ValueError: If `y_true` and `y_pred` have different shapes.

    Examples:
        >>> from sklearn.metrics import dcg_score as dcg_score_sklearn
        >>> y_true = np.array([1, 0, 0, 1, 0])
        >>> y_pred = np.array([0.5, 0.2, 0.1, 0.8, 0.4])
        >>> dcg_score(y_true, y_pred)
            1.6309297535714575
        >>> dcg_score([y_true], [y_pred])
            1.6309297535714575
        >>> dcg_score_sklearn([y_true], [y_pred])
            1.6309297535714573
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same shape, got {y_true.shape} and {y_pred.shape}"
        )
    if y_true.size == 0:
        return 0.0
    if y_true.ndim > 1:
        # The ranking logic below is only valid for a single 1D sample (argsort
        # and take would otherwise mix rows), so score row by row and average.
        return float(
            np.mean(
                [dcg_score(sample_true, sample_pred, k) for sample_true, sample_pred in zip(y_true, y_pred)]
            )
        )
    k = min(y_true.shape[-1], k)
    order = np.argsort(y_pred)[::-1]
    top_k_relevance = np.take(y_true, order[:k])
    gains = 2**top_k_relevance - 1
    # Rank i (0-based) is discounted by log2(i + 2), i.e. log2(2) = 1 for the top item.
    discounts = np.log2(np.arange(len(top_k_relevance)) + 2)
    return np.sum(gains / discounts)


def ndcg_score(y_true: np.ndarray, y_pred: np.ndarray, k: int = 10) -> float:
    """
    Compute the Normalized Discounted Cumulative Gain (NDCG) score at a rank `k`.

    The DCG of the predicted ranking is divided by the DCG of the ideal ranking
    (items sorted by their true relevance).

    Args:
        y_true (np.ndarray): A 1D or 2D array of ground-truth relevance labels.
                            Each element should be a non-negative integer. In case
                            of a 2D array, each row represents a different sample
                            and the mean NDCG over the rows is returned.
        y_pred (np.ndarray): A 1D or 2D array of predicted scores. Each element is
                            a score corresponding to the predicted relevance. The
                            array should have the same shape as `y_true`.
        k (int, optional): The rank at which the NDCG score is calculated. Defaults
                            to 10. If `k` is larger than the number of elements, it
                            will be truncated to the number of elements.

    Returns:
        float: The calculated NDCG score for the top `k` elements. The score ranges
                from 0 to 1, with 1 representing the perfect ranking. A sample
                without any relevant item (ideal DCG of 0) scores 0.0.

    Note:
        Tied prediction scores are ranked in an arbitrary order, so results can
        differ from `sklearn.metrics.ndcg_score` (which averages over ties).

    Examples:
        >>> from sklearn.metrics import ndcg_score as ndcg_score_sklearn
        >>> y_true = np.array([1, 0, 0, 1, 0])
        >>> y_pred = np.array([0.5, 0.2, 0.1, 0.8, 0.4])
        >>> ndcg_score(y_true, y_pred)
            1.0
        >>> ndcg_score_sklearn([y_true], [y_pred])
            1.0
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.size == 0:
        return 0.0
    if y_true.ndim > 1:
        # Normalisation is per sample, so each row is scored on its own (with 1D
        # inputs only) and the per-row scores are averaged.
        return float(
            np.mean(
                [ndcg_score(sample_true, sample_pred, k) for sample_true, sample_pred in zip(y_true, y_pred)]
            )
        )
    best = dcg_score(y_true, y_true, k)
    if best == 0:
        # No relevant item: avoid 0/0 (NaN) and report the worst score.
        return 0.0
    actual = dcg_score(y_true, y_pred, k)
    return actual / best


def mrr_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Computes the Mean Reciprocal Rank (MRR) score.

    Items are ranked by descending ``y_pred``; the reciprocal ranks of all
    relevant items are summed and divided by the number of relevant items.

    THIS MIGHT NOT BE ENTIRELY CORRECT, TO BE DETERMINED:
        - https://github.com/recommenders-team/recommenders/issues/2141

    Args:
        y_true (np.ndarray): A 1D array of ground-truth labels. These should be binary (0 or 1),
                                where 1 indicates the relevant item. A 2D array is treated as
                                one sample per row.
        y_pred (np.ndarray): A 1D array of predicted scores (or 2D, matching `y_true`). These
                                scores indicate the likelihood of items being relevant.

    Returns:
        float: The mean reciprocal rank (MRR) score. For 2D input, the mean of the
            per-row scores. A sample without any relevant item scores 0.0.

    Note:
        Both `y_true` and `y_pred` should have the same shape.
        The function assumes higher scores in `y_pred` indicate higher relevance.

    Examples:
        >>> y_true = np.array([[1, 0, 0, 1, 0]])
        >>> y_pred = np.array([[0.5, 0.2, 0.1, 0.8, 0.4]])
        >>> mrr_score(y_true, y_pred)
            0.75

    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.size == 0:
        return 0.0
    if y_true.ndim > 1:
        # argsort/take below only rank a single 1D sample correctly, so handle
        # each row separately and average.
        return float(
            np.mean([mrr_score(sample_true, sample_pred) for sample_true, sample_pred in zip(y_true, y_pred)])
        )
    n_relevant = np.sum(y_true)
    if n_relevant == 0:
        # No relevant item: avoid 0/0 (NaN).
        return 0.0
    order = np.argsort(y_pred)[::-1]
    ranked_labels = np.take(y_true, order)
    rr_score = ranked_labels / (np.arange(len(ranked_labels)) + 1)
    return np.sum(rr_score) / n_relevant





#ebrec/evaluation/metrics/_classification.py


def auc_score_custom(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Compute the ROC AUC by directly comparing every positive score with every negative score.

    The result is the fraction of (positive, negative) pairs in which the positive
    instance has the strictly higher prediction. Ties count as losses, so the value
    can be lower than sklearn's `roc_auc_score` when a positive and a negative share
    the same score.

    Args:
        y_true (np.ndarray): A binary array indicating the true classification (1 for positive class and 0 for negative class).
        y_pred (np.ndarray): An array of scores as predicted by a model, indicating the likelihood of each instance being positive.

    Returns:
        float: The calculated AUC score, i.e. the share of positive/negative pairs that
                are ordered correctly by the prediction scores.

    Examples:
        >>> y_true = np.array([1, 1, 0, 0, 1, 0, 0, 0])
        >>> y_pred = np.array([0.9999, 0.9838, 0.5747, 0.8485, 0.8624, 0.4502, 0.3357, 0.8985])
        >>> auc_score_custom(y_true, y_pred)
            0.9333333333333333
        >>> from sklearn.metrics import roc_auc_score
        >>> roc_auc_score(y_true, y_pred)
            0.9333333333333333

        The result differs from sklearn when pos/neg predictions have the same score:
        >>> y_true = np.array([1, 1, 0, 0, 1, 0, 0, 0])
        >>> y_pred = np.array([0.9999, 0.8, 0.8, 0.8485, 0.8624, 0.4502, 0.3357, 0.8985])
        >>> auc_score_custom(y_true, y_pred)
            0.7333
        >>> roc_auc_score(y_true, y_pred)
            0.7667
    """
    labels = np.asarray(y_true)
    scores = np.asarray(y_pred)

    # Split the scores by class.
    is_positive = labels.astype(np.bool_)
    positive_scores = scores[is_positive]
    negative_scores = scores[~is_positive]

    # Compare every positive with every negative through broadcasting.
    n_pairs = len(positive_scores) * len(negative_scores)
    n_correctly_ordered = (positive_scores[:, None] > negative_scores[None, :]).sum()
    return n_correctly_ordered / n_pairs






#ebrec/evaluation/metrics/beyond_accuracy.py

def intralist_diversity(
    R: np.ndarray,
    pairwise_distance_function: Callable = cosine_distances,
) -> float:
    """Average pairwise distance between the items of one recommendation list.

    Follows the intra-list diversity of Smyth and McClave (2001): the distance
    matrix of the list with itself is summed and divided by the number of ordered
    item pairs, ``n * (n - 1)``.

    Args:
        R (np.ndarray): A 2D numpy array where each row represents a recommendation.
            This array should be either array-like or a sparse matrix, with shape (n_samples_X, n_features).
        pairwise_distance_function (Callable, optional): A function to compute pairwise distance
            between samples. It is called as ``f(R, R)``. Defaults to `cosine_distances`.

    Returns:
        float: The diversity score, or NaN when the list holds one item or fewer
            (the score is undefined in that case).

    Examples:
        >>> R1 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
        >>> print(intralist_diversity(R1))
            0.022588438516842262
        >>> print(intralist_diversity(np.array([[0.1, 0.2], [0.1, 0.2]])))
            1.1102230246251565e-16
    """
    n_items = R.shape[0]
    if n_items <= 1:
        # No item pair exists, so the diversity is undefined.
        return np.nan
    distance_matrix = pairwise_distance_function(R, R)
    n_ordered_pairs = n_items * (n_items - 1)
    return np.sum(distance_matrix) / n_ordered_pairs


def serendipity(
    R: np.ndarray,
    H: np.ndarray,
    pairwise_distance_function: Callable = cosine_distances,
) -> float:
    """Mean distance between recommended items and the user's reading history.

    Implements serendipity as defined by Lu, Dumitrache and Graus (2020): every
    recommendation is compared with every history item and the distances are averaged.

    Args:
        R (np.ndarray): A 2D numpy array representing the recommendation list, where each row is a recommendation.
            It should be either array-like or a sparse matrix, with shape (n_samples_X, n_features).
        H (np.ndarray): A 2D numpy array representing the user's reading history, with the same format as R.
        pairwise_distance_function (Callable, optional): A function to compute pairwise distance between samples.
            It is called as ``f(R, H)``. Defaults to `cosine_distances`.

    Returns:
        float: The calculated serendipity score.

    References:
        Lu, F., Dumitrache, A., & Graus, D. (2020). Beyond Optimizing for Clicks: Incorporating Editorial Values in News Recommendation.
        Retrieved from https://arxiv.org/abs/2004.09980

    Examples:
        >>> R1 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
        >>> H1 = np.array([[0.7, 0.8, 0.9], [0.1, 0.2, 0.3]])
        >>> print(serendipity(R1, H1))
            0.016941328887631724
    """
    recommendation_to_history = pairwise_distance_function(R, H)
    return np.mean(recommendation_to_history)


def coverage_count(R: np.ndarray) -> int:
    """Count how many different items appear in a recommendation list.

    Args:
        R (np.ndarray): An array containing the items in the recommendation list.

    Returns:
        int: The number of distinct items in `R`.

    Examples:
        >>> R1 = np.array([1, 2, 3, 4, 5, 5, 6])
        >>> print(coverage_count(R1))
            6
    """
    distinct_items = np.unique(R)
    return distinct_items.size


def coverage_fraction(R: np.ndarray, C: np.ndarray) -> float:
    """Share of the item catalogue that is covered by a recommendation list.

    Args:
        R (np.ndarray): An array containing the items in the recommendation list.
        C (np.ndarray): An array representing the universal set of items.
            It should contain all possible items that can be recommended.

    Returns:
        float: Number of unique elements in `R` divided by the number of unique
            elements in `C`.

    Examples:
        >>> R1 = np.array([1, 2, 3, 4, 5, 5, 6])
        >>> C1 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        >>> print(coverage_fraction(R1, C1))  # Expected output: 0.6
            0.6
    """
    n_recommended = np.unique(R).size
    n_catalogue = np.unique(C).size
    return n_recommended / n_catalogue


def novelty(R: np.ndarray) -> float:
    """Novelty of a recommendation list, measured as mean self-information of item popularity.

    Uses the self-information popularity metric described by Zhou et al. (2010)
    and Vargas and Castells (2011).

    Formula:
        Novelty(R) = ( sum_{i∈R} -log2( p_i ) ) / |R|

    where p_i represents the popularity score of each item in the recommendation list R, and |R| is the size of R.

    Args:
        R (np.ndarray): An array of popularity scores (p_i) for each item in the recommendation list.

    Returns:
        float: The calculated novelty score. Higher values indicate less popular (more novel) recommendations.

    References:
        Zhou et al. (2010).
        Vargas & Castells (2011).

    Examples:
        >>> print(novelty(np.array([0.1, 0.2, 0.3, 0.4, 0.5])))  # Expected: High score (low popularity scores)
            1.9405499757656586
        >>> print(novelty(np.array([0.9, 0.9, 0.9, 1.0, 0.5])))  # Expected: Low score (high popularity scores)
            0.29120185606703
    """
    self_information = -np.log2(R)
    return np.mean(self_information)


def index_of_dispersion(x: list[int]) -> float:
    """
    Computes the Index of Dispersion for a given dataset of nominal variables.

    The input is a flat list of category labels (one entry per observed item);
    frequencies per category are counted internally.

    Formula:
        D = ( k * (N^2 - Σf^2) ) / ( N^2 * (k-1) )
    Where:
        k = number of categories observed in `x`,
        N = number of items in the set,
        f = frequency of each category,
        Σf^2 = sum of squared frequencies.

    Note:
        Categories with zero items cannot be inferred from `x`, so `k` only
        counts categories that occur at least once.

    Args:
        x (list[int]): A list of category labels, one per item.

    Returns:
        float: The Index of Dispersion for the dataset. Returns `np.nan` if the input is empty or
                contains only one item (undefined). Returns 0 if several items all share one category.

    References:
        Walker, 1999, Statistics in criminal
        Source: https://www.statisticshowto.com/index-of-dispersion/

    Examples:
        Given the following categories: Math(25), Economics(42), Chemistry(13), Physical Education (8), Religious Studies (13).
        >>> N = np.sum(25+42+13+8+13)
        >>> k = 5
        >>> sq_f2 = np.sum(25**2 + 42**2 + 13**2 + 8**2 + 13**2)
        >>> iod = ( k * (N**2 - sq_f2)) / ( N**2 * (k-1) )
            0.9079992157631604

        Validate method:
        >>> cat = [[1]*25, [2]*42, [3]*13, [4]*8, [5]*13]
        >>> flat_list = [item for sublist in cat for item in sublist]
        >>> index_of_dispersion(flat_list)
            0.9079992157631604
    """
    # number of items
    N = len(x)
    if N == 0:
        # No observations: undefined, and the formula would divide by zero.
        return np.nan
    # compute frequencies
    count = Counter(x)
    # number of categories
    k = len(count)
    if k == 1:
        # A single category makes (k - 1) zero; one item is undefined, several
        # identical items have no dispersion at all.
        return np.nan if N == 1 else 0.0
    # sum of squared frequencies
    sum_f_squared = sum(frequency**2 for frequency in count.values())
    # compute Index of Dispersion
    return k * (N**2 - sum_f_squared) / (N**2 * (k - 1))

#from ebrec.evaluation.utils import convert_to_binary

def convert_to_binary(y_pred: np.ndarray, threshold: float):
    """Binarize prediction scores: values >= `threshold` become 1, values below become 0.

    The input is never modified; a new array with the same dtype is returned, so
    the caller's scores stay usable for other metrics.

    Args:
        y_pred (np.ndarray): Array-like of prediction scores.
        threshold (float): Decision threshold; scores equal to it map to 1.

    Returns:
        np.ndarray: A copy of `y_pred` holding only 0/1 values. NaN entries
            compare false against the threshold both ways and are left untouched.

    Examples:
        >>> convert_to_binary(np.array([0.2, 0.7, 0.5]), threshold=0.5)
            array([0., 1., 1.])
    """
    # np.array copies, whereas np.asarray would alias an ndarray input and the
    # assignments below would overwrite the caller's data.
    binary = np.array(y_pred)
    # Both masks are taken from the original scores before anything is written;
    # otherwise a threshold above 1 would flip the freshly written 1s back to 0.
    at_or_above = binary >= threshold
    below = binary < threshold
    binary[at_or_above] = 1
    binary[below] = 0
    return binary


def is_iterable_nested_dtype(iterable: Iterable[any], dtypes) -> bool:
    """
    Check whether the elements of an iterable are of the given dtype(s).

    Only the first element (``iterable[0]``) is inspected; all elements are
    assumed to share the same type. To check every element instead, use:
    any(isinstance(i, dtypes) for i in a)

    Args:
        iterable (Iterable[Any]): indexable collection (list, array, tuple) of any type of data
        dtypes (Tuple): a type or tuple of possible dtypes, e.g. dtypes = (list, np.ndarray)
    Returns:
        bool: True if the first element is an instance of `dtypes`; False otherwise,
            including when the iterable is empty.

    Examples:
        >>> is_iterable_nested_dtype([1, 2, 3], list)
            False
        >>> is_iterable_nested_dtype([1, 2, 3], (list, int))
            True
        >>> is_iterable_nested_dtype([[1], [2], [3]], list)
            True
        >>> is_iterable_nested_dtype([], list)
            False
    """
    # An empty collection has no first element to inspect (indexing would raise).
    if len(iterable) == 0:
        return False
    return isinstance(iterable[0], dtypes)


def compute_combinations(n: int, r: int) -> int:
    """Compute Combinations where order does not matter (without replacement)

    Source: https://www.statskingdom.com/combinations-calculator.html
    Args:
        n (int): number of items
        r (int): number of items being chosen at a time
    Returns:
        int: number of possible combinations

    Raises:
        ValueError: If `n` or `r` is negative, or if `r` is larger than `n`.

    Formula:
    * nCr = n! / ( (n - r)! * r! )

    Assume the following:
    * we sample without replacement of items
    * order of the outcomes does NOT matter
    """
    # Local import: `np.math` (used previously) no longer exists in NumPy 2.x,
    # and the standard library computes the result with exact integer arithmetic.
    import math

    if r > n:
        # math.comb would return 0 here; keep raising like the factorial-based
        # formula did for a negative (n - r).
        raise ValueError(f"r ({r}) must not be larger than n ({n})")
    return math.comb(n, r)


def scale_range(
    m: np.ndarray,
    r_min: float = None,
    r_max: float = None,
    t_min: float = 0,
    t_max: float = 1.0,
) -> np.ndarray:
    """Scale an array between a range
    Source: https://stats.stackexchange.com/questions/281162/scale-a-number-between-a-range

    m -> ((m-r_min)/(r_max-r_min)) * (t_max-t_min) + t_min

    Args:
        m ∈ [r_min,r_max] denote your measurements to be scaled
        r_min denote the minimum of the range of your measurement.
            Defaults to the minimum of `m` when None.
        r_max denote the maximum of the range of your measurement.
            Defaults to the maximum of `m` when None.
        t_min denote the minimum of the range of your desired target scaling
        t_max denote the maximum of the range of your desired target scaling

    Returns:
        np.ndarray: the measurements rescaled to the target range.
    """
    # Compare against None explicitly: a plain truthiness test would treat a
    # legitimate bound of 0 as "not provided" and silently replace it.
    if r_min is None:
        r_min = np.min(m)
    if r_max is None:
        r_max = np.max(m)
    return ((m - r_min) / (r_max - r_min)) * (t_max - t_min) + t_min


# utils for
def compute_item_popularity_scores(R: Iterable[np.ndarray]) -> dict[str, float]:
    """Compute the popularity of every item as the share of users who interacted with it.

    Formula:
        p_i = | {u ∈ U}, r_ui != Ø | / |U|

    where p_i is the popularity score of an item, U is the set of users, and r_ui is the
    interaction of user u with item i.

    Note:
        The occurrences of an item are counted over the concatenation of all user
        arrays, so each user's array should list an item at most once; duplicates
        within one array inflate the score.

    Args:
        R (Iterable[np.ndarray]): An iterable of numpy arrays, where each array represents the items interacted with by a single user.
            Each element in the array should be a string identifier for an item.

    Returns:
        dict[str, float]: A dictionary where keys are item identifiers and values are their corresponding popularity scores (as floats).

    Examples:
    >>> R = [
            np.array(["item1", "item2", "item3"]),
            np.array(["item1", "item3"]),
            np.array(["item1", "item4"]),
        ]
    >>> print(compute_item_popularity_scores(R))
        {'item1': 1.0, 'item2': 0.3333333333333333, 'item3': 0.6666666666666666, 'item4': 0.3333333333333333}
    """
    n_users = len(R)
    interaction_counts = Counter(np.concatenate(R))
    popularity = {}
    for item, n_interactions in interaction_counts.items():
        popularity[item] = n_interactions / n_users
    return popularity


def compute_normalized_distribution(
    R: np.ndarray,
    weights: np.ndarray = None,
    distribution: dict = None,
) -> dict:
    """
    Accumulate a weighted distribution over the representations of a list of items.

    Each element of `R` contributes its weight to the entry keyed by that element.
    With the default weights (1 / len(R) each) and a fresh dictionary the values
    sum to 1.

    Args:
        R (np.ndarray): An array of items representation.
        weights (np.ndarray, optional): Weights to assign each element in R. Defaults to None.
            * If None, equal weights are assigned to all elements.
        distribution (dict, optional): Dictionary to accumulate distribution values. Defaults to None.
            * If None, a new dictionary is created.
            * If given, it is updated in place and returned.

    Returns:
        dict: A dictionary mapping each representation to its accumulated weight.

    Examples:
        >>> R = np.array(["a", "b", "c", "c"])
        >>> compute_normalized_distribution(R)
            {'a': 0.25, 'b': 0.25, 'c': 0.5}
    """
    n_elements = len(R)

    # Start from the caller's dictionary when one is supplied.
    if distribution is None:
        distr = {}
    else:
        distr = distribution

    # Fall back to uniform weights.
    if weights is None:
        weights = np.ones(n_elements) / n_elements

    for item, weight in zip(R, weights):
        previous = distr.get(item, 0.0)
        distr[item] = weight + previous

    return distr



def get_keys_in_dict(id_list: any, dictionary: dict) -> list[any]:
    """
    Keep only those IDs of `id_list` that occur as keys of `dictionary`.

    Order and duplicates of `id_list` are preserved.

    Args:
        id_list (List[Any]): List of IDs to check against the dictionary.
        dictionary (Dict[Any, Any]): Dictionary where keys are checked against the IDs.

    Returns:
        List[Any]: List of IDs that are also keys in the dictionary.

    Examples:
        >>> get_keys_in_dict(['a', 'b', 'c'], {'a': 1, 'c': 3, 'd': 4})
            ['a', 'c']
    """
    known_ids = []
    for candidate in id_list:
        if candidate in dictionary:
            known_ids.append(candidate)
    return known_ids


def check_key_in_all_nested_dicts(dictionary: dict, key: str) -> None:
    """
    Verify that every value of `dictionary` is itself a dictionary containing `key`.

    The check stops at the first offending entry.

    Args:
        dictionary (dict): The dictionary containing nested dictionaries to check.
        key (str): The key to look for in all nested dictionaries.

    Raises:
        ValueError: If a value is not a dictionary or does not contain `key`.

    Example:
        >>> nested_dict = {
                "101": {"name": "Alice", "age": 30},
                "102": {"name": "Bob", "age": 25},
            }
        >>> check_key_in_all_nested_dicts(nested_dict, "age")
        # No error is raised
        >>> check_key_in_all_nested_dicts(nested_dict, "salary")
        # Raises ValueError: "'salary' is not present in '101' nested dictionary."
    """
    for dict_key in dictionary:
        sub_dict = dictionary[dict_key]
        has_key = isinstance(sub_dict, dict) and key in sub_dict
        if has_key:
            continue
        raise ValueError(
            f"'{key}' is not present in '{dict_key}' nested dictionary."
        )

#ebrec/evaluation/protocols.py

class Metric(Protocol):
    """Protocol for callable evaluation metrics.

    Implementations expose a ``name`` and a ``calculate`` method; calling the
    metric object simply forwards to ``calculate``.
    """

    name: str

    def calculate(self, y_true: np.ndarray, y_score: np.ndarray) -> float: ...

    def __call__(self, y_true: np.ndarray, y_score: np.ndarray) -> float:
        # Calling the metric is shorthand for ``calculate``.
        score = self.calculate(y_true, y_score)
        return score

    def __str__(self) -> str:
        # Show the metric name together with all instance attributes.
        params = self.__dict__
        return f"<Callable Metric: {self.name}>: params: {params}"

    def __repr__(self) -> str:
        return self.__str__()

class AccuracyScore(Metric):
    """Mean accuracy over a collection of (labels, predictions) pairs.

    Each prediction array is first passed through ``convert_to_binary`` with
    ``threshold`` (helper defined elsewhere in this file) and then scored
    with ``accuracy_score``; the per-pair scores are averaged.

    Args:
        threshold (float): Value forwarded to ``convert_to_binary``.
    """

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold
        self.name = "accuracy"

    def calculate(self, y_true: list[np.ndarray], y_pred: list[np.ndarray]) -> float:
        """Return the mean of the per-pair accuracy scores as a Python float.

        Args:
            y_true: One array of labels per pair.
            y_pred: One array of predictions per pair, aligned with ``y_true``.
        """
        pairs = tqdm(
            zip(y_true, y_pred),
            ncols=80,
            total=len(y_true),
            # Label the progress bar with this metric (was hard-coded "AUC").
            desc=self.name.upper(),
        )
        per_pair = [
            accuracy_score(each_labels, convert_to_binary(each_preds, self.threshold))
            for each_labels, each_preds in pairs
        ]
        return float(np.mean(per_pair))


class F1Score(Metric):
    """Mean F1 score over a collection of (labels, predictions) pairs.

    Each prediction array is first passed through ``convert_to_binary`` with
    ``threshold`` (helper defined elsewhere in this file) and then scored
    with ``f1_score``; the per-pair scores are averaged.

    Args:
        threshold (float): Value forwarded to ``convert_to_binary``.
    """

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold
        self.name = "f1"

    def calculate(self, y_true: list[np.ndarray], y_pred: list[np.ndarray]) -> float:
        """Return the mean of the per-pair F1 scores as a Python float.

        Args:
            y_true: One array of labels per pair.
            y_pred: One array of predictions per pair, aligned with ``y_true``.
        """
        pairs = tqdm(
            zip(y_true, y_pred),
            ncols=80,
            total=len(y_true),
            # Label the progress bar with this metric (was hard-coded "AUC").
            desc=self.name.upper(),
        )
        per_pair = [
            f1_score(each_labels, convert_to_binary(each_preds, self.threshold))
            for each_labels, each_preds in pairs
        ]
        return float(np.mean(per_pair))


class RootMeanSquaredError(Metric):
    """Mean of the per-pair root-mean-squared errors.

    For every (labels, predictions) pair the square root of
    ``mean_squared_error`` is taken; those values are then averaged.
    """

    def __init__(self):
        self.name = "rmse"

    def calculate(self, y_true: list[np.ndarray], y_pred: list[np.ndarray]) -> float:
        """Return the mean of the per-pair RMSE values as a Python float.

        Args:
            y_true: One array of labels per pair.
            y_pred: One array of predictions per pair, aligned with ``y_true``.
        """
        pairs = tqdm(
            zip(y_true, y_pred),
            ncols=80,
            total=len(y_true),
            # Label the progress bar with this metric (was hard-coded "AUC").
            desc=self.name.upper(),
        )
        per_pair = [
            np.sqrt(mean_squared_error(each_labels, each_preds))
            for each_labels, each_preds in pairs
        ]
        return float(np.mean(per_pair))


class AucScore(Metric):
    """Mean ``roc_auc_score`` over a collection of (labels, predictions) pairs."""

    def __init__(self):
        self.name = "auc"

    def calculate(self, y_true: list[np.ndarray], y_pred: list[np.ndarray]) -> float:
        """Score every pair with ``roc_auc_score`` and return the mean as a float."""
        n_pairs = len(y_true)
        progress = tqdm(zip(y_true, y_pred), ncols=80, total=n_pairs, desc="AUC")
        pair_scores = []
        for labels, scores in progress:
            pair_scores.append(roc_auc_score(labels, scores))
        return float(np.mean(pair_scores))


class LogLossScore(Metric):
    """Mean ``log_loss`` over a collection of (labels, predictions) pairs.

    Predictions are clipped into ``[eps, 1 - eps]`` before being scored so
    that exact 0/1 probabilities cannot reach the logarithm.
    """

    def __init__(self):
        self.name = "logloss"

    def calculate(self, y_true: list[np.ndarray], y_pred: list[np.ndarray]) -> float:
        """Return the mean of the per-pair log-loss values as a Python float.

        Args:
            y_true: One array of labels per pair.
            y_pred: One array of predicted probabilities per pair, aligned
                with ``y_true``.
        """
        # NOTE: the literal 10e-12 equals 1e-11; the value is kept unchanged.
        eps = 10e-12
        lower, upper = eps, 1.0 - eps
        pairs = tqdm(
            zip(y_true, y_pred),
            ncols=80,
            total=len(y_true),
            # Label the progress bar with this metric (was hard-coded "AUC").
            desc=self.name.upper(),
        )
        per_pair = [
            log_loss(
                each_labels,
                [max(min(p, upper), lower) for p in each_preds],
            )
            for each_labels, each_preds in pairs
        ]
        return float(np.mean(per_pair))


class MrrScore(Metric):
    """Mean ``mrr_score`` over a collection of (labels, predictions) pairs."""

    def __init__(self) -> None:
        # __init__ returns nothing; the former ``-> Metric`` annotation was wrong.
        self.name = "mrr"

    def calculate(self, y_true: list[np.ndarray], y_pred: list[np.ndarray]) -> float:
        """Return the mean of the per-pair MRR values as a Python float.

        Args:
            y_true: One array of labels per pair.
            y_pred: One array of predictions per pair, aligned with ``y_true``.
        """
        pairs = tqdm(
            zip(y_true, y_pred),
            ncols=80,
            total=len(y_true),
            # Label the progress bar with this metric (was hard-coded "AUC").
            desc=self.name.upper(),
        )
        per_pair = [
            mrr_score(each_labels, each_preds) for each_labels, each_preds in pairs
        ]
        return float(np.mean(per_pair))


class NdcgScore(Metric):
    """Mean ``ndcg_score`` at cutoff ``k`` over (labels, predictions) pairs.

    Args:
        k (int): Cutoff forwarded as the third argument of ``ndcg_score``;
            it is also embedded in the metric name (``"ndcg@{k}"``).
    """

    def __init__(self, k: int):
        self.k = k
        self.name = f"ndcg@{k}"

    def calculate(self, y_true: list[np.ndarray], y_pred: list[np.ndarray]) -> float:
        """Return the mean of the per-pair nDCG@k values as a Python float.

        Args:
            y_true: One array of labels per pair.
            y_pred: One array of predictions per pair, aligned with ``y_true``.
        """
        pairs = tqdm(
            zip(y_true, y_pred),
            ncols=80,
            total=len(y_true),
            # Label the progress bar with this metric (was hard-coded "AUC").
            desc=self.name.upper(),
        )
        per_pair = [
            ndcg_score(each_labels, each_preds, self.k)
            for each_labels, each_preds in pairs
        ]
        return float(np.mean(per_pair))


class MetricEvaluator:
    """Run a list of callable metrics on the same labels and predictions.

    ``evaluate()`` stores the results in ``self.evaluations`` (keyed by each
    metric's ``name``) and returns the evaluator itself, so the results can
    be read from ``.evaluations`` or by printing the evaluator.

    >>> y_true = [[1, 0, 0], [1, 1, 0], [1, 0, 0, 0]]
    >>> y_pred = [[0.2, 0.3, 0.5], [0.18, 0.7, 0.1], [0.18, 0.2, 0.1, 0.1]]

    >>> met_eval = MetricEvaluator(
            labels=y_true,
            predictions=y_pred,
            metric_functions=[
                AucScore(),
                MrrScore(),
                NdcgScore(k=5),
                NdcgScore(k=10),
                LogLossScore(),
                RootMeanSquaredError(),
                AccuracyScore(threshold=0.5),
                F1Score(threshold=0.5),
            ],
        )
    >>> met_eval.evaluate().evaluations
    {
        "auc": 0.5555555555555556,
        "mrr": 0.5277777777777778,
        "ndcg@5": 0.7103099178571526,
        "ndcg@10": 0.7103099178571526,
        "logloss": 0.716399020295845,
        "rmse": 0.5022870658128165,
        "accuracy": 0.5833333333333334,
        "f1": 0.2222222222222222
    }

    Args:
        labels: One array of labels per pair.
        predictions: One array of predictions per pair, aligned with ``labels``.
        metric_functions: Non-empty collection of callables; each must have a
            ``name`` attribute and accept ``(labels, predictions)``.

    Raises:
        TypeError: If ``metric_functions`` is empty or holds a non-callable.
    """

    def __init__(
        self,
        labels: list[np.ndarray],
        predictions: list[np.ndarray],
        # Quoted so the class definition does not depend on ``Metric``
        # already being defined when this cell is executed.
        metric_functions: "list[Metric]",
    ):
        self.labels = labels
        self.predictions = predictions
        self.metric_functions = metric_functions
        self.evaluations = dict()

    def evaluate(self) -> "MetricEvaluator":
        """Compute every metric, store the results and return ``self``."""
        self.evaluations = {
            metric_function.name: metric_function(self.labels, self.predictions)
            for metric_function in self.metric_functions
        }
        return self

    @property
    def metric_functions(self):
        return self.__metric_functions

    @metric_functions.setter
    def metric_functions(self, values):
        invalid_callables = self.__invalid_callables(values)
        if not invalid_callables:
            # An empty collection used to fall through to the "not callable"
            # branch and report an empty list of offenders.
            raise TypeError(
                "'metric_functions' must contain at least one callable metric."
            )
        if any(invalid_callables):
            invalid_types = [
                type(item)
                for item, is_invalid in zip(values, invalid_callables)
                if is_invalid
            ]
            raise TypeError(f"Following object(s) are not callable: {invalid_types}")
        self.__metric_functions = values

    @staticmethod
    def __invalid_callables(items: Iterable):
        # One flag per item: True when the item is NOT callable.
        return [not callable(item) for item in items]

    def __str__(self):
        if self.evaluations:
            evaluations_json = json.dumps(self.evaluations, indent=4)
            return f"<MetricEvaluator class>: \n {evaluations_json}"
        else:
            return f"<MetricEvaluator class>: {self.evaluations}"

    def __repr__(self):
        return str(self)

In [52]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import tensorflow as tf
import numpy as np

from tensorflow.keras.layers import Embedding, Input, Dropout, Dense, BatchNormalization
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.regularizers import l2


#ebrec/models/newsrec/nrms.py

class NRMSModel:
    """NRMS model(Neural News Recommendation with Multi-Head Self-Attention)

    Chuhan Wu, Fangzhao Wu, Suyu Ge, Tao Qi, Yongfeng Huang,and Xing Xie, "Neural News
    Recommendation with Multi-Head Self-Attention" in Proceedings of the 2019 Conference
    on Empirical Methods in Natural Language Processing and the 9th International Joint Conference
    on Natural Language Processing (EMNLP-IJCNLP)

    Attributes:
        hparams: Hyper-parameter object. Although annotated as ``dict`` it is read
            through attribute access; the fields used in this class are ``loss``,
            ``optimizer``, ``learning_rate``, ``history_size``, ``title_size``,
            ``head_num``, ``head_dim``, ``attention_hidden_dim``, ``dropout``,
            ``newsencoder_l2_regularization`` and ``newsencoder_units_per_layer``.
        seed: Seed handed to TensorFlow, NumPy and the seeded layers/initializer.
        word2vec_embedding: Matrix used as the initial weights of the embedding layer.
        model: Compiled Keras model used for training.
        scorer: Keras model that scores a single candidate (not compiled here).
        userencoder: User-encoder sub-model (set in ``_build_nrms``).
        newsencoder: News-encoder sub-model (set in ``_build_nrms``).
    """

    def __init__(
        self,
        hparams: dict,
        word2vec_embedding: np.ndarray = None,
        word_emb_dim: int = 300,
        vocab_size: int = 32000,
        seed: int = None,
    ):
        """Initialization steps for NRMS.

        Args:
            hparams: Hyper-parameter object (see the class docstring for the fields read).
            word2vec_embedding: Initial embedding matrix. When None, a
                ``(vocab_size, word_emb_dim)`` matrix is drawn with GlorotUniform.
            word_emb_dim: Embedding width; only used when ``word2vec_embedding`` is None.
            vocab_size: Number of embedding rows; only used when ``word2vec_embedding`` is None.
            seed: Seed for ``tf.random.set_seed``, ``np.random.seed`` and the seeded layers.
        """
        self.hparams = hparams
        self.seed = seed

        # SET SEED:
        # NOTE: this changes the global TensorFlow and NumPy random state.
        tf.random.set_seed(seed)
        np.random.seed(seed)

        # INIT THE WORD-EMBEDDINGS:
        if word2vec_embedding is None:
            # Xavier Initialization
            initializer = GlorotUniform(seed=self.seed)
            self.word2vec_embedding = initializer(shape=(vocab_size, word_emb_dim))
            # self.word2vec_embedding = np.random.rand(vocab_size, word_emb_dim)
        else:
            self.word2vec_embedding = word2vec_embedding

        # BUILD AND COMPILE MODEL:
        # Only the training model is compiled; the scorer is returned as built.
        self.model, self.scorer = self._build_graph()
        data_loss = self._get_loss(self.hparams.loss)
        train_optimizer = self._get_opt(
            optimizer=self.hparams.optimizer, lr=self.hparams.learning_rate
        )
        self.model.compile(loss=data_loss, optimizer=train_optimizer)

    def _get_loss(self, loss: str):
        """Translate the configured loss name into a Keras loss identifier.

        Args:
            loss: Either "cross_entropy_loss" (mapped to "categorical_crossentropy")
                or "log_loss" (mapped to "binary_crossentropy").

        Returns:
            str: The Keras loss name.

        Raises:
            ValueError: If ``loss`` is neither of the supported names.
        """
        if loss == "cross_entropy_loss":
            data_loss = "categorical_crossentropy"
        elif loss == "log_loss":
            data_loss = "binary_crossentropy"
        else:
            raise ValueError(f"this loss not defined {loss}")
        return data_loss

    def _get_opt(self, optimizer: str, lr: float):
        """Get the optimizer according to configuration. Only "adam" is supported.

        Args:
            optimizer: Optimizer name; must be "adam".
            lr: Learning rate passed to the optimizer.

        Returns:
            object: An optimizer.

        Raises:
            ValueError: If ``optimizer`` is not "adam".
        """
        # TODO: shouldn't be a string input you should just set the optimizer, to avoid stuff like this:
        # => 'WARNING:absl:At this time, the v2.11+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs, please use the legacy Keras optimizer instead, located at `tf.keras.optimizers.legacy.Adam`.'
        if optimizer == "adam":
            train_opt = tf.keras.optimizers.Adam(learning_rate=lr)
        else:
            raise ValueError(f"this optimizer not defined {optimizer}")
        return train_opt

    def _build_graph(self):
        """Build NRMS model and scorer (thin wrapper around ``_build_nrms``).

        Returns:
            object: a model used to train.
            object: a model used to evaluate and inference.
        """
        model, scorer = self._build_nrms()
        return model, scorer

    def _build_userencoder(self, titleencoder):
        """The main function to create user encoder of NRMS.

        The history input has shape ``(history_size, title_size)`` per sample. Every
        history entry is encoded with ``titleencoder`` (via TimeDistributed), the
        encoded entries go through ``SelfAttention`` (the same tensor is passed three
        times) and are then pooled by ``AttLayer2``.

        Args:
            titleencoder (object): the news encoder of NRMS.

        Return:
            object: the user encoder of NRMS.
        """
        his_input_title = tf.keras.Input(
            shape=(self.hparams.history_size, self.hparams.title_size), dtype="int32"
        )

        click_title_presents = tf.keras.layers.TimeDistributed(titleencoder)(
            his_input_title
        )
        # The same tensor is supplied three times to the attention layer.
        y = SelfAttention(self.hparams.head_num, self.hparams.head_dim, seed=self.seed)(
            [click_title_presents] * 3
        )
        user_present = AttLayer2(self.hparams.attention_hidden_dim, seed=self.seed)(y)

        model = tf.keras.Model(his_input_title, user_present, name="user_encoder")
        return model

    def _build_newsencoder(self, units_per_layer: list[int] = None):
        """The main function to create news encoder of NRMS.

        Pipeline: trainable embedding (initialised from ``self.word2vec_embedding``)
        -> dropout -> ``SelfAttention`` -> optional Dense/BatchNormalization/Dropout
        stack (or a single dropout when no stack is requested) -> ``AttLayer2``.

        Args:
            units_per_layer (list[int]): Width of each optional Dense layer placed
                after the self-attention. When None or empty, no Dense layers are
                added and only a dropout is applied.

        Return:
            object: the news encoder of NRMS.
        """
        embedding_layer = tf.keras.layers.Embedding(
            self.word2vec_embedding.shape[0],
            self.word2vec_embedding.shape[1],
            weights=[self.word2vec_embedding],
            trainable=True,
        )
        sequences_input_title = tf.keras.Input(
            shape=(self.hparams.title_size,), dtype="int32"
        )
        embedded_sequences_title = embedding_layer(sequences_input_title)

        y = tf.keras.layers.Dropout(self.hparams.dropout)(embedded_sequences_title)
        y = SelfAttention(self.hparams.head_num, self.hparams.head_dim, seed=self.seed)(
            [y, y, y]
        )

        # Create configurable Dense layers (the if - else is something I've added):
        if units_per_layer:
            for layer in units_per_layer:
                # Each block: ReLU Dense with L2 kernel penalty, batch norm, dropout.
                y = tf.keras.layers.Dense(
                    units=layer,
                    activation="relu",
                    kernel_regularizer=tf.keras.regularizers.l2(
                        self.hparams.newsencoder_l2_regularization
                    ),
                )(y)
                y = tf.keras.layers.BatchNormalization()(y)
                y = tf.keras.layers.Dropout(self.hparams.dropout)(y)
        else:
            y = tf.keras.layers.Dropout(self.hparams.dropout)(y)

        pred_title = AttLayer2(self.hparams.attention_hidden_dim, seed=self.seed)(y)

        model = tf.keras.Model(sequences_input_title, pred_title, name="news_encoder")
        return model

    def _build_nrms(self):
        """The main function to create NRMS's logic. The core of NRMS
        is a user encoder and a news encoder.

        Side effects: stores the encoders on ``self.userencoder`` and
        ``self.newsencoder``. Both returned models share these encoders.

        Returns:
            object: a model used to train (softmax over the dot products of the
                candidate encodings with the user encoding).
            object: a model used to evaluate and inference (sigmoid of the dot
                product for a single candidate).
        """

        his_input_title = tf.keras.Input( 
            shape=(self.hparams.history_size, self.hparams.title_size),
            dtype="int32",
        )
        # Candidate axis is left as None, so the number of candidates is not fixed.
        pred_input_title = tf.keras.Input(
            # shape = (hparams.npratio + 1, hparams.title_size)
            shape=(None, self.hparams.title_size),
            dtype="int32",
        )
        # Scorer input: exactly one candidate, reshaped below to drop that axis.
        pred_input_title_one = tf.keras.Input(
            shape=(
                1,
                self.hparams.title_size,
            ),
            dtype="int32",
        )
        pred_title_one_reshape = tf.keras.layers.Reshape((self.hparams.title_size,))(
            pred_input_title_one
        )
        titleencoder = self._build_newsencoder(
            units_per_layer=self.hparams.newsencoder_units_per_layer
        )
        self.userencoder = self._build_userencoder(titleencoder)
        self.newsencoder = titleencoder

        user_present = self.userencoder(his_input_title)
        news_present = tf.keras.layers.TimeDistributed(self.newsencoder)(
            pred_input_title
        )
        news_present_one = self.newsencoder(pred_title_one_reshape)

        # Training head: dot product per candidate followed by softmax.
        preds = tf.keras.layers.Dot(axes=-1)([news_present, user_present])
        preds = tf.keras.layers.Activation(activation="softmax")(preds)

        # Scoring head: dot product for the single candidate followed by sigmoid.
        pred_one = tf.keras.layers.Dot(axes=-1)([news_present_one, user_present])
        pred_one = tf.keras.layers.Activation(activation="sigmoid")(pred_one)

        model = tf.keras.Model([his_input_title, pred_input_title], preds)
        scorer = tf.keras.Model([his_input_title, pred_input_title_one], pred_one)

        return model, scorer

In [3]:
import numpy as np
import random
import json

try:
    import polars as pl
except ImportError:
    print("polars not available")



def _check_columns_in_df(df: pl.DataFrame, columns: list[str]) -> None:
    """
    Ensure that every name in ``columns`` is a column of ``df``.

    Args:
        df (pl.DataFrame): The DataFrame to inspect.
        columns (list[str]): Column names that must be present.

    Returns:
        None.

    Raises:
        ValueError: If at least one name is missing; the message lists all
            missing names in the order they were requested.

    Examples:
    >>> df = pl.DataFrame({"user_id": [1], "first_name": ["J"]})
    >>> _check_columns_in_df(df, columns=["user_id", "not_in"])
        ValueError: Invalid input provided. The DataFrame does not contain columns ['not_in'].
    """
    missing = list(filter(lambda name: name not in df.columns, columns))
    if len(missing) > 0:
        raise ValueError(
            f"Invalid input provided. The DataFrame does not contain columns {missing}."
        )


def _validate_equal_list_column_lengths(df: pl.DataFrame, col1: str, col2: str) -> None:
    """
    Validate that two list columns hold lists of equal length in every row.

    The function is a pure check: it returns None when the lengths agree and
    raises otherwise (it never returns a boolean, despite what the previous
    annotation suggested).

    Args:
        df (pl.DataFrame): The DataFrame containing the list columns.
        col1 (str): The name of the first list column.
        col2 (str): The name of the second list column.

    Returns:
        None.

    Raises:
        ValueError: If any row has lists of different lengths in the two columns.

    >>> df = pl.DataFrame({
            'col1': [[1, 2, 3], [4, 5], [6]],
            'col2': [[10, 20], [30, 40, 50], [60, 70, 80]],
        })
    >>> _validate_equal_list_column_lengths(df, 'col1', 'col2')
        ValueError: Mismatch in the lengths of the number of items (row-based) between the columns: 'col1' and 'col2'. Please ensure equal lengths.
    >>> df = df.with_columns(pl.Series('col1', [[1, 2], [3, 4, 5], [6, 7, 8]]))
    >>> _validate_equal_list_column_lengths(df, 'col1', 'col2')
    """
    # The comparison expression is rooted in ``col1``, so the resulting
    # boolean column carries that name.
    lengths_match = df.select(pl.col(col1).list.len() == pl.col(col2).list.len())[col1]
    if not lengths_match.all():
        raise ValueError(
            f"Mismatch in the lengths of the number of items (row-based) between the columns: '{col1}' and '{col2}'. Please ensure equal lengths."
        )


def slice_join_dataframes(
    df1: pl.DataFrame,
    df2: pl.DataFrame,
    on: str,
    how: str,
) -> pl.DataFrame:
    """
    Join ``df1`` with ``df2`` slice by slice and concatenate the partial results.

    ``df1`` is walked with ``iter_slices()``; each slice is joined against the
    full ``df2`` using the given ``on`` / ``how`` and the joined slices are
    stacked with ``pl.concat``.
    """
    joined_slices = []
    for chunk in df1.iter_slices():
        joined_slices.append(chunk.join(df2, on=on, how=how))
    return pl.concat(joined_slices)


def rename_columns(df: pl.DataFrame, map_dict: dict[str, str]) -> pl.DataFrame:
    """
    Rename the columns of ``df`` according to ``map_dict``, silently ignoring
    mapping entries whose source column is absent. When no entry applies, the
    input frame is returned untouched.

    Examples:
        >>> import polars as pl
        >>> df = pl.DataFrame({'A': [1, 2], 'B': [3, 4]})
        >>> map_dict = {'A': 'X', 'B': 'Y'}
        >>> rename_columns(df, map_dict)
            shape: (2, 2)
            ┌─────┬─────┐
            │ X   ┆ Y   │
            │ --- ┆ --- │
            │ i64 ┆ i64 │
            ╞═════╪═════╡
            │ 1   ┆ 3   │
            │ 2   ┆ 4   │
            └─────┴─────┘
        >>> rename_columns(df, {"Z" : "P"})
            shape: (2, 2)
            ┌─────┬─────┐
            │ A   ┆ B   │
            │ --- ┆ --- │
            │ i64 ┆ i64 │
            ╞═════╪═════╡
            │ 1   ┆ 3   │
            │ 2   ┆ 4   │
            └─────┴─────┘
    """
    applicable = {}
    for old_name, new_name in map_dict.items():
        if old_name in df.columns:
            applicable[old_name] = new_name
    if not applicable:
        return df
    return df.rename(applicable)


def from_dict_to_polars(dictionary: dict) -> pl.DataFrame:
    """
    Turn a dictionary into a two-column DataFrame ("keys", "values").

    Useful when the dictionary keys are integers, which cannot be used
    directly as series names.

    Example:
    >>> dictionary = {1: "a", 2: "b"}
    >>> from_dict_to_polars(dictionary)
        shape: (2, 2)
        ┌──────┬────────┐
        │ keys ┆ values │
        │ ---  ┆ ---    │
        │ i64  ┆ str    │
        ╞══════╪════════╡
        │ 1    ┆ a      │
        │ 2    ┆ b      │
        └──────┴────────┘
    >>> pl.from_dict(dictionary)
        raise ValueError("Series name must be a string.")
            ValueError: Series name must be a string.
    """
    key_column = list(dictionary.keys())
    value_column = list(dictionary.values())
    return pl.DataFrame({"keys": key_column, "values": value_column})


def shuffle_rows(df: pl.DataFrame, seed: int = None) -> pl.DataFrame:
    """
    Shuffle the rows of a frame while keeping each row intact.

    Unlike 'df.sample(fraction=1)', this works on a LazyFrame too. Every
    column is shuffled with the SAME seed, which is what keeps the rows
    aligned. When ``seed`` is None a random seed in [1, 1_000_000] is drawn
    with ``random.randint``.

    Examples:
    >>> df = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
    >>> shuffle_rows(df.lazy(), seed=123).collect()
        shape: (3, 3)
        ┌─────┬─────┬─────┐
        │ a   ┆ b   ┆ c   │
        │ --- ┆ --- ┆ --- │
        │ i64 ┆ i64 ┆ i64 │
        ╞═════╪═════╪═════╡
        │ 1   ┆ 1   ┆ 1   │
        │ 3   ┆ 3   ┆ 3   │
        │ 2   ┆ 2   ┆ 2   │
        └─────┴─────┴─────┘
    >>> shuffle_rows(df.lazy(), seed=None).collect().sort("a")
        shape: (3, 3)
        ┌─────┬─────┬─────┐
        │ a   ┆ b   ┆ c   │
        │ --- ┆ --- ┆ --- │
        │ i64 ┆ i64 ┆ i64 │
        ╞═════╪═════╪═════╡
        │ 1   ┆ 1   ┆ 1   │
        │ 2   ┆ 2   ┆ 2   │
        │ 3   ┆ 3   ┆ 3   │
        └─────┴─────┴─────┘

    Test_:
    >>> all([sum(row) == row[0]*3 for row in shuffle_rows(df, seed=None).iter_rows()])
        True

    Note:
        'pl.all().shuffle(None)' shuffles column-wise, i.e. each column's
        elements are shuffled independently of the other columns (the exact
        output may differ from run to run):
    >>> df_ = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}).select(pl.all().shuffle(None)).sort("a")
    >>> df_
        shape: (3, 3)
        ┌─────┬─────┬─────┐
        │ a   ┆ b   ┆ c   │
        │ --- ┆ --- ┆ --- │
        │ i64 ┆ i64 ┆ i64 │
        ╞═════╪═════╪═════╡
        │ 1   ┆ 3   ┆ 1   │
        │ 2   ┆ 2   ┆ 3   │
        │ 3   ┆ 1   ┆ 2   │
        └─────┴─────┴─────┘
    >>> all([sum(row) == row[0]*3 for row in shuffle_rows(df_, seed=None).iter_rows()])
        False
    """
    if seed is None:
        # A concrete seed is required so that all columns share one permutation.
        seed = random.randint(1, 1_000_000)
    shuffled_columns = pl.all().shuffle(seed)
    return df.select(shuffled_columns)


def keep_unique_values_in_list(df: pl.DataFrame, column: str) -> pl.DataFrame:
    """
    De-duplicate the values inside each list of a list column.

    The column is replaced in place via ``with_columns``; all other columns
    are left as they are.

    Args:
        df (pl.DataFrame): Frame holding the list column.
        column (str): Name of the list column to de-duplicate.

    Returns:
        pl.DataFrame: The frame with ``column`` replaced by its per-row unique
        values. NOTE(review): the order of elements inside each list is
        whatever ``list.unique()`` yields; confirm against the polars docs
        before relying on it.

    Example:
        >>> df = pl.DataFrame({
                "article_ids": [[1, 2, 3, 1, 2], [3, 4, 5, 3], [1, 2, 3, 1, 2, 3]],
                "hh": ["h", "e", "y"]
            })
        >>> keep_unique_values_in_list(df.lazy(), "article_ids").collect()
            shape: (3, 2)
            ┌─────────────┬─────┐
            │ article_ids ┆ hh  │
            │ ---         ┆ --- │
            │ list[i64]   ┆ str │
            ╞═════════════╪═════╡
            │ [1, 2, 3]   ┆ h   │
            │ [3, 4, 5]   ┆ e   │
            │ [1, 2, 3]   ┆ y   │
            └─────────────┴─────┘
    """
    deduplicated = pl.col(column).list.unique()
    return df.with_columns(deduplicated)


def filter_minimum_lengths_from_list(
    df: pl.DataFrame,
    n: int,
    column: str,
) -> pl.DataFrame:
    """Keep only rows whose list in ``column`` has at least ``n`` elements.

    The frame is returned unchanged when ``column`` is not in ``df``, when
    ``n`` is None, or when ``n`` is not positive.

    Args:
        df (pl.DataFrame): The input DataFrame to filter.
        n (int): The minimum number of elements required in the list column.
        column (str): The name of the list column to filter on.

    Returns:
        pl.DataFrame: The filtered DataFrame.

    Example:
    >>> df = pl.DataFrame(
            {
                "user_id": [1, 2, 3, 4],
                "article_ids": [["a", "b", "c"], ["a", "b"], ["a"], ["a"]],
            }
        )
    >>> filter_minimum_lengths_from_list(df, n=2, column="article_ids")
        shape: (2, 2)
        ┌─────────┬─────────────────┐
        │ user_id ┆ article_ids     │
        │ ---     ┆ ---             │
        │ i64     ┆ list[str]       │
        ╞═════════╪═════════════════╡
        │ 1       ┆ ["a", "b", "c"] │
        │ 2       ┆ ["a", "b"]      │
        └─────────┴─────────────────┘
    >>> filter_minimum_lengths_from_list(df, n=None, column="article_ids")
        shape: (4, 2)
        ┌─────────┬─────────────────┐
        │ user_id ┆ article_ids     │
        │ ---     ┆ ---             │
        │ i64     ┆ list[str]       │
        ╞═════════╪═════════════════╡
        │ 1       ┆ ["a", "b", "c"] │
        │ 2       ┆ ["a", "b"]      │
        │ 3       ┆ ["a"]           │
        │ 4       ┆ ["a"]           │
        └─────────┴─────────────────┘
    """
    filtering_enabled = column in df and n is not None and n > 0
    if not filtering_enabled:
        return df
    long_enough = pl.col(column).list.len() >= n
    return df.filter(long_enough)


def filter_maximum_lengths_from_list(
    df: pl.DataFrame,
    n: int,
    column: str,
) -> pl.DataFrame:
    """Keep only rows whose list in ``column`` has at most ``n`` elements.

    The frame is returned unchanged when ``column`` is not in ``df``, when
    ``n`` is None, or when ``n`` is not positive (so ``n=0`` disables the
    filter rather than keeping only empty lists).

    Args:
        df (pl.DataFrame): The input DataFrame to filter.
        n (int): The maximum number of elements allowed in the list column.
        column (str): The name of the list column to filter on.

    Returns:
        pl.DataFrame: The filtered DataFrame.

    Example:
    >>> df = pl.DataFrame(
            {
                "user_id": [1, 2, 3, 4],
                "article_ids": [["a", "b", "c"], ["a", "b"], ["a"], ["a"]],
            }
        )
    >>> filter_maximum_lengths_from_list(df, n=2, column="article_ids")
        shape: (3, 2)
        ┌─────────┬─────────────┐
        │ user_id ┆ article_ids │
        │ ---     ┆ ---         │
        │ i64     ┆ list[str]   │
        ╞═════════╪═════════════╡
        │ 2       ┆ ["a", "b"]  │
        │ 3       ┆ ["a"]       │
        │ 4       ┆ ["a"]       │
        └─────────┴─────────────┘
    >>> filter_maximum_lengths_from_list(df, n=None, column="article_ids")
        shape: (4, 2)
        ┌─────────┬─────────────────┐
        │ user_id ┆ article_ids     │
        │ ---     ┆ ---             │
        │ i64     ┆ list[str]       │
        ╞═════════╪═════════════════╡
        │ 1       ┆ ["a", "b", "c"] │
        │ 2       ┆ ["a", "b"]      │
        │ 3       ┆ ["a"]           │
        │ 4       ┆ ["a"]           │
        └─────────┴─────────────────┘
    """
    filtering_enabled = column in df and n is not None and n > 0
    if not filtering_enabled:
        return df
    short_enough = pl.col(column).list.len() <= n
    return df.filter(short_enough)


def split_df_fraction(
    df: pl.DataFrame,
    fraction=0.8,
    seed: int = None,
    shuffle: bool = True,
):
    """
    Split a DataFrame into two consecutive parts at ``int(len(df) * fraction)``.

    The frame is first passed through ``df.sample(fraction=1.0, ...)`` with the
    given ``shuffle`` and ``seed``; the first part receives the leading rows
    of that result and the second part the remaining rows.

    Raises:
        ValueError: If ``fraction`` is not strictly between 0 and 1.

    >>> df = pl.DataFrame({'A': range(10), 'B': range(10, 20)})
    >>> df1, df2 = split_df_fraction(df, fraction=0.8, seed=42, shuffle=True)
    >>> len(df1)
        8
    >>> len(df2)
        2
    """
    fraction_is_valid = 0 < fraction < 1
    if not fraction_is_valid:
        raise ValueError("fraction must be between 0 and 1")
    resampled = df.sample(fraction=1.0, shuffle=shuffle, seed=seed)
    cut = int(len(resampled) * fraction)
    first_part = resampled[:cut]
    second_part = resampled[cut:]
    return first_part, second_part


def split_df_chunks(df: pl.DataFrame, n_chunks: int):
    """
    Splits a DataFrame into a specified number of consecutive chunks.

    Every chunk except the last holds ``df.height // n_chunks`` rows; the last
    chunk additionally receives the remainder rows.

    Args:
        df (pl.DataFrame): The DataFrame to be split into chunks.
        n_chunks (int): The number of chunks to divide the DataFrame into.
            Must be at least 1.

    Returns:
        list: A list of ``n_chunks`` DataFrame chunks. Each element in the list
        is a DataFrame representing a chunk of the original data.

    Raises:
        ValueError: If ``n_chunks`` is smaller than 1.

    Examples
    >>> import polars as pl
    >>> df = pl.DataFrame({'A': range(3)})
    >>> chunks = split_df_chunks(df, 2)
    >>> chunks
        [shape: (1, 1)
        ┌─────┐
        │ A   │
        │ --- │
        │ i64 │
        ╞═════╡
        │ 0   │
        └─────┘, shape: (2, 1)
        ┌─────┐
        │ A   │
        │ --- │
        │ i64 │
        ╞═════╡
        │ 1   │
        │ 2   │
        └─────┘]
    """
    # Guard against division by zero (n_chunks == 0) and the empty/IndexError
    # results a negative value would otherwise produce.
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be a positive integer, got {n_chunks}")

    # Number of rows in every chunk but the last
    chunk_size = df.height // n_chunks

    # All leading chunks have exactly ``chunk_size`` rows
    chunks = [df[i * chunk_size : (i + 1) * chunk_size] for i in range(n_chunks - 1)]

    # The last chunk runs to the end of the frame, so it absorbs the remainder
    chunks.append(df[(n_chunks - 1) * chunk_size :])

    return chunks


def drop_nulls_from_list(df: pl.DataFrame, column: str) -> pl.DataFrame:
    """
    Remove null elements from inside each list of a list column.

    Only the elements within the lists are affected; as the example shows, a
    row whose whole list is null stays null.

    Args:
        df (pl.DataFrame): The input DataFrame.
        column (str): The name of the list column to clean.

    Returns:
        pl.DataFrame: A new DataFrame in which the lists of ``column`` no longer contain nulls.

    Examples:
    >>> df = pl.DataFrame(
            {"user_id": [101, 102, 103], "dynamic_article_id": [[1, None, 3], None, [4, 5]]}
        )
    >>> print(df)
        shape: (3, 2)
        ┌─────────┬────────────────────┐
        │ user_id ┆ dynamic_article_id │
        │ ---     ┆ ---                │
        │ i64     ┆ list[i64]          │
        ╞═════════╪════════════════════╡
        │ 101     ┆ [1, null, 3]       │
        │ 102     ┆ null               │
        │ 103     ┆ [4, 5]             │
        └─────────┴────────────────────┘
    >>> drop_nulls_from_list(df, "dynamic_article_id")
        shape: (3, 2)
        ┌─────────┬────────────────────┐
        │ user_id ┆ dynamic_article_id │
        │ ---     ┆ ---                │
        │ i64     ┆ list[i64]          │
        ╞═════════╪════════════════════╡
        │ 101     ┆ [1, 3]             │
        │ 102     ┆ null               │
        │ 103     ┆ [4, 5]             │
        └─────────┴────────────────────┘
    """
    non_null_elements = pl.element().drop_nulls()
    cleaned_lists = pl.col(column).list.eval(non_null_elements)
    return df.with_columns(cleaned_lists)


def filter_list_elements(df: pl.DataFrame, column: str, ids: list[any]) -> pl.DataFrame:
    """
    Drop every element of a list column that is not contained in ``ids``.

    Rows whose list is null, or whose list has no element left after
    filtering, end up with null in ``column`` (see the example). Row order and
    column order of the input are retained.

    Args:
        df (pl.DataFrame): The Polars DataFrame to process.
        column (str): The name of the list column to filter.
        ids (list[any]): Identifiers to retain; all other list elements are removed.

    Returns:
        pl.DataFrame: A frame with the same columns as the input, where the lists in
                    ``column`` only contain elements found in ``ids``.

    Examples:
    >>> df = pl.DataFrame({"A": [1, 2, 3, 4, 5], "B": [[1, 3], [3, 4], None, [7, 8], [9, 10]]})
    >>> ids = [1, 3, 5, 7]
    >>> filter_list_elements(df.lazy(), "B", ids).collect()
        shape: (5, 2)
        ┌─────┬───────────┐
        │ A   ┆ B         │
        │ --- ┆ ---       │
        │ i64 ┆ list[i64] │
        ╞═════╪═══════════╡
        │ 1   ┆ [1, 3]    │
        │ 2   ┆ [3]       │
        │ 3   ┆ null      │
        │ 4   ┆ [7]       │
        │ 5   ┆ null      │
        └─────┴───────────┘
    """
    row_id = "_groupby"
    original_columns = df.columns
    indexed = df.with_row_index(row_id)

    # Explode the lists, keep the wanted elements and fold them back per row.
    exploded = indexed.select(pl.col(row_id, column)).drop_nulls().explode(column)
    retained = exploded.filter(pl.col(column).is_in(ids))
    refolded = retained.group_by(row_id).agg(column)

    # Rows without any retained element get null through the left join.
    without_column = indexed.drop(column)
    rejoined = without_column.join(refolded, on=row_id, how="left")
    return rejoined.select(original_columns)


def filter_elements(df: pl.DataFrame, column: str, ids: list[any]) -> pl.DataFrame:
    """
    Replace values of ``column`` that are not contained in ``ids`` with null.

    All rows are kept; row order and column order of the input are retained.

    Args:
        df (pl.DataFrame): The Polars DataFrame to process.
        column (str): The name of the column to filter.
        ids (list[any]): Identifiers to retain; any other value becomes null.

    Returns:
        pl.DataFrame: A frame with the same columns as the input, where values of
                    ``column`` that are not in ``ids`` have been nulled.

    Examples:
    >>> df = pl.DataFrame({"A": [1, 2, 3, 4, 5], "B": [[1, 3], [3, 4], None, [7, 8], [9, 10]]})
        shape: (5, 2)
        ┌─────┬───────────┐
        │ A   ┆ B         │
        │ --- ┆ ---       │
        │ i64 ┆ list[i64] │
        ╞═════╪═══════════╡
        │ 1   ┆ [1, 3]    │
        │ 2   ┆ [3, 4]    │
        │ 3   ┆ null      │
        │ 4   ┆ [7, 8]    │
        │ 5   ┆ [9, 10]   │
        └─────┴───────────┘
    >>> ids = [1, 3, 5, 7]
    >>> filter_elements(df.lazy(), "A", ids).collect()
        shape: (5, 2)
        ┌──────┬───────────┐
        │ A    ┆ B         │
        │ ---  ┆ ---       │
        │ i64  ┆ list[i64] │
        ╞══════╪═══════════╡
        │ 1    ┆ [1, 3]    │
        │ null ┆ [3, 4]    │
        │ 3    ┆ null      │
        │ null ┆ [7, 8]    │
        │ 5    ┆ [9, 10]   │
        └──────┴───────────┘
    """
    row_id = "_groupby"
    original_columns = df.columns
    indexed = df.with_row_index(row_id)

    # Keep only the (row id, value) pairs whose value is one of ``ids``.
    candidates = indexed.select(pl.col(row_id, column)).drop_nulls()
    retained = candidates.filter(pl.col(column).is_in(ids))

    # Rows without a retained value get null through the left join.
    without_column = indexed.drop(column)
    rejoined = without_column.join(retained, on=row_id, how="left")
    return rejoined.select(original_columns)


def concat_str_columns(
    df: pl.DataFrame, columns: list[str]
) -> tuple[pl.DataFrame, str]:
    """
    Append a column holding the space-separated concatenation of `columns`.

    The new column is named after the source columns joined with "-".

    Args:
        df (pl.DataFrame): The input frame.
        columns (list[str]): Names of the string columns to concatenate, in order.

    Returns:
        tuple[pl.DataFrame, str]: The frame with the extra column, and the name
            of that column.

    Examples:
    >>> df = pl.DataFrame(
            {
                "id": [1, 2, 3],
                "first_name": ["John", "Jane", "Alice"],
                "last_name": ["Doe", "Doe", "Smith"],
            }
        )
    >>> concatenated_df, concatenated_column_name = concat_str_columns(df, columns=['first_name', 'last_name'])
    >>> concatenated_df
        shape: (3, 4)
        ┌─────┬────────────┬───────────┬──────────────────────┐
        │ id  ┆ first_name ┆ last_name ┆ first_name-last_name │
        │ --- ┆ ---        ┆ ---       ┆ ---                  │
        │ i64 ┆ str        ┆ str       ┆ str                  │
        ╞═════╪════════════╪═══════════╪══════════════════════╡
        │ 1   ┆ John       ┆ Doe       ┆ John Doe             │
        │ 2   ┆ Jane       ┆ Doe       ┆ Jane Doe             │
        │ 3   ┆ Alice      ┆ Smith     ┆ Alice Smith          │
        └─────┴────────────┴───────────┴──────────────────────┘
    >>> concatenated_column_name
        'first_name-last_name'
    """
    concat_name = "-".join(columns)
    # Pass the expression straight to `with_columns` instead of materialising a
    # one-column frame first; this yields the same column and also works lazily.
    concat_expr = pl.concat_str(columns, separator=" ").alias(concat_name)
    return df.with_columns(concat_expr), concat_name


def filter_empty_text_column(df: pl.DataFrame, column: str) -> pl.DataFrame:
    """
    Keep only the rows whose string `column` is non-empty.

    Args:
        df (pl.DataFrame): The input frame.
        column (str): Name of the string column to test.

    Returns:
        pl.DataFrame: The rows of `df` where `column` has a length above zero.

    Example:
    >>> df = pl.DataFrame({"Name": ["John", "Alice", "Bob", ""], "Age": [25, 28, 30, 22]})
    >>> filter_empty_text_column(df, "Name")
        shape: (3, 2)
        ┌───────┬─────┐
        │ Name  ┆ Age │
        │ ---   ┆ --- │
        │ str   ┆ i64 │
        ╞═══════╪═════╡
        │ John  ┆ 25  │
        │ Alice ┆ 28  │
        │ Bob   ┆ 30  │
        └───────┴─────┘
    """
    # `str.lengths()` is the deprecated spelling of `str.len_bytes()`; a string
    # is empty exactly when its byte length is zero.
    return df.filter(pl.col(column).str.len_bytes() > 0)


def shuffle_list_column(
    df: pl.DataFrame, column: str, seed: int = None
) -> pl.DataFrame:
    """Shuffles the values inside each list of a list column.

    Every row keeps its own elements; only their order within the list changes.
    Row order and column order of the frame are preserved.

    Args:
        df (pl.DataFrame): The input DataFrame.
        column (str): The name of the list column to shuffle.
        seed (int, optional): Seed forwarded to `shuffle_rows`.
            Defaults to None.

    Returns:
        pl.DataFrame: A new DataFrame with the specified column shuffled.

    Example:
    >>> df = pl.DataFrame(
            {
                "id": [1, 2, 3],
                "list_col": [["a-", "b-", "c-"], ["a#", "b#"], ["a@", "b@", "c@"]],
                "rdn": ["h", "e", "y"],
            }
        )
    >>> shuffle_list_column(df, 'list_col', seed=1)
        shape: (3, 3)
        ┌─────┬────────────────────┬─────┐
        │ id  ┆ list_col           ┆ rdn │
        │ --- ┆ ---                ┆ --- │
        │ i64 ┆ list[str]          ┆ str │
        ╞═════╪════════════════════╪═════╡
        │ 1   ┆ ["c-", "b-", "a-"] ┆ h   │
        │ 2   ┆ ["a#", "b#"]       ┆ e   │
        │ 3   ┆ ["b@", "c@", "a@"] ┆ y   │
        └─────┴────────────────────┴─────┘

    The elements of each row are unchanged, only reordered:
    >>> assert (
            sorted(shuffle_list_column(df, "list_col", seed=None)["list_col"].to_list()[0])
            == df["list_col"].to_list()[0]
        )

    Lazy frames are handled the same way:
    >>> shuffle_list_column(df.lazy(), 'list_col', seed=2).collect()

    NOTE(review): a row whose list is null or empty looks like it comes back as
    a one-element list holding null, because `explode` emits a null row for it
    - confirm before relying on this for such rows.
    """
    _COLUMN_ORDER = df.columns
    GROUPBY_ID = generate_unique_name(_COLUMN_ORDER, "_groupby_id")

    # `with_row_index` replaces the deprecated `with_row_count`; the rest of
    # this file already uses it.
    df = df.with_row_index(GROUPBY_ID)
    # Flatten to one element per row, shuffle all rows, then regroup by the
    # original row id so each list gets its own elements back in a new order.
    df_shuffle = (
        df.explode(column)
        .pipe(shuffle_rows, seed=seed)
        .group_by(GROUPBY_ID)
        .agg(column)
    )
    # The left join restores the original row order; the final select restores
    # the column order and drops the helper id.
    return (
        df.drop(column)
        .join(df_shuffle, on=GROUPBY_ID, how="left")
        .select(_COLUMN_ORDER)
    )


def split_df_in_n(df: pl.DataFrame, num_splits: int) -> list[pl.DataFrame]:
    """
    Cut a DataFrame into `num_splits` consecutive row slices.

    Each slice holds ceil(n_rows / num_splits) rows; the trailing slices are
    shorter (possibly empty) when the rows do not divide evenly.

    Args:
        df (pl.DataFrame): The DataFrame to be split.
        num_splits (int): The number of slices to create.

    Returns:
        list[pl.DataFrame]: Exactly `num_splits` frames, in row order.

    Examples:
        >>> df = pl.DataFrame({'A': [1, 2, 3, 4, 5, 6, 7], "B" : [1, 2, 3, 4, 5, 6, 7]})
        >>> [d.shape for d in split_df_in_n(df, 3)]
            [(3, 2), (3, 2), (1, 2)]
        >>> split_df_in_n(df, 3)[2]
            shape: (1, 2)
            ┌─────┬─────┐
            │ A   ┆ B   │
            │ --- ┆ --- │
            │ i64 ┆ i64 │
            ╞═════╪═════╡
            │ 7   ┆ 7   │
            └─────┴─────┘
    """
    chunk_size = int(np.ceil(df.shape[0] / num_splits))
    chunks = []
    for split_idx in range(num_splits):
        start = split_idx * chunk_size
        chunks.append(df[start : start + chunk_size])
    return chunks


def concat_list_str(df: pl.DataFrame, column: str) -> pl.DataFrame:
    """
    Collapse a list-of-strings column into a plain string column.

    The strings of each list are joined with a single space; the resulting
    one-element lists are then exploded so the column holds strings.

    Args:
        df (polars.DataFrame): The input DataFrame.
        column (str): Name of the column that contains lists of strings.

    Returns:
        polars.DataFrame: `df` with `column` converted from lists of strings to
            the space-joined string.

    Examples:
        >>> df = pl.DataFrame({
                "strings": [["ab", "cd"], ["ef", "gh"], ["ij", "kl"]]
            })
        >>> concat_list_str(df, "strings")
            shape: (3, 1)
            ┌─────────┐
            │ strings │
            │ ---     │
            │ str     │
            ╞═════════╡
            │ ab cd   │
            │ ef gh   │
            │ ij kl   │
            └─────────┘
    """
    # Join inside every list first ...
    joined_lists = pl.col(column).list.eval(pl.element().str.concat(" "))
    df_joined = df.with_columns(joined_lists)
    # ... then unwrap the lists into plain values.
    return df_joined.explode(column)

In [5]:

# Column-name constants shared by the helper functions in this notebook.
# Helpers take these as default arguments, so renaming a value here changes
# which columns they read or write.

# BEHAVIORS: column names grouped under the behaviors table.
DEFAULT_IMPRESSION_TIMESTAMP_COL = "impression_time"
DEFAULT_IS_BEYOND_ACCURACY_COL = "is_beyond_accuracy"
DEFAULT_CLICKED_ARTICLES_COL = "article_ids_clicked"
DEFAULT_SCROLL_PERCENTAGE_COL = "scroll_percentage"
DEFAULT_INVIEW_ARTICLES_COL = "article_ids_inview"
DEFAULT_IMPRESSION_ID_COL = "impression_id"
DEFAULT_IS_SUBSCRIBER_COL = "is_subscriber"
DEFAULT_IS_SSO_USER_COL = "is_sso_user"
DEFAULT_ARTICLE_ID_COL = "article_id"
DEFAULT_SESSION_ID_COL = "session_id"
DEFAULT_READ_TIME_COL = "read_time"
DEFAULT_DEVICE_COL = "device_type"
DEFAULT_POSTCODE_COL = "postcode"
DEFAULT_GENDER_COL = "gender"
DEFAULT_USER_COL = "user_id"
DEFAULT_AGE_COL = "age"

# Derived names: the scroll-percentage / read-time names with a "next_" prefix.
DEFAULT_NEXT_SCROLL_PERCENTAGE_COL = f"next_{DEFAULT_SCROLL_PERCENTAGE_COL}"
DEFAULT_NEXT_READ_TIME_COL = f"next_{DEFAULT_READ_TIME_COL}"

# ARTICLES: column names grouped under the articles table.
DEFAULT_ARTICLE_MODIFIED_TIMESTAMP_COL = "last_modified_time"
DEFAULT_ARTICLE_PUBLISHED_TIMESTAMP_COL = "published_time"
DEFAULT_SENTIMENT_LABEL_COL = "sentiment_label"
DEFAULT_SENTIMENT_SCORE_COL = "sentiment_score"
DEFAULT_TOTAL_READ_TIME_COL = "total_read_time"
DEFAULT_TOTAL_PAGEVIEWS_COL = "total_pageviews"
DEFAULT_TOTAL_INVIEWS_COL = "total_inviews"
DEFAULT_ARTICLE_TYPE_COL = "article_type"
DEFAULT_CATEGORY_STR_COL = "category_str"
DEFAULT_SUBCATEGORY_COL = "subcategory"
DEFAULT_ENTITIES_COL = "entity_groups"
DEFAULT_IMAGE_IDS_COL = "image_ids"
DEFAULT_SUBTITLE_COL = "subtitle"
DEFAULT_CATEGORY_COL = "category"
DEFAULT_NER_COL = "ner_clusters"
DEFAULT_PREMIUM_COL = "premium"
DEFAULT_TOPICS_COL = "topics"
DEFAULT_TITLE_COL = "title"
DEFAULT_BODY_COL = "body"
DEFAULT_URL_COL = "url"

# HISTORY: the behaviors names above with a "_fixed" suffix.
DEFAULT_HISTORY_IMPRESSION_TIMESTAMP_COL = f"{DEFAULT_IMPRESSION_TIMESTAMP_COL}_fixed"
DEFAULT_HISTORY_SCROLL_PERCENTAGE_COL = f"{DEFAULT_SCROLL_PERCENTAGE_COL}_fixed"
DEFAULT_HISTORY_ARTICLE_ID_COL = f"{DEFAULT_ARTICLE_ID_COL}_fixed"
DEFAULT_HISTORY_READ_TIME_COL = f"{DEFAULT_READ_TIME_COL}_fixed"

# CREATE: names of columns written by helpers in this notebook
# (add_known_user_column and create_binary_labels_column use them as defaults).
DEFAULT_KNOWN_USER_COL = "is_known_user"
DEFAULT_LABELS_COL = "labels"




#from ebrec.utils._python import create_lookup_dict
import polars as pl
#from ebrec.utils._constants import DEFAULT_ARTICLE_ID_COL


def load_article_id_embeddings(
    df: pl.DataFrame, path: str, item_col: str = DEFAULT_ARTICLE_ID_COL
) -> pl.DataFrame:
    """Read an embeddings parquet file and left-join it onto `df`.

    Args:
        df (pl.DataFrame): Frame to enrich; must contain `item_col`.
        path (str): Path to the parquet file with the document embeddings.
        item_col (str, optional): Join key present in both frames.
            Defaults to DEFAULT_ARTICLE_ID_COL.

    Returns:
        pl.DataFrame: `df` with the columns of the parquet file attached; rows
            without a match keep nulls in those columns (left join).
    """
    embeddings = pl.read_parquet(path)
    return df.join(embeddings, on=item_col, how="left")


def create_article_id_to_value_mapping(
    df: pl.DataFrame,
    value_col: str,
    article_col: str = DEFAULT_ARTICLE_ID_COL,
):
    """Build a lookup from article id to the value stored in `value_col`.

    Args:
        df (pl.DataFrame): Frame containing `article_col` and `value_col`.
        value_col (str): Column whose entries become the lookup values.
        article_col (str, optional): Column whose entries become the lookup keys.
            Defaults to DEFAULT_ARTICLE_ID_COL.

    Returns:
        Whatever `create_lookup_dict` returns for the two selected columns.
    """
    key_value_frame = df.select(article_col, value_col)
    return create_lookup_dict(key_value_frame, key=article_col, value=value_col)


def convert_text2encoding_with_transformers(
    df: pl.DataFrame,
    tokenizer: AutoTokenizer,
    column: str,
    max_length: int = None,
) -> tuple[pl.DataFrame, str]:
    """Converts text in a specified DataFrame column to token ids using a provided tokenizer.

    The encoded column is named "<column>_encode_<tokenizer.name_or_path>".
    Special tokens are not added. When `max_length` is given, every row is
    padded to that length; otherwise no padding is applied. Truncation is
    always requested from the tokenizer.

    Args:
        df (pl.DataFrame): The input DataFrame containing the text column.
        tokenizer (AutoTokenizer): The tokenizer to use for encoding the text. (from transformers import AutoTokenizer)
        column (str): The name of the column containing the text.
        max_length (int, optional): The maximum length of the encoded tokens. Defaults to None.

    Returns:
        tuple[pl.DataFrame, str]: The DataFrame with the additional column of
            token ids, and the name of that column.

    Example:
    >>> from transformers import AutoTokenizer
    >>> import polars as pl
    >>> df = pl.DataFrame({
            'text': ['This is a test.', 'Another test string.', 'Yet another one.']
        })
    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    >>> encoded_df, new_column = convert_text2encoding_with_transformers(df, tokenizer, 'text', max_length=20)
    >>> print(encoded_df)
        shape: (3, 2)
        ┌──────────────────────┬───────────────────────────────┐
        │ text                 ┆ text_encode_bert-base-uncased │
        │ ---                  ┆ ---                           │
        │ str                  ┆ list[i64]                     │
        ╞══════════════════════╪═══════════════════════════════╡
        │ This is a test.      ┆ [2023, 2003, … 0]             │
        │ Another test string. ┆ [2178, 3231, … 0]             │
        │ Yet another one.     ┆ [2664, 2178, … 0]             │
        └──────────────────────┴───────────────────────────────┘
    >>> print(new_column)
        text_encode_bert-base-uncased
    """
    text = df[column].to_list()
    new_column = f"{column}_encode_{tokenizer.name_or_path}"
    # Pad to a fixed width only when a width was requested; with no
    # `max_length` each row keeps its own token count.
    padding = "max_length" if max_length else False
    encoded_tokens = tokenizer(
        text,
        add_special_tokens=False,
        padding=padding,
        max_length=max_length,
        truncation=True,
    )["input_ids"]
    return df.with_columns(pl.Series(new_column, encoded_tokens)), new_column


def create_sort_based_prediction_score(
    df: pl.DataFrame,
    column: str,
    desc: bool,
    article_id_col: str = DEFAULT_ARTICLE_ID_COL,
    prediction_score_col: str = "prediction_score",
) -> pl.DataFrame:
    """
    Rank articles by `column` and score each one as 1 / rank.

    The frame is reduced to the article id and the sort column, sorted, and
    given a 1-based rank; the score of the first row is 1.0, the second 0.5,
    and so on.

    Args:
        df (pl.DataFrame): The input DataFrame to process.
        column (str): The column to sort by.
        desc (bool): If True, sort in descending order; otherwise ascending.
        article_id_col (str, optional): The article ID column. Defaults to "article_id".
        prediction_score_col (str, optional): Name of the score column. Defaults to "prediction_score".

    Returns:
        pl.DataFrame: A sorted frame with `article_id_col`, `column` and the
            new prediction score column.

    Examples:
    >>> import polars as pl
    >>> df = pl.DataFrame({
            "article_id": [1, 2, 3, 4, 5],
            "views": [100, 150, 200, 50, 300],
        })
    >>> create_sort_based_prediction_score(df, "views", True)
        shape: (5, 3)
        ┌────────────┬───────┬──────────────────┐
        │ article_id ┆ views ┆ prediction_score │
        │ ---        ┆ ---   ┆ ---              │
        │ i64        ┆ i64   ┆ f64              │
        ╞════════════╪═══════╪══════════════════╡
        │ 5          ┆ 300   ┆ 1.0              │
        │ 3          ┆ 200   ┆ 0.5              │
        │ 2          ┆ 150   ┆ 0.333333         │
        │ 1          ┆ 100   ┆ 0.25             │
        │ 4          ┆ 50    ┆ 0.2              │
        └────────────┴───────┴──────────────────┘
    """
    rank_col = "index"
    ranked = (
        df.select(article_id_col, column)
        .sort(by=column, descending=desc)
        .with_row_index(name=rank_col, offset=1)
    )
    reciprocal_rank = (1 / pl.col(rank_col)).alias(prediction_score_col)
    return ranked.with_columns(reciprocal_rank).drop(rank_col)

In [None]:

def create_binary_labels_column(
    df: pl.DataFrame,
    shuffle: bool = False,
    seed: int = None,
    clicked_col: str = DEFAULT_CLICKED_ARTICLES_COL,
    inview_col: str = DEFAULT_INVIEW_ARTICLES_COL,
    label_col: str = DEFAULT_LABELS_COL,
) -> pl.DataFrame:
    """Adds a list column of binary labels aligned with the inview articles.

    For every article id in `inview_col` the label states whether that id is
    also present in the row's `clicked_col` list.

    Args:
        df (pl.DataFrame): The input DataFrame.
        shuffle (bool, optional): Shuffle each inview list (via
            `shuffle_list_column`) before labelling. Defaults to False.
        seed (int, optional): Seed forwarded to the shuffle. Defaults to None.
        clicked_col (str, optional): Column with the clicked article ids.
        inview_col (str, optional): Column with the inview article ids.
        label_col (str, optional): Name of the label column to create.

    Returns:
        pl.DataFrame: The input columns followed by `label_col`.

    Examples:
    >>> df = pl.DataFrame(
            {
                DEFAULT_INVIEW_ARTICLES_COL: [[1, 2, 3], [4, 5, 6], [7, 8]],
                DEFAULT_CLICKED_ARTICLES_COL: [[2, 3, 4], [3, 5], None],
            }
        )
    >>> create_binary_labels_column(df)
        shape: (3, 3)
        ┌────────────────────┬─────────────────────┬───────────┐
        │ article_ids_inview ┆ article_ids_clicked ┆ labels    │
        │ ---                ┆ ---                 ┆ ---       │
        │ list[i64]          ┆ list[i64]           ┆ list[i8]  │
        ╞════════════════════╪═════════════════════╪═══════════╡
        │ [1, 2, 3]          ┆ [2, 3, 4]           ┆ [0, 1, 1] │
        │ [4, 5, 6]          ┆ [3, 5]              ┆ [0, 1, 0] │
        │ [7, 8]             ┆ null                ┆ [0, 0]    │
        └────────────────────┴─────────────────────┴───────────┘
    >>> create_binary_labels_column(df.lazy(), shuffle=True, seed=123).collect()
        shape: (3, 3)
        ┌────────────────────┬─────────────────────┬───────────┐
        │ article_ids_inview ┆ article_ids_clicked ┆ labels    │
        │ ---                ┆ ---                 ┆ ---       │
        │ list[i64]          ┆ list[i64]           ┆ list[i8]  │
        ╞════════════════════╪═════════════════════╪═══════════╡
        │ [3, 1, 2]          ┆ [2, 3, 4]           ┆ [1, 0, 1] │
        │ [5, 6, 4]          ┆ [3, 5]              ┆ [1, 0, 0] │
        │ [7, 8]             ┆ null                ┆ [0, 0]    │
        └────────────────────┴─────────────────────┴───────────┘
    Shuffling changes the order of the labels but not how many are positive:
    >>> assert create_binary_labels_column(df, shuffle=True)[DEFAULT_LABELS_COL].list.sum().to_list() == [
            2,
            1,
            0,
        ]
    """
    _check_columns_in_df(df, [inview_col, clicked_col])
    original_columns = df.columns
    row_id = generate_unique_name(original_columns, "_groupby_id")

    df = df.with_row_index(row_id)
    if shuffle:
        df = shuffle_list_column(df, column=inview_col, seed=seed)

    # One row per inview article, flagged 1/0 by membership in the clicked list.
    is_clicked = (
        pl.col(inview_col).is_in(pl.col(clicked_col)).cast(pl.Int8).alias(label_col)
    )
    exploded = df.explode(inview_col).with_columns(is_clicked)
    # Fold the flags back into one list per original row.
    labels_per_row = exploded.group_by(row_id).agg(label_col)

    labelled = df.join(labels_per_row, on=row_id, how="left").drop(row_id)
    return labelled.select(original_columns + [label_col])


def create_user_id_to_int_mapping(
    df: pl.DataFrame, user_col: str = DEFAULT_USER_COL, value_str: str = "id"
):
    """Map every distinct user id to a running integer index.

    Args:
        df (pl.DataFrame): Frame containing `user_col`.
        user_col (str, optional): Column holding the user ids.
            Defaults to DEFAULT_USER_COL.
        value_str (str, optional): Name given to the generated index column.
            Defaults to "id".

    Returns:
        Whatever `create_lookup_dict` returns, keyed by user id with the row
        index as value.

    NOTE(review): `unique()` is called without requesting a stable order, so the
    index assigned to a user may differ between runs - confirm if that matters.
    """
    distinct_users = df.select(pl.col(user_col).unique())
    indexed_users = distinct_users.with_row_index(value_str)
    return create_lookup_dict(indexed_users, key=user_col, value=value_str)


def filter_minimum_negative_samples(
    df,
    n: int,
    inview_col: str = DEFAULT_INVIEW_ARTICLES_COL,
    clicked_col: str = DEFAULT_CLICKED_ARTICLES_COL,
) -> pl.DataFrame:
    """
    Keep the rows that offer at least `n` negative samples.

    The number of negatives is taken as len(inview) - len(clicked). When `n`
    is None or not positive, the frame is returned unchanged.

    Args:
        df: Frame with the inview and clicked list columns.
        n (int): Minimum number of negatives a row must have.
        inview_col (str, optional): Column with the inview article ids.
        clicked_col (str, optional): Column with the clicked article ids.

    Returns:
        pl.DataFrame: The rows of `df` satisfying the threshold.

    Examples:
    >>> df = pl.DataFrame(
            {
                DEFAULT_INVIEW_ARTICLES_COL: [[1, 2, 3], [1], [1, 2, 3]],
                DEFAULT_CLICKED_ARTICLES_COL: [[1], [1], [1, 2]],
            }
        )
    >>> filter_minimum_negative_samples(df, n=1)
        shape: (2, 2)
        ┌────────────────────┬─────────────────────┐
        │ article_ids_inview ┆ article_ids_clicked │
        │ ---                ┆ ---                 │
        │ list[i64]          ┆ list[i64]           │
        ╞════════════════════╪═════════════════════╡
        │ [1, 2, 3]          ┆ [1]                 │
        │ [1, 2, 3]          ┆ [1, 2]              │
        └────────────────────┴─────────────────────┘
    >>> filter_minimum_negative_samples(df, n=2)
        shape: (1, 2)
        ┌────────────────────┬─────────────────────┐
        │ article_ids_inview ┆ article_ids_clicked │
        │ ---                ┆ ---                 │
        │ list[i64]          ┆ list[i64]           │
        ╞════════════════════╪═════════════════════╡
        │ [1, 2, 3]          ┆ [1]                 │
        └────────────────────┴─────────────────────┘
    """
    if n is not None and n > 0:
        # The difference of the list lengths stands in for the negative count.
        n_negatives = pl.col(inview_col).list.len() - pl.col(clicked_col).list.len()
        return df.filter(n_negatives >= n)
    return df


def ebnerd_from_path(
    path: Path,
    history_size: int = 30,
    padding: int = 0,
    user_col: str = DEFAULT_USER_COL,
    history_aids_col: str = DEFAULT_HISTORY_ARTICLE_ID_COL,
) -> pl.DataFrame:
    """
    Load the behaviors of one dataset split and attach each user's history.

    Reads "history.parquet" and "behaviors.parquet" from the directory `path`,
    so the same function serves every split (e.g. train and validation
    folders). The history is reduced to the user id and the history article
    ids, passed through `truncate_history`, and joined onto the behaviors with
    `slice_join_dataframes` (left join on the user id).

    Args:
        path (Path): Directory containing history.parquet and behaviors.parquet.
        history_size (int, optional): Forwarded to `truncate_history` as
            `history_size`. Defaults to 30.
        padding (int, optional): Forwarded to `truncate_history` as
            `padding_value`. Defaults to 0.
        user_col (str, optional): User id column used as the join key.
        history_aids_col (str, optional): History column holding article ids.

    Returns:
        pl.DataFrame: The behaviors frame with the history column joined on.
    """
    # Resolve both files from the caller-supplied directory rather than a
    # machine-specific absolute location; `Path(...)` also accepts plain strings.
    path = Path(path)
    df_history = (
        pl.scan_parquet(path / "history.parquet")
        .select(user_col, history_aids_col)
        .pipe(
            truncate_history,
            column=history_aids_col,
            history_size=history_size,
            padding_value=padding,
            enable_warning=False,
        )
    )
    df_behaviors = (
        pl.scan_parquet(path / "behaviors.parquet")
        .collect()
        .pipe(
            slice_join_dataframes,
            df2=df_history.collect(),
            on=user_col,
            how="left",
        )
    )
    return df_behaviors


def filter_read_times(df, n: int, column: str) -> pl.DataFrame:
    """
    Apply a lower cutoff of `n` to a read-time column.

    Use this to set the cutoff for 'read_time' and 'next_read_time'.
    The frame is returned unchanged when `column` is not in `df`, or when `n`
    is None or not positive.

    Args:
        df: Frame to filter.
        n (int): Minimum value a row must have in `column` to be kept.
        column (str): Name of the column to test.

    Returns:
        pl.DataFrame: Rows with `column >= n`, or `df` itself if no cutoff applies.
    """
    if column in df and n is not None and n > 0:
        return df.filter(pl.col(column) >= n)
    return df


def unique_article_ids_in_behaviors(
    df: pl.DataFrame,
    col: str = "ids",
    item_col: str = DEFAULT_ARTICLE_ID_COL,
    inview_col: str = DEFAULT_INVIEW_ARTICLES_COL,
    clicked_col: str = DEFAULT_CLICKED_ARTICLES_COL,
) -> pl.Series:
    """
    Collect every distinct article id that occurs in a behaviors frame.

    Ids are gathered from the article id column and from the exploded inview
    and clicked list columns; nulls are discarded.

    Args:
        df (pl.DataFrame): Behaviors frame.
        col (str, optional): Name of the returned series. Defaults to "ids".
        item_col (str, optional): Column with a single article id per row.
        inview_col (str, optional): List column with inview article ids.
        clicked_col (str, optional): List column with clicked article ids.

    Returns:
        pl.Series: The distinct, non-null article ids (order not sorted).

    Examples:
        >>> df = pl.DataFrame({
                DEFAULT_ARTICLE_ID_COL: [1, 2, 3, 4],
                DEFAULT_INVIEW_ARTICLES_COL: [[2, 3], [1, 4], [4], [1, 2, 3]],
                DEFAULT_CLICKED_ARTICLES_COL: [[], [2], [3, 4], [1]],
            })
        >>> unique_article_ids_in_behaviors(df).sort()
            [
                1
                2
                3
                4
            ]
    """
    lazy_df = df.lazy()
    # One expression per source of ids; list columns are flattened first.
    id_sources = [
        pl.col(item_col).unique(),
        pl.col(inview_col).explode().unique(),
        pl.col(clicked_col).explode().unique(),
    ]
    stacked = pl.concat([lazy_df.select(expr.alias(col)) for expr in id_sources])
    distinct_ids = stacked.drop_nulls().unique().collect()
    return distinct_ids.to_series()


def add_known_user_column(
    df: pl.DataFrame,
    known_users: Iterable[int],
    user_col: str = DEFAULT_USER_COL,
    known_user_col: str = DEFAULT_KNOWN_USER_COL,
) -> pl.DataFrame:
    """
    Flag each row by whether its user id belongs to `known_users`.

    Args:
        df: A Polars DataFrame object.
        known_users: An iterable of integers representing the known user IDs.
        user_col: Column holding the user ids. Defaults to DEFAULT_USER_COL.
        known_user_col: Name of the boolean column to add.
            Defaults to DEFAULT_KNOWN_USER_COL.

    Returns:
        A new Polars DataFrame with the additional boolean column.

    Examples:
        >>> df = pl.DataFrame({'user_id': [1, 2, 3, 4]})
        >>> add_known_user_column(df, [2, 4])
            shape: (4, 2)
            ┌─────────┬───────────────┐
            │ user_id ┆ is_known_user │
            │ ---     ┆ ---           │
            │ i64     ┆ bool          │
            ╞═════════╪═══════════════╡
            │ 1       ┆ false         │
            │ 2       ┆ true          │
            │ 3       ┆ false         │
            │ 4       ┆ true          │
            └─────────┴───────────────┘
    """
    is_known = pl.col(user_col).is_in(known_users)
    return df.with_columns(is_known.alias(known_user_col))


def sample_article_ids(
    df: pl.DataFrame,
    n: int,
    with_replacement: bool = False,
    seed: int = None,
    inview_col: str = DEFAULT_INVIEW_ARTICLES_COL,
) -> pl.DataFrame:
    """
    Randomly sample article IDs from each row of a DataFrame with or without replacement

    Args:
        df: A polars DataFrame containing the column of article IDs to be sampled.
        n: The number of article IDs to sample from each list.
        with_replacement: A boolean indicating whether to sample with replacement.
            Default is False.
        seed: An optional seed to use for the random number generator.
        inview_col: Name of the list column to sample from.
            Default is DEFAULT_INVIEW_ARTICLES_COL.

    Returns:
        A new polars DataFrame with the same columns as `df`, but with the article
        IDs in the specified column replaced by a list of `n` sampled article IDs.

    Examples (sampled values are illustrative):
    >>> df = pl.DataFrame(
            {
                "clicked": [
                    [1],
                    [4, 5],
                    [7, 8, 9],
                ],
                DEFAULT_INVIEW_ARTICLES_COL: [
                    ["A", "B", "C"],
                    ["D", "E", "F"],
                    ["G", "H", "I"],
                ],
                "col" : [
                    ["h"],
                    ["e"],
                    ["y"]
                ]
            }
        )
    >>> sample_article_ids(df, n=2, seed=42)
        shape: (3, 3)
        ┌───────────┬────────────────────┬───────────┐
        │ clicked   ┆ article_ids_inview ┆ col       │
        │ ---       ┆ ---                ┆ ---       │
        │ list[i64] ┆ list[str]          ┆ list[str] │
        ╞═══════════╪════════════════════╪═══════════╡
        │ [1]       ┆ ["A", "C"]         ┆ ["h"]     │
        │ [4, 5]    ┆ ["D", "F"]         ┆ ["e"]     │
        │ [7, 8, 9] ┆ ["G", "I"]         ┆ ["y"]     │
        └───────────┴────────────────────┴───────────┘

    Sampling more ids than a list holds requires replacement:
    >>> sample_article_ids(df.lazy(), n=4, with_replacement=True, seed=42).collect()
        shape: (3, 3)
        ┌───────────┬────────────────────┬───────────┐
        │ clicked   ┆ article_ids_inview ┆ col       │
        │ ---       ┆ ---                ┆ ---       │
        │ list[i64] ┆ list[str]          ┆ list[str] │
        ╞═══════════╪════════════════════╪═══════════╡
        │ [1]       ┆ ["A", "A", … "C"]  ┆ ["h"]     │
        │ [4, 5]    ┆ ["D", "D", … "F"]  ┆ ["e"]     │
        │ [7, 8, 9] ┆ ["G", "G", … "I"]  ┆ ["y"]     │
        └───────────┴────────────────────┴───────────┘
    """
    _check_columns_in_df(df, [inview_col])
    _COLUMNS = df.columns
    GROUPBY_ID = generate_unique_name(_COLUMNS, "_groupby_id")
    # `with_row_index` replaces the deprecated `with_row_count`; the rest of
    # this file already uses it.
    df = df.with_row_index(name=GROUPBY_ID)

    # Flatten the lists, then draw `n` ids within each original row.
    df_ = (
        df.explode(inview_col)
        .group_by(GROUPBY_ID)
        .agg(
            pl.col(inview_col).sample(n=n, with_replacement=with_replacement, seed=seed)
        )
    )
    # The left join restores row order; the select restores column order and
    # drops the helper id.
    return (
        df.drop(inview_col)
        .join(df_, on=GROUPBY_ID, how="left")
        .select(_COLUMNS)
    )


def remove_positives_from_inview(
    df: pl.DataFrame,
    inview_col: str = DEFAULT_INVIEW_ARTICLES_COL,
    clicked_col: str = DEFAULT_CLICKED_ARTICLES_COL,
):
    """Removes all positive article IDs from the inview column, so that only negative
    article IDs (i.e., those that appear in the inview articles column but not in the
    clicked articles column) are retained.

    A row with no clicks recorded (null in the clicked column) keeps its whole
    inview list; a row whose inview list is null stays null.

    Args:
        df (pl.DataFrame): A DataFrame with columns containing inview articles and clicked articles.
        inview_col (str, optional): Column with the inview article ids.
        clicked_col (str, optional): Column with the clicked article ids.

    Returns:
        pl.DataFrame: A new DataFrame with only negative article IDs retained
            in `inview_col`.

    Examples:
    >>> df = pl.DataFrame(
            {
                "user_id": [1, 1, 2],
                DEFAULT_CLICKED_ARTICLES_COL: [
                    [1, 2],
                    [1],
                    [3],
                ],
                DEFAULT_INVIEW_ARTICLES_COL: [
                    [1, 2, 3],
                    [1, 2, 3],
                    [1, 2, 3],
                ],
            }
        )
    >>> remove_positives_from_inview(df)
        shape: (3, 3)
        ┌─────────┬─────────────────────┬────────────────────┐
        │ user_id ┆ article_ids_clicked ┆ article_ids_inview │
        │ ---     ┆ ---                 ┆ ---                │
        │ i64     ┆ list[i64]           ┆ list[i64]          │
        ╞═════════╪═════════════════════╪════════════════════╡
        │ 1       ┆ [1, 2]              ┆ [3]                │
        │ 1       ┆ [1]                 ┆ [2, 3]             │
        │ 2       ┆ [3]                 ┆ [1, 2]             │
        └─────────┴─────────────────────┴────────────────────┘
    """
    _check_columns_in_df(df, [inview_col, clicked_col])
    negative_article_ids = []
    for inview, clicked in zip(df[inview_col].to_list(), df[clicked_col].to_list()):
        if inview is None:
            # Nothing to filter; keep the null instead of raising.
            negative_article_ids.append(None)
            continue
        # A null clicked cell means there are no positives to remove.
        positives = clicked if clicked is not None else ()
        negative_article_ids.append(
            [article_id for article_id in inview if article_id not in positives]
        )
    return df.with_columns(pl.Series(inview_col, negative_article_ids))


def sampling_strategy_wu2019(
    df: pl.DataFrame,
    npratio: int,
    shuffle: bool = False,
    with_replacement: bool = True,
    seed: int = None,
    inview_col: str = DEFAULT_INVIEW_ARTICLES_COL,
    clicked_col: str = DEFAULT_CLICKED_ARTICLES_COL,
) -> pl.DataFrame:
    """
    Samples negative articles from the inview article pool for a given negative-position-ratio (npratio).
    The npratio (negative article per positive article) is defined as the number of negative article samples
    to draw for each positive article sample.

    This function follows the sampling strategy introduced in the paper "NPA: Neural News Recommendation with
    Personalized Attention" by Wu et al. (KDD '19).

    This is done according to the following steps:
    1. Remove the positive click-article id pairs from the DataFrame.
    2. Explode the DataFrame based on the clicked articles column.
    3. Downsample the inview negative article ids for each exploded row using the specified npratio, either
        with or without replacement.
    4. Concatenate the clicked articles back to the inview articles as lists.
    5. Convert clicked articles column to type List(Int)

    Without shuffling, the positive article id is always the LAST element of each inview list
    (it is appended in step 4). Pass ``shuffle=True`` to randomize its position.

    References:
        Chuhan Wu, Fangzhao Wu, Mingxiao An, Jianqiang Huang, Yongfeng Huang, and Xing Xie. 2019.
        Npa: Neural news recommendation with personalized attention. In KDD, pages 2576-2584. ACM.

    Args:
        df (pl.DataFrame): The input DataFrame containing click-article id pairs.
        npratio (int): The ratio of negative in-view article ids to positive click-article ids.
            Must be non-negative.
        shuffle (bool, optional): Whether to shuffle the order of the in-view article ids in each list.
            Default is False.
        with_replacement (bool, optional): Whether to sample the inview article ids with or without replacement.
            Default is True.
        seed (int, optional): Random seed for reproducibility. It is forwarded both to the negative
            sampling step and (when ``shuffle=True``) to the shuffling step. Default is None.
        inview_col (str, optional): inview column name. Default is DEFAULT_INVIEW_ARTICLES_COL,
        clicked_col (str, optional): clicked column name. Default is DEFAULT_CLICKED_ARTICLES_COL,

    Returns:
        pl.DataFrame: A new DataFrame with downsampled in-view article ids for each click according to the specified npratio.
        The DataFrame has the same columns as the input DataFrame; there is one output row per clicked article,
        and the clicked column is a single-element list.

    Raises:
        ValueError: If npratio is less than 0.
        Missing columns are reported by the column check in `remove_positives_from_inview`
        (presumably a ValueError -- the check helper is defined elsewhere in the file).

    Examples:
    >>> from ebrec.utils._constants import DEFAULT_CLICKED_ARTICLES_COL, DEFAULT_INVIEW_ARTICLES_COL
    >>> import polars as pl
    >>> df = pl.DataFrame(
            {
                "impression_id": [0, 1, 2, 3],
                "user_id": [1, 1, 2, 3],
                DEFAULT_INVIEW_ARTICLES_COL: [[1, 2, 3], [1, 2, 3, 4], [1, 2, 3], [1]],
                DEFAULT_CLICKED_ARTICLES_COL: [[1, 2], [1, 3], [1], [1]],
            }
        )
    >>> df
        shape: (4, 4)
        ┌───────────────┬─────────┬────────────────────┬─────────────────────┐
        │ impression_id ┆ user_id ┆ article_ids_inview ┆ article_ids_clicked │
        │ ---           ┆ ---     ┆ ---                ┆ ---                 │
        │ i64           ┆ i64     ┆ list[i64]          ┆ list[i64]           │
        ╞═══════════════╪═════════╪════════════════════╪═════════════════════╡
        │ 0             ┆ 1       ┆ [1, 2, 3]          ┆ [1, 2]              │
        │ 1             ┆ 1       ┆ [1, 2, … 4]        ┆ [1, 3]              │
        │ 2             ┆ 2       ┆ [1, 2, 3]          ┆ [1]                 │
        │ 3             ┆ 3       ┆ [1]                ┆ [1]                 │
        └───────────────┴─────────┴────────────────────┴─────────────────────┘
    >>> sampling_strategy_wu2019(df, npratio=1, shuffle=False, with_replacement=True, seed=123)
        shape: (6, 4)
        ┌───────────────┬─────────┬────────────────────┬─────────────────────┐
        │ impression_id ┆ user_id ┆ article_ids_inview ┆ article_ids_clicked │
        │ ---           ┆ ---     ┆ ---                ┆ ---                 │
        │ i64           ┆ i64     ┆ list[i64]          ┆ list[i64]           │
        ╞═══════════════╪═════════╪════════════════════╪═════════════════════╡
        │ 0             ┆ 1       ┆ [3, 1]             ┆ [1]                 │
        │ 0             ┆ 1       ┆ [3, 2]             ┆ [2]                 │
        │ 1             ┆ 1       ┆ [4, 1]             ┆ [1]                 │
        │ 1             ┆ 1       ┆ [4, 3]             ┆ [3]                 │
        │ 2             ┆ 2       ┆ [3, 1]             ┆ [1]                 │
        │ 3             ┆ 3       ┆ [null, 1]          ┆ [1]                 │
        └───────────────┴─────────┴────────────────────┴─────────────────────┘
    >>> sampling_strategy_wu2019(df, npratio=1, shuffle=True, with_replacement=True, seed=123)
        shape: (6, 4)
        ┌───────────────┬─────────┬────────────────────┬─────────────────────┐
        │ impression_id ┆ user_id ┆ article_ids_inview ┆ article_ids_clicked │
        │ ---           ┆ ---     ┆ ---                ┆ ---                 │
        │ i64           ┆ i64     ┆ list[i64]          ┆ list[i64]           │
        ╞═══════════════╪═════════╪════════════════════╪═════════════════════╡
        │ 0             ┆ 1       ┆ [3, 1]             ┆ [1]                 │
        │ 0             ┆ 1       ┆ [2, 3]             ┆ [2]                 │
        │ 1             ┆ 1       ┆ [4, 1]             ┆ [1]                 │
        │ 1             ┆ 1       ┆ [4, 3]             ┆ [3]                 │
        │ 2             ┆ 2       ┆ [3, 1]             ┆ [1]                 │
        │ 3             ┆ 3       ┆ [null, 1]          ┆ [1]                 │
        └───────────────┴─────────┴────────────────────┴─────────────────────┘
    >>> sampling_strategy_wu2019(df, npratio=2, shuffle=False, with_replacement=True, seed=123)
        shape: (6, 4)
        ┌───────────────┬─────────┬────────────────────┬─────────────────────┐
        │ impression_id ┆ user_id ┆ article_ids_inview ┆ article_ids_clicked │
        │ ---           ┆ ---     ┆ ---                ┆ ---                 │
        │ i64           ┆ i64     ┆ list[i64]          ┆ list[i64]           │
        ╞═══════════════╪═════════╪════════════════════╪═════════════════════╡
        │ 0             ┆ 1       ┆ [3, 3, 1]          ┆ [1]                 │
        │ 0             ┆ 1       ┆ [3, 3, 2]          ┆ [2]                 │
        │ 1             ┆ 1       ┆ [4, 2, 1]          ┆ [1]                 │
        │ 1             ┆ 1       ┆ [4, 2, 3]          ┆ [3]                 │
        │ 2             ┆ 2       ┆ [3, 2, 1]          ┆ [1]                 │
        │ 3             ┆ 3       ┆ [null, null, 1]    ┆ [1]                 │
        └───────────────┴─────────┴────────────────────┴─────────────────────┘
    # If we use without replacement, we need to ensure there are enough negative samples:
    >>> sampling_strategy_wu2019(df, npratio=2, shuffle=False, with_replacement=False, seed=123)
        polars.exceptions.ShapeError: cannot take a larger sample than the total population when `with_replacement=false`
    ## Either you'll have to remove the samples or split the dataframe yourself and only upsample the samples that don't have enough
    >>> min_neg = 2
    >>> sampling_strategy_wu2019(
            df.filter(pl.col(DEFAULT_INVIEW_ARTICLES_COL).list.len() > (min_neg + 1)),
            npratio=min_neg,
            shuffle=False,
            with_replacement=False,
            seed=123,
        )
        shape: (2, 4)
        ┌───────────────┬─────────┬────────────────────┬─────────────────────┐
        │ impression_id ┆ user_id ┆ article_ids_inview ┆ article_ids_clicked │
        │ ---           ┆ ---     ┆ ---                ┆ ---                 │
        │ i64           ┆ i64     ┆ list[i64]          ┆ list[i64]           │
        ╞═══════════════╪═════════╪════════════════════╪═════════════════════╡
        │ 1             ┆ 1       ┆ [2, 4, 1]          ┆ [1]                 │
        │ 1             ┆ 1       ┆ [2, 4, 3]          ┆ [3]                 │
        └───────────────┴─────────┴────────────────────┴─────────────────────┘
    """
    # Fail early with a clear message instead of letting a negative sample size
    # surface as an opaque error from the sampling step.
    if npratio < 0:
        raise ValueError(f"npratio must be non-negative, got {npratio}")
    df = (
        # Step 1: Remove the positive 'article_id' from inview articles
        df.pipe(
            remove_positives_from_inview, inview_col=inview_col, clicked_col=clicked_col
        )
        # Step 2: Explode the DataFrame based on the clicked articles column
        .explode(clicked_col)
        # Step 3: Downsample the inview negative 'article_id' according to npratio (negative 'article_id' per positive 'article_id')
        .pipe(
            sample_article_ids,
            n=npratio,
            with_replacement=with_replacement,
            seed=seed,
            inview_col=inview_col,
        )
        # Step 4: Concatenate the clicked articles back to the inview articles as lists
        # (the output keeps the name of the first input, i.e. `inview_col`)
        .with_columns(pl.concat_list([inview_col, clicked_col]))
        # Step 5: Convert clicked articles column to type List(Int):
        # the positive id was appended last in step 4, so the 1-element tail is the click.
        .with_columns(pl.col(inview_col).list.tail(1).alias(clicked_col))
    )
    if shuffle:
        df = shuffle_list_column(df, inview_col, seed)
    return df


def truncate_history(
    df: pl.DataFrame,
    column: str,
    history_size: int,
    padding_value: Any = None,
    enable_warning: bool = True,
) -> pl.DataFrame:
    """Keep only the last ``history_size`` items of every list in ``column``.

    The *tail* of each list is retained, so the lists are expected to be in
    ascending (oldest -> newest) order: the most recent items survive.

    Args:
        df (pl.DataFrame): The input DataFrame.
        column (str): Name of the list column to truncate.
        history_size (int): Maximum number of items to keep per list.
        padding_value (Any): When given, lists shorter than ``history_size`` are
            left-padded with this value so every list ends up with the same
            length. Default is None (no padding).
        enable_warning (bool): Emit a warning reminding the caller that the
            history is expected in ascending order. Default is True.

    Returns:
        pl.DataFrame: A new DataFrame with the specified column truncated.

    Examples:
    >>> df = pl.DataFrame(
            {"id": [1, 2, 3], "history": [["a", "b", "c"], ["d", "e", "f", "g"], ["h", "i"]]}
        )
    >>> df
        shape: (3, 2)
        ┌─────┬───────────────────┐
        │ id  ┆ history           │
        │ --- ┆ ---               │
        │ i64 ┆ list[str]         │
        ╞═════╪═══════════════════╡
        │ 1   ┆ ["a", "b", "c"]   │
        │ 2   ┆ ["d", "e", … "g"] │
        │ 3   ┆ ["h", "i"]        │
        └─────┴───────────────────┘
    >>> truncate_history(df, 'history', 3)
        shape: (3, 2)
        ┌─────┬─────────────────┐
        │ id  ┆ history         │
        │ --- ┆ ---             │
        │ i64 ┆ list[str]       │
        ╞═════╪═════════════════╡
        │ 1   ┆ ["a", "b", "c"] │
        │ 2   ┆ ["e", "f", "g"] │
        │ 3   ┆ ["h", "i"]      │
        └─────┴─────────────────┘
    >>> truncate_history(df.lazy(), 'history', 3, '-').collect()
        shape: (3, 2)
        ┌─────┬─────────────────┐
        │ id  ┆ history         │
        │ --- ┆ ---             │
        │ i64 ┆ list[str]       │
        ╞═════╪═════════════════╡
        │ 1   ┆ ["a", "b", "c"] │
        │ 2   ┆ ["e", "f", "g"] │
        │ 3   ┆ ["-", "h", "i"] │
        └─────┴─────────────────┘
    """
    if enable_warning:
        function_name = inspect.currentframe().f_code.co_name
        warnings.warn(f"{function_name}: The history IDs expeced in ascending order")

    history = pl.col(column)
    if padding_value is not None:
        # Left-pad: reverse, append `history_size` copies of the padding value,
        # reverse back. The padding lands in front of the real items, and the
        # tail() below trims the surplus.
        left_padded = (
            history.list.reverse()
            .list.eval(pl.element().extend_constant(padding_value, n=history_size))
            .list.reverse()
        )
        df = df.with_columns(left_padded)

    truncated = history.list.tail(history_size)
    return df.with_columns(truncated)


def create_dynamic_history(
    df: pl.DataFrame,
    history_size: int,
    history_col: str = "history_dynamic",
    user_col: str = DEFAULT_USER_COL,
    item_col: str = DEFAULT_ARTICLE_ID_COL,
    timestamp_col: str = DEFAULT_IMPRESSION_TIMESTAMP_COL,
) -> pl.DataFrame:
    """Attach to every row the items the same user interacted with just before it.

    The frame is sorted by (user, timestamp); for each row a rolling window over the
    preceding ``history_size`` rows of the same user (current row excluded) collects
    the values of ``item_col``.

    Be aware that the rolling window also collects Null items, which can only be
    removed afterwards (unlike 'create_fixed_history', which filters Nulls first).
    As a result, 'history_size' may be N while the resulting list is only
    (N - n_nulls) long.

    Args:
        df (pl.DataFrame): Frame containing ``user_col``, ``item_col`` and ``timestamp_col``.
        history_size (int): Number of preceding rows (per user) the window spans.
        history_col (str): Name of the list column that is added. Default is "history_dynamic".
        user_col (str): User id column. Default is DEFAULT_USER_COL.
        item_col (str): Item (article) id column. Default is DEFAULT_ARTICLE_ID_COL.
        timestamp_col (str): Timestamp column used for ordering. Default is DEFAULT_IMPRESSION_TIMESTAMP_COL.

    Returns:
        pl.DataFrame: The input frame, sorted by (user, timestamp), plus one new column
        ``history_col``: a list of at most 'history_size' item ids from the user's
        preceding rows, ordered from oldest to most recent. Nulls are dropped and the
        list is NOT padded, so it is empty for a user's first row.

    Raises:
        ValueError: If a required column is missing (raised by the column check helper,
            which is defined elsewhere in the file -- exception type per its documentation).

    Examples:
    >>> from ebrec.utils._constants import (
            DEFAULT_IMPRESSION_TIMESTAMP_COL,
            DEFAULT_ARTICLE_ID_COL,
            DEFAULT_USER_COL,
        )
    >>> df = pl.DataFrame(
            {
                DEFAULT_USER_COL: [0, 0, 0, 1, 1, 1, 0, 2],
                DEFAULT_ARTICLE_ID_COL: [
                    9604210,
                    9634540,
                    9640420,
                    9647983,
                    9647984,
                    9647981,
                    None,
                    None,
                ],
                DEFAULT_IMPRESSION_TIMESTAMP_COL: [
                    datetime.datetime(2023, 2, 18),
                    datetime.datetime(2023, 2, 18),
                    datetime.datetime(2023, 2, 25),
                    datetime.datetime(2023, 2, 22),
                    datetime.datetime(2023, 2, 21),
                    datetime.datetime(2023, 2, 23),
                    datetime.datetime(2023, 2, 19),
                    datetime.datetime(2023, 2, 26),
                ],
            }
        )
    >>> create_dynamic_history(df, 3)
        shape: (8, 4)
        ┌─────────┬────────────┬─────────────────────┬────────────────────┐
        │ user_id ┆ article_id ┆ impression_time     ┆ history_dynamic    │
        │ ---     ┆ ---        ┆ ---                 ┆ ---                │
        │ i64     ┆ i64        ┆ datetime[μs]        ┆ list[i64]          │
        ╞═════════╪════════════╪═════════════════════╪════════════════════╡
        │ 0       ┆ 9604210    ┆ 2023-02-18 00:00:00 ┆ []                 │
        │ 0       ┆ 9634540    ┆ 2023-02-18 00:00:00 ┆ [9604210]          │
        │ 0       ┆ null       ┆ 2023-02-19 00:00:00 ┆ [9604210, 9634540] │
        │ 0       ┆ 9640420    ┆ 2023-02-25 00:00:00 ┆ [9604210, 9634540] │
        │ 1       ┆ 9647984    ┆ 2023-02-21 00:00:00 ┆ []                 │
        │ 1       ┆ 9647983    ┆ 2023-02-22 00:00:00 ┆ [9647984]          │
        │ 1       ┆ 9647981    ┆ 2023-02-23 00:00:00 ┆ [9647984, 9647983] │
        │ 2       ┆ null       ┆ 2023-02-26 00:00:00 ┆ []                 │
        └─────────┴────────────┴─────────────────────┴────────────────────┘
    """
    _check_columns_in_df(df, [user_col, timestamp_col, item_col])
    # Temporary integer row index; the name is generated so it cannot clash with a real column.
    row_index_col = generate_unique_name(df.columns, "_groupby_id")
    df = df.sort([user_col, timestamp_col])

    # Rolling window over the integer row index, evaluated per user.
    # closed="left" excludes the current row, so only *previous* items are collected.
    rolling_history = (
        df.with_row_index(name=row_index_col)
        .with_columns(pl.col([row_index_col]).cast(pl.Int64))
        .rolling(
            index_column=row_index_col,
            period=f"{history_size}i",
            closed="left",
            by=[user_col],
        )
        .agg(pl.col(item_col).alias(history_col))
    )

    # Stack the rolling result (row index, user, history) onto the sorted frame,
    # strip the Nulls the window picked up, and discard the helper index.
    with_history = df.with_columns(rolling_history)
    without_nulls = drop_nulls_from_list(with_history, column=history_col)
    return without_nulls.drop(row_index_col)


def create_fixed_history(
    df: pl.DataFrame,
    dt_cutoff: datetime,
    history_size: int = None,
    history_col: str = "history_fixed",
    user_col: str = DEFAULT_USER_COL,
    item_col: str = DEFAULT_ARTICLE_ID_COL,
    timestamp_col: str = DEFAULT_IMPRESSION_TIMESTAMP_COL,
) -> pl.DataFrame:
    """
    Build one fixed history per user and attach it to every row of that user.

    The history consists of the user's non-null items with a timestamp strictly
    before ``dt_cutoff``, in ascending timestamp order. The same list is repeated
    on all rows of the user, including rows at or after the cutoff.

    Args:
        df (pl.DataFrame): Frame containing ``user_col``, ``timestamp_col`` and ``item_col``.
        dt_cutoff (datetime): Cutoff time. Only rows strictly before this time contribute to the history.
        history_size (int, optional): Keep only the last (most recent) ``history_size`` items
            of each history. Default is None, meaning all items are kept.
        history_col (str): Name of the list column that is added. Default is "history_fixed".
        user_col (str): User id column. Default is DEFAULT_USER_COL.
        item_col (str): Item (article) id column. Default is DEFAULT_ARTICLE_ID_COL.
        timestamp_col (str): Timestamp column. Default is DEFAULT_IMPRESSION_TIMESTAMP_COL.

    Returns:
        pl.DataFrame: The input frame, sorted by (user, timestamp), with all original columns
        plus ``history_col``. Users without any qualifying row get null in ``history_col``
        (left join).

    Raises:
        ValueError: If a required column is missing (raised by the column check helper,
            which is defined elsewhere in the file -- exception type per its documentation).

    Examples:
        >>> from ebrec.utils._constants import (
                DEFAULT_IMPRESSION_TIMESTAMP_COL,
                DEFAULT_ARTICLE_ID_COL,
                DEFAULT_USER_COL,
            )
        >>> df = pl.DataFrame(
                {
                    DEFAULT_USER_COL: [0, 0, 0, 1, 1, 1, 0, 2],
                    DEFAULT_ARTICLE_ID_COL: [
                        9604210,
                        9634540,
                        9640420,
                        9647983,
                        9647984,
                        9647981,
                        None,
                        None,
                    ],
                    DEFAULT_IMPRESSION_TIMESTAMP_COL: [
                        datetime.datetime(2023, 2, 18),
                        datetime.datetime(2023, 2, 18),
                        datetime.datetime(2023, 2, 25),
                        datetime.datetime(2023, 2, 22),
                        datetime.datetime(2023, 2, 21),
                        datetime.datetime(2023, 2, 23),
                        datetime.datetime(2023, 2, 19),
                        datetime.datetime(2023, 2, 26),
                    ],
                }
            )
        >>> dt_cutoff = datetime.datetime(2023, 2, 24)
        >>> create_fixed_history(df.lazy(), dt_cutoff).collect()
            shape: (8, 4)
            ┌─────────┬────────────┬─────────────────────┬─────────────────────────────┐
            │ user_id ┆ article_id ┆ impression_time     ┆ history_fixed               │
            │ ---     ┆ ---        ┆ ---                 ┆ ---                         │
            │ i64     ┆ i64        ┆ datetime[μs]        ┆ list[i64]                   │
            ╞═════════╪════════════╪═════════════════════╪═════════════════════════════╡
            │ 0       ┆ 9604210    ┆ 2023-02-18 00:00:00 ┆ [9604210, 9634540]          │
            │ 0       ┆ 9634540    ┆ 2023-02-18 00:00:00 ┆ [9604210, 9634540]          │
            │ 0       ┆ null       ┆ 2023-02-19 00:00:00 ┆ [9604210, 9634540]          │
            │ 0       ┆ 9640420    ┆ 2023-02-25 00:00:00 ┆ [9604210, 9634540]          │
            │ 1       ┆ 9647984    ┆ 2023-02-21 00:00:00 ┆ [9647984, 9647983, 9647981] │
            │ 1       ┆ 9647983    ┆ 2023-02-22 00:00:00 ┆ [9647984, 9647983, 9647981] │
            │ 1       ┆ 9647981    ┆ 2023-02-23 00:00:00 ┆ [9647984, 9647983, 9647981] │
            │ 2       ┆ null       ┆ 2023-02-26 00:00:00 ┆ null                        │
            └─────────┴────────────┴─────────────────────┴─────────────────────────────┘
        >>> create_fixed_history(df.lazy(), dt_cutoff, 1).collect()
            shape: (8, 4)
            ┌─────────┬────────────┬─────────────────────┬───────────────┐
            │ user_id ┆ article_id ┆ impression_time     ┆ history_fixed │
            │ ---     ┆ ---        ┆ ---                 ┆ ---           │
            │ i64     ┆ i64        ┆ datetime[μs]        ┆ list[i64]     │
            ╞═════════╪════════════╪═════════════════════╪═══════════════╡
            │ 0       ┆ 9604210    ┆ 2023-02-18 00:00:00 ┆ [9634540]     │
            │ 0       ┆ 9634540    ┆ 2023-02-18 00:00:00 ┆ [9634540]     │
            │ 0       ┆ null       ┆ 2023-02-19 00:00:00 ┆ [9634540]     │
            │ 0       ┆ 9640420    ┆ 2023-02-25 00:00:00 ┆ [9634540]     │
            │ 1       ┆ 9647984    ┆ 2023-02-21 00:00:00 ┆ [9647981]     │
            │ 1       ┆ 9647983    ┆ 2023-02-22 00:00:00 ┆ [9647981]     │
            │ 1       ┆ 9647981    ┆ 2023-02-23 00:00:00 ┆ [9647981]     │
            │ 2       ┆ null       ┆ 2023-02-26 00:00:00 ┆ null          │
            └─────────┴────────────┴─────────────────────┴───────────────┘
    """
    _check_columns_in_df(df, [user_col, timestamp_col, item_col])

    # Sort first so that every per-user list comes out in ascending time order.
    df = df.sort(user_col, timestamp_col)

    has_item = pl.col(item_col).is_not_null()
    before_cutoff = pl.col(timestamp_col) < dt_cutoff
    collected_items = pl.col(item_col).alias(history_col)

    # One row per user holding the list of qualifying items.
    per_user_history = (
        df.select(user_col, timestamp_col, item_col)
        .filter(has_item)
        .filter(before_cutoff)
        .group_by(user_col)
        .agg(collected_items)
    )
    if history_size is not None:
        # Lists are oldest -> newest, so the tail holds the most recent items.
        most_recent = pl.col(history_col).list.tail(history_size)
        per_user_history = per_user_history.with_columns(most_recent)

    # Left join: users with no qualifying rows keep their rows and get a null history.
    return df.join(per_user_history, on=user_col, how="left")


def create_fixed_history_aggr_columns(
    df: pl.DataFrame,
    dt_cutoff: datetime,
    history_size: int = None,
    columns: list[str] = None,
    suffix: str = "_fixed",
    user_col: str = DEFAULT_USER_COL,
    item_col: str = DEFAULT_ARTICLE_ID_COL,
    timestamp_col: str = DEFAULT_IMPRESSION_TIMESTAMP_COL,
) -> pl.DataFrame:
    """
    This function aggregates historical data in a Polars DataFrame based on a specified cutoff datetime and user-defined columns.
    The historical data is fixed to a given number of most recent records per user.

    Only rows with a non-null ``item_col`` and a timestamp strictly before ``dt_cutoff``
    contribute to the history. The resulting per-user lists are attached to every row of
    that user via a left join on ``user_col``.

    Parameters:
        df (pl.DataFrame): The input Polars DataFrame OR LazyFrame.
        dt_cutoff (datetime): The cutoff datetime for filtering the history.
        history_size (int, optional): The number of most recent records to keep for each user.
            If None, all history before the cutoff is kept.
        columns (list[str], optional): List of column names to be included in the aggregation.
            These columns are in addition to ``item_col``, which is always aggregated.
            Duplicates are ignored. Default is None (aggregate ``item_col`` only).
        suffix (str, optional): Suffix appended to each aggregated column name. Default is "_fixed".
        user_col (str, optional): User id column, used for grouping AND for the final join.
            Default is DEFAULT_USER_COL.
        item_col (str, optional): Item (article) id column. Default is DEFAULT_ARTICLE_ID_COL.
        timestamp_col (str, optional): Timestamp column. Default is DEFAULT_IMPRESSION_TIMESTAMP_COL.

    Returns:
        pl.DataFrame: A new DataFrame (sorted by user and timestamp) with the original columns and one added
        list column ``<name><suffix>`` per aggregated column. The added columns appear in a deterministic order:
        ``item_col`` first, then ``columns`` in the order given.

    Raises:
        ValueError: If the input dataframe does not contain the required columns (raised by the column
            check helper, which is defined elsewhere in the file -- exception type per its documentation).

    Examples:
        >>> from ebrec.utils._constants import (
                DEFAULT_IMPRESSION_TIMESTAMP_COL,
                DEFAULT_ARTICLE_ID_COL,
                DEFAULT_READ_TIME_COL,
                DEFAULT_USER_COL,
            )
        >>> df = pl.DataFrame(
                {
                    DEFAULT_USER_COL: [0, 0, 0, 1, 1, 1, 0, 2],
                    DEFAULT_ARTICLE_ID_COL: [
                        9604210,
                        9634540,
                        9640420,
                        9647983,
                        9647984,
                        9647981,
                        None,
                        None,
                    ],
                    DEFAULT_IMPRESSION_TIMESTAMP_COL: [
                        datetime.datetime(2023, 2, 18),
                        datetime.datetime(2023, 2, 18),
                        datetime.datetime(2023, 2, 25),
                        datetime.datetime(2023, 2, 22),
                        datetime.datetime(2023, 2, 21),
                        datetime.datetime(2023, 2, 23),
                        datetime.datetime(2023, 2, 19),
                        datetime.datetime(2023, 2, 26),
                    ],
                    DEFAULT_READ_TIME_COL: [
                        0,
                        2,
                        8,
                        13,
                        1,
                        1,
                        6,
                        1
                    ],
                    "nothing": [
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                    ],
                }
            )
        >>> dt_cutoff = datetime.datetime(2023, 2, 24)
        >>> columns = [DEFAULT_IMPRESSION_TIMESTAMP_COL, DEFAULT_READ_TIME_COL]
        >>> create_fixed_history_aggr_columns(df.lazy(), dt_cutoff, columns=columns).collect()
            shape: (8, 8)
            ┌─────────┬────────────┬─────────────────────┬───────────┬─────────┬─────────────────────────────┬───────────────────────────────────┬─────────────────┐
            │ user_id ┆ article_id ┆ impression_time     ┆ read_time ┆ nothing ┆ article_id_fixed            ┆ impression_time_fixed             ┆ read_time_fixed │
            │ ---     ┆ ---        ┆ ---                 ┆ ---       ┆ ---     ┆ ---                         ┆ ---                               ┆ ---             │
            │ i64     ┆ i64        ┆ datetime[μs]        ┆ i64       ┆ null    ┆ list[i64]                   ┆ list[datetime[μs]]                ┆ list[i64]       │
            ╞═════════╪════════════╪═════════════════════╪═══════════╪═════════╪═════════════════════════════╪═══════════════════════════════════╪═════════════════╡
            │ 0       ┆ 9604210    ┆ 2023-02-18 00:00:00 ┆ 0         ┆ null    ┆ [9604210, 9634540]          ┆ [2023-02-18 00:00:00, 2023-02-18… ┆ [0, 2]          │
            │ 0       ┆ 9634540    ┆ 2023-02-18 00:00:00 ┆ 2         ┆ null    ┆ [9604210, 9634540]          ┆ [2023-02-18 00:00:00, 2023-02-18… ┆ [0, 2]          │
            │ 0       ┆ null       ┆ 2023-02-19 00:00:00 ┆ 6         ┆ null    ┆ [9604210, 9634540]          ┆ [2023-02-18 00:00:00, 2023-02-18… ┆ [0, 2]          │
            │ 0       ┆ 9640420    ┆ 2023-02-25 00:00:00 ┆ 8         ┆ null    ┆ [9604210, 9634540]          ┆ [2023-02-18 00:00:00, 2023-02-18… ┆ [0, 2]          │
            │ 1       ┆ 9647984    ┆ 2023-02-21 00:00:00 ┆ 1         ┆ null    ┆ [9647984, 9647983, 9647981] ┆ [2023-02-21 00:00:00, 2023-02-22… ┆ [1, 13, 1]      │
            │ 1       ┆ 9647983    ┆ 2023-02-22 00:00:00 ┆ 13        ┆ null    ┆ [9647984, 9647983, 9647981] ┆ [2023-02-21 00:00:00, 2023-02-22… ┆ [1, 13, 1]      │
            │ 1       ┆ 9647981    ┆ 2023-02-23 00:00:00 ┆ 1         ┆ null    ┆ [9647984, 9647983, 9647981] ┆ [2023-02-21 00:00:00, 2023-02-22… ┆ [1, 13, 1]      │
            │ 2       ┆ null       ┆ 2023-02-26 00:00:00 ┆ 1         ┆ null    ┆ null                        ┆ null                              ┆ null            │
            └─────────┴────────────┴─────────────────────┴───────────┴─────────┴─────────────────────────────┴───────────────────────────────────┴─────────────────┘
        >>> create_fixed_history_aggr_columns(df.lazy(), dt_cutoff, 1, columns=columns).collect()
            shape: (8, 8)
            ┌─────────┬────────────┬─────────────────────┬───────────┬─────────┬──────────────────┬───────────────────────┬─────────────────┐
            │ user_id ┆ article_id ┆ impression_time     ┆ read_time ┆ nothing ┆ article_id_fixed ┆ impression_time_fixed ┆ read_time_fixed │
            │ ---     ┆ ---        ┆ ---                 ┆ ---       ┆ ---     ┆ ---              ┆ ---                   ┆ ---             │
            │ i64     ┆ i64        ┆ datetime[μs]        ┆ i64       ┆ null    ┆ list[i64]        ┆ list[datetime[μs]]    ┆ list[i64]       │
            ╞═════════╪════════════╪═════════════════════╪═══════════╪═════════╪══════════════════╪═══════════════════════╪═════════════════╡
            │ 0       ┆ 9604210    ┆ 2023-02-18 00:00:00 ┆ 0         ┆ null    ┆ [9634540]        ┆ [2023-02-18 00:00:00] ┆ [2]             │
            │ 0       ┆ 9634540    ┆ 2023-02-18 00:00:00 ┆ 2         ┆ null    ┆ [9634540]        ┆ [2023-02-18 00:00:00] ┆ [2]             │
            │ 0       ┆ null       ┆ 2023-02-19 00:00:00 ┆ 6         ┆ null    ┆ [9634540]        ┆ [2023-02-18 00:00:00] ┆ [2]             │
            │ 0       ┆ 9640420    ┆ 2023-02-25 00:00:00 ┆ 8         ┆ null    ┆ [9634540]        ┆ [2023-02-18 00:00:00] ┆ [2]             │
            │ 1       ┆ 9647984    ┆ 2023-02-21 00:00:00 ┆ 1         ┆ null    ┆ [9647981]        ┆ [2023-02-23 00:00:00] ┆ [1]             │
            │ 1       ┆ 9647983    ┆ 2023-02-22 00:00:00 ┆ 13        ┆ null    ┆ [9647981]        ┆ [2023-02-23 00:00:00] ┆ [1]             │
            │ 1       ┆ 9647981    ┆ 2023-02-23 00:00:00 ┆ 1         ┆ null    ┆ [9647981]        ┆ [2023-02-23 00:00:00] ┆ [1]             │
            │ 2       ┆ null       ┆ 2023-02-26 00:00:00 ┆ 1         ┆ null    ┆ null             ┆ null                  ┆ null            │
            └─────────┴────────────┴─────────────────────┴───────────┴─────────┴──────────────────┴───────────────────────┴─────────────────┘
        >>> create_fixed_history_aggr_columns(df.lazy(), dt_cutoff, 1).collect()
            shape: (8, 6)
            ┌─────────┬────────────┬─────────────────────┬───────────┬─────────┬──────────────────┐
            │ user_id ┆ article_id ┆ impression_time     ┆ read_time ┆ nothing ┆ article_id_fixed │
            │ ---     ┆ ---        ┆ ---                 ┆ ---       ┆ ---     ┆ ---              │
            │ i64     ┆ i64        ┆ datetime[μs]        ┆ i64       ┆ null    ┆ list[i64]        │
            ╞═════════╪════════════╪═════════════════════╪═══════════╪═════════╪══════════════════╡
            │ 0       ┆ 9604210    ┆ 2023-02-18 00:00:00 ┆ 0         ┆ null    ┆ [9634540]        │
            │ 0       ┆ 9634540    ┆ 2023-02-18 00:00:00 ┆ 2         ┆ null    ┆ [9634540]        │
            │ 0       ┆ null       ┆ 2023-02-19 00:00:00 ┆ 6         ┆ null    ┆ [9634540]        │
            │ 0       ┆ 9640420    ┆ 2023-02-25 00:00:00 ┆ 8         ┆ null    ┆ [9634540]        │
            │ 1       ┆ 9647984    ┆ 2023-02-21 00:00:00 ┆ 1         ┆ null    ┆ [9647981]        │
            │ 1       ┆ 9647983    ┆ 2023-02-22 00:00:00 ┆ 13        ┆ null    ┆ [9647981]        │
            │ 1       ┆ 9647981    ┆ 2023-02-23 00:00:00 ┆ 1         ┆ null    ┆ [9647981]        │
            │ 2       ┆ null       ┆ 2023-02-26 00:00:00 ┆ 1         ┆ null    ┆ null             │
            └─────────┴────────────┴─────────────────────┴───────────┴─────────┴──────────────────┘
        >>> create_fixed_history_aggr_columns(df.lazy(), dt_cutoff, 1).head(1).collect()
            shape: (1, 6)
            ┌─────────┬────────────┬─────────────────────┬───────────┬─────────┬──────────────────┐
            │ user_id ┆ article_id ┆ impression_time     ┆ read_time ┆ nothing ┆ article_id_fixed │
            │ ---     ┆ ---        ┆ ---                 ┆ ---       ┆ ---     ┆ ---              │
            │ i64     ┆ i64        ┆ datetime[μs]        ┆ i64       ┆ null    ┆ list[i64]        │
            ╞═════════╪════════════╪═════════════════════╪═══════════╪═════════╪══════════════════╡
            │ 0       ┆ 9604210    ┆ 2023-02-18 00:00:00 ┆ 0         ┆ null    ┆ [9634540]        │
            └─────────┴────────────┴─────────────────────┴───────────┴─────────┴──────────────────┘
    """
    # `None` sentinel instead of a shared mutable default; copy so the caller's list is never aliased.
    columns = [] if columns is None else list(columns)
    _check_columns_in_df(df, [user_col, item_col, timestamp_col] + columns)

    # Order-preserving de-duplication: the item column always comes first, followed by
    # the requested columns in the order given, so the output schema is deterministic.
    aggr_columns = list(dict.fromkeys([item_col] + columns))

    # Sort first so that every per-user list comes out in ascending time order.
    df = df.sort(user_col, timestamp_col)
    df_history = (
        df.filter(pl.col(item_col).is_not_null())
        .filter(pl.col(timestamp_col) < dt_cutoff)
        .group_by(user_col)
        .agg(
            # `Expr.name.suffix` is the non-deprecated spelling of `Expr.suffix`.
            pl.col(aggr_columns).name.suffix(suffix),
        )
    )
    if history_size is not None:
        # Lists are oldest -> newest, so the tail holds the most recent records.
        history_columns = [col + suffix for col in aggr_columns]
        df_history = df_history.with_columns(
            pl.col(history_columns).list.tail(history_size)
        )
    # Join on the configured user column (not a hard-coded "user_id") so that
    # non-default `user_col` values work.
    return df.join(df_history, on=user_col, how="left")


def add_prediction_scores(
    df: pl.DataFrame,
    scores: Iterable[float],
    prediction_scores_col: str = "scores",
    inview_col: str = DEFAULT_INVIEW_ARTICLES_COL,
) -> pl.DataFrame:
    """
    Attach a flat sequence of prediction scores to `df` as one list per row.

    The in-view column is exploded to one item per line, the scores are placed
    next to those items positionally, and the result is re-grouped per original
    row so every row receives as many scores as it has in-view items.

    Args:
        df (pl.DataFrame): Frame holding the list column `inview_col`.
        scores (Iterable[float]): Scores in exploded in-view order. Anything
            `pl.Series` accepts; nested one-element entries such as
            `[[0.3], [0.4]]` are flattened by the `.explode()` call.
        prediction_scores_col (str): Name of the list column to add.
            Defaults to "scores".
        inview_col (str): Name of the list column with the in-view items.

    Returns:
        pl.DataFrame: `df` with `prediction_scores_col` added.

    Note:
        The number of scores must match the total number of in-view items;
        otherwise polars is expected to raise (exact exception type not
        verified here).

    Example:
    >>> df = pl.DataFrame(
            {
                "id": [1, 2],
                DEFAULT_INVIEW_ARTICLES_COL: [
                    [1, 2, 3],
                    [4, 5],
                ],
            }
        )
    >>> test_prediction = [[0.3], [0.4], [0.5], [0.6], [0.7]]
    >>> add_prediction_scores(df, test_prediction)["scores"].to_list()
        [[0.3, 0.4, 0.5], [0.6, 0.7]]
    """
    # Temporary row-index name that cannot collide with an existing column.
    GROUPBY_ID = generate_unique_name(df.columns, "_groupby_id")
    df_scores = (
        df.lazy()
        .select(pl.col(inview_col))
        .with_row_index(GROUPBY_ID)
        .explode(inview_col)
        .with_columns(pl.Series(prediction_scores_col, scores).explode())
        .group_by(GROUPBY_ID)
        .agg(inview_col, prediction_scores_col)
        # group_by does not promise input order; restore it via the row index.
        .sort(GROUPBY_ID)
        .collect()
    )
    # The row index only ever lived in `df_scores`, so `df` has nothing to
    # drop: dropping a non-existent column raises under strict `drop`.
    return df.with_columns(df_scores.select(prediction_scores_col))


def down_sample_on_users(
    df: pl.DataFrame,
    n: int,
    user_col: str = DEFAULT_USER_COL,
    seed: int = None,
) -> pl.DataFrame:
    """
    Keep at most `n` randomly chosen rows for every distinct user.

    The frame gets a temporary row index, is shuffled, and for each user the
    last `n` shuffled row indices are kept. The indexed (unshuffled) frame is
    then filtered on those indices and the temporary column is removed.

    Args:
        df (pl.DataFrame): Frame to down-sample.
        n (int): Upper bound on the number of rows kept per user.
        user_col (str): Column with the user identifier.
            Defaults to DEFAULT_USER_COL.
        seed (int, optional): Seed forwarded to `DataFrame.sample`.
            Defaults to None.

    Returns:
        pl.DataFrame: The down-sampled frame, without the temporary index.

    Example:
    >>> df = pl.DataFrame(
            {
                "user_id": [1, 1, 1, 2, 2, 3],
                "value": [10, 20, 30, 40, 50, 60],
            }
        )
    >>> down_sample_on_users(df, n=2, user_col="user_id", seed=42).shape
        (5, 2)
    ## Which of user 1's three rows is left out depends on the seed.
    """
    row_id_col = generate_unique_name(df.columns, "_groupby_id")
    indexed = df.with_row_index(row_id_col)

    # Shuffle everything, then collect the row ids belonging to each user.
    shuffled = indexed.sample(fraction=1.0, shuffle=True, seed=seed)
    ids_per_user = shuffled.group_by(pl.col(user_col)).agg(row_id_col)
    # Cap every user's id list at n entries and flatten to one id per line.
    capped_ids = ids_per_user.with_columns(pl.col(row_id_col).list.tail(n))
    keep_ids = capped_ids.select(pl.col(row_id_col).explode())

    is_kept = pl.col(row_id_col).is_in(keep_ids)
    return indexed.filter(is_kept).drop(row_id_col)

In [7]:


def generate_embeddings_with_transformers(
    model: TFAutoModel,
    tokenizer: AutoTokenizer,
    text_list: list[str],
    batch_size: int = 8,
    disable_tqdm: bool = False,
) -> tf.Tensor:
    """
    Encode texts with a pre-trained TensorFlow transformer, one vector per text.

    All texts are tokenized in a single call (padded and truncated), fed
    through `model` in batches, and for every text the hidden state at
    sequence position 0 of `last_hidden_state` is kept.

    Args:
        model (TFAutoModel): The pre-trained transformer model to run.
        tokenizer (AutoTokenizer): Tokenizer matching `model`.
        text_list (list of str): Texts to embed.
        batch_size (int): Number of texts per forward pass. Defaults to 8.
        disable_tqdm (bool): If True, the progress bar is suppressed.
            Defaults to False.

    Returns:
        tf.Tensor: The per-batch outputs concatenated along axis 0.
    """
    encoded = tokenizer(text_list, padding=True, truncation=True, return_tensors="tf")

    # Pair ids with their attention masks and cut them into batches.
    batches = tf.data.Dataset.from_tensor_slices(
        (encoded["input_ids"], encoded["attention_mask"])
    ).batch(batch_size)

    progress = tqdm(batches, desc="Encoding", disable=disable_tqdm)
    first_token_states = [
        # Position 0 is the [CLS]-style token the original code pools on.
        model(input_ids=ids, attention_mask=mask, training=False).last_hidden_state[
            :, 0, :
        ]
        for ids, mask in progress
    ]
    return tf.concat(first_token_states, axis=0)


# Smoke test for `generate_embeddings_with_transformers`.
# NOTE(review): the guard is true when this cell runs in the notebook (the
# captured output below shows it executed), so it loads "xlm-roberta-base"
# and binds `model`, `tokenizer`, `embeddings`, `batch_size`, `text_list`
# and `model_name` at module level.
if __name__ == "__main__":
    model_name = "xlm-roberta-base"
    batch_size = 8
    # Three short Danish sample sentences used as input.
    text_list = [
        "hej med dig. Jeg er en tekst.",
        "Jeg er en anden tekst, skal du spille smart?",
        "oh nej..",
    ]

    # Load the model and tokenizer
    model = TFAutoModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Generate embeddings
    embeddings = generate_embeddings_with_transformers(
        model, tokenizer, text_list, batch_size
    )
    print(embeddings.numpy())  # Convert TensorFlow tensor to NumPy array for inspection


from typing import Iterable
from pathlib import Path
from tqdm import tqdm
import polars as pl
import numpy as np
import datetime
import zipfile
import torch
import time
import json
import yaml
import time


def read_json_file(path: str, verbose: bool = False) -> dict:
    """
    Load and return the JSON document stored at `path`.

    Args:
        path (str): Location of the JSON file.
        verbose (bool): If True, print which file is being read.
            Defaults to False.

    Returns:
        dict: The parsed content, as returned by `json.load`.
    """
    if verbose:
        # The message used to say "Writing JSON", which was misleading here.
        print(f"Reading JSON: '{path}'")
    with open(path) as file:
        return json.load(file)


def write_json_file(dictionary: dict, path: str, verbose: bool = False) -> None:
    """
    Serialize `dictionary` as JSON to `path`, creating missing parent folders.

    Args:
        dictionary (dict): Content to serialize with `json.dump`.
        path (str): Destination file; an existing file is overwritten.
        verbose (bool): If True, print the destination once the file is written.
            Defaults to False.
    """
    parent_dir = Path(path).parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as handle:
        json.dump(dictionary, handle)
    if not verbose:
        return
    print(f"Writing JSON: '{path}'")


def read_yaml_file(path: str) -> dict:
    """
    Parse the YAML file at `path` with `yaml.safe_load` and return the result.
    """
    with open(path, "r") as stream:
        parsed = yaml.safe_load(stream)
    return parsed


def write_yaml_file(dictionary: dict, path: str) -> None:
    """
    Dump `dictionary` as block-style YAML to `path`.

    Missing parent folders are created first; an existing file is overwritten.
    """
    parent_dir = Path(path).parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as stream:
        yaml.dump(dictionary, stream, default_flow_style=False)


def rank_predictions_by_score(
    arr: Iterable[float],
) -> list[np.ndarray]:
    """
    Replace every score by its rank: 1 for the highest score, 2 for the
    second highest, and so on.

    The function ranks ONE row of scores and returns a single numpy array;
    apply it row by row for a list of impressions (see the example).

    Reference:
        https://github.com/recommenders-team/recommenders/blob/main/examples/00_quick_start/nrms_MIND.ipynb

    >>> prediction_scores = [[0.2, 0.1, 0.3], [0.1, 0.2], [0.4, 0.2, 0.1, 0.3]]
    >>> [rank_predictions_by_score(row) for row in prediction_scores]
        [
            array([2, 3, 1]),
            array([2, 1]),
            array([1, 3, 4, 2])
        ]
    """
    # Positions sorted from lowest to highest score, then flipped.
    ascending_positions = np.argsort(arr)
    descending_positions = ascending_positions[::-1]
    # Sorting those positions gives each element's 0-based rank.
    zero_based_rank = np.argsort(descending_positions)
    return zero_based_rank + 1


def write_submission_file(
    impression_ids: Iterable[int],
    prediction_scores: Iterable[any],
    path: Path = Path("predictions.txt"),
    rm_file: bool = True,
    filename_zip: str = None,
) -> None:
    """
    Write one line per impression in a MIND-like layout, then zip the file.

    Each line is the impression id, a space, and the scores as a
    comma-separated list in square brackets without spaces. Afterwards the file
    is handed to `zip_submission_file` together with `rm_file` and
    `filename_zip`.

    Reference:
        https://github.com/recommenders-team/recommenders/blob/main/examples/00_quick_start/nrms_MIND.ipynb

    Example:
    >>> impression_ids = [237, 291, 320]
    >>> prediction_scores = [[0.2, 0.1, 0.3], [0.1, 0.2], [0.4, 0.2, 0.1, 0.3]]
    >>> write_submission_file(impression_ids, prediction_scores, path="predictions.txt", rm_file=False)
    ## Output file:
        237 [0.2,0.1,0.3]
        291 [0.1,0.2]
        320 [0.4,0.2,0.1,0.3]
    """
    path = Path(path)
    with open(path, "w") as handle:
        for impression_id, row_scores in tqdm(zip(impression_ids, prediction_scores)):
            joined_scores = ",".join(map(str, row_scores))
            handle.write(f"{impression_id!s} [{joined_scores}]\n")
    # Compress the finished text file (and optionally remove it).
    zip_submission_file(path=path, rm_file=rm_file, filename_zip=filename_zip)


def read_submission_file(path: Path) -> tuple[int, any]:
    """
    Read a submission file back into impression ids and their score lists.

    Every line is split by `parse_line`; the id part is converted to `int`.

    >>> impression_ids = [237, 291, 320]
    >>> prediction_scores = [[0.2, 0.1, 0.3], [0.1, 0.2], [0.4, 0.2, 0.1, 0.3]]
    >>> write_submission_file(impression_ids, prediction_scores, path="predictions.txt", rm_file=False)
    >>> read_submission_file("predictions.txt")
        (
            [237, 291, 320],
            [[0.2, 0.1, 0.3], [0.1, 0.2], [0.4, 0.2, 0.1, 0.3]]
        )
    """
    ids: list = []
    scores_per_impression: list = []
    with open(path, "r") as handle:
        # `map` is lazy, so lines are still parsed and converted one at a time.
        for raw_id, parsed_scores in map(parse_line, handle):
            ids.append(int(raw_id))
            scores_per_impression.append(parsed_scores)
    return ids, scores_per_impression


def zip_submission_file(
    path: Path,
    filename_zip: str = None,
    verbose: bool = True,
    rm_file: bool = True,
) -> None:
    """
    Compress a single file into a ZIP archive placed in the same directory.

    Args:
        path (Path): The file to compress. It is stored in the archive under
            its own file name (`path.name`).
        filename_zip (str, optional): File name of the archive, created next to
            `path`. Defaults to the name of `path` with its suffix replaced by
            ".zip".
        verbose (bool, optional): If True, print source and destination.
            Defaults to True.
        rm_file (bool, optional): If True, delete `path` after the archive has
            been written. Defaults to True.

    Raises:
        ValueError: If the archive name does not end in ".zip", or if the
            archive would be written on top of the input file.

    Returns:
        None: This function does not return any value.
    """
    path = Path(path)
    if filename_zip:
        path_zip = path.parent.joinpath(filename_zip)
    else:
        path_zip = path.with_suffix(".zip")

    if path_zip.suffix != ".zip":
        raise ValueError(f"suffix for {path_zip.name} has to be '.zip'")
    if path_zip == path:
        # Opening the archive in "w" mode would truncate the input first.
        raise ValueError(f"zip target {path_zip} is the same file as the input")
    if verbose:
        print(f"Zipping {path} to {path_zip}")
    # Context manager closes the archive even when `write` fails.
    with zipfile.ZipFile(path_zip, "w", zipfile.ZIP_DEFLATED) as archive:
        archive.write(path, arcname=path.name)
    if rm_file:
        path.unlink()


def parse_line(l) -> tuple[str, list[float]]:
    """
    Parse one submission line into its identifier and its list of values.

    The line is expected to look like "<id> <json-list>", e.g.
    "237 [0.2,0.1,0.3]". Only the first run of whitespace separates the two
    parts, so a JSON list written with spaces ("[0.2, 0.1]") parses as well.

    Args:
        l (str): A single line, with or without its trailing newline.

    Returns:
        tuple[str, list[float]]: The identifier (still a string) and the
            decoded JSON value.

    Raises:
        ValueError: If the line has no second part, or if that part is not
            valid JSON (`json.JSONDecodeError` is a `ValueError`).
    """
    # maxsplit=1 keeps spaces inside the JSON payload intact.
    impid, ranks = l.strip().split(maxsplit=1)
    ranks = json.loads(ranks)
    return impid, ranks


def time_it(enable=True):
    """
    Decorator factory that prints how long each call of a function took.

    Args:
        enable (bool): If False, the decorated function is called without any
            timing or printing. Defaults to True.

    Returns:
        A decorator. The wrapped function returns whatever the original
        returns and keeps its `__name__` and docstring.

    Example:
    >>> @time_it()
        def work():
            return 1
    >>> work()
        ... work completed in 0.00 seconds
        1
    """
    # Imported here so the block does not rely on a module-level import.
    import functools

    def decorator(func):
        @functools.wraps(func)  # preserve name/docstring of the wrapped function
        def wrapper(*args, **kwargs):
            if not enable:
                return func(*args, **kwargs)
            start_time = time.time()
            result = func(*args, **kwargs)
            elapsed_time = time.time() - start_time
            print(f"... {func.__name__} completed in {elapsed_time:.2f} seconds")
            return result

        return wrapper

    return decorator


def df_shape_time_it(enable=True):
    """
    Decorator factory that reports run time and the change in row count.

    The shape is read from the first positional argument before the call and
    from the return value after it. When either shape is unavailable (for
    example a LazyFrame, no positional argument, or zero input rows), the
    report falls back to "NA".

    Args:
        enable (bool): If False, the decorated function is called without any
            measuring or printing. Defaults to True.

    Returns:
        A decorator. The wrapped function returns whatever the original
        returns and keeps its `__name__` and docstring.
    """
    # Imported here so the block does not rely on a module-level import.
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not enable:
                return func(*args, **kwargs)

            start_shape = None
            try:
                # In case of a LazyFrame this is not possible:
                start_shape = args[0].shape
            except Exception:
                # Best effort only: the report below prints "NA" instead.
                pass
            start_time = time.time()

            # Run function:
            result = func(*args, **kwargs)

            end_time = time.time()
            time_taken = round(end_time - start_time, 6)
            try:
                # In case of a LazyFrame this is not possible:
                end_shape = result.shape
                row_dropped_frac = round(
                    (start_shape[0] - end_shape[0]) / start_shape[0] * 100, 2
                )
                shape_ba = f"=> Before/After: {start_shape}/{end_shape} ({row_dropped_frac}% rows dropped)"
            except Exception:
                # Missing shape or zero input rows; do not mask
                # KeyboardInterrupt/SystemExit like a bare `except:` would.
                shape_ba = "=> Before/After: NA/NA (NA% rows dropped)"
            print(
                f"""Time taken by '{func.__name__}': {time_taken} seconds\n{shape_ba}"""
            )
            return result

        return wrapper

    return decorator


def generate_unique_name(existing_names: list[str], base_name: str = "new_name"):
    """
    Return a name that does not occur in `existing_names`.

    `base_name` is returned unchanged when it is free. Otherwise "_1", "_2",
    ... is appended until an unused name is found.

    Args:
        existing_names (list of str): Names that are already taken.
        base_name (str): Preferred name. Default is 'new_name'.

    Returns:
        str: A name not contained in `existing_names`.

    Example
    >>> existing_names = ['name1', 'name2', 'newName', 'newName_1']
    >>> generate_unique_name(existing_names, 'newName')
        'newName_2'
    """
    if base_name not in existing_names:
        return base_name

    counter = 1
    while f"{base_name}_{counter}" in existing_names:
        counter += 1
    return f"{base_name}_{counter}"


def compute_npratio(n_pos: int, n_neg: int) -> float:
    """
    Number of negative samples per positive sample (`n_neg / n_pos`).

    Similar approach as:
        "Neural News Recommendation with Long- and Short-term User Representations (An et al., ACL 2019)"

    Args:
        n_pos (int): Number of positive samples; must be non-zero.
        n_neg (int): Number of negative samples; may be zero.

    Returns:
        float: The negative-to-positive ratio.

    Raises:
        ZeroDivisionError: If `n_pos` is zero.

    Example:
    >>> pos = 492_185
    >>> neg = 9_224_537
    >>> round(compute_npratio(pos, neg), 2)
        18.74
    >>> compute_npratio(5, 0)
        0.0
    """
    # Dividing directly avoids the spurious ZeroDivisionError that
    # `1 / (n_pos / n_neg)` raised when there are no negatives.
    return n_neg / n_pos


def strfdelta(tdelta: datetime.timedelta):
    """
    Format a timedelta as "<days> days <hours>:<minutes>:<seconds>".

    Hours, minutes and seconds come from `tdelta.seconds` and are not
    zero-padded.

    Example:
    >>> tdelta = datetime.timedelta(days=1, hours=3, minutes=42, seconds=54)
    >>> strfdelta(tdelta)
        '1 days 3:42:54'
    """
    whole_minutes, seconds = divmod(tdelta.seconds, 60)
    hours, minutes = divmod(whole_minutes, 60)
    return f"{tdelta.days} days {hours}:{minutes}:{seconds}"


def str_datetime_now():
    """
    Return the current local time as a "YYYY-MM-DD-HH-MM-SS" string.
    """
    now = datetime.datetime.now()
    return f"{now:%Y-%m-%d-%H-%M-%S}"


def get_object_variables(object_: object) -> dict:
    """
    Collect the plain data attributes found in `vars(object_)`.

    Names starting with a double underscore and callable values (methods,
    functions, classes) are left out.

    Example:
    >>> class example:
            a = 2
            b = 3
    >>> get_object_variables(example)
        {'a': 2, 'b': 3}
    """
    collected = {}
    for attr_name, attr_value in vars(object_).items():
        if attr_name.startswith("__"):
            continue
        if callable(attr_value):
            continue
        collected[attr_name] = attr_value
    return collected


def batch_items_generator(items: Iterable[any], batch_size: int):
    """
    Yield consecutive slices of `items`, each holding at most `batch_size` elements.

    `items` must support `len()` and slicing; every batch is whatever
    `items[start:stop]` returns for that container.

    Args:
        items (list): The items to be chunked.
        batch_size (int): The number of items to include in each batch.

    Yields:
        list: A batch of items from the input list.

    Examples:
        >>> items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> batch_size = 3
        >>> for batch in batch_items_generator(items, batch_size):
        ...     print(batch)
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
    """
    batch_starts = range(0, len(items), batch_size)
    yield from (items[start : start + batch_size] for start in batch_starts)


def unnest_dictionary(dictionary, parent_key="") -> dict:
    """
    Flatten a nested dictionary into a single level.

    A nested key is renamed to "<key>-<parent>", i.e. the names of the
    enclosing keys are appended, innermost first, joined by dashes.
    Top-level keys keep their name.

    Args:
        dictionary (dict): The nested dictionary to be unnested.
        parent_key (str, optional): Name appended to every key of
            `dictionary`; used by the recursion. Defaults to "".

    Returns:
        dict: The flat dictionary.

    Example:
    >>> nested_dict = {
            "key1": "value1",
            "key2": {"nested_key1": "nested_value1", "nested_key2": "nested_value2"},
            "key3": {"nested_key3": {"deeply_nested_key": "deeply_nested_value"}},
        }
    >>> unnest_dictionary(nested_dict)
        {
            "key1": "value1",
            "nested_key1-key2": "nested_value1",
            "nested_key2-key2": "nested_value2",
            "deeply_nested_key-nested_key3-key3": "deeply_nested_value",
        }
    """
    flat = {}
    for name, item in dictionary.items():
        qualified_name = name if not parent_key else f"{name}-{parent_key}"
        if not isinstance(item, dict):
            flat[qualified_name] = item
            continue
        # Recurse; the qualified name becomes the suffix for the children.
        flat.update(unnest_dictionary(item, parent_key=qualified_name))
    return flat


def get_torch_device(use_gpu: bool = True):
    """
    Pick the torch device string to run on.

    Returns "cuda:0" when a GPU is requested and CUDA is available, otherwise
    "cpu". Apple's MPS backend is probed but deliberately mapped to "cpu".
    """
    if not use_gpu:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda:0"
    if torch.backends.mps.is_available():
        return "cpu"  # "mps" is not working for me..
    return "cpu"


def convert_to_nested_list(lst, sublist_size: int):
    """
    Cut `lst` into consecutive slices of `sublist_size` elements; the last
    slice may be shorter.

    Example:
    >>> list_ = [0, 0, 1, 1, 0, 0]
    >>> convert_to_nested_list(list_,3)
        [[0, 0, 1], [1, 0, 0]]
    """
    chunks = []
    for start in range(0, len(lst), sublist_size):
        chunks.append(lst[start : start + sublist_size])
    return chunks


def repeat_by_list_values_from_matrix(
    input_array: np.array,
    matrix: np.array,
    repeats: np.array,
) -> np.array:
    """
    Look up rows of `matrix` by index, then repeat each looked-up entry.

    `input_array` holds row indices into `matrix`. Entry i of the gathered
    result is repeated `repeats[i]` times along the first axis.

    Example:
        >>> input = np.array([[1, 0], [0, 0]])
        >>> matrix = np.array([[7,8,9], [10,11,12]])
        >>> repeats = np.array([1, 2])
        >>> repeat_by_list_values_from_matrix(input, matrix, repeats)
            array([[[10, 11, 12],
                    [ 7,  8,  9]],
                    [[ 7,  8,  9],
                    [ 7,  8,  9]],
                    [[ 7,  8,  9],
                    [ 7,  8,  9]]])
    """
    gathered = matrix[input_array]
    return np.repeat(gathered, repeats, axis=0)


def create_lookup_dict(df: pl.DataFrame, key: str, value: str) -> dict:
    """
    Build a {key: value} dictionary from two columns of a DataFrame.

    Args:
        df (pl.DataFrame): The DataFrame to read from.
        key (str): Column whose entries become the dictionary keys.
        value (str): Column whose entries become the dictionary values.

    Returns:
        dict: One entry per row, pairing `df[key]` with `df[value]`. When a key
            occurs more than once, the last row wins.

    Example:
        >>> df = pl.DataFrame({'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie']})
        >>> create_lookup_dict(df, 'id', 'name')
            {1: 'Alice', 2: 'Bob', 3: 'Charlie'}
    """
    key_column = df[key]
    value_column = df[value]
    return {k: v for k, v in zip(key_column, value_column)}


def create_lookup_objects(
    lookup_dictionary: dict[int, np.array], unknown_representation: str
) -> tuple[dict[int, pl.Series], np.array]:
    """Turn an {id: vector} dictionary into an index lookup and a matrix.

    Row 0 of the matrix is reserved for unknown ids; the vectors of
    `lookup_dictionary` follow in dictionary order, so the first id maps to
    row 1, the second to row 2, and so on.

    Args:
        lookup_dictionary (dict[int, np.array]): Maps an identifier to its
            representation. All representations are stacked into one array.
        unknown_representation (str): How to fill row 0. 'zeros' uses a row of
            zeros; 'mean' uses the column-wise mean of all representations.

    Raises:
        ValueError: If `unknown_representation` is neither 'zeros' nor 'mean'.

    Returns:
        tuple[dict[int, pl.Series], np.array]: A tuple containing two items:
            - A dictionary with the keys of `lookup_dictionary`; each value is
                a one-element polars Series holding that key's row number in
                the matrix.
            - The matrix: the unknown row followed by all representations.

    Example:
    >>> data = {
            10: np.array([0.1, 0.2, 0.3]),
            20: np.array([0.4, 0.5, 0.6]),
            30: np.array([0.7, 0.8, 0.9]),
        }
    >>> lookup_dict, lookup_matrix = create_lookup_objects(data, "zeros")
    >>> {key: series.to_list() for key, series in lookup_dict.items()}
        {10: [1], 20: [2], 30: [3]}
    >>> lookup_matrix
        array([[0. , 0. , 0. ],
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9]])
    """
    # Index lookup: row numbers start at 1 because row 0 is the unknown row.
    lookup_indexes = {}
    for row_number, identifier in enumerate(lookup_dictionary, start=1):
        lookup_indexes[identifier] = pl.Series("", [row_number])

    # Matrix of all known representations, in dictionary order.
    known_rows = np.array(list(lookup_dictionary.values()))

    if unknown_representation == "zeros":
        unknown_row = np.zeros(known_rows.shape[1], dtype=known_rows.dtype)
    elif unknown_representation == "mean":
        unknown_row = np.mean(known_rows, axis=0, dtype=known_rows.dtype)
    else:
        raise ValueError(
            f"'{unknown_representation}' is not a specified method. Can be either 'zeros' or 'mean'."
        )

    return lookup_indexes, np.vstack([unknown_row, known_rows])


# NOTE(review): this redefines the identically named generator declared
# earlier in this notebook; being defined later, this one is the version
# that stays bound.
def batch_items_generator(items: Iterable[any], batch_size: int):
    """
    Generator that walks over `items` in steps of `batch_size` and yields one
    slice per step. The final slice may hold fewer elements.

    Args:
        items (list): The items to be chunked; needs `len()` and slicing.
        batch_size (int): The number of items to include in each batch.

    Yields:
        list: A batch of items from the input list.

    Examples:
        >>> items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> batch_size = 3
        >>> for batch in batch_items_generator(items, batch_size):
        ...     print(batch)
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
    """
    n_items = len(items)
    for lower in range(0, n_items, batch_size):
        upper = lower + batch_size
        yield items[lower:upper]


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFXLMRobertaModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing TFXLMRobertaModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFXLMRobertaModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFXLMRobertaModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFXLMRobertaModel for predictions without further training.
Encoding: 100%|██████████| 1/1 [00:01<00:00,  1.59s/it]

[[ 0.44902223  0.4104863   0.18612826 ... -0.21816656  0.26711282
  -0.05442879]
 [ 0.26226935  0.19896637  0.09871999 ... -0.12434417  0.16267559
  -0.04929945]
 [ 0.32007506  0.29021174  0.06231691 ... -0.54429066  0.21992481
   0.17783664]]





In [8]:


#ebrec/models/newsrec/dataloader.py

@dataclass
class NewsrecDataLoader(tf.keras.utils.Sequence):
    """
    Base batch loader for news recommendation.

    On construction it builds the article lookup objects from `article_dict`
    (see `create_lookup_objects`) and splits `behaviors` into features and
    labels. Subclasses implement `__getitem__` to produce the actual batches.
    """

    # Impressions; must contain `inview_col` and `labels_col`.
    behaviors: pl.DataFrame
    # Name of the column holding each row's history of article ids.
    history_column: str
    # Maps article id -> representation; fed to `create_lookup_objects`.
    article_dict: dict[int, any]
    # 'zeros' or 'mean'; how the row for unknown articles is filled.
    unknown_representation: str
    eval_mode: bool = False
    batch_size: int = 32
    inview_col: str = DEFAULT_INVIEW_ARTICLES_COL
    labels_col: str = DEFAULT_LABELS_COL
    user_col: str = DEFAULT_USER_COL
    # Optional {attribute: value} overrides applied at the end of
    # `__post_init__`. The annotation used to be a `field(...)` call, which
    # is not a type; the default (None) is unchanged.
    kwargs: dict | None = None

    def __post_init__(self):
        """
        Post-initialization method. Loads the data and sets additional attributes.
        """
        self.lookup_article_index, self.lookup_article_matrix = create_lookup_objects(
            self.article_dict, unknown_representation=self.unknown_representation
        )
        # Row 0 of the lookup matrix is the unknown-article row.
        self.unknown_index = [0]
        self.X, self.y = self.load_data()
        if self.kwargs is not None:
            self.set_kwargs(self.kwargs)

    def __len__(self) -> int:
        """Number of batches: rows of `self.X` divided by `batch_size`, rounded up."""
        return int(np.ceil(len(self.X) / float(self.batch_size)))

    def __getitem__(self, idx):
        """
        Placeholder; subclasses return the batch with index `idx`.

        The `idx` parameter was missing, so `loader[i]` on the base class
        failed with a TypeError before it could reach the message below.
        """
        raise ValueError("Function '__getitem__' needs to be implemented.")

    def load_data(self) -> tuple[pl.DataFrame, pl.Series]:
        """
        Split `behaviors` into features and labels.

        Returns:
            X: `behaviors` without the label column, plus an "n_samples"
                column holding the length of each in-view list.
            y: The label column of `behaviors`.
        """
        X = self.behaviors.drop(self.labels_col).with_columns(
            pl.col(self.inview_col).list.len().alias("n_samples")
        )
        y = self.behaviors[self.labels_col]
        return X, y

    def set_kwargs(self, kwargs: dict):
        """Set every key of `kwargs` as an attribute on this loader."""
        for key, value in kwargs.items():
            setattr(self, key, value)


@dataclass
class NRMSDataLoader(NewsrecDataLoader):
    """
    Loader that maps article ids to lookup indices per batch, at access time.
    """

    def transform(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Run `map_list_article_id_to_value` over the history column and then the
        in-view column, using `self.lookup_article_index` as mapping and
        `self.unknown_index` as the fill value for nulls.
        """
        shared_kwargs = dict(
            mapping=self.lookup_article_index,
            fill_nulls=self.unknown_index,
            drop_nulls=False,
        )
        history_mapped = df.pipe(
            map_list_article_id_to_value,
            behaviors_column=self.history_column,
            **shared_kwargs,
        )
        return history_mapped.pipe(
            map_list_article_id_to_value,
            behaviors_column=self.inview_col,
            **shared_kwargs,
        )

    def __getitem__(self, idx) -> tuple[tuple[np.ndarray], np.ndarray]:
        """
        Build batch number `idx`.

        Shapes as documented by the original author:
        his_input_title:    (samples, history_size, document_dimension)
        pred_input_title:   (samples, npratio, document_dimension)
        batch_y:            (samples, npratio)

        In eval mode every in-view article becomes its own sample: labels are
        reshaped to (-1, 1) and each history is repeated "n_samples" times.
        """
        rows = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch_X = self.X[rows].pipe(self.transform)
        batch_y = self.y[rows]
        history_indices = batch_X[self.history_column].to_list()

        if not self.eval_mode:
            labels = np.array(batch_y.to_list())
            his_input_title = self.lookup_article_matrix[history_indices]
            pred_input_title = np.squeeze(
                self.lookup_article_matrix[batch_X[self.inview_col].to_list()],
                axis=2,
            )
        else:
            labels = np.array(batch_y.explode().to_list()).reshape(-1, 1)
            # One copy of the history per in-view article of that impression.
            his_input_title = repeat_by_list_values_from_matrix(
                history_indices,
                matrix=self.lookup_article_matrix,
                repeats=np.array(batch_X["n_samples"]),
            )
            pred_input_title = self.lookup_article_matrix[
                batch_X[self.inview_col].explode().to_list()
            ]

        his_input_title = np.squeeze(his_input_title, axis=2)
        return (his_input_title, pred_input_title), labels


@dataclass
class NRMSDataLoaderPretransform(NewsrecDataLoader):
    """
    Variant that maps article ids to lookup indices once, in `__post_init__`,
    for the entire DataFrame instead of per batch. Useful when the data fits in
    memory, as batches are then served faster during training.
    Note, it might not be as scalable.
    """

    def __post_init__(self):
        """Run the base initialization, then pre-transform all of `self.X`."""
        super().__post_init__()
        shared_kwargs = dict(
            mapping=self.lookup_article_index,
            fill_nulls=self.unknown_index,
            drop_nulls=False,
        )
        history_mapped = self.X.pipe(
            map_list_article_id_to_value,
            behaviors_column=self.history_column,
            **shared_kwargs,
        )
        self.X = history_mapped.pipe(
            map_list_article_id_to_value,
            behaviors_column=self.inview_col,
            **shared_kwargs,
        )

    def __getitem__(self, idx) -> tuple[tuple[np.ndarray], np.ndarray]:
        """
        Build batch number `idx` from the pre-transformed data.

        Shapes as documented by the original author:
        his_input_title:    (samples, history_size, document_dimension)
        pred_input_title:   (samples, npratio, document_dimension)
        batch_y:            (samples, npratio)

        In eval mode every in-view article becomes its own sample: labels are
        reshaped to (-1, 1) and each history is repeated "n_samples" times.
        """
        rows = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch_X = self.X[rows]
        batch_y = self.y[rows]
        history_indices = batch_X[self.history_column].to_list()

        if not self.eval_mode:
            labels = np.array(batch_y.to_list())
            his_input_title = self.lookup_article_matrix[history_indices]
            pred_input_title = np.squeeze(
                self.lookup_article_matrix[batch_X[self.inview_col].to_list()],
                axis=2,
            )
        else:
            labels = np.array(batch_y.explode().to_list()).reshape(-1, 1)
            # One copy of the history per in-view article of that impression.
            his_input_title = repeat_by_list_values_from_matrix(
                history_indices,
                matrix=self.lookup_article_matrix,
                repeats=np.array(batch_X["n_samples"]),
            )
            pred_input_title = self.lookup_article_matrix[
                batch_X[self.inview_col].explode().to_list()
            ]

        his_input_title = np.squeeze(his_input_title, axis=2)
        return (his_input_title, pred_input_title), labels


@dataclass(kw_only=True)
class LSTURDataLoader(NewsrecDataLoader):
    """
    Loader shared by NPA and LSTUR. In addition to the article inputs it
    emits a user index per sample.
    """

    # Maps raw user ids to the indices emitted as `user_indexes`.
    user_id_mapping: dict[int, int] = None
    # Index used for users missing from `user_id_mapping`.
    unknown_user_value: int = 0

    def transform(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Map history and in-view article ids through
        `map_list_article_id_to_value`, then replace the user column by its
        entry in `user_id_mapping` (`unknown_user_value` when absent).
        """
        shared_kwargs = dict(
            mapping=self.lookup_article_index,
            fill_nulls=self.unknown_index,
            drop_nulls=False,
        )
        history_mapped = df.pipe(
            map_list_article_id_to_value,
            behaviors_column=self.history_column,
            **shared_kwargs,
        )
        articles_mapped = history_mapped.pipe(
            map_list_article_id_to_value,
            behaviors_column=self.inview_col,
            **shared_kwargs,
        )
        user_index_expr = pl.col(self.user_col).replace(
            self.user_id_mapping, default=self.unknown_user_value
        )
        return articles_mapped.with_columns(user_index_expr)

    def __getitem__(self, idx) -> tuple[tuple[np.ndarray], np.ndarray]:
        """
        Build batch number `idx`.

        user_indexes:       (samples, 1)
        Remaining shapes as documented by the original author:
        his_input_title:    (samples, history_size, document_dimension)
        pred_input_title:   (samples, npratio, document_dimension)
        batch_y:            (samples, npratio)

        In eval mode every in-view article becomes its own sample: labels are
        reshaped to (-1, 1), and user indices and histories are repeated
        "n_samples" times.
        """
        rows = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch_X = self.X[rows].pipe(self.transform)
        batch_y = self.y[rows]
        history_indices = batch_X[self.history_column].to_list()

        if not self.eval_mode:
            labels = np.array(batch_y.to_list())
            user_indexes = np.array(batch_X[self.user_col].to_list()).reshape(-1, 1)
            his_input_title = self.lookup_article_matrix[history_indices]
            pred_input_title = np.squeeze(
                self.lookup_article_matrix[batch_X[self.inview_col].to_list()],
                axis=2,
            )
        else:
            labels = np.array(batch_y.explode().to_list()).reshape(-1, 1)
            # Repeat each user once per in-view article of the impression.
            repeated_users = batch_X.select(
                pl.col(self.user_col).repeat_by(pl.col("n_samples")).explode()
            )[self.user_col]
            user_indexes = np.array(repeated_users.to_list()).reshape(-1, 1)
            his_input_title = repeat_by_list_values_from_matrix(
                history_indices,
                matrix=self.lookup_article_matrix,
                repeats=np.array(batch_X["n_samples"]),
            )
            pred_input_title = self.lookup_article_matrix[
                batch_X[self.inview_col].explode().to_list()
            ]

        his_input_title = np.squeeze(his_input_title, axis=2)
        return (user_indexes, his_input_title, pred_input_title), labels


@dataclass(kw_only=True)
class NAMLDataLoader(NewsrecDataLoader):
    """
    Batch loader that feeds title, body, category and subcategory inputs for
    both the history and the in-view articles.

    Eval mode not implemented
    """

    # Fill values used when an article id has no category / subcategory mapping.
    unknown_category_value: int = 0
    unknown_subcategory_value: int = 0
    # article id -> body representation; turned into lookup objects in __post_init__.
    body_mapping: dict[int, list[int]] | None = None
    # article id -> category / subcategory value.
    category_mapping: dict[int, int] | None = None
    subcategory_mapping: dict[int, int] | None = None

    def __post_init__(self):
        """Validate the mode, set column prefixes and build the body lookup.

        Raises:
            ValueError: if `eval_mode` is truthy (not implemented for NAML).
        """
        # Fail fast: reject the unsupported mode before doing any lookup work.
        if self.eval_mode:
            raise ValueError("'eval_mode = True' is not implemented for NAML")

        self.title_prefix = "title_"
        self.body_prefix = "body_"
        self.category_prefix = "category_"
        self.subcategory_prefix = "subcategory_"
        (
            self.lookup_article_index_body,
            self.lookup_article_matrix_body,
        ) = create_lookup_objects(
            self.body_mapping, unknown_representation=self.unknown_representation
        )

        return super().__post_init__()

    def _map_history_and_inview(
        self, df: pl.DataFrame, mapping: dict, fill_nulls
    ) -> pl.DataFrame:
        """Map the article ids of the history and in-view columns through `mapping`.

        Ids without a mapping entry are replaced by `fill_nulls`; no rows or
        list elements are dropped (`drop_nulls=False`).
        """
        return df.pipe(
            map_list_article_id_to_value,
            behaviors_column=self.history_column,
            mapping=mapping,
            fill_nulls=fill_nulls,
            drop_nulls=False,
        ).pipe(
            map_list_article_id_to_value,
            behaviors_column=self.inview_col,
            mapping=mapping,
            fill_nulls=fill_nulls,
            drop_nulls=False,
        )

    def transform(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Special case for NAML as it requires body-encoding, verticals, & subvertivals

        Returns one frame holding four mapped copies of `df`; the columns of
        each copy carry the title/body/category/subcategory prefix.
        """
        # (column prefix, id -> value mapping, fill value for unmapped ids)
        # NOTE(review): the body view fills with `self.unknown_index`, i.e. the
        # same value as the title view -- confirm the body lookup reserves the
        # same unknown index.
        views = (
            (self.title_prefix, self.lookup_article_index, self.unknown_index),
            (self.body_prefix, self.lookup_article_index_body, self.unknown_index),
            (self.category_prefix, self.category_mapping, self.unknown_category_value),
            (
                self.subcategory_prefix,
                self.subcategory_mapping,
                self.unknown_subcategory_value,
            ),
        )
        combined = pl.DataFrame()
        for prefix, mapping, fill_value in views:
            mapped = self._map_history_and_inview(df, mapping, fill_value)
            combined = combined.with_columns(
                mapped.select(pl.all().name.prefix(prefix))
            )
        return combined

    def __getitem__(self, idx) -> tuple[tuple[np.ndarray], np.ndarray]:
        """Build batch number `idx`.

        Returns:
            `(inputs, batch_y)` where `inputs` is the 8-tuple
            (his_title, his_body, his_vert, his_subvert,
             pred_title, pred_body, pred_vert, pred_subvert).
        """
        rows = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch_X = self.X[rows].pipe(self.transform)
        batch_y = np.array(self.y[rows].to_list())

        def as_array(prefix: str, column: str) -> np.ndarray:
            # Materialise one prefixed list-column of the transformed batch.
            return np.array(batch_X[prefix + column].to_list())

        def embed(matrix: np.ndarray, prefix: str, column: str) -> np.ndarray:
            # Look the indices up in `matrix` and drop axis 2 of the result.
            return np.squeeze(matrix[as_array(prefix, column)], axis=2)

        def with_trailing_axis(prefix: str, column: str) -> np.ndarray:
            # Category-like values get an explicit trailing axis.
            return as_array(prefix, column)[:, :, np.newaxis]

        history, inview = self.history_column, self.inview_col
        # =>
        his_input_title = embed(self.lookup_article_matrix, self.title_prefix, history)
        his_input_body = embed(
            self.lookup_article_matrix_body, self.body_prefix, history
        )
        his_input_vert = with_trailing_axis(self.category_prefix, history)
        his_input_subvert = with_trailing_axis(self.subcategory_prefix, history)
        # =>
        pred_input_title = embed(self.lookup_article_matrix, self.title_prefix, inview)
        pred_input_body = embed(
            self.lookup_article_matrix_body, self.body_prefix, inview
        )
        pred_input_vert = with_trailing_axis(self.category_prefix, inview)
        pred_input_subvert = with_trailing_axis(self.subcategory_prefix, inview)
        # =>
        return (
            his_input_title,
            his_input_body,
            his_input_vert,
            his_input_subvert,
            pred_input_title,
            pred_input_body,
            pred_input_vert,
            pred_input_subvert,
        ), batch_y

In [9]:
# ebrec/models/newsrec/model_config.py


#
# Default input lengths for title and body (presumably token counts -- TODO confirm).
DEFAULT_TITLE_SIZE = 30
DEFAULT_BODY_SIZE = 40
# All-zero placeholders, one entry per position of the corresponding input.
UNKNOWN_TITLE_VALUE = [0 for _ in range(DEFAULT_TITLE_SIZE)]
UNKNOWN_BODY_VALUE = [0 for _ in range(DEFAULT_BODY_SIZE)]

# Used as the default `title_size` of `hparams_nrms_docvec`.
DEFAULT_DOCUMENT_SIZE = 768


def print_hparams(hparams_class):
    """Print each annotated attribute of `hparams_class` as `name: value`.

    Only names present in `hparams_class.__annotations__` are printed, in
    declaration order; the value is read with `getattr`.
    """
    annotated_names = hparams_class.__annotations__
    for name in annotated_names:
        current_value = getattr(hparams_class, name)
        print(f"{name}: {current_value}")


def hparams_to_dict(hparams_class) -> dict:
    """Return `{name: value}` for every annotated attribute of `hparams_class`.

    Names come from `hparams_class.__annotations__` (declaration order);
    values are the current attribute values read with `getattr`.
    """
    return {
        name: getattr(hparams_class, name)
        for name in hparams_class.__annotations__
    }


class hparams_naml:
    """Hyper-parameter defaults grouped under the NAML name.

    Values are plain class attributes. `print_hparams` / `hparams_to_dict`
    enumerate them through the class annotations, so declaration order is the
    reporting order. The consuming model is not defined in this part of the
    notebook -- presumably the NAML model; verify against its constructor.
    """

    # INPUT DIMENSIONS:
    title_size: int = DEFAULT_TITLE_SIZE
    history_size: int = 20
    body_size: int = DEFAULT_BODY_SIZE
    vert_num: int = 100
    vert_emb_dim: int = 10
    subvert_num: int = 100
    subvert_emb_dim: int = 10
    # MODEL ARCHITECTURE
    dense_activation: str = "relu"
    cnn_activation: str = "relu"
    attention_hidden_dim: int = 200
    filter_num: int = 400
    window_size: int = 3
    # MODEL OPTIMIZER:
    optimizer: str = "adam"
    loss: str = "cross_entropy_loss"
    dropout: float = 0.2
    learning_rate: float = 1e-4


class hparams_lstur:
    """Hyper-parameter defaults grouped under the LSTUR name.

    Values are plain class attributes. `print_hparams` / `hparams_to_dict`
    enumerate them through the class annotations, so declaration order is the
    reporting order. The consuming model is not defined in this part of the
    notebook -- presumably the LSTUR model; verify against its constructor.
    """

    # INPUT DIMENSIONS:
    title_size: int = DEFAULT_TITLE_SIZE
    history_size: int = 20
    n_users: int = 50000
    # MODEL ARCHITECTURE
    cnn_activation: str = "relu"
    # NOTE(review): accepted values of `type` are not visible here -- confirm
    # against the model that reads it.
    type: str = "ini"
    attention_hidden_dim: int = 200
    gru_unit: int = 400
    filter_num: int = 400
    window_size: int = 3
    # MODEL OPTIMIZER:
    optimizer: str = "adam"
    loss: str = "cross_entropy_loss"
    dropout: float = 0.2
    learning_rate: float = 1e-4


class hparams_npa:
    """Hyper-parameter defaults grouped under the NPA name.

    Values are plain class attributes. `print_hparams` / `hparams_to_dict`
    enumerate them through the class annotations, so declaration order is the
    reporting order. The consuming model is not defined in this part of the
    notebook -- presumably the NPA model; verify against its constructor.
    """

    # INPUT DIMENSIONS:
    title_size: int = DEFAULT_TITLE_SIZE
    history_size: int = 20
    n_users: int = 50000
    # MODEL ARCHITECTURE
    cnn_activation: str = "relu"
    attention_hidden_dim: int = 200
    user_emb_dim: int = 400
    filter_num: int = 400
    window_size: int = 3
    # MODEL OPTIMIZER:
    optimizer: str = "adam"
    loss: str = "cross_entropy_loss"
    dropout: float = 0.2
    learning_rate: float = 1e-4


class hparams_nrms:
    """Hyper-parameter defaults grouped under the NRMS name.

    Values are plain class attributes, so assigning to one on the class
    (e.g. `hparams_nrms.title_size = ...`) changes it for every user of the
    class. `print_hparams` / `hparams_to_dict` enumerate the attributes
    through the class annotations, in declaration order.
    """

    # INPUT DIMENSIONS:
    title_size: int = DEFAULT_TITLE_SIZE
    history_size: int = 20
    # MODEL ARCHITECTURE
    head_num: int = 20
    head_dim: int = 20
    attention_hidden_dim: int = 200
    # MODEL OPTIMIZER:
    optimizer: str = "adam"
    loss: str = "cross_entropy_loss"
    dropout: float = 0.2
    learning_rate: float = 1e-4
    # MY OWN LITTLE TWIST:
    # Left as None here; `hparams_nrms_docvec` sets an explicit list of widths.
    newsencoder_units_per_layer: list[int] | None = None
    newsencoder_l2_regularization: float = 1e-4


class hparams_nrms_docvec:
    """Hyper-parameter defaults read by `NRMSDocVec`.

    `NRMSDocVec` accesses these by attribute (`hparams.title_size`,
    `hparams.head_num`, ...). Values are plain class attributes;
    `print_hparams` / `hparams_to_dict` enumerate them through the class
    annotations, in declaration order.
    """

    # INPUT DIMENSIONS:
    # Here `title_size` is the width of the per-article input vector.
    title_size: int = DEFAULT_DOCUMENT_SIZE
    history_size: int = 20
    # MODEL ARCHITECTURE
    # head_num * head_dim is the output width of the NRMSDocVec news encoder.
    head_num: int = 16
    head_dim: int = 16
    attention_hidden_dim: int = 200
    # MODEL OPTIMIZER:
    optimizer: str = "adam"
    loss: str = "cross_entropy_loss"
    dropout: float = 0.2
    learning_rate: float = 1e-4
    # One Dense layer is built per entry. NOTE: this list is a single object
    # stored on the class -- mutating it in place affects every reader;
    # assign a new list instead.
    newsencoder_units_per_layer: list[int] = [512, 512, 512]
    newsencoder_l2_regularization: float = 1e-4

In [10]:



#ebrec/models/newsrec/layers.py


class AttLayer2(layers.Layer):
    """Soft alignment attention implement.

    Scores every step of a 3-D input with ``tanh(x W + b) q``, normalises the
    exponentiated scores over the step axis and returns the score-weighted sum
    of the input over that axis.

    Attributes:
        dim (int): attention hidden dim
        seed (int): seed given to the Glorot initializers of ``W`` and ``q``.
    """

    def __init__(self, dim=200, seed=0, **kwargs):
        """Initialization steps for AttLayer2.

        Args:
            dim (int): attention hidden dim
            seed (int): seed for the weight initializers.
        """

        self.dim = dim
        self.seed = seed
        super(AttLayer2, self).__init__(**kwargs)

    def build(self, input_shape):
        """Initialization for variables in AttLayer2
        There are three variables in AttLayer2, i.e. W, b and q.

        Args:
            input_shape (object): shape of input tensor (must have rank 3).
        """

        assert len(input_shape) == 3
        dim = self.dim
        self.W = self.add_weight(
            name="W",
            shape=(int(input_shape[-1]), dim),
            initializer=keras.initializers.glorot_uniform(seed=self.seed),
            trainable=True,
        )
        self.b = self.add_weight(
            name="b",
            shape=(dim,),
            initializer=keras.initializers.Zeros(),
            trainable=True,
        )
        self.q = self.add_weight(
            name="q",
            shape=(dim, 1),
            initializer=keras.initializers.glorot_uniform(seed=self.seed),
            trainable=True,
        )
        super(AttLayer2, self).build(input_shape)  # be sure you call this somewhere!

    def call(self, inputs, mask=None, **kwargs):
        """Core implementation of soft attention

        Args:
            inputs (object): input tensor.
            mask (object): optional mask multiplied onto the exponentiated
                scores, so masked steps get zero weight.

        Returns:
            object: weighted sum of input tensors.
        """

        attention = K.tanh(K.dot(inputs, self.W) + self.b)
        attention = K.dot(attention, self.q)

        # Drop the trailing size-1 axis produced by the (dim, 1) projection.
        attention = K.squeeze(attention, axis=2)

        attention = K.exp(attention)
        # Identity check: `==` on a tensor is an overloaded operator and is
        # not a reliable way to test for "no mask".
        if mask is not None:
            attention = attention * K.cast(mask, dtype="float32")

        # Epsilon keeps the division finite when every step is masked out.
        attention_weight = attention / (
            K.sum(attention, axis=-1, keepdims=True) + K.epsilon()
        )

        attention_weight = K.expand_dims(attention_weight)
        weighted_input = inputs * attention_weight
        return K.sum(weighted_input, axis=1)

    def compute_mask(self, input, input_mask=None):
        """Compute output mask value

        Args:
            input (object): input tensor.
            input_mask: input mask

        Returns:
            object: output mask (always None: the mask is not propagated).
        """
        return None

    def compute_output_shape(self, input_shape):
        """Compute shape of output tensor

        Args:
            input_shape (tuple): shape of input tensor.

        Returns:
            tuple: shape of output tensor.
        """
        return input_shape[0], input_shape[-1]

    def get_config(self):
        """Add dim and seed to the layer config so the layer can be re-created.

        Returns:
            dict: config of AttLayer2 layer.
        """
        config = super(AttLayer2, self).get_config()
        config.update({"dim": self.dim, "seed": self.seed})
        return config


class SelfAttention(layers.Layer):
    """Multi-head self attention implement.

    Args:
        multiheads (int): The number of heads.
        head_dim (object): Dimension of each head.
        mask_right (boolean): whether to mask right words.

    Returns:
        object: Weighted sum after attention.
    """

    def __init__(self, multiheads, head_dim, seed=0, mask_right=False, **kwargs):
        """Initialization steps for SelfAttention.

        Args:
            multiheads (int): The number of heads.
            head_dim (object): Dimension of each head.
            seed (int): seed given to the Glorot initializers of WQ, WK, WV.
            mask_right (boolean): whether to mask right words.
        """

        self.multiheads = multiheads
        self.head_dim = head_dim
        self.output_dim = multiheads * head_dim
        self.mask_right = mask_right
        self.seed = seed
        super(SelfAttention, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        """Compute shape of output tensor.

        Returns:
            tuple: output shape tuple.
        """

        return (input_shape[0][0], input_shape[0][1], self.output_dim)

    def build(self, input_shape):
        """Initialization for variables in SelfAttention.
        There are three variables in SelfAttention, i.e. WQ, WK and WV.
        WQ is used for linear transformation of query.
        WK is used for linear transformation of key.
        WV is used for linear transformation of value.

        Args:
            input_shape (object): shape of input tensor.
        """

        self.WQ = self.add_weight(
            name="WQ",
            shape=(int(input_shape[0][-1]), self.output_dim),
            initializer=keras.initializers.glorot_uniform(seed=self.seed),
            trainable=True,
        )
        self.WK = self.add_weight(
            name="WK",
            shape=(int(input_shape[1][-1]), self.output_dim),
            initializer=keras.initializers.glorot_uniform(seed=self.seed),
            trainable=True,
        )
        self.WV = self.add_weight(
            name="WV",
            shape=(int(input_shape[2][-1]), self.output_dim),
            initializer=keras.initializers.glorot_uniform(seed=self.seed),
            trainable=True,
        )
        super(SelfAttention, self).build(input_shape)

    def Mask(self, inputs, seq_len, mode="add"):
        """Mask operation used in multi-head self attention

        Positions of axis 1 at or beyond `seq_len[:, 0]` are masked: zeroed
        with mode "mul", pushed to a large negative value with mode "add".

        Args:
            inputs (object): tensor to mask.
            seq_len (object): sequence length of inputs; None disables masking.
            mode (str): mode of mask, "mul" or "add".

        Returns:
            object: tensors after masking.

        Raises:
            ValueError: if `mode` is neither "mul" nor "add".
        """

        if seq_len is None:
            return inputs

        # one_hot marks the first masked position; 1 - cumsum turns that into
        # ones before it and zeros from it onwards.
        mask = K.one_hot(indices=seq_len[:, 0], num_classes=K.shape(inputs)[1])
        mask = 1 - K.cumsum(mask, axis=1)

        # Add trailing axes until the mask broadcasts against `inputs`.
        for _ in range(len(inputs.shape) - 2):
            mask = K.expand_dims(mask, 2)

        if mode == "mul":
            return inputs * mask
        if mode == "add":
            return inputs - (1 - mask) * 1e12
        # Previously an unknown mode silently returned None.
        raise ValueError(f"Unknown mask mode {mode!r}; expected 'mul' or 'add'")

    def _project_to_heads(self, seq, weight):
        """Project `seq` with `weight` and split the last axis into heads.

        Returns a tensor permuted to (batch, heads, steps, head_dim).
        """
        projected = K.dot(seq, weight)
        projected = K.reshape(
            projected,
            shape=(-1, K.shape(projected)[1], self.multiheads, self.head_dim),
        )
        return K.permute_dimensions(projected, pattern=(0, 2, 1, 3))

    def call(self, QKVs):
        """Core logic of multi-head self attention.

        Args:
            QKVs (list): inputs of multi-head self attention i.e. query, key and
                value, optionally followed by the query and value lengths.

        Returns:
            object: output tensors.

        Raises:
            ValueError: if `QKVs` does not hold exactly 3 or 5 elements.
        """
        if len(QKVs) == 3:
            Q_seq, K_seq, V_seq = QKVs
            Q_len, V_len = None, None
        elif len(QKVs) == 5:
            Q_seq, K_seq, V_seq, Q_len, V_len = QKVs
        else:
            # Previously fell through to an UnboundLocalError on Q_seq.
            raise ValueError(
                f"SelfAttention expects 3 or 5 inputs, got {len(QKVs)}"
            )

        Q_seq = self._project_to_heads(Q_seq, self.WQ)
        K_seq = self._project_to_heads(K_seq, self.WK)
        V_seq = self._project_to_heads(V_seq, self.WV)

        # Scaled dot-product scores.
        A = tf.matmul(Q_seq, K_seq, adjoint_a=False, adjoint_b=True) / K.sqrt(
            K.cast(self.head_dim, dtype="float32")
        )

        A = K.permute_dimensions(
            A, pattern=(0, 3, 2, 1)
        )  # A.shape=[batch_size,K_sequence_length,Q_sequence_length,self.multiheads]

        A = self.Mask(A, V_len, "add")
        A = K.permute_dimensions(A, pattern=(0, 3, 2, 1))

        if self.mask_right:
            ones = K.ones_like(A[:1, :1])
            # `K.tf.matrix_band_part` is not available in TF2; the TF2 name of
            # the same op is `tf.linalg.band_part`.
            lower_triangular = tf.linalg.band_part(ones, num_lower=-1, num_upper=0)
            mask = (ones - lower_triangular) * 1e12
            A = A - mask
        A = K.softmax(A)

        # NOTE(review): adjoint_a=True multiplies by the transpose of A, so the
        # sum runs over A's query axis; this only type-checks when query and
        # value lengths match. Kept unchanged -- confirm it is intended.
        O_seq = tf.matmul(A, V_seq, adjoint_a=True, adjoint_b=False)
        O_seq = K.permute_dimensions(O_seq, pattern=(0, 2, 1, 3))

        # Merge the heads back into one feature axis.
        O_seq = K.reshape(O_seq, shape=(-1, K.shape(O_seq)[1], self.output_dim))
        O_seq = self.Mask(O_seq, Q_len, "mul")
        return O_seq

    def get_config(self):
        """add multiheads, head_dim, seed and mask_right into layer config.

        Returns:
            dict: config of SelfAttention layer.
        """
        config = super(SelfAttention, self).get_config()
        config.update(
            {
                "multiheads": self.multiheads,
                "head_dim": self.head_dim,
                "seed": self.seed,
                "mask_right": self.mask_right,
            }
        )
        return config


class ComputeMasking(layers.Layer):
    """Flag the non-zero entries of the input.

    Returns:
        tensor: same shape as the input, cast to the backend float type;
            non-zero positions map to 1 and zeros map to 0.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs, **kwargs):
        # Element-wise "!= 0", then boolean -> float.
        is_nonzero = K.not_equal(inputs, 0)
        float_dtype = K.floatx()
        return K.cast(is_nonzero, float_dtype)

    def compute_output_shape(self, input_shape):
        # Purely element-wise, so the shape is unchanged.
        return input_shape


class OverwriteMasking(layers.Layer):
    """Multiply a value tensor by a mask tensor.

    Args:
        inputs (list): value tensor (index 0) and mask tensor (index 1).

    Returns:
        object: the value tensor scaled by the mask, which is given one extra
            trailing axis so it broadcasts over the values.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # No weights of its own; just defer to the base implementation.
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        values = inputs[0]
        mask = inputs[1]
        broadcastable_mask = K.expand_dims(mask)
        return values * broadcastable_mask

    def compute_output_shape(self, input_shape):
        # Output has the shape of the value tensor.
        value_shape = input_shape[0]
        return value_shape


def PersonalizedAttentivePooling(dim1, dim2, dim3, seed=0):
    """Build a query-driven attentive pooling model.

    The values pass through dropout (rate 0.2) and a tanh Dense projection of
    width `dim3`; the projection is dotted with the query, soft-maxed, and the
    resulting weights pool the dropped-out values over their first axis.

    Attributes:
        dim1 (int): first dimension of value shape.
        dim2 (int): second dimension of value shape.
        dim3 (int): shape of query
        seed (int): seed for the Dense kernel initializer.

    Returns:
        object: Keras model mapping `[values, query]` to the weighted summary
            of the values.
    """
    values_in = keras.Input(shape=(dim1, dim2), dtype="float32")
    query_in = keras.Input(shape=(dim3,), dtype="float32")

    dropped_values = layers.Dropout(0.2)(values_in)
    projection_layer = layers.Dense(
        dim3,
        activation="tanh",
        kernel_initializer=keras.initializers.glorot_uniform(seed=seed),
        bias_initializer=keras.initializers.Zeros(),
    )
    projected_values = projection_layer(dropped_values)

    # Score each value row against the query, then normalise the scores.
    scores = layers.Dot(axes=-1)([query_in, projected_values])
    weights = layers.Activation("softmax")(scores)

    # Weighted sum of the (dropped-out) values over axis 1.
    pooled = layers.Dot((1, 1))([dropped_values, weights])

    return keras.Model([values_in, query_in], pooled)

In [11]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# ebrec/models/newsrec/nrms_docvec.py

class NRMSDocVec:
    """
    Modified NRMS model (Neural News Recommendation with Multi-Head Self-Attention)
    - Initiated with article-embeddings.

    Chuhan Wu, Fangzhao Wu, Suyu Ge, Tao Qi, Yongfeng Huang,and Xing Xie, "Neural News
    Recommendation with Multi-Head Self-Attention" in Proceedings of the 2019 Conference
    on Empirical Methods in Natural Language Processing and the 9th International Joint Conference
    on Natural Language Processing (EMNLP-IJCNLP)

    Attributes:
        hparams: hyper-parameter holder, read by attribute access.
        seed: seed used for TensorFlow, NumPy and the attention layers.
        model: compiled training model (softmax over the candidate list).
        scorer: single-candidate model (sigmoid score).
        userencoder / newsencoder: the shared sub-models.
    """

    # Dense widths used when no `units_per_layer` is supplied.
    _DEFAULT_UNITS_PER_LAYER = (512, 512, 512)

    def __init__(
        self,
        hparams: dict,
        seed: int = None,
    ):
        """Initialization steps for NRMS.

        Args:
            hparams: object exposing the hyper-parameters as attributes
                (e.g. the `hparams_nrms_docvec` class). Despite the `dict`
                annotation it is accessed as `hparams.<name>`.
            seed: optional seed for TensorFlow / NumPy and the attention layers.
        """
        self.hparams = hparams
        self.seed = seed

        # SET SEED:
        tf.random.set_seed(seed)
        np.random.seed(seed)
        # BUILD AND COMPILE MODEL:
        self.model, self.scorer = self._build_graph()
        data_loss = self._get_loss(self.hparams.loss)
        train_optimizer = self._get_opt(
            optimizer=self.hparams.optimizer, lr=self.hparams.learning_rate
        )
        self.model.compile(loss=data_loss, optimizer=train_optimizer)

    def _get_loss(self, loss: str):
        """Translate the configured loss name into a Keras loss name.

        Returns:
            object: Loss function or loss function name

        Raises:
            ValueError: if `loss` is not a known name.
        """
        if loss == "cross_entropy_loss":
            data_loss = "categorical_crossentropy"
        elif loss == "log_loss":
            data_loss = "binary_crossentropy"
        else:
            raise ValueError(f"this loss not defined {loss}")
        return data_loss

    def _get_opt(self, optimizer: str, lr: float):
        """Get the optimizer according to configuration. Usually we will use Adam.

        Returns:
            object: An optimizer.

        Raises:
            ValueError: if `optimizer` is not a known name.
        """
        if optimizer == "adam":
            train_opt = tf.keras.optimizers.Adam(learning_rate=lr)
        else:
            raise ValueError(f"this optimizer not defined {optimizer}")
        return train_opt

    def _build_graph(self):
        """Build NRMS model and scorer.

        Returns:
            object: a model used to train.
            object: a model used to evaluate and inference.
        """
        model, scorer = self._build_nrms()
        return model, scorer

    def _build_userencoder(self, titleencoder):
        """The main function to create user encoder of NRMS.

        Args:
            titleencoder (object): the news encoder of NRMS.

        Return:
            object: the user encoder of NRMS.
        """
        his_input_title = tf.keras.Input(
            shape=(self.hparams.history_size, self.hparams.title_size), dtype="float32"
        )

        # Encode every history article with the shared news encoder.
        click_title_presents = tf.keras.layers.TimeDistributed(titleencoder)(
            his_input_title
        )
        y = SelfAttention(self.hparams.head_num, self.hparams.head_dim, seed=self.seed)(
            [click_title_presents] * 3
        )
        user_present = AttLayer2(self.hparams.attention_hidden_dim, seed=self.seed)(y)

        model = tf.keras.Model(his_input_title, user_present, name="user_encoder")
        return model

    @staticmethod
    def _resolve_units_per_layer(units_per_layer) -> list[int]:
        """Return the Dense widths to build, as a fresh list.

        `None` selects the default widths (512, 512, 512); any other iterable
        is copied unchanged (an empty one means "no hidden layers").
        """
        if units_per_layer is None:
            return list(NRMSDocVec._DEFAULT_UNITS_PER_LAYER)
        return list(units_per_layer)

    def _build_newsencoder(self, units_per_layer: list[int] = None):
        """THIS IS OUR IMPLEMENTATION.
        The main function to create a news encoder.

        Parameters:
            units_per_layer (list[int] | None): The number of neurons in each
                Dense layer. `None` (the default) means [512, 512, 512].
                The previous default `list[512, 512, 512]` was a generic-alias
                object rather than a list and could not be used to build layers.

        Return:
            object: the news encoder.
        """
        DOCUMENT_VECTOR_DIM = self.hparams.title_size
        OUTPUT_DIM = self.hparams.head_num * self.hparams.head_dim
        layer_widths = self._resolve_units_per_layer(units_per_layer)

        # DENSE LAYERS (FINE-TUNED):
        # Trailing comma matters: the shape must be a 1-tuple, not a bare int.
        sequences_input_title = tf.keras.Input(
            shape=(DOCUMENT_VECTOR_DIM,), dtype="float32"
        )
        x = sequences_input_title
        # Create configurable Dense layers:
        for units in layer_widths:
            x = tf.keras.layers.Dense(
                units=units,
                activation="relu",
                kernel_regularizer=tf.keras.regularizers.l2(
                    self.hparams.newsencoder_l2_regularization
                ),
            )(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Dropout(self.hparams.dropout)(x)

        # OUTPUT: width matches the self-attention output (head_num * head_dim).
        pred_title = tf.keras.layers.Dense(units=OUTPUT_DIM, activation="relu")(x)

        # Construct the final model
        model = tf.keras.Model(
            inputs=sequences_input_title, outputs=pred_title, name="news_encoder"
        )

        return model

    def _build_nrms(self):
        """The main function to create NRMS's logic. The core of NRMS
        is a user encoder and a news encoder.

        Returns:
            object: a model used to train.
            object: a model used to evaluate and inference.
        """

        his_input_title = tf.keras.Input(
            shape=(self.hparams.history_size, self.hparams.title_size),
            dtype="float32",
        )
        pred_input_title = tf.keras.Input(
            # shape = (hparams.npratio + 1, hparams.title_size)
            shape=(None, self.hparams.title_size),
            dtype="float32",
        )
        pred_input_title_one = tf.keras.Input(
            shape=(
                1,
                self.hparams.title_size,
            ),
            dtype="float32",
        )
        pred_title_one_reshape = tf.keras.layers.Reshape((self.hparams.title_size,))(
            pred_input_title_one
        )
        # A `None` value in hparams falls back to the encoder's default widths.
        titleencoder = self._build_newsencoder(
            units_per_layer=self.hparams.newsencoder_units_per_layer
        )
        self.userencoder = self._build_userencoder(titleencoder)
        self.newsencoder = titleencoder

        user_present = self.userencoder(his_input_title)
        news_present = tf.keras.layers.TimeDistributed(self.newsencoder)(
            pred_input_title
        )
        news_present_one = self.newsencoder(pred_title_one_reshape)

        # Training head: dot-product scores over the candidate list, soft-maxed.
        preds = tf.keras.layers.Dot(axes=-1)([news_present, user_present])
        preds = tf.keras.layers.Activation(activation="softmax")(preds)

        # Scoring head: one candidate at a time, squashed with a sigmoid.
        pred_one = tf.keras.layers.Dot(axes=-1)([news_present_one, user_present])
        pred_one = tf.keras.layers.Activation(activation="sigmoid")(pred_one)

        model = tf.keras.Model([his_input_title, pred_input_title], preds)
        scorer = tf.keras.Model([his_input_title, pred_input_title_one], pred_one)

        return model, scorer

In [12]:
#src/ebrec/utils/_articles_behaviors.py


def map_list_article_id_to_value(
    behaviors: pl.DataFrame,
    behaviors_column: str,
    mapping: dict[int, pl.Series],
    drop_nulls: bool = False,
    fill_nulls: Any = None,
) -> pl.DataFrame:
    """
    Replace every article ID inside the list-column `behaviors_column` of `behaviors`
    with its value from `mapping`. IDs without an entry in `mapping` become null and are
    then dropped, filled or kept depending on `drop_nulls` / `fill_nulls`.

    The result holds the same columns as `behaviors`; the mapped column is re-attached
    through a join, so it ends up as the last column.

    Args:
        behaviors (pl.DataFrame): The DataFrame containing the column to be mapped.
        behaviors_column (str): The name of the column to be mapped in `behaviors`.
        mapping (dict[int, pl.Series]): A dictionary with article IDs as keys and corresponding values as values.
            Note, 'replace' works a lot faster when values are of type pl.Series!
        drop_nulls (bool): If `True`, unmapped (null) elements are removed from the lists; a row whose
            elements are all removed ends up with a null list. Takes precedence over `fill_nulls`.
            If `False` and `fill_nulls` is specified, null values in `behaviors_column` will be replaced with `fill_nulls`.
        fill_nulls (Optional[Any]): If specified, any null values in `behaviors_column` will be replaced with this value.

    Returns:
        pl.DataFrame: A new DataFrame with the same columns as `behaviors`, but with the article IDs in
            `behaviors_column` replaced by their corresponding values in `mapping`.

    Example:
    >>> behaviors = pl.DataFrame(
            {"user_id": [1, 2, 3, 4, 5], "article_ids": [["A1", "A2"], ["A2", "A3"], ["A1", "A4"], ["A4", "A4"], None]}
        )
    >>> articles = pl.DataFrame(
            {
                "article_id": ["A1", "A2", "A3"],
                "article_type": ["News", "Sports", "Entertainment"],
            }
        )
    >>> articles_dict = dict(zip(articles["article_id"], articles["article_type"]))
    >>> map_list_article_id_to_value(
            behaviors=behaviors,
            behaviors_column="article_ids",
            mapping=articles_dict,
            fill_nulls="Unknown",
        )
        shape: (5, 2)
        ┌─────────┬─────────────────────────────┐
        │ user_id ┆ article_ids                 │
        │ ---     ┆ ---                         │
        │ i64     ┆ list[str]                   │
        ╞═════════╪═════════════════════════════╡
        │ 1       ┆ ["News", "Sports"]          │
        │ 2       ┆ ["Sports", "Entertainment"] │
        │ 3       ┆ ["News", "Unknown"]         │
        │ 4       ┆ ["Unknown", "Unknown"]      │
        │ 5       ┆ ["Unknown"]                 │
        └─────────┴─────────────────────────────┘
    >>> map_list_article_id_to_value(
            behaviors=behaviors,
            behaviors_column="article_ids",
            mapping=articles_dict,
            drop_nulls=True,
        )
        shape: (5, 2)
        ┌─────────┬─────────────────────────────┐
        │ user_id ┆ article_ids                 │
        │ ---     ┆ ---                         │
        │ i64     ┆ list[str]                   │
        ╞═════════╪═════════════════════════════╡
        │ 1       ┆ ["News", "Sports"]          │
        │ 2       ┆ ["Sports", "Entertainment"] │
        │ 3       ┆ ["News"]                    │
        │ 4       ┆ null                        │
        │ 5       ┆ null                        │
        └─────────┴─────────────────────────────┘
    >>> map_list_article_id_to_value(
            behaviors=behaviors,
            behaviors_column="article_ids",
            mapping=articles_dict,
            drop_nulls=False,
        )
        shape: (5, 2)
        ┌─────────┬─────────────────────────────┐
        │ user_id ┆ article_ids                 │
        │ ---     ┆ ---                         │
        │ i64     ┆ list[str]                   │
        ╞═════════╪═════════════════════════════╡
        │ 1       ┆ ["News", "Sports"]          │
        │ 2       ┆ ["Sports", "Entertainment"] │
        │ 3       ┆ ["News", null]              │
        │ 4       ┆ [null, null]                │
        │ 5       ┆ [null]                      │
        └─────────┴─────────────────────────────┘
    """
    # Temporary row-id column used to regroup the exploded rows.
    # `generate_unique_name` is defined elsewhere; presumably it returns a name
    # that does not clash with the existing columns -- TODO confirm.
    GROUPBY_ID = generate_unique_name(behaviors.columns, "_groupby_id")
    behaviors = behaviors.lazy().with_row_index(GROUPBY_ID)
    # => One row per list element; IDs missing from `mapping` become null.
    select_column = (
        behaviors.select(pl.col(GROUPBY_ID), pl.col(behaviors_column))
        .explode(behaviors_column)
        .with_columns(pl.col(behaviors_column).replace(mapping, default=None))
        .collect()
    )
    # => Null handling: dropping wins over filling; with neither, nulls are kept.
    if drop_nulls:
        select_column = select_column.drop_nulls()
    elif fill_nulls is not None:
        select_column = select_column.with_columns(
            pl.col(behaviors_column).fill_null(fill_nulls)
        )
    # => Collapse the elements back into one list per original row.
    select_column = (
        select_column.lazy().group_by(GROUPBY_ID).agg(behaviors_column).collect()
    )
    # Left join keeps rows whose elements were all dropped (their list is null);
    # the helper row-id column is removed again before returning.
    return (
        behaviors.drop(behaviors_column)
        .collect()
        .join(select_column, on=GROUPBY_ID, how="left")
        .drop(GROUPBY_ID)
    )

In [13]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.


#ebrec/models/newsrec/nrms.py

class NRMSModel:
    """NRMS model (Neural News Recommendation with Multi-Head Self-Attention).

    Chuhan Wu, Fangzhao Wu, Suyu Ge, Tao Qi, Yongfeng Huang, and Xing Xie, "Neural News
    Recommendation with Multi-Head Self-Attention" in Proceedings of the 2019 Conference
    on Empirical Methods in Natural Language Processing and the 9th International Joint Conference
    on Natural Language Processing (EMNLP-IJCNLP)

    Attributes:
        hparams: Hyper-parameter object given to the constructor (read via attribute access).
        seed: Seed forwarded to the embedding initializer and the attention layers; may be None.
        word2vec_embedding: Matrix used as the initial weights of the word-embedding layer.
        newsencoder: Keras model returned by `_build_newsencoder`.
        userencoder: Keras model returned by `_build_userencoder`.
        model: Compiled Keras model used for training.
        scorer: Keras model that scores a single candidate; it is not compiled here.
    """

    def __init__(
        self,
        hparams: Any,
        word2vec_embedding: np.ndarray = None,
        word_emb_dim: int = 300,
        vocab_size: int = 32000,
        seed: int = None,
    ):
        """Initialization steps for NRMS: set seeds, prepare embeddings, build and compile.

        Args:
            hparams: Object exposing the hyper-parameters as attributes. This class reads
                `loss`, `optimizer`, `learning_rate`, `history_size`, `title_size`,
                `head_num`, `head_dim`, `attention_hidden_dim`, `dropout`,
                `newsencoder_units_per_layer` and `newsencoder_l2_regularization`.
            word2vec_embedding: Optional 2-D embedding matrix; its first dimension is used as
                the vocabulary size and its second as the embedding width. When None, a
                Glorot-uniform matrix of shape `(vocab_size, word_emb_dim)` is created.
            word_emb_dim: Embedding width, only used when `word2vec_embedding` is None.
            vocab_size: Vocabulary size, only used when `word2vec_embedding` is None.
            seed: Optional random seed. When None, the global TensorFlow/NumPy seeds are
                left untouched.

        Raises:
            ValueError: If `hparams.loss` or `hparams.optimizer` is not supported.
        """
        self.hparams = hparams
        self.seed = seed

        # SET SEED:
        # Only touch the process-wide RNGs when a seed was actually requested. Calling
        # np.random.seed(None) re-seeds from OS entropy, which would silently discard any
        # seed the notebook set before constructing the model.
        if seed is not None:
            tf.random.set_seed(seed)
            np.random.seed(seed)

        # INIT THE WORD-EMBEDDINGS:
        if word2vec_embedding is None:
            # Xavier Initialization
            initializer = GlorotUniform(seed=self.seed)
            self.word2vec_embedding = initializer(shape=(vocab_size, word_emb_dim))
        else:
            self.word2vec_embedding = word2vec_embedding

        # BUILD AND COMPILE MODEL:
        self.model, self.scorer = self._build_graph()
        data_loss = self._get_loss(self.hparams.loss)
        train_optimizer = self._get_opt(
            optimizer=self.hparams.optimizer, lr=self.hparams.learning_rate
        )
        self.model.compile(loss=data_loss, optimizer=train_optimizer)

    def _get_loss(self, loss: str):
        """Translate the configured loss name into a Keras loss identifier.

        Args:
            loss (str): Either "cross_entropy_loss" or "log_loss".

        Returns:
            str: "categorical_crossentropy" or "binary_crossentropy" respectively.

        Raises:
            ValueError: If `loss` is neither of the supported names.
        """
        if loss == "cross_entropy_loss":
            data_loss = "categorical_crossentropy"
        elif loss == "log_loss":
            data_loss = "binary_crossentropy"
        else:
            raise ValueError(f"this loss not defined {loss}")
        return data_loss

    def _get_opt(self, optimizer: str, lr: float):
        """Get the optimizer according to configuration. Only "adam" is supported.

        Args:
            optimizer (str): Name of the optimizer.
            lr (float): Learning rate handed to the optimizer.

        Returns:
            object: A `tf.keras.optimizers.Adam` instance.

        Raises:
            ValueError: If `optimizer` is not "adam".
        """
        # TODO: shouldn't be a string input you should just set the optimizer, to avoid stuff like this:
        # => 'WARNING:absl:At this time, the v2.11+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs, please use the legacy Keras optimizer instead, located at `tf.keras.optimizers.legacy.Adam`.'
        if optimizer == "adam":
            train_opt = tf.keras.optimizers.Adam(learning_rate=lr)
        else:
            raise ValueError(f"this optimizer not defined {optimizer}")
        return train_opt

    def _build_graph(self):
        """Build NRMS model and scorer (thin wrapper around `_build_nrms`).

        Returns:
            object: a model used to train.
            object: a model used to evaluate and inference.
        """
        model, scorer = self._build_nrms()
        return model, scorer

    def _build_userencoder(self, titleencoder):
        """The main function to create user encoder of NRMS.

        The history input of shape `(history_size, title_size)` (int32) is encoded item by
        item with `titleencoder`, passed through `SelfAttention` (the same tensor is given
        three times) and reduced with `AttLayer2`.

        Args:
            titleencoder (object): the news encoder of NRMS.

        Return:
            object: the user encoder of NRMS.
        """
        his_input_title = tf.keras.Input(
            shape=(self.hparams.history_size, self.hparams.title_size), dtype="int32"
        )

        # Apply the news encoder to every article in the click history.
        click_title_presents = tf.keras.layers.TimeDistributed(titleencoder)(
            his_input_title
        )
        y = SelfAttention(self.hparams.head_num, self.hparams.head_dim, seed=self.seed)(
            [click_title_presents] * 3
        )
        user_present = AttLayer2(self.hparams.attention_hidden_dim, seed=self.seed)(y)

        model = tf.keras.Model(his_input_title, user_present, name="user_encoder")
        return model

    def _build_newsencoder(self, units_per_layer: list[int] = None):
        """The main function to create news encoder of NRMS.

        Pipeline: trainable embedding (initialised from `self.word2vec_embedding`) ->
        dropout -> `SelfAttention` -> optional Dense stack (or a single dropout) ->
        `AttLayer2`.

        Args:
            units_per_layer (list[int]): Sizes of optional Dense layers inserted after the
                self-attention. Each Dense layer uses ReLU and L2 kernel regularization and
                is followed by BatchNormalization and Dropout. When None or empty, only a
                single Dropout is applied instead.

        Return:
            object: the news encoder of NRMS.
        """
        embedding_layer = tf.keras.layers.Embedding(
            self.word2vec_embedding.shape[0],
            self.word2vec_embedding.shape[1],
            weights=[self.word2vec_embedding],
            trainable=True,
        )
        sequences_input_title = tf.keras.Input(
            shape=(self.hparams.title_size,), dtype="int32"
        )
        embedded_sequences_title = embedding_layer(sequences_input_title)

        y = tf.keras.layers.Dropout(self.hparams.dropout)(embedded_sequences_title)
        y = SelfAttention(self.hparams.head_num, self.hparams.head_dim, seed=self.seed)(
            [y, y, y]
        )

        # Configurable Dense layers (an addition to the reference implementation):
        if units_per_layer:
            for layer in units_per_layer:
                y = tf.keras.layers.Dense(
                    units=layer,
                    activation="relu",
                    kernel_regularizer=tf.keras.regularizers.l2(
                        self.hparams.newsencoder_l2_regularization
                    ),
                )(y)
                y = tf.keras.layers.BatchNormalization()(y)
                y = tf.keras.layers.Dropout(self.hparams.dropout)(y)
        else:
            y = tf.keras.layers.Dropout(self.hparams.dropout)(y)

        pred_title = AttLayer2(self.hparams.attention_hidden_dim, seed=self.seed)(y)

        model = tf.keras.Model(sequences_input_title, pred_title, name="news_encoder")
        return model

    def _build_nrms(self):
        """The main function to create NRMS's logic. The core of NRMS
        is a user encoder and a news encoder.

        Side effects:
            Sets `self.newsencoder` and `self.userencoder`.

        Returns:
            object: a model used to train. Inputs are `[history, candidates]`; the output is
                a softmax over the dot products between each encoded candidate and the
                encoded user.
            object: a model used to evaluate and inference. Inputs are
                `[history, single candidate]`; the output is the sigmoid of the dot product.
        """

        his_input_title = tf.keras.Input(
            shape=(self.hparams.history_size, self.hparams.title_size),
            dtype="int32",
        )
        pred_input_title = tf.keras.Input(
            # shape = (hparams.npratio + 1, hparams.title_size)
            # The number of candidates is left open (None).
            shape=(None, self.hparams.title_size),
            dtype="int32",
        )
        pred_input_title_one = tf.keras.Input(
            shape=(
                1,
                self.hparams.title_size,
            ),
            dtype="int32",
        )
        # Drop the singleton candidate axis so the news encoder can be applied directly.
        pred_title_one_reshape = tf.keras.layers.Reshape((self.hparams.title_size,))(
            pred_input_title_one
        )
        titleencoder = self._build_newsencoder(
            units_per_layer=self.hparams.newsencoder_units_per_layer
        )
        # The user encoder wraps the same news-encoder instance that encodes the candidates.
        self.userencoder = self._build_userencoder(titleencoder)
        self.newsencoder = titleencoder

        user_present = self.userencoder(his_input_title)
        news_present = tf.keras.layers.TimeDistributed(self.newsencoder)(
            pred_input_title
        )
        news_present_one = self.newsencoder(pred_title_one_reshape)

        preds = tf.keras.layers.Dot(axes=-1)([news_present, user_present])
        preds = tf.keras.layers.Activation(activation="softmax")(preds)

        pred_one = tf.keras.layers.Dot(axes=-1)([news_present_one, user_present])
        pred_one = tf.keras.layers.Activation(activation="sigmoid")(pred_one)

        model = tf.keras.Model([his_input_title, pred_input_title], preds)
        scorer = tf.keras.Model([his_input_title, pred_input_title_one], pred_one)

        return model, scorer

In [14]:
#examples/reproducibility_scripts/args_nrms.py

# Simulated command line for `get_args()`.
# argparse parses sys.argv[1:], so index 0 must hold a program-name placeholder;
# without it the leading "--data_path" flag is consumed as the program name and the
# configured data path is silently ignored.
sys.argv = [
    "args_nrms.py",
    "--data_path", "path/to/data",
    "--datasplit", "train_test_split",
    "--transformer_model_name", "bert-base-uncased",
    "--seed", "42",
    "--bs_train", "32",
    "--bs_test", "32",
    "--epochs", "10",
    "--fraction_test", "0.2",
    "--train_fraction", "0.8",
    # Add any other arguments you need
]


def get_args(argv: list[str] = None):
    """Parse the command-line style options used for NRMSModel training.

    Args:
        argv: Optional list of argument strings to parse (without a program name).
            When None, argparse falls back to `sys.argv[1:]`. Passing the list
            explicitly avoids having to mutate the global `sys.argv` in a notebook.

    Returns:
        argparse.Namespace: The parsed options. Unrecognised arguments (for example the
        ones the Jupyter kernel is launched with) are ignored rather than raising.
    """
    parser = argparse.ArgumentParser(
        description="Argument parser for NRMSModel training"
    )

    parser.add_argument(
        "--data_path",
        type=str,
        default="~/ebnerd_data",
        help="Path to the data directory",
    )

    # General settings
    parser.add_argument("--seed", type=int, default=123, help="Random seed")
    parser.add_argument(
        "--datasplit", type=str, default="ebnerd_small", help="Dataset split to use"
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")

    # Batch sizes
    parser.add_argument(
        "--bs_train", type=int, default=32, help="Batch size for training"
    )
    parser.add_argument(
        "--bs_test", type=int, default=32, help="Batch size for testing"
    )
    parser.add_argument(
        "--batch_size_test_wo_b",
        type=int,
        default=32,
        help="Batch size for testing without balancing",
    )
    parser.add_argument(
        "--batch_size_test_w_b",
        type=int,
        default=4,
        help="Batch size for testing with balancing",
    )

    # History and ratios
    parser.add_argument(
        "--history_size", type=int, default=20, help="History size for the model"
    )
    parser.add_argument(
        "--npratio", type=int, default=4, help="Negative-positive ratio"
    )

    # Training settings
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs")
    parser.add_argument(
        "--train_fraction",
        type=float,
        default=1.0,
        help="Fraction of training data to use",
    )
    parser.add_argument(
        "--fraction_test",
        type=float,
        default=1.0,
        help="Fraction of testing data to use",
    )

    # Model and loader settings
    parser.add_argument(
        "--nrms_loader",
        type=str,
        default="NRMSDataLoaderPretransform",
        choices=["NRMSDataLoaderPretransform", "NRMSDataLoader"],
        help="Data loader type (speed or memory efficient)",
    )

    # Chunk processing
    parser.add_argument(
        "--n_chunks_test", type=int, default=10, help="Number of test chunks to process"
    )
    parser.add_argument(
        "--chunks_done", type=int, default=0, help="Number of chunks already processed"
    )

    # =====================================================================================
    #  ############################# UNIQUE FOR NRMSDocVec ###############################
    # =====================================================================================
    # Transformer settings
    parser.add_argument(
        "--transformer_model_name",
        type=str,
        default="FacebookAI/xlm-roberta-large",
        help="Transformer model name",
    )
    parser.add_argument(
        "--max_title_length",
        type=int,
        default=30,
        help="Maximum length of title encoding",
    )

    # Hyperparameters
    parser.add_argument(
        "--head_num", type=int, default=20, help="Number of attention heads"
    )
    parser.add_argument(
        "--head_dim", type=int, default=20, help="Dimension of each attention head"
    )
    parser.add_argument(
        "--attention_hidden_dim",
        type=int,
        default=200,
        help="Dimension of attention hidden layers",
    )

    # Optimizer settings
    parser.add_argument(
        "--optimizer", type=str, default="adam", help="Optimizer to use"
    )
    parser.add_argument(
        "--loss", type=str, default="cross_entropy_loss", help="Loss function"
    )
    parser.add_argument("--dropout", type=float, default=0.20, help="Dropout rate")
    parser.add_argument(
        "--learning_rate", type=float, default=1e-4, help="Learning rate"
    )

    # Parse known args only, so Jupyter kernel arguments do not abort parsing.
    args, _ = parser.parse_known_args(argv)
    return args
