In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from trustworthai.models.building_blocks.hypermapp3r_blocks import *
import torchvision.transforms.functional as TF

class HyperMapp3r(nn.Module):
    def __init__(self,dropout_p = 0., encoder_sizes=[16,32,64,128,256], inchannels=3, outchannels=2, p_unet_hook=False):
        super().__init__()
        self.out_channels = outchannels
        self.dropout_p = dropout_p
        self.p_unet_hook = p_unet_hook
        
        # input layer
        self.conv_first = nn.Conv2d(inchannels, encoder_sizes[0], kernel_size=5, stride=1, dilation=1, padding='same')
        self.activ = nn.ReLU()
        
        # encoder section
        l = len(encoder_sizes) - 1
        self.down_blocks = nn.ModuleList([
            DownBlock(encoder_sizes[i], encoder_sizes[i+1]) for i in range(0, l)
        ])
        
        self.res_blocks = nn.ModuleList([
            HmResBlock(c, dropout_p) for c in encoder_sizes
        ])
        
        # decoder section
        self.upsample_blocks = nn.ModuleList([
            HmUpsampBlock(c) for c in encoder_sizes[:-1][::-1]
        ])
        
        self.feature_blocks = nn.ModuleList([
            HmFeatureBlock(encoder_sizes[l - i]) for i in range(l-1)
        ])
        
        
        # multi-scale feature section
        self.ms_feature_layers = nn.ModuleList([
            nn.Conv2d(encoder_sizes[2], encoder_sizes[1], 3, padding='same'),
            nn.Conv2d(encoder_sizes[1], encoder_sizes[1], 3, padding='same'),
            nn.Conv2d(encoder_sizes[1], encoder_sizes[1], 3, padding='same')
        ])
        
        
        # output layer
        self.last_1 = nn.Conv2d(encoder_sizes[1], encoder_sizes[1], 3, padding='same')
        self.last_2 = nn.Conv2d(encoder_sizes[1], encoder_sizes[1], 1)
        self.last_3 = nn.Conv2d(encoder_sizes[1], outchannels, 1)
        self.last_norm = nn.InstanceNorm2d(encoder_sizes[1])
        self.num_out_features = encoder_sizes[1]
        
    def forward(self, x):
        # input layer
        out = self.activ(self.conv_first(x))
        # print(out.shape)
        
        skips = []
        
        # encoder section
        out = self.res_blocks[0](out)
        # print(out.shape)
        skips.append(out)
        for i in range(len(self.res_blocks) - 1):
            out = self.down_blocks[i](out)
            out = self.res_blocks[i+1](out)
            # print("loop: ", out.shape)
            skips.append(out)
        
        # decoder section
        ml_features = []
        out = skips.pop()
        for i in range(len(self.upsample_blocks)):
            # print("dec")
            if i > 0:
                sk = skips.pop()
                sk = TF.center_crop(sk, out.shape[-2:])
                out = torch.cat([out, sk], dim=1)
                out = self.feature_blocks[i-1](out)
            
            if i > 1:
                ml_features.append(self.ms_feature_layers[i-2](out))
                
            out = self.upsample_blocks[i](out)
        
        # final layers
        sk = skips.pop()
        sk = TF.center_crop(sk, out.shape[-2:])
        out = torch.cat([out, sk], dim=1)
        out = self.last_norm(self.activ(self.last_1(out)))
        
        # multiscale feature section
        ml_features = [out] + ml_features
        
        for mlf in ml_features:
            print("mlf: ", mlf.shape)
        
        ml_features = [F.interpolate(mf, size=x.shape[-2:], mode='bilinear') for mf in ml_features]
        #combined_features = torch.cat(ml_features, dim=1)
        combined_features = ml_features[0]
        for mlf in ml_features[1:]:
            combined_features += mlf
        print(combined_features.shape)
        
        out = self.activ(self.last_2(combined_features))
        
        if self.p_unet_hook:
            return out
        
        out = self.last_3(out)
        
        return out

    
    @property
    def output_channels(self):
        return self.out_channels

In [35]:
encoder_sizes=[16,32,64,128,256]
base_model = HyperMapp3r(
    dropout_p=0.,
    encoder_sizes=encoder_sizes,
    inchannels=3,
    outchannels=2,
)

In [38]:
import sys

In [41]:
a = torch.randn(250, 20, 10, 2, 224, 160)

In [42]:
sys.getsizeof(a.storage())

14336000048

In [36]:
inp = torch.randn(12, 3, 224, 160)
with torch.no_grad():
    out = base_model.cuda()(inp.cuda())

mlf:  torch.Size([12, 32, 224, 160])
mlf:  torch.Size([12, 32, 56, 40])
mlf:  torch.Size([12, 32, 112, 80])
torch.Size([12, 32, 224, 160])


