In [68]:
# Select which GPU CUDA may see. This is set before `torch` is imported further
# down in this cell.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "2"

# Make the project checkout and its `makitorch` package importable.
# NOTE(review): absolute, user-specific paths - the notebook will not run on
# another machine without editing them.
import sys
sys.path.append('/home/rustam/hyperspecter_segmentation/makitorch')
sys.path.append('/home/rustam/hyperspecter_segmentation/')

# Directory with the fold-0 artefacts loaded later in the notebook: PCA
# parameters, standardization statistics and the train-index file.
PREFIX_INFO_PATH = '/home/rustam/hyperspecter_segmentation/danil_cave/kfolds_data/kfold0'
# Dataset root passed to DatasetCreator, which hands it to HsiDataset.
PATH_DATA = '/raid/rustam/hyperspectral_dataset/new_cropped_hsi_data'


from multiprocessing.dummy import Pool
from multiprocessing import shared_memory

from makitorch import *
import math
import numpy as np
import numba as nb
import comet_ml
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms as T
import torchvision.transforms.functional as TF
from torchvision import utils
import cv2
from Losses import FocalLoss
import matplotlib.pyplot as plt

import seaborn as sns
import json
from tqdm import tqdm

from sklearn.decomposition import PCA
from makitorch.architectures.U2Net import U2Net

from hsi_dataset_api import HsiDataset

from makitorch.dataloaders.HsiDataloader import HsiDataloader
from makitorch.architectures.Unet import Unet, UnetWithFeatureSelection
from makitorch.loss import muti_bce_loss_fusion
from sklearn.metrics import jaccard_score
np.set_printoptions(suppress=True)


from makitorch.data_tools.augmentation import DataAugmentator
from makitorch.data_tools.augmentation import BaseDataAugmentor
from makitorch.data_tools.preprocessing import BaseDataPreprocessor
from makitorch.data_tools.preprocessing import DataPreprocessor

from typing import Callable, Optional, Union

import torch
from sklearn.utils import shuffle
from hsi_dataset_api import HsiDataset


@nb.njit
def cut_into_parts(
        image: np.ndarray, mask: np.ndarray, h_parts: int,
        w_parts: int, h_win: int, w_win: int):
    """Tile an image and its mask into a row-major grid of equally sized windows.

    Parameters
    ----------
    image : np.ndarray
        Array sliced as ``image[:, rows, cols]`` (the first axis is kept whole).
    mask : np.ndarray
        Array sliced as ``mask[rows, cols]`` (any trailing axes are kept whole).
    h_parts, w_parts : int
        Number of windows along the row / column axis.
    h_win, w_win : int
        Window height / width.

    Returns
    -------
    (list, list)
        Image windows and mask windows, in matching order (rows outer,
        columns inner). The windows are slices of the inputs, not copies.
    """
    image_parts_list = []
    mask_parts_list = []

    for row in range(h_parts):
        top = row * h_win
        bottom = top + h_win
        for col in range(w_parts):
            left = col * w_win
            right = left + w_win
            image_parts_list.append(image[:, top:bottom, left:right])
            mask_parts_list.append(mask[top:bottom, left:right])

    return image_parts_list, mask_parts_list


class ShmData:
    """Description of an array stored in a shared-memory segment.

    Holds the three values needed to re-create a NumPy view over an existing
    segment: the segment name, the array shape and the array dtype.
    """

    def __init__(self, shm_name, shape, dtype):
        self.shm_name, self.shape, self.dtype = shm_name, shape, dtype


