In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "-1"

import sys
sys.path.append('/home/rustam/hyperspecter_segmentation/makitorch')
sys.path.append('/home/rustam/hyperspecter_segmentation/')
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.ndimage import gaussian_filter
import cv2
from hsi_dataset_api import HsiDataset
from makitorch.dataloaders.HsiDataloader import HsiDataloader

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms as T
import torchvision.transforms.functional as TF
from torchvision import utils
import cv2
from Losses import FocalLoss
import matplotlib.pyplot as plt

import seaborn as sns
import json

In [None]:
class FocalLossCustom(nn.Module):
    """
    Focal loss for multi-class classification / dense segmentation.

    Adapted from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
    Reference: 'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002)

    Per-element loss computed by ``forward``:
        focal = (1 - pt) ** gamma * CrossEntropy(logit, target)
    where ``pt`` is the softmax probability assigned to the target class.

    :param alpha: stored on the module for interface compatibility; ``forward`` does not use it
    :param gamma: (float) focusing exponent; larger values reduce the relative loss of
                    well-classified elements and put more focus on hard ones
    :param balance_index: (int) class index that is NOT counted when normalising the summed loss
                    (only relevant when ``size_average`` is False)
    :param smooth: (float) validated to lie in [0, 1]; otherwise not used by ``forward``
    :param size_average: (bool) if True, return the mean over all elements; if False, return
                    the sum divided by the number of targets different from ``balance_index``
    """

    def __init__(self, alpha=None, gamma=5.5, balance_index=2, smooth=1e-5, size_average=False):
        super(FocalLossCustom, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average
        self.cel = nn.CrossEntropyLoss(reduction='none')
        self.softmax = nn.Softmax(dim=-1)

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        """
        Compute the focal loss.

        :param logit: raw scores of shape (N, C) or (N, C, d1, d2, ...)
        :param target: class indices with one value per logit element, e.g. (N,), (N, 1),
                    (N, d1, d2, ...) or (N, 1, d1, d2, ...); it is flattened internally
        :return: scalar tensor
        """
        if logit.dim() > 2:
            # N,C,d1,d2,... -> N,C,m (m=d1*d2*...). `reshape` instead of `view` so that
            # non-contiguous inputs (e.g. transposed tensors) are accepted as well.
            logit = logit.reshape(logit.size(0), logit.size(1), -1)
            # N,C,m -> N,m,C -> N*m,C
            logit = logit.permute(0, 2, 1).reshape(-1, logit.size(1))

        # Flatten to one class index per row of `logit`. This covers 1-D targets too
        # (squeezing dim 1 of a 1-D tensor would raise) and guarantees an integer dtype.
        target = target.reshape(-1).long()

        ce_loss = self.cel(logit, target)
        # Probability of the true class, picked directly on the logits' device
        # (no CPU round-trip and no dense one-hot matrix).
        pt = self.softmax(logit).gather(1, target.unsqueeze(1)).squeeze(1)
        loss = torch.pow(1.0 - pt, self.gamma) * ce_loss

        if self.size_average:
            return loss.mean()

        # Norm by positive: number of elements whose class is not `balance_index`.
        # Clamp to 1 so a batch without any such element does not blow the loss up.
        num_positive = torch.sum(target != self.balance_index).clamp(min=1)
        return loss.sum() / num_positive

    def __str__(self):
        return 'focal_loss'

In [None]:
# Instantiate the custom focal loss with its default hyper-parameters.
ce = FocalLossCustom()

In [None]:
# Display the string form of the loss; FocalLossCustom.__str__ returns 'focal_loss'.
str(ce)

In [None]:
# Directory with the precomputed fold-0 artefacts loaded below.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
PREFIX_INFO_PATH = '/home/rustam/hyperspecter_segmentation/danil_cave/kfolds_data/kfold0'


# PCA parameters read by pca_transformation().
# NOTE(review): the first file name ends with '_' before '.npy', unlike the others — confirm it is not a typo.
pca_explained_variance = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaExplainedVariance_.npy')
pca_mean = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaMean.npy')
pca_components = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaComponents.npy')
# Indices used to select the test / train data points of this fold.
test_indices = np.load(f'{PREFIX_INFO_PATH}/kfold0_indx_test.npy')
train_indices = np.load(f'{PREFIX_INFO_PATH}/kfold0_indx_train.npy')

# Root of the dataset passed to HsiDataloaderCutter.
path = '/raid/rustam/hyperspectral_dataset/new_cropped_hsi_data'

In [None]:
def pca_transformation(x):
    """
    Project every pixel spectrum of a cube onto the precomputed PCA basis.

    Relies on the module-level ``pca_mean``, ``pca_components`` and
    ``pca_explained_variance`` arrays.

    :param x: array indexed as (C, H, W)
    :return: float32 array of shape (H, W, N) with N = pca_components.shape[0]
    """
    height, width = x.shape[1], x.shape[2]
    n_components = pca_components.shape[0]

    # (C, H, W) -> (C, H * W) -> (H * W, C): one row per pixel
    pixels = np.reshape(x, (x.shape[0], -1)).swapaxes(0, 1)
    centered = pixels - pca_mean
    # Project onto the components and scale by the square root of the explained variance
    projected = np.dot(centered, pca_components.T) / np.sqrt(pca_explained_variance)
    return projected.reshape((height, width, n_components)).astype(np.float32)

In [None]:
def test_augmentation(image, mask, *args):
    """
    Evaluation-time "augmentation": only converts the inputs to tensors.

    :param image: image passed to torchvision's ``TF.to_tensor``
    :param mask: numpy array; its leading axis is squeezed if it has size 1
    :param args: ignored extra positional arguments
    :return: (image tensor, mask tensor)
    """
    # No min-max normalisation is applied to the image here.
    image_tensor = TF.to_tensor(image)
    mask_tensor = torch.from_numpy(mask).squeeze(0)
    return image_tensor, mask_tensor


def mask2class(mask):
    """
    Return the label that covers the most pixels of ``mask``.

    Ties are resolved in favour of the smallest label value.

    :param mask: non-empty array of labels (any shape)
    :return: int64 array of shape (1,) holding the dominant label
    :raises AssertionError: if ``mask`` is empty
    """
    # One pass over the mask: sorted unique labels plus their pixel counts.
    class_values, pixel_counts = np.unique(mask, return_counts=True)
    assert class_values.size > 0
    # argmax returns the first maximum, i.e. the smallest label among ties.
    dominant = class_values[np.argmax(pixel_counts)]
    return np.array([dominant], dtype=np.int64)


def preprocessing(imgs, masks, map_mask_to_class=False):
    """
    Convert raw cubes and colour masks into the arrays stored by the dataset.

    :param imgs: iterable of arrays; each is transposed with axes (1, 2, 0)
    :param masks: iterable of images accepted by ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)``
    :param map_mask_to_class: if True, every mask is collapsed to its dominant class
                    with ``mask2class``
    :return: (list of images, list of int64 masks with a leading axis of size 1,
             or of ``mask2class`` results)
    """
    params_file = f'{PREFIX_INFO_PATH}/data_standartization_params_kfold0.json'
    with open(params_file, 'r') as f:
        data_standartization_params = json.load(f)
    mean = data_standartization_params.get('means')
    std = data_standartization_params.get('stds')
    assert mean is not None and std is not None

    def standartization(img):
        # Not called at the moment; kept available next to the loaded parameters.
        return np.array((img - mean) / std, dtype=np.float32)

    # Move the first axis of every image to the end.
    _images = []
    for image in imgs:
        _images.append(np.transpose(image, (1, 2, 0)))

    # Grayscale mask -> uint8 -> add a leading axis -> int64.
    _masks = []
    for mask in masks:
        gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY).astype(np.uint8)
        _masks.append(np.expand_dims(gray, 0).astype(np.int64))

    if map_mask_to_class:
        _masks = list(map(mask2class, _masks))
    return _images, _masks

