# ECE 176: Fine-Grained Classification Using a CB-ViT Model

The focus of our final project is reimplementing the work of Shuo Zhu, Xukang Zhang, Yu Wang, Zhongyang Wang, and Jiahao Sun. The main result of their paper is the introduction of a CB-ViT model. This model combines the local feature extraction of convolutional networks with the broad feature extraction of Vision Transformers. A version of their research paper can be found here: https://ietresearch.onlinelibrary.wiley.com/doi/full/10.1049/ipr2.13295. 

In [2]:
# Imports - reused from assignment 5

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T

import numpy as np

import torch.nn.functional as F  # useful stateless functions

from einops.layers.torch import Rearrange
from einops import repeat

import matplotlib.pyplot as plt

In [3]:
# Data setup - similar to assignment 5

# Shared preprocessing: convert each image to a tensor, then resize it to
# 448x448. `train_transform` and `transform` are two names for the same object.
transform = T.Compose([T.ToTensor(), T.Resize((448, 448))])
train_transform = transform

# Oxford 102 dataset for train, val, and test.
# The official 'test' split (6k+ images) is used for training and the
# official 'train' split (1k images) is used for testing.
_flowers_root = "./datasets/flowers"
flower_train = dset.Flowers102(_flowers_root, split='test', download=True, transform=train_transform)
flower_val = dset.Flowers102(_flowers_root, split='val', download=True, transform=transform)
flower_test = dset.Flowers102(_flowers_root, split='train', download=True, transform=transform)

NUM_TRAIN = 6000
NUM_VAL = 1000
NUM_TEST = 1000

batch_size = 8

# All three loaders share the same batch size, shuffling and worker count.
_loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=2)

loader_train = DataLoader(flower_train, **_loader_options)
loader_val = DataLoader(flower_val, **_loader_options)
loader_test = DataLoader(flower_test, **_loader_options)

In [4]:
# Dtype and device selection - reused from assignment 5

USE_GPU = True
# NOTE(review): the models in this file default to 102 output classes and
# nothing visible here reads `num_class` - confirm before relying on it.
num_class = 100
dtype = torch.float32  # every input batch is cast to float32 before the forward pass

# Use CUDA only when it is both requested and actually available.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

# Number of training iterations between loss printouts
print_every = 100

print('using device:', device)

using device: cuda


In [5]:
# Flatten function - reused from assignment 5

def flatten(x):
    """
    Collapse every dimension after the first into a single one.

    Inputs:
    - x: Tensor whose first dimension is the batch dimension, e.g. (N, C, H, W).

    Returns: A view of x with shape (N, C * H * W).
    """
    batch = x.shape[0]
    return x.view(batch, -1)

# We need to wrap `flatten` function in a module in order to stack it
# in nn.Sequential
class Flatten(nn.Module):
    """Layer form of `flatten`, so it can be placed inside a module stack."""

    def forward(self, x):
        # Delegate to the functional helper: (N, ...) -> (N, -1)
        flat = flatten(x)
        return flat

In [6]:
# Random weight function - reused from assignment 5

def random_weight(shape):
    """
    Build a randomly initialised weight tensor that tracks gradients.

    Values are drawn from a standard normal distribution and scaled by
    sqrt(2 / fan_in) (Kaiming normalization). The tensor is created on the
    global `device` with the global `dtype`.

    Inputs:
    - shape: Tuple giving the weight shape. A 2-tuple is treated as a
      fully-connected weight (fan_in = shape[0]); anything else is treated as
      a conv weight [out_channel, in_channel, kH, kW]
      (fan_in = product of shape[1:]).

    Returns: The weight tensor, with requires_grad=True.
    """
    is_fc_weight = len(shape) == 2
    fan_in = shape[0] if is_fc_weight else np.prod(shape[1:])
    scale = np.sqrt(2. / fan_in)
    # randn samples from a standard normal distribution
    w = torch.randn(shape, device=device, dtype=dtype) * scale
    w.requires_grad = True
    return w

def zero_weight(shape):
    """
    Build an all-zero tensor of the given shape that tracks gradients.

    The tensor is created on the global `device` with the global `dtype`.
    """
    w = torch.zeros(shape, device=device, dtype=dtype)
    w.requires_grad = True
    return w

# Smoke test: create a weight of shape [3 x 5].
# The printed tensor shows the device it lives on (e.g. device='cuda:0' when
# running on GPU) and requires_grad=True.
random_weight((3, 5))

