# In [None]:
import os
import time

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, RandomCrop, RandomHorizontalFlip, Normalize, Compose, Lambda

#Debug parameters:
#These flags only control console output and checkpointing; they do not change the training maths.
PRINT_EPOCH_PROGRESS_MESSAGES = False
  #whether to print "Processed: 10000/50000" messages while processing epochs
  #(emitted by train() every 100 batches, and only on epochs that are being printed)
PRINT_EPOCH_PROCESSING_TIME = True
  #whether to print "Processed in: 6.17 seconds" messages after processing an epoch
  #the time shown is measured since the previous such message
EPOCH_PRINT_STRIDE = 1
  #print (and evaluate on the test set) only every Nth epoch, starting with the first
  #(eg 10 will only print every 10th epoch)
  #last epoch is always printed
ALWAYS_PRINT_CLASS_ACCURACY = True
  #whether to print the accuracy per class ("Accuracy for dog: 20%") for every printed epoch
  #last epoch is always printed with class accuracy
PRINT_TRAINING_ACCURACY = True
  #whether to print the accuracy on the training set
  #useful for telling if a bad model is overfitting or just stupid
SAVE_NETWORK = True
  #whether to save the produced network
  #saves are performed after every epoch - each run has its own save file,
  #which is overwritten each time (only the most recent weights are kept)
SAVE_PATH = 'models'
  #networks are all saved to "models/<time the model was started>"
LOAD_NETWORK = False
  #whether to load from a saved network
LOAD_PATH = 'models/'
  #file to load the saved network from when LOAD_NETWORK is True
  #when loading, later saves are written back to this same path
  #NOTE(review): the default is a directory-like placeholder - set it to an actual save file before enabling LOAD_NETWORK
PROGRAM_START = time.time()
  #time (seconds since the epoch) at which this run started; intended to name this run's save file

#Hyper-parameters:
EPOCHS = 100
  #number of passes over the training set - a simple linear network can be as fast as ~5 seconds per epoch, a bigger one can be over 100
  #cpu tends to double the processing time
LEARNING_RATE = 1e-3
  #step size handed to the optimizer
BATCH_SIZE = 4
  #number of images per batch, used for both the training and the test loader
LOSS_FUNCTION = nn.CrossEntropyLoss()
  #loss used for training and reported (averaged) when testing