class DatasetCreator:
    """Load HSI samples, optionally tile them, preprocess them and optionally
    publish the result through shared memory.

    Parameters
    ----------
    data_path : str
        Passed to ``HsiDataset``.
    preprocessing : callable or None
        Called as ``preprocessing(images, masks, map_mask_to_class=...)`` and
        expected to return ``(images, masks)``. ``None`` skips preprocessing.
    indices : container of int or None
        When given, only data points whose position in the dataset iterator is
        contained in ``indices`` are loaded.
    cut_window : (int, int) or None
        ``(h_win, w_win)`` tile size; ``None`` keeps every sample whole.
    map_mask_to_class : bool
        Forwarded to ``preprocessing``.
    create_shared_memory : bool
        When true, the preprocessed images (float32) and masks (int64) are
        copied into two shared-memory segments described by
        ``data_shm_imgs`` / ``data_shm_masks``.
    max_samples : int or None
        Keep at most this many samples (after tiling, before preprocessing).
        Defaults to 100, which is the limit that used to be hard-coded;
        pass ``None`` to keep everything.

    Attributes
    ----------
    images, masks
        Preprocessed samples (lists, or arrays when shared memory is used).
    data_shm_imgs, data_shm_masks : ShmData or None
        Segment descriptors; ``None`` unless ``create_shared_memory`` is true.
    """

    def __init__(
            self,
            data_path: str,
            preprocessing: Optional[Union[DataPreprocessor, Callable]] = BaseDataPreprocessor(),
            indices = None,
            cut_window=(8, 8),
            map_mask_to_class=False,
            create_shared_memory=False,
            max_samples: Optional[int] = 100):
        self.dataset = HsiDataset(data_path)
        self.preprocessing = preprocessing
        self.cut_window = cut_window
        self.map_mask_to_class = map_mask_to_class
        self.create_shared_memory = create_shared_memory

        # Always defined, so callers can test for None instead of hitting an
        # AttributeError when shared memory was not requested.
        self.data_shm_imgs = None
        self.data_shm_masks = None
        # Owning handles of the segments. They are kept so the segments can be
        # closed and unlinked later via release_shared_memory().
        self._shm_imgs = None
        self._shm_masks = None

        self.images = []
        self.masks = []

        for idx, data_point in tqdm(enumerate(self.dataset.data_iterator(opened=True, shuffle=False))):
            if indices is not None and idx not in indices:
                continue
            image, mask = data_point.hsi, data_point.mask
            if cut_window is not None:
                image_parts, mask_parts = self._cut_with_window(image, mask, cut_window)
                self.images += image_parts
                self.masks += mask_parts
            else:
                self.images.append(image)
                self.masks.append(mask)
        print("Preprocess data...")
        if max_samples is not None:
            self.images = self.images[:max_samples]
            self.masks = self.masks[:max_samples]
        if self.preprocessing is not None:
            self.images, self.masks = self.preprocessing(
                self.images, self.masks, map_mask_to_class=map_mask_to_class
            )

        # Create shared memory
        if create_shared_memory:
            print('Create shared memory...')
            # First - map images and masks into np
            self.images = np.asarray(self.images, dtype=np.float32)
            self.masks = np.asarray(self.masks, dtype=np.int64)
            self._shm_imgs, self.data_shm_imgs = self._to_shared_memory(self.images)
            self._shm_masks, self.data_shm_masks = self._to_shared_memory(self.masks)
            print("Shared memory are created for imgs and masks!")

    @staticmethod
    def _to_shared_memory(array):
        """Copy ``array`` into a new shared-memory segment.

        Returns the owning ``SharedMemory`` handle and a ``ShmData`` descriptor
        (name, shape, dtype) that other processes can use to attach to it.
        """
        shm = shared_memory.SharedMemory(create=True, size=array.nbytes)
        view = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
        view[:] = array[:]
        descriptor = ShmData(shm_name=shm.name, shape=array.shape, dtype=array.dtype)
        return shm, descriptor

    def release_shared_memory(self):
        """Close and unlink the segments created by this object, if any.

        Call this once no process needs the shared arrays any more. Safe to
        call repeatedly and when shared memory was never created.
        """
        for attr in ('_shm_imgs', '_shm_masks'):
            shm = getattr(self, attr, None)
            if shm is None:
                continue
            setattr(self, attr, None)
            try:
                shm.close()
            finally:
                # Unlink even if close() fails, so the segment is not leaked.
                shm.unlink()
        self.data_shm_imgs = None
        self.data_shm_masks = None

    def _cut_with_window(self, image, mask, cut_window):
        """Tile ``image`` / ``mask`` with a ``(h_win, w_win)`` window.

        ``image`` must be 3-D with height and width as its last two axes.
        Rows/columns that do not fill a whole window are dropped (a message is
        printed when that happens). Returns ``(image_parts, mask_parts)``.
        """
        assert len(cut_window) == 2
        h_win, w_win = cut_window
        _, h, w = image.shape
        h_parts = h // h_win
        w_parts = w // w_win
        if h % h_win != 0:
            print(f"{h % h_win} pixels will be dropped by h axis. Input shape={image.shape}")

        if w % w_win != 0:
            print(f"{w % w_win} pixels will be dropped by w axis. Input shape={image.shape}")
        return cut_into_parts(
            image=image, mask=mask, h_parts=h_parts, w_parts=w_parts,
            h_win=h_win, w_win=w_win
        )


