### Flask API Experiment (SERVER PART)
This code is intended to run on an AWS server, but it can also be executed locally. It is the server part of the federated learning simulation.

In [6]:
import torch
import pandas as pd
from sklearn.model_selection import train_test_split

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Hyper-parameters keyed by model variant.
#   mf_dim     - embedding size of the matrix-factorization branch
#   layers     - MLP layer sizes; half of layers[0] is the MLP embedding size
#   reg_layers - not read by the code visible in this file; presumably
#   reg_mf       regularisation strengths (TODO confirm against the client)
MODEL_PARAMETERS = dict(
    default_model=dict(
        mf_dim=8,
        layers=[64, 32, 16, 8],
        reg_layers=[0, 0, 0, 0],
        reg_mf=0,
    ),
    FedNCF=dict(
        mf_dim=16,
        layers=[64, 32, 16, 8],
        reg_layers=[0, 0, 0, 0],
        reg_mf=0,
    ),
)

# Number of random negative items drawn for every positive interaction.
NUM_NEGATIVES = 50
# Not referenced by the code visible here; presumably the training batch size.
BATCH_SIZE = 64


In [7]:
from flask import Flask, request
from flask import Flask, send_file
import torch
import pickle
import collections
import copy
import io


from torch.nn import functional as F
from typing import Optional, Tuple
from torch import Tensor
import torch

from typing import List, Tuple
import pandas as pd
import scipy as sp
import numpy as np
import sys
import os

# Minimum rating for a (user, item) training row to count as a positive
# interaction when the per-client training data is built.
THRESHOLD = 3



class Dataset:
    """Load the rating CSVs and build training / evaluation inputs from them.

    The train and test splits are always read from
    ``data/train_exp_200.csv`` and ``data/test_exp_200.csv``.  The
    ``dataset`` argument only selects the negatives file
    (``data/<dataset>-neg.csv``) and labels the startup message.
    """

    def __init__(self, dataset: str):
        """Read both CSV splits and derive the user/item matrix dimensions.

        Args:
            dataset: Dataset name; prefix of the negatives file and label
                used in the printed summary.
        """
        self.train_df = pd.read_csv(os.path.join('data', 'train_exp_200.csv'))
        self.test_df = pd.read_csv(os.path.join('data', 'test_exp_200.csv'))
        self.neg_path = os.path.join('data', dataset + '-neg.csv')
        self.num_users, self.num_items = self.get_matrix_dim()
        print(f'Loaded `{dataset}` dataset: \nNumber of users - {self.num_users}, Number of items - {self.num_items}')

    def get_matrix_dim(self) -> Tuple[int, int]:
        """Return ``(num_users, num_items)`` as the largest training id + 1.

        Only ``train_df`` is consulted.  The values are converted to plain
        ``int`` so callers receive Python integers rather than NumPy scalars.
        """
        num_users = int(self.train_df['user_id'].max()) + 1
        num_items = int(self.train_df['item_id'].max()) + 1
        return num_users, num_items

    def load_client_train_data(self) -> List[List]:
        """Build per-user training samples with random negative sampling.

        Every training row whose rating is at least ``THRESHOLD`` becomes a
        positive sample (label 1).  Each positive is followed by
        ``NUM_NEGATIVES`` items drawn uniformly with ``np.random.randint``
        that the same user has no positive for (label 0).

        Returns:
            A list indexed by user id.  Each entry is
            ``[user_ids, item_ids, labels]``, three parallel lists.
        """
        # A dict keeps first-seen order and removes duplicate pairs, so the
        # positives are visited in training-file order.  The explicit int
        # casts guard against ``DataFrame.values`` yielding floats when the
        # columns have mixed dtypes.
        positives = {}
        for user, item, rating in self.train_df.values:
            if rating >= THRESHOLD:
                positives[(int(user), int(item))] = True

        client_datas = [[[], [], []] for _ in range(self.num_users)]

        for user, item in positives:
            users, items, labels = client_datas[user]
            users.append(user)
            items.append(item)
            labels.append(1)
            for _ in range(NUM_NEGATIVES):
                # Rejection sampling: redraw until the item is not a positive.
                neg = np.random.randint(self.num_items)
                while (user, neg) in positives:
                    neg = np.random.randint(self.num_items)
                users.append(user)
                items.append(neg)
                labels.append(0)

        return client_datas

    def load_test_file(self) -> List[List[int]]:
        """Return ``[user, item]`` pairs from the test split, dropping the third column."""
        return [[user, item] for user, item, _ in self.test_df.values]

    def load_negative_file(self) -> List[List[int]]:
        """Parse the tab-separated negatives file at ``self.neg_path``.

        The first field of every line is skipped; the remaining fields are
        parsed as integers.

        Returns:
            One list of integers per line of the file.
        """
        with open(self.neg_path, "r") as f:
            return [[int(x) for x in line.split("\t")[1:]] for line in f]
    
