# ECE 176: Fine-Grained Classification Using a CB-ViT Model

The focus of our final project is reimplementing the work of Shuo Zhu, Xukang Zhang, Yu Wang, Zhongyang Wang, and Jiahao Sun. The main result of their paper is the introduction of the CB-ViT model, which combines the local feature extraction of convolutional networks with the broad feature extraction of Vision Transformers. A version of their research paper can be found here: https://ietresearch.onlinelibrary.wiley.com/doi/full/10.1049/ipr2.13295. 

In [1]:
# Imports - reused from assignment 5

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T

import numpy as np

import torch.nn.functional as F  # useful stateless functions

In [2]:
# Data setup - reused from assignment 5

NUM_TRAIN = 49000
batch_size = 64

# Preprocessing: convert each image to a tensor, then normalize every RGB
# channel with the hardcoded per-channel mean and standard deviation below.
#
# NOTE: to experiment with data augmentation (e.g. RandomCrop,
# RandomHorizontalFlip) when chasing higher CIFAR-100 accuracy, give
# `train_transform` its own Compose and re-run this cell so that the loaders
# pick up the change.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])
train_transform = transform  # training currently shares the evaluation transform

# One Dataset object per split (train / val / test), each wrapped in a
# DataLoader that forms minibatches. Train and val both read the CIFAR-100
# training images; the samplers restrict them to disjoint index ranges:
# [0, NUM_TRAIN) for training and [NUM_TRAIN, 50000) for validation.
cifar100_train = dset.CIFAR100(
    './datasets/cifar100', train=True, download=True, transform=train_transform
)
loader_train = DataLoader(
    cifar100_train,
    batch_size=batch_size,
    num_workers=2,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),
)

cifar100_val = dset.CIFAR100(
    './datasets/cifar100', train=True, download=True, transform=transform
)
loader_val = DataLoader(
    cifar100_val,
    batch_size=batch_size,
    num_workers=2,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)),
)

cifar100_test = dset.CIFAR100(
    './datasets/cifar100', train=False, download=True, transform=transform
)
loader_test = DataLoader(cifar100_test, batch_size=batch_size, num_workers=2)

In [3]:
# Dtype and device selection - reused from assignment 5

USE_GPU = True
num_class = 100
dtype = torch.float32  # every tensor in this notebook uses 32-bit floats

# Use CUDA only when it is both requested and available; otherwise use the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

# How often (in iterations) the training loop prints the train loss
print_every = 100

print('using device:', device)

using device: cuda


In [4]:
# Flatten function - reused from assignment 5

def flatten(x):
    """
    Collapse every dimension after the first into a single vector per sample.

    Inputs:
    - x: Tensor whose first dimension is the batch dimension, e.g. N x C x H x W

    Returns: Tensor of shape (N, D), where D is the product of the remaining
    dimensions of x.
    """
    N = x.shape[0]  # batch size
    # reshape (instead of view) also accepts non-contiguous inputs such as the
    # result of permute/transpose; for contiguous inputs it still returns a view.
    return x.reshape(N, -1)

# We need to wrap `flatten` function in a module in order to stack it
# in nn.Sequential
class Flatten(nn.Module):
    """Module wrapper around `flatten` so it can be used as a layer."""

    def forward(self, x):
        flat = flatten(x)
        return flat

In [5]:
# Random weight function - reused from assignment 5

def random_weight(shape):
    """
    Build a randomly initialized weight Tensor that tracks gradients.

    Values are drawn from a standard normal distribution and scaled with
    Kaiming normalization, sqrt(2 / fan_in).

    Inputs:
    - shape: Tuple giving the weight shape. A 2-tuple is treated as a fully
      connected weight whose fan-in is shape[0]; any other shape is treated as
      a conv weight [out_channel, in_channel, kH, kW] whose fan-in is the
      product of all dimensions but the first.

    Returns: Tensor of the given shape on the global `device` / `dtype` with
    requires_grad=True.
    """
    fan_in = shape[0] if len(shape) == 2 else np.prod(shape[1:])
    scale = np.sqrt(2. / fan_in)
    # randn samples from the standard normal distribution.
    w = torch.randn(shape, device=device, dtype=dtype) * scale
    return w.requires_grad_()

