In [None]:
# Adapted from https://neptune.ai/blog/neural-network-guide
import os
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder

In [None]:
# Build one ImageFolder dataset per sub-directory of the dataset root and
# concatenate them into a single dataset.
data_dir = "/content/gdrive"
folder_name = "kaggle_dataaset"
image_folders = os.path.join(data_dir, folder_name)

# Every image is resized to 50x50 and converted to a float tensor.
transform = transforms.Compose([transforms.Resize((50, 50)), transforms.ToTensor()])
images = []
# Sort the listing: os.listdir order is filesystem-dependent, and the order of
# the concatenated datasets determines which samples the seeded split selects.
for entry in sorted(os.listdir(image_folders)):
    try:
        images.append(ImageFolder(os.path.join(image_folders, entry), transform=transform))
    except (OSError, RuntimeError) as err:
        # Entries that are not usable image folders (plain files, folders with
        # no class sub-directories or no valid images) are reported and
        # skipped; anything else (e.g. KeyboardInterrupt) still propagates.
        print(entry, "skipped:", err)

if not images:
    raise RuntimeError("No usable image folders found under " + image_folders)
datasets = torch.utils.data.ConcatDataset(images)

In [None]:
# Tally how many samples carry each class index across all sub-datasets.
# `result` stays a Counter (a dict subclass) so that a class with no samples
# reads as 0 instead of raising KeyError, and so that an empty dataset list
# still leaves `result` defined.
result = Counter()
for dataset in datasets.datasets:
    result.update(dataset.targets)

print("""Total Number of Images for each Class:
    Class 0 (No Breast Cancer): {}
    Class 1 (Breast Cancer present): {}""".format(result[0], result[1]))

In [None]:
# Seed torch's global RNG so the random split below is repeatable.
random_seed = 189 # the CS version
torch.manual_seed(random_seed)

# Hold out 25% of the samples for testing. The size is taken from the dataset
# itself rather than from the per-class counts, so this cell does not depend
# on the counting cell having been run or on both classes being present.
test_size = int(0.25 * len(datasets))
print(test_size)
train_size = len(datasets) - test_size
train_dataset, test_dataset = random_split(datasets, [train_size, test_size])

In [None]:
# Batched loaders over the two splits: the training loader reshuffles on every
# pass (batches of 128), the test loader keeps a fixed order (batches of 64).
# Both use two worker processes.
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

In [None]:
# functions to show an image

