In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "3"

import sys
sys.path.append('/home/rustam/hyperspecter_segmentation/makitorch')
sys.path.append('/home/rustam/hyperspecter_segmentation/')
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.ndimage import gaussian_filter
import cv2
from hsi_dataset_api import HsiDataset
from makitorch.dataloaders.HsiDataloader import HsiDataloader

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms as T
import torchvision.transforms.functional as TF
from torchvision import utils
import cv2
from Losses import FocalLoss
import matplotlib.pyplot as plt

import seaborn as sns
import json

In [None]:
# Directory holding the fold-0 artefacts loaded below (split indices, PCA
# parameters) and the standardisation JSON read by `preprocessing`.
PREFIX_INFO_PATH = '/home/rustam/hyperspecter_segmentation/danil_cave/kfolds_data/kfold0'
# Root of the dataset handed to `HsiDataset`.
PATH_DATA = '/raid/rustam/hyperspectral_dataset/new_cropped_hsi_data'

# Sample indices of the fold-0 test / train split.
test_indices = np.load(f'{PREFIX_INFO_PATH}/kfold0_indx_test.npy')
train_indices = np.load(f'{PREFIX_INFO_PATH}/kfold0_indx_train.npy')

# PCA parameters saved for fold 0.
# NOTE(review): none of these three arrays is referenced in the visible part
# of the notebook (the PCA transformation call is commented out) - confirm
# whether they are still needed.
pca_explained_variance = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaExplainedVariance_.npy')
pca_mean = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaMean.npy')
pca_components = np.load(f'{PREFIX_INFO_PATH}/kfold0_PcaComponents.npy')

# CUDA_VISIBLE_DEVICES is set to "3" in the import cell, so 'cuda:0' refers to
# the single GPU exposed to this process.
device = 'cuda:0'


def preprocessing(imgs, masks, cut_window=None):
    """Convert raw images and colour masks into the arrays used by the dataset.

    Args:
        imgs: iterable of 3-D arrays; each one is transposed with axes
            (1, 2, 0), i.e. the first axis is moved to the end.
        masks: iterable of colour masks accepted by ``cv2.cvtColor`` with
            ``cv2.COLOR_BGR2GRAY``.
        cut_window: when not ``None`` every mask is additionally passed
            through ``mask2class`` (defined outside this cell).

    Returns:
        ``(images, masks)`` as two lists. Each mask is an ``int64`` array with
        a leading singleton axis.

    The standardisation statistics are loaded and validated, but the
    standardisation step itself is currently disabled.
    """
    params_path = f'{PREFIX_INFO_PATH}/data_standartization_params_kfold0.json'
    with open(params_path, 'r') as f:
        data_standartization_params = json.load(f)
    mean = data_standartization_params.get('means')
    std = data_standartization_params.get('stds')
    assert mean is not None and std is not None

    def standartization(img):
        # Not called at the moment; kept so the disabled step can be re-enabled.
        return np.array((img - mean) / std, dtype=np.float32)

    # Move the first axis to the end for every image.
    _images = []
    for image in imgs:
        _images.append(np.transpose(image, (1, 2, 0)))

    # Colour mask -> single-channel uint8 -> (1, ...) int64.
    _masks = []
    for mask in masks:
        gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY).astype(np.uint8)
        _masks.append(np.expand_dims(gray, 0).astype(np.int64))

    if cut_window is None:
        return _images, _masks
    return _images, [mask2class(mask) for mask in _masks]


def test_augmentation(image, mask, *args):
    image = TF.to_tensor(image)
    #image = (image - image.min()) / (image.max() - image.min())
    
    mask = torch.from_numpy(mask)
    
    mask = torch.squeeze(mask, 0)
    return image, mask


from makitorch.data_tools.augmentation import DataAugmentator
from makitorch.data_tools.augmentation import BaseDataAugmentor
from makitorch.data_tools.preprocessing import BaseDataPreprocessor
from makitorch.data_tools.preprocessing import DataPreprocessor

from typing import Callable, Optional, Union

import torch
from sklearn.utils import shuffle
from hsi_dataset_api import HsiDataset


