In [1]:
import math
import os
import random
from typing import Any, Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from ipywidgets import IntSlider, interact
from sklearn.metrics import classification_report, confusion_matrix
from torch.nn import Parameter
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

import wandb

wandb.login()

from pytorch_lightning.loggers import WandbLogger

  warn(


cuda


[34m[1mwandb[0m: Currently logged in as: [33mnbennewiz[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
root = "../../Data"
data_paths = [os.path.join(root, path) for path in os.listdir(root)]


def get_data_loaders(
    batch_size=4, sequence_length=4, num_workers=0, pin_memory=False, drop_last=False
):
    # Load all folder paths
    # Split into train/val/test
    n_train = int(len(data_paths) * train_split)
    n_val = int(len(data_paths) * val_split)
    train_paths = data_paths[:n_train]
    val_paths = data_paths[n_train : n_train + n_val]
    test_paths = data_paths[n_train + n_val :]

    # Create datasets
    train_dataset = Dataset2D(train_paths, sequence_length)
    val_dataset = Dataset2D(val_paths, sequence_length)
    test_dataset = Dataset2D(test_paths, sequence_length)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last
    )
    val_loader = DataLoader(val_dataset, batch_size=batch_size, drop_last=drop_last)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=drop_last)

    return train_loader, val_loader, test_loader

10


In [8]:
train_loader, val_loader, test_loader = get_data_loaders(
    batch_size=batch_size, sequence_length=sequence_length
)

In [9]:
for i, (input, target) in enumerate(train_loader):
    print(i, input.shape, target.shape)  # [N, T, C, H, W]
    break

0 torch.Size([4, 8, 1, 256, 256]) torch.Size([4, 1, 1, 256, 256])


#### 2D SimpVP

In [10]:
class BasicConv2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        transpose=False,
        act_norm=False,
    ):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        if not transpose:
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
        else:
            self.conv = nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=stride // 2,
            )
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        y = self.conv(x)
        if self.act_norm:
            y = self.act(self.norm(y))
        return y


class ConvSC(nn.Module):
    def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
        super(ConvSC, self).__init__()
        if stride == 1:
            transpose = False
        self.conv = BasicConv2d(
            C_in,
            C_out,
            kernel_size=3,
            stride=stride,
            padding=1,
            transpose=transpose,
            act_norm=act_norm,
        )

    def forward(self, x):
        y = self.conv(x)
        return y


class GroupConv2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        groups,
        act_norm=False,
    ):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        if in_channels % groups != 0:
            groups = 1
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
        )
        self.norm = nn.GroupNorm(groups, out_channels)
        self.activate = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        y = self.conv(x)
        if self.act_norm:
            y = self.activate(self.norm(y))
        return y


class Inception(nn.Module):
    def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8):
        super(Inception, self).__init__()
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
        layers = []
        for ker in incep_ker:
            layers.append(
                GroupConv2d(
                    C_hid,
                    C_out,
                    kernel_size=ker,
                    stride=1,
                    padding=ker // 2,
                    groups=groups,
                    act_norm=True,
                )
            )
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        y = 0
        for layer in self.layers:
            y += layer(x)
        return y

In [11]:
def stride_generator(N, reverse=False):
    strides = [1, 2] * 10
    if reverse:
        return list(reversed(strides[:N]))
    else:
        return strides[:N]


class Encoder(nn.Module):
    def __init__(self, C_in, C_hid, N_S):
        super(Encoder, self).__init__()
        strides = stride_generator(N_S)
        self.enc = nn.Sequential(
            ConvSC(C_in, C_hid, stride=strides[0]),
            *[ConvSC(C_hid, C_hid, stride=s) for s in strides[1:]]
        )

    def forward(self, x):  # B*4, 3, 128, 128
        enc1 = self.enc[0](x)
        latent = enc1
        for i in range(1, len(self.enc)):
            latent = self.enc[i](latent)
        return latent, enc1


class Decoder(nn.Module):
    def __init__(self, C_hid, C_out, N_S):
        super(Decoder, self).__init__()
        strides = stride_generator(N_S, reverse=True)
        self.dec = nn.Sequential(
            *[ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]],
            ConvSC(2 * C_hid, C_hid, stride=strides[-1], transpose=True)
        )
        self.readout = nn.Conv2d(C_hid, C_out, 1)

    def forward(self, hid, enc1=None):
        for i in range(0, len(self.dec) - 1):
            hid = self.dec[i](hid)
        Y = self.dec[-1](torch.cat([hid, enc1], dim=1))
        Y = self.readout(Y)
        return Y


