In [None]:
# Remove any previous clone so the `git clone` in the setup cell starts from a clean directory.
!rm -rf sngp_wrapper

This is a custom implementation of spectral normalization and a Gaussian process output layer based on the Laplace approximation.




In [None]:
!git clone https://github.com/iamownt/sngp_wrapper.git
!mv sngp_wrapper/sngp_wrapper/* sngp_wrapper/
!rm -r sngp_wrapper/sngp_wrapper
! pip install timm

Cloning into 'sngp_wrapper'...
remote: Enumerating objects: 27, done.[K
remote: Counting objects: 100% (27/27), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 27 (delta 5), reused 24 (delta 5), pack-reused 0 (from 0)[K
Receiving objects: 100% (27/27), 18.56 KiB | 633.00 KiB/s, done.
Resolving deltas: 100% (5/5), done.


In [None]:
import timm
import torch.nn as nn
import torch
import torchvision
from sngp_wrapper.covert_utils import convert_to_sn_my, replace_layer_with_gaussian


class ConvNextTinyGP(nn.Module):
    """ImageNet-pretrained ConvNeXt-Tiny backbone with a fresh linear head.

    The torchvision ConvNeXt-Tiny is used as a feature extractor. Only the
    final ``nn.Linear`` of its head is removed, so the pretrained final
    ``LayerNorm2d`` and ``Flatten`` in front of it keep working. A new
    ``nn.Linear`` named ``classifier`` is attached. This is the layer that
    ``replace_layer_with_gaussian(..., signature="classifier")`` later swaps
    for the GP head.

    Args:
        num_classes: number of outputs of the new classification head.
    """

    def __init__(self, num_classes: int):
        super(ConvNextTinyGP, self).__init__()
        # Downloads the pretrained weights on first use.
        feature_extractor = torchvision.models.convnext_tiny(weights="ConvNeXt_Tiny_Weights.IMAGENET1K_V1")
        # torchvision's ConvNeXt head is Sequential(LayerNorm2d, Flatten, Linear).
        # Replacing the whole head would discard the pretrained final norm, so
        # only the Linear is removed. Its input width is read from the layer
        # instead of hard-coding 768.
        in_features = feature_extractor.classifier[-1].in_features
        feature_extractor.classifier[-1] = nn.Identity()
        self.feature_extractor = feature_extractor
        # The backbone already flattens; kept as a harmless safeguard so the
        # head always receives (batch, features).
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x, kwargs=None):
        """Run the backbone, then the head.

        Args:
            x: image batch; assumed (batch, 3, H, W), as used in this notebook.
            kwargs: optional dict unpacked as keyword arguments into
                ``self.classifier``, e.g. the GP flags ``return_covariance`` or
                ``update_precision_matrix``. ``None`` means no extra arguments,
                which is what a plain ``nn.Linear`` head requires.

        Returns:
            Whatever ``self.classifier`` returns. For a plain linear head this
            is a (batch, num_classes) tensor. For the GP head with
            ``return_covariance=True``, the notebook unpacks it as a pair.
        """
        if kwargs is None:
            kwargs = {}
        features = self.flatten(self.feature_extractor(x))
        return self.classifier(features, **kwargs)

def _count_parameters(module):
    """Total number of scalar elements over all parameters of ``module``."""
    return sum(param.numel() for param in module.parameters())


model = ConvNextTinyGP(num_classes=1000)
print("parameters before conversion", _count_parameters(model))
# NOTE(review): PyTorch's class is named `Conv2d`, and the printed model further
# down shows the Conv2d layers left unparametrized. The "Conv2D" entry therefore
# appears to match nothing. Confirm how convert_to_sn_my compares layer names.
sigma_reparam_model = convert_to_sn_my(
    model,
    spec_norm_replace_list=["Linear", "Conv2D"],
    spec_norm_bound=2.0,
)
# The recorded output shows identical counts before and after conversion.
print("parameters after conversion", _count_parameters(sigma_reparam_model))

parameters before conversion 28587592
parameters after conversion 28587592


In [None]:
# Hyper-parameters for the random-feature GP head that replaces `classifier`.
GP_KWARGS = dict(
    num_inducing=2048,
    gp_scale=1.0,
    gp_bias=0.0,
    gp_kernel_type='gaussian',  # alternative: 'linear'
    gp_input_normalization=True,
    # NOTE(review): in the TF reference layer a negative covariance momentum
    # selects exact (non-moving-average) covariance. Presumably -1 means the
    # same here; confirm in sngp_wrapper.
    gp_cov_discount_factor=-1,
    gp_cov_ridge_penalty=1.0,
    gp_output_bias_trainable=False,
    gp_scale_random_features=False,
    gp_use_custom_random_features=True,
    gp_random_feature_type='orf',
    gp_output_imagenet_initializer=True,
    num_classes=1000,  # keep in sync with ConvNextTinyGP(num_classes=1000)
)
replace_layer_with_gaussian(
    container=sigma_reparam_model,
    signature="classifier",
    **GP_KWARGS,
)

Model is equipped with gaussian process (laplace approximation)

In [None]:
# The conversion helpers changed `model` in place: its Linear layers now print
# as ParametrizedLinear with a _SpectralNorm parametrization.
print(str(model))

ConvNextTinyGP(
  (feature_extractor): ConvNeXt(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
      )
      (1): Sequential(
        (0): CNBlock(
          (block): Sequential(
            (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
            (1): Permute()
            (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
            (3): ParametrizedLinear(
              in_features=96, out_features=384, bias=True
              (parametrizations): ModuleDict(
                (weight): ParametrizationList(
                  (0): _SpectralNorm()
                )
              )
            )
            (4): GELU(approximate='none')
            (5): ParametrizedLinear(
              in_features=384, out_features=96, bias=True
              (parametrizations): ModuleDict(
                (weight): P

In [None]:
# Plain forward pass with every GP flag off: the head returns a single tensor.
kwargs = dict(
    return_random_features=False,
    return_covariance=False,
    update_precision_matrix=False,
    update_covariance_matrix=False,
)
dummy_batch = torch.randn(10, 3, 224, 224)
output = sigma_reparam_model(dummy_batch, kwargs)
print(output)

tensor([[-0.1015, -0.0750,  0.0922,  ..., -0.0538, -0.2242,  0.3168],
        [-0.1016, -0.0750,  0.0922,  ..., -0.0538, -0.2243,  0.3166],
        [-0.1046, -0.0723,  0.0917,  ..., -0.0573, -0.2240,  0.3168],
        ...,
        [-0.1093, -0.0801,  0.1043,  ..., -0.0615, -0.2098,  0.3107],
        [-0.0904, -0.0695,  0.0800,  ..., -0.0533, -0.2270,  0.3220],
        [-0.1017, -0.0745,  0.0919,  ..., -0.0537, -0.2243,  0.3167]],
       grad_fn=<AddBackward0>)


**Simple Example**

In [None]:
# Fix the seed so the random batches below are reproducible.
torch.manual_seed(0)

ind_data = torch.randn(10, 3, 224, 224)
# Shifting the mean by +1 makes this batch distributionally different ("OOD").
ood_data = torch.randn(10, 3, 224, 224) + 1

NUM_PRECISION_PASSES = 10

# Start from a clean GP state so that re-running this cell does not keep
# accumulating precision from earlier runs (see usage note 4 at the end).
sigma_reparam_model.classifier.reset_covariance_matrix()

# NOTE(review): no .eval() call is visible, so the model is still in training
# mode (the nn.Module default). Any stochastic layers in the backbone stay
# active. Confirm which mode the GP wrapper expects.
with torch.no_grad():  # no gradients are needed for this demo; saves memory
    for _ in range(NUM_PRECISION_PASSES):
        sigma_reparam_model(ind_data, {"update_precision_matrix": True})  # we remember the in-domain data
    sigma_reparam_model.classifier.update_covariance_matrix()

    eval_flags = {"update_precision_matrix": False, "return_covariance": True}
    ind_output = sigma_reparam_model(ind_data, eval_flags)
    ood_output = sigma_reparam_model(ood_data, eval_flags)

# The first element is presumably logits rather than probabilities: the plain
# forward pass above printed negative values. Names are kept for compatibility.
ind_prob, ind_cov = ind_output
ood_prob, ood_cov = ood_output


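We compare the mean uncertainty between the in-domain and OOD batches, taking the diagonal of each returned covariance. In the recorded output below the two means are close (0.0221 vs. 0.0229). Treat this toy example as a smoke test of the API rather than as evidence of OOD separation.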

In [None]:
# The diagonal of each batch covariance serves as the per-input uncertainty
# score (one value per sample in the batch).
ind_uncertainty, ood_uncertainty = (cov.diagonal(0) for cov in (ind_cov, ood_cov))
print("ind_uncertainty", ind_uncertainty, "ind mean", ind_uncertainty.mean())
print("ood_uncertainty", ood_uncertainty, "ood mean", ood_uncertainty.mean())

ind_uncertainty tensor([0.0072, 0.0085, 0.0122, 0.0073, 0.1083, 0.0086, 0.0073, 0.0476, 0.0072,
        0.0071], grad_fn=<DiagonalBackward0>) ind mean tensor(0.0221, grad_fn=<MeanBackward0>)
ood_uncertainty tensor([0.0148, 0.0148, 0.0147, 0.0148, 0.0151, 0.0147, 0.0506, 0.0147, 0.0525,
        0.0224], grad_fn=<DiagonalBackward0>) ood mean tensor(0.0229, grad_fn=<MeanBackward0>)


**Important Notes for users:**

1. The RFF-GP layer is based on `tfm.nlp.layers.RandomFeatureGaussianProcess` (TensorFlow Model Garden). I have tested it on several foundation models of 0.5B and 1B parameters, as well as an Inception architecture, and it works well.
2. In the forward pass, set `update_precision_matrix` to control whether the precision matrix is updated (default: `True`).
3. Call `model.classifier.update_covariance_matrix()` **once** before evaluating the model with uncertainty quantification.
4. Call `model.classifier.reset_covariance_matrix()` **at the beginning of each epoch**.
5. To evaluate OOD detection, set `return_covariance=True`. The forward pass then also returns the covariance matrix over the input batch, and its diagonal gives one uncertainty value per input.
6. When tuning hyperparameters, note that a higher spectral-norm bound (`spec_norm_bound`) is looser and may have little effect on the model. You can also safely switch the Gaussian kernel to the linear kernel (`gp_kernel_type='linear'`), in which case the model reverts to the original model.
7. The spectral-norm modules could be folded into the weights at inference time to speed up inference, but this is not implemented in this repository (it may be available in others).