In [None]:
import argparse
import cv2
import numpy as np
import torch
from torchvision import models
from pytorch_grad_cam import GradCAM,ScoreCAM,GradCAMPlusPlus,AblationCAM,XGradCAM,EigenCAM,EigenGradCAM,LayerCAM,FullGrad,GradCAMElementWise
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image,deprocess_image,preprocess_image,preprocess_grayimage
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os,sys
import copy
from torch.utils.data.dataset import Dataset
from scipy.io import savemat
# Report library versions and whether a CUDA device is visible to this kernel.
for _title, _version in (("PyTorch Version: ", torch.__version__),
                         ("Torchvision Version: ", torchvision.__version__)):
    print(_title, _version)
torch.cuda.is_available()

In [None]:
from __future__ import print_function, division
import argparse
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import shutil
import math
import fnmatch
import nets
from collections import OrderedDict

from collections import defaultdict
from sklearn.model_selection import train_test_split
import random
from datetime import datetime
from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score
import seaborn as sns
from pytorch_model_summary import summary
from torch.nn.utils import weight_norm
from matplotlib import pyplot as plt

In [None]:
class CAE_2d_classify(nn.Module):
    """Convolutional encoder followed by an MLP classification head.

    The encoder applies four bias-free stride-2 convolutions (each followed by
    ReLU, and by BatchNorm for the first three). Its output is flattened,
    projected to ``latent_channels`` features and classified into
    ``num_clusters`` classes.

    Module attribute names (``encoder2d.conv2d1_1``, ``embedding2d``, ``fc1``,
    ``bn_fc1``, ``fc2`` ...) are the ``state_dict`` keys that checkpoints are
    loaded by, so they must not be renamed.

    Args:
        input_shape: ``(H, W, C)`` of the input images; ``forward`` expects
            tensors of shape ``(N, C, H, W)``.
        latent_channels: width of the embedding layer.
        num_clusters: number of output classes.
        filters: output channels of the four encoder convolutions.
        dropout: dropout probability applied before the final linear layer.
    """

    def __init__(self, input_shape=(400, 400, 1), latent_channels=512, num_clusters=2,
                 filters=(32, 64, 128, 256), dropout=0.1):
        super(CAE_2d_classify, self).__init__()
        self.pretrained = False
        self.latent_channels = latent_channels
        # Copy into fresh lists so the instance never aliases a caller's (or a
        # shared default) sequence.
        self.input_shape = list(input_shape)
        self.filters = list(filters)
        self.num_clusters = num_clusters

        self.encoder2d = nn.Sequential(OrderedDict([
            ('conv2d1_1', nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=False)),
            ('relu2d1_1', nn.ReLU()),
            ('bn2d1_1', nn.BatchNorm2d(filters[0])),
            ('conv2d2_1', nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=False)),
            ('relu2d2_1', nn.ReLU()),
            ('bn2d2_1', nn.BatchNorm2d(filters[1])),
            ('conv2d3_1', nn.Conv2d(filters[1], filters[2], 5, stride=2, padding=2, bias=False)),
            ('relu2d3_1', nn.ReLU()),
            ('bn2d3_1', nn.BatchNorm2d(filters[2])),
            ('conv2d4_1', nn.Conv2d(filters[2], filters[3], 3, stride=2, padding=0, bias=False)),
            ('relu2d4_1', nn.ReLU()),
        ]))

        # Derive the flattened encoder size from the actual conv layers. A
        # closed form such as (H // 2 // 2 // 2 - 1) // 2 floors where the
        # padded 5x5/stride-2 convs round up, so it is only exact when H and W
        # are multiples of 8 (e.g. H=100 gives 5, but the real size is 6).
        out_h, out_w = self.input_shape[0], self.input_shape[1]
        for layer in self.encoder2d:
            if isinstance(layer, nn.Conv2d):
                out_h = self._conv_out_size(out_h, layer.kernel_size[0], layer.stride[0], layer.padding[0])
                out_w = self._conv_out_size(out_w, layer.kernel_size[1], layer.stride[1], layer.padding[1])
        lin_features_len = out_h * out_w * filters[3]

        self.embedding2d = nn.Linear(lin_features_len, latent_channels, bias=False)
        self.fc1 = nn.Linear(latent_channels, 1024)
        self.bn_fc1 = nn.BatchNorm1d(1024)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(1024, num_clusters)
        self.relu = nn.ReLU()

    @staticmethod
    def _conv_out_size(size, kernel, stride, padding):
        """Output length of a dilation-1 convolution along one spatial axis."""
        return (size + 2 * padding - kernel) // stride + 1

    def forward(self, x):
        """Classify a batch ``x`` of shape ``(N, C, H, W)``.

        Returns:
            Log-probabilities of shape ``(N, num_clusters)`` (``log_softmax``
            over dim 1).
        """
        x = self.encoder2d(x)
        x = x.flatten(1)
        x = self.embedding2d(x)
        x = self.relu(x)
        x = self.fc1(x)
        x = self.bn_fc1(x)
        x = self.relu(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)

