In [22]:
import sys

#%conda install --yes jupyter
#%conda install --yes pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
%conda list


Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import h5py
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from  torch.utils.data import DataLoader

import albumentations as A
import albumentations.pytorch.transforms as Atorch

import pytorch_lightning as pl

import segmentation_models_pytorch as smp

from pathlib import Path
from tqdm.auto import tqdm

BASEDIR = Path("")
fn = BASEDIR / "train_eval.hdf5"

# WICHTIG: Nur so viele CPUs benutzen wie unserem Job zugeteilt sind.
#   Sonst wird alles *sehr* langsam!
N_CPUS = int(os.getenv("SLURM_CPUS_PER_TASK", 1))
torch.set_num_threads(N_CPUS)

  warn(


In [3]:
torch.cuda.is_available()

False

In [4]:
objectWithMask = 0
objectWithPost_fire = 0
objectWithPre_fire = 0

objectsSummary = []
currentObject = {}

def listObject(obj, indent=0):
    global objectWithMask
    global objectWithPost_fire
    global objectWithPre_fire
    global objectsSummary
    global currentObject

    
    
    for name, thing in obj.items():
        if indent == 0:
            currentObject = { "name" : "", "mask": False, "pre_fire" : False, "post_fire": False}

        if isinstance(thing, h5py.Group):
            if indent == 0:
                currentObject["name"] = name
            listObject(thing, indent+1)
            
        else:
            if name == "mask":
                currentObject["mask"] = True
                objectWithMask += 1
            elif name == "post_fire":
                currentObject["post_fire"] = True
                objectWithPost_fire += 1
            elif name == "pre_fire":
                currentObject["pre_fire"] = True
                objectWithPre_fire += 1
            else:
                print(name)
    
        if indent == 0:
            objectsSummary.append(currentObject)

In [5]:
objectWithMask = 0
objectWithPost_fire = 0
objectWithPre_fire = 0

objectsSummary = []
currentObject = {}

with h5py.File(fn, "r") as fd:
    listObject(fd)

In [6]:
class FiresDataset(torch.utils.data.Dataset):
    COORDS = {"x": range(512), "y": range(512), "band": ["1", "2", "3", "4", "5", "6", "7", "8", "8a", "9", "11", "12"]}
    
    def __init__(self, filename, folds=(0, 1, 2, 3, 4), exclude_pre=False, transform=None):
        self._filename = filename
        self._fd = h5py.File(filename, "r")
        self._transform = transform
        self._names = []        
        for name in self._fd:
            if self._fd[name].attrs["fold"] not in folds:
                continue
            self._names.append((name, "post_fire"))
            if "pre_fire" in self._fd[name] and not exclude_pre:
                self._names.append((name, "pre_fire"))
        
    def __getitem__(self, idx):
        name, state = self._names[idx]
        data = self._fd[name][state][...].astype("float32") / 10000.0
        if state == "pre_fire":
            mask = np.zeros((512, 512), dtype="float32")
        else:
            mask = self._fd[name]["mask"][..., 0].astype("float32")
            
        b2 = data[..., 1]
        b3 = data[..., 2]
        b4 = data[..., 3]
        b6 = data[..., 5]
        b7 = data[..., 6]
        b8a = data[..., 8]
        b11 = data[..., 10]
        b12 = data[..., 11]

        # Indices taken from https://www.mdpi.com/2072-4.astype("float32") / 10000.092/14/7/1727
        #
        # NBR
        nbr = (b12 - b8a) / (b12 + b8a + 1.0e-8)
        
        # MIRBI
        mirbi = 10 * b12 - 9.8 * b11 + 2.0

        # BAIS2:
        c1 = 1 - np.sqrt((b6 * b7 * b8a) / (b4 + 1.0e-8))
        c2 = (b12 - b8a) / np.sqrt(b12 + b8a + 1.0e-8) + 1
        bais2 = (c1 * c2)        
        
        # NBR+:        
        nbr_plus = ((b12 - b8a - b3 - b2) / (b12 + b8a + b3 + b2 + 1.0e-8))

        # Stack indices into a new image in CHW format.
        image =  np.stack((nbr, mirbi, bais2, nbr_plus))

        if self._transform:
            # Transpose image so we get HWC instead of CHW format.
            # Transform is responsible for transposing back as required by PyTorch.
            image = image.transpose((1, 2, 0))
            xfrm = self._transform(image=image, mask=mask)
            image, mask = xfrm["image"], xfrm["mask"]
        return {"image": image, "mask": mask[None, :]}

    def __len__(self):
        return len(self._names)

Wir augmentieren die Trainingsdaten in dem wir jedes Bild (mit Wahrscheinlichkeit 0.5):

* Horizontal spiegeln
* Vertical spiegeln
* Transponieren
* Um 90° rotieren

In [7]:
class FireModel(pl.LightningModule):
    def __init__(self, in_channels=4, batch_size=16, lr=0.000025, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.batch_size = batch_size        
        self.model = smp.UnetPlusPlus(in_channels=in_channels, 
                                      classes=1,
                                      encoder_depth=3,
                                      encoder_weights="imagenet",
                                      decoder_channels=[256, 128, 64],
                                      **kwargs)
        # for image segmentation dice loss is a good first choice
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

    def forward(self, image):
        return self.model(image)

    def shared_step(self, batch, stage):      
        image, mask = batch["image"], batch["mask"]

        logits = self.forward(image)
        
        loss = self.loss_fn(logits, mask)
        
        metrics = {}
        metrics[f"{stage}_loss"] = loss
        
        prob_mask = logits.sigmoid()
        pred_mask = (prob_mask > 0.5).long()
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")        
        metrics[f"{stage}_iou"] = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

        self.log_dict(metrics, prog_bar=True)                  
        return {"loss": loss}

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")        

    def train_dataloader(self):
        train_xfrm = A.Compose([
            A.VerticalFlip(p=0.5),       
            A.HorizontalFlip(p=0.5),
            A.Transpose(p=0.5),
            A.RandomRotate90(p=0.5),
            Atorch.ToTensorV2()
        ])

        result = DataLoader(
            FiresDataset(fn, folds=[1, 2, 3, 4], transform=train_xfrm, exclude_pre=True),
            batch_size=self.batch_size,
            num_workers=N_CPUS,
            shuffle=True,
            pin_memory=True,
            drop_last=True,
        )
        return result
    
    def val_dataloader(self):
        result = DataLoader(
            FiresDataset(fn, folds=[0], exclude_pre=True),
            batch_size=self.batch_size,
            num_workers=N_CPUS,
            shuffle=False,
            pin_memory=True,
            drop_last=False,
        )
        return result
    
    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")
                
    def configure_optimizers(self):
        # TODO: Can we do better? We should probably implement a learning rate schedule?
        return torch.optim.Adam(self.parameters(), lr=self.lr)

## Modell erstellen

Entweder wir laden einen vorherigen Checkpoint und setzen das Training fort, oder wir erstellen ein neues (leeres) Modell.

In [8]:
if False:
    model = FireModel.load_from_checkpoint("lightning_logs/version_19/checkpoints/epoch=19-step=520.ckpt")
else:
    model = FireModel(lr=0.001)

## Trainieren

Anstelle einer Trainingsschleife nutzen wir den PyTorch Lightning `Trainer` um das Trainieren zu koordinieren.

In [9]:
trainer = pl.Trainer(log_every_n_steps=5, max_epochs=20)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Training starten bis die maximale Anzahl an Epochen erreicht ist oder das Training stagniert.

In [10]:
trainer.validate(model=model, verbose=True)

TypeError: An invalid dataloader was returned from `FireModel.val_dataloader()`. Found <torch.utils.data.dataloader.DataLoader object at 0x00000217093D6B00>.

In [None]:


trainer.fit(model)


  | Name    | Type         | Params
-----------------------------------------
0 | model   | UnetPlusPlus | 23.1 M
1 | loss_fn | DiceLoss     | 0     
-----------------------------------------
23.1 M    Trainable params
0         Non-trainable params
23.1 M    Total params
92.535    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

TypeError: An invalid dataloader was returned from `FireModel.val_dataloader()`. Found <torch.utils.data.dataloader.DataLoader object at 0x0000021A8A77F1F0>.