In [15]:
import warnings
import time

from datetime import datetime
import torch
import torchvision.transforms as transforms
from laplace import Laplace
from torch.utils.data import DataLoader

from src.data.dataloader import MURADataset
from src.models.models import CNN, CNN_3
from netcal.metrics import ECE
from src.models.utils import pred, save_laplace, load_laplace

batch_size = 32
shuffle = True
num_workers = 1

warnings.filterwarnings("ignore")

la_path = 'models/laplace_diag_17-04-2022_13.pkl'
N = 10
method = 'union'

device = "cuda" if torch.cuda.is_available() else "cpu"
la = load_laplace(la_path)

la._device = torch.device(device)

la_sample = la.sample(N)

if method == "average":
    samples = la_sample.mean(axis=0)
if method == "intersect":
    samples = la_sample.min(axis=0).values
if method == "union":
    samples = la_sample.max(axis=0).values

model = CNN(input_channels=3, input_height=256, input_width=256, num_classes=7).to(
    device
)

model.eval()

model.load_state_dict(
    torch.load(
        "models/STATEtrained_model_epocs100_16_04_22_trans_1_layers_5.pt",
        map_location=torch.device(device),
    )
)

model.l_out.weight.data = torch.reshape(torch.reshape(samples, (65, 7))[:-1], (7,64))
model.l_out.bias.data = torch.reshape(samples, (65, 7))[-1]

torch.save(
    model.state_dict(),
    f"./models/tester_{method}.pt",
)

In [329]:
samples.shape

torch.Size([455])

In [327]:
torch.reshape(torch.reshape(parameters_to_vector(model.l_out.parameters()), (65, 7))[:-1], (7,64))

