In [1]:
print("strawberry")

strawberry


In [2]:
import torch
import numpy as np

# dataset
from twaidata.torchdatasets.in_ram_ds import MRISegmentation2DDataset, MRISegmentation3DDataset
from torch.utils.data import DataLoader, random_split, ConcatDataset

# model
from trustworthai.models.uq_models.initial_variants.HyperMapp3r_deterministic import HyperMapp3r
from trustworthai.models.uq_models.initial_variants.HyperMapp3r_DDU import HyperMapp3rDDU
from trustworthai.models.uq_models.initial_variants.HyperMapp3r_SSN import HyperMapp3rSSN


# augmentation and pretrain processing
from trustworthai.utils.augmentation.standard_transforms import RandomFlip, GaussianBlur, GaussianNoise, \
                                                            RandomResizeCrop, RandomAffine, \
                                                            NormalizeImg, PairedCompose, LabelSelect, \
                                                            PairedCentreCrop, CropZDim
# loss function
from trustworthai.utils.losses_and_metrics.per_individual_losses import (
    dice_loss,
    log_cosh_dice_loss,
    TverskyLoss,
    FocalTverskyLoss,
    DiceLossMetric
)
from torch.nn import BCELoss, MSELoss, BCEWithLogitsLoss

# fitter
from trustworthai.utils.fitting_and_inference.fitters.basic_lightning_fitter import StandardLitModelWrapper
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
import pytorch_lightning as pl

# misc
import os
import torch
import matplotlib.pyplot as plt
import torch
from torchinfo import summary
import torch.distributions as td

In [3]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f5c41ea99d0>

### Set the seed

In [4]:
seed = 3407
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

### define datasets and dataloaders

In [5]:
is3D = False

In [6]:
root_dir = "/disk/scratch/s2208943/ipdis/preprep/out_data/collated/"
#root_dir = "/media/benp/NVMEspare/datasets/preprocessing_attempts/local_results/collated/"
wmh_dir = root_dir + "WMH_challenge_dataset/"
ed_dir = root_dir + "EdData/"

In [7]:
# domains = [
#             wmh_dir + d for d in ["Singapore", "Utrecht", "GE3T"]
#           ]

# domains = [
#             wmh_dir + d for d in ["Singapore", "Utrecht", "GE3T"]
#           ] + [
#             ed_dir + d for d in ["domainA", "domainB", "domainC", "domainD"]
#           ]


domains = [
            ed_dir + d for d in ["domainA", "domainB", "domainC", "domainD"]
          ]

In [8]:
# augmentation definintion
def get_transforms(is_3D):
    transforms = [
        LabelSelect(label_id=1),
        RandomFlip(p=0.5, orientation="horizontal"),
        # GaussianBlur(p=0.5, kernel_size=7, sigma=(.1, 1.5)),
        # GaussianNoise(p=0.2, mean=0, sigma=0.2),
        # RandomAffine(p=0.2, shear=(.1,3.)),
        # RandomAffine(p=0.2, degrees=5),
        #RandomResizeCrop(p=1., scale=(0.6, 1.), ratio=(3./4., 4./3.))
        RandomResizeCrop(p=1., scale=(0.3, 0.5), ratio=(3./4., 4./3.)) # ssn
    ]
    if not is_3D:
        transforms.append(lambda x, y: (x, y.squeeze().type(torch.long)))
        return PairedCompose(transforms)
    else:
        transforms.append(CropZDim(size=32, minimum=0, maximum=-1))
        transforms.append(lambda x, y: (x, y.squeeze().type(torch.long)))
        return PairedCompose(transforms)

In [9]:
# function to do train validate test split
test_proportion = 0.1
validation_proportion = 0.2

def train_val_test_split(dataset, val_prop, test_prop, seed):
    # I think the sklearn version might be prefereable for determinism and things
    # but that involves fiddling with the dataset implementation I think....
    size = len(dataset)
    test_size = int(test_prop*size) 
    val_size = int(val_prop*size)
    train_size = size - val_size - test_size
    train, val, test = random_split(dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(seed))
    return train, val, test

