In [None]:
import numpy as np
import glob
import cv2
import os
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import jaccard_score

import torch
import torch.nn as nn

from multiprocessing import Pool

from tqdm import tqdm

from typing import *

In [None]:
# Experiment names. They are used as path components of the result folders and as
# keys of the aggregated result dictionaries.
# The two 'SW-' experiments are the ones evaluated with the 'swapped' class lists of
# OTHER_CLASSES_INFO (see collect_data).
# NOTE(review): '-WO' presumably means "without attention" -- confirm with the training code.
ATT_EXP = 'AttentionV11'
WO_ATT_EXP = 'AttentionV11-WO'
SW_ATT_EXP = 'SW-AttentionV11'
SW_WO_ATT_EXP = 'SW-AttentionV11-WO'

# OTHER_CLASSES_INFO[kfold][variant][split] -> list of class ids of the original
# (ground-truth) mask that are compared against the predicted "other" class.
#   kfold   : 'kfold0' | 'kfold1' | 'kfold2'
#   variant : 'origin' | 'swapped' -- in every fold the 'swapped' variant has the
#             'val' and 'test' lists of 'origin' exchanged
#   split   : 'val' | 'test'
OTHER_CLASSES_INFO = {
    'kfold0': {
        'origin': {
            'val': [7, 8, 9],
            'test': [4, 11, 12, 13, 14, 16]
        },
        'swapped': {
            'val': [4, 11, 12, 13, 14, 16],
            'test': [7, 8, 9],
        }
    },
    'kfold1': {
        'origin': {
            'val': [4, 5, 6, 16],
            'test': [1, 3, 8, 9, 10, 11, 15]
        },
        'swapped': {
            'val': [1, 3, 8, 9, 10, 11, 15],
            'test': [4, 5, 6, 16]
        }
    },
    'kfold2': {
        'origin': {
            'val': [8, 10, 12, 14],
            'test': [5, 6, 9, 11, 15, 16]
        },
        'swapped': {
            'val': [5, 6, 9, 11, 15, 16],
            'test': [8, 10, 12, 14],
        }
    },
}

In [None]:
def calculate_iou(pred_list, target_list, num_classes):
    """
    Compute per-class IoU (Jaccard score) for every prediction/target pair.

    Parameters
    ----------
    pred_list : iterable of torch.Tensor
        Predicted class-index tensors (they are passed to ``one_hot``).
    target_list : iterable of torch.Tensor
        Ground-truth class-index tensors, paired with ``pred_list`` by position.
    num_classes : int
        Number of classes used for the one-hot encoding.

    Returns
    -------
    np.ndarray
        The per-pair results of ``jaccard_score(..., average=None)`` stacked along
        a new first axis, i.e. one row per pair.
    """
    def _flat_one_hot(labels):
        # One-hot encode and flatten every pixel into a row of `num_classes` indicators.
        encoded = torch.nn.functional.one_hot(labels, num_classes)
        return encoded.view(-1, num_classes).cpu()

    per_pair_scores = []
    for sample_pred, sample_target in zip(pred_list, target_list):
        pred_flat = _flat_one_hot(sample_pred)
        target_flat = _flat_one_hot(sample_target)
        # zero_division=1 is forwarded to scikit-learn for classes with an empty union.
        per_pair_scores.append(
            jaccard_score(target_flat, pred_flat, average=None, zero_division=1)
        )
    return np.stack(per_pair_scores)


