In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor, Normalize, Compose, Lambda
import torch.nn.functional as F
import time

#Debug parameters:

#whether to print "Processed: 10000/50000" messages while an epoch is being trained
PRINT_EPOCH_PROGRESS_MESSAGES = False

#whether to print "Processed in: 6.17 seconds" messages after an epoch has been reported
PRINT_EPOCH_PROCESSING_TIME = True

#only every Nth epoch is reported (eg 10 reports every 10th epoch)
#the last epoch is always reported
EPOCH_PRINT_STRIDE = 1

#whether to print the accuracy per class ("Accuracy for dog: 20%") for every reported epoch
#the last epoch is always reported with class accuracy
ALWAYS_PRINT_CLASS_ACCURACY = True

#Hyper-parameters:

#number of passes over the training set - a simple network will be around ~5 seconds
#per epoch, a bigger one can be ~20
EPOCHS = 100
LEARNING_RATE = 0.001
BATCH_SIZE = 4
LOSS_FUNCTION = nn.CrossEntropyLoss()

#per channel this computes (x - 0.5) / 0.5, which maps values in [0,1] onto [-1,1]
TRANSFORM = Compose([
    ToTensor(),
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

#run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Device: ' + device)

#CIFAR10 train and test splits, downloaded into ./datasets when not already present;
#both splits go through the same TRANSFORM
train_data, test_data = (
    torchvision.datasets.CIFAR10(
        root='datasets',
        train=is_train_split,
        download=True,
        transform=TRANSFORM,
    )
    for is_train_split in (True, False)
)

#both loaders serve shuffled batches of BATCH_SIZE samples
train_dataloader, test_dataloader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=True)
    for split in (train_data, test_data)
)

#human-readable class names, indexed by label; used for printing results by class
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

class NeuralNetwork(nn.Module):
    """Small convolutional classifier: two conv/pool stages, then three linear layers.

    Takes 3-channel images and returns 10 raw scores (one per class) per image.
    The first linear layer expects 32 channels of 5x5 features, which is what the
    convolutional stage produces for 32x32 inputs
    (32 -> 28 after conv -> 14 after pool -> 10 after conv -> 5 after pool).
    """

    def __init__(self):
        super().__init__()

        feature_layers = [
            #5x5 convolutions turn the 3 input colour channels into 8 channels
            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5),
            nn.ReLU(),
            #halve each channel by keeping the highest value of every 2x2 block,
            #which keeps the amount of processing small
            nn.MaxPool2d(kernel_size=2, stride=2),
            #turn those 8 channels into 32
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=5),
            nn.ReLU(),
            #halve again
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        classifier_layers = [
            #fully connected layer over the flattened 32 channels of 5x5 features
            #(no padding is used, so the convolutions lose the edges)
            nn.Linear(in_features=32 * 5 * 5, out_features=200),
            nn.ReLU(),
            #a second hidden layer
            nn.Linear(in_features=200, out_features=100),
            nn.ReLU(),
            #finally reduce to 10 outputs - these are our output classes
            nn.Linear(in_features=100, out_features=10),
        ]

        self.convolutional = nn.Sequential(*feature_layers)
        self.linear = nn.Sequential(*classifier_layers)

    def forward(self, x):
        """Return the 10 class scores for each image in the batch `x`."""
        features = self.convolutional(x)
        #reshape eg a [4, 32, 5, 5] tensor into a [4, 32*5*5] tensor
        #so the linear layers can work with it
        flattened = features.view(-1, 32 * 5 * 5)
        return self.linear(flattened)

def train(dataloader, model, loss_fn, optimizer, do_print=False):
    """Run one pass over `dataloader`, updating `model` after every batch.

    dataloader: yields (inputs, labels) batches; its `.dataset` length is only
                used for the progress message
    model:      called on each batch of inputs to produce predictions
    loss_fn:    called as loss_fn(prediction, labels); the result is backpropagated
    optimizer:  has its gradients cleared and is stepped once per batch
    do_print:   progress messages are printed only when this and
                PRINT_EPOCH_PROGRESS_MESSAGES are both truthy
    """
    total_samples = len(dataloader.dataset)
    for step, (images, targets) in enumerate(dataloader):
        #move tensors to correct device
        images, targets = images.to(device), targets.to(device)

        #forward pass and loss
        batch_loss = loss_fn(model(images), targets)

        #clear old gradients, backpropagate, then update the weights
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        #print something every 100 batches to let us know it's not dead
        if not (PRINT_EPOCH_PROGRESS_MESSAGES and do_print):
            continue
        if step % 100 != 0:
            continue
        samples_done = step * len(images)
        print("Processed: " + str(samples_done) + "/" + str(total_samples))

