In [1]:
import torch
from torch import nn
import numpy as np

from typing import Dict

from torch.nn.utils.spectral_norm import SpectralNorm

from helper import assert_shape
from shared import FullyConnectedLayers, ResidualBlock
from vit_utils import make_vit_backbone, forward_vit

import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import RandomCrop, Normalize

In [2]:
class SpectralConv1d(nn.Conv1d):
    """`nn.Conv1d` whose `weight` is wrapped with spectral normalization.

    Takes exactly the same positional and keyword arguments as `nn.Conv1d`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register the spectral-norm reparameterization on `weight`.
        # dim=0 is the output-channel axis of a Conv1d weight.
        SpectralNorm.apply(
            self,
            name="weight",
            n_power_iterations=1,
            dim=0,
            eps=1e-12,
        )

# Smoke test: kernel_size=7 with stride=1 and padding=3 should keep the length axis.
x = torch.randn(5, 64, 100)  # [batch, channels, length]

conv1d_modified = SpectralConv1d(
    in_channels=64,
    out_channels=64,
    kernel_size=7,
    stride=1,
    padding=3,
)
modified_out = conv1d_modified(x)
modified_out.shape

torch.Size([5, 64, 100])

In [3]:
class LocalBatchNorm(nn.Module):
    """Batch normalization whose statistics come from small "virtual" batches.

    With large batch sizes the variance across the batch can be very high,
    especially early in training, which may destabilize the discriminator.
    The batch is therefore split into groups of at most `virtual_bs` samples
    and every group is standardized on its own.

    Args:
        num_features: Size of axis 1 of the input; one scale/shift per feature.
        affine: If True, learn a per-feature scale (`weights`) and shift (`bias`).
        virtual_bs: Upper bound on the number of samples per group.
        eps: Added to the standard deviation to avoid division by zero.

    Raises:
        ValueError: If `virtual_bs` is smaller than 1.
    """

    def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 8, eps: float = 1e-8):
        super().__init__()

        if virtual_bs < 1:
            raise ValueError(f"virtual_bs must be >= 1, got {virtual_bs}")

        self.num_features = num_features
        self.affine = affine             # learn weight and biases?
        self.virtual_bs = virtual_bs
        self.eps = eps

        if self.affine:
            self.weights = nn.Parameter(torch.ones(num_features))
            self.bias    = nn.Parameter(torch.zeros(num_features))

    def _standardize(self, grouped: torch.Tensor) -> torch.Tensor:
        """Standardize a [G, b, N, C] tensor per group (axis 0) and per feature (axis 2)."""
        mean = grouped.mean([1, 3], keepdim=True)   # mean: [G, 1, N, 1]
        var = grouped.var([1, 3], keepdim=True)     # var : [G, 1, N, 1]
        # eps is added to the standard deviation (not to the variance).
        return (grouped - mean) / (torch.sqrt(var) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize a 3-D input `x` of shape [B, N, C] and return the same shape.

        The batch is split into G = ceil(B / virtual_bs) groups. When B is a
        multiple of G the groups are equal-sized and handled in one vectorized
        pass; otherwise the batch is split into G near-equal groups that are
        standardized one by one (a plain `view` cannot express unequal groups).
        """
        shape = x.shape
        batch = shape[0]
        if batch == 0:
            # Nothing to normalize; a (0, -1, ...) reshape would be ambiguous.
            return x

        G = -(-batch // self.virtual_bs)            # ceil(B / virtual_bs), as a plain int

        if batch % G == 0:
            grouped = x.reshape(G, -1, shape[1], shape[2])      # [G, B/G, N, C]
            x = self._standardize(grouped).reshape(shape)       # [B, N, C]
        else:
            # Group sizes differ by at most one sample.
            parts = torch.tensor_split(x, G, dim=0)
            x = torch.cat([self._standardize(part[None])[0] for part in parts], dim=0)

        if self.affine:
            # weights / bias: [N] -> [1, N, 1], broadcast against x: [B, N, C]
            x = x * self.weights[None, :, None] + self.bias[None, :, None]

        return x

# 18 samples with the default virtual_bs=8 -> ceil(18 / 8) = 3 groups of 6 samples.
x = torch.randn(18, 64, 100)  # [batch, features, length]

lbn = LocalBatchNorm(64, affine=True)
out = lbn(x)
out.shape

torch.Size([18, 64, 100])

In [4]:
# Step-by-step replay of the LocalBatchNorm forward pass on plain tensors.
x = torch.randn(18, 64, 100)                 # x: [B, N, C]
shape = x.shape

num_features = 64
virtual_bs = 8
eps = 1e-8
affine = True

weight = nn.Parameter(torch.ones(num_features))
bias = nn.Parameter(torch.zeros(num_features))

# Number of virtual batches: G = ceil(18 / 8) = 3, i.e. 6 samples per group.
G = np.ceil(shape[0] / virtual_bs).astype(int)
x = x.view(G, -1, shape[1], shape[2])        # x: [G, B/G, N, C]

# Statistics per group and per feature, reduced over (samples, length).
mean = x.mean([1, 3], keepdim=True)          # mean: [G, 1, N, 1]
var = x.var([1, 3], keepdim=True)            # var : [G, 1, N, 1]
x = (x - mean) / (torch.sqrt(var) + eps)     # x   : [G, B/G, N, C]

if affine:
    # weight / bias: [N] -> [1, N, 1], broadcast against x: [G, B/G, N, C]
    x = x * weight[None, :, None] + bias[None, :, None]

x = x.view(shape)                            # back to [B, N, C]
x.shape

torch.Size([18, 64, 100])

In [5]:
def make_block(channels: int, kernel_size: int):
    """Build a conv -> norm -> activation stage that keeps the channel count.

    Args:
        channels: Number of input and output channels of the convolution.
        kernel_size: Convolution kernel size; the padding is `kernel_size // 2`
            in circular mode.

    Returns:
        An `nn.Sequential` of SpectralConv1d, LocalBatchNorm and LeakyReLU.
    """
    conv = SpectralConv1d(
        in_channels=channels,
        out_channels=channels,
        kernel_size=kernel_size,
        padding=kernel_size // 2,
        padding_mode="circular",
    )
    bn = LocalBatchNorm(num_features=channels)
    act = nn.LeakyReLU(0.2, True)  # negative_slope=0.2, inplace=True
    return nn.Sequential(conv, bn, act)
# Smoke test: a block with an odd kernel should return the input shape unchanged.
x = torch.rand(5, 64, 100)  # [batch, channels, length]
block = make_block(channels=64, kernel_size=7)
out = block(x)
out.shape

torch.Size([5, 64, 100])

In [6]:
# Smoke test for the imported FullyConnectedLayers helper: a [4, 3] input with
# arguments (3, 10) is expected to come out as [4, 10].
x = torch.rand(4, 3)
f = FullyConnectedLayers(3, 10)
f(x).shape

torch.Size([4, 10])

In [7]:
# Scratch version of the conditional discriminator head, step by step.
channels = 384  # DINO ViT-S output
c_dim = 512     # Text embedding from CLIP
cmap_dim = 64   # Projection space for conditional score

x = torch.rand(5, channels, 64)     # x: [B, channels, dim]
c = torch.rand(5, c_dim)            # c: [B, c_dim]

# Both blocks pad by kernel_size // 2, so with odd kernels the shape of x is preserved.
main = nn.Sequential(
    make_block(channels=channels, kernel_size=1),
    ResidualBlock(make_block(channels=channels, kernel_size=9)),
)

cmapper = FullyConnectedLayers(in_features=c_dim, out_features=cmap_dim)
cls = SpectralConv1d(in_channels=channels, out_channels=cmap_dim, kernel_size=1, padding=0)

h = main(x)                         # h  : [B, channels, dim]
out = cls(h)                        # out: [B, cmap_dim, dim]

cmap = cmapper(c).unsqueeze(-1)     # cmap: [B, cmap_dim, 1]

# Project onto the condition embedding: multiply, sum over cmap_dim, rescale.
out = (out * cmap).sum(1, keepdim=True) * np.sqrt(1 / cmap_dim)   # out: [B, 1, dim]
out.shape

torch.Size([5, 1, 64])

In [8]:
class DiscHead(nn.Module):
    """Discriminator head applied to a [B, channels, dim] feature sequence.

    When `c_dim > 0` the head is conditional: features are mapped to
    `cmap_dim` channels and projected onto an embedding of the condition `c`.
    Otherwise a single-channel score is produced directly.

    Args:
        channels: Channel count of the input features (e.g. DINO ViT-S output).
        c_dim: Size of the conditioning vector (e.g. CLIP text embedding);
            0 disables conditioning.
        cmap_dim: Size of the projection space for the conditional score.
    """

    def __init__(self, channels: int, c_dim: int, cmap_dim: int):
        super().__init__()

        self.channels = channels # DINO ViT-S output
        self.c_dim = c_dim       # Text embedding from CLIP
        self.cmap_dim = cmap_dim # Projection space for conditional score

        # Both blocks pad by kernel_size // 2, so with odd kernels the
        # sequence length of x is preserved.
        self.main = nn.Sequential(
            make_block(channels = channels, kernel_size = 1),
            ResidualBlock(make_block(channels = channels, kernel_size = 9))
        )

        if self.c_dim > 0:
            self.cmapper = FullyConnectedLayers(in_features = c_dim, out_features = cmap_dim)
            self.cls = SpectralConv1d(in_channels = channels, out_channels = cmap_dim, kernel_size = 1, padding = 0)
        else:
            self.cls = SpectralConv1d(in_channels = channels, out_channels = 1, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Score the features `x`, optionally conditioned on `c`.

        Args:
            x: Features of shape [B, channels, dim].
            c: Conditioning of shape [B, c_dim]; only read when `c_dim > 0`.

        Returns:
            Tensor of shape [B, 1, dim].
        """
        # Use this module's own layers (not same-named notebook globals), so
        # the head works on a fresh kernel and trains its own parameters.
        h = self.main(x)                        # h  : [B, channels, dim]
        out = self.cls(h)                       # out: [B, cmap_dim, dim] or [B, 1, dim]

        if self.c_dim > 0:
            cmap = self.cmapper(c).unsqueeze(-1)            # cmap: [B, cmap_dim, 1]
            out = (out * cmap).sum(1, keepdim=True)         # out : [B, 1, dim]
            out = out * np.sqrt(1 / self.cmap_dim)          # rescale by 1 / sqrt(cmap_dim)

        return out

# Smoke test for DiscHead in conditional mode.
channels = 384  # DINO ViT-S output
c_dim = 512     # Text embedding from CLIP
cmap_dim = 64   # Projection space for conditional score

x = torch.rand(5, channels, 64)   # x: [B, channels, dim]
c = torch.rand(5, c_dim)          # c: [B, c_dim]

dh = DiscHead(channels=channels, c_dim=c_dim, cmap_dim=cmap_dim)
out = dh(x, c)
out.shape


torch.Size([5, 1, 64])

In [9]:
# Hook configuration: indices of the ViT blocks to tap, plus an optional extra hook
# (`hook_patch`) that contributes one more entry to the returned features.
hooks = [2,5,8,11]
hook_patch = True

n_hooks = len(hooks) + int(hook_patch) # n_hooks: 5

# Wrap a pretrained timm 'vit_small_patch16_224_dino' model with the project helper
# that receives the hook configuration.
model = make_vit_backbone(
            timm.create_model('vit_small_patch16_224_dino', pretrained=True),
            patch_size=[16,16], hooks=hooks, hook_patch=hook_patch,
        )

# Rebind `model` to the wrapper's `.model`, in eval mode with gradients disabled.
# NOTE(review): the next two lines read `model.model` again after this rebinding —
# confirm against vit_utils that the rebound object still exposes a `.model` attribute.
model = model.model.eval().requires_grad_(False)
img_res = model.model.patch_embed.img_size[0]  # img_res  : 224
embed_dim = model.model.embed_dim              # embed_dim: 384
norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

# Resize to the backbone resolution, normalize with ImageNet statistics, extract features.
# Note that `x` is rebound from the image tensor to the value returned by forward_vit.
x = torch.rand(5, 3, 200, 200)
x = nn.functional.interpolate(x, (img_res, img_res), mode='area')  # x: [B, channel, 224, 224]
x = norm(x)
x = forward_vit(model, x)

# The result is indexed by stringified positions: '0'..'3' are paired with `hooks`,
# and '4' is the extra entry read when hook_patch is enabled (labelled "Dropout Block"
# below — presumably the patch-level hook; verify against vit_utils).
for i in range(len(x)-1):
    print(f"Block {hooks[i]} Shape:", list(x[f'{i}'].shape))
if hook_patch:
    print(f"Dropout Block Shape:", list(x['4'].shape))

  model = create_fn(


Block 2 Shape: [5, 384, 196]
Block 5 Shape: [5, 384, 196]
Block 8 Shape: [5, 384, 196]
Block 11 Shape: [5, 384, 196]
Dropout Block Shape: [5, 384, 196]


In [10]:
class DINO(nn.Module):
    """Frozen feature extractor built on timm's 'vit_small_patch16_224_dino'.

    Args:
        hooks: Indices of the ViT blocks whose activations are collected.
        hook_patch: Whether the extra patch-level hook is enabled; it adds
            one more entry to the returned features.
    """

    def __init__(self, hooks: list = [2, 5, 8, 11], hook_patch: bool = True):
        super().__init__()

        self.n_hooks = len(hooks) + int(hook_patch) # n_hooks: 5

        self.model = make_vit_backbone(
            timm.create_model('vit_small_patch16_224_dino', pretrained=True),
            # Pass a copy so the shared default list can never be mutated downstream.
            patch_size=[16,16], hooks=list(hooks), hook_patch=hook_patch,
        )
        self.model = self.model.model.eval().requires_grad_(False)

        # Read the metadata from this instance's own backbone. The previous code
        # read a notebook-global `model`, which breaks on a fresh kernel.
        self.img_res = self.model.model.patch_embed.img_size[0]            # img_res  : 224
        self.embed_dim = self.model.model.embed_dim                        # embed_dim: 384
        self.norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    def forward(self, x: torch.Tensor) -> Dict:
        """Resize, normalize and run the backbone.

        Args:
            x: Image batch of shape [B, channel, H, W].

        Returns:
            The value returned by `forward_vit` for the resized, normalized batch.
        """
        # An int `size` is applied to every spatial axis of a 4-D input.
        x = nn.functional.interpolate(x, self.img_res, mode='area')  # x: [B, channel, 224, 224]
        x = self.norm(x)  # the instance's own Normalize, not a notebook global
        return forward_vit(self.model, x)
    
# Smoke test for the DINO wrapper; `hooks` and `hook_patch` are the notebook-level
# settings defined in the backbone exploration cell.
x = torch.rand(5, 3, 200, 200)   # [B, channel, H, W]

dino = DINO()
out = dino(x)

# All entries except the last are paired with the hooked block indices.
for i in range(len(out) - 1):
    print(f"Block {hooks[i]} Shape:", list(out[str(i)].shape))
if hook_patch:
    print(f"Dropout Block Shape:", list(out['4'].shape))

Block 2 Shape: [5, 384, 196]
Block 5 Shape: [5, 384, 196]
Block 8 Shape: [5, 384, 196]
Block 11 Shape: [5, 384, 196]
Dropout Block Shape: [5, 384, 196]


In [11]:
# Display the embedding width stored on the wrapper by DINO.__init__.
dino.embed_dim

384