In [4]:
import deepdish
import glob
# import h5py11
import lightning as L
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import segmentation_models_pytorch as smp
import torchvision.transforms as T
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.optim import AdamW
from torchmetrics.functional import dice
from torch.utils.data import Dataset, DataLoader
import warnings
import yaml
# from pprint import pprint

warnings.filterwarnings("ignore")

In [5]:
%pwd

'/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands'

In [6]:
torch.set_float32_matmul_precision('medium')

In [7]:
data_path_mf_train = "/home/albert/ml/Contrails/data/full_dataset/train/"
data_path_mf_valid = "/home/albert/ml/Contrails/data/full_dataset/validation/"

### Create and save config

In [25]:
!mkdir "./resnest101e_ash_attention_sample_loss_kfold"

mkdir: cannot create directory ‘./resnest101e_ash_attention_sample_loss_kfold’: File exists


In [29]:
%%writefile resnest101e_ash_attention_sample_loss_kfold.yaml

output_dir: "./resnest101e_ash_attention_sample_loss_kfold"

folds:
    n_splits: 4
    random_state: 42
train_folds: [3]

seed: 42

train_bs: 42
valid_bs: 60
workers: 6

progress_bar_refresh_rate: 1

early_stop:
    monitor: "val_loss"
    mode: "min"
    patience: 5
    verbose: 1

trainer:
    max_epochs: 24
    min_epochs: 20
    enable_progress_bar: True
    precision: "16-mixed"
    devices: 1

model:
    alpha: -1.0
    beta: -1.0
    alpha_union: -1.0
    beta_union: -1.0
    weight_union: -1.0
    weight_ground_truth: -1.0
    seg_model: "Unet"
    encoder_name: "timm-resnest101e"
    encoder_depth: 5
    loss_smooth: 1.0
    decoder_attention_type: "scse"
    image_size: 384
    optimizer_params:
        lr: 0.0005
        weight_decay: 0.0
        eps: 1.0e-6
    scheduler:
        name: "CosineAnnealingLR"
        params:
            CosineAnnealingLR:
                T_max: 2
                eta_min: 1.0e-6
                last_epoch: -1
            ReduceLROnPlateau:
                mode: "min"
                factor: 0.31622776601
                patience: 4
                verbose: True

Overwriting resnest101e_ash_attention_sample_loss_kfold.yaml


In [30]:
with open("resnest101e_ash_attention_sample_loss_kfold.yaml", "r") as file_obj:
    resnest101e_ash_attention_sample_loss_kfold = yaml.safe_load(file_obj)

### 101 512

In [63]:
!mkdir "./resnest101e_ash_attention_sample_loss_kfold_512"

In [64]:
%%writefile resnest101e_ash_attention_sample_loss_kfold_512.yaml

output_dir: "./resnest101e_ash_attention_sample_loss_kfold_512"

folds:
    n_splits: 4
    random_state: 42
train_folds: [0, 1, 2, 3]

seed: 42

train_bs: 20
valid_bs: 30
workers: 6

progress_bar_refresh_rate: 1

early_stop:
    monitor: "val_loss"
    mode: "min"
    patience: 5
    verbose: 1

trainer:
    max_epochs: 30
    min_epochs: 24
    enable_progress_bar: True
    precision: "16-mixed"
    devices: 1

model:
    alpha: -1.0
    beta: -1.0
    alpha_union: -1.0
    beta_union: -1.0
    weight_union: -1.0
    weight_ground_truth: -1.0
    seg_model: "Unet"
    encoder_name: "timm-resnest101e"
    encoder_depth: 5
    loss_smooth: 1.0
    decoder_attention_type: "scse"
    image_size: 384
    optimizer_params:
        lr: 0.0005
        weight_decay: 0.0
        eps: 1.0e-6
    scheduler:
        name: "CosineAnnealingLR"
        params:
            CosineAnnealingLR:
                T_max: 2
                eta_min: 1.0e-6
                last_epoch: -1
            ReduceLROnPlateau:
                mode: "min"
                factor: 0.31622776601
                patience: 4
                verbose: True

Writing resnest101e_ash_attention_sample_loss_kfold_512.yaml


In [65]:
with open("resnest101e_ash_attention_sample_loss_kfold_512.yaml", "r") as file_obj:
    resnest101e_ash_attention_sample_loss_kfold_512 = yaml.safe_load(file_obj)

### Dataset for full dataset