In [10]:
# load datasets
# this step is quite slow, all the data is being loaded into memory
if is3D:
    datasets_domains = [MRISegmentation3DDataset(root_dir, domain, transforms=get_transforms(is_3D=True)) for domain in domains]
else:
    datasets_domains = [MRISegmentation2DDataset(root_dir, domain, transforms=get_transforms(is_3D=False)) for domain in domains]

# split into train, val test datasets
datasets = [train_val_test_split(dataset, validation_proportion, test_proportion, seed) for dataset in datasets_domains]

# concat the train val test datsets
train_dataset = ConcatDataset([ds[0] for ds in datasets])
val_dataset = ConcatDataset([ds[1] for ds in datasets])
test_dataset = ConcatDataset([ds[2] for ds in datasets])

In [11]:
len(train_dataset), len(val_dataset), len(test_dataset)

(8743, 2497, 1248)

In [12]:
# define dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size = 16, shuffle=False, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

### setup model

In [13]:
in_channels = 3
out_channels = 1

if is3D:
    pass
else:
    # encoder_features=[16, 32, 64, 128, 256] # orig: [16, 32, 64, 128, 256]
    # decoder_features=encoder_features[::-1][1:]
    model_raw = HyperMapp3r(dims=2,
                 in_channels=3,
                 out_channels=2,
                 encoder_features=[16, 32, 64, 128, 256],
                 decoder_features=[128, 64, 32, 16],
                 softmax=False,
                 up_res_blocks=False,
                 block_params={
                     "dropout_p":0.1,
                     "norm_type":"in", 
                     "dropout_both_layers":False,
                 }
                   )
    
    
    optimizer_params={"lr":2e-3}
    lr_scheduler_params={"step_size":20, "gamma":0.1}
    # optimizer_params={"lr":2e-3, "momentum":0.6}
    # optimizer = torch.optim.RMSprop
    # lr_scheduler_params={"milestones":[10,100,200], "gamma":0.5}
    # lr_scheduler_constructor = torch.optim.lr_scheduler.MultiStepLR


In [14]:
### functions for getting code from the evidential distribution. Nice.
scale = 8
def relu_evidence(logits):
    return torch.nn.functional.relu(logits)

def exp_evidence(logits):
    return logits.clamp(-10, 10).exp()

def softplus_evidence(logits):
    return torch.nn.functional.softplus(logits)


def get_S(evidence):
    # evidence is shape [b, c, <dims>], we want an S per pixel, so reduce on dim 1
    S = (evidence + 1.).sum(dim = 1).unsqueeze(1)
    return S

def get_bk(evidence, S):
    return evidence / S

def get_uncert(K, S):
    return K / S

def get_alpha(evidence):
    return evidence + 1.

def get_one_hot_target(K, target):
    one_hot = torch.zeros((target.shape[0], K, *target.shape[1:])).to(target.device)
    one_hot[:,0] = 1 - target
    one_hot[:,1] = target
    
    return one_hot

def get_mean_p_hat(alpha, S):
    return alpha / S

In [15]:
z = torch.rand(1, 1, 224, 160)

In [16]:
map = torch.nn.functional.max_pool2d(z, kernel_size=8, stride=8)

In [17]:
map.shape

torch.Size([1, 1, 28, 20])

In [18]:
def digamma(values):
    return torch.digamma(values).clamp(-100,100)

def get_alpha_modified(alpha, one_hot_target):
    return one_hot_target + ((1 - one_hot_target) * alpha)

