In [None]:
import csv
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm

In [None]:
def mask_addition(base_img, mask):
    """Blend a colourised heat-map onto an image and return it as uint8.

    The mask is scaled to 0..255 and passed through OpenCV's JET colormap,
    then rescaled to floats in [0, 1] and added to `base_img`. The sum is
    divided by its own maximum, so the brightest pixel maps to 255.

    Args:
        base_img: float image of shape (H, W, 3), presumably in [0, 1] -- TODO confirm.
        mask: float map of shape (H, W), expected to lie in [0, 1].

    Returns:
        The blended image as a uint8 array.
    """
    colour_map = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heat = np.float32(colour_map) / 255
    blend = heat + np.float32(base_img)
    peak = np.max(blend)
    return np.uint8(255 * (blend / peak))

In [None]:
def cam_data_generator(image, vgg_map, resnet_map, densenet_map, c, out_dir='grad_cam'):
    """Save a 4-panel figure: the input X-ray plus the three Grad-CAM overlays.

    Args:
        image: normalized image tensor; only item 0 of the batch is drawn
            (presumably shaped (1, 3, H, W) -- TODO confirm against callers).
        vgg_map, resnet_map, densenet_map: per-model CAMs in [0, 1] with the
            same spatial size as `image`.
        c: filename prefix; the figure is written to
            ``<out_dir>/<c>grad_cam.png``.
        out_dir: output directory, created if missing. The default keeps the
            previously hard-coded ``'grad_cam'`` location.
    """
    # Undo the per-channel normalization so the X-ray is displayable as RGB.
    image = image.detach().cpu().numpy()
    image = np.squeeze(np.transpose(image[0], (1, 2, 0)))
    image = image * np.array((0.229, 0.224, 0.225)) + \
        np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    grad_models = {
        'X-Ray Image': image,
        'VGG-16': mask_addition(image, vgg_map),
        'ResNet-18': mask_addition(image, resnet_map),
        'DenseNet-121': mask_addition(image, densenet_map)
    }

    # matplotlib >= 3.6 renamed the seaborn styles to 'seaborn-v0_8-*', and
    # 3.8 removed the old names. Pick whichever this installation provides.
    style = 'seaborn-v0_8-notebook'
    if style not in plt.style.available:
        style = 'seaborn-notebook'
    plt.style.use(style)

    fig = plt.figure(figsize=(20, 4))
    try:
        for i, (name, img) in enumerate(grad_models.items()):
            ax = fig.add_subplot(1, 4, i + 1, xticks=[], yticks=[])
            if i:
                # Overlays come from an OpenCV colormap (BGR); reverse channels for matplotlib.
                img = img[:, :, ::-1]
            ax.imshow(img)
            ax.set_xlabel(name, fontweight='bold')

        fig.suptitle(
            'grad_cl_ac_map Comparison',
            fontweight='bold', fontsize=18
        )
        fig.tight_layout()
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, c + 'grad_cam.png'))
    finally:
        # This is called once per image, so every unclosed figure is leaked memory.
        plt.close(fig)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Root folder that holds the train/ val/ test/ ImageFolder trees. Set the
# GRADCAM_DATA_ROOT environment variable to point elsewhere rather than
# editing this machine-specific default.
DATA_ROOT = os.environ.get(
    'GRADCAM_DATA_ROOT', 'C:/Users/Hemanth/Desktop/Pre-trained with GradCAM'
)
SPLITS = ('train', 'val', 'test')
dirs = {split: f'{DATA_ROOT}/{split}' for split in SPLITS}

# Preprocessing: fixed resize, random horizontal flip for training only, then
# per-channel normalization with the standard ImageNet statistics
# (cam_data_generator inverts these same values for display).
IMG_SIZE = (128, 128)
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]


def _make_transform(augment):
    """Build the preprocessing pipeline; `augment` inserts a random horizontal flip."""
    steps = [transforms.Resize(IMG_SIZE)]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=NORM_MEAN, std=NORM_STD))
    return transforms.Compose(steps)


transform = {split: _make_transform(augment=(split == 'train')) for split in SPLITS}

In [None]:
def worker_funtion(id):
    """DataLoader ``worker_init_fn`` that gives each worker its own NumPy stream.

    Per the torch.utils.data docs, a worker's ``torch.initial_seed()`` is the
    loader's base seed plus the worker id. The id is subtracted back out, and
    the (id, base seed) pair is mixed through ``np.random.SeedSequence`` to
    seed NumPy's legacy global generator.

    Args:
        id: the worker index supplied by the DataLoader.
    """
    base_seed = torch.initial_seed() - id
    entropy = np.random.SeedSequence([id, base_seed])
    np.random.seed(entropy.generate_state(4))

