In [1]:
# Get all the variables, classes and functions we defined in the previous lessons
from vars.week_3 import *

# Import new modules
import torch
import torch.nn as nn
from torchsummary import summary

  from .autonotebook import tqdm as notebook_tqdm


# 4. Defining simple models 
## 4.1 nn.Sequential
The `Sequential` class is very simple: it accepts a sequence of neural network `Modules` as arguments and arranges them so that the output of each one is automatically passed as the input to the next in line. This saves us from writing some code, but it has drawbacks, as we shall see shortly. The following is the simplest possible neural network: it consists only of an input layer, one hidden layer and an output layer. It is good practice to print a summary of your network using `torchsummary.summary`. This lets you inspect your network's parameters and the input/output sizes of each layer. It also acts as a sanity check for your model, because it will complain if the input/output sizes of consecutive layers aren't compatible with each other.

In [2]:
def get_simple_linear_net():
    """Build a one-hidden-layer fully connected classifier for 28x28 MNIST images.

    Layer sizes: 28*28 = 784 inputs -> 128 hidden units (ReLU) -> 10 outputs.
    The final layer has no activation, so the network outputs one raw score
    (logit) per digit class. Returns a fresh, untrained ``nn.Sequential``.
    """
    n_pixels = 28 * 28   # MNIST images are 28*28 pixels
    n_hidden = 128       # the hidden size can be anything we want
    n_classes = 10       # one output per digit we want to classify

    layers = [
        nn.Flatten(),                    # flatten each image into a 1d vector of pixel values
        nn.Linear(n_pixels, n_hidden),   # connect every input pixel to every hidden node
        nn.ReLU(),                       # only let a signal through if it is > 0
        nn.Linear(n_hidden, n_classes),  # connect every hidden node to each of the 10 classes
    ]
    return nn.Sequential(*layers)

# Print a per-layer table (output shapes and parameter counts) for a single 1x28x28 image;
# it complains if the sizes of consecutive layers do not line up.
summary(model=get_simple_linear_net(), input_size=(1, 28, 28), device="cpu")


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1                  [-1, 784]               0
            Linear-2                  [-1, 128]         100,480
              ReLU-3                  [-1, 128]               0
            Linear-4                   [-1, 10]           1,290
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.39
Estimated Total Size (MB): 0.40
----------------------------------------------------------------


### 4.1.1  Simple training loop
Now that we've defined a network, we can start training it! Let's define the simplest possible training function, which only needs the model, the number of training epochs, the training dataloader and an optimiser.

In [12]:
def train_model(model, epochs, train_dl, optimiser, log_every=5):
    """Train ``model`` on ``train_dl`` for ``epochs`` epochs, printing progress as it goes.

    Args:
        model: the ``nn.Module`` to train; it is switched to training mode.
        epochs: number of full passes over ``train_dl``.
        train_dl: iterable of ``(image_batch, label_batch)`` pairs whose ``len()``
            is the number of batches per epoch (e.g. a ``DataLoader``).
        optimiser: a ``torch.optim`` optimiser built over ``model.parameters()``.
        log_every: print a progress line every ``log_every`` batches (must be > 0).
            The last batch of each epoch is always printed as well.

    Returns:
        None. The model's weights are updated in place.
    """
    loss_fn = nn.CrossEntropyLoss()  # has no per-batch state, so build it once
    total_steps = len(train_dl)
    msg = ""
    for epoch in range(epochs):
        correct = 0
        total = 0

        model.train()  # set model to training mode
        for batch_num, (image_batch, label_batch) in enumerate(train_dl):
            batch_sz = len(image_batch)
            output = model(image_batch)  # raw class scores (logits), one row per image

            # Cross-entropy loss applies a softmax to the raw scores internally and
            # penalises the model for assigning a low probability to the true class
            # label. That is why the network itself does not end in a softmax.
            losses = loss_fn(output, label_batch)

            # Gradients accumulate by default, so clear the previous step's before backpropagating
            optimiser.zero_grad()
            losses.backward()
            optimiser.step()  # update model weights based on loss gradients

            # The predicted class is the index of the highest score in each row of
            # `output` (assumes output is (batch, num_classes)). Comparing it with the
            # true labels gives a boolean per image; summing counts the correct ones.
            preds = torch.argmax(output, dim=1)
            correct += int(torch.eq(preds, label_batch).sum())
            total += batch_sz
            running_accuracy = 100 * correct / total  # accuracy over this epoch so far

            #### Fancy printing stuff, you can ignore this! ######
            is_last_batch = batch_num + 1 == total_steps
            if (batch_num + 1) % log_every == 0 or is_last_batch:
                print(" " * len(msg), end='\r')
                msg = f'Train epoch[{epoch+1}/{epochs}], MiniBatch[{batch_num + 1}/{total_steps}], Loss: {losses.item():.5f}, Acc: {running_accuracy:.5f}'
                # Overwrite the same line while training; finish with a newline at the very end
                training_done = is_last_batch and epoch == epochs - 1
                print(msg, end="\n" if training_done else "\r", flush=True)
            #### Fancy printing stuff, you can ignore this! ######