def dice_loss(preds, ground_truth, eps=1e-5, dim=None, use_softmax=False, softmax_dim=1):
    """
    Computes Dice loss according to the formula from:
    V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
    Link to the paper: http://campar.in.tum.de/pub/milletari2016Vnet/milletari2016Vnet.pdf

    ``loss = 1 - 2 * sum(p * g) / (sum(p ** 2) + sum(g ** 2) + eps)``

    Parameters
    ----------
    preds : torch.Tensor
        Predicted probabilities (or raw scores when ``use_softmax`` is True).
        Must be broadcastable against ``ground_truth``. Integer or boolean
        tensors (e.g. binary masks) are cast to float.
    ground_truth : torch.Tensor
        Ground truth labels. Cast to float and moved to the device of ``preds``.
    eps : float
        Added to the Dice denominator to prevent division by zero.
    dim : int, sequence of int or None
        Dimensions that are summed over. One loss value is returned per position
        of the remaining dimensions. If None, all elements are summed and a
        scalar is returned.
    use_softmax : bool
        If True, ``preds`` is passed through a softmax first.
    softmax_dim : int
        Dimension the softmax is applied over (only used when ``use_softmax``).

    Returns
    -------
    torch.Tensor
        ``1 - dice`` (0 for a perfect match, close to 1 for no overlap).
    """
    ground_truth = ground_truth.float().to(device=preds.device)

    # Softmax is not implemented for integer tensors, so binary int masks are
    # promoted first. Floating/complex inputs keep their dtype.
    if not (preds.is_floating_point() or preds.is_complex()):
        preds = preds.float()

    if use_softmax:
        preds = nn.functional.softmax(preds, dim=softmax_dim)

    def _reduce(values):
        # dim=None means "sum everything"; spelled out so it does not depend on
        # how a given torch version treats an explicit dim=None argument.
        if dim is None:
            return torch.sum(values)
        return torch.sum(values, dim=dim)

    numerator = _reduce(preds * ground_truth)
    p_squared = _reduce(torch.square(preds))
    # Ground truth is squared as well, which keeps the formula valid for soft labels.
    g_squared = _reduce(torch.square(ground_truth))
    denominator = p_squared + g_squared + eps

    dice = 2 * numerator / denominator
    return 1 - dice

#def pred2origin_subclasses(origin: np.ndarray, pred: np.ndarray):
    

In [None]:
# Axes of the evaluation grid; every combination of one entry from each list
# becomes one ParamsData job in params_list below.
kfolds_list = ['kfold0', 'kfold1', 'kfold2']
cur_exp_list = [ATT_EXP, SW_ATT_EXP, SW_WO_ATT_EXP, WO_ATT_EXP]
run_list = ['R1', 'R2', 'R3', 'R4']
# Dataset splits; also the last path component of a result folder.
type_indices_list = ['test', 'val']

class ParamsData:
    """Plain attribute holder describing one evaluation job (fold, experiment, run, split)."""
    # Fold name, e.g. 'kfold0'; used as a key into OTHER_CLASSES_INFO.
    kfold: str
    # Experiment name (one of the *_EXP constants).
    cur_exp: str
    # Run name, e.g. 'R1'.
    run: str
    # Dataset split: 'val' or 'test'.
    type_indices: str

# One ParamsData per (fold, experiment, run, split) combination, in nested-loop order.
params_list = []
for kfold in kfolds_list:
    for cur_exp in cur_exp_list:
        for run in run_list:
            for type_indices in type_indices_list:
                params_data = ParamsData()
                params_data.kfold, params_data.cur_exp = kfold, cur_exp
                params_data.run, params_data.type_indices = run, type_indices
                params_list += [params_data]

