In [27]:
import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
import json

In [63]:
class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, use1x1conv=False, stride=1):
        """
        use1x1conv: used to change channels and shape
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if use1x1conv:
            self.conv3 = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None

    def forward(self, X):
        # print(type(self.conv1), type(self.bn1), type(self.relu))
        Y = self.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        # print(Y.shape, X.shape)
        return self.relu(Y + X)


def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    blks = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # half the size of figure
            blks.append(Residual(in_channels, out_channels, use1x1conv=True, stride=2))
        else:
            blks.append(Residual(out_channels, out_channels))
    return blks


class ResNet(nn.Module):
    def __init__(self):
        super().__init__()
        b1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        b2 = nn.Sequential(*resnet_block(in_channels=3, out_channels=3, 
                                         num_residuals=2, first_block=True))
        
        b3 = nn.Sequential(*resnet_block(in_channels=3, out_channels=24, num_residuals=2))

        b4 = nn.Sequential(*resnet_block(in_channels=24, out_channels=64, num_residuals=2))
        
        b5 = nn.Sequential(*resnet_block(64, 128, 2))
        
        self.net = nn.Sequential(b1, b2, b3, b4, b5, nn.AdaptiveAvgPool2d((
            1, 1)), nn.Flatten(), nn.Linear(128, 100))

    def forward(self, x):
        return self.net(x)

In [64]:
resetnet = ResNet()
x = torch.randn((1, 3, 32, 32))

for layer in resetnet.net:
    x = layer(x)
    print(layer.__class__.__name__, 'output shape:\t', x.shape)

Sequential output shape:	 torch.Size([1, 3, 16, 16])
Sequential output shape:	 torch.Size([1, 3, 16, 16])
Sequential output shape:	 torch.Size([1, 24, 8, 8])
Sequential output shape:	 torch.Size([1, 64, 4, 4])
Sequential output shape:	 torch.Size([1, 128, 2, 2])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 128, 1, 1])
Flatten output shape:	 torch.Size([1, 128])
Linear output shape:	 torch.Size([1, 100])


## datasets CIFAR100

In [65]:
root_path = "./data"
train_dataset = datasets.CIFAR100(root=root_path, 
                  download=True, 
                  train=True, 
                  transform=transforms.ToTensor())

test_dataset = datasets.CIFAR100(root=root_path,
                  download=True, 
                  train=False, 
                  transform=transforms.ToTensor())
classes = json.load(open("./datasets/cifar-100-labels.json"))

Files already downloaded and verified
Files already downloaded and verified


In [66]:
def init_weight(m):
    if isinstance(m, (nn.Linear)):
        nn.init.normal_(m.weight)
        nn.init.zeros_(m.bias)
resetnet.apply(init_weight)

ResNet(
  (net): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Residual(
        (conv1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): Residual(
        (conv1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv2): Conv2d(3, 3, ker

In [67]:
from utils import ProgressMeter, AverageMeter

In [68]:
for name, param in resetnet.named_parameters():
    print(name, param.size())

net.0.0.weight torch.Size([3, 3, 3, 3])
net.0.0.bias torch.Size([3])
net.0.1.weight torch.Size([3])
net.0.1.bias torch.Size([3])
net.1.0.conv1.weight torch.Size([3, 3, 3, 3])
net.1.0.conv1.bias torch.Size([3])
net.1.0.bn1.weight torch.Size([3])
net.1.0.bn1.bias torch.Size([3])
net.1.0.conv2.weight torch.Size([3, 3, 3, 3])
net.1.0.conv2.bias torch.Size([3])
net.1.0.bn2.weight torch.Size([3])
net.1.0.bn2.bias torch.Size([3])
net.1.1.conv1.weight torch.Size([3, 3, 3, 3])
net.1.1.conv1.bias torch.Size([3])
net.1.1.bn1.weight torch.Size([3])
net.1.1.bn1.bias torch.Size([3])
net.1.1.conv2.weight torch.Size([3, 3, 3, 3])
net.1.1.conv2.bias torch.Size([3])
net.1.1.bn2.weight torch.Size([3])
net.1.1.bn2.bias torch.Size([3])
net.2.0.conv1.weight torch.Size([24, 3, 3, 3])
net.2.0.conv1.bias torch.Size([24])
net.2.0.bn1.weight torch.Size([24])
net.2.0.bn1.bias torch.Size([24])
net.2.0.conv2.weight torch.Size([24, 24, 3, 3])
net.2.0.conv2.bias torch.Size([24])
net.2.0.bn2.weight torch.Size([24])
ne

In [77]:
num_epochs = 10
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(resetnet.parameters(), lr=0.1)
batch_size = 32

train_dataloader = DataLoader(train_dataset, batch_size)
def train():
    for epoch in range(num_epochs):
        end = time.time()
        
        batch_time = AverageMeter("Batch time:", ':.2f')
        acc = AverageMeter("Acc:", ":.4f")
        losses = AverageMeter('Loss:', ':.4f')
        
        progress = ProgressMeter(
            len(train_dataloader), 
            [batch_time, acc, losses],
            prefix="Epoch: [{}]".format(epoch + 1)
        )
        
        for i, (X, y) in enumerate(train_dataloader):
            optimizer.zero_grad()
            pred_y = resetnet(X)
            loss = criterion(pred_y, y).sum()
            loss.backward()
            optimizer.step()

            # print(loss.item())
            losses.update(loss.item())


            acc.update((pred_y.argmax(dim=-1) == y).float().mean().item())
            batch_time.update(time.time() - end)
            
            progress.display(i + 1)
 
train()

Epoch: [1][   1/1563]	Batch time: 0.06 (0.06)	Acc: 0.1875 (0.1875)	Loss: 2.9773 (2.9773)
Epoch: [1][   2/1563]	Batch time: 0.10 (0.08)	Acc: 0.2812 (0.2344)	Loss: 3.0829 (3.0301)
Epoch: [1][   3/1563]	Batch time: 0.15 (0.10)	Acc: 0.3125 (0.2604)	Loss: 2.9358 (2.9987)
Epoch: [1][   4/1563]	Batch time: 0.18 (0.12)	Acc: 0.3750 (0.2891)	Loss: 2.9589 (2.9887)
Epoch: [1][   5/1563]	Batch time: 0.23 (0.14)	Acc: 0.2188 (0.2750)	Loss: 2.8916 (2.9693)
Epoch: [1][   6/1563]	Batch time: 0.29 (0.17)	Acc: 0.3125 (0.2812)	Loss: 2.8307 (2.9462)
Epoch: [1][   7/1563]	Batch time: 0.33 (0.19)	Acc: 0.3750 (0.2946)	Loss: 2.7736 (2.9215)
Epoch: [1][   8/1563]	Batch time: 0.37 (0.21)	Acc: 0.2812 (0.2930)	Loss: 2.9224 (2.9217)
Epoch: [1][   9/1563]	Batch time: 0.42 (0.24)	Acc: 0.4688 (0.3125)	Loss: 2.5060 (2.8755)
Epoch: [1][  10/1563]	Batch time: 0.46 (0.26)	Acc: 0.2188 (0.3031)	Loss: 3.0836 (2.8963)
Epoch: [1][  11/1563]	Batch time: 0.50 (0.28)	Acc: 0.4688 (0.3182)	Loss: 2.1385 (2.8274)
Epoch: [1][  12/1563]

In [1]:
torch.save(resetnet.state_dict(), "checkpoint.pth.tar")

NameError: name 'torch' is not defined