**Import necessary modules**
nn and optim for neural network construction and optimizers


In [1]:
import torch, cv2, os, random, json, time, sys
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import _LRScheduler
import math

In [2]:
# This is a test for the imported torch library

# Assume input feature map with dimensions [batch_size, channels, height, width]
input_feature_map = torch.randn(1, 5, 5)  # Example input

# Define global average pooling operation
global_avg_pool = nn.AdaptiveAvgPool2d(1)
global_max_pool = nn.AdaptiveMaxPool2d(1)
convolution = nn.Conv2d(3, 5, 3)

# Apply global average pooling to input
map = global_max_pool(input_feature_map)

print(map.shape)
print(input_feature_map)
print(map)

torch.Size([1, 1, 1])
tensor([[[-1.1697, -1.3042, -0.6255, -1.5804, -0.5656],
         [ 1.3725, -0.7085, -0.7689, -1.0101,  0.6127],
         [-1.5991,  1.1792, -0.1324, -1.1519,  0.8122],
         [ 0.0517, -0.5503,  1.0421, -2.4580,  1.1521],
         [-0.0437,  1.1453,  0.2751, -1.1008, -0.3836]]])
tensor([[[1.3725]]])


# Model Architecture
* Individual Components:
    * CBAM Module
    * Channel Attention
    * Spatial Attention
    * DCNN (Deep Convolutional Neural Network)

Convolutional Block Attention Module:
This allows the neural network to focus on specific aspects of the image and improves
the representation of interests. If, for example, the input feature map tensor dimensions
are 6 x 127 x 127, the output will have the same dimensions. The CBAM module works on each feature map and enhances certain aspects of each feature map.

# Proposed Architecture for this task:
**Channel Attention**


**Spatial Attention**


**CBAM (Convolutional Block Attention Module)**


**CNN Feature Extracion**
Input (512 x 512) ->

6 x (Conv2d -> CBAM -> PReLU -> Conv2d -> CBAM -> PReLU -> MaxPool) ->

Flatten the feature maps to serve as inputs to DNN (5880 inputs) ->
DNN (output 4)

**DNN Classification**
L1 (5880) ->
ReLU ->
L2 (1024) ->
ReLU ->
L3 (512) ->
ReLU ->
L4 (256) ->
ReLU ->
L5 (128) ->
ReLU ->
L6 (64) ->
ReLU ->
L7 (32) ->
ReLU ->
L8 (4) ->
Softmax ->

Output: Classification probabilities (4 x 1)

# Notes
There are two architectures for this task. Both of them are almost the same but one does not use the attention mechanism. This provides a baseline to determine if the attention mechanism is effective.

In [3]:
class channel_attention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=4):
        super(channel_attention, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # We can use Conv2d instead of FCLs to simplify the operation and avoid having to flatten the layers.
        # The operation is essentially the same as in the CBAM paper but applied in a different way.
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, math.ceil(in_channels / reduction_ratio), 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(math.ceil(in_channels / reduction_ratio), in_channels, 1, bias=False)
        )

    def forward(self, x):
        # As the network is trained, the channels (feature maps) that should be paid more attention to become more pronounced.
        # Example avg_o and max_o ==> input: [batch_size, 6, 512, 512] -> 2 x [batch_size, 6, 1, 1] -> 2 x [batch_size, 2, 1, 1] -> 2 x [batch_size, 6, 1, 1]
        avg_o = self.fc(self.avg_pool(x))
        max_o = self.fc(self.max_pool(x))
        # Here just add the two channel attentions and put it through a sigmoid function.
        # This will give the attention score for each channel.
        out = torch.sigmoid(avg_o + max_o)
        return out


class spatial_attention(nn.Module):
    def __init__(self, kernel_size=7):
        super(spatial_attention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)

    def forward(self, x):

        # Compress across the channel dimension by getting the average and max of all values across that dimension.
        # input: (batch_size, #channels, height, width) -> output: (batch_size, 1, height, width)
        avg_map = torch.mean(x, dim=1, keepdim=True)
        max_map, thr = torch.max(x, dim=1, keepdim=True)

        # Concat the two maps.
        # input: 2 x (batch_size, 1, height, width) -> output: (batch_size, 2, height, width)
        x = torch.cat([avg_map, max_map], dim=1)

        x = self.conv(x)
        out = torch.sigmoid(x)
        return out

class CBAM(nn.Module):
    def __init__(self, in_channels, reduction_ratio=4,sa_kernel_size=7):
        super(CBAM, self).__init__()
        self.channel = channel_attention(in_channels, reduction_ratio)
        self.spatial = spatial_attention(sa_kernel_size)
    def forward(self, x):
        x = x * self.channel(x)
        x = x * self.spatial(x)
        return x

