In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch import nn
import random, os

from cloudcasting.dataset import SatelliteDataset


def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random`` module, NumPy's legacy global RNG and PyTorch's
    CPU / CUDA / MPS generators, and switches cuDNN to deterministic kernels.

    Note: assigning ``PYTHONHASHSEED`` at runtime does not change the hash seed of
    the interpreter that is already running; it only affects child processes
    started afterwards.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)

    # PyTorch: default (CPU) generator, then the CUDA generator.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Prefer reproducible cuDNN kernel selection over autotuned (faster) kernels.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Apple-silicon generator (returns without effect when MPS is unavailable).
    torch.mps.manual_seed(seed)
    
# Fix all RNGs up front so the rest of the notebook is reproducible.
seed_everything(seed=42)



In [26]:
from diffusers import DDPMScheduler, UNet2DModel
from cloudcasting.constants import NUM_CHANNELS, IMAGE_SIZE_TUPLE, NUM_FORECAST_STEPS


class ConditionedUnet(nn.Module):
    """UNet that predicts satellite imagery conditioned on previous timesteps.

    Conditioning is done by stacking the ``history_steps`` input frames along the
    channel axis and feeding the result to a plain (unconditional) ``UNet2DModel``
    with ``NUM_CHANNELS * history_steps`` input channels and ``NUM_CHANNELS``
    output channels.

    Args:
        image_size: ``sample_size`` passed to ``UNet2DModel``. This should be the
            cropped spatial size produced by ``forward``, i.e. each input dimension
            rounded down to a multiple of 16.
        history_steps: number of input timesteps used as conditioning (>= 1).

    Raises:
        ValueError: if ``history_steps`` is less than 1.
    """

    history_steps: int

    def __init__(self, image_size, history_steps=1):
        super().__init__()

        if history_steps < 1:
            raise ValueError(f"history_steps must be >= 1, got {history_steps}")

        self.history_steps = history_steps

        # An unconditional UNet whose extra input channels carry the conditioning
        # frames: one group of NUM_CHANNELS channels per history step.
        self.model = UNet2DModel(
            sample_size=image_size,  # the target image resolution
            in_channels=NUM_CHANNELS * history_steps,
            out_channels=NUM_CHANNELS,  # the number of output channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            block_out_channels=(128, 256, 512),
            # Plain ResNet blocks only. Swap in "AttnDownBlock2D" / "AttnUpBlock2D"
            # to add spatial self-attention at a given resolution.
            down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
            up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"),
            # Also disables the attention layer in the UNet mid block.
            add_attention=False,
        )

    def forward(self, x, t, rollout_steps=NUM_FORECAST_STEPS):
        """Run the UNet on a stack of conditioning frames.

        Args:
            x: tensor whose trailing four dims are
                ``(NUM_CHANNELS, history_steps, *IMAGE_SIZE_TUPLE)``; an optional
                leading batch dimension is allowed.
            t: diffusion timestep(s), forwarded unchanged to ``UNet2DModel``.
            rollout_steps: currently unused; kept for interface compatibility.

        Returns:
            The UNet ``.sample`` output, shaped ``(batch, NUM_CHANNELS, H', W')``
            where ``H'``/``W'`` are the input height/width rounded down to a
            multiple of 16. Unbatched input yields a batch dimension of 1.

        Raises:
            ValueError: if the trailing dims of ``x`` do not match the expected shape.
        """
        expected_shape = (NUM_CHANNELS, self.history_steps, *IMAGE_SIZE_TUPLE)
        if tuple(x.shape[-4:]) != expected_shape:
            raise ValueError(
                f"Input shape {tuple(x.shape)} does not match expected trailing shape "
                f"{expected_shape} (channels, history_steps, height, width)"
            )

        # Crop height/width down to a multiple of 16 so the UNet's down/upsampling
        # path produces feature maps that line up with the skip connections.
        height, width = x.shape[-2:]
        x_cropped = x[..., : (height // 16) * 16, : (width // 16) * 16]

        # Merge channel and time axes: (..., C, T, H, W) -> (batch, C*T, H, W).
        # Negative indexing lets unbatched input through (it gains a batch dim of 1).
        net_input = x_cropped.reshape(
            -1, x_cropped.shape[-4] * x_cropped.shape[-3], *x_cropped.shape[-2:]
        )

        return self.model(net_input, t).sample
    


In [27]:
from cloudcasting.constants import DATA_INTERVAL_SPACING_MINUTES

# Location of the training zarr store. Set the CLOUDCAST_TRAINING_DATA environment
# variable to run this notebook on a machine other than the original author's.
TRAINING_DATA_PATH = os.environ.get(
    "CLOUDCAST_TRAINING_DATA",
    "/Users/nsimpson/code/climetrend/cloudcast/2020_training_nonhrv.zarr",
)
# Number of input frames stacked as conditioning for the model.
HISTORY_STEPS = 1

# Instantiate the torch dataset object.
# NOTE(review): history_mins presumably counts minutes *before* the current frame,
# hence (HISTORY_STEPS - 1) intervals for HISTORY_STEPS frames — confirm against
# SatelliteDataset.
dataset = SatelliteDataset(
    zarr_path=TRAINING_DATA_PATH,
    start_time=None,
    end_time=None,
    history_mins=(HISTORY_STEPS - 1) * DATA_INTERVAL_SPACING_MINUTES,
    forecast_mins=180,
    sample_freq_mins=15,
    nan_to_num=True,
)

In [None]:
# Number of complete samples in the dataset.
# This counts overlapping windows, not completely distinct time periods.
n_samples = len(dataset)
print(n_samples)

# Share of NaN values in the underlying data, expressed as a percentage
# (np.mean of the boolean mask is a fraction in [0, 1], so scale by 100).
nan_percentage = 100 * np.mean(np.isnan(dataset.ds)).compute()
print(f"NaN percentage: {nan_percentage}")

In [29]:
# Small batch: this notebook only pulls a single batch to check the model wiring.
batch_size = 2
# Load in the main process. If this is raised above 0, pass a seeded
# `worker_init_fn` and `generator` to DataLoader so worker RNGs stay
# reproducible (see the PyTorch reproducibility notes).
num_workers = 0

dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    num_workers=num_workers,
)


In [None]:
# Pull a single batch so the tensor shapes can be inspected.
batch_iterator = iter(dataloader)
X, y = next(batch_iterator)
del batch_iterator

X.shape


In [31]:
# Build the model for the cropped spatial size that ConditionedUnet.forward feeds
# to the UNet: height and width each rounded down to a multiple of 16.
full_height, full_width = X.shape[-2:]
x_cropped_shape = [full_height // 16 * 16, full_width // 16 * 16]
model = ConditionedUnet(x_cropped_shape, history_steps=HISTORY_STEPS)

In [32]:
# Choose an accelerator instead of hard-coding Apple's MPS backend, so the notebook
# also runs on CUDA or CPU-only machines. MPS is tried first to keep the original
# behaviour on Apple silicon.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

X = X.to(DEVICE)
model = model.to(DEVICE)

In [None]:
# Smoke-test one forward pass at diffusion timestep 3. no_grad skips building an
# autograd graph we never backpropagate through. Only the output shape is shown,
# instead of dumping the whole tensor.
with torch.no_grad():
    unet_output = model(X, 3)

unet_output.shape