In [2]:
import random, logging
from types import SimpleNamespace

import torch, wandb
from fastprogress import progress_bar

from cloud_diffusion.dataset import download_dataset, CloudDataset
from cloud_diffusion.ddpm import ddim_sampler
from cloud_diffusion.models import UNet2D, get_unet_params
from cloud_diffusion.utils import parse_args, set_seed
from cloud_diffusion.wandb import to_video, vhtile

## Download cloud_diffusion dataset from weights and biases

[notes on how to download their datasets](https://wandb.ai/capecape/ddpm_clouds/reports/GTC-Diffusion-on-the-Clouds--VmlldzozNzQ1OTkz)

In [15]:
# Names of the wandb dataset artifacts; DATASET_ARTIFACT is the one downloaded below.
DATASET_ARTIFACT_SMALL = 'capecape/gtc/np_dataset:v0' #single band
DATASET_ARTIFACT_LARGE = 'capecape/gtc/np_dataset:v1' #twice the sequences by combining 2 bands.
DATASET_ARTIFACT = DATASET_ARTIFACT_SMALL

In [None]:
# Open a wandb run tagged as a dataset download, declare the artifact as an
# input of that run, and fetch its files. `artifact_dir` holds whatever
# `artifact.download()` returns (presumably the local download directory).
with wandb.init(job_type="download_dataset"):
    artifact = wandb.use_artifact(DATASET_ARTIFACT, type='dataset')
    artifact_dir = artifact.download()

## Load downloaded small cloud_diffusion dataset

In [None]:
import os
import numpy as np

In [None]:
def load_numpy_arrays_from_folder(folder_path):
    """Load every ``.npy`` file located directly inside ``folder_path``.

    Files are read in sorted filename order so the returned list is
    deterministic; ``os.listdir`` on its own yields entries in arbitrary order.

    Args:
        folder_path: directory to scan (sub-directories are not traversed).

    Returns:
        A list with one numpy array per ``.npy`` file, ordered by filename.
        Entries whose name does not end in ``.npy`` are ignored.
    """
    npy_names = sorted(
        name for name in os.listdir(folder_path) if name.endswith(".npy")
    )
    return [np.load(os.path.join(folder_path, name)) for name in npy_names]

In [None]:
# Hard-coded absolute location of the downloaded `np_dataset:v0` artifact on this machine.
folder_path = "/bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/np_dataset:v0"
# `dataset` is a plain list with one numpy array per .npy file in the folder.
# NOTE: the name `dataset` is rebound to a SatelliteDataset further down the notebook.
dataset = load_numpy_arrays_from_folder(folder_path)

## Inspect small cloud_diffusion dataset

In [5]:
file = np.load('/bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/np_dataset:v0/B07_2022_06_01.npy')

In [6]:
file.shape

(96, 1, 446, 780)

In [None]:
dataset[0].shape

### Each numpy file is an array with shape (96, 1, 446, 780), i.e. a batch of 96 items, where each item is a single-channel image of size 446x780.

In [None]:
import matplotlib.pyplot as plt

In [10]:
data_slice = file[0, 0, :, :]

In [None]:
# Plot the data slice
plt.imshow(data_slice, cmap='gray')
plt.title('Data Slice from Numpy Array')
plt.show()

In [14]:
PROJECT_NAME = "ddpm_clouds"

In [16]:
config = SimpleNamespace(
    model_name="unet_small", # model name to save [unet_small, unet_big]
    sampler_steps=333, # number of sampler steps on the diffusion process
    num_frames=4, # number of frames to use as input,
    img_size=64, # image size to use
    num_random_experiments = 2, # we will perform inference multiple times on the same inputs
    seed=42,
    device="cuda" if torch.cuda.is_available() else "cpu",
    # device="mps",
    sampler="ddim",
    future_frames=10,  # number of future frames
    bs=8, # how many samples
)

## Examine data preparation steps for [inference](https://github.com/tcapelle/cloud_diffusion/blob/master/inference.py)
### from Line 56 'def prepare_data(self)':

### use cloud_diffusion download dataset function

In [17]:
files = download_dataset(DATASET_ARTIFACT, PROJECT_NAME)

[34m[1mwandb[0m: Downloading large artifact np_dataset:v0, 3816.62MB. 30 files... 
[34m[1mwandb[0m:   30 of 30 files downloaded.  
Done. 0:0:2.6


### files = list of file paths to downloaded numpy arrays

In [18]:
files

[Path('/bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/np_dataset:v0/B07_2022_06_01.npy'),
 Path('/bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/np_dataset:v0/B07_2022_06_02.npy'),
 Path('/bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/np_dataset:v0/B07_2022_06_03.npy'),
 Path('/bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/np_dataset:v0/B07_2022_06_04.npy'),
 Path('/bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/np_dataset:v0/B07_2022_06_05.npy'),
 Path('/bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/np_dataset:v0/B07_2022_06_06.npy'),
 Path('/bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/np_dataset:v0/B07_2022_06_07.npy'),
 Path('/bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/np_dataset:v0/B07_2022_06_08.npy'),
 Path('/bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/np_dataset:v0/B07_2022_06_09

## Extract three days' worth of validation data from the full dataset

In [19]:
valid_ds = CloudDataset(files=files[-3:], # 3 days of validation data 
                                num_frames=config.num_frames, img_size=config.img_size)

## Each item of the validation dataset is a tensor with shape [4, 64, 64]

In [20]:
valid_ds[0]

tensor([[[-0.2148, -0.2150, -0.2152,  ..., -0.1546, -0.1568, -0.1596],
         [-0.2171, -0.2172, -0.2176,  ..., -0.1672, -0.1678, -0.1722],
         [-0.2128, -0.2161, -0.2200,  ..., -0.1913, -0.1936, -0.1970],
         ...,
         [ 0.0137,  0.0527, -0.0090,  ..., -0.2838, -0.2498, -0.2445],
         [ 0.0307,  0.0268,  0.0024,  ..., -0.3208, -0.2935, -0.2796],
         [ 0.0470,  0.0357,  0.0557,  ..., -0.3338, -0.3256, -0.2947]],

        [[-0.2133, -0.2139, -0.2168,  ..., -0.1524, -0.1532, -0.1519],
         [-0.2143, -0.2176, -0.2187,  ..., -0.1621, -0.1595, -0.1598],
         [-0.2139, -0.2173, -0.2168,  ..., -0.1801, -0.1788, -0.1849],
         ...,
         [ 0.0360,  0.1036,  0.0661,  ..., -0.2281, -0.3011, -0.3177],
         [ 0.1195,  0.1047,  0.0782,  ..., -0.2324, -0.2912, -0.3200],
         [ 0.0910,  0.0788,  0.0494,  ..., -0.2922, -0.2995, -0.3123]],

        [[-0.2136, -0.2103, -0.2142,  ..., -0.1571, -0.1563, -0.1579],
         [-0.2159, -0.2149, -0.2160,  ..., -0

In [21]:
valid_ds[0].shape

torch.Size([4, 64, 64])

In [30]:
valid_ds

<cloud_diffusion.dataset.CloudDataset at 0x1503e5842920>

## randomly select some samples from the validation dataset

In [72]:
idxs = random.choices(range(len(valid_ds) - config.future_frames), k=config.bs)  # select some samples

In [73]:
idxs

[113, 7, 58, 135, 7, 53, 174, 146]

In [None]:
# fix the batch to the same samples for reproducibility
batch = valid_ds[idxs].to(config.device)
batch

In [26]:
batch.shape

torch.Size([8, 4, 64, 64])

In [76]:
vid_sequence = valid_ds[idxs[0]:idxs[0]+4+config.future_frames,0,...]
vid_sequence.shape

torch.Size([14, 64, 64])

## Prepare Our Data

In [8]:
from cloudcasting.dataset import SatelliteDataset
from cloudcasting.constants import DATA_INTERVAL_SPACING_MINUTES

# Zarr store holding the training satellite data (absolute path on this machine).
TRAINING_DATA_PATH = "/bask/projects/v/vjgo8416-climate/shared/data/eumetsat/training/2022_training_nonhrv.zarr"
# Number of history frames; converted to minutes below as (HISTORY_STEPS - 1) intervals.
HISTORY_STEPS = 3

# Instantiate the torch dataset object
# NOTE: this rebinds `dataset`, which earlier in the notebook held a list of numpy arrays.
dataset = SatelliteDataset(
    zarr_path=TRAINING_DATA_PATH,
    start_time="2022-01-31",
    end_time=None,
    history_mins=(HISTORY_STEPS - 1) * DATA_INTERVAL_SPACING_MINUTES,
    forecast_mins=15,
    sample_freq_mins=15,
    nan_to_num=True,
)

In [9]:
import torch
import numpy as np
import torchvision.transforms as T

class SatelliteDataset2(SatelliteDataset):
    """SatelliteDataset variant that returns one resized, randomly cropped tensor per sample.

    Args:
        *args, **kwargs: forwarded unchanged to ``SatelliteDataset``.
        img_size: side length of the square crop (default 64).
    """

    def __init__(self, *args, img_size=64, **kwargs):
        super().__init__(*args, **kwargs)
        # 614/372 is presumably the native width/height ratio of the imagery -- TODO confirm.
        resized_width = int(img_size*614/372)
        # Resize to (img_size, resized_width), then take a random img_size x img_size crop.
        # NOTE(review): the crop position is drawn on every access, so repeated or
        # neighbouring indices are not guaranteed to share the same spatial window.
        self.tfms = T.Compose([
            T.Resize((img_size, resized_width)),
            T.RandomCrop(img_size),
        ])

    def __getitem__(self, idx):
        """Return ``0.5 - tfms(x)`` for the sample at ``idx``.

        The arrays returned by the parent ``__getitem__`` are concatenated along
        axis -3 and index -3 of the leading axis is kept (assumes the leading
        axis is the channel axis -- TODO confirm against SatelliteDataset).
        """
        parts = super().__getitem__(idx)
        joined = np.concatenate(parts, axis=-3)
        selected = torch.from_numpy(joined[-3])
        return 0.5-self.tfms(selected)

In [10]:
# Same arguments as the plain SatelliteDataset above; img_size is left at its default (64).
dataset2 = SatelliteDataset2(
    zarr_path=TRAINING_DATA_PATH,
    start_time="2022-01-31",
    end_time=None,
    history_mins=(HISTORY_STEPS - 1) * DATA_INTERVAL_SPACING_MINUTES,
    forecast_mins=15,
    sample_freq_mins=15,
    nan_to_num=True,
)

In [31]:
dataset2[0]

tensor([[[0.4760, 0.4760, 0.4760,  ..., 0.4763, 0.4764, 0.4763],
         [0.4760, 0.4760, 0.4760,  ..., 0.4761, 0.4763, 0.4761],
         [0.4760, 0.4760, 0.4760,  ..., 0.4764, 0.4764, 0.4763],
         ...,
         [0.4760, 0.4760, 0.4760,  ..., 0.4762, 0.4764, 0.4765],
         [0.4760, 0.4760, 0.4760,  ..., 0.4773, 0.4973, 0.5646],
         [0.5945, 0.5945, 0.5945,  ..., 0.8918, 1.0996, 1.2834]],

        [[0.4765, 0.4765, 0.4765,  ..., 0.4765, 0.4765, 0.4765],
         [0.4765, 0.4765, 0.4765,  ..., 0.4765, 0.4765, 0.4765],
         [0.4764, 0.4764, 0.4765,  ..., 0.4765, 0.4765, 0.4765],
         ...,
         [0.4765, 0.4765, 0.4765,  ..., 0.4765, 0.4765, 0.4767],
         [0.4764, 0.4765, 0.4764,  ..., 0.4779, 0.4974, 0.5647],
         [0.5947, 0.5949, 0.5948,  ..., 0.8923, 1.0997, 1.2834]],

        [[0.4765, 0.4765, 0.4765,  ..., 0.4765, 0.4765, 0.4765],
         [0.4765, 0.4765, 0.4765,  ..., 0.4765, 0.4765, 0.4765],
         [0.4765, 0.4765, 0.4765,  ..., 0.4765, 0.4765, 0.

In [28]:
dataset2[0].shape

torch.Size([4, 64, 64])

In [77]:
idxs = random.choices(range(len(dataset2) - config.future_frames), k=config.bs)  # select some samples
idxs

[2683, 7173, 9853, 79, 9809, 8498, 4141, 1892]

In [None]:
# fix the batch to the same samples for reproducibility
# Built from every entry of `idxs`, so it keeps working when config.bs is not 8.
batch = torch.stack([dataset2[i] for i in idxs]).to(config.device)
batch

In [53]:
batch.shape

torch.Size([8, 4, 64, 64])

In [63]:
dataset2[idxs[0]][0]

tensor([[0.4765, 0.4765, 0.4765,  ..., 0.4765, 0.4765, 0.4765],
        [0.4765, 0.4765, 0.4765,  ..., 0.4765, 0.4765, 0.4765],
        [0.4765, 0.4765, 0.4765,  ..., 0.4765, 0.4765, 0.4765],
        ...,
        [1.3331, 1.3331, 1.3331,  ..., 1.4513, 1.4955, 1.5000],
        [1.5000, 1.5000, 1.5000,  ..., 1.5000, 1.5000, 1.5000],
        [1.5000, 1.5000, 1.5000,  ..., 1.5000, 1.5000, 1.5000]])

In [78]:
idx = idxs[0]

In [79]:
# Sequence made of frame 0 of each consecutive sample starting at `idx`.
# Length is the input frames plus the forecast horizon (4 + 10 = 14 with the current config).
n_seq_frames = config.num_frames + config.future_frames
vid_sequence2 = torch.stack([dataset2[idx + offset][0] for offset in range(n_seq_frames)])
vid_sequence2.shape

torch.Size([14, 64, 64])

In [82]:
vid_sequence[0].shape

torch.Size([64, 64])

In [87]:
from pathlib import Path

import torch
import wandb
import numpy as np

## For Training

def to_wandb_image(img):
    "Build a wandb.Image from a tensor whose leading-dim slices are placed side by side"
    slices = img.split(1)
    tiled = torch.cat(slices, dim=-1)
    return wandb.Image(tiled.cpu().numpy())

def log_images(xt, samples):
    "Send sampled images to wandb under the key `sampled_images`"
    # Everything in xt except the last entry along dim 1, moved to the samples' device.
    context = xt[:, :-1, ...].to(samples.device)
    frames = torch.cat([context, samples], dim=1)
    images = [to_wandb_image(frame) for frame in frames]
    wandb.log({"sampled_images": images})

def save_model(model, model_name):
    """Save the model's state dict locally and log it to wandb as a model artifact.

    The file is written to ``models/<run id>_<model_name>.pth``; an active
    wandb run is required because the run id is part of the name.

    Args:
        model: object exposing ``state_dict()``.
        model_name: base name; it is prefixed with the current wandb run id.
    """
    model_name = f"{wandb.run.id}_{model_name}"
    models_folder = Path("models")
    # exist_ok avoids the check-then-create race of `if not exists(): mkdir()`.
    models_folder.mkdir(parents=True, exist_ok=True)
    # Single source of truth for the file location, used for both save and upload.
    weights_path = models_folder / f"{model_name}.pth"
    torch.save(model.state_dict(), weights_path)
    at = wandb.Artifact(model_name, type="model")
    at.add_file(str(weights_path))
    wandb.log_artifact(at)


## For Inference
def htile(img):
    "Place the leading-dim slices of `img` next to each other along the last dim."
    pieces = torch.split(img, 1, dim=0)
    return torch.cat(pieces, dim=-1)

def vtile(img):
    "Stack the leading-dim slices of `img` on top of each other along the second-to-last dim."
    pieces = torch.split(img, 1, dim=0)
    return torch.cat(pieces, dim=-2)

def vhtile(*imgs):
    "Tile each batch horizontally into a row, then stack the rows vertically."
    rows = [htile(batch) for batch in imgs]
    stacked_rows = torch.cat(rows, dim=0)
    return vtile(stacked_rows)

def scale(arr):
    """Min-max scale the values of ``arr`` into [0, 1].

    Args:
        arr: array/tensor exposing ``min()`` and ``max()`` and supporting
            element-wise arithmetic.

    Returns:
        ``(arr - min) / (max - min)``. When every value is identical the range
        is zero; the result is then all zeros instead of NaN from a 0/0 division.
    """
    m, M = arr.min(), arr.max()
    span = M - m
    if span == 0:
        # Constant input: divide by 1 so the output is zeros with the usual dtype.
        span = 1
    return (arr - m) / span

def preprocess_frames(data):
    "Turn `data` into a list of uint8 numpy frames suitable for wandb.Video"
    # Drop size-1 dims, then rescale globally with `scale` before stretching to 0..255.
    normalized = scale(data.squeeze())
    frames = []
    for frame in normalized:
        stretched = 255 * frame
        frames.append(stretched.cpu().numpy().astype(np.uint8))
    return frames

In [92]:
def make_video(data):
    "Create a wandb.Video from `data`, repeating each frame over 3 channels."
    stacked = np.stack(preprocess_frames(data))
    # Insert a channel axis after the frame axis and replicate it three times.
    three_channel = np.repeat(stacked[:, None, ...], 3, axis=1)
    return wandb.Video(three_channel)

In [93]:
import random, logging
from types import SimpleNamespace

import torch, wandb
from fastprogress import progress_bar

from cloud_diffusion.dataset import download_dataset, CloudDataset
from cloud_diffusion.ddpm import ddim_sampler
from cloud_diffusion.models import UNet2D, get_unet_params
from cloud_diffusion.utils import parse_args, set_seed
from cloud_diffusion.wandb import to_video, vhtile

logging.basicConfig(format='%(asctime)s - %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

PROJECT_NAME = "ddpm_clouds"
JOB_TYPE = "inference"
MODEL_ARTIFACT = "capecape/ddpm_clouds/esezp3jh_unet_small:v0"  # small model

config = SimpleNamespace(
    model_name="unet_small", # model name to save [unet_small, unet_big]
    sampler_steps=333, # number of sampler steps on the diffusion process
    num_frames=4, # number of frames to use as input,
    img_size=64, # image size to use
    num_random_experiments = 2, # we will perform inference multiple times on the same inputs
    seed=42,
    device="cuda" if torch.cuda.is_available() else "cpu",
    # device="mps",
    sampler="ddim",
    future_frames=10,  # number of future frames
    bs=8, # how many samples
)

class Inference:
    """Run autoregressive forecasting with a pretrained UNet and log the results to wandb.

    Args:
        config: namespace providing ``seed``, ``sampler_steps``, ``model_name``,
            ``num_frames``, ``future_frames``, ``bs``, ``device`` and
            ``num_random_experiments``.
        valid_ds: optional indexable dataset to sample from. When omitted, the
            notebook-level ``dataset2`` is used (the original behaviour).
    """

    def __init__(self, config, valid_ds=None):
        self.config = config
        self.valid_ds = valid_ds
        set_seed(config.seed)

        # create a batch of data to use for inference
        self.prepare_data()

        # we default to ddim as it's faster and as good as ddpm
        self.sampler = ddim_sampler(config.sampler_steps)

        # create the Unet
        model_params = get_unet_params(config.model_name, config.num_frames)

        logger.info(f"Loading model {config.model_name} from artifact: {MODEL_ARTIFACT}")
        self.model = UNet2D.from_artifact(model_params, MODEL_ARTIFACT).to(config.device)

        self.model.eval()

    def _sequence_length(self):
        "Frames in one full sequence: the input frames plus the forecast horizon."
        return self.config.num_frames + self.config.future_frames

    def _ground_truth_frames(self, idx):
        "Stack frame 0 of each consecutive dataset sample, starting at `idx`."
        # NOTE(review): if the dataset crops randomly on every access (as
        # SatelliteDataset2 in this notebook does), these frames are cropped
        # independently of each other and of the batch inputs -- confirm intent.
        return torch.stack(
            [self.valid_ds[idx + offset][0] for offset in range(self._sequence_length())]
        )

    def prepare_data(self):
        """Pick ``config.bs`` random start indices and build the input batch.

        Sets ``self.idxs`` (the sampled indices) and ``self.batch`` (the stacked
        samples at those indices, moved to ``config.device``).
        """
        cfg = self.config
        if getattr(self, "valid_ds", None) is None:
            self.valid_ds = dataset2

        # A start index must leave room for the whole ground-truth sequence read
        # in `log_to_wandb` (idx .. idx + sequence_length - 1), not just for the
        # future frames; otherwise the last reads can run past the dataset end.
        n_starts = len(self.valid_ds) - self._sequence_length() + 1
        self.idxs = random.choices(range(n_starts), k=cfg.bs)  # select some samples

        # Build the batch from the indices sampled just above (not from a stale
        # notebook-level `idxs`), so the batch and the logged ground truth match.
        self.batch = torch.stack([self.valid_ds[i] for i in self.idxs]).to(cfg.device)

    def sample_more(self, frames, future_frames=1):
        "Autoregressive sampling, starting from `frames`. It is hardcoded to work with 3 frame inputs."
        for _ in progress_bar(range(future_frames), total=future_frames, leave=True):
            # compute new frame with previous 3 frames
            new_frame = self.sampler(self.model, frames[:,-3:,...])
            # add new frame to the sequence
            frames = torch.cat([frames, new_frame.to(frames.device)], dim=1)
        return frames.cpu()

    def forecast(self):
        """Perform inference on the batch of data.

        Returns:
            A list with ``config.num_random_experiments`` entries, each the
            result of ``sample_more`` on ``self.batch`` (returned on the CPU).
        """
        logger.info(f"Forecasting {self.batch.shape[0]} samples for {self.config.future_frames} future frames.")
        sequences = []
        for i in range(self.config.num_random_experiments):
            logger.info(f"Generating {i+1}/{self.config.num_random_experiments} futures.")
            frames = self.sample_more(self.batch, self.config.future_frames)
            sequences.append(frames)

        return sequences

    def log_to_wandb(self, sequences):
        "Create a table with the ground truth and the generated frames. Log it to wandb."
        cfg = self.config
        gen_columns = [f"gen_{i}" for i in range(cfg.num_random_experiments)]
        table = wandb.Table(columns=["id", "gt", *gen_columns, "gt/gen"])
        for i, idx in enumerate(self.idxs):
            # Build the ground-truth stack once and reuse it for video and image.
            gt_frames = self._ground_truth_frames(idx)
            preds = [frames[i] for frames in sequences]
            gt_vid = make_video(gt_frames)
            pred_vids = [make_video(pred) for pred in preds]
            gt_gen = wandb.Image(vhtile(gt_frames, *preds))
            table.add_data(idx, gt_vid, *pred_vids, gt_gen)
        logger.info("Logging results to wandb...")
        wandb.log({f"gen_table_{cfg.future_frames}_random":table})

In [94]:
# Seed before starting the run (Inference.__init__ seeds again with the same value).
set_seed(config.seed)

# Everything inside the `with` block is attributed to a single wandb run of type JOB_TYPE.
with wandb.init(project=PROJECT_NAME, job_type=JOB_TYPE, 
                config=config, tags=["ddpm", config.model_name]):
    infer = Inference(config)
    sequences = infer.forecast()
    infer.log_to_wandb(sequences)

2024-10-24 17:03:02,034 - INFO - Loading model unet_small from artifact: capecape/ddpm_clouds/esezp3jh_unet_small:v0
[34m[1mwandb[0m:   1 of 1 files downloaded.  
2024-10-24 17:03:03,546 - INFO - Forecasting 8 samples for 10 future frames.
2024-10-24 17:03:03,548 - INFO - Generating 1/2 futures.


Loading model from: /bask/projects/v/vjgo8416-climate/users/lwcf1795/diffusion/artifacts/esezp3jh_unet_small:v0/esezp3jh_unet_small.pth


2024-10-24 17:03:58,953 - INFO - Generating 2/2 futures.


2024-10-24 17:05:43,402 - INFO - Logging results to wandb...