class CNN_Attention(nn.Module):
    def __init__(self):
        super(CNN_Attention, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3),
            nn.BatchNorm2d(6),
            CBAM(6),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=12, kernel_size=3),
            nn.BatchNorm2d(12),
            CBAM(12),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3),
            nn.BatchNorm2d(24),
            CBAM(24),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(in_channels=24, out_channels=48, kernel_size=3),
            nn.BatchNorm2d(48),
            CBAM(48),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(in_channels=48, out_channels=96, kernel_size=3),
            nn.BatchNorm2d(96),
            CBAM(96),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.block6 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=192, kernel_size=3),
            nn.BatchNorm2d(192),
            CBAM(192),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x

# CNN without CBAM
class CNN_NoAttention(nn.Module):
    def __init__(self):
        super(CNN_NoAttention, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=12, kernel_size=3),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(in_channels=24, out_channels=48, kernel_size=3),
            nn.BatchNorm2d(48),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(in_channels=48, out_channels=96, kernel_size=3),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.block6 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=192, kernel_size=3),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x


class DNN(nn.Module):
    def __init__(self):
        super(DNN, self).__init__()
        self.classify = nn.Sequential(
            nn.Linear(18816, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 4)
        )
    def forward(self, x):
        return self.classify(x)

class BrainTumorClassifier_Attention(nn.Module):
    def __init__(self):
        super(BrainTumorClassifier_Attention, self).__init__()
        self.feature_extraction = CNN_Attention()
        self.classification = DNN()
    def forward(self, x):
        features = self.feature_extraction(x)
        flattened_features = features.view(features.size(0), -1)
        classification = self.classification(flattened_features)
        return classification

class BrainTumorClassifier_NoAttention(nn.Module):
    def __init__(self):
        super(BrainTumorClassifier_NoAttention, self).__init__()
        self.feature_extraction = CNN_NoAttention()
        self.classification = DNN()
    def forward(self, x):
        features = self.feature_extraction(x)
        flattened_features = features.view(features.size(0), -1)
        classification = self.classification(flattened_features)
        return classification


In [4]:
attention_select = input("Use attention for classification? (Y|N)")
model = None
if (attention_select == "Y"):
    model = BrainTumorClassifier_Attention()
else:
    model = BrainTumorClassifier_NoAttention()
model.to(torch.float32)

device = torch.device("cpu")
if torch.cuda.is_available:
    torch.cuda.empty_cache()
    device = torch.device("cuda")

model = model.to(device)

dvc = next(model.parameters()).device
print("Model is on device:", dvc)

total_params = sum(p.numel() for p in model.parameters())

print(f"Total number of parameters: {total_params}")

num_parameters = 6736274
size_of_float32 = 4  # 4 bytes for float32
total_memory_bytes = num_parameters * size_of_float32
total_memory_mb = (total_memory_bytes / (2**20))

print(f"Total estimated memory usage: {total_memory_mb:.2f} MB")

Model is on device: cuda:0
Total number of parameters: 41565168
Total estimated memory usage: 25.70 MB


In [5]:
class ImageDataset(Dataset):
    def __init__(self, IO_pairs):
        self.IO_pairs = IO_pairs

    def __len__(self):
        return len(self.IO_pairs)

    def __getitem__(self, index):
        # Get the image
        image_class_name = self.IO_pairs[index][0]
        image_tensor_list = self.IO_pairs[index][1]
        classification_target = self.IO_pairs[index][2]

        return image_class_name, image_tensor_list, classification_target

In [None]:
num_epochs = 20

pt_path = "./content/train/train_datasets/train_dataset1.pt"
dataset = torch.load(pt_path)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
model.train()

class_loss_function = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.03)
#scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=2, eta_min=0.00001)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

correction_threshold = 0.4
max_repetitions = 10

supervise = input("Enable constant supervision? (Y|N)")
superv_bool = False
if (supervise == "Y"):
    superv_bool = True
else:
    superv_bool = False

sum_class_loss = 0
total_backwards = 0
previous_avg_loss = 0

