# Imports

## w0d2_solutions

In [2]:
pip install fancy_einsum einops wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting fancy_einsum
  Downloading fancy_einsum-0.0.3-py3-none-any.whl (6.2 kB)
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[K     |████████████████████████████████| 41 kB 447 kB/s 
[?25hCollecting wandb
  Downloading wandb-0.13.5-py2.py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 25.7 MB/s 
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.29-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 63.6 MB/s 
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.10.1-py2.py3-none-any.whl (166 kB)
[K     |████████████████████████████████| 166 kB 60.0 MB/s 
Collecting setproctitle
  Downloading setproctitle-1.3.2-cp37-cp37m-manylinux_2_5_x86_

In [3]:
 # Section 1

from fancy_einsum import einsum
from typing import Union, Optional, Callable
import numpy as np

def einsum_trace(mat):
    return einsum("i i", mat)

def einsum_mv(mat, vec):
    return einsum("i j, j -> i", mat, vec)

def einsum_mm(mat1, mat2):
    return einsum("i j, j k -> i k", mat1, mat2)

def einsum_inner(vec1, vec2):
    return einsum("i, i", vec1, vec2)

def einsum_outer(vec1, vec2):
    return einsum("i, j -> i j", vec1, vec2)



# Section 2

import torch as t
from collections import namedtuple
TestCase = namedtuple("TestCase", ["output", "size", "stride"])

test_cases = [
    TestCase(
        output=t.tensor([0, 1, 2, 3]), 
        size=(4,), 
        stride=(1,)
    ),
    # Explanation: the output is a 1D vector of length 4 (hence size=(4,))
    # and each time you move one element along in this output vector, you also want to move
    # one element along the `test_input_a` tensor

    TestCase(
        output=t.tensor([0, 1, 2, 3, 4]), 
        size=(5,), 
        stride=(1,)
    ),
    # Explanation: the tensor is held in a contiguous memory block. When you get to the end
    # of one row, a single stride jumps to the start of the next row

    TestCase(
        output=t.tensor([0, 5, 10, 15]), 
        size=(4,), 
        stride=(5,)
    ),
    # Explanation: this is same as previous case, only now you're moving in colspace (i.e. skipping
    # 5 elements) each time you move one element across the output tensor.
    # So stride is 5 rather than 1

    TestCase(
        output=t.tensor([[0, 1, 2], [5, 6, 7]]), 
        size=(2, 3), 
        stride=(5, 1)),
    # Explanation: consider the output tensor. As you move one element along a row, you want to jump
    # one element in the `test_input_a` (since you're just going to the next row). As you move
    # one element along a column, you want to jump to the next column, i.e. a stride of 5.

    TestCase(
        output=t.tensor(
            [[0, 1, 2], 
             [10, 11, 12]]
        ), 
        size=(2, 3), 
        stride=(10, 1)),

    TestCase(
        output=t.tensor(
            [[0, 0, 0], 
             [11, 11, 11]]
        ), 
        size=(2, 3),
        stride=(11, 0)),

    TestCase(
        output=t.tensor(
            [0, 6, 12, 18]
        ), 
        size=(4,), 
        stride=(6,)),

    TestCase(
        output=t.tensor(
            [[[0, 1, 2]], [[9, 10, 11]]]
        ), 
        size=(2, 1, 3), 
        stride=(9, 0, 1)),
    # Note here that the middle element of `stride` doesn't actually matter, since you never
    # jump in this dimension. You could change it and the test result would still be the same

    TestCase(
        output=t.tensor(
            [
                [
                    [[0, 1], [2, 3]], 
                    [[4, 5], [6, 7]]
                ], 
                [
                    [[12, 13], [14, 15]], 
                    [[16, 17], [18, 19]]
                ]
            ]
        ),
        size=(2, 2, 2, 2),
        stride=(12, 4, 2, 1),
    ),
]




def as_strided_trace(mat: t.Tensor) -> t.Tensor:
    
    stride = mat.stride()
    
    assert len(stride) == 2, f"matrix should have size 2"
    assert mat.size(0) == mat.size(1), "matrix should be square"
    
    return mat.as_strided((mat.size(0),), (sum(stride),)).sum()

def as_strided_mv(mat: t.Tensor, vec: t.Tensor) -> t.Tensor:
    
    sizeM = mat.shape
    sizeV = vec.shape
    
    strideM = mat.stride()
    strideV = vec.stride()
    
    assert len(sizeM) == 2, f"mat1 should have size 2"
    assert sizeM[1] == sizeV[0], f"mat{list(sizeM)}, vec{list(sizeV)} not compatible for multiplication"
    
    vec_expanded = vec.as_strided(mat.shape, (0, strideV[0]))
    
    product_expanded = mat * vec_expanded
    
    return product_expanded.sum(dim=1)

def as_strided_mm(matA: t.Tensor, matB: t.Tensor) -> t.Tensor:
    
    assert len(matA.shape) == 2, f"mat1 should have size 2"
    assert len(matB.shape) == 2, f"mat2 should have size 2"
    assert matA.shape[1] == matB.shape[0], f"mat1{list(matA.shape)}, mat2{list(matB.shape)} not compatible for multiplication"
    
    # Get the matrix strides, and matrix dims
    sA0, sA1 = matA.stride()
    dA0, dA1 = matA.shape
    sB0, sB1 = matB.stride()
    dB0, dB1 = matB.shape
    
    expanded_size = (dA0, dA1, dB1)
    
    matA_expanded_stride = (sA0, sA1, 0)
    matA_expanded = matA.as_strided(expanded_size, matA_expanded_stride)
    
    matB_expanded_stride = (0, sB0, sB1)
    matB_expanded = matB.as_strided(expanded_size, matB_expanded_stride)
    
    product_expanded = matA_expanded * matB_expanded
    
    return product_expanded.sum(dim=1)





def conv1d_minimal(x: t.Tensor, weights: t.Tensor) -> t.Tensor:
    """Like torch's conv1d using bias=False and all other keyword arguments left at their default values.

    x: shape (batch, in_channels, width)
    weights: shape (out_channels, in_channels, kernel_width)

    Returns: shape (batch, out_channels, output_width)
    """
    
    batch, in_channels, width = x.shape
    out_channels, in_channels_2, kernel_width = weights.shape
    assert in_channels == in_channels_2, "in_channels for x and weights don't match up"
    output_width = width - kernel_width + 1
    
    xsB, xsI, xsWi = x.stride()
    wsO, wsI, wsW = weights.stride()
    
    x_new_shape = (batch, in_channels, output_width, kernel_width)
    x_new_stride = (xsB, xsI, xsWi, xsWi)
    # Common error: xsWi is always 1, so if you put 1 here you won't spot your mistake until you try this with conv2d!
    x_strided = x.as_strided(size=x_new_shape, stride=x_new_stride)
    
    return einsum(
        "batch in_channels output_width kernel_width, out_channels in_channels kernel_width -> batch out_channels output_width", 
        x_strided, weights
    )

def conv2d_minimal(x: t.Tensor, weights: t.Tensor) -> t.Tensor:
    """Like torch's conv2d using bias=False and all other keyword arguments left at their default values.

    x: shape (batch, in_channels, height, width)
    weights: shape (out_channels, in_channels, kernel_height, kernel_width)

    Returns: shape (batch, out_channels, output_height, output_width)
    """
    
    batch, in_channels, height, width = x.shape
    out_channels, in_channels_2, kernel_height, kernel_width = weights.shape
    assert in_channels == in_channels_2, "in_channels for x and weights don't match up"
    output_width = width - kernel_width + 1
    output_height = height - kernel_height + 1
    
    xsB, xsIC, xsH, xsW = x.stride() # B for batch, IC for input channels, H for height, W for width
    wsOC, wsIC, wsH, wsW = weights.stride()
    
    x_new_shape = (batch, in_channels, output_height, output_width, kernel_height, kernel_width)
    x_new_stride = (xsB, xsIC, xsH, xsW, xsH, xsW)
    
    x_strided = x.as_strided(size=x_new_shape, stride=x_new_stride)
    
    return einsum(
        "batch in_channels output_height output_width kernel_height kernel_width, \
out_channels in_channels kernel_height kernel_width \
-> batch out_channels output_height output_width",
        x_strided, weights
    )

def pad1d(x: t.Tensor, left: int, right: int, pad_value: float) -> t.Tensor:
    """Return a new tensor with padding applied to the edges.

    x: shape (batch, in_channels, width), dtype float32

    Return: shape (batch, in_channels, left + right + width)
    """
    B, C, W = x.shape
    output = x.new_full(size=(B, C, left + W + right), fill_value=pad_value)
    output[..., left : left + W] = x
    # Note - you can't use `left:-right`, because `right` could be zero.
    return output
    


def pad2d(x: t.Tensor, left: int, right: int, top: int, bottom: int, pad_value: float) -> t.Tensor:
    """Return a new tensor with padding applied to the edges.

    x: shape (batch, in_channels, height, width), dtype float32

    Return: shape (batch, in_channels, top + height + bottom, left + width + right)
    """
    B, C, H, W = x.shape
    output = x.new_full(size=(B, C, top + H + bottom, left + W + right), fill_value=pad_value)
    output[..., top : top + H, left : left + W] = x
    return output

def conv1d(x, weights, stride: int = 1, padding: int = 0) -> t.Tensor:
    """Like torch's conv1d using bias=False.

    x: shape (batch, in_channels, width)
    weights: shape (out_channels, in_channels, kernel_width)

    Returns: shape (batch, out_channels, output_width)
    """
    
    x_padded = pad1d(x, left=padding, right=padding, pad_value=0)
    
    batch, in_channels, width = x_padded.shape
    out_channels, in_channels_2, kernel_width = weights.shape
    assert in_channels == in_channels_2, "in_channels for x and weights don't match up"
    output_width = 1 + (width - kernel_width) // stride
    # note, we assume padding is zero in the formula here, because we're working with input which has already been padded
    
    xsB, xsI, xsWi = x_padded.stride()
    wsO, wsI, wsW = weights.stride()
    
    x_new_shape = (batch, in_channels, output_width, kernel_width)
    x_new_stride = (xsB, xsI, xsWi * stride, xsWi)
    # Explanation for line above:
    #     we need to multiply the stride corresponding to the `output_width` dimension
    #     because this is the dimension that we're sliding the kernel along
    x_strided = x_padded.as_strided(size=x_new_shape, stride=x_new_stride)
    
    return einsum("B IC OW wW, OC IC wW -> B OC OW", x_strided, weights)

IntOrPair = Union[int, tuple]
Pair = tuple

def force_pair(v: IntOrPair) -> Pair:
    """Convert v to a pair of int, if it isn't already."""
    if isinstance(v, tuple):
        if len(v) != 2:
            raise ValueError(v)
        return (int(v[0]), int(v[1]))
    elif isinstance(v, int):
        return (v, v)
    raise ValueError(v)

def conv2d(x, weights, stride: IntOrPair = 1, padding: IntOrPair = 0) -> t.Tensor:
    """Like torch's conv2d using bias=False

    x: shape (batch, in_channels, height, width)
    weights: shape (out_channels, in_channels, kernel_height, kernel_width)


    Returns: shape (batch, out_channels, output_height, output_width)
    """

    stride_h, stride_w = force_pair(stride)
    padding_h, padding_w = force_pair(padding)
    
    x_padded = pad2d(x, left=padding_w, right=padding_w, top=padding_h, bottom=padding_h, pad_value=0)
    
    batch, in_channels, height, width = x_padded.shape
    out_channels, in_channels_2, kernel_height, kernel_width = weights.shape
    assert in_channels == in_channels_2, "in_channels for x and weights don't match up"
    output_width = 1 + (width - kernel_width) // stride_w
    output_height = 1 + (height - kernel_height) // stride_h
    
    xsB, xsIC, xsH, xsW = x_padded.stride() # B for batch, IC for input channels, H for height, W for width
    wsOC, wsIC, wsH, wsW = weights.stride()
    
    x_new_shape = (batch, in_channels, output_height, output_width, kernel_height, kernel_width)
    x_new_stride = (xsB, xsIC, xsH * stride_h, xsW * stride_w, xsH, xsW)
    
    x_strided = x_padded.as_strided(size=x_new_shape, stride=x_new_stride)
    
    return einsum("B IC OH OW wH wW, OC IC wH wW -> B OC OH OW", x_strided, weights)


def maxpool2d(x: t.Tensor, kernel_size: IntOrPair, stride: Optional[IntOrPair] = None, padding: IntOrPair = 0
) -> t.Tensor:
    """Like PyTorch's maxpool2d.

    x: shape (batch, channels, height, width)
    stride: if None, should be equal to the kernel size

    Return: (batch, channels, output_height, output_width)
    """

    if stride is None:
        stride = kernel_size
    stride_height, stride_width = force_pair(stride)
    padding_height, padding_width = force_pair(padding)
    kernel_height, kernel_width = force_pair(kernel_size)
    
    x_padded = pad2d(x, left=padding_width, right=padding_width, top=padding_height, bottom=padding_height, pad_value=-t.inf)
    
    batch, channels, height, width = x_padded.shape
    output_width = 1 + (width - kernel_width) // stride_width
    output_height = 1 + (height - kernel_height) // stride_height
    
    xsB, xsC, xsH, xsW = x_padded.stride()
    
    x_new_shape = (batch, channels, output_height, output_width, kernel_height, kernel_width)
    x_new_stride = (xsB, xsC, xsH * stride_height, xsW * stride_width, xsH, xsW)
    
    x_strided = x_padded.as_strided(size=x_new_shape, stride=x_new_stride)
    
    output = t.amax(x_strided, dim=(-1, -2))
    return output


# =============== PART 4 ===============

from torch import nn

class MaxPool2d(nn.Module):
    def __init__(self, kernel_size: IntOrPair, stride: Optional[IntOrPair] = None, padding: IntOrPair = 1):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Call the functional version of maxpool2d."""
        return maxpool2d(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)

    def extra_repr(self) -> str:
        """Add additional information to the string representation of this class."""
        return ", ".join([f"{key}={getattr(self, key)}" for key in ["kernel_size", "stride", "padding"]])


class ReLU(nn.Module):
    def forward(self, x: t.Tensor) -> t.Tensor:
        return t.maximum(x, t.tensor(0.0))


import functools
class Flatten(nn.Module):
    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: t.Tensor) -> t.Tensor:
        """Flatten out dimensions from start_dim to end_dim, inclusive of both.
        """
        shape = input.shape
        
        start_dim = self.start_dim
        end_dim = self.end_dim if self.end_dim >= 0 else len(shape) + self.end_dim
        
        shape_left = shape[:start_dim]
        shape_middle = functools.reduce(lambda x, y: x*y, shape[start_dim : end_dim+1])
        shape_right = shape[end_dim+1:]
        
        new_shape = shape_left + (shape_middle,) + shape_right
        
        return t.reshape(input, new_shape)

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={getattr(self, key)}" for key in ["start_dim", "end_dim"]])

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias=True):
        """A simple linear (technically, affine) transformation.

        The fields should be named `weight` and `bias` for compatibility with PyTorch.
        If `bias` is False, set `self.bias` to None.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        
        sf = 1 / np.sqrt(in_features)
        
        weight = sf * (2 * t.rand(out_features, in_features) - 1)
        self.weight = nn.Parameter(weight)
        
        if bias:
            bias = sf * (2 * t.rand(out_features,) - 1)
            self.bias = nn.Parameter(bias)
        else:
            self.bias = None

    def forward(self, x: t.Tensor) -> t.Tensor:
        """
        x: shape (*, in_features)
        Return: shape (*, out_features)
        """
        x = einsum("... in_features, out_features in_features -> ... out_features", x, self.weight)
        if self.bias is not None: x += self.bias
        return x

    def extra_repr(self) -> str:
        # note, we need to use `self.bias is not None`, because `self.bias` is either a tensor or None, not bool
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"


class Conv2d(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, kernel_size: IntOrPair, stride: IntOrPair = 1, padding: IntOrPair = 0
    ):
        """
        Same as torch.nn.Conv2d with bias=False.

        Name your weight field `self.weight` for compatibility with the PyTorch version.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
        kernel_height, kernel_width = force_pair(kernel_size)
        sf = 1 / np.sqrt(in_channels * kernel_width * kernel_height)
        weight = sf * (2 * t.rand(out_channels, in_channels, kernel_height, kernel_width) - 1)
        self.weight = nn.Parameter(weight)

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Apply the functional conv2d you wrote earlier."""
        return conv2d(x, self.weight, self.stride, self.padding)

    def extra_repr(self) -> str:
        keys = ["in_channels", "out_channels", "kernel_size", "stride", "padding"]
        return ", ".join([f"{key}={getattr(self, key)}" for key in keys])

## w0d3_solutions

In [4]:
from einops import rearrange
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
import PIL
from torch import nn

MAIN = False

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu1 = ReLU()
        self.maxpool1 = MaxPool2d(kernel_size=2, stride=2, padding=0)
        
        self.conv2 = Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.relu2 = ReLU()
        self.maxpool2 = MaxPool2d(kernel_size=2, stride=2, padding=0)
        
        self.flatten = Flatten()
        self.fc1 = Linear(in_features=3136, out_features=128)
        self.fc2 = Linear(in_features=128, out_features=10)
        
    def forward(self, x: t.Tensor) -> t.Tensor:
        x = self.maxpool1(self.relu1(self.conv1(x)))
        x = self.maxpool2(self.relu2(self.conv2(x)))
        x = self.fc2(self.fc1(self.flatten(x)))
        return x

# model = ConvNet()
# print(model)

class Sequential(nn.Module):
    def __init__(self, *modules: nn.Module):
        super().__init__()
        for i, mod in enumerate(modules):
            self.add_module(str(i), mod)

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Chain each module together, with the output from one feeding into the next one."""
        for mod in self._modules.values():
            if mod is not None: x = mod(x)
        return x

class BatchNorm2d(nn.Module):
    running_mean: t.Tensor         # shape: (num_features,)
    running_var: t.Tensor          # shape: (num_features,)
    num_batches_tracked: t.Tensor  # shape: ()

    def __init__(self, num_features: int, eps=1e-05, momentum=0.1):
        '''Like nn.BatchNorm2d with track_running_stats=True and affine=True.

        Name the learnable affine parameters `weight` and `bias` in that order.
        '''
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        
        self.weight = nn.Parameter(t.ones(num_features))  # type: ignore
        self.bias = nn.Parameter(t.zeros(num_features))  # type: ignore
        
        self.register_buffer("running_mean", t.zeros(num_features))
        self.register_buffer("running_var", t.ones(num_features))
        self.register_buffer("num_batches_tracked", t.tensor(0))

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Normalize each channel.

        Compute the variance using `torch.var(x, unbiased=False)`
        Hint: you may also find it helpful to use the argument `keepdim`.

        x: shape (batch, channels, height, width)
        Return: shape (batch, channels, height, width)
        '''
        
        # Calculating mean and var over all dims except for the channel dim
        if self.training:
            # Using keepdim=True so we don't have to worry about broadasting them with x at the end
            mean = t.mean(x, dim=(0, 2, 3), keepdim=True)
            var = t.var(x, dim=(0, 2, 3), unbiased=False, keepdim=True)
            # Updating running mean and variance, in line with PyTorch documentation
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.squeeze()
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze()
            self.num_batches_tracked += 1
        else:
            mean = rearrange(self.running_mean, "channels -> 1 channels 1 1")
            var = rearrange(self.running_var, "channels -> 1 channels 1 1")
        
        # Rearranging these so they can be broadcasted (although there are other ways you could do this)
        weight = rearrange(self.weight, "channels -> 1 channels 1 1")
        bias = rearrange(self.bias, "channels -> 1 channels 1 1")
        
        return ((x - mean) / t.sqrt(var + self.eps)) * weight + bias

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={getattr(self, key)}" for key in ["num_features", "eps", "momentum"]])


class AveragePool(nn.Module):
    def forward(self, x: t.Tensor) -> t.Tensor:
        """
        x: shape (batch, channels, height, width)
        Return: shape (batch, channels)
        """
        return t.mean(x, dim=(2, 3))


class ResidualBlock(nn.Module):
    def __init__(self, in_feats: int, out_feats: int, first_stride=1):
        """A single residual block with optional downsampling.

        For compatibility with the pretrained model, declare the left side branch first using a `Sequential`.

        If first_stride is > 1, this means the optional (conv + bn) should be present on the right branch. Declare it second using another `Sequential`.
        """
        super().__init__()
        
        self.left = Sequential(
            Conv2d(in_feats, out_feats, kernel_size=3, stride=first_stride, padding=1),
            BatchNorm2d(out_feats),
            ReLU(),
            Conv2d(out_feats, out_feats, kernel_size=3, stride=1, padding=1),
            BatchNorm2d(out_feats)
        )
        
        if first_stride > 1:
            self.right = Sequential(
                Conv2d(in_feats, out_feats, kernel_size=1, stride=first_stride),
                BatchNorm2d(out_feats)
            )
        else:
            self.right = nn.Identity()  # type: ignore
            
        self.relu = ReLU()

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Compute the forward pass.

        x: shape (batch, in_feats, height, width)

        Return: shape (batch, out_feats, height / stride, width / stride)

        If no downsampling block is present, the addition should just add the left branch's output to the input.
        """
        x_left = self.left(x)
        x_right = self.right(x)
        return self.relu(x_left + x_right)


class BlockGroup(nn.Module):
    def __init__(self, n_blocks: int, in_feats: int, out_feats: int, first_stride=1):
        """An n_blocks-long sequence of ResidualBlock where only the first block uses the provided stride."""
        super().__init__()
        
        blocks = [ResidualBlock(in_feats, out_feats, first_stride)] + [
            ResidualBlock(out_feats, out_feats) for n in range(n_blocks - 1)
        ]
        self.blocks = nn.Sequential(*blocks)  # type: ignore
        
    def forward(self, x: t.Tensor) -> t.Tensor:
        """Compute the forward pass.
        x: shape (batch, in_feats, height, width)

        Return: shape (batch, out_feats, height / first_stride, width / first_stride)
        """
        return self.blocks(x)



class ResNet34(nn.Module):
    def __init__(
        self,
        n_blocks_per_group=[3, 4, 6, 3],
        out_features_per_group=[64, 128, 256, 512],
        first_strides_per_group=[1, 2, 2, 2],
        n_classes=1000,
    ):
        super().__init__()
        in_feats0 = 64

        self.in_layers = Sequential(
            Conv2d(3, in_feats0, kernel_size=7, stride=2, padding=3),
            BatchNorm2d(in_feats0),
            ReLU(),
            MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        all_in_feats = [in_feats0] + out_features_per_group[:-1]
        self.residual_layers = Sequential(
            *(
                BlockGroup(*args)
                for args in zip(
                    n_blocks_per_group,
                    all_in_feats,
                    out_features_per_group,
                    first_strides_per_group,
                )
            )
        )
        # Alternative that uses `add_module`, in a way which makes the layer names line up:
        # for idx, (n_blocks, in_feats, out_feats, first_stride) in enumerate(zip(
        #     n_blocks_per_group, all_in_feats, out_features_per_group, strides_per_group
        # )):
        #     self.add_module(f"layer{idx+1}", BlockGroup(n_blocks, in_feats, out_feats, first_stride))

        self.out_layers = Sequential(
            AveragePool(),
            Flatten(),
            Linear(out_features_per_group[-1], n_classes),
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        """
        x: shape (batch, channels, height, width)
        Return: shape (batch, n_classes)
        """
        x = self.in_layers(x)
        x = self.residual_layers(x)
        x = self.out_layers(x)
        return x


if MAIN:
    # ImageNet transforms:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def prepare_data(images: list) -> t.Tensor:  # type: ignore
    """
    Return: shape (batch=len(images), num_channels=3, height=224, width=224)
    """
    x = t.stack([transform(img) for img in images], dim=0)  # type: ignore
    return x




# ================================= ConvNet training & testing =================================

if MAIN:
    epochs = 3
    loss_fn = nn.CrossEntropyLoss()
    batch_size = 128

    MODEL_FILENAME = "./w1d2_convnet_mnist.pt"
    device = "cuda" if t.cuda.is_available() else "cpu"

    trainset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
    trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
    testset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
    testloader = DataLoader(testset, batch_size=64, shuffle=True)

def train_convnet(trainloader: DataLoader, testloader: DataLoader, epochs: int, loss_fn: Callable) -> tuple:
    """
    Defines a ConvNet using our previous code, and trains it on the data in trainloader.
    
    Returns tuple of (loss_list, accuracy_list), where accuracy_list contains the fraction of accurate classifications on the test set, at the end of each epoch.
    """
    
    model = ConvNet().to(device).train()
    optimizer = t.optim.Adam(model.parameters())
    loss_list = []
    accuracy_list = []
    
    for epoch in tqdm_notebook(range(epochs)):
        
        for (x, y) in tqdm_notebook(trainloader, leave=False):
            
            x = x.to(device)
            y = y.to(device)
            
            optimizer.zero_grad()
            y_hat = model(x)
            loss = loss_fn(y_hat, y)
            loss.backward()
            optimizer.step()
            
            loss_list.append(loss.item())
        
        with t.inference_mode():
            
            accuracy = 0
            total = 0
            
            for (x, y) in testloader:

                x = x.to(device)
                y = y.to(device)

                y_hat = model(x)
                y_predictions = y_hat.argmax(1)
                accuracy += (y_predictions == y).sum().item()
                total += y.size(0)

            accuracy_list.append(accuracy/total)
            
        print(f"Epoch {epoch+1}/{epochs}, train loss is {loss:.6f}, accuracy is {accuracy}/{total}")  # type: ignore
    
    print(f"Saving model to: {MODEL_FILENAME}")
    t.save(model, MODEL_FILENAME)
    return loss_list, accuracy_list

if MAIN:
    loss_list, accuracy_list = train_convnet(trainloader, testloader, epochs, loss_fn)

# Actual work

In [5]:
# %%
import torch as t
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from fancy_einsum import einsum
from typing import Union, Optional, Callable
import numpy as np
from einops import rearrange
from tqdm.notebook import tqdm_notebook
import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import time
import wandb
import os
# import utils
# from w0d3_solutions import ResNet34

MAIN = (__name__ == '__main__')
device = "cuda" if t.cuda.is_available() else "cpu"

os.environ['WANDB_NOTEBOOK_NAME'] = 'arena_w2d1'
wandb.login()



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 

··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [6]:
if MAIN:
    cifar_mean = [0.485, 0.456, 0.406]
    cifar_std = [0.229, 0.224, 0.225]

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar_mean, std=cifar_std)
    ])

    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    # utils.show_cifar_images(trainset, rows=3, cols=5)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## previous iterations

In [7]:
# def train(trainset, testset, epochs: int, loss_fn: Callable, batch_size: int, lr: float) -> None:

#     config_dict = {
#         "batch_size": batch_size,
#         "epochs": epochs,
#         "lr": lr,
#     }
#     wandb.init(project="w2d1_resnet", config=config_dict)

#     model = ResNet34().to(device).train()
#     optimizer = t.optim.Adam(model.parameters(), lr=lr)

#     examples_seen = 0
#     start_time = time.time()

#     trainloader = DataLoader(trainset, shuffle=True, batch_size=batch_size)
#     testloader = DataLoader(testset, shuffle=True, batch_size=batch_size)

#     wandb.watch(model, criterion=loss_fn, log="all", log_freq=10, log_graph=True)

#     for epoch in range(epochs):

#         progress_bar = tqdm_notebook(trainloader)

#         for (x, y) in progress_bar:

#             x = x.to(device)
#             y = y.to(device)

#             optimizer.zero_grad()
#             y_hat = model(x)
#             loss = loss_fn(y_hat, y)
#             loss.backward()
#             optimizer.step()

#             progress_bar.set_description(f"Epoch = {epoch}, Loss = {loss.item():.4f}")

#             examples_seen += len(x)
#             wandb.log({"train_loss": loss, "elapsed": time.time() - start_time}, step=examples_seen)

#         with t.inference_mode():

#             accuracy = 0
#             total = 0

#             for (x, y) in testloader:

#                 x = x.to(device)
#                 y = y.to(device)

#                 y_hat = model(x)
#                 y_predictions = y_hat.argmax(1)
#                 accuracy += (y_predictions == y).sum().item()
#                 total += y.size(0)

#             wandb.log({"test_accuracy": accuracy/total}, step=examples_seen)

#     filename = f"{wandb.run.dir}/model_state_dict.pt"
#     print(f"Saving model to: {filename}")
#     t.save(model.state_dict(), filename)
#     wandb.save(filename)
#     wandb.finish()


In [10]:
# if MAIN:
#   epochs = 1
#   loss_fn = nn.CrossEntropyLoss()
#   batch_size = 256
#   lr = 0.0025
#   train(trainset, testset, epochs, loss_fn, batch_size, lr)

VBox(children=(Label(value='83.307 MB of 83.307 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
elapsed,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_accuracy,▁
train_loss,█▄▃▃▂▃▃▃▂▂▂▂▂▂▂▂▃▂▂▂▂▂▂▁▁▂▁▁▂▂▂▂▁▁▁▁▁▁▁▃

0,1
elapsed,46.63053
test_accuracy,0.4486
train_loss,2.31925


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


  0%|          | 0/196 [00:00<?, ?it/s]



Saving model to: /content/wandb/run-20221114_032021-31e8fxtu/files/model_state_dict.pt


## current iteration

In [12]:
## Training loop with wandb sweep

def train() -> None:

    wandb.init()

    epochs = wandb.config.epochs
    batch_size = wandb.config.batch_size
    lr = wandb.config.lr

    model = ResNet34().to(device).train()
    optimizer = t.optim.Adam(model.parameters(), lr=lr)

    examples_seen = 0
    start_time = time.time()

    trainloader = DataLoader(trainset, shuffle=True, batch_size=batch_size)
    testloader = DataLoader(testset, shuffle=True, batch_size=batch_size)

    wandb.watch(model, criterion=loss_fn, log="all", log_freq=10, log_graph=True)

    for epoch in range(epochs):

        progress_bar = tqdm_notebook(trainloader)

        for (x, y) in progress_bar:

            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            y_hat = model(x)
            loss = loss_fn(y_hat, y)
            loss.backward()
            optimizer.step()

            progress_bar.set_description(f"Epoch = {epoch}, Loss = {loss.item():.4f}")

            examples_seen += len(x)
            wandb.log({"train_loss": loss, "elapsed": time.time() - start_time}, step=examples_seen)


        with t.inference_mode():

            accuracy = 0
            total = 0

            for (x, y) in testloader:

                x = x.to(device)
                y = y.to(device)

                y_hat = model(x)
                y_predictions = y_hat.argmax(1)
                accuracy += (y_predictions == y).sum().item()
                total += y.size(0)

            wandb.log({"test_accuracy": accuracy/total}, step=examples_seen)

        print(f"Epoch {epoch+1}/{epochs}, train loss is {loss:.6f}, accuracy is {accuracy}/{total}")

    filename = f"{wandb.run.dir}/model_state_dict.pt"
    print(f"Saving model to: {filename}")
    t.save(model.state_dict(), filename)
    wandb.save(filename)


In [13]:
sweep_config = {
    'method': 'random',
    'name': 'w2d1_resnet_sweep_2',
    'metric': {'name': 'test_accuracy', 'goal': 'maximize'},
    'parameters': 
    {
        'batch_size': {'values': [64, 128, 256]},
        'epochs': {'min': 1, 'max': 3},
        'lr': {'max': 0.1, 'min': 0.0001, 'distribution': 'log_uniform_values'}
     }
}

sweep_id = wandb.sweep(sweep=sweep_config, project='w2d1_resnet')

wandb.agent(sweep_id=sweep_id, function=train, count=2)
wandb.finish()

Create sweep with ID: enceyqij
Sweep URL: https://wandb.ai/abhatt349/w2d1_resnet/sweeps/enceyqij


[34m[1mwandb[0m: Agent Starting Run: ueeqolrr with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	lr: 0.004230750891130828


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 1/3, train loss is 2.252360, accuracy is 4206/10000


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 2/3, train loss is 1.274308, accuracy is 5457/10000


  0%|          | 0/782 [00:00<?, ?it/s]



Epoch 3/3, train loss is 0.874775, accuracy is 6150/10000
Saving model to: /content/wandb/run-20221114_041740-ueeqolrr/files/model_state_dict.pt


0,1
elapsed,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_accuracy,▁▆█
train_loss,▇▆█▆▅▇▅▆▆▄▃▆▃▄▅▄▃▅▃▃▄▄▄▃▃▃▄▃▂▃▃▃▃▁▁▃▁▂▂▂

0,1
elapsed,419.07201
test_accuracy,0.615
train_loss,0.87477


[34m[1mwandb[0m: Agent Starting Run: ciz5v2us with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	lr: 0.010532929832841502


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 1/2, train loss is 1.583317, accuracy is 4557/10000


  0%|          | 0/782 [00:00<?, ?it/s]



Epoch 2/2, train loss is 0.841183, accuracy is 6034/10000
Saving model to: /content/wandb/run-20221114_042505-ciz5v2us/files/model_state_dict.pt


0,1
elapsed,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
test_accuracy,▁█
train_loss,█▇▆▆▅▅▅▅▄▅▅▃▄▄▄▄▃▃▃▄▂▃▃▃▄▃▂▂▂▂▃▂▂▂▂▃▁▁▁▃

0,1
elapsed,281.36186
test_accuracy,0.6034
train_loss,0.84118
