In [1]:
# ---------------------------------------------------------------------------- #
# An implementation of https://arxiv.org/pdf/1512.03385.pdf                    #
# See section 4.2 for the model architecture on CIFAR-10                       #
# Some part of the code was referenced from below                              #
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py   #
# ---------------------------------------------------------------------------- #

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os

In [2]:
# flags
color = 1

In [3]:
ResNet18_color = "ResNet18_color/"
ResNet18_gray = "ResNet18_gray/"
PIXELS_DIR = f"{ResNet18_color}pixel_data/"
PIXELS_DIR_GRAY = f"{ResNet18_gray}pixel_data_grayscale/"

In [4]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Hyper-parameters
num_epochs = 300
learning_rate = 0.001

# Image preprocessing modules
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

cuda


In [5]:
# CIFAR-10 dataset
if color:
    train_dataset = torchvision.datasets.ImageFolder(root=f'./{PIXELS_DIR}train',
                                                 transform=transform)

    test_dataset = torchvision.datasets.ImageFolder(root=f'./{PIXELS_DIR}test',
                                                transform=transform)
else:
    train_dataset = torchvision.datasets.ImageFolder(root=f'./{PIXELS_DIR_GRAY}train',
                                                 transform=transform)

    test_dataset = torchvision.datasets.ImageFolder(root=f'./{PIXELS_DIR_GRAY}test',
                                                transform=transform)

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=128, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=128, 
                                          shuffle=False)

In [6]:
# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                     stride=stride, padding=1, bias=False)

# Residual block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

# ResNet
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [7]:
model = ResNet(ResidualBlock, [2, 2, 2]).to(device)

In [8]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [9]:
# For updating learning rate
def update_lr(optimizer, lr):    
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [10]:
# Train the model
def train(epoch):
    global curr_lr
    train_loss = 0
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        acc = 100. * correct / total
#         acc = accuracy(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        loss_ = train_loss/(i+1)

        if (i+1) % 100 == 0:
            print (f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss_:.4f}, Accuracy: {acc:.4f}")

    print (f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss_:.4f}, Accuracy: {acc:.4f}")
    logs.write(f"{epoch+1}, {loss_}, {acc}")

    # Decay learning rate
    if (epoch+1) % 20 == 0:
        curr_lr /= 3
        update_lr(optimizer, curr_lr)