In [None]:
def tot_correct(preds, labels):
    """Count rows whose highest-scoring class (along dim 1) equals the label.

    Args:
        preds: score tensor with classes along dimension 1.
        labels: tensor of integer class indices, one per row of `preds`.

    Returns:
        int: the number of correct predictions.
    """
    predicted = torch.argmax(preds, dim=1)
    hits = predicted == labels
    return hits.sum().item()


def tot_preds(model, loader):
    """Run `model` over every batch of `loader` and stack the raw outputs.

    The model is switched to eval mode and autograd is disabled. Each batch is
    indexed with ``batch[0]``, and only that element is fed to the model.

    Returns:
        A tensor on `device` whose rows are the model outputs, in loader
        order. If the loader yields no batches, returns an empty 1-D tensor.
    """
    model.eval()
    outputs = []
    with torch.no_grad():
        for batch in loader:
            outputs.append(model(batch[0].to(device)))

    if not outputs:
        return torch.tensor([], device=device)
    # Concatenate once at the end. Growing an accumulator with torch.cat on
    # every batch re-copies everything seen so far (quadratic work).
    return torch.cat(outputs, dim=0)

In [None]:
def _train_one_epoch(model, criterion, optimizer, train_loop):
    """One optimisation pass over the tqdm-wrapped loader; returns (loss_sum, n_correct, n_seen)."""
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for batch in train_loop:
        images, labels = batch[0].to(device), batch[1].to(device)
        preds = model(images)
        loss = criterion(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Assumes `criterion` returns a per-batch mean, so weight it back into a sum.
        loss_sum += loss.item() * batch_size
        n_correct += tot_correct(preds, labels)
        n_seen += batch_size
        train_loop.set_postfix(loss=loss.item(), acc=n_correct / n_seen)
    return loss_sum, n_correct, n_seen


def _evaluate(model, criterion, loader):
    """Gradient-free pass over `loader`; returns (loss_sum, n_correct, n_seen)."""
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for batch in loader:
            images, labels = batch[0].to(device), batch[1].to(device)
            preds = model(images)
            batch_size = labels.size(0)
            loss_sum += criterion(preds, labels).item() * batch_size
            n_correct += tot_correct(preds, labels)
            n_seen += batch_size
    return loss_sum, n_correct, n_seen


def fit_model(epochs, model, criterion, optimizer, train_dl, valid_dl):
    """Train `model` for `epochs` epochs, checkpointing the best validation loss.

    Each epoch runs a training pass and then a validation pass. Loss and
    accuracy are averaged over the number of samples actually seen, rather
    than a hard-coded dataset size.

    Side effects:
        * Writes ``<model_name>.pth`` (the state dict) whenever the validation
          loss is <= the best so far.
        * Writes ``<model_name>.csv`` at the end, holding per-epoch
          epoch/running_loss/running_acc/val_loss/val_acc rows.
        ``model_name`` is the model's class name lower-cased.
    """
    model_name = type(model).__name__.lower()
    val_loss_min = float('inf')  # np.Inf no longer exists in NumPy 2.x
    fields = [
        'epoch', 'running_loss', 'running_acc', 'val_loss', 'val_acc'
    ]
    rows = []

    for epoch in range(epochs):
        train_loop = tqdm(train_dl)
        train_loop.set_description(f'Epoch [{epoch+1:2d}/{epochs}]')

        loss_sum, n_correct, n_seen = _train_one_epoch(
            model, criterion, optimizer, train_loop
        )
        n_seen = max(n_seen, 1)  # avoid ZeroDivisionError on an empty loader
        running_loss = loss_sum / n_seen
        running_acc = n_correct / n_seen

        val_sum, val_correct, val_seen = _evaluate(model, criterion, valid_dl)
        val_seen = max(val_seen, 1)
        val_loss = val_sum / val_seen
        val_acc = val_correct / val_seen

        rows.append([epoch, running_loss, running_acc, val_loss, val_acc])

        train_loop.write(
            f'\nTrain loss in this Epoch: {running_loss:.6f}', end='\t')
        train_loop.write(f'Validation loss this Epoch: {val_loss:.6f}\n')

        if val_loss <= val_loss_min:
            train_loop.write('\t\tSaving best model\n')
            torch.save(model.state_dict(), f'{model_name}.pth')
            val_loss_min = val_loss

    # The csv module requires newline=''; without it Windows gets blank rows.
    with open(f'{model_name}.csv', 'w', newline='') as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(fields)
        csv_writer.writerows(rows)

In [None]:
def vgg_loader(pretrained=False, out_features=None, path=None):
    """Build a torchvision VGG-16, optionally re-headed and restored from disk.

    Args:
        pretrained: forwarded unchanged to ``torchvision.models.vgg16``.
        out_features: if given, the classifier's final layer is replaced by a
            fresh ``Linear(4096, out_features)``; earlier classifier layers are kept.
        path: optional state-dict file, loaded (mapped to `device`) after
            re-heading, so its shapes must match the new head.

    Returns:
        The model, moved to `device`.
    """
    net = torchvision.models.vgg16(pretrained=pretrained)

    if out_features is not None:
        kept_layers = list(net.classifier.children())[:-1]
        new_head = torch.nn.Linear(in_features=4096, out_features=out_features)
        net.classifier = torch.nn.Sequential(*kept_layers, new_head)

    if path is not None:
        state = torch.load(path, map_location=device)
        net.load_state_dict(state)

    return net.to(device)


def resnet_loader(pretrained=False, out_features=None, path=None):
    """Build a torchvision ResNet-18, optionally re-headed and restored from disk.

    Args:
        pretrained: forwarded unchanged to ``torchvision.models.resnet18``.
        out_features: if given, ``fc`` is replaced by a fresh ``Linear(512, out_features)``.
        path: optional state-dict file, loaded (mapped to `device`) after re-heading.

    Returns:
        The model, moved to `device`.
    """
    net = torchvision.models.resnet18(pretrained=pretrained)

    if out_features is not None:
        net.fc = torch.nn.Linear(in_features=512, out_features=out_features)

    if path is not None:
        state = torch.load(path, map_location=device)
        net.load_state_dict(state)

    return net.to(device)


def densenet_loader(pretrained=False, out_features=None, path=None):
    """Build a torchvision DenseNet-121, optionally re-headed and restored from disk.

    Args:
        pretrained: forwarded unchanged to ``torchvision.models.densenet121``.
        out_features: if given, ``classifier`` is replaced by a fresh
            ``Linear(1024, out_features)``.
        path: optional state-dict file, loaded (mapped to `device`) after re-heading.

    Returns:
        The model, moved to `device`.
    """
    net = torchvision.models.densenet121(pretrained=pretrained)

    if out_features is not None:
        net.classifier = torch.nn.Linear(in_features=1024, out_features=out_features)

    if path is not None:
        state = torch.load(path, map_location=device)
        net.load_state_dict(state)

    return net.to(device)

In [None]:
dataset_train = datasets.ImageFolder(root=dirs['train'], transform=transform['train'])
val_set = datasets.ImageFolder(root=dirs['val'], transform=transform['val'])

# Inverse-frequency sampling: each training sample is weighted by
# 1 / (number of samples in its class), and len(dataset) samples are drawn
# with replacement per epoch, so classes are balanced in expectation.
train_targets = torch.as_tensor(dataset_train.targets)
class_tot_count = train_targets.bincount()
weight = 1 / class_tot_count
samples_weight = weight[train_targets]
sampler = WeightedRandomSampler(
    samples_weight,
    len(samples_weight),
    replacement=True,
)

In [None]:
# Training batches are drawn through the class-balancing sampler. Per the
# DataLoader docs, worker_init_fn only runs in worker processes, i.e. when
# num_workers > 0.
train_dl = DataLoader(
    dataset_train,
    batch_size=40,
    sampler=sampler,
    num_workers=0,
    worker_init_fn=worker_funtion,
)
valid_dl = DataLoader(val_set, batch_size=40)

epochs = 25
criterion = nn.CrossEntropyLoss()

# The next cell retrains all three models (approx. 6 hours); it is disabled by default.

In [None]:
# Retraining all three backbones takes roughly 6 hours, so it is off by default.
# Set TRAIN_MODELS = True to fine-tune each model from pretrained weights with a
# fresh 2-class head. fit_model writes <name>.pth (best validation loss) and
# <name>.csv (per-epoch metrics) for each one.
TRAIN_MODELS = False

if TRAIN_MODELS:
    for build_model in (vgg_loader, resnet_loader, densenet_loader):
        model = build_model(pretrained=True, out_features=2)
        fit_model(
            epochs=epochs,
            model=model,
            criterion=criterion,
            optimizer=optim.Adam(model.parameters(), lr=3e-5),
            train_dl=train_dl,
            valid_dl=valid_dl
        )

In [None]:
# Restore the fine-tuned two-class models, presumably from the checkpoints
# written by fit_model (named after the lower-cased model class).
# The tuple is evaluated left to right, keeping the original construction order.
resnet18, vgg16, densenet121 = (
    resnet_loader(out_features=2, path='resnet.pth'),
    vgg_loader(out_features=2, path='vgg.pth'),
    densenet_loader(out_features=2, path='densenet.pth'),
)

In [None]:
# Evaluation views: the training set is rebuilt with the deterministic test
# transform (no random flip). Note this rebinds dataset_train and train_dl
# from the training cells.
eval_transform = transform['test']
dataset_train = datasets.ImageFolder(root=dirs['train'], transform=eval_transform)
dataset_test = datasets.ImageFolder(root=dirs['test'], transform=eval_transform)
train_dl, test_dl = (
    DataLoader(dataset_train, batch_size=128),
    DataLoader(dataset_test, batch_size=120),
)

# The next cell generates Grad-CAM comparison figures for the test images;
# the saved figures could be reused as a dataset.

In [None]:
class grad_cl_ac_map:
    """Grad-CAM heat-map built from one layer's activations and gradients.

    The forward/backward hooks on `target_layer` exist only while
    ``__call__`` runs and are removed afterwards. Creating many instances
    over the same model therefore does not accumulate hooks. An accumulated
    hook would also keep its instance, and every activation it recorded,
    alive.
    """

    def __init__(self, model, target_layer):
        self.model = model.eval()
        self.target_layer = target_layer
        # Activations and gradients captured by the hooks during the latest call.
        self.fmaps = []
        self.grads = []

    def fl_fmap(self, module, input, output):
        """Forward hook: record the target layer's output."""
        self.fmaps.append(output)

    def grad_vals(self, module, grad_input, grad_output):
        """Backward hook: record the gradient w.r.t. the target layer's output."""
        self.grads.append(grad_output[0])

    def get_cam_weights(self, grads):
        """Per-channel weights: `grads` averaged over its last two (spatial) axes."""
        return np.mean(grads, axis=(1, 2))

    @staticmethod
    def _normalize(cam):
        """Shift `cam` to min 0 and scale it to max 1; a constant map becomes all zeros, not NaN."""
        cam = cam - np.min(cam)
        peak = np.max(cam)
        return cam / peak if peak > 0 else cam

    def __call__(self, image, label=None):
        """Compute the CAM for item 0 of `image`.

        Args:
            image: input batch tensor; ``label.item()`` on the argmax implies
                a batch of one when `label` is None.
            label: class index to explain; defaults to the predicted class.

        Returns:
            float32 array resized to the image's (H, W), scaled to [0, 1].
        """
        self.fmaps.clear()
        self.grads.clear()
        handles = [
            self.target_layer.register_forward_hook(self.fl_fmap),
            # Legacy module backward hook, kept deliberately.
            # NOTE(review): register_full_backward_hook rejects in-place ops on
            # the hooked output, which torchvision's DenseNet appears to apply
            # (an in-place ReLU after `features`). Confirm before switching.
            self.target_layer.register_backward_hook(self.grad_vals),
        ]
        try:
            preds = self.model(image)
            self.model.zero_grad()

            if label is None:
                label = preds.argmax(dim=1).item()

            preds[:, label].backward()
        finally:
            for handle in handles:
                handle.remove()

        fmaps = self.fmaps[-1].detach().cpu().numpy()[0, :]
        grads = self.grads[-1].detach().cpu().numpy()[0, :]

        cam_weights = self.get_cam_weights(grads)
        cam = np.zeros(fmaps.shape[1:], dtype=np.float32)
        # Weighted sum of feature-map channels.
        for weight, fmap in zip(cam_weights, fmaps):
            cam += weight * fmap

        cam = np.maximum(cam, 0)
        # cv2.resize takes dsize as (width, height).
        cam = cv2.resize(cam, image.shape[-2:][::-1])
        return self._normalize(cam)

# Build one Grad-CAM helper per model up front rather than one per image, so
# hooks are never re-registered on the same layers for every image.
vgg_gcam = grad_cl_ac_map(model=vgg16, target_layer=vgg16.features[-1])
resnet_gcam = grad_cl_ac_map(model=resnet18, target_layer=resnet18.layer4[-1])
densenet_gcam = grad_cl_ac_map(model=densenet121, target_layer=densenet121.features[-1])

# cam_data_generator saves its figures under grad_cam/.
os.makedirs('grad_cam', exist_ok=True)

for i in range(len(dataset_test)):
    image, label = dataset_test[i]
    image = image.unsqueeze(dim=0).to(device)

    vgg_map = vgg_gcam(image, label)
    resnet_map = resnet_gcam(image, label)
    densenet_map = densenet_gcam(image, label)

    # Label 0 is treated as COVID, anything else as Normal.
    class_tag = "COVID" if label == 0 else "Normal"
    c = f"{class_tag}_{i}_"

    cam_data_generator(image, vgg_map, resnet_map, densenet_map, c)
    plt.close('all')  # release this image's figure before the next one