In [1]:
import os 
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data._utils.collate import default_collate

from sklearn.metrics import roc_auc_score, f1_score

import syft as sy
from uuid import UUID
from uuid import uuid4

hook = sy.TorchHook(torch)

In [2]:
from src.psi.util import Client, Server
from src.utils import add_ids
from src.utils.data_utils import id_collate_fn

In [3]:
class VerticalDataset(Dataset):
    """Dataset for Vertical Federated Learning"""

    def __init__(self, ids, data, labels=None):
        """
        Args:
            ids (Numpy Array) : Numpy Array with UUIDS
            data (Numpy Array) : Numpy Array with Features
            targets (Numpy Array) : Numpy Array with Labels. None if not available. 
        """
        self.ids = ids
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        """Return record single record"""
        feature = self.data[index].astype(np.float32)

        if self.labels is None:
            label  = None
        else:
            label = int(self.labels[index]) if self.labels is not None else None

        id = self.ids[index]

        # Return a tuple of non-None elements
        return (*filter(lambda x: x is not None, (feature, label, id)),)
    
    def get_ids(self):
        """Return a list of the ids of this dataset."""
        return [str(id_) for id_ in self.ids]
    
    def sort_by_ids(self):
        """
        Sort the dataset by IDs in ascending order
        """
        ids = self.get_ids()
        sorted_idxs = np.argsort(ids)


        self.data = self.data[sorted_idxs]

        if self.labels is not None:
            self.labels = self.labels[sorted_idxs]

        self.ids = self.ids[sorted_idxs]

# Load Data

In [4]:
# Load Intact Data
data_dir = "/ssd003/projects/pets/datasets"
INTACT_DATA_PATH = f"{data_dir}/prdct_insurance_stats_info.csv"
intact_df_full = pd.read_csv(INTACT_DATA_PATH)
intact_df_full.head()

Unnamed: 0,PWAPART,PWABEDR,PWALAND,PPERSAUT,PBESAUT,PMOTSCO,PVRAAUT,PAANHANG,PTRACTOR,PWERKT,...,AWAOREG,ABRAND,AZEILPL,APLEZIER,AFIETS,AINBOED,ABYSTAND,UUID,ORIGIN,CARAVAN
0,0.0,0.0,0.0,0.666667,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.142857,0.0,0.0,0.0,0.0,0.0,cb8db3d9-e0c2-4f2c-a866-d579a1d61ff4,train,0
1,0.666667,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.142857,0.0,0.0,0.0,0.0,0.0,eed33082-7621-473d-b255-c5d35198627a,train,0
2,0.666667,0.0,0.0,0.666667,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.142857,0.0,0.0,0.0,0.0,0.0,7950c7c4-4003-44d0-88dc-b05e6437afd4,train,0
3,0.0,0.0,0.0,0.666667,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.142857,0.0,0.0,0.0,0.0,0.0,6b807aa1-88d8-4ab0-ad1b-a8ef051f91e1,train,0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.142857,0.0,0.0,0.0,0.0,0.0,c1100358-d07d-4d70-aad0-49ace40eadda,train,0


In [5]:
#Load Vendor Data
VENDOR_DATA_PATH = f"{data_dir}/demograhic_info.csv"
vendor_df_full = pd.read_csv(VENDOR_DATA_PATH)
vendor_df_full.head()

Unnamed: 0,MAANTHUI,MGEMOMV,MGEMLEEF,MGODRK,MGODPR,MGODOV,MGODGE,MRELGE,MRELSA,MRELOV,...,MOSHOOFD_3,MOSHOOFD_4,MOSHOOFD_5,MOSHOOFD_6,MOSHOOFD_7,MOSHOOFD_8,MOSHOOFD_9,MOSHOOFD_10,UUID,ORIGIN
0,0.0,0.4,0.2,0.0,0.555556,0.2,0.333333,0.777778,0.0,0.222222,...,0,0,0,0,0,1,0,0,cb8db3d9-e0c2-4f2c-a866-d579a1d61ff4,train
1,0.0,0.2,0.2,0.111111,0.444444,0.2,0.444444,0.666667,0.285714,0.222222,...,0,0,0,0,0,1,0,0,eed33082-7621-473d-b255-c5d35198627a,train
2,0.0,0.2,0.2,0.0,0.444444,0.4,0.444444,0.333333,0.285714,0.444444,...,0,0,0,0,0,1,0,0,7950c7c4-4003-44d0-88dc-b05e6437afd4,train
3,0.0,0.4,0.4,0.222222,0.333333,0.4,0.444444,0.555556,0.285714,0.222222,...,1,0,0,0,0,0,0,0,6b807aa1-88d8-4ab0-ad1b-a8ef051f91e1,train
4,0.0,0.6,0.2,0.111111,0.444444,0.2,0.444444,0.777778,0.142857,0.222222,...,0,0,0,0,0,0,0,1,c1100358-d07d-4d70-aad0-49ace40eadda,train