def xent_bayes_risk(alpha, S, one_hot_target):
    digamma_S = torch.digamma(S).expand(alpha.shape)
    digamma_alpha = torch.digamma(alpha)
    
    p_ij = one_hot_target * (digamma_S - digamma_alpha)
    per_pixel_loss =  torch.sum(p_ij, dim=1)
    
    return torch.sum(per_pixel_loss, dim=(-2,-1)).mean() # reduction = mean


def mse_bayes_risk(mean_p_hat, S, one_hot_target):
    l_err = torch.nn.functional.mse_loss(mean_p_hat, one_hot_target, reduction='none')
    
    l_var = mean_p_hat * (1.- mean_p_hat) / (S + 1.)
    
    return (l_err + l_var).sum(dim=(-2,-1)).mean()


def KL(alpha_modified):
    K = alpha_modified.shape[1]
    beta = torch.ones((1, *alpha_modified.shape[1:])).to(alpha_modified.device)
    sum_alpha = alpha_modified.sum(dim=1)
    sum_beta = beta.sum(dim=1)
    
    lnB = torch.lgamma(sum_alpha) - torch.lgamma(alpha_modified).sum(dim=1)
    lnB_uni = torch.lgamma(beta).sum(dim=1) - torch.lgamma(sum_beta)
    
    dg0 = torch.digamma(sum_alpha).unsqueeze(1)
    dg1 = torch.digamma(alpha_modified)
    
    diff = (alpha_modified - beta)
    v = (dg1 - dg0)
    
    # print(sum_alpha.shape)
    # print(sum_beta.shape)
    # print(diff.shape)
    # print(v.shape)
    
    rhs = torch.sum(diff * v, dim=1)
    
    kl = lnB + lnB_uni + rhs
    
    return torch.sum(kl, dim=(-2,-1)).mean()
    
    
def combined_loss(logits, target):
    # compute the scaled down map
    logits = torch.nn.functional.max_pool2d(logits, kernel_size=scale, stride=scale)
    target = torch.nn.functional.max_pool2d(target.type(torch.float32), kernel_size=scale, stride=scale)
    target = target.type(torch.long)
    
    
    # get relevent terms required for loss func
    evidence = softplus_evidence(logits)
    S = get_S(evidence)
    alpha = get_alpha(evidence)
    K = alpha.shape[1]
    one_hot = get_one_hot_target(K, target)
    mean_p_hat = get_mean_p_hat(alpha, S)
    alpha_modified = get_alpha_modified(alpha, one_hot)
    
    
    #mse = mse_bayes_risk(mean_p_hat, S, one_hot)
    xent = xent_bayes_risk(alpha, S, one_hot)
    kl = KL(alpha_modified)
    
    #dice = dice_loss(mean_p_hat, target)
    
    # print(kl)
    # print(xent)
    # print(dice)
    # print("--")
    
    return xent + kl * 0.2 #+ dice

In [19]:
#summary(model_raw, (1, 3, 128, 128))

In [20]:
loss = combined_loss
# loss = soft_dice

In [21]:
optimizer_params={"lr":1e-3}
optimizer = torch.optim.Adam
lr_scheduler_params={"milestones":[100,200], "gamma":0.5}
lr_scheduler_constructor = torch.optim.lr_scheduler.MultiStepLR

In [22]:
from torchmetrics import Metric
class MeanAccuracyMetric(Metric):
    is_differentiable = False
    higher_is_better = True
    full_state_update = False
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        
        preds = torch.nn.functional.max_pool2d(preds, kernel_size=scale, stride=scale)
        target = torch.nn.functional.max_pool2d(target.type(torch.float32), kernel_size=scale, stride=scale)
        
        # non zero inds
        nzs = target.sum(dim=(-2, -1)) > 0
        ps = preds[nzs]
        ts = target[nzs]
        
        evidence = softplus_evidence(ps)
        cs = evidence.argmax(dim=1)
        
        correct = torch.sum(cs == ts)
        
        
        #update = torch.sum(preds==1)
        self.correct += correct
        self.total += ts.shape[0] * ts.shape[-1] * ts.shape[-2]

    def compute(self):
        return self.correct.float() / self.total
    
    
