In [1]:
import sys
import os

# Configure the MPS allocator before the project import below (which presumably
# pulls in torch -- TODO confirm). See the PyTorch MPS environment-variable docs
# for the exact meaning of a 0.0 ratio.
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'

# Make the repository root importable; guarded so re-running the cell does not
# keep appending the same entry.
if "../.." not in sys.path:
    sys.path.append("../..")
from src.modules.wanvae import _video_vae

# The checkpoint location can be overridden with WAN_VAE_CHECKPOINT so the
# notebook is not tied to one machine; the original path remains the default.
pretrained_path = os.environ.get(
    "WAN_VAE_CHECKPOINT",
    "/Users/tosinkuye/ape/wanup/checkpoints/Wan2.1_VAE.pth",
)
if not os.path.isfile(pretrained_path):
    raise FileNotFoundError(
        f"VAE checkpoint not found at {pretrained_path!r}; "
        "set the WAN_VAE_CHECKPOINT environment variable to its location."
    )
vae = _video_vae(pretrained_path=pretrained_path, z_dim=16, device='cpu')

In [None]:
# Show the module tree of the loaded VAE's encoder for reference.
print(vae.encoder)

In [None]:

# Show the module tree of the loaded VAE's decoder, for comparison with the
# custom Decoder defined further down in this notebook.
print(vae.decoder)

In [27]:
import torch

# Seed the global RNG so the random tensors and module initialisations created
# from here on are reproducible after Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Small random 5-D sample tensor for quick experiments.
tensor = torch.randn(1, 3, 4, 16, 16)

In [28]:
from torch import nn
import torch.nn.functional as F
from src.modules.wanvae import ResidualBlock, CausalConv3d, RMS_norm,  Resample, AttentionBlock
from einops import rearrange
from natten import  NeighborhoodAttention2D, NeighborhoodAttention3D
# Number of trailing entries along dim 2 of an activation that Decoder.forward
# copies into feat_cache for its cached causal convolutions
# (dim 2 is presumably the time axis -- TODO confirm).
CACHE_T = 2

class SpatialNeighbourhoodAttentionBlock(nn.Module):
    """
    Single-head spatial neighbourhood attention with a residual connection.

    Each frame of a ``(b, c, t, h, w)`` input is attended independently: the
    time axis is folded into the batch axis, the result is normalised with
    ``RMS_norm`` in channel-first layout, moved to the channels-last layout
    ``(b * t, h, w, c)`` for ``NeighborhoodAttention2D``, and finally restored
    to ``(b, c, t, h, w)`` before being added to the input.

    Args:
        dim: Channel count ``c`` of the input.
        kernel_size: Neighbourhood size forwarded to ``NeighborhoodAttention2D``.
        dilation: Dilation forwarded to ``NeighborhoodAttention2D``.
    """

    def __init__(self, dim, kernel_size=(3, 3), dilation=(1, 1)):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.na = NeighborhoodAttention2D(
            dim,
            num_heads=1,
            kernel_size=kernel_size,
            dilation=dilation,
        )

    def forward(self, x):
        """
        Apply per-frame neighbourhood attention.

        Args:
            x: Tensor of shape ``(b, c, t, h, w)``.

        Returns:
            Tensor of the same shape as ``x`` (attention output plus ``x``).
        """
        identity = x
        b, c, t, h, w = x.shape

        # Fold time into batch: (b, c, t, h, w) -> (b * t, c, h, w).
        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        x = self.norm(x)

        # Move channels last with a real transpose. A bare .view() here would
        # only reinterpret memory and mix channel values with spatial positions.
        x = x.permute(0, 2, 3, 1).contiguous()
        x = self.na(x)

        # Undo both layout changes: (b * t, h, w, c) -> (b, c, t, h, w).
        x = x.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3)
        return x + identity
    
    