In [11]:
class ContrailsDatasetDf(Dataset):
    def __init__(self, df, split="train", mode="single", delta_t = 0):
        self.delta_t = delta_t
        self.split = split
        self.mode = mode
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.df=df
    
    def read_record(self, directory):
        record_data = {}
        for x in [
            "band_11", 
            "band_14", 
            "band_15"
        ]:
            record_data[x] = np.load(os.path.join(directory, x + ".npy"))

        return record_data

    def normalize_range(self, data, bounds):
        """Maps data to the range [0, 1]."""
        return (data - bounds[0]) / (bounds[1] - bounds[0])
    
    def get_false_color(self, record_data):
        _T11_BOUNDS = (243, 303)
        _CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
        _TDIFF_BOUNDS = (-4, 2)
        r = self.normalize_range(record_data["band_15"] - record_data["band_14"], _TDIFF_BOUNDS)
        g = self.normalize_range(record_data["band_14"] - record_data["band_11"], _CLOUD_TOP_TDIFF_BOUNDS)
        b = self.normalize_range(record_data["band_14"], _T11_BOUNDS)
        false_color = np.clip(np.stack([r, g, b], axis=2), 0, 1)
        if self.mode == "single":
            t_null = 4
            return false_color[..., t_null+self.delta_t]
        else:
            return false_color
    
    def __getitem__(self, index):
        row = self.df.iloc[index]
        path = row.path
        data = self.read_record(path)    
        img = self.get_false_color(data)
        if self.split == "validation":
            label = np.load(os.path.join(path, "human_pixel_masks.npy")).squeeze()
            label = torch.Tensor(label).to(torch.int64)
        if self.split == "train":
            label = np.load(os.path.join(path, "human_pixel_masks.npy")).squeeze()
            label = torch.Tensor(label).to(torch.int64)
            label_indiv = np.load(os.path.join(path, "human_individual_masks.npy")).squeeze()
            labelers = label_indiv[1,1,:].shape[0]
            label_union = torch.zeros([256, 256])
            for i in range(labelers):
                label_i = label_indiv[:, :,i]
                label_union[label_i[:, :] == 1] = 1   

        if self.mode == "single":
            img = torch.tensor(np.reshape(img, (256, 256, 3, -1))).to(torch.float32).permute(3, 2, 0, 1).squeeze()
        else:
            img = torch.tensor(np.reshape(img, (256, 256, 3, -1))).to(torch.float32).permute(3, 2, 0, 1)

        img = self.normalize_image(img)

        if self.split in ["train"]:
            return img.float(), [label.float(),label_union.float()]
        if self.split in ["validation"]:
            return img.float(), label.float()

        return img.float()
    
    def __len__(self):
        return len(self.df)

In [12]:
class ContrailsDatasetMixed(Dataset):
    def __init__(self, split="train", mode="single", delta_t = 0):
        self.delta_t = delta_t
        self.split = split
        self.mode = mode
        self.path = (lambda x: data_path_mf_train if x == "train" else data_path_mf_valid)(self.split)
        self.examples = os.listdir(self.path)
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    
    def read_record(self, directory):
        record_data = {}
        for x in [
            "band_11", 
            "band_14", 
            "band_15"
        ]:
            record_data[x] = np.load(os.path.join(directory, x + ".npy"))

        return record_data

    def normalize_range(self, data, bounds):
        """Maps data to the range [0, 1]."""
        return (data - bounds[0]) / (bounds[1] - bounds[0])
    
    def get_false_color(self, record_data):
        _T11_BOUNDS = (243, 303)
        _CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
        _TDIFF_BOUNDS = (-4, 2)
        r = self.normalize_range(record_data["band_15"] - record_data["band_14"], _TDIFF_BOUNDS)
        g = self.normalize_range(record_data["band_14"] - record_data["band_11"], _CLOUD_TOP_TDIFF_BOUNDS)
        b = self.normalize_range(record_data["band_14"], _T11_BOUNDS)
        false_color = np.clip(np.stack([r, g, b], axis=2), 0, 1)
        if self.mode == "single":
            t_null = 4
            return false_color[..., t_null+self.delta_t]
        else:
            return false_color
    
    def __getitem__(self, index):
        path = f"{self.path}{self.examples[index]}"
        data = self.read_record(path)    
        img = self.get_false_color(data)
        if self.split == "validation":
            label = np.load(os.path.join(path, "human_pixel_masks.npy")).squeeze()
            label = torch.Tensor(label).to(torch.int64)
        if self.split == "train":
            label = np.load(os.path.join(path, "human_pixel_masks.npy")).squeeze()
            label = torch.Tensor(label).to(torch.int64)
            label_indiv = np.load(os.path.join(path, "human_individual_masks.npy")).squeeze()
            labelers = label_indiv[1,1,:].shape[0]
            label_union = torch.zeros([256, 256])
            for i in range(labelers):
                label_i = label_indiv[:, :,i]
                label_union[label_i[:, :] == 1] = 1   

        if self.mode == "single":
            img = torch.tensor(np.reshape(img, (256, 256, 3, -1))).to(torch.float32).permute(3, 2, 0, 1).squeeze()
        else:
            img = torch.tensor(np.reshape(img, (256, 256, 3, -1))).to(torch.float32).permute(3, 2, 0, 1)

        img = self.normalize_image(img)

        if self.split in ["train"]:
            return img.float(), [label.float(),label_union.float()]
        if self.split in ["validation"]:
            return img.float(), label.float()

        return img.float()
    
    def __len__(self):
        return len(self.examples)