In [None]:
def collect_data(
    single_params,
    results_root='/raid/rustam/hyperspectral_dataset/'
                 'diff_exp_with_other__attention_with_other/result_masks_0',
):
    """
    Evaluate how well the predicted "other" class covers each original sub-class.

    For every mask file of the selected fold/experiment/run/split, and for every
    sub-class listed in ``OTHER_CLASSES_INFO`` that occurs in the original mask,
    the binary mask ``origin == sub_class`` is compared with the binary mask
    ``pred == other_indx`` using IoU and Dice loss.

    Parameters
    ----------
    single_params : ParamsData
        Provides ``kfold``, ``cur_exp``, ``run`` and ``type_indices``.
    results_root : str
        Folder containing ``<kfold>/<cur_exp>/<run>/<type_indices>/{origin,pred}``
        with ``.npy`` masks stored under the same file name in both sub-folders.

    Returns
    -------
    (dict, ParamsData)
        ``{'dice': {class_id: [values]}, 'iou': {class_id: [values]}}`` with
        string class ids, and the unchanged ``single_params`` so results can be
        matched to their job after parallel execution.

    Notes
    -----
    * The values stored under ``'dice'`` come from ``dice_loss`` and are therefore
      ``1 - Dice`` (lower is better).
    * NOTE(review): the prediction mask contains every pixel predicted as "other",
      so pixels belonging to a different "other" sub-class count as false
      positives for the sub-class under evaluation -- confirm this is intended.
    """
    # Index of the "other" class in the predicted masks.
    other_indx = 6 if single_params.kfold == 'kfold1' else 7

    variant = 'swapped' if single_params.cur_exp in [SW_WO_ATT_EXP, SW_ATT_EXP] else 'origin'
    other_classes = OTHER_CLASSES_INFO[single_params.kfold][variant][single_params.type_indices]

    result_dict = {
        'dice': {str(cl): [] for cl in other_classes},
        'iou': {str(cl): [] for cl in other_classes},
    }

    path_to_results = (
        f'{results_root}/{single_params.kfold}/{single_params.cur_exp}/'
        f'{single_params.run}/{single_params.type_indices}'
    )
    path_to_origin = f'{path_to_results}/origin'
    path_to_pred = f'{path_to_results}/pred'

    # Sorted so the order of the collected values does not depend on the file system.
    for single_origin_path in sorted(glob.glob(f'{path_to_origin}/*')):
        file_name = os.path.basename(single_origin_path)
        origin = np.squeeze(np.load(single_origin_path))
        pred = np.squeeze(np.load(f'{path_to_pred}/{file_name}'))

        # The prediction mask is the same for every sub-class of this file.
        pred_mask_t = torch.from_numpy((pred == other_indx).astype(np.int64))

        for other_subclass in other_classes:
            if other_subclass not in origin:
                continue

            origin_mask_t = torch.from_numpy((origin == other_subclass).astype(np.int64))

            # Single pair, two classes -> array of shape (1, 2); keep the foreground score.
            iou = calculate_iou([pred_mask_t], [origin_mask_t], 2)
            result_dict['iou'][str(other_subclass)].append(float(iou[-1, -1]))

            # (N, H, W, C) --> (N, C, H, W)
            origin_one_hotted_tensor = torch.nn.functional.one_hot(
                origin_mask_t.unsqueeze(0), 2  # Num classes
            ).permute(0, -1, 1, 2)
            dice = dice_loss(
                pred_mask_t.unsqueeze(0).unsqueeze(0), origin_one_hotted_tensor,
                dim=[0, 2, 3], use_softmax=False, softmax_dim=1,
            )
            # Last channel == foreground (sub-class present).
            result_dict['dice'][str(other_subclass)].append(float(dice[-1]))
    return result_dict, single_params

In [None]:
# Evaluate every parameter combination in worker processes. imap yields the
# results in the order of params_list; each item is a (result_dict, params) tuple.
with Pool(16) as p:
    result_dict_list: List[Tuple[dict, ParamsData]] = list(
        tqdm(p.imap(collect_data, params_list), total=len(params_list))
    )

# Regroup into all_result_dict[kfold][cur_exp][run][type_indices] -> result_dict.
all_result_dict = {}
for result_dict, params in result_dict_list:
    per_run_results = (
        all_result_dict
        .setdefault(params.kfold, {})
        .setdefault(params.cur_exp, {})
        .setdefault(params.run, {})
    )
    per_run_results[params.type_indices] = result_dict

In [None]:
def draw_single_figure(net_result_dict, folder_save):
    """
    Plot per-class metric bars for one experiment and save them as PNG files.

    Parameters
    ----------
    net_result_dict : dict
        ``{run: {type_indx: {'dice': {class: [values]}, 'iou': {class: [values]}}}}``.
    folder_save : str
        Output folder (created if missing). Receives ``dice.png`` and ``iou.png``.

    Notes
    -----
    * Each figure has one facet per ``type_indx``, one bar group per run plus a
      ``'mean'`` group (average of the per-run class means), coloured by class.
    * A class with no collected values is reported as 1.0 for ``'dice'`` and 0.0
      for ``'iou'``.
    * The classes of the ``'mean'`` group are taken from run ``'R1'`` when it
      exists, otherwise from the first run.
    * Nothing is drawn when ``net_result_dict`` is empty.
    """
    os.makedirs(folder_save, exist_ok=True)
    if not net_result_dict:
        return

    # Value reported for a class without any collected values.
    empty_fallback = {'dice': 1.0, 'iou': 0.0}

    def _mean(values, metric):
        if len(values) == 0:
            return empty_fallback[metric]
        return np.asarray(values).mean()

    # Run whose class keys drive the 'mean' rows.
    reference_run = net_result_dict.get('R1', next(iter(net_result_dict.values())))
    mean_type_indices = [t for t in ('val', 'test') if t in reference_run]

    for metric in ('dice', 'iou'):
        rows = {'run': [], 'type_indx': [], 'class': [], 'metric_result': []}

        # One row per (run, split, class).
        for run_k, res_per_run in net_result_dict.items():
            for type_indx, res_per_type in res_per_run.items():
                for class_indx, values in res_per_type[metric].items():
                    rows['run'].append(run_k)
                    rows['type_indx'].append(type_indx)
                    rows['class'].append(str(class_indx))
                    rows['metric_result'].append(_mean(values, metric))

        # 'mean' rows: average of the per-run means for every (split, class).
        for type_indx in mean_type_indices:
            for class_indx in reference_run[type_indx][metric].keys():
                per_run_means = [
                    _mean(res_per_run[type_indx][metric][str(class_indx)], metric)
                    for res_per_run in net_result_dict.values()
                ]
                rows['run'].append('mean')
                rows['type_indx'].append(type_indx)
                rows['class'].append(str(class_indx))
                rows['metric_result'].append(_mean(per_run_means, metric))

        plt.rcParams.update({'font.size': 22})
        g = sns.catplot(
            x='run', y='metric_result', hue='class',
            col='type_indx', data=pd.DataFrame.from_dict(rows), kind='bar',
            height=12, aspect=.9
        )
        g.savefig(f'{folder_save}/{metric}.png')
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(g.figure)

