In [1]:
print("strawberry")

strawberry


In [8]:
import torch
import numpy as np

# dataset
from twaidata.torchdatasets.in_ram_ds import MRISegmentation2DDataset, MRISegmentation3DDataset
from torch.utils.data import DataLoader, random_split, ConcatDataset

# model
from trustworthai.models.uq_models.drop_UNet import UNet
#from trustworthai.models.uq_models.HyperMapp3r import HyperMapp3r
from trustworthai.models.base_models.torchUNet import UNet as StandardUNet

# augmentation and pretrain processing
from trustworthai.utils.augmentation.standard_transforms import RandomFlip, GaussianBlur, GaussianNoise, \
                                                            RandomResizeCrop, RandomAffine, \
                                                            NormalizeImg, PairedCompose, LabelSelect, \
                                                            PairedCentreCrop, CropZDim
# loss function
from trustworthai.utils.losses_and_metrics.tversky_loss import TverskyLoss
from trustworthai.utils.losses_and_metrics.misc_metrics import IOU
from trustworthai.utils.losses_and_metrics.dice import dice, DiceMetric
from trustworthai.utils.losses_and_metrics.dice_losses import DiceLoss, GeneralizedDiceLoss
from trustworthai.utils.losses_and_metrics.power_jaccard_loss import PowerJaccardLoss
from torch.nn import BCELoss, MSELoss, BCEWithLogitsLoss

# fitter
from trustworthai.utils.fitting_and_inference.fitters.basic_lightning_fitter import StandardLitModelWrapper
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
import pytorch_lightning as pl

# misc
import os
import torch
import matplotlib.pyplot as plt
import torch
from torchinfo import summary

### Set the seed

In [9]:
seed = 3407
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

### define datasets and dataloaders

In [10]:
is3D = False

In [11]:
root_dir = "/disk/scratch/s2208943/ipdis/preprep/out_data/collated/"
wmh_dir = root_dir + "WMH_challenge_dataset/"
ed_dir = root_dir + "EdData/"

In [12]:
# domains = [
#             wmh_dir + d for d in ["Singapore", "Utrecht", "GE3T"]
#           ]

# domains = [
#             wmh_dir + d for d in ["Singapore", "Utrecht", "GE3T"]
#           ] + [
#             ed_dir + d for d in ["domainA", "domainB", "domainC", "domainD"]
#           ]


domains = [
            ed_dir + d for d in ["domainA", "domainB", "domainC", "domainD"]
          ]

In [13]:
# augmentation definintion
def get_transforms(is_3D):
    transforms = [
        LabelSelect(label_id=1),
        RandomFlip(p=0.5, orientation="horizontal"),
        # GaussianBlur(p=0.5, kernel_size=7, sigma=(.1, 1.5)),
        # GaussianNoise(p=0.2, mean=0, sigma=0.2),
        # RandomAffine(p=0.2, shear=(.1,3.)),
        # RandomAffine(p=0.2, degrees=5),
        # RandomResizeCrop(p=1., scale=(0.6, 1.), ratio=(3./4., 4./3.))
    ]
    if not is_3D:
        return PairedCompose(transforms)
    else:
        transforms.append(CropZDim(size=32, minimum=0, maximum=-1))
        return PairedCompose(transforms)

In [14]:
# function to do train validate test split
test_proportion = 0.1
validation_proportion = 0.2

def train_val_test_split(dataset, val_prop, test_prop, seed):
    # I think the sklearn version might be prefereable for determinism and things
    # but that involves fiddling with the dataset implementation I think....
    size = len(dataset)
    test_size = int(test_prop*size) 
    val_size = int(val_prop*size)
    train_size = size - val_size - test_size
    train, val, test = random_split(dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(seed))
    return train, val, test

In [15]:
# load datasets
# this step is quite slow, all the data is being loaded into memory
if is3D:
    datasets_domains = [MRISegmentation3DDataset(root_dir, domain, transforms=get_transforms(is_3D=True)) for domain in domains]
else:
    datasets_domains = [MRISegmentation2DDataset(root_dir, domain, transforms=get_transforms(is_3D=False)) for domain in domains]

# split into train, val test datasets
datasets = [train_val_test_split(dataset, validation_proportion, test_proportion, seed) for dataset in datasets_domains]