### Testing new training labels

### Lightning Module

#### Tversky loss

In [13]:
class LightningModuleTrLoss(L.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # print(config["aux_params"])
        self.model = model = smp.Unet(
            encoder_name=config["encoder_name"],
            encoder_depth=config["encoder_depth"],
            decoder_channels = (256, 128, 64, 32,16)[:config["encoder_depth"]],
            # aux_params = config["aux_params"],
            decoder_attention_type =config["decoder_attention_type"],
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
            activation=None,
        )
        self.loss_module = smp.losses.DiceLoss(mode="binary", smooth=config["loss_smooth"])
        self.val_step_outputs = []
        self.val_step_labels = []
        self.alpha = config["alpha"]
        self.beta = config["beta"]
        self.alpha_union = config["alpha_union"]
        self.beta_union = config["beta_union"]

    def forward(self, batch):
        imgs = batch
        preds = self.model(imgs)
        return preds

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), **self.config["optimizer_params"])

        if self.config["scheduler"]["name"] == "CosineAnnealingLR":
            scheduler = CosineAnnealingLR(
                optimizer,
                **self.config["scheduler"]["params"]["CosineAnnealingLR"],
            )
            lr_scheduler_dict = {"scheduler": scheduler, "interval": "step"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}
        elif self.config["scheduler"]["name"] == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(
                optimizer,
                **self.config["scheduler"]["params"]["ReduceLROnPlateau"],
            )
            lr_scheduler = {"scheduler": scheduler, "monitor": "val_loss"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def training_step(self, batch, batch_idx):
        imgs, [label_gt,label_union] = batch
        preds = self.model(imgs)
        if self.config["image_size"] != 256:
            preds = torch.nn.functional.interpolate(preds, size=256, mode='bilinear')
        loss_ground_truth = smp.losses.TverskyLoss("binary", classes=None, log_loss=False, from_logits=True, smooth=0.0, ignore_index=None, eps=1e-06, alpha=self.alpha, beta=self.beta, gamma=1.0)(preds, label_gt)
        loss_union = smp.losses.TverskyLoss("binary", classes=None, log_loss=False, from_logits=True, smooth=0.0, ignore_index=None, eps=1e-06, alpha=self.alpha_union, beta=self.beta_union, gamma=1.0)(preds, label_union)
        loss = torch.mul(loss_ground_truth,self.config["weight_ground_truth"])+torch.mul(loss_union, self.config["weight_union"])
        self.log('loss_union', float(loss_union) , on_step=False, on_epoch=True, prog_bar=True) #, batch_size=16
        self.log('loss_ground_truth', float(loss_ground_truth) , on_step=False, on_epoch=True, prog_bar=True)

        for param_group in self.trainer.optimizers[0].param_groups:
            lr = param_group["lr"]
        self.log("lr", lr, on_step=True, on_epoch=False, prog_bar=True)

        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        if self.config["image_size"] != 256:
            preds = torch.nn.functional.interpolate(preds, size=256, mode='bilinear')
        loss = self.loss_module(preds, labels)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_step_outputs.append(preds)
        self.val_step_labels.append(labels)

    def on_validation_epoch_end(self):
        all_preds = torch.cat(self.val_step_outputs)
        all_labels = torch.cat(self.val_step_labels)
        all_preds = torch.sigmoid(all_preds)
        self.val_step_outputs.clear()
        self.val_step_labels.clear()
        val_dice = dice(all_preds, all_labels.long())
        self.log("val_dice", val_dice, on_step=False, on_epoch=True, prog_bar=True)
        if self.trainer.global_rank == 0:
            print(f"\nEpoch: {self.current_epoch}", flush=True)

In [14]:
class LightningModuleTrLoss512(L.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # print(config["aux_params"])
        self.model = model = smp.Unet(
            encoder_name=config["encoder_name"],
            encoder_depth=config["encoder_depth"],
            decoder_channels = (256, 128, 64, 32,16)[:config["encoder_depth"]],
            # aux_params = config["aux_params"],
            decoder_attention_type =config["decoder_attention_type"],
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
            activation=None,
        )
        self.loss_module = smp.losses.DiceLoss(mode="binary", smooth=config["loss_smooth"])
        self.val_step_outputs = []
        self.val_step_labels = []
        self.alpha = config["alpha"]
        self.beta = config["beta"]
        self.alpha_union = config["alpha_union"]
        self.beta_union = config["beta_union"]

    def forward(self, batch):
        imgs = batch
        preds = self.model(imgs)
        return preds

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), **self.config["optimizer_params"])

        if self.config["scheduler"]["name"] == "CosineAnnealingLR":
            scheduler = CosineAnnealingLR(
                optimizer,
                **self.config["scheduler"]["params"]["CosineAnnealingLR"],
            )
            lr_scheduler_dict = {"scheduler": scheduler, "interval": "step"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}
        elif self.config["scheduler"]["name"] == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(
                optimizer,
                **self.config["scheduler"]["params"]["ReduceLROnPlateau"],
            )
            lr_scheduler = {"scheduler": scheduler, "monitor": "val_loss"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def training_step(self, batch, batch_idx):
        imgs, [label_gt,label_union] = batch
        imgs = torch.nn.functional.interpolate(imgs, size=512, mode='bilinear')
        preds = self.model(imgs)
        if self.config["image_size"] != 256:
            preds = torch.nn.functional.interpolate(preds, size=256, mode='bilinear')
        loss_ground_truth = smp.losses.TverskyLoss("binary", classes=None, log_loss=False, from_logits=True, smooth=0.0, ignore_index=None, eps=1e-06, alpha=self.alpha, beta=self.beta, gamma=1.0)(preds, label_gt)
        loss_union = smp.losses.TverskyLoss("binary", classes=None, log_loss=False, from_logits=True, smooth=0.0, ignore_index=None, eps=1e-06, alpha=self.alpha_union, beta=self.beta_union, gamma=1.0)(preds, label_union)
        loss = torch.mul(loss_ground_truth,self.config["weight_ground_truth"])+torch.mul(loss_union, self.config["weight_union"])
        self.log('loss_union', float(loss_union) , on_step=False, on_epoch=True, prog_bar=True) #, batch_size=16
        self.log('loss_ground_truth', float(loss_ground_truth) , on_step=False, on_epoch=True, prog_bar=True)

        for param_group in self.trainer.optimizers[0].param_groups:
            lr = param_group["lr"]
        self.log("lr", lr, on_step=True, on_epoch=False, prog_bar=True)

        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        imgs = torch.nn.functional.interpolate(imgs, size=512, mode='bilinear')
        preds = self.model(imgs)
        if self.config["image_size"] != 256:
            preds = torch.nn.functional.interpolate(preds, size=256, mode='bilinear')
        loss = self.loss_module(preds, labels)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_step_outputs.append(preds)
        self.val_step_labels.append(labels)

    def on_validation_epoch_end(self):
        all_preds = torch.cat(self.val_step_outputs)
        all_labels = torch.cat(self.val_step_labels)
        all_preds = torch.sigmoid(all_preds)
        self.val_step_outputs.clear()
        self.val_step_labels.clear()
        val_dice = dice(all_preds, all_labels.long())
        self.log("val_dice", val_dice, on_step=False, on_epoch=True, prog_bar=True)
        if self.trainer.global_rank == 0:
            print(f"\nEpoch: {self.current_epoch}", flush=True)

### Training functions

In [15]:
def sample_union_loss(delta, cfg, module,alpha, alpha_union, weight_ground_truth):
    cfg['model']["alpha"] = alpha
    cfg['model']["beta"] = 1-alpha
    cfg['model']["alpha_union"] = alpha_union
    cfg['model']["beta_union"] = 1-alpha_union
    cfg['model']["weight_ground_truth"] = weight_ground_truth
    cfg['model']["weight_union"] = 1-weight_ground_truth
    identifier = "resnest101e_alpha="+str(alpha)+"_alpha_union="+str(alpha_union)+"_weight_ground_truth="+str(weight_ground_truth)
    dataset_train = ContrailsDatasetMixed("train", "single",delta)
    dataset_validation = ContrailsDatasetMixed("validation","single", delta)
    
    data_loader_train = DataLoader(
    dataset_train,
    batch_size=cfg["train_bs"],
    shuffle=True,
    num_workers=cfg["workers"],
    )
    
    data_loader_validation = DataLoader(
        dataset_validation,
        batch_size=cfg["valid_bs"],
        shuffle=False,
        num_workers=cfg["workers"],
    )
    
    checkpoint_callback = ModelCheckpoint(
    save_weights_only=True,
    monitor="val_dice",
    dirpath=cfg["output_dir"],
    mode="max",
    filename= identifier,
    save_top_k=1,
    verbose=1,
    )
    progress_bar_callback = TQDMProgressBar(
    refresh_rate=cfg["progress_bar_refresh_rate"]
    )

    early_stop_callback = EarlyStopping(**cfg["early_stop"])
    csv_logger = CSVLogger(cfg["output_dir"], name=identifier)
    trainer = L.Trainer(
        callbacks=[checkpoint_callback, early_stop_callback, progress_bar_callback], logger = csv_logger,
        **cfg["trainer"],
    )
    
    cfg["model"]["scheduler"]["params"]["CosineAnnealingLR"]["T_max"] *= len(data_loader_train)/cfg["trainer"]["devices"]
    model = module(cfg["model"])
    
    trainer.fit(model, data_loader_train, data_loader_validation)
    
    

### Training the timesteps

In [60]:
len(train_df)

20529

In [17]:
import os

import pandas as pd
from sklearn.model_selection import KFold

folders_valid = os.listdir(data_path_mf_valid)



valid_df = pd.DataFrame(folders_valid, columns=['record_id'])
valid_df['train'] = 'valid'
valid_df.to_csv(os.path.join("/home/albert/ml/Contrails/data/full_dataset", 'valid_df.csv'), index=False)
valid_df["path"] = data_path_mf_valid + valid_df["record_id"].astype(str)

folders_train = os.listdir(data_path_mf_train)
train_df = pd.DataFrame(folders_train, columns=['record_id'])
train_df['train'] = 'train'
train_df.to_csv(os.path.join("/home/albert/ml/Contrails/data/full_dataset", 'train_df.csv'), index=False)
train_df["path"] = data_path_mf_train + train_df["record_id"].astype(str)

# dataset_train = ContrailsDatasetDf(valid_df, "validation", "single",0 )
# a, _ = dataset_train.__getitem__(1231)
# a.shape
# data_loader_train = DataLoader(
#         dataset_train,
#         batch_size=2,
#         shuffle=True,
#         num_workers=2,
#     )



In [None]:
import gc
gc.enable()
alpha = 0.5
alpha_union = 0.85
weight_ground_truth = 0.45

cfg = resnest101e_ash_attention_sample_loss_kfold
cfg['model']["alpha"] = alpha
cfg['model']["beta"] = 1-alpha
cfg['model']["alpha_union"] = alpha_union
cfg['model']["beta_union"] = 1-alpha_union
cfg['model']["weight_ground_truth"] = weight_ground_truth
cfg['model']["weight_union"] = 1-weight_ground_truth
config = cfg
identifier = "resnest26d_alpha="+str(alpha)+"_alpha_union="+str(alpha_union)+"_weight_ground_truth="+str(weight_ground_truth)

# folders_train = os.listdir(data_path_mf_train)
# train_df = pd.DataFrame(folders_train, columns=['record_id'])
# train_df['train'] = 'train'
# train_df.to_csv(os.path.join("/home/albert/ml/Contrails/data/full_dataset", 'train_df.csv'), index=False)
# train_df["path"] = data_path_mf_train + train_df["record_id"].astype(str)

# Fold = KFold(shuffle=True, **config["folds"])
# for n, (trn_index, val_index) in enumerate(Fold.split(train_df)):
#     train_df.loc[val_index, "kfold"] = int(n)
# train_df["kfold"] = train_df["kfold"].astype(int)

for fold in config["train_folds"]:
    print(f"\n###### Fold {fold}")
    trn_df = train_df[train_df.kfold != fold].reset_index(drop=True)
    vld_df = train_df[train_df.kfold == fold].reset_index(drop=True)

    dataset_train = ContrailsDatasetDf(trn_df, "train", "single",0 )
    dataset_validation = ContrailsDatasetDf(vld_df,"validation", "single",0)

    data_loader_train = DataLoader(
        dataset_train,
        batch_size=config["train_bs"],
        shuffle=True,
        num_workers=config["workers"],
    )
    data_loader_validation = DataLoader(
        dataset_validation,
        batch_size=config["valid_bs"],
        shuffle=False,
        num_workers=config["workers"],
    )

    checkpoint_callback = ModelCheckpoint(
        save_weights_only=True,
        monitor="val_dice",
        dirpath=config["output_dir"],
        mode="max",
        filename=f"model-f{fold}-{{val_dice:.4f}}",
        save_top_k=1,
        verbose=1,
    )

    progress_bar_callback = TQDMProgressBar(
        refresh_rate=config["progress_bar_refresh_rate"]
    )

    early_stop_callback = EarlyStopping(**config["early_stop"])

    csv_logger = CSVLogger(cfg["output_dir"], name=identifier+f'fold_{fold}/')
    trainer = L.Trainer(
        callbacks=[checkpoint_callback, early_stop_callback, progress_bar_callback],
        logger=csv_logger,
        **config["trainer"],
    )

    model = LightningModuleTrLoss(config["model"])

    trainer.fit(model, data_loader_train, data_loader_validation)

    del (
        dataset_train,
        dataset_validation,
        data_loader_train,
        data_loader_validation,
        model,
        trainer,
        checkpoint_callback,
        progress_bar_callback,
        early_stop_callback,
    )
    torch.cuda.empty_cache()
    gc.collect() 

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



###### Fold 3


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type     | Params
-----------------------------------------
0 | model       | Unet     | 56.6 M
1 | loss_module | DiceLoss | 0     
-----------------------------------------
56.6 M    Trainable params
0         Non-trainable params
56.6 M    Total params
226.221   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]