In [8]:
# Carve out validation set
assert intact_df_full.shape[0] == vendor_df_full.shape[0]

intact_df = intact_df_full[intact_df_full['ORIGIN'] == 'train']
intact_df_val = intact_df_full[intact_df_full['ORIGIN'] == 'test']

_ = intact_df.pop('ORIGIN')
_ = intact_df_val.pop('ORIGIN')

vendor_df = vendor_df_full[vendor_df_full['ORIGIN'] =='train']
vendor_df_val = vendor_df_full[vendor_df_full['ORIGIN']=='test']

_ = vendor_df.pop('ORIGIN')
_ = vendor_df_val.pop('ORIGIN')

In [9]:
# Get UID Column
uuids = intact_df.pop('UUID').values
uuids_val = intact_df_val.pop('UUID').values

_ = vendor_df.pop('UUID')
_ = vendor_df_val.pop('UUID')

In [10]:
## Sanity Check
print(f"Training \tIntact Data: {str(intact_df.shape)} Vendor: {str(vendor_df.shape)}")
print(f"Validation: \tIntact Data: {str(intact_df_val.shape)} Vendor: {str(vendor_df_val.shape)}")

Training 	Intact Data: (5822, 43) Vendor: (5822, 91)
Validation: 	Intact Data: (4000, 43) Vendor: (4000, 91)


### Define Dataloader Classes 

