## Classification of Fashion MNIST with a Vanilla FFN Achitecture

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms

## Set up the class for our Vanilla FFN

In [6]:
# Build up the class for Vanilla FFN
class VanillaFFN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(VanillaFFN, self).__init__()
        self.fc1 = nn.Linear(in_features=input_size, out_features=input_size // 2)
        self.fc2 = nn.Linear(in_features=input_size // 2, out_features=num_classes)
    
    def forward(self, x):
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.log_softmax(self.fc2(x), dim=1)
        return x

#### *Sanity check for the forward pass on the network*

In [7]:
# model sanity_check
sample_net = VanillaFFN(784, 10)
print(sample_net(torch.randn(64, 784)).shape)

torch.Size([64, 10])


## Let's set up the training data loader that'll feed in batches of data

In [8]:
# Download Dataset
train_data = datasets.FashionMNIST(root='./data', 
                                   train=True, 
                                   transform=transforms.Compose([
                                       transforms.ToTensor()
                                   ]),
                                   download=True)
# Set up the data loader for our training data.
train_data_loader = DataLoader(dataset=train_data, 
                               shuffle=True,
                               batch_size=512)

In [9]:
# sanity check for train data loader.
for idx, data_tuple in enumerate(train_data_loader):
    print(data_tuple[0].shape, data_tuple[1].shape)
    break

torch.Size([512, 1, 28, 28]) torch.Size([512])


## Device Assignment and Training Loop

In [10]:
# check and setup device tensor computations should be assigned to
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [11]:
INPUT_SIZE = 28 * 28
NUM_CLASSES = 10

vanilla_fnn_object = VanillaFFN(INPUT_SIZE, NUM_CLASSES)
vanilla_fnn_object.to(device=device)
loss_criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(vanilla_fnn_object.parameters(), lr=0.001)

for epoch in range(10):
    loss = 0
    for iteration, (data, targets) in enumerate(train_data_loader):
        # assign data and targets to device.
        data = data.to(device=device)
        targets = targets.to(device=device)
        
        # forward pass, compute losses, backpropagate.
        outputs = vanilla_fnn_object(data)
        loss = loss_criterion(outputs, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Epoch: {}, Loss: {}".format(epoch + 1, loss))
print("Learning Done!")

Epoch: 1, Loss: 0.5267444252967834
Epoch: 2, Loss: 0.5788680911064148
Epoch: 3, Loss: 0.656249463558197
Epoch: 4, Loss: 0.35731497406959534
Epoch: 5, Loss: 0.38181138038635254
Epoch: 6, Loss: 0.4821494519710541
Epoch: 7, Loss: 0.28118789196014404
Epoch: 8, Loss: 0.24160528182983398
Epoch: 9, Loss: 0.39852407574653625
Epoch: 10, Loss: 0.3136313855648041
Learning Done!


## Testing data loader set up followed by "Accuracy Test"

In [13]:
# set up testing data and data loader
test_data = datasets.FashionMNIST(root='./data',
                                  train=False,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor()
                                  ]))

test_data_loader = DataLoader(dataset=test_data,
                              shuffle=True,
                              batch_size=512)

In [14]:
# Lets check the accuracy on the testing set.
num_correct = 0
num_samples = 0
vanilla_fnn_object.eval()    # to set in eval mode

with torch.no_grad():
    for iteration, (data, targets) in enumerate(test_data_loader):
        # assign data and targets to device.
        data = data.to(device=device)
        targets = targets.to(device=device)
        outputs = vanilla_fnn_object(data)
        
        values, index_of_max_value = outputs.max(1)
        num_correct += (index_of_max_value == targets).sum()
        num_samples += targets.size(0)
    
    print(num_correct, num_samples)
    print("Testing Accuracy of Vanilla FNN model on the FMNIST dataset: {}".format((num_correct.item() / num_samples) * 100))

tensor(8718, device='cuda:0') 10000
Testing Accuracy of Vanilla FNN model on the FMNIST dataset: 87.18
