In [1]:
from typing import Tuple
import torch
from torch import nn

class TemporalEncoder(nn.Module):
    """Encode a short sequence of frames with a stacked GRU.

    Every frame is flattened to a vector of ``input_size[0] * input_size[1]``
    values and the frames are fed, in order, to a GRU with ``num_images``
    stacked layers. The final hidden state of every layer is reshaped back
    into an image, so the output holds ``num_images`` maps per sample.

    Parameters
    ----------
    input_size : Tuple[int, int]
        Height and width of a single frame.
    num_images : int
        Number of stacked GRU layers, and therefore of output maps.
    device : str
        Device the module is moved to on construction.
    """
    def __init__(
        self, input_size: Tuple[int, int], num_images: int, device: str
        ) -> None:
        super().__init__()
        # Set the input size of the image.
        self.input_size = input_size
        # Set the size of the flattened image.
        self.flatten_size = input_size[0] * input_size[1]
        # A single GRU with one stacked layer per image.
        self.gru = nn.GRU(
            self.flatten_size, self.flatten_size, num_layers=num_images,
            batch_first=True)
        # Set the device used for the computations.
        self.device = device
        self.to(device)

    def to(self, *args, **kwargs) -> "TemporalEncoder":
        """Move and/or cast the module, keeping ``self.device`` up to date.

        Accepts the same arguments as ``nn.Module.to`` and, like it, returns
        ``self`` so that calls such as ``model = TemporalEncoder(...).to(dev)``
        keep working. ``self.device`` is only updated when a device is passed
        explicitly (positionally or as ``device=``).
        """
        super().to(*args, **kwargs)
        device = kwargs.get("device")
        if device is None and args and isinstance(args[0], (str, torch.device)):
            device = args[0]
        if device is not None:
            self.device = device
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the frames through the GRU and return its final hidden states.

        Parameters
        ----------
        x : torch.Tensor
            Batch of frame sequences of shape ``(batch, frames, H, W)`` with
            ``(H, W) == input_size``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch, num_images, H, W)``: for every sample,
            the final hidden state of each GRU layer reshaped as an image.
        """
        batch_size = x.shape[0]
        num_layers = self.gru.num_layers
        # nn.GRU expects h_0 as (num_layers, batch, hidden) even when
        # batch_first=True. It is created on the input's device so it cannot
        # drift from a stale `self.device`.
        initial_hidden_state = torch.zeros(
            num_layers, batch_size, self.flatten_size, dtype=x.dtype,
            device=x.device)

        _, hidden = self.gru(x.flatten(start_dim=2), initial_hidden_state)

        # h_n is (num_layers, batch, hidden): bring the batch axis to the front
        # before reshaping, otherwise states of different samples get mixed.
        out = hidden.transpose(0, 1).reshape(
            batch_size, num_layers, self.input_size[0], self.input_size[1])
        return out

In [2]:
import numpy as np

# Load one sample file from the dataset artifact to inspect its raw layout.
# NOTE(review): this path presumably only exists once the `np_dataset:v0`
# artifact has been downloaded (done further down via `download_dataset`), so
# on a fresh kernel this cell may fail — confirm and consider moving it after
# the download cell.
img_ex = np.load('./artifacts/np_dataset-v0/B07_2022_06_02.npy')

In [3]:
# Inspect the raw layout of the loaded file (presumably
# frames x channel x height x width — TODO confirm the axis meaning).
img_ex.shape

(96, 1, 446, 780)

In [4]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 64x64 frames are flattened to 4096 features inside the encoder and
# num_images=3 sets the number of stacked GRU layers.
# NOTE(review): a 3-layer 4096 -> 4096 GRU is large (roughly 300M parameters);
# confirm this fits the target device.
model = TemporalEncoder(input_size=(64, 64), num_images=3, device=device)

In [5]:
from pathlib import Path

def ls(path: Path):
    "List the entries directly inside `path`, in sorted order"
    entries = path.iterdir()
    return sorted(entries)

In [6]:
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as T
import wandb
from tqdm import tqdm as progress_bar

# from cloud_diffusion.utils import ls

PROJECT_NAME = "ddpm_clouds"
DATASET_ARTIFACT = 'capecape/gtc/np_dataset:v0'

class DummyNextFrameDataset:
    "Stand-in dataset whose samples are freshly drawn Gaussian noise"
    def __init__(self, num_frames=4, img_size=64, N=1000):
        self.N = N
        self.num_frames = num_frames
        self.img_size = img_size

    def __len__(self):
        "Nominal number of samples, as given by `N`"
        return self.N

    def __getitem__(self, idx):
        "Draws a new (num_frames, img_size, img_size) sample; `idx` is ignored"
        sample_shape = (self.num_frames, self.img_size, self.img_size)
        return torch.randn(*sample_shape)


class CloudDataset:
    """Dataset for cloud images
    It loads numpy files from wandb artifact and stacks them into a single array
    It also applies some transformations to the images
    """
    def __init__(self, 
                 files, # list of numpy files to load (they come from the artifact)
                 num_frames=4, # how many consecutive frames to stack
                 scale=True, # if True, scale images to the interval [-0.5, 0.5]
                 img_size=64, # resize dim, original images are big (446, 780); None keeps the loaded size
                 valid=False, # if True, transforms are deterministic
                ):
        tfms = []
        if img_size is not None:
            # Resize to a landscape frame first (1.7 is presumably close to the
            # aspect ratio of the raw images — TODO confirm), then cut a square
            # out of it. Without a target size there is nothing to crop to, so
            # both steps are skipped.
            tfms.append(T.Resize((img_size, int(img_size*1.7))))
            tfms.append(T.CenterCrop(img_size) if valid else T.RandomCrop(img_size))
        self.tfms = T.Compose(tfms)
        self.load_data(files, num_frames, scale)
        
    def load_day(self, file, scale=True):
        "Loads one numpy file; if `scale`, maps its values to [-0.5, 0.5]"
        one_day = np.load(file)
        if scale:
            # _scale gives [0, 1]; subtracting it from 0.5 sends the maximum to
            # -0.5 and the minimum to 0.5.
            one_day = 0.5 - self._scale(one_day)
        return one_day

    def load_data(self, files, num_frames, scale):
        "Loads all data into a single array self.data"
        data = []
        # TODO: download all files
        # Only the first two files are used for now.
        for file in progress_bar(files[0:2], leave=False):
            one_day = self.load_day(file, scale)
            # Every run of `num_frames` consecutive frames becomes one sample.
            # sliding_window_view appends the window axis last, so it is moved
            # right after the sample axis.
            wds = np.lib.stride_tricks.sliding_window_view(
                one_day.squeeze(),
                num_frames,
                axis=0).transpose((0,3,1,2))
            data.append(wds)
        self.data = np.concatenate(data, axis=0)

    def shuffle(self):
        """Shuffles the dataset, useful for getting 
        interesting samples on the validation dataset"""
        # The permutation is drawn with torch (so torch seeding applies) and
        # converted to numpy, which is what indexes self.data.
        idxs = torch.randperm(len(self.data)).numpy()
        self.data = self.data[idxs]
        return self

    @staticmethod
    def _scale(arr):
        "Scales values of array in [0,1]"
        m, M = arr.min(), arr.max()
        span = M - m
        if span == 0:
            # Constant input: avoid 0/0 (NaNs) and map everything to 0.
            span = 1
        return (arr - m) / span
    
    def __getitem__(self, idx):
        "Returns the transformed sample(s) at `idx` as a torch tensor"
        return self.tfms(torch.from_numpy(self.data[idx]))
    
    def __len__(self): return len(self.data)

    def save(self, fname="cloud_frames.npy"):
        "Saves self.data to `fname` with np.save"
        np.save(fname, self.data)


class CloudDatasetInference(CloudDataset):
    """Variant of `CloudDataset` whose `load_data` keeps every loaded file as
    one sequence instead of cutting it into sliding windows."""
    def load_data(self, files, num_frames=None, scale=None):
        "Loads all data into a single array self.data"
        # `num_frames` is accepted for signature compatibility but not used here.
        # TODO: download everything
        days = [self.load_day(file, scale) for file in files[0:2]]
        # Every day is cut to the length of the shortest one, and never kept
        # beyond 100 entries, so that the days can be stacked.
        max_length = min([100] + [len(day) for day in days])
        trimmed = [day[:max_length] for day in days]
        self.data = np.stack(trimmed, axis=0).squeeze()

def download_dataset(at_name, project_name):
    """Downloads dataset from wandb artifact

    Uses the active wandb run when there is one; otherwise a temporary run is
    started in `project_name` and finished again once the download is over.
    Returns the sorted list of paths inside the downloaded artifact directory.
    """
    def _get_dataset(run):
        artifact = run.use_artifact(at_name, type='dataset')
        return artifact.download()

    if wandb.run is not None:
        # Reuse the caller's run and leave it open.
        artifact_dir = _get_dataset(wandb.run)
    else:
        run = wandb.init(project=project_name, job_type="download_dataset")
        try:
            artifact_dir = _get_dataset(run)
        finally:
            # Close the temporary run even when the download fails.
            run.finish()

    return ls(Path(artifact_dir))


In [7]:
# Fetch the dataset artifact and get the sorted list of files it contains.
files = download_dataset(DATASET_ARTIFACT, project_name=PROJECT_NAME)
# Build the training dataset with the default settings (windows of 4 frames,
# 64x64 random crops).
train_ds = CloudDataset(files)
# Slicing the dataset returns the selected windows stacked along the first axis.
print(f"Let's grab 5 samples: {train_ds[0:5].shape}")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mriccardo-spolaor94[0m ([33mai-industry[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact np_dataset:v0, 3816.62MB. 30 files... Done. 0:0:0.1


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

                                             

Let's grab 5 samples: torch.Size([5, 4, 64, 64])




In [8]:
# Keep only the first 3 frames of each of the first 3 windows.
train_ds[0:3][:, :3].shape

torch.Size([3, 3, 64, 64])

In [10]:
# Feed the first 3 frames of the first 3 windows through the encoder on `device`.
out = model(train_ds[0:3][:,:3].to(device))



In [11]:
# Check the shape of the encoder output.
out.shape

torch.Size([3, 3, 64, 64])