Epoch: 0


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


Epoch: 0


Metric val_loss improved. New best score: 0.451
Epoch 0, global step 367: 'val_dice' reached 0.56138 (best 0.56138), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest101e_ash_attention_sample_loss_kfold/model-f3-val_dice=0.5614.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 1


Metric val_loss improved by 0.048 >= min_delta = 0.0. New best score: 0.402
Epoch 1, global step 734: 'val_dice' reached 0.60450 (best 0.60450), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest101e_ash_attention_sample_loss_kfold/model-f3-val_dice=0.6045.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 2


Metric val_loss improved by 0.027 >= min_delta = 0.0. New best score: 0.376
Epoch 2, global step 1101: 'val_dice' reached 0.63120 (best 0.63120), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest101e_ash_attention_sample_loss_kfold/model-f3-val_dice=0.6312.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 3


Epoch 3, global step 1468: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 4


Metric val_loss improved by 0.017 >= min_delta = 0.0. New best score: 0.358
Epoch 4, global step 1835: 'val_dice' reached 0.64662 (best 0.64662), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest101e_ash_attention_sample_loss_kfold/model-f3-val_dice=0.6466.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 5


Epoch 5, global step 2202: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 6


Epoch 6, global step 2569: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 7


Epoch 7, global step 2936: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 8


