In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

import syft as sy
import numpy as np
from pathlib import Path
import pandas as pd
from tqdm.notebook import tqdm
# Seed NumPy's global RNG; the dataset's negative sampling draws from
# np.random.choice, so this fixes which negative products get picked.
# NOTE(review): torch's RNG is not seeded in the visible cells, so weight
# initialisation and DataLoader shuffling are presumably not reproducible -- confirm.
np.random.seed(666)
from Distributed_HM_Data import Distributed_HM, binary_acc

# Input data is read from <current working directory>/../Data.
dataDir = Path.cwd().parent/'Data/'

# The model trains on CPU because PySyft 0.2.9 has bugs with CUDA.
# The CUDA device selection below is therefore left disabled.
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(device)

In [None]:
import sys
import logging

# Preserve the training log: open config2.log for writing (third argument is
# the buffer size, 10).
# NOTE(review): the file handle is never closed in the visible cells.
so = open("config2.log", 'w', 10)
# NOTE(review): presumably relies on the kernel's stdout/stderr stream objects
# honouring an `echo` attribute to mirror output into the file -- confirm
# against the ipykernel version in use.
sys.stdout.echo = so
sys.stderr.echo = so

# Point the first handler of the IPython logger at the same file and let
# INFO-level (and higher) records through.
get_ipython().log.handlers[0].stream = so
get_ipython().log.setLevel(logging.INFO)

In [None]:
class HMSaleTrainDataLoader(Dataset):
    """Training set of HM sales data with negative sampling.

    Every distinct transaction row becomes one positive sample (label 1) and is
    followed by ``negative_samples`` negative samples (label 0). A negative
    sample copies every feature of its positive sample except the product id,
    which is replaced by a randomly drawn product the customer never bought.

    Args:
        transactions (pd.DataFrame): Dataframe of transaction records. It must
            provide the columns listed in ``FEATURE_COLUMNS``; their values are
            converted to tensors, so they have to be numeric.
        all_products_id (list): A list contains all product ids; negatives are
            drawn from it with ``np.random.choice``.
        negative_samples (int): Number of negatives generated per positive.
            Defaults to 4, i.e. a negative:positive ratio of 4:1.

    Raises:
        ValueError: If negatives are requested for a customer who has already
            bought every product in ``all_products_id``.
    """

    # Column order defines the order of the tensors returned by __getitem__
    # (the label is appended last).
    FEATURE_COLUMNS = (
        "customer_id",
        "article_id",
        "price",
        "sales_channel_id",
        "club_member_status",
        "age",
        "product_group_name",
        "colour_group_name",
        "index_name",
    )

    def __init__(self, transactions, all_products_id, negative_samples=4):
        (
            self.customers, self.products, self.prices, self.sales_channels,
            self.club_status, self.age_groups, self.product_groups, self.color_groups,
            self.index_name, self.labels,
        ) = self.get_dataset(transactions, all_products_id, negative_samples)

    def __len__(self):
        return len(self.customers)

    def __getitem__(self, idx):
        return (
            self.customers[idx], self.products[idx], self.prices[idx],
            self.sales_channels[idx], self.club_status[idx], self.age_groups[idx],
            self.product_groups[idx], self.color_groups[idx], self.index_name[idx],
            self.labels[idx],
        )

    def get_dataset(self, transactions, all_products_id, negative_samples=4):
        """Build the positive and negative samples.

        Returns:
            tuple: Ten tensors -- one per entry of ``FEATURE_COLUMNS`` (in that
            order) followed by the 0/1 labels.
        """
        # De-duplicate the transaction rows.
        records = set(zip(*(transactions[column] for column in self.FEATURE_COLUMNS)))

        # Products each customer actually bought. Rejection must be decided on
        # the (customer, product) pair alone: comparing the full feature tuple
        # would accept a product the customer bought at another price/channel
        # and label it as a negative.
        purchased = {}
        for record in records:
            purchased.setdefault(record[0], set()).add(record[1])
        catalogue = set(all_products_id)

        features = [[] for _ in self.FEATURE_COLUMNS]
        labels = []

        def emit(record, label):
            for column, value in zip(features, record):
                column.append(value)
            labels.append(label)

        for record in tqdm(records):
            emit(record, 1)

            customer = record[0]
            bought = purchased[customer]
            # Without this guard the resampling loop below would never end.
            if negative_samples > 0 and len(bought) >= len(catalogue) and catalogue <= bought:
                raise ValueError(
                    f"cannot draw negative samples for customer {customer!r}: "
                    "every product in all_products_id has been purchased"
                )

            for _ in range(negative_samples):
                candidate = np.random.choice(all_products_id)
                while candidate in bought:
                    candidate = np.random.choice(all_products_id)
                # Same context features as the positive, different product.
                emit((customer, candidate) + record[2:], 0)

        return tuple(torch.tensor(column) for column in features) + (torch.tensor(labels),)


