##### Import require dependencies

In [9]:
import numpy as np
import torch
import gc
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler

# Device configuration
device = torch.device("mps" if torch.mps.is_available() else "cpu")
device

device(type='mps')

##### Load CIFAR10 dataset

In [10]:
def data_loader(root, batch_size, random_seed=5, valid_size=0.1, shuffle=True, test=False):
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    if test:
        dataset = datasets.CIFAR10(
            root=root,
            train=False,
            download=True,
            transform=transform,
        )

        data_loader = torch.utils.data.DataLoader(
            dataset, 
            batch_size=batch_size, 
            shuffle=shuffle
        )

        return data_loader
    
    # load the dataset
    train_dataset = datasets.CIFAR10(
        root=root, train=True,
        download=True,
        transform=transform,
    )

    valid_dataset = datasets.CIFAR10(
        root=root,
        train=True,
        download=True, 
        transform=transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler)

    return (train_loader, valid_loader)

train_loader, valid_loader = data_loader(
    root='./datasets',
    batch_size=64
)

test_loader = data_loader(
    root='./datasets',
    batch_size=64,
    test=True
)

train_loader.dataset.data.shape

(50000, 32, 32, 3)

##### Build ResidualBlock

In [14]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1, down_sample = None):
        super(ResidualBlock, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
            nn.BatchNorm2d(out_channels)
        )

        self.down_sample = down_sample
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.down_sample:
            residual = self.down_sample(x)

        out += residual
        out = self.relu(out)
        return out

##### Build Model ResNet

In [21]:
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes = 10):
        super(ResNet, self).__init__()

        self.in_planes = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.max_pool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride = 1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride = 2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride = 2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride = 2)
        self.avg_pool = nn.AvgPool2d(7, stride=1)

        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        down_sample = None
        if stride != 1 or self.in_planes != planes:
            down_sample = nn.Sequential(
                nn.Conv2d(self.in_planes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes),
            )

        layers = []
        layers.append(block(self.in_planes, planes, stride, down_sample))
        self.in_planes = planes
        for i in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.max_pool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

##### Training

In [25]:

num_classes = 10
num_epochs = 20
batch_size = 16
learning_rate = 0.01
path = "models/torch-resnet"

model = ResNet(ResidualBlock, [3, 4, 6, 3]).to(device)

#Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.001, momentum = 0.9)  

#Train the model
total_step = len(train_loader)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        #Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(epoch+1, num_epochs, i+1, total_step, loss.item()))

        #Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del images, labels, outputs
        torch.cuda.empty_cache()
        gc.collect()

    print ("Epoch [{}/{}], Loss: {:.4f}".format(epoch+1, num_epochs, loss.item()))

torch.save(model.state_dict(), path)

Epoch [1/20], Step [1/704], Loss: 2.4780
Epoch [1/20], Step [2/704], Loss: 2.3366
Epoch [1/20], Step [3/704], Loss: 2.5150
Epoch [1/20], Step [4/704], Loss: 2.4239
Epoch [1/20], Step [5/704], Loss: 2.5685
Epoch [1/20], Step [6/704], Loss: 2.4030
Epoch [1/20], Step [7/704], Loss: 2.2277
Epoch [1/20], Step [8/704], Loss: 2.4448
Epoch [1/20], Step [9/704], Loss: 2.4341
Epoch [1/20], Step [10/704], Loss: 2.0827
Epoch [1/20], Step [11/704], Loss: 2.3374
Epoch [1/20], Step [12/704], Loss: 2.4648
Epoch [1/20], Step [13/704], Loss: 2.2556
Epoch [1/20], Step [14/704], Loss: 2.3061
Epoch [1/20], Step [15/704], Loss: 2.6321
Epoch [1/20], Step [16/704], Loss: 2.4192
Epoch [1/20], Step [17/704], Loss: 2.3231
Epoch [1/20], Step [18/704], Loss: 2.2205
Epoch [1/20], Step [19/704], Loss: 2.3871
Epoch [1/20], Step [20/704], Loss: 2.4605
Epoch [1/20], Step [21/704], Loss: 2.5355
Epoch [1/20], Step [22/704], Loss: 2.4179
Epoch [1/20], Step [23/704], Loss: 2.3324
Epoch [1/20], Step [24/704], Loss: 2.0531
E

KeyboardInterrupt: 

In [None]:
# Validation
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in valid_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        del images, labels, outputs

    print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total))