class MeanTruePositiveMetric(Metric):
    is_differentiable = False
    higher_is_better = True
    full_state_update = False
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        
        preds = torch.nn.functional.max_pool2d(preds, kernel_size=scale, stride=scale)
        target = torch.nn.functional.max_pool2d(target.type(torch.float32), kernel_size=scale, stride=scale)
        
        # non zero inds
        nzs = target.sum(dim=(-2, -1)) > 0
        ps = preds[nzs]
        ts = target[nzs]
        
        # true positives
        # print(ts.shape)
        tps = ts.view(-1) == 1
        
        evidence = softplus_evidence(ps)
        cs = evidence.argmax(dim=1)
        # print(cs.shape)
        cs = cs.view(-1)[tps]
        
        # print(cs.shape)
        correct = cs.shape[0]
        
        
        #update = torch.sum(preds==1)
        self.correct += correct
        self.total += tps.sum()

    def compute(self):
        return self.correct.float() / self.total
    

In [23]:
# model = StandardLitModelWrapper(model, loss, 
#                                 logging_metric=DiceLossMetric,
#                                 optimizer_params=optimizer_params,
#                                 lr_scheduler_params=lr_scheduler_params,
#                                 is_uq_model=False,
#                                 optimizer_constructor=optimizer,
#                                 lr_scheduler_constructor=lr_scheduler_constructor
#                                )
checkpoint="../epoch=38-step=1560.ckpt"
model = StandardLitModelWrapper.load_from_checkpoint(checkpoint, model=model_raw, loss=loss, logging_metric=MeanAccuracyMetric)

In [25]:
# x1, y1 = next(iter(train_dataloader))
# with torch.no_grad():
#     pred = model(x1.to(model.device))

In [26]:
checkpoint_dir = "./lightning_logs"
strategy = None
# strategy = "deepspeed_stage_2"
# strategy = "dp"
#strategy = "deepspeed_stage_2_offload"

accelerator="gpu"
devices=1
max_epochs=500
precision = 32

#checkpoint_callback = ModelCheckpoint(checkpoint_dir, save_top_k=2, monitor="val_loss")
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=500, verbose="False", mode="min", check_finite=True)
trainer = pl.Trainer(
    callbacks=[early_stop_callback], # callbacks=[checkpoint_callback, early_stop_callback]
    accelerator=accelerator,
    devices=devices,
    max_epochs=max_epochs,
    strategy=strategy,
    precision=precision,
    default_root_dir=checkpoint_dir
)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


### train

