# Why Residual Connections?

Since the introduction of the Transformer architecture ([Attention Is All You Need](https://arxiv.org/abs/1706.03762)), both the original implementation and modern implementations have used residual connections to improve performance and speed up training.

The catch is that many ML researchers today don't know where this idea came from, or mistakenly attribute it to the Transformer architecture. My goal in this notebook is to explain what residual connections are, when they should be used, and, broadly, why they work.

# What are Residual Connections?

![Transformer architecture](transformer_red.png "Transformer Architecture")
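As a minimal sketch of the idea: a residual connection adds a sub-layer's input back to its output, so the layer computes `x + F(x)` instead of just `F(x)`. The original Transformer wraps each attention and feed-forward sub-layer this way, as `LayerNorm(x + Sublayer(x))`, with dropout applied to the sub-layer output before the addition. The PyTorch code below assumes a position-wise feed-forward network as the sub-layer; the class name `ResidualSublayer` and the sizes are illustrative choices, not taken from any particular implementation.

```python
import torch
import torch.nn as nn


class ResidualSublayer(nn.Module):
    """Hypothetical wrapper computing LayerNorm(x + Dropout(F(x))), i.e. post-norm as in the original Transformer."""

    def __init__(self, sublayer: nn.Module, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The skip path carries x through unchanged, so the sub-layer only
        # has to learn a correction F(x) on top of its input.
        return self.norm(x + self.dropout(self.sublayer(x)))


d_model = 64
# Illustrative sub-layer: a position-wise feed-forward network.
# Its output must have the same shape as its input for the addition to work.
ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model))
block = ResidualSublayer(ffn, d_model)

x = torch.randn(2, 10, d_model)  # (batch, sequence length, d_model)
y = block(x)
print(y.shape)  # torch.Size([2, 10, 64])
```

Because the skip path is a plain addition, the sub-layer's output shape has to match its input shape. Many modern implementations move the normalization before the sub-layer (`x + Sublayer(LayerNorm(x))`), but the residual addition itself is the same.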

# Training an MNIST Model Quickly with fastai

## Define the data loaders

```python
from fastai.vision.all import * # Usually imports everything needed for computer vision tasks in fastai
import os
import torchvision

assert os.path.exists('/mnt/2tb-drive/data'), "Data directory does not exist. Please check the path."
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
trainmnist = torchvision.datasets.MNIST('/mnt/2tb-drive/data/MNIST', train=True, download=True, transform=transform)
trainloader = DataLoader(dataset=trainmnist, batch_size=32, shuffle=True, num_workers=16) # fastai's DataLoader
testmnist = torchvision.datasets.MNIST('/mnt/2tb-drive/data/MNIST', train=False, download=True, transform=transform)
testloader = DataLoader(dataset=testmnist, batch_size=32, shuffle=False, num_workers=16) # fastai's DataLoader

trainloader, testloader
```

```
(<fastai.data.load.DataLoader at 0x7f45304a5a90>,
 <fastai.data.load.DataLoader at 0x7f4530473250>)
```

## Residual Networks are Just Fine for MNIST

```python
dls = DataLoaders(trainloader, testloader)
dls.c = 10  # Set number of classes for MNIST
# https://docs.fast.ai/vision.learner.html#vision_learner
learn = vision_learner(dls, resnet18, pretrained=False, loss_func=F.cross_entropy, metrics=accuracy, n_in=1)
learn.fit(4, 1e-3)  # Train for 4 epochs with a learning rate of 1e-3 (0.001)
```

| epoch | train_loss | valid_loss | accuracy | time (mm:ss) |
|---|---|---|---|---|
| 0 | 0.117826 | 0.110482 | 0.9684 | 03:03 |
| 1 | 0.08414 | 0.040583 | 0.987 | 03:06 |
| 2 | 0.062278 | 0.047078 | 0.9865 | 03:07 |
| 3 | 0.048858 | 0.03145 | 0.9903 | 03:11 |
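ResNet-18 is built from residual "basic blocks". The sketch below is a simplified re-implementation for illustration, not fastai's or torchvision's source: two 3×3 convolutions with batch norm compute `F(x)`, and the block's input is added back before the final ReLU. When a block changes the channel count or spatial size, the skip path needs a 1×1 convolution (a projection) so the two tensors can be added. The class name `BasicBlock` and the tensor sizes are assumptions chosen for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Simplified ResNet basic block: out = ReLU(F(x) + shortcut(x))."""

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()
        # Residual branch F(x): conv -> BN -> ReLU -> conv -> BN
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # Shortcut: identity when shapes already match, otherwise a 1x1 projection
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))  # the residual addition


x = torch.randn(8, 64, 14, 14)         # (batch, channels, height, width)
same = BasicBlock(64, 64)              # identity shortcut
down = BasicBlock(64, 128, stride=2)   # projection shortcut
print(same(x).shape)  # torch.Size([8, 64, 14, 14])
print(down(x).shape)  # torch.Size([8, 128, 7, 7])
```

The identity shortcut adds no parameters; the projection is used only where the residual branch changes the tensor's shape.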
