In [43]:
import argparse
import os
from pathlib import Path

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
from torch.utils.data import DataLoader

from dm_zoo.dff.PixelDiffusion import (
    PixelDiffusionConditional,
)
from WD.datasets import Conditional_Dataset_Zarr_Iterable
from WD.utils import create_dir
from WD.io import create_xr_output_variables
# from WD.io import load_config, write_config  # noqa F401
import pytorch_lightning as pl


In [44]:
model_name = "2023-09-01_14-19-31"  # we have to pass this to the bash file every time! (should contain a string of the date the run was started).
nens = 1  # we have to pass this to the bash file every time!

# Dataset config (its path does not depend on the run) and the training config of the chosen run.
ds_config = OmegaConf.load("/data/compoundx/WeatherDiff/hydra_configs/rasp_thuerey_iterative/.hydra/config.yaml")
ml_config = OmegaConf.load(f"/data/compoundx/WeatherDiff/hydra_configs/training/rasp_thuerey_iterative/conditional_diffusion_6h_iterative/{model_name}/.hydra/config.yaml")

model_load_dir = Path(f"/data/compoundx/WeatherDiff/saved_model/rasp_thuerey_iterative/conditional_diffusion_6h_iterative/{model_name}/lightning_logs/version_0/checkpoints/")

test_ds_path = "/data/compoundx/WeatherDiff/model_input/rasp_thuerey_iterative_test.zarr"

# No shuffling of chunks or within chunks, so the test set is iterated in a fixed order.
ds = Conditional_Dataset_Zarr_Iterable(
    test_ds_path,
    ds_config.template,
    shuffle_chunks=False,
    shuffle_in_chunks=False,
)

# Path.iterdir() yields entries in arbitrary order, so sort before picking to make the
# checkpoint choice reproducible, and fail with a clear message if the directory is empty.
# NOTE(review): with several checkpoints this takes the lexicographically first one — confirm that is intended.
checkpoint_candidates = sorted(model_load_dir.iterdir())
if not checkpoint_candidates:
    raise FileNotFoundError(f"No checkpoint found in {model_load_dir}")
model_ckpt = checkpoint_candidates[0]

# Conditioning = input channels for every conditioning timestep + the static constant channels;
# the model generates one output channel per target channel.
conditioning_channels = ds.array_inputs.shape[1] * len(ds.conditioning_timesteps) + ds.array_constants.shape[0]
generated_channels = ds.array_targets.shape[1]

# loss_fn / sampler are passed as None explicitly (presumably not needed for inference — TODO confirm).
restored_model = PixelDiffusionConditional.load_from_checkpoint(
    model_ckpt,
    config=ml_config.experiment.pixel_diffusion,
    conditioning_channels=conditioning_channels,
    generated_channels=generated_channels,
    loss_fn=None,
    sampler=None,
)


Is Time embed used ?  True
Cyclical Padding ?  False


In [45]:
# Show which device the restored model is on; the rollout cell moves the constants to this device.
restored_model.device

device(type='cuda', index=0)

In [48]:
# Autoregressive rollout: for every ensemble member, walk the test set and, for each batch,
# feed the model's prediction back in as the conditioning input for the next step.
dl = DataLoader(ds, batch_size=3)
trainer = pl.Trainer()  # NOTE(review): unused below (forward() is called directly); kept in case later cells use it.

n_steps = 2  # hard-code for now. Relax later.
max_batches = 5  # only roll out the first `max_batches` test batches (previously a hard-coded `i == 4` break)

constants = torch.tensor(ds.array_constants[:], dtype=torch.float).to(restored_model.device)

# The fed-back input is [prediction, constants]. It can only replace the original conditioning
# if the channel counts agree (e.g. a single conditioning timestep and targets == inputs).
assert generated_channels + constants.shape[0] == conditioning_channels, (
    "autoregressive feedback requires generated + constant channels == conditioning channels"
)

restored_model.eval()  # switch to inference mode once, not on every step
res = []
with torch.no_grad():
    for _member in range(nens):  # loop over ensemble members
        ts = []
        for batch_idx, b in enumerate(dl):  # loop over batches in the test set
            print(batch_idx)
            model_input = b
            # (batch, n_steps, *target_shape), with batch size and target shape taken from b[1].
            trajectories = torch.zeros(size=(b[1].shape[0], n_steps, *b[1].shape[1:]))
            for step in range(n_steps):
                # `out` is used as a single tensor below (slice assignment, .size(0)).
                out = restored_model.forward(model_input)
                trajectories[:, step, ...] = out
                # Broadcast the static constants over the batch and append them to the prediction.
                batch_constants = constants.unsqueeze(0).expand(out.size(0), *constants.size())
                model_input = [torch.cat([out, batch_constants], dim=1), None]  # we don't need the true target here
            ts.append(trajectories)
            if batch_idx + 1 >= max_batches:
                break
        res.append(torch.cat(ts, dim=0))

# Stack the per-member results -> (nens, n_samples, n_steps, *target_shape).
# (Previously `torch.stack(out, dim=0)`, which just rebuilt the last batch's prediction.)
res = torch.stack(res, dim=0)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


0


In [None]:
# Number of per-batch trajectory tensors collected for the last ensemble member (`ts` is reset per member).
len(ts)

2

In [20]:
# Shape of the second batch's trajectories: (batch, n_steps, *target_shape). Requires at least 2 batches.
ts[1].shape

torch.Size([2, 3, 38, 32, 64])

In [12]:
# Shape of the trajectories of the last batch processed by the rollout loop (the loop leaves it bound).
trajectories.shape

torch.Size([2, 3, 38, 32, 64])

In [38]:
# Point the notebook at the geopotential-only dataset.
# NOTE(review): this cell rebinds `model_name`, `nens`, `ds_config`, `model_load_dir`,
# `test_ds_path` and `ds`, but does NOT reload `ml_config` or `restored_model`; those
# still refer to whichever run was loaded last (if any).
model_name = "2023-09-01_14-19-31"  # we have to pass this to the bash file every time! (should contain a string of the date the run was started).
nens = 1  # we have to pass this to the bash file every time!

ds_config = OmegaConf.load("/data/compoundx/WeatherDiff/hydra_configs/rasp_thuerey_geopotential/.hydra/config.yaml")

model_load_dir = Path(f"/data/compoundx/WeatherDiff/saved_model/rasp_thuerey_geopotential/conditional_diffusion_6h_iterative/{model_name}/lightning_logs/version_0/checkpoints/")

test_ds_path = "/data/compoundx/WeatherDiff/model_input/rasp_thuerey_geopotential_test.zarr"

# Same fixed (unshuffled) iteration order as for the evaluation dataset.
ds = Conditional_Dataset_Zarr_Iterable(
    test_ds_path,
    ds_config.template,
    shuffle_chunks=False,
    shuffle_in_chunks=False,
)

In [41]:
# Inspect the raw `inputs` array backing `ds` (its repr shows path, shape and dtype).
ds.array_inputs

<zarr.core.Array '/inputs/data' (2917, 38, 32, 64) float32 read-only>