In [27]:
trainer.fit(model, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type               | Params
------------------------------------------------------------
0 | model                | HyperMapp3r        | 2.8 M 
1 | logging_metric_train | MeanAccuracyMetric | 0     
2 | logging_metric_val   | MeanAccuracyMetric | 0     
------------------------------------------------------------
2.8 M     Trainable params
0         Non-trainable params
2.8 M     Total params
11.098    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  digamma_S = torch.digamma(S).expand(alpha.shape)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 33.651


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 3.940 >= min_delta = 0.0. New best score: 29.710


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 1.064 >= min_delta = 0.0. New best score: 28.646


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.923 >= min_delta = 0.0. New best score: 27.723


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 1.181 >= min_delta = 0.0. New best score: 26.542


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.839 >= min_delta = 0.0. New best score: 25.704


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.867 >= min_delta = 0.0. New best score: 24.836


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.352 >= min_delta = 0.0. New best score: 24.485


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.274 >= min_delta = 0.0. New best score: 24.211


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.973 >= min_delta = 0.0. New best score: 23.238


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.551 >= min_delta = 0.0. New best score: 22.687


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.369 >= min_delta = 0.0. New best score: 22.317


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.087 >= min_delta = 0.0. New best score: 22.230


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.335 >= min_delta = 0.0. New best score: 21.896


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.146 >= min_delta = 0.0. New best score: 21.749


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.195 >= min_delta = 0.0. New best score: 21.554


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.160 >= min_delta = 0.0. New best score: 21.394


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.030 >= min_delta = 0.0. New best score: 21.365


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.021 >= min_delta = 0.0. New best score: 21.343


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.050 >= min_delta = 0.0. New best score: 21.293


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.408 >= min_delta = 0.0. New best score: 20.885


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.157 >= min_delta = 0.0. New best score: 20.728


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.054 >= min_delta = 0.0. New best score: 20.674


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.499 >= min_delta = 0.0. New best score: 20.176


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [28]:
trainer.validate(model, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val_loss             20.80113983154297
    val_metric_epoch        0.9771565198898315
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_metric_epoch': 0.9771565198898315, 'val_loss': 20.80113983154297}]

In [31]:
a = torch.rand(1, 1, 10, 10)
a.requires_grad_(True)

tensor([[[[0.4235, 0.0912, 0.2689, 0.3777, 0.5751, 0.4552, 0.5021, 0.5417,
           0.6960, 0.3727],
          [0.1864, 0.6071, 0.9696, 0.9349, 0.7431, 0.9379, 0.6117, 0.6556,
           0.4828, 0.7148],
          [0.1091, 0.6460, 0.4286, 0.8095, 0.9196, 0.7096, 0.3917, 0.7920,
           0.7960, 0.7430],
          [0.1061, 0.9554, 0.1220, 0.5295, 0.9287, 0.7299, 0.7553, 0.8270,
           0.3334, 0.5610],
          [0.1751, 0.9672, 0.7131, 0.9283, 0.5504, 0.5355, 0.2801, 0.5935,
           0.6761, 0.9183],
          [0.0337, 0.2889, 0.1691, 0.6720, 0.5568, 0.4092, 0.5225, 0.1885,
           0.0274, 0.2909],
          [0.8703, 0.3404, 0.3863, 0.7916, 0.0881, 0.2999, 0.6033, 0.4069,
           0.2805, 0.2457],
          [0.0215, 0.9805, 0.0366, 0.5619, 0.6067, 0.4569, 0.4668, 0.8777,
           0.8244, 0.8979],
          [0.0088, 0.3680, 0.8272, 0.0814, 0.1913, 0.5256, 0.9087, 0.7746,
           0.3294, 0.2913],
          [0.6453, 0.6923, 0.9573, 0.9221, 0.1295, 0.5587, 0.5846, 0.9056

In [33]:
b = torch.nn.functional.avg_pool2d(a, 2, stride=2)

In [36]:
torch.sigmoid(b)

tensor([[[[0.5810, 0.6542, 0.6633, 0.6406, 0.6380],
          [0.6116, 0.6160, 0.6947, 0.6663, 0.6476],
          [0.5905, 0.6504, 0.6255, 0.5978, 0.6173],
          [0.6349, 0.6092, 0.5897, 0.6431, 0.6369],
          [0.6055, 0.6675, 0.5869, 0.6886, 0.6004]]]],
       grad_fn=<SigmoidBackward0>)

In [155]:
def epistemic_uncert(K, S):
    return K / S

def aleotoric_uncert(alpha, S):
    lhs = torch.digamma(alpha + 1)
    rhs = torch.digamma(S + 1)
    
    frac = alpha / S
    
    return -torch.sum(frac * (lhs - rhs), dim=1)
    
    
def distributional_uncert(alpha, S):
    frac = alpha / S
    rhs = torch.log(frac) - torch.digamma(alpha + 1) + torch.digamma(S + 1)
    return -torch.sum(frac * rhs, dim=1)

In [51]:
trainer.validate(model, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val_loss            27.073522567749023
    val_metric_epoch        0.9772764444351196
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_metric_epoch': 0.9772764444351196, 'val_loss': 27.073522567749023}]