In [1]:
# header files needed
import numpy as np
import torch
import torch.nn as nn
import torchvision

In [None]:
# Attach Google Drive to the Colab filesystem; the dataset paths used later
# in this notebook live underneath this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Seed NumPy, the PyTorch CPU generator and the PyTorch CUDA generator with
# one shared value so the three calls cannot drift apart.
SEED = 1234
for seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(SEED)

In [None]:
# Training-time preprocessing: random rotation of up to 30 degrees, resize to
# 224x224, random horizontal flip, tensor conversion, per-channel normalisation.
# NOTE(review): the mean/std values look like the usual statistics expected by
# the pretrained backbone -- confirm against the weights actually loaded.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

_tv_transforms = torchvision.transforms
_train_steps = [
    _tv_transforms.RandomRotation(30),
    _tv_transforms.Resize((224, 224)),
    _tv_transforms.RandomHorizontalFlip(),
    _tv_transforms.ToTensor(),
    _tv_transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
]
train_transforms = _tv_transforms.Compose(_train_steps)

In [None]:
# get data
DATA_ROOT = "/content/drive/My Drive/"

# Validation images must be preprocessed deterministically. Reusing the
# training pipeline would apply random rotations and flips at evaluation time,
# so the same model could score differently on every pass. This pipeline keeps
# only the resize, tensor conversion and normalisation (same constants as the
# training pipeline) and drops the random augmentations.
val_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_data = torchvision.datasets.ImageFolder(DATA_ROOT + "train_images/", transform=train_transforms)
val_data = torchvision.datasets.ImageFolder(DATA_ROOT + "val_images/", transform=val_transforms)

# ImageFolder derives label indices from the sub-folder names of each root
# independently; if the two trees differ, the same integer label would mean
# different classes in training and validation.
if train_data.class_to_idx != val_data.class_to_idx:
    raise ValueError(
        "train and validation folders define different classes: "
        + str(train_data.class_to_idx) + " vs " + str(val_data.class_to_idx)
    )

print(len(train_data))
print(len(val_data))

In [None]:
# data loaders
BATCH_SIZE = 32
NUM_WORKERS = 16  # NOTE(review): may exceed the CPU cores available on the runtime -- confirm

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Validation is iterated in a fixed order: shuffling adds nothing to the
# metrics, changes which samples land in the (possibly smaller) last batch,
# and draws from the global random generator on every pass.
val_loader = torch.utils.data.DataLoader(
    dataset=val_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [None]:
# model class
class VGG16_CBAM(torch.nn.Module):
  """VGG16-BN style backbone with an attention gate after every conv stage.

  The convolutional layers are taken from ``model.features`` (the modules are
  reused, not copied) and regrouped into 13 conv/BN/ReLU stages. After each
  stage the activations are re-weighted by:

  1. a channel gate: sigmoid of (global max pool + global average pool),
     one scalar per channel -- no MLP is applied to the pooled values;
  2. a spatial gate: the channel-wise max and mean maps are concatenated and
     passed through a 7x7 conv, batch norm and sigmoid.

  The spatial gate is a single module shared by all 13 stages. The backbone's
  own max-pool layers are skipped when slicing and replaced by ``self.pool``.

  Args:
    model: object exposing ``features``, an iterable module whose children are
      laid out like torchvision's ``vgg16_bn().features`` (44 children, with
      max-pools at indices 6, 13, 23, 33 and 43). The last conv stage must
      output 512 channels so that the flattened 7x7 map has 25088 features.
    num_classes: number of output logits produced by the classifier head.
  """

  # (start, stop) index pairs into ``model.features`` for each conv/BN/ReLU
  # stage. Indices 6, 13, 23, 33 and 43 (the max-pool layers) are left out.
  _STAGE_SLICES = (
      (0, 3), (3, 6),
      (7, 10), (10, 13),
      (14, 17), (17, 20), (20, 23),
      (24, 27), (27, 30), (30, 33),
      (34, 37), (37, 40), (40, 43),
  )

  # 1-based stage numbers that are followed by the 2x2 max pool.
  _POOL_AFTER = frozenset((2, 4, 7, 10, 13))

  # init function
  def __init__(self, model, num_classes=2):
    super().__init__()

    # pool layer (stands in for the max-pools skipped by _STAGE_SLICES)
    self.pool = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=2, stride=2))

    # spatial attention: 2 input maps (channel max, channel mean) -> 1 gate map
    self.spatial_attention = torch.nn.Sequential(
        torch.nn.Conv2d(2, 1, kernel_size=7, padding=3, stride=1),
        torch.nn.BatchNorm2d(1),
        torch.nn.Sigmoid()
    )

    # features: registered as features_1 ... features_13 so parameter names in
    # the state dict stay the same as before.
    backbone_layers = list(model.features.children())
    for stage, (start, stop) in enumerate(self._STAGE_SLICES, start=1):
      setattr(self, "features_" + str(stage), torch.nn.Sequential(*backbone_layers[start:stop]))

    self.avgpool = nn.AdaptiveAvgPool2d(7)

    # classifier: the output width follows num_classes instead of being fixed.
    self.classifier = torch.nn.Sequential(
        torch.nn.Linear(25088, 4096),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(),
        torch.nn.Linear(4096, 4096),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(),
        torch.nn.Linear(4096, num_classes)
    )

  def _attend(self, x):
    """Apply the channel gate followed by the shared spatial gate to ``x``."""
    # Global pooling to 1x1 per channel. Adaptive pooling gives the same result
    # as the former fixed-size kernels on 224x224 inputs and also handles other
    # input resolutions.
    pooled = (torch.nn.functional.adaptive_max_pool2d(x, 1)
              + torch.nn.functional.adaptive_avg_pool2d(x, 1))
    x = x * torch.sigmoid(pooled)  # (N, C, 1, 1) gate broadcasts over H and W

    # Per-pixel max and mean over the channel dimension, stacked as 2 maps.
    descriptor = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
    return x * self.spatial_attention(descriptor)

  # forward
  def forward(self, x):
    """Return class logits of shape (batch, num_classes) for an image batch ``x``."""
    for stage in range(1, len(self._STAGE_SLICES) + 1):
      x = getattr(self, "features_" + str(stage))(x)
      x = self._attend(x)
      if stage in self._POOL_AFTER:
        x = self.pool(x)

    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    return self.classifier(x)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Wrap a VGG16-BN loaded with pretrained weights in the attention model