In [None]:
# Per-class figures for every (fold, experiment) pair, saved under res_<kfold>/<cur_exp>.
for kfold, kfold_dict in all_result_dict.items():
    for cur_exp, net_result_dict in kfold_dict.items():
        folder_save = f'res_{kfold}/{cur_exp}'
        draw_single_figure(net_result_dict, folder_save)

In [None]:
def draw_single_figure_mean(net_result_dict, folder_save):
    """
    Plot class-averaged metric bars for one experiment and save them as PNG files.

    Parameters
    ----------
    net_result_dict : dict
        ``{run: {type_indx: {'dice': {class: [values]}, 'iou': {class: [values]}}}}``.
    folder_save : str
        Output folder (created if missing). Receives ``dice_mean.png`` and
        ``iou_mean.png``.

    Notes
    -----
    * Each run contributes one bar per ``type_indx``: the mean over classes of the
      per-class means. The extra ``'mean'`` bar averages, over classes, the
      per-class mean across runs.
    * A class with no collected values is reported as 1.0 for ``'dice'`` and 0.0
      for ``'iou'``.
    * The classes of the ``'mean'`` bar are taken from run ``'R1'`` when it
      exists, otherwise from the first run.
    * Nothing is drawn when ``net_result_dict`` is empty.
    """
    os.makedirs(folder_save, exist_ok=True)
    if not net_result_dict:
        return

    # Value reported for a class without any collected values.
    empty_fallback = {'dice': 1.0, 'iou': 0.0}

    def _mean(values, metric):
        if len(values) == 0:
            return empty_fallback[metric]
        return np.asarray(values).mean()

    # Run whose class keys drive the 'mean' rows.
    reference_run = net_result_dict.get('R1', next(iter(net_result_dict.values())))
    mean_type_indices = [t for t in ('val', 'test') if t in reference_run]

    for metric in ('dice', 'iou'):
        rows = {'run': [], 'type_indx': [], 'metric_result': []}

        # One row per (run, split): mean over classes of the per-class means.
        for run_k, res_per_run in net_result_dict.items():
            for type_indx, res_per_type in res_per_run.items():
                per_class_means = [
                    _mean(values, metric) for values in res_per_type[metric].values()
                ]
                rows['run'].append(run_k)
                rows['type_indx'].append(type_indx)
                rows['metric_result'].append(_mean(per_class_means, metric))

        # 'mean' rows: for every class average across runs, then average over classes.
        for type_indx in mean_type_indices:
            per_class_run_means = [
                _mean(
                    [
                        _mean(res_per_run[type_indx][metric][str(class_indx)], metric)
                        for res_per_run in net_result_dict.values()
                    ],
                    metric,
                )
                for class_indx in reference_run[type_indx][metric].keys()
            ]
            rows['run'].append('mean')
            rows['type_indx'].append(type_indx)
            rows['metric_result'].append(_mean(per_class_run_means, metric))

        plt.rcParams.update({'font.size': 22})
        g = sns.catplot(
            x='run', y='metric_result',
            col='type_indx', data=pd.DataFrame.from_dict(rows), kind='bar',
            height=12, aspect=.9
        )
        g.savefig(f'{folder_save}/{metric}_mean.png')
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(g.figure)

In [None]:
# Class-averaged figures for every (fold, experiment) pair, saved under res_<kfold>/<cur_exp>.
for kfold, kfold_dict in all_result_dict.items():
    for cur_exp, net_result_dict in kfold_dict.items():
        folder_save = f'res_{kfold}/{cur_exp}'
        draw_single_figure_mean(net_result_dict, folder_save)