In [None]:
class SalesNN(nn.Module):
    """Partial model for sales domain.

    Embeds the user and item ids, concatenates both embeddings with the price
    and sales-channel features along the last dimension, and passes the result
    through a stack of ``Linear`` layers, each followed by ``LeakyReLU``
    (including the last one).

    Args:
        num_users (int): Number of users (rows of the user embedding table).
        num_items (int): Number of products (rows of the item embedding table).
        input_size (int): Width of the concatenated encoder input. The default
            50 equals 16 + 32 + 1 + 1, i.e. it assumes ``prices`` and
            ``sales_channels`` each contribute one feature column.
        user_embedding_dim (int): Size of a user embedding vector.
        item_embedding_dim (int): Size of an item embedding vector.
        hidden_size_1 (int): Width of the hidden layer.
        output_size (int): Width of the latent vector returned by ``forward``.

    Raises:
        ValueError: If ``num_users`` or ``num_items`` is ``None`` or zero.
    """
    def __init__(
            self,
            num_users: int,
            num_items: int,
            input_size: int = 50,
            user_embedding_dim: int = 16,
            item_embedding_dim: int = 32,
            hidden_size_1: int = 128,
            output_size: int = 32,
        ):
        super().__init__()
        # Validate before the embedding tables are built, otherwise a missing
        # size surfaces as an unrelated error from nn.Embedding.
        if not num_users or not num_items:
            raise ValueError(
                "num_users and num_items must both be positive integers, "
                f"got num_users={num_users!r}, num_items={num_items!r}"
            )

        self.user_embedding_layer = nn.Embedding(num_embeddings=num_users, embedding_dim=user_embedding_dim)
        self.item_embedding_layer = nn.Embedding(num_embeddings=num_items, embedding_dim=item_embedding_dim)
        self.relu = nn.LeakyReLU()

        layer_sizes = [input_size, hidden_size_1, output_size]
        self.encoder = nn.Sequential(
            *[
                nn.Linear(in_features=n_in, out_features=n_out)
                for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
            ]
        )

    def forward(self, user_input, item_input, prices, sales_channels):
        """Return the sales-domain latent vector for a batch."""
        user_embedding = self.user_embedding_layer(user_input)
        item_embedding = self.item_embedding_layer(item_input)

        latent_vec = torch.cat([user_embedding, item_embedding, prices, sales_channels], dim=-1)
        for layer in self.encoder:
            latent_vec = self.relu(layer(latent_vec))

        return latent_vec

    # save weights of partial model on remote worker
    def get_weights(self):
        return self.state_dict()
    
class CustomersNN(nn.Module):
    """Partial model for customer domain.

    Concatenates the two customer features along the last dimension and runs
    them through the encoder; every ``Linear`` layer is followed by
    ``LeakyReLU``.

    Args:
        input_size (int): Width of the concatenated input (default 2).
        output_size (int): Width of the latent vector returned by ``forward``.
    """
    def __init__(
            self,
            input_size: int = 2,
            output_size: int = 5,
        ):
        super().__init__()
        self.relu = nn.LeakyReLU()

        widths = [input_size, output_size]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(in_features=fan_in, out_features=fan_out))
        self.encoder = nn.Sequential(*stages)

    def forward(self, club_status, age_groups):
        """Return the customer-domain latent vector.

        Args:
            club_status: club-membership status feature of the batch.
            age_groups: age feature of the batch.
        """
        hidden = torch.cat([club_status, age_groups], dim=-1)
        for stage in self.encoder:
            hidden = self.relu(stage(hidden))
        return hidden

    # save weights of partial model on remote worker
    def get_weights(self):
        weights = self.state_dict()
        return weights

