# ResNet for CIFAR-10

This code is inspired by the GitHub repository `pytorch-cifar`, which provides the basic block implementation of a Residual Network architecture. In our project, we declare a ResNet12 architecture with four layers containing `[2, 1, 1, 1]` blocks respectively.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions with dropout in between.

    Main path: conv3x3 -> batch norm -> ReLU -> dropout -> conv3x3 -> batch
    norm. The result is added to the shortcut branch and passed through a
    final ReLU.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
        dropout_rate: Drop probability of the dropout layer.
    """

    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, dropout_rate=0.25):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path; only the first convolution applies ``stride``.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.dropout = nn.Dropout(dropout_rate)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut branch: identity when the shapes already line up,
        # otherwise a 1x1 convolution + batch norm projection.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum of both branches."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.dropout(hidden)  # dropout sits right after the first activation
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)


class ResNet(nn.Module):
    """Residual network: a 3x3 stem, four residual stages and a linear head.

    Args:
        block: Block class (or factory) called as
            ``block(in_planes, planes, stride, dropout_rate)``; it must expose
            an ``expansion`` attribute.
        num_blocks: Sequence of four ints, the number of blocks per stage.
        num_classes: Size of the output layer.
        dropout_rate: Drop probability used before the final linear layer.
            It is also passed to every block, so it takes the place of the
            block's own default.
    """

    def __init__(self, block, num_blocks, num_classes=10, dropout_rate=0.5):
        super(ResNet, self).__init__()
        # Channel count fed to the next block; updated by _make_layer.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # Stages 2-4 halve the spatial size (stride 2) and double the channels.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, dropout_rate=dropout_rate)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, dropout_rate=dropout_rate)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, dropout_rate=dropout_rate)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, dropout_rate=dropout_rate)
        self.dropout = nn.Dropout(dropout_rate)  # dropout before the final layer
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride, dropout_rate):
        """Build one stage: the first block uses ``stride``, the rest stride 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        # Separate loop name so the ``stride`` argument is not shadowed.
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride, dropout_rate))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For 32x32 inputs the final feature map is
        # 4x4, so this equals the previous fixed ``avg_pool2d(out, 4)``; the
        # adaptive form also works for other input resolutions, where the
        # fixed window either failed or produced the wrong feature count.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.dropout(out)  # dropout before the final fully connected layer
        out = self.linear(out)
        return out

def ResNet12():
    """Build the ResNet12 model: ``BasicBlock`` with [2, 1, 1, 1] blocks per stage."""
    blocks_per_stage = [2, 1, 1, 1]
    model = ResNet(BasicBlock, blocks_per_stage)
    return model


Next, we define a few helper functions to import data and format training progress. Some are also taken from the `pytorch-cifar` utility functions.

In [2]:
import os
import sys
import time
import math
import zipfile
import pickle
import torch.nn as nn
import torch.nn.init as init


# Module-level state for progress reporting.
# NOTE(review): none of these three names is referenced in the visible part of
# this notebook; presumably leftovers from the pytorch-cifar utilities --
# confirm before removing.
term_width = 50
last_time = time.time()  # timestamp taken when this cell is executed
begin_time = last_time

def progress_bar(current, total, msg=None):
    """Draw a one-line text progress bar on stdout.

    The line starts with a carriage return so successive calls overwrite each
    other; a newline is written once ``current == total``.

    Args:
        current: Number of completed steps.
        total: Total number of steps. A ``total`` of 0 is rendered as complete
            instead of raising ``ZeroDivisionError``.
        msg: Optional text appended after the counters.
    """
    bar_length = 60
    progress = current / total if total else 1.0

    # Clamp so that out-of-range progress cannot change the bar width.
    filled = min(max(int(round(bar_length * progress)), 0), bar_length)
    if filled >= bar_length:
        # Finished: the whole bar is filled and no arrow is shown.
        bar = "=" * bar_length
    else:
        # The '>' arrow is the last character of the filled portion (or the
        # first cell when nothing is filled yet). It always takes exactly one
        # cell, so the bar is always ``bar_length`` characters wide.
        head = max(filled - 1, 0)
        bar = "=" * head + ">" + "-" * (bar_length - head - 1)

    text = "\rProgress: [{0}] {1:.2f}% ({2}/{3}) {4}".format(
        bar, progress * 100, current, total, msg if msg else "")
    sys.stdout.write(text)
    sys.stdout.flush()
    if current == total:
        sys.stdout.write('\n')

def format_time(seconds):
    """Format a duration in seconds as a compact string such as ``'1h1m'``.

    Only the two largest non-zero units among days (D), hours (h), minutes
    (m), seconds (s) and milliseconds (ms) are shown. ``'0ms'`` is returned
    when no unit is positive.
    """
    # Peel off each unit in turn, truncating towards zero.
    days = int(seconds / 3600 / 24)
    rest = seconds - days * 3600 * 24
    hours = int(rest / 3600)
    rest = rest - hours * 3600
    minutes = int(rest / 60)
    rest = rest - minutes * 60
    whole_seconds = int(rest)
    millis = int((rest - whole_seconds) * 1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )
    # Keep at most the first two positive units, largest first.
    shown = [str(value) + suffix for value, suffix in units if value > 0][:2]
    return ''.join(shown) or '0ms'


def unpickle(filename, member="cifar_test_nolabels.pkl"):
    """Load a pickled object stored inside a zip archive.

    Args:
        filename: Path to the zip archive.
        member: Name of the pickled file inside the archive. Defaults to the
            competition file name this notebook uses.

    Returns:
        The unpickled object.

    Raises:
        KeyError: If ``member`` is not present in the archive.
    """
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # use this on archives that come from a trusted source.
    # The context managers close the archive even if reading or unpickling
    # raises, which the explicit close() call did not.
    with zipfile.ZipFile(filename, "r") as zf:
        with zf.open(member) as fh:
            return pickle.loads(fh.read())

Finally, we instantiate our model and count the number of parameters. We also import the CIFAR-10 dataset.

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn.init as init
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os
import argparse
import sys
import time
import math
from PIL import Image

# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best accuracy so far; train() updates it from the *training* accuracy of each epoch
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
# Training pipeline: random horizontal and vertical flips, then tensor
# conversion and per-channel normalization.
# NOTE(review): the Normalize constants are presumably the CIFAR-10 channel
# means / standard deviations -- confirm against the source they came from.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: same conversion and normalization, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

class CIFARLikeDataset(Dataset):
    """Dataset over the five ``data_batch_<i>`` pickle files found in a folder.

    Each file is unpickled with bytes keys; ``b'data'`` rows are stacked and
    reshaped to ``(N, 32, 32, 3)`` and ``b'labels'`` are concatenated into a
    single list.

    Args:
        folder_path: Directory containing ``data_batch_1`` .. ``data_batch_5``.
        transform: Optional callable applied to each PIL image on access.
    """

    def __init__(self, folder_path, transform=None):
        self.transform = transform
        self.labels = []
        chunks = []

        # Read every batch file. NOTE: pickle executes code from the file, so
        # these files must come from a trusted source.
        for batch_number in range(1, 6):
            batch_path = os.path.join(folder_path, f'data_batch_{batch_number}')
            with open(batch_path, 'rb') as handle:
                batch = pickle.load(handle, encoding='bytes')
            chunks.append(batch[b'data'])
            self.labels.extend(batch[b'labels'])

        # Rows are flat, channel-first images; convert them to HWC layout.
        stacked = np.vstack(chunks)
        self.data = stacked.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, applying ``transform`` if set."""
        # The numpy array is wrapped in a PIL image before transforming.
        img = Image.fromarray(self.data[idx])
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]
    
# Folder that holds the data_batch_1 .. data_batch_5 files read by CIFARLikeDataset.
custom_data_path = './data'  # Update this path
cifar_like_train = CIFARLikeDataset(folder_path=custom_data_path, transform=transform_train)


# Official CIFAR-10 splits from torchvision (downloaded into ./data if missing).
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
# NOTE(review): the training data is the concatenation of the train split, the
# *test* split and the custom batches. Because testset is part of what the
# model is trained on, metrics computed from testloader are not held-out
# estimates -- confirm this is intended (e.g. for a final competition fit).
combined_set = torch.utils.data.ConcatDataset([trainset, testset, cifar_like_train])
print(len(trainset), len(combined_set))
trainloader = torch.utils.data.DataLoader(
    combined_set, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Class names, indexed by the integer label.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model: instantiate the ResNet12 network defined in the first cell.
print('==> Building model..')
net = ResNet12()

def count_parameters(model):
    """Return the number of scalar elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad`` set to False are not counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Show the architecture and check the trainable-parameter count against the
# 5 million limit; the assert stops the cell if the limit is exceeded.
print(net)
count_params = count_parameters(net)
print("Count params: ", count_params)
assert count_params < 5000000, "Parameters must be lower than 5 million"

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
50000 110000
==> Building model..
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout): Dropout(p=0.5, inplace=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

# Training
For training, we use `CrossEntropyLoss` to measure how the model is doing. This is a suitable loss for an image classification task. Next, we use the `Adam` optimizer to train the parameters. The learning rate will be updated using a cosine annealing schedule.

In [4]:
# Move the model to the selected device. On CUDA it is additionally wrapped in
# DataParallel and cudnn autotuning is enabled; on CPU the bare model is used.
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

# Loss
criterion = nn.CrossEntropyLoss()
# Training Optimizer (Adam with its default hyper-parameters)
optimizer = optim.Adam(net.parameters())
# Learning rate scheduler
# NOTE(review): T_max=200, but the training loop in this notebook is configured
# for 500 epochs and calls scheduler.step() once per epoch -- confirm how the
# cosine schedule is meant to behave after epoch 200.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# History: one loss value per epoch, appended by train() and test().
train_loss_hist = []
test_loss_hist = []

# Training
def train(epoch, dataloader):
    """Train ``net`` for one epoch and checkpoint it when accuracy improves.

    Uses the module-level ``net``, ``criterion``, ``optimizer`` and ``device``.
    Appends the epoch's mean per-batch loss to ``train_loss_hist``. When the
    epoch's training accuracy is at least ``best_acc``, the model state is
    written to ``./checkpoint/ckpt.pth`` and ``best_acc`` is updated.

    Args:
        epoch: Epoch number, used for logging and stored in the checkpoint.
        dataloader: Iterable of ``(inputs, targets)`` batches.
    """
    global best_acc
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = len(dataloader)
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx+1, num_batches, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save the history as the mean loss per batch (the same quantity the
    # progress bar shows). The raw sum grows with the number of batches, which
    # made the train and test curves incomparable because the two loaders have
    # different lengths.
    train_loss_hist.append(train_loss / num_batches if num_batches else 0.0)
    # Save checkpoint.
    # NOTE(review): checkpoint selection is driven by *training* accuracy.
    acc = 100.*correct/total if total else 0.0  # empty loader -> 0 instead of ZeroDivisionError
    if acc >= best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc


def test(epoch, dataloader):
    """Evaluate ``net`` on ``dataloader`` without updating its weights.

    Uses the module-level ``net``, ``criterion`` and ``device``. Appends the
    mean per-batch loss to ``test_loss_hist``.

    Args:
        epoch: Epoch number; accepted for symmetry with ``train`` and unused.
        dataloader: Iterable of ``(inputs, targets)`` batches.
    """
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = len(dataloader)
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx+1, num_batches, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                          % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save the history as the mean loss per batch so it is on the same scale
    # as the training history.
    test_loss_hist.append(test_loss / num_batches if num_batches else 0.0)

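Train the model. The loop below is configured for 500 epochs (`range(start_epoch, start_epoch+500)`); after each training epoch the model is evaluated on the test loader and the learning-rate scheduler is stepped.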

In [None]:
# Train for 500 epochs: one training pass, one evaluation pass and one
# scheduler step per epoch.
# NOTE(review): the scheduler in this notebook is created with T_max=200;
# confirm that running past 200 epochs is intended.
for epoch in range(start_epoch, start_epoch+500):
    train(epoch, trainloader)
    test(epoch, testloader)
    scheduler.step()


Epoch: 0
Saving..

Epoch: 1
Saving..

Epoch: 2
Saving..

Epoch: 3
Saving..

Epoch: 4
Saving..

Epoch: 5
Saving..

Epoch: 6
Saving..

Epoch: 7
Saving..

Epoch: 8
Saving..

Epoch: 9
Saving..

Epoch: 10
Saving..

Epoch: 11
Saving..

Epoch: 12
Saving..

Epoch: 13
Saving..

Epoch: 14
Saving..

Epoch: 15
Saving..

Epoch: 16
Saving..

Epoch: 17
Saving..

Epoch: 18
Saving..

Epoch: 19
Saving..

Epoch: 20
Saving..

Epoch: 21
Saving..

Epoch: 22
Saving..

Epoch: 23
Saving..

Epoch: 24
Saving..

Epoch: 25
Saving..

Epoch: 26
Saving..

Epoch: 27
Saving..

Epoch: 28
Saving..

Epoch: 29


Saving..

Epoch: 30
Saving..

Epoch: 31
Saving..

Epoch: 32
Saving..

Epoch: 33
Saving..

Epoch: 34
Saving..

Epoch: 35
Saving..

Epoch: 36
Saving..

Epoch: 37
Saving..

Epoch: 38
Saving..

Epoch: 39
Saving..

Epoch: 40
Saving..

Epoch: 41
Saving..

Epoch: 42
Saving..

Epoch: 43
Saving..

Epoch: 44
Saving..

Epoch: 45

Epoch: 46
Saving..

Epoch: 47

Epoch: 48
Saving..

Epoch: 49
Saving..

Epoch: 50
Saving..

Epoch: 51

Epoch: 52
Saving..

Epoch: 53

Epoch: 54
Saving..

Epoch: 55
Saving..

Epoch: 56
Saving..

Epoch: 57
Saving..

Epoch: 58


Saving..

Epoch: 59
Saving..

Epoch: 60

Epoch: 61

Epoch: 62
Saving..

Epoch: 63
Saving..

Epoch: 64
Saving..

Epoch: 65
Saving..

Epoch: 66
Saving..

Epoch: 67
Saving..

Epoch: 68
Saving..

Epoch: 69

Epoch: 70
Saving..

Epoch: 71

Epoch: 72
Saving..

Epoch: 73
Saving..

Epoch: 74
Saving..

Epoch: 75
Saving..

Epoch: 76
Saving..

Epoch: 77
Saving..

Epoch: 78

Epoch: 79

Epoch: 80
Saving..

Epoch: 81
Saving..

Epoch: 82

Epoch: 83

Epoch: 84
Saving..

Epoch: 85

Epoch: 86
Saving..

Epoch: 87



Epoch: 88
Saving..

Epoch: 89

Epoch: 90
Saving..

Epoch: 91
Saving..

Epoch: 92
Saving..

Epoch: 93
Saving..

Epoch: 94

Epoch: 95
Saving..

Epoch: 96
Saving..

Epoch: 97

Epoch: 98
Saving..

Epoch: 99

Epoch: 100

Epoch: 101
Saving..

Epoch: 102

Epoch: 103

Epoch: 104
Saving..

Epoch: 105
Saving..

Epoch: 106

Epoch: 107

Epoch: 108

Epoch: 109
Saving..

Epoch: 110
Saving..

Epoch: 111

Epoch: 112

Epoch: 113
Saving..

Epoch: 114
Saving..

Epoch: 115

Epoch: 116


Saving..

Epoch: 117
Saving..

Epoch: 118
Saving..

Epoch: 119
Saving..

Epoch: 120
Saving..

Epoch: 121

Epoch: 122

Epoch: 123

Epoch: 124

Epoch: 125
Saving..

Epoch: 126

Epoch: 127

Epoch: 128
Saving..

Epoch: 129

Epoch: 130

Epoch: 131
Saving..

Epoch: 132

Epoch: 133
Saving..

Epoch: 134

Epoch: 135

Epoch: 136

Epoch: 137
Saving..

Epoch: 138
Saving..

Epoch: 139
Saving..

Epoch: 140

Epoch: 141
Saving..

Epoch: 142
Saving..

Epoch: 143

Epoch: 144

Epoch: 145


Saving..

Epoch: 146
Saving..

Epoch: 147

Epoch: 148

Epoch: 149

Epoch: 150
Saving..

Epoch: 151

Epoch: 152

Epoch: 153
Saving..

Epoch: 154

Epoch: 155

Epoch: 156

Epoch: 157
Saving..

Epoch: 158

Epoch: 159

Epoch: 160
Saving..

Epoch: 161

Epoch: 162
Saving..

Epoch: 163

Epoch: 164
Saving..

Epoch: 165

Epoch: 166
Saving..

Epoch: 167
Saving..

Epoch: 168

Epoch: 169
Saving..

Epoch: 170

Epoch: 171

Epoch: 172

Epoch: 173

Epoch: 174



Epoch: 175
Saving..

Epoch: 176

Epoch: 177
Saving..

Epoch: 178

Epoch: 179

Epoch: 180
Saving..

Epoch: 181

Epoch: 182

Epoch: 183

Epoch: 184

Epoch: 185

Epoch: 186

Epoch: 187

Epoch: 188

Epoch: 189

Epoch: 190

Epoch: 191

Epoch: 192
Saving..

Epoch: 193

Epoch: 194

Epoch: 195

Epoch: 196

Epoch: 197

Epoch: 198

Epoch: 199

Epoch: 200

Epoch: 201

Epoch: 202

Epoch: 203



Epoch: 204

Epoch: 205
Saving..

Epoch: 206

Epoch: 207

Epoch: 208

Epoch: 209

Epoch: 210

Epoch: 211

Epoch: 212

Epoch: 213

Epoch: 214

Epoch: 215

Epoch: 216

Epoch: 217

Epoch: 218

Epoch: 219

Epoch: 220

Epoch: 221

Epoch: 222

Epoch: 223
Saving..

Epoch: 224

Epoch: 225
Saving..

Epoch: 226

Epoch: 227

Epoch: 228

Epoch: 229

Epoch: 230

Epoch: 231

Epoch: 232



Epoch: 233

Epoch: 234

Epoch: 235

Epoch: 236

Epoch: 237

Epoch: 238

Epoch: 239

Epoch: 240

Epoch: 241

Epoch: 242

Epoch: 243

Epoch: 244

Epoch: 245

Epoch: 246

Epoch: 247

Epoch: 248

Epoch: 249

Epoch: 250

Epoch: 251

Epoch: 252

Epoch: 253

Epoch: 254

Epoch: 255

Epoch: 256

Epoch: 257

Epoch: 258

Epoch: 259

Epoch: 260

Epoch: 261

Epoch: 262



Epoch: 263

Epoch: 264

Epoch: 265

Epoch: 266

Epoch: 267

Epoch: 268

Epoch: 269

Epoch: 270

Epoch: 271

Epoch: 272

Epoch: 273

Epoch: 274

Epoch: 275

Epoch: 276

Epoch: 277

Epoch: 278

Epoch: 279

Epoch: 280

Epoch: 281

Epoch: 282

Epoch: 283

Epoch: 284

Epoch: 285

Epoch: 286

Epoch: 287

Epoch: 288

Epoch: 289

Epoch: 290

Epoch: 291



Epoch: 292

Epoch: 293

Epoch: 294

Epoch: 295

Epoch: 296

Epoch: 297

Epoch: 298

Epoch: 299

Epoch: 300

Epoch: 301

Epoch: 302

Epoch: 303

Epoch: 304

Epoch: 305

Epoch: 306

Epoch: 307

Epoch: 308

Epoch: 309

Epoch: 310

Epoch: 311

Epoch: 312

Epoch: 313

Epoch: 314

Epoch: 315

Epoch: 316

Epoch: 317

Epoch: 318

Epoch: 319

Epoch: 320

Epoch: 321



Epoch: 322

Epoch: 323

Epoch: 324

Epoch: 325

Epoch: 326

Epoch: 327

Epoch: 328

Epoch: 329

Epoch: 330

Epoch: 331

Epoch: 332

Epoch: 333

Epoch: 334

Epoch: 335

Epoch: 336

Epoch: 337

Epoch: 338

Epoch: 339

Epoch: 340

We can observe the loss over time.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Plot the per-epoch loss values recorded by train() and test().
# The x axis is the epoch index; the legend calls the curves "Error", but the
# plotted values are losses.
plt.figure(figsize=(10,6))
plt.plot(train_loss_hist, '-', linewidth=3, label='Train Error')
plt.plot(test_loss_hist, '-', linewidth=3, label='Test Error')
plt.legend()
plt.grid(True)
plt.show()

We also can see the prediction performance on the test data

In [None]:
from sklearn.metrics import confusion_matrix

# Collect the model's predicted class for every test image.
# Switch to inference mode explicitly so dropout is disabled no matter which
# cell ran last (previously this relied on test() having been called).
net.eval()
labels = []
with torch.no_grad():
    for batch_idx, (input_test,label_test) in enumerate(testloader):
        input_test = input_test.to(device)
        outputs = net(input_test.float())
        _, predicted = outputs.max(1)
        labels.append(predicted)

# flatten the labels
# One concatenation instead of np.append per batch, which re-copied the whole
# array on every iteration. ``pred`` is now an integer array of class indices.
pred = torch.cat(labels).cpu().numpy() if labels else np.array([], dtype=int)

# predict performance: rows are true classes, columns are predicted classes
cf_matrix = confusion_matrix(testset.targets, pred)
cf_matrix

It looks like all labels have around 90% accuracy score. We see that label 2 ('bird') has the lowest accuracy of 88.7%. Let's look at some wrong prediction examples:

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline

# Reload the test split without a transform so indexing yields the raw images
# for display.
print_testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True)

# One row of three panels, one per misclassified example.
fig, ax = plt.subplots(1, 3, figsize=(15, 10))
count = 0

# Walk the predictions in test-set order and show the first three mistakes.
for i in range(0, len(pred)):
    if print_testset.targets[i] != pred[i]:
        # print the image
        im = print_testset[i][0]
        ax[count].imshow(im)
        ax[count].set_title("Act:%s Pred:%s" % (classes[print_testset.targets[i]], classes[int(pred[i])]))
        count += 1

    # Stop once all three panels are filled.
    if count == 3:
        break

plt.show()

For our best result, we trained the model using the combined dataset for 100 epochs on an RTX8000 GPU in NYU's HPC cluster.

In [None]:
# Train for 100 epochs: one training pass, one evaluation pass and one
# scheduler step per epoch.
# NOTE(review): this reuses the current net, optimizer and scheduler objects
# and numbers epochs from start_epoch again -- confirm whether a fresh model
# was intended for this run.
for epoch in range(start_epoch, start_epoch+100):
    train(epoch, trainloader)
    test(epoch, testloader)
    scheduler.step()

# Prediction
For Kaggle's data prediction, we will import the trained model and start predicting.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn.init as init

import torchvision
import torchvision.transforms as transforms

# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model
print('==> Building model..')
net = ResNet12()
net = net.to(device)
net = torch.nn.DataParallel(net)
if device == 'cuda':
    cudnn.benchmark = True

# load state
saved_model = torch.load('./checkpoint/ckpt.pth', map_location=torch.device(device))
state_dict = saved_model['net']
# The training cell wraps the model in DataParallel only when CUDA is
# available, so a checkpoint written on a GPU has keys prefixed with
# 'module.' while one written on a CPU does not. Load into whichever module
# matches the key layout so both kinds of checkpoint work here.
has_module_prefix = all(key.startswith('module.') for key in state_dict)
load_target = net if has_module_prefix else net.module
load_target.load_state_dict(state_dict)
net.eval()

Preprocess the raw data and predict in batches of a thousand.

In [None]:
import numpy as np
from PIL import Image

# Load the unlabeled competition data (a pickled dict inside a zip archive).
cifar_test_nolabels = unpickle("../data/cifar_test_nolabels.pkl.zip")
competition_data = cifar_test_nolabels[b'data']
# NOTE(review): y_id is not used in the visible cells; the CSV cell reads the
# ids from the archive again.
y_id = cifar_test_nolabels[b'ids']

# Pre-transform every image once and keep the resulting tensors in a list.
competitionset = []
for d in competition_data:
    # Each entry is reshaped to (3, 32, 32) and transposed to
    # height x width x channel before being wrapped as an RGB PIL image.
    d = d.reshape(3, 32, 32).transpose(1,2,0)
    im = Image.fromarray(d, mode='RGB')
    im = transform_test(im)
    competitionset.append(im)

# No shuffling, so predictions stay in the same order as the ids.
competitionloader = torch.utils.data.DataLoader(
    competitionset, batch_size=1000, shuffle=False)

# Predicted class indices, one tensor per batch, filled by competition_test().
labels = []


def competition_test():
    """Predict a class index for every image served by ``competitionloader``.

    Runs ``net`` in evaluation mode without gradient tracking, appends one
    tensor of predicted class indices per batch to the module-level ``labels``
    list and prints a line after each batch.
    """
    net.eval()

    with torch.no_grad():
        for batch_number, batch in enumerate(competitionloader, start=1):
            logits = net(batch.to(device).float())
            # max over the class dimension returns (values, indices).
            predicted = logits.max(1)[1]
            labels.append(predicted)
            print("Batch %d done" % batch_number)


competition_test()

Finally, we write the prediction labels into a .csv file.

In [None]:
import pandas as pd

# flatten the labels: concatenate the per-batch prediction tensors into one
# array (np.array([]) is float, so the values are floats at this point)
pred = np.array([])
for batch in labels:
    pred = np.append(pred, batch.tolist())


# Read the image ids from the competition archive.
competition_data_ids = unpickle("../data/cifar_test_nolabels.pkl.zip")[b'ids']
    
# create dataframe to save prediction
df = pd.DataFrame()
df['ID'] = competition_data_ids
df['Label'] = pred
# Cast the float predictions back to integer class indices.
df['Label'] = df['Label'].astype(int)

# Write the ID/Label table without the dataframe index column.
df.to_csv('out.csv', index=False)