# In [None]:
import torch
import torch.nn.functional as F
import time
import logging
import copy
from sklearn.metrics import roc_auc_score, confusion_matrix

# Prefer the first CUDA GPU when PyTorch can see one; fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
class GradCAM:
    """Grad-CAM explainer for a model that takes three image tensors.

    Forward hooks cache each hooked module's output (feature map) and backward
    hooks cache the gradient w.r.t. that output. ``__call__`` combines the two
    for one layer into a per-sample heatmap.

    Args:
        model: network called as ``model(image1, image2, image3)`` that returns
            per-class scores of shape (B, C). It is switched to eval mode here.
        candidate_layers: names (as yielded by ``model.named_modules()``) of the
            modules to hook. ``None`` hooks every module, including the root.
    """

    def __init__(self, model, candidate_layers=None):
        self.model = model
        self.model.eval()
        self.device = next(model.parameters()).device
        self.handlers = []  # hook handles, released by remove_hook()
        self.fmap_pool = {}  # module name -> detached forward output
        self.grad_pool = {}  # module name -> detached gradient w.r.t. output
        self.candidate_layers = candidate_layers  # list of names, or None

        def save_fmaps(key):
            def forward_hook(module, input, output):
                # Skip modules whose output is not a single tensor (e.g. tuples)
                # instead of crashing when every module is hooked.
                if isinstance(output, torch.Tensor):
                    self.fmap_pool[key] = output.detach()

            return forward_hook

        def save_grads(key):
            def backward_hook(module, grad_in, grad_out):
                if grad_out and grad_out[0] is not None:
                    self.grad_pool[key] = grad_out[0].detach()

            return backward_hook

        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                self.handlers.append(module.register_forward_hook(save_fmaps(name)))
                # NOTE(review): register_backward_hook is deprecated in favour of
                # register_full_backward_hook. The latter can raise when a later
                # op modifies the hooked output in place (e.g. ReLU(inplace=True)),
                # so the legacy hook is kept; revisit if the model allows it.
                self.handlers.append(module.register_backward_hook(save_grads(name)))

    def _prepare_ids(self, ids):
        """Coerce class ``ids`` into a (B, k) long index tensor for ``scatter_``.

        A Python int, a 0-d tensor or a 1-D sequence is reshaped to (N, 1). A
        single id is broadcast to every sample in the batch. Tensors that are
        already 2-D are passed through unchanged (apart from dtype/device).
        """
        ids = torch.as_tensor(ids, dtype=torch.long, device=self.nll.device)
        if ids.dim() < 2:
            # scatter_ needs an index with the same rank as the (B, C) scores.
            ids = ids.reshape(-1, 1).expand(self.nll.shape[0], -1).contiguous()
        return ids

    def _encode_one_hot(self, ids):
        """Return a tensor shaped like ``self.nll`` with 1.0 at ``ids`` along dim 1."""
        # zeros_like already matches nll's device and dtype, which is what
        # nll.backward(gradient=...) requires.
        one_hot = torch.zeros_like(self.nll)
        one_hot.scatter_(1, ids, 1.0)
        return one_hot

    def forward(self, image1, image2, image3):
        """Run the model and cache the scores (``nll``) and softmax probabilities."""
        self.image_shape = image1.shape[2:]
        self.nll = self.model(image1, image2, image3)
        self.prob = F.softmax(self.nll, dim=1)

    def backward(self, ids):
        """
        Class-specific backpropagation: backpropagate a one-hot selection of
        ``ids`` from the cached scores, filling the gradient hooks.
        """
        one_hot = self._encode_one_hot(self._prepare_ids(ids))
        self.model.zero_grad()
        # retain_graph keeps the graph alive so the caller may backprop again.
        self.nll.backward(gradient=one_hot, retain_graph=True)

    def remove_hook(self):
        """
        Remove all the forward/backward hook functions
        """
        for handle in self.handlers:
            handle.remove()
        self.handlers.clear()

    def __call__(self, image1, image2, image3, ids=None, target_layer='bn2'):
        """
        Generate Grad-CAM heatmaps.

        Args:
            image1, image2, image3 (torch.Tensor): model inputs. The heatmap is
                resized to the spatial size of ``image1`` (``shape[2:]``).
            ids (int | sequence | torch.Tensor, optional): class id(s) to
                explain. A scalar or 1-D value is reshaped to (B, 1). ``None``
                uses each sample's highest-probability class.
            target_layer (str): name of the hooked module whose feature map and
                gradient are used. Defaults to 'bn2'.
        Returns:
            gcam (torch.Tensor): (B, 1, H, W) heatmap, min-max scaled to [0, 1]
                per sample. A spatially constant map comes out as all zeros.
            ids (torch.Tensor): (B, k) long tensor of the explained class ids.
        Raises:
            KeyError: if nothing was captured for ``target_layer`` (for example,
                it is not in ``candidate_layers``).
        """
        # Forward first so the default class choice uses the current inputs.
        self.forward(image1, image2, image3)
        if ids is None:
            ids = self.prob.argmax(dim=1, keepdim=True)
        ids = self._prepare_ids(ids)
        self.backward(ids)

        # Get hooked activations and gradients.
        if target_layer not in self.fmap_pool or target_layer not in self.grad_pool:
            raise KeyError(
                'No feature map/gradient captured for layer {!r}; '
                'is it in candidate_layers?'.format(target_layer))
        grads_val = self.grad_pool[target_layer]
        fmap = self.fmap_pool[target_layer]
        # Channel weights: spatially averaged gradients.
        weights = F.adaptive_avg_pool2d(grads_val, 1)

        gcam = torch.mul(fmap, weights).sum(dim=1, keepdim=True)
        gcam = F.relu(gcam)
        gcam = F.interpolate(
            gcam, self.image_shape, mode='bilinear', align_corners=False)

        # Per-sample min-max normalisation to [0, 1].
        B, C, H, W = gcam.shape
        gcam = gcam.view(B, -1)
        gcam -= gcam.min(dim=1, keepdim=True)[0]
        # Clamp the denominator so an all-zero map stays 0 instead of NaN.
        gcam /= gcam.max(dim=1, keepdim=True)[0].clamp_min(
            torch.finfo(gcam.dtype).tiny)
        gcam = gcam.view(B, C, H, W)

        return gcam, ids
    
# train function
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_auc = 0.0
    best_epoch = 0
    best_loss = 0.0
    best_cm = None
    best_fpr = None
    best_tpr = None
    best_thresholds = None
    best_gcam = None
    best_gcam_id = None

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-'*10)
        logging.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
        logging.info('-'*10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                scheduler.step()
                model.train() # set model to training mode
            else:
                model.eval() # set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            running_auc = 0.0
            running_cm = None
            running_fpr = None
            running_tpr = None
            running_thresholds = None
            running_gcam = None
            running_gcam_id = None

            # iterate over data
            for inputs1, inputs2, inputs3, labels in dataloaders[phase]:
                inputs1 = inputs1.to(device)
                inputs2 = inputs2.to(device)
                inputs3 = inputs3.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase=='train'):
                    gcam, ids = grad_cam(inputs1, inputs2, inputs3)
                    outputs = model(inputs1, inputs2, inputs3)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # statistics
                running_loss += loss.item() * inputs1.size(0)
                running_corrects += torch.sum(preds == labels.data)
                running_auc += roc_auc_score(labels.data)
                running_cm += confusion_matrix(labels.data, preds)
                running_gcam += gcam
                running_gcam_id += ids