# Federated learning recommendation system based on the SpFedRec Framework

Recommendation systems play a key role in the success of modern e-commerce webshop software. To create a solution with enhanced data protection, the system's design is based on [federated learning](https://research.ibm.com/blog/what-is-federated-learning). Using federated learning [introduces new challenges](https://arxiv.org/abs/2301.00767) in security, client-side energy and hardware requirements, system complexity, and communication costs. To tackle these issues, the implementation is based on the [SpFedRec Framework](https://journalofcloudcomputing.springeropen.com/articles/10.1186/s13677-023-00435-5) combined with other well-known techniques.

TODOS:
- concretize this document's main goal
- make it clear that this document doesn't distinguish the SAS and REC server from the article but handles it as a single server side component
- add explanation that choosing negative items using some kind of strategy is probably a good idea

## 0. Import dependencies

First, let's import the dependencies that this notebook requires!

In [1]:
import math
import numpy as np
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn

## 1. Load data

To inspect and try out the implementation, sample data is required. Luckily, a large, high-quality tabular dataset is freely available in this subject area: the [H&M Personalized Fashion Recommendations dataset](https://www.kaggle.com/competitions/h-and-m-personalized-fashion-recommendations). Because this notebook is only a demonstration, I won't use the whole dataset, just a small part of it. The following cell contains a small sample from this dataset, which is used later to demonstrate the SpFedRec framework. The data is stored in a pandas DataFrame.

The sample DataFrame contains transactions: one column holds the buying client's ID and the other holds the purchased item's ID.

In [2]:
sample_data = [
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 663713001],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 541518023],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 578020002],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 723529001],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 351484002],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 727808001],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 727808007],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 858883002],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 851400006],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 750424014],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 870304002],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 852643001],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 852643003],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 794321007],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 505221004],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 685687003],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 685687004],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 685687001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 505221001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 508184022],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 522992001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 605106001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 567618001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 528931002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 349301001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 590414001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 590414002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 570309005],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 577992001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 552570004],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 649018001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 633150009],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 581162008],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 616808001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 567618002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 622964004],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 464454004],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 550718001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 583533001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 272591001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 686406001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 413707001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 665851003],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 665851001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 656213001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 351933001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 478549001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 634591002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 665654001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 724244001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 681569001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 703843001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 609598006],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 682899001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 644073002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 678079005],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 678079003],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 562637001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 682334001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 609598012],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 629801001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 644763001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 444325004],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 678339001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 628794001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 634591003],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 693387003],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 636392001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 644073001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 664075001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 638939001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 628816002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 425217006],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 672800001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 703656001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 566618004],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 605939001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 671502001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 694671001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 610665002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 657291004],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 663613001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 628816005],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 692778001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 655267003],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 692778002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 660108001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 664368004],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 708352002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 681376001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 572187001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 752945001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 651697001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 578478001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 745843001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 531526002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 619580007],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 619580001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 713692001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 779136002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 708379003],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 719260001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 554784003],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 659983002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 782643001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 515815001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 619580008],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 784278001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 312878001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 312878010],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 730683001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 787147002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 614622018],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 745745001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 666444002],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 349301041],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 721257001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 160442010],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 849942001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 372860001],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 160442007],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 304786008],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 554757003],
    ["00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2", 808651003],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 663713001],
    ["000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318", 541518023],
    ["00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4c73235dccbbc132280", 688873012],
    ["00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4c73235dccbbc132280", 501323011],
    ["00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4c73235dccbbc132280", 598859003],
    ["00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4c73235dccbbc132280", 688873020],
    ["00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4c73235dccbbc132280", 688873011],
    ["0008968c0d451dbc5a9968da03196fe20051965edde7413775c4eb3be9abe9c2", 531310002],
    ["0008968c0d451dbc5a9968da03196fe20051965edde7413775c4eb3be9abe9c2", 529841001],
    ["000aa7f0dc06cd7174389e76c9e132a67860c5f65f970699daccc14425ac31a8", 501820043],
    ["000aa7f0dc06cd7174389e76c9e132a67860c5f65f970699daccc14425ac31a8", 501820043],
    ["000aa7f0dc06cd7174389e76c9e132a67860c5f65f970699daccc14425ac31a8", 674681001],
    ["000aa7f0dc06cd7174389e76c9e132a67860c5f65f970699daccc14425ac31a8", 671505001],
    ["000aa7f0dc06cd7174389e76c9e132a67860c5f65f970699daccc14425ac31a8", 671505001],
]

transactions_df = pd.DataFrame(sample_data, columns=["customer_id", "article_id"])

transactions_df["customer_id"] = transactions_df["customer_id"].astype("string")
transactions_df["article_id"] = pd.to_numeric(transactions_df["article_id"], downcast="unsigned")

print(f"Number of transactions: {len(transactions_df)}")
print(transactions_df.head())

Number of transactions: 138
                                         customer_id  article_id
0  000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...   663713001
1  000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...   541518023
2  000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...   578020002
3  000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...   723529001
4  000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...   351484002


This notebook demonstrates the case in which two clients perform training. Let's choose the IDs of these training clients!

In [3]:
number_of_customer_ids = 2

customer_ids = transactions_df["customer_id"].unique()[:number_of_customer_ids]

print("IDs of customers that participate in the learning:")
for customer_id in customer_ids:
    print(customer_id)

IDs of customers that participate in the learning:
000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318
00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2


Let's create DataFrames that contain only the transactions of the training clients!

In [4]:
def filter_dataframe_by_customer_id(dataframe, customer_id):
    return dataframe[dataframe["customer_id"] == customer_id]

In [5]:
client_transactions_dfs = {}

for customer_id in customer_ids:
    client_transactions_dfs[customer_id] = filter_dataframe_by_customer_id(transactions_df, customer_id) \
                                            .drop_duplicates()

for idx, (client_id, client_transactions_df) in enumerate(client_transactions_dfs.items()):
    print(f"Count of unique transactions of client {idx + 1}: {len(client_transactions_df)}")
    print(f"First transactions of client {idx + 1}:")
    print(f"{client_transactions_df.head()}\n")

Count of unique transactions of client 1: 14
First transactions of client 1:
                                         customer_id  article_id
0  000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...   663713001
1  000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...   541518023
2  000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...   578020002
3  000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...   723529001
4  000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...   351484002

Count of unique transactions of client 2: 110
First transactions of client 2:
                                          customer_id  article_id
14  00007d2de826758b65a93dd24ce629ed66842531df6699...   505221004
15  00007d2de826758b65a93dd24ce629ed66842531df6699...   685687003
16  00007d2de826758b65a93dd24ce629ed66842531df6699...   685687004
17  00007d2de826758b65a93dd24ce629ed66842531df6699...   685687001
18  00007d2de826758b65a93dd24ce629ed66842531df6699...   505221001



## 2. Training

The SpFedRec framework is based on the [Two Tower model](https://cloud.google.com/blog/products/ai-machine-learning/scaling-deep-retrieval-tensorflow-two-towers-architecture). Simply put, this model uses two neural networks to create embedding vectors for queries and articles, then a nearest neighbour search finds the articles that best match the query. During training, the networks are adjusted so that the embeddings of matching query-article pairs become more similar to each other. The main difference between the SpFedRec framework and the basic Two Tower model is that in SpFedRec the article tower is stored on a central server, while the query towers are stored on the client side, and each client has its own query tower.

To use PyTorch, we first need to define what kind of hardware we wish to run it on.
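
The retrieval step described above can be sketched in a few lines. This is only an illustration under simplifying assumptions: the embeddings are made-up 3-dimensional lists rather than tower outputs, the search is a brute-force scan instead of an approximate nearest neighbour index, and the names `cosine_similarity`, `top_k_items` and the toy data are hypothetical, not part of SpFedRec.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_items(query_embedding, item_embeddings, k):
    """Return the k item IDs whose embeddings are most similar to the query."""
    scored = sorted(
        item_embeddings.items(),
        key=lambda pair: cosine_similarity(query_embedding, pair[1]),
        reverse=True,
    )
    return [item_id for item_id, _ in scored[:k]]


# Hypothetical embeddings, made up for illustration only.
toy_query_embedding = [1.0, 0.0, 0.0]
toy_item_embeddings = {
    "item_a": [0.9, 0.1, 0.0],   # points almost the same way as the query
    "item_b": [0.0, 1.0, 0.0],   # orthogonal to the query
    "item_c": [-1.0, 0.0, 0.0],  # points the opposite way
}

recommended = top_k_items(toy_query_embedding, toy_item_embeddings, k=2)
print(recommended)  # ['item_a', 'item_b']
```

In the federated setting, the query embedding would come from the client's own tower and the item embeddings from the server's tower; the scan itself stays the same.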

In [6]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using {device} device")

Using cpu device


Let's define the hyperparameters on which the model architecture depends!

In [7]:
number_of_hidden_linear_layers = 3
number_of_neurons_in_layer = 64
embedding_vector_dimension = 16

Let's define the model!

In [8]:
class TowerModel(nn.Module):

    def __init__(
        self, 
        unique_input_count, 
        number_of_hidden_linear_layers, 
        number_of_neurons_in_layer, 
        output_vector_dimension,
    ):
        super().__init__()

        modules = []
        modules.append(
            nn.Embedding(
                num_embeddings=unique_input_count,
                embedding_dim=number_of_neurons_in_layer,
            )
        )
        for hidden_layer_idx in range(0, number_of_hidden_linear_layers):
            if hidden_layer_idx != (number_of_hidden_linear_layers - 1):
                modules.append(nn.Linear(number_of_neurons_in_layer, number_of_neurons_in_layer))
                modules.append(nn.ReLU())
            else:
                modules.append(nn.Linear(number_of_neurons_in_layer, output_vector_dimension))
        self.linear_relu_stack = nn.Sequential(*modules)

    def forward(self, x):
        return self.linear_relu_stack(x)

Because an Embedding layer is used in the tower models, the number of unique items must be calculated first (the number of clients, `number_of_customer_ids`, was already set above)!

In [9]:
number_of_unique_article_ids = transactions_df["article_id"].nunique()

print(f"Number of unique articles: {number_of_unique_article_ids}")

Number of unique articles: 134


After this, it is possible to instantiate the tower models!

In [10]:
item_model = TowerModel(
    unique_input_count=number_of_unique_article_ids,
    number_of_hidden_linear_layers=number_of_hidden_linear_layers,
    number_of_neurons_in_layer=number_of_neurons_in_layer,
    output_vector_dimension=embedding_vector_dimension,
).to(device)

query_models = {}
for customer_id in customer_ids:
    query_models[customer_id] = TowerModel(
        unique_input_count=number_of_customer_ids,
        number_of_hidden_linear_layers=number_of_hidden_linear_layers,
        number_of_neurons_in_layer=number_of_neurons_in_layer,
        output_vector_dimension=embedding_vector_dimension,
    ).to(device)

print(f"Item model:\n{item_model}\n")
for customer_id in customer_ids:
    print(f"Query model for {customer_id}:\n{query_models[customer_id]}\n")

Item model:
TowerModel(
  (linear_relu_stack): Sequential(
    (0): Embedding(134, 64)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=16, bias=True)
  )
)

Query model for 000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318:
TowerModel(
  (linear_relu_stack): Sequential(
    (0): Embedding(2, 64)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=16, bias=True)
  )
)

Query model for 00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2:
TowerModel(
  (linear_relu_stack): Sequential(
    (0): Embedding(2, 64)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=16, bias=True)
  )
)

Let's define the loss function and the optimizers for the models!

In [11]:
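cosine_similarity_fn = nn.CosineSimilarity(dim=1)
# The similarities form a 1D vector per batch, so the softmax runs over dim 0.
# Passing dim explicitly avoids PyTorch's implicit-dimension deprecation warning.
log_softmax_fn = nn.LogSoftmax(dim=0)
nll_loss_fn = nn.NLLLoss()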

item_model_optimizer = torch.optim.Adam(item_model.parameters())

query_model_optimizers = {}
for customer_id in customer_ids:
    query_model_optimizers[customer_id] = torch.optim.Adam(query_models[customer_id].parameters())
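
For reference, the three modules instantiated above compute the following standard quantities. The notation is introduced here only for this summary ($q$, $v_i$, $s$, $x$, $t$ and $B$ are not names from the notebook): $q$ is a query embedding, $v_1, \dots, v_B$ are the item embeddings of one batch, and PyTorch additionally guards the denominator with a small constant for numerical stability.

$$
s_i = \frac{q \cdot v_i}{\lVert q \rVert_2 \, \lVert v_i \rVert_2},
\qquad
\operatorname{LogSoftmax}(s)_i = s_i - \log \sum_{j=1}^{B} e^{s_j}
$$

For a vector of log-probabilities $x$ and a target class index $t$, the negative log-likelihood loss of a single sample is

$$
\ell_{\mathrm{NLL}}(x, t) = -x_t .
$$

As a small worked case, with $B = 2$ and $s_1 = s_2$ both log-softmax values equal $-\ln 2 \approx -0.693$, so the loss is $\ln 2 \approx 0.693$ whichever index is the target.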

### 2.1. Client side training

The first step of client-side training is that each client requests the embeddings of the articles it previously interacted with (in this sample, an interaction means a purchase). Requesting only those items would reveal which items the client interacted with, and so would leak private data. To hide this, the clients also request embeddings for items they did not interact with. These negative items cost extra communication, and the trade-off is poor if adding them to the training set does not noticeably improve the model's behaviour.

In this implementation, each client determines the number of requested negative items as a fixed rate relative to the number of its positive items. Let's define this rate!

In [12]:
minimum_negative_to_positive_item_per_client_rate = 0.25

Then let's construct the item queries!

In [13]:
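def create_client_item_query(client_transactions_df, transactions_df, customer_id, negative_to_positive_rate):
    # Negative items: the first articles bought by any other customer.
    # Note: this simple choice does not deduplicate the negatives.
    number_of_negative_items = math.ceil(len(client_transactions_df) * negative_to_positive_rate)
    negative_items = transactions_df[transactions_df["customer_id"] != customer_id]["article_id"] \
        .head(number_of_negative_items)
    return pd.concat([client_transactions_df["article_id"], negative_items]).to_numpy()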
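
Taking the first rows of the other clients' transactions is the simplest possible choice of negatives. One alternative strategy, sketched below as an assumption rather than as part of SpFedRec, is to sample the negatives at random from the items the client has not interacted with, which also rules out duplicates and accidental positives. The function `sample_negative_items` and the toy catalogue are hypothetical.

```python
import math
import random


def sample_negative_items(all_item_ids, positive_item_ids, negative_to_positive_rate, seed=None):
    """Pick random items the client has NOT interacted with.

    The number of negatives is ceil(number of positives * rate), capped by
    the number of available candidates.
    """
    positives = set(positive_item_ids)
    # Deduplicate and exclude positives, so a negative is never a bought item.
    candidates = sorted(set(all_item_ids) - positives)
    wanted = math.ceil(len(positives) * negative_to_positive_rate)
    rng = random.Random(seed)
    return rng.sample(candidates, min(wanted, len(candidates)))


# Hypothetical item IDs, made up for illustration only.
toy_catalogue = [101, 102, 103, 104, 105, 106, 107, 108]
toy_positives = [101, 102, 103, 104]

toy_negatives = sample_negative_items(
    toy_catalogue, toy_positives, negative_to_positive_rate=0.25, seed=0
)
print(toy_negatives)  # one item drawn from 105..108
```

The seed is only there to make the sketch reproducible; whether random sampling hides the positives well enough is a separate question that this sketch does not answer.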

In [14]:
item_queries_of_clients = {}

for customer_id in customer_ids:
    item_queries_of_clients[customer_id] = create_client_item_query(
        client_transactions_df=client_transactions_dfs[customer_id],
        transactions_df=transactions_df,
        customer_id=customer_id,
        negative_to_positive_rate=minimum_negative_to_positive_item_per_client_rate,
    )

print(f"Number of queried items of the first client: {len(item_queries_of_clients[customer_ids[0]])}")
print(f"Queried items of the first client:\n{item_queries_of_clients[customer_ids[0]]}")

Number of queried items of the first client: 18
Queried items of the first client:
[663713001 541518023 578020002 723529001 351484002 727808001 727808007
 858883002 851400006 750424014 870304002 852643001 852643003 794321007
 505221004 685687003 685687004 685687001]


To train the models, labels are needed. Let's label an item with 1 if it's a positive item for the client, and with 0 otherwise!

In [15]:
def create_labels(every_item, positive_items):
    labels = {}
    for item in every_item:
        labels[item] = 1 if item in positive_items else 0
    return labels

In [16]:
labels_of_clients = {}

for customer_id in customer_ids:
    labels_of_clients[customer_id] = create_labels(
        every_item=item_queries_of_clients[customer_id], 
        positive_items=client_transactions_dfs[customer_id]["article_id"].to_numpy(),
    )

print(f"First client's labelled item IDs and values: {list(labels_of_clients[customer_ids[0]].items())}")

First client's labelled item IDs and values: [(663713001, 1), (541518023, 1), (578020002, 1), (723529001, 1), (351484002, 1), (727808001, 1), (727808007, 1), (858883002, 1), (851400006, 1), (750424014, 1), (870304002, 1), (852643001, 1), (852643003, 1), (794321007, 1), (505221004, 0), (685687003, 0), (685687004, 0), (685687001, 0)]


In this simplified case the model only uses the IDs of the clients and items. The ID fields are [sparse categorical variables](https://www.kaggle.com/code/colinmorris/embedding-layers), so it's a good idea to use embedding layers to transform these inputs. Before the IDs can be passed to PyTorch's embedding layer, they must be encoded as integers from a bounded range. A simple solution is to replace each ID with its index in the set of all distinct IDs; the upper bound is then the size of that set.

In [17]:
unique_article_ids = transactions_df["article_id"].unique()

def get_index_of_customer_id(customer_id):
    return np.where(customer_ids == customer_id)[0][0]

def get_index_of_article_id(article_id):
    return np.where(unique_article_ids == article_id)[0][0]
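
`np.where` scans the whole array on every call. For larger catalogues, a dictionary built once gives the same indices with constant-time lookups. The sketch below assumes the indices should follow the order of first appearance, which is also the order pandas' `unique()` returns; `build_index_lookup` and the toy IDs are hypothetical names used only here.

```python
def build_index_lookup(ids):
    """Map each distinct ID to a dense index, in order of first appearance."""
    lookup = {}
    for identifier in ids:
        if identifier not in lookup:
            lookup[identifier] = len(lookup)
    return lookup


# A few article IDs with one repeat, for illustration only.
toy_article_ids = [663713001, 541518023, 663713001, 578020002]
toy_article_index = build_index_lookup(toy_article_ids)

print(toy_article_index)
# {663713001: 0, 541518023: 1, 578020002: 2}
```

The largest index is always one less than the number of distinct IDs, which is exactly the bound that `nn.Embedding(num_embeddings=...)` needs.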

In the following step, the item embedding vectors are calculated for every item that appears in any client's item query.

In [18]:
item_embedding_lists = {}

for customer_id in customer_ids:
    for article_id in item_queries_of_clients[customer_id]:
        item_embedding_lists[article_id] = item_model(
            torch.from_numpy(np.asarray(get_index_of_article_id(article_id))).to(device).int()
        ).detach().cpu().numpy()  # .cpu() is required before .numpy() on GPU devices

print(f"Article ID: {list(item_embedding_lists.keys())[0]}, embedding vector: {item_embedding_lists[list(item_embedding_lists.keys())[0]]}")

Article ID: 663713001, embedding vector: [ 0.06250614  0.00994737 -0.04070795 -0.00223187 -0.02485958  0.04084307
  0.00524519 -0.12380616  0.07251057  0.12679553 -0.09844545 -0.02348811
  0.23133865 -0.04011521 -0.09609254 -0.06597807]


TODO add explanation to DataSet creation and move this to logical place

In [19]:
class ClientTrainingDataset(Dataset):

    def __init__(self, item_query_of_client, item_embeddings, labels_of_client):
        data = []
        labels = []

        for item_id in item_query_of_client:
            data.append(item_embeddings[item_id])
            labels.append(labels_of_client[item_id])
        
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx]), torch.tensor(self.labels[idx])

Now we can define the training loop on the client side.

In [20]:
def query_tower_train_loop(
    client_id,
    dataloader,
    model,
    cosine_similarity_fn, 
    loss_fn,
    log_softmax_fn, 
    optimizer,
):
    model.train()
    # The query tower's input is the index of the client that owns this model.
    client_index = torch.tensor(get_index_of_customer_id(client_id)).to(device)
    for X, y in dataloader:
        # Move the batch to the same device as the model.
        X, y = X.to(device), y.to(device)
        query_embedding = model(client_index)
        expanded_query_embedding = query_embedding.unsqueeze(0)
        query_embedding_in_batch = expanded_query_embedding.repeat(len(X), 1)
        similarity = cosine_similarity_fn(X, query_embedding_in_batch)
        loss = loss_fn(log_softmax_fn(similarity), y)
        loss.backward()
        # TODO retrieve gradients
        optimizer.step()
        optimizer.zero_grad()

    # Only the loss of the last batch is reported.
    print(f"Loss of client {client_id}: {loss.item()}")
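
The `# TODO retrieve gradients` step needs gradients with respect to the item embeddings, because those are what the server would use to update the item tower. One way to obtain them in PyTorch, shown here as a minimal sketch with toy tensors and a stand-in loss (not the notebook's loss and not necessarily how SpFedRec does it), is to mark the received embeddings as requiring gradients and read their `.grad` attribute after `backward()`.

```python
import torch

# Toy stand-ins (hypothetical values): a batch of 4 item embeddings received
# from the server and one query embedding produced by the client's tower.
item_embeddings = torch.randn(4, 16, requires_grad=True)
query_embedding = torch.randn(16)

similarity = torch.nn.functional.cosine_similarity(
    item_embeddings, query_embedding.unsqueeze(0).expand(4, -1), dim=1
)

# Any scalar loss works for the illustration; here the mean similarity.
loss = similarity.mean()
loss.backward()

# One gradient row per item embedding, same shape as the embeddings.
item_gradients = item_embeddings.grad
print(item_gradients.shape)  # torch.Size([4, 16])
```

In the training loop this would correspond to calling `X.requires_grad_(True)` before the similarity is computed and collecting `X.grad` after `loss.backward()`.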

Let's define the batch size hyperparameter of the training!

In [21]:
batch_size = 8
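
With `DataLoader`'s default settings the last, smaller batch is kept rather than dropped, so the number of samples does not have to be a multiple of the batch size. The helper below, a hypothetical function written only for this illustration, shows the resulting batch sizes for the 18 queried items of the first client.

```python
def batch_sizes(number_of_samples, batch_size):
    """Sizes of the batches a sequential loader yields when the last batch is kept."""
    full_batches, remainder = divmod(number_of_samples, batch_size)
    return [batch_size] * full_batches + ([remainder] if remainder else [])


print(batch_sizes(18, 8))  # [8, 8, 2]
```

Since the training loop reports only the loss of the last batch, for this client the printed value is computed on just 2 items.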

After these steps, the training loop can finally be executed for both clients.

In [22]:
def execute_client_training_loop():
    for customer_id in customer_ids:
        query_tower_train_loop(
            client_id=customer_id,
            dataloader=DataLoader(
                dataset=ClientTrainingDataset(
                    item_query_of_client=item_queries_of_clients[customer_id],
                    item_embeddings=item_embedding_lists,
                    labels_of_client=labels_of_clients[customer_id],
                ),
                batch_size=batch_size,
            ),
            model=query_models[customer_id],
            cosine_similarity_fn=cosine_similarity_fn,
            log_softmax_fn=log_softmax_fn,
            loss_fn=nll_loss_fn,
            optimizer=query_model_optimizers[customer_id],
        )

In [23]:
execute_client_training_loop()

Loss of client 000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318: 0.7234421968460083
Loss of client 00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2: 0.6931471824645996
Loss of client 000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318: 0.6965091228485107
Loss of client 00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2: 0.6931471824645996
Loss of client 000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318: 0.6720598936080933
Loss of client 00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2: 0.6931471824645996



### 2.2. Server side training

After the clients have calculated the gradients, they need to upload them to the server to update the item model. Because the gradients contain information about the user, the clients can't simply upload their gradients as they are. To solve this problem, the SpFedRec framework uses a circular secret-sharing technique.
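
The general idea behind secret sharing can be illustrated with plain additive shares. The sketch below is a generic textbook construction under stated assumptions, not the exact circular protocol of the SpFedRec article: every client splits each gradient value into random shares that sum to it modulo a prime, hands one share to every client, and only the per-client partial sums reach the server. `MODULUS`, `SCALE` and all function names are hypothetical, and a real system would draw randomness from the `secrets` module instead of a seeded generator.

```python
import random

MODULUS = 2**61 - 1  # assumed prime modulus for the share arithmetic
SCALE = 10**6        # assumed fixed-point scale for encoding floats


def encode(value):
    """Encode a float as a fixed-point residue modulo MODULUS."""
    return round(value * SCALE) % MODULUS


def decode(value):
    """Map a residue back to a float; large residues stand for negatives."""
    signed = value - MODULUS if value > MODULUS // 2 else value
    return signed / SCALE


def split_into_shares(secret, number_of_shares, rng):
    """Split an encoded integer into random shares that sum to it modulo MODULUS."""
    shares = [rng.randrange(MODULUS) for _ in range(number_of_shares - 1)]
    shares.append((secret - sum(shares)) % MODULUS)
    return shares


def aggregate_with_secret_sharing(client_gradients, rng):
    """Return the element-wise sum of the gradients without exposing any single one."""
    n = len(client_gradients)
    dim = len(client_gradients[0])
    # received[j][d] accumulates the shares client j holds for coordinate d.
    received = [[0] * dim for _ in range(n)]
    for gradient in client_gradients:
        for d, value in enumerate(gradient):
            for j, share in enumerate(split_into_shares(encode(value), n, rng)):
                received[j][d] = (received[j][d] + share) % MODULUS
    # Each client uploads only its partial sums; the server adds them up.
    totals = [sum(received[j][d] for j in range(n)) % MODULUS for d in range(dim)]
    return [decode(total) for total in totals]


# Hypothetical 2-dimensional gradients of three clients.
toy_gradients = [[0.5, -1.25], [0.25, 0.75], [-0.5, 1.0]]
aggregated = aggregate_with_secret_sharing(toy_gradients, random.Random(0))
print(aggregated)  # [0.25, 0.5]
```

Each individual share is a uniformly random residue, so a single share reveals nothing about the gradient it came from; only the sum over all clients is recoverable by the server.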

#### 2.2.1. Central client model training

#### 2.2.2. Central item model training

## 3. Inferring

TODO

## 4. Final thoughts

TODO