# ECE 176: Fine-Grained Classification Using a CB-ViT Model

The focus of our final project is reimplementing the work of Shuo Zhu, Xukang Zhang, Yu Wang, Zhongyang Wang, and Jiahao Sun. The main result of their paper is the introduction of a CB-ViT model, which combines the local feature extraction of convolutional networks with the broad feature extraction of Vision Transformers. A version of their research paper can be found here: https://ietresearch.onlinelibrary.wiley.com/doi/full/10.1049/ipr2.13295

In [1]:
# Imports - reused from assignment 5

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T

import numpy as np

import torch.nn.functional as F  # useful stateless functions

In [2]:
# Data setup - reused from assignment 5

NUM_TRAIN = 49000
batch_size = 64

# Preprocessing shared by every split: convert to a tensor, then normalise
# each RGB channel with the hard-coded per-channel mean and standard deviation.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

# The training pipeline is currently the very same object as the evaluation
# pipeline. To add data augmentation (e.g. RandomCrop, RandomHorizontalFlip),
# build a separate Compose for `train_transform` and re-run this cell.
train_transform = transform

# One Dataset object per split. The train and val datasets both read the
# CIFAR-100 training files; they are told apart only by the samplers below.
cifar100_train = dset.CIFAR100(
    './datasets/cifar100', train=True, download=True, transform=train_transform,
)
cifar100_val = dset.CIFAR100(
    './datasets/cifar100', train=True, download=True, transform=transform,
)
cifar100_test = dset.CIFAR100(
    './datasets/cifar100', train=False, download=True, transform=transform,
)

# DataLoaders turn the Datasets into minibatches of `batch_size` examples.
# Indices [0, NUM_TRAIN) are drawn for training and [NUM_TRAIN, 50000) are
# held out for validation; the test loader walks its dataset with no sampler.
loader_train = DataLoader(
    cifar100_train,
    batch_size=batch_size,
    num_workers=2,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),
)
loader_val = DataLoader(
    cifar100_val,
    batch_size=batch_size,
    num_workers=2,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)),
)
loader_test = DataLoader(cifar100_test, batch_size=batch_size, num_workers=2)

In [3]:
# Dtype and device selection - reused from assignment 5

USE_GPU = True
num_class = 100
dtype = torch.float32  # every floating-point tensor in this notebook uses float32

# Use CUDA only when it is both requested and actually available;
# in every other case fall back to the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

# How many training iterations pass between two progress reports.
print_every = 100

print('using device:', device)

using device: cuda


In [4]:
# Random weight function - reused from assignment 5

def random_weight(shape):
    """
    Return a trainable tensor of the given shape with Kaiming-scaled noise.

    Values are drawn from a standard normal distribution and multiplied by
    sqrt(2 / fan_in). The tensor is created on the module-level `device`
    with the module-level `dtype`, and has requires_grad=True so gradients
    are computed for it during the backward pass.

    Inputs:
    - shape: Tuple of ints. A 2-tuple is treated as a fully-connected weight
      whose fan-in is its first entry; any other shape is treated as a conv
      weight [out_channel, in_channel, kH, kW] whose fan-in is the product of
      all entries after the first.
    """
    fan_in = shape[0] if len(shape) == 2 else np.prod(shape[1:])
    scale = np.sqrt(2. / fan_in)
    noise = torch.randn(shape, device=device, dtype=dtype)
    return (noise * scale).requires_grad_()

def zero_weight(shape):
    """
    Return a trainable all-zeros tensor of the given shape.

    The tensor lives on the module-level `device`, uses the module-level
    `dtype`, and has requires_grad=True.
    """
    return torch.zeros(
        shape,
        dtype=dtype,
        device=device,
        requires_grad=True,
    )

# Sanity check: create a weight of shape [3 x 5] and display it as the
# cell's output.
# On a GPU run the type should be `torch.cuda.FloatTensor`;
# otherwise it should be `torch.FloatTensor`.
random_weight((3, 5))

tensor([[-1.5944,  1.1914,  0.0858,  0.2109,  1.9863],
        [-0.3466,  0.7288,  0.9983, -0.2790, -0.5860],
        [ 0.9807,  0.4504, -0.5694,  0.4216,  0.5454]], device='cuda:0',
       requires_grad=True)

In [5]:
# Check accuracy function - reused from assignment 5

