Implementing a CNN with residual connections (https://arxiv.org/pdf/1512.03385) and training it on MNIST. The training settings are vanilla.

In [32]:
import torch
from torchvision import datasets, transforms
import wandb

Downloading the MNIST dataset. Doing it this way using torchvision makes it such that we don't have to manually create a Dataset class; this does that for us.

In [39]:
def numerical_to_distribution(label):
  return torch.Tensor([0 if i != label else 1 for i in range(10)])

"""print(numerical_to_distribution(0))
print(numerical_to_distribution(4))
print(numerical_to_distribution(8))
print(numerical_to_distribution(5))
print(numerical_to_distribution(9))"""

mnist_train = datasets.MNIST(root = "./train",
                             train =  True,
                             transform = transforms.functional.pil_to_tensor,
                             target_transform = numerical_to_distribution,
                             download = True)

mnist_test = datasets.MNIST(root = "./test",
                             train =  False,
                             transform = transforms.functional.pil_to_tensor,
                             target_transform = numerical_to_distribution,
                             download = True)

# mnist_train.__getitem__(0)

Splitting the overall MNIST training set to train and validation sets.
80/20.

In [59]:
generator = torch.Generator().manual_seed(11)
train, val = torch.utils.data.random_split(mnist_train, [0.8, 0.2], generator = generator)

The model. I am adding residual connections after every 2 layers in the generic VGG-11 ("A" in Table 1 of this paper: https://arxiv.org/pdf/1409.1556). VGG-11 may not be the deepest network to see the fruits of residual connections, but for learning how to implement residual connections, that shouldn't matter.

In [66]:
class ResConnCNN(torch.nn.Module):

  def __init__(self):
    super(ResConnCNN, self).__init__()

    self.conv1 = torch.nn.Conv2d(1, 64, 3, padding = "same")
    self.conv2 = torch.nn.Conv2d(64, 128, 3, padding = "same")
    self.conv3 = torch.nn.Conv2d(128, 256, 3, padding = "same")
    self.conv4 = torch.nn.Conv2d(256, 256, 3, padding = "same")
    self.conv5 = torch.nn.Conv2d(256, 512, 3, padding = "same")
    self.conv6 = torch.nn.Conv2d(512, 512, 3, padding = "same")
    self.conv7 = torch.nn.Conv2d(512, 512, 3, padding = "same")
    self.conv8 = torch.nn.Conv2d(512, 512, 3, padding = "same")

    self.fc1 = torch.nn.Linear(512, 4096)
    self.fc2 = torch.nn.Linear(4096, 4096)
    self.fc3 = torch.nn.Linear(4096, 10)

    self.maxPool = torch.nn.MaxPool2d(2, ceil_mode=True)
    self.activation = torch.nn.ReLU()
 
  def forward(self, x):
    x_1 = self.maxPool(self.activation(self.conv2(self.maxPool(self.activation(self.conv1(x))))))
    # adding extra empty channels to adjust dims. Otherwise, when adding with the residue below,
    # which will have more channels, error will occur. This solution is directly from the ResNet paper.
    # You can also do conv to x_1 as an alternative solution; read the paper for details.
    x_1_dim_adjusted = torch.cat((x_1, torch.zeros(x_1.shape).to(torch.device("cuda"))), 1)

    x_2 = self.maxPool(self.activation(self.conv4(self.activation(self.conv3(x_1)))) + x_1_dim_adjusted)
    x_2_dim_adjusted = torch.cat((x_2, torch.zeros(x_2.shape).to(torch.device("cuda"))), 1)

    x_3 = self.maxPool(self.activation(self.conv6(self.activation(self.conv5(x_2)))) + x_2_dim_adjusted)

    # the number of channels don't change in the output of x_4, thus no need to adjust dims for this.
    x_4 = self.maxPool(self.activation(self.conv8(self.activation(self.conv7(x_3)))) + x_3)

    # permuting x_4 to comply with the specifications of "Linear". Basically, this permute makes it
    # such that you have only one row filled with all our values, instead of separating individual
    # values in different channels. The latter configuration, to be fair, is "unnatural".
    output = self.fc3(self.activation(self.fc2(self.activation(self.fc1(torch.permute(x_4, (0, 2, 3, 1)))))))

    return output

# mod = ResConnCNN()
# print(mod(mnist_train.__getitem__(0)[0].float()).shape)
# mnist_train.__getitem__(0)[1].shape

Training staging area.

In [67]:
model = ResConnCNN().to(torch.device("cuda"))
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.001, momentum = 0.99)

batch_size = 256

train_loader = torch.utils.data.DataLoader(
    dataset = train,
    batch_size = batch_size,
    shuffle = True
)

val_loader = torch.utils.data.DataLoader(
    dataset = val,
)

test_loader = torch.utils.data.DataLoader(
    dataset = mnist_test,
)

Wandb init.

In [None]:
wandb.init(
    project = "CNN with Residual Connections",
    config = {
        "dataset": "MNIST",
        "epochs": 20,
    }
)

Training code.

In [69]:
for epoch in range(20):
  model.train()

  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(torch.device("cuda")), target.to(torch.device("cuda"))
    optimizer.zero_grad()

    # torch requires the input to cross entropy to be (minibatch, C), where C = number of classes.
    output = torch.squeeze(model(data.float())) 
      
    loss = loss_fn(output, target)

    loss.backward()
    optimizer.step()

    if batch_idx % 20 == 0:
        wandb.log({"Loss": loss})

  acc = test(val_loader)
  wandb.log({"Accuracy": acc})
    
wandb.finish()

0,1
Accuracy,▁▁▂▄▇▇▇█████████████
Loss,█▅▅▅▄▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Accuracy,98.70833
Loss,0.00783


<img src = "run.png" width="400" height="200">

Testing code.

In [47]:
def test(loader):
    model.eval()

    total_data = 0
    total_accurate_pred = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(torch.device("cuda")), target.to(torch.device("cuda"))
            
            softmax = torch.nn.Softmax(dim = 0)
            output = (softmax(torch.squeeze(model(data.float()))) > 0.5).float()
            target = torch.squeeze(target)
            
            total_accurate_pred += torch.equal(output, target)
            total_data += 1
            
    return (total_accurate_pred / total_data) * 100

Running the test set on the final model. 

In [70]:
print(test(test_loader))

98.86


The network, at least, works. Can probably make it perform better with hyperparameter tuning, but that is not the point of this. 