In [None]:
# Inspect which folds were aggregated.
all_result_dict.keys()

In [None]:
# Select one fold for manual inspection. `kfold` is assigned explicitly so that
# paths built from it later (e.g. the scratch output folder) name the fold that
# is actually being inspected, not whatever value an earlier loop left behind.
kfold = 'kfold1'
kfold_dict = all_result_dict[kfold]

In [None]:
# Experiments available for the selected fold.
kfold_dict.keys()

In [None]:
# Experiment to inspect; uses the shared constant instead of repeating its literal value.
cur_exp = ATT_EXP

In [None]:
# Results of the selected experiment: {run: {type_indx: {'dice': ..., 'iou': ...}}}.
net_result_dict = kfold_dict[cur_exp]
# Scratch output folder for the figures drawn below.
# NOTE(review): `kfold` is not assigned in this cell; make sure it names the same
# fold that `kfold_dict` was taken from.
folder_save = f'test_{kfold}/{cur_exp}'

In [None]:
# Number of collected Dice-loss values per class for run R1 on the validation split.
[(k, len(v)) for (k, v) in net_result_dict['R1']['val']['dice'].items()]

In [None]:
# Raw Dice-loss values collected for class '16' (run R1, validation split).
net_result_dict['R1']['val']['dice']['16']

In [None]:
# Scratch re-run of the class-averaged figures for the experiment selected above.
# This cell used to be an inline copy of draw_single_figure_mean's body; calling
# the function keeps the two from drifting apart.
draw_single_figure_mean(net_result_dict, folder_save)

In [None]:
!rm -r test_kfold2/

In [None]:
# Take the first job of the grid to debug the evaluation outside the process pool.
single_params = params_list[0]
single_params.kfold, single_params.type_indices, single_params.cur_exp

In [None]:
# Evaluate the selected job in-process for debugging. This cell used to be an
# inline copy of collect_data's body; calling the function keeps the two from
# drifting apart.
result_dict = collect_data(single_params)[0]

In [None]:
# Number of IoU values collected per class for the debugged job.
[len(result_dict['iou'][k]) for k in result_dict['iou'].keys()]

In [None]:
# Build the long-format per-class tables (one row per run/split/class plus 'mean'
# rows across runs) for a single experiment.
result_dice_dict = {'run': [], 'type_indx': [], 'class': [], 'metric_result': []}
result_iou_dict = {'run': [], 'type_indx': [], 'class': [], 'metric_result': []}

net_result_dict = all_result_dict['kfold0'][ATT_EXP]

# Same fallbacks as draw_single_figure: a class without collected values is
# reported as 1.0 for Dice loss and 0.0 for IoU instead of producing NaN
# (np.mean of an empty array) and a RuntimeWarning.
get_dice_mean = lambda x: 1.0 if len(x) == 0 else np.asarray(x).mean()
get_iou_mean = lambda x: 0.0 if len(x) == 0 else np.asarray(x).mean()

for metric_name, metric_rows, get_metric_mean in (
    ('dice', result_dice_dict, get_dice_mean),
    ('iou', result_iou_dict, get_iou_mean),
):
    # Per-run rows.
    for run_k, res_per_run_dict in net_result_dict.items():
        for type_indx, res_per_type_indx_dict in res_per_run_dict.items():
            for class_indx, class_values in res_per_type_indx_dict[metric_name].items():
                metric_rows['run'].append(run_k)
                metric_rows['type_indx'].append(type_indx)
                metric_rows['class'].append(str(class_indx))
                metric_rows['metric_result'].append(get_metric_mean(class_values))

    # 'mean' rows: average of the per-run means; class keys are taken from run R1.
    for type_indx in ['val', 'test']:
        for class_indx in net_result_dict['R1'][type_indx][metric_name].keys():
            metric_rows['run'].append('mean')
            metric_rows['type_indx'].append(type_indx)
            metric_rows['class'].append(str(class_indx))
            metric_rows['metric_result'].append(
                get_metric_mean([
                    get_metric_mean(net_result_dict[run_k][type_indx][metric_name][str(class_indx)])
                    for run_k in net_result_dict.keys()
                ])
            )

In [None]:
# Long-format table with columns run / type_indx / class / metric_result.
df_dice = pd.DataFrame.from_dict(result_dice_dict)
df_dice.head()