def check_accuracy_part34(loader, model):
    """
    Measure the classification accuracy of `model` over every batch in `loader`.

    The model is switched to evaluation mode (and is left in that mode), and
    the forward passes run with gradient tracking disabled. A header line and
    a "correct / total" summary are printed.

    Inputs:
    - loader: A DataLoader yielding (inputs, labels) batches. Its
      `dataset.train` flag only selects the header text: a dataset built with
      train=True is reported as the validation set, anything else as the
      test set.
    - model: A PyTorch Module producing one score per class for each input.

    Returns: The fraction of correctly classified examples, as a Python float.
    """
    header = ('Checking accuracy on validation set' if loader.dataset.train
              else 'Checking accuracy on test set')
    print(header)

    model.eval()  # evaluation mode, e.g. batch norm uses its running statistics
    num_correct, num_samples = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device=device, dtype=dtype)
            labels = labels.to(device=device, dtype=torch.long)
            # The predicted class is the index of the highest score in each row.
            preds = model(inputs).max(dim=1).indices
            num_correct += (preds == labels).sum()
            num_samples += preds.size(0)

    acc = float(num_correct) / num_samples
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc

In [6]:
# Training function - reused from assignment 5

def train_part34(model, optimizer, epochs=1):
    """
    Train a model on the minibatches of the module-level `loader_train`.

    Every `print_every` iterations (a module-level setting read at run time)
    the current loss is printed and the accuracy on `loader_val` is reported.

    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for

    Returns: The accuracy of the model on `loader_val` after the last epoch
    """
    model = model.to(device=device)  # parameters must live where the data will
    for epoch in range(epochs):
        for step, (x, y) in enumerate(loader_train, start=1):
            # Re-enter training mode on every step: the periodic accuracy
            # check below leaves the model in evaluation mode.
            model.train()
            inputs = x.to(device=device, dtype=dtype)
            targets = y.to(device=device, dtype=torch.long)

            loss = F.cross_entropy(model(inputs), targets)

            # Clear stale gradients, backpropagate the loss, then let the
            # optimizer apply the update to the parameters it manages.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % print_every != 0:
                continue
            print('Epoch %d, Iteration %d, loss = %.4f' % (epoch, step, loss.item()))
            check_accuracy_part34(loader_val, model)
            print()
    return check_accuracy_part34(loader_val, model)

In [7]:
# Flatten function - reused from assignment 5

def flatten(x):
    """
    Collapse every dimension after the first into one vector per sample.

    Inputs:
    - x: A tensor whose first dimension indexes the samples, e.g. (N, C, H, W).

    Returns: A view of `x` with shape (N, -1); no data is copied.
    """
    return x.view(x.size(0), -1)

# We need to wrap `flatten` function in a module in order to stack it
# in nn.Sequential
class Flatten(nn.Module):
    """Module form of per-sample flattening, so it can be stacked in nn.Sequential."""

    def forward(self, x):
        # Keep the first (sample) dimension and merge all the others,
        # exactly as the module-level `flatten` helper does.
        return x.view(x.shape[0], -1)

In [8]:
########################################################################
# TODO: Implement the forward function for the Resnet specified        #
# above. HINT: You might need to create a helper class to              # 
# define a Resnet block and then use that block here to create         #
# the resnet layers i.e. conv2_x, conv3_x, conv4_x and conv5_x         #
########################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

class block(nn.Module):
    """
    Two stacked 3x3 convolutions with a ReLU in between.

    When `batch_flag` is truthy, each convolution is followed by its own
    BatchNorm2d layer. The output is returned without a final activation and
    without any skip connection; the caller is expected to add those.

    Inputs to the constructor:
    - in_channel: Number of channels of the incoming feature map.
    - out_channel: Number of channels produced by both convolutions.
    - stride_1: Stride of the first convolution (the second always uses 1).
    - batch_flag: Whether forward() applies the batch-norm layers.
    """

    def __init__(self, in_channel, out_channel, stride_1, batch_flag):
        super().__init__()
        self.batch_flag = batch_flag

        # Both convolutions pad by 1, so only `stride_1` can shrink the
        # spatial size. Weights are re-initialised with Kaiming normal.
        self.conv_1 = nn.Conv2d(in_channel, out_channel, kernel_size=3,
                                stride=stride_1, padding=1)
        nn.init.kaiming_normal_(self.conv_1.weight)
        self.conv_2 = nn.Conv2d(out_channel, out_channel, kernel_size=3,
                                padding=1)
        nn.init.kaiming_normal_(self.conv_2.weight)

        # The norm layers are always created, even if forward() skips them.
        self.batch1 = nn.BatchNorm2d(out_channel)
        self.batch2 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        hidden = self.conv_1(x)
        if self.batch_flag:
            hidden = self.batch1(hidden)
        result = self.conv_2(F.relu(hidden))
        if self.batch_flag:
            result = self.batch2(result)
        return result

