In [None]:
import os
# Expose only physical GPU 2 to this process; it must be set before CUDA is initialised.
# Inside the process that GPU is addressed as 'cuda:0'.
os.environ['CUDA_VISIBLE_DEVICES'] = "2"

import sys
# Make the local makitorch checkout and the project root importable.
# NOTE(review): hard-coded, machine-specific absolute paths; consider a configurable root.
sys.path.append('/home/rustam/hyperspecter_segmentation/makitorch')
sys.path.append('/home/rustam/hyperspecter_segmentation/')

In [None]:
from makitorch import *

In [None]:
import numpy as np
import comet_ml
import pytorch_lightning as pl
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.loggers import TensorBoardLogger
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms as T
import torchvision.transforms.functional as TF
from torchvision import utils
import cv2
from Losses import FocalLoss
import matplotlib.pyplot as plt

import seaborn as sns
import json

In [None]:
from sklearn.decomposition import PCA
from makitorch.architectures.U2Net import U2Net

In [None]:
from hsi_dataset_api import HsiDataset

In [None]:
from makitorch.dataloaders.HsiDataloader import HsiDataloader
from makitorch.architectures.Unet import Unet, UnetWithFeatureSelection
from makitorch.loss import muti_bce_loss_fusion

In [None]:
from sklearn.metrics import jaccard_score
np.set_printoptions(suppress=True)
def matrix2onehot(matrix, num_classes=17):
    """One-hot encode every element of an integer label array.

    Args:
        matrix: numpy array of integer class ids in [0, num_classes); any shape.
        num_classes: number of columns of the encoding.

    Returns:
        float64 array of shape (matrix.size, num_classes); row k has a single 1 in the
        column given by the k-th element of ``matrix`` in C (row-major) order.
    """
    flat_labels = matrix.copy().ravel()
    # Row i of the identity matrix is exactly the one-hot vector for class i.
    return np.eye(num_classes)[flat_labels]


def calculate_iou(eval_loader, model, num_classes=17, loss=None):
    """Score a segmentation model batch by batch with per-class IoU (Jaccard index).

    Args:
        eval_loader: iterable of ``(inputs, targets)`` batches, e.g. a DataLoader or the
            list of batches collected by ``NnModel.validation_step``.
        model: callable returning class logits with the class axis at dim 1.
        num_classes: number of label classes used for the one-hot encoding.
        loss: optional ``loss(preds, targets)``; when given, its value is recorded per batch.

    Returns:
        ``(iou, targets, preds, losses)``:
        - iou: array of shape (n_batches, num_classes) with the per-class Jaccard score of
          every batch. A class absent from both target and prediction scores 1
          (``zero_division=1``).
        - targets / preds: per-batch label maps as squeezed numpy arrays.
        - losses: per-batch loss values as numpy scalars, or ``None`` entries when
          ``loss`` is None.

    Raises:
        ValueError: if ``eval_loader`` yields no batches.
    """
    res_list, target_list, pred_list, loss_list = [], [], [], []

    # Pure inference: do not build an autograd graph. This matters when the function
    # is called directly from a cell rather than inside Lightning's validation loop.
    with torch.no_grad():
        for in_data_x, val_data in eval_loader:
            preds = model(in_data_x)
            if loss is not None:
                loss_list.append(loss(preds, val_data).cpu().detach().numpy())
            else:
                loss_list.append(None)

            # argmax over logits equals argmax over softmax(logits), so the softmax is skipped.
            pred_labels = np.squeeze(np.argmax(preds.cpu().detach().numpy(), axis=1))
            target = np.squeeze(val_data.cpu().detach().numpy())

            target_list.append(target)
            pred_list.append(pred_labels)

            res_list.append(
                jaccard_score(
                    matrix2onehot(target, num_classes=num_classes),
                    matrix2onehot(pred_labels, num_classes=num_classes),
                    average=None,
                    zero_division=1,
                )
            )

    if not res_list:
        raise ValueError('calculate_iou: eval_loader yielded no batches')
    return np.stack(res_list), target_list, pred_list, loss_list