Epoch 8, global step 3303: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 9


Monitored metric val_loss did not improve in the last 5 records. Best score: 0.358. Signaling Trainer to stop.
Epoch 9, global step 3670: 'val_dice' was not in top 1
Trainer was signaled to stop but the required `min_epochs=20` or `min_steps=None` has not been met. Training will continue...


Validation: 0it [00:00, ?it/s]


Epoch: 10


Monitored metric val_loss did not improve in the last 6 records. Best score: 0.358. Signaling Trainer to stop.
Epoch 10, global step 4037: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 11


Monitored metric val_loss did not improve in the last 7 records. Best score: 0.358. Signaling Trainer to stop.
Epoch 11, global step 4404: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 12


Monitored metric val_loss did not improve in the last 8 records. Best score: 0.358. Signaling Trainer to stop.
Epoch 12, global step 4771: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 13


Monitored metric val_loss did not improve in the last 9 records. Best score: 0.358. Signaling Trainer to stop.
Epoch 13, global step 5138: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 14


Monitored metric val_loss did not improve in the last 10 records. Best score: 0.358. Signaling Trainer to stop.
Epoch 14, global step 5505: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 15


Monitored metric val_loss did not improve in the last 11 records. Best score: 0.358. Signaling Trainer to stop.
Epoch 15, global step 5872: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 16


Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.350
Epoch 16, global step 6239: 'val_dice' reached 0.65373 (best 0.65373), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest101e_ash_attention_sample_loss_kfold/model-f3-val_dice=0.6537.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 17