class Mid_Xnet(nn.Module):
    def __init__(self, channel_in, channel_hid, N_T, incep_ker=[3, 5, 7, 11], groups=8):
        super(Mid_Xnet, self).__init__()

        self.N_T = N_T
        enc_layers = [
            Inception(
                channel_in,
                channel_hid // 2,
                channel_hid,
                incep_ker=incep_ker,
                groups=groups,
            )
        ]
        for i in range(1, N_T - 1):
            enc_layers.append(
                Inception(
                    channel_hid,
                    channel_hid // 2,
                    channel_hid,
                    incep_ker=incep_ker,
                    groups=groups,
                )
            )
        enc_layers.append(
            Inception(
                channel_hid,
                channel_hid // 2,
                channel_hid,
                incep_ker=incep_ker,
                groups=groups,
            )
        )

        dec_layers = [
            Inception(
                channel_hid,
                channel_hid // 2,
                channel_hid,
                incep_ker=incep_ker,
                groups=groups,
            )
        ]
        for i in range(1, N_T - 1):
            dec_layers.append(
                Inception(
                    2 * channel_hid,
                    channel_hid // 2,
                    channel_hid,
                    incep_ker=incep_ker,
                    groups=groups,
                )
            )
        dec_layers.append(
            Inception(
                2 * channel_hid,
                channel_hid // 2,
                channel_in,
                incep_ker=incep_ker,
                groups=groups,
            )
        )

        self.enc = nn.Sequential(*enc_layers)
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        B, T, C, H, W = x.shape
        x = x.reshape(B, T * C, H, W)

        # encoder
        skips = []
        z = x
        for i in range(self.N_T):
            z = self.enc[i](z)
            if i < self.N_T - 1:
                skips.append(z)

        # decoder
        z = self.dec[0](z)
        for i in range(1, self.N_T):
            z = self.dec[i](torch.cat([z, skips[-i]], dim=1))

        y = z.reshape(B, T, C, H, W)
        return y


class SimVP(nn.Module):
    def __init__(
        self,
        shape_in,
        hid_S=16,
        hid_T=256,
        N_S=4,
        N_T=8,
        incep_ker=[3, 5, 7, 11],
        groups=8,
    ):
        super(SimVP, self).__init__()
        T, C, H, W = shape_in
        self.enc = Encoder(C, hid_S, N_S)
        self.hid = Mid_Xnet(T * hid_S, hid_T, N_T, incep_ker, groups)
        self.dec = Decoder(hid_S, C, N_S)

    def forward(self, x_raw):
        B, T, C, H, W = x_raw.shape
        x = x_raw.view(B * T, C, H, W)

        embed, skip = self.enc(x)
        _, C_, H_, W_ = embed.shape

        z = embed.view(B, T, C_, H_, W_)
        hid = self.hid(z)
        hid = hid.reshape(B * T, C_, H_, W_)

        Y = self.dec(hid, skip)
        Y = Y.reshape(B, T, C, H, W)
        return Y

In [12]:
# testing the model
model = SimVP([8, 1, 256, 256]).to(device)
x = torch.randn(4, 8, 1, 256, 256).to(device)
out = model(x)
print(out.shape)
print(out[:, :1, :, :].shape)

torch.Size([4, 8, 1, 256, 256])
torch.Size([4, 1, 1, 256, 256])


In [13]:
class Pl_Model(pl.LightningModule):
    def __init__(
        self,
        passed_model: nn.Module,
        config: Dict[str, Any],
    ):
        super(Pl_Model, self).__init__()
        self.passed_model = passed_model
        self.config = config

        # speicher alle parameter ab
        self.save_hyperparameters()

        # Setup training components
        self.mse_criterion = nn.MSELoss()
        self.huber_criterion = nn.HuberLoss(delta=1.0)

    def forward(self, x):
        x = self.passed_model(x)
        return x

    def configure_optimizers(self):
        """Sets the Optimizer for the Model"""
        optimizer = optim.Adam(
            self.parameters(),
            lr=config["learning_rate"],
        )
        return [optimizer]

    def _calculate_loss(self, batch, mode="train"):
        """Calculates the loss for a batch in different modes (training, validation, testing)"""
        inputs, targets = batch
        # to device
        # inputs = inputs.to(device)
        # targets = targets.to(device)

        # forward pass
        outputs = self.forward(inputs)
        # get only the first predicted frame
        outputs = outputs[:, :1, :, :]
        # calcualte losses
        mse_loss = self.mse_criterion(outputs, targets)
        huber_loss = self.huber_criterion(outputs, targets)
        total_loss = mse_loss + 0.5 * huber_loss

        # logging
        self.log(f"{mode}_mse_loss", mse_loss)
        self.log(f"{mode}_huber_loss", huber_loss)
        self.log(f"{mode}_total_loss", total_loss)

        return total_loss, mse_loss, huber_loss

    def training_step(self, batch, batch_idx):
        loss, _, _ = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="test")

    def check_losses(self, loader, mode, use_wandb=False):
        mse_loss = 0.0
        huber_loss = 0.0
        total_loss = 0.0
        for inputs, targets in loader:
            mse_loss_ = self.mse_criterion(inputs[:, -1, :, :, :].unsqueeze(1), targets)
            huber_loss_ = self.huber_criterion(
                inputs[:, -1, :, :, :].unsqueeze(1), targets
            )
            total_loss_ = mse_loss_ + 0.5 * huber_loss_

            mse_loss += mse_loss_.item()
            huber_loss += huber_loss_.item()
            total_loss += total_loss_.item()
        mse_loss = mse_loss / len(loader)
        huber_loss = huber_loss / len(loader)
        total_loss = total_loss / len(loader)

        if use_wandb:
            wandb.log({f"Checked_{mode}_mse_loss": mse_loss})
            wandb.log({f"Checked_{mode}_mse_loss": huber_loss})
            wandb.log({f"Checked_{mode}_mse_loss": total_loss})

        return mse_loss, huber_loss, total_loss

    def log_predictions(self):
        """Log example predictions to wandb"""
        # needs to be added to other method
        if epoch % self.config["viz_interval"] == 0:
            self.log_predictions()
        # but this whole method needs to be rewritten
        self.model.eval()
        with torch.no_grad():
            # Get a batch of validation data
            data, target = next(iter(self.val_loader))
            data = data.to(self.device)
            target = target.to(self.device)

            # Generate predictions
            output = self.model(data)

            # Log images
            wandb.log(
                {
                    "predictions": wandb.Image(output[0, 0].cpu()),
                    "targets": wandb.Image(target[0, 0].cpu()),
                    "input_sequence": [
                        wandb.Image(data[0, i].cpu()) for i in range(data.shape[1])
                    ],
                }
            )