In [None]:
def init_weights(m):
    """Re-initialise a ``nn.Conv2d`` layer in place; intended for ``model.apply(init_weights)``.

    Weights get Kaiming-uniform initialisation with the ReLU gain, and biases (when
    present) are set to 0.01. Any other module type is left untouched.
    """
    if not isinstance(m, nn.Conv2d):
        return
    torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
    conv_bias = m.bias
    if conv_bias is not None:
        conv_bias.data.fill_(0.01)

In [None]:
class WeightConstraint:
    """Callable for ``nn.Module.apply`` that clamps a module's ``weight`` into [0, 1].

    Modules without a ``weight`` attribute are skipped. The clamped tensor is written
    back through ``weight.data``.
    """

    def __call__(self, module):
        if not hasattr(module, 'weight'):
            return
        module.weight.data = module.weight.data.clamp(0, 1)

In [None]:
class CustomLoss(nn.Module):
    """Cross-entropy plus a penalty on feature-selection weights.

    penalty = sum over w of ``1 - (|w| / 0.99 - 1) ** 2``; each term equals 0 at
    ``|w| = 0`` and reaches its maximum of 1 at ``|w| = 0.99``.

    NOTE(review): ``forward`` takes three arguments ``(fs_weight, preds, mask)``, while
    ``NnModel`` calls its loss as ``loss(preds, mask)``; this class is not used in the
    visible notebook.
    """

    def __init__(self):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()

    def forward(self, fs_weight, preds, mask):
        classification_term = self.ce(preds, mask)
        scaled = torch.abs(fs_weight) / 0.99
        selection_penalty = torch.sum(1 - (scaled - 1) ** 2)
        return classification_term + selection_penalty

In [None]:
class NnModel(pl.LightningModule):
    """LightningModule that trains ``model`` with ``loss`` and scores validation by IoU.

    ``validation_step`` only collects the raw validation batches. All scoring happens
    once per epoch in ``validation_epoch_end`` via ``calculate_iou``: per-image IoU,
    per-image loss, their means and, optionally, prediction/target heatmaps.

    Args:
        model: segmentation network whose output is fed to ``loss`` and to
            ``calculate_iou`` (which takes the argmax over dim 1).
        loss: callable ``loss(preds, mask)`` returning a scalar tensor.
        enable_image_logging: when True, a prediction/target heatmap figure is sent to
            the attached logger for every validation batch.
    """

    def __init__(self, model, loss, enable_image_logging=False):
        super().__init__()
        self.model = model
        self.loss = loss
        self.enable_image_logging = enable_image_logging

    def _custom_histogram_adder(self):
        """Log a histogram of every parameter at the current epoch.

        NOTE(review): relies on ``experiment.add_histogram`` (TensorBoard's
        SummaryWriter API) and is not called anywhere in this notebook.
        """
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(name, params, self.current_epoch)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        # Adam at lr=1e-3; the learning rate is multiplied by 0.999 at every scheduler step.
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def training_step(self, train_batch, batch_idx):
        img, mask = train_batch
        preds = self.model(img)
        loss = self.loss(preds, mask)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Keep the batch as-is; it is scored together with the others at epoch end.
        return batch

    def _log_prediction_figure(self, batch_idx, pred_s, target_s):
        """Render prediction vs. target heatmaps side by side and send them to the logger."""
        if self.logger is None:
            return
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
        sns.heatmap(pred_s, ax=ax1, vmin=0, vmax=17)
        sns.heatmap(target_s, ax=ax2, vmin=0, vmax=17)
        fig.savefig('temp_fig.png')
        plt.close(fig)

        # Use this module's own logger instead of a notebook-global `trainer`.
        experiment = self.logger.experiment
        if hasattr(experiment, 'log_image'):
            # Comet logger
            experiment.log_image(
                'temp_fig.png', name=f'{batch_idx}',
                overwrite=False, step=self.global_step
            )
        else:
            # TensorBoard logger. cv2 loads images as BGR, while add_image expects RGB.
            img = cv2.cvtColor(cv2.imread('temp_fig.png'), cv2.COLOR_BGR2RGB)
            experiment.add_image(
                f'{batch_idx}', img, global_step=self.global_step, dataformats='HWC'
            )

    def validation_epoch_end(self, outputs):
        print('Size epoch end input: ', len(outputs))
        metric, target_list, pred_list, loss_list = calculate_iou(outputs, self.model, loss=self.loss)
        per_batch = zip(loss_list, metric, target_list, pred_list)
        for batch_idx, (loss_s, metric_s, target_s, pred_s) in enumerate(per_batch):
            if self.enable_image_logging:
                self._log_prediction_figure(batch_idx, pred_s, target_s)

            d = {f'iou_{i}': iou for i, iou in enumerate(metric_s)}
            self.log_dict(d, on_step=False, on_epoch=True, prog_bar=True)
            d = {f'loss_image_{batch_idx}': torch.tensor(loss_s, dtype=torch.float)}
            self.log_dict(d, on_step=False, on_epoch=True, prog_bar=True)

        self.log_dict(
            {
                f"mean_iou_class_{i}": torch.tensor(iou, dtype=torch.float)
                for i, iou in enumerate(metric.mean(axis=0))
            },
            on_step=False, on_epoch=True, prog_bar=True
        )

        self.log_dict(
            {
                "mean_iou": torch.tensor(np.asarray(metric).mean(), dtype=torch.float),
            },
            on_step=False, on_epoch=True, prog_bar=True
        )

        self.log_dict(
            {
                "mean_loss": torch.tensor(np.asarray(loss_list).mean(), dtype=torch.float),
            },
            on_step=False, on_epoch=True, prog_bar=True
        )

