# Incremental Unified Framework (IUF) Testing

We want to test each component and module in isolation to make sure it's working properly

_TODO:_
1. Go through ViT code and understand unique implementation
2. Update to include modules for creating Discriminator, Encoder, & Decoder
3. Check for paper/code for architecture or hyperparams match our implementation (num_heads, dim, etc)
4. Latent Space Regularization
5. Gradient Update Regularization


In [1]:
import einops as ein
import torch
import torch.nn as nn
import torch.nn.functional as F
from Methods import BaseAnomalyDetector
from Methods.IUF.iuf import IUF_Model, IUF_Loss
from Methods.IUF.ViT import MultiHeadSelfAttention, ViTBlock, ViT
from Methods.IUF.utils.discriminator import Discriminator
from Methods.IUF.utils.encoder import Encoder
from Methods.IUF.utils.decoder import Decoder

In [136]:
# Starting to put together an IUF pipeline for testing
# TODO:
# - Create train_one_epoch() and eval_one_epoch(), based on BaseAnomalyDetector
# - Set up training and eval experiments
#    - Check whether the recon_loss works better with .sum() or .mean()
# - Change results to testing all tasks after final task training



iuf = IUF()
dummy = torch.rand(8, 3, 224, 224)
x_recon, discrim_out, s_vals = iuf(dummy)

loss = IUF_Loss(x=dummy,
                x_recon=x_recon,
                singular_vals=s_vals,
                t=3,
                discrim_output=discrim_out,
                task_idx=torch.ones(8, dtype=torch.long))
iuf.update_grad(loss)

Using cuda device
Reconstruction error:  tensor(624796.8125, grad_fn=<SumBackward0>)
Discriminator error:  tensor(21.9315, grad_fn=<NllLossBackward0>)
Singular Value error:  tensor(1.4867, grad_fn=<SumBackward0>)


In [133]:
iuf.get_parameter('decoder.conv1.weight').grad

In [32]:
loss.backward()

In [46]:
for p in iuf.named_parameters():
    print(p[0], p[1].shape, sep="\n")
    print()

discriminator.pos_embedding
torch.Size([1, 196, 64])

discriminator.conv1.weight
torch.Size([64, 3, 16, 16])

discriminator.conv1.bias
torch.Size([64])

discriminator.ln1.weight
torch.Size([64])

discriminator.ln1.bias
torch.Size([64])

discriminator.vit_blocks.0.norm1.weight
torch.Size([64])

discriminator.vit_blocks.0.norm1.bias
torch.Size([64])

discriminator.vit_blocks.0.MHSA.query.weight
torch.Size([64, 64])

discriminator.vit_blocks.0.MHSA.query.bias
torch.Size([64])

discriminator.vit_blocks.0.MHSA.key.weight
torch.Size([64, 64])

discriminator.vit_blocks.0.MHSA.key.bias
torch.Size([64])

discriminator.vit_blocks.0.MHSA.value.weight
torch.Size([64, 64])

discriminator.vit_blocks.0.MHSA.value.bias
torch.Size([64])

discriminator.vit_blocks.0.MHSA.out_proj.weight
torch.Size([64, 64])

discriminator.vit_blocks.0.MHSA.out_proj.bias
torch.Size([64])

discriminator.vit_blocks.0.norm2.weight
torch.Size([64])

discriminator.vit_blocks.0.norm2.bias
torch.Size([64])

discriminator.vit_blo

In [114]:

print(p.grad.shape, omega.shape, v.shape, t3.shape, sep="\n")

torch.Size([64, 64])
torch.Size([64])
torch.Size([64, 64])
torch.Size([64, 64])


In [115]:
t2.shape

torch.Size([64, 64])

In [117]:
v.inverse() @ t2

tensor([[ 0.0000e+00,  1.4461e-01, -4.9813e+00,  ...,  1.9231e+00,
          2.8157e+01,  9.8665e-01],
        [ 0.0000e+00, -2.4274e-01,  1.1241e+01,  ..., -3.7909e+00,
         -7.2293e+01, -6.6403e-01],
        [ 0.0000e+00,  1.6958e-01, -8.1958e+00,  ...,  2.6518e+00,
          5.3161e+01,  2.9080e-01],
        ...,
        [ 0.0000e+00, -6.8746e-02,  4.1377e-01,  ...,  3.0917e-01,
          1.5057e+00, -1.6652e-01],
        [ 0.0000e+00, -1.7945e-01,  6.3121e-01,  ..., -2.6570e-01,
          3.1411e+00, -3.0540e+00],
        [ 0.0000e+00,  8.1898e-02,  3.7128e-04,  ...,  3.7983e-01,
         -2.4713e+00,  2.1231e+00]], grad_fn=<MmBackward0>)

In [90]:
dict(iuf.named_parameters())

215

In [91]:
iuf.named_parameters().data

AttributeError: 'generator' object has no attribute 'data'

In [134]:
x_recon

tensor([[[[ 8.6983e-02, -1.0990e-01, -3.8556e-02,  ...,  2.4592e-02,
           -1.4347e-01, -1.9034e-01],
          [ 4.8430e-02,  4.0400e-03, -1.0452e-02,  ..., -1.1076e-01,
            6.4735e-02, -6.2962e-02],
          [ 6.8348e-02,  2.5883e-01,  1.2144e-02,  ...,  1.5799e-01,
           -1.1598e-01, -1.6044e-01],
          ...,
          [-3.7563e-01, -8.6864e-02,  1.4916e-01,  ...,  6.1944e-02,
            1.4293e-01,  2.1511e-01],
          [ 6.8932e-02,  4.8663e-02,  3.7446e-01,  ..., -4.6764e-01,
           -4.1933e-02,  3.2577e-02],
          [ 7.7945e-02, -4.3533e-02,  2.0552e-01,  ..., -1.4802e-01,
           -3.5046e-02, -1.3873e-01]],

         [[ 2.8014e-02, -9.2312e-02,  2.0312e-01,  ...,  4.9354e-02,
            1.5624e-02,  2.7589e-02],
          [-8.9889e-02, -4.1114e-02,  9.0667e-02,  ...,  1.3049e-02,
            7.8257e-02,  1.2341e-01],
          [ 1.4991e-01,  4.8009e-02, -3.8246e-02,  ..., -3.7881e-01,
           -1.5205e-02, -1.5243e-01],
          ...,
     