# concat the train val test datsets
train_dataset = ConcatDataset([ds[0] for ds in datasets])
val_dataset = ConcatDataset([ds[1] for ds in datasets])
test_dataset = ConcatDataset([ds[2] for ds in datasets])

In [16]:
len(train_dataset), len(val_dataset), len(test_dataset)

(8743, 2497, 1248)

In [17]:
# define dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size = 16, shuffle=False, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

### setup model

In [18]:
import torch
import torch.nn as nn
from trustworthai.models.uq_models.uq_model import UQModel
from trustworthai.models.uq_models.uq_layers.uq_generic_layer import UQLayerWrapper
 
# various dropout and dropconnect layers
from trustworthai.models.uq_models.uq_layers.dropoutconnect import (
    UQDropout,
    UQDropout2d,
    UQDropout3d,
    UQGaussianDropout,
    UQGaussianDropout2d,
    UQGaussianDropout3d,
    UQDropConnect,
    UQDropConnect2d,
    UQDropConnect3d,
    UQGaussianConnect,
    UQGaussianConnect2d,
    UQGaussianConnect3d,
)

def normalization_layer(planes, norm='gn', gn_groups=None, dims=2, as_uq_layer=False):
    if as_uq_layer:
        wrapper = lambda l : UQLayerWrapper(l)
    else:
        wrapper = lambda x : x
    if dims == 2:
        if norm == "bn":
            return lambda : wrapper(nn.BatchNorm2d(planes))
        elif norm == "gn":
            return lambda : wrapper(nn.GroupNorm(gn_groups, planes)) # it does 2d auomatically?
        elif norm == "in":
            return lambda : wrapper(nn.InstanceNorm2d(planes))
        else:
            raise ValueError(f"norm type {norm} not supported, only 'bn', 'in', or 'gn' supported")
    elif dims == 3:
        if norm == "bn":
            return lambda : wrapper(nn.BatchNorm3d(planes))
        elif norm == "gn":
            return lambda : wrapper(nn.GroupNorm(gn_groups, planes)) # it does 3d automatically?
        elif norm == "in":
            return lambda : wrapper(nn.InstanceNorm3d(planes))
        else:
            raise ValueError(f"norm type {norm} not supported, only 'bn', 'in', or 'gn' supported")