class SpatiaTemporallNeighbourhoodAttentionBlock(nn.Module):
    """
    Single-head spatio-temporal neighbourhood attention with a residual connection.

    The ``(b, c, t, h, w)`` input is normalised with ``RMS_norm`` in
    channel-first layout, transposed to the channels-last layout
    ``(b, t, h, w, c)`` for ``NeighborhoodAttention3D``, transposed back and
    added to the input.

    Args:
        dim: Channel count ``c`` of the input.
        kernel_size: Neighbourhood size forwarded to ``NeighborhoodAttention3D``.
        dilation: Dilation forwarded to ``NeighborhoodAttention3D``.
    """

    def __init__(self, dim, kernel_size=(3, 3, 3), dilation=(1, 1, 1)):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim, images=False)
        self.na = NeighborhoodAttention3D(
            dim,
            num_heads=1,
            kernel_size=kernel_size,
            dilation=dilation,
        )

    def forward(self, x):
        """
        Apply neighbourhood attention jointly over time and space.

        Args:
            x: Tensor of shape ``(b, c, t, h, w)``.

        Returns:
            Tensor of the same shape as ``x`` (attention output plus ``x``).

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D (b, c, t, h, w) tensor, got shape {tuple(x.shape)}")
        identity = x

        x = self.norm(x)
        # Channels last via a real transpose: (b, c, t, h, w) -> (b, t, h, w, c).
        # A bare .view() would only reinterpret memory and scramble the layout.
        x = x.permute(0, 2, 3, 4, 1).contiguous()
        x = self.na(x)
        # Back to channel-first: (b, t, h, w, c) -> (b, c, t, h, w).
        x = x.permute(0, 4, 1, 2, 3)
        return x + identity

class Decoder(nn.Module):
    """
    Causal 3-D decoder with an extra super-sampling tail.

    Layer order, as built in ``__init__``:

    1. ``conv1``: ``CausalConv3d(z_dim, dims[0], 3, padding=1)``.
    2. ``middle``: ResidualBlock -> AttentionBlock -> ResidualBlock at ``dims[0]``.
    3. ``upsamples``: one stage per entry of ``dim_mult`` (``num_res_blocks + 1``
       residual blocks, each optionally followed by neighbourhood attention),
       with a ``Resample`` after every stage except the last. This is followed
       by the super-sampling tail: ``Resample('upsample2d')``, a stage,
       ``Resample('upsample2d_15')`` and a final stage.
    4. ``head``: RMS_norm -> SiLU -> ``CausalConv3d(out_dim, 3, 3, padding=1)``.

    Args:
        dim: Base channel width; stage widths are ``dim * m`` for ``m`` in ``dim_mult``.
        z_dim: Channel count of the latent input.
        dim_mult: Per-stage channel multipliers (consumed in reverse order).
        num_res_blocks: Each stage holds ``num_res_blocks + 1`` residual blocks.
        attn_scales: Values of the running ``scale`` at which a neighbourhood
            attention block follows every residual block.
        temporal_upsample: Per-Resample flag selecting ``'upsample3d'`` (True)
            or ``'upsample2d'`` (False).
        dropout: Dropout value forwarded to every ``ResidualBlock``.

    Note:
        The list defaults are shared between instances; they are only read here
        and must not be mutated by callers.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[1.0, 2.0],
                 temporal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.z_dim = z_dim
        self.dim = dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temporal_upsample = temporal_upsample
        self.dropout = dropout

        # Channel widths from the bottleneck outwards; the deepest width is
        # repeated so the first stage keeps its channel count. Unpacking (rather
        # than list + list) also accepts a tuple for dim_mult.
        dims = [dim * u for u in [dim_mult[-1], *dim_mult[::-1]]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        upsamples = []

        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # Stages 1-3 follow a Resample; their input width is halved
            # (presumably because Resample halves channels -- TODO confirm).
            if i in (1, 2, 3):
                in_dim = in_dim // 2
            upsamples.extend(self._make_stage(in_dim, out_dim, scale))

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temporal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0

        # Super-sampling tail: one more 'upsample2d' Resample, a stage at half
        # the width, an 'upsample2d_15' Resample, then a stage at width / 1.5.
        # The width bookkeeping assumes the Resample modes shrink channels by
        # 2 and 1.5 respectively -- TODO confirm against Resample.
        upsamples.append(Resample(out_dim, mode='upsample2d'))
        scale *= 2.0

        out_dim = out_dim // 2
        upsamples.extend(self._make_stage(out_dim, out_dim, scale))
        upsamples.append(Resample(out_dim, mode='upsample2d_15'))
        scale *= 1.5

        out_dim = int(out_dim / 1.5)
        upsamples.extend(self._make_stage(out_dim, out_dim, scale))

        self.upsamples = nn.Sequential(*upsamples)

        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1)
        )

    def _make_stage(self, in_dim, out_dim, scale):
        """Build ``num_res_blocks + 1`` residual blocks, each followed by attention when ``scale`` is in ``attn_scales``."""
        layers = []
        for _ in range(self.num_res_blocks + 1):
            layers.append(ResidualBlock(in_dim, out_dim, self.dropout))
            if scale in self.attn_scales:
                layers.append(SpatiaTemporallNeighbourhoodAttentionBlock(out_dim))
            in_dim = out_dim
        return layers

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run ``conv`` with the cache slot at ``feat_idx[0]``, refresh that slot from ``x`` and advance the counter."""
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # Fewer than two entries in this chunk: prepend the last cached
            # entry from the previous chunk.
            cache_x = torch.cat([
                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                cache_x,
            ], dim=2)
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Decode ``x``.

        Args:
            x: 5-D input tensor whose dim 1 has ``z_dim`` channels.
            feat_cache: Optional indexable cache with one slot per cached
                convolution (slots may be ``None``); updated in place. When
                ``None`` every layer is called without a cache.
            feat_idx: Single-element list used as a mutable slot counter;
                incremented in place. A fresh ``[0]`` is created when omitted,
                so the counter never leaks between calls through a shared
                default.

        Returns:
            The output of the final ``CausalConv3d`` in ``head``.
        """
        if feat_idx is None:
            feat_idx = [0]

        ## conv1
        if feat_cache is not None:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            # The neighbourhood attention blocks only accept `x`; passing the
            # cache arguments to them would raise a TypeError.
            accepts_cache = not isinstance(
                layer, SpatiaTemporallNeighbourhoodAttentionBlock)
            if feat_cache is not None and accepts_cache:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                x = self._cached_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
        

