<a href="https://colab.research.google.com/github/anirudh-chakravarthy/MRSCAtt/blob/main/MRSCAtt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MRSCAtt: A Spatio-Channel Attention-Guided Network for Mars Rover Image Classification (CVPRW 2021)

In [None]:
# Mount Google Drive inside the Colab VM; the following cells read the data
# and write checkpoints/results relative to a folder on that Drive.
from google.colab import drive
# force_remount=True remounts even when a mount already exists at this path.
drive.mount("/content/gdrive", force_remount=True)

Mounted at /content/gdrive


In [None]:
%cd gdrive/MyDrive/MARS_ROVER_IMAGE_CLASSIFICATION/
DATA_DIR = 'DATA/msl-images/'
CKPT_DIR = 'checkpoints-Roshan/'
RESULT_FILE = 'Results-Roshan'

In [None]:
import os
import os.path as osp
import cv2
import pandas as pd
import time
import json
import random
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

from collections import OrderedDict
from collections import namedtuple
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import DataParallel
from torchvision import models, transforms, utils
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from IPython.display import display, clear_output
import pdb

# Specifications of the dataset. (256x193 images) (25 different categories)
# IMG_HEIGHT / IMG_WIDTH are the resize target used by the transform pipelines.
IMG_WIDTH = 256
IMG_HEIGHT = 193
# Number of output classes of the classifier.
NUM_CATEGORIES = 25
# Training hyper-parameters; they seed the hyper-parameter grid of the training cell.
BATCH_SIZE = 64
NUM_EPOCHS = 20
LR = 1e-4
RANDOM_SEED = 42

# seeding for reproducibility
# (Python's `random`, NumPy and PyTorch each keep their own generator state.)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True

## Data loading

