Code based on https://github.com/pytorch/examples/blob/master/mnist/main.py

In this exercise, we are using high-level abstractions from torch.nn like nn.Linear.
Note: during the next lab session we will go one level deeper and implement more things
by hand.

Tasks:

 1. Read the code.

 2. Check that the given implementation reaches 95% test accuracy for the architecture input-128-128-10 after a few epochs.

 3. Add the option to use SGD with momentum instead of ADAM.

 4. Experiment with different learning rates. Use the provided TrainingVisualizer
 to plot the learning curves and gradient-to-weight ratios. Compare visualizations
 for different learning rates for both ADAM and SGD with momentum.

 5. Parameterize the constructor by a list of sizes of hidden layers of the MLP.
 Note that this requires creating a list of layers as an attribute of the Net class,
 and one can't use a standard Python list containing nn.Modules (why?).
 Check torch.nn.ModuleList.

If you run this notebook locally, you may need to install some packages.
You can do so by adding the following code cell to the notebook and running it:
```
!pip install torch torchvision plotly ipywidgets
```
This notebook can also use a Colab GPU. However, remember to terminate your GPU session after class; otherwise you may use up all of your free GPU time for the week.

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import plotly.graph_objects as go
import sys

# When the notebook runs on Google Colab, turn on Colab's custom widget manager.
# NOTE(review): presumably required for the plotly FigureWidget plots created
# later in this notebook to render on Colab - confirm.
if "google.colab" in sys.modules:
    from google.colab import output

    output.enable_custom_widget_manager()

In [7]:
# @title Visualize gradients


class TrainingVisualizer:
    """Live plotly figures for comparing several training runs.

    Three figures are created and displayed on construction:

    * training loss per logged step,
    * ``lr * std(grad) / std(weight)`` of the first linear layer per logged
      step (log-scaled y axis),
    * test accuracy per epoch.

    Every call to :meth:`track_model` starts a new trace in each figure, so
    curves of consecutive runs (e.g. different optimizers or learning rates)
    end up in the same plots.

    Args:
        log_interval: a point is logged for batches whose index is a multiple
            of this value.
        max_traces: number of traces pre-allocated in every figure, i.e. the
            maximal number of runs that can be tracked by one visualizer.
    """

    def __init__(self, log_interval: int = 10, max_traces: int = 25):
        self.log_interval = log_interval
        # Must be set before the figures are built: init_line_plot reads it.
        self.max_traces = max_traces
        self.train_loss_fig = self.init_line_plot(
            title="Training loss", xaxis_title="Step"
        )
        self.grad_to_weight_fig = self.init_line_plot(
            title="Gradient standard deviation to weight standard deviation ratio at 1st layer",
            xaxis_title="Step",
            yaxis_title="Gradient to weight ratio (log scale)",
            yaxis_type="log",
        )
        self.test_acc_fig = self.init_line_plot(
            title="Test accuracy", x=[], xaxis_title="Epoch", mode="lines+markers"
        )

        # Parameters related to current tracked model and its training
        self.first_linear_layer = None
        self.lr = None
        self.trace_idx = -1

    def init_line_plot(
        self,
        title: str,
        x=None,
        xaxis_title: str = None,
        yaxis_title: str = None,
        yaxis_type: str = "linear",
        mode: str = "lines",
    ):
        """Create, display and return a figure widget with empty traces.

        Args:
            title: figure title.
            x: initial x values of every trace; ``None`` leaves x unset.
            xaxis_title: label of the x axis.
            yaxis_title: label of the y axis.
            yaxis_type: plotly axis type of the y axis, e.g. "linear" or "log".
            mode: plotly scatter mode of every trace.

        Returns:
            The displayed ``go.FigureWidget``.
        """
        fig = go.Figure()
        fig.update_layout(
            title=title,
            title_x=0.5,
            xaxis_title=xaxis_title,
            yaxis_title=yaxis_title,
            height=400,
            width=1500,
            margin=dict(b=10, t=60),
        )
        fig.update_yaxes(type=yaxis_type)
        # We cannot add new traces dynamically because Colab has a problem with widgets
        # from plotly (traces added dynamically are rendered twice).
        # As an ugly workaround we create a lot of empty traces and update them later
        # with actual data. Empty traces are not plotted.
        for _ in range(self.max_traces):
            fig.add_trace(go.Scatter(x=x, y=[], showlegend=True, mode=mode))

        fig_widget = go.FigureWidget(fig)
        # `display` is the notebook built-in; this class is meant to be used
        # from a notebook cell.
        display(fig_widget)
        return fig_widget

    def track_model(
        self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, lr: float
    ):
        """
        Start tracking training metrics for a new model.

        The first ``nn.Linear`` found while walking ``model.modules()`` is
        used for the gradient-to-weight plot; it is ``None`` if the model has
        no linear layer. The new traces are named "<optimizer class>, <lr>".

        Raises:
            IndexError: if all pre-allocated traces are already in use.
        """
        next_idx = self.trace_idx + 1
        capacity = len(self.train_loss_fig.data)
        if next_idx >= capacity:
            # Fail before touching any state so the visualizer stays usable.
            raise IndexError(
                f"All {capacity} pre-allocated traces are in use; create a new "
                "TrainingVisualizer (optionally with a larger max_traces)."
            )

        # model.modules() also descends into containers (ModuleList,
        # Sequential, ...), so nested linear layers are found as well.
        # Resetting to None avoids silently reusing the previous model's layer.
        self.first_linear_layer = next(
            (module for module in model.modules() if isinstance(module, nn.Linear)),
            None,
        )
        self.lr = lr
        self.trace_idx = next_idx

        trace_name = f"{type(optimizer).__name__}, {lr}"
        for fig in (self.train_loss_fig, self.grad_to_weight_fig, self.test_acc_fig):
            fig.data[self.trace_idx].name = trace_name

    def plot_gradients_and_loss(self, batch_idx: int, loss: float):
        """Log the loss and the first-layer gradient-to-weight ratio.

        Nothing is logged unless ``batch_idx`` is a multiple of
        ``log_interval``. Must be called after ``loss.backward()`` so that the
        gradient of the tracked layer is populated; when no layer is tracked
        or its gradient is missing, NaN is recorded instead so that both
        step-indexed traces keep the same length.
        """
        if batch_idx % self.log_interval != 0:
            return

        self.train_loss_fig.data[self.trace_idx].y += (loss,)

        layer = self.first_linear_layer
        if layer is None or layer.weight.grad is None:
            grad_to_weight_ratio = float("nan")
        else:
            # Scaled by the learning rate: for plain SGD this is the size of
            # the weight update relative to the size of the weights.
            grad_to_weight_ratio = (
                self.lr * layer.weight.grad.std() / layer.weight.std()
            ).item()

        self.grad_to_weight_fig.data[self.trace_idx].y += (grad_to_weight_ratio,)

    def plot_accuracy(self, epoch: int, accuracy: float):
        """Append the point ``(epoch, accuracy)`` to the current accuracy trace."""
        self.test_acc_fig.data[self.trace_idx].x += (epoch,)
        self.test_acc_fig.data[self.trace_idx].y += (accuracy,)