In [None]:
# Device string for manual inference.
# NOTE(review): this variable is not referenced by any later cell in this notebook.
device = 'cuda:0'

In [None]:
# Pre-computed PCA parameters read from the working directory; consumed by
# pca_transformation (mean subtraction, projection, scaling by sqrt(explained variance)).
# NOTE(review): presumably exported from a fitted PCA (explained_variance_, mean_, components_) — confirm.
pca_explained_variance = np.load('PcaExplainedVariance_.npy')
pca_mean = np.load('PcaMean.npy')
pca_components = np.load('PcaComponents.npy')

In [None]:
def pca_transformation(x, mean=None, components=None, explained_variance=None):
    """Project a (C, H, W) hyperspectral cube onto PCA components, pixel by pixel.

    Each pixel's C-dim spectrum is centred with ``mean``, projected with
    ``components`` and divided by ``sqrt(explained_variance)`` (whitening).

    Args:
        x: array of shape (C, H, W).
        mean: (C,) per-band mean. Defaults to the notebook-global ``pca_mean``.
        components: (N, C) projection matrix. Defaults to the global ``pca_components``.
        explained_variance: (N,) variances. Defaults to the global ``pca_explained_variance``.

    Returns:
        float32 array of shape (H, W, N).
    """
    # The PCA parameters used to be read implicitly from globals; they can now be
    # passed explicitly, and the globals remain the fallback.
    if mean is None:
        mean = pca_mean
    if components is None:
        components = pca_components
    if explained_variance is None:
        explained_variance = pca_explained_variance

    pixels = np.reshape(x, (x.shape[0], -1)).T  # (C, H, W) -> (H * W, C)
    projected = np.dot(pixels - mean, components.T) / np.sqrt(explained_variance)
    return np.reshape(
        projected, (x.shape[1], x.shape[2], components.shape[0])
    ).astype(np.float32)  # (H, W, N)

In [None]:
def preprocessing_old(imgs, masks, target_size=(256, 256)):
    """Legacy preprocessing: resize images and label masks to ``target_size``.

    Args:
        imgs: iterable of image objects with a PIL-style ``.resize(size, resample)``.
        masks: iterable of mask objects with the same API.
        target_size: output (width, height); defaults to the original (256, 256).

    Returns:
        ``(images, masks)`` lists of resized objects.

    NOTE(review): ``Image`` is not imported explicitly in this notebook; presumably it
    comes from ``from makitorch import *`` — confirm. This function is not used by the
    current pipeline (``preprocessing`` is).
    """
    _images = [image.resize(target_size, Image.BILINEAR) for image in imgs]
    # Masks hold class ids. Nearest-neighbour keeps them valid, whereas bilinear
    # resampling would blend neighbouring ids into labels that don't exist.
    _masks = [mask.resize(target_size, Image.NEAREST) for mask in masks]
    return _images, _masks

