# CNN with Channel-wise Attention

- Modify a standard CNN by inserting channel-wise attention modules at suitable
locations.

#### Setup
Import libraries, and load the dataset. The function to load the dataset has been provided in `dataset_wrapper` and the data is stored in the `./data` local directory.


In [2]:
from dataset_wrapper import get_pet_datasets
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from question1 import acc_string

In [3]:
# Request the three dataset splits with 128x128 image dimensions from ./data.
train_dataset, val_dataset, test_dataset = get_pet_datasets(
    img_width=128,
    img_height=128,
    root_path='./data',
)
# Report the split sizes (only the train and test sizes are printed).
print(f"Loaded data, train = {len(train_dataset)}, test = {len(test_dataset)}")

Loaded data, train = 5719, test = 716


In [4]:
# Ask PyTorch whether a CUDA device is usable; the notebook displays the result.
torch.cuda.is_available()

True

In [5]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines where CUDA is not available.
compute_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

Define the Data-loaders

In [6]:
# Wrap each dataset split in a DataLoader.
batch_size = 64

# Only the training data is shuffled, so every epoch sees a different batch order.
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Evaluation loaders keep a fixed order: accuracy does not depend on ordering,
# and a fixed order keeps the per-batch-averaged validation loss reproducible.
validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

testing_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### Defining the Neural Network
This is based on a baseline CNN architecture with five layers.

The implementation is inspired by this source: https://medium.com/@simonyihunie/a-comprehensive-guide-to-attention-mechanisms-in-cnns-from-intuition-to-implementation-7a40df01a118

