In [None]:

import os

# Select the GPU through the environment before torch is imported further down.
# NOTE(review): presumably this must run before the first CUDA initialisation
# to take effect - confirm if the import order is ever changed.
GPU_ID = "0"
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID

import sys
# Extend the module search path with the project checkouts
# (presumably where makitorch, Losses and hsi_dataset_api live - TODO confirm).
sys.path.append('/home/rustam/hyperspecter_segmentation/makitorch')
sys.path.append('/home/rustam/hyperspecter_segmentation/')

# Folder with the fold-0 artefacts (PCA parameters, split indices,
# standardization params) and the dataset root handed to DatasetCreator.
PREFIX_INFO_PATH = '/home/rustam/hyperspecter_segmentation/danil_cave/kfolds_data/kfold0'
PATH_DATA = '/raid/rustam/hyperspectral_dataset/new_cropped_hsi_data'


from multiprocessing.dummy import Pool
from multiprocessing import shared_memory
import copy
from makitorch import *
import math
import numpy as np
import numba as nb
import comet_ml
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms as T
import torchvision.transforms.functional as TF
from torchvision import utils
import cv2
from Losses import FocalLoss
import matplotlib.pyplot as plt

import seaborn as sns
import json
from tqdm import tqdm

from sklearn.decomposition import PCA
from makitorch.architectures.U2Net import U2Net

from hsi_dataset_api import HsiDataset

from makitorch.dataloaders.HsiDataloader import HsiDataloader
from makitorch.architectures.Unet import Unet, UnetWithFeatureSelection
from makitorch.loss import muti_bce_loss_fusion
from sklearn.metrics import jaccard_score
np.set_printoptions(suppress=True)


from makitorch.data_tools.augmentation import DataAugmentator
from makitorch.data_tools.augmentation import BaseDataAugmentor
from makitorch.data_tools.preprocessing import BaseDataPreprocessor
from makitorch.data_tools.preprocessing import DataPreprocessor

from typing import Callable, Optional, Union

import torch
from sklearn.utils import shuffle
from hsi_dataset_api import HsiDataset


@nb.njit
def cut_into_parts(
        image: np.ndarray, mask: np.ndarray, h_parts: int, 
        w_parts: int, h_win: int, w_win: int):
    """Slice an image and its mask into a grid of equally sized tiles.

    The image is indexed as ``image[:, rows, cols]`` and the mask as
    ``mask[rows, cols]``. Tiles are produced row by row (outer loop over
    ``h_parts``, inner loop over ``w_parts``); every tile spans ``h_win`` rows
    and ``w_win`` columns.

    Returns:
        Two lists of equal length: the image tiles and the matching mask tiles.
    """
    tiles = []
    tile_masks = []

    for row in range(h_parts):
        top = row * h_win
        bottom = top + h_win
        for col in range(w_parts):
            left = col * w_win
            right = left + w_win
            tiles.append(image[:, top:bottom, left:right])
            tile_masks.append(mask[top:bottom, left:right])
    return tiles, tile_masks


class ShmData:
    """Plain record describing an array stored in a shared-memory segment.

    Holds the segment name together with the shape and dtype of the array
    that was written into it.
    """

    def __init__(self, shm_name, shape, dtype):
        self.shm_name, self.shape, self.dtype = shm_name, shape, dtype