class NeuralCollaborativeFiltering(torch.nn.Module):
    """Two-branch recommender: a matrix-factorization branch plus an MLP branch.

    Both branches have their own user and item embeddings.  Their outputs
    are concatenated and passed through a single linear output layer.
    Sizes come from ``MODEL_PARAMETERS['FedNCF']``.
    """

    def __init__(self, num_users: int, num_items: int):
        """Create the embeddings, the MLP stack and the output layer on ``DEVICE``.

        Args:
            num_users: Number of rows in the user embedding tables.
            num_items: Number of rows in the item embedding tables.
        """
        super().__init__()
        params = MODEL_PARAMETERS['FedNCF']
        layers = params['layers']
        mf_dim = params['mf_dim']
        mlp_dim = layers[0] // 2

        self.mf_embedding_user = torch.nn.Embedding(num_embeddings=num_users, embedding_dim=mf_dim, device=DEVICE)
        self.mf_embedding_item = torch.nn.Embedding(num_embeddings=num_items, embedding_dim=mf_dim, device=DEVICE)

        self.mlp_embedding_user = torch.nn.Embedding(num_embeddings=num_users, embedding_dim=mlp_dim, device=DEVICE)
        self.mlp_embedding_item = torch.nn.Embedding(num_embeddings=num_items, embedding_dim=mlp_dim, device=DEVICE)

        # The MLP input is the concatenation of the two MLP embeddings.  The
        # widths are derived from the config instead of being hard-coded, so
        # the model stays consistent if MODEL_PARAMETERS changes.
        self.mlp = torch.nn.ModuleList()
        current_dim = 2 * mlp_dim
        for width in layers[1:]:
            # Linear and ReLU alternate so the state-dict keys stay
            # mlp.0, mlp.2, ...  The layers are created on DEVICE like the
            # embeddings.
            self.mlp.append(torch.nn.Linear(current_dim, width, device=DEVICE))
            current_dim = width
            self.mlp.append(torch.nn.ReLU())

        # The output layer sees the MF vector concatenated with the last MLP
        # activation.
        self.output_layer = torch.nn.Linear(in_features=mf_dim + current_dim, out_features=1, device=DEVICE)

    def forward(self, user_input: Tensor,
                item_input: Tensor,
                target: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        """Score user/item pairs and optionally compute the training loss.

        Args:
            user_input: Integer user ids (assumed one id per sample; TODO
                confirm the exact shape against the client code).
            item_input: Integer item ids, aligned with ``user_input``.
            target: Optional labels.  When given they are reshaped to
                ``(batch, 1)``, cast to float32 and used for the loss.

        Returns:
            ``(scores, loss)``.  ``scores`` is the sigmoid of the output
            layer.  ``loss`` is the binary cross-entropy computed on the raw
            logits, or ``None`` when ``target`` is not given.
        """
        # matrix factorization: element-wise product of user and item vectors
        mf_user_latent = torch.flatten(self.mf_embedding_user(user_input), start_dim=1)
        mf_item_latent = torch.flatten(self.mf_embedding_item(item_input), start_dim=1)
        mf_vector = torch.mul(mf_user_latent, mf_item_latent)

        # mlp: concatenated user and item vectors pushed through the stack
        mlp_user_latent = torch.flatten(self.mlp_embedding_user(user_input), start_dim=1)
        mlp_item_latent = torch.flatten(self.mlp_embedding_item(item_input), start_dim=1)
        mlp_vector = torch.cat([mlp_user_latent, mlp_item_latent], dim=1)
        for layer in self.mlp:
            mlp_vector = layer(mlp_vector)

        predict_vector = torch.cat([mf_vector, mlp_vector], dim=1)
        logits = self.output_layer(predict_vector)

        loss = None
        if target is not None:
            target = target.view(target.shape[0], 1).to(torch.float32)
            # The loss is taken on the raw logits, before the sigmoid below.
            loss = F.binary_cross_entropy_with_logits(logits, target)

        return torch.sigmoid(logits), loss


# Flask application that exposes the federated-learning endpoints below.
app = Flask(__name__)


# Module-level state shared by every request handler.
# Constructing the dataset reads the CSV files under data/ and prints a summary;
# its user/item counts size the embedding tables of the global model.
dataset = Dataset("yelp")
server_model = NeuralCollaborativeFiltering(dataset.num_users, dataset.num_items)
server_model.to(DEVICE)

# Client state dicts received during the current round; emptied by the /train
# handler once a round has been averaged.
client_weights_list = []

def federated_averaging(client_weights: List[collections.OrderedDict]) -> collections.OrderedDict:
    """Return the element-wise mean of the clients' state dicts.

    Args:
        client_weights: One state dict per client.  Every dict must contain
            the keys of the first one, with values that support ``+`` and
            ``/`` (tensors or plain numbers).

    Returns:
        A new ordered dict with the keys of the first client, each value
        being the unweighted mean across all clients.  The inputs are not
        modified.

    Raises:
        ValueError: If ``client_weights`` is empty.
    """
    if not client_weights:
        raise ValueError("federated_averaging needs at least one client state dict")

    print("----------Performe fed avg----------")
    num_clients = len(client_weights)
    first, *others = client_weights
    # Deep copy so the accumulation below never touches the first client's values.
    averages = copy.deepcopy(first)

    for key in averages:
        for weights in others:
            averages[key] += weights[key]
        # Out-of-place division: in-place true division raises on
        # integer-dtype tensors, because the float result cannot be stored.
        averages[key] = averages[key] / num_clients
    return averages


import threading

# Number of client uploads that make up one federated round.
CLIENTS_PER_ROUND = 200

# Flask's development server handles requests in threads by default, so the
# append / length check / clear sequence on the shared list must be atomic.
_round_lock = threading.Lock()


@app.route('/train', methods=['POST'])
def receive_weights_and_train():
    """Accept one client's pickled weights and average once the round is full.

    Expects a multipart upload with a ``file`` part containing a pickled
    state dict.  Each upload is queued.  When ``CLIENTS_PER_ROUND`` uploads
    have arrived they are averaged, loaded into ``server_model``, and the
    queue is emptied for the next round.

    Returns:
        A ``(message, status)`` tuple: 400 when the file part is missing or
        unnamed, 200 otherwise.
    """
    if 'file' not in request.files:
        return 'No file part', 400

    file = request.files['file']
    if file.filename == '':
        return 'No selected file', 400
    if not file:
        return 'Error in processing request', 500

    # SECURITY: pickle.loads executes arbitrary code embedded in the payload.
    # Only expose this endpoint to trusted clients, or move to a safe format
    # (this needs a matching change on the client side).
    received_weights = pickle.loads(file.read())

    with _round_lock:
        client_weights_list.append(received_weights)

        if len(client_weights_list) < CLIENTS_PER_ROUND:
            return 'Weights received and waiting for more clients', 200

        try:
            averaged_weights = federated_averaging(client_weights_list)
            server_model.load_state_dict(averaged_weights)
        finally:
            # Always start the next round empty; a failed round must not
            # leave stale uploads behind and block every later round.
            client_weights_list.clear()

    return 'Weights averaged and model updated', 200

@app.route('/get_weights', methods=['GET'])
def get_weights():
    """
    Serve the current global model weights as a downloadable file.

    The server model's state dict is serialized with ``torch.save`` into
    memory and returned as the attachment ``model_weights.pth``.
    """
    payload = io.BytesIO()
    state = server_model.state_dict()
    torch.save(state, payload)

    # Rewind so the response streams the serialized bytes from the start.
    payload.seek(0)

    return send_file(
        payload,
        mimetype='application/octet-stream',
        as_attachment=True,
        download_name="model_weights.pth",
    )

# Start Flask's built-in server when this file is executed directly.
# NOTE(review): "0.0.0.0" listens on every network interface, which exposes the
# pickle-based /train endpoint to the whole network - confirm this is intended.
if __name__ == '__main__':
    app.run("0.0.0.0", port=8080)


Loaded `yelp` dataset: 
Number of users - 200, Number of items - 3082
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:8080
 * Running on http://192.168.0.36:8080
[33mPress CTRL+C to quit[0m
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:39] "POST /train HTTP/1.1" 200 -

----------Performe fed avg----------


127.0.0.1 - - [28/Jan/2024 14:23:54] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:54] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:54] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:54] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:54] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:55] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:55] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:55] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:55] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:55] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:55] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:55] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:55] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:55] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:23:55] "POST /train HTTP/1.1" 200 -
127.0.0.1 

----------Performe fed avg----------


127.0.0.1 - - [28/Jan/2024 14:24:09] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:09] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:09] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:09] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:09] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:09] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:10] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:10] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:10] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:10] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:10] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:10] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:10] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:10] "POST /train HTTP/1.1" 200 -
127.0.0.1 - - [28/Jan/2024 14:24:10] "POST /train HTTP/1.1" 200 -
127.0.0.1 