# Trash

In [None]:
# IPython shell magic: install nibabel, which provides `nib.load` for the .nii.gz volumes read below.
!pip install nibabel 

In [None]:
class Config:
    """Static experiment settings: network, data geometry, data paths and training.

    Every setting is a plain instance attribute assigned in ``__init__``.

    NOTE: ``data_path_tr`` is consumed as a string prefix
    (``str(data_path_tr) + str(study_id) + '/'``), so the ``'.../patient00'``
    prefix only yields a well-formed ACDC folder name for single-digit ids.
    """

    def __init__(self):
        # --- CNN graph ------------------------------------------------------
        self.fs = 3  # convolution filter size
        self.interp_val = 1  # decoder upsampling: 0 = bilinear, 1 = nearest neighbour

        # --- data geometry, classes and resolution ---------------------------
        self.dataset_name = 'acdc'
        self.img_size_x, self.img_size_y = 192, 192
        self.img_size_flat = self.img_size_x * self.img_size_y  # pixels per flattened image
        self.num_channels = 1  # grey-scale input
        self.num_classes = 4  # 0 background, 1 rv, 2 myo, 3 lv
        self.size = (self.img_size_x, self.img_size_y)
        self.target_resolution = (1.36719, 1.36719)  # in-plane target resolution
        self.class_name = 'rv'  # alternative: 'lv'

        # --- paths ---------------------------------------------------------
        self.val_step_update = 50  # validation step at which values are saved
        code_root = '/usr/bmicnas01/data-biwi-01/krishnch/projects/self_tr/contrastive_lr/git_may24/'
        self.base_dir = code_root
        self.srt_dir = code_root
        # raw volumes in their native resolution (prefix, see class docstring)
        self.data_path_tr = '/content/drive/MyDrive/Data/ACDC/training_ACDC/patient00'
        # volumes already resampled to target_resolution and cropped a priori
        self.data_path_tr_cropped = '/usr/bmicnas01/data-biwi-01/krishnch/datasets/heart_acdc/acdc_bias_corr_cropped/patient'

        # --- training hyper-parameters ---------------------------------------
        self.lr = 0.001  # segmentation-net learning rate
        self.mtask_bs = 20  # pre-training batch size
        self.batch_size_ft = 10  # fine-tuning batch size
        self.struct_name = ['rv', 'myo', 'lv']  # foreground structures