In [11]:
from torch.optim import SGD

# Defining network hyperparameters
epochs = 5
batch_sz = 32
learning_rate = 0.005
SEED = 0  # fixes the random weight initialisation and data shuffling so reruns are comparable

# Seed before load_data in case it splits or shuffles the data randomly
torch.manual_seed(SEED)
train_dl, val_dl, test_dl = load_data(DATA_PATH, batch_sz=batch_sz)   # Creating a data split

network = get_simple_linear_net()                    # Creating an instance of our network
optim = SGD(network.parameters(), lr=learning_rate)  # Stochastic gradient descent optimiser
train_model(network, epochs, train_dl, optim)        # Calling our training function

tensor([4, 8, 7, 8, 3, 8, 2, 5, 9, 0, 9, 6, 2, 4, 6, 1, 2, 4, 6, 7, 6, 4, 0, 4,
        2, 5, 0, 2, 5, 2, 6, 6])
tensor([6, 9, 1, 2, 4, 3, 1, 6, 3, 2, 4, 1, 7, 7, 1, 0, 8, 4, 4, 3, 6, 7, 9, 7,
        3, 8, 8, 7, 5, 4, 2, 3])
tensor([4, 9, 6, 5, 4, 3, 0, 6, 9, 7, 7, 8, 6, 3, 0, 3, 1, 9, 3, 1, 9, 3, 5, 5,
        7, 3, 5, 2, 0, 1, 3, 5])
tensor([0, 7, 5, 6, 5, 8, 7, 9, 2, 6, 9, 7, 2, 7, 5, 9, 7, 3, 9, 3, 0, 5, 3, 1,
        2, 2, 4, 6, 9, 8, 8, 7])
tensor([2, 7, 9, 8, 2, 2, 0, 8, 4, 9, 2, 3, 4, 6, 7, 3, 8, 0, 0, 6, 7, 7, 7, 2,
        0, 4, 2, 7, 8, 2, 3, 1])
tensor([1, 6, 8, 5, 0, 4, 1, 1, 6, 9, 6, 8, 4, 2, 0, 4, 5, 1, 1, 0, 9, 7, 2, 9,
        1, 3, 4, 5, 3, 2, 2, 1])
tensor([0, 7, 3, 0, 9, 6, 8, 3, 0, 9, 6, 6, 4, 0, 9, 1, 3, 9, 7, 4, 1, 6, 7, 3,
        5, 4, 6, 4, 5, 1, 1, 2])
tensor([5, 6, 9, 8, 2, 5, 1, 4, 8, 6, 7, 3, 2, 7, 8, 0, 5, 1, 2, 2, 0, 6, 5, 6,
        5, 6, 8, 8, 2, 0, 3, 1])
