# Test the apply flag works correctly
(and that we get random sampling (nondeterministic) behaviour when we want to!)

In [1]:
import torch
from torchinfo import summary

In [2]:
from trustworthai.models.uq_models.drop_UNet import UNet
from trustworthai.models.uq_models.uq_model import UQModel

In [83]:
p=0.4
mdim=True
args = {
    "dropout_type":None,
    "dropconnect_type":"bernoulli",
    "p":p,
    "mdim":False,
    "norm_type":"bn",
    "uqnorm":True,
    "gn_groups":4
}

model = UNet(in_channels=3,
             out_channels=2,
             kernel_size=3,
             init_features=32,
             softmax=False,
             dropout_type=args["dropout_type"],
             dropout_p=args["p"],
             gaussout_mean=1, 
             dropconnect_type=args["dropconnect_type"],
             dropconnect_p=args["p"],
             gaussconnect_mean=1,
             norm_type=args["norm_type"],
             use_uq_norm_layer=args['uqnorm'],
             use_multidim_dropout=args["mdim"],  
             use_multidim_dropconnect=args["mdim"], 
             groups=None,
             gn_groups=args["gn_groups"], 
            )

In [153]:
#model = model.cuda()

In [57]:
summary(model, (1, 3, 160, 224))

Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     --                        --
├─Block: 1-1                             [1, 32, 160, 224]         --
│    └─UQDropConnect2d: 2-1              [1, 32, 160, 224]         --
│    └─UQLayerWrapper: 2-2               [1, 32, 160, 224]         --
│    │    └─BatchNorm2d: 3-1             [1, 32, 160, 224]         64
│    └─ReLU: 2-3                         [1, 32, 160, 224]         --
│    └─UQDropConnect2d: 2-4              [1, 32, 160, 224]         --
│    └─UQLayerWrapper: 2-5               [1, 32, 160, 224]         --
│    │    └─BatchNorm2d: 3-2             [1, 32, 160, 224]         64
│    └─ReLU: 2-6                         [1, 32, 160, 224]         --
├─MaxPool2d: 1-2                         [1, 32, 80, 112]          --
├─Block: 1-3                             [1, 64, 80, 112]          --
│    └─UQDropConnect2d: 2-7              [1, 64, 80, 112]          --
│    └─UQLayerW

In [84]:
model=model.cuda()

In [95]:
xs = torch.rand(1, 3, 160, 224)

In [96]:
model.set_applyfunc(True)

### check for determinism and non determinism as we expect

In [116]:
model.set_applyfunc(False)

In [117]:
y1 = model(xs.cuda())
y2 = model(xs.cuda())

In [118]:
torch.sum(y1[0] != y2[0]) # should equal zero for determinism

tensor(0, device='cuda:0')

In [119]:
model.set_applyfunc(True)

In [126]:
z1 = model(xs.cuda())
z2 = model(xs.cuda())

In [127]:
torch.sum(z1[0] != z2[0]), torch.numel(zs[0])

(tensor(71680, device='cuda:0'), 71680)

### checkinf for non determinism fo the btach norm layers is more complex

In [47]:
a = torch.rand(3, 160, 224); ass = torch.stack([a,a,a,a,a,a,a,a])
b = torch.rand(3, 160, 224); bss = torch.stack([b,b,b,b,b,b,b,b])

**non-determinism**

In [64]:
model.set_applyfunc(True)
_ = model(ass.cuda())
model.set_applyfunc(False)
z1 = model(xs.cuda())

model.set_applyfunc(True)
_ = model(bss.cuda())
model.set_applyfunc(False)
z2 = model(xs.cuda())

In [65]:
torch.sum(z1[0] != z2[0]), torch.numel(zs[0])

(tensor(0, device='cuda:0'), 71680)

**determinism**

In [62]:
model.set_applyfunc(False)
_ = model(ass.cuda())
z1 = model(xs.cuda())

_ = model(bss.cuda())
z2 = model(xs.cuda())

In [63]:
torch.sum(z1[0] != z2[0]), torch.numel(zs[0])

(tensor(0, device='cuda:0'), 71680)

### hmmm setting to false works but back to true again doesn't...

In [124]:
model.set_applyfunc(False)

In [125]:
for m in model.modules():
    if isinstance(m, UQModel):
        print("yes")
        print(m.get_applyfunc())
        #print(m)
    # else:
    #     print("no")
    

yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
yes
False
