In [5]:
import torch
from torch import nn
from videogamegen.models.modules import AdaptiveGroupNorm
from einops import rearrange

In [3]:
class ResBlock2d(nn.Module):
    """Pre-activation 2D residual block: two (norm -> SiLU -> 3x3 conv) stages plus a skip path.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count produced by the block.
        adaptive: when True, the residual branch is modulated per channel as
            ``branch * beta + gamma`` before the skip path is added. Note the naming:
            ``beta`` is the multiplicative scale (initialised to ones) and ``gamma`` the
            additive shift (initialised to zeros), so the modulation starts out as the
            identity. The names are kept because they are parameter / state_dict keys.

    ``AdaptiveGroupNorm`` is project-defined (videogamegen.models.modules); inside the
    ``nn.Sequential`` it only receives the feature tensor.
    """

    def __init__(self, in_channels, out_channels, adaptive=False):
        super().__init__()

        # Skip path: a 1x1 projection is needed only when the channel count changes.
        needs_projection = in_channels != out_channels
        self.identity = (
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
            if needs_projection
            else nn.Identity()
        )

        if adaptive:
            per_channel_shape = (1, out_channels, 1, 1)
            self.beta = nn.Parameter(torch.ones(per_channel_shape))    # scale, starts at 1
            self.gamma = nn.Parameter(torch.zeros(per_channel_shape))  # shift, starts at 0
        self.adaptive = adaptive

        # Residual branch. The build order fixes the Sequential indices (block.0, block.2,
        # block.3, block.5) that appear in state_dict keys, so it must not be rearranged.
        layers = []
        for stage_in_channels in (in_channels, out_channels):
            layers.extend([
                AdaptiveGroupNorm(stage_in_channels),
                nn.SiLU(),
                nn.Conv2d(
                    in_channels=stage_in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                ),
            ])
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return the (optionally modulated) residual branch plus the skip-path output."""
        branch = self.block(x)
        if self.adaptive:
            branch = branch * self.beta + self.gamma
        skip = self.identity(x)
        return branch + skip

In [4]:
# Stand-in latent-parameter tensor; it is split along dim 1 into mean / log-variance halves below.
# A fixed-seed local generator makes the values (and every output derived from them)
# reproducible on Restart & Run All without touching torch's global RNG state.
x = torch.rand((1, 256, 4, 4), generator=torch.Generator().manual_seed(0))

In [7]:
# First half of the channels -> mean, second half -> log-variance.
mean, logvar = x.chunk(2, dim=1)

In [11]:
# Shape of the mean half (channel count is half of x's).
mean.size()

torch.Size([1, 128, 4, 4])

In [12]:
# Shape of the log-variance half; it should match the mean half.
logvar.size()

torch.Size([1, 128, 4, 4])

In [13]:
# Element-wise term of the closed-form KL(N(mean, exp(logvar)) || N(0, I));
# reducing it and scaling by -0.5 gives the KL divergence.
k = (1 + logvar) - mean ** 2 - torch.exp(logvar)

In [14]:
# The per-element KL term keeps the full (batch, channel, height, width) shape.
k.size()

torch.Size([1, 128, 4, 4])

In [8]:
# Closed-form KL(N(mean, exp(logvar)) || N(0, I)) per sample: sum the element-wise terms over
# every latent dimension (channels, height, width), then average over the batch dimension.
# Summing (rather than torch.mean over all elements) keeps the KL on the per-sample scale;
# a plain mean over everything would shrink it by C*H*W (2048 for this x).
kl_terms = 1 + logvar - mean.pow(2) - logvar.exp()
kl_div = (-0.5 * kl_terms.flatten(start_dim=1).sum(dim=1)).mean()

In [10]:
# Display the scalar (0-dim tensor) KL value.
kl_div

tensor(545.6481)