In [2]:
import torch
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.version.cuda)



True
1
12.1


# U-Net Model

In [3]:
class DoubleConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.step = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1),
            torch.nn.ReLU()
        )


    def forward(self, X):
        return self.step(X)

In [4]:
class UNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Encoding Layers
        self.layer1 = DoubleConv(in_channels = 1, out_channels = 64)
        self.layer2 = DoubleConv(in_channels = 64, out_channels = 128)
        self.layer3 = DoubleConv(in_channels = 128, out_channels = 256)
        self.layer4 = DoubleConv(in_channels = 256, out_channels = 512)

        # Decoding Layers
        self.layer5 = DoubleConv(in_channels = 512+256, out_channels = 256) # verbunden mit layer3
        self.layer6 = DoubleConv(in_channels = 256+128, out_channels = 128) # verbunden mit layer2
        self.layer7 = DoubleConv(in_channels = 128+64, out_channels = 64) # verbunden mit layer1
        self.layer8 = torch.nn.Conv2d(in_channels = 64, out_channels = 1, kernel_size = 1)

        # Pooling
        self.maxpool = torch.nn.MaxPool2d(2)

    def forward(self, x):
        # Encoding Layers
        x1 = self.layer1(x)
        x1m = self.maxpool(x1)

        x2 = self.layer2(x1m)
        x2m = self.maxpool(x2)

        x3 = self.layer3(x2m)
        x3m = self.maxpool(x3)

        x4 = self.layer4(x3m)

        # Decoding Layers
        x5 = torch.nn.Upsample(scale_factor=2, mode="bilinear")(x4) # Upsample x4
        x5 = torch.cat([x5, x3], dim=1) # konkatinierte features: verbindung zwischen layer5 und layer3. Channel dim = 1
        x5 = self.layer5(x5)

        x6 = torch.nn.Upsample(scale_factor=2, mode="bilinear")(x5) 
        x6 = torch.cat([x6, x2], dim=1) 
        x6 = self.layer6(x6)

        x7 = torch.nn.Upsample(scale_factor=2, mode="bilinear")(x6) 
        x7 = torch.cat([x7, x1], dim=1) 
        x7 = self.layer7(x7)

        ret = self.layer8(x7)

        return ret

In [5]:
model = UNet()

random_input = torch.randn(1,1,256,256) # (batch size, channel dimension, pixel width, pixel height)
output = model(random_input)
assert output.shape == torch.Size([1,1,256,256])


# Training

In [6]:
from pathlib import Path

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import imgaug.augmenters as iaa
import numpy as np
import matplotlib.pyplot as plt

from dataset import CardiacDataset



### Augmentation Pipeline
Zufällige Transformationen um robuster zu lernen

In [7]:
seq = iaa.Sequential([
    iaa.Affine(scale=(0.85, 1.15), # Zoom in or out
               rotate=(-45, 45)),  # Rotate up to 45 degrees
    iaa.ElasticTransformation()  # Random Elastic Deformations
])

### Dataset

In [8]:
train_path = Path("Preprocessed/train/")
val_path = Path("Preprocessed/val/")

train_dataset = CardiacDataset(train_path, seq) # nur das train dataset wird augmented
val_dataset = CardiacDataset(val_path, None)

### Dataloader

In [9]:
batch_size = 8
num_workers = 4

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

### Custom Lossfunction (Dice Loss)

In [10]:
class DiceLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, mask): # predicted segmentation, real segmentation
        pred = torch.flatten(pred)
        mask = torch.flatten(mask)

        counter = (pred * mask).sum()
        denum = pred.sum() + mask.sum() + 1e-8 #  1e-8, so that we never divide by 0
        dice = (2*counter) / denum
        return 1 - dice

In [18]:
class BrainTumorSegmentation(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.model = UNet()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr= 1e-4)
        self.loss_fn = DiceLoss()

    def forward(self, data):
        return torch.sigmoid(self.model(data))
    
    def training_step(self, batch, batch_idx):
        mri, mask = batch              
        mask = mask.float() # real segmentation
        pred = self(mri) # predicted segmentation

        loss = self.loss_fn(pred, mask)

        self.log("Train Dice", loss)

        if batch_idx % 50 == 0:
            self.log_images(mri.cpu(), pred.cpu(), mask.cpu(), "Train")

        return loss
    
    def validation_step(self, batch, batch_idx):
        mri, mask = batch              
        mask = mask.float() # real segmentation
        pred = self(mri) # predicted segmentation

        loss = self.loss_fn(pred, mask)

        self.log("Val Dice", loss)

        if batch_idx % 2 == 0:
            self.log_images(mri.cpu(), pred.cpu(), mask.cpu(), "Val")

        return loss
    
    def log_images(self, mri, pred, mask, name):

        pred = pred > 0.5 # Only pixels with pred > 0.5 are is segmantation class (Threshholding)
        
        fig, axis = plt.subplots(1, 2)
        axis[0].imshow(mri[0][0], cmap="bone")
        mask_ = np.ma.masked_where(mask[0][0] == 0, mask[0][0])
        axis[0].imshow(mask_, alpha=0.6)
        
        axis[1].imshow(mri[0][0], cmap="bone")
        mask_ = np.ma.masked_where(pred[0][0] == 0, pred[0][0])
        axis[1].imshow(mask_, alpha=0.6)
        
        self.logger.experiment.add_figure(name, fig, self.global_step)


    def configure_optimizers(self):
        return [self.optimizer]

In [19]:
torch.manual_seed(0)
model = BrainTumorSegmentation()

In [20]:
checkpoint_callback = ModelCheckpoint(monitor="Val Dice", save_top_k=10, mode="min")

In [21]:
trainer = pl.Trainer(devices=1, accelerator="gpu", logger=TensorBoardLogger(save_dir="logs"), log_every_n_steps=1, callbacks = checkpoint_callback, max_epochs=75)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [22]:
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type     | Params
-------------------------------------
0 | model   | UNet     | 7.8 M 
1 | loss_fn | DiceLoss | 0     
-------------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.127    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  dv = np.float64(self.norm.vmax) - np.float64(self.norm.vmin)
  vmid = np.float64(self.norm.vmin) + dv / 2
c:\Users\Flori\anaconda3\envs\udemy\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

c:\Users\Flori\anaconda3\envs\udemy\lib\site-packages\pytorch_lightning\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
