In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "2"

import sys
sys.path.append('/home/rustam/hyperspecter_segmentation/makitorch')
sys.path.append('/home/rustam/hyperspecter_segmentation/')

PREFIX_INFO_PATH = '/home/rustam/hyperspecter_segmentation/danil_cave/kfolds_data/kfold0'
PATH_DATA = '/raid/rustam/hyperspectral_dataset/new_cropped_hsi_data'


from multiprocessing.dummy import Pool
from multiprocessing import shared_memory

from makitorch import *
import math
import numpy as np
import numba as nb
import comet_ml
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms as T
import torchvision.transforms.functional as TF
from torchvision import utils
import cv2
from Losses import FocalLoss
import matplotlib.pyplot as plt

import seaborn as sns
import json
from tqdm import tqdm

from sklearn.decomposition import PCA
from makitorch.architectures.U2Net import U2Net

from hsi_dataset_api import HsiDataset

from makitorch.dataloaders.HsiDataloader import HsiDataloader
from makitorch.architectures.Unet import Unet, UnetWithFeatureSelection
from makitorch.loss import muti_bce_loss_fusion
from sklearn.metrics import jaccard_score
np.set_printoptions(suppress=True)


from makitorch.data_tools.augmentation import DataAugmentator
from makitorch.data_tools.augmentation import BaseDataAugmentor
from makitorch.data_tools.preprocessing import BaseDataPreprocessor
from makitorch.data_tools.preprocessing import DataPreprocessor

from typing import Callable, Optional, Union

import torch
from sklearn.utils import shuffle
from hsi_dataset_api import HsiDataset

import time

In [None]:
# Target device string for the model and tensors. CUDA_VISIBLE_DEVICES is set to "2"
# at the top of the notebook, so index 0 here addresses the only GPU exposed to
# this process (presumably physical GPU 2 - confirm on the host).
device = 'cuda:0'

In [None]:
def cut_into_parts_model_input(
        image: np.ndarray, h_parts: int, 
        w_parts: int, h_win: int, w_win: int):
    """
    Split a 4-D batch into a list of non-overlapping spatial tiles.

    Parameters
    ----------
    image : array-like
        Object indexed as ``image[:, :, rows, cols]``; the two leading axes
        are kept whole and only the last two axes are sliced.
    h_parts : int
        Number of tiles along the third axis (rows of tiles).
    w_parts : int
        Number of tiles along the fourth axis (columns of tiles).
    h_win : int
        Tile extent along the third axis.
    w_win : int
        Tile extent along the fourth axis.

    Returns
    -------
    list
        ``h_parts * w_parts`` slices in row-major order: all tiles of the
        first tile row from left to right, then the second row, and so on.
    """
    return [
        image[
            :, :,
            slice(row * h_win, (row + 1) * h_win),
            slice(col * w_win, (col + 1) * w_win),
        ]
        for row in range(h_parts)
        for col in range(w_parts)
    ]


def merge_parts_into_single_mask(
        preds, shape, h_parts: int, 
        w_parts: int, h_win: int, w_win: int):
    """
    Reassemble per-tile predictions into one full-size tensor.

    Parameters
    ----------
    preds : torch.Tensor
        Tile predictions. The expected layout is the one produced by
        concatenating the tiles along dim 0 in row-major tile order, i.e.
        ``(h_parts * w_parts * batch, C, h_win, w_win)`` where the first
        ``batch`` entries belong to tile 0, the next ``batch`` to tile 1, etc.
        Any other layout is indexed as ``preds[tile_index]`` (legacy behaviour).
    shape : sequence of int
        Shape of the output, ``(batch, C, H, W)``.
    h_parts, w_parts : int
        Number of tiles along the height / width.
    h_win, w_win : int
        Tile height / width.

    Returns
    -------
    torch.Tensor
        Tensor of ``shape`` with the dtype and device of ``preds``. Regions not
        covered by any tile stay zero.
    """
    pred_mask = torch.zeros(
        shape,
        dtype=preds.dtype, device=preds.device
    )
    batch_size = shape[0]
    num_tiles = h_parts * w_parts

    if (preds.dim() == len(shape) and batch_size > 0
            and preds.shape[0] == num_tiles * batch_size):
        # Tiles were concatenated along the batch axis: regroup them so that
        # tiles[k] is the whole batch chunk of tile k. Indexing preds[k]
        # directly would take a single sample and broadcast it over the batch,
        # which is only correct when batch_size == 1.
        tiles = preds.reshape(num_tiles, batch_size, *preds.shape[1:])
    else:
        # Unknown layout: keep the original one-entry-per-tile indexing.
        tiles = preds

    counter = 0
    for h_i in range(h_parts):
        for w_i in range(w_parts):
            pred_mask[
                :, :,
                h_i * h_win: (h_i + 1) * h_win,
                w_i * w_win: (w_i + 1) * w_win
            ] = tiles[counter]
            counter += 1
    return pred_mask


