In [4]:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from tqdm.notebook import tqdm
from torch.nn import functional as F
import algos

# Prefer the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
class Residual(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a skip connection.

    Args:
        num_channels: Number of output channels of both 3x3 convolutions (and
            of the optional 1x1 convolution on the skip path).
        use_1x1conv: If True, the skip path goes through a 1x1 convolution
            instead of being the identity.
        strides: Stride of the first 3x3 convolution and of the 1x1
            convolution on the skip path.
    """

    def __init__(self, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        # Lazy layers infer their input channel count on the first forward pass.
        self.conv1 = nn.LazyConv2d(num_channels, kernel_size=3,
                                   stride=strides, padding=1)
        self.conv2 = nn.LazyConv2d(num_channels, kernel_size=3,
                                   stride=1, padding=1)
        if use_1x1conv:
            self.conv3 = nn.LazyConv2d(num_channels, kernel_size=1,
                                       stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()

    def forward(self, X):
        """Return ``relu(main_path(X) + skip(X))``.

        The main path is conv1 -> bn1 -> relu -> conv2 -> bn2. The skip path
        is ``conv3(X)`` when the 1x1 convolution exists, otherwise ``X``.
        """
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Test the optional layer explicitly instead of relying on the
        # truthiness of an nn.Module.
        if self.conv3 is not None:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

In [9]:
class ResNet(nn.Module):
    """Generic ResNet assembled from a stem, residual stages and a linear head.

    Args:
        arch: Iterable of ``(num_residuals, num_channels)`` pairs, one per
            stage; each pair is unpacked into :meth:`block`.
        num_classes: Number of output features of the final linear layer.
    """

    def b1(self):
        """Build the stem: 7x7/2 conv (64 channels), batch norm, ReLU, 3x3/2 max-pool."""
        return nn.Sequential(
            nn.LazyConv2d(64, 7, 2, 3),
            nn.LazyBatchNorm2d(), nn.ReLU(),
            nn.MaxPool2d(3,2,1)
        )

    def block(self, num_residuals, num_channels, first_block=False):
        """Build one stage made of ``num_residuals`` Residual blocks.

        Unless ``first_block`` is True, the first Residual of the stage uses a
        1x1 convolution on the skip path and stride 2; every other Residual
        uses the defaults (identity skip, stride 1).
        """
        blk = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.append(Residual(num_channels, use_1x1conv=True, strides=2))
            else:
                blk.append(Residual(num_channels))
        return nn.Sequential(*blk)

    def __init__(self, arch, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(self.b1())
        # Stages are registered as 'b2', 'b3', ... following the stem.
        for i, b in enumerate(arch):
            self.net.add_module(f'b{i+2}', self.block(*b, first_block=(i==0)))
        self.net.add_module('last', nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)), nn.Flatten(),
            nn.LazyLinear(num_classes)
        ))
        # NOTE(review): the lazy layers have not seen any input yet, so their
        # weights are not materialized at this point. Whether algos.init_cnn
        # has any effect here depends on its implementation (not visible) --
        # TODO confirm; apply_init() can be used once real input is available.
        self.net.apply(algos.init_cnn)

    def forward(self, x):
        """Run ``x`` through the whole network and return its output."""
        return self.net(x)

    def apply_init(self, inputs, init=None):
        """Run one forward pass on ``inputs``, then optionally apply ``init``.

        Args:
            inputs: Sequence of positional arguments passed to ``forward``.
            init: Optional callable applied to every sub-module of
                ``self.net`` via ``nn.Module.apply`` after the forward pass.
        """
        # The output is discarded, so skip building an autograd graph.
        with torch.no_grad():
            self.forward(*inputs)
        if init is not None:
            self.net.apply(init)

In [12]:
class ResNet18(ResNet):
    """ResNet with four stages of two residual blocks each.

    The stages use 64, 128, 256 and 512 channels respectively.
    """

    def __init__(self, num_classes=10):
        stage_widths = (64, 128, 256, 512)
        arch = tuple((2, width) for width in stage_widths)
        super().__init__(arch, num_classes)

    def forward(self, x):
        """Feed ``x`` through ``self.net`` and return the result."""
        out = self.net(x)
        return out

In [None]:
# Load the data through the project helper; only the first of the two returned
# values is used in this cell.
# NOTE(review): the discarded second value is presumably a test/validation
# loader -- confirm against algos.load_mnist.
train_loader, _ = algos.load_mnist()
# Train a freshly constructed ResNet18 with the project's fit helper and keep
# whatever it returns as `model`.
model = algos.fit(ResNet18(), train_loader)



  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [1/10], Average Loss: 0.3845


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [2/10], Average Loss: 0.2484


  0%|          | 0/469 [00:00<?, ?it/s]