### Single Fold

In [1]:
import deepdish
import glob
# import h5py11
import lightning as L
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import segmentation_models_pytorch as smp
import torchvision.transforms as T
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.optim import AdamW
from torchmetrics.functional import dice
from torch.utils.data import Dataset, DataLoader
import warnings
import yaml
# from pprint import pprint

warnings.filterwarnings("ignore")

In [30]:
class LightningModule(L.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = smp.Unet(encoder_name = "timm-resnest101e",
                              encoder_weights=None,
                              in_channels=3,
                              classes=1,
                              activation=None,
                              decoder_attention_type = "scse",
                              )
        self.loss_module = smp.losses.DiceLoss(mode="binary", smooth=1)
        self.val_step_outputs = []
        self.val_step_labels = []

    def forward(self, batch):
        return self.model(batch)
    
    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        preds = torch.nn.functional.interpolate(preds, size=256, mode='bilinear')
        preds = torch.sigmoid(preds).cpu().detach().numpy()
        predicted_mask_with_threshold = np.zeros((30,256, 256))
        predicted_mask_with_threshold[preds[0, :, :] < 0.45] = 0
        predicted_mask_with_threshold[preds[0, :, :] > 0.45] = 1
        loss = self.loss_module(preds, labels)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_step_outputs.append(preds)
        self.val_step_labels.append(labels)

    def on_validation_epoch_end(self):
        all_preds = torch.cat(self.val_step_outputs)
        all_labels = torch.cat(self.val_step_labels)
        all_preds = torch.sigmoid(all_preds)
        self.val_step_outputs.clear()
        self.val_step_labels.clear()
        val_dice = dice(all_preds, all_labels.long())
        self.log("val_dice", val_dice, on_step=False, on_epoch=True, prog_bar=True)
        if self.trainer.global_rank == 0:
            print(f"\nEpoch: {self.current_epoch}", flush=True)



In [31]:
data_path_mf_train = "/home/albert/ml/Contrails/data/full_dataset/train/"
data_path_mf_valid = "/home/albert/ml/Contrails/data/full_dataset/validation/"

In [32]:
class ContrailsDatasetMixed(Dataset):
    def __init__(self, split="train", mode="single", delta_t = 0):
        self.delta_t = delta_t
        self.split = split
        self.mode = mode
        self.path = (lambda x: data_path_mf_train if x == "train" else data_path_mf_valid)(self.split)
        self.examples = os.listdir(self.path)
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    
    def read_record(self, directory):
        record_data = {}
        for x in [
            "band_11", 
            "band_14", 
            "band_15"
        ]:
            record_data[x] = np.load(os.path.join(directory, x + ".npy"))

        return record_data

    def normalize_range(self, data, bounds):
        """Maps data to the range [0, 1]."""
        return (data - bounds[0]) / (bounds[1] - bounds[0])
    
    def get_false_color(self, record_data):
        _T11_BOUNDS = (243, 303)
        _CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
        _TDIFF_BOUNDS = (-4, 2)
        r = self.normalize_range(record_data["band_15"] - record_data["band_14"], _TDIFF_BOUNDS)
        g = self.normalize_range(record_data["band_14"] - record_data["band_11"], _CLOUD_TOP_TDIFF_BOUNDS)
        b = self.normalize_range(record_data["band_14"], _T11_BOUNDS)
        false_color = np.clip(np.stack([r, g, b], axis=2), 0, 1)
        if self.mode == "single":
            t_null = 4
            return false_color[..., t_null+self.delta_t]
        else:
            return false_color
    
    def __getitem__(self, index):
        path = f"{self.path}{self.examples[index]}"
        data = self.read_record(path)    
        img = self.get_false_color(data)
        if self.split == "validation":
            label = np.load(os.path.join(path, "human_pixel_masks.npy")).squeeze()
            label = torch.Tensor(label).to(torch.int64)
        if self.split == "train":
            label = np.load(os.path.join(path, "human_pixel_masks.npy")).squeeze()
            label = torch.Tensor(label).to(torch.int64)
            label_indiv = np.load(os.path.join(path, "human_individual_masks.npy")).squeeze()
            labelers = label_indiv[1,1,:].shape[0]
            label_union = torch.zeros([256, 256])
            for i in range(labelers):
                label_i = label_indiv[:, :,i]
                label_union[label_i[:, :] == 1] = 1   

        if self.mode == "single":
            img = torch.tensor(np.reshape(img, (256, 256, 3, -1))).to(torch.float32).permute(3, 2, 0, 1).squeeze()
        else:
            img = torch.tensor(np.reshape(img, (256, 256, 3, -1))).to(torch.float32).permute(3, 2, 0, 1)

        img = self.normalize_image(img)

        if self.split in ["train"]:
            return img.float(), [label.float(),label_union.float()]
        if self.split in ["validation"]:
            return img.float(), label.float()

        return img.float()
    
    def __len__(self):
        return len(self.examples)

In [33]:
path = '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest101e_ash_attention_sample_loss_kfold/model-f3-val_dice=0.6537.ckpt'

dataset_validation = ContrailsDatasetMixed("validation","single", 0)
data_loader_validation = DataLoader(
    dataset_validation,
    batch_size=30,
    shuffle=False,
    num_workers=1,
)

progress_bar_callback = TQDMProgressBar(
refresh_rate=1
)

trainer = L.Trainer()
model = LightningModule().load_from_checkpoint(path)
trainer.validate(model, data_loader_validation)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: 0it [00:00, ?it/s]

IndexError: boolean index did not match indexed array along dimension 0; dimension is 30 but corresponding boolean dimension is 1