In [None]:
import ast
import sunpy.map
import torch

import matplotlib.pyplot as plt

from sdoml import SDOMLDataset
from timeit import default_timer as timer

First, we instantiate the ``SDOMLDataset`` class to load one month of the six optically-thin SDO/AIA channels (94A/131A/171A/193A/211A/335A) alongside the three components of the HMI magnetograms (Bx, By, Bz) from ``fdl-sdoml-v2/sdomlv2_small.zarr`` and ``fdl-sdoml-v2/sdomlv2_hmi_small.zarr``.

**For an in-depth overview of accessing data, see the AIA example notebook!**

In [None]:
# Locations of the SDOML v2 "small" Zarr stores on Google Cloud Storage, matching
# the paths described above. The original cell pointed these at an absolute path on
# one developer's laptop while still declaring storage_location="gcs", so it could
# not run anywhere else.
# NOTE(review): to read a local copy instead, point these at the local .zarr
# directories and presumably change "storage_location" too — confirm the accepted
# values against the SDOMLDataset documentation.
AIA_ROOT = "fdl-sdoml-v2/sdomlv2_small.zarr"
HMI_ROOT = "fdl-sdoml-v2/sdomlv2_hmi_small.zarr"

# Maximum size of the LRU cache placed in front of the remote store:
# 512 * 512 * 4096 = 1 GiB, presumably in bytes as for zarr's LRUStoreCache — confirm.
CACHE_MAX_SIZE = 1 * 512 * 512 * 4096

sdomlds = SDOMLDataset(
    cache_max_size=CACHE_MAX_SIZE,
    years=[
        "2010",
    ],
    data_to_load={
        "AIA": {
            "storage_location": "gcs",
            "root": AIA_ROOT,
            "channels": ["94A", "131A", "171A", "193A", "211A", "335A"],
        },
        "HMI": {
            "storage_location": "gcs",
            "root": HMI_ROOT,
            "channels": ["Bx", "By", "Bz"],
        },
    },
)

Let's wrap the dataset in a ``torch.utils.data.DataLoader`` with a ``batch_size`` of 1 and no shuffling of the data.

As will be evident, the first access to a given chunk is relatively slow (it is retrieved from the remote store on Google Cloud Storage), whereas a second access is faster because it is served from the cache. For more information, see https://zarr.readthedocs.io/en/stable/api/storage.html#zarr.storage.LRUStoreCache

In [None]:
# Yield one sample per batch, in the dataset's own order (no shuffling), so the
# first batch is always the first observation in the store.
dataloader = torch.utils.data.DataLoader(sdomlds, shuffle=False, batch_size=1)

In [None]:
# Fetch the first batch twice and time each fetch: the first access pulls the chunks
# from the remote store, the second should be served from the LRU cache. After the
# loop, `data` holds the first batch (the same sample both times, since shuffle=False).
for access in ("first (remote)", "second (cached)"):
    start = timer()
    data = next(iter(dataloader))
    print(f"{access} access: {timer() - start:.2f} s")

### Plotting one set of images

For the single batch returned by the dataloader, the following cell builds a ``sunpy.map.Map`` for each channel from the image data (``data["data"]``) and the metadata (``data["meta"]``), and plots the maps on a grid.

In [None]:
# Build one sunpy Map per channel of the first sample in the batch and draw them on
# a grid with NCOLS columns. The number of rows is derived from the channel count, so
# adding channels no longer overflows the previously hard-coded 3x3 grid (the default
# 6 AIA + 3 HMI channels still give 3x3).
NCOLS = 3
# data["data"][inst] is indexed as [batch, channel, y, x] below, so shape[1] is the
# channel count for that instrument.
n_panels = sum(images.shape[1] for images in data["data"].values())
nrows = max(1, -(-n_panels // NCOLS))  # ceiling division; at least one row

plt.figure(figsize=(20, 20 * nrows / NCOLS))

i = 0
# iterate through instruments (here there is only AIA and HMI)
for inst, images in data["data"].items():
    # The metadata for sample 0 arrives as the string repr of a dict whose values are
    # indexed per channel below. Parse it once per instrument rather than once per
    # channel; ast.literal_eval accepts only Python literals, so it is safe here.
    header_lists = ast.literal_eval(data["meta"][inst][0])

    # iterate through the channels of batch index 0
    for img_index in range(images.shape[1]):
        # Create a sunpy map with the image and its per-channel header
        selected_image = images[0, img_index, :, :]
        selected_header = {
            key: values[img_index] for key, values in header_lists.items()
        }
        my_map = sunpy.map.Map(selected_image.numpy(), selected_header)

        # Place the map in the next grid cell, using the map's WCS as the projection
        ax = plt.subplot(nrows, NCOLS, i + 1, projection=my_map)

        # HMI magnetograms: diverging magnetogram colormap clipped at +/-1000
        if str(my_map.meta.get("instrume", "")).startswith("HMI"):
            my_map.plot_settings["cmap"] = "hmimag"
            my_map.plot_settings["norm"] = plt.Normalize(-1000.0, 1000.0)

        my_map.plot(axes=ax)

        i += 1

plt.show()

---