tensor([[-0.1714, -0.1416, -0.0820, -0.0553, -0.0986,  0.4477,  0.5367,  0.8040,
         -0.1054, -0.1196, -0.2342,  0.7754,  0.2543, -0.1372,  0.6522, -0.2257,
         -0.3246,  0.8251,  0.5270, -0.1109,  0.8341, -0.1458,  0.1337, -0.3559,
         -0.1020, -0.1072,  0.6151, -0.1169, -0.1092,  0.7664, -0.1830, -0.1022,
         -0.1335, -0.3054, -0.2112, -0.3189, -0.1337,  0.2937,  0.3473, -0.2284,
          0.8369, -0.0646, -0.0743, -0.2321, -0.1896, -0.0971,  0.3692, -0.0868,
         -0.3395, -0.3295,  0.3550, -0.0985, -0.1551, -0.1572, -0.3203, -0.1426,
         -0.2239,  0.5537,  0.8222,  0.8099, -0.3459, -0.1104, -0.0927, -0.2411],
        [-0.1144, -0.1047, -0.0719, -0.0651, -0.0724, -0.0994, -0.0926, -0.0571,
         -0.0671, -0.0941, -0.1207, -0.0594, -0.1096, -0.0698, -0.0590, -0.1016,
         -0.1144, -0.0554, -0.1046, -0.0752, -0.0617, -0.1101, -0.0734, -0.1175,
         -0.0732, -0.0490, -0.1082, -0.1066, -0.0564, -0.0507, -0.1140, -0.0699,
         -0.1153, -0.0955, 

In [326]:
model.l_out.weight.data

tensor([[-0.1714, -0.1416, -0.0820, -0.0553, -0.0986,  0.4477,  0.5367,  0.8040,
         -0.1054, -0.1196, -0.2342,  0.7754,  0.2543, -0.1372,  0.6522, -0.2257,
         -0.3246,  0.8251,  0.5270, -0.1109,  0.8341, -0.1458,  0.1337, -0.3559,
         -0.1020, -0.1072,  0.6151, -0.1169, -0.1092,  0.7664, -0.1830, -0.1022,
         -0.1335, -0.3054, -0.2112, -0.3189, -0.1337,  0.2937,  0.3473, -0.2284,
          0.8369, -0.0646, -0.0743, -0.2321, -0.1896, -0.0971,  0.3692, -0.0868,
         -0.3395, -0.3295,  0.3550, -0.0985, -0.1551, -0.1572, -0.3203, -0.1426,
         -0.2239,  0.5537,  0.8222,  0.8099, -0.3459, -0.1104, -0.0927, -0.2411],
        [-0.1144, -0.1047, -0.0719, -0.0651, -0.0724, -0.0994, -0.0926, -0.0571,
         -0.0671, -0.0941, -0.1207, -0.0594, -0.1096, -0.0698, -0.0590, -0.1016,
         -0.1144, -0.0554, -0.1046, -0.0752, -0.0617, -0.1101, -0.0734, -0.1175,
         -0.0732, -0.0490, -0.1082, -0.1066, -0.0564, -0.0507, -0.1140, -0.0699,
         -0.1153, -0.0955, 

In [224]:
torch.reshape(parameters_to_vector(model.l_out.parameters()), (65, 7))[-1]

tensor([ 0.1458, -0.4801,  0.1616,  0.2786,  0.1849, -0.4138,  0.1300],
       grad_fn=<SelectBackward0>)

In [246]:
model.l_out.bias

Parameter containing:
tensor([ 0.1458, -0.4801,  0.1616,  0.2786,  0.1849, -0.4138,  0.1300],
       requires_grad=True)

In [239]:
model.l_out.parameters

<bound method Module.parameters of Linear(in_features=64, out_features=7, bias=True)>

In [304]:
torch.reshape(parameters_to_vector(model.l_out.parameters()), (65, 7))

tensor([[-0.1714, -0.1416, -0.0820, -0.0553, -0.0986,  0.4477,  0.5367],
        [ 0.8040, -0.1054, -0.1196, -0.2342,  0.7754,  0.2543, -0.1372],
        [ 0.6522, -0.2257, -0.3246,  0.8251,  0.5270, -0.1109,  0.8341],
        [-0.1458,  0.1337, -0.3559, -0.1020, -0.1072,  0.6151, -0.1169],
        [-0.1092,  0.7664, -0.1830, -0.1022, -0.1335, -0.3054, -0.2112],
        [-0.3189, -0.1337,  0.2937,  0.3473, -0.2284,  0.8369, -0.0646],
        [-0.0743, -0.2321, -0.1896, -0.0971,  0.3692, -0.0868, -0.3395],
        [-0.3295,  0.3550, -0.0985, -0.1551, -0.1572, -0.3203, -0.1426],
        [-0.2239,  0.5537,  0.8222,  0.8099, -0.3459, -0.1104, -0.0927],
        [-0.2411, -0.1144, -0.1047, -0.0719, -0.0651, -0.0724, -0.0994],
        [-0.0926, -0.0571, -0.0671, -0.0941, -0.1207, -0.0594, -0.1096],
        [-0.0698, -0.0590, -0.1016, -0.1144, -0.0554, -0.1046, -0.0752],
        [-0.0617, -0.1101, -0.0734, -0.1175, -0.0732, -0.0490, -0.1082],
        [-0.1066, -0.0564, -0.0507, -0.1140, -0.069

In [299]:
torch.reshape(parameters_to_vector(model.l_out.parameters()), (7, 65))

tensor([[-0.1714, -0.1416, -0.0820, -0.0553, -0.0986,  0.4477,  0.5367,  0.8040,
         -0.1054, -0.1196, -0.2342,  0.7754,  0.2543, -0.1372,  0.6522, -0.2257,
         -0.3246,  0.8251,  0.5270, -0.1109,  0.8341, -0.1458,  0.1337, -0.3559,
         -0.1020, -0.1072,  0.6151, -0.1169, -0.1092,  0.7664, -0.1830, -0.1022,
         -0.1335, -0.3054, -0.2112, -0.3189, -0.1337,  0.2937,  0.3473, -0.2284,
          0.8369, -0.0646, -0.0743, -0.2321, -0.1896, -0.0971,  0.3692, -0.0868,
         -0.3395, -0.3295,  0.3550, -0.0985, -0.1551, -0.1572, -0.3203, -0.1426,
         -0.2239,  0.5537,  0.8222,  0.8099, -0.3459, -0.1104, -0.0927, -0.2411,
         -0.1144],
        [-0.1047, -0.0719, -0.0651, -0.0724, -0.0994, -0.0926, -0.0571, -0.0671,
         -0.0941, -0.1207, -0.0594, -0.1096, -0.0698, -0.0590, -0.1016, -0.1144,
         -0.0554, -0.1046, -0.0752, -0.0617, -0.1101, -0.0734, -0.1175, -0.0732,
         -0.0490, -0.1082, -0.1066, -0.0564, -0.0507, -0.1140, -0.0699, -0.1153,
         