In [None]:
class MarsRoverDataset(Dataset):
    """Base dataset for the Mars rover image splits.

    Subclasses populate ``root_dir``, ``images``, ``labels`` and ``transform``.
    Each item is a dict with two keys: ``'image'`` (the decoded image, or the
    transformed image when a transform is set) and ``'class'`` (the label).
    """

    def __init__(self):
        self.root_dir = None   # folder holding the images and split listings
        self.transform = None  # optional callable applied to the decoded image
        self.images = None     # image paths relative to root_dir
        self.labels = None     # labels as read from the listing (strings)

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict: ``{'image': ..., 'class': ...}``. When a transform is set the
            image is transformed and the label becomes a ``torch.long`` scalar
            tensor; otherwise the raw array and the raw label are returned.

        Raises:
            OSError: if the image file cannot be read.
        """
        path = osp.join(self.root_dir, self.images[idx])
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside a transform.
        image = cv2.imread(path)
        if image is None:
            raise OSError(f'Could not read image file: {path}')
        # NOTE(review): cv2.imread yields BGR channel order while the network
        # normalises with per-channel ImageNet statistics -- confirm whether a
        # BGR->RGB conversion is wanted here.
        sample = {'image': image, 'class': self.labels[idx]}

        if self.transform:
            sample['image'] = self.transform(sample['image'])
            sample['class'] = torch.tensor(int(sample['class']), dtype=torch.long)
        return sample

    def read_txt(self, split):
        """Parse the ``<split>-calibrated-shuffled.txt`` listing.

        The listing is looked up in ``self.root_dir``; when that is unset the
        module-level ``DATA_DIR`` is used. Every non-blank line must hold an
        image path followed by a label, separated by whitespace.

        Args:
            split: split name used as the file-name prefix (e.g. ``'train'``).

        Returns:
            tuple: ``(image_paths, labels)``, two parallel lists of strings.

        Raises:
            ValueError: if a non-blank line has fewer than two fields.
        """
        base_dir = self.root_dir if self.root_dir is not None else DATA_DIR
        listing = osp.join(base_dir, split + '-calibrated-shuffled.txt')
        imgs = []
        labels = []

        with open(listing, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # tolerate blank lines (e.g. a trailing newline at EOF)
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f'{listing}:{line_no}: expected "<image> <label>", got {line!r}')
                imgs.append(fields[0])
                labels.append(fields[1])

        return imgs, labels

class trainDataset(MarsRoverDataset):
    """Training split, listed in ``train-calibrated-shuffled.txt``."""

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.root_dir = root_dir
        image_paths, image_labels = self.read_txt('train')
        self.images = image_paths
        self.labels = image_labels
        self.transform = transform

class validDataset(MarsRoverDataset):
    """Validation split, listed in ``val-calibrated-shuffled.txt``."""

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.root_dir = root_dir
        image_paths, image_labels = self.read_txt('val')
        self.images = image_paths
        self.labels = image_labels
        self.transform = transform

class testDataset(MarsRoverDataset):
    """Test split, listed in ``test-calibrated-shuffled.txt``."""

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.root_dir = root_dir
        image_paths, image_labels = self.read_txt('test')
        self.images = image_paths
        self.labels = image_labels
        self.transform = transform

In [None]:
# Training pipeline: convert to a tensor, resize to the fixed input size, then
# apply random augmentation (horizontal flip and rotation of up to 72 degrees).
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(72),
])
# Evaluation pipeline: deterministic, no augmentation.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
])

# The training split uses the augmenting pipeline; it was previously built with
# test_transforms, which left train_transforms unused and disabled augmentation.
train_set = trainDataset(root_dir=DATA_DIR, transform=train_transforms)
valid_set = validDataset(root_dir=DATA_DIR, transform=test_transforms)
test_set = testDataset(root_dir=DATA_DIR, transform=test_transforms)

## MRSCAtt network

In [None]:
class ChannelSpatialBlock(nn.Module):
    """Channel-then-spatial attention block in the style of CBAM.

    Based on the Convolutional Block Attention Module described in
    https://arxiv.org/abs/1807.06521. The channel branch squeezes the feature
    map with average and max pooling, passes each through its own two-layer
    bottleneck (Linear -> BatchNorm -> Dropout, twice) and gates the channels
    with the sigmoid of the summed branches. The spatial branch convolves the
    channel-wise mean and max maps into a single-channel sigmoid mask.

    Args:
        in_channels: number of channels C of the input feature map.
        kernel_size: odd kernel size of the spatial-attention convolution.
        ratio: reduction ratio r of the channel bottleneck (hidden size C // r).

    Raises:
        ValueError: if ``kernel_size`` is even, since the spatial mask could
            then not keep the input's height and width.
    """

    def __init__(self, in_channels=2048, kernel_size=7, ratio=8):
        super(ChannelSpatialBlock, self).__init__()
        if kernel_size % 2 == 0:
            raise ValueError(
                f'kernel_size must be odd to preserve spatial size, got {kernel_size}')

        self.in_channels = in_channels
        self.ratio = ratio
        hidden = in_channels // self.ratio

        self.dropout = nn.Dropout(p=0.2)

        self.avg_fc1 = nn.Linear(in_channels, hidden)
        self.avg_fc2 = nn.Linear(hidden, in_channels)
        self.max_fc1 = nn.Linear(in_channels, hidden)
        self.max_fc2 = nn.Linear(hidden, in_channels)
        # "Same" padding derived from the kernel size; a fixed padding of 3 is
        # only correct for the default 7x7 kernel.
        self.conv = nn.Conv2d(
            2, 1, kernel_size=kernel_size, padding=kernel_size // 2)

        self.avg_bn1 = nn.BatchNorm1d(hidden, affine=True)
        self.avg_bn2 = nn.BatchNorm1d(in_channels, affine=True)
        self.max_bn1 = nn.BatchNorm1d(hidden, affine=True)
        self.max_bn2 = nn.BatchNorm1d(in_channels, affine=True)
        self.bn = nn.BatchNorm2d(1, affine=True)

        self.init_weights()

    def init_weights(self):
        """Kaiming-normal weights and zero biases for every Conv2d/Linear layer."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0)

    def _channel_mlp(self, pooled, fc1, bn1, fc2, bn2):
        """Run one pooled descriptor (N, C) through its bottleneck; returns (N, C)."""
        pooled = self.dropout(bn1(fc1(pooled)))     # (N, C/r)
        return self.dropout(bn2(fc2(pooled)))       # (N, C)

    def channel_attention(self, x):
        """Gate the channels of ``x`` (N, C, H, W).

        Returns:
            tuple: ``(scale, x * scale)`` where ``scale`` has shape (N, C, 1, 1).
        """
        avg_pool = torch.mean(x, dim=(-1, -2), keepdim=False)  # (N, C)
        avg_pool = self._channel_mlp(
            avg_pool, self.avg_fc1, self.avg_bn1, self.avg_fc2, self.avg_bn2)

        max_pool, _ = torch.max(x, dim=-1, keepdim=False)         # (N, C, H)
        max_pool, _ = torch.max(max_pool, dim=-1, keepdim=False)  # (N, C)
        max_pool = self._channel_mlp(
            max_pool, self.max_fc1, self.max_bn1, self.max_fc2, self.max_bn2)

        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        scale = torch.sigmoid(avg_pool + max_pool)[:, :, None, None]  # (N, C, 1, 1)
        return scale, x * scale

    def spatial_attention(self, x):
        """Gate the spatial positions of ``x`` (N, C, H, W).

        Returns:
            tuple: ``(mask, x * mask)`` where ``mask`` has shape (N, 1, H, W).
        """
        avg_pool = torch.mean(x, dim=1, keepdim=True)    # (N, 1, H, W)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)  # (N, 1, H, W)

        concat = torch.cat([avg_pool, max_pool], dim=1)  # (N, 2, H, W)
        concat = self.conv(concat)                       # (N, 1, H, W)
        concat = self.bn(concat)
        concat = torch.sigmoid(concat)                   # (N, 1, H, W)

        return concat, x * concat

    def forward(self, x):
        """Apply channel attention followed by spatial attention; shape is preserved."""
        _, attention_feature = self.channel_attention(x)
        _, attention_feature = self.spatial_attention(attention_feature)
        return attention_feature

In [None]:
class MRSCAtt(nn.Module):
    """ResNet-50 classifier with a spatio-channel attention block on the last stage.

    The stem and stages res2-res4 of the pretrained backbone have their
    parameters frozen; res5, the attention block and the new classification
    head remain trainable. Inputs are normalised inside ``forward`` with the
    registered per-channel ``mean``/``std`` buffers.

    NOTE(review): freezing only sets ``requires_grad = False``; BatchNorm
    layers of the frozen stages would still update their running statistics
    while the module is in train mode -- confirm this is intended.
    """

    def __init__(self, num_classes=25):
        super().__init__()
        backbone = models.resnet50(pretrained=True)

        # Stem and pooling layers reused from the backbone.
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu        # 1/2 resolution, 64 channels
        self.maxpool = backbone.maxpool
        self.avgpool = backbone.avgpool

        # Residual stages (output stride, output channels).
        self.res2 = backbone.layer1      # 1/4, 256
        self.res3 = backbone.layer2      # 1/8, 512
        self.res4 = backbone.layer3      # 1/16, 1024
        self.res5 = backbone.layer4      # 1/32, 2048
        self.fc = nn.Linear(2048, num_classes)

        # Spatio-channel attention applied to the res5 feature map.
        self.csb = ChannelSpatialBlock(in_channels=2048)

        # Freeze the stem and the first three residual stages.
        frozen_stages = (self.conv1, self.bn1, self.res2, self.res3, self.res4)
        for weight in (p for stage in frozen_stages for p in stage.parameters()):
            weight.requires_grad = False

        # Per-channel normalisation constants, broadcastable over (N, 3, H, W).
        self.register_buffer(
            'mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer(
            'std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Classify a batch of images.

        Returns:
            tuple: ``(logits, r4, r3, r2)`` -- the class scores followed by the
            feature maps of stages res4, res3 and res2.
        """
        normalised = (x - self.mean) / self.std

        stem = self.maxpool(self.relu(self.bn1(self.conv1(normalised))))  # 1/4, 64
        r2 = self.res2(stem)  # 1/4, 256
        r3 = self.res3(r2)    # 1/8, 512
        r4 = self.res4(r3)    # 1/16, 1024
        r5 = self.res5(r4)    # 1/32, 2048

        attended = self.csb(r5)

        pooled = torch.flatten(self.avgpool(attended), 1)
        logits = self.fc(pooled)

        return logits, r4, r3, r2

In [None]:
# Count the trainable parameters only, i.e. those with requires_grad set.
model = MRSCAtt(NUM_CATEGORIES)
num_params = sum(
    map(torch.numel, filter(lambda p: p.requires_grad, model.parameters())))
print(num_params)

## Training

In [None]:
class RunBuilder():
    """Expands a hyper-parameter grid into one ``Run`` tuple per combination."""

    @staticmethod
    def get_runs(params):
        """Return the Cartesian product of the value lists in ``params``.

        Args:
            params: ordered mapping of hyper-parameter name -> list of values.

        Returns:
            list: ``Run`` namedtuples whose fields are the keys of ``params``,
            one per combination, in ``itertools.product`` order.
        """
        Run = namedtuple('Run', params.keys())
        return [Run(*combo) for combo in product(*params.values())]

class RunManager():
    """Bookkeeping for training runs: per-epoch metrics, checkpoints, result files.

    Typical use: ``begin_run`` -> for each epoch (``begin_epoch`` ->
    ``track_*`` per batch -> ``saveCheckpoints`` -> ``end_epoch``) ->
    ``end_run`` -> ``save``.
    """

    def __init__(self):
        # per-epoch accumulators
        self.epoch_count = 0
        self.epoch_loss = 0
        self.train_epoch_num_correct = 0
        self.epoch_start_time = None
        self.valid_epoch_loss = 0
        self.valid_epoch_num_correct = 0
        self.test_epoch_num_correct = 0

        # per-run state
        self.run_params = None
        self.run_count = 0
        self.run_data = []  # one result row (OrderedDict) per finished epoch
        self.run_start_time = None

        self.network = None
        self.train_loader = None
        self.valid_loader = None
        self.test_loader = None

    def begin_run(self, run, network, train_loader, valid_loader, test_loader, start_epoch=0):
        """Register the objects of a new run.

        Args:
            run: namedtuple of hyper-parameters (must provide ``_asdict``).
            network: the model being trained.
            train_loader, valid_loader, test_loader: loaders of the three splits.
            start_epoch: number of epochs already completed, e.g. when resuming;
                the first ``begin_epoch`` then reports ``start_epoch + 1``.
        """
        self.run_start_time = time.time()

        self.run_params = run
        self.run_count += 1
        # Honour start_epoch (it used to be accepted but silently ignored).
        self.epoch_count = start_epoch

        self.network = network
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

    def end_run(self):
        """Reset the epoch counter at the end of a run."""
        self.epoch_count = 0

    def begin_epoch(self):
        """Advance the epoch counter, clear the accumulators, enter train mode."""
        self.epoch_start_time = time.time()
        self.epoch_count += 1
        self.epoch_loss = 0
        self.valid_epoch_loss = 0
        self.train_epoch_num_correct = 0
        self.valid_epoch_num_correct = 0
        self.test_epoch_num_correct = 0
        self.network.train()

    @staticmethod
    def _ratio(total, count):
        """Return ``total / count``, or 0.0 when ``count`` is zero."""
        return total / count if count else 0.0

    def _epoch_results(self):
        """Build the result row (OrderedDict) for the epoch that just finished."""
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time

        # The tracked losses are per-batch values (mean reduction is the
        # default of nn.CrossEntropyLoss), so they are averaged over the number
        # of batches. Dividing by the number of samples would under-report the
        # loss by roughly the batch size.
        train_loss = self._ratio(self.epoch_loss, len(self.train_loader))
        valid_loss = self._ratio(self.valid_epoch_loss, len(self.valid_loader))
        # Correct-prediction counts are per sample, hence the dataset sizes.
        train_accuracy = self._ratio(
            self.train_epoch_num_correct, len(self.train_loader.dataset))
        valid_accuracy = self._ratio(
            self.valid_epoch_num_correct, len(self.valid_loader.dataset))
        test_accuracy = self._ratio(
            self.test_epoch_num_correct, len(self.test_loader.dataset))

        results = OrderedDict()
        results["run"] = self.run_count
        results["epoch"] = self.epoch_count
        results['train_loss'] = train_loss
        results['valid_loss'] = valid_loss
        results["train_accuracy"] = train_accuracy * 100
        results["valid_accuracy"] = valid_accuracy * 100
        results["test_accuracy"] = test_accuracy * 100
        results['epoch_duration'] = epoch_duration
        results['run_duration'] = run_duration
        results.update(self.run_params._asdict())
        return results

    def end_epoch(self):
        """Record this epoch's metrics and redraw the results table."""
        self.run_data.append(self._epoch_results())
        df = pd.DataFrame.from_dict(self.run_data, orient='columns')

        clear_output(wait=True)
        display(df)

    def track_loss(self, loss):
        """Accumulate a training-batch loss (a scalar tensor)."""
        self.epoch_loss += loss.item()

    def track_valid_loss(self, loss):
        """Accumulate a validation-batch loss (a scalar tensor)."""
        self.valid_epoch_loss += loss.item()

    def track_num_correct(self, train_preds, train_labels):
        """Accumulate the number of correct training predictions of a batch."""
        self.train_epoch_num_correct += self._get_num_correct(train_preds, train_labels)

    def track_valid_stats(self, valid_preds, valid_labels):
        """Accumulate the number of correct validation predictions of a batch."""
        self.valid_epoch_num_correct += self._get_num_correct(valid_preds, valid_labels)

    def track_test_stats(self, test_preds, test_labels):
        """Accumulate the number of correct test predictions of a batch."""
        self.test_epoch_num_correct += self._get_num_correct(test_preds, test_labels)

    @torch.no_grad()
    def _get_num_correct(self, preds, labels):
        """Count rows of ``preds`` whose arg-max along dim 1 equals ``labels``."""
        return preds.argmax(dim=1).eq(labels).sum().item()

    def saveCheckpoints(self, out_dir, loss, optimizer=None):
        """Write ``epoch_<n>.pth`` with model, optimizer and loss to ``out_dir``.

        Args:
            out_dir: target folder; created when missing.
            loss: loss value stored alongside the weights. Tensors are detached
                and moved to the CPU so the file neither keeps an autograd
                history nor requires the original device to be loaded.
            optimizer: optimizer whose state is saved. When omitted, a
                module-level variable named ``optimizer`` is used, which keeps
                existing ``saveCheckpoints(out_dir, loss)`` calls working.

        Raises:
            RuntimeError: if no network has been registered via ``begin_run``
                or no optimizer can be found.
        """
        if self.network is None:
            raise RuntimeError('begin_run() must be called before saveCheckpoints()')
        if optimizer is None:
            # Legacy behaviour: fall back to the notebook-global optimizer.
            optimizer = globals().get('optimizer')
        if optimizer is None:
            raise RuntimeError('no optimizer given and no global `optimizer` defined')

        os.makedirs(out_dir, exist_ok=True)
        ckpt = osp.join(out_dir, 'epoch_' + str(self.epoch_count) + '.pth')
        torch.save({
            'epoch': self.epoch_count,
            # The model registered with this manager, not a same-named global.
            'model_state_dict': self.network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.detach().cpu() if torch.is_tensor(loss) else loss,
            }, ckpt)

    def save(self, fileName):
        """Dump all recorded rows to ``<fileName>.csv`` and ``<fileName>.json``."""
        pd.DataFrame.from_dict(self.run_data, orient='columns').to_csv(f'{fileName}.csv')
        with open(f'{fileName}.json', 'w', encoding='utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii=False, indent=4)

In [None]:
# Hyper-parameter grid: one run is executed per combination of these values.
params = OrderedDict(
    lr=[LR],
    batch_size=[BATCH_SIZE],
    epochs=[NUM_EPOCHS],
)

m = RunManager()

for run in RunBuilder.get_runs(params):
    model = MRSCAtt(NUM_CATEGORIES)
    model = DataParallel(
        model.to(torch.cuda.current_device()),
        device_ids=[torch.cuda.current_device()]
        )
    loss_fn = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), run.lr)

    # NOTE(review): the training loader is not shuffled between epochs --
    # confirm this is intended.
    train_loader = DataLoader(train_set, batch_size=run.batch_size)
    valid_loader = DataLoader(valid_set, batch_size=run.batch_size)
    test_loader = DataLoader(test_set, batch_size=run.batch_size)
    m.begin_run(run, model, train_loader, valid_loader, test_loader)

    for epoch in tqdm(range(run.epochs)):
        # begin_epoch resets the metric accumulators and puts the model in train mode
        m.begin_epoch()
        for batch in train_loader:
            # move each tensor to the GPU once and reuse it
            images, labels = batch['image'].cuda(), batch['class'].cuda()
            train_pred, r4, r3, r2 = model(images)
            train_loss = loss_fn(train_pred, labels)

            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            m.track_loss(train_loss)
            m.track_num_correct(train_pred, labels)
        model.eval()

        # Evaluation needs no gradients: without no_grad every forward pass
        # would record an autograd graph and waste GPU memory.
        with torch.no_grad():
            # validation accuracy stats
            for valid_batch in valid_loader:
                valid_images = valid_batch['image'].cuda()
                valid_labels = valid_batch['class'].cuda()
                valid_pred, _, _, _ = model(valid_images)
                valid_loss = loss_fn(valid_pred, valid_labels)

                m.track_valid_loss(valid_loss)
                m.track_valid_stats(valid_pred, valid_labels)

            # test accuracy stats
            for test_batch in test_loader:
                test_images = test_batch['image'].cuda()
                test_labels = test_batch['class'].cuda()
                test_pred, _, _, _ = model(test_images)
                m.track_test_stats(test_pred, test_labels)

        # The checkpoint stores the loss of the last training batch; detach it
        # so no autograd history is attached to the saved value.
        m.saveCheckpoints(CKPT_DIR, train_loss.detach())
        m.end_epoch()

    m.end_run()
m.save('Results-Roshan-R50-SCAtt-loss-curves')

In [None]:
# Hard-coded per-epoch loss values (15 epochs each) for the loss-curve figure.
train_losses = [
    0.020033, 0.004966, 0.002251, 0.001156, 0.000543,
    0.000284, 0.000173, 0.000131, 0.000062, 0.000043,
    0.000035, 0.000029, 0.000025, 0.000021, 0.000019,
]
valid_losses = [
    0.034559, 0.018036, 0.016703, 0.013259, 0.013580,
    0.012898, 0.014908, 0.013318, 0.013530, 0.013288,
    0.013193, 0.013178, 0.013189, 0.013137, 0.013228,
]
# One curve per series; the x axis is the 1-based epoch number.
for _series, _label in ((train_losses, 'Training loss'),
                        (valid_losses, 'Validation loss')):
    plt.plot(range(1, len(_series) + 1), _series, label=_label)
plt.xlabel('No. of epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

## Evaluation

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
# Location of the checkpoint that is evaluated below.
# NOTE(review): sota_result_file is not read anywhere in the visible code;
# presumably it records which results file the checkpoint belongs to.
sota_result_file = "Results-Roshan-R50-SCAtt-50e-res5freeze-aug-0.2newdropout-newbn.csv"
sota_epoch = "epoch_13.pth" # 1-indexed
sota_checkpoints_folder = "checkpoints-Roshan10"
# Full path of the checkpoint file handed to torch.load.
sota_model_path = osp.join(sota_checkpoints_folder, sota_epoch)

def track_test_stats(test_preds, test_labels):
    """Return how many predictions of a test batch match their labels."""
    batch_hits = get_num_correct(test_preds, test_labels)
    return batch_hits

def get_num_correct(preds, labels):
    """Count the rows of ``preds`` whose arg-max along dim 1 equals ``labels``.

    Returns:
        int: number of matching rows.
    """
    predicted = torch.argmax(preds, dim=1)
    matches = torch.eq(predicted, labels)
    return matches.sum().item()

# Rebuild the network exactly as it was trained (wrapped in DataParallel) so
# that the parameter names stored in the checkpoint match.
model = MRSCAtt(NUM_CATEGORIES)
model = DataParallel(
    model.to(torch.cuda.current_device()),
    device_ids=[torch.cuda.current_device()]
    )
model.eval()

# loading old state
checkpoint = torch.load(sota_model_path)
model.load_state_dict(checkpoint['model_state_dict'])

# test loader
test_set = testDataset(root_dir=DATA_DIR, transform=test_transforms)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)

# inference
num_correct = 0
pred_chunks = []
label_chunks = []
with torch.no_grad():
    for test_batch in test_loader:
        test_images, test_labels = test_batch['image'], test_batch['class']
        test_pred, _, _, _ = model(test_images.cuda())
        num_correct += track_test_stats(test_pred, test_labels.cuda())

        # collect per-batch results on the CPU and concatenate once afterwards
        pred_chunks.append(test_pred.argmax(dim=1).cpu())
        label_chunks.append(test_labels.cpu())

empty = torch.zeros(0, dtype=torch.long, device='cpu')
predlist = torch.cat(pred_chunks) if pred_chunks else empty
lbllist = torch.cat(label_chunks) if label_chunks else empty

# confusion matrix
# Fix the label set so that row/column i always refers to class i. Without it
# the matrix is built over the union of true and predicted labels, which can
# differ from the set of true labels and mislabel the per-class accuracies.
class_ids = np.arange(NUM_CATEGORIES)
conf_mat = confusion_matrix(lbllist.numpy(), predlist.numpy(), labels=class_ids)

# accuracy
test_accuracy = 100 * num_correct / len(test_loader.dataset)
support = conf_mat.sum(1)          # number of true samples per class
has_support = support > 0          # classes that occur in the test labels
# Per-class accuracy; classes without test samples stay NaN instead of
# triggering a division by zero.
class_accuracy_values = np.full(NUM_CATEGORIES, np.nan)
class_accuracy_values[has_support] = (
    100 * conf_mat.diagonal()[has_support] / support[has_support])
class_accuracy = {
    f'Class {k}': class_accuracy_values[k] for k in class_ids[has_support]}
print(f"Class Accuracy: {class_accuracy}")
print(f"Test Accuracy: {test_accuracy}")