In [None]:
import sys
from pathlib import Path
# Add the parent of the current working directory to sys.path (only once).
# Presumably this is what makes the local `Distributed_CreditCard_Data`
# import below resolvable -- confirm the module lives in that directory.
parent_dir = str(Path.cwd().parent)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

import syft as sy
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
# Seeds NumPy's global RNG only; torch's RNG is not seeded in this cell.
np.random.seed(666)
# Must stay after the sys.path tweak above (see note there).
from Distributed_CreditCard_Data import CreditCardDataLoader, Distributed_CreditCard, binary_acc

# `CreditCard/` directory two levels above the current working directory;
# the training .npy files are loaded from here later in the notebook.
dataDir = Path.cwd().parent.parent/'CreditCard/'

In [None]:
import sys
import logging

# Preserve the training log in a file.
# Opens (and truncates) config1.log for writing; the third positional argument
# is open()'s `buffering` parameter. The handle is never closed in this cell.
so = open("config1.log", 'w', 10)
# NOTE(review): `echo` is presumably the attribute ipykernel's output streams
# use to mirror what is written to them -- confirm for the kernel version in
# use; on a plain file object this assignment would have no such effect.
sys.stdout.echo = so
sys.stderr.echo = so

# Point the first handler of IPython's logger at the same file and log at INFO.
# Assumes that logger has at least one handler exposing `.stream` -- TODO confirm.
get_ipython().log.handlers[0].stream = so
get_ipython().log.setLevel(logging.INFO)

In [4]:
class DigitalNN(nn.Module):
    """Partial model for the digital transaction domain.

    Layer stack: Linear -> ReLU -> BatchNorm1d -> Dropout(0.3) -> Linear -> ReLU.

    Args:
        input_size (int): number of features in the digital transaction domain
        hidden_size (int): width of the hidden layer (default 32)
        output_size (int): width of the tensor returned by ``forward`` (default 16)
    """

    def __init__(self, input_size: int, hidden_size: int = 32, output_size: int = 16):
        super().__init__()

        # Modules are created in execution order and wrapped in a Sequential
        # named `layers_stack`, so state-dict keys are `layers_stack.<index>.*`.
        stages = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, output_size),
            nn.ReLU(),
        ]
        self.layers_stack = nn.Sequential(*stages)

    def forward(self, digital_input):
        """Run ``digital_input`` through every layer of the stack, in order."""
        return self.layers_stack(digital_input)

    def get_weights(self):
        """Return this partial model's ``state_dict()``."""
        return self.state_dict()

class RetailNN(nn.Module):
    """Partial model for the retail transaction domain.

    Layer stack: Linear -> ReLU -> BatchNorm1d -> Dropout(0.3) -> Linear -> ReLU.

    Args:
        input_size (int): number of features in the retail transaction domain
        hidden_size (int): width of the hidden layer (default 32)
        output_size (int): width of the tensor returned by ``forward`` (default 16)
    """

    def __init__(self, input_size: int, hidden_size: int = 32, output_size: int = 16):
        super().__init__()

        # Modules are created in execution order and wrapped in a Sequential
        # named `layers_stack`, so state-dict keys are `layers_stack.<index>.*`.
        stages = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, output_size),
            nn.ReLU(),
        ]
        self.layers_stack = nn.Sequential(*stages)

    def forward(self, retail_input):
        """Run ``retail_input`` through every layer of the stack, in order."""
        return self.layers_stack(retail_input)

    def get_weights(self):
        """Return this partial model's ``state_dict()``."""
        return self.state_dict()
    
class FraudPrevNN(nn.Module):
    """Partial model for the fraud prevention domain.

    Layer stack: Linear -> ReLU -> BatchNorm1d -> Dropout(0.3) -> Linear -> ReLU.

    Args:
        input_size (int): number of features in the fraud prevention domain
        hidden_size (int): width of the hidden layer (default 32)
        output_size (int): width of the tensor returned by ``forward`` (default 16)
    """

    def __init__(self, input_size: int, hidden_size: int = 32, output_size: int = 16):
        super().__init__()

        # Modules are created in execution order and wrapped in a Sequential
        # named `layers_stack`, so state-dict keys are `layers_stack.<index>.*`.
        stages = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, output_size),
            nn.ReLU(),
        ]
        self.layers_stack = nn.Sequential(*stages)

    def forward(self, fraud_prev_input):
        """Run ``fraud_prev_input`` through every layer of the stack, in order."""
        return self.layers_stack(fraud_prev_input)

    def get_weights(self):
        """Return this partial model's ``state_dict()``."""
        return self.state_dict()