class HsiDataloaderCutter(torch.utils.data.IterableDataset):
    """Iterable dataset that loads HSI samples into memory, optionally cut into windows.

    All selected samples are read from ``HsiDataset`` in the constructor,
    optionally split into non-overlapping ``cut_window`` tiles, then passed
    once through ``preprocessing``. ``augmentation`` is applied lazily to every
    item during iteration.
    """

    def __init__(
            self, 
            data_path: str,
            preprocessing: Optional[Union[DataPreprocessor, Callable]] = BaseDataPreprocessor(),
            augmentation: Optional[Union[DataAugmentator, Callable]] = BaseDataAugmentor(),
            indices = None,
            shuffle_data=False,
            cut_window=(8, 8)
        ):
        """
        Args:
            data_path: path handed to ``HsiDataset``.
            preprocessing: callable invoked once as
                ``preprocessing(images, masks, cut_window=cut_window)``;
                ``None`` disables it.
            augmentation: callable invoked per item as
                ``augmentation(image, mask, cut_window)``; ``None`` yields the
                stored ``(image, mask)`` pair unchanged.
            indices: container of sample positions to keep; ``None`` keeps all.
            shuffle_data: reshuffle the stored items at the start of every
                iteration.
            cut_window: ``(h, w)`` tile size, or ``None`` to keep whole images.
        """
        super().__init__()
        self.shuffle_data = shuffle_data
        self.dataset = HsiDataset(data_path)
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.cut_window = cut_window

        self.images = []
        self.masks = []

        for idx, data_point in enumerate(self.dataset.data_iterator(opened=True, shuffle=False)):
            if indices is not None and idx not in indices:
                continue
            image, mask = data_point.hsi, data_point.mask
            if cut_window is not None:
                image_parts, mask_parts = self._cut_with_window(image, mask, cut_window)
                self.images += image_parts
                self.masks += mask_parts
            else:
                self.images.append(image)
                self.masks.append(mask)

        if self.preprocessing is not None:
            self.images, self.masks = self.preprocessing(self.images, self.masks, cut_window=cut_window)

    def _cut_with_window(self, image, mask, cut_window):
        """Split one sample into non-overlapping ``cut_window`` tiles.

        ``image`` must be 3-D with its last two axes being height and width;
        ``mask`` is sliced along its first two axes. Pixels that do not fill a
        whole window are dropped, and a message is printed for each affected
        axis.

        Returns:
            ``(image_tiles, mask_tiles)`` in row-major tile order.
        """
        assert len(cut_window) == 2
        h_win, w_win = cut_window
        _, h, w = image.shape
        h_parts = h // h_win
        w_parts = w // w_win
        if h % h_win != 0:
            print(f"{h % h_win} pixels will be dropped by h axis. Input shape={image.shape}")

        if w % w_win != 0:
            print(f"{w % w_win} pixels will be dropped by w axis. Input shape={image.shape}")

        image_parts_list = []
        mask_parts_list = []

        for h_i in range(h_parts):
            rows = slice(h_i * h_win, (h_i + 1) * h_win)
            for w_i in range(w_parts):
                cols = slice(w_i * w_win, (w_i + 1) * w_win)
                image_parts_list.append(image[:, rows, cols])
                mask_parts_list.append(mask[rows, cols])
        return image_parts_list, mask_parts_list

    def __iter__(self):
        """Yield one (optionally augmented) ``(image, mask)`` item at a time."""
        if self.shuffle_data:
            self.images, self.masks = shuffle(self.images, self.masks)

        for image, mask in zip(self.images, self.masks):
            if self.augmentation is None:
                # `augmentation` is Optional: mirror the `preprocessing is None`
                # handling instead of failing with "NoneType is not callable".
                yield image, mask
            else:
                yield self.augmentation(image, mask, self.cut_window)