In [27]:
out.shape

torch.Size([12, 2, 224, 160])

In [28]:
from torchinfo import summary

In [14]:
summary(base_model, (12, 3, 224, 160))

Layer (type:depth-idx)                   Output Shape              Param #
HyperMapp3r                              [12, 2, 224, 160]         9,248
├─Conv2d: 1-1                            [12, 16, 224, 160]        1,216
├─ReLU: 1-2                              [12, 16, 224, 160]        --
├─ModuleList: 1-11                       --                        (recursive)
│    └─HmResBlock: 2-1                   [12, 16, 224, 160]        --
│    │    └─Conv2d: 3-1                  [12, 16, 224, 160]        12,560
│    │    └─InstanceNorm2d: 3-2          [12, 16, 224, 160]        --
│    │    └─ReLU: 3-3                    [12, 16, 224, 160]        --
│    │    └─Dropout2d: 3-4               [12, 16, 224, 160]        --
│    │    └─Conv2d: 3-5                  [12, 16, 224, 160]        2,320
│    │    └─InstanceNorm2d: 3-6          [12, 16, 224, 160]        --
│    │    └─ReLU: 3-7                    [12, 16, 224, 160]        --
├─ModuleList: 1-10                       --                    

In [4]:
import torch
import torch.nn as nn
from trustworthai.models.stochastic_wrappers.ssn.LowRankMVCustom import LowRankMultivariateNormalCustom
from trustworthai.models.stochastic_wrappers.ssn.ReshapedDistribution import ReshapedDistribution
from trustworthai.models.uq_model import UncertaintyQuantificationModel
from tqdm import tqdm
import torch.distributions as td

class DeepSSN(UncertaintyQuantificationModel):
    def __init__(self, base_model, rank, diagonal, epsilon, intermediate_channels, out_channels, dims):
        super().__init__()
        self.base_model = base_model
        self.ssn_rank = rank
        self.ssn_diagonal = diagonal
        self.ssn_epsilon = epsilon
        self.ssn_num_classes = out_channels
        
        self.lrelu = nn.LeakyReLU(0.01)
        
        self.mean_l = nn.Conv2d(intermediate_channels, out_channels, kernel_size = (1,) *  dims, padding='same')
        self.log_cov_diag_l = nn.Conv2d(intermediate_channels, out_channels, kernel_size = (1,) * dims, padding='same')
        self.cov_factor_l = nn.Conv2d(intermediate_channels, out_channels * self.ssn_rank, kernel_size = (1,) * dims, padding='same')
        
    def forward(self, x):
        logits = self.lrelu(self.base_model(x))

        batch_size = logits.shape[0]
        event_shape = (self.ssn_num_classes,) + logits.shape[2:]
        
        mean = self.mean_l(logits)
        mean = mean.view((batch_size, -1))
        
        cov_diag = self.log_cov_diag_l(logits).exp() + self.ssn_epsilon
        cov_diag = cov_diag.view((batch_size, -1))
        
        cov_factor = self.cov_factor_l(logits)
        cov_factor = cov_factor.view((batch_size, self.ssn_rank, self.ssn_num_classes, -1))
        cov_factor = cov_factor.flatten(2,3)
        cov_factor = cov_factor.transpose(1,2)
        
        # covariance tends to blow up to infinity, hence set to 0 outside the ROI
        mask = x[:,1]
        mask = mask.unsqueeze(1).expand((batch_size, self.ssn_num_classes) + mask.shape[1:]).reshape(batch_size, -1)
        cov_factor = cov_factor * mask.unsqueeze(-1)
        cov_diag = cov_diag * mask + self.ssn_epsilon
    
        
        if self.ssn_diagonal:
            base_distribution = td.Independent(td.Normal(loc=mean, scale=torch.sqrt(cov_diag)), 1)
        else:
            try:
                base_distribution = LowRankMultivariateNormalCustom(loc=mean, cov_factor=cov_factor, cov_diag=cov_diag)
            except Exception as e:
                print("was thrown: ", e)
                print('hmm: Covariance became non invertible using independent normals for this batch!')
                print("cov diag okay: ", torch.sum(cov_diag <=0))
                print("sqrt cov diag okay: ", torch.sum(torch.sqrt(cov_diag) <=0))
                
                try:
                    base_distribution = td.Independent(td.Normal(loc=mean, scale=torch.sqrt(cov_diag)),1)
                except Exception as e:
                    print("second fail: ", e)
                    print(torch.min(torch.sqrt(cov_diag), torch.max(torch.sqrt(cov_diag))))
        
        distribution = ReshapedDistribution(base_distribution, event_shape)
        
        shape = (batch_size,) + event_shape
        logit_mean_view = mean.view(shape).detach()
        cov_diag_view = cov_diag.view(shape).detach()
        cov_factor_view = cov_factor.transpose(2,1).view((batch_size, self.ssn_num_classes * self.ssn_rank) + event_shape[1:]).detach()
        
        output_dict = {
            'logit_mean':logit_mean_view,
            'cov_diag':cov_diag_view,
            'cov_factor':cov_factor_view,
            'distribution':distribution,
        }
        
        return output_dict
    
    def mean(self, x, temperature=1):
        return self(x)['logit_mean'] / temperature
    
    def _samples_from_dist(self, dist, num_samples, rsample=True, symmetric=True):
        if symmetric:
            assert num_samples % 2 == 0
            num_samples = num_samples // 2
            
        if rsample:
            samples = dist.rsample((num_samples,))
        else:
            samples = dist.sample((num_samples,))
        
        if symmetric:
            mean = dist.mean
            samples = samples - mean
            return torch.cat([samples, -samples]) + mean
        else:
            return samples
    
    def mean_and_sample(self, x, num_samples, rsample=True, temperature=1):
        # NOTE: this does temperature scaling!!
        t = temperature
        out = self(x)
        mean = out['logit_mean']
        dist = out['distribution']
        samples = self._samples_from_dist(dist, num_samples, rsample)
        return mean/t, samples/t
        
            