class DatasetCreator:
    """Load HSI samples, optionally tile and preprocess them, and keep them in memory.

    After construction ``images`` and ``masks`` hold the prepared samples.
    With ``create_shared_memory=True`` both are NumPy arrays backed by
    ``multiprocessing.shared_memory`` segments; ``data_shm_imgs`` and
    ``data_shm_masks`` then describe those segments (name, shape, dtype) and
    are ``None`` otherwise. Release the segments with :meth:`close_shm` or by
    using the instance as a context manager.
    """

    def __init__(
            self, 
            data_path: str,
            preprocessing: Optional[Union[DataPreprocessor, Callable]] = BaseDataPreprocessor(),
            indices = None,
            cut_window=(8, 8),
            map_mask_to_class=False,
            create_shared_memory=False,
            shuffle_then_prepared=False):
        """Read, tile, preprocess and optionally share the dataset.

        Args:
            data_path: Location handed to ``HsiDataset``.
            preprocessing: Callable invoked as
                ``preprocessing(images, masks, map_mask_to_class=...)`` that
                returns the new ``(images, masks)``; ``None`` skips the step.
            indices: Optional container of sample positions (in iteration
                order of the dataset) to keep; ``None`` keeps every sample.
            cut_window: ``(h_win, w_win)`` tile size, or ``None`` to keep
                whole images.
            map_mask_to_class: Forwarded to ``preprocessing``.
            create_shared_memory: Copy the prepared data into shared memory.
            shuffle_then_prepared: Shuffle images and masks together after
                preprocessing.
        """
        self.dataset = HsiDataset(data_path)
        self.preprocessing = preprocessing
        self.cut_window = cut_window
        self.map_mask_to_class = map_mask_to_class
        self.create_shared_memory = create_shared_memory
        
        self.images = []
        self.masks = []
        self._shm_imgs = None
        self._shm_masks = None
        # Always defined, so callers can test for None instead of hitting an
        # AttributeError when shared memory was not requested.
        self.data_shm_imgs = None
        self.data_shm_masks = None

        for idx, data_point in tqdm(enumerate(self.dataset.data_iterator(opened=True, shuffle=False))):
            if indices is not None and idx not in indices:
                continue
            image, mask = data_point.hsi, data_point.mask
            if cut_window is not None:
                image_parts, mask_parts = self._cut_with_window(image, mask, cut_window)
                self.images += image_parts
                self.masks += mask_parts
            else:
                self.images.append(image)
                self.masks.append(mask)
        print("Preprocess data...")
        if self.preprocessing is not None:
            self.images, self.masks = self.preprocessing(
                self.images, self.masks, map_mask_to_class=map_mask_to_class
            )

        if shuffle_then_prepared:
            self.images, self.masks = shuffle(self.images, self.masks)

        if create_shared_memory:
            print('Create shared memory...')
            # Each array is converted and copied on its own, so only one
            # temporary duplicate exists at a time.
            self._shm_imgs, self.images, self.data_shm_imgs = self._copy_to_shared_memory(
                np.asarray(self.images, dtype=np.float32)
            )
            self._shm_masks, self.masks, self.data_shm_masks = self._copy_to_shared_memory(
                np.asarray(self.masks, dtype=np.int64)
            )
            print("Shared memory are created for imgs and masks!")

    def __enter__(self):
        """Return the creator itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the shared-memory segments (if any) when the block exits."""
        self.close_shm()

    @staticmethod
    def _copy_to_shared_memory(array):
        """Copy ``array`` into a newly created shared-memory segment.

        Returns:
            ``(segment, view, descriptor)``: the ``SharedMemory`` object, a
            NumPy array backed by its buffer, and a ``ShmData`` record with
            the segment name, shape and dtype.
        """
        segment = shared_memory.SharedMemory(create=True, size=array.nbytes)
        view = np.ndarray(array.shape, dtype=array.dtype, buffer=segment.buf)
        view[:] = array[:]
        descriptor = ShmData(shm_name=segment.name, shape=view.shape, dtype=view.dtype)
        return segment, view, descriptor

    def close_shm(self):
        """Close and unlink the shared-memory segments created by this object.

        Does nothing when the instance was built without shared memory.
        ``images`` and ``masks`` are reset to empty lists.
        """
        if not self.create_shared_memory:
            return
        # Drop this object's references to the shared-memory-backed arrays
        # before the segments are closed.
        self.images = []
        self.masks = []
        for attr_name in ("_shm_masks", "_shm_imgs"):
            segment = getattr(self, attr_name)
            if segment is not None:
                segment.close()
                segment.unlink()
                setattr(self, attr_name, None)
        print("Shared memory for masks and images are success cleared!")

    def _cut_with_window(self, image, mask, cut_window):
        """Tile one image/mask pair with a ``(h_win, w_win)`` window.

        ``image.shape`` is unpacked as ``(_, h, w)``. Pixels that do not fill
        a complete window are dropped, and a message is printed when that
        happens.

        Returns:
            The ``(image_parts, mask_parts)`` lists built by ``cut_into_parts``.
        """
        assert len(cut_window) == 2
        h_win, w_win = cut_window
        _, h, w = image.shape
        h_parts = h // h_win
        w_parts = w // w_win
        if h % h_win != 0:
            print(f"{h % h_win} pixels will be dropped by h axis. Input shape={image.shape}")

        if w % w_win != 0:
            print(f"{w % w_win} pixels will be dropped by w axis. Input shape={image.shape}")
        return cut_into_parts(
            image=image, mask=mask, h_parts=h_parts, w_parts=w_parts,
            h_win=h_win, w_win=w_win
        )