tensor([[-0.5847, -0.5212,  1.3914,  0.6470, -0.1718],
        [ 1.4946, -0.1659, -0.2562, -0.8890, -0.7908],
        [ 0.6552,  0.4206,  0.1405, -0.7177, -0.7956]], device='cuda:0',
       requires_grad=True)

In [7]:
# Check accuracy function - reused from assignment 5

def check_accuracy_part34(loader, model, print_flag=True):
    """
    Compute the top-1 accuracy of a model over every batch of a loader.

    The model is put in evaluation mode and is NOT switched back; callers
    that continue training must call model.train() again.

    Inputs:
    - loader: Iterable yielding (x, y) batches of inputs and integer labels.
    - model: A PyTorch Module mapping x to per-class scores of shape (N, C).
    - print_flag: (Optional) If truthy, print the correct/total counts.

    Returns: The accuracy as a float in [0, 1]. An empty loader gives 0.0
    instead of raising ZeroDivisionError.
    """
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)  # index of the highest score per sample
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
    # num_correct is a 0-dim tensor after the first batch; convert it once here.
    num_correct = int(num_correct)
    acc = num_correct / num_samples if num_samples else 0.0
    if print_flag:
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc

In [8]:
# Training function - reused from assignment 5

def train_part34(model, optimizer, epochs=1):
    """
    Train a model on the global `loader_train` using the PyTorch Module API.

    Every `print_every` iterations the current loss is printed and the model
    is evaluated on `loader_val`. Once training finishes, the collected
    validation accuracies are plotted against the iteration count.

    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for

    Returns: The accuracy of the model on `loader_val` after training
    """
    iterations = []
    accuracies = []
    itera = 1  # global iteration counter across epochs (1-based)
    model = model.to(device=device)  # move the model parameters to CPU/GPU

    # Anomaly detection is a global autograd switch, so it is enabled once
    # here instead of on every iteration. It is a debugging aid and slows
    # training down; turn it off once the model trains cleanly.
    torch.autograd.set_detect_anomaly(True)

    for e in range(epochs):
        for t, (x, y) in enumerate(loader_train):
            # check_accuracy_part34 leaves the model in eval mode, so
            # training mode has to be restored on every iteration.
            model.train()
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (t + 1) % print_every == 0:
                print('Epoch %d, Iteration %d, loss = %.4f' % (e, t + 1, loss.item()))
                acc = check_accuracy_part34(loader_val, model)
                print()
                iterations.append(itera)
                accuracies.append(acc)
            itera += 1

    # Plot once, after training: the lists are cumulative, so plotting at the
    # end of every epoch only stacked redundant copies of the same curve.
    if iterations:
        plt.plot(iterations, accuracies)
        plt.title("Validation Accuracy")
        plt.xlabel("Iteration")
        plt.ylabel("Accuracy")
    return check_accuracy_part34(loader_val, model)

In [9]:
# ResNet 50

import torch
import torch.nn as nn
import torch.nn.functional as F

