In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

from skimage.util import random_noise

In [None]:


# Pick the compute device once for the whole notebook: use the CUDA GPU when
# PyTorch reports one as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print('Using {} device'.format(DEVICE))



In [None]:


class NeuralNetwork(nn.Module):
    """
    Small convolutional classifier used by the demo

    Two convolution -> ReLU -> max-pool stages are followed by a single
    fully connected layer that produces 10 output values per sample. The
    fully connected layer expects 7 * 7 * 64 features per sample, which is
    what the two stages produce for a single-channel 28 x 28 input (each
    convolution keeps the spatial size, each pooling layer halves it).
    """

    def __init__(self):
        super().__init__()

        # A fully connected alternative to the CNN below would be:
        #   Flatten -> Linear(28 * 28, 512) -> ReLU
        #           -> Linear(512, 512)     -> ReLU
        #           -> Linear(512, 10)

        # First stage: 1 input channel -> 32 feature maps, then 2x2 pooling
        first_stage = [
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

        # Second stage: 32 -> 64 feature maps, then 2x2 pooling
        second_stage = [
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

        # An nn.Dropout2d(p=0.5) layer could optionally be appended to each
        # stage; it is left out here.
        self.conv_relu_stack = nn.Sequential(*first_stage, *second_stage)

        # Classification head mapping the flattened features to 10 outputs
        self.linear = nn.Linear(7 * 7 * 64, 10, bias=True)

    def forward(self, sample):
        """
        Runs the input sample through the network

        Parameters
        ----------
        sample : tensor
            The input sample

        Returns
        -------
        out : tensor
            The output of the network
        """
        features = self.conv_relu_stack(sample)

        # Collapse everything but the leading (batch) dimension so the
        # features can be fed to the fully connected layer
        flat_features = features.view(features.shape[0], -1)

        return self.linear(flat_features)



In [None]:


class SaltPepperTransform:
    """
    Define a custom PyTorch transform to implement
    Salt and Pepper Data Augmentation

    The transform receives a PIL image, adds noise to it with
    ``skimage.util.random_noise`` and hands back a PIL image, so it can be
    placed in a ``transforms.Compose`` pipeline ahead of ``ToTensor``.
    """

    def __init__(self, amount, mode='s&p'):
        """
        Pass custom parameters to the transform in init

        Parameters
        ----------
        amount : float
            The amount of salt and pepper noise to add to the image sample
        mode : str, optional
            The noise mode forwarded to ``skimage.util.random_noise``.
            Defaults to ``'s&p'`` (salt and pepper), which is what this
            transform is named and documented to do. Earlier revisions
            hard-coded ``'salt'`` (salt only); pass ``mode='salt'`` to get
            that behaviour back.
        """
        super().__init__()
        self.amount = amount
        self.mode = mode

        # conversion transforms we will use
        self.to_tensor = transforms.ToTensor()
        self.to_pil = transforms.ToPILImage()

    def __call__(self, sample):
        """
        Transform the sample when called

        Parameters
        ----------
        sample : PIL.Image
            The image to augment with noise

        Returns
        -------
        noise_img : PIL.Image
            The image with noise added
        """
        # random_noise works on array data, so go PIL -> tensor -> noisy
        # array, wrap the result in a tensor again and convert back to PIL
        noisy_array = random_noise(self.to_tensor(sample),
                                   mode=self.mode, amount=self.amount)
        noise_img = torch.tensor(noisy_array)

        return self.to_pil(noise_img)



In [None]:


def get_dataloaders(batch_size=None):
    """
    Gets the DataLoaders that will be used in the demo

    The FashionMNIST training and test splits are stored under the local
    ``data`` directory (``download=True`` is passed for both splits).
    Random flips, rotation, an affine transform and ``SaltPepperTransform``
    noise are applied to the training split only; the test split is only
    converted to tensors.

    Parameters
    ----------
    batch_size : int, optional
        Number of samples per batch for both DataLoaders. When omitted, the
        module-level ``BATCH_SIZE`` is used, so that global must already be
        defined in that case.

    Returns
    -------
    train_dataloader : torch.utils.data.DataLoader
        The DataLoader containing the training data
    test_dataloader : torch.utils.data.DataLoader
        The DataLoader containing the test data
    """
    # Fall back to the notebook-wide setting when no explicit size is given
    if batch_size is None:
        batch_size = BATCH_SIZE

    # Augmentation pipeline applied to every training image. Use
    # transform=ToTensor() instead to train without augmentation.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(90),
        transforms.RandomAffine(90, (0.3, 0.3), (1.0, 2.0)),
        SaltPepperTransform(0.05),
        transforms.ToTensor()
    ])

    # Download training data from open datasets.
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=train_transform,
    )

    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )

    # Reshuffle the training samples every epoch so the batches differ
    # between epochs; the evaluation order is kept fixed.
    train_dataloader = DataLoader(training_data, batch_size=batch_size,
                                  shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    # Show one test example as a sanity check. Indexing the dataset yields a
    # single (sample, score) pair, so the shape has no batch dimension.
    first_sample, first_score = test_data[0]
    print('Shape of Test Sample: [C, H, W]: {}'.format(first_sample.shape))
    print('Test Score y: {} type: {}'.format(first_score, type(first_score)))

    return train_dataloader, test_dataloader



In [None]:


def train_network(dataloader, training_model, loss_function, optimizer=None,
                  learning_rate=1e-3):
    """
    Trains the demo network for one pass over the training data

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        The DataLoader with the training data
    training_model : nn.Module
        The network being trained
    loss_function : CrossEntropyLoss
        The loss function used in training
    optimizer : torch.optim.Optimizer, optional
        The optimizer used to update the network weights. Pass the same
        instance on every call to keep optimizer state (for example
        momentum) across epochs. When omitted, a fresh plain SGD optimizer
        is created for this call.
    learning_rate : float, optional
        Learning rate of the SGD optimizer created when ``optimizer`` is
        omitted. Ignored when an optimizer is supplied.

    Returns
    -------
    None
    """
    # Define the optimizer that will train the network model, unless the
    # caller supplied one
    if optimizer is None:
        optimizer = torch.optim.SGD(training_model.parameters(),
                                    lr=learning_rate)

    # Set the network in training mode that enables some training-only
    # layers like dropout
    training_model.train()

    # Perform the training
    total_training_size = len(dataloader.dataset)
    for batch, (sample, score) in enumerate(dataloader):

        # Copy the sample to the device doing the calculations
        sample, score = sample.to(DEVICE), score.to(DEVICE)

        # Compute prediction error between the model output and the
        # training score
        prediction = training_model(sample)
        loss = loss_function(prediction, score)

        # Perform backpropagation and update the network weights to minimize
        # the error
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Output the score every 100th training iteration so output
        # window is not spammed
        if batch % 100 == 0:
            current_loss = loss.item()
            current_batch = batch * len(sample)
            print('Current Loss: {} Training Progress: [{} / {}]'.format(
                current_loss, current_batch, total_training_size))



In [None]:


def test_network(dataloader, test_model, loss_function):
    """
    Runs the demo network in validation mode

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        The DataLoader with the validation testing data
    test_model : nn.Module
        The network being validated
    loss_function : CrossEntropyLoss
        The loss function used in validation

    Returns
    -------
    None
    """
    # Set the network into evaluation mode
    test_model.eval()

    # Initializes the counts to zero
    num_correct = 0.0
    total_loss = 0.0

    # Disable the gradient calculation since not training
    with torch.no_grad():
        for sample, score in dataloader:

            # Copy the sample to the device doing the calculations
            sample, score = sample.to(DEVICE), score.to(DEVICE)

            # Compute prediction error between the model output and the 
            # training score
            prediction = test_model(sample)
            total_loss = loss_function(prediction, score).item()

            num_correct += (prediction.argmax(1) == score).type(
                               torch.float).sum().item()

    # Calculate the average loss by dividing by the number of batches
    average_loss = total_loss / len(dataloader)

    # Calculate the fraction of correct outputs
    percent_correct = 100.0 * (num_correct / len(dataloader.dataset))
    print('Accuracy: {}%, Average Loss: {}'.format(
        percent_correct, average_loss))




In [None]:


if __name__ == '__main__':

    # Hyper-parameters of the demo run: samples per batch and number of
    # passes over the training data
    BATCH_SIZE, EPOCHS = 64, 5

    # Build the network on the selected device and show its layers
    model = NeuralNetwork().to(DEVICE)
    print(model)

    # Loss function shared by the training and validation passes
    loss_func = nn.CrossEntropyLoss()

    # Fetch the training and testing DataLoaders. The datasets are requested
    # with download=True and stored under the local "data" directory.
    train_dl, test_dl = get_dataloaders()

    # Alternate one training pass and one validation pass per epoch
    for epoch in range(1, EPOCHS + 1):
        print('Training Epoch: {}'.format(epoch),
              '---------------------', sep='\n')
        train_network(train_dl, model, loss_func)
        test_network(test_dl, model, loss_func)
    print('Complete')

