## 0. Init

In [None]:
import sys
import os
# Expose no CUDA device to this process (empty list), forcing CPU execution.
# This is set before TensorFlow is imported in the next cell so that the
# setting is already in place when the library initialises.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Add the directory two levels up to the import path — presumably the
# repository root holding the local `choice_learn` package (confirm against
# the repository layout).
sys.path.append("../../")

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from choice_learn.basket_models.self_attention_model import SelfAttentionModel


In [None]:
# Sanity check: report how many GPU devices TensorFlow can see in this session.
physical_devices = tf.config.list_physical_devices(device_type="GPU")
num_gpus = len(physical_devices)
print("Num GPUs:", num_gpus)

## I.  Synthetic Dataset

We build a synthetic dataset whose catalog has 9 items, I =
{0,...,8}, with the following interactions:
- Cannibalization: {0,1, 2} on the one hand and {3,4,5} on the other hand
form groups of items cannibalizing each other.
- Complementarity: each of the items in {0,1,2} is complementary to each
of the items in {3,4,5};
- Neutral: 6, 7 and 8 are neutral in the sense that they don’t have a specific
interaction with other items.
- When choosing among the first nest, user 0 prefers item 0, user 1 the item 1 and user 2 the item 2.

In [None]:
from choice_learn.basket_models.datasets.synthetic_dataset import SyntheticDataGenerator

# items_nest : dict
#     Defines the item sets (nests). Keys are nest indexes, values are the
#     lists of item indexes belonging to each nest.
items_nest = {
    0: [0, 1, 2],
    1: [3, 4, 5],
    2: [6, 7, 8],
}

# nests_interactions : list
#     Interactions between nests, one row per nest and one column per nest.
#     Symmetry must be ensured by the user. There are three nests above, so
#     the matrix is 3x3 (a fourth column would refer to a nest that does not
#     exist in `items_nest`).
nests_interactions = [
    ["", "compl", "neutral"],
    ["compl", "", "neutral"],
    ["neutral", "neutral", ""],
]

# user_profile : dict
#     Defines the user profiles. Keys are user indexes, values are dicts with
#     'nest' and 'item' keys.
user_profile = {
    0: {"nest": 0, "item": 0},
    1: {"nest": 0, "item": 1},
    2: {"nest": 0, "item": 2},
}

# proba_complementary_items : float
#     Probability of adding complementary items to the basket.
# proba_neutral_items : float
#     Probability of adding neutral items to the basket.
# noise_proba : float
#     Probability of adding noise items to the basket.
generator = SyntheticDataGenerator(
    items_nest=items_nest,
    nests_interactions=nests_interactions,
    proba_complementary_items=1,
    proba_neutral_items=0.0,
    noise_proba=0.0,
    user_profile=user_profile,
)

# A single assortment in which every item of the catalog is available; the
# catalog size is derived from `items_nest` instead of being hard-coded.
n_catalog_items = sum(len(nest_items) for nest_items in items_nest.values())
data = generator.generate_trip_dataset(
    n_baskets=1000,
    assortments_matrix=np.ones((1, n_catalog_items)),
)

data.available_items

## II. Self Attention Model

In [None]:
from tensorflow.keras.callbacks import EarlyStopping

# --- Optimisation settings ---
optimizer = "adam"
lr = 5e-3
n_epochs = 20
batch_size = 32

# --- Model settings ---
# Embedding sizes, keyed by "short_term" and "long_term".
latent_sizes = dict(short_term=2, long_term=2)
short_term_ratio = 0.3
hinge_margin = 0.7
n_negative_samples = 1

# --- Regularisation (both set to zero here) ---
λ = 0.0  # handed to the model as `l2_regularization`
dropout_rate = 0.0

# NOTE(review): `L` is not passed to the model constructor in this notebook;
# confirm whether it is still needed.
L = 7



