## Imports

In [1]:
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math
from IPython.display import clear_output
import numpy as np
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [47]:
def rseed(seed=0):
    """Seed every random number generator used in this notebook.

    Args:
        seed: Integer seed handed, in order, to PyTorch (CPU), PyTorch (all
            CUDA devices), Python's ``random`` module and NumPy's legacy
            global generator. Defaults to 0.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [48]:
# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# Maximum gradient norm used by the training loop; the loop only clips when
# this value is > 0, so 0.0 turns clipping off.
grad_clipping = 10.0

In [49]:
# Number of image channels fed to the model.
input_size = 1

# Per-sample preprocessing: convert to a tensor, then normalise with
# mean 0.1307 and std 0.3081.
transform = [
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
]
# Alternative for a fully connected model: also append
# torchvision.transforms.Lambda(torch.flatten) and set
# input_size = mnist_train.data.shape[1] * mnist_train.data.shape[2].

# Train and test splits share the same root directory and preprocessing;
# each split gets its own Compose built from the same transform list.
mnist_train, mnist_test = (
    torchvision.datasets.MNIST(
        "/tmp/data",
        train=is_train_split,
        download=True,
        transform=torchvision.transforms.Compose(transform),
    )
    for is_train_split in (True, False)
)

In [50]:
# Mini-batch size shared by both loaders (64 was an earlier setting).
batch_size = 256

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True, num_workers=8
)
# Evaluation gains nothing from shuffling. A fixed order keeps the batch
# composition (including the shorter final batch) identical between epochs,
# so per-batch test metrics are comparable from one epoch to the next.
test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=batch_size, shuffle=False, num_workers=8
)

In [51]:
# Number of target classes, read from the dataset's class list.
# NOTE(review): a later cell reassigns output_size to the literal 10.
output_size = len(mnist_train.classes)

In [52]:
# Switch gradient recording off globally. The training loop below obtains its
# gradient estimate from forward-mode AD (fc.jvp) and assigns p.grad by hand,
# so it never calls backward().
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x764b90e4c690>

In [53]:
from typing import List, Tuple, Dict, Any, Union, Optional

In [54]:
from torch import Tensor

In [55]:
class NeuralNet(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_sizes: List[int],
        output_size: int,
        activation_function: Optional[torch.nn.Module] = None,
    ):
        """Standard fully-connected network.

        Layers are registered as ``fc0, act0, fc1, act1, ...`` followed by a
        final linear layer ``out``.

        Args:
            input_size (int): input size of the model.
            hidden_sizes (List[int]): a list of hidden sizes. The list is
                copied, never modified, so it can be reused by the caller.
            output_size (int): The number of output classes.
            activation_function (Optional[torch.nn.Module], optional): the
                activation function for the hidden layers. When given, the
                same module instance is registered after every linear layer;
                when None, a fresh ``nn.ReLU`` is created per layer.
                Defaults to None.

        """
        super(NeuralNet, self).__init__()
        self.input_size = input_size
        # Build a new list rather than inserting into the caller's list:
        # mutating the argument would corrupt the architecture of any second
        # model built from the same `hidden_sizes` object.
        self.hidden_sizes = [input_size, *hidden_sizes]
        for i in range(len(self.hidden_sizes) - 1):
            setattr(
                self,
                f"fc{i}",
                nn.Linear(self.hidden_sizes[i], self.hidden_sizes[i + 1]),
            )
            # Explicit None check: some modules (e.g. an empty nn.Sequential)
            # are falsy and would otherwise be silently replaced by ReLU.
            activation = (
                activation_function if activation_function is not None else nn.ReLU()
            )
            setattr(self, f"act{i}", activation)
        self.out = nn.Linear(self.hidden_sizes[-1], output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every (linear, activation) pair in order, then the output layer.

        Args:
            x (torch.Tensor): input whose last dimension equals ``input_size``.

        Returns:
            torch.Tensor: output of the final linear layer; no softmax is applied.
        """
        for i in range(len(self.hidden_sizes) - 1):
            x = getattr(self, f"act{i}")(getattr(self, f"fc{i}")(x))
        x = self.out(x)
        return x

