In [6]:
from diffusers import DDIMScheduler
from pytorch_lightning import Trainer, seed_everything
from omegaconf import OmegaConf
from CasCast.networks.prediff.taming.autoencoder_kl import AutoencoderKL
from pipeline.modeldefinitions.dit import CasFormer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pipeline.utils import load_checkpoint_cascast
import torch
import torch.nn as nn
"""
B T C H W   
"""
class DiT(nn.Module):
    """Bundle a ``CasFormer`` backbone with a ``DDIMScheduler``.

    The scheduler is only stored on the instance; ``forward`` does not touch
    it, so it is there for whatever code drives noising / sampling.
    """

    def __init__(self, config):
        """Build the backbone and the scheduler from ``config``.

        Args:
            config: object exposing ``Model`` (handed to ``CasFormer`` as its
                ``config``) and ``timesteps`` (handed to ``DDIMScheduler`` as
                ``num_train_timesteps``).
        """
        super().__init__()
        backbone = CasFormer(arch='DiT-custom', config=config.Model)
        noise_scheduler = DDIMScheduler(num_train_timesteps=config.timesteps)
        self.model = backbone
        self.scheduler = noise_scheduler

    def forward(self, noisy: torch.Tensor, timesteps: torch.Tensor, cond: torch.Tensor):
        """Run the backbone on ``(noisy, timesteps, cond)`` and return its output unchanged.

        NOTE(review): what the output represents (noise, sample, ...) is
        decided by ``CasFormer`` and is not visible here -- confirm there.
        """
        prediction = self.model(noisy, timesteps, cond)
        return prediction

class Autoencoder(nn.Module):
    """Frozen, inference-only wrapper around a pretrained ``AutoencoderKL``.

    The wrapped model is put in eval mode, loaded from a checkpoint and has
    gradients disabled on all of its parameters. ``encode`` / ``decode``
    operate on 5-D tensors and process them one slice along dim 1 at a time.

    NOTE(review): ``nn.Module.train()`` recurses into child modules, so calling
    ``.train()`` on this wrapper (or on a parent module that contains it) puts
    the autoencoder back into training mode. Call ``.eval()`` again afterwards
    if that matters for the wrapped model.
    """

    # Checkpoint used when the caller does not supply ``ckpt_path``.
    DEFAULT_CKPT_PATH = "/home/vatsal/NWM/weather/pipeline/autoencoder_ckpt.pth"

    def __init__(self, config, ckpt_path=None):
        """Build the autoencoder, load its weights and freeze it.

        Args:
            config: mapping expanded as keyword arguments to ``AutoencoderKL``.
            ckpt_path: checkpoint file passed to ``load_checkpoint_cascast``.
                ``None`` (the default) selects ``DEFAULT_CKPT_PATH``.
        """
        super().__init__()
        self.autoencoder = AutoencoderKL(**config)
        self.autoencoder.eval()
        weights_path = self.DEFAULT_CKPT_PATH if ckpt_path is None else ckpt_path
        load_checkpoint_cascast(weights_path, self.autoencoder)
        # requires_grad_(False) already covers every parameter of the module,
        # so no separate per-parameter loop is needed.
        self.autoencoder.requires_grad_(False)

    @staticmethod
    def _map_frames(fn, x):
        """Apply ``fn`` to each slice of ``x`` along dim 1 and re-stack the results on dim 1."""
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D tensor (B, T, C, H, W), got shape {tuple(x.shape)}"
            )
        # Slices are visited in index order, so any randomness inside ``fn``
        # is consumed in the same order as a plain ``for i in range(T)`` loop.
        return torch.stack([fn(frame) for frame in x.unbind(dim=1)], dim=1)

    @torch.no_grad()
    def encode(self, x):
        """Encode every frame of ``x`` and return the stacked latents.

        Args:
            x: tensor of shape (B, T, C, H, W). The notebook feeds values
                scaled to [0, 1].

        Returns:
            Tensor of shape (B, T, ...) holding, for each frame, the result of
            ``self.autoencoder.encode(frame).sample()``. Presumably this is a
            draw from the encoder's posterior, so the result is stochastic --
            confirm in ``AutoencoderKL``.

        Raises:
            ValueError: if ``x`` is not 5-dimensional.
        """
        return self._map_frames(
            lambda frame: self.autoencoder.encode(frame).sample(), x
        )

    @torch.no_grad()
    def decode(self, x):
        """Decode every latent frame of ``x`` and return the stacked outputs.

        Args:
            x: latent tensor of shape (B, T, latent_C, H, W). Latents are not
                confined to [0, 1]; the values recorded in this notebook span
                roughly -8.5 to 20.2.

        Returns:
            Tensor of shape (B, T, ...) holding ``self.autoencoder.decode``
            applied to each latent frame.

        Raises:
            ValueError: if ``x`` is not 5-dimensional.
        """
        return self._map_frames(self.autoencoder.decode, x)