tensor([0, 9, 3, 5, 6, 2, 2, 1, 4, 7, 3, 4, 7, 1, 3, 8, 6, 3, 2, 8, 9, 5, 9, 7,
        6, 6, 9,

tensor([5, 3, 8, 7, 3, 7, 2, 0, 9, 1, 8, 8, 1, 7, 7, 3, 9, 6, 1, 0, 9, 5, 0, 5,
        3, 8, 5, 2, 8, 3, 5, 6])
tensor([1, 7, 2, 7, 7, 0, 2, 1, 2, 0, 8, 4, 1, 9, 8, 8, 4, 8, 6, 2, 8, 0, 4, 6,
        9, 8, 1, 0, 8, 2, 7, 0])
tensor([1, 1, 3, 3, 2, 2, 7, 6, 0, 3, 6, 7, 7, 6, 3, 6, 2, 2, 3, 6, 9, 6, 7, 8,
        5, 4, 0, 8, 9, 0, 7, 5])
tensor([0, 3, 0, 0, 0, 7, 6, 2, 9, 7, 6, 5, 6, 4, 6, 8, 9, 7, 9, 1, 3, 9, 9, 3,
        7, 4, 6, 0, 3, 0, 2, 4])
tensor([5, 7, 8, 0, 7, 3, 5, 8, 1, 3, 9, 3, 6, 7, 6, 9, 4, 6, 8, 1, 7, 1, 0, 5,
        2, 6, 6, 3, 9, 6, 8, 7])
tensor([6, 4, 0, 8, 0, 2, 8, 0, 6, 9, 8, 1, 6, 7, 3, 2, 7, 7, 8, 3, 0, 2, 0, 5,
        7, 9, 1, 6, 6, 3, 8, 8])
tensor([1, 6, 7, 9, 5, 0, 1, 5, 7, 7, 7, 5, 7, 8, 0, 4, 5, 7, 5, 8, 0, 2, 7, 1,
        8, 5, 2, 6, 3, 0, 9, 3])
tensor([0, 1, 2, 9, 0, 6, 9, 8, 7, 4, 7, 3, 8, 1, 2, 6, 4, 5, 9, 1, 9, 9, 9, 6,
        4, 2, 1, 9, 6, 8, 1, 1])
tensor([8, 8, 3, 7, 2, 2, 6, 4, 1, 5, 0, 4, 9, 4, 5, 6, 4, 7, 2, 7, 5, 3, 5, 6,
        8, 5, 0,

tensor([1, 1, 3, 1, 3, 6, 8, 3, 9, 6, 6, 6, 6, 7, 4, 2, 5, 2, 7, 8, 3, 3, 0, 7,
        6, 2, 9, 6, 4, 4, 9, 7])
tensor([7, 5, 9, 5, 2, 2, 1, 8, 4, 9, 2, 7, 1, 0, 9, 8, 3, 3, 5, 1, 1, 7, 2, 2,
        3, 6, 9, 1, 2, 6, 6, 2])
tensor([7, 3, 7, 5, 2, 0, 3, 7, 0, 4, 5, 8, 8, 6, 9, 2, 6, 1, 4, 1, 8, 3, 1, 2,
        2, 0, 7, 3, 0, 1, 6, 9])
tensor([4, 6, 6, 3, 0, 0, 7, 6, 8, 9, 4, 5, 9, 0, 8, 1, 6, 9, 0, 8, 2, 4, 1, 5,
        2, 3, 3, 6, 9, 9, 5, 9])
tensor([3, 8, 9, 7, 1, 8, 2, 4, 2, 3, 7, 9, 9, 2, 6, 7, 0, 5, 1, 1, 0, 5, 9, 5,
        2, 9, 6, 4, 0, 0, 2, 7])
tensor([1, 2, 8, 2, 8, 0, 5, 3, 4, 7, 6, 1, 6, 8, 5, 7, 3, 1, 1, 5, 8, 5, 9, 8,
        3, 0, 4, 9, 3, 9, 1, 3])
tensor([8, 6, 8, 8, 5, 5, 6, 9, 9, 7, 2, 4, 9, 4, 0, 5, 6, 3, 5, 3, 8, 6, 5, 9,
        6, 8, 6, 4, 9, 6, 7, 7])
tensor([4, 9, 1, 4, 6, 3, 4, 4, 0, 3, 6, 8, 1, 1, 8, 2, 0, 2, 8, 8, 4, 6, 9, 1,
        7, 2, 6, 9, 2, 7, 1, 6])
tensor([4, 9, 1, 0, 4, 4, 8, 3, 9, 8, 6, 2, 8, 5, 8, 6, 7, 7, 2, 2, 9, 9, 1, 0,
        2, 1, 9,

tensor([1, 7, 8, 8, 4, 7, 4, 9, 8, 2, 6, 4, 8, 2, 9, 8, 1, 3, 3, 1, 5, 4, 7, 8,
        0, 9, 8, 5, 2, 8, 0, 8])
tensor([3, 8, 4, 3, 0, 6, 7, 3, 9, 2, 1, 8, 9, 6, 6, 3, 4, 9, 1, 9, 5, 8, 9, 1,
        5, 1, 0, 4, 9, 0, 3, 0])
tensor([4, 0, 1, 2, 3, 1, 9, 5, 4, 7, 6, 7, 4, 0, 4, 5, 1, 0, 1, 4, 5, 0, 4, 0,
        4, 7, 6, 0, 9, 1, 9, 7])
tensor([9, 5, 1, 8, 6, 7, 5, 6, 7, 1, 8, 6, 5, 3, 0, 9, 4, 2, 7, 7, 3, 4, 3, 2,
        4, 1, 4, 1, 0, 7, 5, 1])
tensor([3, 8, 0, 0, 8, 4, 6, 0, 7, 0, 9, 8, 8, 4, 7, 8, 6, 4, 3, 3, 2, 7, 3, 4,
        9, 1, 3, 1, 6, 8, 9, 6])
tensor([8, 9, 7, 3, 0, 7, 4, 3, 9, 3, 4, 6, 6, 5, 8, 7, 3, 5, 6, 6, 8, 7, 3, 2,
        0, 6, 6, 2, 4, 7, 6, 2])
tensor([3, 7, 0, 5, 7, 3, 6, 1, 1, 8, 7, 6, 3, 2, 4, 0, 4, 1, 4, 0, 3, 0, 1, 3,
        7, 6, 8, 7, 4, 5, 4, 9])
tensor([3, 2, 1, 7, 2, 3, 7, 4, 1, 5, 1, 6, 8, 5, 3, 5, 5, 3, 8, 1, 5, 1, 4, 7,
        9, 9, 5, 5, 2, 1, 6, 5])
tensor([4, 4, 4, 1, 6, 3, 8, 9, 0, 9, 6, 1, 0, 7, 8, 3, 7, 9, 5, 0, 3, 4, 4, 6,
        0, 8, 7,

tensor([2, 3, 0, 9, 9, 2, 4, 4, 7, 6, 9, 7, 7, 4, 8, 5, 5, 7, 0, 7, 2, 8, 4, 2,
        9, 6, 2, 3, 2, 1, 0, 3])
tensor([2, 2, 1, 2, 8, 3, 7, 5, 1, 8, 8, 6, 7, 5, 9, 8, 2, 5, 7, 0, 3, 4, 7, 3,
        5, 9, 1, 8, 6, 6, 5, 2])
tensor([7, 4, 0, 2, 0, 5, 4, 6, 8, 4, 1, 2, 1, 8, 1, 9, 6, 8, 0, 3, 9, 5, 1, 9,
        9, 2, 0, 3, 5, 9, 1, 8])
tensor([1, 3, 7, 6, 8, 6, 9, 3, 0, 7, 1, 4, 1, 3, 3, 1, 6, 9, 3, 7, 7, 4, 8, 8,
        4, 5, 0, 7, 8, 4, 2, 3])
tensor([9, 0, 0, 4, 9, 3, 2, 1, 5, 8, 4, 7, 4, 7, 1, 7, 8, 0, 7, 4, 1, 9, 3, 2,
        4, 3, 7, 4, 2, 4, 7, 9])
tensor([1, 0, 6, 0, 4, 0, 4, 2, 1, 8, 7, 5, 8, 6, 2, 4, 8, 1, 8, 3, 6, 2, 0, 1,
        4, 7, 5, 6, 9, 9, 5, 4])
tensor([5, 3, 1, 2, 5, 9, 3, 4, 9, 6, 9, 3, 4, 3, 0, 8, 6, 0, 2, 8, 2, 3, 1, 8,
        1, 3, 6, 0, 7, 0, 1, 7])
tensor([7, 0, 6, 1, 2, 0, 3, 8, 8, 9, 4, 5, 1, 4, 7, 9, 8, 0, 0, 1, 6, 5, 7, 2,
        5, 2, 6, 7, 4, 2, 8, 3])
tensor([4, 6, 2, 3, 7, 1, 3, 1, 9, 8, 1, 5, 6, 7, 8, 6, 0, 3, 2, 7, 1, 7, 1, 2,
        0, 6, 2,

tensor([9, 0, 7, 2, 9, 6, 5, 7, 7, 7, 4, 5, 7, 6, 8, 6, 8, 3, 9, 1, 0, 0, 2, 8,
        9, 0, 2, 0, 5, 5, 9, 8])
tensor([7, 3, 5, 2, 3, 7, 4, 8, 2, 3, 9, 8, 4, 2, 1, 9, 5, 2, 9, 9, 3, 5, 4, 7,
        2, 8, 9, 2, 4, 8, 4, 0])
tensor([7, 1, 3, 9, 2, 2, 2, 4, 7, 6, 0, 0, 4, 5, 4, 2, 1, 8, 2, 9, 9, 3, 4, 1,
        8, 9, 6, 0, 9, 9, 5, 1])
tensor([3, 3, 9, 9, 7, 8, 5, 1, 7, 2, 7, 4, 6, 6, 3, 1, 7, 8, 0, 1, 6, 9, 4, 3,
        0, 1, 1, 3, 4, 6, 3, 8])
tensor([5, 9, 5, 8, 5, 2, 8, 5, 1, 2, 1, 2, 2, 9, 6, 9, 9, 8, 7, 6, 1, 6, 1, 9,
        5, 1, 3, 4, 9, 6, 1, 1])
tensor([0, 8, 1, 3, 5, 7, 6, 2, 0, 8, 8, 9, 7, 1, 3, 7, 2, 1, 5, 2, 3, 5, 5, 9,
        3, 7, 4, 3, 9, 8, 4, 1])
tensor([2, 2, 0, 7, 9, 7, 5, 6, 2, 4, 9, 8, 8, 6, 7, 0, 7, 7, 3, 6, 3, 7, 6, 6,
        8, 5, 1, 4, 6, 0, 9, 9])
tensor([8, 3, 4, 4, 1, 5, 4, 5, 2, 2, 6, 9, 2, 8, 5, 1, 3, 1, 6, 9, 4, 1, 9, 0,
        3, 0, 1, 4, 4, 1, 7, 6])
tensor([5, 9, 2, 7, 6, 1, 3, 4, 1, 4, 3, 4, 4, 6, 6, 2, 8, 3, 8, 1, 2, 5, 9, 1,
        7, 7, 1,

tensor([4, 7, 5, 6, 3, 9, 8, 4, 2, 7, 2, 2, 4, 8, 7, 4, 1, 5, 3, 4, 9, 0, 9, 0,
        1, 1, 6, 4, 7, 1, 7, 7])
tensor([3, 0, 1, 2, 8, 4, 9, 5, 6, 9, 1, 6, 4, 5, 8, 1, 2, 8, 4, 2, 0, 5, 6, 8,
        5, 8, 4, 0, 4, 9, 3, 1])
tensor([2, 0, 2, 9, 5, 7, 0, 7, 7, 9, 4, 0, 6, 7, 7, 3, 0, 6, 4, 0, 3, 1, 0, 0,
        6, 5, 2, 0, 2, 3, 7, 2])
tensor([9, 1, 9, 8, 8, 0, 4, 8, 8, 0, 4, 5, 3, 1, 2, 4, 0, 1, 3, 0, 6, 2, 8, 5,
        3, 1, 6, 4, 8, 3, 7, 3])
tensor([0, 6, 6, 9, 8, 7, 6, 8, 0, 7, 1, 1, 8, 4, 9, 6, 5, 1, 4, 9, 2, 3, 6, 4,
        2, 1, 5, 6, 9, 6, 8, 9])
tensor([3, 3, 5, 5, 0, 2, 8, 7, 0, 3, 7, 8, 7, 4, 9, 5, 6, 9, 4, 0, 2, 7, 7, 8,
        5, 3, 9, 3, 6, 1, 1, 1])
tensor([3, 2, 4, 7, 1, 2, 9, 5, 9, 8, 8, 4, 0, 6, 9, 8, 1, 1, 7, 2, 2, 8, 6, 8,
        9, 7, 1, 3, 7, 7, 4, 8])
tensor([2, 4, 0, 9, 5, 0, 2, 0, 4, 9, 6, 7, 7, 5, 8, 3, 1, 9, 1, 3, 6, 7, 6, 9,
        1, 9, 2, 7, 4, 0, 3, 4])
tensor([9, 0, 2, 0, 8, 7, 1, 9, 7, 7, 8, 6, 9, 0, 6, 4, 9, 6, 0, 8, 9, 2, 7, 4,
        6, 0, 1,

tensor([3, 3, 1, 9, 7, 8, 0, 8, 5, 7, 2, 3, 1, 5, 1, 4, 7, 0, 4, 6, 5, 1, 5, 2,
        8, 1, 6, 0, 9, 5, 5, 2])
tensor([4, 1, 2, 8, 8, 1, 9, 0, 5, 1, 8, 4, 3, 2, 5, 9, 0, 3, 6, 5, 5, 1, 3, 6,
        5, 2, 1, 1, 8, 8, 2, 2])
tensor([2, 5, 6, 3, 5, 4, 6, 6, 8, 9, 4, 3, 6, 2, 9, 5, 0, 7, 0, 4, 3, 7, 6, 6,
        3, 6, 5, 1, 0, 8, 1, 4])
tensor([7, 0, 6, 5, 2, 2, 6, 3, 6, 0, 2, 2, 9, 6, 9, 0, 8, 4, 0, 1, 0, 8, 1, 5,
        5, 0, 2, 0, 3, 8, 3, 0])
tensor([7, 4, 3, 1, 5, 7, 8, 2, 7, 2, 6, 1, 8, 2, 5, 5, 2, 1, 5, 6, 0, 2, 1, 4,
        8, 0, 5, 6, 4, 6, 7, 8])
tensor([9, 6, 7, 0, 1, 6, 7, 6, 4, 7, 1, 0, 9, 9, 8, 8, 5, 8, 0, 5, 2, 3, 6, 3,
        1, 6, 4, 9, 1, 5, 1, 0])
tensor([8, 6, 6, 2, 4, 6, 7, 6, 4, 3, 9, 2, 6, 8, 6, 0, 9, 7, 6, 2, 1, 7, 4, 4,
        4, 3, 3, 4, 6, 2, 6, 1])
tensor([8, 6, 0, 3, 2, 8, 5, 5, 7, 7, 9, 0, 8, 2, 0, 2, 7, 9, 7, 4, 9, 6, 6, 6,
        7, 5, 5, 6, 0, 0, 9, 9])
tensor([4, 9, 3, 3, 1, 8, 1, 8, 6, 4, 8, 5, 5, 4, 9, 8, 5, 3, 8, 5, 9, 8, 6, 0,
        6, 5, 6,

tensor([1, 8, 4, 4, 1, 4, 4, 6, 9, 6, 3, 6, 6, 7, 6, 1, 1, 2, 8, 2, 1, 1, 6, 7,
        9, 0, 4, 8, 7, 1, 6, 4])
tensor([8, 8, 2, 4, 3, 2, 4, 8, 8, 1, 8, 2, 4, 7, 2, 1, 0, 6, 7, 4, 3, 4, 0, 8,
        4, 1, 9, 6, 5, 5, 1, 1])
tensor([9, 1, 8, 2, 8, 2, 7, 6, 5, 3, 6, 8, 5, 6, 6, 6, 0, 1, 5, 1, 8, 7, 8, 5,
        0, 0, 2, 6, 0, 4, 9, 3])
tensor([7, 8, 4, 3, 8, 8, 7, 8, 2, 4, 6, 0, 5, 3, 8, 8, 4, 9, 6, 9, 1, 8, 8, 8,
        5, 5, 9, 1, 1, 5, 5, 8])
tensor([4, 2, 9, 4, 4, 1, 7, 7, 5, 8, 3, 6, 5, 3, 3, 8, 7, 3, 9, 7, 5, 1, 2, 2,
        4, 0, 3, 3, 9, 5, 7, 1])
tensor([2, 5, 5, 6, 4, 6, 7, 1, 3, 3, 4, 5, 2, 7, 0, 1, 8, 0, 0, 3, 1, 5, 5, 1,
        2, 8, 9, 8, 5, 9, 3, 2])
tensor([2, 7, 0, 2, 7, 1, 5, 3, 7, 6, 0, 5, 7, 5, 4, 6, 0, 7, 8, 5, 4, 5, 9, 9,
        3, 2, 6, 4, 9, 2, 9, 4])
tensor([7, 2, 1, 5, 4, 0, 0, 8, 9, 2, 2, 8, 0, 3, 3, 3, 1, 3, 0, 1, 1, 3, 8, 7,
        9, 3, 2, 0, 1, 1, 8, 5])
tensor([9, 4, 9, 9, 3, 5, 7, 3, 6, 6, 7, 7, 4, 5, 9, 1, 7, 0, 9, 0, 8, 7, 3, 6,
        5, 4, 8,

tensor([3, 3, 1, 6, 9, 4, 5, 1, 2, 7, 2, 7, 8, 1, 3, 7, 6, 0, 3, 7, 4, 2, 9, 9,
        7, 1, 2, 0, 5, 0, 0, 6])
tensor([6, 3, 9, 0, 3, 1, 4, 1, 1, 8, 1, 0, 2, 8, 1, 5, 8, 2, 1, 1, 4, 8, 3, 0,
        0, 4, 7, 0, 4, 2, 6, 1])
tensor([5, 6, 2, 3, 2, 6, 7, 8, 8, 7, 9, 7, 6, 0, 7, 2, 1, 9, 6, 9, 1, 6, 0, 3,
        5, 3, 2, 4, 3, 8, 4, 7])
tensor([9, 9, 7, 0, 9, 5, 7, 3, 7, 5, 1, 9, 8, 9, 1, 4, 2, 3, 1, 4, 2, 0, 2, 8,
        9, 2, 0, 8, 7, 6, 6, 3])
tensor([3, 3, 3, 8, 4, 9, 8, 1, 9, 6, 4, 9, 2, 3, 7, 0, 1, 4, 2, 8, 3, 1, 2, 8,
        1, 7, 4, 0, 2, 3, 8, 3])
tensor([3, 5, 6, 4, 9, 5, 4, 7, 3, 2, 1, 3, 4, 2, 4, 2, 8, 7, 8, 8, 7, 1, 7, 9,
        7, 9, 5, 2, 3, 3, 7, 4])
tensor([9, 1, 7, 0, 6, 7, 0, 6, 4, 9, 3, 8, 6, 2, 3, 1, 1, 4, 0, 8, 7, 9, 4, 6,
        7, 4, 5, 8, 3, 6, 6, 4])
tensor([7, 8, 6, 2, 7, 2, 1, 3, 2, 6, 8, 9, 9, 4, 7, 9, 6, 4, 0, 1, 7, 3, 4, 5,
        4, 8, 4, 9, 1, 5, 1, 7])
tensor([6, 9, 2, 6, 6, 1, 4, 4, 4, 4, 2, 5, 5, 0, 2, 9, 6, 5, 6, 6, 9, 7, 5, 8,
        3, 3, 4,

KeyboardInterrupt: 

### 4.1.2 Debrief: Simple model with simple training loop
At the end of training, our model performs pretty well: accuracy should be around 80-90% most of the time. This is far better than random chance (10% with ten classes), so our model seems to have learned something about the dataset and can make good predictions. But it could be better! Before we look into improving it, there is something else that needs fixing...

### 4.1.3 Training device
Something you may have noticed so far is that the training loop runs quite slowly. Five epochs is not a long time at all in the machine learning world, yet training still takes a while to complete. This is because we've been asking the CPU to do all the tensor calculations needed to update the weights. That is inefficient, because GPUs are much better at processing large amounts of data in parallel. You should always use a GPU to train machine learning models if one is available. PyTorch makes it easy to detect whether a GPU is available and to move code you've written for the CPU onto the GPU:

In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # CUDA is a low-level toolkit that provides libraries for interacting with NVIDIA graphics cards
print(f"Using device: {DEVICE}") # This should print out `Using device: cuda` if PyTorch detects a GPU on your system


# We also need to modify our training loop to accept the device as an argument (defaulting to
# DEVICE) and to transfer each batch of input tensors to that device
def train_model_gpu(model, epochs, train_dl, optimiser, device=None, log_every=5):
    """Train ``model`` on ``train_dl`` for ``epochs`` epochs on ``device``, printing progress.

    Args:
        model: the ``nn.Module`` to train. It must already live on ``device``
            (e.g. via ``model.to(device)``); this function only moves the data.
        epochs: number of full passes over ``train_dl``.
        train_dl: iterable of ``(image_batch, label_batch)`` pairs whose ``len()``
            is the number of batches per epoch (e.g. a ``DataLoader``).
        optimiser: a ``torch.optim`` optimiser built over ``model.parameters()``.
        device: the ``torch.device`` each batch is moved to; ``None`` means ``DEVICE``.
        log_every: print a progress line every ``log_every`` batches (must be > 0).
            The last batch of each epoch is always printed as well.

    Returns:
        None. The model's weights are updated in place.
    """
    if device is None:
        device = DEVICE
    loss_fn = nn.CrossEntropyLoss()  # has no per-batch state, so build it once
    total_steps = len(train_dl)
    msg = ""
    for epoch in range(epochs):
        correct = 0
        total = 0

        model.train()
        for batch_num, (image_batch, label_batch) in enumerate(train_dl):
            batch_sz = len(image_batch)

            # Transferring image and label tensors to the training device #
            image_batch = image_batch.to(device)
            label_batch = label_batch.to(device)
            ################################################################

            output = model(image_batch)
            losses = loss_fn(output, label_batch)

            optimiser.zero_grad()
            losses.backward()
            optimiser.step()

            preds = torch.argmax(output, dim=1)
            correct += int(torch.eq(preds, label_batch).sum())
            total += batch_sz
            running_accuracy = 100 * correct / total  # accuracy over this epoch so far

            #### Fancy printing stuff, you can ignore this! ######
            is_last_batch = batch_num + 1 == total_steps
            if (batch_num + 1) % log_every == 0 or is_last_batch:
                print(" " * len(msg), end='\r')
                msg = f'Train epoch[{epoch+1}/{epochs}], MiniBatch[{batch_num + 1}/{total_steps}], Loss: {losses.item():.5f}, Acc: {running_accuracy:.5f}'
                # Overwrite the same line while training; finish with a newline at the very end
                training_done = is_last_batch and epoch == epochs - 1
                print(msg, end="\n" if training_done else "\r", flush=True)
            #### Fancy printing stuff, you can ignore this! ######

In [None]:
# Finally, we need to transfer our model to the device as well, and can begin training.
# Move the model BEFORE building the optimiser (as the torch.optim docs recommend), so the
# optimiser is guaranteed to hold the same parameter objects that will actually be trained.
network = get_simple_linear_net().to(DEVICE)
optim = SGD(network.parameters(), lr=learning_rate)
train_model_gpu(network, epochs, train_dl, optim)

# If a GPU was detected, training should now run faster; on a CPU-only machine nothing changes.