In [None]:
from makitorch.data_tools.augmentation import DataAugmentator
from makitorch.data_tools.augmentation import BaseDataAugmentor
from makitorch.data_tools.preprocessing import BaseDataPreprocessor
from makitorch.data_tools.preprocessing import DataPreprocessor

from typing import Callable, Optional, Union

import torch
from sklearn.utils import shuffle
from hsi_dataset_api import HsiDataset


class HsiDataloaderCutter(torch.utils.data.IterableDataset):
    """
    Iterable dataset that loads HSI data points into memory and optionally cuts
    every image/mask pair into non-overlapping windows.

    All selected data points are read, cut and preprocessed in ``__init__``;
    ``__iter__`` only (optionally) shuffles and applies the augmentation.

    NOTE: no per-worker sharding is implemented. With ``DataLoader(num_workers > 0)``
    every worker iterates over its own full copy of the data.

    :param data_path: path passed to ``HsiDataset``
    :param preprocessing: callable ``(images, masks, map_mask_to_class=...) -> (images, masks)``
                    applied once to all collected parts; ``None`` disables it
    :param augmentation: callable ``(image, mask, cut_window)`` applied to every yielded pair;
                    ``None`` yields the stored ``(image, mask)`` pair unchanged
    :param indices: positions (in iteration order of the dataset) of the data points to keep;
                    ``None`` keeps everything
    :param shuffle_data: shuffle the stored parts at the start of every iteration
    :param cut_window: ``(h_win, w_win)`` window size, or ``None`` to keep whole images
    :param map_mask_to_class: forwarded to ``preprocessing``
    """

    def __init__(
            self, 
            data_path: str,
            preprocessing: Optional[Union[DataPreprocessor, Callable]] = BaseDataPreprocessor(),
            augmentation: Optional[Union[DataAugmentator, Callable]] = BaseDataAugmentor(),
            indices = None,
            shuffle_data=False,
            cut_window=(8, 8),
            map_mask_to_class=False,
        ):
        super().__init__()
        self.shuffle_data = shuffle_data
        self.dataset = HsiDataset(data_path)
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.cut_window = cut_window
        
        self.images = []
        self.masks = []

        # A set gives O(1) membership tests; `idx in <array>` would scan the whole
        # index array for every data point.
        selected = None if indices is None else set(np.asarray(indices).ravel().tolist())
        
        for idx, data_point in enumerate(self.dataset.data_iterator(opened=True, shuffle=False)):
            if selected is not None and idx not in selected:
                continue
            image, mask = data_point.hsi, data_point.mask
            if cut_window is not None:
                image_parts, mask_parts = self._cut_with_window(image, mask, cut_window)
                self.images += image_parts
                self.masks += mask_parts
            else:
                self.images.append(image)
                self.masks.append(mask)
        
        if self.preprocessing is not None:
            self.images, self.masks = self.preprocessing(
                self.images, self.masks, map_mask_to_class=map_mask_to_class
            )
    
    def _cut_with_window(self, image, mask, cut_window):
        """
        Split an image and its mask into aligned, non-overlapping windows.

        The image is sliced on its last two axes (its shape is unpacked as ``(_, h, w)``),
        the mask on its first two axes. Rows/columns that do not fill a whole window
        are dropped (a message is printed when that happens).

        :return: (list of image parts, list of mask parts), in row-major window order
        """
        assert len(cut_window) == 2
        h_win, w_win = cut_window
        _, h, w = image.shape
        h_parts = h // h_win
        w_parts = w // w_win
        if h % h_win != 0:
            print(f"{h % h_win} pixels will be dropped by h axis. Input shape={image.shape}")

        if w % w_win != 0:
            print(f"{w % w_win} pixels will be dropped by w axis. Input shape={image.shape}")

        image_parts_list = []
        mask_parts_list = []

        for h_i in range(h_parts):
            rows = slice(h_i * h_win, (h_i + 1) * h_win)
            for w_i in range(w_parts):
                cols = slice(w_i * w_win, (w_i + 1) * w_win)
                image_parts_list.append(image[:, rows, cols])
                mask_parts_list.append(mask[rows, cols])
        return image_parts_list, mask_parts_list

    def __iter__(self):
        """Yield the (augmented) image/mask pairs, reshuffling first if requested."""
        if self.shuffle_data:
            self.images, self.masks = shuffle(self.images, self.masks)
        
        for image, mask in zip(self.images, self.masks):
            if self.augmentation is None:
                # `augmentation` is declared Optional: without it, pass the pair through.
                yield image, mask
            else:
                yield self.augmentation(image, mask, self.cut_window)


