## ResNet on CIFAR-10

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    """Basic two-convolution residual block.

    Main path:  conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
    Shortcut:   identity, or a strided 1x1 conv + BN projection whenever the
                stride or the channel count changes, so both paths have the
                same shape and can be summed.
    The sum of the two paths is passed through a final ReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        first_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.conv1 = nn.Sequential(first_conv, nn.BatchNorm2d(out_channels), nn.ReLU())

        second_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                                stride=1, padding=1, bias=False)
        self.conv2 = nn.Sequential(second_conv, nn.BatchNorm2d(out_channels))

        # The identity shortcut only works when input and output shapes match.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=stride, bias=False)
            self.downsample = nn.Sequential(projection, nn.BatchNorm2d(out_channels))
        else:
            self.downsample = None

    def forward(self, x):
        main_path = self.conv2(self.conv1(x))
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(main_path + shortcut)


class ResNet(nn.Module):
    """CIFAR-style ResNet.

    The network is a 3x3 conv stem with 16 channels, followed by three stages
    of ResBlocks with 16, 32 and 64 channels. The second and third stages halve
    the spatial size in their first block. A global average pool and a linear
    classifier close the network.

    Args:
        num_blocks: number of ResBlocks in each of the three stages, e.g.
            ``[3, 3, 3]``. Each stage always contains at least one block.
        num_classes: size of the output logit vector.
    """

    def __init__(self, num_blocks, num_classes=10):
        super(ResNet, self).__init__()

        # Stem; its ReLU is applied in forward() via self.relu.
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16)
        )

        self.layer2 = self._make_stage(16, 16, num_blocks[0], stride=1)
        self.layer3 = self._make_stage(16, 32, num_blocks[1], stride=2)
        self.layer4 = self._make_stage(32, 64, num_blocks[2], stride=2)

        # Adaptive pooling reduces any spatial size to 1x1. For 32x32 inputs
        # the last feature map is 8x8, so this matches nn.AvgPool2d(8) while
        # no longer breaking on other input resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _make_stage(in_channels, out_channels, num_blocks, stride):
        """Build one stage.

        The stage is a first block that may change the shape (via ``stride``
        and channels), followed by ``num_blocks - 1`` shape-preserving blocks.
        It is returned as an ``nn.ModuleList`` so the state_dict keys stay
        ``layerN.<i>.*``.
        """
        blocks = [ResBlock(in_channels, out_channels, stride)]
        blocks += [ResBlock(out_channels, out_channels) for _ in range(num_blocks - 1)]
        return nn.ModuleList(blocks)

    def forward(self, x):
        out = self.relu(self.layer1(x))

        for stage in (self.layer2, self.layer3, self.layer4):
            for block in stage:
                out = block(out)

        out = self.avgpool(out)
        # Tensor.flatten avoids constructing a new nn.Flatten module per call.
        out = out.flatten(1)
        return self.fc(out)


def _cifar_resnet(blocks_per_stage):
    """Build a ResNet with ``blocks_per_stage`` ResBlocks in each of its 3 stages.

    Depth = 6 * blocks_per_stage + 2: two 3x3 convs per block across three
    stages, plus the stem conv and the final linear layer.
    """
    return ResNet([blocks_per_stage] * 3)


def resnet20():
    """ResNet-20 (3 blocks per stage)."""
    return _cifar_resnet(3)


def resnet32():
    """ResNet-32 (5 blocks per stage)."""
    return _cifar_resnet(5)


def resnet44():
    """ResNet-44 (7 blocks per stage)."""
    return _cifar_resnet(7)


def resnet56():
    """ResNet-56 (9 blocks per stage)."""
    return _cifar_resnet(9)



import torch
from torchvision.datasets import CIFAR10, CIFAR100
import torchvision.transforms as T
from torch.utils.data import DataLoader
import time
# from resnet import resnet56


# Per-channel normalization statistics shared by the train and val pipelines.
# NOTE(review): these are the ImageNet mean/std. CIFAR-10's own statistics
# differ (mean roughly (0.491, 0.482, 0.447)). Changing them would invalidate
# checkpoints trained with the values below, so they are kept as-is.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training augmentation: random 32x32 crop from a 4-pixel zero-padded image
# plus random horizontal flips.
train_transform = T.Compose([
    T.RandomCrop(size=32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

val_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=NORM_MEAN, std=NORM_STD)
])


def _load_cifar10(train, transform, root='./'):
    """Load a CIFAR-10 split from ``root``, downloading it only if needed."""
    try:
        return CIFAR10(root=root, train=train, transform=transform, download=False)
    except RuntimeError:
        # torchvision raises RuntimeError when the dataset files are missing
        # or fail the integrity check. Catching only that error (instead of a
        # bare except) lets KeyboardInterrupt and real bugs propagate.
        return CIFAR10(root=root, train=train, transform=transform, download=True)


train_ds = _load_cifar10(train=True, transform=train_transform)
train_dl = DataLoader(train_ds, batch_size=128, shuffle=True)

val_ds = _load_cifar10(train=False, transform=val_transform)
val_dl = DataLoader(val_ds, batch_size=128)

print(f"Total {len(train_ds)} Training Data")
print(f"Total {len(val_ds)} Validation Data")



model = resnet56()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)

# Divide the learning rate by 10 once 99 and 124 epochs have completed.
# This gives lr 0.1 for epochs 1-99, 0.01 for epochs 100-124 and 0.001 for
# epochs 125-150 (1-indexed). The scheduler is stepped once at the end of
# every epoch.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[99, 124], gamma=0.1)

criterion = torch.nn.CrossEntropyLoss()

epochs = 150

best_acc = 0.  # best validation accuracy seen so far, in percent
total_time = 0


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader``; return the mean batch loss."""
    model.train()
    total_loss = 0.
    for img, label in loader:
        img, label = img.to(device), label.to(device)
        optimizer.zero_grad()
        loss = criterion(model(img), label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(len(loader), 1)


def _count_correct(model, loader, device):
    """Return how many samples in ``loader`` are classified correctly by ``model``."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for img, label in loader:
            img, label = img.to(device), label.to(device)
            correct += (model(img).argmax(dim=1) == label).sum().item()
    return correct


for epoch in range(epochs):
    tick = time.time()

    train_loss = _train_one_epoch(model, train_dl, optimizer, criterion, device)
    print(f"\nEpoch {epoch + 1:4d} Train Loss: {train_loss:.6f}")

    correct = _count_correct(model, val_dl, device)
    val_acc = 100 * correct / len(val_ds)
    print(f"Epoch {epoch + 1:4d} Validation Accuracy: {val_acc}%")

    # NOTE(review): val_ds is the CIFAR-10 test split, so selecting the
    # checkpoint on it makes the reported best accuracy optimistic.
    if best_acc < val_acc:
        print("New Best Accuracy")
        best_acc = val_acc
        torch.save({
            'model_state_dict': model.state_dict()
        }, 'model.pt')

    lr_scheduler.step()

    tock = time.time()
    total_time += tock - tick
    print(f"Total Time for Epoch {epoch + 1:4d}: {tock - tick:.6f}")

print(f"Total Time: {total_time:.6f}")







