#### TODO:
- add `.backward` to `PhiTensor`
- fix data subjects for `CrossEntropyLoss`
- implement the optimizer in a similar way (using `nn.Module`, with `forward` overridden to accept DP tensors)

In [1]:
import syft
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from syft import nn

In [3]:
import syft.core.tensor.nn.functional as F
from syft import PhiTensor

In [4]:
class ConvNet(torch.nn.Module):
    """Five-stage convolutional classifier built from syft ``nn`` layers.

    Each stage is Conv2d(kernel_size=3, padding=2) -> BatchNorm2d -> leaky ReLU
    -> MaxPool2d(kernel_size=2, stride=2), widening the channels
    3 -> 32 -> 64 -> 128 -> 256 -> 512. An AvgPool2d(3) follows, then the
    feature map is flattened and passed to a single linear layer.

    NOTE(review): flattening to ``512 * 1 * 1`` features assumes the average
    pool leaves a 1x1 spatial map. With standard conv/pool output-size
    arithmetic this holds for 50x50 inputs (the size used in this notebook);
    confirm for other input sizes and for syft's layer implementations.

    NOTE(review): the layers come from ``syft.nn``, not ``torch.nn``. Whether
    they are registered as torch parameters (and so visible to
    ``model.parameters()`` / ``torch.optim``) is not shown here — confirm.
    """

    # Width of the flattened feature vector handed to ``fc``: 512 channels
    # from ``conv5`` times an (assumed) 1x1 spatial map after ``avg``.
    _FLAT_FEATURES = 512 * 1 * 1

    def __init__(self, num_classes: int = 2):
        """Create the layers.

        Args:
            num_classes: output width of the final linear layer. Defaults to
                2, the value previously hard-coded.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=2)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=2)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=2)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(512)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.avg = nn.AvgPool2d(3)
        self.fc = nn.Linear(self._FLAT_FEATURES, num_classes)

    def forward(self, x: PhiTensor):
        """Return the class scores for ``x``.

        Args:
            x: input batch; assumes shape (N, 3, H, W) with 3 channels to
                match ``conv1`` (H = W = 50 in this notebook) — TODO confirm
                for other inputs.

        Returns:
            The output of ``fc``, one row of ``num_classes`` scores per sample.
        """
        x = self.pool(F.leaky_relu(self.bn1(self.conv1(x))))
        x = self.pool(F.leaky_relu(self.bn2(self.conv2(x))))
        x = self.pool(F.leaky_relu(self.bn3(self.conv3(x))))
        x = self.pool(F.leaky_relu(self.bn4(self.conv4(x))))
        x = self.pool(F.leaky_relu(self.bn5(self.conv5(x))))
        x = self.avg(x)
        # Flatten everything except the batch dimension.
        x = x.reshape((-1, self._FLAT_FEATURES))
        return self.fc(x)

In [5]:
# One network instance, applied to the sample batch ``x`` in the loss cell.
cnn_model: torch.nn.Module = ConvNet()

In [6]:
from syft import PhiTensor
import numpy as np

# Sample batch for a smoke test of the model: N images with C_in channels of
# H_in x W_in pixels, batch-first and channels-first (C_in matches conv1's
# in_channels=3).
N = 10
C_in = 3
H_in = 50
W_in = 50

input_shape = (N, C_in, H_in, W_in)
x = PhiTensor(
    # ``randint``'s ``high`` is exclusive, so 256 is required for pixel values
    # to cover the full declared range [min_vals, max_vals] = [0, 255].
    child=np.random.randint(low=0, high=256, size=input_shape),
    # Every element is attributed to data subject 0.
    data_subjects=np.zeros(input_shape),
    min_vals=0,
    max_vals=255,
)

In [7]:
def create_phi_tensor(shape=(1, 3, 50, 50)):
    """Return a random 8-bit image PhiTensor owned by one random data subject.

    Args:
        shape: shape of the image tensor. The default is a single 3-channel
            50x50 image in batch-first, channels-first layout, the same layout
            as the sample batch ``x`` and what ``ConvNet.conv1`` (3 input
            channels) consumes. The previous HWC ``(50, 50, 3)`` tensor had no
            batch dimension and put the channels last, so it could not be fed
            to the model.

    Returns:
        A PhiTensor with integer pixels in [0, 255]. Every element belongs to
        the same data subject, chosen uniformly from {0, 1}.
    """
    subject = np.random.choice([0, 1])
    return PhiTensor(
        # ``high`` is exclusive: 256 lets pixels reach the declared max of 255.
        child=np.random.randint(0, 256, shape),
        data_subjects=np.ones(shape) * subject,
        min_vals=0,
        max_vals=255,
    )

def create_target_phi_tensor(input_shape, num_classes=2):
    """Return a PhiTensor of random integer class labels.

    Args:
        input_shape: shape of the label tensor (e.g. an int batch size).
        num_classes: number of classes; labels are drawn uniformly from
            ``[0, num_classes)``. Defaults to 2, the previously hard-coded value.

    Returns:
        A PhiTensor with bounds ``[0, num_classes - 1]``, all elements
        attributed to data subject 0.

    Raises:
        ValueError: if ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    return PhiTensor(
        child=np.random.randint(low=0, high=num_classes, size=input_shape),
        data_subjects=np.zeros(input_shape),
        min_vals=0,
        max_vals=num_classes - 1,
    )