def zero_weight(shape):
    """Return an all-zero Tensor of the given shape that tracks gradients."""
    w = torch.zeros(shape, device=device, dtype=dtype)
    w.requires_grad = True
    return w

# create a weight of shape [3 x 5]
# you should see the type `torch.cuda.FloatTensor` if you use GPU. 
# Otherwise it should be `torch.FloatTensor`
random_weight((3, 5))

tensor([[-0.2513, -0.3043,  2.0349,  0.3886,  0.8291],
        [-1.4269, -0.4210, -0.6799, -0.0029,  1.1006],
        [-0.5245, -0.6563, -0.3006, -0.2547, -0.3857]], device='cuda:0',
       requires_grad=True)

In [6]:
# Check accuracy function - reused from assignment 5

def check_accuracy_part34(loader, model):
    """
    Measure the classification accuracy of a model over every batch of a loader.

    Inputs:
    - loader: A DataLoader whose dataset exposes a boolean `train` attribute;
      it only selects which header line is printed.
    - model: A PyTorch Module mapping a batch of inputs to per-class scores.

    Returns: The fraction of correctly classified samples, as a Python float.

    Note: the model is switched to evaluation mode and left in that mode.
    """
    print('Checking accuracy on validation set' if loader.dataset.train
          else 'Checking accuracy on test set')
    correct, total = 0, 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            labels = labels.to(device=device, dtype=torch.long)
            # Index of the highest score along dim 1 is the predicted class.
            preds = model(images).max(1)[1]
            correct += (preds == labels).sum()
            total += preds.size(0)
        acc = float(correct) / total
        print('Got %d / %d correct (%.2f)' % (correct, total, 100 * acc))
    return acc

In [7]:
# Training function - reused from assignment 5

def train_part34(model, optimizer, epochs=1):
    """
    Train a model on the batches of `loader_train` using the PyTorch Module API.

    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for

    Returns: The accuracy of the model on `loader_val` after the last epoch.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    for epoch in range(epochs):
        for step, (x, y) in enumerate(loader_train, start=1):
            # Re-enter training mode on every iteration because the periodic
            # accuracy check below switches the model to evaluation mode.
            model.train()
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            loss = F.cross_entropy(model(x), y)

            # Clear stale gradients, backpropagate the loss, then let the
            # optimizer update the parameters it manages.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % print_every == 0:
                print('Epoch %d, Iteration %d, loss = %.4f' % (epoch, step, loss.item()))
                check_accuracy_part34(loader_val, model)
                print()
    return check_accuracy_part34(loader_val, model)

In [8]:
# Sample class - reused from assignment 5

class Network(nn.Module):
    """
    Small convolutional classifier: conv5x5 -> ReLU -> conv3x3 -> ReLU -> linear.

    Inputs:
    - in_channel: Number of channels of the input images.
    - channel_1: Number of output channels of the first convolution.
    - channel_2: Number of output channels of the second convolution.
    - num_classes: Number of scores produced per image.
    - in_height, in_width: (Optional) Spatial size of the input images;
      defaults to 32 x 32.
    """

    def __init__(self, in_channel, channel_1, channel_2, num_classes,
                 in_height=32, in_width=32):
        super().__init__()
        self.conv_1 = nn.Conv2d(in_channel, channel_1, (5, 5), padding=2)
        nn.init.kaiming_normal_(self.conv_1.weight)
        self.conv_2 = nn.Conv2d(channel_1, channel_2, (3, 3), padding=1)
        nn.init.kaiming_normal_(self.conv_2.weight)
        # Both convolutions keep the spatial size (stride 1, "same" padding),
        # so the classifier sees channel_2 * H * W features. This equals the
        # previously hard-coded 65536 for channel_2 = 64 on 32 x 32 inputs.
        self.fc1 = nn.Linear(channel_2 * in_height * in_width, num_classes)
        # Not used by forward (which calls F.relu); kept so the attribute
        # remains available to existing code.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return class scores of shape (N, num_classes) for a batch x."""
        hidden = F.relu(self.conv_1(x))
        hidden = F.relu(self.conv_2(hidden))
        # Flatten C * H * W into one feature vector per image.
        return self.fc1(hidden.reshape(hidden.shape[0], -1))