In [7]:
class ChannelwiseAttention(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Each channel of the input is averaged over its spatial dimensions, the
    resulting per-channel vector is passed through a two-layer bottleneck
    (Linear -> ReLU -> Linear -> Sigmoid), and the input is rescaled
    channel by channel with the resulting weights in (0, 1).

    Args:
        channel: Number of channels of the feature map fed to ``forward``.
        reduction: Divisor applied to ``channel`` to obtain the bottleneck
            width. The width is clamped to at least 1 so that
            ``channel < reduction`` does not create an empty hidden layer.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        # Without the clamp, channel // reduction can be 0, which builds
        # zero-width Linear layers and turns the gate into a constant 0.5.
        hidden = max(channel // reduction, 1)
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.fullyConnected = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def init_weights(self):
        """Re-initialise the weights of every sub-module, by layer type.

        This method is NOT called from ``__init__``; it only takes effect
        when invoked explicitly. Conv2d layers get Kaiming-normal weights
        (fan-out), BatchNorm2d layers get weight 1 / bias 0, and Linear
        layers get N(0, 0.001) weights; any bias present is zeroed.
        The module as defined here only contains Linear layers, so the
        Conv2d / BatchNorm2d branches matter only if such layers are added.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Rescale ``x`` (4-D: batch, channel, height, width) per channel.

        Returns a tensor with the same shape as ``x``.
        """
        b, c, _, _ = x.size()
        # Squeeze: one average value per (sample, channel).
        y = self.pooling(x).view(b, c)
        # Excite: per-channel weights, reshaped so they broadcast over H and W.
        y = self.fullyConnected(y).view(b, c, 1, 1)
        return x * y

In [37]:
class ChannelWiseCNN(nn.Module):
    """Three-stage CNN with a channel-attention gate and a shortcut branch.

    Main path: three (5x5 conv -> batch norm -> ReLU -> 2x2 max-pool) stages
    producing 16, 32 and 64 channels, followed by ``ChannelwiseAttention``.
    Shortcut path: a 1x1 convolution from the raw input to 64 channels,
    max-pooled three times so its spatial size matches the main path.
    The two paths are summed, globally average-pooled and classified by a
    single linear layer.

    Args:
        classes: Number of output logits.
        in_channels: Number of channels of the input image.
        reduction: Bottleneck reduction passed to ``ChannelwiseAttention``.
    """

    def __init__(self, classes=4, in_channels=3, reduction=4):
        super().__init__()

        def conv_stage(n_in, n_out):
            # The padded 5x5 convolution keeps the spatial size; the pool halves it.
            return nn.Sequential(
                nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(num_features=n_out),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

        # For a 128x128 input the spatial size goes 128 -> 64 -> 32 -> 16.
        self.layer1 = conv_stage(in_channels, 16)
        self.layer2 = conv_stage(16, 32)
        self.layer3 = conv_stage(32, 64)

        self.ChannelAttention = ChannelwiseAttention(channel=64, reduction=reduction)

        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Shortcut branch: lift the raw input to 64 channels, then halve the
        # spatial size three times to mirror the three pooled stages above.
        shortcut_layers = [nn.Conv2d(in_channels, 64, kernel_size=1)]
        shortcut_layers += [nn.MaxPool2d(kernel_size=2, stride=2) for _ in range(3)]
        shortcut_layers.append(nn.ReLU())
        self.skip = nn.Sequential(*shortcut_layers)

        self.fc = nn.Linear(in_features=64, out_features=classes)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the class logits, shape (batch, classes), for an image batch ``x``."""
        # Main path: three conv stages, then channel re-weighting.
        attended = self.ChannelAttention(self.layer3(self.layer2(self.layer1(x))))

        # Shortcut around the whole block, computed from the untouched input.
        # (self.skip already ends in a ReLU, so this extra ReLU leaves values unchanged.)
        shortcut = self.relu(self.skip(x))

        # Merge, collapse each channel to one value, and flatten to (batch, 64).
        pooled = self.global_pool(attended + shortcut)
        flat = pooled.flatten(1)

        return self.fc(flat)


### Training and Testing

In [38]:
def do_training(model, experiment_name, criterion, optimizer, num_epochs=20, patience=5):
    """Train ``model`` with early stopping on validation loss, logging to TensorBoard.

    Reads the module-level ``training_dataloader``, ``validation_dataloader``
    and ``compute_device``. Scalars and the model graph are written under
    ``runs/<experiment_name>``.

    Args:
        model: Network to train; it is modified in place.
        experiment_name: Sub-directory of ``runs/`` used for the TensorBoard logs.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Maximum number of epochs to run.
        patience: Number of consecutive epochs without a new minimum
            validation loss after which training stops.

    Returns:
        The same ``model`` object, loaded with the weights from the epoch
        that had the lowest validation loss.
    """
    writer = SummaryWriter('runs/' + experiment_name)

    min_validation_loss = None
    best_model_state = None  # snapshot of the best weights, restored after training
    wait = 0

    # Make sure the model starts from fresh weights. modules() walks the whole
    # tree; children() would only visit direct children and therefore skip
    # every layer wrapped in an nn.Sequential or a custom sub-module.
    for layer in model.modules():
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()

    try:
        for epoch in range(num_epochs):  # epoch iteration loop
            model.train()
            train_loss_epoch_total = 0
            batches_count = 0

            for i, (images, labels) in enumerate(training_dataloader):
                images = images.to(compute_device)
                labels = labels.to(compute_device)

                # Log the graph once; tracing it runs an extra forward pass,
                # so repeating it every epoch is wasted work.
                if epoch == 0 and i == 0:
                    writer.add_graph(model, images)

                # forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss_epoch_total += loss.item()
                batches_count += 1

            train_loss = train_loss_epoch_total / batches_count
            writer.add_scalar('Loss/train', train_loss, epoch + 1)

            # validation loss and accuracy
            model.eval()
            correct = 0
            total = 0
            val_loss_epoch_total = 0
            val_batches_count = 0
            with torch.no_grad():
                for images, labels in validation_dataloader:
                    images = images.to(compute_device)
                    labels = labels.to(compute_device)
                    outputs = model(images)

                    val_loss_epoch_total += criterion(outputs, labels).item()
                    val_batches_count += 1

                    predicted = outputs.argmax(dim=1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            val_acc = 100 * correct / total
            val_loss = val_loss_epoch_total / val_batches_count

            # Use the same 1-based step as the training loss so the curves line up.
            writer.add_scalar('Accuracy/validation', val_acc, epoch + 1)
            writer.add_scalar('Loss/validation', val_loss, epoch + 1)

            if min_validation_loss is None or val_loss < min_validation_loss:
                min_validation_loss = val_loss
                # state_dict() returns references to the live tensors, which
                # later optimizer steps overwrite in place; clone them so the
                # snapshot really holds this epoch's weights.
                best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
                wait = 0
            else:
                wait += 1

            if wait >= patience:
                break  # exit early if there has been no improvement in validation loss

        if best_model_state is not None:
            model.load_state_dict(best_model_state)
            print("Best model weights restored.")
    finally:
        # Flush and release the log file even if training raised.
        writer.close()

    return model


def do_testing(model, dataloader):
    """Measure the classification accuracy of ``model`` over ``dataloader``.

    Args:
        model: Network returning one row of class scores per sample.
        dataloader: Iterable of ``(images, labels)`` batches.

    Returns:
        A tuple ``(accuracy, message)`` where ``accuracy`` is the percentage
        of correctly classified samples (``0.0`` when the loader yields no
        samples) and ``message`` is a printable summary of it.
    """
    # Move batches to wherever the model's weights live; fall back to the
    # module-level compute_device for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else compute_device

    model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Computed once after the loop, so an empty loader no longer leaves
    # ``accuracy`` undefined.
    accuracy = 100 * correct / total if total else 0.0

    return accuracy, 'Accuracy of the model on the provided images: {} %'.format(accuracy)

In [47]:
# Build the network with its default hyper-parameters, then place it on the compute device.
channelwise_cnn = ChannelWiseCNN()
channelwise_cnn = channelwise_cnn.to(compute_device)

In [48]:
# Per-class loss weights; the four numerators sum to 7149.
# NOTE(review): these look like per-class sample fractions. If so, the loss
# up-weights the more frequent classes -- confirm this is intended rather
# than inverse-frequency weighting.
weights = [600/7149, 1771/7149, 2590/7149, 2188/7149]
# Create the tensor directly on compute_device instead of the legacy
# torch.FloatTensor(...).cuda(), which always targets the default CUDA device.
class_weights = torch.tensor(weights, dtype=torch.float32, device=compute_device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(channelwise_cnn.parameters(), lr=0.01)

In [49]:
# Train for at most 100 epochs, stopping after 10 epochs without a new best validation loss.
channelwise_cnn_trained = do_training(
    channelwise_cnn,
    experiment_name='channelwise_cnn',
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=100,
    patience=10,
)

Best model weights restored.


In [50]:
# Evaluate the trained model on the held-out test loader.
# NOTE: this rebinds `acc_string`, shadowing the name imported from `question1`;
# from here on it refers to the message string returned by do_testing.
acc, acc_string = do_testing(channelwise_cnn_trained, testing_dataloader)

In [51]:
# Display the accuracy summary message returned by do_testing.
print(acc_string)

Accuracy of the model on the provided images: 66.34078212290503 %