# Define a `worker_init_fn` that configures each dataset copy differently
def worker_init_fn(worker_id):
    """Give the dataset copy of the current DataLoader worker its own shard.

    The full sample set is taken either from the shared-memory segments
    described by ``dataset.shared_memory_imgs_data`` /
    ``dataset.shared_memory_masks_data`` (when both are set) or from
    ``dataset.images`` / ``dataset.masks``. It is split into contiguous,
    equally sized chunks (the last ones may be shorter or empty) and
    ``dataset.images`` / ``dataset.masks`` are replaced by this worker's chunk.

    The ``worker_id`` argument is ignored; the id reported by
    ``torch.utils.data.get_worker_info()`` is used instead.
    """
    info = torch.utils.data.get_worker_info()
    dataset = info.dataset
    imgs_meta = dataset.shared_memory_imgs_data
    masks_meta = dataset.shared_memory_masks_data

    if imgs_meta is None or masks_meta is None:
        # No shared memory: the dataset must already carry its samples.
        assert dataset.images is not None and dataset.masks is not None
        all_images = dataset.images
        all_masks = dataset.masks
    else:
        # Attach to the existing segments and keep the handles on the dataset
        # so the buffers stay alive as long as the views are used.
        imgs_shm = shared_memory.SharedMemory(name=imgs_meta.shm_name)
        all_images = np.ndarray(
            imgs_meta.shape, dtype=imgs_meta.dtype, buffer=imgs_shm.buf
        )
        dataset.shm_imgs = imgs_shm
        masks_shm = shared_memory.SharedMemory(name=masks_meta.shm_name)
        all_masks = np.ndarray(
            masks_meta.shape, dtype=masks_meta.dtype, buffer=masks_shm.buf
        )
        dataset.shm_masks = masks_shm

    total = len(all_images)
    n_workers = info.num_workers
    # Ceil division, so every sample is assigned to some worker.
    chunk = int(math.ceil(total / float(n_workers)))
    worker_id = info.id
    print(f"id={worker_id}, num_workers={n_workers}")
    first = worker_id * chunk
    last = min(first + chunk, total)
    dataset.images = list(all_images[first:last])
    dataset.masks = list(all_masks[first:last])