def test(dataloader, model, loss_fn, do_classes=False):
    """Evaluate `model` on `dataloader` and print its accuracy and average loss.

    dataloader: yields (inputs, labels) batches; labels are expected to be a 1-D
                tensor of non-negative integer class indices
    model:      an nn.Module; it is switched to eval mode for the evaluation and
                put back into whatever mode it was in afterwards
    loss_fn:    called as loss_fn(prediction, labels) and must return a scalar
    do_classes: when truthy, also print one "Accuracy for <class>" line for every
                entry of the module-level `classes` tuple

    Prints the results; returns nothing.
    """
    num_classes = len(classes)
    #a loss with reduction='sum' is already a per-batch total; anything else (the
    #default is 'mean') is a per-sample average and has to be scaled by the batch
    #size before it can be summed up and divided by the number of samples
    loss_is_summed = getattr(loss_fn, "reduction", "mean") == "sum"

    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    #per-class counters, kept on the CPU
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_seen = torch.zeros(num_classes, dtype=torch.long)

    was_training = model.training
    model.eval() #layers such as dropout/batch norm behave differently when training
    try:
        with torch.no_grad(): #disable gradients when not training (makes it faster)
            for inputs, labels in dataloader:
                #move tensors to correct device
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                batch_len = labels.size(0)

                batch_loss = loss_fn(outputs, labels).item()
                total_loss += batch_loss if loss_is_summed else batch_loss * batch_len

                hits = outputs.argmax(1) == labels
                total_correct += int(hits.sum().item())
                total_seen += batch_len

                if do_classes:
                    #count samples and correct predictions per class for the whole
                    #batch at once instead of comparing sample by sample
                    class_seen += torch.bincount(labels, minlength=num_classes).cpu()
                    class_correct += torch.bincount(labels[hits], minlength=num_classes).cpu()
    finally:
        model.train(was_training)

    if total_seen == 0:
        #nothing was evaluated, so there is nothing to divide by
        print("No samples to evaluate")
        return

    if do_classes:
        for classname, correct, seen in zip(classes, class_correct.tolist(), class_seen.tolist()):
            #divide by the number of samples actually seen for this class rather
            #than assuming every class has the same share of the dataset
            percentage_correct = round(100 * correct / seen, 4) if seen else 0.0
            print("Accuracy for " + classname + ": " + str(percentage_correct) + "%")

    average_loss = round(total_loss / total_seen, 5)
    percentage_correct = round(100 * total_correct / total_seen, 4)
    print("Accuracy: " + str(percentage_correct) + "%")
    print("Average Loss: " + str(average_loss))

#build the network on the chosen device and train it with plain SGD
model = NeuralNetwork().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

#timestamp of the last "Processed in" report (or of the start of training)
previous_time = time.time()

for epoch in range(EPOCHS):
    is_last_epoch = epoch == EPOCHS - 1
    #only report if this matches an epoch print stride, or is the last epoch
    do_print = epoch % EPOCH_PRINT_STRIDE == 0 or is_last_epoch
    #likewise for printing class accuracies
    do_classes = ALWAYS_PRINT_CLASS_ACCURACY or is_last_epoch

    if do_print:
        print("-----------------------------")
        print("Epoch " + str(epoch + 1))

    train(train_dataloader, model, LOSS_FUNCTION, optimizer, do_print)

    #results are only calculated for epochs that are reported
    if not do_print:
        continue

    test(test_dataloader, model, LOSS_FUNCTION, do_classes)
    if PRINT_EPOCH_PROCESSING_TIME:
        #time since the previous report, so this spans every epoch skipped in between
        elapsed_seconds = round(time.time() - previous_time, 2)
        print("Processed in: " + str(elapsed_seconds) + " seconds")
        previous_time = time.time()

print("finished :D")