contin = "Y"
for epoch in range(num_epochs):

    for i, j in enumerate(dataloader):
        
        image_class_name = j[0]
        image_tensor = j[1]
        image_class_tensor = j[2]

        image_tensor = image_tensor.to(torch.float32)
        image_class_tensor = image_class_tensor.to(torch.float32)

        image_tensor = image_tensor.to(device)
        image_class_tensor = image_class_tensor.to(device)
            
        optimizer.zero_grad()
        prediction = model(image_tensor)
        classification_loss = class_loss_function(prediction, image_class_tensor)
        classification_loss.backward()
        optimizer.step()
        sum_class_loss += round(classification_loss.item(), 5)
        total_backwards += 1

        print(f"Epoch [{epoch+1}/{num_epochs}], Progress: [{i+1}/{len(dataloader)}], Class_Loss: {round(classification_loss.item(), 5)}, Learning Rate: {scheduler.get_last_lr()[0]}, Previous Loss: {previous_avg_loss}")

        # Reduce memory to device time overhead
        if (round(classification_loss.item(), 5) > previous_avg_loss or round(classification_loss.item(), 5) > 0.35) and epoch + 1 >= 2:
            for m in range(min(math.ceil(epoch), max_repetitions)):
                optimizer.zero_grad()
                prediction = model(image_tensor)
                classification_loss = class_loss_function(prediction, image_class_tensor)
                classification_loss.backward()
                optimizer.step()
                sum_class_loss += round(classification_loss.item(), 5)
                total_backwards += 1

                print(f"Epoch [{epoch+1}/{num_epochs}], Progress: [{i+1}/{len(dataloader)}, T{m}], Class_Loss: {round(classification_loss.item(), 5)}, Learning Rate: {scheduler.get_last_lr()[0]}, Previous Loss: {previous_avg_loss}")
    
    previous_avg_loss = round(sum_class_loss / total_backwards, 5)
    sum_class_loss = 0
    total_backwards = 0
    
    if ((epoch + 1) % 4 == 0):
        if (superv_bool):
            contin = input("Continue for 4 more epochs? (Y|N)")
            if (contin == "N"):
                break
            else:
                pass
        scheduler.step()
    
    if (contin == "N"):
        break

In [None]:
if not os.path.exists("./content/models"):
    os.mkdir("./content/models")
mytext = input("Enter the model name: ")
torch.save(model, f"./content/models/{mytext}.pt")

In [6]:
eval_select = input("Which model do you want to evaluate?")
attention_select = input("Attention? (Y|N)")

if (attention_select == "Y"):
    loaded_model: BrainTumorClassifier_Attention = torch.load(f"./content/models/{eval_select}.pt")
else:
    loaded_model: BrainTumorClassifier_NoAttention = torch.load(f"./content/models/{eval_select}.pt")

loaded_model.to(torch.float32)
loaded_model.to(device)
loaded_model.eval()

test_set = torch.load("./content/test/test_datasets/test_dataset1.pt")
train_set = torch.load("./content/train/train_datasets/train_dataset1.pt")

test_correct = 0
test_tested = 0

for i in range(test_set.__len__()):

    current_eval = test_set.__getitem__(i)
    inputTensor = current_eval[1]
    inputTensor = inputTensor.to(torch.float32)
    inputTensor = inputTensor.to(device)
    inputTensor = inputTensor.unsqueeze(0)
    
    target = current_eval[2]
    predicted_output = loaded_model(inputTensor)
    if (predicted_output.argmax() == target.argmax()):
        test_correct += 1
    test_tested += 1

    print("Testing Eval: " + str(test_correct / test_tested))

train_correct = 0
train_tested = 0

for i in range(train_set.__len__()):

    current_eval = train_set.__getitem__(i)
    inputTensor = current_eval[1]
    inputTensor = inputTensor.to(torch.float32)
    inputTensor = inputTensor.to(device)
    inputTensor = inputTensor.unsqueeze(0)
    
    target = current_eval[2]
    predicted_output = loaded_model(inputTensor)
    if (predicted_output.argmax() == target.argmax()):
        train_correct += 1
    train_tested += 1

    print("Training Eval: " + str(train_correct / train_tested))


print("\n\nOverall Results: ")
print("Testing Eval: " + str(test_correct / test_tested))
print("Training Eval: " + str(train_correct / train_tested))

Testing Eval: 1.0
Testing Eval: 1.0
Testing Eval: 0.6666666666666666
Testing Eval: 0.5
Testing Eval: 0.6
Testing Eval: 0.6666666666666666
Testing Eval: 0.7142857142857143
Testing Eval: 0.75
Testing Eval: 0.6666666666666666
Testing Eval: 0.7
Testing Eval: 0.6363636363636364
Testing Eval: 0.5833333333333334
Testing Eval: 0.6153846153846154
Testing Eval: 0.6428571428571429
Testing Eval: 0.6666666666666666
Testing Eval: 0.625
Testing Eval: 0.6470588235294118
Testing Eval: 0.6666666666666666
Testing Eval: 0.631578947368421
Testing Eval: 0.6
Testing Eval: 0.5714285714285714
Testing Eval: 0.5909090909090909
Testing Eval: 0.6086956521739131
Testing Eval: 0.625
Testing Eval: 0.64
Testing Eval: 0.6153846153846154
Testing Eval: 0.5925925925925926
Testing Eval: 0.6071428571428571
Testing Eval: 0.6206896551724138
Testing Eval: 0.6
Testing Eval: 0.5806451612903226
Testing Eval: 0.5625
Testing Eval: 0.5757575757575758
Testing Eval: 0.5882352941176471
Testing Eval: 0.6
Testing Eval: 0.5833333333333334