class HsiDataloaderCutter(torch.utils.data.IterableDataset):
    """Iterable dataset over already preprocessed ``(image, mask)`` samples.

    Samples come either from the ``images`` / ``masks`` passed in directly, or
    (with ``images=None, masks=None``) from shared memory: in that case
    ``worker_init_fn`` must be given to the DataLoader so that each worker
    fills ``self.images`` / ``self.masks`` with its shard before iteration.

    Parameters
    ----------
    images, masks
        Sequences of samples, or ``None`` when shared memory is used.
    augmentation : callable or None
        Called as ``augmentation(image, mask, map_mask_to_class=...)`` for
        every sample; its result is yielded. ``None`` yields the raw
        ``(image, mask)`` pair.
    shuffle_data : bool
        Reshuffle the samples held by this dataset copy on every ``__iter__``.
    map_mask_to_class : bool
        Forwarded to ``augmentation``.
    shared_memory_imgs_data, shared_memory_masks_data : ShmData or None
        Segment descriptors read by ``worker_init_fn``.
    preprocessing, indices, cut_window, data_start, data_end
        Accepted for interface compatibility. ``preprocessing`` and
        ``cut_window`` are only stored; the others are not used by this class.
    """

    def __init__(
            self,
            images, masks,
            preprocessing: Optional[Union[DataPreprocessor, Callable]] = BaseDataPreprocessor(),
            augmentation: Optional[Union[DataAugmentator, Callable]] = BaseDataAugmentor(),
            indices = None,
            shuffle_data=False,
            cut_window=(8, 8),
            map_mask_to_class=False,
            data_start=None, data_end=None,
            shared_memory_imgs_data: ShmData = None,
            shared_memory_masks_data: ShmData = None
        ):
        super().__init__()
        self.shuffle_data = shuffle_data
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.cut_window = cut_window
        self.map_mask_to_class = map_mask_to_class
        self.shared_memory_imgs_data = shared_memory_imgs_data
        self.shared_memory_masks_data = shared_memory_masks_data

        # Shared-memory handles; set by worker_init_fn so the buffers outlive
        # the array views stored in self.images / self.masks.
        self.shm_imgs = None
        self.shm_masks = None

        self.images = images
        self.masks = masks

    def __iter__(self):
        """Yield one (optionally augmented) sample at a time."""
        assert self.images is not None and self.masks is not None, (
            "No samples to iterate over: pass images/masks, or use a DataLoader "
            "with num_workers > 0 and worker_init_fn when reading from shared memory."
        )
        if self.shuffle_data:
            self.images, self.masks = shuffle(self.images, self.masks)

        for image, mask in zip(self.images, self.masks):
            if self.augmentation is None:
                # `augmentation` is annotated Optional: None means "no augmentation".
                yield image, mask
            else:
                yield self.augmentation(
                    image, mask,
                    map_mask_to_class=self.map_mask_to_class
                )


In [69]:

# NOTE(review): `device` is not referenced by any other cell visible in this notebook.
device = 'cuda:0'
# PCA parameters for fold 0. They are read as module-level globals by
# pca_transformation(): the mean is subtracted, the data is projected onto the
# components and divided by sqrt(explained variance).
pca_explained_variance = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaExplainedVariance_.npy')
pca_mean = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaMean.npy')
pca_components = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaComponents.npy')


def pca_transformation(x):
    """Project the channel axis of ``x`` onto the PCA components and whiten it.

    Uses the module-level ``pca_mean``, ``pca_components`` and
    ``pca_explained_variance``: every pixel's channel vector is centred,
    multiplied by ``pca_components.T`` and divided by
    ``sqrt(pca_explained_variance)``.

    Parameters
    ----------
    x : np.ndarray
        Either a single image ``(C, H, W)`` or a batch ``(N, C, H, W)``.

    Returns
    -------
    np.ndarray (float32)
        ``(H, W, K)`` - channels last - for a 3-D input, and ``(N, K, H, W)`` -
        channels first - for a 4-D input, where ``K = pca_components.shape[0]``.

    Raises
    ------
    ValueError
        If ``x`` is neither 3-D nor 4-D.
    """
    rank = len(x.shape)
    if rank not in (3, 4):
        raise ValueError(f"Unknown shape={x.shape}, must be of len 3 or 4.")

    if rank == 4:
        # (N, C, H, W) -> (N, H, W, C) so the channel axis is contracted by np.dot
        channels_last = np.transpose(x, (0, 2, 3, 1))
        centred = channels_last - pca_mean
        whitened = np.dot(centred, pca_components.T) / np.sqrt(pca_explained_variance)
        # (N, H, W, K) -> (N, K, H, W)
        channels_first = np.transpose(whitened, (0, -1, 1, 2))
        return channels_first.astype(np.float32, copy=False)

    n_channels, height, width = x.shape
    # (C, H, W) -> (C, H * W) -> (H * W, C): one row per pixel
    pixels = x.reshape((n_channels, -1)).T
    centred = pixels - pca_mean
    whitened = np.dot(centred, pca_components.T) / np.sqrt(pca_explained_variance)
    # (H * W, K) -> (H, W, K)
    return whitened.reshape((height, width, pca_components.shape[0])).astype(np.float32, copy=False)

