In [37]:
import torch
from torch import nn, optim, tensor
from typing import List, Optional

from datasets import load_dataset
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F

from torch.utils.data import DataLoader

In [2]:
class ProjectionLayer(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        
    def forward(self, x):
        return self.bn(self.proj(x))

In [3]:
class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, ks =3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=3//2)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=3//2)
        
        if stride != 1 or in_channels != out_channels:
            self.proj = ProjectionLayer(in_channels, out_channels, stride)
        else:
            self.proj = nn.Identity()
        
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act2 = nn.ReLU()
        
    def forward(self, x):
        projection = self.proj(x)
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(projection + self.bn2(self.conv2(x)))
        return x
        

In [4]:
data = torch.randn(1, 3, 128, 128)
resnet = ResNetBlock(3, 100, 2)
resnet(data).shape

torch.Size([1, 100, 64, 64])

In [5]:
class BottleneckResNet(nn.Module):
    def __init__(self, in_channels, bottleneck_channels, out_channels, stride):
        super().__init__()
        if in_channels != out_channels or stride !=1:
            self.proj = ProjectionLayer(in_channels, out_channels, stride)
        else:
            self.proj = nn.Identity()
        
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(),
            nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(),
            nn.Conv2d(bottleneck_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        self.act = nn.ReLU()
        
    def forward(self, x):
        return self.act(self.proj(x) + self.net(x))

In [6]:
resnet = BottleneckResNet(33, 223, 23, 1)
data = torch.randn(1, 33, 128, 128)
resnet(data).shape

torch.Size([1, 23, 128, 128])

In [60]:
class ResNet(nn.Module):
    def __init__(self, n_blocks,
                n_channels: Optional[List[int]], 
                bottlenecks: Optional[List[int]], 
                img_channels=3,
                first_kernel_size=7,
                out_channels=1000
                ):
        
        super().__init__()
        
        modules = [
            nn.Conv2d(img_channels, n_channels[0], first_kernel_size, 2, first_kernel_size//2),
            nn.BatchNorm2d(n_channels[0]),
        ]
        
        assert len(n_blocks) == len(n_channels)
        if bottlenecks is None:
            for pos, (block_len, channel_size) in enumerate(zip(n_blocks, n_channels)):
                for i in range(block_len - 1):
                    modules.append(ResNetBlock(channel_size, channel_size, 1))
                if pos != len(n_blocks)-1:
                    modules.append(ResNetBlock(channel_size, n_channels[pos+1], 2))
    
        else:
            for pos, (block_len, bottleneck_size, channel_size) in enumerate(zip(n_blocks, bottlenecks, n_channels)):
                for i in range(block_len - 1):
                    modules.append(BottleneckResNet(channel_size, bottleneck_size, channel_size, 1))
                if pos != len(n_blocks)-1:
                    modules.append(BottleneckResNet(channel_size, bottleneck_size, n_channels[pos+1], 2))
        
        self.net = nn.Sequential(*modules)
        self.fc = nn.Linear(n_channels[-1], out_channels)
        
        
    def forward(self, x):
        x = self.net(x)
        x = x.view(x.shape[0], x.shape[1], -1)
        x = x.mean(-1)
        x = self.fc(x)
        
        return x

In [61]:
resnet = ResNet([2, 2], [64, 128], None)

In [62]:
data = torch.randn(2, 3, 128, 128)
resnet(data).shape

torch.Size([2, 1000])

In [63]:
ds = load_dataset("ylecun/mnist")

In [64]:
x, y = ds['train'].features

In [65]:
dsd_trn = ds['train']
dsd_val = ds['test']

tfs = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(32)
])
    

def collate_fn(b):
    xb = torch.stack([tfs(o[x]) for o in b])
    yb = torch.stack([tensor(o[y]) for o in b])
    return xb, yb

In [70]:
dl_trn = DataLoader(dsd_trn, batch_size=300, collate_fn=collate_fn)

In [71]:
next(iter(dl_trn))

(tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],
 
 
         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],
 
 
         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],
 
 
         ...,
 
 
         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ..

In [90]:
n_epochs = 2
tmax = n_epochs * len(dl_trn)

lr= 1
model = ResNet([4, 6], [64, 128], [64, 64], img_channels=1, out_channels=10, first_kernel_size=3)
opt = optim.AdamW(model.parameters(), lr)
schedo = CosineAnnealingLR(opt, tmax, )

for n in range(n_epochs):
    for i, (xb, yb) in enumerate(dl_trn):
        out = model(xb)
        loss = F.cross_entropy(out, yb)
    
        loss.backward()

        with torch.no_grad():
            opt.step()
            opt.zero_grad()
            schedo.step()
            
            if i % (len(dl_trn)/10) == 0 or i == len(dl_trn) -1:
                print(loss)

tensor(2.4311, grad_fn=<NllLossBackward0>)
tensor(2.3520, grad_fn=<NllLossBackward0>)
tensor(1.7827, grad_fn=<NllLossBackward0>)
tensor(1.4508, grad_fn=<NllLossBackward0>)
tensor(1.0181, grad_fn=<NllLossBackward0>)
tensor(0.6496, grad_fn=<NllLossBackward0>)
tensor(0.3520, grad_fn=<NllLossBackward0>)
tensor(0.3655, grad_fn=<NllLossBackward0>)
tensor(0.2827, grad_fn=<NllLossBackward0>)
tensor(0.1909, grad_fn=<NllLossBackward0>)
tensor(0.4274, grad_fn=<NllLossBackward0>)
tensor(0.2304, grad_fn=<NllLossBackward0>)
tensor(0.1500, grad_fn=<NllLossBackward0>)
tensor(0.0874, grad_fn=<NllLossBackward0>)
tensor(0.0732, grad_fn=<NllLossBackward0>)
tensor(0.0549, grad_fn=<NllLossBackward0>)
tensor(0.0501, grad_fn=<NllLossBackward0>)
tensor(0.0825, grad_fn=<NllLossBackward0>)
tensor(0.0762, grad_fn=<NllLossBackward0>)
tensor(0.0403, grad_fn=<NllLossBackward0>)
tensor(0.0308, grad_fn=<NllLossBackward0>)
tensor(0.1525, grad_fn=<NllLossBackward0>)


In [83]:
dl_val = DataLoader(dsd_val, batch_size=300, collate_fn=collate_fn)

In [91]:
model.eval()
accuracy = []
with torch.no_grad():
    for i, (xb, yb) in enumerate(dl_val):
        out = model(xb)
        out.softmax(-1)
        accuracy.append((out.argmax(-1) == yb).float().mean())

In [92]:
print(tensor(accuracy).mean())

tensor(0.9893)