class HsiDataloaderCutter(torch.utils.data.IterableDataset):
    """Iterable dataset that yields augmented ``(image, mask)`` pairs.

    The samples are supplied already loaded; every iteration pairs images
    with masks, optionally reshuffles them first, and passes each pair
    through the ``augmentation`` callable.
    """

    def __init__(
            self, 
            images, masks,
            augmentation: Optional[Union[DataAugmentator, Callable]] = BaseDataAugmentor(),
            shuffle_data=False,
            cut_window=(8, 8),
            map_mask_to_class=False,
            shared_memory_imgs_data: ShmData = None,
            shared_memory_masks_data: ShmData = None
        ):
        """Store the samples and the iteration settings.

        Image augmentations are switched off (``ignore_image_augs``) when the
        cut window is exactly 1x1.
        """
        super().__init__()
        self.images = images
        self.masks = masks
        self.augmentation = augmentation
        self.shuffle_data = shuffle_data
        self.cut_window = cut_window
        self.map_mask_to_class = map_mask_to_class

        single_pixel_window = (
            cut_window is not None and cut_window[0] == 1 and cut_window[1] == 1
        )
        self.ignore_image_augs = True if single_pixel_window else False

        # Shared-memory descriptors are only stored; nothing in this class
        # attaches to the segments.
        self.shared_memory_imgs_data = shared_memory_imgs_data
        self.shared_memory_masks_data = shared_memory_masks_data
        self.shm_imgs: shared_memory.SharedMemory = None
        self.shm_masks: shared_memory.SharedMemory = None

    def __iter__(self):
        """Yield ``augmentation(image, mask, ...)`` for every stored pair.

        With ``shuffle_data`` set, images and masks are shuffled together and
        the shuffled order replaces the stored one.
        """
        assert self.images is not None and self.masks is not None
        if self.shuffle_data:
            self.images, self.masks = shuffle(self.images, self.masks)

        for sample, target in zip(self.images, self.masks):
            yield self.augmentation(
                sample, target,
                map_mask_to_class=self.map_mask_to_class,
                ignore_image_augs=self.ignore_image_augs,
            )


def pca_transformation(x):
    """Project spectral channels onto the PCA components and whiten them.

    Reads the module-level ``pca_mean``, ``pca_components`` and
    ``pca_explained_variance``. The channel axis is centred with ``pca_mean``,
    multiplied by ``pca_components.T`` and divided by the square root of the
    explained variance.

    Args:
        x: Either one image ``(C, H, W)`` or a batch ``(N, C, H, W)``.

    Returns:
        float32 array. Note the layouts differ: a single image comes back
        channels-last ``(H, W, K)``, a batch comes back channels-first
        ``(N, K, H, W)``, where ``K = pca_components.shape[0]``.

    Raises:
        ValueError: If ``x`` has neither 3 nor 4 dimensions.
    """
    rank = len(x.shape)
    if rank not in (3, 4):
        raise ValueError(f"Unknown shape={x.shape}, must be of len 3 or 4.")

    if rank == 3:
        n_channels, height, width = x.shape
        # (C, H, W) -> (C, H * W) -> (H * W, C)
        pixels = x.reshape((n_channels, -1)).T
        centred = pixels - pca_mean
        projected = np.dot(centred, pca_components.T) / np.sqrt(pca_explained_variance)
        target_shape = (height, width, pca_components.shape[0])
        return projected.reshape(target_shape).astype(np.float32, copy=False)

    # (N, C, H, W) -> (N, H, W, C) so the channel axis is last for the dot product
    channels_last = np.transpose(x, (0, 2, 3, 1))
    centred = channels_last - pca_mean
    projected = np.dot(centred, pca_components.T) / np.sqrt(pca_explained_variance)
    # (N, H, W, K) -> (N, K, H, W)
    channels_first = np.transpose(projected, (0, -1, 1, 2))
    return channels_first.astype(np.float32, copy=False)