class ConvNet(nn.Module):
    def __init__(self, input_size: int = 1, output_size: int = 10):
        """Small convolutional classifier for the MNIST dataset.

        Four 3x3 convolutions with 64 channels each (padding 1) are followed
        by two linear layers. The first linear layer expects 3136 flattened
        features, i.e. 64 channels * 7 * 7, which is what 28x28 inputs yield
        after the two 2x2 max-poolings in ``forward``.

        Args:
            input_size (int): number of input channels.
            output_size (int): The number of output classes.

        """
        super().__init__()
        channels = 64
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(input_size, channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.conv3 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.conv4 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.fc1 = nn.Linear(3136, 1024)
        self.fc2 = nn.Linear(1024, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class scores for a batch of images.

        Args:
            x (torch.Tensor): batch of images, indexed by batch first.

        Returns:
            torch.Tensor: output of the last linear layer, one row per image.
        """
        # Two stages of (conv, ReLU, conv, ReLU, 2x2 max-pool).
        stages = ((self.conv1, self.conv2), (self.conv3, self.conv4))
        for conv_a, conv_b in stages:
            x = F.relu(conv_a(x))
            x = F.relu(conv_b(x))
            x = F.max_pool2d(x, 2)
        # Flatten everything except the batch dimension.
        flat = x.view(x.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [56]:
import math
import torch.func as fc

def exponential_lr_decay(step: int, k: float):
    """Exponential decay factor ``e ** (-step * k)``.

    Args:
        step: Number of scheduler steps taken so far.
        k: Decay rate per step.

    Returns:
        The multiplicative factor; 1.0 at ``step == 0``.
    """
    exponent = -step * k
    return math.pow(math.e, exponent)

In [57]:
def _xent(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Compute cross-entropy loss.

    Args:
        x (torch.Tensor): Output of the model.
        t (torch.Tensor): Targets.

    Returns:
        torch.Tensor: Cross-entropy loss.
    """
    return F.cross_entropy(x, t)

In [58]:
from typing import ValuesView, KeysView

In [59]:
def functional_xent(
    params: ValuesView,
    buffers: Dict[str, Tensor],
    names: KeysView,
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """Cross-entropy loss of ``model`` evaluated with explicitly supplied parameters.

    The model is run through ``fc.functional_call`` so that the loss is a
    function of ``params`` rather than of the parameters stored on the module.

    Args:
        params: Model parameters, in the same order as ``names``.
        buffers: Buffers of the model.
        names: Names of the parameters; paired positionally with ``params``.
        model: A pytorch model.
        x (torch.Tensor): Input tensor for the PyTorch model.
        t (torch.Tensor): Targets.

    Returns:
        torch.Tensor: Cross-entropy loss.
    """
    # Rebuild the name -> tensor mapping that functional_call expects.
    param_map = dict(zip(names, params))
    logits = fc.functional_call(model, (param_map, buffers), (x,))
    return _xent(logits, t)

In [60]:
import copy

In [61]:
def plot(log):
    """Build a Plotly figure of the train and test loss curves.

    Args:
        log: Mapping holding the sequences ``log['train_losses']`` and
            ``log['test_losses']``; each is drawn as a line against its index.

    Returns:
        The configured figure. It is not displayed here; the caller decides
        when to call ``.show()``.
    """
    fig = go.Figure()
    curves = (("train_losses", "Train loss"), ("test_losses", "Test loss"))
    for key, label in curves:
        fig.add_trace(go.Scatter(y=log[key], mode="lines", name=label))

    # Titles
    fig.update_layout(title="", xaxis_title="Epoch", yaxis_title="Cross-Entropy Loss")

    # Fixed figure size, white background, legend fonts
    fig.update_layout(
        autosize=False,
        width=700,
        height=400,
        plot_bgcolor='rgba(255, 255, 255, 1)',
        legend=dict(font=dict(size=15, color="black")),
        legend_title=dict(font=dict(size=20, color="blue")),
    )

    # Horizontal legend, anchored by its bottom-right corner at (x=1, y=1.02)
    fig.update_layout(
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
    )

    # Light grey grid on both axes
    grid_style = dict(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(**grid_style)

    return fig

In [62]:
# Model dimensions used when the network is built below.
input_size = 1    # input channels
output_size = 10  # output classes

In [63]:
import time
from IPython.display import clear_output

from functools import partial

# Seed every RNG before the model is initialised so that weight init, the
# random tangents drawn during training and the loader shuffling are
# reproducible after a kernel restart. Without this call rseed() is defined
# but never used.
rseed(0)

# Fresh model for this run. The fully connected alternative is
# NeuralNet(input_size, [1024, 1024], output_size, torch.nn.ReLU()).
model = ConvNet(input_size, output_size)
model.to(device)
model.float()
model.train()

# Plain SGD without momentum or weight decay (lr=1e-2 was an earlier setting).
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=2e-4,
    nesterov=False,
    momentum=0.0,
    weight_decay=0.0,
)

optimizer.zero_grad(set_to_none=True)

# Learning rate is multiplied by exp(-1e-4 * step); the scheduler is stepped
# once per training batch in the loop below.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: exponential_lr_decay(step, 1e-4),
)

# Run length and bookkeeping (200 epochs was an earlier setting).
total_epochs = 50
steps = 0
t_total = 0.0

# Name -> tensor views of the live model. `params` are the very tensors the
# optimizer updates; `names` pairs with them positionally in functional_xent.
named_buffers = dict(model.named_buffers())
named_params = dict(model.named_parameters())
names = named_params.keys()
params = named_params.values()

# Copy of the model moved to the "meta" device. It is only used as the module
# argument of fc.functional_call, which is given the real parameters and
# buffers explicitly.
base_model = copy.deepcopy(model)
base_model.to("meta")

# Per-epoch metrics collected by the training loop.
log = {
    'train_losses': [],
    'test_losses': [],
    'test_accuracies': [],
}

In [64]:
# Forward-gradient training: for each batch a random tangent v is sampled for
# every parameter, forward-mode AD returns the loss together with the
# directional derivative (jvp) along v, and v * jvp is used as the gradient.

# named_params is never modified, so the tuple of parameter tensors can be
# built once instead of on every step.
primals = tuple(params)

for epoch in range(total_epochs):
    t0 = time.perf_counter()
    train_loss_sum = 0.0
    for images, labels in train_loader:
        steps += 1
        images, labels = images.to(device), labels.to(device)

        # sample perturbation (tangent) vectors for every parameter of the model
        v_params = tuple(torch.randn_like(p) for p in primals)
        f = partial(
            functional_xent,
            model=base_model,
            names=names,
            buffers=named_buffers,
            x=images,
            t=labels,
        )

        # Forward AD: loss value and its directional derivative along v_params
        loss, jvp = fc.jvp(f, (primals,), (v_params,))
        # Weight by batch size so the shorter final batch is not over-counted.
        train_loss_sum += loss.item() * labels.size(0)

        # Setting gradients
        for v, p in zip(v_params, primals):
            p.grad = v * jvp

        # Clip gradients
        if grad_clipping > 0:
            torch.nn.utils.clip_grad.clip_grad_norm_(
                parameters=primals, max_norm=grad_clipping, error_if_nonfinite=True
            )

        # Optimizer step
        optimizer.step()

        # Lr scaling
        scheduler.step()

        # Zero out grads
        optimizer.zero_grad(set_to_none=True)

    # Only the training pass is timed; evaluation below is excluded.
    t1 = time.perf_counter()
    t_total += t1 - t0

    # Test loss and accuracy, both averaged per sample.
    test_loss_sum = 0.0
    test_correct = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        y = model(images)
        test_loss_sum += _xent(y, labels).item() * labels.size(0)
        test_correct += (y.argmax(dim=-1) == labels).sum().item()
    n_test = len(test_loader.dataset)
    log['test_losses'].append(test_loss_sum / n_test)
    log['test_accuracies'].append(test_correct / n_test)
    log['train_losses'].append(train_loss_sum / len(train_loader.dataset))

    # Redraw the curves; wait=True defers clearing until new output arrives.
    plot(log).show()
    clear_output(wait=True)

print(f"Mean time: {t_total / total_epochs:.4f}")

Mean time: 1.6034


In [65]:
# Show the final loss curves again. NOTE(review): the training cell ends each
# epoch with clear_output(wait=True) and then prints, which presumably replaces
# the last figure - confirm.
plot(log).show()

In [44]:
# Test-set accuracy of the current parameters, evaluated by calling the
# meta-device model functionally with the live parameters and buffers.
acc = 0  # number of correctly classified test samples
for images, labels in test_loader:
    labels = labels.to(device)
    out = fc.functional_call(base_model, (named_params, named_buffers), (images.to(device),))
    # Softmax is monotonic, so the arg-max of the raw outputs is the same class.
    pred = out.argmax(dim=-1)
    # Accumulate as a Python int so the final division also works when the
    # loader yields no batches.
    acc += (pred == labels).sum().item()
print(f"Test accuracy: {acc / len(mnist_test):.4f}")

Test accuracy: 0.1041