In [None]:
from skimage import io
import skimage.transform
import scipy
import scipy.io as sio
from torch.utils.data.dataset import Dataset

class MyDataset_mat(Dataset):
    """Dataset of 2-D images stored as MATLAB .mat files (variable ``data``).

    Each item is the image divided by 65535, as a float32 tensor with a
    leading singleton channel axis, together with its label as an int64
    tensor. Passing ``load_fl=True`` to ``__getitem__`` also loads a companion
    fluorescence (FL) image whose path is derived from the TM path by fixed
    character offsets.
    """

    def __init__(self, tm_paths, labels, transforms_tm=None, transforms_fl=None):
        self.paths_tm = tm_paths
        self.transforms_tm = transforms_tm
        self.transforms_fl = transforms_fl
        self.labels = labels

    @staticmethod
    def _load_scaled(mat_path):
        """Load ``data`` from ``mat_path`` as float32 / 65535 with shape (1, ...)."""
        # NOTE(review): dividing by 65535 presumably assumes uint16 source data.
        raw = sio.loadmat(mat_path)['data']
        return torch.from_numpy(raw.astype(np.float32) / 65535).unsqueeze(dim=0)

    def __getitem__(self, index, load_fl=False):
        tm_path = self.paths_tm[index]
        image = self._load_scaled(tm_path)

        target = torch.from_numpy(np.asarray(self.labels[index])).to(torch.int64)

        if self.transforms_tm:
            image = self.transforms_tm(image)

        if not load_fl:
            return image, target

        # The FL file lives in an 'FL/' sibling folder; the directory prefix and
        # the sample id are cut from the TM path at fixed offsets.
        fl_path = tm_path[:-23] + 'FL/' + tm_path[-19:-10] + '_2DFL.mat'
        fl_image = self._load_scaled(fl_path)
        if self.transforms_fl:
            fl_image = self.transforms_fl(fl_image)
        return image, fl_image, target

    def __len__(self):
        return len(self.paths_tm)
    
#Visualize image stacks
from mpl_toolkits.axes_grid1 import make_axes_locatable


def DisplayImage(img, LimMin, LimMax, title='2D Projection', n_slices=10, step=8):
    """Display a 2-D image, or a grid of slices from a 3-D stack, with a colorbar.

    If ``img`` is 3-D with more than one entry on axis 2, it is treated as a
    stack:
      * ``n_slices`` panels are laid out on two rows.
      * Panel ``k`` shows ``img[:, :, k * step]`` with the 'jet' colormap and
        is titled ``'z = k'``.
      * A single colorbar is attached to the last panel.
    Anything else is squeezed and drawn in grayscale, titled ``title``, with a
    colorbar.

    Args:
        img: numpy array of shape ``(H, W)``, ``(H, W, 1)`` or ``(H, W, D)``.
        LimMin: lower colour limit (imshow ``vmin``).
        LimMax: upper colour limit (imshow ``vmax``).
        title: title of the single-image plot (not used for stacks).
        n_slices: number of stack panels to draw (default 10).
        step: index stride between displayed slices (default 8).

    Raises:
        ValueError: if ``n_slices`` < 1 for a stack.
        IndexError: if the stack has fewer than ``(n_slices - 1) * step + 1``
            slices.
    """
    if img.ndim == 3 and img.shape[2] > 1:
        if n_slices < 1:
            raise ValueError('n_slices must be at least 1')
        last_slice = (n_slices - 1) * step
        if last_slice >= img.shape[2]:
            raise IndexError(
                f'stack has {img.shape[2]} slices; n_slices={n_slices} with '
                f'step={step} needs at least {last_slice + 1}')
        n_cols = (n_slices + 1) // 2
        for depth in range(n_slices):
            plt.subplot(2, n_cols, depth + 1)
            handle = plt.imshow(img[:, :, depth * step], vmin=LimMin, vmax=LimMax, cmap='jet')
            plt.title('z = ' + str(depth))
    else:
        handle = plt.imshow(np.squeeze(img), vmin=LimMin, vmax=LimMax, cmap='gray')
        plt.title(title)

    # One colorbar docked to the right of the last-drawn axes. All stack panels
    # share vmin/vmax, so a single bar covers them all.
    divider = make_axes_locatable(plt.gca())
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(handle, cax=cax)
    plt.show()