In [9]:
# Training and validating of sample class - reused from assignment 5

"""learning_rate = 1e-3
channel_1 = 32
channel_2 = 64

model = None
optimizer = None

model = Network(3, channel_1, channel_2, 100)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

print_every = 100
train_part34(model, optimizer, epochs=1)""";

In [14]:
# ResNet 50

class conv_miniblock(nn.Module):
    """Placeholder module: defines no layers yet and its forward returns None."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # TODO: not implemented yet; the input is ignored.
        return None
    
class conv_block(nn.Module):
    """Placeholder module: defines no layers yet and its forward returns None."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # TODO: not implemented yet; the input is ignored.
        return None

class ResNet50(nn.Module):
    """Placeholder module: defines no layers yet and its forward returns None."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # TODO: not implemented yet; the input is ignored.
        return None

In [10]:
# "CNN Modules"

class FEM(nn.Module):
    """
    Two-branch convolutional module whose branches are blended by learned weights.

    The input goes through a 3x3 branch and a 5x5 branch (each conv ->
    BatchNorm -> SiLU). Their sum is max- and average-pooled, concatenated
    along the height axis, convolved, flattened and projected to `in_width`
    logits. The softmax of those logits (gamma) mixes the two branches as
    gamma * R1 + (1 - gamma) * R2.

    Inputs:
    - in_channel: Number of input channels (the output has the same number).
    - in_height, in_width: Spatial size of the input. They fix the size of the
      fully connected layer, so forward only accepts inputs of this size.
    """

    def __init__(self, in_channel, in_height, in_width):
        super().__init__()
        self.conv3_1 = nn.Conv2d(in_channel, in_channel, (3, 3), padding=1)
        nn.init.kaiming_normal_(self.conv3_1.weight)
        self.batch31 = nn.BatchNorm2d(in_channel)
        self.silu31 = nn.SiLU()
        self.conv5 = nn.Conv2d(in_channel, in_channel, (5, 5), padding=2)
        nn.init.kaiming_normal_(self.conv5.weight)
        self.batch5 = nn.BatchNorm2d(in_channel)
        self.silu5 = nn.SiLU()
        self.max = nn.MaxPool2d(kernel_size=(2, 2))
        self.avg = nn.AvgPool2d(kernel_size=(2, 2))
        self.conv3_2 = nn.Conv2d(in_channel, in_channel, (3, 3), padding=1)
        nn.init.kaiming_normal_(self.conv3_2.weight)
        # Built-in equivalent of flattening everything but the batch dimension.
        self.flatten = nn.Flatten()
        # The two 2x2-pooled maps are stacked along the height axis, giving a
        # tensor of in_channel x 2 * (H // 2) x (W // 2) values per sample.
        fc_size = in_channel * 2 * (in_height // 2) * (in_width // 2)
        self.fc = nn.Linear(fc_size, in_width)
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        """Map x of shape (N, in_channel, in_height, in_width) to the same shape."""
        R1 = self.silu31(self.batch31(self.conv3_1(x)))
        R2 = self.silu5(self.batch5(self.conv5(x)))
        Rm = R1 + R2
        Rn = torch.cat((self.max(Rm), self.avg(Rm)), dim=2)
        Rp = self.conv3_2(Rn)
        M = self.fc(self.flatten(Rp))
        # gamma is (N, in_width). Insert singleton channel/height axes so each
        # sample's weights broadcast over its own feature map; multiplying the
        # raw (N, W) tensor only broadcasts correctly for N == 1 and otherwise
        # errors or (when N == H) silently mixes samples with rows.
        gamma = self.soft(M)[:, None, None, :]
        Rx = R1 * gamma + R2 * (1 - gamma)
        return Rx
    
class CNN_Block(nn.Module):
    """
    1x1 conv -> BatchNorm -> SiLU -> 3x3 conv -> BatchNorm -> SiLU -> FEM.

    Inputs:
    - in_channel: Number of channels of the input.
    - hidden_channel: Number of channels produced by the 1x1 convolution and
      kept through the rest of the block.
    - in_height, in_width: Spatial size of the input, forwarded to FEM.
    """

    def __init__(self, in_channel, hidden_channel, in_height, in_width):
        super().__init__()
        # Pointwise convolution that changes the channel count.
        self.conv1 = nn.Conv2d(in_channel, hidden_channel, (1, 1), padding=0)
        nn.init.kaiming_normal_(self.conv1.weight)
        self.batch1 = nn.BatchNorm2d(hidden_channel)
        self.silu1 = nn.SiLU()
        # 3x3 convolution; padding=1 keeps the spatial size.
        self.conv3 = nn.Conv2d(hidden_channel, hidden_channel, (3, 3), padding=1)
        nn.init.kaiming_normal_(self.conv3.weight)
        self.batch3 = nn.BatchNorm2d(hidden_channel)
        self.silu3 = nn.SiLU()
        self.FEM = FEM(hidden_channel, in_height, in_width)

    def forward(self, x):
        hidden = self.silu1(self.batch1(self.conv1(x)))
        hidden = self.silu3(self.batch3(self.conv3(hidden)))
        return self.FEM(hidden)

In [16]:
# SFE 

class SFE(nn.Module):
    """
    Stack of four strided 3x3 stages followed by one 1x1 stage.

    Every stage is conv -> BatchNorm -> ReLU and keeps the channel count at
    `in_channel`. The 3x3 convolutions use stride 2 and padding 2; the final
    1x1 convolution uses stride 1 and padding 2.

    Inputs:
    - in_channel: Number of channels of the input (and of the output).
    """

    def __init__(self, in_channel):
        super().__init__()

        def strided_conv():
            # 3x3, stride-2, padding-2 convolution with Kaiming-normal weights.
            conv = nn.Conv2d(in_channel, in_channel, (3, 3), stride=2, padding=2)
            nn.init.kaiming_normal_(conv.weight)
            return conv

        self.conv3_1 = strided_conv()
        self.batch3_1 = nn.BatchNorm2d(in_channel)
        self.relu3_1 = nn.ReLU()

        self.conv3_2 = strided_conv()
        self.batch3_2 = nn.BatchNorm2d(in_channel)
        self.relu3_2 = nn.ReLU()

        self.conv3_3 = strided_conv()
        self.batch3_3 = nn.BatchNorm2d(in_channel)
        self.relu3_3 = nn.ReLU()

        self.conv3_4 = strided_conv()
        self.batch3_4 = nn.BatchNorm2d(in_channel)
        self.relu3_4 = nn.ReLU()

        self.conv1 = nn.Conv2d(in_channel, in_channel, (1, 1), stride=1, padding=2)
        nn.init.kaiming_normal_(self.conv1.weight)
        self.batch1 = nn.BatchNorm2d(in_channel)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        # Apply the stages one after another, in the order they were defined.
        out = self.relu3_1(self.batch3_1(self.conv3_1(x)))
        out = self.relu3_2(self.batch3_2(self.conv3_2(out)))
        out = self.relu3_3(self.batch3_3(self.conv3_3(out)))
        out = self.relu3_4(self.batch3_4(self.conv3_4(out)))
        out = self.relu1(self.batch1(self.conv1(out)))
        return out

In [11]:
# Smoke test for FEM: push one all-ones 3 x 128 x 128 image through the module
# and print the shape of the result.
x = torch.ones(1, 3, 128, 128, dtype=dtype)  # N x C x H x W
model = FEM(in_channel=3, in_height=128, in_width=128)
output = model(x)
print(output.shape)

torch.Size([1, 3, 128, 128])


In [13]:
# Smoke test for CNN_Block: a 32-channel, 255 x 123 all-ones input (odd sizes
# on purpose) is reduced to 3 hidden channels; print the shape of the result.
x = torch.ones(1, 32, 255, 123, dtype=dtype)  # N x C x H x W
model = CNN_Block(in_channel=32, hidden_channel=3, in_height=255, in_width=123)
output = model(x)
print(output.shape)

torch.Size([1, 3, 255, 123])


In [19]:
# Smoke test for SFE: push one all-ones 3 x 128 x 128 image through the module
# and print the shape of the (spatially downsampled) result.
x = torch.ones(1, 3, 128, 128, dtype=dtype)  # N x C x H x W
model = SFE(in_channel=3)
output = model(x)
print(output.shape)

torch.Size([1, 3, 14, 14])