# custom block for selecting dropout/drop connect methods and normalization methods
class Block(UQModel):
    def __init__(self, 
                 in_channels,
                 out_channels,
                 name,
                 dims=2, # 2 =2D, 3=3D,
                 kernel_size=3,
                 dropout_type="bernoulli",
                 dropout_p=0.1,
                 gaussout_mean=1, # NOTE THE PREDICT STEP CURRENTLY ONLY SUPPORTS MEAN = 1
                 dropconnect_type="bernoulli",
                 dropconnect_p=0.1,
                 gaussconnect_mean=1,
                 norm_type="bn", # batch norm, or instance 'in' or group 'gn'
                 use_uq_norm_layer=False,
                 use_multidim_dropout=True, # use 2d or 3d dropout instead of 1d dropout. applies to gaussian dropout too
                 use_multidim_dropconnect = True, # use 2d or 3d dropconnect instead of 1d dropconnect, applies to gaussian dropconnect too
                 groups=1,
                 gn_groups=4, # number of groups for group norm normalization.
                ):
        super().__init__()
        
        # determine convolution func
        if dims == 2:
            conv_f = nn.Conv2d
        elif dims == 3:
            conv_f = nn.Conv3d
        else:
            raise ValueError(f"values of dims of 2 or 3 (2D or 2D conv) are supported only, not {dims}")
            
        # determine dropout func
        if dropout_type:
            # standard dropout
            if dropout_type == "bernoulli":
                if use_multidim_dropout:
                    if dims == 2:
                        dropout_f = UQDropout2d
                    else:
                        dropout_f = UQDropout3d
                else:
                    dropout_f = UQDropout
                    
            # gaussian dropout    
            elif dropout_type == "gaussian":
                if use_multidim_dropout:
                    if dims == 2:
                        dropout_f = UQGaussianDropout2d
                    elif dims == 3:
                        dropout_f = UQGaussianDropout3d
                else:
                    dropout_f = UQGaussianDropout
            else:
                raise ValueError(f"dropout type {dropout_type} not supported, "
                                 "only 'bernoulli' or 'gaussian' are supported")
        # no dropout
        else:
            dropout_f = None
        
        # determine dropconnect function
        if dropconnect_type:
            # standard dropconnect
            if dropconnect_type == "bernoulli":
                if use_multidim_dropout:
                    if dims == 2:
                        dropconnect_f = UQDropConnect2d
                    else:
                        dropconnect_f = UQDropConnect3d
                else:
                    dropconnect_f = UQDropConnect
                    
            # gaussian dropout    
            elif dropconnect_type == "gaussian":
                if use_multidim_dropconnect:
                    if dims == 2:
                        dropconnect_f = UQGaussianConnect2d
                    elif dims == 3:
                        dropconnect_f = UQGaussianConnect3d
                else:
                    dropconnect_f = UQGaussianConnect
            else:
                raise ValueError(f"dropconnect type {dropconnect_type} not supported, "
                                 "only 'bernoulli' or 'gaussian' are supported")
        else:
            dropconnect_f = None
    
        # determine normalization type
        norm_layer = normalization_layer(out_channels, norm=norm_type, gn_groups=gn_groups, dims=dims, as_uq_layer=use_uq_norm_layer)

        # layers needed for the forward pass
        self.conv1 = conv_f(in_channels, out_channels, kernel_size, padding=1, bias=False)
        if dropconnect_f:
            if dropconnect_type == "bernoulli":
                self.convout1 = dropconnect_f(self.conv1, None, dropconnect_p)
            else:
                self.convout1 = dropconnect_f(self.conv1, None, gaussconnect_mean, dropconnect_p)
        else:
            self.convout1 = self.conv1

        if dropout_f:
            if dropout_type == "bernoulli":
                self.dropout1 = dropout_f(dropout_p)
            else:
                self.dropout1 = dropout_f(gaussout_mean, dropout_p)
        else:
            self.dropout1 = None

        self.norm1 = norm_layer()

        self.conv2 = conv_f(out_channels, out_channels, kernel_size, padding=1, bias=False)
        if dropconnect_f:
            if dropconnect_type == "bernoulli":
                self.convout2 = dropconnect_f(self.conv2, None, dropconnect_p)
            else:
                self.convout2 = dropconnect_f(self.conv2, None, gaussconnect_mean, dropconnect_p)
        else:
            self.convout2 = self.conv2

        if dropout_f:
            if dropout_type == "bernoulli":
                self.dropout2 = dropout_f(dropout_p)
            else:
                self.dropout2 = dropout_f(gaussout_mean, dropout_p)
        else:
            self.dropout2 = None

        self.norm2 = norm_layer()


        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.convout1(x)
        x = self.norm1(x)
        if self.dropout1:
            x = self.dropout1(x)
        x = self.relu(x)
        
        x = self.convout2(x)
        x = self.norm2(x)
        if self.dropout2:
            x = self.dropout2(x)
        x = self.relu(x)
        
        return x

