In [1]:
from types import SimpleNamespace

import torch

from cloud_diffusion.dataset import download_dataset, CloudDataset
from cloud_diffusion.ddpm import noisify_ddpm


PROJECT_NAME = "ddpm_clouds"
DATASET_ARTIFACT = "capecape/gtc/np_dataset:v0"

config = SimpleNamespace(
    epochs=50,  # number of epochs
    model_name="unet_small",  # model name to save [unet_small, unet_big]
    strategy="ddpm",  # strategy to use ddpm
    noise_steps=1000,  # number of noise steps on the diffusion process
    sampler_steps=333,  # number of sampler steps on the diffusion process
    seed=42,  # random seed
    batch_size=128,  # batch size
    img_size=64,  # image size
    device="cuda",  # device
    num_workers=8,  # number of workers for dataloader
    num_frames=4,  # number of frames to use as input
    lr=5e-4,  # learning rate
    validation_days=3,  # number of days to use for validation
    log_every_epoch=5,  # log every n epochs to wandb
    n_preds=8,  # number of predictions to make
)

# downlaod the dataset from the wandb.Artifact
files = download_dataset(DATASET_ARTIFACT, PROJECT_NAME)
train_days, valid_days = files[: -config.validation_days], files[-config.validation_days :]
train_ds = CloudDataset(files=train_days, num_frames=config.num_frames, img_size=config.img_size)
valid_ds = CloudDataset(files=valid_days, num_frames=config.num_frames, img_size=config.img_size).shuffle()