Epoch 17, global step 6606: 'val_dice' was not in top 1


In [None]:
path = 
class LightningModuleValid(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = smp.Unet(encoder_name = "timm-resnest101e",
                              encoder_weights=None,
                              in_channels=3,
                              classes=1,
                              activation=None,
                              decoder_attention_type = "scse",
                              )

    def forward(self, batch):
        return self.model(batch)
    
model = LightningModule().load_from_checkpoint(path)

## 512 101e

In [None]:
trainer.

In [None]:
import gc
gc.enable()
alpha = 0.5
alpha_union = 0.85
weight_ground_truth = 0.45

cfg = resnest101e_ash_attention_sample_loss_kfold_512
cfg['model']["alpha"] = alpha
cfg['model']["beta"] = 1-alpha
cfg['model']["alpha_union"] = alpha_union
cfg['model']["beta_union"] = 1-alpha_union
cfg['model']["weight_ground_truth"] = weight_ground_truth
cfg['model']["weight_union"] = 1-weight_ground_truth
config = cfg
identifier = "resnest26d_alpha="+str(alpha)+"_alpha_union="+str(alpha_union)+"_weight_ground_truth="+str(weight_ground_truth)

folders_train = os.listdir(data_path_mf_train)
train_df = pd.DataFrame(folders_train, columns=['record_id'])
train_df['train'] = 'train'
train_df.to_csv(os.path.join("/home/albert/ml/Contrails/data/full_dataset", 'train_df.csv'), index=False)
train_df["path"] = data_path_mf_train + train_df["record_id"].astype(str)

Fold = KFold(shuffle=True, **config["folds"])
for n, (trn_index, val_index) in enumerate(Fold.split(train_df)):
    train_df.loc[val_index, "kfold"] = int(n)
train_df["kfold"] = train_df["kfold"].astype(int)

for fold in config["train_folds"]:
    print(f"\n###### Fold {fold}")
    trn_df = train_df[train_df.kfold != fold].reset_index(drop=True)
    vld_df = train_df[train_df.kfold == fold].reset_index(drop=True)

    dataset_train = ContrailsDatasetDf(trn_df, "train", "single",0 )
    dataset_validation = ContrailsDatasetDf(vld_df,"validation", "single",0)

    data_loader_train = DataLoader(
        dataset_train,
        batch_size=config["train_bs"],
        shuffle=True,
        num_workers=config["workers"],
    )
    data_loader_validation = DataLoader(
        dataset_validation,
        batch_size=config["valid_bs"],
        shuffle=False,
        num_workers=config["workers"],
    )

    checkpoint_callback = ModelCheckpoint(
        save_weights_only=True,
        monitor="val_dice",
        dirpath=config["output_dir"],
        mode="max",
        filename=f"model-f{fold}-{{val_dice:.4f}}",
        save_top_k=1,
        verbose=1,
    )

    progress_bar_callback = TQDMProgressBar(
        refresh_rate=config["progress_bar_refresh_rate"]
    )

    early_stop_callback = EarlyStopping(**config["early_stop"])

    csv_logger = CSVLogger(cfg["output_dir"], name=identifier+f'fold_{fold}/')
    trainer = L.Trainer(
        callbacks=[checkpoint_callback, early_stop_callback, progress_bar_callback],
        logger=csv_logger,
        **config["trainer"],
    )

    model = LightningModuleTrLoss512(config["model"])

    trainer.fit(model, data_loader_train, data_loader_validation)

    del (
        dataset_train,
        dataset_validation,
        data_loader_train,
        data_loader_validation,
        model,
        trainer,
        checkpoint_callback,
        progress_bar_callback,
        early_stop_callback,
    )
    torch.cuda.empty_cache()
    gc.collect() 

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



###### Fold 0


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type     | Params
-----------------------------------------
0 | model       | Unet     | 56.6 M
1 | loss_module | DiceLoss | 0     
-----------------------------------------
56.6 M    Trainable params
0         Non-trainable params
56.6 M    Total params
226.221   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]