In [5]:
from trustworthai.utils.data_preprep.dataset_pipelines import load_data
from trustworthai.utils.fitting_and_inference.fitters.basic_lightning_fitter import StandardLitModelWrapper
from trustworthai.utils.fitting_and_inference.get_trainer import get_trainer
from trustworthai.models.core_models.Hypermapp3r import HyperMapp3r
from trustworthai.models.stochastic_wrappers.ssn.ssn import SSN
from torchinfo import summary
import torch

In [2]:
# get the 2d axial slice dataloaders
train_dl, val_dl, test_dl = load_data(
    dataset="ed", 
    test_proportion=0.15, 
    validation_proportion=0.15,
    seed=3407,
    empty_proportion_retained=0.1,
    batch_size=32,
    dataloader2d_only=True,
    cross_validate=True,
    cv_split=0
)

NameError: name 'load_data' is not defined

In [None]:
train_dl.dataset[0][0].sum(), test_dl.dataset[0][0].sum(), val_dl.dataset[0][0].sum()

In [57]:
encoder_sizes

[16, 32, 64, 128, 256]

In [58]:
#summary(base_model, (12, 3, 224, 160))

In [59]:
model_raw = SSN(
    base_model=base_model,
    rank=15,
    diagonal=False,
    epsilon=1e-5,
    intermediate_channels=base_model.output_channels,
    out_channels=2,
    dims=2
    ).cuda()

In [60]:
optimizer_params={"lr":2e-4, "weight_decay":0.0001}
optimizer = torch.optim.Adam
lr_scheduler_params={"milestones":[1000], "gamma":0.5}
lr_scheduler_constructor = torch.optim.lr_scheduler.MultiStepLR

In [61]:
loss = SSNCombinedDiceXentLoss(
    empty_slice_weight=0.5,
    mc_samples=10,
    dice_factor=5,
    xent_factor=0.01,
    sample_dice_coeff=0.05,
)

In [62]:
model = StandardLitModelWrapper(model_raw, loss, 
                                    logging_metric=lambda : None,
                                    optimizer_params=optimizer_params,
                                    lr_scheduler_params=lr_scheduler_params,
                                    optimizer_constructor=optimizer,
                                    lr_scheduler_constructor=lr_scheduler_constructor
                                   )

In [63]:
trainer = get_trainer(max_epochs=100, results_dir="/home/s2208943/ipdis/results/test/ssn_test/", early_stop_patience=15)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [64]:
trainer.fit(model, train_dl, val_dl)

Missing logger folder: /home/s2208943/ipdis/results/test/ssn_test/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                    | Params
--------------------------------------------------
0 | model | SSN                     | 6.3 M 
1 | loss  | SSNCombinedDiceXentLoss | 0     
--------------------------------------------------
6.3 M     Trainable params
0         Non-trainable params
6.3 M     Total params
25.240    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 11.096


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 3.011 >= min_delta = 0.01. New best score: 8.085


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 1.136 >= min_delta = 0.01. New best score: 6.949


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.822 >= min_delta = 0.01. New best score: 6.127


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.233 >= min_delta = 0.01. New best score: 5.895


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.445 >= min_delta = 0.01. New best score: 5.449


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.074 >= min_delta = 0.01. New best score: 5.376


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.057 >= min_delta = 0.01. New best score: 5.319