In [None]:
def preprocess_mask(mask):
    """Clean a single-channel uint8 label mask with morphological filtering.

    Uses a 2x2 all-ones kernel: erode x2, then dilate x4, then erode x2. This is
    roughly an opening (drops small specks) followed by a closing (fills small holes).

    Args:
        mask: uint8 single-channel mask, as passed in by ``preprocessing``.

    Returns:
        The filtered mask returned by the final ``cv2.erode``.
    """
    structuring_element = np.ones((2, 2), dtype=np.uint8)
    shrunk = cv2.erode(mask, structuring_element, iterations=2)
    grown = cv2.dilate(shrunk, structuring_element, iterations=4)
    return cv2.erode(grown, structuring_element, iterations=2)

def preprocessing(imgs, masks):
    """Turn raw cubes and colour masks into model-ready arrays.

    Images: PCA projection (``pca_transformation``), then per-channel standardisation
    using the 'means' / 'stds' stored in ``data_standartization_params.json``; the
    result is float32 (H, W, N).
    Masks: BGR -> grayscale uint8, cleaned by ``preprocess_mask``, then given a leading
    axis and cast to int64, i.e. (1, H, W).

    Returns:
        ``(images, masks)`` lists.
    """
    with open('data_standartization_params.json', 'r') as params_file:
        params = json.load(params_file)
    channel_means = params.get('means')
    channel_stds = params.get('stds')

    projected = [pca_transformation(raw_image) for raw_image in imgs]
    images_out = [
        np.array((cube - channel_means) / channel_stds, dtype=np.float32)
        for cube in projected
    ]

    masks_out = []
    for raw_mask in masks:
        gray = cv2.cvtColor(raw_mask, cv2.COLOR_BGR2GRAY).astype(np.uint8)
        cleaned = preprocess_mask(gray)
        masks_out.append(np.expand_dims(cleaned, 0).astype(np.int64))
    return images_out, masks_out

In [None]:
def test_augmentation(image, mask):
    """Deterministic evaluation transform: convert to tensors, no augmentation.

    Args:
        image: array accepted by ``TF.to_tensor`` (converted to channels-first).
        mask: numpy label mask; its leading axis is squeezed if it has size 1.

    Returns:
        ``(image_tensor, mask_tensor)``.
    """
    image_tensor = TF.to_tensor(image)
    mask_tensor = torch.squeeze(torch.from_numpy(mask), 0)
    return image_tensor, mask_tensor

In [None]:
def augmentation(image, mask):
    """Random train-time augmentation applied jointly to an image and its label mask.

    Both receive the same random parameters:
      1. rotation by an angle drawn from [-30, 30] degrees (bilinear for the image,
         nearest-neighbour for the mask so class ids stay valid);
      2. horizontal flip with probability 0.5;
      3. vertical flip with probability 0.5.

    Args:
        image: array accepted by ``TF.to_tensor`` (converted to channels-first).
        mask: numpy label mask, presumably (1, H, W) int64 as produced by ``preprocessing``.

    Returns:
        ``(image, mask)`` tensors; the mask's leading axis is squeezed if it has size 1.
    """
    image = TF.to_tensor(image)
    mask = torch.from_numpy(mask)
    angle = T.RandomRotation.get_params((-30, 30))
    image = TF.rotate(image, angle, interpolation=T.InterpolationMode.BILINEAR)
    mask = TF.rotate(mask, angle, interpolation=T.InterpolationMode.NEAREST)

    # Flip decisions come from torch's RNG, like the rotation angle above, so a single
    # torch.manual_seed makes the whole augmentation reproducible. (Older PyTorch
    # versions also copied NumPy's global RNG state identically into DataLoader workers.)
    if torch.rand(1).item() < 0.5:
        image = TF.hflip(image)
        mask = TF.hflip(mask)

    if torch.rand(1).item() < 0.5:
        image = TF.vflip(image)
        mask = TF.vflip(mask)

    mask = torch.squeeze(mask, 0)
    return image, mask

In [None]:
# The train/test split is fixed on disk. The commented-out lines show how it was
# presumably generated: a random permutation of 384 sample indices, with the first
# 310 used for training and the rest for testing.
# random = np.random.permutation(np.arange(384))
# test_indices = random[310:]
# train_indices = random[:310]

