In [1]:
import abc
import collections
import enum
import math
import pathlib
import typing
import warnings

import numpy as np
import torch
import torch.optim
import torch.utils.data
import tqdm
from matplotlib import pyplot as plt

from util import draw_reliability_diagram, cost_function, setup_seeds, calc_calibration_curve

EXTENDED_EVALUATION = False
"""
Set `EXTENDED_EVALUATION` to `True` in order to generate additional plots on validation data.
"""

USE_PRETRAINED_INIT = True
"""
If `USE_PRETRAINED_INIT` is `True`, then MAP inference uses provided pretrained weights.
You should not modify MAP training or the CNN architecture before passing the hard baseline.
If you set the constant to `False` (to further experiment),
this solution always performs MAP inference before running your SWAG implementation.
Note that MAP inference can take a long time.
"""

'\nIf `USE_PRETRAINED_INIT` is `True`, then MAP inference uses provided pretrained weights.\nYou should not modify MAP training or the CNN architecture before passing the hard baseline.\nIf you set the constant to `False` (to further experiment),\nthis solution always performs MAP inference before running your SWAG implementation.\nNote that MAP inference can take a long time.\n'

In [2]:
torch.rand(4).shape

torch.Size([4])

In [3]:
torch.rand(4).unsqueeze(0).shape

torch.Size([1, 4])

In [4]:
L = [torch.rand(4).unsqueeze(0) for i in range(5)]

In [5]:
torch.cat(L,0).shape

torch.Size([5, 4])

In [6]:
torch.mean(torch.cat(L,0),0)

tensor([0.3612, 0.6067, 0.5483, 0.5626])

In [7]:
# def main():
#     raise RuntimeError(
#         "This main() method is for illustrative purposes only"
#         " and will NEVER be called when running your solution to generate your submission file!\n"
#         "The checker always directly interacts with your SWAGInference class and evaluate method.\n"
#         "You can remove this exception for local testing, but be aware that any changes to the main() method"
#         " are ignored when generating your submission file."
#     )



In [8]:
pathlib

<module 'pathlib' from 'C:\\APPS\\Anaconda3\\envs\\pai\\lib\\pathlib.py'>

In [9]:
data_dir = pathlib.Path.cwd()
model_dir = pathlib.Path.cwd()
output_dir = pathlib.Path.cwd()

In [10]:
# Load training data
train_xs = torch.from_numpy(np.load(data_dir / "train_xs.npz")["train_xs"])
raw_train_meta = np.load(data_dir / "train_ys.npz")
train_ys = torch.from_numpy(raw_train_meta["train_ys"])
train_is_snow = torch.from_numpy(raw_train_meta["train_is_snow"])
train_is_cloud = torch.from_numpy(raw_train_meta["train_is_cloud"])
dataset_train = torch.utils.data.TensorDataset(train_xs, train_is_snow, train_is_cloud, train_ys)

FileNotFoundError: [Errno 2] No such file or directory: 'Y:\\private\\desktop-dinfk-xp\\2023-f\\pai\\pai_proj\\task2_handout_e14a_works_for_new_mac\\train_xs.npz'

In [None]:
train_xs.shape

In [None]:
plt.imshow(train_xs[100].permute(1,2,0))

In [None]:
counts, bins = np.histogram(train_ys, bins=[-1,0,1,2,3,4,5])
plt.hist(bins[:-1], bins, weights=counts)

In [None]:
counts, bins = np.histogram(train_is_snow, bins=[-1,0,1,2,3,4,5])
plt.hist(bins[:-1], bins, weights=counts)

In [None]:
counts, bins = np.histogram(train_is_cloud, bins=[-1,0,1,2,3,4,5])
plt.hist(bins[:-1], bins, weights=counts)

In [None]:
# Load validation data
val_xs = torch.from_numpy(np.load(data_dir / "val_xs.npz")["val_xs"])
raw_val_meta = np.load(data_dir / "val_ys.npz")
val_ys = torch.from_numpy(raw_val_meta["val_ys"])
val_is_snow = torch.from_numpy(raw_val_meta["val_is_snow"])
val_is_cloud = torch.from_numpy(raw_val_meta["val_is_cloud"])
dataset_val = torch.utils.data.TensorDataset(val_xs, val_is_snow, val_is_cloud, val_ys)

In [None]:
val_xs.shape

In [11]:
val_ys

NameError: name 'val_ys' is not defined

In [12]:
val_is_snow

NameError: name 'val_is_snow' is not defined

In [13]:
counts, bins = np.histogram(val_ys, bins=[-1,0,1,2,3,4,5])
plt.hist(bins[:-1], bins, weights=counts)

NameError: name 'val_ys' is not defined

In [14]:
counts, bins = np.histogram(val_ys, bins=[-1,0,1,2,3,4,5])
plt.hist(bins[:-1], bins, weights=counts)

NameError: name 'val_ys' is not defined

In [15]:
counts, bins = np.histogram(val_is_snow, bins=[-1,0,1,2,3,4,5])
plt.hist(bins[:-1], bins, weights=counts)

NameError: name 'val_is_snow' is not defined

In [16]:
counts, bins = np.histogram(val_is_cloud, bins=[-1,0,1,2,3,4,5])
plt.hist(bins[:-1], bins, weights=counts)

NameError: name 'val_is_cloud' is not defined

In [17]:
bins

NameError: name 'bins' is not defined

In [18]:
# Fix all randomness
setup_seeds()

In [19]:
# Build and run the actual solution
train_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

NameError: name 'dataset_train' is not defined

In [20]:
swag = SWAGInference(
    train_xs=dataset_train.tensors[0],
    model_dir=model_dir,
)
swag.fit(train_loader)
swag.calibrate(dataset_val)

NameError: name 'SWAGInference' is not defined

In [21]:











# fork_rng ensures that the evaluation does not change the rng state.
# That way, you should get exactly the same results even if you remove evaluation
# to save computational time when developing the task
# (as long as you ONLY use torch randomness, and not e.g. random or numpy.random).
with torch.random.fork_rng():
    evaluate(swag, dataset_val, EXTENDED_EVALUATION, output_dir)


NameError: name 'evaluate' is not defined

