In [None]:
"""
Implementation of an attention-based model for item recommendation.

Cf. "Attention-Based Transactional Context Embedding for Next-Item Recommendation".
Wang et al. (2018).
"""


In [None]:
# Imports and environment setup for the notebook.
# NOTE(review): `json`, `tqdm` and `TripDataset` are not used in the cells shown
# here — confirm nothing else in the notebook needs them before removing.
import json
from pathlib import Path
import os
# Must be set before `import tensorflow` below to have any effect. '0' is
# presumably TensorFlow's most verbose C++ log level (nothing filtered); '2' or
# '3' would hide INFO/WARNING noise — confirm for the installed TF version.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' 
import sys

# Make the repository's local `choice_learn` package importable. This path is
# resolved against the kernel's current working directory, so it presumably
# assumes the notebook is launched two levels below the repo root — TODO confirm.
sys.path.append("./../../")
# Debug aid: shows the working directory the relative path above is resolved from.
print(os.getcwd())

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tqdm

from choice_learn.basket_models import TripDataset
from choice_learn.basket_models.attn_model import AttentionBasedContextEmbedding
from choice_learn.basket_models.synthetic_dataset import SyntheticDataGenerator

In [None]:
# Parameters
#
# Everything a reader might want to tune lives in this cell.

n_baskets = 2000  # number of synthetic baskets generated for training
epochs = 400  # number of training epochs passed to the model
lr = 0.02  # learning rate passed to the model
embedding_dim = 6  # embedding dimension passed to the model
n_negative_samples = 3  # number of negative samples passed to the model
seed = 42  # global random seed so that Restart & Run All is reproducible

# Seed the global NumPy and TensorFlow RNGs before any data is generated or any
# model is trained. This assumes the synthetic generator and the model draw from
# these global RNGs — TODO confirm against their implementations.
np.random.seed(seed)
tf.random.set_seed(seed)

In [None]:
# Generate synthetic dataset
#
# The generator is configured with the probabilities below (presumably the
# shares of complementary items, neutral items and noise in each basket);
# `n_baskets` trips are then drawn from it for training.
data_gen = SyntheticDataGenerator(
    noise_proba=0.15,
    proba_neutral_items=0.3,
    proba_complementary_items=0.7,
)
trip_dataset = data_gen.generate_trip_dataset(n_baskets)

In [None]:
# Instantiate and train the model
#
# The model is built from the hyper-parameters of the "Parameters" cell and
# sized to the generator's item catalogue (the number of columns of its
# assortment matrix). It is then fitted on the training trips; `history`
# keeps the training-loss trace used in the plot below.
n_items = data_gen.assortment_matrix.shape[1]

model1 = AttentionBasedContextEmbedding(
    epochs=epochs,
    lr=lr,
    embedding_dim=embedding_dim,
    n_negative_samples=n_negative_samples,
)
model1.instantiate(n_items=n_items)
history = model1.fit(trip_dataset)


In [None]:
# Visualize the model's conditional distribution P(i | j) and the training loss
#
# Every item j is fed to the trained model as a single-item ("elementary")
# basket, and the stacked predictions are drawn as a heatmap. This is the
# *model's* predicted distribution, not the empirical one from the synthetic data.
n_items = data_gen.assortment_matrix.shape[1]
contexts = tf.constant([[i] for i in range(n_items)], dtype=tf.int32)
context_prediction = model1.predict(contexts)

# Stack once and reuse it for both the image and the top of its colour scale.
prediction_matrix = np.stack(context_prediction)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

im1 = axes[0].imshow(
    prediction_matrix,
    vmin=0.0,
    vmax=np.max(prediction_matrix),
    cmap="Spectral",
)
axes[0].set_title("Model P(i|j) on elementary baskets")
# Assumes row r holds the prediction for context item j=r and that columns
# index the candidate item i — TODO confirm against `model1.predict`.
axes[0].set_xlabel("Item i")
axes[0].set_ylabel("Context item j")
fig.colorbar(im1, ax=axes[0])

axes[1].plot(history["train_loss"], label="Training Loss")
axes[1].set_xlabel("Training Steps")
axes[1].set_ylabel("Loss")
axes[1].set_title("Training Loss History")

fig.tight_layout()
plt.show()

In [None]:
# Create evaluation dataset
#
# A fresh batch of synthetic trips is drawn from the same generator. It is not
# used for training, so it serves as held-out data.
n_eval_baskets = 100
eval_dataset = data_gen.generate_trip_dataset(n_eval_baskets)

# Evaluate model1 on the held-out trips, then persist it so that the next cell
# can reload it into a second model and compare losses.
loss_eval_dataset_1 = model1.evaluate(eval_dataset)
print(f"Loss of model1 on the evaluation dataset {loss_eval_dataset_1}")

saved_model_path = "attn_model.json"
model1.save_model(saved_model_path)

In [None]:
# Create a second model without instantiating
#
# Save/load round-trip check: load the file written by model1 into a fresh
# model and evaluate it on the same held-out trips. If save/load is faithful,
# the two printed losses should agree, up to any randomness inside `evaluate`
# (e.g. negative sampling — TODO confirm whether evaluation samples).
model2 = AttentionBasedContextEmbedding(
    epochs=epochs,
    lr=lr,
    embedding_dim=embedding_dim,
    n_negative_samples=n_negative_samples,
)

saved_model_file = Path("attn_model.json")
try:
    # Load first model and compare results on evaluation dataset
    model2.load_model(str(saved_model_file))
    loss_eval_dataset_2 = model2.evaluate(eval_dataset)
    print(f"Loss of model2 on the evaluation dataset {loss_eval_dataset_2}")
finally:
    # Always remove the temporary file, even when loading or evaluation fails.
    # `missing_ok=True` avoids masking the real error if the save never happened.
    saved_model_file.unlink(missing_ok=True)