def standartization(img, mean, std):
    """Standardize ``img`` as ``(img - mean) / std`` and return it.

    The arithmetic uses augmented assignment, so objects that implement
    in-place operators (NumPy arrays, for example) are modified in place and
    the caller's ``img`` changes as well.
    """
    img -= mean
    img /= std
    return img

def standartization_pool(mean, std):
    """Build a one-argument standardizer suitable for ``Pool.imap``.

    ``mean`` and ``std`` are converted to float32 arrays and given two
    trailing singleton axes, e.g. ``(comp,) -> (comp, 1, 1)``, so that they
    broadcast against batches laid out as ``(N, C, H, W)``.

    Returns:
        A callable ``f(x)`` that returns ``standartization(x, mean, std)``.
    """
    mean_b = np.array(mean, dtype=np.float32)[..., np.newaxis, np.newaxis]
    std_b = np.array(std, dtype=np.float32)[..., np.newaxis, np.newaxis]

    def _standardize(x):
        return standartization(x, mean=mean_b, std=std_b)

    return _standardize


def mask2class(mask):
    """Return the most frequent label of ``mask`` as a one-element int64 array.

    When several labels share the highest pixel count, the smallest label
    wins (``np.unique`` returns labels sorted and ``np.argmax`` picks the
    first maximum). Any label value is allowed, including ``-1``.

    Args:
        mask: Array of class labels, any shape.

    Returns:
        ``np.ndarray`` of shape ``(1,)`` and dtype ``int64``.

    Raises:
        ValueError: If ``mask`` contains no elements.
    """
    labels, pixel_counts = np.unique(mask, return_counts=True)
    if labels.size == 0:
        raise ValueError("mask2class() received an empty mask")
    dominant_label = labels[np.argmax(pixel_counts)]
    return np.array([dominant_label], dtype=np.int64)