class dataloaderObj:
    """Load and preprocess ACDC, Prostate (Medical Decathlon) and MMWHS NIfTI volumes.

    Paths, target resolution, crop size and class count are copied from a
    ``Config``-like object at construction time. Subject folders are addressed
    by plain string concatenation: ``str(prefix) + str(study_id) + '/...'``.
    """

    # define functions to load data from acdc/prostate/mmwhs dataset
    def __init__(self, cfg):
        self.data_path_tr = cfg.data_path_tr
        self.data_path_tr_cropped = cfg.data_path_tr_cropped
        self.target_resolution = cfg.target_resolution
        self.dataset_name = cfg.dataset_name
        self.size = cfg.size
        self.num_classes = cfg.num_classes
        self.one_label = 0

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _nifti_array(nifti_img):
        """Return the voxel array of a loaded nibabel image.

        ``np.asanyarray(img.dataobj)`` is nibabel's documented replacement for
        ``get_data()`` (removed in nibabel 5): same scaled values, and the
        on-disk dtype is kept (unlike ``get_fdata()``, which forces float64).
        """
        return np.asanyarray(nifti_img.dataobj)

    def _subject_prefix(self, study_id_list):
        """Return ``str(self.data_path_tr) + str(last_id)`` for the LAST id in *study_id_list*.

        The public loaders accept a list but have only ever loaded its final
        subject; this helper makes that explicit.

        Raises:
            ValueError: if *study_id_list* is empty.
        """
        last_id = None
        seen = False
        for study_id in study_id_list:
            last_id, seen = study_id, True
        if not seen:
            raise ValueError('study_id_list must contain at least one study id')
        return str(self.data_path_tr) + str(last_id)

    def _load_volume(self, img_path, seg_path, ret_affine, label_present, first_channel=False):
        """Load and normalize an image (and optionally its mask).

        Returns ``(img[, label], pixel_size[, affine])``: the label is included
        when ``label_present == 1`` and the affine when ``ret_affine != 0`` --
        the tuple layout shared by all public ``load_*_imgs`` methods.
        """
        image_load = nib.load(img_path)
        image = self._nifti_array(image_load)
        pixel_size = image_load.header['pixdim'][1:4]
        affine = image_load.affine
        if first_channel:
            # the volume is indexed as 4-D here; keep only index 0 of the last axis
            image = image[:, :, :, 0]
        image = self.normalize_minmax_data(image)

        outputs = [image]
        if label_present == 1:
            outputs.append(self._nifti_array(nib.load(seg_path)))
        outputs.append(pixel_size)
        if ret_affine != 0:
            outputs.append(affine)
        return tuple(outputs)

    @staticmethod
    def _crop_or_pad(img_slice, nx, ny, pad_shape):
        """Center-crop and/or zero-pad the first two axes of *img_slice* to (nx, ny).

        If both axes exceed the target, a view into *img_slice* is returned;
        otherwise a new zero array of shape *pad_shape* is filled.
        """
        x, y = img_slice.shape[0], img_slice.shape[1]
        x_s = (x - nx) // 2  # crop offsets (input larger than target)
        y_s = (y - ny) // 2
        x_c = (nx - x) // 2  # pad offsets (input not larger than target)
        y_c = (ny - y) // 2

        if x > nx and y > ny:
            return img_slice[x_s:x_s + nx, y_s:y_s + ny]

        slice_cropped = np.zeros(pad_shape)
        if x <= nx and y > ny:
            slice_cropped[x_c:x_c + x, :] = img_slice[:, y_s:y_s + ny]
        elif x > nx and y <= ny:
            slice_cropped[:, y_c:y_c + y] = img_slice[x_s:x_s + nx, :]
        else:
            slice_cropped[x_c:x_c + x, y_c:y_c + y] = img_slice[:, :]
        return slice_cropped

    @staticmethod
    def _concat_slices(volumes):
        """Concatenate volumes along axis 2; a single volume is returned unchanged."""
        if len(volumes) == 1:
            return volumes[0]
        return np.concatenate(volumes, axis=2)

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def normalize_minmax_data(self, image_data, min_val=1, max_val=99):
        """Min-max normalize *image_data* using percentiles as the min and max.

        The ``min_val``-th and ``max_val``-th percentiles of the whole array map
        to 0 and 1, so outliers fall outside [0, 1]. Works for arrays of any
        rank. If both percentiles are equal, the division yields inf/nan.

        input params :
            image_data : MRI volume (e.g. a 3D scan) to normalize
            min_val : percentile mapped to 0
            max_val : percentile mapped to 1
        returns:
            final_image_data : normalized array (same shape as the input)
        """
        min_val_1p = np.percentile(image_data, min_val)
        max_val_99p = np.percentile(image_data, max_val)
        return (image_data - min_val_1p) / (max_val_99p - min_val_1p)

    def load_acdc_imgs(self, study_id_list, ret_affine=0, label_present=1):
        """Load an ACDC image (and its label) with its pixel dimensions.

        Walks ``str(self.data_path_tr) + str(study_id) + '/'`` for the LAST id of
        *study_id_list*, collects (in sorted order) files whose lower-cased name
        contains ``_frame01`` or ``_frame04`` and uses the first as the image
        and the second as the label. The collected list is printed.

        input params :
            study_id_list: subject ids; only the last one is loaded
            ret_affine: non-zero to also return the affine matrix
            label_present : 1 to also load the label (0 for unlabeled images)
        returns :
            (image, [label,] pixel_size[, affine]) -- image is min-max normalized
        raises:
            ValueError if study_id_list is empty; IndexError if the files are missing
        """
        path_files = self._subject_prefix(study_id_list) + '/'
        systole_lstfiles = []
        for dirName, _subdirs, fileList in os.walk(path_files):
            fileList.sort()
            for filename in fileList:
                lowered = filename.lower()
                if "_frame01" in lowered or "_frame04" in lowered:
                    systole_lstfiles.append(os.path.join(dirName, filename))
        print(systole_lstfiles)
        seg_path = systole_lstfiles[1] if label_present == 1 else None
        return self._load_volume(systole_lstfiles[0], seg_path, ret_affine, label_present)

    def load_mmwhs_imgs(self, study_id_list, ret_affine=0, label_present=1):
        """Load an MMWHS image (``img.nii.gz``) and label (``seg.nii.gz``).

        Only the LAST id of *study_id_list* is loaded. Parameters and return
        layout are as in :meth:`load_acdc_imgs`.
        """
        prefix = self._subject_prefix(study_id_list)
        return self._load_volume(prefix + '/img.nii.gz', prefix + '/seg.nii.gz',
                                 ret_affine, label_present)

    def load_prostate_imgs_md(self, study_id_list, ret_affine=0, label_present=1):
        """Load a Prostate MD image (``img.nii.gz``, channel 0) and label (``mask.nii.gz``).

        Only the LAST id of *study_id_list* is loaded. Parameters and return
        layout are as in :meth:`load_acdc_imgs`.
        """
        print('PZ Decathlon')
        prefix = self._subject_prefix(study_id_list)
        return self._load_volume(prefix + '/img.nii.gz', prefix + '/mask.nii.gz',
                                 ret_affine, label_present, first_channel=True)

    def crop_or_pad_slice_to_size_1hot(self, img_slice, nx, ny):
        """Center-crop / zero-pad a one-hot 2D slice (x, y, classes) to (nx, ny, num_classes)."""
        return self._crop_or_pad(img_slice, nx, ny, (nx, ny, self.num_classes))

    def crop_or_pad_slice_to_size(self, img_slice, nx, ny):
        """Center-crop / zero-pad a 2D slice to (nx, ny)."""
        return self._crop_or_pad(img_slice, nx, ny, (nx, ny))

    def preprocess_data(self, img, mask, pixel_size, label_present=1):
        """Rescale each slice to ``self.target_resolution`` and crop/pad it to ``self.size``.

        Images use linear interpolation (order=1); masks use nearest neighbour
        (order=0) so label values are preserved.

        input params :
            img : 3D volume, slices along axis 2
            mask : matching 3D mask (ignored unless label_present == 1)
            pixel_size : native pixel size; the first two entries are used
            label_present : 1 to also process the mask
        returns:
            cropped_img[, cropped_mask] : arrays of shape (nx, ny, img.shape[2])
        """
        nx, ny = self.size
        scale_vector = [pixel_size[0] / self.target_resolution[0],
                        pixel_size[1] / self.target_resolution[1]]

        img_slices, mask_slices = [], []
        for slice_no in range(img.shape[2]):
            slice_rescaled = transform.rescale(np.squeeze(img[:, :, slice_no]),
                                               scale_vector,
                                               order=1,
                                               preserve_range=True,
                                               mode='constant')
            img_slices.append(self.crop_or_pad_slice_to_size(slice_rescaled, nx, ny))
            if label_present == 1:
                mask_rescaled = transform.rescale(np.squeeze(mask[:, :, slice_no]),
                                                  scale_vector,
                                                  order=0,
                                                  preserve_range=True,
                                                  mode='constant')
                mask_slices.append(self.crop_or_pad_slice_to_size(mask_rescaled, nx, ny))

        # one stack at the end instead of re-concatenating per slice (was O(n^2))
        cropped_img = np.stack(img_slices, axis=2)
        if label_present == 1:
            return cropped_img, np.stack(mask_slices, axis=2)
        return cropped_img

    def load_cropped_img_labels(self, train_ids_list, label_present=1):
        """Load a-priori preprocessed (normalized, cropped) volumes and stack them along axis 2.

        Reads ``img_cropped.nii.gz`` (and ``mask_cropped.nii.gz`` when
        label_present == 1) from ``str(self.data_path_tr_cropped) + str(id)``.

        returns:
            img_cat[, mask_cat] : all subjects concatenated along axis 2
        """
        img_list, mask_list = [], []
        for study_id in train_ids_list:
            prefix = str(self.data_path_tr_cropped) + str(study_id)
            img_list.append(self._nifti_array(nib.load(prefix + '/img_cropped.nii.gz')))
            if label_present == 1:
                mask_list.append(self._nifti_array(nib.load(prefix + '/mask_cropped.nii.gz')))

        img_cat = self._concat_slices(img_list)
        if label_present == 1:
            return img_cat, self._concat_slices(mask_list)
        return img_cat


In [None]:
# Smoke test of dataloaderObj on ACDC subject 1.
# NOTE(review): this cell needs np/os/nib/transform, which are imported in a
# later cell of this notebook -- run that import cell first.
# NOTE: `loader` here is a dataloaderObj instance; a later cell rebinds the
# name `loader` to a function.
cfg = Config()
loader = dataloaderObj(cfg)
# returns (normalized image, label, pixel_size) for label_present=1, ret_affine=0
a,b,c = loader.load_acdc_imgs([1])
print(a.shape)
print(b.shape)
#for i in range(10):
    #cv2_imshow(a[:,:,i]*256)
    #cv2_imshow(b[:,:,i]*256)

# resample to cfg.target_resolution and crop/pad to cfg.size; a, b are reused by the next cell
a,b = loader.preprocess_data(a,b,c)

In [None]:
# Display the preprocessed image `a` and mask `b` produced by the previous cell.
print(a.shape, b.shape)

# Show at most the first 10 slices: a volume may have fewer than 10 slices
# along axis 2, and indexing past the end would raise IndexError.
for i in range(min(10, a.shape[2])):
    cv2_imshow(a[:,:,i]*256)
    # Labels run 0..cfg.num_classes-1 (0 bg, 1 rv, 2 myo, 3 lv). Scale them so each
    # class gets a distinct grey level; multiplying by 256 pushed every foreground
    # class to >= 256 (presumably clipped to white by cv2_imshow).
    cv2_imshow(b[:,:,i]*(255.0/(cfg.num_classes-1)))

#print(c)

In [None]:
# Quick look at one ground-truth slice of patient001 resized to 192x192.
path = '/content/drive/MyDrive/Data/ACDC/training_ACDC/patient001/patient001_frame01_gt.nii.gz'

# get_fdata() returns a float64 array directly. It replaces get_data() (removed
# in nibabel 5) plus the former np.array(..., dtype=float)/1.0 copy.
img = nib.load(path).get_fdata()
print(img.shape)

#x = transform.rescale(img[:,:,4], [192, 192], order=1, preserve_range=True, mode = 'constant')
x = img[:,:,5]
# The interpolation flag must be a keyword argument: cv2.resize's third
# positional parameter is `dst`, so INTER_AREA was previously never applied.
x = cv2.resize(x, (192,192), interpolation=cv2.INTER_AREA)
cv2_imshow(x*255.0)

