In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Set device variable

In [2]:
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#print(device)

# Prepare Dataset objects

In [3]:
data_path = './'

cifar10 = datasets.CIFAR10(
    data_path, train=True, download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4915, 0.4823, 0.4468),
                             (0.2470, 0.2435, 0.2616))
    ]))

cifar10_val = datasets.CIFAR10(
    data_path, train=False, download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4915, 0.4823, 0.4468),
                             (0.2470, 0.2435, 0.2616))
    ]))

n_out = 10

# Prepare Dataloader objects

In [4]:
#kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available()else {}

train_loader = torch.utils.data.DataLoader(cifar10, batch_size=64, shuffle=True)#,  **kwargs)

val_loader = torch.utils.data.DataLoader(cifar10_val, batch_size=64,shuffle=False)#, **kwargs)

# Build a CNN Model via Subclassing nn.Module and functional API

**The functional API**

**torch.nn.functional** provides many functions that work like the modules we find in nn. But instead of working on the input arguments and stored parameters like the module counterparts, they take inputs and parameters as arguments to the function call.

For example, the functional counterpart of nn.Linear is nn.functional.linear, which is a function that has signature linear(input, weight, bias=None). The weight and bias parameters are arguments to the function.

In our CNN model below, it makes sense to keep using nn modules for nn.Linear and nn.Conv2d so that Net will be able to manage their Parameters during training. However, we can safely switch to the functional counterparts of pooling and activation, since they have no parameters.

Thus, the functional way also sheds light on what the nn.Module API is all about: a Module is a container for state in the forms of Parameters and submodules combined with the instructions to do a forward. 

In [5]:
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(8 * 8 * 8, 32)
        self.fc2 = nn.Linear(32, n_out)
        
    def forward(self, x):
        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)
        out = out.view(-1, 8 * 8 * 8)
        out = torch.tanh(self.fc1(out))
        out = self.fc2(out)
        return out

model = Net()

# model.to(device)

# Start Training

In [6]:
learning_rate = 1e-2

optimizer = optim.SGD(model.parameters(), lr=learning_rate)

loss_fn = nn.CrossEntropyLoss()

n_epochs = 100

for epoch in range(n_epochs):
    model.train(True)
    for imgs, labels in train_loader:
        #imgs, labels = imgs.to(device), labels.to(device)
        
        outputs = model(imgs)   # important:  nn.Conv2d expects a B × C × H × W shaped tensor as input
        train_loss = loss_fn(outputs, labels)
  
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
    
    model.eval()
    
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            #imgs, labels = imgs.to(device), labels.to(device)
            
            outputs = model(imgs)
            val_loss = loss_fn(outputs, labels)
            
            _, predicted = torch.max(outputs, dim=1)
            total += labels.shape[0]
            correct += int((predicted == labels).sum())
    print("Epoch: %d, train_loss: %f, val_loss: %f, val_accuracy: %f" % (epoch, float(train_loss), float(val_loss), (correct / total)))

Epoch: 0, train_loss: 1.715013, val_loss: 1.681804, val_accuracy: 0.332600
Epoch: 1, train_loss: 1.604180, val_loss: 1.598014, val_accuracy: 0.403100
Epoch: 2, train_loss: 1.313389, val_loss: 1.407520, val_accuracy: 0.446200
Epoch: 3, train_loss: 1.309000, val_loss: 1.356401, val_accuracy: 0.470300
Epoch: 4, train_loss: 1.649784, val_loss: 1.309236, val_accuracy: 0.504700
Epoch: 5, train_loss: 1.032436, val_loss: 1.222861, val_accuracy: 0.517400


KeyboardInterrupt: 