def preprocessing(imgs, masks, map_mask_to_class=False, split_size=256, n_workers=18):
    """PCA-project and standardize the images, and reduce the masks to one channel.

    Steps:
      1. Read ``means`` / ``stds`` from
         ``data_standartization_params_kfold0.json`` under ``PREFIX_INFO_PATH``.
      2. Stack ``imgs`` into one float32 array and split it into chunks.
      3. Run ``pca_transformation`` and then the standardization on every
         chunk in a thread pool.
      4. Keep only the first channel of the last mask axis and move it to
         axis 1.

    Args:
        imgs: Sequence of equally shaped images (stacked with ``np.asarray``).
        masks: Sequence of equally shaped masks with the channel axis last.
        map_mask_to_class: When true, every mask is reduced to its dominant
            class with ``mask2class``.
        split_size: Upper bound on the number of chunks handed to the pool.
        n_workers: Number of pool threads.

    Returns:
        ``(images, masks)`` as two lists with one entry per sample.
    """
    with open(f'{PREFIX_INFO_PATH}/data_standartization_params_kfold0.json', 'r') as f:
        data_standartization_params = json.load(f)
    mean = data_standartization_params.get('means')
    std = data_standartization_params.get('stds')
    assert mean is not None and std is not None
    print('Create np array of imgs and masks...')
    imgs_np = np.asarray(imgs, dtype=np.float32)  # (N, C, H, W)
    masks_np = np.asarray(masks, dtype=np.int64)  # (N, H, W, channels), see the slice below
    print("Split imgs dataset...")
    # Never ask for more chunks than there are samples: the surplus chunks
    # would be empty and only add pool overhead.
    n_chunks = max(1, min(split_size, len(imgs_np)))
    imgs_split_np = np.array_split(imgs_np, n_chunks)
    print('Start preprocess images...')
    # Wo PCA
    # _images = [np.transpose(image, (1, 2, 0)) for image in imgs]
    # W Pca
    standardize = standartization_pool(mean=mean, std=std)
    with Pool(n_workers) as p:
        _images = list(tqdm(
            p.imap(pca_transformation, imgs_split_np),
            total=len(imgs_split_np),
        ))
        _images = list(tqdm(
            p.imap(standardize, _images),
            total=len(imgs_split_np),
        ))
    # Merge the chunks back into one array, then into a list of samples.
    _images = list(np.concatenate(_images, axis=0))
    print("Preprocess masks...")
    # Keep the first mask channel only: (N, H, W, channels) -> (N, 1, H, W)
    _masks = list(np.transpose(masks_np[..., 0:1], (0, -1, 1, 2)))
    print("Finish preprocess!")
    if map_mask_to_class:
        _masks = [mask2class(mask) for mask in _masks]
    return _images, _masks


def test_augmentation(image, mask, **kwargs):
    """Evaluation-time counterpart of ``augmentation``: tensor conversion only.

    Both arrays are wrapped with ``torch.from_numpy`` and the mask is squeezed
    along dimension 0. Extra keyword arguments are accepted and ignored.

    Returns:
        ``(image_tensor, mask_tensor)``.
    """
    #image = (image - image.min()) / (image.max() - image.min())
    image_tensor = torch.from_numpy(image)
    mask_tensor = torch.from_numpy(mask).squeeze(0)
    return image_tensor, mask_tensor


def augmentation(image, mask, map_mask_to_class=False, ignore_image_augs=False):
    """Convert a sample to tensors and apply random geometric augmentations.

    Unless ``ignore_image_augs`` is set, the image is rotated by a random
    angle drawn from (-30, 30) degrees and then flipped horizontally and
    vertically, each with its own ``torch.rand(1) > 0.5`` draw. The mask
    receives the same transforms (nearest-neighbour rotation) unless
    ``map_mask_to_class`` is set, in which case it is left untouched.

    Returns:
        ``(image_tensor, mask_tensor)``, the mask squeezed along dimension 0.
    """
    image = torch.from_numpy(image)
    mask = torch.from_numpy(mask)
    transform_mask = not map_mask_to_class

    if not ignore_image_augs:
        # Rotation: one angle shared by image and mask
        angle = T.RandomRotation.get_params((-30, 30))
        image = TF.rotate(image, angle, interpolation=T.InterpolationMode.BILINEAR)
        if transform_mask:
            mask = TF.rotate(mask, angle, interpolation=T.InterpolationMode.NEAREST)

        # Horizontal flip first, then vertical flip, each drawn independently
        for flip in (TF.hflip, TF.vflip):
            if torch.rand(1) > 0.5:
                image = flip(image)
                if transform_mask:
                    mask = flip(mask)

    #image = (image - image.min()) / (image.max() - image.min())
    return image, torch.squeeze(mask, 0)
    

In [None]:
# NOTE: these two paths repeat the values already assigned in the first cell.
PREFIX_INFO_PATH = '/home/rustam/hyperspecter_segmentation/danil_cave/kfolds_data/kfold0'
PATH_DATA = '/raid/rustam/hyperspectral_dataset/new_cropped_hsi_data'


# PCA parameters; pca_transformation() reads them as module-level globals.
pca_explained_variance = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaExplainedVariance_.npy')
pca_mean = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaMean.npy')
pca_components = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaComponents.npy')