In [None]:
class MySuperNet3DLittleInput(nn.Module):
    """Small 3-D convolutional classifier for a single hyperspectral patch.

    The input is a channel-first patch ``(N, bands, H, W)``; the shape comments
    below are for ``(N, 237, 8, 8)``. The spectral axis is moved last and
    treated as the third convolution dimension. The output is one vector of 17
    logits per patch.
    """

    def __init__(self, in_f=17, out_f=17, *args):
        """
        Args:
            in_f: accepted for interface compatibility; not used by the layers.
            out_f: accepted for interface compatibility; the classifier width
                is fixed to 17 below.
            *args: ignored.
        """
        super().__init__()
        #self.bn_start = nn.BatchNorm3d(in_f)

        self.conv1 = nn.Conv3d(1, 16, kernel_size=(3, 3, 11), stride=(1, 1, 3), padding=(1, 1, 6))
        # (N, 16, 8, 8, 80)
        self.bn1 = nn.BatchNorm3d(16)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv3d(16, 16, kernel_size=(3, 3, 5), stride=1, padding=(1, 1, 2))
        # (N, 16, 8, 8, 80)
        self.bn2 = nn.BatchNorm3d(16)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv3d(16, 16, kernel_size=(3, 3, 5), stride=1, padding=(1, 1, 2))
        # (N, 16, 8, 8, 80)
        self.bn3 = nn.BatchNorm3d(16)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv3d(16, 16, kernel_size=(2, 2, 5), stride=1, padding=(0, 0, 2))
        # (N, 16, 7, 7, 80)
        self.bn4 = nn.BatchNorm3d(16)
        self.act4 = nn.ReLU()

        self.pooling = nn.AvgPool3d((2, 2, 3), stride=(2, 2, 3), padding=(0, 0, 1))
        # (N, 16, 3, 3, 27)
        self.dropout = nn.Dropout(0.5)
        self.linear = nn.Linear(16 * 3 * 3 * 27, 17, bias=False)

    def forward(self, x):
        """Return class logits ``(N, 17)`` for a batch of patches ``(N, bands, H, W)``.

        Implemented as ``forward`` (not ``__call__``) so that ``nn.Module``'s
        own ``__call__`` stays in charge and forward hooks keep working;
        ``net(x)`` behaves as before.
        """
        # (N, 237, 8, 8) -> (N, 1, 8, 8, 237)
        x = x.permute(0, 2, 3, 1)
        x = x.unsqueeze(dim=1)

        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.act4(self.bn4(self.conv4(x)))

        x = self.pooling(x)
        x = x.view(x.shape[0], -1) # (N, ...)
        x = self.dropout(x)
        x = self.linear(x)
        return x

In [None]:
# Build the network and push one random (1, 237, 8, 8) patch through it as a
# shape smoke test: a mismatch with the linear layer's input size would raise
# here. The resulting shape is discarded; the cell displays the module repr.
net = MySuperNet3DLittleInput(17, 17)
_ = net(torch.randn(1, 237, 8, 8)).shape
net

In [None]:

