In [1]:
# Fetch the project sources; git creates a local 'brain-seg' directory for them.
!git clone https://github.com/aqntks/brain-seg

Cloning into 'brain-seg'...
remote: Enumerating objects: 77, done.[K
remote: Counting objects: 100% (77/77), done.[K
remote: Compressing objects: 100% (54/54), done.[K
remote: Total 77 (delta 32), reused 61 (delta 18), pack-reused 0[K
Unpacking objects: 100% (77/77), done.


In [11]:
%cd brain-seg
!mkdir data

mkdir: cannot create directory ‘data’: File exists


In [3]:
!pip install -r requirements.txt

Collecting monai
  Downloading monai-0.7.0-202109240007-py3-none-any.whl (650 kB)
[K     |████████████████████████████████| 650 kB 5.4 MB/s 
Installing collected packages: monai
Successfully installed monai-0.7.0


In [13]:
import os
import time
import torch
import torch.nn as nn

from collections import OrderedDict

from monai.data import decollate_batch
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    EnsureType,
)

import data_loader
from utils import pre_visualize, train_graph

In [None]:
# Build the datasets and loaders used by the rest of the notebook.
# data_loader.brats_brain is project code not shown here; it receives the 'data/'
# directory and its four return values are unpacked as the training/validation
# datasets and their loaders.
train_ds, val_ds, train_loader, val_loader = data_loader.brats_brain('data/')

In [None]:
# Preview the validation dataset before training.
# NOTE(review): pre_visualize comes from the project's utils module; what exactly it
# renders is not visible from this notebook.
pre_visualize(val_ds)

In [None]:
# Device the model and every batch are moved to (first CUDA GPU).
device = torch.device("cuda:0")

# Number of training epochs; also used as T_max of the cosine LR schedule.
max_epochs = 300
# Validation runs every `val_interval` epochs (1 = after every epoch).
val_interval = 1
# When True, validation inference runs under torch.cuda.amp.autocast.
VAL_AMP = True

In [6]:
class UNet3d(nn.Module):
    """3D U-Net: four encoder levels, a bottleneck and four decoder levels.

    Each encoder level is followed by a 2x2x2 max-pool; each decoder level is
    preceded by a stride-2 transposed convolution and receives the matching
    encoder output through a channel-wise skip connection. Because of the four
    pooling steps, every spatial dimension of the input must be divisible by 16
    for the skip concatenations to line up.

    Args:
        in_channels: number of channels of the input volume.
        out_channels: number of channels of the output map.
        init_features: channel width of the first level; it doubles at every
            deeper level (x2, x4, x8, x16 at the bottleneck).

    The forward pass returns sigmoid-activated values, i.e. numbers in [0, 1],
    not raw logits.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super().__init__()

        # Channel widths per depth level: level 1 .. level 4, then bottleneck.
        c1, c2, c3, c4, c5 = (init_features * mult for mult in (1, 2, 4, 8, 16))
        make_block = UNet3d._block
        window = (2, 2, 2)  # shared pooling window / up-sampling kernel and stride

        # Contracting path (layers are created in the same order as they run).
        self.encoder1 = make_block(in_channels, c1, name="enc1")
        self.pool1 = nn.MaxPool3d(window)
        self.encoder2 = make_block(c1, c2, name="enc2")
        self.pool2 = nn.MaxPool3d(window)
        self.encoder3 = make_block(c2, c3, name="enc3")
        self.pool3 = nn.MaxPool3d(window)
        self.encoder4 = make_block(c3, c4, name="enc4")
        self.pool4 = nn.MaxPool3d(window)

        self.bottleneck = make_block(c4, c5, name="bottleneck")

        # Expanding path: each decoder block sees the up-sampled features
        # concatenated with the skip connection, hence twice the level width.
        self.upconv4 = nn.ConvTranspose3d(c5, c4, kernel_size=window, stride=window, bias=True)
        self.decoder4 = make_block(c4 * 2, c4, name="dec4")
        self.upconv3 = nn.ConvTranspose3d(c4, c3, kernel_size=window, stride=window, bias=True)
        self.decoder3 = make_block(c3 * 2, c3, name="dec3")
        self.upconv2 = nn.ConvTranspose3d(c3, c2, kernel_size=window, stride=window, bias=True)
        self.decoder2 = make_block(c2 * 2, c2, name="dec2")
        self.upconv1 = nn.ConvTranspose3d(c2, c1, kernel_size=window, stride=window, bias=True)
        self.decoder1 = make_block(c1 * 2, c1, name="dec1")

        # 1x1x1 projection from the first-level width to the requested output channels.
        self.conv = nn.Conv3d(
            in_channels=c1,
            out_channels=out_channels,
            kernel_size=(1, 1, 1),
            padding=(0, 0, 0),
            bias=True,
        )

    def forward(self, x):
        """Run the network on ``x`` and return sigmoid-activated outputs."""
        # Encoder: keep every level's output for the skip connections.
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        features = self.bottleneck(self.pool4(skip4))

        # Decoder: up-sample, concatenate the skip along channels, refine.
        features = self.decoder4(torch.cat((self.upconv4(features), skip4), dim=1))
        features = self.decoder3(torch.cat((self.upconv3(features), skip3), dim=1))
        features = self.decoder2(torch.cat((self.upconv2(features), skip2), dim=1))
        features = self.decoder1(torch.cat((self.upconv1(features), skip1), dim=1))

        return torch.sigmoid(self.conv(features))

    @staticmethod
    def _block(in_channels, features, name):
        """Build a (conv 3x3x3 -> batch norm -> ReLU) x 2 stage.

        Sub-module names are prefixed with ``name`` so state_dict keys stay
        unique and readable (e.g. ``enc1conv1``, ``enc1norm1``).
        """

        def conv3(n_in):
            # Padding 1 with a 3x3x3 kernel and stride 1 preserves spatial size.
            return nn.Conv3d(
                in_channels=n_in,
                out_channels=features,
                kernel_size=(3, 3, 3),
                stride=(1, 1, 1),
                padding=(1, 1, 1),
                bias=True,
            )

        layers = OrderedDict()
        layers[name + "conv1"] = conv3(in_channels)
        layers[name + "norm1"] = nn.BatchNorm3d(num_features=features)
        layers[name + "relu1"] = nn.ReLU(inplace=True)
        layers[name + "conv2"] = conv3(features)
        layers[name + "norm2"] = nn.BatchNorm3d(num_features=features)
        layers[name + "relu2"] = nn.ReLU(inplace=True)
        return nn.Sequential(layers)

In [None]:
# Network with 4 input channels and 3 output channels, placed on the configured device.
model = UNet3d(in_channels=4, out_channels=3).to(device)

# UNet3d.forward already ends with torch.sigmoid, so the model emits probabilities.
# The loss and the post-processing below therefore must NOT apply a second sigmoid:
# sigmoid(p) for p in (0, 1) always lies in (0.5, 0.74), which squashes the training
# gradients and makes every voxel pass the 0.5 threshold during validation.
loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=False)
optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)
# Cosine decay of the learning rate over the whole run (stepped once per epoch in train()).
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

# "mean" gives one scalar over all channels; "mean_batch" keeps one value per channel.
dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

# Validation post-processing: binarise the probabilities the model already outputs
# (AsDiscrete thresholds at its default of 0.5).
post_trans = Compose(
    [EnsureType(), AsDiscrete(threshold_values=True)]
)

In [7]:
def train(train_ds, val_ds, train_loader, val_loader):
    """Train the global ``model``, validate it periodically and plot the curves.

    The function relies on notebook-level objects created in earlier cells:
    ``model``, ``device``, ``loss_function``, ``optimizer``, ``lr_scheduler``,
    ``dice_metric``, ``dice_metric_batch``, ``post_trans``, ``max_epochs``,
    ``val_interval`` and ``VAL_AMP``.

    Args:
        train_ds: training dataset (kept for call compatibility; not read here).
        val_ds: validation dataset (kept for call compatibility; not read here).
        train_loader: iterable of batches; each batch is indexed with
            ``"image"`` and ``"label"``.
        val_loader: same as ``train_loader`` but for validation.

    Side effects:
        Prints progress, writes the best checkpoint to
        ``weights/best_metric_model.pth`` (the directory is created if missing),
        and finally calls ``train_graph`` with the collected histories.

    Returns:
        None.
    """

    def inference(volume):
        """Sliding-window prediction, optionally under autocast (see VAL_AMP)."""

        def _compute(volume):
            return sliding_window_inference(
                inputs=volume,
                roi_size=(240, 240, 160),
                sw_batch_size=1,
                predictor=model,
                overlap=0.5,
            )

        if VAL_AMP:
            with torch.cuda.amp.autocast():
                return _compute(volume)
        else:
            return _compute(volume)

    # torch.save does not create missing directories; without this the first
    # "new best" checkpoint would fail on a fresh clone.
    weights_dir = "weights/"
    os.makedirs(weights_dir, exist_ok=True)

    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

    # len(loader) is the real number of batches per epoch: it accounts for a
    # partial last batch and does not depend on loader.batch_size being set.
    steps_per_epoch = len(train_loader)

    # train
    best_metric = -1
    best_metric_epoch = -1
    best_metrics_epochs_and_time = [[], [], []]
    epoch_loss_values = []
    metric_values = []
    metric_values_tc = []
    metric_values_wt = []
    metric_values_et = []

    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            # Mixed-precision forward pass; the scaler guards the backward
            # pass against gradient underflow.
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            print(
                f"{step}/{steps_per_epoch}"
                f", train_loss: {loss.item():.4f}"
                f", step time: {(time.time() - step_start):.4f}"
            )
        lr_scheduler.step()
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():

                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    val_outputs = inference(val_inputs)
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    # Both metrics accumulate internally until aggregate() is called.
                    dice_metric(y_pred=val_outputs, y=val_labels)
                    dice_metric_batch(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                metric_values.append(metric)
                # Per-channel Dice; channels 0/1/2 are reported as tc/wt/et.
                metric_batch = dice_metric_batch.aggregate()
                metric_tc = metric_batch[0].item()
                metric_values_tc.append(metric_tc)
                metric_wt = metric_batch[1].item()
                metric_values_wt.append(metric_wt)
                metric_et = metric_batch[2].item()
                metric_values_et.append(metric_et)
                dice_metric.reset()
                dice_metric_batch.reset()

                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    best_metrics_epochs_and_time[0].append(best_metric)
                    best_metrics_epochs_and_time[1].append(best_metric_epoch)
                    best_metrics_epochs_and_time[2].append(time.time() - total_start)
                    torch.save(
                        model.state_dict(),
                        os.path.join(weights_dir, "best_metric_model.pth"),
                    )
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                    f"\nbest mean dice: {best_metric:.4f}"
                    f" at epoch: {best_metric_epoch}"
                )
        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
    total_time = time.time() - total_start

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}, total time: {total_time}.")

    train_graph(epoch_loss_values, val_interval, metric_values, metric_values_tc, metric_values_wt, metric_values_et)

In [None]:
# Start the full training run; train() prints progress, saves the best checkpoint
# under weights/ and plots the curves when finished.
train(train_ds, val_ds, train_loader, val_loader)