In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import copy
import pandas as pd 
import torch.nn.init as init

In [None]:
!git clone https://github.com/Ashish-Arya-CS/Coursera-Content.git

Cloning into 'Coursera-Content'...
remote: Enumerating objects: 3049, done.[K
remote: Counting objects: 100% (6/6), done.[K
remote: Compressing objects: 100% (4/4), done.[K
remote: Total 3049 (delta 1), reused 6 (delta 1), pack-reused 3043[K
Receiving objects: 100% (3049/3049), 79.25 MiB | 20.72 MiB/s, done.
Resolving deltas: 100% (1/1), done.
Checking out files: 100% (3287/3287), done.


In [None]:
# Locations of the two image splits inside the cloned repository.
TRAIN_ROOT = '/content/Coursera-Content/Brain-MRI/Training/'
TEST_ROOT = '/content/Coursera-Content/Brain-MRI/Testing/'

# Untransformed datasets (no `transform` is passed here); one ImageFolder per split.
train_dataset, test_dataset = (
    torchvision.datasets.ImageFolder(root=split_root)
    for split_root in (TRAIN_ROOT, TEST_ROOT)
)

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
class CNNModel(nn.Module):
    """VGG16 classifier whose last fully-connected layer is replaced for this task.

    Args:
        num_classes: Number of output logits of the replaced final layer.
            Defaults to 4, the value that was previously hard-coded.
    """

    def __init__(self, num_classes=4):
        super(CNNModel, self).__init__()
        # NOTE(review): newer torchvision releases reportedly prefer `weights=...`
        # over `pretrained=True` -- confirm against the installed version.
        self.vgg16 = models.vgg16(pretrained=True)

        # Swap the output layer (classifier[6]) for one with `num_classes` outputs,
        # keeping the same number of input features.
        in_feats = self.vgg16.classifier[6].in_features
        self.vgg16.classifier[6] = nn.Linear(in_feats, num_classes)

    def forward(self, x):
        """Return the logits produced by the wrapped VGG16 for the batch `x`."""
        return self.vgg16(x)
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the VGG16-based classifier and move its parameters to the chosen device.
model = CNNModel()
model.to(device)
model

# %% Prepare data for pretrained model
# Rebuild both splits with the same preprocessing: resize every image to
# 255x255 and convert it to a tensor.
train_dataset, test_dataset = (
    torchvision.datasets.ImageFolder(
        root=split_root,
        transform=transforms.Compose([
            transforms.Resize((255, 255)),
            transforms.ToTensor(),
        ]),
    )
    for split_root in (TRAIN_ROOT, TEST_ROOT)
)

#train_dataset[0][0].permute(1,2,0)

# %% Create data loaders
batch_size = 32
# One shuffled loader per split, both using the same batch size.
train_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in (train_dataset, test_dataset)
)

# %% Train
cross_entropy_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)
epochs = 10

# Make `epochs` passes over the training loader, one optimizer step per batch.
for epoch in range(epochs):
    for i, batch in enumerate(train_loader):
        # Move the image batch and its labels to the model's device.
        inputs, labels = (tensor.to(device) for tensor in batch)
        optimizer.zero_grad()
        outputs = model(inputs)
        # The loss receives the raw model outputs and the labels as loaded;
        # no explicit one-hot encoding is performed here.
        loss = cross_entropy_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        print(loss)

# %% Inspect predictions for first batch
# Switch to inference mode so train-only layers (e.g. dropout) do not perturb
# the predictions. The model stays in eval mode afterwards; call model.train()
# before resuming training.
model.eval()
inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.numpy()
# No gradients are needed for inspection, so skip building the autograd graph.
with torch.no_grad():
    outputs = model(inputs).max(1).indices.cpu().numpy()
# Fraction of samples in this batch whose predicted class equals the label.
print("Batch accuracy: ", (labels == outputs).mean())
# Side-by-side table of true labels and predicted classes (pandas is imported
# as `pd` in the notebook's import cell).
comparison = pd.DataFrame({"labels": labels, "outputs": outputs})
comparison

# %% Layerwise relevance propagation for VGG16
# For other CNN architectures this code might become more complex
# Source: https://git.tu-berlin.de/gmontavon/lrp-tutorial
# http://iphome.hhi.de/samek/pdf/MonXAI19.pdf

def new_layer(layer, g):
    """Clone a layer and pass its parameters through the function g.

    Args:
        layer: Module to copy; the original object is left untouched.
        g: Callable applied to the copy's `weight` and `bias` tensors; its
            result is wrapped in a fresh `torch.nn.Parameter`.

    Returns:
        A deep copy of `layer` whose existing `weight` / `bias` were replaced
        by `g(weight)` / `g(bias)`. Attributes that are missing (e.g. pooling
        layers) or `None` (e.g. `bias=False`) are left as they are, and errors
        raised inside `g` are no longer silently swallowed.
    """
    layer = copy.deepcopy(layer)
    for name in ("weight", "bias"):
        param = getattr(layer, name, None)
        # Skip parameters the layer does not have instead of turning a missing
        # bias into an empty Parameter.
        if param is not None:
            setattr(layer, name, torch.nn.Parameter(g(param)))
    return layer

def dense_to_conv(layers):
    """Return a copy of `layers` in which every nn.Linear is replaced by a Conv2d.

    A Linear at position 0 is treated as reading a 512-channel 7x7 feature map
    and becomes a 7x7 convolution; every other Linear becomes a 1x1
    convolution. The weights are reshaped from the Linear's weight matrix and
    the bias is reused. Non-Linear entries are passed through unchanged.
    """
    converted = []
    for position, module in enumerate(layers):
        if not isinstance(module, nn.Linear):
            converted.append(module)
            continue

        out_channels = module.weight.shape[0]
        if position == 0:
            # First dense layer: fixed 512 x 7 x 7 input layout.
            in_channels, kernel = 512, 7
        else:
            in_channels, kernel = module.weight.shape[1], 1

        conv = nn.Conv2d(in_channels, out_channels, kernel)
        conv.weight = nn.Parameter(
            module.weight.reshape(out_channels, in_channels, kernel, kernel)
        )
        conv.bias = nn.Parameter(module.bias)
        converted.append(conv)
    return converted

def get_linear_layer_indices(model):
    """Return the positions of the classifier's nn.Linear layers.

    Each position is the layer's index inside `model.vgg16`'s classifier,
    shifted by the number of feature layers plus one.
    """
    vgg_parts = model.vgg16._modules
    offset = len(vgg_parts['features']) + 1
    return [
        offset + position
        for position, module in enumerate(vgg_parts['classifier'])
        if isinstance(module, nn.Linear)
    ]

def apply_lrp_on_vgg16(model, image):
    """Compute layer-wise relevance propagation (LRP) scores for a single image.

    The VGG16 inside `model` is unrolled into a flat list of layers (features,
    avgpool, classifier with dense layers converted to convolutions). The image
    is pushed through that list while every activation is stored, and the
    relevance of the strongest output is then propagated back to the input.

    Args:
        model: Object exposing the network as `model.vgg16` with `features`,
            `avgpool` and `classifier` sub-modules.
        image: A single image tensor without batch dimension; a batch axis is
            added here. The reshape below assumes the avgpool output holds
            512 x 7 x 7 values.

    Returns:
        Tensor of relevance scores with the same shape as the batched input.

    NOTE(review): the forward pass runs the layers as they are, so train-only
    layers such as dropout are presumably active unless the model is in eval
    mode -- confirm at the call site.
    """
    image = torch.unsqueeze(image, 0)
    # >>> Step 1: Extract layers
    layers = list(model.vgg16._modules['features']) \
                + [model.vgg16._modules['avgpool']] \
                + dense_to_conv(list(model.vgg16._modules['classifier']))
    linear_layer_indices = get_linear_layer_indices(model)
    # >>> Step 2: Propagate image through layers and store activations
    n_layers = len(layers)
    activations = [image] + [None] * n_layers # list of activations

    for layer in range(n_layers):
        if layer in linear_layer_indices:
            if layer == 32:
                # The avgpool output was flattened below; restore the spatial
                # layout expected by the first (7x7) converted dense layer.
                activations[layer] = activations[layer].reshape((1, 512, 7, 7))
        activation = layers[layer].forward(activations[layer])
        if isinstance(layers[layer], torch.nn.modules.pooling.AdaptiveAvgPool2d):
            activation = torch.flatten(activation, start_dim=1)
        activations[layer+1] = activation

    # >>> Step 3: Replace last layer with one-hot-encoding
    # Keep only the maximum output and zero all others. The mask is built with
    # the SAME shape as the output activation: a (1, C) tensor would broadcast
    # against the (1, C, 1, 1) conv output and spread relevance over all classes.
    # It also stays on the activation's device, so no global `device` is needed.
    output_activation = activations[-1].detach()
    one_hot_output = torch.where(output_activation == output_activation.max(),
                                 output_activation,
                                 torch.zeros_like(output_activation))
    print(one_hot_output.flatten().tolist())
    activations[-1] = one_hot_output

    # >>> Step 4: Backpropagate relevance scores
    relevances = [None] * n_layers + [activations[-1]]
    # Iterate over the layers in reverse order
    for layer in range(0, n_layers)[::-1]:
        current = layers[layer]
        # Treat max pooling layers as avg pooling
        if isinstance(current, torch.nn.MaxPool2d):
            layers[layer] = torch.nn.AvgPool2d(2)
            current = layers[layer]
        if isinstance(current, torch.nn.Conv2d) or \
           isinstance(current, torch.nn.AvgPool2d) or\
           isinstance(current, torch.nn.Linear):
            # Detach the stored activation and make it a leaf so its gradient
            # can be read after the backward call below.
            activations[layer] = activations[layer].data.requires_grad_(True)

            # Apply variants of LRP depending on the depth
            # see: https://link.springer.com/chapter/10.1007%2F978-3-030-28954-6_10
            if layer <= 16:
                # Lower layers, LRP-gamma >> Favor positive contributions (activations)
                rho = lambda p: p + 0.25*p.clamp(min=0)
                incr = lambda z: z+1e-9
            elif layer <= 30:
                # Middle layers, LRP-epsilon >> Remove some noise / Only most salient factors survive
                rho = lambda p: p
                incr = lambda z: z+1e-9+0.25*((z**2).mean()**.5).data
            else:
                # Upper Layers, LRP-0 >> Basic rule
                rho = lambda p: p
                incr = lambda z: z+1e-9

            # Transform weights of layer and execute forward pass
            z = incr(new_layer(layers[layer],rho).forward(activations[layer]))
            # Element-wise division between relevance of the next layer and z
            s = (relevances[layer+1]/z).data
            # Calculate the gradient and multiply it by the activation
            (z * s).sum().backward()
            c = activations[layer].grad
            # Assign new relevance values
            relevances[layer] = (activations[layer]*c).data
        else:
            # Layers without a propagation rule pass the relevance through unchanged.
            relevances[layer] = relevances[layer+1]

    # >>> Potential Step 5: Apply different propagation rule for pixels
    return relevances[0]

# %%
# Calculate relevances for first image in this test batch


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

tensor(1.4258, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.4514, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3317, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3765, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3213, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3992, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3652, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.2668, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3037, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.2868, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.2960, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.2643, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3329, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.1776, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.1685, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.2132, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.2115, device='cuda:0', grad_fn=

In [None]:
class ResUnit(nn.Module):
    """Residual unit made of three BatchNorm -> ReLU -> Conv stages.

    The stages use 1x1, 3x3 and 1x1 convolutions with `outplanes / 4`
    intermediate channels; `stride` is applied by the 3x3 convolution.
    The input is added back to the result, going through a 1x1 projection
    convolution first whenever the channel count or the stride changes.
    """

    def __init__(self, inplanes, outplanes, stride=1):
        super(ResUnit, self).__init__()
        self.inplanes = inplanes
        self.outplanes = outplanes
        self.stride = stride

        squeezed = int(outplanes/4)  # channel width of the two inner stages
        self.conv1 = nn.Conv2d(inplanes, squeezed, kernel_size=1, bias=True)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv2 = nn.Conv2d(squeezed, squeezed, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(squeezed)
        self.conv3 = nn.Conv2d(squeezed, outplanes, kernel_size=1, bias=True)
        self.bn3 = nn.BatchNorm2d(squeezed)
        self.relu = nn.ReLU(inplace=True)
        # Projection for the shortcut; always created, used only when needed.
        self.make_downsample = nn.Conv2d(inplanes, outplanes, kernel_size=1,
                                         stride=stride, bias=False)

    def forward(self, x):
        """Run the three stages on `x` and add the (possibly projected) input."""
        out = x
        stages = ((self.bn1, self.conv1),
                  (self.bn2, self.conv2),
                  (self.bn3, self.conv3))
        for norm, conv in stages:
            out = conv(self.relu(norm(out)))

        shape_changes = self.stride != 1 or self.inplanes != self.outplanes
        shortcut = self.make_downsample(x) if shape_changes else x

        out += shortcut
        return out


In [None]:
class AttentionModule_stage1(nn.Module):
    """Attention module whose soft-mask branch pools three times.

    After one pre-processing residual unit the input feeds a trunk branch
    (two residual units) and a mask branch. The mask branch max-pools three
    times, then upsamples back to `size3`, `size2` and finally `size1`,
    adding a skip connection at the two intermediate scales. The sigmoid mask
    `m` modulates the trunk as `(1 + m) * trunk`.

    Args:
        inplanes, outplanes: Channel counts passed to every residual unit.
        size1, size2, size3: Target spatial sizes of the three upsampling
            steps, from the largest scale down to the smallest.
    """

    def __init__(self, inplanes, outplanes, size1, size2, size3):
        super(AttentionModule_stage1, self).__init__()
        self.size1, self.size2, self.size3 = size1, size2, size3

        def unit():
            # Every residual unit in this module shares the same channel setup.
            return ResUnit(inplanes, outplanes)

        def pool():
            return nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.Resblock1 = unit()  # pre-processing unit before the split
        self.trunkbranch = nn.Sequential(unit(), unit())

        # Downsampling half of the mask branch.
        self.maxpool1 = pool()
        self.Resblock2 = unit()
        self.skip1 = unit()

        self.maxpool2 = pool()
        self.Resblock3 = unit()
        self.skip2 = unit()

        self.maxpool3 = pool()
        self.Resblock4 = nn.Sequential(unit(), unit())

        # Upsampling half of the mask branch.
        self.Resblock5 = unit()
        self.Resblock6 = unit()

        self.output_block = nn.Sequential(
            nn.BatchNorm2d(outplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(outplanes, outplanes, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(outplanes, outplanes, kernel_size=1, stride=1, bias=False),
            nn.Sigmoid(),
        )

        self.last_block = unit()

    @staticmethod
    def _upsample(feature_map, size):
        """Bilinearly resize `feature_map` to the spatial `size`."""
        return nn.functional.interpolate(feature_map, size=size,
                                         mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return the attention-modulated features for `x`."""
        x = self.Resblock1(x)
        trunk = self.trunkbranch(x)

        # Mask branch, downward path: one residual unit between poolings.
        down1 = self.maxpool1(x)
        enc1 = self.Resblock2(down1)
        lateral1 = self.skip1(down1)

        down2 = self.maxpool2(enc1)
        enc2 = self.Resblock3(down2)
        lateral2 = self.skip2(down2)

        bottom = self.Resblock4(self.maxpool3(enc2))

        # Mask branch, upward path: upsample and merge the skip connections.
        dec2 = self.Resblock5(self._upsample(bottom, self.size3) + lateral2)
        dec1 = self.Resblock6(self._upsample(dec2, self.size2) + lateral1)
        mask = self.output_block(self._upsample(dec1, self.size1))

        return self.last_block((1 + mask) * trunk)

class AttentionModule_stage2(nn.Module):
    """Attention module whose soft-mask branch pools twice.

    After one pre-processing residual unit the input feeds a trunk branch
    (two residual units) and a mask branch. The mask branch max-pools twice,
    then upsamples to `size2` (adding one skip connection) and to `size1`.
    The sigmoid mask `m` modulates the trunk as `(1 + m) * trunk`.

    Args:
        inplanes, outplanes: Channel counts passed to every residual unit.
        size1, size2: Target spatial sizes of the two upsampling steps,
            larger scale first.
    """

    def __init__(self, inplanes, outplanes, size1, size2):
        super(AttentionModule_stage2, self).__init__()
        self.size1, self.size2 = size1, size2

        def unit():
            # Every residual unit in this module shares the same channel setup.
            return ResUnit(inplanes, outplanes)

        def pool():
            return nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.Resblock1 = unit()  # pre-processing unit before the split
        self.trunkbranch = nn.Sequential(unit(), unit())

        # Downsampling half of the mask branch.
        self.maxpool1 = pool()
        self.Resblock2 = unit()
        self.skip1 = unit()
        self.maxpool2 = pool()
        self.Resblock3 = nn.Sequential(unit(), unit())

        # Upsampling half of the mask branch.
        self.Resblock4 = unit()

        self.output_block = nn.Sequential(
            nn.BatchNorm2d(outplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(outplanes, outplanes, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(outplanes, outplanes, kernel_size=1, stride=1, bias=False),
            nn.Sigmoid(),
        )

        self.last_block = unit()

    @staticmethod
    def _upsample(feature_map, size):
        """Bilinearly resize `feature_map` to the spatial `size`."""
        return nn.functional.interpolate(feature_map, size=size,
                                         mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return the attention-modulated features for `x`."""
        x = self.Resblock1(x)
        trunk = self.trunkbranch(x)

        # Mask branch, downward path.
        down1 = self.maxpool1(x)
        enc1 = self.Resblock2(down1)
        lateral1 = self.skip1(down1)

        bottom = self.Resblock3(self.maxpool2(enc1))

        # Mask branch, upward path.
        dec1 = self.Resblock4(self._upsample(bottom, self.size2) + lateral1)
        mask = self.output_block(self._upsample(dec1, self.size1))

        return self.last_block((1 + mask) * trunk)

class AttentionModule_stage3(nn.Module):
    """Attention module whose soft-mask branch pools once.

    After one pre-processing residual unit the input feeds a trunk branch
    (two residual units) and a mask branch. The mask branch max-pools once,
    applies two residual units and upsamples back to `size1`. The sigmoid
    mask `m` modulates the trunk as `(1 + m) * trunk`.

    Args:
        inplanes, outplanes: Channel counts passed to every residual unit.
        size1: Target spatial size of the upsampling step.
    """

    def __init__(self, inplanes, outplanes, size1):
        super(AttentionModule_stage3, self).__init__()
        self.size1 = size1

        def unit():
            # Every residual unit in this module shares the same channel setup.
            return ResUnit(inplanes, outplanes)

        self.Resblock1 = unit()  # pre-processing unit before the split
        self.trunkbranch = nn.Sequential(unit(), unit())

        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.Resblock2 = nn.Sequential(unit(), unit())

        self.output_block = nn.Sequential(
            nn.BatchNorm2d(outplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(outplanes, outplanes, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(outplanes, outplanes, kernel_size=1, stride=1, bias=False),
            nn.Sigmoid(),
        )

        self.last_block = unit()

    def forward(self, x):
        """Return the attention-modulated features for `x`."""
        x = self.Resblock1(x)
        trunk = self.trunkbranch(x)

        # Mask branch: pool, refine, then resize back to `size1`.
        bottom = self.Resblock2(self.maxpool1(x))
        resized = nn.functional.interpolate(bottom, size=self.size1,
                                            mode='bilinear', align_corners=True)
        mask = self.output_block(resized)

        return self.last_block((1 + mask) * trunk)

In [None]:
# Rebuild both splits with 96x96 inputs: resize every image, then convert it
# to a tensor.
train_dataset, test_dataset = (
    torchvision.datasets.ImageFolder(
        root=split_root,
        transform=transforms.Compose([
            transforms.Resize((96, 96)),
            transforms.ToTensor(),
        ]),
    )
    for split_root in (TRAIN_ROOT, TEST_ROOT)
)

# %% Create data loaders
batch_size = 32
# One shuffled loader per split, both using the same batch size.
train_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in (train_dataset, test_dataset)
)

In [None]:
class ResidualAttentionNetwork(nn.Module):
    """Residual attention classifier built from ResUnit and attention modules.

    Layout: stem convolution -> max-pool -> residual unit -> two stage-1
    attention modules -> strided residual unit -> two stage-2 modules ->
    strided residual unit -> two stage-3 modules -> three residual units ->
    BatchNorm/ReLU/6x6 average pool -> linear classifier.

    The upsampling sizes given to the attention modules (48, 24, 12) and the
    6x6 average pool are fixed, which matches the 96x96 inputs noted in
    `forward`.

    Args:
        num_classes: Number of output logits. Defaults to 4, the value that
            was previously hard-coded.
    """

    def __init__(self, num_classes=4):
        super(ResidualAttentionNetwork, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=16,
                                             kernel_size=3, stride=1, padding=1, bias=False),
                                   nn.BatchNorm2d(16),
                                   nn.ReLU(inplace=True))
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.Resblock1 = ResUnit(16, 64)
        self.attention_module1 = AttentionModule_stage1(64, 64, size1=(48,48), size2=(24,24), size3=(12,12)) #(48,48)
        self.attention_module2 = AttentionModule_stage1(64, 64, size1=(48,48), size2=(24,24), size3=(12,12))
        self.Resblock2 = ResUnit(64, 128, 2) #(24, 24)
        self.attention_module3 = AttentionModule_stage2(128, 128, size1=(24,24), size2=(12,12))
        self.attention_module4 = AttentionModule_stage2(128, 128, size1=(24,24), size2=(12,12))
        self.Resblock3 = ResUnit(128, 256, 2)
        self.attention_module5 = AttentionModule_stage3(256, 256, size1=(12,12))
        self.attention_module6 = AttentionModule_stage3(256, 256, size1=(12,12))
        self.Resblock4 = nn.Sequential(ResUnit(256, 512, 2),
                                       ResUnit(512, 512),
                                       ResUnit(512, 512))
        # Attribute name (including its spelling) is kept: it is part of the
        # module's child names and state-dict keys.
        self.Avergepool = nn.Sequential(
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=6, stride=1)
            )

        self.fc1 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the `num_classes` raw logits for the image batch `x`."""
        x = self.conv1(x) # (96,96)
        x = self.maxpool1(x)
        x = self.Resblock1(x)
        x = self.attention_module1(x)
        x = self.attention_module2(x)
        x = self.Resblock2(x)
        x = self.attention_module3(x)
        x = self.attention_module4(x)
        x = self.Resblock3(x)
        x = self.attention_module5(x)
        x = self.attention_module6(x)
        x = self.Resblock4(x)
        x = self.Avergepool(x)
        # Flatten the pooled features to (batch, features) for the classifier.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)

        return x

In [None]:
def weight_init(m):
    """Initialise one module in place; intended for use with `model.apply`.

    - nn.Conv2d / nn.Linear: Xavier-normal weights, standard-normal bias.
    - nn.BatchNorm2d: weights drawn from N(1, 0.02), bias set to 0.

    Parameters that are absent (`bias=False`, `affine=False`) are skipped
    instead of raising. Modules of any other type are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm2d):
        # weight and bias are None when the layer was built with affine=False.
        if m.weight is not None:
            init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)

In [None]:
# Build the attention network on the same `device` the training loop moves
# its batches to (instead of an unconditional .cuda()), then re-initialise
# its weights.
model = ResidualAttentionNetwork()
model.to(device)
model.apply(weight_init)
cross_entropy_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)
epochs = 10

In [None]:
# Make `epochs` passes over the training loader, one optimizer step per batch.
for epoch in range(epochs):
    for i, batch in enumerate(train_loader):
        # Move the image batch and its labels to the model's device.
        inputs, labels = (tensor.to(device) for tensor in batch)
        optimizer.zero_grad()
        outputs = model(inputs)
        # The loss receives the raw model outputs and the labels as loaded;
        # no explicit one-hot encoding is performed here.
        loss = cross_entropy_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        print(epoch, i, loss)

0 0 tensor(2.3488, device='cuda:0', grad_fn=<NllLossBackward0>)
0 1 tensor(1.6703, device='cuda:0', grad_fn=<NllLossBackward0>)
0 2 tensor(1.9890, device='cuda:0', grad_fn=<NllLossBackward0>)
0 3 tensor(1.6983, device='cuda:0', grad_fn=<NllLossBackward0>)
0 4 tensor(1.2592, device='cuda:0', grad_fn=<NllLossBackward0>)
0 5 tensor(1.3517, device='cuda:0', grad_fn=<NllLossBackward0>)
0 6 tensor(1.0038, device='cuda:0', grad_fn=<NllLossBackward0>)
0 7 tensor(0.9811, device='cuda:0', grad_fn=<NllLossBackward0>)
0 8 tensor(1.0988, device='cuda:0', grad_fn=<NllLossBackward0>)
0 9 tensor(1.1568, device='cuda:0', grad_fn=<NllLossBackward0>)
0 10 tensor(1.0287, device='cuda:0', grad_fn=<NllLossBackward0>)
0 11 tensor(1.2046, device='cuda:0', grad_fn=<NllLossBackward0>)
0 12 tensor(1.1762, device='cuda:0', grad_fn=<NllLossBackward0>)
0 13 tensor(0.9319, device='cuda:0', grad_fn=<NllLossBackward0>)
0 14 tensor(0.7525, device='cuda:0', grad_fn=<NllLossBackward0>)
0 15 tensor(0.9290, device='cuda:0'

In [None]:
class GradCam:
    """Grad-CAM style heat-map overlay for a model with named child modules.

    The model is run child by child in `named_children()` order. The output
    of the child named 'attention_module6' is used as the feature map, and
    the input is flattened before the child named 'fc1'.

    NOTE(review): `cv2` (OpenCV) is used in `__call__` but is not imported in
    the visible part of this notebook -- make sure it is imported first.
    """

    def __init__(self, model):
        # The model is switched to eval mode; `model.eval()` returns the model.
        self.model = model.eval()
        self.feature = None   # feature map captured during the forward pass
        self.gradient = None  # gradient of the score w.r.t. that feature map

    def save_gradient(self, grad):
        """Tensor hook: remember the gradient flowing into the feature map."""
        self.gradient = grad

    def __call__(self, x):
        """Return one heat-map overlay per image of the batch `x`.

        Args:
            x: Image batch; the last two axes are used as height and width
                and each image is transposed as (C, H, W) for the overlay.

        Returns:
            Tensor stacking the overlays produced by `transforms.ToTensor()`.
        """
        image_size = (x.size(-1), x.size(-2))  # (width, height) for cv2.resize
        feature_maps = []

        for i in range(x.size(0)):
            # Min-max normalise the image for the overlay.
            img = x[i].data.cpu().numpy()
            img = img - np.min(img)
            if np.max(img) != 0:
                img = img / np.max(img)

            feature = x[i].unsqueeze(0)

            for name, module in self.model.named_children():
                if name == 'fc1':
                    feature = feature.view(feature.size(0), -1)
                feature = module(feature)
                if name == 'attention_module6':
                    feature.register_hook(self.save_gradient)
                    self.feature = feature

            # torch.sigmoid replaces F.sigmoid, which was never imported here.
            classes = torch.sigmoid(feature)
            one_hot, _ = classes.max(dim=-1)
            self.model.zero_grad()
            # Back-propagate the score of the strongest class only. Calling
            # backward() on the full, non-scalar `classes` tensor raises.
            one_hot.sum().backward()

            # Channel weights: gradient averaged over the spatial axes.
            weight = self.gradient.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)

            mask = torch.relu((weight * self.feature).sum(dim=1)).squeeze(0)
            mask = cv2.resize(mask.detach().cpu().numpy(), image_size)
            mask = mask - np.min(mask)

            if np.max(mask) != 0:
                mask = mask / np.max(mask)

            feature_map = np.float32(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET))
            cam = feature_map + np.float32((np.uint8(img.transpose((1, 2, 0)) * 255)))
            cam = cam - np.min(cam)

            if np.max(cam) != 0:
                cam = cam / np.max(cam)

            feature_maps.append(transforms.ToTensor()(cv2.cvtColor(np.uint8(255 * cam), cv2.COLOR_BGR2RGB)))

        feature_maps = torch.stack(feature_maps)

        return feature_maps