In [8]:
from functools import reduce

In [None]:
class Net(nn.Module):
    """Fully connected network (MLP) returning log-probabilities.

    Args:
        sizes: widths of all layers, input and output included, e.g.
            ``[784, 128, 128, 10]``. Consecutive pairs define the linear
            layers. ``None`` selects the default 784-128-128-10 architecture
            (after flattening an image of size 28x28 we have 784 inputs).

    Raises:
        ValueError: if ``sizes`` has fewer than two entries, i.e. does not
            describe at least one layer.
    """

    def __init__(self, sizes=None):
        super().__init__()
        if sizes is None:
            sizes = (784, 128, 128, 10)
        sizes = list(sizes)
        if len(sizes) < 2:
            raise ValueError(
                "sizes must contain at least the input and the output size, "
                f"got {sizes!r}"
            )

        # nn.ModuleList (unlike a plain Python list) registers the layers as
        # sub-modules, so their parameters are visible to model.parameters()
        # and follow the model in .to(device).
        self.fcs = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes, sizes[1:])]
        )

    def forward(self, x: torch.Tensor):
        """Flatten everything but the batch dimension and apply the MLP.

        ReLU follows every layer except the last one; the output of the last
        layer goes through ``log_softmax`` over dimension 1.
        """
        x = torch.flatten(x, 1)

        last_idx = len(self.fcs) - 1
        for i, layer in enumerate(self.fcs):
            x = layer(x)
            # No ReLU on the output layer: clamping the logits at zero would
            # distort the predicted distribution.
            if i != last_idx:
                x = F.relu(x)

        return F.log_softmax(x, dim=1)