def standartization(img, mean, std):
    """Standardize ``img`` in place: subtract ``mean``, then divide by ``std``.

    ``img`` itself is modified and returned; no copy is made. ``mean`` and
    ``std`` must broadcast against ``img``.
    """
    np.subtract(img, mean, out=img)
    np.divide(img, std, out=img)
    return img

def standartization_pool(mean, std):
    """Build a one-argument standardization function for use with ``Pool.imap``.

    ``mean`` and ``std`` are copied into float32 arrays and given two trailing
    singleton axes, e.g. ``(comp,) -> (comp, 1, 1)``, so they broadcast over
    the last two axes of an array whose channel axis is third from the end
    (such as ``(N, comp, H, W)``).

    Returns
    -------
    callable
        ``f(x)`` that calls ``standartization(x, mean=..., std=...)``.
    """
    mean_arr = np.array(mean, dtype=np.float32)[..., np.newaxis, np.newaxis]
    std_arr = np.array(std, dtype=np.float32)[..., np.newaxis, np.newaxis]

    def apply(x):
        return standartization(x, mean=mean_arr, std=std_arr)

    return apply


def mask2class(mask):
    """Return the most frequent value of ``mask`` as a one-element int64 array.

    Ties are resolved in favour of the smallest value. All values are counted
    in a single ``np.unique`` pass, and no sentinel is used, so a mask whose
    dominant class is ``-1`` is handled like any other.

    Raises
    ------
    AssertionError
        If ``mask`` is empty.
    """
    classes, counts = np.unique(mask, return_counts=True)
    assert classes.size > 0, "mask2class received an empty mask"
    # np.unique returns sorted values and argmax picks the first maximum,
    # so the smallest class wins a tie.
    dominant = classes[np.argmax(counts)]
    return np.array([dominant], dtype=np.int64)


def preprocessing(imgs, masks, map_mask_to_class=False, split_size=256):
    """PCA-transform and standardize the images, and reduce the masks to one channel.

    Steps:
      1. Read ``means`` / ``stds`` from the fold-0 standardization JSON.
      2. Stack the images (float32) and masks (int64) and split the images
         into ``split_size`` chunks along the first axis.
      3. In a pool of 18 threads, run ``pca_transformation`` on every chunk
         and then the standardization built by ``standartization_pool``.
      4. Keep only channel 0 of the last mask axis and move it in front of
         the spatial axes.
      5. If ``map_mask_to_class`` is true, replace every mask by
         ``mask2class(mask)``.

    Parameters
    ----------
    imgs, masks
        Sequences of equally shaped samples.
        # presumably images are (C, H, W) and masks (H, W, 3) - TODO confirm
    map_mask_to_class : bool
        See step 5.
    split_size : int
        Number of chunks processed by the thread pool.

    Returns
    -------
    (list, list)
        Per-sample image arrays and per-sample masks.
    """
    params_path = f'{PREFIX_INFO_PATH}/data_standartization_params_kfold0.json'
    with open(params_path, 'r') as params_file:
        norm_params = json.load(params_file)
    mean, std = norm_params.get('means'), norm_params.get('stds')
    assert mean is not None and std is not None

    print('Create np array of imgs and masks...')
    imgs_np = np.asarray(imgs, dtype=np.float32)
    masks_np = np.asarray(masks, dtype=np.int64)

    print("Split imgs dataset...")
    chunks = np.array_split(imgs_np, split_size)
    n_chunks = len(chunks)

    print('Start preprocess images...')
    with Pool(18) as pool:
        # Stage 1: PCA projection + whitening, chunk by chunk (order preserved).
        projected = list(tqdm(pool.imap(pca_transformation, chunks), total=n_chunks))
        # Stage 2: in-place standardization of the projected chunks.
        standardize = standartization_pool(mean=mean, std=std)
        normalized = list(tqdm(pool.imap(standardize, projected), total=n_chunks))
    # Merge the chunks back and split into one array per sample.
    _images = list(np.concatenate(normalized, axis=0))

    print("Preprocess masks...")
    # Keep channel 0 only and move it right after the sample axis.
    first_channel = masks_np[..., 0:1]
    _masks = list(np.transpose(first_channel, (0, -1, 1, 2)))
    print("Finish preprocess!")

    if map_mask_to_class:
        _masks = list(map(mask2class, _masks))
    return _images, _masks