Epoch: 0


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


Epoch: 0


Metric val_loss improved. New best score: 0.448
Epoch 0, global step 770: 'val_dice' reached 0.56261 (best 0.56261), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest101e_ash_attention_sample_loss_kfold_512/model-f0-val_dice=0.5626.ckpt' as top 1


In [None]:
import seaborn as sn
import matplotlib.pyplot as plt

for fold in config["train_folds"]:
    metrics = pd.read_csv(f"/kaggle/working/logs_f{fold}/lightning_logs/version_0/metrics.csv")
    del metrics["step"]
    del metrics["lr"]
    del metrics["train_loss_step"]
    metrics.set_index("epoch", inplace=True)
    g = sn.relplot(data=metrics, kind="line")
    plt.title(f"Fold {fold}")
    plt.gcf().set_size_inches(15, 5)
    plt.grid()
    plt.show()

### Results

## Baseline

In [135]:
sample_union_loss(0, resnest26d_ash_attention_sample_loss, LightningModuleTrLoss, alpha = 0.5,alpha_union = 0.5, weight_ground_truth = 1)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type     | Params
-----------------------------------------
0 | model       | Unet     | 25.3 M
1 | loss_module | DiceLoss | 0     
-----------------------------------------
25.3 M    Trainable params
0         Non-trainable params
25.3 M    Total params
101.314   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]


