# Clothing recommendation system based on two tower model

This notebook contains the implementation of a recommendation system based on the [two tower model architecture](https://research.google/pubs/pub48840/), together with experiments on it. The two tower model is implemented, trained and tested with the [PyTorch framework](https://pytorch.org/). The [pandas](https://pandas.pydata.org/) library is used to simplify operations on the structured dataset. The dataset used for the experiments is the [H&M Personalized Fashion Recommendations dataset](https://www.kaggle.com/competitions/h-and-m-personalized-fashion-recommendations/). [Weights & Biases (wandb)](https://wandb.ai/) is used to track the experiments, run the hyperparameter search and save the model artifacts. The model is exported and saved in [ONNX](https://onnx.ai/) format. The notebook is written in Python.

## 1. Two tower model implementation and experimentation

Let's import the required dependencies first! If a dependency cannot be imported because it is missing from the system, it can be installed, for instance, with [pip](https://pypi.org/project/pip/).

In [None]:
import bisect
import json
import math
import numpy as np
import pandas as pd
import torch
import wandb
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn
import traceback

To make the steps that involve random behaviour reproducible, let's seed NumPy's random number generator with a constant.

In [None]:
np.random.seed(2)
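`np.random.seed` only seeds NumPy's global generator. PyTorch keeps its own generator (used, for example, for weight initialisation and `DataLoader` shuffling), and so does Python's `random` module. A minimal sketch of seeding all three at once; the helper name `seed_everything` is hypothetical and not part of the notebook:

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed every random generator the notebook relies on."""
    random.seed(seed)        # Python's built-in generator
    np.random.seed(seed)     # NumPy's global generator (used for negative sampling)
    torch.manual_seed(seed)  # PyTorch generators on all devices


seed_everything(2)
first_draw = (np.random.rand(3), torch.rand(3))
seed_everything(2)
second_draw = (np.random.rand(3), torch.rand(3))

# Re-seeding reproduces the same random numbers in both libraries
print(np.allclose(first_draw[0], second_draw[0]), torch.equal(first_draw[1], second_draw[1]))
```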

Different goals can be better served by different datasets. However, datasets built on the H&M fashion data tend to need similar functionality, such as attribute transformations. The **WebshopDataContainer** class contains these common implementations, which the more specific datasets reuse.

In [None]:
class WebshopDataContainer:
    _ARTICLE_GARMENT_GROUP_NAME_NUMERIC_ENCODER = {
        "Accessories": 0,
        "Blouses": 1,
        "Dressed": 2,
        "Dresses Ladies": 3,
        "Dresses/Skirts girls": 4,
        "Jersey Basic": 5,
        "Jersey Fancy": 6,
        "Knitwear": 7,
        "Outdoor": 8,
        "Shirts": 9,
        "Shoes": 10,
        "Shorts": 11,
        "Skirts": 12,
        "Socks and Tights": 13,
        "Special Offers": 14,
        "Swimwear": 15,
        "Trousers": 16,
        "Trousers Denim": 17,
        "Under-, Nightwear": 18,
        "Unknown": 19,
        "Woven/Jersey/Knitted mix Baby": 20,
    }

    _ARTICLE_GRAPHICAL_APPEARANCE_NUMERIC_ENCODER = {
        "All over pattern": 0,
        "Application/3D": 1,
        "Argyle": 2,
        "Chambray": 3,
        "Check": 4,
        "Colour blocking": 5,
        "Contrast": 6,
        "Denim": 7,
        "Dot": 8,
        "Embroidery": 9,
        "Front print": 10,
        "Glittering/Metallic": 11,
        "Hologram": 12,
        "Jacquard": 13,
        "Lace": 14,
        "Melange": 15,
        "Mesh": 16,
        "Metallic": 17,
        "Mixed solid/pattern": 18,
        "Neps": 19,
        "Other pattern": 20,
        "Other structure": 21,
        "Placement print": 22,
        "Sequin": 23,
        "Slub": 24,
        "Solid": 25,
        "Stripe": 26,
        "Transparent": 27,
        "Treatment": 28,
        "Unknown": 29,
    }

    _ARTICLE_INDEX_NAME_NUMERIC_ENCODER = {
        "Baby Sizes 50-98": 0,
        "Children Accessories, Swimwear": 1,
        "Children Sizes 134-170": 2,
        "Children Sizes 92-140": 3,
        "Divided": 4,
        "Ladies Accessories": 5,
        "Ladieswear": 6,
        "Lingeries/Tights": 7,
        "Menswear": 8,
        "Sport": 9,
    }

    _ARTICLE_PERCEIVED_COLOR_MASTER_NAME_NUMERIC_ENCODER = {
        "Beige": 0,
        "Black": 1,
        "Blue": 2,
        "Bluish Green": 3,
        "Brown": 4,
        "Green": 5,
        "Grey": 6,
        "Khaki green": 7,
        "Lilac Purple": 8,
        "Metal": 9,
        "Mole": 10,
        "Orange": 11,
        "Pink": 12,
        "Red": 13,
        "Turquoise": 14,
        "undefined": 15,
        "Unknown": 16,
        "White": 17,
        "Yellow": 18,
        "Yellowish Green": 19,
    }

    _ARTICLE_PRODUCT_GROUP_NAME_NUMERIC_ENCODER = {
        "Accessories": 0,
        "Bags": 1,
        "Cosmetic": 2,
        "Fun": 3,
        "Furniture": 4,
        "Garment Full body": 5,
        "Garment Lower body": 6,
        "Garment Upper body": 7,
        "Garment and Shoe care": 8,
        "Interior textile": 9,
        "Items": 10,
        "Nightwear": 11,
        "Shoes": 12,
        "Socks & Tights": 13,
        "Stationery": 14,
        "Swimwear": 15,
        "Underwear": 16,
        "Underwear/nightwear": 17,
        "Unknown": 18,
    }

    def __init__(
            self,
            articles_df,
            customers_df,
            transactions_df,
            customer_age_group_size,
            customer_minimum_age,
            article_id_column_name,
            article_garment_group_column_name,
            article_graphical_appearance_column_name,
            article_index_name_column_name,
            article_perceived_color_master_column_name,
            article_product_group_column_name,
            customer_id_column_name,
            customer_age_column_name,
            transaction_date_column_name,
    ):

        self._articles_df = articles_df
        self._customers_df = customers_df
        self._transactions_df = transactions_df

        self._article_id_column_name = article_id_column_name
        self._article_garment_group_column_name = article_garment_group_column_name
        self._article_graphical_appearance_column_name = article_graphical_appearance_column_name
        self._article_index_name_column_name = article_index_name_column_name
        self._article_perceived_color_master_column_name = article_perceived_color_master_column_name
        self._article_product_group_column_name = article_product_group_column_name
        self._customer_id_column_name = customer_id_column_name
        self._customer_age_column_name = customer_age_column_name
        self._transaction_date_column_name = transaction_date_column_name

        self._customer_age_group_size = customer_age_group_size
        self._customer_minimum_age = customer_minimum_age

        if transaction_date_column_name in transactions_df.columns:
            self._is_transaction_date_column_used = True
            self._transaction_dates = transactions_df[transaction_date_column_name].unique()
        else:
            self._is_transaction_date_column_used = False

    def _calculate_customer_age_group_index(self, age):
        return math.ceil((age - self._customer_minimum_age) / self._customer_age_group_size)

    def _create_result_attributes(self, article, customer, transaction):
        article_id = article[self._article_id_column_name]
        customer_id = customer[self._customer_id_column_name]
        result_attributes = {
            self._article_id_column_name: self._articles_df.index.get_loc(
                self._articles_df[self._articles_df[self._article_id_column_name] == article_id].index[0]),
            self._customer_id_column_name: self._customers_df.index.get_loc(
                self._customers_df[self._customers_df[self._customer_id_column_name] == customer_id].index[0]),
        }

        if self._article_garment_group_column_name in self._articles_df.columns:
            result_attributes[self._article_garment_group_column_name] = \
                self._ARTICLE_GARMENT_GROUP_NAME_NUMERIC_ENCODER[
                    article[self._article_garment_group_column_name]
                ]

        if self._article_graphical_appearance_column_name in self._articles_df.columns:
            result_attributes[self._article_graphical_appearance_column_name] = (
                self._ARTICLE_GRAPHICAL_APPEARANCE_NUMERIC_ENCODER[
                    article[self._article_graphical_appearance_column_name]
                ]
            )

        if self._article_index_name_column_name in self._articles_df.columns:
            result_attributes[self._article_index_name_column_name] = self._ARTICLE_INDEX_NAME_NUMERIC_ENCODER[
                article[self._article_index_name_column_name]
            ]

        if self._article_perceived_color_master_column_name in self._articles_df.columns:
            result_attributes[self._article_perceived_color_master_column_name] = (
                self._ARTICLE_PERCEIVED_COLOR_MASTER_NAME_NUMERIC_ENCODER[
                    article[self._article_perceived_color_master_column_name]
                ])

        if self._article_product_group_column_name in self._articles_df.columns:
            result_attributes[self._article_product_group_column_name] = \
                self._ARTICLE_PRODUCT_GROUP_NAME_NUMERIC_ENCODER[
                    article[self._article_product_group_column_name]
                ]

        if self._customer_age_column_name in self._customers_df.columns:
            result_attributes[self._customer_age_column_name] = self._calculate_customer_age_group_index(
                customer[self._customer_age_column_name]
            )

        if self._is_transaction_date_column_used:
            result_attributes[self._transaction_date_column_name] = np.where(
                self._transaction_dates == transaction[self._transaction_date_column_name]
            )[0][0]

        return result_attributes

    @classmethod
    def get_article_garment_group_name_numeric_encoder(cls):
        return cls._ARTICLE_GARMENT_GROUP_NAME_NUMERIC_ENCODER.copy()

    @classmethod
    def get_article_graphical_appearance_numeric_encoder(cls):
        return cls._ARTICLE_GRAPHICAL_APPEARANCE_NUMERIC_ENCODER.copy()

    @classmethod
    def get_article_index_name_numeric_encoder(cls):
        return cls._ARTICLE_INDEX_NAME_NUMERIC_ENCODER.copy()

    @classmethod
    def get_article_perceived_color_master_name_numeric_encoder(cls):
        return cls._ARTICLE_PERCEIVED_COLOR_MASTER_NAME_NUMERIC_ENCODER.copy()

    @classmethod
    def get_article_product_group_name_numeric_encoder(cls):
        return cls._ARTICLE_PRODUCT_GROUP_NAME_NUMERIC_ENCODER.copy()
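Because `_calculate_customer_age_group_index` uses `math.ceil`, the minimum age alone maps to group 0, and ages in the interval (minimum + (k-1)·size, minimum + k·size] map to group k. The standalone sketch below reproduces that arithmetic and adds a lookup that falls back to the `"Unknown"` entry for values missing from an encoder that has one. The fallback is an assumption for illustration (the class above raises a `KeyError` instead), and the helper names, the minimum age of 16 and the group size of 5 are hypothetical:

```python
import math


def age_group_index(age, minimum_age=16, group_size=5):
    # Same formula as WebshopDataContainer._calculate_customer_age_group_index
    return math.ceil((age - minimum_age) / group_size)


def encode_category(value, encoder):
    # Hypothetical fallback: map values missing from the encoder to its "Unknown" entry
    return encoder.get(value, encoder["Unknown"])


# Subset of _ARTICLE_GARMENT_GROUP_NAME_NUMERIC_ENCODER
garment_group_encoder = {"Accessories": 0, "Shoes": 10, "Unknown": 19}

print([age_group_index(age) for age in (16, 17, 21, 22, 30)])
print(encode_category("Shoes", garment_group_encoder), encode_category("Hats", garment_group_encoder))
```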

The **NegativeBatchSampledTransactionDataset** stores the data of the transactions. As its name suggests, it also generates negative transaction samples: for every real transaction, a negative transaction with the same date and customer is generated, and its article is chosen so that the customer has no real transaction with it. Whether a transaction is real or a negative example is stored as its label (1.0 or 0.0). The negative article IDs are generated once per customer and can be cached in a JSON file.

In [None]:
class NegativeBatchSampledTransactionDataset(Dataset, WebshopDataContainer):

    def __init__(
            self,
            articles_df,
            customers_df,
            transactions_df,
            customer_age_group_size,
            customer_minimum_age,
            article_id_column_name,
            article_garment_group_column_name,
            article_graphical_appearance_column_name,
            article_index_name_column_name,
            article_perceived_color_master_column_name,
            article_product_group_column_name,
            customer_id_column_name,
            customer_age_column_name,
            transaction_date_column_name,
            get_negative_transactions_from_cache: bool = False,
    ):
        WebshopDataContainer.__init__(
            self=self,
            articles_df=articles_df,
            customers_df=customers_df,
            transactions_df=transactions_df,
            customer_age_group_size=customer_age_group_size,
            customer_minimum_age=customer_minimum_age,
            article_id_column_name=article_id_column_name,
            article_garment_group_column_name=article_garment_group_column_name,
            article_graphical_appearance_column_name=article_graphical_appearance_column_name,
            article_index_name_column_name=article_index_name_column_name,
            article_perceived_color_master_column_name=article_perceived_color_master_column_name,
            article_product_group_column_name=article_product_group_column_name,
            customer_id_column_name=customer_id_column_name,
            customer_age_column_name=customer_age_column_name,
            transaction_date_column_name=transaction_date_column_name,
        )

        self._cached_negative_transactions_path \
            = "two_tower_model_clothing_recommendation_system_cache/cached_negative_transactions.json"

        if get_negative_transactions_from_cache:
            try:
                with open(self._cached_negative_transactions_path, "r") as json_file:
                    self._negative_article_ids_of_customers = json.load(json_file)
            except FileNotFoundError:
                self._negative_article_ids_of_customers = None
        else:
            self._negative_article_ids_of_customers = None

        if self._negative_article_ids_of_customers is None:
            self._negative_article_ids_of_customers = self._create_negative_article_ids_of_customers(
                articles_df=self._articles_df,
                transactions_df=self._transactions_df,
            )
            try:
                with open(self._cached_negative_transactions_path, "w") as json_file:
                    json.dump(self._negative_article_ids_of_customers, json_file)
            except (PermissionError, FileNotFoundError, OSError) as e:
                print(f"Failed to save negative transactions! {e}")

        self.transaction_counter_for_customer = {}

    def _create_negative_article_ids_of_customers(self, articles_df, transactions_df):
        negative_article_ids_of_customers = {}
        customer_count = transactions_df[self._customer_id_column_name].nunique()
        print("Starting to generate negative articles for customers...")

        for index, customer_id in enumerate(transactions_df[self._customer_id_column_name].drop_duplicates()):
            transactions_of_customer = transactions_df[transactions_df[self._customer_id_column_name] == customer_id]
            ids_of_articles_purchased_by_customer = transactions_of_customer[self._article_id_column_name].drop_duplicates()

            ids_of_negative_articles_of_customer = (
                articles_df[~articles_df[self._article_id_column_name].isin(ids_of_articles_purchased_by_customer)]
                [self._article_id_column_name]
            )

            negative_article_ids_of_customers[customer_id] = np.random.choice(
                ids_of_negative_articles_of_customer,
                size=len(transactions_of_customer),
                replace=False,
            ).tolist()

            if index % 100 == 0:
                print(f"Negative article generation progress: {index}/{customer_count}")

        print("Finished generating negative articles for customers!")
        return negative_article_ids_of_customers

    def __len__(self):
        return len(self._transactions_df)

    def __getitem__(self, idx):
        transaction = self._transactions_df.iloc[idx]
        customer = self._customers_df[
            self._customers_df[self._customer_id_column_name] == transaction[self._customer_id_column_name]
            ].iloc[0]
        positive_article = self._articles_df[
            self._articles_df[self._article_id_column_name] == transaction[self._article_id_column_name]
            ].iloc[0]

        customer_id = transaction[self._customer_id_column_name]
        negative_article_ids_of_customer = self._negative_article_ids_of_customers[customer_id]
        transaction_counter_of_customer = self.transaction_counter_for_customer.get(customer_id, 0)
        # Wrap around so that later epochs reuse the pre-sampled negatives instead of raising an IndexError
        negative_article_id = negative_article_ids_of_customer[
            transaction_counter_of_customer % len(negative_article_ids_of_customer)
        ]
        negative_article = (
            self._articles_df[self._articles_df[self._article_id_column_name] == negative_article_id].iloc[0]
        )

        negative_transaction = pd.Series({
            self._article_id_column_name: negative_article_id,
            self._customer_id_column_name: customer_id,
        })

        if self._transaction_date_column_name in self._transactions_df.columns:
            negative_transaction[self._transaction_date_column_name] = transaction[self._transaction_date_column_name]

        self.transaction_counter_for_customer[customer_id] = transaction_counter_of_customer + 1

        positive_transaction_attributes = self._create_result_attributes(
            article=positive_article,
            customer=customer,
            transaction=transaction,
        )
        negative_transactions_attributes = self._create_result_attributes(
            article=negative_article,
            customer=customer,
            transaction=negative_transaction,
        )

        return (
            [positive_transaction_attributes, negative_transactions_attributes],
            [1.0, 0.0],
        )
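The negative sampling step can be illustrated on toy data. For each customer, the sketch draws, without replacement, as many article IDs as the customer has transactions, from the articles the customer never bought, mirroring `_create_negative_article_ids_of_customers`. Two deliberate differences: it iterates with `groupby`, which avoids scanning the whole transaction table once per customer, and it uses NumPy's `Generator` API instead of the global `np.random` state. The column names and IDs are made up:

```python
import numpy as np
import pandas as pd

articles_df = pd.DataFrame({"article_id": [101, 102, 103, 104, 105]})
transactions_df = pd.DataFrame({
    "customer_id": ["a", "a", "b", "b", "b"],
    "article_id": [101, 102, 101, 103, 103],
})

rng = np.random.default_rng(2)
negative_article_ids_of_customers = {}
for customer_id, transactions_of_customer in transactions_df.groupby("customer_id"):
    purchased = transactions_of_customer["article_id"].unique()
    # Candidate negatives: articles this customer never bought
    candidates = articles_df.loc[~articles_df["article_id"].isin(purchased), "article_id"].to_numpy()
    # One negative article per real transaction, drawn without replacement
    negative_article_ids_of_customers[customer_id] = rng.choice(
        candidates, size=len(transactions_of_customer), replace=False
    ).tolist()

print(negative_article_ids_of_customers)
```

As in the original method, `choice` with `replace=False` fails if a customer has more transactions than there are articles they never bought.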

The **TransactionDataset** is a simpler dataset: it only maps the attributes of each transaction and its related entities (article, customer) to their encoded form and returns them with the label 1.0, without generating negative samples.

In [None]:
class TransactionDataset(Dataset, WebshopDataContainer):

    def __init__(
            self,
            articles_df,
            customers_df,
            transactions_df,
            customer_age_group_size,
            customer_minimum_age,
            article_id_column_name,
            article_garment_group_column_name,
            article_graphical_appearance_column_name,
            article_index_name_column_name,
            article_perceived_color_master_column_name,
            article_product_group_column_name,
            customer_id_column_name,
            customer_age_column_name,
            transaction_date_column_name,
    ):
        WebshopDataContainer.__init__(
            self=self,
            articles_df=articles_df,
            customers_df=customers_df,
            transactions_df=transactions_df,
            customer_age_group_size=customer_age_group_size,
            customer_minimum_age=customer_minimum_age,
            article_id_column_name=article_id_column_name,
            article_garment_group_column_name=article_garment_group_column_name,
            article_graphical_appearance_column_name=article_graphical_appearance_column_name,
            article_index_name_column_name=article_index_name_column_name,
            article_perceived_color_master_column_name=article_perceived_color_master_column_name,
            article_product_group_column_name=article_product_group_column_name,
            customer_id_column_name=customer_id_column_name,
            customer_age_column_name=customer_age_column_name,
            transaction_date_column_name=transaction_date_column_name,
        )

    def __len__(self):
        return len(self._transactions_df)

    def __getitem__(self, idx):
        transaction = self._transactions_df.iloc[idx]
        customer = self._customers_df[
            self._customers_df[self._customer_id_column_name] == transaction[self._customer_id_column_name]
            ].iloc[0]
        positive_article = self._articles_df[
            self._articles_df[self._article_id_column_name] == transaction[self._article_id_column_name]
            ].iloc[0]

        transaction_attributes = self._create_result_attributes(
            article=positive_article,
            customer=customer,
            transaction=transaction,
        )

        return transaction_attributes, 1.0
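Both datasets return plain Python dicts keyed by column name, so PyTorch's default collate function can batch them: a list of such dicts becomes one dict of tensors, and the float labels become a tensor. A minimal sketch with a hypothetical toy dataset standing in for `TransactionDataset` (the column names are illustrative):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyTransactionDataset(Dataset):
    """Hypothetical stand-in returning (attributes, label) pairs like TransactionDataset."""

    def __init__(self, rows):
        self._rows = rows

    def __len__(self):
        return len(self._rows)

    def __getitem__(self, idx):
        return self._rows[idx], 1.0


rows = [
    {"article_id": 0, "customer_id": 3, "index_name": 6},
    {"article_id": 2, "customer_id": 1, "index_name": 8},
    {"article_id": 1, "customer_id": 3, "index_name": 4},
]
loader = DataLoader(ToyTransactionDataset(rows), batch_size=2, shuffle=False)
attributes, labels = next(iter(loader))

# attributes is a dict of int64 tensors (one entry per column); labels is a float64 tensor
print(attributes["article_id"], labels)
```

Since Python floats collate to `float64`, the labels may need a `.float()` cast before being compared with `float32` model outputs in a loss function.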

The **TwoTowerModel** class contains the implementation of the recommender model. A two tower model consists of two separate sub-models (towers), which usually share the same neural network architecture. Each tower receives the input data of one entity type: here the item tower computes its output from the attributes of the articles, while the query tower uses the attributes of the customers (and, optionally, the transaction date). Both towers compute an embedding vector, and the goal of training is to make the vectors of connected entities similar. In the current example, an article and a customer are connected if they appear together in a transaction. Low-cardinality categorical attributes (typically enum values) are one-hot encoded, while sparse, high-cardinality categorical attributes (such as IDs) are encoded using embedding layers. The vocabulary size of each embedding layer is incremented by one, because a deployed recommendation system may receive values that were not present in its training data. Such unknown values should be mapped to this extra, highest index.
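A minimal sketch of the idea, independent of the class below: one low-cardinality attribute per tower is one-hot encoded, each ID is embedded with one extra row reserved for unknown values, each tower is a small MLP, and the tower outputs are compared with cosine similarity. All sizes, the helper `to_known_or_unknown`, the class `MiniTwoTower` and the choice of `torch.nn.functional.cosine_similarity` as the similarity function are assumptions for illustration:

```python
import torch
import torch.nn.functional as F
from torch import nn

NUM_ARTICLES = 4      # known article IDs: 0..3
NUM_INDEX_NAMES = 10  # one-hot encoded article attribute
NUM_CUSTOMERS = 5     # known customer IDs: 0..4
NUM_AGE_GROUPS = 8    # one-hot encoded customer attribute
EMBEDDING_DIM = 6
OUTPUT_DIM = 16


def to_known_or_unknown(ids, vocabulary_size):
    # Map IDs outside the training vocabulary to the extra, highest index
    return torch.where((ids >= 0) & (ids < vocabulary_size), ids, torch.full_like(ids, vocabulary_size))


class MiniTwoTower(nn.Module):
    def __init__(self):
        super().__init__()
        # One extra row per embedding table, reserved for unknown IDs
        self.article_embedding = nn.Embedding(NUM_ARTICLES + 1, EMBEDDING_DIM)
        self.customer_embedding = nn.Embedding(NUM_CUSTOMERS + 1, EMBEDDING_DIM)
        self.item_tower = nn.Sequential(
            nn.Linear(EMBEDDING_DIM + NUM_INDEX_NAMES, 32), nn.ReLU(), nn.Linear(32, OUTPUT_DIM)
        )
        self.query_tower = nn.Sequential(
            nn.Linear(EMBEDDING_DIM + NUM_AGE_GROUPS, 32), nn.ReLU(), nn.Linear(32, OUTPUT_DIM)
        )

    def forward(self, article_id, index_name, customer_id, age_group):
        item_input = torch.cat([
            self.article_embedding(to_known_or_unknown(article_id, NUM_ARTICLES)),
            # one_hot returns int64, so cast before concatenating with float embeddings
            F.one_hot(index_name, NUM_INDEX_NAMES).float(),
        ], dim=1)
        query_input = torch.cat([
            self.customer_embedding(to_known_or_unknown(customer_id, NUM_CUSTOMERS)),
            F.one_hot(age_group, NUM_AGE_GROUPS).float(),
        ], dim=1)
        return F.cosine_similarity(self.item_tower(item_input), self.query_tower(query_input), dim=1)


model = MiniTwoTower()
scores = model(
    article_id=torch.tensor([0, 3, 42]),  # 42 was never seen, so it uses the unknown row
    index_name=torch.tensor([6, 8, 4]),
    customer_id=torch.tensor([1, 4, 2]),
    age_group=torch.tensor([2, 0, 7]),
)
print(scores.shape)  # one similarity score per (article, customer) pair
```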

In [None]:
class TwoTowerModel(nn.Module):

    def __init__(
            self,
            number_of_neurons_in_the_first_hidden_dense_layer,
            number_of_neurons_in_the_second_hidden_dense_layer,
            number_of_neurons_in_the_third_hidden_dense_layer,
            article_id_embedding_dimension_and_unique_value_count,
            article_garment_group_name_unique_value_count,
            article_graphical_appearance_name_unique_value_count,
            article_index_name_unique_value_count,
            article_perceived_color_master_name_unique_value_count,
            article_product_group_name_unique_value_count,
            customer_id_embedding_dimension_and_unique_value_count,
            customer_age_group_count,
            transaction_date_embedding_dimension_and_unique_value_count,
            similarity_function,
    ):
        super().__init__()

        self._initialize_input_layer(
            article_id_embedding_dimension_and_unique_value_count=article_id_embedding_dimension_and_unique_value_count,
            article_index_name_unique_value_count=article_index_name_unique_value_count,
            article_garment_group_name_unique_value_count=article_garment_group_name_unique_value_count,
            article_product_group_name_unique_value_count=article_product_group_name_unique_value_count,
            article_perceived_color_master_name_unique_value_count=article_perceived_color_master_name_unique_value_count,
            article_graphical_appearance_name_unique_value_count=article_graphical_appearance_name_unique_value_count,
            customer_id_embedding_dimension_and_unique_value_count=customer_id_embedding_dimension_and_unique_value_count,
            customer_age_group_count=customer_age_group_count,
            transaction_date_embedding_dimension_and_unique_value_count=transaction_date_embedding_dimension_and_unique_value_count,
        )

        self._create_item_tower_model(
            dimension_of_input_for_first_dense_layer=self.calculate_item_tower_input_dimension(),
            number_of_neurons_in_the_first_hidden_dense_layer=number_of_neurons_in_the_first_hidden_dense_layer,
            number_of_neurons_in_the_second_hidden_dense_layer=number_of_neurons_in_the_second_hidden_dense_layer,
            number_of_neurons_in_the_third_hidden_dense_layer=number_of_neurons_in_the_third_hidden_dense_layer,
        )
        self._create_query_tower_model(
            dimension_of_input_for_first_dense_layer=self.calculate_query_tower_input_dimension(),
            number_of_neurons_in_the_first_hidden_dense_layer=number_of_neurons_in_the_first_hidden_dense_layer,
            number_of_neurons_in_the_second_hidden_dense_layer=number_of_neurons_in_the_second_hidden_dense_layer,
            number_of_neurons_in_the_third_hidden_dense_layer=number_of_neurons_in_the_third_hidden_dense_layer,
        )

        self.similarity_function = similarity_function

    def _initialize_input_layer(
            self,
            article_id_embedding_dimension_and_unique_value_count,
            article_index_name_unique_value_count,
            article_garment_group_name_unique_value_count,
            article_product_group_name_unique_value_count,
            article_perceived_color_master_name_unique_value_count,
            article_graphical_appearance_name_unique_value_count,
            customer_id_embedding_dimension_and_unique_value_count,
            customer_age_group_count,
            transaction_date_embedding_dimension_and_unique_value_count,
    ):
        # TODO prepare one-hot encoded attribute processing for unknown values
        if article_id_embedding_dimension_and_unique_value_count is not None:
            self._article_id_embedding_dimension = article_id_embedding_dimension_and_unique_value_count[0]
            self.article_id_encoder = nn.Embedding(
                num_embeddings=article_id_embedding_dimension_and_unique_value_count[1] + 1,
                embedding_dim=article_id_embedding_dimension_and_unique_value_count[0],
            )
        else:
            self._article_id_embedding_dimension = 0
            # Set explicitly so that the forward pass can check for a missing encoder
            self.article_id_encoder = None

        if customer_id_embedding_dimension_and_unique_value_count is not None:
            self._customer_id_embedding_dimension = customer_id_embedding_dimension_and_unique_value_count[0]
            self.customer_id_encoder = nn.Embedding(
                num_embeddings=customer_id_embedding_dimension_and_unique_value_count[1] + 1,
                embedding_dim=customer_id_embedding_dimension_and_unique_value_count[0],
            )
        else:
            self._customer_id_embedding_dimension = 0
            self.customer_id_encoder = None

        if transaction_date_embedding_dimension_and_unique_value_count is not None \
                and transaction_date_embedding_dimension_and_unique_value_count != 0:
            self._customer_transaction_date_dimension = transaction_date_embedding_dimension_and_unique_value_count[0]
            self.transaction_date_encoder = nn.Embedding(
                num_embeddings=transaction_date_embedding_dimension_and_unique_value_count[1] + 1,
                embedding_dim=transaction_date_embedding_dimension_and_unique_value_count[0],
            )
        else:
            self._customer_transaction_date_dimension = 0
            self.transaction_date_encoder = None

        self._customer_age_group_unique_value_count = customer_age_group_count

        self._article_index_name_unique_value_count = article_index_name_unique_value_count
        self._article_garment_group_name_unique_value_count = article_garment_group_name_unique_value_count
        self._article_product_group_name_unique_value_count = article_product_group_name_unique_value_count
        self._article_perceived_color_master_name_unique_value_count = article_perceived_color_master_name_unique_value_count
        self._article_graphical_appearance_name_unique_value_count = article_graphical_appearance_name_unique_value_count

    def _create_item_tower_model(
            self,
            dimension_of_input_for_first_dense_layer,
            number_of_neurons_in_the_first_hidden_dense_layer,
            number_of_neurons_in_the_second_hidden_dense_layer,
            number_of_neurons_in_the_third_hidden_dense_layer,
    ):
        self.item_tower_model = self._create_dense_sequential_stack(
            list_of_layer_input_and_output_dimensions=self._assemble_list_of_sequential_dense_linear_stack_dimensions(
                dimension_of_input_for_first_dense_layer=dimension_of_input_for_first_dense_layer,
                number_of_neurons_in_the_first_hidden_dense_layer=number_of_neurons_in_the_first_hidden_dense_layer,
                number_of_neurons_in_the_second_hidden_dense_layer=number_of_neurons_in_the_second_hidden_dense_layer,
                number_of_neurons_in_the_third_hidden_dense_layer=number_of_neurons_in_the_third_hidden_dense_layer,
            ),
        )

    @classmethod
    def _assemble_list_of_sequential_dense_linear_stack_dimensions(
            cls,
            dimension_of_input_for_first_dense_layer,
            number_of_neurons_in_the_first_hidden_dense_layer,
            number_of_neurons_in_the_second_hidden_dense_layer,
            number_of_neurons_in_the_third_hidden_dense_layer,
    ):
        list_of_tower_layer_input_and_output_dimensions = [
            dimension_of_input_for_first_dense_layer,
            number_of_neurons_in_the_first_hidden_dense_layer,
            number_of_neurons_in_the_second_hidden_dense_layer,
        ]

        # The third hidden layer is optional: it is skipped when its size is None or 0
        if number_of_neurons_in_the_third_hidden_dense_layer is not None and number_of_neurons_in_the_third_hidden_dense_layer != 0:
            list_of_tower_layer_input_and_output_dimensions.append(number_of_neurons_in_the_third_hidden_dense_layer)

        return list_of_tower_layer_input_and_output_dimensions

    @classmethod
    def _create_dense_sequential_stack(cls, list_of_layer_input_and_output_dimensions):
        modules = []
        number_of_layers = len(list_of_layer_input_and_output_dimensions) - 1
        for layer_idx in range(0, number_of_layers):
            modules.append(
                nn.Linear(
                    in_features=list_of_layer_input_and_output_dimensions[layer_idx],
                    out_features=list_of_layer_input_and_output_dimensions[layer_idx + 1],
                )
            )
            if layer_idx != (number_of_layers - 1):
                modules.append(nn.ReLU())

        return nn.Sequential(*modules)

    def calculate_item_tower_input_dimension(self):
        return (
                self._article_id_embedding_dimension
                + self._article_garment_group_name_unique_value_count
                + self._article_index_name_unique_value_count
                + self._article_product_group_name_unique_value_count
                + self._article_graphical_appearance_name_unique_value_count
                + self._article_perceived_color_master_name_unique_value_count
        )

    def _create_query_tower_model(
            self,
            dimension_of_input_for_first_dense_layer,
            number_of_neurons_in_the_first_hidden_dense_layer,
            number_of_neurons_in_the_second_hidden_dense_layer,
            number_of_neurons_in_the_third_hidden_dense_layer,
    ):
        self.query_tower_model = self._create_dense_sequential_stack(
            list_of_layer_input_and_output_dimensions=self._assemble_list_of_sequential_dense_linear_stack_dimensions(
                dimension_of_input_for_first_dense_layer=dimension_of_input_for_first_dense_layer,
                number_of_neurons_in_the_first_hidden_dense_layer=number_of_neurons_in_the_first_hidden_dense_layer,
                number_of_neurons_in_the_second_hidden_dense_layer=number_of_neurons_in_the_second_hidden_dense_layer,
                number_of_neurons_in_the_third_hidden_dense_layer=number_of_neurons_in_the_third_hidden_dense_layer,
            ),
        )

    def calculate_query_tower_input_dimension(self):
        return (self._customer_id_embedding_dimension + self._customer_transaction_date_dimension
                + self._customer_age_group_unique_value_count)

    def forward(
            self,
            article_id_index=None,
            article_index_name_index=None,
            article_garment_group_name_index=None,
            article_graphical_appearance_name_index=None,
            article_perceived_color_master_name_index=None,
            article_product_group_name_index=None,
            customer_id_index=None,
            customer_age_group_index=None,
            transaction_date_index=None,
    ):
        item_tower_input = self._preprocess_item_tower_input_using_input_layer(
            article_id_index=article_id_index,
            article_index_name_index=article_index_name_index,
            article_garment_group_name_index=article_garment_group_name_index,
            article_product_group_name_index=article_product_group_name_index,
            perceived_color_master_name_index=article_perceived_color_master_name_index,
            graphical_appearance_name_index=article_graphical_appearance_name_index,
        )
        query_tower_input = self._preprocess_query_tower_input_using_input_layer(
            customer_id_index=customer_id_index,
            customer_age_group_index=customer_age_group_index,
            transaction_date_index=transaction_date_index,
        )

        item_embedding = self.item_tower_model(item_tower_input)
        query_embedding = self.query_tower_model(query_tower_input)

        return self.similarity_function(item_embedding, query_embedding)

    def _preprocess_item_tower_input_using_input_layer(
            self,
            article_id_index,
            article_index_name_index,
            article_garment_group_name_index,
            article_product_group_name_index,
            perceived_color_master_name_index,
            graphical_appearance_name_index,
    ):
        feature_tensors = []

        if article_id_index is not None and self.article_id_encoder is not None:
            feature_tensors.append(self.article_id_encoder(article_id_index))

        # one_hot returns integer tensors, so they are cast to float to match the embeddings and the dense layers.
        if article_index_name_index is not None and self._article_index_name_unique_value_count != 0:
            feature_tensors.append(
                nn.functional.one_hot(article_index_name_index, self._article_index_name_unique_value_count).float()
            )

        if article_garment_group_name_index is not None and self._article_garment_group_name_unique_value_count != 0:
            feature_tensors.append(
                nn.functional.one_hot(
                    article_garment_group_name_index,
                    self._article_garment_group_name_unique_value_count,
                ).float()
            )

        if article_product_group_name_index is not None and self._article_product_group_name_unique_value_count != 0:
            feature_tensors.append(
                nn.functional.one_hot(
                    article_product_group_name_index,
                    self._article_product_group_name_unique_value_count,
                ).float()
            )

        if perceived_color_master_name_index is not None and self._article_perceived_color_master_name_unique_value_count != 0:
            feature_tensors.append(
                nn.functional.one_hot(
                    perceived_color_master_name_index,
                    self._article_perceived_color_master_name_unique_value_count,
                ).float()
            )

        if graphical_appearance_name_index is not None and self._article_graphical_appearance_name_unique_value_count != 0:
            feature_tensors.append(
                nn.functional.one_hot(
                    graphical_appearance_name_index,
                    self._article_graphical_appearance_name_unique_value_count,
                ).float()
            )

        return torch.cat(feature_tensors, dim=1)

    def _preprocess_query_tower_input_using_input_layer(
            self,
            customer_id_index,
            customer_age_group_index,
            transaction_date_index,
    ):
        feature_tensors = []

        if customer_id_index is not None and self.customer_id_encoder is not None:
            feature_tensors.append(self.customer_id_encoder(customer_id_index))

        # Cast the one-hot encoding to float, like in the item tower's input layer.
        if customer_age_group_index is not None and self._customer_age_group_unique_value_count != 0:
            feature_tensors.append(
                nn.functional.one_hot(customer_age_group_index, self._customer_age_group_unique_value_count).float()
            )

        if transaction_date_index is not None and self.transaction_date_encoder is not None:
            feature_tensors.append(self.transaction_date_encoder(transaction_date_index))

        return torch.cat(feature_tensors, dim=1)
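To make the input layers above concrete, the following is a minimal, self-contained sketch of the same two-tower idea on toy sizes: categorical indices are one-hot encoded and concatenated with learned embeddings, each tower maps its input to an embedding of the same size, and the two embeddings are compared with cosine similarity. The cardinalities, layer sizes and the `make_tower` helper are hypothetical and do not come from the notebook. The explicit `.float()` cast is there because `nn.functional.one_hot` returns an integer tensor, which `nn.Linear` cannot consume directly.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy sizes, not the values used in the notebook.
NUM_ARTICLES, ARTICLE_EMB_DIM = 10, 4
NUM_GARMENT_GROUPS = 3
NUM_CUSTOMERS, CUSTOMER_EMB_DIM = 8, 4
NUM_AGE_GROUPS = 5
OUTPUT_DIM = 6


def make_tower(input_dim: int, output_dim: int) -> nn.Module:
    # A small MLP tower; the notebook's towers use three configurable hidden layers.
    return nn.Sequential(nn.Linear(input_dim, 16), nn.ReLU(), nn.Linear(16, output_dim))


article_embedding = nn.Embedding(NUM_ARTICLES, ARTICLE_EMB_DIM)
customer_embedding = nn.Embedding(NUM_CUSTOMERS, CUSTOMER_EMB_DIM)
item_tower = make_tower(ARTICLE_EMB_DIM + NUM_GARMENT_GROUPS, OUTPUT_DIM)
query_tower = make_tower(CUSTOMER_EMB_DIM + NUM_AGE_GROUPS, OUTPUT_DIM)
similarity = nn.CosineSimilarity(dim=1)

# A batch of 3 (article, customer) pairs, given as integer indices.
article_ids = torch.tensor([0, 4, 9])
garment_groups = torch.tensor([2, 0, 1])
customer_ids = torch.tensor([1, 1, 7])
age_groups = torch.tensor([0, 3, 4])

# one_hot returns int64, so cast it before concatenating with the float embeddings.
item_input = torch.cat(
    [article_embedding(article_ids),
     nn.functional.one_hot(garment_groups, NUM_GARMENT_GROUPS).float()],
    dim=1,
)
query_input = torch.cat(
    [customer_embedding(customer_ids),
     nn.functional.one_hot(age_groups, NUM_AGE_GROUPS).float()],
    dim=1,
)

# One cosine similarity score in [-1, 1] per (article, customer) pair.
scores = similarity(item_tower(item_input), query_tower(query_input))
print(item_input.shape, query_input.shape, scores.shape)
# torch.Size([3, 7]) torch.Size([3, 9]) torch.Size([3])
```

The input dimension of each tower is simply the sum of the embedding dimensions and the one-hot widths of the features it uses, which is what the notebook's dimension calculations add up.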

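**train_model** trains the model for one epoch. Each batch of positive transactions is concatenated with its negative transactions, the loss is calculated on the combined batch, and the optimizer updates the model's weights.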

In [None]:
def train_model(
        negative_batch_dataloader,
        model,
        loss_fn,
        optimizer,
        article_id_input_key,
        article_index_name_input_key,
        article_garment_group_input_key,
        article_product_group_input_key,
        article_perceived_color_master_input_key,
        article_graphical_appearance_input_key,
        customer_id_input_key,
        customer_age_group_input_key,
        transaction_date_input_key,
):
    model.train()
    device = next(model.parameters()).device
    size = len(negative_batch_dataloader.dataset)
    for batch_index, (X, y) in enumerate(negative_batch_dataloader):
        # Stack the positive (X[0], y[0]) and negative (X[1], y[1]) half-batches
        # and move them to the device of the model.
        concatenated_positive_and_negative_items_input = {}
        for input_key in X[0].keys():
            concatenated_positive_and_negative_items_input[input_key] = torch.cat(
                (X[0][input_key], X[1][input_key]),
                dim=0,
            ).to(device)
        concatenated_positive_and_negative_items_label = torch.cat((y[0], y[1]), dim=0).to(device)
        logits = model(
            article_id_index=concatenated_positive_and_negative_items_input.get(
                article_id_input_key
            ),
            article_index_name_index=concatenated_positive_and_negative_items_input.get(
                article_index_name_input_key
            ),
            article_garment_group_name_index=concatenated_positive_and_negative_items_input.get(
                article_garment_group_input_key
            ),
            article_graphical_appearance_name_index=concatenated_positive_and_negative_items_input.get(
                article_graphical_appearance_input_key
            ),
            article_perceived_color_master_name_index=concatenated_positive_and_negative_items_input.get(
                article_perceived_color_master_input_key
            ),
            article_product_group_name_index=concatenated_positive_and_negative_items_input.get(
                article_product_group_input_key
            ),
            customer_id_index=concatenated_positive_and_negative_items_input.get(
                customer_id_input_key
            ),
            customer_age_group_index=concatenated_positive_and_negative_items_input.get(
                customer_age_group_input_key
            ),
            transaction_date_index=concatenated_positive_and_negative_items_input.get(
                transaction_date_input_key
            ),
        )
        loss = loss_fn(logits, concatenated_positive_and_negative_items_label)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch_index % 100 == 0:
            # Count samples by the batch dimension of the positive half-batch;
            # len() of the input dict would count feature keys instead.
            loss_value, current = loss.item(), (batch_index + 1) * len(y[0])
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")
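The core of each training step is merging the positive half-batch `X[0]` with the negative half-batch `X[1]`. The sketch below isolates that step on toy tensors; the feature keys and values are hypothetical, and it assumes, as `train_model` does, that each half-batch is a dict of tensors whose first dimension is the batch dimension. It also shows why progress should be counted from the labels: `len()` of the input dict counts features, not samples.

```python
import torch


def concatenate_positive_and_negative_batches(positive_inputs, negative_inputs,
                                              positive_labels, negative_labels):
    """Stack the positive and negative half-batches along the batch dimension.

    Assumes both input dicts share the same keys and that every value is a
    tensor whose first dimension is the (half) batch size.
    """
    inputs = {
        key: torch.cat((positive_inputs[key], negative_inputs[key]), dim=0)
        for key in positive_inputs
    }
    labels = torch.cat((positive_labels, negative_labels), dim=0)
    return inputs, labels


# Hypothetical half-batches of 2 positive and 2 negative interactions.
positive = {"article_id": torch.tensor([3, 5]), "customer_id": torch.tensor([0, 1])}
negative = {"article_id": torch.tensor([7, 2]), "customer_id": torch.tensor([0, 1])}
inputs, labels = concatenate_positive_and_negative_batches(
    positive, negative, torch.ones(2), torch.zeros(2)
)

# len(inputs) counts the feature keys (2 here), not the samples (4).
print(len(inputs), len(labels))  # 2 4
```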

**perform_loss_test** evaluates the model with the given loss function on positive and negative transactions, and returns the average loss, weighted by the number of samples in each batch.

In [None]:
def perform_loss_test(
        negative_batch_dataloader,
        model,
        loss_fn,
        article_id_input_key,
        article_index_name_input_key,
        article_garment_group_input_key,
        article_product_group_input_key,
        article_perceived_color_master_input_key,
        article_graphical_appearance_input_key,
        customer_id_input_key,
        customer_age_group_input_key,
        transaction_date_input_key,
):
    loss_test_average_test_loss = 0
    number_of_samples_in_dataset = len(negative_batch_dataloader.dataset)
    device = next(model.parameters()).device
    for batch_index, (X, y) in enumerate(negative_batch_dataloader):
        concatenated_positive_and_negative_items_input = {}
        for input_key in X[0].keys():
            concatenated_positive_and_negative_items_input[input_key] = torch.cat(
                (X[0][input_key], X[1][input_key]),
                dim=0,
            ).to(device)
        concatenated_positive_and_negative_items_label = torch.cat((y[0], y[1]), dim=0).to(device)
        logits = model(
            article_id_index=concatenated_positive_and_negative_items_input.get(
                article_id_input_key
            ),
            article_index_name_index=concatenated_positive_and_negative_items_input.get(
                article_index_name_input_key
            ),
            article_garment_group_name_index=concatenated_positive_and_negative_items_input.get(
                article_garment_group_input_key
            ),
            article_graphical_appearance_name_index=concatenated_positive_and_negative_items_input.get(
                article_graphical_appearance_input_key
            ),
            article_perceived_color_master_name_index=concatenated_positive_and_negative_items_input.get(
                article_perceived_color_master_input_key
            ),
            article_product_group_name_index=concatenated_positive_and_negative_items_input.get(
                article_product_group_input_key
            ),
            customer_id_index=concatenated_positive_and_negative_items_input.get(
                customer_id_input_key
            ),
            customer_age_group_index=concatenated_positive_and_negative_items_input.get(
                customer_age_group_input_key
            ),
            transaction_date_index=concatenated_positive_and_negative_items_input.get(
                transaction_date_input_key
            ),
        )
        loss = loss_fn(logits, concatenated_positive_and_negative_items_label)
        # Weight the batch loss by the batch's sample count (len() of the input dict would count feature keys).
        loss_test_average_test_loss += loss.item() * len(y[0])

        if batch_index % 100 == 0:
            current_sample_idx = (batch_index + 1) * len(y[0])
            print(f"Current average test loss: {loss_test_average_test_loss / current_sample_idx:>7f} " +
                  f"[{current_sample_idx:>5d}/{number_of_samples_in_dataset:>5d}]")

    loss_test_average_test_loss /= number_of_samples_in_dataset

    return loss_test_average_test_loss
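`perform_loss_test` accumulates `loss * batch_size` and divides by the dataset size, which is a sample-weighted mean of the per-batch losses. A minimal stdlib sketch of that averaging, with made-up loss values:

```python
def sample_weighted_average_loss(batch_losses_and_sizes):
    """Average per-batch mean losses, weighting each batch by its sample count.

    batch_losses_and_sizes: iterable of (mean_loss_of_batch, number_of_samples).
    """
    total_loss = 0.0
    total_samples = 0
    for mean_loss, number_of_samples in batch_losses_and_sizes:
        total_loss += mean_loss * number_of_samples
        total_samples += number_of_samples
    return total_loss / total_samples


# Hypothetical batches: the last one is smaller, as a DataLoader's final batch can be.
batches = [(0.8, 4), (0.6, 4), (1.2, 2)]
print(sample_weighted_average_loss(batches))  # approximately 0.8
```

A plain mean of the three batch losses would instead give about 0.867, over-weighting the small final batch.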

**perform_recall_test** measures the model's recall on a test dataset of positive interactions. For each customer, it keeps the k highest scores among the customer's transactions (k is *inspected_top_k_per_customer*), counts a transaction as a true positive if its score is at least 0.5, and calculates the customer's recall as the ratio of true positives to the kept transactions. The function returns the average of these per-customer recall scores.

In [None]:
def perform_recall_test(
        inspected_top_k_per_customer,
        positive_interactions_dataloader,
        model,
        article_id_input_key,
        article_index_name_input_key,
        article_garment_group_input_key,
        article_product_group_input_key,
        article_perceived_color_master_input_key,
        article_graphical_appearance_input_key,
        customer_id_input_key,
        customer_age_group_input_key,
        transaction_date_input_key,
):
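    highest_logits_of_customers = {}
    recall_scores_of_customers = []
    number_of_samples_in_dataset = len(positive_interactions_dataloader.dataset)
    device = next(model.parameters()).device
    for batch_idx, (X, y) in enumerate(positive_interactions_dataloader):
        # Move the input tensors to the same device as the model.
        X = {input_key: input_tensor.to(device) for input_key, input_tensor in X.items()}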
        logits = model(
            article_id_index=X.get(
                article_id_input_key
            ),
            article_index_name_index=X.get(
                article_index_name_input_key
            ),
            article_garment_group_name_index=X.get(
                article_garment_group_input_key
            ),
            article_graphical_appearance_name_index=X.get(
                article_graphical_appearance_input_key
            ),
            article_perceived_color_master_name_index=X.get(
                article_perceived_color_master_input_key
            ),
            article_product_group_name_index=X.get(
                article_product_group_input_key
            ),
            customer_id_index=X.get(
                customer_id_input_key
            ),
            customer_age_group_index=X.get(
                customer_age_group_input_key
            ),
            transaction_date_index=X.get(
                transaction_date_input_key
            ),
        )
        for logit_idx, logit in enumerate(logits):
            # .item() turns the 0-dim tensor into a plain Python number: tensors hash by object identity,
            # so they would never match an existing dictionary key.
            customer_id = X[customer_id_input_key][logit_idx].item()

            if customer_id not in highest_logits_of_customers:
                highest_logits_of_customers[customer_id] = []

            bisect.insort(highest_logits_of_customers[customer_id], logit.item(), key=lambda x: -x)

            # Keep only the customer's top k scores.
            if len(highest_logits_of_customers[customer_id]) > inspected_top_k_per_customer:
                highest_logits_of_customers[customer_id] \
                    = highest_logits_of_customers[customer_id][:inspected_top_k_per_customer]

        if batch_idx % 100 == 0:
            current_sample_idx = (batch_idx + 1) * len(logits)
            print(f"[{current_sample_idx:>5d}/{number_of_samples_in_dataset:>5d}]")

    for customer_id in highest_logits_of_customers.keys():
        true_positive_counter = 0
        for logit in highest_logits_of_customers[customer_id]:
            if logit >= 0.5:
                true_positive_counter += 1
        recall_scores_of_customers.append(true_positive_counter / len(highest_logits_of_customers[customer_id]))

    return sum(recall_scores_of_customers) / len(recall_scores_of_customers)
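Stripped of the model call, the metric groups the scores by customer, keeps each customer's k highest scores, and counts those at or above 0.5. The stdlib sketch below reproduces that on made-up scores, using `heapq.nlargest` in place of the repeated `bisect.insort` calls (a substitution, not the notebook's code). It uses plain, hashable customer ids: PyTorch tensors hash by object identity, so 0-dim tensor elements would not work as dictionary keys without `.item()`.

```python
import heapq
from collections import defaultdict


def thresholded_top_k_recall(customer_ids, scores, k, threshold=0.5):
    """Mirror the logic of perform_recall_test on plain Python values.

    For every customer, keep the k highest scores of their (positive)
    interactions, count how many reach the threshold, and average the
    per-customer ratios.
    """
    scores_of_customers = defaultdict(list)
    for customer_id, score in zip(customer_ids, scores):
        scores_of_customers[customer_id].append(score)

    per_customer_recall = []
    for customer_scores in scores_of_customers.values():
        top_k = heapq.nlargest(k, customer_scores)
        hits = sum(1 for score in top_k if score >= threshold)
        per_customer_recall.append(hits / len(top_k))
    return sum(per_customer_recall) / len(per_customer_recall)


# Hypothetical scores: customer "a" has three interactions, "b" has one.
customer_ids = ["a", "a", "a", "b"]
scores = [0.9, 0.2, 0.6, 0.4]
print(thresholded_top_k_recall(customer_ids, scores, k=2))  # 0.5
```

With k=2, customer "a" keeps 0.9 and 0.6 (both hits) and "b" keeps 0.4 (no hit), giving (1 + 0) / 2 = 0.5. With k=3, "a" also keeps 0.2, so the result drops to (2/3 + 0) / 2 = 1/3.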

**test_model** runs both tests (loss and recall) on the model in evaluation mode without gradient tracking, and returns their scores.

In [None]:
def test_model(
        loss_test_negative_batch_dataloader,
        recall_test_dataloader,
        model,
        loss_fn,
        article_id_input_key,
        article_index_name_input_key,
        article_garment_group_input_key,
        article_product_group_input_key,
        article_perceived_color_master_input_key,
        article_graphical_appearance_input_key,
        customer_id_input_key,
        customer_age_group_input_key,
        transaction_date_input_key,
):
    model.eval()
    with torch.no_grad():
        loss_test_result_average_loss = perform_loss_test(
            negative_batch_dataloader=loss_test_negative_batch_dataloader,
            model=model,
            loss_fn=loss_fn,
            article_id_input_key=article_id_input_key,
            article_index_name_input_key=article_index_name_input_key,
            article_garment_group_input_key=article_garment_group_input_key,
            article_product_group_input_key=article_product_group_input_key,
            article_perceived_color_master_input_key=article_perceived_color_master_input_key,
            article_graphical_appearance_input_key=article_graphical_appearance_input_key,
            customer_id_input_key=customer_id_input_key,
            customer_age_group_input_key=customer_age_group_input_key,
            transaction_date_input_key=transaction_date_input_key,
        )
        recall_score = perform_recall_test(
            inspected_top_k_per_customer=500,
            positive_interactions_dataloader=recall_test_dataloader,
            model=model,
            article_id_input_key=article_id_input_key,
            article_index_name_input_key=article_index_name_input_key,
            article_garment_group_input_key=article_garment_group_input_key,
            article_product_group_input_key=article_product_group_input_key,
            article_perceived_color_master_input_key=article_perceived_color_master_input_key,
            article_graphical_appearance_input_key=article_graphical_appearance_input_key,
            customer_id_input_key=customer_id_input_key,
            customer_age_group_input_key=customer_age_group_input_key,
            transaction_date_input_key=transaction_date_input_key,
        )

        return loss_test_result_average_loss, recall_score
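`test_model` uses both `model.eval()` and `torch.no_grad()`, which do different things: `eval()` switches layers such as dropout and batch normalization to inference behaviour, while `no_grad()` stops autograd from recording the computation. A minimal check on a hypothetical toy network with dropout (whether the notebook's towers contain such layers is not shown here):

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical tiny network with dropout, standing in for a tower.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(2, 4)

model.eval()  # dropout becomes the identity, so outputs are deterministic
with torch.no_grad():  # no autograd graph is recorded
    first = model(x)
    second = model(x)

print(torch.equal(first, second), first.requires_grad)  # True False
```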

**load_training_and_test_data** loads all the data required for the model inspection. The training dataset contains the transactions from 7 to 13 January 2019, and the test dataset contains those from 14 to 20 January 2019; the same test transactions are used for both the loss and the recall tests. Only those articles and customers are kept that have at least one transaction in this two-week period.

In [None]:
def load_training_and_test_data(
        article_id_column_name,
        article_garment_group_column_name,
        article_graphical_appearance_column_name,
        article_index_name_column_name,
        article_perceived_color_master_column_name,
        article_product_group_column_name,
        customer_id_column_name,
        customer_age_column_name,
        transaction_date_column_name,
        customer_age_group_size,
        customer_age_minimum,
        is_article_garment_group_name_used,
        is_article_graphical_appearance_name_used,
        is_article_index_name_used,
        is_article_perceived_color_master_name_used,
        is_article_product_group_name_used,
        transaction_date_embedding_dimension,
):
    path_to_articles = "C:\\Sajat\\Egyetem\\MSc\\Onallo\\HM_dataset\\articles.csv"
    path_to_customers = "C:\\Sajat\\Egyetem\\MSc\\Onallo\\HM_dataset\\customers.csv"
    path_to_transactions = "C:\\Sajat\\Egyetem\\MSc\\Onallo\\HM_dataset\\transactions_train.csv"

    transactions_df_columns = [article_id_column_name, customer_id_column_name, transaction_date_column_name]

    transactions_df_iter = pd.read_csv(
        filepath_or_buffer=path_to_transactions,
        usecols=transactions_df_columns,
        iterator=True,
        chunksize=100_000,
    )

    transactions_start_date = "2019-01-07"
    transactions_end_date = "2019-01-20"
    test_transactions_delimiter_date = "2019-01-14"

    transactions_df = pd.concat(
        [
            transactions_df_chunk[
                (transactions_df_chunk[transaction_date_column_name] >= transactions_start_date) &
                (transactions_df_chunk[transaction_date_column_name] <= transactions_end_date)
                ]
            for transactions_df_chunk in transactions_df_iter]
    )

    training_transactions_df = transactions_df[
        transactions_df[transaction_date_column_name] < test_transactions_delimiter_date]
    loss_test_transactions_df = transactions_df[
        transactions_df[transaction_date_column_name] >= test_transactions_delimiter_date]
    recall_test_transactions_df = loss_test_transactions_df

    if transaction_date_embedding_dimension != 0:
        transactions_df.drop(transaction_date_column_name, axis=1)
        loss_test_transactions_df.drop(transaction_date_column_name, axis=1)
        recall_test_transactions_df.drop(transaction_date_column_name, axis=1)

    articles_df_columns = [article_id_column_name]

    if is_article_garment_group_name_used:
        articles_df_columns.append(article_garment_group_column_name)

    if is_article_graphical_appearance_name_used:
        articles_df_columns.append(article_graphical_appearance_column_name)

    if is_article_index_name_used:
        articles_df_columns.append(article_index_name_column_name)

    if is_article_perceived_color_master_name_used:
        articles_df_columns.append(article_perceived_color_master_column_name)

    if is_article_product_group_name_used:
        articles_df_columns.append(article_product_group_column_name)

    articles_df = pd.read_csv(
        filepath_or_buffer=path_to_articles,
        usecols=articles_df_columns,
    )

    articles_df = articles_df[articles_df[article_id_column_name].isin(transactions_df[article_id_column_name])]

    customers_df_columns = [customer_id_column_name]

    if customer_age_group_size != 0:
        customers_df_columns.append(customer_age_column_name)

    customers_df = pd.read_csv(
        filepath_or_buffer=path_to_customers,
        usecols=customers_df_columns,
    )

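    customers_df = customers_df[
        customers_df[customer_id_column_name].isin(transactions_df[customer_id_column_name])
    ].copy()

    # The age column is only loaded when age groups are used. Missing ages are filled
    # with the observed maximum age plus half of the minimum age.
    if customer_age_group_size != 0:
        customers_df[customer_age_column_name] = customers_df[customer_age_column_name].fillna(
            customers_df[customer_age_column_name].max() + customer_age_minimum / 2
        )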

    print(f"Number of articles: {len(articles_df)}")
    print(f"Number of customers: {len(customers_df)}")
    print(f"Number of training transactions: {len(training_transactions_df)}")
    print(f"Number of loss test and recall test transactions: {len(loss_test_transactions_df)}")

    return articles_df, customers_df, training_transactions_df, loss_test_transactions_df, recall_test_transactions_df
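The date filters compare the `t_dat` values as strings. A stdlib sketch with the same boundaries shows which dates land in which split (the sample dates are made up); zero-padded `YYYY-MM-DD` strings sort in calendar order, so the string comparison is safe, and both windows span seven days starting on a Monday.

```python
from datetime import date

# Window boundaries as used in load_training_and_test_data.
TRANSACTIONS_START_DATE = "2019-01-07"
TEST_TRANSACTIONS_DELIMITER_DATE = "2019-01-14"
TRANSACTIONS_END_DATE = "2019-01-20"


def split_by_date(transaction_dates):
    """Split ISO-8601 date strings into training and test lists.

    Zero-padded YYYY-MM-DD strings sort lexicographically in date order,
    which is why plain string comparison works on the t_dat column.
    """
    in_window = [d for d in transaction_dates
                 if TRANSACTIONS_START_DATE <= d <= TRANSACTIONS_END_DATE]
    training = [d for d in in_window if d < TEST_TRANSACTIONS_DELIMITER_DATE]
    test = [d for d in in_window if d >= TEST_TRANSACTIONS_DELIMITER_DATE]
    return training, test


dates = ["2019-01-06", "2019-01-07", "2019-01-13", "2019-01-14", "2019-01-20", "2019-01-21"]
training, test = split_by_date(dates)
print(training)  # ['2019-01-07', '2019-01-13']
print(test)      # ['2019-01-14', '2019-01-20']
print(date.fromisoformat(TRANSACTIONS_START_DATE).weekday())  # 0, i.e. Monday
```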

**get_pytorch_device** selects the hardware used for the model's training and testing: a CUDA GPU if one is available, otherwise Apple's MPS backend if available, otherwise the CPU.

In [None]:
def get_pytorch_device():
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )

    print(f"Using {device} device")

    return device
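Calling `.to(device)` on the model moves only its parameters; the input tensors produced by a `DataLoader` stay on the CPU unless they are moved as well. A minimal sketch of one way to do that, reading the device from the model itself (the helper name is hypothetical):

```python
import torch
from torch import nn


def move_inputs_to_model_device(model, inputs):
    """Move every tensor in an input dict to the device of the model's parameters.

    Assumes the model has at least one parameter and all of them live on one device.
    """
    device = next(model.parameters()).device
    return {key: value.to(device) for key, value in inputs.items()}


model = nn.Linear(3, 1)  # stays on the CPU in this example
inputs = {"article_id": torch.tensor([1, 2]), "customer_id": torch.tensor([0, 4])}
moved = move_inputs_to_model_device(model, inputs)
print({key: str(value.device) for key, value in moved.items()})
# {'article_id': 'cpu', 'customer_id': 'cpu'}
```

Reading the device from the parameters avoids passing `device` through every training and test function.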

**train_and_test_model** trains and tests the model for the given number of epochs, recreating the datasets in every epoch. Because *NegativeBatchSampledTransactionDataset* generates a negative transaction for each positive one, every batch contains twice as many items as the batch size given to the *DataLoader*, so only half of the required batch size is passed to it.

In [None]:
def train_and_test_model(
        model,
        loss_fn,
        optimizer,
        articles_df,
        customers_df,
        training_transactions_df,
        loss_test_transactions_df,
        recall_test_transactions_df,
        customer_age_group_size,
        customer_minimum_age,
        epochs,
        training_batch_size,
        test_batch_size,
        article_id_column_name,
        article_garment_group_column_name,
        article_graphical_appearance_column_name,
        article_index_name_column_name,
        article_perceived_color_master_column_name,
        article_product_group_column_name,
        customer_id_column_name,
        customer_age_column_name,
        transaction_date_column_name,
):
    print("Starting training and testing...")
    loss_test_loss = 0
    recall_score = 0
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")

        train_model(
            negative_batch_dataloader=DataLoader(
                dataset=NegativeBatchSampledTransactionDataset(
                    articles_df=articles_df,
                    customers_df=customers_df,
                    transactions_df=training_transactions_df,
                    customer_age_group_size=customer_age_group_size,
                    customer_minimum_age=customer_minimum_age,
                    article_id_column_name=article_id_column_name,
                    article_garment_group_column_name=article_garment_group_column_name,
                    article_graphical_appearance_column_name=article_graphical_appearance_column_name,
                    article_index_name_column_name=article_index_name_column_name,
                    article_perceived_color_master_column_name=article_perceived_color_master_column_name,
                    article_product_group_column_name=article_product_group_column_name,
                    customer_id_column_name=customer_id_column_name,
                    customer_age_column_name=customer_age_column_name,
                    transaction_date_column_name=transaction_date_column_name,
                ),
                batch_size=round(training_batch_size / 2),
            ),
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            article_id_input_key=article_id_column_name,
            article_index_name_input_key=article_index_name_column_name,
            article_garment_group_input_key=article_garment_group_column_name,
            article_product_group_input_key=article_product_group_column_name,
            article_perceived_color_master_input_key=article_perceived_color_master_column_name,
            article_graphical_appearance_input_key=article_graphical_appearance_column_name,
            customer_id_input_key=customer_id_column_name,
            customer_age_group_input_key=customer_age_column_name,
            transaction_date_input_key=transaction_date_column_name,
        )

        loss_test_loss, recall_score = test_model(
            loss_test_negative_batch_dataloader=DataLoader(
                dataset=NegativeBatchSampledTransactionDataset(
                    articles_df=articles_df,
                    customers_df=customers_df,
                    transactions_df=loss_test_transactions_df,
                    customer_age_group_size=customer_age_group_size,
                    customer_minimum_age=customer_minimum_age,
                    article_id_column_name=article_id_column_name,
                    article_garment_group_column_name=article_garment_group_column_name,
                    article_graphical_appearance_column_name=article_graphical_appearance_column_name,
                    article_index_name_column_name=article_index_name_column_name,
                    article_perceived_color_master_column_name=article_perceived_color_master_column_name,
                    article_product_group_column_name=article_product_group_column_name,
                    customer_id_column_name=customer_id_column_name,
                    customer_age_column_name=customer_age_column_name,
                    transaction_date_column_name=transaction_date_column_name,
                ),
                batch_size=round(test_batch_size / 2),
            ),
            recall_test_dataloader=DataLoader(
                dataset=TransactionDataset(
                    articles_df=articles_df,
                    customers_df=customers_df,
                    transactions_df=recall_test_transactions_df,
                    customer_age_group_size=customer_age_group_size,
                    customer_minimum_age=customer_minimum_age,
                    article_id_column_name=article_id_column_name,
                    article_garment_group_column_name=article_garment_group_column_name,
                    article_graphical_appearance_column_name=article_graphical_appearance_column_name,
                    article_index_name_column_name=article_index_name_column_name,
                    article_perceived_color_master_column_name=article_perceived_color_master_column_name,
                    article_product_group_column_name=article_product_group_column_name,
                    customer_id_column_name=customer_id_column_name,
                    customer_age_column_name=customer_age_column_name,
                    transaction_date_column_name=transaction_date_column_name,
                ),
                batch_size=test_batch_size,
            ),
            model=model,
            loss_fn=loss_fn,
            article_id_input_key=article_id_column_name,
            article_index_name_input_key=article_index_name_column_name,
            article_garment_group_input_key=article_garment_group_column_name,
            article_product_group_input_key=article_product_group_column_name,
            article_perceived_color_master_input_key=article_perceived_color_master_column_name,
            article_graphical_appearance_input_key=article_graphical_appearance_column_name,
            customer_id_input_key=customer_id_column_name,
            customer_age_group_input_key=customer_age_column_name,
            transaction_date_input_key=transaction_date_column_name,
        )

    print(f"Loss test loss: {loss_test_loss}")
    print(f"Recall test recall: {recall_score}")
    return loss_test_loss, recall_score
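Because the requested batch size is halved with `round`, odd batch sizes are not preserved exactly: Python rounds halves to the nearest even integer. A quick check of the resulting number of samples per step, assuming one negative per positive as described above (the helper name is hypothetical):

```python
def effective_batch_size(requested_batch_size: int) -> int:
    """Samples per step when half the batch size is passed to the DataLoader.

    Assumes one negative transaction is generated for every positive one,
    so the model sees twice the DataLoader's batch size.
    """
    return 2 * round(requested_batch_size / 2)


for requested in (64, 5, 7):
    print(requested, effective_batch_size(requested))
# 64 64
# 5 4
# 7 8
```

Even batch sizes are therefore the safe choice for the *batch_size* hyperparameter.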

**perform_configured_model_inspection** initializes the model inspection's parameters from its *config* parameter, builds the model, the loss function and the optimizer, and then runs the training and the tests.

In [None]:
def perform_configured_model_inspection(config):
    article_id_column_name = "article_id"
    article_garment_group_column_name = "garment_group_name"
    article_graphical_appearance_column_name = "graphical_appearance_name"
    article_index_name_column_name = "index_name"
    article_perceived_color_master_column_name = "perceived_colour_master_name"
    article_product_group_column_name = "product_group_name"
    customer_id_column_name = "customer_id"
    customer_age_column_name = "age"
    transaction_date_column_name = "t_dat"

    customer_maximum_age = 99
    customer_minimum_age = 16

    articles_df, customers_df, training_transactions_df, loss_test_transactions_df, recall_test_transactions_df \
        = load_training_and_test_data(
        article_id_column_name=article_id_column_name,
        article_garment_group_column_name=article_garment_group_column_name,
        article_graphical_appearance_column_name=article_graphical_appearance_column_name,
        article_index_name_column_name=article_index_name_column_name,
        article_perceived_color_master_column_name=article_perceived_color_master_column_name,
        article_product_group_column_name=article_product_group_column_name,
        customer_id_column_name=customer_id_column_name,
        customer_age_column_name=customer_age_column_name,
        transaction_date_column_name=transaction_date_column_name,
        customer_age_group_size=config.customer_age_group_size,
        customer_age_minimum=customer_minimum_age,
        is_article_garment_group_name_used=config.is_article_garment_group_name_used,
        is_article_graphical_appearance_name_used=config.is_article_graphical_appearance_name_used,
        is_article_index_name_used=config.is_article_index_name_used,
        is_article_perceived_color_master_name_used=config.is_article_perceived_color_master_name_used,
        is_article_product_group_name_used=config.is_article_product_group_name_used,
        transaction_date_embedding_dimension=config.transaction_date_embedding_dimension,
    )

    device = get_pytorch_device()

    two_tower_model = TwoTowerModel(
        number_of_neurons_in_the_first_hidden_dense_layer=config.number_of_neurons_in_the_first_hidden_dense_layer,
        number_of_neurons_in_the_second_hidden_dense_layer=config.number_of_neurons_in_the_second_hidden_dense_layer,
        number_of_neurons_in_the_third_hidden_dense_layer=config.number_of_neurons_in_the_third_hidden_dense_layer,
        article_id_embedding_dimension_and_unique_value_count=(
            config.article_id_embedding_dimension,
            len(articles_df)
        ) if config.article_id_embedding_dimension != 0 else None,
        article_index_name_unique_value_count=(
            len(WebshopDataContainer.get_article_index_name_numeric_encoder()) if config.is_article_index_name_used else 0
        ),
        article_garment_group_name_unique_value_count=(
            len(WebshopDataContainer.get_article_garment_group_name_numeric_encoder())
            if config.is_article_garment_group_name_used else 0
        ),
        article_graphical_appearance_name_unique_value_count=(
            len(WebshopDataContainer.get_article_graphical_appearance_numeric_encoder())
            if config.is_article_graphical_appearance_name_used else 0
        ),
        article_product_group_name_unique_value_count=(
            len(WebshopDataContainer.get_article_product_group_name_numeric_encoder())
            if config.is_article_product_group_name_used else 0
        ),
        article_perceived_color_master_name_unique_value_count=(
            len(WebshopDataContainer.get_article_perceived_color_master_name_numeric_encoder())
            if config.is_article_perceived_color_master_name_used else 0
        ),
        customer_id_embedding_dimension_and_unique_value_count=(
            config.customer_id_embedding_dimension,
            len(customers_df),
        ) if config.customer_id_embedding_dimension != 0 else None,
        customer_age_group_count=math.ceil(
            (customer_maximum_age - customer_minimum_age) / config.customer_age_group_size
        ) if config.customer_age_group_size != 0 else 0,
        transaction_date_embedding_dimension_and_unique_value_count=(
            config.transaction_date_embedding_dimension,
            training_transactions_df[transaction_date_column_name].nunique(),
        ) if config.transaction_date_embedding_dimension != 0 else None,
        similarity_function=nn.CosineSimilarity(),
    ).to(device)
    cross_entropy_loss_fn = nn.CrossEntropyLoss()
    two_tower_model_adam_optimizer = torch.optim.Adam(
        params=two_tower_model.parameters(),
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        weight_decay=config.weight_decay,
        amsgrad=config.is_amsgrad_used,
    )

    test_loss, recall_score = train_and_test_model(
        model=two_tower_model,
        loss_fn=cross_entropy_loss_fn,
        optimizer=two_tower_model_adam_optimizer,
        articles_df=articles_df,
        customers_df=customers_df,
        training_transactions_df=training_transactions_df,
        loss_test_transactions_df=loss_test_transactions_df,
        recall_test_transactions_df=recall_test_transactions_df,
        customer_age_group_size=config.customer_age_group_size,
        customer_minimum_age=customer_minimum_age,
        epochs=config.epochs,
        training_batch_size=config.batch_size,
        test_batch_size=config.batch_size,
        article_id_column_name=article_id_column_name,
        article_garment_group_column_name=article_garment_group_column_name,
        article_graphical_appearance_column_name=article_graphical_appearance_column_name,
        article_index_name_column_name=article_index_name_column_name,
        article_perceived_color_master_column_name=article_perceived_color_master_column_name,
        article_product_group_column_name=article_product_group_column_name,
        customer_id_column_name=customer_id_column_name,
        customer_age_column_name=customer_age_column_name,
        transaction_date_column_name=transaction_date_column_name,
    )

    return two_tower_model, test_loss, recall_score
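The `customer_age_group_count` argument above uses `ceil((maximum_age - minimum_age) / group_size)`. The mapping from an age to its group index is defined elsewhere in the notebook. Assuming it is the floor division `(age - minimum_age) // group_size` (an assumption, not confirmed here), the oldest customer lands one index past the last group whenever the age range is an exact multiple of the group size. The sketch below checks this arithmetic with illustrative ages. `age_group_index` and `safe_age_group_count` are hypothetical helpers. Under that assumption, `(maximum_age - minimum_age) // group_size + 1` is one count that always covers the maximum age.

```python
import math


def age_group_index(age, minimum_age, group_size):
    # Hypothetical age-to-group mapping: floor division of the offset from the youngest age.
    return (age - minimum_age) // group_size


def ceil_age_group_count(minimum_age, maximum_age, group_size):
    # Same formula as the customer_age_group_count argument in the cell above.
    return math.ceil((maximum_age - minimum_age) / group_size)


def safe_age_group_count(minimum_age, maximum_age, group_size):
    # One group for every index age_group_index can return for ages in [minimum_age, maximum_age].
    return (maximum_age - minimum_age) // group_size + 1


# Illustrative age bounds, not taken from the H&M dataset.
for minimum_age, maximum_age, group_size in [(16, 99, 5), (16, 96, 5)]:
    print(
        (minimum_age, maximum_age, group_size),
        "ceil count:", ceil_age_group_count(minimum_age, maximum_age, group_size),
        "safe count:", safe_age_group_count(minimum_age, maximum_age, group_size),
        "oldest customer's index:", age_group_index(maximum_age, minimum_age, group_size),
    )
```

With the bounds (16, 96, 5), the ceil-based count is 16, but age 96 maps to index 16, one past the last valid index 15. With (16, 99, 5), both counts are 17.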

**perform_configured_model_inspection_with_exception_logging** wraps *perform_configured_model_inspection* in a try-except block. It prints the exception's traceback and then re-raises the exception. This function exists mainly to make errors in the model inspection's implementation easier to detect.

In [None]:
def perform_configured_model_inspection_with_exception_logging(config):
    try:
        return perform_configured_model_inspection(config)
    except Exception:
        print("Model inspection was interrupted due to an exception being thrown:")
        traceback.print_exc()
        # A bare raise re-raises the caught exception with its original traceback.
        raise
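Here is a minimal, self-contained illustration of the same log-and-re-raise pattern. A hypothetical `failing_inspection` function stands in for `perform_configured_model_inspection`. The traceback is printed to stderr, which is captured here so that it can be inspected, and a bare `raise` still propagates the original exception to the caller.

```python
import contextlib
import io
import traceback


def failing_inspection(config):
    # Hypothetical stand-in for perform_configured_model_inspection.
    raise ValueError(f"invalid batch size: {config['batch_size']}")


def inspection_with_exception_logging(config):
    try:
        return failing_inspection(config)
    except Exception:
        print("Model inspection was interrupted due to an exception being thrown:")
        traceback.print_exc()  # writes the traceback to sys.stderr
        raise  # propagate the original exception to the caller


captured_stderr = io.StringIO()
caught = None
with contextlib.redirect_stderr(captured_stderr):
    try:
        inspection_with_exception_logging({"batch_size": -1})
    except ValueError as error:
        caught = error

print(type(caught).__name__, caught)
print("ValueError" in captured_stderr.getvalue())
```

The caller still receives the `ValueError`, so a sweep run fails visibly instead of silently returning a result.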

**save_model_in_onnx_format_and_get_path** exports the received PyTorch model to ONNX format and returns the saved model's local path. The function uses the older *torch.onnx.export* function instead of the newer *torch.onnx.dynamo_export* function, because unfortunately the newer solution didn't work when I tried to use it. The first (batch) dimension of the ONNX model's input and output is saved as a dynamic axis, so the exported model can be used with any batch size.

In [None]:
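def save_model_in_onnx_format_and_get_path(model, filename, input_dimension, sample_input=None):
    local_destination_path = f"../ml_artifacts/{filename}.onnx"
    if sample_input is None:
        sample_input = torch.randn(1, input_dimension)
    # Tracing runs the model on the sample input, so the input has to be on the
    # same device as the model's parameters (e.g. when the model was moved to a GPU).
    first_parameter = next(model.parameters(), None)
    if first_parameter is not None:
        sample_input = sample_input.to(first_parameter.device)
    torch.onnx.export(
        model,
        sample_input,
        local_destination_path,
        export_params=True,
        input_names=["input"],
        output_names=["output"],
        # Mark the first (batch) dimension as dynamic, so the exported model accepts any batch size.
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
    )
    return local_destination_path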

**save_model_as_wandb_artifact** saves the model in ONNX format, then uploads the file to wandb as a model artifact.

In [None]:
def save_model_as_wandb_artifact(model, artifact_model_name, filename, input_dimension, sample_input=None):
    model_artifact_path = save_model_in_onnx_format_and_get_path(
        model=model,
        filename=filename,
        input_dimension=input_dimension,
        sample_input=sample_input,
    )
    model_artifact = wandb.Artifact(name=artifact_model_name, type="model")
    model_artifact.add_file(model_artifact_path)
    wandb.log_artifact(model_artifact)

The model inspections and hyperparameter optimizations are run using [wandb sweeps](https://docs.wandb.ai/guides/sweeps). **sweep_main** runs one model inspection with the hyperparameters chosen by the sweep and logs the inspected model's loss and recall score to wandb. It then exports the tower models and the embedding layers to ONNX and uploads them to wandb as artifacts.

In [None]:
def sweep_main():
    wandb.init()
    model, loss, recall_score = perform_configured_model_inspection_with_exception_logging(wandb.config)
    wandb.log({
        "loss": loss,
        "recall": recall_score,
    })
    save_model_as_wandb_artifact(
        model=model.item_tower_model,
        artifact_model_name="Retrival_two_tower_model_item_tower",
        filename="retrieval_item_tower_model",
        input_dimension=model.calculate_item_tower_input_dimension(),
    )
    save_model_as_wandb_artifact(
        model=model.query_tower_model,
        artifact_model_name="Retrival_two_tower_model_query_tower",
        filename="retrieval_query_tower_model",
        input_dimension=model.calculate_query_tower_input_dimension(),
    )
    save_model_as_wandb_artifact(
        model=model.article_id_encoder,
        artifact_model_name="Retrival_two_tower_model_item_id_encoder",
        filename="retrieval_item_id_encoder",
        input_dimension=1,
        sample_input=torch.zeros((1, 1), dtype=torch.int64),
    )
    save_model_as_wandb_artifact(
        model=model.customer_id_encoder,
        artifact_model_name="Retrival_two_tower_model_query_id_encoder",
        filename="retrieval_query_id_encoder",
        input_dimension=1,
        sample_input=torch.zeros((1, 1), dtype=torch.int64),
    )
    if model.transaction_date_encoder is not None:
        save_model_as_wandb_artifact(
            model=model.transaction_date_encoder,
            artifact_model_name="Retrival_two_tower_query_transaction_date_encoder",
            filename="retrieval_query_transaction_date_encoder",
            input_dimension=1,
            sample_input=torch.zeros((1, 1), dtype=torch.int64),
        )
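The item tower, the query tower and the ID encoders are exported as separate models. This fits the usual two tower serving pattern: item embeddings can be computed once for the whole catalogue, and only the query tower has to run for each customer request. The following minimal numpy sketch shows that retrieval step under stated assumptions. Random arrays stand in for the outputs of the exported towers, and loading and running the ONNX files is omitted. Candidates are ranked by cosine similarity, mirroring the `nn.CosineSimilarity()` similarity function used during training. `retrieve_top_k` is a hypothetical helper.

```python
import numpy as np


def l2_normalize(vectors, eps=1e-8):
    # L2-normalize along the last axis so that dot products equal cosine similarities.
    norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
    return vectors / np.maximum(norms, eps)


def retrieve_top_k(query_embedding, item_embeddings, k):
    # Cosine similarity between one query embedding and every item embedding.
    scores = l2_normalize(item_embeddings) @ l2_normalize(query_embedding)
    # argpartition selects the k best candidates without a full sort,
    # then only those k candidates are sorted by descending score.
    top_k = np.argpartition(-scores, k - 1)[:k]
    top_k = top_k[np.argsort(-scores[top_k])]
    return top_k, scores[top_k]


rng = np.random.default_rng(2)
# Stand-ins for item tower outputs (1,000 articles) and one query tower output.
item_embeddings = rng.normal(size=(1000, 64)).astype(np.float32)
query_embedding = rng.normal(size=64).astype(np.float32)

indices, scores = retrieve_top_k(query_embedding, item_embeddings, k=5)
print(indices, scores)
```

For a large catalogue, the same idea is usually served with an approximate nearest neighbour index instead of a brute-force matrix product. The brute-force version is kept here because it is easy to verify.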

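This cell defines the value ranges of the hyperparameters, following the wandb sweep configuration syntax. `values` lists the discrete options to choose from, while `min` and `max` bound a continuous range. The sweep uses the `bayes` search method to maximize the logged `recall` metric.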

In [None]:
sweep_parameters = {
    "batch_size": {"values": [256, 512, 1024]},
    "epochs": {"values": [1, 2, 4]},

    "learning_rate": {"max": 0.1, "min": 0.0001},
    "beta1": {"max": 0.99, "min": 0.9},
    "beta2": {"max": 0.9999, "min": 0.999},
    "weight_decay": {"max": 0.01, "min": 0.0001},
    "is_amsgrad_used": {"values": [0, 1]},

    "number_of_neurons_in_the_first_hidden_dense_layer": {"values": [128, 256]},
    "number_of_neurons_in_the_second_hidden_dense_layer": {"values": [64, 128, 256]},
    "number_of_neurons_in_the_third_hidden_dense_layer": {"values": [0, 32, 64, 128, 256]},

    "article_id_embedding_dimension": {"values": [16, 32, 64, 128]},
    "is_article_garment_group_name_used": {"values": [0, 1]},
    "is_article_graphical_appearance_name_used": {"values": [0, 1]},
    "is_article_index_name_used": {"values": [0, 1]},
    "is_article_perceived_color_master_name_used": {"values": [0, 1]},
    "is_article_product_group_name_used": {"values": [0, 1]},

    "customer_id_embedding_dimension": {"values": [16, 32, 64, 128]},
    "customer_age_group_size": {"values": [0, 5, 10]},
    "transaction_date_embedding_dimension": {"values": [0, 8, 16, 32, 64, 128, 256]},
}

sweep_configuration = {
    "method": "bayes",
    "name": "retrieval-model-sweep",
    "metric": {
        "goal": "maximize",
        "name": "recall",
    },
    "parameters": sweep_parameters,
}
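When a parameter is given only `min`/`max` bounds, wandb's sweep configuration reference states that the default distribution is `uniform` for float bounds (and `int_uniform` for integer bounds). This means `learning_rate` and `weight_decay` above are searched on a linear scale. The sketch below is plain arithmetic and does not call wandb. It compares how likely a learning rate below 0.001 is under uniform sampling and under log-uniform sampling in the configured [0.0001, 0.1] range. If small learning rates should be explored more often, the sweep configuration reference also describes a `log_uniform_values` distribution. Check it against the installed wandb version before relying on it.

```python
import math


def probability_below_uniform(threshold, low, high):
    # P(X < threshold) for X ~ Uniform(low, high).
    return (threshold - low) / (high - low)


def probability_below_log_uniform(threshold, low, high):
    # P(X < threshold) when log(X) ~ Uniform(log(low), log(high)).
    return (math.log(threshold) - math.log(low)) / (math.log(high) - math.log(low))


low, high = 0.0001, 0.1  # learning_rate bounds from sweep_parameters
threshold = 0.001

print(f"uniform:     {probability_below_uniform(threshold, low, high):.4f}")
print(f"log-uniform: {probability_below_log_uniform(threshold, low, high):.4f}")
```

Under uniform sampling, only about 0.9% of draws fall below 0.001. Under log-uniform sampling, one third of draws do, because the range spans three orders of magnitude.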

The cell below authenticates with wandb.

In [None]:
wandb.login()

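This code cell creates a sweep (a series of parametrized model inspections) in the `wandb-test` project, then starts an agent that executes it. The `count` argument limits how many runs the agent executes, so with `count=1` only a single hyperparameter configuration is tried.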

In [None]:
sweep_id = wandb.sweep(sweep=sweep_configuration, project="wandb-test")

wandb.agent(sweep_id=sweep_id, function=sweep_main, count=1)

## 2. Planned further development tasks

The current inspections use a relatively small dataset because of the limited hardware currently available to me for training and testing the model. For this reason, I'm actively looking for opportunities to access better hardware. Once better hardware is available, it could be worth adjusting the value ranges of some hyperparameters.

The current solution could be improved by using a [multi-stage recommendation system architecture](https://resources.nvidia.com/en-us-merlin/bad-a-multi-stage-recommender). The currently implemented two tower model could serve as its retrieval stage, and a ranker model could be added that produces better predictions from the retrieval model's output. Implementing a filtering stage could also make sense: it seems possible that customers usually don't buy articles they have bought previously. This idea should first be checked and confirmed in the [Exploratory Data Analysis Notebook](hm_dataset_inspection_eda.ipynb).
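The following minimal pandas sketch covers the two parts of the filtering idea under stated assumptions. It uses the H&M transaction columns `t_dat`, `customer_id` and `article_id`, toy data instead of the real transactions, and the hypothetical helpers `repeat_purchase_share` and `filter_previously_bought`. The first helper measures how often a purchase repeats an article that the same customer bought on an earlier date, which is the hypothesis to confirm in the EDA notebook. The second helper drops previously bought articles from each customer's retrieved candidates. A filtering stage like this only pays off if the measured share turns out to be low.

```python
import pandas as pd


def repeat_purchase_share(transactions):
    # First date on which each customer bought each article.
    first_purchase_date = transactions.groupby(["customer_id", "article_id"])["t_dat"].transform("min")
    # A transaction counts as a repeat if the same customer already bought
    # the same article on an earlier date (same-day duplicates are not repeats).
    is_repeat = transactions["t_dat"] > first_purchase_date
    return is_repeat.mean()


def filter_previously_bought(candidates, transactions):
    # Anti-join: keep only candidate (customer, article) pairs that were never purchased.
    bought_pairs = transactions[["customer_id", "article_id"]].drop_duplicates()
    merged = candidates.merge(bought_pairs, on=["customer_id", "article_id"], how="left", indicator=True)
    return merged.loc[merged["_merge"] == "left_only", candidates.columns].reset_index(drop=True)


# Toy data, not taken from the H&M dataset.
transactions = pd.DataFrame({
    "t_dat": pd.to_datetime(["2020-09-01", "2020-09-01", "2020-09-05", "2020-09-07", "2020-09-08"]),
    "customer_id": ["a", "a", "a", "b", "b"],
    "article_id": [101, 102, 101, 103, 104],
})
candidates = pd.DataFrame({
    "customer_id": ["a", "a", "b", "b"],
    "article_id": [101, 105, 103, 106],
})

print(repeat_purchase_share(transactions))
print(filter_previously_bought(candidates, transactions))
```

In the toy data, one of the five transactions is a repeat (customer `a` buys article 101 again), so the share is 0.2. The filter keeps only articles 105 and 106.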