def collect_prediction_and_target(eval_loader, model, cut_window=(8, 8), image_shape=(512, 512), num_classes=17):
    """
    Run tile-wise inference over a loader and gather predictions and targets.

    Every input batch is cut into non-overlapping windows, all windows are
    stacked into one large batch, passed through ``model`` and the outputs are
    stitched back into a full-size prediction.

    Parameters
    ----------
    eval_loader : iterable
        Yields ``(input, target)`` pairs; ``input`` is sliced as
        ``input[:, :, rows, cols]``.
    model : callable
        Maps a batch of windows to per-window predictions with
        ``num_classes`` channels and the same spatial size as the window.
    cut_window : tuple of int
        ``(h_win, w_win)`` window size.
    image_shape : tuple of int
        ``(H, W)`` of the full image. Pixels beyond the last whole window
        (when a side is not divisible by the window) are left as zeros.
    num_classes : int
        Channel count of the stitched prediction.

    Returns
    -------
    tuple of torch.Tensor
        ``(predictions, targets)``, each concatenated along dim 0 over all
        batches of the loader.

    Notes
    -----
    Gradient tracking is not disabled here; wrap the call in
    ``torch.no_grad()`` when only evaluating to avoid keeping autograd graphs
    for every batch.
    """
    target_list = []
    pred_list = []

    h_win, w_win = cut_window
    # Rows of tiles come from the image height, columns from the image width.
    # (The counts were previously swapped, which only worked for square
    # images with square windows.)
    h_parts = image_shape[0] // h_win
    w_parts = image_shape[1] // w_win

    # Passing the loader itself (not iter(loader)) lets tqdm show a total
    # when the loader has a length.
    for in_data_x, val_data in tqdm(eval_loader):
        batch_size = in_data_x.shape[0]
        # Cut the image into pieces and stack them into a single big batch.
        in_data_x_parts_list = cut_into_parts_model_input(
            in_data_x, h_parts=h_parts, 
            w_parts=w_parts, h_win=h_win, w_win=w_win
        )
        # (h_parts * w_parts * batch_size, C_in, h_win, w_win)
        in_data_x_batch = torch.cat(in_data_x_parts_list, dim=0)
        # Make predictions: (h_parts * w_parts * batch_size, num_classes, h_win, w_win)
        preds = model(in_data_x_batch)
        # Rebuild the full image from the pieces.
        pred_mask = merge_parts_into_single_mask(
            preds=preds, shape=(batch_size, num_classes, image_shape[0], image_shape[1]), 
            h_parts=h_parts, w_parts=w_parts, h_win=h_win, w_win=w_win
        )
        target_list.append(val_data)
        pred_list.append(pred_mask)
    return (torch.cat(pred_list, dim=0), 
            torch.cat(target_list, dim=0)
    )


def matrix2onehot(matrix, num_classes=17):
    """
    One-hot encode an integer label array.

    Parameters
    ----------
    matrix : np.ndarray
        Integer class indices of any shape; it is flattened (on a copy, the
        input is not modified) before encoding.
    num_classes : int
        Width of the one-hot encoding.

    Returns
    -------
    np.ndarray
        Array of shape ``(matrix.size, num_classes)`` holding 1 at each
        element's class index and 0 elsewhere, in NumPy's default float dtype.
    """
    labels = matrix.flatten()
    n_items = labels.size
    encoded = np.zeros((n_items, num_classes))
    row_ids = np.arange(n_items)
    encoded[row_ids, labels] = 1
    return encoded

def list_target_to_onehot(target_tensor, num_classes=17):
    """
    One-hot encode a batch of label maps into channel-first layout.

    Parameters
    ----------
    target_tensor : torch.Tensor or sequence of torch.Tensor
        Integer label maps, ``(N, H, W)`` as a tensor or ``N`` maps of shape
        ``(H, W)``.
    num_classes : int
        Number of classes, i.e. the size of the channel axis of the result.

    Returns
    -------
    torch.Tensor
        ``(N, num_classes, H, W)`` float64 tensor on the CPU (the same dtype
        and device the previous NumPy-based implementation produced).

    Notes
    -----
    The encoding is done in one vectorised call instead of a per-sample
    NumPy loop. ``torch.nn.functional.one_hot`` rejects negative labels,
    whereas NumPy indexing silently wrapped them around.
    """
    if not torch.is_tensor(target_tensor):
        target_tensor = torch.stack(
            [torch.as_tensor(target) for target in target_tensor], dim=0
        )
    # one_hot requires int64 input and appends the class axis last: (N, H, W, C).
    one_hot = torch.nn.functional.one_hot(target_tensor.detach().long(), num_classes)
    # Move classes to the channel axis: (N, H, W, C) -> (N, C, H, W).
    one_hot = one_hot.permute(0, 3, 1, 2)
    return one_hot.to(dtype=torch.float64, device='cpu').contiguous()
        