def test_augmentation(image, mask, **kwargs):
    image = torch.from_numpy(image)
    #image = (image - image.min()) / (image.max() - image.min())
    
    mask = torch.from_numpy(mask)
    
    mask = torch.squeeze(mask, 0)
    return image, mask


def augmentation(image, mask, map_mask_to_class=False):
    """Training-time augmentation: random rotation and random flips.

    The image is rotated by an angle drawn from ``[-30, 30]`` degrees
    (bilinear interpolation) and then flipped horizontally and vertically,
    each time ``np.random.random()`` exceeds 0.5. Unless
    ``map_mask_to_class`` is true, the mask gets the same rotation
    (nearest-neighbour interpolation) and the same flips. Finally, axis 0
    of the mask is squeezed away when it has size 1.

    Parameters
    ----------
    image, mask : np.ndarray
        Converted with ``torch.from_numpy``.
    map_mask_to_class : bool
        When true the mask is not transformed geometrically.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        ``(image, mask)``.
    """
    transform_mask = not map_mask_to_class
    image = torch.from_numpy(image)
    mask = torch.from_numpy(mask)

    angle = T.RandomRotation.get_params((-30, 30))
    image = TF.rotate(image, angle, interpolation=T.InterpolationMode.BILINEAR)
    if transform_mask:
        mask = TF.rotate(mask, angle, interpolation=T.InterpolationMode.NEAREST)

    # Horizontal flip first, then vertical; one random draw per flip.
    for flip in (TF.hflip, TF.vflip):
        if np.random.random() > 0.5:
            image = flip(image)
            if transform_mask:
                mask = flip(mask)

    return image, torch.squeeze(mask, 0)

In [77]:
# A 1x1 window turns every pixel of an image into its own sample.
cut_window=(1,1)
# Only the first train index of fold 0 is kept ([:1]), so a single data point is loaded.
train_indices = np.load(f'{PREFIX_INFO_PATH}/kfold0_indx_train.npy')[:1]
# Load, tile and preprocess the data, then publish images and masks in shared
# memory so that DataLoader workers can attach to them by name.
dataset_creator_train = DatasetCreator(
    PATH_DATA, preprocessing=preprocessing, 
    indices=train_indices, cut_window=cut_window,
    create_shared_memory=True
)

361it [00:13, 26.99it/s]
100%|██████████| 256/256 [00:00<00:00, 23433.91it/s]
100%|██████████| 256/256 [00:00<00:00, 37875.83it/s]


Preprocess data...
Create np array of imgs and masks...
Split imgs dataset...
Start preprocess images...
Preprocess masks...
Finish preprocess!
Create shared memory...
Shared memory are created for imgs and masks!