In [107]:
# Hyper-parameters for the custom Decoder built in the next cell.
# attn_scales is empty, so no neighbourhood-attention blocks are inserted
# into the upsampling path.
cfg = {
    'dim': 96,
    'z_dim': 16,
    'dim_mult': [1, 2, 4, 4],
    'num_res_blocks': 2,
    'attn_scales': [],
    'temporal_upsample': [False, True, True],
    'dropout': 0.0,
}

In [None]:
# Instantiate the custom decoder from cfg and place it on the MPS device.
decoder = Decoder(**cfg)
decoder = decoder.to('mps')
# Switch to training mode; as the last expression this also displays the module.
decoder.train()

In [None]:
def print_model_parameters(model):
    """Print the total and the trainable parameter counts of ``model``.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors with
            ``numel()`` and ``requires_grad`` (e.g. an ``nn.Module``).
    """
    total_params = 0
    trainable_params = 0
    # Single pass: every parameter counts towards the total, and only those
    # that require gradients count as trainable.
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
# Baseline: parameter counts of the decoder that ships with the loaded VAE.
print_model_parameters(vae.decoder)

In [None]:
# Parameter counts of the custom Decoder, for comparison with the baseline above.
print_model_parameters(decoder)

In [109]:
from torch.optim import Adam
from tqdm.auto import tqdm
# Only the custom decoder's parameters are handed to the optimiser.
optimizer = Adam(params=decoder.parameters(), lr=5e-4)

# Random input clip, sampled on the CPU and then moved to MPS.
input_tensor = torch.randn((1, 3, 4, 16, 16)).to('mps')
# All-ones target whose last two dims are 3x those of the input (48 vs 16).
target = torch.ones((1, 3, 4, 48, 48), device='mps')

In [110]:
def loss_fn(pred, target):
    """Return the sum of the mean-squared error and the mean absolute error.

    Args:
        pred: Predicted tensor.
        target: Reference tensor, passed to the losses together with ``pred``.

    Returns:
        Scalar tensor ``mse_loss(pred, target) + l1_loss(pred, target)``.
    """
    # TODO: add a KL divergence term as well (not implemented yet).
    return F.mse_loss(pred, target) + F.l1_loss(pred, target)

In [None]:
# Move the pretrained VAE to MPS so its encoder can consume the MPS-resident
# input_tensor in the training loop below.
vae.to('mps')

In [None]:
# Fit the custom decoder for 100 steps on the single (input_tensor, target) pair.
for epoch in tqdm(range(100)):
    # No gradients are tracked through the pretrained encoder; its output is
    # split in two along dim 1 and only the first half is decoded
    # (presumably mean / log-variance -- TODO confirm).
    with torch.no_grad():
        z, log_var = torch.chunk(vae.encoder(input_tensor), 2, dim=1)

    pred = decoder(z)
    loss = loss_fn(pred, target)

    # Clear stale gradients right before computing the new ones.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item()}")