def imshow(img):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: tensor laid out as (channels, height, width). It may live on any
            device and may be attached to an autograd graph; it is detached
            and copied to the CPU before conversion.
    """
    npimg = img.detach().cpu().numpy()
    # matplotlib expects (height, width, channels)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# get some random training images
dataiter = iter(trainloader)
# Use the built-in next(): the DataLoader iterator's .next() method is not
# available in current PyTorch releases.
images, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images[:6], nrow=3))
# show labels
labels[:6]

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device("cpu")

### CalHacks 8.0 FHE
Using the Taylor series expansion of Softplus, a smooth approximation of ReLU:

$$
\begin{aligned}
f(x) &= \ln{(1+e^x)} \\
&\approx \ln{2} + \frac{x}{2} + \frac{x^2}{8} + O(x^4)
\end{aligned}
$$

And the smooth maximum approximation:

$$
\max{\textbf{x}} \approx \frac{\sum_{x_i \in \textbf{x}} x_i^{\left(p+k\right)}}{\sum_{x_i \in \textbf{x}} x_i^{p}}
$$
for large constants $p, k \in \mathbb{Z}^+$.

The above approximations are chosen so that every operation stays within the subset allowed under Fully Homomorphic Encryption (FHE). We inject them into `pytorch` here.

In [None]:
# Swap in the FHE-friendly approximations of ReLU and max-pooling described above.
# NOTE(review): fhe_approx is a project-local module not visible here; presumably
# these helpers monkey-patch the passed-in `F` and `nn` modules in place, which
# would mean this cell must run before the model is defined — confirm against
# fhe_approx.
from fhe_approx import patch_relu, patch_maxpools
patch_relu(F, nn)
patch_maxpools(F, nn)

### FHE for log-softmax
To handle log-softmax with the subset of operations permitted in the FHE context, we approximate $\ln{x}$ using its Taylor series expansion about $x = 1$ (which converges for $0 < x \le 2$): $$\ln{x} \approx (x - 1) - \frac{1}{2} \left(x-1\right)^2 +\frac{1}{3}\left(x-1\right)^3 - \frac{1}{4} \left(x-1\right)^4 + O\left(\left(x-1\right)^5\right)$$

We similarly handle the exponential function using its Maclaurin series: $$e^x \triangleq \exp{x} = \sum_{i=0}^{\infty} \frac{x^i}{i!} \approx 1 + \frac{x}{1!} + \frac{x^2}{2!} + \frac{x^3}{3!} + \frac{x^4}{4!} + O(x^5)$$

Putting these approximations together, we can handle softmax, which is defined as $$
s(\textbf{z})_i \triangleq \frac{e^{z_i}}{\sum_{j=1}^{k} e^{z_j}}
$$
for $\textbf{z} \in \mathbb{R}^k$. Using the exponential function approximation above, it is tractable in the FHE context.

Finally, we simply take the natural log of the output of the above using the log approximation we defined earlier.

In [None]:
# Swap in the FHE-friendly log-softmax approximation described above.
# NOTE(review): fhe_approx is a project-local module not visible here; presumably
# this replaces log-softmax on the passed-in `F` and `nn` modules — confirm the
# patched function's signature against fhe_approx.
from fhe_approx import patch_logsoftmax
patch_logsoftmax(F, nn)

In [None]:
class BreastCancerClassifyNet(nn.Module):
  def __init__(self):
    super(BreastCancerClassifyNet, self).__init__()
    self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
    self.conv2 = nn.Conv2d(64, 128, kernel_size=3)
    self.conv3 = nn.Conv2d(128, 256, kernel_size=3)
    self.pool = nn.MaxPool2d(2, 2)
    self.fc1 = nn.Linear(4096, 1024)
    self.fc2 = nn.Linear(1024, 512)
    self.fc3 = nn.Linear(512, 1)

  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = self.pool(F.relu(self.conv3(x)))
    x = x.view(-1, self.flat_features(x))
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    x = F.log_softmax(x)
    return x

  def flat_features(self, x):
    size = x.size()[1:]
    num_features = 1
    for s in size:
      num_features *= s
    return num_features

net = BreastCancerClassifyNet()
net = net.to(device)

In [None]:
criterion = nn.BCELoss()
optimizer = optim.SGD(net.parameters(), lr = 0.001)

In [None]:
test_data_iter = iter(testloader)
test_images, test_labels = test_data_iter.next()
for epoch in range(20):
  running_loss = 0
  for i, data in enumerate(trainloader, 0):
    input_imgs, labels = data
    input_imgs = input_imgs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    outputs = net(input_imgs)
    labels = labels.unsqueeze(1).float()
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    #printing stats and checking prediction as we train
    running_loss += loss.item()
    if i % 10000 == 0:
      print('epoch', epoch+1, 'loss', running_loss/10000)
      imshow(torchvision.utils.make_grid(test_images[0].detach()))
      test_out = net(test_images.to(device))
      _, predicted_out = torch.max(test_out, 1)
      print('Predicted : ', ' '.join('%5s' % predicted_out[0]))


print('Training finished')

In [None]:
# Measure classification accuracy over the whole test split.
correct = 0
total = 0
with torch.no_grad():
  for test_images, test_labels in testloader:
    test_out = net(test_images.to(device))
    # predicted class = index of the largest output along dim 1
    _, predicted = torch.max(test_out, 1)
    total += test_labels.size(0)
    # compare on the CPU, where the loader leaves the labels
    correct += (predicted.cpu() == test_labels).sum().item()

# Report the number of images actually evaluated instead of a hard-coded count,
# and avoid dividing by zero when the test loader is empty.
accuracy = 100 * correct / total if total else 0.0
print('Accuracy of the network on the %d test images: %d %%' % (
        total, accuracy))