In [89]:
# The dataset starts without samples (images=None, masks=None): each worker
# fills them from shared memory in worker_init_fn, using the descriptors below.
# `preprocessing`, `indices` and `cut_window` are accepted by the constructor but
# play no part in iteration.
dataset_train = HsiDataloaderCutter(
    images=None, masks=None,
    preprocessing=preprocessing, 
    augmentation=augmentation, indices=train_indices,
    shuffle_data=True, cut_window=cut_window,
    shared_memory_imgs_data=dataset_creator_train.data_shm_imgs,
    shared_memory_masks_data=dataset_creator_train.data_shm_masks,
)

In [90]:
# worker_init_fn gives each of the 4 workers a disjoint, contiguous shard of the
# samples. The dataset has no samples of its own, so it can only be iterated
# through workers.
train_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=5, 
    num_workers=4, pin_memory=False, prefetch_factor=2,
    worker_init_fn=worker_init_fn
)

In [106]:
# Sanity check: which samples does the multi-worker loader deliver?
# Relies on the cell that overwrites the shared-memory masks with the values
# 1..N (one distinct id per sample) having been run first.
counter = 0
# NOTE(review): keys cover '0'..'99' while the ids written into the masks start
# at 1, so '0' stays False and any id >= 100 is added as a new key.
dict_used = dict([(str(i), False) for i in range(100)])
for in_x, target in iter(train_loader):
    for target_s in target:
        # Read the single value stored in this sample's target.
        dict_used[str(int(target_s[0][0]))] = True
    if counter % 1_000 == 0:
        print(counter)
    counter += 1
# Number of batches received.
counter

id=0, num_workers=4
id=1, num_workers=4id=2, num_workers=4

id=3, num_workers=4
0


20

In [107]:
# Display the batch count left behind by the most recently run counting loop.
counter

20

In [108]:
# Shape of the last input batch drawn from the loader.
in_x.shape

torch.Size([5, 17, 1, 1])

In [109]:
# Shape of the last target batch and the value stored in its first sample.
target.shape, target[0][0][0]

(torch.Size([5, 1, 1]), tensor(78))

In [110]:
# Show which sample ids were seen while iterating over train_loader.
dict_used