# Sample indices of the fold-0 split, passed as `indices` to DatasetCreator.
test_indices = np.load(f'{PREFIX_INFO_PATH}/kfold0_indx_test.npy')
train_indices = np.load(f'{PREFIX_INFO_PATH}/kfold0_indx_train.npy')

# Tile size (height, width) used for the training set, and the train batch size.
cut_window = (512, 512)
batch_size = 2

In [None]:
# Train: samples are tiled with `cut_window`, preprocessed, shuffled once after
# preparation and reshuffled on every pass by the dataset itself.
dataset_creator_train = DatasetCreator(
    PATH_DATA, preprocessing=preprocessing, 
    indices=train_indices, cut_window=cut_window,
    create_shared_memory=False, shuffle_then_prepared=True
)
dataset_train = HsiDataloaderCutter(
    images=dataset_creator_train.images, masks=dataset_creator_train.masks,
    augmentation=augmentation,
    shuffle_data=True, cut_window=cut_window,
)
# drop_last=True discards a final batch smaller than batch_size.
train_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=batch_size, drop_last=True
)
# Test: whole images (no tiling), no random augmentation, one sample per batch.
dataset_creator_test = DatasetCreator(
    PATH_DATA, preprocessing=preprocessing, 
    indices=test_indices, cut_window=None,
    create_shared_memory=False
)
dataset_test = HsiDataloaderCutter(
    images=dataset_creator_test.images, masks=dataset_creator_test.masks,
    augmentation=test_augmentation,
    shuffle_data=False, cut_window=None,
)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=1)

In [None]:
from tqdm import tqdm

In [None]:
# Materialise one full pass over the train loader as NumPy batches.
# The train loader shuffles and augments, so every run of this cell
# collects a different set of batches.
val_loader = train_loader

specter_list = []  # image batches
target_list = []   # mask batches, same order as specter_list

for img_s, mask_s in tqdm(val_loader):
    specter_list.append(img_s.numpy())
    target_list.append(mask_s.numpy())
# Number of collected batches (shown as the cell output).
len(specter_list)

In [None]:
# Selection used by the plots below as specter_list[indx][indx_b][indx_sp]:
indx_b = 0    # position of the sample inside the batch
indx = 122    # position of the batch inside specter_list / target_list
indx_sp = 0   # channel of the image

In [None]:
# Heatmap of the selected channel of the selected sample.
sns.heatmap(specter_list[indx][indx_b][indx_sp])

In [None]:
# Heatmap of the mask that belongs to the selected sample.
sns.heatmap(target_list[indx][indx_b])

In [None]:
# Value distribution over all channels of the top-left 32x32 patch.
# NOTE(review): sns.distplot is deprecated in recent seaborn releases
# (histplot/displot replace it) - confirm against the installed version.
cube = specter_list[indx][indx_b][:, :32, :32].copy().reshape(-1)
sns.distplot(cube)

In [None]:
# Value distribution of the selected channel only (flattened).
cube = specter_list[indx][indx_b][indx_sp].copy().reshape(-1)
sns.distplot(cube)

In [None]:
# Show the selected channel as an 8-bit image.
# NOTE(review): the data went through PCA + standardization, so values are
# presumably not confined to [0, 1]; the uint8 cast would then wrap - confirm.
plt.imshow((specter_list[indx][indx_b][indx_sp] * 255).astype(np.uint8))

In [None]:
import cv2

# Global histogram equalization of the selected channel:
# flatten -> scale to 8 bit -> equalize -> scale back to float32.
cube = specter_list[indx][indx_b][indx_sp].copy().reshape(-1)
cube = (cube * 255).astype(np.uint8)
cube = cv2.equalizeHist(cube)
cube = (cube / 255).astype(np.float32)

In [None]:
# Value distribution after histogram equalization.
sns.distplot(cube)

In [None]:
# Equalized channel as an image; 512x512 matches cut_window used for training tiles.
plt.imshow((cube.reshape(512, 512) * 255).astype(np.uint8))