In [39]:
from trustworthai.models.uq_models.uq_model import UQModel
class PyramidUNet(UQModel):

    def __init__(self, in_channels=3, out_channels=1, init_features=32, softmax=True,
                 kernel_size=3,
                 dropout_type="bernoulli",
                 dropout_p=0.1,
                 gaussout_mean=1, # NOTE THE PREDICT STEP CURRENTLY ONLY SUPPORTS MEAN = 1
                 dropconnect_type="bernoulli",
                 dropconnect_p=0.1,
                 gaussconnect_mean=1,
                 norm_type="bn", # batch norm, or instance 'in' or group 'gn'
                 use_uq_norm_layer=False,
                 use_multidim_dropout = True, # use 2d or 3d dropout instead of 1d dropout. applies to gaussian dropout too
                 use_multidim_dropconnect = True, # use 2d or 3d dropconnect instead of 1d dropconnect, applies to gaussian dropconnect too
                 groups=1,
                 gn_groups=4, # number of groups for group norm normalization.
                ):
        super().__init__()
                 
        block_params = {"dims":2, "kernel_size":kernel_size,"dropout_type":dropout_type,
                        "dropout_p":dropout_p,"gaussout_mean":gaussout_mean,
                        "dropconnect_p":dropconnect_p,"dropconnect_type":dropconnect_type,"gaussconnect_mean":gaussconnect_mean,
                        "norm_type":norm_type,"use_uq_norm_layer":use_uq_norm_layer,"use_multidim_dropout":use_multidim_dropout,
                        "use_multidim_dropconnect":use_multidim_dropconnect,"groups":groups,
                        "gn_groups":gn_groups,
                       }

        features = init_features
        self.encoder1 = Block(in_channels, features, name="enc1", **block_params)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = Block(features, features * 2, name="enc2",**block_params)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = Block(features * 2, features * 4, name="enc3", **block_params)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = Block(features * 4, features * 8, name="enc4", **block_params)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = Block(features * 8, features * 16, name="bottleneck", **block_params)

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = Block((features * 8) * 2, features * 8, name="dec4", **block_params)
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = Block((features * 4) * 2, features * 4, name="dec3", **block_params)
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = Block((features * 2) * 2, features * 2, name="dec2", **block_params)
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = Block(features * 2, features, name="dec1", **block_params)

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )
        self.do_softmax = softmax
        
        self.final_bottle = nn.Conv2d(in_channels=features*16, out_channels=out_channels, kernel_size=1)
        self.final4 = nn.Conv2d(in_channels=features*8, out_channels=out_channels, kernel_size=1)
        self.final3 = nn.Conv2d(in_channels=features*4, out_channels=out_channels, kernel_size=1)
        self.final2 = nn.Conv2d(in_channels=features*2, out_channels=out_channels, kernel_size=1)
        self.final1 = nn.Conv2d(in_channels=features, out_channels=out_channels,kernel_size=1)

    def forward(self, x, return_layer=None):
        #print(return_layer)
        if return_layer == None:
            return_layer = -1
        curr_layer = 0
        
        enc1 = self.encoder1(x)
        curr_layer += 1
        if curr_layer == return_layer:
            return enc1
        
        enc2 = self.encoder2(self.pool1(enc1))
        curr_layer += 1
        if curr_layer == return_layer:
            return enc2
        
        enc3 = self.encoder3(self.pool2(enc2))
        if curr_layer == return_layer:
            return enc3
        
        enc4 = self.encoder4(self.pool3(enc3))
        if curr_layer == return_layer:
            return enc4
        

        bottleneck = self.bottleneck(self.pool4(enc4))
        if curr_layer == return_layer:
            return bottleneck
        
        out_bln = self.final_bottle(bottleneck)

        dec4 = self.upconv4(bottleneck)
        if curr_layer == return_layer:
            return dec4
        
        out4 = self.final4(dec4)
        
        dec4 = torch.cat((dec4, enc4), dim=1)
        
        dec4 = self.decoder4(dec4)
        if curr_layer == return_layer:
            return dec4
        
        dec3 = self.upconv3(dec4)
        if curr_layer == return_layer:
            return dec3
        
        out3 = self.final3(dec3)
        
        dec3 = torch.cat((dec3, enc3), dim=1)
        
        dec3 = self.decoder3(dec3)
        if curr_layer == return_layer:
            return dec3
        
        dec2 = self.upconv2(dec3)
        if curr_layer == return_layer:
            return dec2
        
        out2 = self.final2(dec2)
        
        dec2 = torch.cat((dec2, enc2), dim=1)
        
        dec2 = self.decoder2(dec2)
        if curr_layer == return_layer:
            return dec2
        
        dec1 = self.upconv1(dec2)
        if curr_layer == return_layer:
            return dec1
        
        out1 = self.final1(dec1)
        
        dec1 = torch.cat((dec1, enc1), dim=1)
        
        dec1 = self.decoder1(dec1)
        if curr_layer == return_layer:
            return dec1
        
        out = self.conv(dec1)
        
        out1 = torch.nn.functional.interpolate(out1, out.shape[-2:], mode='bilinear')
        out2 = torch.nn.functional.interpolate(out2, out.shape[-2:], mode='bilinear')
        out3 = torch.nn.functional.interpolate(out3, out.shape[-2:], mode='bilinear')
        out4 = torch.nn.functional.interpolate(out4, out.shape[-2:], mode='bilinear')
        out_bln = torch.nn.functional.interpolate(out_bln, out.shape[-2:], mode='bilinear')
        
        return [out, out1, out2, out3, out4, out_bln]

In [40]:
in_channels = 3
out_channels = 1

if is3D:
    model = UNet3D(in_channels,
                 out_channels,
                 init_features=32,
                 kernel_size=3,
                 softmax=False,
                 dropout_type=None,
                 dropout_p=None,
                 gaussout_mean=None, 
                 dropconnect_type=None,
                 dropconnect_p=None,
                 gaussconnect_mean=None,
                 norm_type="bn", 
                 use_multidim_dropout = None,  
                 use_multidim_dropconnect = None, 
                 groups=None,
                 gn_groups=None, 
                )
    optimizer_params={"lr":2e-3}
    lr_scheduler_params={"step_size":100, "gamma":0.5}