def train(
    model: torch.nn.Module,
    device: torch.device,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    log_interval: int,
    visualizer: TrainingVisualizer,
    verbose: bool = False,
):
    """Train ``model`` for one pass over ``train_loader``.

    For every batch the negative log-likelihood loss is computed, the
    gradients are back-propagated, the visualizer is notified and the
    optimizer takes one step.

    Args:
        model: network producing log-probabilities.
        device: device the batches are moved to; the model must already be there.
        train_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimizer holding the parameters of ``model``.
        epoch: epoch number, only used in the progress message.
        log_interval: with ``verbose``, print progress every this many batches.
        visualizer: receives the batch index and loss of every batch.
        verbose: whether to print progress messages.
    """
    progress_template = "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}"

    # Train mode sets model.training to True, which can alter the behavior of
    # modules like Dropout. The flag also propagates to properly registered
    # sub-modules: compare print(model.fcs[0].training) before and after
    # calling model.eval() or model.train().
    model.train()
    assert model.training

    for batch_idx, (data, target) in enumerate(train_loader):
        # PyTorch raises an error if the model and the data live on different
        # devices; the two most popular devices are cpu and cuda.
        # .to(torch_type) can also change the dtype, e.g. data.to(torch.float32).
        # For more, see https://pytorch.org/docs/stable/tensors.html
        data = data.to(device)
        target = target.to(device)

        # Let p be a parameter of the model. Initially p.grad is None; after
        # loss.backward() p.grad holds (d loss)/(d p). optimizer.step() applies
        # the gradients and updates the optimizer statistics, while
        # optimizer.zero_grad() clears p.grad. Without zero_grad() the
        # gradients of consecutive loss.backward() calls would accumulate.
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        # Gradients are inspected after backward() and before the step.
        visualizer.plot_gradients_and_loss(batch_idx, loss.item())
        optimizer.step()

        if batch_idx % log_interval != 0 or not verbose:
            continue

        samples_seen = batch_idx * len(data)
        percent_done = 100.0 * batch_idx / len(train_loader)
        print(
            progress_template.format(
                epoch,
                samples_seen,
                len(train_loader.dataset),
                percent_done,
                loss.item(),
            )
        )


def test(
    model: torch.nn.Module,
    device: torch.device,
    test_loader: torch.utils.data.DataLoader,
    epoch: int,
    visualizer: TrainingVisualizer,
    verbose: bool = False,
):
    """Evaluate ``model`` on ``test_loader`` and report the accuracy.

    Args:
        model: network producing log-probabilities.
        device: device the batches are moved to; the model must already be there.
        test_loader: loader yielding ``(data, target)`` batches; its
            ``dataset`` length is used as the number of samples.
        epoch: epoch number passed on to the visualizer.
        visualizer: receives the accuracy (in percent) for ``epoch``.
        verbose: whether to print the average loss and the accuracy.
    """
    # Eval mode sets model.training to False, which alters the behaviour of
    # (sub)modules like Dropout.
    model.eval()
    assert not model.training

    loss_sum = 0
    n_correct = 0
    # By default torch tracks computations (which uses memory) for later
    # backpropagation whenever it meets a tensor with requires_grad=True.
    # torch.no_grad() disables this tracking.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            log_probs = model(data)

            # Batch losses are summed here and divided by the dataset size
            # below; usually a mean reduction is preferred.
            loss_sum += F.nll_loss(log_probs, target, reduction="sum").item()
            # The predicted class is the index of the max log-probability.
            predicted = log_probs.argmax(dim=1)
            n_correct += predicted.eq(target.view_as(predicted)).sum().item()

    n_samples = len(test_loader.dataset)
    mean_loss = loss_sum / n_samples
    accuracy = 100.0 * n_correct / n_samples

    if verbose:
        print(
            "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
                mean_loss,
                n_correct,
                n_samples,
                accuracy,
            )
        )
    visualizer.plot_accuracy(epoch, accuracy)

In [42]:
# Reproducibility and logging
seed = 1
log_interval = 10

# Optimization hyper-parameters (the learning rate used to be 1e-2)
epochs = 5
lr = 3e-2

# Training uses more memory than testing because of the gradient computation,
# so the test batches can be larger.
batch_size = 256
test_batch_size = 1000

# True when a CUDA device is available
use_cuda = torch.cuda.is_available()

In [43]:
# for reproducibility
torch.manual_seed(seed)
# If we have a CUDA-capable device (e.g. an NVIDIA GPU) we use it, otherwise we
# fall back to the CPU. There are also other devices like mps, but their
# support for operations may be limited.
device = torch.device("cuda" if use_cuda else "cpu")