In [None]:
# Collect every constructor argument in one mapping so the configuration can
# be read (or printed) at a glance, then unpack it into the model.
model_config = {
    "optimizer": optimizer,
    "n_negative_samples": n_negative_samples,
    "lr": lr,
    "epochs": n_epochs,
    "batch_size": batch_size,
    "latent_sizes": latent_sizes,
    "hinge_margin": hinge_margin,
    "short_term_ratio": short_term_ratio,
    "l2_regularization": λ,
    "dropout_rate": dropout_rate,
}
model = SelfAttentionModel(**model_config)

# The item and user counts are taken from the generated dataset.
model.instantiate(n_users=data.n_users, n_items=data.n_items)

In [None]:
# Train the model on the synthetic trip dataset. The returned `history` is
# indexed by "train_loss" and "val_loss" in the plotting cell below.
history = model.fit(trip_dataset=data, verbose=2)

In [None]:
# Plot the loss curves recorded in `history` during training.
# Each curve gets an explicit label: `plt.legend()` only lists labelled
# artists, so without labels the legend is empty and matplotlib warns.
plt.plot(history["train_loss"], label="Train loss")
plt.plot(history["val_loss"], label="Validation loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.legend()

plt.show()

## III. Embedding Visualisation

In [None]:
# Grab the fitted model's parameters for the plots below.
# NOTE(review): `V` is plotted as the item embeddings and `U` as the user
# embeddings in the next cell; the role of `X` is not shown in this notebook.
X, U, V = model.X, model.U, model.V
# `d` decides below whether the embeddings are padded (d == 1) or projected
# with PCA (d > 2) before plotting.
d = model.d_long

In [None]:
from sklearn.decomposition import PCA
import seaborn as sns

# Work on NumPy arrays from the start so both branches below (PCA / no PCA)
# handle the same type. The user embeddings are copied into a local name:
# the notebook-level `U` is never rebound, so re-running this cell always
# yields the same result (in-place padding of `U` would accumulate on re-runs).
embedding = np.asarray(model.V)
user_embedding = np.asarray(U)

# Context embedding and affinity matrix for one example basket.
basket_batch = [[0, 1, 3, 7]]
m_batch, affinity_matrix = model.embed_context(basket_batch, is_training=False)

if d == 1:
    # One-dimensional embeddings: append a zero column so they can be drawn
    # on a 2-D scatter plot.
    embedding = np.hstack(
        [embedding, np.zeros((embedding.shape[0], 1), dtype=embedding.dtype)]
    )
    user_embedding = np.hstack(
        [user_embedding, np.zeros((user_embedding.shape[0], 1), dtype=user_embedding.dtype)]
    )

# Only project with PCA when there are more than two dimensions; the PCA is
# fitted on the item embeddings and the users are mapped into the same plane.
use_pca = d > 2
if use_pca:
    pca = PCA(n_components=2)
    V_pca = pca.fit_transform(embedding)
    U_pca = pca.transform(user_embedding)
else:
    V_pca = embedding
    U_pca = user_embedding

plt.figure(figsize=(10, 8))
plt.scatter(V_pca[:, 0], V_pca[:, 1], label="Items")
plt.scatter(U_pca[:, 0], U_pca[:, 1], color="red", marker="x", s=100, label="Users")

# Annotate each point slightly above its marker: item index for items,
# "U<index>" for users.
for i, (x_coord, y_coord) in enumerate(V_pca[:, :2]):
    plt.annotate(str(i), (x_coord, y_coord + 0.05),
                 fontsize=8, ha="center", va="center")
for i, (x_coord, y_coord) in enumerate(U_pca[:, :2]):
    plt.annotate(f"U{i}", (x_coord, y_coord + 0.05),
                 fontsize=8, ha="center", va="center")

# The title only mentions PCA when a projection was actually applied.
plt.title("PCA visualization of item embeddings" if use_pca else "Item embeddings")
plt.xlabel("Dimension 1")
plt.ylabel("Dimension 2")
plt.legend()
plt.show()


# Heatmap of the affinity matrix computed for the example basket.
sns.heatmap(affinity_matrix[0],
            annot=True,
            fmt='.2f',
            )

