# Classifying Fashion-MNIST

Now it's your turn to build and train a neural network. You'll be using the [Fashion-MNIST dataset](https://github.com/zalandoresearch/fashion-mnist), a drop-in replacement for the MNIST dataset. MNIST is actually quite trivial with neural networks where you can easily achieve better than 97% accuracy. Fashion-MNIST is a set of 28x28 greyscale images of clothes. It's more complex than MNIST, so it's a better representation of the actual performance of your network, and a better representation of datasets you'll use in the real world.

In this notebook, you'll build your own neural network. For the most part, you could just copy and paste the code from Part 3, but you wouldn't be learning. It's important for you to write the code yourself and get it to work. Feel free to consult the previous notebooks though as you work through this.

First off, let's load the dataset through torchvision.

In [8]:
import torch
from torchvision import datasets, transforms

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=True)

# Download and load the test data
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1024, shuffle=True)

Here we can see one of the images.

In [2]:
from matplotlib import pyplot as plt

image, label = next(iter(trainloader))
plt.imshow(image[0].view([28,28]))
plt.show()

<Figure size 640x480 with 1 Axes>

## Building the network

Here you should define your network. As with MNIST, each image is 28x28 which is a total of 784 pixels, and there are 10 classes. You should include at least one hidden layer. We suggest you use ReLU activations for the layers and to return the logits or log-softmax from the forward pass. It's up to you how many layers you add and the size of those layers.

In [3]:
from torch import nn, optim
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.h1 = nn.Linear(784, 256)
        self.h2 = nn.Linear(256, 128)
        self.h3 = nn.Linear(128, 64)
        self.out = nn.Linear(64, 10)
        self.relu = nn.ReLU()
        self.logsoftmax = nn.LogSoftmax(dim=1)
        
    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = self.h1(x)
        x = self.relu(x)
        x = self.h2(x)
        x = self.relu(x)

        x = self.h3(x)
        x = self.relu(x)

        x = self.out(x)
        x = self.logsoftmax(x)
        return x

## Sending Data to Virtual Workers

In [4]:
# Importing syft, torch and hooking torch
import syft as sy, torch as th
hook = sy.TorchHook(th)

In [5]:
workers_dict = {}
i = 0

for images, labels in trainloader:
    worker_name = 'W{}'.format(i)
    workers_dict[worker_name] = {}
    workers_dict[worker_name]['worker'] = sy.VirtualWorker(hook, id=worker_name)
    workers_dict[worker_name]['images'] = images.send(workers_dict[worker_name]['worker'])
    workers_dict[worker_name]['labels'] = labels.send(workers_dict[worker_name]['worker'])
    i+=1

In [6]:
# # not needed because virtual workers already have all other virtual workers added

# for worker in workers_dict.values():
#     worker['worker'].add_workers([
#         w['worker'] for w in workers_dict.values() if w['worker'].id != worker['worker'].id
#     ])

In [7]:
model = Network()
criterion = nn.NLLLoss()
optmizer = optim.SGD(model.parameters(), lr=0.01)
secure_worker = sy.VirtualWorker(hook, id='secure_worker')

In [9]:
# copying model to each worker
for worker in workers_dict.values():
    worker['model'] = model.copy().send(worker['worker'])
    worker['optmizer'] = optim.SGD(model.parameters(), lr=0.01)

In [None]:
epoch = 100

for i in range(epoch):
    
    # training each workers' models
    for worker in workers_dict.values():
        
        w_model = worker['model']
        w_images = worker['images']
        w_opt = worker['optmizer']
        w_labels = worker['labels']
        
        output = w_model(w_images)
        
        w_opt.zero_grad()

        loss = criterion(output, w_labels)

        loss.backward()
        w_opt.step()
        
    else:
        
        # sending models to secure worker
        for worker in workers_dict.values():
            worker['model'].move(secure_worker)
        
        # averaging models
        h1_weights = th.stack([w['model'].h1.weight.data for w in workers_dict.values()]).mean(0)
        h1_bias = th.stack([w['model'].h1.bias.data for w in workers_dict.values()]).mean(0)

        h2_weights = th.stack([w['model'].h2.weight.data for w in workers_dict.values()]).mean(0)
        h2_bias = th.stack([w['model'].h2.bias.data for w in workers_dict.values()]).mean(0)

        h3_weights = th.stack([w['model'].h3.weight.data for w in workers_dict.values()]).mean(0)
        h3_bias = th.stack([w['model'].h3.bias.data for w in workers_dict.values()]).mean(0)

        out_weights = th.stack([w['model'].out.weight.data for w in workers_dict.values()]).mean(0)
        out_bias = th.stack([w['model'].out.bias.data for w in workers_dict.values()]).mean(0)
        
        # updating local model 
        with th.no_grad():
            model.h1.weight.set_(h1_weights.get())
            model.h1.bias.set_(h1_bias.get())
            model.h2.weight.set_(h2_weights.get())
            model.h2.bias.set_(h2_bias.get())
            model.h3.weight.set_(h3_weights.get())
            model.h3.bias.set_(h3_bias.get())
            model.out.weight.set_(out_weights.get())
            model.out.bias.set_(out_bias.get())
        
        # copying updated model to virtual workers
        for worker in workers_dict.values():
            worker['model'] = model.copy().send(worker['worker'])    

In [11]:
# epoch = 10

# for i in range(epoch):
#     for worker in workers_dict.values():
        
#         w_model = worker['model']
#         w_images = worker['images']
#         w_opt = worker['optmizer']
#         w_labels = worker['labels']
        
#         output = w_model(w_images)
        
#         w_opt.zero_grad()

#         loss = criterion(output, w_labels)

#         loss.backward()
#         w_opt.step()
# else:
#     for worker in workers_dict.values():
#         worker['model'].move(secure_worker)

In [12]:
# h1_weights = th.stack([w['model'].h1.weight.data for w in workers_dict.values()]).mean(0)
# h1_bias = th.stack([w['model'].h1.bias.data for w in workers_dict.values()]).mean(0)

# h2_weights = th.stack([w['model'].h2.weight.data for w in workers_dict.values()]).mean(0)
# h2_bias = th.stack([w['model'].h2.bias.data for w in workers_dict.values()]).mean(0)

# h3_weights = th.stack([w['model'].h3.weight.data for w in workers_dict.values()]).mean(0)
# h3_bias = th.stack([w['model'].h3.bias.data for w in workers_dict.values()]).mean(0)

# out_weights = th.stack([w['model'].out.weight.data for w in workers_dict.values()]).mean(0)
# out_bias = th.stack([w['model'].out.bias.data for w in workers_dict.values()]).mean(0)

In [13]:
# with th.no_grad():
#     model.h1.weight.set_(h1_weights.get())
#     model.h1.bias.set_(h1_bias.get())
#     model.h2.weight.set_(h2_weights.get())
#     model.h2.bias.set_(h2_bias.get())
#     model.h3.weight.set_(h3_weights.get())
#     model.h3.bias.set_(h3_bias.get())
#     model.out.weight.set_(out_weights.get())
#     model.out.bias.set_(out_bias.get())

In [14]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

dataiter = iter(testloader)
images, labels = dataiter.next()
img = images[0]
label = labels[0]

# plt.imshow(img.view([28,28]))

# Convert 2D image to 1D vector
img = img.resize_(1, 784)

ps = torch.exp(model(img))
predicted = ps.argmax().item()

print('Label: {}'.format(label))
print('Predicted: {}'.format(predicted))
# plt.show()

Label: 2
Predicted: 3


In [22]:
images, labels = next(iter(testloader))
(model(images).argmax(1) == labels).double().mean().item()

0.072265625