# Main Code

In [None]:
#!pip install -q -U segmentation-models-pytorch albumentations > /dev/null
#import segmentation_models_pytorch as smp

In [None]:
import torch
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import numpy as np
from tqdm import tqdm_notebook as tqdm
import cv2
#from sklearn import metrics
#from sklearn.metrics import jaccard_score as js
from sklearn.metrics import roc_curve as rc
from sklearn.metrics import auc as auc
#from PIL import Image, ImageOps
#from torch.autograd import Variable as v
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
%matplotlib inline
#import pandas as pd
import os
from google.colab.patches import cv2_imshow
import pickle
import nibabel as nib
import pathlib
from skimage import transform
from sklearn.manifold import TSNE
from google.colab import drive
# Mount Google Drive so the '/content/drive/MyDrive/Data/ACDC/...' paths used in this notebook resolve.
drive.mount('/content/drive', force_remount=True)

In [None]:
def loader(img_path, mask_path, noise=0):
    """Load a NIfTI image/mask pair as four cropped, resized 2D slices.

    For each of the first 4 slices along the last axis:
      * crop rows 30:186 and columns 20:196,
      * resize to 128x128 with INTER_AREA.
    Then:
      * the image is divided by 255,
      * uniform noise in ``[0, noise)`` is added to the image,
      * the mask is binarized at 0.5 (any label >= 0.5 becomes 1).

    Args:
        img_path: path of the image volume (.nii/.nii.gz).
        mask_path: path of the matching label volume.
        noise: amplitude of the additive uniform noise (0 disables it).

    Returns:
        (img, mask): arrays of shape (4, 128, 128); img is float64.
    """
    # np.asanyarray(dataobj) is nibabel's replacement for the removed get_data()
    img = np.asanyarray(nib.load(img_path).dataobj)
    mask = np.asanyarray(nib.load(mask_path).dataobj)
    img, mask = np.array(img), np.array(mask)
    I, M = [], []
    for i in range(4):
        im, ms = img[:,:,i][30:186,20:196], mask[:,:,i][30:186,20:196]
        # interpolation must be passed by keyword: the third positional
        # argument of cv2.resize is `dst`, so INTER_AREA was silently ignored
        im = cv2.resize(im, (128,128), interpolation=cv2.INTER_AREA)
        ms = cv2.resize(ms, (128,128), interpolation=cv2.INTER_AREA)
        I.append(im)
        M.append(ms)
    img = np.array(I, dtype=float)/255.0
    mask = np.array(M)
    img = img + noise*np.random.rand(*img.shape)

    mask[mask >= 0.5] = 1
    mask[mask < 0.5] = 0

    return img, mask

def read_dataset(root_path):
    """List ACDC image/label paths, one pair per patient folder under *root_path*.

    For each sub-directory ``<name>`` (sorted), uses
    ``<name>/<name>_frame01.nii.gz`` if it exists and otherwise assumes
    ``<name>/<name>_frame04.nii.gz``. Labels are the matching ``*_gt.nii.gz``.
    Non-directory entries (e.g. stray notes or archives) are skipped instead of
    being turned into non-existent patient paths.

    Returns:
        (images, labels): two lists of path strings of equal length.
    """
    images = []
    labels = []
    for image_name in sorted(os.listdir(root_path)):
        image_path = os.path.join(root_path, image_name)
        if not os.path.isdir(image_path):
            continue
        has_frame01 = os.path.isfile(os.path.join(image_path, image_name + '_frame01.nii.gz'))
        stem = os.path.join(image_path, image_name + ('_frame01' if has_frame01 else '_frame04'))
        images.append(stem + '.nii.gz')
        labels.append(stem + '_gt.nii.gz')
    return images, labels

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over the ACDC patient folders under *root_path*.

    Each item is ``(img, mask)`` as produced by ``loader``: img is a float32
    tensor and mask keeps the dtype returned by ``loader``.

    The base class is named explicitly (``torch.utils.data.Dataset``) because
    this class shadows the imported ``Dataset``. Re-running this cell would
    otherwise subclass the previous notebook version of this class.
    """

    def __init__(self, root_path, noise=0):
        self.root = root_path
        self.images, self.labels = read_dataset(self.root)
        self.noise = noise  # amplitude of uniform noise added by `loader`
        print('Num Images:', len(self.images), 'Num Labes:', len(self.labels))
        print('images: ', self.images)
        print('labels: ', self.labels)

    def __getitem__(self, index):
        img, mask = loader(self.images[index], self.labels[index], self.noise)
        img = torch.tensor(img, dtype = torch.float32)
        mask = torch.tensor(mask)
        return img, mask

    def __len__(self):
        assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)

In [None]:
# Sanity check: load one ACDC subject and show slice 1 of the image and mask.
# `loader` already crops each slice to [30:186, 20:196] and resizes it to
# 128x128. Cropping that output again would zoom into (and distort) the slice
# the network actually sees, so the slice is displayed as returned.
img, mask = loader('/content/drive/MyDrive/Data/ACDC/training_ACDC/patient056/patient056_frame01.nii.gz', 
                   '/content/drive/MyDrive/Data/ACDC/training_ACDC/patient056/patient056_frame01_gt.nii.gz')
print(img.shape,mask.shape)
img, mask = img[1], mask[1]
print(img.shape,mask.shape)
cv2_imshow(img * 255.0)
cv2_imshow(mask * 255.0)

In [None]:
#U-Net Accessories
class BasicConv2d(nn.Module):
    """Conv2d (with bias) -> BatchNorm2d -> ReLU.

    Convolution hyper-parameters (kernel size, stride, padding, dilation) are
    forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, dilation=dilation, bias=True)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class DoubleConv(nn.Module):
    """Two rounds of (3x3 conv, padding 1 -> BatchNorm2d -> ReLU).

    The first conv maps ``in_channels -> mid_channels`` (``out_channels`` when
    ``mid_channels`` is falsy); the second maps ``mid -> out_channels``.
    Spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(DoubleConv, self).__init__()
        mid = mid_channels or out_channels
        layers = [
            nn.Conv2d(in_channels, mid, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=False),
            nn.Conv2d(mid, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=False),
        ]
        # keep the attribute name: it determines the state_dict keys
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Halve the spatial size with 2x2 max-pooling, then apply a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # attribute name fixes the state_dict keys (maxpool_conv.1.*)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        out = self.maxpool_conv(x)
        return out

class Up(nn.Module):
    """Upsample ``x1`` by 2, pad it to the skip tensor ``x2``, concatenate, then DoubleConv.

    With ``bilinear`` the channel count is unchanged by upsampling and the
    DoubleConv uses ``in_channels // 2`` mid channels; otherwise a 2x2
    transposed conv halves the channels first.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels , half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # tensors are NCHW: compare height (dim 2) and width (dim 3) with the skip
        diff_h = x2.size()[2] - upsampled.size()[2]
        diff_w = x2.size()[3] - upsampled.size()[3]
        left, top = diff_w // 2, diff_h // 2
        # F.pad order for the last two dims: (left, right, top, bottom)
        upsampled = F.pad(upsampled, [left, diff_w - left, top, diff_h - top])
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Final 1x1 convolution projecting features to ``out_channels`` maps."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

class BatchNorm(nn.Module):
    """BatchNorm2d over ``out_channels`` followed by ReLU.

    The constructor must be ``__init__`` (not ``init``) for ``self.bn`` and
    ``self.relu`` to exist when the module is created.
    """

    def __init__(self, out_channels):
        super(BatchNorm, self).__init__()
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.bn(x)
        return self.relu(x)