Epoch: 0


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


Epoch: 0


Metric val_loss improved. New best score: 0.498
Epoch 0, global step 302: 'val_dice' reached 0.52061 (best 0.52061), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest26d_ash_attention_sample_loss/resnest101e_alpha=0.5_alpha_union=0.5_weight_ground_truth=1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 1


Metric val_loss improved by 0.040 >= min_delta = 0.0. New best score: 0.457
Epoch 1, global step 604: 'val_dice' reached 0.55547 (best 0.55547), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest26d_ash_attention_sample_loss/resnest101e_alpha=0.5_alpha_union=0.5_weight_ground_truth=1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 2


Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.455
Epoch 2, global step 906: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 3


Metric val_loss improved by 0.029 >= min_delta = 0.0. New best score: 0.426
Epoch 3, global step 1208: 'val_dice' reached 0.58091 (best 0.58091), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest26d_ash_attention_sample_loss/resnest101e_alpha=0.5_alpha_union=0.5_weight_ground_truth=1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 4


Epoch 4, global step 1510: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 5


Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 0.411
Epoch 5, global step 1812: 'val_dice' reached 0.59485 (best 0.59485), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest26d_ash_attention_sample_loss/resnest101e_alpha=0.5_alpha_union=0.5_weight_ground_truth=1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 6


Epoch 6, global step 2114: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 7


Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.406
Epoch 7, global step 2416: 'val_dice' reached 0.60106 (best 0.60106), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest26d_ash_attention_sample_loss/resnest101e_alpha=0.5_alpha_union=0.5_weight_ground_truth=1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 8


Epoch 8, global step 2718: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 9


Epoch 9, global step 3020: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 10


Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.404
Epoch 10, global step 3322: 'val_dice' reached 0.60125 (best 0.60125), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest26d_ash_attention_sample_loss/resnest101e_alpha=0.5_alpha_union=0.5_weight_ground_truth=1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 11


Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 0.394
Epoch 11, global step 3624: 'val_dice' reached 0.61328 (best 0.61328), saving model to '/home/albert/ml/Contrails/notebooks/full_dataset/test_use_all_bands/resnest26d_ash_attention_sample_loss/resnest101e_alpha=0.5_alpha_union=0.5_weight_ground_truth=1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]


Epoch: 12


Epoch 12, global step 3926: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 13


Epoch 13, global step 4228: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 14


Epoch 14, global step 4530: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 15


Epoch 15, global step 4832: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 16


Monitored metric val_loss did not improve in the last 5 records. Best score: 0.394. Signaling Trainer to stop.
Epoch 16, global step 5134: 'val_dice' was not in top 1
Trainer was signaled to stop but the required `min_epochs=24` or `min_steps=None` has not been met. Training will continue...


Validation: 0it [00:00, ?it/s]


Epoch: 17


Monitored metric val_loss did not improve in the last 6 records. Best score: 0.394. Signaling Trainer to stop.
Epoch 17, global step 5436: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 18


Monitored metric val_loss did not improve in the last 7 records. Best score: 0.394. Signaling Trainer to stop.
Epoch 18, global step 5738: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 19


Monitored metric val_loss did not improve in the last 8 records. Best score: 0.394. Signaling Trainer to stop.
Epoch 19, global step 6040: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 20


Monitored metric val_loss did not improve in the last 9 records. Best score: 0.394. Signaling Trainer to stop.
Epoch 20, global step 6342: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]


Epoch: 21


Monitored metric val_loss did not improve in the last 10 records. Best score: 0.394. Signaling Trainer to stop.
Epoch 21, global step 6644: 'val_dice' was not in top 1


1. Bei nicht benutzung von k folds und den unterschiedlichen verteilungen der klassen in val/train ist tversky_loss mit a>b marginal besser als dice loss in val_dice
2. Image size größer scheint val_dice zu verbessern bei verlängerter lr
