In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

In [19]:
# Load and normalize the data

SEED = 0  # fixes the train/validation split so every Restart & Run All sees the same subsets

transform = transforms.Compose(
    [transforms.ToTensor(),
     # ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 maps each channel to [-1, 1]
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batchSize = 5
validSize = 0.2 # use 20% of train set as validation

trainValidSet = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testSet = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Derive the train length from the validation length so the two always sum to
# len(trainValidSet); two independent int() truncations can leave samples over,
# which makes random_split raise. The seeded generator makes the split reproducible.
validCount = int(len(trainValidSet) * validSize)
trainCount = len(trainValidSet) - validCount
trainSet, validSet = torch.utils.data.random_split(
    trainValidSet, [trainCount, validCount],
    generator=torch.Generator().manual_seed(SEED))

trainLoader = torch.utils.data.DataLoader(trainSet, batch_size=batchSize, shuffle=True)
# Only the mean validation loss is computed, so batch order is irrelevant: no shuffling.
validLoader = torch.utils.data.DataLoader(validSet, batch_size=batchSize, shuffle=False)
testLoader = torch.utils.data.DataLoader(testSet, batch_size=batchSize, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Identity-matrix embedding: class index i -> one-hot row i. The training loop
# passes integer labels straight to nn.CrossEntropyLoss and does not call this.
to_onehot = nn.Embedding(len(classes), len(classes))
to_onehot.weight.data = torch.eye(len(classes))

Files already downloaded and verified
Files already downloaded and verified


In [15]:
# Sanity check: number of batches in each loader, plus the (C, H, W) shape of one test image.
(*map(len, (trainLoader, validLoader, testLoader)), next(iter(testLoader))[0][0].shape)

(8000, 2000, 2000, torch.Size([3, 32, 32]))

In [17]:
# Define the network class

class ConvNet(nn.Module):
    """Small LeNet-style CNN mapping 3x32x32 images to 10 class logits.

    Spatial size per stage: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5,
    which is why the first fully connected layer takes 16 * 5 * 5 features.
    """

    def __init__(self):
        super().__init__()
        # Keep this registration order: it fixes parameter order (state_dict keys,
        # optimizer param groups) and the order in which initial weights are drawn.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=100)
        self.fc2 = nn.Linear(in_features=100, out_features=10)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to (batch, 10) unnormalized logits."""
        # Two identical conv -> leaky ReLU -> 2x2 max-pool stages.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.leaky_relu(conv(x)))
        features = x.flatten(start_dim=1)  # keep the batch dimension, flatten the rest
        hidden = F.leaky_relu(self.fc1(features))
        return self.fc2(hidden)



In [20]:
# Plain SGD (default momentum of 0) over every ConvNet parameter; cross-entropy
# is applied to the raw logits the network returns, with integer class labels.
learningRate = 0.0001

network = ConvNet()
lossFunction = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=network.parameters(), lr=learningRate)


In [21]:
epochs = 10

trainLossList = []
validLossList = []


def _runEpoch(model, loader, lossFn, optimizer=None):
    """Make one pass over `loader` and return the mean per-batch loss.

    Args:
        model: the network to run.
        loader: iterable of (images, labels) batches.
        lossFn: criterion called as lossFn(predictions, labels).
        optimizer: when given, the model is put in training mode and updated once
            per batch; when None, the model is put in evaluation mode and autograd
            is disabled, so no computation graph is built.

    Returns:
        The sum of per-batch losses divided by the number of batches.

    Raises:
        ValueError: if `loader` yields no batches (the mean would be undefined).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same as eval()
    totalLoss = 0.0
    batchCount = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            predictions = model(images)
            loss = lossFn(predictions, labels)
            if training:
                # Zero *before* backward so a step never mixes in stale gradients,
                # e.g. ones left behind by an earlier run that was interrupted.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            totalLoss += loss.item()
            batchCount += 1
    if batchCount == 0:
        raise ValueError("loader yielded no batches")
    return totalLoss / batchCount


for epoch in range(epochs):
    ### TRAINING ###
    trainLoss = _runEpoch(network, trainLoader, lossFunction, optimizer)
    trainLossList.append(trainLoss)

    ### VALIDATION ###
    validLoss = _runEpoch(network, validLoader, lossFunction)
    validLossList.append(validLoss)

    # Print result of epoch
    print(f'Epoch [{epoch+1}/{epochs}] \t Training Loss: {trainLoss} \t Validation Loss: {validLoss}')

KeyboardInterrupt: 