In [None]:
# Test-fold dataset: data points selected by `test_indices`, cut into 8x8 windows,
# masks kept per pixel (map_mask_to_class=False).
dataset_test = HsiDataloaderCutter(
    path, preprocessing=preprocessing, 
    augmentation=test_augmentation, indices=test_indices,
    cut_window=(8, 8), map_mask_to_class=False
)

# NOTE(review): the train-fold dataset below is disabled, but later cells iterate
# `val_loader_train` — it has to be created somewhere before they run.
#dataset_train = HsiDataloaderCutter(
#    path, preprocessing=preprocessing, 
#    augmentation=test_augmentation, indices=train_indices
#)

In [None]:
# One window per batch, in stored order.
val_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1)

# The class-statistics cells further down iterate `val_loader_train`, so the train-fold
# loader must exist for the notebook to run top to bottom on a fresh kernel.
# `cut_window` is left at the class default, (8, 8).
dataset_train = HsiDataloaderCutter(
    path, preprocessing=preprocessing,
    augmentation=test_augmentation, indices=train_indices
)
val_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=1)

In [None]:
from tqdm import tqdm

In [None]:
# Collect the first 1001 windows (indices 0..1000) of the test loader as numpy arrays
# for visual inspection in the following cells.
val_loader = val_loader_test
specter_list = []
target_list = []

for i, (img_s, mask_s) in enumerate(tqdm(val_loader)):
    # The loader uses batch_size=1, so [0] drops the batch axis.
    specter_list.append(img_s[0].numpy())
    target_list.append(mask_s[0].numpy())
    # Print the index of every window whose mask contains more than one class.
    if len(np.unique(target_list[-1])) > 1:
        print(i)
    # Stop after index 1000 to keep the sample small.
    if i == 1000:
        break
# Number of collected windows (cell output).
len(specter_list)