# (two output classes) and move it to the selected device.
# NOTE(review): newer torchvision releases prefer a `weights=` argument over
# `pretrained=True` -- confirm against the installed version.
pretrained_model = torchvision.models.vgg16_bn(pretrained=True)
model = VGG16_CBAM(pretrained_model, num_classes=2).to(device)
print(model)

In [None]:
# Loss function: cross-entropy over the class logits, averaged over the batch
# (the default reduction, spelled out here).
criterion = nn.CrossEntropyLoss(reduction="mean")

In [None]:
# Optimizer: SGD with momentum and weight decay over every model parameter.
SGD_SETTINGS = {"lr": 1e-3, "momentum": 0.9, "weight_decay": 5e-4}
optimizer = torch.optim.SGD(model.parameters(), **SGD_SETTINGS)

In [None]:
# Training configuration.
NUM_EPOCHS = 30
CHECKPOINT_MIN_EPOCH = 10  # no checkpoint is written before this epoch index
BEST_MODEL_PATH = "best_model.pth"

# Per-epoch history. Accuracies are stored as floats (percentages).
train_losses = []
train_acc = []
val_losses = []
val_acc = []
best_metric = -1
best_metric_epoch = -1


def _run_epoch(network, loader, loss_fn, run_device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given the network is put in training mode and its
    weights are updated after every batch. Without it the network is put in
    eval mode and gradients are disabled.

    The loss is averaged over samples rather than over batches, so a smaller
    final batch does not get extra weight.

    Raises:
        ValueError: if the loader yields no samples.
    """
    is_training = optimizer is not None
    network.train(is_training)

    loss_sum = 0.0
    num_correct = 0
    num_samples = 0
    with torch.set_grad_enabled(is_training):
        for images, labels in loader:
            images = images.to(run_device)
            labels = labels.to(run_device)

            if is_training:
                optimizer.zero_grad()
            logits = network(images)
            loss = loss_fn(logits, labels)
            if is_training:
                loss.backward()
                optimizer.step()

            batch_count = labels.size(0)
            # loss_fn returns the batch mean; scale back to a per-sample sum.
            loss_sum += loss.item() * batch_count
            _, predicted = logits.max(1)
            num_correct += predicted.eq(labels).sum().item()
            num_samples += batch_count

    if num_samples == 0:
        raise ValueError("data loader produced no samples")
    return loss_sum / float(num_samples), 100.0 * (float(num_correct) / float(num_samples))


# train and validate
for epoch in range(NUM_EPOCHS):

    # train
    training_loss, training_accuracy = _run_epoch(model, train_loader, criterion, device, optimizer)
    train_losses.append(training_loss)
    train_acc.append(training_accuracy)

    # validate
    valid_loss, valid_accuracy = _run_epoch(model, val_loader, criterion, device)
    val_losses.append(valid_loss)
    val_acc.append(valid_accuracy)

    # store best model (by validation accuracy, once past the minimum epoch)
    if valid_accuracy > best_metric and epoch >= CHECKPOINT_MIN_EPOCH:
        best_metric = valid_accuracy
        best_metric_epoch = epoch
        torch.save(model.state_dict(), BEST_MODEL_PATH)

    print()
    print("Epoch" + str(epoch) + ":")
    print("Training Accuracy: " + str(training_accuracy) + "    Validation Accuracy: " + str(valid_accuracy))
    print("Training Loss: " + str(training_loss) + "    Validation Loss: " + str(valid_loss))
    print()

In [None]:
import matplotlib.pyplot as plt

# Plot the training loss recorded for each epoch. The x axis is derived from
# the history itself, so the plot still works when training ran for a
# different number of epochs (or was interrupted) instead of failing on a
# length mismatch.
e = list(range(len(train_losses)))

fig, ax = plt.subplots()
ax.plot(e, train_losses)
ax.set_xlabel("Epoch")
ax.set_ylabel("Training loss")
ax.set_title("Training loss per epoch")
plt.show()