In [11]:
# Test the model
def test(epoch):
    global best_losses
    model.eval()
    with torch.no_grad():
        test_loss = 0
        correct = 0
        total = 0
        for i, (images, labels) in enumerate(test_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            acc = 100. * correct / total

            test_loss += loss.item()
            loss_ = test_loss/(i+1)

        print(f"Loss: {loss_:.4f}, Accuracy: {acc:.4f}")
        logs.write(f" {loss_}, {acc}\n")

    
    # Save checkpoint and replace old best model if current model is better
    if loss < best_losses:
        best_losses = loss
        torch.save(model.state_dict(), '{}/model-epoch-{}-losses-{:.3f}.pth'.format(checkpoints,epoch+1,loss))

In [12]:
# accuracy = MulticlassAccuracy(num_classes=10).to(device)
if color:
    checkpoints = f'{ResNet18_color}checkpoints_color'
    logs_file = f'{ResNet18_color}logs_color.txt'
else:
    checkpoints = f'{ResNet18_gray}checkpoints_gray'
    logs_file = f'{ResNet18_gray}logs_gray.txt'
    
os.makedirs(checkpoints, exist_ok=True)
best_losses = torch.tensor(1e10)
total_step = len(train_loader)
curr_lr = learning_rate
logs = open(logs_file,'w')
logs.write("Epoch, Loss, Accuracy, Val_loss, Val_accuracy\n")
logs.close()

In [13]:
# for epoch in range(num_epochs):
for epoch in range(num_epochs):
    logs = open(logs_file,'a')
    train(epoch)
    test(epoch)
    logs.close()

Epoch [1/300], Step [100/391], Loss: 1.7985, Accuracy: 32.3984
Epoch [1/300], Step [200/391], Loss: 1.6530, Accuracy: 38.3945
Epoch [1/300], Step [300/391], Loss: 1.5499, Accuracy: 42.5859
Epoch [1/300], Step [391/391], Loss: 1.4777, Accuracy: 45.4240
Loss: 1.3529, Accuracy: 50.0200
Epoch [2/300], Step [100/391], Loss: 2.2475, Accuracy: 15.9609
Epoch [2/300], Step [200/391], Loss: 2.1061, Accuracy: 18.6523
Epoch [2/300], Step [300/391], Loss: 2.0406, Accuracy: 20.2865
Epoch [2/300], Step [391/391], Loss: 1.9711, Accuracy: 23.2320
Loss: 1.6809, Accuracy: 36.9700
Epoch [3/300], Step [100/391], Loss: 1.6073, Accuracy: 39.1797
Epoch [3/300], Step [200/391], Loss: 1.5247, Accuracy: 42.8906
Epoch [3/300], Step [300/391], Loss: 1.4785, Accuracy: 44.9661
Epoch [3/300], Step [391/391], Loss: 1.4398, Accuracy: 46.5080
Loss: 1.2959, Accuracy: 52.3000
Epoch [4/300], Step [100/391], Loss: 1.2851, Accuracy: 53.7500
Epoch [4/300], Step [200/391], Loss: 1.2498, Accuracy: 54.8047
Epoch [4/300], Step [3

Epoch [29/300], Step [391/391], Loss: 0.3516, Accuracy: 87.4560
Loss: 0.4871, Accuracy: 83.5600
Epoch [30/300], Step [100/391], Loss: 0.3408, Accuracy: 88.0391
Epoch [30/300], Step [200/391], Loss: 0.3417, Accuracy: 88.1172
Epoch [30/300], Step [300/391], Loss: 0.3454, Accuracy: 87.8073
Epoch [30/300], Step [391/391], Loss: 0.3457, Accuracy: 87.8020
Loss: 0.4975, Accuracy: 82.7600
Epoch [31/300], Step [100/391], Loss: 0.3497, Accuracy: 87.8984
Epoch [31/300], Step [200/391], Loss: 0.3402, Accuracy: 88.0820
Epoch [31/300], Step [300/391], Loss: 0.3409, Accuracy: 88.0000
Epoch [31/300], Step [391/391], Loss: 0.3436, Accuracy: 87.8320
Loss: 0.5071, Accuracy: 83.1400
Epoch [32/300], Step [100/391], Loss: 0.3174, Accuracy: 88.5312
Epoch [32/300], Step [200/391], Loss: 0.3199, Accuracy: 88.4375
Epoch [32/300], Step [300/391], Loss: 0.3308, Accuracy: 88.2474
Epoch [32/300], Step [391/391], Loss: 0.3339, Accuracy: 88.1040
Loss: 0.4900, Accuracy: 83.5900
Epoch [33/300], Step [100/391], Loss: 0.

Epoch [58/300], Step [200/391], Loss: 0.2227, Accuracy: 92.1758
Epoch [58/300], Step [300/391], Loss: 0.2253, Accuracy: 91.9974
Epoch [58/300], Step [391/391], Loss: 0.2250, Accuracy: 92.0440
Loss: 0.4913, Accuracy: 85.1500
Epoch [59/300], Step [100/391], Loss: 0.2180, Accuracy: 92.1250
Epoch [59/300], Step [200/391], Loss: 0.2258, Accuracy: 91.9688
Epoch [59/300], Step [300/391], Loss: 0.2232, Accuracy: 92.1042
Epoch [59/300], Step [391/391], Loss: 0.2221, Accuracy: 92.1500
Loss: 0.4979, Accuracy: 84.8500
Epoch [60/300], Step [100/391], Loss: 0.2226, Accuracy: 92.2656
Epoch [60/300], Step [200/391], Loss: 0.2224, Accuracy: 92.0938
Epoch [60/300], Step [300/391], Loss: 0.2230, Accuracy: 92.0312
Epoch [60/300], Step [391/391], Loss: 0.2225, Accuracy: 92.0160
Loss: 0.5020, Accuracy: 84.8800
Epoch [61/300], Step [100/391], Loss: 0.2026, Accuracy: 92.9141
Epoch [61/300], Step [200/391], Loss: 0.2067, Accuracy: 92.6992
Epoch [61/300], Step [300/391], Loss: 0.2050, Accuracy: 92.7422
Epoch [6

Loss: 0.4931, Accuracy: 85.7000
Epoch [87/300], Step [100/391], Loss: 0.1885, Accuracy: 93.1641
Epoch [87/300], Step [200/391], Loss: 0.1835, Accuracy: 93.4297
Epoch [87/300], Step [300/391], Loss: 0.1871, Accuracy: 93.3464
Epoch [87/300], Step [391/391], Loss: 0.1867, Accuracy: 93.3400
Loss: 0.5034, Accuracy: 85.4300
Epoch [88/300], Step [100/391], Loss: 0.1818, Accuracy: 93.5078
Epoch [88/300], Step [200/391], Loss: 0.1826, Accuracy: 93.4727
Epoch [88/300], Step [300/391], Loss: 0.1826, Accuracy: 93.4427
Epoch [88/300], Step [391/391], Loss: 0.1841, Accuracy: 93.3740
Loss: 0.5035, Accuracy: 85.2800
Epoch [89/300], Step [100/391], Loss: 0.1839, Accuracy: 93.4219
Epoch [89/300], Step [200/391], Loss: 0.1866, Accuracy: 93.4414
Epoch [89/300], Step [300/391], Loss: 0.1841, Accuracy: 93.4688
Epoch [89/300], Step [391/391], Loss: 0.1861, Accuracy: 93.4100
Loss: 0.5212, Accuracy: 84.9700
Epoch [90/300], Step [100/391], Loss: 0.1845, Accuracy: 93.5469
Epoch [90/300], Step [200/391], Loss: 0.

Epoch [115/300], Step [200/391], Loss: 0.1736, Accuracy: 93.7852
Epoch [115/300], Step [300/391], Loss: 0.1772, Accuracy: 93.6589
Epoch [115/300], Step [391/391], Loss: 0.1793, Accuracy: 93.5520
Loss: 0.5018, Accuracy: 85.1800
Epoch [116/300], Step [100/391], Loss: 0.1789, Accuracy: 93.7266
Epoch [116/300], Step [200/391], Loss: 0.1792, Accuracy: 93.6406
Epoch [116/300], Step [300/391], Loss: 0.1812, Accuracy: 93.5052
Epoch [116/300], Step [391/391], Loss: 0.1816, Accuracy: 93.5420
Loss: 0.5075, Accuracy: 85.3600
Epoch [117/300], Step [100/391], Loss: 0.1852, Accuracy: 93.3438
Epoch [117/300], Step [200/391], Loss: 0.1797, Accuracy: 93.4570
Epoch [117/300], Step [300/391], Loss: 0.1804, Accuracy: 93.6198
Epoch [117/300], Step [391/391], Loss: 0.1810, Accuracy: 93.5080
Loss: 0.5081, Accuracy: 85.1200
Epoch [118/300], Step [100/391], Loss: 0.1861, Accuracy: 93.4609
Epoch [118/300], Step [200/391], Loss: 0.1807, Accuracy: 93.6172
Epoch [118/300], Step [300/391], Loss: 0.1801, Accuracy: 93

Epoch [143/300], Step [300/391], Loss: 0.1786, Accuracy: 93.5651
Epoch [143/300], Step [391/391], Loss: 0.1785, Accuracy: 93.6080
Loss: 0.4950, Accuracy: 85.3900
Epoch [144/300], Step [100/391], Loss: 0.1739, Accuracy: 93.6094
Epoch [144/300], Step [200/391], Loss: 0.1754, Accuracy: 93.6133
Epoch [144/300], Step [300/391], Loss: 0.1775, Accuracy: 93.5807
Epoch [144/300], Step [391/391], Loss: 0.1778, Accuracy: 93.6020
Loss: 0.4991, Accuracy: 85.6300
Epoch [145/300], Step [100/391], Loss: 0.1833, Accuracy: 93.3438
Epoch [145/300], Step [200/391], Loss: 0.1800, Accuracy: 93.4688
Epoch [145/300], Step [300/391], Loss: 0.1778, Accuracy: 93.5677
Epoch [145/300], Step [391/391], Loss: 0.1763, Accuracy: 93.6100
Loss: 0.5080, Accuracy: 85.2600
Epoch [146/300], Step [100/391], Loss: 0.1841, Accuracy: 93.5391
Epoch [146/300], Step [200/391], Loss: 0.1796, Accuracy: 93.6484
Epoch [146/300], Step [300/391], Loss: 0.1794, Accuracy: 93.6068
Epoch [146/300], Step [391/391], Loss: 0.1790, Accuracy: 93

Epoch [171/300], Step [391/391], Loss: 0.1800, Accuracy: 93.5900
Loss: 0.5119, Accuracy: 85.4500
Epoch [172/300], Step [100/391], Loss: 0.1790, Accuracy: 93.6562
Epoch [172/300], Step [200/391], Loss: 0.1844, Accuracy: 93.4297
Epoch [172/300], Step [300/391], Loss: 0.1807, Accuracy: 93.5573
Epoch [172/300], Step [391/391], Loss: 0.1811, Accuracy: 93.5120
Loss: 0.5150, Accuracy: 85.1600
Epoch [173/300], Step [100/391], Loss: 0.1808, Accuracy: 93.2812
Epoch [173/300], Step [200/391], Loss: 0.1795, Accuracy: 93.5234
Epoch [173/300], Step [300/391], Loss: 0.1797, Accuracy: 93.5599
Epoch [173/300], Step [391/391], Loss: 0.1812, Accuracy: 93.5160
Loss: 0.4992, Accuracy: 85.4800
Epoch [174/300], Step [100/391], Loss: 0.1805, Accuracy: 93.2656
Epoch [174/300], Step [200/391], Loss: 0.1818, Accuracy: 93.4414
Epoch [174/300], Step [300/391], Loss: 0.1810, Accuracy: 93.4870
Epoch [174/300], Step [391/391], Loss: 0.1797, Accuracy: 93.5060
Loss: 0.4975, Accuracy: 85.5300
Epoch [175/300], Step [100/

Loss: 0.5113, Accuracy: 85.5600
Epoch [200/300], Step [100/391], Loss: 0.1665, Accuracy: 93.9688
Epoch [200/300], Step [200/391], Loss: 0.1713, Accuracy: 93.8320
Epoch [200/300], Step [300/391], Loss: 0.1777, Accuracy: 93.6406
Epoch [200/300], Step [391/391], Loss: 0.1783, Accuracy: 93.6300
Loss: 0.5043, Accuracy: 85.4900
Epoch [201/300], Step [100/391], Loss: 0.1719, Accuracy: 93.6562
Epoch [201/300], Step [200/391], Loss: 0.1782, Accuracy: 93.4336
Epoch [201/300], Step [300/391], Loss: 0.1757, Accuracy: 93.6406
Epoch [201/300], Step [391/391], Loss: 0.1752, Accuracy: 93.7080
Loss: 0.5019, Accuracy: 85.3000
Epoch [202/300], Step [100/391], Loss: 0.1785, Accuracy: 93.3594
Epoch [202/300], Step [200/391], Loss: 0.1796, Accuracy: 93.5156
Epoch [202/300], Step [300/391], Loss: 0.1785, Accuracy: 93.5729
Epoch [202/300], Step [391/391], Loss: 0.1781, Accuracy: 93.5840
Loss: 0.5013, Accuracy: 85.7200
Epoch [203/300], Step [100/391], Loss: 0.1824, Accuracy: 93.5938
Epoch [203/300], Step [200/

Epoch [228/300], Step [100/391], Loss: 0.1687, Accuracy: 93.8438
Epoch [228/300], Step [200/391], Loss: 0.1735, Accuracy: 93.7461
Epoch [228/300], Step [300/391], Loss: 0.1738, Accuracy: 93.7943
Epoch [228/300], Step [391/391], Loss: 0.1760, Accuracy: 93.7000
Loss: 0.5015, Accuracy: 85.7400
Epoch [229/300], Step [100/391], Loss: 0.1793, Accuracy: 93.5000
Epoch [229/300], Step [200/391], Loss: 0.1769, Accuracy: 93.6094
Epoch [229/300], Step [300/391], Loss: 0.1767, Accuracy: 93.6068
Epoch [229/300], Step [391/391], Loss: 0.1784, Accuracy: 93.5940
Loss: 0.4977, Accuracy: 85.3500
Epoch [230/300], Step [100/391], Loss: 0.1752, Accuracy: 93.7031
Epoch [230/300], Step [200/391], Loss: 0.1752, Accuracy: 93.7617
Epoch [230/300], Step [300/391], Loss: 0.1776, Accuracy: 93.6536
Epoch [230/300], Step [391/391], Loss: 0.1785, Accuracy: 93.5920
Loss: 0.5069, Accuracy: 85.2600
Epoch [231/300], Step [100/391], Loss: 0.1727, Accuracy: 93.9375
Epoch [231/300], Step [200/391], Loss: 0.1723, Accuracy: 93

Epoch [256/300], Step [200/391], Loss: 0.1795, Accuracy: 93.5781
Epoch [256/300], Step [300/391], Loss: 0.1806, Accuracy: 93.5417
Epoch [256/300], Step [391/391], Loss: 0.1786, Accuracy: 93.6480
Loss: 0.5158, Accuracy: 85.4400
Epoch [257/300], Step [100/391], Loss: 0.1785, Accuracy: 93.7188
Epoch [257/300], Step [200/391], Loss: 0.1808, Accuracy: 93.6133
Epoch [257/300], Step [300/391], Loss: 0.1806, Accuracy: 93.5677
Epoch [257/300], Step [391/391], Loss: 0.1804, Accuracy: 93.5700
Loss: 0.5031, Accuracy: 85.2600
Epoch [258/300], Step [100/391], Loss: 0.1756, Accuracy: 93.8906
Epoch [258/300], Step [200/391], Loss: 0.1792, Accuracy: 93.6172
Epoch [258/300], Step [300/391], Loss: 0.1769, Accuracy: 93.7135
Epoch [258/300], Step [391/391], Loss: 0.1775, Accuracy: 93.6900
Loss: 0.5060, Accuracy: 85.5600
Epoch [259/300], Step [100/391], Loss: 0.1725, Accuracy: 93.9922
Epoch [259/300], Step [200/391], Loss: 0.1766, Accuracy: 93.7344
Epoch [259/300], Step [300/391], Loss: 0.1778, Accuracy: 93

Epoch [284/300], Step [300/391], Loss: 0.1764, Accuracy: 93.7344
Epoch [284/300], Step [391/391], Loss: 0.1790, Accuracy: 93.6620
Loss: 0.5076, Accuracy: 85.1400
Epoch [285/300], Step [100/391], Loss: 0.1830, Accuracy: 93.2891
Epoch [285/300], Step [200/391], Loss: 0.1779, Accuracy: 93.6016
Epoch [285/300], Step [300/391], Loss: 0.1791, Accuracy: 93.5469
Epoch [285/300], Step [391/391], Loss: 0.1804, Accuracy: 93.4860
Loss: 0.5076, Accuracy: 85.5000
Epoch [286/300], Step [100/391], Loss: 0.1763, Accuracy: 93.6484
Epoch [286/300], Step [200/391], Loss: 0.1772, Accuracy: 93.6719
Epoch [286/300], Step [300/391], Loss: 0.1782, Accuracy: 93.6094
Epoch [286/300], Step [391/391], Loss: 0.1794, Accuracy: 93.5980
Loss: 0.5120, Accuracy: 85.2400
Epoch [287/300], Step [100/391], Loss: 0.1805, Accuracy: 93.7734
Epoch [287/300], Step [200/391], Loss: 0.1818, Accuracy: 93.5469
Epoch [287/300], Step [300/391], Loss: 0.1803, Accuracy: 93.5495
Epoch [287/300], Step [391/391], Loss: 0.1814, Accuracy: 93

In [14]:
# Save the model checkpoint
if color:
    torch.save(model.state_dict(), f'{ResNet18_color}resnet_color.ckpt')
else:
    torch.save(model.state_dict(), f'{ResNet18_gray}resnet_gray.ckpt')