# In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor, Normalize, Compose, Lambda
import time

# Debug parameters:
# Whether to print "Processed: 10000/50000" progress messages while training an epoch.
PRINT_EPOCH_PROGRESS_MESSAGES = False
# Whether to print "Processed in: 6.17 seconds" messages after training an epoch.
PRINT_EPOCH_PROCESSING_TIME = False
# Report results only on every EPOCH_PRINT_STRIDE-th epoch (e.g. 10 reports
# epochs 1, 11, 21, ...); the last epoch is always reported.
EPOCH_PRINT_STRIDE = 5

# Hyper-parameters:
# Number of full passes over the training set (author's estimate: ~5 s per
# epoch for a simple network, ~20 s for a bigger one).
EPOCHS = 100
LEARNING_RATE = 1e-3
BATCH_SIZE = 64
LOSS_FUNCTION = nn.CrossEntropyLoss()
# Convert images to tensors, then map each RGB channel from [0, 1] to [-1, 1]
# via (x - 0.5) / 0.5.
TRANSFORM = Compose([
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device: ' + device)

# CIFAR-10 is downloaded into ./datasets on first run.
train_data = torchvision.datasets.CIFAR10(
    root='datasets',
    train=True,
    download=True,
    transform=TRANSFORM,
)
test_data = torchvision.datasets.CIFAR10(
    root='datasets',
    train=False,
    download=True,
    transform=TRANSFORM,
)
# Shuffle training batches every epoch; evaluation metrics do not depend on
# sample order, so the test set is iterated in a fixed order.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

class NeuralNetwork(nn.Module):
    """Fully-connected classifier for 3x32x32 images producing 10 class logits.

    Each input is flattened (from dim 1) to 3*32*32 = 3072 features and passed
    through two 512-unit hidden layers with ReLU activations. The output layer
    is deliberately left linear: ``nn.CrossEntropyLoss`` applies log-softmax
    internally and expects unnormalised logits, so a trailing ReLU would clamp
    every negative logit to zero (and zero its gradient), crippling training.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.network_stack = nn.Sequential(
            nn.Linear(3 * 32 * 32, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),  # raw logits: no activation here
        )

    def forward(self, x):
        """Return class logits of shape (batch, 10) for a batch of images ``x``."""
        x = self.flatten(x)
        return self.network_stack(x)

def train(dataloader, model, loss_fn, optimizer, do_print=False):
    """Run one optimisation pass (epoch) of ``model`` over ``dataloader``.

    Args:
        dataloader: yields ``(inputs, labels)`` batches; ``len(dataloader.dataset)``
            is used only as the total in progress messages.
        model: module being trained; batches are moved to the module-level
            ``device``, so the model is expected to live there too.
        loss_fn: callable ``loss_fn(prediction, labels)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        do_print: when True and ``PRINT_EPOCH_PROGRESS_MESSAGES`` is set, print
            the number of samples processed so far every 100 batches.
    """
    model.train()  # training-mode behaviour for dropout/batch-norm, if present
    size = len(dataloader.dataset)
    processed = 0
    for batch_num, (inputs, labels) in enumerate(dataloader):
        # Move tensors to the same device as the model.
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Compute prediction and loss.
        prediction = model(inputs)
        loss = loss_fn(prediction, labels)
        # Backpropagate and optimise the model.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Running count includes this batch (and is exact for a short last batch).
        processed += len(inputs)
        # Print something every 100 batches to show training is still alive.
        if PRINT_EPOCH_PROGRESS_MESSAGES and do_print and batch_num % 100 == 0:
            print("Processed: " + str(processed) + "/" + str(size))

def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and average loss.

    Args:
        dataloader: yields ``(inputs, labels)`` batches; ``len(dataloader.dataset)``
            must be the total number of samples.
        model: classifier whose ``argmax`` over dim 1 is the predicted label;
            batches are moved to the module-level ``device`` before the forward pass.
        loss_fn: assumed to return the *mean* loss over a batch (the default
            ``reduction='mean'`` of ``nn.CrossEntropyLoss``).

    Returns:
        ``(percentage_correct, average_loss)``, the two printed values
        (callers may ignore them).
    """
    was_training = model.training
    model.eval()  # inference-mode behaviour for dropout/batch-norm, if present
    total_loss = 0.0
    total_correct = 0
    try:
        with torch.no_grad():  # no gradients needed for evaluation (faster, less memory)
            for inputs, labels in dataloader:
                # Move tensors to the same device as the model.
                inputs = inputs.to(device)
                labels = labels.to(device)
                prediction = model(inputs)
                # loss_fn averages over the batch; re-weight by batch size so the
                # final figure is a true per-sample mean, even for a short last batch.
                total_loss += loss_fn(prediction, labels).item() * len(labels)
                total_correct += (prediction.argmax(1) == labels).sum().item()
    finally:
        model.train(was_training)  # restore whichever mode the caller had set

    size = len(dataloader.dataset)
    average_loss = round(total_loss / size, 5)
    # Scale before rounding to avoid float artifacts such as 52.010000000000005.
    percentage_correct = round(100 * total_correct / size, 2)
    print("Accuracy: " + str(percentage_correct) + "%")
    print("Average Loss: " + str(average_loss))
    return percentage_correct, average_loss

model = NeuralNetwork().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for epoch in range(EPOCHS):
    # Only report on every EPOCH_PRINT_STRIDE-th epoch, and always on the last one.
    do_print = epoch % EPOCH_PRINT_STRIDE == 0 or epoch == EPOCHS - 1
    if do_print:
        print("-----------------------------")
        print("Epoch " + str(epoch + 1))
    # Time this epoch's training pass on its own, rather than everything since
    # the last report (which would span several epochs plus evaluation).
    epoch_start = time.time()
    train(train_dataloader, model, LOSS_FUNCTION, optimizer, do_print)
    epoch_seconds = time.time() - epoch_start
    # Evaluation is only run on epochs whose results are printed.
    if do_print:
        test(test_dataloader, model, LOSS_FUNCTION)
        if PRINT_EPOCH_PROCESSING_TIME:
            print("Processed in: " + str(round(epoch_seconds, 2)) + " seconds")
print("finished :D")