In [8]:
# Smoke test: forward the sample batch and score it against random labels.
loss_fn = nn.CrossEntropyLoss()
prediction = cnn_model(x)

# One label per sample: reuse the batch size ``N`` of ``x`` instead of
# repeating the literal 10, so the two cannot drift apart.
target = create_target_phi_tensor(N)
loss_fn(prediction, target)

tensor(0.9165)


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
# Tiny in-memory stand-in for a data loader: ten (image, label) pairs, each
# holding a single image and a single label.
loader_train = [(create_phi_tensor(), create_target_phi_tensor(1)) for _ in range(10)]

In [None]:
# Training hyper-parameters. ``classes`` and ``batch_size`` are not referenced
# in the visible cells; kept for later use.
epochs, classes, batch_size = 1, 2, 128
alpha = 0.002  # Adamax learning rate
device = 'cpu'

# Model, loss and optimizer.
# NOTE(review): ConvNet's layers are syft ``nn`` layers; confirm they expose
# torch parameters, otherwise ``model.parameters()`` may be empty.
model = ConvNet().to(device)
pb_loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adamax(params=model.parameters(), lr=alpha)

In [None]:
from tqdm import tqdm

total_step = len(loader_train)
# Placeholder for outputs published under a DP budget (see the TODO in the
# loop); nothing is appended to it yet.
published_output_list = []
# Print a summary every ``log_interval`` steps and always after an epoch's
# last step. With only 10 samples, a 100-step interval on its own never fires.
log_interval = 100

for epoch in range(epochs):
    # ``enumerate`` has no length, so pass ``total`` for a real progress bar.
    for i, (images, labels) in tqdm(enumerate(loader_train), total=total_step):
        # Forward pass
        outputs = model(images)
        loss = pb_loss_func(outputs, labels)
        print("Loss: ", loss.child)

        # TODO: publish ``outputs`` via
        # ``outputs.publish(get_budget_for_user=..., deduct_epsilon_for_user=...,
        # ledger=..., sigma=1000).decode()`` once a ledger and budget
        # callbacks are available in this notebook.

        # Backward and optimize.
        # NOTE(review): per the TODO at the top of the notebook, PhiTensor may
        # not implement ``.backward`` yet — confirm before relying on this.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step = i + 1
        if step % log_interval == 0 or step == total_step:
            # Read the value through ``loss.child``, as the print above does.
            print(
                f"Epoch [{epoch + 1}/{epochs}], Step [{step}/{total_step}], "
                f"Loss: {float(loss.child):.4f}"
            )