class NnModel(pl.LightningModule):
    """Lightning wrapper around a patch classifier with optional experiment logging.

    Training optimises `loss(model(img), mask)`. Validation hands the raw
    batches to `validation_epoch_end`, which computes predictions, IoU and
    dice metrics through helper functions that are not defined in this class
    (`collect_prediction_and_target`, `list_target_to_onehot`, `dice_loss`,
    `calculate_iou`, `clear_metric_calculation`).
    """
    def __init__(
            self, model, loss,
            T_0=10, T_mult=2, experiment=None, enable_image_logging=True):
        """
        Args:
            model: wrapped network, called as `model(img)`.
            loss: callable `loss(preds, mask)`.
            T_0, T_mult: forwarded to `CosineAnnealingWarmRestarts`.
            experiment: optional logger object exposing `log_metric`,
                `log_metrics`, `log_image` and `log_confusion_matrix`
                (the inline comment below refers to a Comet logger);
                when None, metrics are printed instead.
            enable_image_logging: when True, a prediction/target heatmap
                figure is rendered for every validation sample.
        """
        super().__init__()
        self.model = model
        self.loss = loss
        self.experiment = experiment
        self.enable_image_logging = enable_image_logging
        #self.weight_contraint_function = WeightConstraint()

        self.T_0 = T_0
        self.T_mult = T_mult

    def _custom_histogram_adder(self):
        """Add a histogram of every named parameter to `self.logger.experiment` for the current epoch."""
        for name,params in self.named_parameters():
            self.logger.experiment.add_histogram(name,params,self.current_epoch)

    def forward(self, x):
        """Delegate to the wrapped model."""
        out = self.model(x)
        return out

    def configure_optimizers(self):
        """Adam (lr=1e-3) with a cosine-annealing-with-warm-restarts schedule."""
        optimizer = optim.Adam(
            self.parameters(), lr=1e-3
        )
        lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, 
            T_0=self.T_0, T_mult=self.T_mult, eta_min=0
        )
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def training_step(self, train_batch, batch_idx):
        """Compute and log the training loss for one batch."""
        img, mask = train_batch
        preds = self.model(img) # (N, C)
        loss = self.loss(preds, mask) # (N,)
        self.log('train_loss', loss)
        if self.experiment is not None:
            self.experiment.log_metric("train_loss", loss, epoch=self.current_epoch, step=self.global_step)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the batch untouched; all validation work happens in `validation_epoch_end`."""
        return batch

    def validation_epoch_end(self, outputs):
        """Evaluate the whole validation set and log IoU / dice metrics.

        `outputs` is the list of batches returned by `validation_step`; it is
        passed to `collect_prediction_and_target` in place of a data loader.
        """
        print('Size epoch end input: ', len(outputs))

        pred_tensor, target_tensor = collect_prediction_and_target(outputs, self.model)
        target_one_hotted_tensor = list_target_to_onehot(target_tensor)
        dice_loss_val = dice_loss(pred_tensor, target_one_hotted_tensor, dim=[0, 2, 3], use_softmax=True, softmax_dim=1)
        metric, pred_as_mask_list = calculate_iou(pred_tensor, target_tensor)

        # Per-sample logging: heatmap figure (optional) and per-class IoU.
        for batch_idx, (metric_s, target_s, pred_s) in enumerate(zip(metric, target_tensor, pred_as_mask_list)):
            if self.enable_image_logging:
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
                sns.heatmap(pred_s, ax=ax1, vmin=0, vmax=17)
                sns.heatmap(target_s.cpu().detach().numpy(), ax=ax2, vmin=0, vmax=17)
                # The figure is written to a fixed file name that is
                # overwritten for every sample, then uploaded from disk.
                fig.savefig('temp_fig.png')
                plt.close(fig)

    #             trainer.logger.experiment.log_histogram_3d(
    #                 self.model.features_selection.weight.detach().cpu().numpy(),
    #                 name='band-selection layer',
    #                 step=self.global_step
    #             )
                if self.experiment is not None:
                    # For Comet logger
                    self.experiment.log_image(
                        'temp_fig.png', name=f'{batch_idx}', 
                        overwrite=False, step=self.global_step
                    )

            d = {f'iou_{i}': iou for i, iou in enumerate(metric_s)}

            if self.experiment is not None:
                self.experiment.log_metrics(d, epoch=self.current_epoch)
            else:
                print(d)
        if self.experiment is not None:
            # Add confuse matrix
            self.experiment.log_confusion_matrix(
                target_tensor.cpu().detach().numpy().reshape(-1), 
                np.asarray(pred_as_mask_list).reshape(-1)
            )

        # Aggregate metrics over the whole validation set.
        mean_per_class_metric, mean_metric = clear_metric_calculation(metric, target_tensor, pred_tensor)
        mean_dice_loss_per_class_dict = {
            f"mean_dice_loss_per_class_{i}": torch.tensor(d_l, dtype=torch.float)
            for i, d_l in enumerate(dice_loss_val)
        }
        mean_dice_loss_dict = {
            f"mean_dice_loss": torch.tensor(dice_loss_val.mean(), dtype=torch.float)
        }
        mean_iou_class_dict = {
            f"mean_iou_class_{i}": torch.tensor(iou, dtype=torch.float)
            for i, iou in enumerate(mean_per_class_metric)
        }
        mean_iou_dict = {
            "mean_iou": torch.tensor(mean_metric, dtype=torch.float),
        }

        # Log this metric in order to save checkpoint of experements
        self.log_dict(mean_iou_dict)

        if self.experiment is not None:

            self.experiment.log_metrics(
                mean_dice_loss_per_class_dict,
                epoch=self.current_epoch
            )

            self.experiment.log_metrics(
                mean_dice_loss_dict,
                epoch=self.current_epoch
            )

            self.experiment.log_metrics(
                mean_iou_class_dict,
                epoch=self.current_epoch
            )

            self.experiment.log_metrics(
                mean_iou_dict,
                epoch=self.current_epoch
            )
        else:
            # No experiment logger attached: dump the aggregated metrics to stdout.
            print(mean_dice_loss_per_class_dict)
            print(mean_dice_loss_dict)
            print(mean_iou_class_dict)
            print(mean_iou_dict)
            print('---------------------------------')

In [None]:
import glob
import re

# All checkpoints written by the selected training run.
CHECKPOINT_GLOB = 'pytorch_li_logs/(run=1)MySuperNet3DLittleInput | _LrCosine W weight decay lower_arch_50ep_Wo full PCA._RustamPreprocess(k=1)_CEcosine(t_0=2,t_mul=1) | arch_type=MySuperNet3DLittleInput/*'


def _checkpoint_score(path):
    """Return the metric value encoded in the last '-'-separated part of a checkpoint name.

    The part is assumed to look like ``<metric>=<value>.ckpt`` (TODO confirm
    against the ModelCheckpoint filename template). The full number after
    ``=`` is parsed, so checkpoints that differ only beyond the second decimal
    are still ranked correctly. If no ``=<number>`` is found, the original
    fixed-position slice is used.
    """
    tail = path.split('/')[-1].split('-')[-1]
    match = re.search(r'=(\d+(?:\.\d+)?)', tail)
    if match is not None:
        return float(match.group(1))
    return float(tail[9:13])


# Highest score first.
w_sorted = sorted(glob.glob(CHECKPOINT_GLOB), key=lambda x: -_checkpoint_score(x))

if not w_sorted:
    # Fail with an actionable message instead of a bare IndexError below.
    raise FileNotFoundError(f'No checkpoints match: {CHECKPOINT_GLOB}')

pick_best_one = w_sorted[0]
w_sorted

In [None]:
# Display the first conv layer's bias. In top-to-bottom order this runs before
# the checkpoint is loaded, so these are the randomly initialised values.
net.conv1.bias

In [None]:
# NOTE(review): same expression as the previous cell; presumably meant to be
# re-run after the checkpoint is loaded to confirm the weights changed - confirm
# or remove.
net.conv1.bias

In [None]:
# Restore the Lightning wrapper from the selected checkpoint. `model` and
# `loss` are the NnModel constructor arguments; no loss is supplied because
# this notebook only runs inference.
model = NnModel.load_from_checkpoint(
    pick_best_one,
    loss=None, model=net
)

In [None]:
# Take the wrapped network back out of the Lightning module, move it to the
# evaluation device and switch it to eval mode (affects the dropout and
# batch-norm layers it contains).
net = model.model
net.to(device=device)
net.eval()

In [None]:
# Test split of fold 0. cut_window=None keeps whole images (no tiling in the
# dataset) and makes `preprocessing` skip its `mask2class` step; tiling is done
# later inside `collect_prediction_and_target`.
dataset_test = HsiDataloaderCutter(
    PATH_DATA, preprocessing=preprocessing, 
    augmentation=test_augmentation, indices=test_indices,
    cut_window=None
)

In [None]:
# One full image per batch.
val_loader = torch.utils.data.DataLoader(dataset_test, batch_size=1)

In [None]:
from tqdm import tqdm

In [None]:
def collect_prediction_and_target(eval_loader, model, cut_window=(8, 8), image_shape=(512, 512), num_classes=17):
    """Run a patch classifier over every window of every image and assemble dense predictions.

    Each image is tiled into non-overlapping ``cut_window`` windows. The model
    returns one logit vector per window, which is broadcast over all pixels of
    that window in the output map. Pixels outside the last full window (when
    ``image_shape`` is not a multiple of ``cut_window``) stay zero.

    Args:
        eval_loader: iterable of ``(images, targets)`` batches; ``images`` is
            indexed as ``(N, C, H, W)``.
        model: callable mapping ``(N, C, h, w)`` windows to ``(N, num_classes)``.
            When it exposes ``parameters()``, inputs are moved to the device of
            its first parameter; otherwise inputs are passed as they are.
        cut_window: ``(h, w)`` window size.
        image_shape: ``(H, W)`` of the images in ``eval_loader``.
        num_classes: length of the logit vector returned by ``model``.

    Returns:
        ``(predictions, targets)``: predictions of shape
        ``(total_N, num_classes, H, W)`` and the concatenated targets.
    """
    # Follow the model's own device instead of a notebook-level `device`
    # global, so the function also works when defined/used in isolation.
    try:
        model_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        model_device = None

    win_h, win_w = cut_window
    n_rows = image_shape[0] // win_h
    n_cols = image_shape[1] // win_w

    target_list = []
    pred_list = []

    with torch.no_grad():
        for in_data_x, val_data in tqdm(eval_loader):
            batch_size = in_data_x.shape[0]
            pred_mask = torch.zeros(
                (batch_size, num_classes, image_shape[0], image_shape[1]),
                dtype=in_data_x.dtype, device=in_data_x.device
            )
            # Take prediction from each window
            for w_i in range(n_cols):
                cols = slice(w_i * win_w, (w_i + 1) * win_w)
                for h_i in range(n_rows):
                    rows = slice(h_i * win_h, (h_i + 1) * win_h)
                    img_part = in_data_x[:, :, rows, cols]
                    if model_device is not None:
                        img_part = img_part.to(device=model_device)
                    pred = model(img_part).cpu()
                    # (N, num_classes) -> (N, num_classes, 1, 1), broadcast over the window
                    pred_mask[:, :, rows, cols] = pred.unsqueeze(dim=-1).unsqueeze(dim=-1)

            target_list.append(val_data)
            pred_list.append(pred_mask)

    return (torch.cat(pred_list, dim=0),
            torch.cat(target_list, dim=0)
    )

In [None]:
# Dense logits and ground-truth masks for the whole test split, using the
# function's default 8x8 window, 512x512 image size and 17 classes.
pred_tensor, target_tensor = collect_prediction_and_target(val_loader, net)

In [None]:
# Sanity check of the collected tensors' shapes.
print(f'pred_tensor.shape={pred_tensor.shape}')
print(f'target_tensor.shape={target_tensor.shape}')

In [None]:
# Logits -> class probabilities -> per-pixel class index.
pred_softmax = nn.functional.softmax(pred_tensor, dim=1)
pred_np = pred_softmax.cpu().detach().numpy()
pred_np = np.transpose(pred_np, [0, 2, 3, 1])
pred_np = np.argmax(pred_np, axis=-1)

# Drop singleton axes of the target but always keep the leading sample axis.
# A bare `torch.squeeze` would also remove it when the split holds a single
# image, which breaks `target_np[indx]` and the alignment with `pred_np`.
target_np = target_tensor.cpu().detach().numpy()
target_np = target_np.reshape(
    (target_np.shape[0],) + tuple(size for size in target_np.shape[1:] if size != 1)
)

In [None]:
# Both arrays should have the same shape before they are compared.
pred_np.shape, target_np.shape

In [None]:
# Position (within the test split) of the sample visualised in the next two cells.
indx=22

In [None]:
# Predicted class map for the selected sample (colour scale capped at 17).
sns.heatmap(pred_np[indx], vmax=17)

In [None]:
# Ground-truth class map for the same sample, on the same capped colour scale.
sns.heatmap(target_np[indx], vmax=17)

In [None]:
from sklearn.metrics import f1_score

In [None]:
# Pixel-wise macro-averaged F1 over the whole test split.
# scikit-learn expects (y_true, y_pred); pass them by keyword so the roles
# cannot be swapped silently.
f1_score(
    y_true=np.asarray(target_np).reshape(-1),
    y_pred=np.asarray(pred_np).reshape(-1),
    average='macro'
)

In [None]:
# Pixel-wise support-weighted F1 over the whole test split.
# The weights are the number of *true* pixels per class, so the ground truth
# must be passed as y_true. With the arguments swapped, the classes would be
# weighted by how often they were predicted instead.
f1_score(
    y_true=np.asarray(target_np).reshape(-1),
    y_pred=np.asarray(pred_np).reshape(-1),
    average='weighted'
)

In [None]:
# Overall pixel accuracy: fraction of pixels whose predicted class equals the target.
np.mean(np.asarray(pred_np).reshape(-1) == np.asarray(target_np).reshape(-1))