def calculate_iou(pred_list, target_list, num_classes=17):
    """
    Compute the per-class Jaccard index (IoU) for every sample.

    Parameters
    ----------
    pred_list : iterable of torch.Tensor
        Per-sample class scores of shape ``(num_classes, H, W)``.
    target_list : iterable of torch.Tensor
        Per-sample integer label maps of shape ``(H, W)``.
    num_classes : int
        Number of classes.

    Returns
    -------
    res_np : np.ndarray
        ``(num_samples, num_classes)`` IoU values. A class absent from both
        prediction and target of a sample scores 1 (``zero_division=1``).
    pred_as_mask_list : list of torch.Tensor
        Arg-max label map ``(H, W)`` for every sample, on the device of the
        corresponding prediction.
    """
    res_list = []
    pred_as_mask_list = []

    for preds, target in zip(pred_list, target_list):
        # Softmax is monotonic along the class axis, so the arg-max of the raw
        # scores is the predicted class; no need to normalise first.
        pred_mask = torch.argmax(preds.detach(), dim=0)
        pred_as_mask_list.append(pred_mask)

        # scikit-learn works on host arrays: move the indicator matrices to the
        # CPU explicitly so CUDA tensors are handled as well. one_hot needs
        # int64 labels, hence the cast of the target.
        preds_one_hoted = (
            torch.nn.functional.one_hot(pred_mask, num_classes)
            .reshape(-1, num_classes).cpu().numpy()
        )
        target_one_hoted = (
            torch.nn.functional.one_hot(target.detach().long(), num_classes)
            .reshape(-1, num_classes).cpu().numpy()
        )
        res_list.append(
            jaccard_score(target_one_hoted, preds_one_hoted, average=None, zero_division=1)
        )

    res_np = np.stack(res_list)
    return res_np, pred_as_mask_list


def dice_loss(preds, ground_truth, eps=1e-5, dim=None, use_softmax=False, softmax_dim=1):
    """
    Computes Dice loss according to the formula from:
    V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
    Link to the paper: http://campar.in.tum.de/pub/milletari2016Vnet/milletari2016Vnet.pdf

    Parameters
    ----------
    preds : torch.Tensor
        Predicted probabilities, or raw scores when `use_softmax` is True.
    ground_truth : torch.Tensor
        Ground truth labels with a shape compatible with `preds`
        (e.g. one-hot encoded). Cast to float and moved to the device of `preds`.
    eps : float
        Added to the Dice denominator to prevent division by zero.
    dim : int, list of int or None
        Axes that are summed when computing the Dice value; one loss value is
        returned per index of the remaining axes. If None, Dice is computed
        over the entire batch.
    use_softmax : bool
        If True, softmax is applied to `preds` along `softmax_dim` first.
    softmax_dim : int
        Axis used for the softmax when `use_softmax` is True.

    Returns
    -------
    torch.Tensor
        ``1 - dice``: a scalar tensor when `dim` is None, otherwise a tensor
        with the axes listed in `dim` reduced away.
    """
    ground_truth = ground_truth.float().to(device=preds.device)

    if use_softmax:
        preds = nn.functional.softmax(preds, dim=softmax_dim)

    def _reduce(values):
        # Call the full reduction without a `dim` argument instead of
        # forwarding an explicit dim=None, which not every PyTorch release
        # accepts.
        if dim is None:
            return torch.sum(values)
        return torch.sum(values, dim=dim)

    numerator = _reduce(preds * ground_truth)

    p_squared = _reduce(torch.square(preds))
    # For 0/1 labels squaring is a no-op (0^2 = 0, 1^2 = 1); it is kept so
    # that soft labels follow the same formula as the predictions.
    g_squared = _reduce(torch.square(ground_truth))
    denominator = p_squared + g_squared + eps

    dice = 2 * numerator / denominator
    return 1 - dice