In [None]:
class UNet(nn.Module):
    """U-Net mapping a 3-channel image to a 3-channel sigmoid map.

    Encoder widths 64-128-256-512-1024 (bottleneck halved when ``bilinear``),
    decoder mirrors them with skip connections, and a 1x1 conv + sigmoid
    produces the output.
    """

    def __init__(self, bilinear=False):
        super(UNet, self).__init__()
        self.bilinear = bilinear
        factor = 2 if bilinear else 1

        self.act = nn.Sigmoid()

        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        self.up4 = Up(1024, 512 // factor, bilinear)
        self.up3 = Up(512, 256 // factor, bilinear)
        # Bilinear upsampling does not halve channels, so this stage must also
        # be divided by `factor`. Only then does up1's input (64 skip + 64
        # upsampled) match its expected 128 channels. With bilinear=False,
        # factor == 1 and nothing changes.
        self.up2 = Up(256, 128 // factor, bilinear)
        self.up1 = Up(128, 64, bilinear)

        self.inc = DoubleConv(3, 64)
        self.outconv = OutConv(64, 3)

    def forward(self, x, factor=0.001, visual = False):
        """Run the network; ``factor`` and ``visual`` are accepted but unused."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        xb = self.down4(x4)

        y4 = self.up4(xb, x4)
        y3 = self.up3(y4, x3)
        y2 = self.up2(y3, x2)
        y1 = self.up1(y2, x1)

        out = self.outconv(y1)
        out = self.act(out)
        return out

class Encoder_normie(nn.Module):
    """Plain U-Net contracting path for 1-channel input.

    ``forward`` returns ``(bottleneck, x4, x3, x2, x1)`` -- the bottleneck
    followed by the skip tensors from deepest to shallowest, in the argument
    order a U-Net decoder consumes them.
    """

    def __init__(self, bilinear=False):
        super().__init__()
        self.bilinear = bilinear
        # the bottleneck width is halved when the decoder upsamples bilinearly
        bottleneck_width = 1024 // (2 if bilinear else 1)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, bottleneck_width)
        self.inc = DoubleConv(1, 64)

    def forward(self, x):
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3):
            skips.append(stage(skips[-1]))
        bottleneck = self.down4(skips[-1])
        x1, x2, x3, x4 = skips
        return bottleneck, x4, x3, x2, x1

class Block(nn.Module):
    """Residual unit: two width-preserving 3x3 BasicConv2d layers plus an identity skip."""

    def __init__(self, in_channels=64):
        super().__init__()
        self.conv1 = BasicConv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = BasicConv2d(in_channels, in_channels, 3, 1, 1)

    def forward(self, x):
        residual = self.conv2(self.conv1(x))
        return x + residual

class Encoder(nn.Module):
    """Residual CNN encoder for 1-channel input.

    Four stages of residual ``Block``s (3, 4, 6 and 3 blocks at widths 64, 128,
    256, 512). Each stage is followed by a 1x1 conv doubling the width and a
    2x2 max-pool, then a final 1x1 BasicConv2d at 1024 channels. Spatial size
    shrinks by 16 overall.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are created in the same order as before so parameter
        # initialization consumes the RNG identically.
        self.inc = BasicConv2d(1, 64, 3, 1, 1)
        self.out = BasicConv2d(1024, 1024, 1)
        self.mp = nn.MaxPool2d(2)
        stage1 = [Block(64) for _ in range(3)]
        self.lm1 = nn.Conv2d(64,128,1)
        stage2 = [Block(128) for _ in range(4)]
        self.lm2 = nn.Conv2d(128,256,1)
        stage3 = [Block(256) for _ in range(6)]
        self.lm3 = nn.Conv2d(256,512,1)
        stage4 = [Block(512) for _ in range(3)]
        self.lm4 = nn.Conv2d(512,1024,1)
        self.layer1 = nn.Sequential(*stage1)
        self.layer2 = nn.Sequential(*stage2)
        self.layer3 = nn.Sequential(*stage3)
        self.layer4 = nn.Sequential(*stage4)

    def forward(self, x):
        x = self.inc(x)
        stages = ((self.layer1, self.lm1), (self.layer2, self.lm2),
                  (self.layer3, self.lm3), (self.layer4, self.lm4))
        for blocks, widen in stages:
            x = self.mp(widen(blocks(x)))
        return self.out(x)

class Decoder(nn.Module):
    """U-Net expanding path producing a 1-channel sigmoid map.

    ``forward(xb, x4, x3, x2, x1)`` takes the bottleneck and the skip tensors
    from deepest to shallowest (the order ``Encoder_normie.forward`` returns).
    """

    def __init__(self, bilinear=False):
        super(Decoder, self).__init__()
        self.bilinear = bilinear
        factor = 2 if bilinear else 1
        self.up4 = Up(1024, 512 // factor, bilinear)
        self.up3 = Up(512, 256 // factor, bilinear)
        # divide by `factor` here too: otherwise, with bilinear=True, up1 would
        # receive 64 + 128 = 192 channels instead of the 128 it expects
        self.up2 = Up(256, 128 // factor, bilinear)
        self.up1 = Up(128, 64, bilinear)
        self.act = nn.Sigmoid()
        self.outconv = OutConv(64, 1)

    def forward(self, xb, x4, x3, x2, x1):
        y4 = self.up4(xb, x4)
        y3 = self.up3(y4, x3)
        y2 = self.up2(y3, x2)
        y1 = self.up1(y2, x1)

        out = self.outconv(y1)
        out = self.act(out)
        return out

class Grep(nn.Module):
    """Projection head: flatten -> Linear -> LayerNorm -> Linear -> LayerNorm -> softmax.

    The defaults (65536 -> 3200 -> 128) are consistent with the ``Encoder``
    output for a 1x128x128 input (1024 channels x 8 x 8 after four 2x
    max-pools). Pass other sizes to use the head with different encoders or
    input resolutions.

    Args:
        in_features: number of features after flattening the input.
        hidden_features: width of the intermediate layer.
        out_features: size of the output probability vector.
    """

    def __init__(self, in_features=65536, hidden_features=3200, out_features=128):
        super(Grep, self).__init__()
        self.flat = nn.Flatten()
        self.l1 = nn.Linear(in_features, hidden_features)
        self.ln1 = nn.LayerNorm(hidden_features)
        self.l2 = nn.Linear(hidden_features, out_features)
        self.ln2 = nn.LayerNorm(out_features)

    def forward(self, x):
        x = self.flat(x)
        x = self.ln1(self.l1(x))
        x = self.ln2(self.l2(x))
        return F.softmax(x, dim=1)



In [None]:
# Count the trainable parameters (requires_grad=True) of a freshly constructed Encoder.
pytorch_total_params_enc = sum(p.numel() for p in Encoder().parameters() if p.requires_grad)
print('Parameters: ', pytorch_total_params_enc)

In [None]:
# Training configuration used by the cells below.
root_path = '/content/drive/MyDrive/Data/ACDC/training_ACDC'
input_size = (3, 256, 256)  # for kaggle 448
batch_size = 1
learning_rate = 1e-6
epochs = 500

# early-stopping / LR-schedule bookkeeping
INITAL_EPOCH_LOSS = 10000
NUM_EARLY_STOP = 20
NUM_UPDATE_LR = 5

# train on the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def convert(A, Ttnp = True):
    """Convert between a torch tensor and a NumPy array.

    Args:
        A: a tensor (when ``Ttnp`` is true) or array-like (otherwise).
        Ttnp: "tensor to numpy". When true, return ``A`` detached, moved to
            the CPU and converted to NumPy. Otherwise build a tensor from
            ``A`` and move it to the module-level ``device``.
    """
    to_numpy = Ttnp == True
    if not to_numpy:
        return torch.tensor(A).to(device)
    return A.detach().cpu().numpy()

class Visualize():
    """Collect vectors and show a 2-D t-SNE scatter plot of them."""

    def __init__(self):
        super(Visualize, self).__init__()
        self.lst = []

    def register(self, x):
        """Queue one vector for the plot."""
        self.lst.append(x)

    def make_numpy(self, vector_length=128):
        """Stack the registered vectors into a float array ``self.dataset``."""
        self.dimension = vector_length
        self.dataset = np.array(self.lst, dtype=float)
        self.shape = self.dataset.shape

    def make_representation(self):
        """Embed ``self.dataset`` into 2-D with t-SNE (stored in ``self.embeddings``)."""
        tsne = TSNE(n_components=2, learning_rate='auto', init='random')
        self.embeddings = tsne.fit_transform(self.dataset)

    def plot(self):
        """Scatter-plot the 2-D embeddings and display the figure."""
        xs = self.embeddings[:, 0]
        ys = self.embeddings[:, 1]
        plt.scatter(xs, ys, s=5)
        plt.show()

class contrastive_loss(nn.Module):
    """InfoNCE-style contrastive loss over a hand-written cosine similarity.

    ``forward(pos, neg, t)`` uses ``pos[0]`` as the anchor, the remaining
    entries of ``pos`` as positives and every entry of ``neg`` as a negative:

        N = sum_{k >= 1} exp(cos(pos[k], pos[0]) / t)
        D = sum_{k}      exp(cos(pos[0], neg[k]) / t)
        loss = -log(N / (N + D))

    The accumulators are CPU tensors, so the result is a CPU tensor of shape
    ``[1]``.
    """

    def __init__(self):
        super(contrastive_loss, self).__init__()

    def mag(self, x):
        """Euclidean norm of all elements of ``x``."""
        return (x ** 2).sum() ** (1 / 2)

    def cosine_similarity(self, x, y):
        """Cosine similarity of ``x`` and ``y`` treated as flat vectors."""
        dot = (x * y).sum()
        return dot / (self.mag(x) * self.mag(y))

    def forward(self, pos, neg, t=1):
        numer = torch.zeros([1])
        denom = torch.zeros([1])
        for k in range(1, len(pos)):
            numer += torch.exp(self.cosine_similarity(pos[k], pos[0]).cpu() / t)
        for other in neg:
            denom += torch.exp(self.cosine_similarity(pos[0], other).cpu() / t)
        return -torch.log(numer / (numer + denom))


In [None]:
def perturb(x, crop=(10, 118), out_size=(128, 128)):
    """Crop-and-resize augmentation of the first slice of ``x``.

    Takes ``x[0, 0]``, crops rows and columns ``crop[0]:crop[1]``, resizes
    the crop back to ``out_size`` with area interpolation and returns it as a
    ``(1, 1, H, W)`` tensor on the module-level ``device`` (via ``convert``).

    Args:
        x: tensor indexed as ``x[0, 0, rows, cols]``.
        crop: ``(start, stop)`` applied to both spatial axes (default: the
            original ``10:118``).
        out_size: ``dsize`` for ``cv2.resize``, i.e. ``(width, height)``
            (default ``(128, 128)``).
    """
    start, stop = crop
    arr = convert(x)[0, 0, start:stop, start:stop]
    # The interpolation flag must be passed by keyword: cv2.resize's third
    # positional parameter is ``dst``, not ``interpolation``.
    arr = cv2.resize(arr, out_size, interpolation=cv2.INTER_AREA)
    return convert(arr, False).unsqueeze(0).unsqueeze(0)

def adversarial(x, xp, Enc, G, loss, strength=0.7):
    """Return a saliency-weighted noisy copy of ``x``.

    The gradient of ``loss.cosine_similarity(G(Enc(x)), xp)`` with respect to
    the input is turned into a saliency map: absolute value, max over dim 1,
    then min-max scaled to [0, 1]. Uniform noise in [0, 1) is multiplied by
    that map and by ``strength`` and added to ``x``.

    Args:
        x: input image tensor; it is not modified.
        xp: reference embedding; treated as a constant.
        Enc, G: encoder and projection head (only their forward passes are
            used; no gradients are written to their parameters).
        loss: object providing ``cosine_similarity(a, b)``.
        strength: maximum noise amplitude (default 0.7).

    Returns:
        ``x + noise`` with the same shape as ``x``.
    """
    # Detached copy: flagging the caller's tensor with requires_grad would
    # leak into their later computations.
    x_att = x.detach().clone().requires_grad_(True)
    x_a = G(Enc.forward(x_att))
    adg = loss.cosine_similarity(x_a, xp.detach())
    # autograd.grad computes only d(adg)/d(x_att). A plain .backward() would
    # also accumulate into Enc/G parameter .grad and pollute the next
    # optimizer step.
    (grad,) = torch.autograd.grad(adg, x_att)
    sal, _ = torch.max(grad.abs(), dim=1)
    sal_min = sal.min()
    sal_range = sal.max() - sal_min
    if sal_range > 0:
        sal = (sal - sal_min) / sal_range
    else:
        # A flat saliency map carries no information: add no noise.
        sal = torch.zeros_like(sal)
    noise = torch.rand_like(x) * sal * strength
    return x + noise


class MyFrame():
    """Contrastive pre-training wrapper around an encoder and the ``Grep`` head.

    Holds the encoder ``Enc`` (instantiated from the ``encoder`` class passed
    in), the projection head ``G``, one Adam optimizer over both, and the
    ``contrastive_loss`` criterion.
    """

    def __init__(self, encoder, learning_rate, device, evalmode=False):
        # ``evalmode`` is accepted for API compatibility but is not used.
        self.Enc = encoder().to(device)
        self.G = Grep().to(device)
        self.optimizer = torch.optim.Adam(
            params=list(self.Enc.parameters()) + list(self.G.parameters()),
            lr=learning_rate,
            weight_decay=0.0001,
        )
        self.loss = contrastive_loss().to(device)
        self.lr = learning_rate

    def set_input(self, img_batch, mask_batch=None):
        """Store the next batch; masks are kept but ``optimize`` does not use them."""
        self.img = img_batch
        self.mask = mask_batch

    def _slice(self, idx):
        """Slice ``idx`` (dim 1) of batch item 0, shaped ``(1, 1, H, W)``."""
        return self.img[0, idx, :, :].unsqueeze(0).unsqueeze(0)

    def _embed(self, x):
        """Encode and project one slice."""
        return self.G(self.Enc.forward(x))

    def optimize(self):
        """Run one contrastive optimization step over every slice of ``self.img``.

        For each slice ``i`` the positives are its embedding and that of its
        ``perturb``-ed copy. The negatives are the embeddings of every other
        slice and of that slice's perturbed copy. Per-slice losses
        (temperature 0.1) are summed and back-propagated once.

        Returns:
            ``(loss, preds)``: the summed loss as a float and the list of
            anchor embeddings, one per slice.

        Raises:
            ValueError: if ``self.img`` has no slices along dim 1.
        """
        self.optimizer.zero_grad()
        num_slices = self.img.shape[1]
        if num_slices == 0:
            raise ValueError('optimize() needs at least one slice along img dim 1')
        # NOTE(review): only batch item 0 is used; other batch items are ignored.
        total = 0
        preds = []
        for i in range(num_slices):
            x = self._slice(i)
            xo = self._embed(x)
            xp = self._embed(perturb(x))
            preds.append(xo)
            pos = [xo, xp]
            neg = []
            for j in range(num_slices):
                if j == i:
                    continue
                other = self._slice(j)
                neg.append(self._embed(other))
                neg.append(self._embed(perturb(other)))
            total += self.loss.forward(pos, neg, 0.1)
        total.backward()
        self.optimizer.step()
        return total.item(), preds

    def save(self, path):
        """Save encoder and head weights into directory ``path`` (``_KC`` names)."""
        torch.save(self.Enc.state_dict(), path + '/' + 'pre_enc_resnet34_cropped_KC' + '.pth')
        torch.save(self.G.state_dict(), path + '/' + 'pre_G_resnet34_cropped_KC' + '.pth')

    def load(self, path):
        """Load weights written by ``save``.

        ``map_location='cpu'`` lets GPU-saved checkpoints load on CPU-only
        machines; ``load_state_dict`` copies the values onto the modules'
        current device.
        """
        self.Enc.load_state_dict(torch.load(path + '/' + 'pre_enc_resnet34_cropped_KC' + '.pth', map_location='cpu'))
        self.G.load_state_dict(torch.load(path + '/' + 'pre_G_resnet34_cropped_KC' + '.pth', map_location='cpu'))

    def update_lr(self, new_lr, factor=False):
        """Set the optimizer learning rate and report the change once.

        Args:
            new_lr: the new rate or, if ``factor`` is true, the divisor applied
                to the current rate (``self.lr / new_lr``).
            factor: interpret ``new_lr`` as a divisor.
        """
        if factor:
            new_lr = self.lr / new_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr
        print('update learning rate: %f -> %f' % (self.lr, new_lr))
        self.lr = new_lr

class proposed_loss(nn.Module):
    """Weighted segmentation loss: 0.15*L1 + 0.4*soft-Dice + 0.15*BCE + 0.3*IoU.

    Args:
        batch: if true, the Dice coefficient is computed over the whole batch
            at once. Otherwise it is computed per sample (summing all
            non-batch dimensions) and averaged. IoU is always per sample.
        smooth: additive smoothing in the Dice and IoU ratios. An empty
            prediction on an empty target then scores 1 instead of 0/0 = NaN.

    ``forward(y_true, y_pred)`` requires ``y_pred`` in [0, 1] (``BCELoss``)
    and ``y_true`` of the same shape.
    """

    def __init__(self, batch=True, smooth=1e-6):
        super(proposed_loss, self).__init__()
        self.batch = batch
        self.smooth = smooth
        self.mae_loss = torch.nn.L1Loss()
        self.bce_loss = torch.nn.BCELoss()

    @staticmethod
    def _sample_dims(t):
        """All dimensions except the batch dimension 0."""
        return tuple(range(1, t.dim()))

    def soft_dice_coeff(self, y_true, y_pred):
        if self.batch:
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(y_true * y_pred)
        else:
            dims = self._sample_dims(y_true)
            i = y_true.sum(dim=dims)
            j = y_pred.sum(dim=dims)
            intersection = (y_true * y_pred).sum(dim=dims)
        score = (2. * intersection + self.smooth) / (i + j + self.smooth)
        return score.mean()

    def soft_dice_loss(self, y_true, y_pred):
        return 1 - self.soft_dice_coeff(y_true, y_pred)

    def iou_loss(self, inputs, targets):
        dims = self._sample_dims(inputs)
        intersection = (inputs * targets).sum(dim=dims)
        total = (inputs + targets).sum(dim=dims)
        union = total - intersection
        IoU = (intersection + self.smooth) / (union + self.smooth)
        return 1 - IoU.mean()

    def forward(self, y_true, y_pred):
        a = self.mae_loss(y_pred, y_true)
        b = self.soft_dice_loss(y_true, y_pred)
        c = self.bce_loss(y_pred, y_true)
        d = self.iou_loss(y_pred, y_true)
        return 0.15 * a + 0.4 * b + 0.15 * c + 0.3 * d




In [None]:
dataset = Dataset(root_path)
# Keep only the first 10 entries for a quick training run.
dataset.images, dataset.labels = dataset.images[:10], dataset.labels[:10]
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Root locations (insert your own).
model_root = '/content/drive/MyDrive/Data/Models'
logfile_root = '/content/drive/MyDrive/Var Logs'
varfile_root = '/content/drive/MyDrive/Var Logs/Variables'

# Only the run name (cont_pre_resnet34_cropped_KC) should change between runs.
model_path = model_root
logfile_path = f'{logfile_root}/cont_pre_resnet34_cropped_KC.txt'
varfile_tr = f'{varfile_root}/cont_pre_resnet34_cropped_KC_tr.pkl'

In [None]:
def trainer(epoch, epochs, train_loader, solver, logfile=None):
    """Run one contrastive pre-training epoch.

    Each ``(img, mask)`` batch is moved to the module-level ``device``, fed
    to ``solver.optimize()``, and the returned per-slice embeddings are
    collected for a t-SNE plot shown at the end of the epoch.

    Args:
        epoch: current epoch number (display/logging only).
        epochs: total number of epochs (display/logging only).
        train_loader: sized iterable of ``(img, mask)`` batches.
        solver: object with ``set_input``, ``optimize`` and ``lr``
            (e.g. ``MyFrame``).
        logfile: optional open text file that receives an epoch summary.

    Returns:
        ``(train_epoch_loss, keep_training)``. The loss is the mean of
        ``solver.optimize()`` over batches. ``keep_training`` is always True
        (no early stopping is implemented here).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    tsne_plotter = Visualize()
    keep_training = True
    print('Epoch {}/{}'.format(epoch, epochs))
    length = len(train_loader)
    if length == 0:
        raise ValueError('train_loader is empty')
    train_epoch_loss = 0
    iterator = tqdm(enumerate(train_loader), total=length, leave=False, desc=f'Epoch {epoch}/{epochs}')
    for _, (img, mask) in iterator:
        img = img.to(device)
        mask = mask.to(device)
        solver.set_input(img, mask)
        train_loss, preds = solver.optimize()
        for p in preds:
            # Register the first row of each embedding for t-SNE.
            tsne_plotter.register(p[0].detach().cpu().numpy())
        train_epoch_loss += train_loss

    # optimize() returns one loss per batch, so average over batches.
    train_epoch_loss = train_epoch_loss / length

    print('train_loss:', train_epoch_loss)
    print('Learning rate: ', solver.lr)
    tsne_plotter.make_numpy()
    tsne_plotter.make_representation()
    tsne_plotter.plot()

    if logfile is not None:
        logfile.write('Epoch: ' + str(epoch) + '/' + str(epochs) + '\n')
        logfile.write('train_loss: ' + str(train_epoch_loss) + '\n')
        logfile.write('Learning rate: ' + str(solver.lr) + '\n')
        # End the separator lines so the next epoch starts on a new line.
        logfile.write('------------------------------------------------------------------\n')
        logfile.write('------------------------------------------------------------------\n')

    return train_epoch_loss, keep_training

def train(init=True):
    """Contrastive pre-training loop, from scratch or resumed.

    Args:
        init: true starts fresh and truncates ``logfile_path``. False resumes:
            it loads the checkpoints from ``model_path`` and the history
            ``[EP, Ls, BL]`` (epochs, losses, best-loss-so-far) from
            ``varfile_tr``.

    Each epoch runs ``trainer``. The model is saved whenever the epoch loss
    is <= the best loss so far. The history is re-pickled to ``varfile_tr``
    after every epoch.
    """
    solver = MyFrame(Encoder, learning_rate, device)
    EP, Ls, BL = [], [], []
    if not init:
        solver.load(model_path)
        # Unpack into BL so the best loss survives a resume.
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(varfile_tr, 'rb') as v:
            EP, Ls, BL = pickle.load(v)
        # NOTE(review): optimizer state and solver.lr are not restored on resume.
    else:
        with open(logfile_path, 'w') as logfile:
            logfile.write('\n')

    start_ep = EP[-1] if EP else 0
    tbl = BL[-1] if BL else 100000000000
    for epoch in range(start_ep + 1, epochs + 1):
        # Close the log every epoch so its contents are flushed to disk.
        with open(logfile_path, 'a') as logfile:
            l, k = trainer(epoch, epochs, train_loader, solver, logfile)
        if l <= tbl:
            tbl = l
            solver.save(model_path)
            print('Saving model...Updating best loss')
        else:
            print('Not saving model...Loss: ', l, ', Best Loss: ', tbl)
        EP += [epoch]
        Ls += [l]
        BL += [tbl]
        # Closing the file guarantees the pickle is fully written.
        with open(varfile_tr, 'wb') as var:
            pickle.dump([EP, Ls, BL], var)
        print('----------------------------------------------------------------------------------------')
        print('----------------------------------------------------------------------------------------')
        print('----------------------------------------------------------------------------------------')
        if not k:
            break

#KC Training

In [None]:
# Fresh run: truncates the log file, builds new weights and starts at epoch 1.
train(True)

#Best Training

In [None]:
# Resume: loads the checkpoints from model_path and the pickled history from
# varfile_tr, then continues after the last recorded epoch.
train(False)

In [None]:
# Fresh run again: truncates the log, and the first epoch's save overwrites
# the existing checkpoints in model_path.
train(True)

#Visuals

In [None]:
solver1, solver2 = (MyFrame(Encoder, learning_rate, device) for _ in range(2))

# solver1 <- checkpoints without suffix.
solver1.Enc.load_state_dict(torch.load(f'{model_path}/pre_enc_resnet34_cropped.pth'))
solver1.G.load_state_dict(torch.load(f'{model_path}/pre_G_resnet34_cropped.pth'))

# solver2 <- "_KC" checkpoints (the names written by MyFrame.save).
solver2.Enc.load_state_dict(torch.load(f'{model_path}/pre_enc_resnet34_cropped_KC.pth'))
solver2.G.load_state_dict(torch.load(f'{model_path}/pre_G_resnet34_cropped_KC.pth'))

In [None]:
dataset = Dataset(root_path)
# Use the first 52 entries for the visualisation.
dataset.images, dataset.labels = dataset.images[:52], dataset.labels[:52]
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
class Visualize_proper():
    """Collect labelled vectors and show a colour-coded 2-D t-SNE scatter.

    Label 0/1/2 is drawn red/blue/green; any other label is drawn black.
    """

    # Colour per registered label; unlisted labels fall back to black.
    _LABEL_COLORS = {0: 'red', 1: 'blue', 2: 'green'}

    def __init__(self):
        super(Visualize_proper, self).__init__()
        self.lst = []
        self.lab = []

    def register(self, x, chan):
        """Queue vector ``x`` with label ``chan``."""
        self.lst.append(x)
        self.lab.append(chan)

    def make_numpy(self, vector_length=128):
        """Stack the registered vectors into a float array ``self.dataset``."""
        self.dataset = np.array(self.lst, dtype=float)
        self.shape = self.dataset.shape
        self.dimension = vector_length

    def make_representation(self):
        """Embed ``self.dataset`` into 2-D with t-SNE (``self.embeddings``)."""
        emb = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(self.dataset)
        self.embeddings = emb

    def plot(self, lim=15):
        """Scatter the embeddings coloured by label and show the figure.

        Args:
            lim: both axes are fixed to ``[-lim, lim]`` (points outside are
                clipped); pass ``None`` to let matplotlib autoscale.
        """
        if lim is not None:
            plt.xlim(-lim, lim)
            plt.ylim(-lim, lim)
        n = self.embeddings.shape[0]
        colors = [self._LABEL_COLORS.get(c, 'black') for c in self.lab[:n]]
        # One scatter call for all points instead of one artist per point.
        plt.scatter(self.embeddings[:, 0], self.embeddings[:, 1], c=colors, s=5)
        plt.show()

In [None]:
# Embed every slice with both solvers and compare the t-SNE layouts.
tsne_plotter1 = Visualize_proper()
tsne_plotter2 = Visualize_proper()

length = len(train_loader)
iterator = tqdm(enumerate(train_loader), total=length, leave=False)
# Inference only: no_grad avoids building an autograd graph for each slice.
# NOTE(review): the solvers are not switched to eval(); if Encoder contains
# train-mode layers (e.g. BatchNorm) they still behave as in training here.
with torch.no_grad():
    for index, (img, mask) in iterator:
        img = img.to(device)
        for i in range(img.shape[1]):
            im = img[0, i, :, :].unsqueeze(0).unsqueeze(0)
            p1 = solver1.G.forward(solver1.Enc.forward(im))
            p2 = solver2.G.forward(solver2.Enc.forward(im))
            # Slice index i is the colour label.
            tsne_plotter1.register(p1[0].detach().cpu().numpy(), i)
            tsne_plotter2.register(p2[0].detach().cpu().numpy(), i)
print('Ours')
tsne_plotter1.make_numpy()
tsne_plotter1.make_representation()
tsne_plotter1.plot()
print('KC')
tsne_plotter2.make_numpy()
tsne_plotter2.make_representation()
tsne_plotter2.plot()

In [None]:
# Same 52-entry subset, but Dataset built with a second argument of 0.6
# (meaning defined by Dataset — confirm there).
dataset = Dataset(root_path, 0.6)
dataset.images, dataset.labels = dataset.images[:52], dataset.labels[:52]
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Repeat the t-SNE comparison on the dataset variant built above.
tsne_plotter1 = Visualize_proper()
tsne_plotter2 = Visualize_proper()

length = len(train_loader)
iterator = tqdm(enumerate(train_loader), total=length, leave=False)
# Inference only: no_grad avoids building an autograd graph for each slice.
# NOTE(review): the solvers are not switched to eval(); if Encoder contains
# train-mode layers (e.g. BatchNorm) they still behave as in training here.
with torch.no_grad():
    for index, (img, mask) in iterator:
        img = img.to(device)
        for i in range(img.shape[1]):
            im = img[0, i, :, :].unsqueeze(0).unsqueeze(0)
            p1 = solver1.G.forward(solver1.Enc.forward(im))
            p2 = solver2.G.forward(solver2.Enc.forward(im))
            # Slice index i is the colour label.
            tsne_plotter1.register(p1[0].detach().cpu().numpy(), i)
            tsne_plotter2.register(p2[0].detach().cpu().numpy(), i)
print('Ours')
tsne_plotter1.make_numpy()
tsne_plotter1.make_representation()
tsne_plotter1.plot()
print('KC')
tsne_plotter2.make_numpy()
tsne_plotter2.make_representation()
tsne_plotter2.plot()

In [None]:
class Model(nn.Module):
    """``Encoder`` followed by the ``Grep`` projection head, as one module."""

    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.G = Grep()

    def forward(self, x):
        return self.G(self.enc(x))


model1 = Model()
print(model1.enc.layer4)

In [None]:
model1 = Model()

# Load the checkpoints without the "_KC" suffix into the combined model.
model1.enc.load_state_dict(torch.load(f'{model_path}/pre_enc_resnet34_cropped.pth'))
model1.G.load_state_dict(torch.load(f'{model_path}/pre_G_resnet34_cropped.pth'))


In [None]:
!git clone https://github.com/jacobgil/pytorch-grad-cam.git

In [None]:
%cd pytorch-grad-cam

In [None]:
!pip install ttach

In [None]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
from torchvision.models.segmentation import deeplabv3_resnet50
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import requests
import torchvision
from PIL import Image
from torchvision import transforms
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
import cv2

l=[]
# ``loader`` is defined elsewhere in the notebook; presumably it returns the
# image volume and its ground-truth volume for one ACDC case — confirm.
image_url, mask = loader('/content/drive/MyDrive/Data/ACDC/training_ACDC/patient056/patient056_frame01.nii.gz', 
                   '/content/drive/MyDrive/Data/ACDC/training_ACDC/patient056/patient056_frame01_gt.nii.gz')
# Keep the first entry along axis 0 (the slice keeps that axis with length 1).
image = np.array(image_url)[0:1]
rgb_img = np.float32(image) / 255
print(rgb_img.shape)
print(mask.shape)
# Add the batch dimension unconditionally so CPU and GPU runs feed the model
# the same shape.
input_tensor = torch.tensor(rgb_img).unsqueeze(0)
# NOTE(review): Model() is freshly initialised here; the weights loaded into
# model1 earlier are not used — confirm this is intended.
model = Model()
if torch.cuda.is_available():
    model = model.cuda()
    input_tensor = input_tensor.cuda()

output = model(input_tensor)
print(output.shape)

In [None]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
from torchvision.models.segmentation import deeplabv3_resnet50
import torch
# F must stay torch.nn.functional: Grep.forward calls F.softmax through this
# global, and torch.functional is a different module.
import torch.nn.functional as F
import numpy as np
import requests
import torchvision
from PIL import Image
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from sklearn import preprocessing

# Download a sample photo (network access required).
image_url = "https://farm1.staticflickr.com/6/9606553_ccc7518589_z.jpg"
image = np.array(Image.open(requests.get(image_url, stream=True).raw))

# Scale to [0, 1], add uniform noise, renormalise by the maximum, and keep
# float32 for show_cam_on_image.
rgb_img = np.float32(image) / 255
rgb_img = rgb_img + 0.4*np.random.rand(rgb_img.shape[0], rgb_img.shape[1], rgb_img.shape[2])
rgb_img = rgb_img/np.amax(rgb_img)
rgb_img = np.float32(rgb_img)
input_tensor = preprocess_image(rgb_img,
                                mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])

# Taken from the torchvision tutorial
# https://pytorch.org/vision/stable/auto_examples/plot_visualization_utils.html
model = deeplabv3_resnet50(pretrained=True, progress=False)
model = model.eval()
if torch.cuda.is_available():
    model = model.cuda()
    input_tensor = input_tensor.cuda()

output = model(input_tensor)
print(output)

In [None]:
class SegmentationModelOutputWrapper(torch.nn.Module):
    """Make a segmentation model return a plain tensor.

    The wrapped model's output is indexed with ``"out"``, so GradCAM receives
    a tensor instead of a mapping.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        outputs = self.model(x)
        return outputs["out"]


model = SegmentationModelOutputWrapper(model)
output = model(input_tensor)
print(output.shape)

In [None]:
# Per-pixel class probabilities of the segmentation output.
normalized_masks = torch.nn.functional.softmax(output, dim=1).cpu()
sem_classes = [
    '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]
sem_class_to_idx = dict(zip(sem_classes, range(len(sem_classes))))

car_category = sem_class_to_idx["car"]
# Most likely class per pixel for the first image in the batch.
car_mask = normalized_masks[0].argmax(axis=0).detach().cpu().numpy()
car_mask_uint8 = 255 * np.uint8(car_mask == car_category)
car_mask_float = np.float32(car_mask == car_category)

# Show the input and the (3-channel) car mask side by side.
both_images = np.hstack((image, np.repeat(car_mask_uint8[:, :, None], 3, axis=-1)))
print(both_images.shape)
Image.fromarray(both_images)

In [None]:
from pytorch_grad_cam import GradCAM

class SemanticSegmentationTarget:
    """GradCAM target: sum of one class's output map inside a mask.

    Args:
        category: channel index selected from the per-image model output.
        mask: NumPy array multiplied element-wise with that channel (moved
            to the GPU when CUDA is available).
    """

    def __init__(self, category, mask):
        self.category = category
        mask_t = torch.from_numpy(mask)
        self.mask = mask_t.cuda() if torch.cuda.is_available() else mask_t

    def __call__(self, model_output):
        class_map = model_output[self.category, :, :]
        print(class_map.shape)
        print((class_map * self.mask).sum())
        return (class_map * self.mask).sum()

    
# Run GradCAM on the last backbone stage of the wrapped DeepLabV3 model,
# targeting the car-class pixels.
target_layers = [model.model.backbone.layer4]
print(target_layers)
targets = [SemanticSegmentationTarget(car_category, car_mask_float)]
with GradCAM(model=model,
             target_layers=target_layers,
             use_cuda=torch.cuda.is_available()) as cam:
    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=targets)[0, :]
    # The target object has no ``shape``; report the mask it applies instead.
    print(input_tensor.shape, targets[0].mask.shape)
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

Image.fromarray(cam_image)

In [None]:
from pytorch_grad_cam import GradCAM
class SemanticSegmentationTarget:
    """GradCAM target for the contrastive ``Model`` (redefines the earlier class).

    Returns the sum of ``model_output[self.input_tensor1] * self.mask``.

    NOTE(review): the caller below passes the float image tensor as
    ``input_tensor1``. PyTorch only accepts integer/bool tensors as indices,
    so ``__call__`` looks likely to raise. Confirm the intended target (e.g.
    an embedding index).
    """
    def __init__(self, input_tensor1, mask):
        # Object used to index the model output, and the mask as a tensor
        # (moved to the GPU when CUDA is available).
        self.input_tensor1 = input_tensor1
        self.mask = torch.from_numpy(mask)
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()
        
    def __call__(self, model_output):
        # Prints, then returns, the masked sum of the indexed output.
        print((model_output[self.input_tensor1]* self.mask).sum())
        return (model_output[self.input_tensor1]* self.mask).sum()

# GradCAM on the encoder's layer4 of the contrastive model.
# NOTE(review): in the visible cell order, ``model`` was last rebound to a
# SegmentationModelOutputWrapper, which has no ``enc``. This cell appears to
# assume ``model`` is the ``Model()`` built in the ACDC setup cell — re-run
# that cell first, or confirm the intended model.
target_layers = [model.enc.layer4]
print(target_layers)
targets = [SemanticSegmentationTarget(input_tensor, mask)]
with GradCAM(model=model,
             target_layers=target_layers,
             use_cuda=torch.cuda.is_available()) as cam:
    grayscale_cam = cam(input_tensor=input_tensor,targets=targets)[0, :]
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
Image.fromarray(cam_image)