In [1]:
import torchvision
import torch

import numpy as np
import matplotlib.pyplot as plt
import cv2
import monai

from sklearn import decomposition
from sklearn import manifold
from sklearn.cluster import KMeans

## Check 2D image inputs

In [2]:
def check_input(iterator, batch_size):
    for data in iterator:
        input = data['img']
        label = data['label']
        
        plt.imshow(torchvision.utils.make_grid(input, normalize=True).permute(1,2,0))
        print(''.join(f'{label[i]} 'for i in range(batch_size)))
        plt.show()

## Hook

In [None]:
class Hook():
    """Capture what a PyTorch hook receives from a module.

    After the hooked module runs, ``self.input`` holds the hook's second
    argument and ``self.features`` its third. For a forward hook these are the
    module's positional inputs (a tuple) and its output. For a backward hook
    they are the gradients PyTorch passes as ``grad_input`` / ``grad_output``.
    Both are ``None`` until the hook has fired at least once.

    Can be used as a context manager so the hook is always removed::

        with Hook(model.layer4) as hook:
            model(x)
        feats = hook.features

    Args:
        module: the ``torch.nn.Module`` to attach to.
        backward: if False (default) register a forward hook, otherwise a
            backward hook.
    """

    def __init__(self, module, backward = False):
        # Defined up front so reading them before the first pass gives None
        # instead of AttributeError.
        self.input = None
        self.features = None
        if not backward:
            self.hook = module.register_forward_hook(self.hook_fn)
        else:
            # NOTE(review): register_backward_hook is deprecated in recent
            # PyTorch in favour of register_full_backward_hook; the full
            # variant can reject in-place ops applied to the hooked outputs,
            # so switching should be checked against the target model.
            self.hook = module.register_backward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Store the hook arguments; returning None leaves them unmodified."""
        self.input = input
        self.features = output

    def close(self):
        """Remove the registered hook from the module."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

## Feature extraction

In [3]:
def get_representations(model, iterator, device, hook = None, flatten = False):
    '''Run `model` over `iterator` and collect its outputs (or hooked features) and labels.

    The model is switched to eval mode and run under ``torch.no_grad()``.

    Args:
        model: the network to evaluate.
        iterator: iterable of dict batches with an ``'img'`` input tensor and
            a ``'label'`` entry (tensor or list of numbers).
        device: device the inputs are moved to before the forward pass.
        hook: optional object with a ``features`` attribute refreshed on every
            forward pass (e.g. a ``Hook``). When given, its ``features`` are
            collected instead of the model output.
        flatten: if True, flatten each sample to 1-D so the result is
            ``(N, D)``. PCA / t-SNE / K-Means need this when the hooked layer
            returns feature maps.

    Returns:
        ``(outputs, labels)``: CPU tensors concatenated along dim 0.

    Raises:
        RuntimeError: if `iterator` yields no batches.

    Example::

        hook_model = Hook(model.layer4)
        outputs, labels = get_representations(model, test_loader, device,
                                              hook = hook_model, flatten = True)
    '''
    model.eval()

    outputs = []
    labels = []

    with torch.no_grad():
        for data in iterator:
            images = data['img'].to(device)
            output = model(images)

            if hook is not None:
                # The hook captured the target layer's output during the forward pass above.
                output = hook.features
            if flatten:
                output = output.flatten(start_dim = 1)

            outputs.append(output.cpu())
            # as_tensor lets plain lists of numeric labels be concatenated as well.
            labels.append(torch.as_tensor(data['label']))

    if not outputs:
        raise RuntimeError('get_representations: iterator yielded no batches')

    return torch.cat(outputs, dim = 0), torch.cat(labels, dim = 0)

## PCA

In [4]:
def get_pca(data, n_components = 2):
    '''Project `data` onto its first `n_components` principal components.

    Args:
        data: array-like or CPU tensor of shape (N, ...). Trailing dimensions
            are flattened, so (N, C, H, W) feature maps are accepted.
        n_components: number of principal components to keep.

    Returns:
        numpy array of shape (N, n_components).
    '''
    # PCA needs a 2-D (samples, features) matrix.
    matrix = np.asarray(data).reshape(len(data), -1)
    # random_state makes the randomized SVD solver (picked automatically for
    # large inputs) reproducible, like the fixed seeds in get_tsne / plot_kmeans.
    pca = decomposition.PCA(n_components = n_components, random_state = 0)
    return pca.fit_transform(matrix)

## t-SNE

In [6]:
def get_tsne(data, n_components = 2):
    '''Embed `data` into `n_components` dimensions with t-SNE (random_state=0).

    Args:
        data: array-like or CPU tensor of shape (N, ...). Trailing dimensions
            are flattened, so (N, C, H, W) feature maps are accepted.
        n_components: output dimensionality.

    Returns:
        numpy array of shape (N, n_components).

    NOTE: scikit-learn requires the t-SNE perplexity (default 30) to be
    smaller than N, so very small inputs will be rejected.
    '''
    # t-SNE needs a 2-D (samples, features) matrix.
    matrix = np.asarray(data).reshape(len(data), -1)
    tsne = manifold.TSNE(n_components = n_components, random_state = 0)
    return tsne.fit_transform(matrix)

## Plot representations (PCA, t-SNE)