def clear_metric_calculation(final_metric, target_t, pred_t, num_classes=17):
    """
    Average a per-sample, per-class metric over the samples where the class occurs.

    For every sample, the metric of class ``i`` is taken into account only if
    ``i`` appears in that sample's target or predicted label map. This drops
    the scores of classes that are absent from both.

    Parameters
    ----------
    final_metric : sequence
        Per-sample metric rows; ``final_metric[n][i]`` is the metric of class
        ``i`` for sample ``n`` (e.g. an ``(N, C)`` array).
    target_t : torch.Tensor or sequence
        Per-sample target label maps (torch tensors or NumPy arrays).
    pred_t : torch.Tensor or sequence
        Per-sample predicted label maps (torch tensors or NumPy arrays).
    num_classes : int
        Number of classes ``C``.

    Returns
    -------
    mean_per_class_metric : list
        For each class, the mean metric over the samples where it occurs, or
        0.0 if it never occurs.
    mean_metric : float
        Mean of `mean_per_class_metric` over all ``num_classes`` classes
        (classes that never occur contribute 0.0).
    """
    def _present_classes(label_map):
        # Accept NumPy arrays for both targets and predictions.
        if isinstance(label_map, np.ndarray):
            label_map = torch.from_numpy(label_map)
        return set(torch.unique(label_map.long()).tolist())

    per_class_values = [[] for _ in range(num_classes)]

    # For each image
    for metric_s, target_t_s, pred_t_s in zip(final_metric, target_t, pred_t):
        present = _present_classes(target_t_s) | _present_classes(pred_t_s)
        for class_idx in range(num_classes):
            if class_idx in present:
                per_class_values[class_idx].append(metric_s[class_idx])

    mean_per_class_metric = [
        sum(values) / len(values) if values else 0.0
        for values in per_class_values
    ]
    mean_metric = sum(mean_per_class_metric) / len(mean_per_class_metric)
    return mean_per_class_metric, mean_metric



In [None]:


class MySuperNetLittleInput(nn.Module):
    """
    Small fully-convolutional network for per-pixel classification.

    Five ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages
    (``in_f -> 128 -> 128 -> 64 -> 64 -> out_f`` channels) followed by a 1x1
    convolution. Every layer uses stride 1 with matching padding, so the
    spatial size of the input is preserved.

    Parameters
    ----------
    in_f : int
        Number of input channels.
    out_f : int
        Number of output channels (classes).
    *args
        Accepted and ignored.
    """

    def __init__(self, in_f=237, out_f=17, *args):
        super().__init__()

        self.conv1 = nn.Conv2d(in_f, 128, kernel_size=3, stride=1, padding=1)
        # (N, 128, H, W)
        self.bn1 = nn.BatchNorm2d(128)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        # (N, 128, H, W)
        self.bn2 = nn.BatchNorm2d(128)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        # (N, 64, H, W)
        self.bn3 = nn.BatchNorm2d(64)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        # (N, 64, H, W)
        self.bn4 = nn.BatchNorm2d(64)
        self.act4 = nn.ReLU()

        self.conv5 = nn.Conv2d(64, out_f, kernel_size=3, stride=1, padding=1)
        # (N, out_f, H, W)
        self.bn5 = nn.BatchNorm2d(out_f)
        self.act5 = nn.ReLU()

        self.final_conv = nn.Conv2d(out_f, out_f, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """
        Map an ``(N, in_f, H, W)`` batch to ``(N, out_f, H, W)`` class scores.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        module goes through ``nn.Module.__call__`` and registered hooks run.
        """
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.act4(self.bn4(self.conv4(x)))
        x = self.act5(self.bn5(self.conv5(x)))
        return self.final_conv(x)


In [None]:
# Smoke test: build the network for 17 input channels / 17 output channels, switch
# it to eval mode and run one forward pass on a random 128x128 input without
# tracking gradients. The output is discarded.
net = MySuperNetLittleInput(in_f=17, out_f=17)
net.eval()
with torch.no_grad():
    _ = net(torch.randn(1, 17, 128, 128))
# Moving the network to `device` is disabled here.
#net.to(device=device)

In [None]:
# Synthetic stand-in data: 39 random float inputs of shape (17, 512, 512) and
# 39 integer label maps of shape (512, 512) with values in [0, 17).
fake_input = torch.rand(39, 17, 512, 512, dtype=torch.float)
fake_masks = torch.randint(0, 17, size=(39, 512, 512)).long()

In [None]:
def data_generator(fake_input, fake_masks):
    """
    Yield ``(input, mask)`` pairs one sample at a time, mimicking a loader
    with batch size 1.

    Both arguments are iterated in lockstep along their first axis; every
    yielded tensor gets a new leading axis of size 1.
    """
    paired_samples = zip(fake_input, fake_masks)
    yield from (
        (sample.unsqueeze(0), mask.unsqueeze(0))
        for sample, mask in paired_samples
    )

In [None]:
# One-shot generator over the synthetic samples; it is exhausted after a single
# pass and must be recreated before it can be iterated again.
data_g = data_generator(fake_input, fake_masks)

In [None]:
# Tile-wise inference over all generated samples using 128x128 windows; returns
# the stitched predictions and the targets, each concatenated along dim 0.
pred_tensor, target_tensor = collect_prediction_and_target(data_g, net, cut_window=(128, 128))

In [None]:
# Time the list_target_to_onehot encoding of all targets; the last expression
# displays the elapsed wall-clock seconds.
start = time.time()
target_one_hotted_tensor = list_target_to_onehot(target_tensor)
time.time() - start

In [None]:
target_one_hotted_tensor.shape

In [None]:
# Time the same encoding done with torch's built-in one_hot, for comparison with
# list_target_to_onehot. one_hot appends the class axis last, so it is permuted
# to the second position afterwards.
start = time.time()
target_one_hotted_tensor_n = torch.nn.functional.one_hot(
    target_tensor, 17 # Num classes
)
target_one_hotted_tensor_n = target_one_hotted_tensor_n.permute(0, -1, 1, 2)
time.time() - start

In [None]:
torch.permute

In [None]:
target_one_hotted_tensor_n.shape

In [None]:
# Fraction of elements on which the two one-hot encodings agree; 1.0 means they
# are identical.
torch.mean((target_one_hotted_tensor_n == target_one_hotted_tensor).float())

In [None]:
# Time the Dice loss. Softmax is applied over dim 1 and the sums run over
# dims 0, 2 and 3, leaving one loss value per index of dim 1.
start = time.time()
dice_loss_val = dice_loss(pred_tensor, target_one_hotted_tensor, dim=[0, 2, 3], use_softmax=True, softmax_dim=1)
time.time() - start

In [None]:
# Time the per-sample, per-class IoU computation; also yields the arg-max masks.
start = time.time()
metric, pred_as_mask_list = calculate_iou(pred_tensor, target_tensor)
time.time() - start

In [None]:
# Same IoU measurement run a second time (presumably to compare against the
# first timing).
start = time.time()
metric, pred_as_mask_list = calculate_iou(pred_tensor, target_tensor)
time.time() - start

In [None]:
pred_as_mask_list[0].dtype

In [None]:
metric.mean(axis=0)

In [None]:
isinstance(pred_as_mask_list[0], np.ndarray)

In [None]:
# Time clear_metric_calculation on the IoU values; its return value is discarded
# and only the elapsed seconds are displayed.
start = time.time()
clear_metric_calculation(metric, target_tensor, pred_as_mask_list)
time.time() - start

In [None]:
# Same clear_metric_calculation measurement run a second time (presumably to
# compare against the first timing).
start = time.time()
clear_metric_calculation(metric, target_tensor, pred_as_mask_list)
time.time() - start

In [None]:
# Independent NumPy copy of the first prediction, used below to compare
# np.unique with torch.unique; the second line displays its shape.
d_np = pred_tensor[0].detach().numpy().copy()
d_np.shape

In [None]:
# Time np.unique on the NumPy copy of the first prediction.
start = time.time()
_ = np.unique(d_np)
time.time() - start

In [None]:
# Time torch.unique on the same prediction, kept as a torch tensor.
start = time.time()
_ = torch.unique(pred_tensor[0])
time.time() - start

In [None]:
# NumPy one-hot encoding of the first target: one row per pixel, 17 columns.
target = target_tensor[0].cpu().detach().numpy()
target_one_hoted = matrix2onehot(target, num_classes=17)

In [None]:
target_tensor.shape

In [None]:
# torch one-hot encoding of the same target; the class axis is appended last.
target_one_hoted_t = torch.nn.functional.one_hot(target_tensor[0], 17)

In [None]:
target_one_hoted[1000:1010].astype(np.int32)

In [None]:
target_one_hoted_t.view(-1, 17)[1000:1010]

In [None]:
# Fraction of elements on which the NumPy and torch one-hot encodings of the
# first target agree (both viewed as one row per pixel); 1.0 means identical.
np.mean(target_one_hoted.astype(np.float32) == target_one_hoted_t.view(-1, 17).numpy().astype(np.float32))

In [None]:
isinstance(target_tensor[0], np.ndarray)

In [None]:
import seaborn as sns

In [None]:
# Time drawing the first target label map as a seaborn heatmap.
start = time.time()
sns.heatmap(target_tensor[0])
time.time() - start

In [None]:
target_tensor.shape