class GovernanceNN(nn.Module):
    """Partial model for the governance side; returns raw class logits.

    Layer stack: two (Linear -> ReLU -> BatchNorm1d -> Dropout(0.3)) stages
    followed by a final Linear. In this notebook it is registered under the
    ``"server"`` key and receives the concatenated client outputs.

    Args:
        input_size (int): number of input features (default 48)
        hidden_size (int): width of both hidden layers (default 64)
        output_size (int): number of logits returned by ``forward`` (default 2)
    """
    def __init__(
            self,
            input_size: int = 48,
            hidden_size: int = 64,
            output_size: int = 2,
        ):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.layers_stack = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(0.3),
            # Deliberately no activation after the output layer: the training
            # step feeds this output to nn.CrossEntropyLoss, which expects
            # unconstrained logits. A trailing ReLU clamps them to >= 0 and,
            # once both logits are zero, passes no gradient back at all.
            # Parameter indices (0, 2, 4, 6, 8) and thus state-dict keys are
            # unaffected because ReLU carries no parameters.
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, governance_input):
        """Run ``governance_input`` through the layer stack and return the logits."""
        return self.layers_stack(governance_input)

    def get_weights(self):
        """Return this partial model's ``state_dict()``."""
        return self.state_dict()
    
class SplitFraudNN(nn.Module):
    """Coordinator that chains the client partial models into the server model.

    Args:
        models (dict): partial models keyed by ``owner.id`` for every entry of
            ``data_owners``, plus the key ``"server"`` for the server model.
        optimizers (iterable): optimizers that ``zero_grad`` / ``step`` drive
            together.
        data_owners (iterable): client workers; each must expose ``.id``, used
            to index both ``models`` and the ``data_pointer`` given to
            ``forward``.
        server: destination handed to ``.move(...)`` for every client output.
    """
    def __init__(self, models, optimizers, data_owners, server):
        super().__init__()

        # `models` is a plain dict, so the partial models are NOT registered
        # as sub-modules of this wrapper; train()/eval() fan out manually.
        self.models = models
        self.optimizers = optimizers
        self.data_owners = data_owners
        self.server = server

    def forward(self, data_pointer):
        """Run every client model, move the outputs to the server and predict.

        Args:
            data_pointer (dict): per-owner inputs keyed by ``owner.id``.

        Returns:
            Output of ``models["server"]`` applied to the client outputs
            concatenated along ``dim=1``.
        """
        # individual client's output up to their respective cut layer
        client_output = {}

        # outputs moved to the server, in `data_owners` order, to be concatenated
        remote_output = []

        for owner in self.data_owners:
            client_output[owner.id] = self.models[owner.id](data_pointer[owner.id])
            remote_output.append(client_output[owner.id].move(self.server, requires_grad=True))

        # concatenate the clients' outputs feature-wise to build the server input
        server_input = torch.cat(remote_output, dim=1)
        # server side makes the prediction
        pred = self.models["server"](server_input)

        return pred

    def zero_grad(self):
        """Clear the gradients held by every optimizer."""
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        """Apply one update step with every optimizer."""
        for opt in self.optimizers:
            opt.step()

    def train(self, mode=True):
        """Put every partial model into training (or, with ``mode=False``, eval) mode.

        Mirrors ``nn.Module.train``: accepts ``mode`` and returns ``self``, so
        ``split.train(False)`` and chained calls behave as on any other module.
        """
        # The wrapper has no registered children, so this only records the
        # mode on the wrapper itself; the loop below reaches the real models.
        super().train(mode)
        for model in self.models.values():
            model.train(mode)
        return self

    def eval(self):
        """Put every partial model into evaluation mode and return ``self``."""
        return self.train(False)