#average
input_t-1  -> target_t
model(input_t-1) = output_t

loss = criterion(output_t, target_t)
loss = criterion(input_t-1, target_t)

criterion(output_t, target_t) < criterion(input_t-1, target_t)

In [14]:
config = {
    # for the dataloaders
    "batch_size": 20,
    "learning_rate": 0.0005,
    "num_workers": 10,  # 0, wenn die gpu nicht benutzt wird
    "pin_memory": True,  # False, wenn die gpu nicht benutzt wird
    "drop_last": False,
    "epochs": 40,
    #'log_interval': 20,
    #'viz_interval': 1,
    "run_name": "2D-SimpVP_v1",
    "input_frames": 8,
    "base_filters": 32,
}

# Initialize model
model = SimVP(shape_in=[8, 1, 256, 256])

# Get data loaders
train_loader, val_loader, test_loader = get_data_loaders(
    batch_size=config["batch_size"],
    sequence_length=config["input_frames"],
    num_workers=config["num_workers"],
    pin_memory=config["pin_memory"],
    drop_last=config["drop_last"],
)

wandb_logger = WandbLogger(project="perfusion-ct-prediction", name=config["run_name"])

# Initialize pl_model
pl_model = Pl_Model(
    passed_model=model,
    config=config,
)

# Initialize trainer
trainer = pl.Trainer(
    logger=wandb_logger,
    accelerator="gpu",
    devices=[0] if torch.cuda.is_available() else None,
    max_epochs=config["epochs"],
)

wandb_logger.watch(pl_model)

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'passed_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['passed_model'])`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [15]:
trainer.fit(
    pl_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# check the losses "to beat"
pl_model.check_losses(train_loader, mode="train", use_wandb=True)
pl_model.check_losses(val_loader, mode="val", use_wandb=True)
pl_model.check_losses(test_loader, mode="test", use_wandb=True)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name            | Type      | Params | Mode 
------------------------------------------------------
0 | passed_model    | SimVP     | 13.8 M | train
1 | mse_criterion   | MSELoss   | 0      | train
2 | huber_criterion | HuberLoss | 0      | train
------------------------------------------------------
13.8 M    Trainable params
0         Non-trainable params
13.8 M    Total params
55.030    Total estimated model params size (MB)
355       Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.
/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=40` reached.


MisconfigurationException: You are trying to `self.log()` but the loop's result collection is not registered yet. This is most likely because you are trying to log in a `predict` hook, but it doesn't support logging

In [16]:
val_results = trainer.validate(dataloaders=val_loader)
test_results = trainer.test(dataloaders=test_loader)

Restoring states from the checkpoint path at ./perfusion-ct-prediction/aw5tpimc/checkpoints/epoch=39-step=2560.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at ./perfusion-ct-prediction/aw5tpimc/checkpoints/epoch=39-step=2560.ckpt


Validation: |                                                                                                 …

Restoring states from the checkpoint path at ./perfusion-ct-prediction/aw5tpimc/checkpoints/epoch=39-step=2560.ckpt


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     val_huber_loss        0.005670313723385334
      val_mse_loss         0.011347890831530094
     val_total_loss        0.014183047227561474
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at ./perfusion-ct-prediction/aw5tpimc/checkpoints/epoch=39-step=2560.ckpt
/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Testing: |                                                                                                    …

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_huber_loss       0.006982726044952869
      test_mse_loss        0.013984769582748413
     test_total_loss        0.01747613213956356
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [18]:
save_load_path = f"../ModelWeights/{config['run_name']}.ckpt"
trainer.save_checkpoint(save_load_path)