class ResNet(nn.Module):
    """
    Small residual network: a 7x7 stem convolution, four residual stages
    (conv2_x .. conv5_x, one `block` each) and a linear classifier.

    Each stage is paired with a 1x1 convolution (skip2 .. skip5) that projects
    the stage input to the stage's output channels (and stride) so the two can
    be added. Only conv5_x / skip5 use stride 2; the other stages keep the
    spatial size. Adaptive average pooling reduces the final feature map to
    1x1 before the classifier.

    Inputs to the constructor:
    - batch_flag: Whether batch normalisation is applied after the stem
      convolution and inside every `block`.
    - num_classes: (Optional) Number of output scores per example.
      Defaults to 100.
    """

    def __init__(self, batch_flag, num_classes=100):
        super().__init__()
        # Stem: 3 input channels -> 64, halved spatially by the stride-2
        # convolution and again by the stride-2 max pool.
        self.conv1 = nn.Conv2d(3, 64, (7,7), stride=2, padding=3)
        nn.init.kaiming_normal_(self.conv1.weight)
        self.batch = nn.BatchNorm2d(64)
        self.max = nn.MaxPool2d((3,3), stride=2, padding=1)

        # Residual stages, each with its 1x1 projection shortcut. The skip
        # convolutions keep PyTorch's default initialisation.
        self.conv2_x = block(64, 64, 1, batch_flag)
        self.skip2 = nn.Conv2d(64, 64, (1,1), padding=0)
        self.conv3_x = block(64, 128, 1, batch_flag)
        self.skip3 = nn.Conv2d(64, 128, (1,1), padding=0)
        self.conv4_x = block(128, 256, 1, batch_flag)
        self.skip4 = nn.Conv2d(128, 256, (1,1), padding=0)
        self.conv5_x = block(256, 512, 2, batch_flag)
        self.skip5 = nn.Conv2d(256, 512, (1,1), stride=2, padding=0)

        # Classifier head on the globally pooled 512-channel features.
        self.avg = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512*1*1, num_classes)
        nn.init.kaiming_normal_(self.fc.weight)
        self.batch_flag = batch_flag
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Compute class scores for a batch of images.

        Inputs:
        - x: A float tensor with 3 channels in dimension 1, i.e. (N, 3, H, W).

        Returns: A tensor of shape (N, num_classes) holding unnormalised scores.
        """
        # The stem is the only place where the two configurations differ at
        # this level; each `block` handles its own batch norm internally.
        stem = self.conv1(x)
        if self.batch_flag:
            stem = self.batch(stem)
        features = self.max(self.relu(stem))

        stages = (
            (self.conv2_x, self.skip2),
            (self.conv3_x, self.skip3),
            (self.conv4_x, self.skip4),
            (self.conv5_x, self.skip5),
        )
        for stage, shortcut in stages:
            # NOTE(review): the residual branch goes through a ReLU *before*
            # the addition as well as after it. The usual formulation applies
            # the ReLU only after the sum; this is kept unchanged so existing
            # results stay reproducible - confirm whether it is intended.
            features = self.relu(self.relu(stage(features)) + shortcut(features))

        pooled = flatten(self.avg(features))
        return self.fc(pooled)
    
########################################################################
#                             END OF YOUR CODE                         #
########################################################################

In [9]:
learning_rate = 1e-3

# Cleared first so that, if construction below fails, no model or optimizer
# from an earlier run of this cell is left behind under these names.
model = None
optimizer = None

# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

# ResNet without batch normalisation, optimised with Adam.
model = ResNet(False)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

# `print_every` is a module-level setting that train_part34 reads while it
# runs. Raise it to 700 for this run only; the `finally` clause puts the
# usual value back even if training raises or is interrupted.
print_every = 700
try:
    train_part34(model, optimizer, epochs=10)
finally:
    print_every = 100

Epoch 0, Iteration 700, loss = 3.6109
Checking accuracy on validation set
Got 151 / 1000 correct (15.10)

Epoch 1, Iteration 700, loss = 2.9985
Checking accuracy on validation set
Got 251 / 1000 correct (25.10)

Epoch 2, Iteration 700, loss = 2.4002
Checking accuracy on validation set
Got 342 / 1000 correct (34.20)

Epoch 3, Iteration 700, loss = 2.0542
Checking accuracy on validation set
Got 372 / 1000 correct (37.20)

Epoch 4, Iteration 700, loss = 1.8887
Checking accuracy on validation set
Got 383 / 1000 correct (38.30)

Epoch 5, Iteration 700, loss = 2.0410
Checking accuracy on validation set
Got 410 / 1000 correct (41.00)

Epoch 6, Iteration 700, loss = 1.2317
Checking accuracy on validation set
Got 445 / 1000 correct (44.50)

Epoch 7, Iteration 700, loss = 1.1610
Checking accuracy on validation set
Got 410 / 1000 correct (41.00)

Epoch 8, Iteration 700, loss = 0.8167
Checking accuracy on validation set
Got 409 / 1000 correct (40.90)

Epoch 9, Iteration 700, loss = 0.5293
Checking