next(iter(valid_ds))

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mphinate[0m ([33mmanchester_prize[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact np_dataset:v0, 3816.62MB. 30 files... 
[34m[1mwandb[0m:   30 of 30 files downloaded.  
Done. 0:6:44.4


tensor([[[ 1.6665e-02,  6.3076e-02,  5.8489e-02,  ..., -2.7529e-01,
          -2.9290e-01, -3.0397e-01],
         [ 1.6778e-04,  5.4600e-02,  5.6731e-02,  ..., -2.9674e-01,
          -3.0536e-01, -3.1290e-01],
         [ 1.9444e-02,  4.4553e-02,  3.7216e-02,  ..., -3.0977e-01,
          -3.2065e-01, -3.2652e-01],
         ...,
         [-3.3933e-01, -3.2881e-01, -2.5199e-01,  ..., -3.1279e-01,
          -3.1407e-01, -3.1473e-01],
         [-3.3650e-01, -3.2294e-01, -2.7958e-01,  ..., -3.1458e-01,
          -3.1538e-01, -3.1643e-01],
         [-3.2120e-01, -2.8458e-01, -2.8682e-01,  ..., -3.1503e-01,
          -3.1591e-01, -3.1746e-01]],

        [[ 3.1216e-02,  2.7759e-02,  3.9924e-02,  ..., -2.5125e-01,
          -2.8785e-01, -2.9981e-01],
         [ 3.0683e-02,  3.7654e-02,  3.7798e-02,  ..., -2.2979e-01,
          -2.9156e-01, -3.0760e-01],
         [ 2.6864e-02,  3.3317e-02,  2.5960e-02,  ..., -1.8531e-01,
          -2.6132e-01, -3.0356e-01],
         ...,
         [-3.2504e-01, -3

In [4]:
train_ds = CloudDataset(files=train_days, num_frames=config.num_frames, img_size=446)

In [5]:
next(iter(train_ds)).shape

torch.Size([4, 446, 446])

In [None]:
import numpy as np
import torchvision.transforms as T
from cloudcasting.constants import IMAGE_SIZE_TUPLE


class CloudcastingDataset(SatelliteDataset):
    def __init__(self, img_size, valid=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        tfms = [T.Resize((img_size, int(img_size * (IMAGE_SIZE_TUPLE[1] / IMAGE_SIZE_TUPLE[0]))))] if img_size is not None else []
        tfms += [T.RandomCrop(img_size)] if not valid else [T.CenterCrop(img_size)]
        self.tfms = T.Compose(tfms)

    def __getitem__(self, idx: int):
        # concatenate future prediction and previous frames along time axis
        concat_data = np.concatenate(super().__getitem__(idx), axis=-3)
        # data is in [0,1] range, normalize to [-0.5, 0.5]
        # note that -1s could be NaNs, which are now at +1.5
        # output has shape (11, history_steps + forecast_horizon, height, width)
        return 0.5 - self.tfms(torch.from_numpy(concat_data))

In [None]:


class CloudDataset(SatelliteDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, idx):
        return torch.cat(self.data[idx], dim=-3)

In [7]:


x = np.ones((2, 3, 4, 5, 6))

x[0][0].shape

(4, 5, 6)

In [8]:
x[0, 0].shape

(4, 5, 6)

In [9]:
x[0, 0, ...].shape

(4, 5, 6)

In [14]:
x[:, None, ...].shape

(2, 1, 3, 4, 5, 6)

In [61]:
dummy_data = torch.randn(10, 11, 4, 64, 64, requires_grad=False)

betamin, betamax, n_steps = 0.0001, 0.02, 1000
beta = torch.linspace(betamin, betamax, n_steps)
alpha = 1.0 - beta
alphabar = alpha.cumprod(dim=0)
sigma = beta.sqrt()


def noisify_ddpm(x0):
    "Noise by ddpm"
    device = x0.device
    n = len(x0)
    t = torch.randint(0, n_steps, (n,), dtype=torch.long)
    ε = torch.randn(x0.shape, device=device)
    ᾱ_t = alphabar[t].reshape(-1, 1, 1, 1).to(device)
    xt = ᾱ_t.sqrt() * x0 + (1 - ᾱ_t).sqrt() * ε
    return xt, t.to(device), ε


from torch import vmap


def noisify_last_frame_channels(frames, noise_func):
    "Noisify the last frame of a sequence. Inputs have shape (batch, channels, time, height, width)."
    past_frames = frames[:, :, :-1]
    last_frame = frames[:, :, -1:]

    # vmap over channels (dim=1) -- idk why output dim = 1 doesn't work, but this does!
    # our out dims are (channels, batch, time, height, width), so we rejig later
    # the None will just not vmap over the returned diffusion step counts (called t)
    channel_noisify = vmap(noise_func, in_dims=1, out_dims=(0, None, 0), randomness="same")
    noise, t, e = channel_noisify(last_frame)

    # reshape to (batch, channels, time, height, width) ready for diffusion model, both for noise and e
    # leave channels intact for now
    noise = torch.swapaxes(noise, 0, 1)
    history_and_noisy_target = torch.cat([past_frames, noise], dim=2)
    history_and_noisy_target = history_and_noisy_target.view(
        history_and_noisy_target.shape[0],
        history_and_noisy_target.shape[1] * history_and_noisy_target.shape[2],  # collapse channels and time
        history_and_noisy_target.shape[3],
        history_and_noisy_target.shape[4],
    )

    e = torch.swapaxes(e, 0, 1)
    e = e.view(e.shape[0], e.shape[1] * e.shape[2], e.shape[3], e.shape[4])

    return history_and_noisy_target, t, e

In [62]:
noisify_last_frame_channels(dummy_data, lambda x: (0 * x, 1, 0 * x))[0][0, :, 0, 0]

tensor([ 0.5776,  0.6119, -0.5367,  0.0000, -0.1512,  0.9278,  0.7324, -0.0000,
        -0.2882,  2.1622,  2.3457,  0.0000, -0.9894, -1.6108,  0.8110,  0.0000,
        -1.8311, -0.9710, -0.7868,  0.0000,  0.0486, -1.0543,  1.3018, -0.0000,
         0.9455, -0.1048, -1.1720,  0.0000, -0.3532, -0.3649, -1.0455, -0.0000,
         0.2589, -0.5421,  0.4593,  0.0000, -0.9825,  0.3658, -1.3455,  0.0000,
         1.2885, -0.3904, -0.8328,  0.0000])

In [36]:
dummy_data[:, -1:].shape

torch.Size([10, 1, 1, 64, 64])