# Incremental Unified Framework (IUF) Testing

We want to test each component and module in isolation to make sure it works properly.

_TODO:_
1. Go through the ViT code and understand what is unique about its implementation
2. Add modules for creating the Discriminator, Encoder, and Decoder
3. Check the paper/official code to confirm that our architecture and hyperparameters (num_heads, dim, etc.) match it
4. Latent Space Regularization
5. Gradient Update Regularization


In [1]:
import einops as ein
import torch
import torch.nn as nn
import torch.nn.functional as F
from Methods.IUF.ViT import MultiHeadSelfAttention, ViTBlock, ViT
from Methods.IUF.utils.discriminator import Discriminator
from Methods.IUF.utils.encoder import Encoder

In [7]:
# Starting to put together an IUF pipeline for testing
class IUF(nn.Module):
    """Work-in-progress Incremental Unified Framework pipeline used for testing.

    Feeds the input through a ``Discriminator`` in feature-extraction mode,
    then passes both the raw input and the discriminator's per-layer
    features to an ``Encoder``.
    """

    def __init__(self):
        super().__init__()
        # Both sub-modules are built with their default constructor arguments.
        self.discriminator = Discriminator()
        self.encoder = Encoder()

    def forward(self, x, *, return_latent=False):
        """Run the discriminator and the encoder on ``x``.

        Args:
            x: Input batch handed to both the discriminator and the encoder
                (this notebook feeds a ``(8, 3, 224, 224)`` random tensor).
            return_latent: When True, also return the encoder output.
                Defaults to False so existing callers still receive only the
                discriminator features.

        Returns:
            The discriminator features (per the comment below, a list of
            length num_layers whose items are (B x L x E) tensors), or a
            ``(features, latent_features)`` tuple if ``return_latent`` is True.
        """
        # List of length num_layers,
        # where each item is a tensor of size (B x L x E)
        oasa_features = self.discriminator(x, return_features=True)

        # The encoder always runs so this test pipeline keeps exercising it,
        # even when the caller does not ask for its output.
        latent_features = self.encoder(x, oasa_features)

        if return_latent:
            return oasa_features, latent_features
        return oasa_features

# Smoke test: build the pipeline and push one random batch through it.
# NOTE(review): keep construction before sampling so any global-RNG use by
# IUF() happens in the same order as before.
iuf = IUF()
dummy = torch.rand((8, 3, 224, 224))  # presumably (B, C, H, W) images
out = iuf(dummy)

In [8]:
# Shape of the first entry of the returned feature list
# (the forward() comment describes each entry as B x L x E).
out[0].size()

torch.Size([8, 196, 64])

In [4]:
# Total of all values in the first batch item of the first feature tensor.
torch.sum(out[0][0])

tensor(1., grad_fn=<SumBackward0>)

In [4]:
# Scratch check: `q * q` on tensors should square element-wise
# (not perform a matrix product).
q = torch.rand((8, 3, 224, 224))
k = torch.mul(q, q)
k

tensor([[[[2.4207e-03, 9.8799e-01, 6.8344e-02,  ..., 2.9395e-01,
           9.1263e-01, 1.0908e-01],
          [2.1685e-02, 1.7980e-01, 1.0857e-01,  ..., 2.0768e-01,
           7.3983e-02, 4.8949e-01],
          [5.1859e-02, 3.1769e-02, 1.7874e-01,  ..., 7.9489e-02,
           7.8932e-01, 7.5410e-02],
          ...,
          [7.2196e-01, 9.7465e-01, 6.6747e-01,  ..., 5.2817e-01,
           2.6647e-01, 1.6131e-01],
          [5.0014e-01, 7.9827e-01, 2.0470e-01,  ..., 2.6861e-01,
           3.9106e-01, 1.8346e-01],
          [4.1077e-02, 8.2152e-01, 2.8943e-01,  ..., 2.0944e-02,
           9.7998e-01, 8.5060e-01]],

         [[2.6021e-01, 5.2151e-01, 1.7537e-02,  ..., 7.3854e-02,
           4.8579e-02, 1.0855e-02],
          [9.3027e-01, 5.6885e-05, 9.8884e-01,  ..., 1.7535e-01,
           7.4032e-02, 7.9196e-01],
          [9.8599e-01, 1.4638e-02, 2.6416e-02,  ..., 6.2555e-02,
           4.8376e-02, 7.1530e-01],
          ...,
          [7.4061e-03, 5.4547e-01, 2.0823e-01,  ..., 2.3168

In [5]:
# Display q so its entries can be compared by hand with k = q * q from the
# earlier cell (e.g. k[0, 0, 0, 0] should equal q[0, 0, 0, 0] ** 2).
q

tensor([[[[4.9201e-02, 9.9398e-01, 2.6143e-01,  ..., 5.4217e-01,
           9.5532e-01, 3.3028e-01],
          [1.4726e-01, 4.2403e-01, 3.2949e-01,  ..., 4.5572e-01,
           2.7200e-01, 6.9963e-01],
          [2.2772e-01, 1.7824e-01, 4.2277e-01,  ..., 2.8194e-01,
           8.8843e-01, 2.7461e-01],
          ...,
          [8.4968e-01, 9.8724e-01, 8.1699e-01,  ..., 7.2675e-01,
           5.1621e-01, 4.0163e-01],
          [7.0720e-01, 8.9346e-01, 4.5243e-01,  ..., 5.1827e-01,
           6.2535e-01, 4.2832e-01],
          [2.0267e-01, 9.0638e-01, 5.3799e-01,  ..., 1.4472e-01,
           9.8994e-01, 9.2228e-01]],

         [[5.1011e-01, 7.2216e-01, 1.3243e-01,  ..., 2.7176e-01,
           2.2041e-01, 1.0419e-01],
          [9.6450e-01, 7.5422e-03, 9.9440e-01,  ..., 4.1875e-01,
           2.7209e-01, 8.8992e-01],
          [9.9297e-01, 1.2099e-01, 1.6253e-01,  ..., 2.5011e-01,
           2.1994e-01, 8.4575e-01],
          ...,
          [8.6059e-02, 7.3856e-01, 4.5633e-01,  ..., 4.8134

In [6]:
# Square the first entry of q, read directly from the tensor rather than
# hand-copied from the rounded printout: q comes from an unseeded torch.rand,
# so a pasted literal is inexact and goes stale on every re-run.
# Compare the result with k[0, 0, 0, 0].
q[0, 0, 0, 0].item() ** 2

0.002420738401