# Training batches are shuffled on every device; shuffling must not depend on
# whether CUDA is available. The test set is evaluated in a fixed order.
train_kwargs = {"batch_size": batch_size, "shuffle": True}
test_kwargs = {"batch_size": test_batch_size}
if use_cuda:
    # Loader options that are only applied when running on a GPU.
    cuda_kwargs = {"num_workers": 1, "pin_memory": True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

In [44]:
# data loader preparation
# Images are converted to tensors and then normalized with Normalize(mean, std).
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Only the training split requests a download of the data.
data_root = "../data"
dataset1 = datasets.MNIST(data_root, train=True, download=True, transform=transform)
dataset2 = datasets.MNIST(data_root, train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [45]:
# Creates and displays the three (initially empty) figures. Reuse this object
# across training runs so that their curves end up in the same plots.
visualizer = TrainingVisualizer(log_interval=log_interval)

FigureWidget({
    'data': [{'mode': 'lines',
              'showlegend': True,
              'type': 'scatter',
              'uid': '6586816f-7a78-48fa-bd8e-50c18563218b',
              'y': []},
             {'mode': 'lines',
              'showlegend': True,
              'type': 'scatter',
              'uid': '4270e61e-c265-4a15-9d1c-e3fdd729fd19',
              'y': []},
             {'mode': 'lines',
              'showlegend': True,
              'type': 'scatter',
              'uid': '8c6f38b4-763c-426c-8f79-6c6ac1afa10a',
              'y': []},
             {'mode': 'lines',
              'showlegend': True,
              'type': 'scatter',
              'uid': '07d4fb27-5fe8-4526-9980-01fed5f2a64b',
              'y': []},
             {'mode': 'lines',
              'showlegend': True,
              'type': 'scatter',
              'uid': '1f30a61e-a127-4bf1-9153-490371d3ffd8',
              'y': []},
             {'mode': 'lines',
              'showlegend': True,
     

FigureWidget({
    'data': [{'mode': 'lines',
              'showlegend': True,
              'type': 'scatter',
              'uid': 'be4aa67b-bb63-42b7-8497-897de89130b9',
              'y': []},
             {'mode': 'lines',
              'showlegend': True,
              'type': 'scatter',
              'uid': '8c71afd0-3508-494c-aa74-10a546b991f8',
              'y': []},
             {'mode': 'lines',
              'showlegend': True,
              'type': 'scatter',
              'uid': '1a8592aa-5aa7-4965-92be-5cb8e635445c',
              'y': []},
             {'mode': 'lines',
              'showlegend': True,
              'type': 'scatter',
              'uid': 'f7538b22-a1cc-4826-9bcc-5fac05194d76',
              'y': []},
             {'mode': 'lines',
              'showlegend': True,
              'type': 'scatter',
              'uid': '6aed7b52-af2f-4765-99e0-6580ecf4c0a8',
              'y': []},
             {'mode': 'lines',
              'showlegend': True,
     

FigureWidget({
    'data': [{'mode': 'lines+markers',
              'showlegend': True,
              'type': 'scatter',
              'uid': '4c5cf411-21e1-40c8-8e1e-a61d1b867e4e',
              'x': [],
              'y': []},
             {'mode': 'lines+markers',
              'showlegend': True,
              'type': 'scatter',
              'uid': '74a19fb7-0a43-4011-8af5-c7acf4dcf832',
              'x': [],
              'y': []},
             {'mode': 'lines+markers',
              'showlegend': True,
              'type': 'scatter',
              'uid': '7ff5c2ff-e960-46d2-a78e-0546e25f4c77',
              'x': [],
              'y': []},
             {'mode': 'lines+markers',
              'showlegend': True,
              'type': 'scatter',
              'uid': '08dc1312-1efe-4ddf-a49d-a9ec287c9e3c',
              'x': [],
              'y': []},
             {'mode': 'lines+markers',
              'showlegend': True,
              'type': 'scatter',
              'uid': '2

In [47]:
# Note that there is a difference between taking a tensor to a device and a model to a device.
# Consider writing model.to(torch.device("cpu")) and some_tensor.to(torch.device("cpu")).
# What seems to be happening in-place?
model = Net([784, 128, 128, 10]).to(device)
# Look at the output of list(model.parameters()) as you add new parameters to the model.
# What happens when the parameters are inside a standard list instead of torch.nn.ModuleList?

# Optimizer used for this run: "adam" or "sgd" (SGD with momentum).
optimizer_name = "adam"
momentum = 0.9  # only used by SGD

if optimizer_name == "adam":
    optimizer = optim.Adam(model.parameters(), lr=lr)
elif optimizer_name == "sgd":
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
else:
    raise ValueError(f"Unknown optimizer_name: {optimizer_name!r}")

# Starts a new trace in every figure, labelled with the optimizer class and lr.
visualizer.track_model(model, optimizer, lr)

print_details = False  # change if you want additional output

for epoch in range(1, epochs + 1):
    train(
        model,
        device,
        train_loader,
        optimizer,
        epoch,
        log_interval,
        visualizer,
        verbose=print_details,
    )
    test(model, device, test_loader, epoch, visualizer, verbose=print_details)

# Investigate the results: what is the shape of the test accuracy curve? Try to explain it.