In [11]:
class SinglePartitionDataLoader(DataLoader):
    """DataLoader for a single vertically-partitioned dataset"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.collate_fn = id_collate_fn

In [12]:
class VerticalDataLoader:
    """Dataloader which batches data from a complete
    set of vertically-partitioned datasets
    i.e. the images dataset AND the labels dataset
    """

    def __init__(self, data1, data2, *args, **kwargs):

        self.dataloader1 = SinglePartitionDataLoader(
            data1, *args, **kwargs
        )
        self.dataloader2 = SinglePartitionDataLoader(
            data2, *args, **kwargs
        )

    def __iter__(self):
        """
        Zip Dataloaders 
        """
        return zip(self.dataloader1, self.dataloader2)

    def __len__(self):
        """
        Return length of dataset
        """
        return (len(self.dataloader1) + len(self.dataloader2)) // 2

    def drop_non_intersecting(self, intersection):
        """Remove elements and ids in the datasets that are not in the intersection."""
        self.dataloader1.dataset.data = self.dataloader1.dataset.data[intersection]
        self.dataloader1.dataset.ids = self.dataloader1.dataset.ids[intersection]

        self.dataloader1.dataset.labels = self.dataloader1.dataset.labels[intersection]
        self.dataloader2.dataset.ids = self.dataloader2.dataset.ids[intersection]

    def sort_by_ids(self) -> None:
        """
        Sort each dataset by ids
        """
        self.dataloader1.dataset.sort_by_ids()
        self.dataloader2.dataset.sort_by_ids()

## Initialize Datasets 

In [14]:
# Intact Dataset

TARGET_COLUMN = "CARAVAN"

#Training
intact_labels = np.array(intact_df.pop(TARGET_COLUMN))
intact_data = np.array(intact_df)
print("train", uuids.shape, intact_data.shape, intact_labels.shape)
intact_dim = intact_data.shape[1]
intact_dataset = VerticalDataset(ids=uuids, data=intact_data, labels=intact_labels)

#Validation
intact_labels_val = np.array(intact_df_val.pop(TARGET_COLUMN))
intact_data_val = np.array(intact_df_val)
print("validation", uuids_val.shape, intact_data_val.shape, intact_labels_val.shape)
intact_dataset_val = VerticalDataset(ids=uuids_val, data=intact_data_val, labels=intact_labels_val)

train (5822,) (5822, 42) (5822,)
validation (4000,) (4000, 42) (4000,)


In [15]:
# Vendor Dataset

#Training
vendor_data = np.array(vendor_df)
print(vendor_data.shape)
vendor_dim = vendor_data.shape[1]
vendor_feat_dim = 4
vendor_dataset = VerticalDataset(ids=uuids, data=vendor_data, labels=None)

#Validation
vendor_dataset_val = np.array(vendor_df_val)
vendor_dataset_val = VerticalDataset(ids=uuids_val, data=vendor_dataset_val, labels=None)

(5822, 91)


## Initialize Dataloader

In [16]:
## Initialize Train Dataloader 
dataloader = VerticalDataLoader(intact_dataset, vendor_dataset, batch_size=512)

# Compute private set intersection
client_items = dataloader.dataloader1.dataset.get_ids()
server_items = dataloader.dataloader2.dataset.get_ids()
 
client = Client(client_items)
server = Server(server_items)

setup, response = server.process_request(client.request, len(client_items))
intersection = client.compute_intersection(setup, response)

# Order data
dataloader.drop_non_intersecting(intersection)
dataloader.sort_by_ids()

In [17]:
## Initialize Validation Dataloader 
val_dataloader = VerticalDataLoader(intact_dataset_val, vendor_dataset_val, batch_size=512)

# Compute private set intersection
val_client_items = val_dataloader.dataloader1.dataset.get_ids()
val_server_items = val_dataloader.dataloader2.dataset.get_ids()

val_client = Client(val_client_items)
val_server = Server(val_server_items)

val_setup, val_response = val_server.process_request(val_client.request, len(val_client_items))
val_intersection = val_client.compute_intersection(val_setup, val_response)

# Order data
val_dataloader.drop_non_intersecting(val_intersection)
val_dataloader.sort_by_ids()

## **Model Preparation**

In [None]:
class IntactModel(torch.nn.Module):
    """ 
    Model for the Intact dataset
    
    Attributes
    ----------
    dim: 
        Dimensionality of Intact Data
    Methods
    -------
    forward(x):
        Performs a forward pass through the Intact Model
    """
    def __init__(self, intact_dim, vendor_dim): 
        super(IntactModel, self).__init__()
        self.fused_input_dim = intact_dim + vendor_dim
        self.layers = nn.Sequential(
            nn.Linear(self.fused_input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )
    
    def forward(self, intact_feat, vendor_feat):
        feat = torch.cat([intact_feat, vendor_feat], dim=1)
        pred = self.layers(feat)
        return pred

In [None]:
class VendorModel(torch.nn.Module):
    """ 
    Model for Vendor variables
    
    Attributes
    ----------
    dim: 
        Dimensionality of the vendor data
    Methods
    -------
    forward(x):
        Performs a forward pass through the Credit Bureau Model
    """
    
    def __init__(self, vendor_dim): 
        super(VendorModel, self).__init__()
        self.vendor_dim = vendor_dim
        self.layers = torch.nn.Sequential(
            nn.Linear(self.vendor_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 4),
            nn.Sigmoid()
        )
    
    def forward(self, vendor_feat):
        pred = self.layers(vendor_feat)
        return pred

In [None]:
class SplitNN:
    """
    A class representing SplitNN

    Attributes
    ----------
    intact_model:  
        Home Credit Neural Network Module

    vendor_model:   
        Credit Bureau Neural Network Module

    intact_opt:  
        Optimizer for the Home Credit Neural Network Module

    vendor_model:   
        Optimizer for the Credit Bureau Neural Network Module

    data: 
        A list storing intermediate computations at each index

    remote_tensors: 
        A list storing intermediate computations at each index (Computation from each model detached from global computation graph)

    Methods
    -------
    forward(x):
        Performs a forward pass through the SplitNN

    backward(): 
        Performs a backward pass through the SplitNN

    zero_grads():
        Zeros the gradients of all networks in SplitNN

    step():
        Updates the parameters of all networks in SplitNN
    """


    def __init__(self, intact_model, vendor_model, intact_opt, vendor_opt):
        self.intact_model = intact_model
        self.vendor_model = vendor_model
        self.intact_opt = intact_opt
        self.vendor_opt = vendor_opt
        self.data = []
        self.remote_tensors = []

    def forward(self, intact_x, vendor_x):
        """
        Parameters
        ----------
        x:  
            Input Sample 
        """

        data = []
        remote_tensors = []

        # Forward pass through first model
        data.append(self.vendor_model(vendor_x))

        # if location of data is the same as location of the subsequent model
        if data[-1].location == self.intact_model.location:
            # store computation in remote tensor array 
            # Gradients will be only computed backward upto the point of detachment
            remote_tensors.append(data[-1].detach().requires_grad_())
        else:
            # else move data to location of subsequent model and store computation in remote tensor array 
            # Gradients will be only computed backward upto the point of detachment
            remote_tensors.append(
                data[-1].detach().move(self.intact_model.location).requires_grad_()
            )

        # Get and return final output of model
        data.append(self.intact_model(intact_x, remote_tensors[-1]))

        self.data = data 
        self.remote_tensors = remote_tensors
        return data[-1]

    def backward(self):
        # if location of data is the same as detatched data 
        if self.remote_tensors[0].location == self.data[0].location:
            # Store gradients from remote_tensor 
            grads = self.remote_tensors[0].grad.copy()
        else:
            # Move gradients to lovation of Store grad
            grads = self.remote_tensors[0].grad.copy().move(self.data[0].location)

        self.data[0].backward(grads)

    def zero_grads(self):
        """
        Parameters
        ----------
        """
        self.vendor_opt.zero_grad()
        self.intact_opt.zero_grad()


    def step(self):
        """
        Parameters
        ----------
        """
        self.vendor_opt.step()
        self.intact_opt.step()

### Initialize and Configure Models

In [None]:
# Training globals 
epochs = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Iniatialize Intact Model and Optimizer
intact_model = IntactModel(intact_dim, vendor_feat_dim)
intact_opt = torch.optim.Adam(intact_model.parameters(), lr=.001,  betas=(0.9, 0.999))

# Iniatialize Credit Bureau Model and Optmizer
vendor_model = VendorModel(vendor_dim)
vendor_opt = torch.optim.Adam(vendor_model.parameters(), lr=.001,  betas=(0.9, 0.999))

# Define Split Neural Network
splitNN = SplitNN(intact_model, vendor_model, intact_opt, vendor_opt)
criterion = torch.nn.BCELoss()

### Configure (Virtual) Remote Workers

In [None]:
# create some workers
intact_worker = sy.VirtualWorker(hook, id="intact")
vendor_worker = sy.VirtualWorker(hook, id="vendor")

# Send Model Segments to model locations
model_locations = [intact_worker, vendor_worker]
models = [intact_model, vendor_model]
for model, location in zip(models, model_locations):
    model.send(location)

## Training

In [None]:
def train_step(dataloader, splitNN):
    running_loss = 0
    for (intact_data, labels, id1), (vendor_data, id2) in dataloader:
        # Send data and labels to machine model is on
        labels = labels.float()
        intact_data = intact_data.send(intact_model.location)
        labels = labels.send(intact_model.location)
        vendor_data = vendor_data.send(vendor_model.location)

        # Zero our grads
        splitNN.zero_grads()
    
        # Make a prediction
        pred = splitNN.forward(intact_data, vendor_data).squeeze()

        # Figure out how much we missed by
        loss = criterion(pred, labels)
    
        # Backprop the loss on the end layer
        loss.backward()
        splitNN.backward()
    
        # Change the weights
        splitNN.step()
        
        # Accumulate Loss
        running_loss += loss.get()

    
    return running_loss

In [None]:
def val_step(val_dataloader, splitNN):
    running_loss = 0
    exs = 0 
    correct = 0
    aucs = []
    f1s = []
    for (intact_data_val, labels_val, id1), (vendor_data_val, id2) in val_dataloader:
        # Send data and labels to machine model is on
        labels_val  = labels_val.float()
        intact_data_val = intact_data_val.send(intact_model.location)
        labels_val  = labels_val.send(intact_model.location)
        vendor_data_val = vendor_data_val.send(vendor_model.location)
    
        # Make a prediction
        with torch.no_grad():
            pred = splitNN.forward(intact_data_val, vendor_data_val).squeeze()
        
        #Calcualte Loss
        criterion = torch.nn.BCELoss()
        loss = criterion(pred, labels_val)

        #Calculate AUC
        thresh_pred = (pred > .5).float()
        thresh_pred = thresh_pred.get().int()
        labels_val = labels_val.get().int()

        # Fix Me: Undefined for batches with all-same labels...
        auc = roc_auc_score(labels_val, pred.get().numpy())
        f1 = f1_score(labels_val, thresh_pred)

        #Calculate Accuracy Components
        num_exs = intact_data_val.shape[0]
        num_correct = torch.sum(thresh_pred == labels_val).item()

        # Accumulate loss, accuracy and auc
        exs += num_exs
        correct += num_correct
        running_loss += loss.get()
        aucs.append(auc)
        f1s.append(f1)

    auc = np.mean(np.array(aucs))
    f1 = np.mean(np.array(f1s))
    accuracy = correct / exs

    return f1, accuracy, running_loss

In [None]:
metric_names = ["Train Loss", "Validation Loss", "Accuracy", "F1"]
metrics = {metric:[] for metric in metric_names}

# Train Loop
for i in range(epochs):

    # Train Step 
    train_loss = train_step(dataloader, splitNN)

    # Train Step
    f1, accuracy, val_loss = val_step(val_dataloader, splitNN)
    
    # Log metrics
    print(f"Epoch: {i} \t F1: {f1}")
    metrics["Train Loss"].append(train_loss.item())
    metrics["Validation Loss"].append(val_loss.item())
    metrics["Accuracy"].append(accuracy)
    metrics["F1"].append(f1)