class ProductsNN(nn.Module):
    """Partial model for product domain.

    Embeds the product group, colour group and index name ids, concatenates
    the three embeddings along the last dimension, and passes the result
    through a stack of ``Linear`` layers, each followed by ``LeakyReLU``
    (including the last one).

    Args:
        num_product_groups (int): Number of product groups.
        num_color_groups (int): Number of color groups.
        num_index_name (int): Number of index names.
        product_group_embedding_dim (int): Size of a product-group embedding.
        color_group_embedding_dim (int): Size of a colour-group embedding.
        index_name_embedding_dim (int): Size of an index-name embedding.
        input_size (int): Width of the concatenated encoder input. The default
            30 equals the sum of the default embedding sizes (8 + 16 + 6).
        hidden_size_1 (int): Width of the hidden layer.
        output_size (int): Width of the latent vector returned by ``forward``.

    Raises:
        ValueError: If any of ``num_product_groups``, ``num_color_groups`` or
            ``num_index_name`` is ``None`` or zero.
    """
    def __init__(
            self,
            num_product_groups: int,
            num_color_groups: int,
            num_index_name: int,
            product_group_embedding_dim: int = 8,
            color_group_embedding_dim: int = 16,
            index_name_embedding_dim: int = 6,
            input_size: int = 30,
            hidden_size_1: int = 64,
            output_size: int = 16,
        ):
        super().__init__()
        # Validate before the embedding tables are built, otherwise a missing
        # size surfaces as an unrelated error from nn.Embedding.
        if not num_product_groups or not num_color_groups or not num_index_name:
            raise ValueError(
                "num_product_groups, num_color_groups and num_index_name must all be "
                f"positive integers, got {num_product_groups!r}, {num_color_groups!r}, "
                f"{num_index_name!r}"
            )

        self.product_group_embedding_layer = nn.Embedding(num_embeddings=num_product_groups, embedding_dim=product_group_embedding_dim)
        self.color_group_embedding_layer = nn.Embedding(num_embeddings=num_color_groups, embedding_dim=color_group_embedding_dim)
        self.index_name_embedding_layer = nn.Embedding(num_embeddings=num_index_name, embedding_dim=index_name_embedding_dim)
        self.relu = nn.LeakyReLU()

        layer_sizes = [input_size, hidden_size_1, output_size]
        self.encoder = nn.Sequential(
            *[
                nn.Linear(in_features=n_in, out_features=n_out)
                for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
            ]
        )

    def forward(self, product_groups, color_groups, index_name):
        """Return the product-domain latent vector for a batch."""
        product_group_embedding = self.product_group_embedding_layer(product_groups)
        color_group_embedding = self.color_group_embedding_layer(color_groups)
        index_name_embedding = self.index_name_embedding_layer(index_name)

        latent_vec = torch.cat([product_group_embedding, color_group_embedding, index_name_embedding], dim=-1)

        for layer in self.encoder:
            latent_vec = self.relu(layer(latent_vec))

        return latent_vec

    # save weights of partial model on remote worker
    def get_weights(self):
        return self.state_dict()

class GovernanceNN(nn.Module):
    """Partial model for governance (server) side.

    Decodes the aggregated client latent vectors with a stack of ``Linear``
    layers. ``LeakyReLU`` is applied after every layer, the output layer
    included.

    Args:
        input_size (int): Width of the aggregated latent input. The default 53
            equals the summed default output widths of the three client
            models (32 + 5 + 16).
        hidden_size_1 (int): Width of the first hidden layer.
        hidden_size_2 (int): Width of the second hidden layer.
        output_size (int): Number of output units (default 2).
    """
    def __init__(
            self,
            input_size: int = 53,
            hidden_size_1: int = 128,
            hidden_size_2: int = 64,
            output_size: int = 2,
        ):
        super().__init__()
        self.relu = nn.LeakyReLU()

        widths = [input_size, hidden_size_1, hidden_size_2, output_size]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(in_features=fan_in, out_features=fan_out))
        self.decoder = nn.Sequential(*stages)

    def forward(self, agg_latent_input):
        """Return the prediction for the aggregated latent vectors.

        Args:
            agg_latent_input: aggregated input of latent vectors from client models.
        """
        out = agg_latent_input
        for stage in self.decoder:
            out = self.relu(stage(out))
        return out

    # save weights of partial model on remote worker
    def get_weights(self):
        weights = self.state_dict()
        return weights