test_indices = np.load('test_indices.npy')
train_indices = np.load('train_indices.npy')
# NOTE(review): absolute, machine-specific dataset path.
path = '/raid/rustam/hyperspectral_dataset/cropped_hsi_data'

# Training data gets random augmentation; test data only the deterministic tensor conversion.
dataset_train = HsiDataloader(
    path, preprocessing=preprocessing, 
    augmentation=augmentation, indices=train_indices,
    shuffle_data=True
)
dataset_test = HsiDataloader(path, preprocessing=preprocessing, augmentation=test_augmentation, indices=test_indices)

In [None]:
# Training batches of 28 samples. NOTE(review): no shuffle=True is passed here;
# dataset_train was built with shuffle_data=True, so presumably HsiDataloader shuffles itself — confirm.
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=28)

In [None]:
# One image per batch, so calculate_iou's per-batch scores are per-image scores.
val_loader = torch.utils.data.DataLoader(dataset_test, batch_size=1)

In [None]:
iter_train = iter(train_loader)

In [None]:
# One training batch for visual sanity checks; presumably data[0] = images, data[1] = masks.
data = next(iter_train)

In [None]:
# NOTE(review): the permutation (0, 1, 2, 3) is the identity, so data_i has data[0]'s layout.
data_i = np.transpose(data[0], (0, 1, 2, 3))
data_i.shape

In [None]:
data[0].cpu().detach().numpy().shape, data[1].cpu().detach().numpy().shape

In [None]:
data[1][0].shape

In [None]:
# Mask of the first sample after an 11x11 median blur (cast to uint8 for cv2).
sns.heatmap(cv2.medianBlur(data[1][0].cpu().detach().numpy().astype(np.uint8), 11))

In [None]:
# Channel 9 of the first sample in the batch.
sns.heatmap(data_i[0][9].cpu().detach().numpy())

In [None]:
# Channel 7 of the first sample in the batch.
plt.imshow(data_i[0][7].cpu().detach().numpy())

In [None]:
# First channel of the first training image: (C, H, W) -> (H, W, C), then channel 0.
sns.heatmap(np.transpose(data[0][0].cpu().detach().numpy(), (1, 2, 0))[..., 0])

In [None]:
# Same view for a validation sample. Note: this rebinds `data` to a validation batch.
data = next(iter(val_loader))

In [None]:
sns.heatmap(np.transpose(data[0][0].cpu().detach().numpy(), (1, 2, 0))[..., 0])

In [None]:
data[1].shape