def DisplayandSaveImage(img, LimMin, LimMax, title='2D Projection', save_name='2D Projection.jpg',
                        n_slices=10, step=8):
    """Display an image or a stack of slices with a colorbar, save it, then show it.

    Plotting follows the same rules in both cases:
      * If ``img`` is 3-D with more than one entry on axis 2, ``n_slices``
        panels are drawn on two rows. Panel ``k`` shows ``img[:, :, k * step]``
        in 'jet' and is titled ``'z = k'``, and one colorbar is attached to the
        last panel.
      * Otherwise the squeezed image is drawn in grayscale, titled ``title``,
        with a colorbar.

    The figure is written to ``save_name`` at 200 dpi for both stacks and
    single images. Previously stacks were never saved.

    Args:
        img: numpy array of shape ``(H, W)``, ``(H, W, 1)`` or ``(H, W, D)``.
        LimMin: lower colour limit (imshow ``vmin``).
        LimMax: upper colour limit (imshow ``vmax``).
        title: title of the single-image plot (not used for stacks).
        save_name: output file path; ``None`` skips saving.
        n_slices: number of stack panels to draw (default 10).
        step: index stride between displayed slices (default 8).

    Raises:
        ValueError: if ``n_slices`` < 1 for a stack.
        IndexError: if the stack has fewer than ``(n_slices - 1) * step + 1``
            slices.
    """
    if img.ndim == 3 and img.shape[2] > 1:
        if n_slices < 1:
            raise ValueError('n_slices must be at least 1')
        last_slice = (n_slices - 1) * step
        if last_slice >= img.shape[2]:
            raise IndexError(
                f'stack has {img.shape[2]} slices; n_slices={n_slices} with '
                f'step={step} needs at least {last_slice + 1}')
        n_cols = (n_slices + 1) // 2
        for depth in range(n_slices):
            plt.subplot(2, n_cols, depth + 1)
            handle = plt.imshow(img[:, :, depth * step], vmin=LimMin, vmax=LimMax, cmap='jet')
            plt.title('z = ' + str(depth))
    else:
        handle = plt.imshow(np.squeeze(img), vmin=LimMin, vmax=LimMax, cmap='gray')
        plt.title(title)

    # One colorbar docked to the right of the last-drawn axes.
    divider = make_axes_locatable(plt.gca())
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(handle, cax=cax)
    # Save before show: with inline backends show() may release the figure.
    if save_name is not None:
        plt.savefig(save_name, dpi=200)
    plt.show()
        

In [None]:
import pandas as pd

# Per-sample manifest. Column roles are only what the variable names below
# suggest -- NOTE(review): confirm against whatever produced this CSV.
csv_path = './reports/PT_GroundTruthLabel&ClusterLabel.csv'
csv_data = pd.read_csv(csv_path)

sscpath = csv_data['name']
tmpath = csv_data['tmname']
cluster_label = csv_data['cluster']
truth_label = csv_data['label']

first_label = cluster_label[0]
print(first_label)
print('Type of label: ', type(first_label))

# Only the TM images are resized. transforms_fl is not set, so FL images keep
# their stored size. NOTE(review): confirm this is intended, since the FL mask
# is displayed next to the 400x400 CAM mask downstream.
transforms_tm = transforms.Resize([400, 400])
total_datasets = MyDataset_mat(tmpath, cluster_label, transforms_tm=transforms_tm)
print('Total data size: ', len(total_datasets))