class SplitNN(nn.Module):
    """Coordinates the client partial models and the server model of a split network.

    Args:
        models (dict): Maps a worker id to the model living there. ``forward``
            looks up the ids of the data owners plus the key ``"server"``.
        optimizers (list): Optimizers stepped together by ``zero_grads``/``step``.
        data_owner (iterable): Workers that own input features; each must
            expose an ``id`` attribute.
        server: Worker the client outputs are moved to before they are
            concatenated and fed to the server model.
    """

    # How many leading entries of ``data_pointer[owner.id]`` each client model
    # takes as positional inputs. Owners with any other id are skipped.
    _CLIENT_INPUT_COUNTS = {
        "sales_domain": 4,
        "customer_domain": 2,
        "product_domain": 3,
    }

    def __init__(self, models, optimizers, data_owner, server):
        # Initialise nn.Module first so attribute assignment goes through a
        # fully set-up module.
        super().__init__()
        self.models = models
        self.optimizers = optimizers
        self.data_owners = data_owner
        self.server = server

    def forward(self, data_pointer):
        """Run every client model, move the outputs to the server and predict.

        Args:
            data_pointer (dict): Maps an owner id to the indexable collection
                of inputs held by that owner.

        Returns:
            The output of ``models["server"]`` on the concatenated client outputs.
        """
        # Outputs that were moved to the server, in data-owner order.
        remote_output = []

        # Use the instance's own owners/models/server rather than notebook
        # globals, so the object works regardless of what the enclosing
        # namespace happens to define.
        for owner in self.data_owners:
            input_count = self._CLIENT_INPUT_COUNTS.get(owner.id)
            if input_count is None:
                continue
            owner_inputs = data_pointer[owner.id]
            # Client output up to its cut layer.
            client_output = self.models[owner.id](
                *(owner_inputs[position] for position in range(input_count))
            )
            remote_output.append(client_output.move(self.server, requires_grad=True))

        # concat outputs from clients and send to server side
        server_input = torch.cat(remote_output, dim=-1)
        # make prediction on server model
        return self.models["server"](server_input)

    def zero_grads(self):
        """Clear the gradients held by every optimizer."""
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        """Apply one update with every optimizer."""
        for opt in self.optimizers:
            opt.step()

    def train(self, mode=True):
        """Put every partial model into training (or, with ``mode=False``, evaluation) mode.

        Accepts ``mode`` and returns ``self`` to stay compatible with
        ``nn.Module.train``; calling it without arguments behaves as before.
        """
        for model in self.models.values():
            if mode:
                model.train()
            else:
                model.eval()
        self.training = mode
        return self

    def eval(self):
        """Put every partial model into evaluation mode and return ``self``."""
        for model in self.models.values():
            model.eval()
        self.training = False
        return self

    def load_weights(self, file_prefix):
        """Load each model's state dict from ``{file_prefix}_{id}_weights.pth``."""
        for loc in self.models.keys():
            self.models[loc].load_state_dict(torch.load(f"{file_prefix}_{loc}_weights.pth"))

    @property
    def location(self):
        """Location of the first registered model, or ``None`` when there are no models."""
        if not self.models:
            return None
        # ``models`` is keyed by worker id, so index 0 is not a valid lookup;
        # take the first model in insertion order instead.
        first_model = next(iter(self.models.values()))
        return first_model.location
    