In [22]:
class CNN(torch.nn.Module):
    """
    Small convolutional neural network used in this task.
    You should not modify this class before passing the hard baseline.

    Note that if you change the architecture of this network,
    you need to re-run MAP inference and cannot use the provided pretrained weights anymore.
    Hence, you need to set `USE_PRETRAINED_INIT = False` at the top of this file.
    """
    def __init__(
        self,
        in_channels: int,
        out_classes: int,
    ):
        super().__init__()

        self.layer0 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 32, kernel_size=5),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
        )
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 32, kernel_size=3),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 32, kernel_size=3),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
        )
        self.pool1 = torch.nn.MaxPool2d((2, 2), stride=(2, 2))

        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
        )
        self.layer4 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, kernel_size=3),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
        )
        self.pool2 = torch.nn.MaxPool2d((2, 2), stride=(2, 2))

        self.layer5 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, kernel_size=3),
        )

        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))

        self.linear = torch.nn.Linear(64, out_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool1(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.pool2(x)
        x = self.layer5(x)

        # Average features over both spatial dimensions, and remove the now superfluous dimensions
        x = self.global_pool(x).squeeze(-1).squeeze(-1)

        # Note: this network does NOT output the per-class probabilities y =[y_1, ..., y_C],
        # but a feature vector z such that y = softmax(z).
        # This avoids numerical instabilities during optimization.
        # The PyTorch loss automatically handles this.
        log_softmax = self.linear(x)

        return log_softmax

In [23]:
network = CNN(in_channels=3, out_classes=6)

In [24]:
network.named_parameters()

<generator object Module.named_parameters at 0x000001EFB8BCDDD0>

In [25]:
current_params = {name: param.detach() for name, param in network.named_parameters()}

In [26]:
current_params

{'layer0.0.weight': tensor([[[[ 0.0595, -0.0510, -0.0224,  0.0542, -0.1087],
           [ 0.0692, -0.0238,  0.0587,  0.0161, -0.0141],
           [ 0.0320,  0.0057,  0.0422, -0.0450, -0.0084],
           [-0.0104,  0.0167, -0.0005,  0.1009,  0.0359],
           [-0.0430, -0.0697, -0.0194, -0.0498, -0.0370]],
 
          [[ 0.0055,  0.0688,  0.0628, -0.1129,  0.0716],
           [ 0.0323,  0.1095,  0.0762, -0.1052, -0.1098],
           [-0.0557,  0.1014, -0.0192,  0.0494, -0.0537],
           [ 0.1133, -0.0489,  0.0866,  0.0014, -0.0608],
           [ 0.0594, -0.0613,  0.0340, -0.0333, -0.0127]],
 
          [[-0.1110, -0.0551,  0.0627, -0.0281,  0.1150],
           [ 0.0926, -0.0054, -0.0771,  0.0703,  0.0358],
           [-0.0746,  0.0750,  0.0701,  0.1024, -0.0647],
           [-0.0190, -0.0022,  0.0169, -0.0876, -0.0819],
           [ 0.0628, -0.0271,  0.0564,  0.0066,  0.0379]]],
 
 
         [[[ 0.0254,  0.0420,  0.0572, -0.1069,  0.0581],
           [-0.0812, -0.0871,  0.0070, -0

In [27]:
for name, param in current_params.items():
    break

In [28]:
name

'layer0.0.weight'

In [29]:
current_params[name]*100

tensor([[[[  5.9497,  -5.0966,  -2.2385,   5.4198, -10.8708],
          [  6.9249,  -2.3755,   5.8745,   1.6052,  -1.4139],
          [  3.2027,   0.5696,   4.2173,  -4.4999,  -0.8419],
          [ -1.0395,   1.6737,  -0.0461,  10.0940,   3.5933],
          [ -4.3002,  -6.9739,  -1.9355,  -4.9807,  -3.7002]],

         [[  0.5529,   6.8835,   6.2762, -11.2878,   7.1582],
          [  3.2258,  10.9530,   7.6217, -10.5209, -10.9790],
          [ -5.5695,  10.1396,  -1.9234,   4.9416,  -5.3660],
          [ 11.3301,  -4.8855,   8.6594,   0.1367,  -6.0832],
          [  5.9353,  -6.1294,   3.3960,  -3.3343,  -1.2659]],

         [[-11.1012,  -5.5054,   6.2658,  -2.8068,  11.5014],
          [  9.2559,  -0.5407,  -7.7074,   7.0317,   3.5837],
          [ -7.4642,   7.5003,   7.0102,  10.2410,  -6.4728],
          [ -1.9007,  -0.2237,   1.6865,  -8.7632,  -8.1940],
          [  6.2816,  -2.7077,   5.6404,   0.6583,   3.7914]]],


        [[[  2.5391,   4.1985,   5.7240, -10.6938,   5.8122],


In [30]:
current_params[name]*10000*current_params[name]

tensor([[[[3.5399e+01, 2.5975e+01, 5.0110e+00, 2.9374e+01, 1.1817e+02],
          [4.7955e+01, 5.6431e+00, 3.4509e+01, 2.5768e+00, 1.9990e+00],
          [1.0257e+01, 3.2448e-01, 1.7786e+01, 2.0249e+01, 7.0876e-01],
          [1.0807e+00, 2.8012e+00, 2.1278e-03, 1.0189e+02, 1.2912e+01],
          [1.8492e+01, 4.8636e+01, 3.7460e+00, 2.4807e+01, 1.3692e+01]],

         [[3.0569e-01, 4.7382e+01, 3.9391e+01, 1.2741e+02, 5.1240e+01],
          [1.0406e+01, 1.1997e+02, 5.8091e+01, 1.1069e+02, 1.2054e+02],
          [3.1019e+01, 1.0281e+02, 3.6994e+00, 2.4420e+01, 2.8794e+01],
          [1.2837e+02, 2.3868e+01, 7.4985e+01, 1.8697e-02, 3.7005e+01],
          [3.5228e+01, 3.7569e+01, 1.1533e+01, 1.1117e+01, 1.6026e+00]],

         [[1.2324e+02, 3.0309e+01, 3.9260e+01, 7.8782e+00, 1.3228e+02],
          [8.5673e+01, 2.9233e-01, 5.9404e+01, 4.9445e+01, 1.2843e+01],
          [5.5714e+01, 5.6254e+01, 4.9143e+01, 1.0488e+02, 4.1897e+01],
          [3.6127e+00, 5.0036e-02, 2.8444e+00, 7.6794e+01, 6

In [31]:
param

tensor([[[[ 0.0595, -0.0510, -0.0224,  0.0542, -0.1087],
          [ 0.0692, -0.0238,  0.0587,  0.0161, -0.0141],
          [ 0.0320,  0.0057,  0.0422, -0.0450, -0.0084],
          [-0.0104,  0.0167, -0.0005,  0.1009,  0.0359],
          [-0.0430, -0.0697, -0.0194, -0.0498, -0.0370]],

         [[ 0.0055,  0.0688,  0.0628, -0.1129,  0.0716],
          [ 0.0323,  0.1095,  0.0762, -0.1052, -0.1098],
          [-0.0557,  0.1014, -0.0192,  0.0494, -0.0537],
          [ 0.1133, -0.0489,  0.0866,  0.0014, -0.0608],
          [ 0.0594, -0.0613,  0.0340, -0.0333, -0.0127]],

         [[-0.1110, -0.0551,  0.0627, -0.0281,  0.1150],
          [ 0.0926, -0.0054, -0.0771,  0.0703,  0.0358],
          [-0.0746,  0.0750,  0.0701,  0.1024, -0.0647],
          [-0.0190, -0.0022,  0.0169, -0.0876, -0.0819],
          [ 0.0628, -0.0271,  0.0564,  0.0066,  0.0379]]],


        [[[ 0.0254,  0.0420,  0.0572, -0.1069,  0.0581],
          [-0.0812, -0.0871,  0.0070, -0.0197,  0.0678],
          [-0.0669, -0.

In [32]:
current_params.keys()

dict_keys(['layer0.0.weight', 'layer0.0.bias', 'layer0.1.weight', 'layer0.1.bias', 'layer1.0.weight', 'layer1.0.bias', 'layer1.1.weight', 'layer1.1.bias', 'layer2.0.weight', 'layer2.0.bias', 'layer2.1.weight', 'layer2.1.bias', 'layer3.0.weight', 'layer3.0.bias', 'layer3.1.weight', 'layer3.1.bias', 'layer4.0.weight', 'layer4.0.bias', 'layer4.1.weight', 'layer4.1.bias', 'layer5.0.weight', 'layer5.0.bias', 'linear.weight', 'linear.bias'])

In [33]:
def _create_weight_copy(net):
    """Create an all-zero copy of the network weights as a dictionary that maps name -> weight"""
    return {
        name: torch.zeros_like(param, requires_grad=False)
        for name, param in net.named_parameters()
    }

In [34]:
_create_weight_copy(network)

{'layer0.0.weight': tensor([[[[0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]],
 
          [[0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]],
 
          [[0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]]],
 
 
         [[[0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]],
 
          [[0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]],
 
          [[0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
      

In [35]:
network

CNN(
  (layer0): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer1): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer2): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (layer3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer4): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1,

In [36]:
network.load_state_dict(torch.load(PRETRAINED_WEIGHTS_FILE))

NameError: name 'PRETRAINED_WEIGHTS_FILE' is not defined

In [37]:
torch.load(PRETRAINED_WEIGHTS_FILE).keys()

NameError: name 'PRETRAINED_WEIGHTS_FILE' is not defined

In [None]:
current_params.keys()

In [38]:
PRETRAINED_WEIGHTS_FILE = model_dir / "map_weights.pt"

In [39]:
current_params

{'layer0.0.weight': tensor([[[[ 0.0595, -0.0510, -0.0224,  0.0542, -0.1087],
           [ 0.0692, -0.0238,  0.0587,  0.0161, -0.0141],
           [ 0.0320,  0.0057,  0.0422, -0.0450, -0.0084],
           [-0.0104,  0.0167, -0.0005,  0.1009,  0.0359],
           [-0.0430, -0.0697, -0.0194, -0.0498, -0.0370]],
 
          [[ 0.0055,  0.0688,  0.0628, -0.1129,  0.0716],
           [ 0.0323,  0.1095,  0.0762, -0.1052, -0.1098],
           [-0.0557,  0.1014, -0.0192,  0.0494, -0.0537],
           [ 0.1133, -0.0489,  0.0866,  0.0014, -0.0608],
           [ 0.0594, -0.0613,  0.0340, -0.0333, -0.0127]],
 
          [[-0.1110, -0.0551,  0.0627, -0.0281,  0.1150],
           [ 0.0926, -0.0054, -0.0771,  0.0703,  0.0358],
           [-0.0746,  0.0750,  0.0701,  0.1024, -0.0647],
           [-0.0190, -0.0022,  0.0169, -0.0876, -0.0819],
           [ 0.0628, -0.0271,  0.0564,  0.0066,  0.0379]]],
 
 
         [[[ 0.0254,  0.0420,  0.0572, -0.1069,  0.0581],
           [-0.0812, -0.0871,  0.0070, -0

In [40]:
network.load_state_dict(current_params)

RuntimeError: Error(s) in loading state_dict for CNN:
	Missing key(s) in state_dict: "layer0.1.running_mean", "layer0.1.running_var", "layer1.1.running_mean", "layer1.1.running_var", "layer2.1.running_mean", "layer2.1.running_var", "layer3.1.running_mean", "layer3.1.running_var", "layer4.1.running_mean", "layer4.1.running_var". 

In [None]:
network

In [46]:
current_params

{'layer0.0.weight': tensor([[[[ 0.0595, -0.0510, -0.0224,  0.0542, -0.1087],
           [ 0.0692, -0.0238,  0.0587,  0.0161, -0.0141],
           [ 0.0320,  0.0057,  0.0422, -0.0450, -0.0084],
           [-0.0104,  0.0167, -0.0005,  0.1009,  0.0359],
           [-0.0430, -0.0697, -0.0194, -0.0498, -0.0370]],
 
          [[ 0.0055,  0.0688,  0.0628, -0.1129,  0.0716],
           [ 0.0323,  0.1095,  0.0762, -0.1052, -0.1098],
           [-0.0557,  0.1014, -0.0192,  0.0494, -0.0537],
           [ 0.1133, -0.0489,  0.0866,  0.0014, -0.0608],
           [ 0.0594, -0.0613,  0.0340, -0.0333, -0.0127]],
 
          [[-0.1110, -0.0551,  0.0627, -0.0281,  0.1150],
           [ 0.0926, -0.0054, -0.0771,  0.0703,  0.0358],
           [-0.0746,  0.0750,  0.0701,  0.1024, -0.0647],
           [-0.0190, -0.0022,  0.0169, -0.0876, -0.0819],
           [ 0.0628, -0.0271,  0.0564,  0.0066,  0.0379]]],
 
 
         [[[ 0.0254,  0.0420,  0.0572, -0.1069,  0.0581],
           [-0.0812, -0.0871,  0.0070, -0

In [47]:
network.named_parameters

<bound method Module.named_parameters of CNN(
  (layer0): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer1): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer2): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (layer3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer4): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (1):

In [48]:
dd ={
            name: torch.zeros_like(param, requires_grad=False)
            for name, param in network.named_parameters()
}

In [49]:
dd.keys()

dict_keys(['layer0.0.weight', 'layer0.0.bias', 'layer0.1.weight', 'layer0.1.bias', 'layer1.0.weight', 'layer1.0.bias', 'layer1.1.weight', 'layer1.1.bias', 'layer2.0.weight', 'layer2.0.bias', 'layer2.1.weight', 'layer2.1.bias', 'layer3.0.weight', 'layer3.0.bias', 'layer3.1.weight', 'layer3.1.bias', 'layer4.0.weight', 'layer4.0.bias', 'layer4.1.weight', 'layer4.1.bias', 'layer5.0.weight', 'layer5.0.bias', 'linear.weight', 'linear.bias'])

In [50]:
network.named_parameters

<bound method Module.named_parameters of CNN(
  (layer0): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer1): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer2): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (layer3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer4): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (1):

In [51]:
d = {1: 23, 3:4}

In [52]:
d2 = {1:32}

In [53]:
d.update(d2)

In [54]:
d

{1: 32, 3: 4}

In [55]:
network.state_dict().update(current_params)

In [42]:
net = network.state_dict()
net.update(sampled_param)
self.network.load_state_dict(net)
self._update_batchnorm()

NameError: name 'sampled_param' is not defined

In [44]:
param.shape

torch.Size([32, 3, 5, 5])

In [49]:
D = collections.deque()

In [50]:
K = 10

In [52]:
for i in range(K):
    D.append(_create_weight_copy(network))

In [58]:
len(D)

10

In [61]:
dv = []

In [63]:
for D_i in D:
    dv.append(D_i[name])

In [66]:
dv[0].shape

torch.Size([32, 3, 5, 5])

In [68]:
torch.tensor(dv)

ValueError: only one element tensors can be converted to Python scalars

In [41]:
class InferenceMode(enum.Enum):
    """
    Inference mode switch for your implementation.
    `MAP` simply predicts the most likely class using pretrained MAP weights.
    `SWAG_DIAGONAL` and `SWAG_FULL` correspond to SWAG-diagonal and the full SWAG method, respectively.
    """
    MAP = 0
    SWAG_DIAGONAL = 1
    SWAG_FULL = 2

In [80]:
class SWAGInference(object):
    """
    Your implementation of SWA-Gaussian.
    This class is used to run and evaluate your solution.
    You must preserve all methods and signatures of this class.
    However, you can add new methods if you want.

    We provide basic functionality and some helper methods.
    You can pass all baselines by only modifying methods marked with TODO.
    However, we encourage you to skim other methods in order to gain a better understanding of SWAG.
    """

    def __init__(
        self,
        train_xs: torch.Tensor,
        model_dir: pathlib.Path,
        # TODO(1): change inference_mode to InferenceMode.SWAG_DIAGONAL
        # TODO(2): change inference_mode to InferenceMode.SWAG_FULL
        inference_mode: InferenceMode = InferenceMode.SWAG_DIAGONAL, # InferenceMode.MAP,
        # TODO(2): optionally add/tweak hyperparameters
        swag_epochs: int = 30,
        swag_learning_rate: float = 0.045,
        swag_update_freq: int = 1,
        deviation_matrix_max_rank: int = 15,
        bma_samples: int = 30,
    ):
        """
        :param train_xs: Training images (for storage only)
        :param model_dir: Path to directory containing pretrained MAP weights
        :param inference_mode: Control which inference mode (MAP, SWAG-diagonal, full SWAG) to use
        :param swag_epochs: Total number of gradient descent epochs for SWAG
        :param swag_learning_rate: Learning rate for SWAG gradient descent
        :param swag_update_freq: Frequency (in epochs) for updating SWAG statistics during gradient descent
        :param deviation_matrix_max_rank: Rank of deviation matrix for full SWAG
        :param bma_samples: Number of networks to sample for Bayesian model averaging during prediction
        """

        self.model_dir = model_dir
        self.inference_mode = inference_mode
        self.swag_epochs = swag_epochs
        self.swag_learning_rate = swag_learning_rate
        self.swag_update_freq = swag_update_freq
        self.deviation_matrix_max_rank = deviation_matrix_max_rank
        self.bma_samples = bma_samples

        # Network used to perform SWAG.
        # Note that all operations in this class modify this network IN-PLACE!
        self.network = CNN(in_channels=3, out_classes=6)

        # Store training dataset to recalculate batch normalization statistics during SWAG inference
        self.train_dataset = torch.utils.data.TensorDataset(train_xs)

        # SWAG-diagonal
        # TODO(1): create attributes for SWAG-diagonal
        #  Hint: self._create_weight_copy() creates an all-zero copy of the weights
        #  as a dictionary that maps from weight name to values.
        #  Hint: you never need to consider the full vector of weights,
        #  but can always act on per-layer weights (in the format that _create_weight_copy() returns)
        self.w_swa = self._create_weight_copy()
        self.w2_swa = self._create_weight_copy()
        self.n = 0

        # Full SWAG
        # TODO(2): create attributes for SWAG-diagonal
        #  Hint: check collections.deque

        # Calibration, prediction, and other attributes
        # TODO(2): create additional attributes, e.g., for calibration
        self._prediction_threshold = None  # this is an example, feel free to be creative

    def update_swag(self) -> None:
        """
        Update SWAG statistics with the current weights of self.network.
        """

        # Create a copy of the current network weights
        current_params = {name: param.detach() for name, param in self.network.named_parameters()}

        # SWAG-diagonal
        for name, param in current_params.items():
            # TODO(1): update SWAG-diagonal attributes for weight `name` using `current_params` and `param`
            self.w_swa[name] = (self.n * self.w_swa[name] + param)/(self.n + 1)
            self.w2_swa[name] = (self.n * self.w2_swa[name] + param*param)/(self.n + 1)
            # raise NotImplementedError("Update SWAG-diagonal statistics")

        # Full SWAG
        if self.inference_mode == InferenceMode.SWAG_FULL:
            # TODO(2): update full SWAG attributes for weight `name` using `current_params` and `param`
            raise NotImplementedError("Update full SWAG statistics")

    def fit_swag(self, loader: torch.utils.data.DataLoader) -> None:
        """
        Fit SWAG on top of the pretrained network self.network.
        This method should perform gradient descent with occasional SWAG updates
        by calling self.update_swag().
        """

        # We use SGD with momentum and weight decay to perform SWA.
        # See the paper on how weight decay corresponds to a type of prior.
        # Feel free to play around with optimization hyperparameters.
        optimizer = torch.optim.SGD(
            self.network.parameters(),
            lr=self.swag_learning_rate,
            momentum=0.9,
            nesterov=False,
            weight_decay=1e-4,
        )
        loss = torch.nn.CrossEntropyLoss(
            reduction="mean",
        )
        # TODO(2): Update SWAGScheduler instantiation if you decided to implement a custom schedule.
        #  By default, this scheduler just keeps the initial learning rate given to `optimizer`.
        lr_scheduler = SWAGScheduler(
            optimizer,
            epochs=self.swag_epochs,
            steps_per_epoch=len(loader),
        )

        # TODO(1): Perform initialization for SWAG fitting
        self.update_swag()
        # raise NotImplementedError("Initialize SWAG fitting")

        self.network.train()
        with tqdm.trange(self.swag_epochs, desc="Running gradient descent for SWA") as pbar:
            pbar_dict = {}
            for epoch in pbar:
                average_loss = 0.0
                average_accuracy = 0.0
                num_samples_processed = 0
                for batch_xs, batch_is_snow, batch_is_cloud, batch_ys in loader:
                    optimizer.zero_grad()
                    pred_ys = self.network(batch_xs)
                    batch_loss = loss(input=pred_ys, target=batch_ys)
                    batch_loss.backward()
                    optimizer.step()
                    pbar_dict["lr"] = lr_scheduler.get_last_lr()[0]
                    lr_scheduler.step()

                    # Calculate cumulative average training loss and accuracy
                    average_loss = (batch_xs.size(0) * batch_loss.item() + num_samples_processed * average_loss) / (
                        num_samples_processed + batch_xs.size(0)
                    )
                    average_accuracy = (
                        torch.sum(pred_ys.argmax(dim=-1) == batch_ys).item()
                        + num_samples_processed * average_accuracy
                    ) / (num_samples_processed + batch_xs.size(0))
                    num_samples_processed += batch_xs.size(0)
                    pbar_dict["avg. epoch loss"] = average_loss
                    pbar_dict["avg. epoch accuracy"] = average_accuracy
                    pbar.set_postfix(pbar_dict)

                # TODO(1): Implement periodic SWAG updates using the attributes defined in __init__
                if epoch % self.swag_update_freq == 0:
                    self.n = epoch/self.swag_update_freq
                    self.update_swag()
                # raise NotImplementedError("Periodically update SWAG statistics")

    def calibrate(self, validation_data: torch.utils.data.Dataset) -> None:
        """
        Calibrate your predictions using a small validation set.
        validation_data contains well-defined and ambiguous samples,
        where you can identify the latter by having label -1.
        """
        if self.inference_mode == InferenceMode.MAP:
            # In MAP mode, simply predict argmax and do nothing else
            self._prediction_threshold = 0.0
            return

        # TODO(1): pick a prediction threshold, either constant or adaptive.
        #  The provided value should suffice to pass the easy baseline.
        self._prediction_threshold = 2.0 / 3.0

        # TODO(2): perform additional calibration if desired.
        #  Feel free to remove or change the prediction threshold.
        val_xs, val_is_snow, val_is_cloud, val_ys = validation_data.tensors
        assert val_xs.size() == (140, 3, 60, 60)  # N x C x H x W
        assert val_ys.size() == (140,)
        assert val_is_snow.size() == (140,)
        assert val_is_cloud.size() == (140,)

    def predict_probabilities_swag(self, loader: torch.utils.data.DataLoader) -> torch.Tensor:
        """
        Perform Bayesian model averaging using your SWAG statistics and predict
        probabilities for all samples in the loader.
        Outputs should be a Nx6 tensor, where N is the number of samples in loader,
        and all rows of the output should sum to 1.
        That is, output row i column j should be your predicted p(y=j | x_i).
        """

        self.network.eval()

        # Perform Bayesian model averaging:
        # Instead of sampling self.bma_samples networks (using self.sample_parameters())
        # for each datapoint, you can save time by sampling self.bma_samples networks,
        # and perform inference with each network on all samples in loader.
        per_model_sample_predictions = []
        for _ in tqdm.trange(self.bma_samples, desc="Performing Bayesian model averaging"):
            # TODO(1): Sample new parameters for self.network from the SWAG approximate posterior
            self.sample_parameters()
            # raise NotImplementedError("Sample network parameters")

            # TODO(1): Perform inference for all samples in `loader` using current model sample,
            #  and add the predictions to per_model_sample_predictions
            predictions = []
            for (batch_xs,) in loader:
                predictions.append(self.network(batch_xs))
            
            predictions = torch.cat(predictions)
            
            per_model_sample_predictions.append(predictions.unsqueeze(0))
            # raise NotImplementedError("Perform inference using current model")

        assert len(per_model_sample_predictions) == self.bma_samples
        assert all(
            isinstance(model_sample_predictions, torch.Tensor)
            and model_sample_predictions.dim() == 2  # N x C
            and model_sample_predictions.size(1) == 6
            for model_sample_predictions in per_model_sample_predictions
        )

        # TODO(1): Average predictions from different model samples into bma_probabilities
        # raise NotImplementedError("Aggregate predictions from model samples")
        bma_logits = torch.cat(per_model_sample_predictions, 0)
        bma_logits = torch.mean(bma_logits, 0)
        bma_probabilities = torch.softmax(bma_logits, dim=-1)

        assert bma_probabilities.dim() == 2 and bma_probabilities.size(1) == 6  # N x C
        return bma_probabilities

    def sample_parameters(self) -> None:
        """
        Sample a new network from the approximate SWAG posterior.
        For simplicity, this method directly modifies self.network in-place.
        Hence, after calling this method, self.network corresponds to a new posterior sample.
        """

        # Instead of acting on a full vector of parameters, all operations can be done on per-layer parameters.
        net = self.network.state_dict()

        for name, param in self.network.named_parameters():
            # SWAG-diagonal part
            z_1 = torch.randn(param.size())
            # TODO(1): Sample parameter values for SWAG-diagonal
            # raise NotImplementedError("Sample parameter for SWAG-diagonal")
            
            current_mean = self.w_swa[name]
            current_std = torch.sqrt(self.w2_swa[name] - self.w_swa[name]**2)
            
            assert current_mean.size() == param.size() and current_std.size() == param.size()

            # Diagonal part
            sampled_param = current_mean + current_std * z_1

            # Full SWAG part
            if self.inference_mode == InferenceMode.SWAG_FULL:
                # TODO(2): Sample parameter values for full SWAG
                raise NotImplementedError("Sample parameter for full SWAG")
                sampled_param += ...

            # Modify weight value in-place; directly changing self.network
            param.data = sampled_param

            net[name] = sampled_param

        # TODO(1): Don't forget to update batch normalization statistics using self._update_batchnorm()
        #  in the appropriate place!
        self.network.load_state_dict(net)
        self._update_batchnorm()

        # raise NotImplementedError("Update batch normalization statistics for newly sampled network")

    def predict_labels(self, predicted_probabilities: torch.Tensor) -> torch.Tensor:
        """
        Predict labels in {0, 1, 2, 3, 4, 5} or "don't know" as -1
        based on your model's predicted probabilities.
        The parameter predicted_probabilities is an Nx6 tensor containing predicted probabilities
        as returned by predict_probabilities(...).
        The output should be a N-dimensional long tensor, containing values in {-1, 0, 1, 2, 3, 4, 5}.
        """

        # label_probabilities contains the per-row maximum values in predicted_probabilities,
        # max_likelihood_labels the corresponding column index (equivalent to class).
        label_probabilities, max_likelihood_labels = torch.max(predicted_probabilities, dim=-1)
        num_samples, num_classes = predicted_probabilities.size()
        assert label_probabilities.size() == (num_samples,) and max_likelihood_labels.size() == (num_samples,)

        # A model without uncertainty awareness might simply predict the most likely label per sample:
        # return max_likelihood_labels

        # A bit better: use a threshold to decide whether to return a label or "don't know" (label -1)
        # TODO(2): implement a different decision rule if desired
        return torch.where(
            label_probabilities >= self._prediction_threshold,
            max_likelihood_labels,
            torch.ones_like(max_likelihood_labels) * -1,
        )

    def _create_weight_copy(self) -> typing.Dict[str, torch.Tensor]:
        """Create an all-zero copy of the network weights as a dictionary that maps name -> weight"""
        return {
            name: torch.zeros_like(param, requires_grad=False)
            for name, param in self.network.named_parameters()
        }

    def fit(
        self,
        loader: torch.utils.data.DataLoader,
    ) -> None:
        """
        Perform full SWAG fitting procedure.
        If `PRETRAINED_WEIGHTS_FILE` is `True`, this method skips the MAP inference part,
        and uses pretrained weights instead.

        Note that MAP inference can take a very long time.
        You should hence only perform MAP inference yourself after passing the hard baseline
        using the given CNN architecture and pretrained weights.
        """

        # MAP inference to obtain initial weights
        PRETRAINED_WEIGHTS_FILE = self.model_dir / "map_weights.pt"
        if USE_PRETRAINED_INIT:
            self.network.load_state_dict(torch.load(PRETRAINED_WEIGHTS_FILE))
            print("Loaded pretrained MAP weights from", PRETRAINED_WEIGHTS_FILE)
        else:
            self.fit_map(loader)

        # SWAG
        if self.inference_mode in (InferenceMode.SWAG_DIAGONAL, InferenceMode.SWAG_FULL):
            self.fit_swag(loader)

    def fit_map(self, loader: torch.utils.data.DataLoader) -> None:
        """
        MAP inference procedure to obtain initial weights of self.network.
        This is the exact procedure that was used to obtain the pretrained weights we provide.
        """
        map_epochs = 140
        initial_lr = 0.01
        decayed_lr = 0.0001
        decay_start_epoch = 50
        decay_factor = decayed_lr / initial_lr

        # Create optimizer, loss, and a learning rate scheduler that aids convergence
        optimizer = torch.optim.SGD(
            self.network.parameters(),
            lr=initial_lr,
            momentum=0.9,
            nesterov=False,
            weight_decay=1e-4,
        )
        loss = torch.nn.CrossEntropyLoss(
            reduction="mean",
        )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            [
                torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0),
                torch.optim.lr_scheduler.LinearLR(
                    optimizer,
                    start_factor=1.0,
                    end_factor=decay_factor,
                    total_iters=(map_epochs - decay_start_epoch) * len(loader),
                ),
            ],
            milestones=[decay_start_epoch * len(loader)],
        )

        # Put network into training mode
        # Batch normalization layers are only updated if the network is in training mode,
        # and are replaced by a moving average if the network is in evaluation mode.
        self.network.train()
        with tqdm.trange(map_epochs, desc="Fitting initial MAP weights") as pbar:
            pbar_dict = {}
            # Perform the specified number of MAP epochs
            for epoch in pbar:
                average_loss = 0.0
                average_accuracy = 0.0
                num_samples_processed = 0
                # Iterate over batches of randomly shuffled training data
                for batch_xs, _, _, batch_ys in loader:
                    # Training step
                    optimizer.zero_grad()
                    pred_ys = self.network(batch_xs)
                    batch_loss = loss(input=pred_ys, target=batch_ys)
                    batch_loss.backward()
                    optimizer.step()

                    # Save learning rate that was used for step, and calculate new one
                    pbar_dict["lr"] = lr_scheduler.get_last_lr()[0]
                    with warnings.catch_warnings():
                        # Suppress annoying warning (that we cannot control) inside PyTorch
                        warnings.simplefilter("ignore")
                        lr_scheduler.step()

                    # Calculate cumulative average training loss and accuracy
                    average_loss = (batch_xs.size(0) * batch_loss.item() + num_samples_processed * average_loss) / (
                        num_samples_processed + batch_xs.size(0)
                    )
                    average_accuracy = (
                        torch.sum(pred_ys.argmax(dim=-1) == batch_ys).item()
                        + num_samples_processed * average_accuracy
                    ) / (num_samples_processed + batch_xs.size(0))
                    num_samples_processed += batch_xs.size(0)

                    pbar_dict["avg. epoch loss"] = average_loss
                    pbar_dict["avg. epoch accuracy"] = average_accuracy
                    pbar.set_postfix(pbar_dict)

    def predict_probabilities(self, xs: torch.Tensor) -> torch.Tensor:
        """
        Predict class probabilities for the given images xs.
        This method returns an NxC float tensor,
        where row i column j corresponds to the probability that y_i is class j.

        This method uses different strategies depending on self.inference_mode.
        """
        self.network = self.network.eval()

        # Create a loader that we can deterministically iterate many times if necessary
        loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(xs),
            batch_size=32,
            shuffle=False,
            num_workers=0,
            drop_last=False,
        )

        with torch.no_grad():  # save memory by not tracking gradients
            if self.inference_mode == InferenceMode.MAP:
                return self.predict_probabilities_map(loader)
            else:
                return self.predict_probabilities_swag(loader)

    def predict_probabilities_map(self, loader: torch.utils.data.DataLoader) -> torch.Tensor:
        """
        Predict probabilities assuming that self.network is a MAP estimate.
        This simply performs a forward pass for every batch in `loader`,
        concatenates all results, and applies a row-wise softmax.
        """
        predictions = []
        for (batch_xs,) in loader:
            predictions.append(self.network(batch_xs))

        predictions = torch.cat(predictions)
        return torch.softmax(predictions, dim=-1)

    def _update_batchnorm(self) -> None:
        """
        Reset and fit batch normalization statistics using the training dataset self.train_dataset.
        We provide this method for you for convenience.
        See the SWAG paper for why this is required.

        Batch normalization usually uses an exponential moving average, controlled by the `momentum` parameter.
        However, we are not training but want the statistics for the full training dataset.
        Hence, setting `momentum` to `None` tracks a cumulative average instead.
        The following code stores original `momentum` values, sets all to `None`,
        and restores the previous hyperparameters after updating batchnorm statistics.
        """

        old_momentum_parameters = dict()
        for module in self.network.modules():
            # Only need to handle batchnorm modules
            if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                continue

            # Store old momentum value before removing it
            old_momentum_parameters[module] = module.momentum
            module.momentum = None

            # Reset batch normalization statistics
            module.reset_running_stats()

        loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=0,
            drop_last=False,
        )

        self.network.train()
        for (batch_xs,) in loader:
            self.network(batch_xs)
        self.network.eval()

        # Restore old `momentum` hyperparameter values
        for module, momentum in old_momentum_parameters.items():
            module.momentum = momentum

In [82]:
class SWAGScheduler(torch.optim.lr_scheduler.LRScheduler):
    """
    Custom learning rate scheduler that calculates a different learning rate each gradient descent step.
    The default implementation keeps the original learning rate constant, i.e., does nothing.
    You can implement a custom schedule inside calculate_lr,
    and add+store additional attributes in __init__.
    You should not change any other parts of this class.
    """

    def calculate_lr(self, current_epoch: float, old_lr: float) -> float:
        """
        Calculate the learning rate for the epoch given by current_epoch.
        current_epoch is the fractional epoch of SWA fitting, starting at 0.
        That is, an integer value x indicates the start of epoch (x+1),
        and non-integer values x.y correspond to steps in between epochs (x+1) and (x+2).
        old_lr is the previous learning rate.

        This method should return a single float: the new learning rate.
        """
        # TODO(2): Implement a custom schedule if desired
        return old_lr

    # TODO(2): Add and store additional arguments if you decide to implement a custom scheduler
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        epochs: int,
        steps_per_epoch: int,
    ):
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        super().__init__(optimizer, last_epoch=-1, verbose=False)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
            )
        return [
            self.calculate_lr(self.last_epoch / self.steps_per_epoch, group["lr"])
            for group in self.optimizer.param_groups
        ]


def evaluate(
    swag: SWAGInference,
    eval_dataset: torch.utils.data.Dataset,
    extended_evaluation: bool,
    output_dir: pathlib.Path,
) -> None:
    """
    Evaluate your model.
    Feel free to change or extend this code.
    :param swag: Trained model to evaluate
    :param eval_dataset: Validation dataset
    :param: extended_evaluation: If True, generates additional plots
    :param output_dir: Directory into which extended evaluation plots are saved
    """

    print("Evaluating model on validation data")

    # We ignore is_snow and is_cloud here, but feel free to use them as well
    xs, is_snow, is_cloud, ys = eval_dataset.tensors

    # Predict class probabilities on test data,
    # most likely classes (according to the max predicted probability),
    # and classes as predicted by your SWAG implementation.
    pred_prob_all = swag.predict_probabilities(xs)
    pred_prob_max, pred_ys_argmax = torch.max(pred_prob_all, dim=-1)
    pred_ys = swag.predict_labels(pred_prob_all)

    # Create a mask that ignores ambiguous samples (those with class -1)
    nonambiguous_mask = ys != -1

    # Calculate three kinds of accuracy:
    # 1. Overall accuracy, counting "don't know" (-1) as its own class
    # 2. Accuracy on all samples that have a known label. Predicting -1 on those counts as wrong here.
    # 3. Accuracy on all samples that have a known label w.r.t. the class with the highest predicted probability.
    accuracy = torch.mean((pred_ys == ys).float()).item()
    accuracy_nonambiguous = torch.mean((pred_ys[nonambiguous_mask] == ys[nonambiguous_mask]).float()).item()
    accuracy_nonambiguous_argmax = torch.mean(
        (pred_ys_argmax[nonambiguous_mask] == ys[nonambiguous_mask]).float()
    ).item()
    print(f"Accuracy (raw): {accuracy:.4f}")
    print(f"Accuracy (non-ambiguous only, your predictions): {accuracy_nonambiguous:.4f}")
    print(f"Accuracy (non-ambiguous only, predicting most-likely class): {accuracy_nonambiguous_argmax:.4f}")

    # Determine which threshold would yield the smallest cost on the validation data
    # Note that this threshold does not necessarily generalize to the test set!
    # However, it can help you judge your method's calibration.
    thresholds = [0.0] + list(torch.unique(pred_prob_max, sorted=True))
    costs = []
    for threshold in thresholds:
        thresholded_ys = torch.where(pred_prob_max <= threshold, -1 * torch.ones_like(pred_ys), pred_ys)
        costs.append(cost_function(thresholded_ys, ys).item())
    best_idx = np.argmin(costs)
    print(f"Best cost {costs[best_idx]} at threshold {thresholds[best_idx]}")
    print("Note that this threshold does not necessarily generalize to the test set!")

    # Calculate ECE and plot the calibration curve
    calibration_data = calc_calibration_curve(pred_prob_all.numpy(), ys.numpy(), num_bins=20)
    print("Validation ECE:", calibration_data["ece"])

    if extended_evaluation:
        print("Plotting reliability diagram")
        fig = draw_reliability_diagram(calibration_data)
        fig.savefig(output_dir / "reliability_diagram.pdf")

        sorted_confidence_indices = torch.argsort(pred_prob_max)

        # Plot samples your model is most confident about
        print("Plotting most confident validation set predictions")
        most_confident_indices = sorted_confidence_indices[-10:]
        fig, ax = plt.subplots(4, 5, figsize=(13, 11))
        for row in range(0, 4, 2):
            for col in range(5):
                sample_idx = most_confident_indices[5 * row // 2 + col]
                ax[row, col].imshow(xs[sample_idx].permute(1, 2, 0).numpy())
                ax[row, col].set_axis_off()
                ax[row + 1, col].set_title(f"pred. {pred_ys[sample_idx]}, true {ys[sample_idx]}")
                bar_colors = ["C0"] * 6
                if ys[sample_idx] >= 0:
                    bar_colors[ys[sample_idx]] = "C1"
                ax[row + 1, col].bar(
                    np.arange(6), pred_prob_all[sample_idx].numpy(), tick_label=np.arange(6), color=bar_colors
                )
        fig.suptitle("Most confident predictions", size=20)
        fig.savefig(output_dir / "examples_most_confident.pdf")

        # Plot samples your model is least confident about
        print("Plotting least confident validation set predictions")
        least_confident_indices = sorted_confidence_indices[:10]
        fig, ax = plt.subplots(4, 5, figsize=(13, 11))
        for row in range(0, 4, 2):
            for col in range(5):
                sample_idx = least_confident_indices[5 * row // 2 + col]
                ax[row, col].imshow(xs[sample_idx].permute(1, 2, 0).numpy())
                ax[row, col].set_axis_off()
                ax[row + 1, col].set_title(f"pred. {pred_ys[sample_idx]}, true {ys[sample_idx]}")
                bar_colors = ["C0"] * 6
                if ys[sample_idx] >= 0:
                    bar_colors[ys[sample_idx]] = "C1"
                ax[row + 1, col].bar(
                    np.arange(6), pred_prob_all[sample_idx].numpy(), tick_label=np.arange(6), color=bar_colors
                )
        fig.suptitle("Least confident predictions", size=20)
        fig.savefig(output_dir / "examples_least_confident.pdf")


In [83]:
# Fix all randomness
setup_seeds()

# Build and run the actual solution
train_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

swag = SWAGInference(
    train_xs=dataset_train.tensors[0],
    model_dir=model_dir,
)
# swag.fit(train_loader)
swag.calibrate(dataset_val)

<torch.utils.data.dataset.TensorDataset at 0x7f9fc9c133d0>

In [87]:
with torch.random.fork_rng():
    evaluate(swag, dataset_val, EXTENDED_EVALUATION, output_dir)

Evaluating model on validation data


Performing Bayesian model averaging: 100%|█████| 30/30 [03:18<00:00,  6.62s/it]


AssertionError: 

In [93]:
swag.train_dataset

<torch.utils.data.dataset.TensorDataset at 0x7f9fba5b2400>

In [94]:
loader = torch.utils.data.DataLoader(
    swag.train_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    drop_last=False,
)

In [96]:
swag.bma_samples = 2

In [120]:
swag.network.eval()

# Perform Bayesian model averaging:
# Instead of sampling self.bma_samples networks (using self.sample_parameters())
# for each datapoint, you can save time by sampling self.bma_samples networks,
# and perform inference with each network on all samples in loader.
per_model_sample_predictions = []
for _ in tqdm.trange(swag.bma_samples, desc="Performing Bayesian model averaging"):
    # TODO(1): Sample new parameters for self.network from the SWAG approximate posterior
    swag.sample_parameters()
    # raise NotImplementedError("Sample network parameters")

    # TODO(1): Perform inference for all samples in `loader` using current model sample,
    #  and add the predictions to per_model_sample_predictions
    predictions = []
    for (batch_xs,) in loader:
        predictions.append(swag.network(batch_xs))
    
    predictions = torch.cat(predictions)
    
    per_model_sample_predictions.append(predictions)
    # raise NotImplementedError("Perform inference using current model")

assert len(per_model_sample_predictions) == swag.bma_samples
assert all(
    isinstance(model_sample_predictions, torch.Tensor)
    and model_sample_predictions.dim() == 2  # N x C
    and model_sample_predictions.size(1) == 6
    for model_sample_predictions in per_model_sample_predictions
)

per_model_sample_predictions = [p.unsqueeze(0) for p in per_model_sample_predictions]

# TODO(1): Average predictions from different model samples into bma_probabilities
# raise NotImplementedError("Aggregate predictions from model samples")
bma_logits = torch.cat(per_model_sample_predictions, 0)
bma_logits = torch.mean(bma_logits, 0)
bma_probabilities = torch.softmax(bma_logits, dim=-1)

assert bma_probabilities.dim() == 2 and bma_probabilities.size(1) == 6  # N x C
# return bma_probabilities

Performing Bayesian model averaging: 100%|███████| 2/2 [00:29<00:00, 14.80s/it]


In [117]:
len(per_model_sample_predictions)

2

In [119]:
per_model_sample_predictions[0].shape

torch.Size([1, 1800, 6])

In [115]:
assert all(
    isinstance(model_sample_predictions, torch.Tensor)
    and model_sample_predictions.dim() == 2  # N x C
    and model_sample_predictions.size(1) == 6
    for model_sample_predictions in per_model_sample_predictions
)

AssertionError: 

In [116]:
model_sample_predictions

NameError: name 'model_sample_predictions' is not defined

In [100]:
per_model_sample_predictions[0].shape

torch.Size([1, 1800, 6])

In [103]:
bma_logits = torch.cat(per_model_sample_predictions, 0)

In [104]:
bma_logits = torch.mean(bma_logits, 0)

In [105]:
bma_logits.shape

torch.Size([1800, 6])

In [121]:
torch.softmax(bma_logits, dim=-1).shape

torch.Size([1800, 6])

In [122]:
torch.softmax(bma_logits, dim=-1)

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward0>)

In [110]:
bma_probabilities.shape

NameError: name 'bma_probabilities' is not defined

In [67]:
current_mean

tensor([0., 0., 0., 0., 0., 0.])

In [68]:
current_std

tensor([0., 0., 0., 0., 0., 0.])

In [71]:
# TODO(1): Don't forget to update batch normalization statistics using self._update_batchnorm()
#  in the appropriate place!
net = swag.network.state_dict()

In [73]:
net.keys()

odict_keys(['layer0.0.weight', 'layer0.0.bias', 'layer0.1.weight', 'layer0.1.bias', 'layer0.1.running_mean', 'layer0.1.running_var', 'layer0.1.num_batches_tracked', 'layer1.0.weight', 'layer1.0.bias', 'layer1.1.weight', 'layer1.1.bias', 'layer1.1.running_mean', 'layer1.1.running_var', 'layer1.1.num_batches_tracked', 'layer2.0.weight', 'layer2.0.bias', 'layer2.1.weight', 'layer2.1.bias', 'layer2.1.running_mean', 'layer2.1.running_var', 'layer2.1.num_batches_tracked', 'layer3.0.weight', 'layer3.0.bias', 'layer3.1.weight', 'layer3.1.bias', 'layer3.1.running_mean', 'layer3.1.running_var', 'layer3.1.num_batches_tracked', 'layer4.0.weight', 'layer4.0.bias', 'layer4.1.weight', 'layer4.1.bias', 'layer4.1.running_mean', 'layer4.1.running_var', 'layer4.1.num_batches_tracked', 'layer5.0.weight', 'layer5.0.bias', 'linear.weight', 'linear.bias'])

In [74]:
sampled_param

tensor([0., 0., 0., 0., 0., 0.])

In [76]:
sampled_param.shape

torch.Size([6])

In [77]:
net

OrderedDict([('layer0.0.weight',
              tensor([[[[0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.]],
              
                       [[0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.]],
              
                       [[0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.]]],
              
              
                      [[[0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 

In [75]:
net.update(sampled_param)

TypeError: iteration over a 0-d tensor

In [None]:


swag.network.load_state_dict(net)
swag._update_batchnorm()

In [123]:
3/4

0.75

In [124]:
4/2

2.0

In [128]:
with tqdm.trange(swag.swag_epochs, desc="Running gradient descent for SWA") as pbar:
    pbar_dict = {}
    for epoch in pbar:
        break

Running gradient descent for SWA:   0%|                 | 0/30 [00:00<?, ?it/s]


In [129]:
epoch

0

In [130]:
pbar

<tqdm.std.tqdm at 0x7f9e4398baf0>