In [None]:
class MySuperNet(nn.Module):
    """Plain fully-convolutional segmentation network.

    Input BatchNorm, then four 5x5 conv -> BatchNorm -> ReLU stages with channel widths
    in_f -> 4*in_f -> 8*in_f -> 4*in_f -> in_f, then a 1x1 conv to ``out_f`` channels.
    Every conv has stride 1 and padding that preserves the spatial size, so an
    (N, in_f, H, W) input yields an (N, out_f, H, W) output.

    Args:
        in_f: number of input channels.
        out_f: number of output channels (one logit per class).
    """

    def __init__(self, in_f=17, out_f=17):
        super().__init__()
        self.bn_start = nn.BatchNorm2d(in_f)

        self.conv1 = nn.Conv2d(in_f, in_f * 4, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(in_f * 4)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(in_f * 4, in_f * 8, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(in_f * 8)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv2d(in_f * 8, in_f * 4, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm2d(in_f * 4)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv2d(in_f * 4, in_f, kernel_size=5, stride=1, padding=2)
        self.bn4 = nn.BatchNorm2d(in_f)
        self.act4 = nn.ReLU()

        self.final_conv = nn.Conv2d(in_f, out_f, kernel_size=1, stride=1, padding=0)

    # Defined as `forward` (not `__call__`) so nn.Module.__call__ still runs registered
    # hooks and `net.forward(x)` works. Calling `net(x)` is unchanged.
    def forward(self, x):
        x = self.bn_start(x)

        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.act4(self.bn4(self.conv4(x)))

        return self.final_conv(x)

In [None]:
# Segmentation network: 17 input channels -> 17 output channels (class logits).
net = MySuperNet(17, 17)

In [None]:
net

In [None]:
# Smoke test on CPU: a random 512x512 input should keep its spatial size.
# NOTE(review): runs with autograd enabled and in training mode (BatchNorm stats get updated).
net(torch.randn(1, 17, 512, 512)).shape

In [None]:
# Comet experiment tracking. The API key comes from the COMET_API_KEY environment
# variable rather than being written into the notebook, where it would leak through
# version control and shared copies. If the variable is unset, api_key=None is passed
# and CometLogger falls back to its own configuration lookup.
logger = CometLogger(
    api_key=os.environ.get("COMET_API_KEY"),
    workspace="your-workspace",  # Optional
    project_name="your-project-name",  # Optional
    experiment_name="new IOU//lower arch//50ep.W PCA.// RustamPreprocess(k=2) /makiloss/gamma=4/balance=2"
)

# Alternative: local TensorBoard logging (NnModel's image logging handles both).
#logger = TensorBoardLogger(
#    'logs/'
#)


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLossCustom(nn.Module):
    """Focal loss for dense (per-pixel) multi-class segmentation.

    Adapted from https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
    ('Focal Loss for Dense Object Detection', https://arxiv.org/abs/1708.02002).

    Per pixel: ``FL = (1 - p_t) ** gamma * CE``, where ``p_t`` is the softmax probability
    of the target class and ``CE`` is that pixel's cross-entropy.

    Reduction:
        size_average=True  -> mean over all pixels;
        size_average=False -> sum over all pixels divided by the number of pixels whose
                              label differs from ``balance_index`` (at least 1).

    :param alpha: accepted for API compatibility; currently NOT used by ``forward``.
    :param gamma: focusing exponent; larger values down-weight well-classified pixels.
    :param balance_index: class id excluded from the normalising pixel count.
    :param smooth: validated to lie in [0, 1]; currently NOT used by ``forward``.
    :param size_average: selects the reduction described above.
    """

    def __init__(self, alpha=None, gamma=4, balance_index=2, smooth=1e-5, size_average=False):
        super(FocalLossCustom, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average
        self.cel = nn.CrossEntropyLoss(reduction='none')
        self.softmax = nn.Softmax(dim=-1)

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        """
        :param logit: (N, C) or (N, C, d1, d2, ...) unnormalised class scores.
        :param target: integer labels shaped (N, d1, ...) or (N, 1, d1, ...).
        :return: scalar loss tensor.
        """
        num_class = logit.shape[1]

        if logit.dim() > 2:
            # (N, C, d1, d2, ...) -> (N * d1 * d2 * ..., C)
            logit = logit.reshape(logit.size(0), num_class, -1)
            logit = logit.permute(0, 2, 1).reshape(-1, num_class)
        target = torch.squeeze(target, 1).reshape(-1, 1).long()

        ce_loss = self.cel(logit, target.reshape(-1))
        # Probability of the true class of every pixel, gathered on the logits' device
        # (no CPU one-hot round trip).
        p_t = self.softmax(logit).gather(1, target).squeeze(1)
        loss = torch.pow(1 - p_t, self.gamma) * ce_loss

        if self.size_average:
            return loss.mean()
        # Normalise by the "positive" pixels. The clamp keeps the loss finite when every
        # pixel is labelled balance_index (previously this divided by 1e-10).
        num_positive = torch.sum(target != self.balance_index)
        return loss.sum() / num_positive.clamp(min=1)

In [None]:
# Wrap the network and the focal loss in the LightningModule defined above.
# Alternative loss that was tried:
# model = NnModel(net, muti_bce_loss_fusion, enable_image_logging=True)
model = NnModel(net, FocalLossCustom(), enable_image_logging=True)

# 50 epochs on one GPU, validating every second epoch, logging to the Comet logger above.
trainer = pl.Trainer(
    gpus=1, 
    max_epochs=50,
    check_val_every_n_epoch=2,
    logger=logger
)
# Previous long-run configuration (validation effectively only at the end):
# trainer = pl.Trainer(
#     gpus=1, 
#     max_epochs=2000,
#     check_val_every_n_epoch=2000)
trainer.fit(model, train_loader, val_loader)

In [None]:
# Free-form HTML notes attached to the Comet experiment (currently empty).
msg = """

"""

In [None]:
logger.experiment.log_html(msg)

In [None]:
# Finish the Comet experiment.
logger.experiment.end()

In [None]:
from sklearn.metrics import jaccard_score
np.set_printoptions(suppress=True)
def matrix2onehot(matrix, num_classes=17):
    """One-hot encode every element of an integer label array.

    NOTE(review): identical re-definition of the ``matrix2onehot`` defined earlier in
    this notebook; one of the two copies can be removed.

    Returns:
        float64 array of shape (matrix.size, num_classes) with a single 1 per row.
    """
    labels = matrix.copy().reshape(-1)
    encoded = np.zeros((labels.size, num_classes))
    np.put_along_axis(encoded, labels[:, None], 1, axis=1)
    return encoded


def calculate_iou(eval_loader, model, num_classes=17):
    """Mean per-class IoU of ``model`` over ``eval_loader``, computed on the CPU.

    NOTE(review): this re-definition shadows the earlier ``calculate_iou``, which takes a
    ``loss`` argument, returns four values and is what ``NnModel.validation_epoch_end``
    calls. Training re-run after this cell would fail in validation.

    Args:
        eval_loader: iterable of ``(inputs, targets)`` batches.
        model: network returning class logits with the class axis at dim 1. Inputs are
            moved to the CPU, so the model must be on the CPU as well.
        num_classes: number of label classes.

    Returns:
        ``(mean_iou, targets, preds)``:
        - mean_iou: (num_classes,) per-class IoU averaged over batches. A class absent
          from both target and prediction of a batch scores 0 there (``zero_division=0``).
        - targets / preds: the per-batch label maps.

    Raises:
        ValueError: if ``eval_loader`` yields no batches.
    """
    res_list, target_list, pred_list = [], [], []

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        for in_data_x, val_data in eval_loader:
            logits = model(in_data_x.cpu()).cpu().detach().numpy()
            preds = np.squeeze(np.argmax(logits, axis=1))
            target = np.squeeze(val_data.cpu().detach().numpy())

            target_list.append(target)
            pred_list.append(preds)

            res_list.append(
                jaccard_score(
                    matrix2onehot(target, num_classes=num_classes),
                    matrix2onehot(preds, num_classes=num_classes),
                    average=None,
                    zero_division=0,
                )
            )

    if not res_list:
        raise ValueError('calculate_iou: eval_loader yielded no batches')
    return np.stack(res_list).mean(axis=0), target_list, pred_list

In [None]:
# Evaluate the trained network on the validation set with the calculate_iou defined just
# above, which returns three values (not four). eval() makes BatchNorm use its running
# statistics instead of normalising each single-image batch with its own statistics.
net.eval()
with torch.no_grad():
    res, target_list, pred_list = calculate_iou(val_loader, net, num_classes=17)
res, res.mean()

In [None]:
# Index of the validation image to inspect in the next two cells.
iiiii = 12

In [None]:
# Ground-truth label map of that image.
sns.heatmap(target_list[iiiii])

In [None]:
# Predicted label map of that image.
sns.heatmap(pred_list[iiiii])

In [None]:
iter_val = iter(val_loader)

In [None]:
in_data_x, val_data = next(iter_val)

In [None]:
# Raw logits for one validation image.
# NOTE(review): runs with autograd enabled and in whatever train/eval mode `net` is in.
preds = net(in_data_x)

In [None]:
preds.shape, val_data.shape

In [None]:
preds_np, val_data_np = preds.detach().numpy(), val_data.detach().numpy()
preds_np.shape, val_data_np.shape

In [None]:
# Logits -> label map by taking the argmax over the class axis.
preds_np = np.argmax(preds_np, axis=1)
preds_np.shape

In [None]:
# NOTE(review): iou_numpy is not defined in this notebook; presumably it is exported by
# `from makitorch import *` — confirm. preds_np already has a leading batch axis after the
# argmax, so expand_dims adds an extra axis that val_data_np lacks — check expected shapes.
iou_numpy(np.expand_dims(preds_np, axis=0), val_data_np)

In [None]:
preds_np

In [None]:
val_data_np

In [None]:
sns.heatmap(preds_np[0])

In [None]:
sns.heatmap(val_data_np[0])

In [None]:
# First input channel of the image.
plt.imshow(in_data_x[0, 0])