In [7]:
def plot_representations(data, labels, n_images = None, cmap = 'hsv'):
    '''Scatter the first two columns of `data`, coloured by `labels`, with a legend.

    Example::

        outputs, labels = get_representations(model, test_loader, device)

        output_pca = get_pca(outputs)
        plot_representations(output_pca, labels)

        output_tsne = get_tsne(outputs)
        plot_representations(output_tsne, labels)

    Args:
        data: 2-D array; only columns 0 and 1 are plotted.
        labels: length-N numeric labels used for the point colours.
        n_images: if given, plot only the first `n_images` points.
        cmap: matplotlib colormap name. The default 'hsv' is cyclic, so the
            lowest and highest labels get similar (red) colours; a qualitative
            map such as 'tab10' separates classes more clearly.
    '''
    if n_images is not None:
        data = data[:n_images]
        labels = labels[:n_images]

    fig, ax = plt.subplots(figsize = (15, 15))
    scatter = ax.scatter(data[:, 0], data[:, 1], c = labels, cmap = cmap)
    # Without a legend the colours cannot be mapped back to labels.
    handles, legend_labels = scatter.legend_elements()
    ax.legend(handles, legend_labels, title = 'label')
    ax.set_xlabel('component 1')
    ax.set_ylabel('component 2')



## K-Means Clustering

In [None]:
def plot_kmeans(data, n_cluster, n_images = None):
    '''Cluster `data` with K-Means and scatter dimensions 0 and 1 coloured by cluster.

    K-Means is fit on all (flattened) feature dimensions, but only dimensions
    0 and 1 of the points and centroids are drawn. Passing a 2-D projection
    (e.g. the output of get_pca) makes the plot show the space that was
    clustered.

    Example::

        outputs, labels = get_representations(model, test_loader, device)
        plot_kmeans(outputs, n_cluster = 5)

    Args:
        data: array-like or CPU tensor of shape (N, ...); trailing dims are
            flattened and at least two features are required for plotting.
        n_cluster: number of clusters.
        n_images: if given, cluster only the first `n_images` samples.
    '''
    if n_images is not None:
        data = data[:n_images]
    # numpy array so the boolean cluster masks below index rows reliably;
    # reshape turns (N, C, H, W) feature maps into the (N, D) matrix KMeans needs.
    points = np.asarray(data).reshape(len(data), -1)

    kmeans = KMeans(n_clusters = n_cluster, random_state = 0)
    k_labels = kmeans.fit_predict(points)
    centroids = kmeans.cluster_centers_

    fig, ax = plt.subplots(figsize = (15, 15))
    # Colours come from matplotlib's default property cycle (a cmap has no
    # effect here because no per-point `c` values are given).
    for cluster in np.unique(k_labels):
        members = points[k_labels == cluster]
        ax.scatter(members[:, 0], members[:, 1], label = cluster, s = 80)
    ax.scatter(centroids[:, 0], centroids[:, 1], s = 90, color = 'k', marker = 'x', label = 'centroid')
    ax.legend(title = 'cluster')

## Easy & difficult cases: Grad-CAM visualization

In [None]:
def test_easy_difficult(iterator, model, target_layer, device):
    '''Show each image next to its Grad-CAM overlay, titled with GT and predicted label.

    Example::

        test_easy_difficult(test_loader, model, target_layer = 'layer4', device = device)

    Args:
        iterator: iterable of dict batches with an ``'img'`` tensor and a
            ``'label'`` entry.
        model: classifier; its output is assumed to be (batch, n_classes)
            logits (argmax over dim 1 is taken as the prediction).
        target_layer: layer name handed to ``monai.visualize.GradCAM``.
        device: device the images are moved to.

    NOTE(review): despite the ``test_`` prefix this is not a unit test; pytest
    would try to collect it if this code moved into a test module.
    '''
    # Inference-mode BatchNorm/Dropout so predictions are deterministic.
    model.eval()
    cam = monai.visualize.GradCAM(nn_module = model, target_layers = target_layer)

    for data in iterator:
        images = data['img'].to(device)
        label = data['label']

        # The prediction needs only a forward pass; no_grad avoids building a graph.
        with torch.no_grad():
            _, pred = torch.max(model(images), dim = 1)  # (max, argmax)

        # Grad-CAM needs gradients, so it runs outside no_grad.
        # (B, C, H, W) -> (B, H, W, C) for per-image colour mapping.
        cam_result = cam(images.float()).permute(0, 2, 3, 1)

        for i in range(cam_result.shape[0]):
            heat = np.uint8(cam_result[i].detach().cpu().numpy() * 255)
            # applyColorMap returns BGR; reverse the channel axis so matplotlib
            # (which expects RGB) shows high activation as red, not blue.
            cam_show = cv2.applyColorMap(heat, cv2.COLORMAP_JET)[..., ::-1] / 255
            image = images[i].detach().cpu().numpy().transpose(1, 2, 0)

            fig, (ax1, ax2) = plt.subplots(1, 2)
            ax1.imshow(image, 'gray')
            ax1.axis('off')

            ax2.imshow(image, 'gray')
            ax2.imshow(cam_show, alpha = 0.3)
            ax2.axis("off")

            fig.suptitle(f'GT : {int(label[i])}  Pred : {int(pred[i])}')
            fig.tight_layout()
            # Display each figure right away instead of keeping every figure
            # for the whole loader open until a single final show().
            plt.show()