In [None]:
def get_args():
    """Parse the CAM-visualisation command-line options.

    ``parse_known_args`` is used so unrecognised arguments are ignored rather
    than aborting. Presumably this exists because the notebook kernel's own
    argv reaches this parser.

    Options:
        --use-cuda / --no-cuda: request GPU use. GPU is requested by default;
            ``--no-cuda`` turns it off.
        --image-path: input image path (default './examples/both.png').
        --aug_smooth, --eigen_smooth: CAM smoothing flags (default off).
        --method: CAM algorithm name (default 'ablationcam').

    Returns:
        argparse.Namespace. ``use_cuda`` is True only when GPU use was
        requested and ``torch.cuda.is_available()``.
    """
    method_choices = ['gradcam', 'gradcam++', 'scorecam', 'xgradcam', 'ablationcam',
                      'eigencam', 'eigengradcam', 'layercam', 'fullgrad',
                      'gradcamelementwise']

    parser = argparse.ArgumentParser()
    # store_true with default=True can never be switched off on its own, so
    # --no-cuda writes False into the same destination.
    parser.add_argument('--use-cuda', dest='use_cuda', action='store_true', default=True,
                        help='Use NVIDIA GPU acceleration')
    parser.add_argument('--no-cuda', dest='use_cuda', action='store_false',
                        help='Disable NVIDIA GPU acceleration')
    parser.add_argument(
        '--image-path',
        type=str,
        default='./examples/both.png',
        help='Input image path')
    parser.add_argument('--aug_smooth', action='store_true', default=False,
                        help='Apply test time augmentation to smooth the CAM')
    parser.add_argument('--eigen_smooth', action='store_true', default=False,
                        help='Reduce noise by taking the first principal component of cam_weights*activations')
    parser.add_argument('--method', type=str, default='ablationcam',
                        choices=method_choices,
                        help='Can be ' + '/'.join(method_choices))

    args, _unknown = parser.parse_known_args()
    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    if args.use_cuda:
        print('Using GPU for acceleration')
    else:
        print('Using CPU for computation')

    return args

In [None]:
def Display_cam(input_img, cam_img, cam_mask, fl_mask, save_name=None):
    """Plot the input image, CAM map, CAM mask and FL mask side by side.

    Each of the four panels is drawn in grayscale, with its colour range
    stretched to that panel's own min/max.

    Args:
        input_img, cam_img, cam_mask, fl_mask: arrays accepted by matplotlib's
            ``imshow``.
        save_name: optional output path. When given, its parent directory is
            created if missing and the figure is saved at 300 dpi before it is
            shown.
    """
    import os

    panels = [
        (input_img, 'Input Image'),
        (cam_img, 'CAM Image'),
        (cam_mask, 'CAM Mask'),
        (fl_mask, 'FL Mask'),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(20, 4))
    for ax, (img, panel_title) in zip(axes, panels):
        ax.imshow(img, vmin=img.min(), vmax=img.max(), cmap='gray')
        ax.set_title(panel_title)

    if save_name is not None:
        parent = os.path.dirname(save_name)
        if parent:
            os.makedirs(parent, exist_ok=True)
        fig.savefig(save_name, dpi=300)
    plt.show()
    # Release the figure: this function is called once per sample in a loop.
    plt.close(fig)

In [None]:
args = get_args()
# Hard-coded override of the method chosen on the command line / its default.
args.method = 'gradcam'
methods = {"gradcam": GradCAM,
           "scorecam": ScoreCAM,
           "gradcam++": GradCAMPlusPlus,
           "ablationcam": AblationCAM,
           "xgradcam": XGradCAM,
           "eigencam": EigenCAM,
           "eigengradcam": EigenGradCAM,
           "layercam": LayerCAM,
           "fullgrad": FullGrad,
           "gradcamelementwise": GradCAMElementWise}

basedir = 'nets/ckpt_CAE_2d_classify_001/'
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine. The CAM wrapper receives use_cuda and handles device placement.
scratch_model = torch.load(basedir + 'Classify_Trained_001_latest.pth.tar', map_location='cpu')
img_size = [400, 400, 1]
latent_channels = 512
model_name = 'CAE_2d_classify'
CellName = ['Translocated', 'Un-Translocated']
# Fraction of each image's own maximum at or above which a pixel is foreground.
threshold_cam = 0.15
threshold_fl = 0.3


