In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np
import time

########################################
# You can define whatever classes if needed
########################################

class IdentityResNet(nn.Module):
    
    # __init__ takes 4 parameters
    # nblk_stage1: number of blocks in stage 1, nblk_stage2.. similar
    def __init__(self, nblk_stage1, nblk_stage2, nblk_stage3, nblk_stage4):
        super(IdentityResNet, self).__init__()
    ########################################
    # Implement the network
    # You can declare whatever variables
    ########################################
        self.nblk_stage1 = nblk_stage1
        self.nblk_stage2 = nblk_stage2
        self.nblk_stage3 = nblk_stage3
        self.nblk_stage4 = nblk_stage4

        self.conv0 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        
        
        self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.BN64 = nn.BatchNorm2d(num_features=64)

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.BN128 = nn.BatchNorm2d(num_features=128)
        
        
        
        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.BN256 = nn.BatchNorm2d(num_features=256)

        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)
        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.BN512 = nn.BatchNorm2d(num_features=512)

        self.shortcut2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, stride=2)
        self.shortcut3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, stride=2)
        self.shortcut4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1, stride=2)

        
        self.stage1=nn.Sequential(self.BN64,
                                  torch.nn.ReLU(),
                                  self.conv1,
                                  self.BN64,
                                  self.conv1,
                                  torch.nn.ReLU(),
                                  self.conv1
                                  )
        
        self.stage2_2=nn.Sequential(self.BN128,
                                    torch.nn.ReLU(),
                                    self.conv2_2,
                                    self.BN128,
                                    torch.nn.ReLU(),
                                    self.conv2_2
            
                                    )
        self.stage3_2=nn.Sequential(self.BN256,
                                    torch.nn.ReLU(),
                                    self.conv3_2,
                                    self.BN256,
                                    torch.nn.ReLU(),
                                    self.conv3_2
            
                                    )
        self.stage4_2=nn.Sequential(self.BN512,
                                    torch.nn.ReLU(),
                                    self.conv4_2,
                                    self.BN512,
                                    torch.nn.ReLU(),
                                    self.conv4_2
            
                                    )


        self.linear = nn.Linear(512, 10)

    ########################################
    # You can define whatever methods
    ########################################
    
    def forward(self, x):
        ########################################
        # Implement the network
        # You can declare or define whatever variables or methods
        ########################################
        out=self.conv0(x)
        for i in range(self.nblk_stage1):
          out=self.stage1(out)+out

        out=self.BN64(out)
        out=F.relu(out)
        SC2=out
        SC2=self.shortcut2(SC2)
        out=self.conv2_1(out)
        out=self.BN128(out)
        out=F.relu(out)
        out=self.conv2_2(out)
        out=SC2+out

        for i in range(self.nblk_stage2-1):
          out=self.stage2_2(out)+out

        out=self.BN128(out)
        out=F.relu(out)
        SC3=out
        SC3=self.shortcut3(SC3)
        out=self.conv3_1(out)
        out=self.BN256(out)
        out=F.relu(out)
        out=self.conv3_2(out)
        out=SC3+out

        for i in range(self.nblk_stage3-1):
          out=self.stage3_2(out)+out  

        out=self.BN256(out)
        out=F.relu(out)
        SC4=out
        SC4=self.shortcut4(SC4)
        out=self.conv4_1(out)
        out=self.BN512(out)
        out=F.relu(out)
        out=self.conv4_2(out)
        out=SC4+out

        for i in range(self.nblk_stage4-1):
          out=self.stage4_2(out)+out 

        avg_pool = F.avg_pool2d(out, kernel_size=4, stride=4)
        avg_pool = avg_pool.view(-1, 512)
        out = self.linear(avg_pool)


        return out

########################################
# Q1. set device
# First, check availability of GPU.
# If available, set dev to "cuda:0";
# otherwise set dev to "cpu"
########################################
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('current device: ', dev)


########################################
# data preparation: CIFAR10
########################################

########################################
# Q2. set batch size
# set batch size for training data
########################################
batch_size = 4

# preprocessing
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load training data
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)

# load test data
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


# define network
net = IdentityResNet(nblk_stage1=2, nblk_stage2=2, nblk_stage3=2, nblk_stage4=2)

########################################
# Q3. load model to GPU
# Complete below to load model to GPU
########################################
net = net.to(dev)

# set loss function
criterion = nn.CrossEntropyLoss()

########################################
# Q4. optimizer
# Complete below to use SGD with momentum (alpha= 0.9)
# set proper learning rate
########################################
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# start training
t_start = time.time()

for epoch in range(5):  # loop over the dataset multiple times
  
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(dev), data[1].to(dev)
             
        ########################################
        # Q5. make sure gradients are zero!
        # zero the parameter gradients
        ########################################
        optimizer.zero_grad()
        
        ########################################
        # Q6. perform forward pass
        ########################################
        
        outputs = net(inputs)
        
        # set loss
        loss = criterion(outputs, labels)
        
        ########################################
        # Q7. perform backprop
        ########################################
        
        loss.backward()
        
        ########################################
        # Q8. take a SGD step
        ########################################
        optimizer.step()
        
        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
            t_end = time.time()
            print('elapsed:', t_end-t_start, ' sec')  
            t_start = t_end

print('Finished Training')


# now testing
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

########################################
# Q9. complete below
# when testing, computation is done without building graphs
########################################
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(dev), data[1].to(dev)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

# per-class accuracy
for i in range(10):
    print('Accuracy of %5s' %(classes[i]), ': ',
          100 * class_correct[i] / class_total[i],'%')

# overall accuracy
print('Overall Accurracy: ', (sum(class_correct)/sum(class_total))*100, '%')



current device:  cuda:0
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
[1,  2000] loss: 1.969
elapsed: 84.62864780426025  sec
[1,  4000] loss: 1.685
elapsed: 83.99553275108337  sec
[1,  6000] loss: 1.512
elapsed: 84.29839849472046  sec
[1,  8000] loss: 1.358
elapsed: 84.00432109832764  sec
[1, 10000] loss: 1.248
elapsed: 83.51493620872498  sec
[1, 12000] loss: 1.135
elapsed: 83.7755355834961  sec
[2,  2000] loss: 1.019
elapsed: 104.59989285469055  sec
[2,  4000] loss: 0.999
elapsed: 83.7048728466034  sec
[2,  6000] loss: 0.950
elapsed: 83.42711448669434  sec
[2,  8000] loss: 0.921
elapsed: 83.40395951271057  sec
[2, 10000] loss: 0.874
elapsed: 83.3845283985138  sec
[2, 12000] loss: 0.866
elapsed: 83.2923994064331  sec
[3,  2000] loss: 0.740
elapsed: 104.26908755302429  sec
[3,  4000] loss: 0.748
elapsed: 83.65341591835022  sec
[3,  6000] loss: 0.742
elapsed: 83.43548274040222  sec
[3,  8000] loss: 0.715
elapsed: 83.46729564666748  sec
[3, 10000] loss: 0.696
elapsed: 8