In [None]:
# Per-class bars of the Dice table, one facet per split.
plt.rcParams.update({'font.size': 22})
g = sns.catplot(
    x='run', y='metric_result', hue='class', 
    col='type_indx', data=df_dice, kind='bar',
    height=12, aspect=.9
)
# g.axes is an array of Axes (one per facet) and has no set_title(); the overall
# title therefore goes on the figure. y > 1 keeps it above the facet titles.
g.figure.suptitle('dice', y=1.02)
g.savefig('d.png')

In [None]:
# Manual selection of one job (fold / experiment / run / split) for visual inspection.
kfold = 'kfold2'
cur_exp = SW_WO_ATT_EXP
run = 'R2'
type_indices = 'val'  # one of 'val', 'test'

# Index of the "other" class in the predicted masks (kfold1 has one class fewer).
# NOTE(review): num_classes is not read by any cell visible in this notebook.
if kfold == 'kfold1':
    num_classes = 7
    other_indx = 6
else: 
    num_classes = 8
    other_indx = 7


# Original class ids evaluated against the "other" class for this job; the
# 'swapped' lists apply to the two SW-* experiments.
other_classes = OTHER_CLASSES_INFO[kfold]['swapped' if cur_exp in [SW_WO_ATT_EXP, SW_ATT_EXP] else 'origin'][type_indices]
other_classes

In [None]:
# Result folder of the selected job; it holds the 'origin' (ground-truth) and
# 'pred' (predicted) masks in two sub-folders.
path_to_results = '/raid/rustam/hyperspectral_dataset/' +\
                 f'diff_exp_with_other__attention_with_other/result_masks_0/{kfold}/{cur_exp}/{run}/{type_indices}'

path_to_origin = f'{path_to_results}/origin' 
path_to_pred = f'{path_to_results}/pred' 

In [None]:
# Pick one random ground-truth mask and the prediction stored under the same file name.
# The folder is listed once (and sorted) so the index always refers to the list it
# was drawn for, and a seeded RNG selects the same file on every run.
origin_paths = sorted(glob.glob(f'{path_to_origin}/*'))
single_origin_path = origin_paths[np.random.randint(0, len(origin_paths))]
file_name = os.path.basename(single_origin_path)
single_pred_path = f'{path_to_pred}/{file_name}'

In [None]:
# Load both masks and drop singleton dimensions.
origin = np.squeeze(np.load(single_origin_path))
pred = np.squeeze(np.load(single_pred_path))

In [None]:
# Predicted class map.
sns.heatmap(pred)

In [None]:
# Ground-truth class map and the class ids that occur in it.
sns.heatmap(origin), np.unique(origin)

In [None]:
# Zero-filled float32 array with the shape of the ground-truth mask.
# NOTE(review): not read by any later cell visible in this notebook.
local_mask = np.zeros_like(origin, dtype=np.float32)

In [None]:
# Sub-class to inspect: the third entry of the job's class list.
other_subclass = other_classes[2]
other_subclass

In [None]:
# Binary masks: pixels of the selected sub-class in the ground truth, and pixels
# predicted as the "other" class.
origin_mask = (origin == other_subclass).astype(np.int64)
pred_mask = (pred == other_indx).astype(np.int64)
# Their product is 1 only where both masks are 1 (the overlap).
correct_predicted_area = pred_mask * origin_mask
sns.heatmap(correct_predicted_area)

In [None]:
# Ground-truth mask of the selected sub-class.
sns.heatmap(origin_mask)

In [None]:
# Mask of the pixels predicted as the "other" class.
sns.heatmap(pred_mask)

In [None]:
# IoU of the two binary masks (2 classes: background / foreground), one row for this single pair.
pred_mask_t = torch.from_numpy(pred_mask)
origin_mask_t = torch.from_numpy(origin_mask)
calculate_iou([pred_mask_t], [origin_mask_t], 2)

In [None]:
# Dice loss (1 - Dice) of the two binary masks, one value per one-hot channel of
# the ground truth.
origin_one_hotted_tensor = torch.nn.functional.one_hot(
    origin_mask_t.unsqueeze(0), 2 # Num classes
)
# (N, H, W, C) --> (N, C, H, W)
origin_one_hotted_tensor = origin_one_hotted_tensor.permute(0, -1, 1, 2)
# The prediction gets a batch and a channel dimension so it lines up with the
# (N, C, H, W) ground truth; the sum runs over batch and both spatial dimensions.
dice_loss(
    pred_mask_t.unsqueeze(0).unsqueeze(0), origin_one_hotted_tensor,
    dim=[0, 2, 3], use_softmax=False, softmax_dim=1,
)