def train(x, label, splitNN):
    """Run one optimisation step of the split network on a single batch.

    Args:
        x: batch input handed unchanged to ``splitNN.forward``.
        label: target classes for the batch.
        splitNN: object providing ``zero_grads``, ``forward`` and ``step``.

    Returns:
        The detached loss after calling ``.get()`` on it.
    """
    # Start from clean gradients on every optimizer.
    splitNN.zero_grads()

    # Forward pass, then measure how far the prediction is from the labels.
    prediction = splitNN.forward(x)
    batch_loss = nn.CrossEntropyLoss()(prediction, label)

    # Backpropagate from the loss, then let the optimizers update the weights.
    batch_loss.backward()
    splitNN.step()

    # NOTE(review): `.get()` presumably fetches the loss from the remote
    # worker (PySyft pointer) -- confirm.
    detached_loss = batch_loss.detach()
    return detached_loss.get()
 

In [None]:
# Load the transactions and build the negative-sampled training set.
hm_data = pd.read_csv(dataDir/'medium_train.csv')
all_products_id = hm_data["article_id"].unique()
train_data = HMSaleTrainDataLoader(hm_data, all_products_id)
train_loader = DataLoader(train_data, batch_size=1024, shuffle=True)

# set up virtual worker
# The hook is created first and handed to every VirtualWorker below.
hook = sy.TorchHook(torch)
sales_domain = sy.VirtualWorker(hook, id="sales_domain")
customer_domain = sy.VirtualWorker(hook, id="customer_domain")
product_domain = sy.VirtualWorker(hook, id="product_domain")
server = sy.VirtualWorker(hook, id="server")
# Worker the labels are sent to inside the training loop.
label_owner = sy.VirtualWorker(hook, id="label_owner")

# Workers that hold input features vs. workers that hold a partial model.
data_owners = (sales_domain, customer_domain, product_domain)
model_locations = [sales_domain, customer_domain, product_domain, server]

# NOTE(review): presumably splits each batch of train_loader across the data
# owners; Distributed_HM is defined in Distributed_HM_Data -- confirm.
distributed_trainloader = Distributed_HM(data_owners=data_owners, data_loader=train_loader)

# set up parameters for model
# Each embedding table is sized by the number of distinct values in its column.
# NOTE(review): this assumes the ids are encoded as 0..n-1; a larger raw id
# would index past the end of the embedding table -- confirm.
num_users = len(hm_data.customer_id.unique())
print("num_users:", num_users)
num_items = len(all_products_id)
print("num_items:", num_items)
num_product_groups = len(hm_data.product_group_name.unique())
print("num_product_groups:", num_product_groups)
num_color_groups = len(hm_data.colour_group_name.unique())
print("num_color_groups:", num_color_groups)
# Unlike the other sizes, num_index_name is not printed.
num_index_name = len(hm_data.index_name.unique())

# One partial model per worker id; the keys must match the worker ids because
# the models are looked up by `location.id` below.
models = {
    "sales_domain": SalesNN(num_users=num_users, num_items=num_items),
    "customer_domain": CustomersNN(),
    "product_domain": ProductsNN(num_product_groups=num_product_groups, num_color_groups=num_color_groups, num_index_name=num_index_name),
    "server": GovernanceNN(),
}

# set up optimizer for clients' model
# One Adam optimizer per model (server included), built before the models are sent.
optimizers = [
    optim.Adam(models[location.id].parameters(), lr=0.003)
    for location in model_locations
]

# Send every model to the worker whose id it is registered under.
for location in model_locations:
    models[location.id].send(location)

In [None]:
print(models)

epochs = 150
# Anomaly detection is switched on for the whole run.
torch.autograd.set_detect_anomaly(True)
splitnn = SplitNN(models, optimizers, data_owners, server)

for i in range(epochs):
    splitnn.train()
    running_loss = 0.0

    # One optimisation step per distributed batch; the labels are sent to the
    # label owner before the step.
    for data_ptr, labels in distributed_trainloader:
        running_loss += train(data_ptr, labels.send(label_owner), splitnn)

    # Report the mean batch loss once the epoch has finished.
    mean_loss = running_loss/len(distributed_trainloader)
    print("Epoch {} - Training loss: {}".format(i, mean_loss))
        