In [None]:
# Shapes of the last batch left over from the collection loop (batch axis included).
mask_s.shape, img_s.shape

In [None]:
# Index of the collected window to inspect.
indx = 955
# Index along the first axis of that window's image — presumably the spectral channel; TODO confirm.
indx_sp = 55

In [None]:
# Heatmap of slice `indx_sp` of the selected window's image.
sns.heatmap(specter_list[indx][indx_sp])

In [None]:
# Heatmap of the selected window's mask.
sns.heatmap(target_list[indx])

In [None]:
# Step to the next window.
# NOTE(review): not idempotent — every re-run of this cell advances `indx` again.
indx += 1

In [None]:
import seaborn as sns
import pandas as pd

In [None]:
def count_classes(loader, num_classes=17):
    """
    Gather per-class statistics from the masks of a loader in a single pass.

    :param loader: iterable yielding ``(image, mask)`` batches with tensor masks
    :param num_classes: number of class keys ('0' .. str(num_classes - 1)) to initialise
    :return: (batch_counts, pixel_counts) — two dicts keyed by the class index as a string:
             the number of batches in which the class occurs, and its total pixel count
    :raises KeyError: if a mask contains a class outside the initialised keys
    """
    batch_counts = {str(i): 0 for i in range(num_classes)}
    pixel_counts = {str(i): 0 for i in range(num_classes)}
    for _, mask_s in tqdm(loader):
        # Classes present in this batch together with their pixel counts.
        classes, counts = torch.unique(mask_s, return_counts=True)
        for n_c, n_pixels in zip(classes.tolist(), counts.tolist()):
            batch_counts[str(n_c)] += 1
            pixel_counts[str(n_c)] += n_pixels
    return batch_counts, pixel_counts


# Both loaders must have been created by earlier cells.
num_class2count_test, num_class2count_pixels_test = count_classes(val_loader_test)
num_class2count_train, num_class2count_pixels_train = count_classes(val_loader_train)

In [None]:
# Per-class occurrence counts of the train fold, scaled by 1/324.
# NOTE(review): 324 is a magic number — presumably the number of train images; TODO confirm.
sns.barplot(data=pd.DataFrame(num_class2count_train, index=[0]) / 324)

In [None]:
# Per-class occurrence counts of the test fold, scaled by 1/37.
# NOTE(review): 37 is a magic number — presumably the number of test images; TODO confirm.
sns.barplot(data=pd.DataFrame(num_class2count_test, index=[0]) / 37)

In [None]:
# Per-class difference of occurrence counts (test minus train), printed alongside the raw counts.
num_class2count_diff = dict()

for k,v in num_class2count_test.items():
    num_class2count_diff[str(k)] = v - num_class2count_train[k]
    # zfill(2) pads the values to two characters so the printed columns line up.
    print(f'class={str(k).zfill(2)} num_test={str(v).zfill(2)} num_train={str(num_class2count_train[k]).zfill(2)}')

In [None]:
# Bar plot of the per-class occurrence difference (test minus train).
sns.barplot(data=pd.DataFrame(num_class2count_diff, index=[0]))

In [None]:
# Bar plot of the per-class pixel totals of the train fold.
sns.barplot(data=pd.DataFrame(num_class2count_pixels_train, index=[0]))

In [None]:
# Bar plot of the per-class pixel totals of the test fold.
sns.barplot(data=pd.DataFrame(num_class2count_pixels_test, index=[0]))

In [None]:
# Per-class difference of pixel totals (test minus train).
num_class2count_pixels_diff = dict()

for k,v in num_class2count_pixels_test.items():
    num_class2count_pixels_diff[str(k)] = v - num_class2count_pixels_train[k]

In [None]:
# Bar plot of the per-class pixel-total difference (test minus train).
sns.barplot(data=pd.DataFrame(num_class2count_pixels_diff, index=[0]))

In [None]:
# Scratch experiment: cross entropy without reduction, i.e. one loss value per element.
cel = nn.CrossEntropyLoss(reduction='none')

In [None]:
# Random logits of shape (3, 85) and 15 random class labels drawn from 0..16.
# NOTE(review): `input` shadows the Python builtin of the same name.
input = torch.randn(3, 5 * 17, requires_grad=True)
target = torch.empty(3 * 5, dtype=torch.long).random_(17)
input, target

In [None]:
# Reinterpret the logits as 15 rows of 17 class scores, one row per label.
input = input.view(3 * 5, 17)

In [None]:
# Per-element losses and their mean; `target` is already 1-D, so view(-1) keeps its shape.
output = cel(input, target.view(-1))
output, output.mean()

In [None]:
# Same computation with the target passed as-is, for comparison with the previous cell.
output = cel(input, target)
output, output.mean()

In [None]:
# Show GPU status via the shell. The notebook sets CUDA_VISIBLE_DEVICES to "-1" in its first cell.
!nvidia-smi