def binarize(img, frac):
    """Return a float32 0/1 mask of pixels ``>= frac * img.max()``.

    An image whose maximum is <= 0 (e.g. an all-zero CAM) gives an all-zero
    mask. Thresholding at ``frac * 0`` would otherwise mark every pixel.
    """
    peak = img.max()
    if peak <= 0:
        return np.zeros(img.shape, dtype=np.float32)
    return (img >= frac * peak).astype(np.float32)


def to_hw1(arr):
    """Squeeze ``arr`` (assumed 2-D after squeezing) and append a trailing axis -> (H, W, 1)."""
    return np.squeeze(arr)[..., np.newaxis]


def mask_path_for(tm_path):
    """Build the CAM-mask .mat path for a TM image path.

    Uses fixed character offsets into the path, the same ones MyDataset_mat
    uses to locate the FL image. It therefore only works for this dataset's
    file-naming scheme.
    """
    return tm_path[:-23] + 'Mask/' + tm_path[-19:-10] + '_2DMask.mat'


# The model, target layer and CAM hooks do not change between samples, so they
# are built once instead of once per image.
model = globals()[model_name]()
model.load_state_dict(scratch_model['best_model_wts'])
model.eval()
target_layers = [model.encoder2d.conv2d4_1]
cam_algorithm = methods[args.method]

# Per-sample records for the mask manifest written by the next cell.
name_ssc, name_tm, name_mask = [], [], []
label_cluster, label_truth = [], []

# targets=None makes the CAM explain the highest-scoring class for each image.
# A specific class can be targeted with e.g. [ClassifierOutputTarget(k)].
targets = None

with cam_algorithm(model=model,
                   target_layers=target_layers,
                   use_cuda=args.use_cuda) as cam:
    # AblationCAM and ScoreCAM have batched implementations; this sets their
    # internal batch size.
    cam.batch_size = 32

    for ii in range(len(total_datasets)):
        ImgTensor, FLTensor, ImgLabel = total_datasets.__getitem__(ii, load_fl=True)
        input_tensor = ImgTensor.unsqueeze(0)

        grayscale_cam = cam(input_tensor=input_tensor,
                            targets=targets,
                            aug_smooth=args.aug_smooth,
                            eigen_smooth=args.eigen_smooth)
        # Only one image in the batch.
        cam_mask = grayscale_cam[0, :]

        cam_mask_thresholded = to_hw1(binarize(cam_mask, threshold_cam))
        fl_mask_thresholded = to_hw1(binarize(FLTensor.numpy(), threshold_fl))
        input_img = to_hw1(ImgTensor.numpy())
        cam_img = to_hw1(cam_mask)

        figure_path = f'Images_Resize/{CellName[int(ImgLabel)]}/{ii}_{args.method}.jpg'
        os.makedirs(os.path.dirname(figure_path), exist_ok=True)
        Display_cam(input_img, cam_img, cam_mask_thresholded, fl_mask_thresholded,
                    save_name=figure_path)

        maskname = mask_path_for(tmpath[ii])
        os.makedirs(os.path.dirname(maskname), exist_ok=True)
        savemat(maskname, {'data': cam_mask_thresholded})

        print('\nCluster label:', CellName[cluster_label[ii]])
        print('Ground Truth label:', CellName[truth_label[ii]])

        name_ssc.append(sscpath[ii])
        name_tm.append(tmpath[ii])
        name_mask.append(maskname)
        label_cluster.append(cluster_label[ii])
        label_truth.append(truth_label[ii])

In [None]:
import csv

# Mask manifest: one row per processed sample, columns in this order.
my_dict = {
    'name': name_ssc,
    'tmname': name_tm,
    'maskname': name_mask,
    'label': label_truth,
    'cluster': label_cluster,
}
headers = my_dict.keys()
with open('./reports/PT_mask.csv', 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=list(headers))
    writer.writeheader()
    writer.writerows(dict(zip(headers, row)) for row in zip(*my_dict.values()))