else:
    model = PyramidUNet(in_channels,
                 out_channels,
                 kernel_size=3,
                 init_features=32,
                 softmax=False,
                 dropout_type="bernoulli",
                 dropout_p=0.1,
                 gaussout_mean=None, 
                 dropconnect_type="gaussian",
                 dropconnect_p=0.1,
                 gaussconnect_mean=1,
                 norm_type="bn", 
                 use_multidim_dropout = True,  
                 use_multidim_dropconnect = True, 
                 groups=None,
                 gn_groups=None, 
                )
    optimizer_params={"lr":1e-3}
    lr_scheduler_params={"step_size":30, "gamma":0.1}
    # encoder_features=[16, 32, 64, 128, 256] # orig: [16, 32, 64, 128, 256]
    # decoder_features=encoder_features[::-1][1:]
    # model = HyperMapp3r(dims=2, in_channels=in_channels, out_channels=1, softmax=False,
    #                    encoder_features=encoder_features,
    #              decoder_features=decoder_features)
    optimizer_params={"lr":1e-3}
    lr_scheduler_params={"step_size":5, "gamma":0.1}


In [41]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

import torch
import torch.nn as nn
import torch.nn.functional as F

class PyramidLitModelWrapper(pl.LightningModule):
    def __init__(self, model, loss=F.cross_entropy, logging_metrics=None, optimizer_params={"lr":1e-3}, lr_scheduler_params={"step_size":30, "gamma":0.1}, is_uq_model=False):
        super().__init__()
        """
        logging metrics are (name, metric function)
        """
        self.model = model
        self.loss = loss
        self.logging_metrics = nn.ModuleList(logging_metrics)
        self.optim_params = optimizer_params
        self.lr_scheduler_params = lr_scheduler_params
        self.is_uq_model = False

        
    def forward(self, x, **kwargs):
        return self.model(x, **kwargs)
    
    def configure_optimizers(self):
        # optimizer and schedulers go in the configure optimizers hook
        optimizer = torch.optim.Adam(self.parameters(), **self.optim_params)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **self.lr_scheduler_params)
        return [optimizer], [lr_scheduler]
    
    def training_step(self, batch, batch_idx):
        """
        lightning automates the training loop, 
        does epoch, back_tracking, optimizers and schedulers,
        and metric reduction.
        we just define how we want to process a single batch. 
        we can optionally pass optimizer_idx if we want to define multiple optimizers within the configure_optimizers
        hook, and I presume we can add our own parameters also to functions?
        """
        
        if self.is_uq_model:
            self.model.set_applyfunc(True)
        
        X, y = batch
        y_hats = self(X)
        loss = 0
        for y_hat in y_hats:
            loss += self.loss(y_hat, y)
        self.log("train_loss", loss)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        """
        note: call trainer.validate() automatically loads the best checkpoint if checkpointing was enabled during fitting
        well yes I want to enable checkpointing but will deal with that later.
        also it does stuff like model.eval() and torch.no_grad() automatically which is nice.
        I will need a custom eval thing to do my dropout estimation but can solve that later too.
        """
        if self.is_uq_model:
            self.model.set_applyfunc(False)
        
        X, y = batch
        y_hats = self(X)
        val_loss = 0
        for y_hat in y_hats:
            val_loss += self.loss(y_hat, y)
        self.log("val_loss", val_loss)
        
    def test_step(self, batch, batch_idx):
        """
        we would need to directly call this function using the trainer
        """
        
        if self.is_uq_model:
            self.model.set_applyfunc(False)
        
        X, y = batch
        y_hats = self(X)
        test_loss = 0
        for y_hat in y_hats:
            test_loss += self.loss(y_hat, y)
        self.log("test_loss", test_loss)
        
    def predict_step(self, batch, batch_idx):
        """
        just for making predictions as opposed to collecting metrics etc
        note to use this, we just call .predict(dataloader) and it then automates the look
        these functions are for a single batch. Nice.
        """
        X, y = batch
        pred = self(X)
        return pred