In [23]:
# Read the VAE hyper-parameters from YAML, build the frozen wrapper, then
# move it onto the GPU.
VAE_CONFIG_PATH = "/home/vatsal/NWM/weather/pipeline/configs/models/vae.yaml"
config = OmegaConf.load(VAE_CONFIG_PATH)
autoencoder = Autoencoder(config)
autoencoder = autoencoder.to("cuda")

['autoencoder_kl', 'lpipsWithDisc']
[32mloaded autoencoder_kl successfully the game is on[0m


In [31]:
import os

import h5py

# Directory of HDF5 files to sample from (SEVIR VIL, 2017, going by the path).
path = "/home/vatsal/Dataserver/NWM/datasets/sevir/data/vil/2017"

# os.listdir returns entries in arbitrary order; sorting makes files[0] the
# same file on every run and every machine.
files = sorted(os.listdir(path))
file = files[0]

# Open read-only inside a context manager so the file handle is released even
# if the read fails. Indexing the dataset copies entry 10 into memory, so
# `vil` remains usable after the file is closed.
with h5py.File(os.path.join(path, file), 'r') as data:
    vil = data['vil'][10]
print(vil.shape)  # (H, W, T) -- the recorded run printed (384, 384, 49)

(384, 384, 49)


In [32]:
# vil is laid out (H, W, T): bring time to the front, then insert singleton
# batch and channel axes to obtain (1, T, 1, H, W).
x = torch.tensor(vil).permute(2, 0, 1)[None, :, None]
# Scale the raw values by 1/255 (maps 8-bit intensities onto [0, 1]).
x = x.float() / 255.0
# Encode on the GPU; the result is (1, T, latent_C, H', W').
enc = autoencoder.encode(x.to("cuda"))

In [33]:
# Report the value range of the latents.
print(f"min enc: {enc.min().item()}")
print(f"max enc: {enc.max().item()}")

min enc: -8.521828651428223
max enc: 20.1545352935791


In [None]:
import torch

# Sanity check: offsetting by a reference slice and applying an affine map
# (and its inverse) must reproduce the input up to float rounding.

# Random input plus an untouched copy to compare against at the end.
x = torch.randn(4, 64, 64)
tmp_x = x.clone()

# 1) Reference slice, kept 3-D so it broadcasts: shape (1, 64, 64).
x0 = x[:1]

# 2) Express every slice relative to the reference (slice 0 becomes all zeros).
x = x.sub(x0)

# 3) Affine map that would send [-1, 1] onto [0, 1]; nothing here clamps the
#    values, so inputs outside [-1, 1] land outside [0, 1].
x = x.add(1).div(2)

# (the tensor would be sent / stored at this point)

# 4) Inverse affine map, back to the scale used in step 2.
x = x.mul(2).sub(1)

# 5) Restore the reference slice.
x = x.add(x0)

# The round trip should match the original within tolerance: prints tensor(True).
print(torch.isclose(tmp_x, x, atol=1e-5).all())


tensor(True)