{'0': False,
 '1': True,
 '2': True,
 '3': True,
 '4': True,
 '5': True,
 '6': True,
 '7': True,
 '8': True,
 '9': True,
 '10': True,
 '11': True,
 '12': True,
 '13': True,
 '14': True,
 '15': True,
 '16': True,
 '17': True,
 '18': True,
 '19': True,
 '20': True,
 '21': True,
 '22': True,
 '23': True,
 '24': True,
 '25': True,
 '26': True,
 '27': True,
 '28': True,
 '29': True,
 '30': True,
 '31': True,
 '32': True,
 '33': True,
 '34': True,
 '35': True,
 '36': True,
 '37': True,
 '38': True,
 '39': True,
 '40': True,
 '41': True,
 '42': True,
 '43': True,
 '44': True,
 '45': True,
 '46': True,
 '47': True,
 '48': True,
 '49': True,
 '50': True,
 '51': True,
 '52': True,
 '53': True,
 '54': True,
 '55': True,
 '56': True,
 '57': True,
 '58': True,
 '59': True,
 '60': True,
 '61': True,
 '62': True,
 '63': True,
 '64': True,
 '65': True,
 '66': True,
 '67': True,
 '68': True,
 '69': True,
 '70': True,
 '71': True,
 '72': True,
 '73': True,
 '74': True,
 '75': True,
 '76': True,
 '77': T

In [105]:
# Debug helper: overwrite the masks held in shared memory with the values 1..N
# (one distinct id per sample) so that the samples coming out of the
# DataLoader can be identified.
# NOTE(review): execution count In [105] - this cell was run before the
# coverage-check loop (In [106]) although it appears after it in the notebook.
size = len(dataset_creator_train.masks)
masks_t = []
for i in range(size):
    masks_t.append(
        # One (1, 1, 1) array per sample, filled with the sample's 1-based index.
        np.ones((1, 1, 1)).astype(np.int64) * (i+1)
    )
masks_t = np.array(masks_t)
from multiprocessing import shared_memory
# Attach to the existing masks segment by name and view it as an array.
existing_shm_masks = shared_memory.SharedMemory(name=dataset_creator_train.data_shm_masks.shm_name)
# NOTE(review): despite its name, `dataset_imgs_np` is a view over the *masks* segment.
dataset_imgs_np = np.ndarray(
    dataset_creator_train.data_shm_masks.shape, 
    dtype=dataset_creator_train.data_shm_masks.dtype, buffer=existing_shm_masks.buf
)
# Write in place: every process attached to the segment sees the new values.
dataset_imgs_np[:] = masks_t[:]

In [32]:
# Debug helper: replace every mask of the in-memory test dataset by a (1, 1, 1)
# array holding the sample's 1-based index.
# NOTE(review): `dataset_creator_test` is created in the cell that follows this
# one in the notebook; the execution counts (In [26], then In [32]) show that
# cell was run first.
size = len(dataset_creator_test.masks)
dataset_creator_test.masks = []
for i in range(size):
    dataset_creator_test.masks.append(
        np.ones((1, 1, 1)).astype(np.int64) * (i+1)
    )

In [26]:
# Same data and 1x1 tiling as the train creator, but kept in ordinary process
# memory (no shared memory).
# NOTE(review): this reuses `train_indices`, so the "test" set is the training data.
dataset_creator_test = DatasetCreator(
    PATH_DATA, preprocessing=preprocessing, 
    indices=train_indices, cut_window=(1, 1),
    create_shared_memory=False
)

361it [00:13, 26.36it/s]


Preprocess data...
Create np array of imgs and masks...


 14%|█▍        | 36/256 [00:00<00:00, 312.15it/s]

Split imgs dataset...
Start preprocess images...


100%|██████████| 256/256 [00:00<00:00, 884.74it/s]
100%|██████████| 256/256 [00:00<00:00, 13700.40it/s]


Preprocess masks...
Finish preprocess!


In [38]:
# Samples are passed in directly here; worker_init_fn then takes its non-shared-memory
# branch and slices these lists into one shard per worker.
dataset_test = HsiDataloaderCutter(
    images=dataset_creator_test.images, masks=dataset_creator_test.masks,
    preprocessing=preprocessing, 
    augmentation=test_augmentation, indices=train_indices,
    shuffle_data=False, cut_window=(1, 1),
)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=1, 
    num_workers=14, pin_memory=False, prefetch_factor=2,
    worker_init_fn=worker_init_fn
)

In [39]:
# Count the batches produced by test_loader, printing progress every 1000 batches.
# The recorded output ends in a KeyboardInterrupt, so the final count was never shown.
counter = 0
for in_x, target in iter(test_loader):
    if counter % 1_000 == 0:
        print(counter)
    counter += 1
counter

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000


KeyboardInterrupt: 

In [118]:
# Shape of the last input batch drawn from a loader.
in_x.shape

torch.Size([5, 17, 1, 1])

In [119]:
# Shape of the last target batch drawn from a loader.
target.shape

torch.Size([5, 1, 1])

In [42]:
# Display the last target batch.
target

tensor([[[301908]]])

In [121]:
# Sanity check: compare the first two channels of the first sample before and
# after a rotation. With the 1x1 samples used here, the recorded output shows
# identical values.
# NOTE(review): `TF` is already imported in the first cell; this import is redundant.
import torchvision.transforms.functional as TF
print(in_x[0, :2])
TF.rotate(in_x, -150, interpolation=T.InterpolationMode.NEAREST)[0, :2]

tensor([[[ 3.2171]],

        [[-0.2690]]])


tensor([[[ 3.2171]],

        [[-0.2690]]])