[Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)

[Diffusers Quicktour](https://huggingface.co/docs/diffusers/en/quicktour)

In [1]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import os

In [5]:
import gcsfs
fs = gcsfs.GCSFileSystem()

## Dataloader

In [16]:
batch_size = 32
shuffle_data = True

In [6]:
user_path = "gs://leap-scratch/sammyagrawal"
fs.ls(user_path)

['leap-scratch/sammyagrawal/aquaplanet_in_1',
 'leap-scratch/sammyagrawal/aquaplanet_in_1.zarr',
 'leap-scratch/sammyagrawal/aquaplanet_in_2.zarr',
 'leap-scratch/sammyagrawal/aquaplanet_in_3.zarr',
 'leap-scratch/sammyagrawal/aquaplanet_out_2.zarr',
 'leap-scratch/sammyagrawal/aquaplanet_out_3.zarr',
 'leap-scratch/sammyagrawal/input_climsim.npy',
 'leap-scratch/sammyagrawal/output_climsim.npy']

In [7]:
with fs.open(os.path.join(user_path, "input_climsim.npy"), 'rb') as f:
    X_npy = np.load(f)
with fs.open(os.path.join(user_path, "output_climsim.npy"), 'rb') as f:
    Y_npy = np.load(f)


In [None]:
from sklearn.model_selection import train_test_split
X_train, X_tst, y_train, y_tst = train_test_split(X_npy, Y_npy, test_size=0.3)
X_val, X_test, y_val, y_test = train_test_split(X_tst, y_tst, test_size=0.5)

In [None]:
X_npy

In [13]:
class ClimsimDataset(Dataset):
    def __init__(self, input_npy, output_npy):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.X = torch.tensor(input_npy, device=self.device, dtype=torch.float32)
        self.Y = torch.tensor(output_npy, device=self.device, dtype=torch.float32)
        assert self.X.shape[0] == self.Y.shape[0], "Number of samples does not match"

    def __len__(self):
        return(X.shape[0])

    def __getitem__(self, idx):
        return(self.X[idx], self.Y[idx])

In [14]:
train_ds = ClimsimDataset(X_npy, Y_npy)

In [15]:
%%time
train_ds[5]

CPU times: user 149 μs, sys: 38 μs, total: 187 μs
Wall time: 193 μs


(tensor([2.1850e+02, 2.3478e+02, 2.4224e+02, 2.5092e+02, 2.5916e+02, 2.6575e+02,
         2.6545e+02, 2.5751e+02, 2.4577e+02, 2.3765e+02, 2.3161e+02, 2.2734e+02,
         2.2404e+02, 2.2101e+02, 2.1757e+02, 2.1354e+02, 2.0922e+02, 2.0694e+02,
         2.0351e+02, 2.0360e+02, 2.0340e+02, 2.0598e+02, 2.0900e+02, 2.1271e+02,
         2.1670e+02, 2.2074e+02, 2.2472e+02, 2.2858e+02, 2.3228e+02, 2.3585e+02,
         2.3934e+02, 2.4267e+02, 2.4585e+02, 2.4889e+02, 2.5178e+02, 2.5449e+02,
         2.5729e+02, 2.5999e+02, 2.6255e+02, 2.6483e+02, 2.6693e+02, 2.6887e+02,
         2.7072e+02, 2.7240e+02, 2.7383e+02, 2.7517e+02, 2.7634e+02, 2.7724e+02,
         2.7825e+02, 2.7938e+02, 2.8025e+02, 2.8112e+02, 2.8201e+02, 2.8277e+02,
         2.8359e+02, 2.8443e+02, 2.8531e+02, 2.8613e+02, 2.8683e+02, 2.8782e+02,
         1.5264e-06, 1.5171e-06, 1.5060e-06, 1.5031e-06, 1.4913e-06, 1.4880e-06,
         1.5005e-06, 1.5142e-06, 1.5336e-06, 1.5127e-06, 1.3748e-06, 1.2566e-06,
         1.2185e-06, 1.2256e

In [None]:
train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle_data) 