TRAIN_TRANSFORM = Compose(
    [RandomCrop(32, padding=4), #pad by 4 pixels, then take a random 32x32 crop (small random shifts)
     RandomHorizontalFlip(), #alter the images in ways that don't change the subject, to give us "more" images to learn from
     ToTensor(),
     Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) #normalizing the data from [0,1] to [-1,1]: (x - 0.5) / 0.5 per channel
TEST_TRANSFORM = Compose(
    [ToTensor(),
     Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) #same normalization as training set, but no random augmentation

#Run on the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

#CIFAR-10 is downloaded into ./datasets on first use; each split gets its own transform pipeline
train_data = torchvision.datasets.CIFAR10(
    root='datasets', train=True, download=True, transform=TRAIN_TRANSFORM,
)
test_data = torchvision.datasets.CIFAR10(
    root='datasets', train=False, download=True, transform=TEST_TRANSFORM,
)

#both loaders use the same batch size and reshuffle every epoch
train_dataloader = DataLoader(dataset=train_data, shuffle=True, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(dataset=test_data, shuffle=True, batch_size=BATCH_SIZE)

#human-readable names indexed by label id, used for printing results by class
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

#Base network with very few modifications from source.
#(Unsurprisingly) hard to optimize, since the source has already done their best to optimize it.
#Dropout in conv block 2 is increased from 0.05 to 0.1.
#NOTE(review): as written, conv block 3 uses 5x5 kernels, which makes this class layer-for-layer
#identical to FiveKernelFinalBlock - confirm whether 3x3 kernels were intended for the base network.
class BiggerNeuralNetwork(nn.Module):
    """Convolutional classifier: three conv blocks followed by a fully connected head.

    Each conv block is conv -> batch-norm -> ReLU -> conv -> ReLU -> 2x2 max-pool,
    and every convolution is padded so that only the pooling layers shrink the image.
    The head maps 4096 flattened features to 10 output scores (one per class).
    """

    def __init__(self):
        super().__init__()

        feature_layers = [
            # Convolutional block 1: 3 -> 32 -> 64 channels
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # Convolutional block 2: 64 -> 128 -> 128 channels, then channel-wise dropout
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(p=0.1),
            # Convolutional block 3: 128 -> 256 -> 256 channels with 5x5 kernels
            nn.Conv2d(128, 256, 5, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        ]
        self.convolutional = nn.Sequential(*feature_layers)

        classifier_layers = [
            nn.Dropout(p=0.1),
            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10),
        ]
        self.linear = nn.Sequential(*classifier_layers)

    def forward(self, x):
        """Return the 10 class scores for each image in the batch ``x``."""
        features = self.convolutional(x)
        # 4096 = 256 channels * 4 * 4, the feature-map size for 32x32 inputs after three 2x2 poolings
        flattened = features.view(-1, 4096)
        return self.linear(flattened)

#Uses 5x5 kernels in the final convolutional block.
#Significantly increases processing time (over double!). The rate at which the network learns is noticeably improved.
#Not run for long enough to see peak, but tentatively appears to be higher? Exciting!
class FiveKernelFinalBlock(nn.Module):
    """Convolutional classifier whose last conv block uses 5x5 kernels.

    Blocks 1 and 2 use 3x3 convolutions (padding 1); block 3 uses 5x5
    convolutions (padding 2). Each block ends in a 2x2 max-pool, and the
    flattened result is fed to a three-layer fully connected head with 10 outputs.
    """

    def __init__(self):
        super().__init__()

        conv_stack = (
            # Convolutional block 1
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # Convolutional block 2 (ends with channel-wise dropout)
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(p=0.1),
            # Convolutional block 3 (the 5x5 kernels this class is named after)
            nn.Conv2d(128, 256, 5, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.convolutional = nn.Sequential(*conv_stack)

        dense_stack = (
            nn.Dropout(p=0.1),
            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10),
        )
        self.linear = nn.Sequential(*dense_stack)

    def forward(self, x):
        """Map a batch of images ``x`` to one row of 10 class scores per image."""
        feature_maps = self.convolutional(x)
        # flatten to rows of 4096 values (256 channels * 4 * 4 for 32x32 inputs) for the dense head
        return self.linear(feature_maps.view(-1, 4096))

def train(dataloader, model, loss_fn, optimizer, do_print=False):
    """Train ``model`` for one pass over ``dataloader``.

    For every batch the loss is computed, back-propagated, and one optimizer
    step is taken. Reads the module-level ``device`` plus the
    ``PRINT_TRAINING_ACCURACY`` and ``PRINT_EPOCH_PROGRESS_MESSAGES`` flags.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches; must expose
            ``.dataset`` so the total sample count can be read.
        model: the network to optimize (updated in place).
        loss_fn: callable taking ``(prediction, labels)`` and returning a scalar loss.
        optimizer: optimizer bound to ``model``'s parameters.
        do_print: when True (and progress messages are enabled) a progress
            line is printed every 100 batches.
    """
    size = len(dataloader.dataset)
    total_correct = 0
    # Explicitly select training mode so dropout / batch-norm behave correctly
    # even if an earlier evaluation pass left the model in eval mode.
    model.train()
    for batch_num, (inputs, labels) in enumerate(dataloader):
        # move tensors to correct device
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Compute prediction and loss
        prediction = model(inputs)
        loss = loss_fn(prediction, labels)
        # Backpropagate and optimize model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if PRINT_TRAINING_ACCURACY:
            total_correct += (prediction.argmax(1) == labels).sum().item()
        # print something every 100 batches to let us know it's not dead
        if PRINT_EPOCH_PROGRESS_MESSAGES and do_print and batch_num % 100 == 0:
            current = batch_num * len(inputs)
            print("Processed: " + str(current) + "/" + str(size))
    if PRINT_TRAINING_ACCURACY:
        # an empty dataset reports 0% instead of dividing by zero
        percentage_correct = round(100 * total_correct / size, 4) if size else 0.0
        print("Training Set Accuracy: " + str(percentage_correct) + "%")

def test(dataloader, model, loss_fn, do_classes=False):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and average loss.

    The model is switched to eval mode for the duration of the call (so dropout
    is disabled and batch-norm uses its running statistics) and its previous
    mode is restored afterwards. Reads the module-level ``device`` and ``classes``.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches.
        model: the network to evaluate; its weights are not modified.
        loss_fn: callable taking ``(prediction, labels)`` and returning a scalar
            loss. If it exposes ``reduction == 'sum'`` the batch value is taken
            as a total; otherwise it is treated as a per-sample mean.
        do_classes: when True, also print the accuracy of each class in ``classes``.
    """
    num_classes = len(classes)
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    # per-class tallies: how many samples of each class were seen / predicted correctly
    seen_per_class = torch.zeros(num_classes, dtype=torch.long)
    correct_per_class = torch.zeros(num_classes, dtype=torch.long)
    loss_is_summed = getattr(loss_fn, 'reduction', 'mean') == 'sum'

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # disable gradients when not training (makes it faster)
            for inputs, labels in dataloader:
                # move tensors to correct device
                inputs = inputs.to(device)
                labels = labels.to(device)
                prediction = model(inputs)
                batch_size = len(labels)
                batch_loss = loss_fn(prediction, labels).item()
                # weight mean losses by batch size so the final figure is a true per-sample average
                total_loss += batch_loss if loss_is_summed else batch_loss * batch_size
                is_correct = prediction.argmax(1) == labels
                total_correct += is_correct.sum().item()
                total_seen += batch_size
                if do_classes:
                    seen_per_class += torch.bincount(labels, minlength=num_classes).cpu()
                    correct_per_class += torch.bincount(labels[is_correct], minlength=num_classes).cpu()
    finally:
        # leave the model in whatever mode the caller had it in
        model.train(was_training)

    if do_classes:
        per_class = zip(classes, correct_per_class.tolist(), seen_per_class.tolist())
        for classname, correct, seen in per_class:
            # divide by the number of samples actually seen for this class, not an assumed equal share
            percentage_correct = round(100 * correct / seen, 4) if seen else 0.0
            print("Accuracy for " + classname + ": " + str(percentage_correct) + "%")

    average_loss = round(total_loss / total_seen, 5) if total_seen else 0.0
    percentage_correct = round(100 * total_correct / total_seen, 4) if total_seen else 0.0
    print("Accuracy: " + str(percentage_correct) + "%")
    print("Average Loss: " + str(average_loss))

# Build the network, optionally restore saved weights, then train and evaluate for EPOCHS epochs.
model = FiveKernelFinalBlock().to(device)
if LOAD_NETWORK:
    # load_state_dict updates the model in place; its return value only reports key mismatches,
    # so `model` must not be rebound to it. map_location lets a GPU-saved file load on a CPU-only machine.
    model.load_state_dict(torch.load(LOAD_PATH, map_location=device))

# alternative optimizer, kept for experimentation:
#optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=0)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

# A resumed run keeps writing to the file it was loaded from; a fresh run gets its own
# file named after the program start time.
if LOAD_NETWORK:
    save_path = LOAD_PATH
else:
    save_path = os.path.join(SAVE_PATH, str(PROGRAM_START))
if SAVE_NETWORK:
    save_dir = os.path.dirname(save_path)
    if save_dir:
        # torch.save does not create missing directories
        os.makedirs(save_dir, exist_ok=True)

previous_time = time.time()
for epoch in range(EPOCHS):
    is_last_epoch = epoch == EPOCHS - 1
    # only print if this matches an epoch print stride, or is the last epoch
    do_print = epoch % EPOCH_PRINT_STRIDE == 0 or is_last_epoch
    # likewise for printing class accuracies
    do_classes = ALWAYS_PRINT_CLASS_ACCURACY or is_last_epoch
    if do_print:
        print("-----------------------------")
        print("Epoch " + str(epoch + 1))
    train(train_dataloader, model, LOSS_FUNCTION, optimizer, do_print)
    # only calculate results if printing them:
    if do_print:
        test(test_dataloader, model, LOSS_FUNCTION, do_classes)
        if PRINT_EPOCH_PROCESSING_TIME:
            # elapsed time since the previous report (covers any skipped epochs too)
            print("Processed in: " + str(round(time.time() - previous_time, 2)) + " seconds")
            previous_time = time.time()
    if SAVE_NETWORK:
        # overwrite the same file every epoch so it always holds the latest weights
        torch.save(model.state_dict(), save_path)

print("finished :D")