class IdentityBlock(nn.Module):
    """
    Bottleneck residual block whose skip connection is the unmodified input.

    Main path: 1x1 conv -> BN -> ReLU -> kxk conv -> BN -> ReLU -> 1x1 conv
    -> BN, all with stride 1. The input is then added and a final ReLU is
    applied, so filters[2] must equal in_channels. The kxk conv pads by
    kernel_size // 2, which keeps the spatial size for odd kernel sizes.

    Inputs:
    - in_channels: Number of channels of the input feature map.
    - filters: Tuple (F1, F2, F3) of output channels for the three convs.
    - kernel_size: Kernel size of the middle convolution.
    """

    def __init__(self, in_channels, filters, kernel_size):
        super().__init__()
        reduce_ch, mid_ch, out_ch = filters

        self.conv1 = nn.Conv2d(in_channels, reduce_ch, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(reduce_ch)

        self.conv2 = nn.Conv2d(
            reduce_ch,
            mid_ch,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(mid_ch)

        self.conv3 = nn.Conv2d(mid_ch, out_ch, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_ch)

        self.relu = nn.ReLU()  # non-inplace ReLU, shared by all three activations

    def forward(self, x):
        residual = x

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # Skip connection, then the final activation
        return self.relu(y + residual)

class ConvolutionalBlock(nn.Module):
    """
    Bottleneck residual block with a projected skip connection.

    Main path: 1x1 conv (carrying the stride) -> BN -> ReLU -> kxk conv -> BN
    -> ReLU -> 1x1 conv -> BN. The skip path is a strided 1x1 conv + BN that
    maps the input to F3 channels at the same spatial size as the main path.
    The two paths are summed and passed through a final ReLU.

    Inputs:
    - in_channels: Number of channels of the input feature map.
    - filters: Tuple (F1, F2, F3) of output channels for the three convs.
    - kernel_size: Kernel size of the middle convolution.
    - stride: Stride of the first conv and of the skip-path conv.
    """

    def __init__(self, in_channels, filters, kernel_size, stride):
        super().__init__()
        reduce_ch, mid_ch, out_ch = filters

        self.conv1 = nn.Conv2d(in_channels, reduce_ch, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(reduce_ch)

        self.conv2 = nn.Conv2d(
            reduce_ch,
            mid_ch,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(mid_ch)

        self.conv3 = nn.Conv2d(mid_ch, out_ch, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_ch)

        # Projection so the skip path matches the main path's shape
        projection = nn.Conv2d(in_channels, out_ch, kernel_size=1, stride=stride, bias=False)
        self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_ch))

        self.relu = nn.ReLU()  # non-inplace ReLU, shared by all three activations

    def forward(self, x):
        residual = self.shortcut(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # Skip connection, then the final activation
        return self.relu(y + residual)

class ResNet50(nn.Module):
    """
    ResNet-50 style backbone built from ConvolutionalBlock and IdentityBlock.

    An average-pooling layer and a `num_classes`-way linear layer are
    constructed, but `forward` does not apply them: it returns the stage-5
    feature map.

    Inputs:
    - num_classes: Output size of the (currently unused) linear layer.
    """

    def __init__(self, num_classes=102):
        super().__init__()
        # Initial convolution and max-pooling
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()  # non-inplace ReLU
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages 2-5: one ConvolutionalBlock followed by N IdentityBlocks
        self.layer2 = self._make_stage(64, (64, 64, 256), num_identity=2, stride=1)
        self.layer3 = self._make_stage(256, (128, 128, 512), num_identity=3, stride=2)
        self.layer4 = self._make_stage(512, (256, 256, 1024), num_identity=5, stride=2)
        self.layer5 = self._make_stage(1024, (512, 512, 2048), num_identity=2, stride=2)

        # Average pooling and fully connected layer (not applied in forward)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, num_classes)

    @staticmethod
    def _make_stage(in_channels, filters, num_identity, stride):
        """Build one stage: a ConvolutionalBlock, then `num_identity` IdentityBlocks."""
        blocks = [ConvolutionalBlock(in_channels, filters=filters, kernel_size=3, stride=stride)]
        for _ in range(num_identity):
            blocks.append(IdentityBlock(filters[-1], filters=filters, kernel_size=3))
        return nn.Sequential(*blocks)

    def forward(self, x):
        # x shape: (batch_size, 3, H, W)
        # Shapes for an input of size (1, 3, 448, 448):
        #   after stem   -> [1, 64, 112, 112]
        #   after layer2 -> [1, 256, 112, 112]
        #   after layer3 -> [1, 512, 56, 56]
        #   after layer4 -> [1, 1024, 28, 28]
        #   after layer5 -> [1, 2048, 14, 14]
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)

        # self.avgpool / self.fc are intentionally not applied here.
        return x

In [10]:
# "CNN Modules"

class FEM(nn.Module):
    """
    Two-branch (3x3 and 5x5) convolution module with a learned mixing weight.

    Both branches map `in_channel` channels to in_channel // 2 channels. Their
    sum is max- and average-pooled (2x2); the two pooled maps are concatenated
    along the height axis, convolved, flattened and fed to a linear layer with
    `in_width` outputs. The softmax of that output (gamma) is broadcast along
    the last (width) axis to blend the branches: R1 * gamma + R2 * (1 - gamma).

    The linear layer's input size is derived from `in_height` and `in_width`,
    so the input feature map must have exactly that spatial size.

    Inputs:
    - in_channel: Number of channels of the input feature map.
    - in_height: Height of the input feature map.
    - in_width: Width of the input feature map.
    """

    def __init__(self, in_channel, in_height, in_width):
        super().__init__()
        self.channel_half = in_channel // 2

        # Branch 1: 3x3 convolution
        self.conv3_1 = nn.Conv2d(in_channel, self.channel_half, (3, 3), padding=1)
        nn.init.kaiming_normal_(self.conv3_1.weight)
        self.batch31 = nn.BatchNorm2d(self.channel_half)
        self.silu31 = nn.SiLU()

        # Branch 2: 5x5 convolution
        self.conv5 = nn.Conv2d(in_channel, self.channel_half, (5, 5), padding=2)
        nn.init.kaiming_normal_(self.conv5.weight)
        self.batch5 = nn.BatchNorm2d(self.channel_half)
        self.silu5 = nn.SiLU()

        # Mixing-weight head
        self.max = nn.MaxPool2d(kernel_size=(2, 2))
        self.avg = nn.AvgPool2d(kernel_size=(2, 2))
        self.conv3_2 = nn.Conv2d(self.channel_half, self.channel_half, (3, 3), padding=1)
        nn.init.kaiming_normal_(self.conv3_2.weight)
        self.flatten = Flatten()
        # Each pooled map is (H // 2) x (W // 2); stacking two of them along
        # the height axis gives 2 * (H // 2) rows.
        self.fc_size = int(self.channel_half * 2 * (in_height // 2) * (in_width // 2))
        self.fc = nn.Linear(self.fc_size, in_width)
        self.soft = nn.Softmax(dim=1)

        self.in_channel = in_channel
        self.in_height = in_height
        self.in_width = in_width

    def forward(self, x):
        branch3 = self.silu31(self.batch31(self.conv3_1(x)))
        branch5 = self.silu5(self.batch5(self.conv5(x)))

        fused = branch3 + branch5
        pooled = torch.cat((self.max(fused), self.avg(fused)), dim=2)
        logits = self.fc(self.flatten(self.conv3_2(pooled)))

        # (N, in_width) -> (N, 1, 1, in_width) so it broadcasts over the width axis
        gamma = self.soft(logits).view(x.shape[0], 1, 1, self.in_width)
        return branch3 * gamma + branch5 * (1 - gamma)
    
class CNN_Block(nn.Module):
    """
    Channel-reducing convolution stack followed by a FEM module.

    A 1x1 conv halves the channels, a 3x3 conv halves them again (each
    followed by BatchNorm and SiLU), and the result is passed through FEM,
    which itself outputs half of its input channels.

    Inputs:
    - in_channel: Number of channels of the input feature map.
    - in_height: Height of the input feature map (forwarded to FEM).
    - in_width: Width of the input feature map (forwarded to FEM).
    """

    def __init__(self, in_channel, in_height, in_width):
        super().__init__()
        self.channel_half = in_channel // 2
        self.channel_quarter = self.channel_half // 2

        # 1x1 conv: in_channel -> in_channel / 2
        self.conv1 = nn.Conv2d(in_channel, self.channel_half, (1, 1), padding=0)
        nn.init.kaiming_normal_(self.conv1.weight)
        self.batch1 = nn.BatchNorm2d(self.channel_half)
        self.silu1 = nn.SiLU()

        # 3x3 conv: in_channel / 2 -> in_channel / 4
        self.conv3 = nn.Conv2d(self.channel_half, self.channel_quarter, (3, 3), padding=1)
        nn.init.kaiming_normal_(self.conv3.weight)
        self.batch3 = nn.BatchNorm2d(self.channel_quarter)
        self.silu3 = nn.SiLU()

        self.FEM = FEM(self.channel_quarter, in_height, in_width)

    def forward(self, x):
        reduced = self.silu1(self.batch1(self.conv1(x)))
        refined = self.silu3(self.batch3(self.conv3(reduced)))
        return self.FEM(refined)

In [11]:
# SFE 

class SFE(nn.Module):
    """
    Convolutional downsampler that keeps the channel count unchanged.

    Four 3x3 stride-2 convolutions (padding 2) are followed by one 1x1
    stride-1 convolution (padding 3); every convolution is followed by
    BatchNorm and ReLU. A 448x448 input becomes a 36x36 output.

    Inputs:
    - in_channel: Number of input (and output) channels.
    """

    def __init__(self, in_channel):
        super().__init__()

        def make_conv(kernel, stride, padding):
            """Create a channel-preserving conv with Kaiming-normal weights."""
            conv = nn.Conv2d(in_channel, in_channel, kernel, stride=stride, padding=padding)
            nn.init.kaiming_normal_(conv.weight)
            return conv

        self.conv3_1 = make_conv((3, 3), 2, 2)
        self.batch3_1 = nn.BatchNorm2d(in_channel)
        self.relu3_1 = nn.ReLU()

        self.conv3_2 = make_conv((3, 3), 2, 2)
        self.batch3_2 = nn.BatchNorm2d(in_channel)
        self.relu3_2 = nn.ReLU()

        self.conv3_3 = make_conv((3, 3), 2, 2)
        self.batch3_3 = nn.BatchNorm2d(in_channel)
        self.relu3_3 = nn.ReLU()

        self.conv3_4 = make_conv((3, 3), 2, 2)
        self.batch3_4 = nn.BatchNorm2d(in_channel)
        self.relu3_4 = nn.ReLU()

        # Padding of 3 on the 1x1 conv so that a 448x448 input becomes a 36x36 output
        self.conv1 = make_conv((1, 1), 1, 3)
        self.batch1 = nn.BatchNorm2d(in_channel)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        stages = (
            (self.conv3_1, self.batch3_1, self.relu3_1),
            (self.conv3_2, self.batch3_2, self.relu3_2),
            (self.conv3_3, self.batch3_3, self.relu3_3),
            (self.conv3_4, self.batch3_4, self.relu3_4),
            (self.conv1, self.batch1, self.relu1),
        )
        out = x
        for conv, norm, act in stages:
            out = act(norm(conv(out)))
        return out

In [12]:
# "ViT Modules"

class PatchEmbedding(nn.Module):   # Obtained from tutorial: https://youtu.be/j3VNqtJUoz0?si=iZLnmtbygLGLQl9K
    """
    Split an image into non-overlapping square patches and embed each one.

    The input (b, c, H, W) is rearranged into (b, num_patches,
    patch_size * patch_size * c) and each flattened patch is linearly
    projected to `emb_size`. H and W must be multiples of `patch_size`.

    Inputs:
    - in_channels: Number of channels of the input image.
    - patch_size: Side length of each square patch.
    - emb_size: Size of the embedding produced for each patch.
    """

    def __init__(self, in_channels=3, patch_size=12, emb_size=48):
        super().__init__()
        patch_dim = patch_size * patch_size * in_channels
        to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
        self.projection = nn.Sequential(to_patches, nn.Linear(patch_dim, emb_size))

    def forward(self, x):
        return self.projection(x)

class Attention(nn.Module): # Obtained from tutorial: https://youtu.be/j3VNqtJUoz0?si=iZLnmtbygLGLQl9K
    """
    Multi-head self-attention over a batch of token sequences.

    The input is projected by three separate linear layers into query, key
    and value, which are then passed to nn.MultiheadAttention.

    Inputs:
    - dim: Embedding size of each token.
    - n_heads: Number of attention heads (must divide dim).
    - dropout: Dropout probability applied to the attention weights.

    forward takes x of shape (batch, tokens, dim) and returns a tensor of the
    same shape.
    """

    def __init__(self, dim, n_heads, dropout):
        super().__init__()
        # batch_first=True is required: the ViT passes (batch, tokens, dim).
        # With the default (batch_first=False) nn.MultiheadAttention treats
        # dim 0 as the sequence axis, so attention would be computed across
        # the images of a batch instead of across the tokens of each image.
        self.attention = nn.MultiheadAttention(
            embed_dim=dim, num_heads=n_heads, dropout=dropout, batch_first=True
        )
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)

    def forward(self, x):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        # The attention weights are returned as well but not used.
        attention_out, _ = self.attention(q, k, v)
        return attention_out
    
class PreNorm(nn.Module):   # Obtained from tutorial: https://youtu.be/j3VNqtJUoz0?si=iZLnmtbygLGLQl9K
    """
    Apply LayerNorm over the last dimension before calling a wrapped function.

    Inputs:
    - dim: Size of the last dimension to normalize over.
    - function: Callable applied to the normalized input; any keyword
      arguments given to forward are passed on to it.
    """

    def __init__(self, dim, function):
        super().__init__()
        self.layer = nn.LayerNorm(dim)
        self.function = function

    def forward(self, x, **kwargs):
        normed = self.layer(x)
        return self.function(normed, **kwargs)
    
class FeedForward(nn.Sequential):   # Obtained from tutorial: https://youtu.be/j3VNqtJUoz0?si=iZLnmtbygLGLQl9K
    """
    Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Inputs:
    - dim: Input and output feature size.
    - hidden_dim: Feature size between the two linear layers.
    - dropout: (Optional) Dropout probability used after each linear stage.
    """

    def __init__(self, dim, hidden_dim, dropout=0):
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        hidden_drop = nn.Dropout(dropout)
        project = nn.Linear(hidden_dim, dim)
        output_drop = nn.Dropout(dropout)
        super().__init__(expand, activation, hidden_drop, project, output_drop)

class ResidualAdd(nn.Module): # Obtained from tutorial: https://youtu.be/j3VNqtJUoz0?si=iZLnmtbygLGLQl9K
    """
    Wrap a function with a residual (skip) connection: function(x) + x.

    Inputs:
    - function: Callable applied to the input; any keyword arguments given to
      forward are passed on to it. Its output must be addable to its input.
    """

    def __init__(self, function):
        super().__init__()
        self.function = function

    def forward(self, x, **kwargs):
        # Out-of-place add. An in-place `+=` on the function's output would
        # overwrite the caller's tensor whenever the function returns its
        # input unchanged, and fails outright on leaf tensors that require grad.
        return self.function(x, **kwargs) + x
    
class ViT(nn.Module): # Obtained from tutorial: https://youtu.be/j3VNqtJUoz0?si=iZLnmtbygLGLQl9K
    """
    Vision Transformer classifier with an optional additive side input.

    The image is split into patches and embedded, a class token is prepended,
    positional embeddings are added, and the token sequence is passed through
    `n_layers` pre-norm transformer blocks (attention + feed-forward, each
    with a residual connection). The class token is then normalized and
    projected to `out_dim` scores.

    Inputs:
    - ch: Number of channels of the input image.
    - img_size: Height and width of the (square) input image.
    - patch_size: Side length of each square patch.
    - emb_dim: Token embedding size.
      Not sure about emb_dim since each patch is 3x12x12 (CxHxW) = 432 but that is big
    - n_layers: Number of transformer blocks.
    - out_dim: Number of output classes.
    - dropout: Dropout probability used in attention and feed-forward layers.
    - heads: Number of attention heads.
    """

    def __init__(self, ch=3, img_size=36, patch_size=12, emb_dim=432,
                n_layers=12, out_dim=102, dropout=0.1, heads=4):
        super(ViT, self).__init__()
        # Attributes
        self.channels = ch
        self.height = img_size
        self.width = img_size
        self.patch_size = patch_size
        self.n_layers = n_layers
        # Patching
        self.patch_embedding = PatchEmbedding(in_channels=ch, patch_size=patch_size, emb_size=emb_dim)
        # Learnable params
        num_patches = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(
            torch.randn(1, num_patches + 1, emb_dim))
        self.cls_token = nn.Parameter(torch.rand(1, 1, emb_dim))
        # Transformer Encoder
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            transformer_block = nn.Sequential(
                ResidualAdd(PreNorm(emb_dim, Attention(emb_dim, n_heads = heads, dropout = dropout))),
                ResidualAdd(PreNorm(emb_dim, FeedForward(emb_dim, emb_dim, dropout = dropout))))
            self.layers.append(transformer_block)
        # Classification head
        self.head = nn.Sequential(nn.LayerNorm(emb_dim), nn.Linear(emb_dim, out_dim))

    def forward(self, img, cnn_in=None):
        """
        Inputs:
        - img: Image batch of shape (b, ch, img_size, img_size).
        - cnn_in: (Optional) Tensor broadcastable to (b, num_patches + 1,
          emb_dim) that is added to the token sequence before the transformer
          blocks. When None, nothing is added.

        Returns: Class scores of shape (b, out_dim).
        """
        # Get patch embedding vectors
        x = self.patch_embedding(img)
        b, n, _ = x.shape
        # Add cls token to inputs
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embedding[:, :(n + 1)]
        # Inject the side input (e.g. CNN-branch features), if any
        if cnn_in is not None:
            x = x + cnn_in
        # Transformer layers
        for layer in self.layers:
            x = layer(x)
        # Output based on classification token
        return self.head(x[:, 0, :])

In [13]:
# CB_ViT Model

class CB_ViT(nn.Module):
    """
    CB-ViT model: a CNN branch whose features are injected into a ViT branch.

    CNN branch: ResNet50 -> CNN_Block -> 1x1 conv down to 10 channels ->
    adaptive average pooling to (432, 1) -> LayerNorm, giving a
    (N, 10, 432) tensor. ViT branch: SFE reduces the image to 36x36, which
    the ViT splits into 9 patches (+1 class token = 10 tokens of size 432).
    The CNN tensor is added to the ViT tokens inside the ViT.

    All sizes are fixed for 3x448x448 inputs and 102 output classes.
    """

    def __init__(self):
        super().__init__()
        # For (1,3,448,448) input, the ResNet is expected to output [1, 2048, 14, 14]
        self.ResNet = ResNet50()
        self.CNN_modules = CNN_Block(in_channel=2048, in_height=14, in_width=14)
        self.SFE = SFE(in_channel=3)
        # Down converter: 256 channels -> 10 "token" channels
        self.conv1_down = nn.Conv2d(in_channels=256, out_channels=10, kernel_size=(1, 1))
        self.adapt = nn.AdaptiveAvgPool2d(output_size=(432, 1))
        self.layer = nn.LayerNorm(432)  # emb_dim = 432
        self.ViT = ViT(
            ch=3,
            img_size=36,
            patch_size=12,
            emb_dim=432,
            n_layers=12,
            out_dim=102,
            dropout=0.1,
            heads=4,
        )

    def forward(self, x):
        # CNN branch: [1, 3, 448, 448] -> [1, 2048, 14, 14] -> [1, 256, 14, 14]
        features = self.CNN_modules(self.ResNet(x))

        # Down converter: [1, 256, 14, 14] -> [1, 10, 14, 14] -> [1, 10, 432, 1]
        pooled = self.adapt(self.conv1_down(features))
        # Drop the trailing singleton axis: [1, 10, 432, 1] -> [1, 10, 432]
        tokens = pooled.view(pooled.shape[0], pooled.shape[1], -1)
        cnn_tokens = self.layer(tokens)

        # ViT branch
        reduced_img = self.SFE(x)
        return self.ViT(img=reduced_img, cnn_in=cnn_tokens)

In [14]:
# ResNet_ViT Model

class Res_ViT(nn.Module):
    """
    Baseline model: ResNet50 feature map fed directly into a ViT.

    The ResNet output (expected [N, 2048, 14, 14] for 448x448 inputs) is split
    by the ViT into four 7x7 patches, giving 5 tokens (4 patches + 1 class
    token) of size 1024. The ViT's additive side input is filled with zeros,
    so it has no effect.
    """

    def __init__(self):
        super().__init__()
        self.ResNet = ResNet50()
        # Zero side input of shape (1, num_patches + 1, emb_dim). The leading
        # 1 broadcasts over any batch size (a fixed batch of 8 breaks on the
        # final, smaller batch of a loader). Registering it as a
        # non-persistent buffer makes it follow model.to(device) without
        # adding an entry to the state_dict.
        self.register_buffer("y", torch.zeros((1, 5, 1024), dtype=dtype), persistent=False)
        self.ViT = ViT(ch=2048, img_size=14, patch_size=7, emb_dim=1024, n_layers=12, out_dim=102, dropout=0.1, heads=4)

    def forward(self, x):
        resnet_out = self.ResNet(x)
        output = self.ViT(img=resnet_out, cnn_in=self.y)
        return output

In [15]:
# Test Res_ViT
"""x = torch.ones((8,3,448,448), dtype=dtype) # NxCxHxW
model = Res_ViT()
output = model(x)
print(output.size())""";

In [None]:
# Training CB_ViT
# (the model instantiated below is CB_ViT, not Res_ViT)

# SGD hyperparameters
learning_rate = 0.001
momentum = 0.9
weight_decay = 0.00005
# NOTE(review): the trailing "#0" looks like a disabled 300-epoch setting - confirm.
epochs = 30#0

model = CB_ViT()
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

# Re-assigns the module-level print_every read by train_part34:
# print the loss and check validation accuracy every 100 iterations.
print_every = 100
train_part34(model, optimizer, epochs=epochs)

Epoch 0, Iteration 100, loss = 4.4619
Got 10 / 1020 correct (0.98)