In [42]:
print("WARNING: FOR SOME REASON PARAMETERS INSIDE DROPCONNECT BLOCKS ARE NOT REGISTERING THEIR NUMBER OF PARAMS CORRECTLY")



In [43]:
summary(model, (1, 3, 128, 128))

Layer (type:depth-idx)                   Output Shape              Param #
PyramidUNet                              --                        --
├─Block: 1-1                             [1, 32, 128, 128]         --
│    └─UQGaussianConnect2d: 2-1          [1, 32, 128, 128]         --
│    └─BatchNorm2d: 2-2                  [1, 32, 128, 128]         64
│    └─UQDropout2d: 2-3                  [1, 32, 128, 128]         --
│    └─ReLU: 2-4                         [1, 32, 128, 128]         --
│    └─UQGaussianConnect2d: 2-5          [1, 32, 128, 128]         --
│    └─BatchNorm2d: 2-6                  [1, 32, 128, 128]         64
│    └─UQDropout2d: 2-7                  [1, 32, 128, 128]         --
│    └─ReLU: 2-8                         [1, 32, 128, 128]         --
├─MaxPool2d: 1-2                         [1, 32, 64, 64]           --
├─Block: 1-3                             [1, 64, 64, 64]           --
│    └─UQGaussianConnect2d: 2-9          [1, 64, 64, 64]           --
│    └─BatchNor

In [44]:
loss = GeneralizedDiceLoss(normalization='sigmoid')

In [45]:
model = PyramidLitModelWrapper(model, loss, 
                                optimizer_params=optimizer_params,
                                lr_scheduler_params=lr_scheduler_params,
                                is_uq_model=True
                               )

In [46]:
checkpoint_dir = "/disk/scratch/s2208943/results/"
strategy = None
# strategy = "deepspeed_stage_2"
# strategy = "dp"
#strategy = "deepspeed_stage_2_offload"

accelerator="gpu"
devices=1
max_epochs=1000
precision = 16

checkpoint_callback = ModelCheckpoint(checkpoint_dir, save_top_k=2, monitor="val_loss")
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=100, verbose="False", mode="min", check_finite=True)
trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stop_callback],
    accelerator=accelerator,
    devices=devices,
    max_epochs=max_epochs,
    strategy=strategy,
    precision=precision,
    default_root_dir=checkpoint_dir
)


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


### train

In [47]:
trainer.fit(model, train_dataloader, val_dataloader)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                | Params
--------------------------------------------------------
0 | model           | PyramidUNet         | 14.8 M
1 | loss            | GeneralizedDiceLoss | 0     
2 | logging_metrics | ModuleList          | 0     
--------------------------------------------------------
14.8 M    Trainable params
0         Non-trainable params
14.8 M    Total params
29.649    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 4.830


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.961 >= min_delta = 0.0. New best score: 3.869


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.133 >= min_delta = 0.0. New best score: 3.736


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.178 >= min_delta = 0.0. New best score: 3.558


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.124 >= min_delta = 0.0. New best score: 3.435


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.021 >= min_delta = 0.0. New best score: 3.414


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.049 >= min_delta = 0.0. New best score: 3.365


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.043 >= min_delta = 0.0. New best score: 3.322


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [79]:
trainer.validate(model, val_dataloader, ckpt_path="best")

Restoring states from the checkpoint path at /disk/scratch/s2208943/results/epoch=9-step=1370.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /disk/scratch/s2208943/results/epoch=9-step=1370.ckpt


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val_loss            3.3639798164367676
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 3.3639798164367676}]

In [80]:
trainer.test(model, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss            3.596513509750366
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 3.596513509750366}]

### validate each layer separately

In [64]:
class PyramidLitModelWrapperLayerSel(StandardLitModelWrapper):
    def __init__(self, layer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layer = layer
        
    def forward(self, x, **kwargs):
        outs = self.model(x, **kwargs)
        return outs[self.layer]

In [59]:
x, y = next(iter(val_dataloader))

In [62]:
v = model2(x)

In [63]:
v.shape

torch.Size([32, 1, 224, 160])

In [85]:
ckpt = "pyramid_sigmoid.ckpt"
model2 = PyramidLitModelWrapperLayerSel.load_from_checkpoint(
    ckpt,
    model=model.model, 
    layer=1,
    loss=loss, 
    is_uq_model=True,
)

In [86]:
trainer.validate(model2, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val_loss            0.4824848473072052
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.4824848473072052}]