def train(x, target, splitNN):
    """Run one optimisation step of the split network on a single batch.

    Args:
        x: batch input, passed unchanged to ``splitNN.forward``.
        target: class labels compared against the prediction with
            ``nn.CrossEntropyLoss``.
        splitNN: object exposing ``zero_grad()``, ``forward(x)`` and ``step()``
            (``SplitFraudNN`` in this notebook).

    Returns:
        The detached loss after calling ``.get()`` on it.
        NOTE(review): ``.get()`` is presumably PySyft's call that retrieves a
        remote tensor -- confirm against the installed syft version.
    """
    # The wrapper's method is `zero_grad`; the former `zero_grads()` spelling
    # does not exist on it and raised AttributeError on the first batch.
    splitNN.zero_grad()
    pred = splitNN.forward(x)

    criterion = nn.CrossEntropyLoss()
    loss = criterion(pred, target)

    # Backprop the loss from the end layer, then update every partial model
    loss.backward()
    splitNN.step()

    return loss.detach().get()
        

In [None]:
# Load the three per-domain training arrays and the labels from `dataDir`.
digital_transaction_data = np.load(dataDir/'digital_transaction_train.npy')
retail_transaction_data = np.load(dataDir/'retail_transaction_train.npy')
fraud_prevention_data = np.load(dataDir/'fraud_prevention_train.npy')
labels = np.load(dataDir/'labels_train.npy')

# Wrap the arrays in the project dataset and batch it (256 samples, reshuffled
# on every pass over the loader).
train_data = CreditCardDataLoader(digital_transaction_data, retail_transaction_data, fraud_prevention_data, labels)
train_loader = DataLoader(train_data, batch_size=256, shuffle=True)

# set up virtual workers for SplitFraudNN: one per data domain plus the server.
# The hook must exist before the workers because each worker is built from it.
hook = sy.TorchHook(torch)
digital_domain = sy.VirtualWorker(hook, id="digital_domain")
retail_domain = sy.VirtualWorker(hook, id="retail_domain")
fraud_prevention_domain = sy.VirtualWorker(hook, id="fraud_prevention_domain")
server = sy.VirtualWorker(hook, id="server")

# `data_owners` are the clients only; `model_locations` additionally holds the server.
data_owners = (digital_domain, retail_domain, fraud_prevention_domain)
model_locations = [digital_domain, retail_domain, fraud_prevention_domain, server]

# set up distributed data loader for SplitFraudNN
distributed_trainloader = Distributed_CreditCard(data_owners=data_owners, data_loader=train_loader)

# set up models for SplitFraudNN; keys must equal the worker ids above because
# models are looked up by `<worker>.id`.

models = {
    # assumes each domain array has 10 feature columns -- TODO confirm against the .npy files
    "digital_domain": DigitalNN(input_size=10),
    "retail_domain": RetailNN(input_size=10),
    "fraud_prevention_domain": FraudPrevNN(input_size=10),
    # 48 = 3 client models x 16 (their default output_size), concatenated on dim=1
    "server": GovernanceNN(input_size=48),
}

# set up optimizers for SplitFraudNN: one Adam per location, server included
# (the loop variable is named `owner` but iterates over `model_locations`).
optimizers = [
    optim.Adam(models[owner.id].parameters(), lr=1e-4) for owner in model_locations
]

# Send every model to the worker it is keyed by. This happens after the
# optimizers were built from the models' parameters.
for location in model_locations:
    models[location.id].send(location)

In [None]:
# Dump the partial models so the run can be identified from the output.
print(models)

epochs = 100
# Anomaly detection stays enabled for the whole training run.
torch.autograd.set_detect_anomaly(True)
splitFraudNN = SplitFraudNN(models, optimizers, data_owners, server)

for i in tqdm(range(epochs)):
    # Every partial model goes into training mode at the start of each epoch.
    splitFraudNN.train()
    running_loss = 0.0

    for data_ptr, labels in distributed_trainloader:
        # The labels are sent to the server worker before the training step.
        labels = labels.send(server)
        loss = train(data_ptr, labels, splitFraudNN)
        running_loss += loss

    # Mean loss over the batches of this epoch (the inner loop never breaks,
    # so this report runs once per epoch).
    print("Epoch {} - Training loss: {}".format(i, running_loss/len(distributed_trainloader)))