In [None]:
import cv2

# Contrast-limited adaptive histogram equalization of the selected channel.
# CLAHE equalizes per image tile (tileGridSize), so it has to see the 2-D
# band; the result is flattened afterwards because the following cells expect
# `cube` as a 1-D vector.
# NOTE(review): values outside [0, 1] wrap in the uint8 cast - confirm the
# value range of the standardized data.
cube = specter_list[indx][indx_b][indx_sp].copy()
cube = (cube * 255).astype(np.uint8)
# create a CLAHE object (Arguments are optional).
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
cube = clahe.apply(cube).reshape(-1)
cube = (cube / 255).astype(np.float32)

In [None]:
# Value distribution after CLAHE.
sns.distplot(cube)

In [None]:
# CLAHE result as an image; 512x512 matches cut_window used for training tiles.
plt.imshow((cube.reshape(512, 512) * 255).astype(np.uint8))

In [None]:
import seaborn as sns
import pandas as pd

In [None]:
# Per-class statistics for the test and train loaders, keyed by the class id
# as a string ("0" .. "16"):
#   num_class2count_*         - number of batches in which the class occurs
#   num_class2count_pixels_*  - total number of pixels of the class
num_class2count_test = {str(i): 0 for i in range(17)}
num_class2count_train = {str(i): 0 for i in range(17)}

num_class2count_pixels_test = {str(i): 0 for i in range(17)}
num_class2count_pixels_train = {str(i): 0 for i in range(17)}

# NOTE(review): val_loader_test / val_loader_train are not defined in any
# visible cell (only val_loader, train_loader and test_loader are) - define
# them before running this cell.
# Every loader is walked exactly once, so both statistics come from the same
# samples even when the loader shuffles or augments.
for loader, occurrence_counts, pixel_counts in (
        (val_loader_test, num_class2count_test, num_class2count_pixels_test),
        (val_loader_train, num_class2count_train, num_class2count_pixels_train),
):
    for img_s, mask_s in tqdm(loader):
        class_ids, class_pixels = np.unique(mask_s.numpy(), return_counts=True)
        for n_c, n_pixels in zip(class_ids, class_pixels):
            occurrence_counts[str(n_c)] += 1
            pixel_counts[str(n_c)] += n_pixels

In [None]:
# Class occurrence in train, divided by 324
# (presumably the number of train samples - TODO confirm).
sns.barplot(data=pd.DataFrame(num_class2count_train, index=[0]) / 324)

In [None]:
# Class occurrence in test, divided by 37
# (presumably the number of test samples - TODO confirm).
sns.barplot(data=pd.DataFrame(num_class2count_test, index=[0]) / 37)

In [None]:
# Per-class difference of the occurrence counts (test minus train),
# printing both raw counts along the way.
num_class2count_diff = dict()

for k,v in num_class2count_test.items():
    num_class2count_diff[str(k)] = v - num_class2count_train[k]
    print(f'class={str(k).zfill(2)} num_test={str(v).zfill(2)} num_train={str(num_class2count_train[k]).zfill(2)}')

In [None]:
# Bar plot of the per-class occurrence difference (test minus train).
sns.barplot(data=pd.DataFrame(num_class2count_diff, index=[0]))

In [None]:
# Bar plot of the per-class pixel totals in train.
sns.barplot(data=pd.DataFrame(num_class2count_pixels_train, index=[0]))

In [None]:
# Bar plot of the per-class pixel totals in test.
sns.barplot(data=pd.DataFrame(num_class2count_pixels_test, index=[0]))

In [None]:
# Per-class difference of the pixel totals (test minus train).
num_class2count_pixels_diff = dict()

for k,v in num_class2count_pixels_test.items():
    num_class2count_pixels_diff[str(k)] = v - num_class2count_pixels_train[k]

In [None]:
# Bar plot of the per-class pixel-total difference (test minus train).
sns.barplot(data=pd.DataFrame(num_class2count_pixels_diff, index=[0]))