to-do
- create a dataset of images and tag each one with is_big_flare (true or false)
  - run the current notebook to get the hourly images for 2015 (regardless of flare/no-flare)
  - turn this code into a module with a function that pulls the desired images from the SDO dataset
    - function schema:
      get_sdo_solar_images_from_aws(
          wanted_times, # list of datetimes for which images are desired
          save_folder_path, # folder path to save .png files
      ) -> fetched_image_paths # list of paths of the saved .png image files
  - use the function to pull the images for all big-flare events in 2015 and add them to the images dataset
  - move on to building a flare classifier
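The function schema in the to-do list could be sketched as below. This is a minimal sketch under stated assumptions, not the notebook's implementation: the `images_zarray`, `cmap`, `vmin` and `vmax` keyword arguments are hypothetical additions (the schema only names `wanted_times` and `save_folder_path`), and the function assumes the opened zarr array carries a per-image `T_OBS` list in its attributes, as the later cells of this notebook use. File names follow the `dt=yyyy-mm-dd_hhmmss.png` convention from the TODO further down; the default `"sdoaia171"` colormap is only available after `import sunpy.visualization.colormaps`.

```python
import os
from typing import Iterable, List

import numpy as np
import pandas as pd
from matplotlib import image as mpimg


def get_sdo_solar_images_from_aws(
    wanted_times: Iterable,
    save_folder_path: str,
    images_zarray=None,  # HYPOTHETICAL extra argument: an opened zarr array, e.g. root["171A"]
    cmap: str = "sdoaia171",  # registered by `import sunpy.visualization.colormaps`
    vmin: float = 10,
    vmax: float = 1000,
) -> List[str]:
    """Save the image closest in time to each wanted time as a PNG; return the saved paths."""
    if images_zarray is None:
        raise ValueError("pass an opened zarr array, e.g. images_zarray=root['171A']")
    # assumption: attrs["T_OBS"] holds one observation-time string per image
    obs_times = pd.to_datetime(np.asarray(images_zarray.attrs["T_OBS"]), utc=True)
    os.makedirs(save_folder_path, exist_ok=True)

    fetched_image_paths: List[str] = []
    for wanted in pd.to_datetime(list(wanted_times), utc=True):
        # index of the observation nearest in time to the wanted time
        idx = int(np.argmin(abs(obs_times - wanted)))
        # file name follows the dt=yyyy-mm-dd_hhmmss.png convention from the TODO
        file_name = obs_times[idx].strftime("dt=%Y-%m-%d_%H%M%S.png")
        path = os.path.join(save_folder_path, file_name)
        if path in fetched_image_paths:
            continue  # several wanted times mapped to the same observation
        if not os.path.exists(path):  # don't re-fetch images already on disk
            image = np.asarray(images_zarray[idx])
            mpimg.imsave(path, image, cmap=cmap, vmin=vmin, vmax=vmax, origin="lower")
        fetched_image_paths.append(path)
    return fetched_image_paths
```

With the notebook's data, a call might look like `get_sdo_solar_images_from_aws(wanted_times, images_png_folder, images_zarray=images_171a_zarray)`. Passing the array in, rather than opening S3 inside the function, keeps the sketch testable without network access.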

In [None]:
# %matplotlib inline

import os
from typing import Union
import zarr

# import gcsfs
import s3fs
import sunpy.map

import dask.array as da
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd
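# importing this module registers the SDO colormaps (e.g. "sdoaia171") with matplotlib;
# it is not aliased as `cm`, so it no longer shadows `matplotlib.cm` imported above
import sunpy.visualization.colormaps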

from astropy.time import Time
from dask.distributed import Client, LocalCluster
from sunpy.visualization import axis_labels_from_ctype, wcsaxes_compat

from matplotlib import animation
from IPython.display import HTML

In [None]:
# enable a global dask cache (dask.cache needs the optional `cachey` package)

from dask.cache import Cache
cache = Cache(1e9)  # use up to 1 GB of memory
cache.register()    # turn the cache on globally

In [None]:
# AWS_ZARR_ROOT = (
#     "s3://gov-nasa-hdrl-data1/contrib/fdl-sdoml/fdl-sdoml-v2/sdomlv2_small.zarr/"
# )

AWS_ZARR_ROOT = (
    "s3://gov-nasa-hdrl-data1/contrib/fdl-sdoml/fdl-sdoml-v2/sdomlv2.zarr/2015/"
)

def s3_connection(path_to_zarr: str) -> s3fs.S3Map:
    """
    Instantiate an anonymous connection to the AWS S3 store at `path_to_zarr`.
    """
    return s3fs.S3Map(
        root=path_to_zarr,
        # anonymous access requires no credentials
        s3=s3fs.S3FileSystem(anon=True),
        check=False,
    )


def load_single_aws_zarr(
    path_to_zarr: str,
    cache_max_single_size: Union[int, None] = None,
) -> Union[zarr.Array, zarr.Group]:
    """
    Open the zarr store at `path_to_zarr` read-only, wrapped in an LRU cache.

    `cache_max_single_size` is the maximum cache size in bytes; None means unbounded.
    """
    return zarr.open(
        zarr.LRUStoreCache(
            store=s3_connection(path_to_zarr),
            max_size=cache_max_single_size,
        ),
        mode="r",
    )

## 2. Reading and loading the AIA data


The SDO ML dataset is stored in the Zarr format, a format for the storage of chunked, compressed, N-dimensional arrays with NumPy dtypes. For an in-depth overview, see https://zarr.readthedocs.io/en/stable/tutorial.html.

In [None]:
# first, open (read-only) the zarr group stored on AWS S3
root = load_single_aws_zarr(
    path_to_zarr=AWS_ZARR_ROOT,
)

In [None]:
# `root.tree()` displays the hierarchy of groups and arrays under `root`.
# print(root.tree())

As shown in the tree, the hierarchy consists of groups containing arrays, each shown with its shape and data type. In this example, we primarily look at the 171 Å channel from 2015 (the year selected in `AWS_ZARR_ROOT`). Each image is 512x512 and stored as float32, and the channel can be accessed as follows:

In [None]:
images_171a_zarray = root["171A"]

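Alternatively, we could have opened the top level of the dataset and selected the year inside the hierarchy:

```
root_all_years = load_single_aws_zarr("s3://gov-nasa-hdrl-data1/contrib/fdl-sdoml/fdl-sdoml-v2/sdomlv2.zarr/")
images_171a_zarray = root_all_years["2015"]["171A"]
```

which becomes increasingly useful in the full dataset (where the hierarchy contains the years 2010 to present).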

**Loading with Dask**

We can then wrap this data in a lazily evaluated dask array.

In [None]:
all_image = da.from_array(images_171a_zarray)
all_image

The dask representation above shows the array's shape and chunking. For the 2010 171 Å data in the small dataset, for example, the shape is (6135, 512, 512), split into 52 chunks of (120, 512, 512), each of 125.83 MB. The data can now be manipulated like a NumPy array, with computation deferred until `.compute()` is called.
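The chunk figures quoted above follow from simple arithmetic: each chunk holds 120 float32 images of 512x512 pixels, and 6135 images need ceil(6135 / 120) chunks. A quick check, using only the numbers from the text:

```python
import math

import numpy as np

n_images, height, width = 6135, 512, 512  # shape of the 2010 171 Å example
chunk_images = 120  # images per chunk along the time axis
bytes_per_value = np.dtype(np.float32).itemsize  # 4 bytes per float32 pixel

chunk_bytes = chunk_images * height * width * bytes_per_value
n_chunks = math.ceil(n_images / chunk_images)  # the last chunk is only partly filled

print(f"{n_chunks} chunks of {chunk_bytes / 1e6:.2f} MB")  # 52 chunks of 125.83 MB
```

Newer dask versions report sizes in binary units, so the same chunk may be displayed as 120.00 MiB (125,829,120 bytes divided by 2^20).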

We can load and display one image now:

In [None]:
image = all_image[0, :, :]
plt.figure(figsize=(5, 5))
colormap = plt.get_cmap("sdoaia171")  # registered by sunpy.visualization.colormaps
plt.imshow(image.compute(), origin="lower", vmin=10, vmax=1000, cmap=colormap)

Depending on the use case, we may wish to extract a subset of this data in various ways. The following sections step through a number of operations that we may wish to perform on the data.

### 2a. Selecting images based on header information

The dataset retains all of the FITS header information, with the original keywords. For the definitions of the AIA keywords, refer to the following document:
http://jsoc.stanford.edu/~jsoc/keywords/AIA/AIA02840_K_AIA-SDO_FITS_Keyword_Document.pdf

One can also list all of the AIA keywords included:
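A minimal sketch for listing the keywords, assuming (as the `T_OBS` lookup later in this notebook does) that the array's `.attrs` maps each FITS keyword to its values. The `example_attrs` dictionary is made-up stand-in data so the snippet runs offline; with the real data the call would be `list_header_keywords(images_171a_zarray.attrs)`.

```python
from typing import Any, List, Mapping


def list_header_keywords(attrs: Mapping[str, Any]) -> List[str]:
    """Return the header keywords stored in a zarr array's attributes, sorted alphabetically."""
    return sorted(attrs.keys())


# made-up stand-in for images_171a_zarray.attrs (one value per image for each keyword)
example_attrs = {"T_OBS": ["2015-01-01T00:00:07Z"], "EXPTIME": [2.0], "WAVELNTH": [171]}
print(list_header_keywords(example_attrs))  # ['EXPTIME', 'T_OBS', 'WAVELNTH']
```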

We can extract the exposure (and observation) times from the data attributes (the header information) and use them to select a subset of the images.
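A sketch of such a selection, assuming the attributes hold per-image `EXPTIME` (exposure time, in seconds) and `T_OBS` lists in the same order as the images; `select_by_exposure` is a hypothetical helper, not part of the dataset's API. Because the array is not ordered by observation time, the selected indices are returned in `T_OBS` order.

```python
import numpy as np
import pandas as pd


def select_by_exposure(attrs, min_exptime: float, max_exptime: float) -> np.ndarray:
    """Indices of images whose EXPTIME is within [min_exptime, max_exptime], in T_OBS order."""
    exptime = np.asarray(attrs["EXPTIME"], dtype=float)
    obs_times = pd.to_datetime(np.asarray(attrs["T_OBS"]), utc=True)
    idxs = np.flatnonzero((exptime >= min_exptime) & (exptime <= max_exptime))
    # the zarr array is not ordered by observation time, so sort the selection by T_OBS
    return idxs[obs_times[idxs].argsort()]


# made-up example attributes for four images
example_attrs = {
    "EXPTIME": [2.0, 1.0, 2.0, 3.0],
    "T_OBS": [
        "2015-01-01T02:00:00Z",
        "2015-01-01T00:00:00Z",
        "2015-01-01T01:00:00Z",
        "2015-01-01T03:00:00Z",
    ],
}
print(select_by_exposure(example_attrs, 1.5, 2.5))  # [2 0]
```

The resulting indices could then be used on the dask array, e.g. `all_image[selected_idxs]`.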

### 2b. Selecting images based on indices

The data is not ordered by observation time, so rather than slicing a contiguous range, we simply index the array at the positions of the images whose observation times (`T_OBS`) are closest to the times we want.

In [None]:
# We are going to choose data at 1-hour intervals

# select times at a frequency of 60 minutes
# ("60min"; the older "60T" alias is deprecated from pandas 2.2)
wanted_times = pd.date_range(
    start="2015-01-01 00:00:00", end="2015-12-31 23:59:59", freq="60min", tz="UTC"
)

In [None]:
# select the subset of images closest in time to each wanted time

# observation time of every image in the array
images_zry_times = pd.to_datetime(np.array(images_171a_zarray.attrs["T_OBS"]))

# for each wanted time, get the index of the image with the nearest T_OBS: images_zry_wanted_idxs
images_zry_wanted_idxs = []
for selected_time in wanted_times:
    images_zry_wanted_idxs.append(np.argmin(abs(images_zry_times - selected_time)))
# several wanted times can map to the same image, so drop duplicates and sort
images_zry_wanted_idxs = sorted(set(images_zry_wanted_idxs))

# observation times of the selected images
images_wanted_times = images_zry_times[images_zry_wanted_idxs]

print("images_zry_wanted_idxs = ", images_zry_wanted_idxs)
print("len(images_zry_wanted_idxs) = ", len(images_zry_wanted_idxs))
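The loop above compares every wanted time against every observation time, which costs O(N x M) for N wanted times and M images. A vectorised alternative, sketched here as a hypothetical helper `nearest_time_indices`, sorts the observation times once and uses `np.searchsorted` to find each wanted time's two neighbours. It assumes at least two observations, and exact ties may resolve differently than `np.argmin` (here the later neighbour wins).

```python
import numpy as np
import pandas as pd


def _to_utc_ns(times) -> np.ndarray:
    """Convert timestamps to a naive-UTC datetime64[ns] array."""
    return np.asarray(pd.to_datetime(times, utc=True).tz_convert(None), dtype="datetime64[ns]")


def nearest_time_indices(obs_times, wanted_times) -> np.ndarray:
    """For each wanted time, the index in `obs_times` of the nearest observation."""
    obs = _to_utc_ns(obs_times)
    want = _to_utc_ns(wanted_times)
    order = np.argsort(obs, kind="stable")  # obs_times need not be sorted
    obs_sorted = obs[order]
    # insertion point of each wanted time among the sorted observations, kept in [1, n-1]
    pos = np.clip(np.searchsorted(obs_sorted, want), 1, len(obs_sorted) - 1)
    left, right = obs_sorted[pos - 1], obs_sorted[pos]
    # step back to the left neighbour when it is strictly closer
    pos = pos - ((want - left) < (right - want)).astype(pos.dtype)
    return order[pos]


obs = ["2015-01-01T00:00:10Z", "2015-01-01T02:00:00Z", "2015-01-01T01:00:05Z"]
wanted = pd.date_range("2015-01-01 00:00", periods=3, freq="60min", tz="UTC")
print(nearest_time_indices(obs, wanted))  # [0 2 1]
```

In the notebook this would replace the loop with `sorted(set(nearest_time_indices(images_zry_times, wanted_times)))`.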

In [None]:
# fetch a single image from AWS


def get_single_solar_image(image_idx, images_zarray=images_171a_zarray):
    """Return image number `image_idx` of `images_zarray` as a NumPy array."""
    # indexing the zarr array directly reads only the chunk that contains this image
    return np.asarray(images_zarray[image_idx, :, :])


# an arbitrary example: the 4003rd selected hourly image
image = get_single_solar_image(images_zry_wanted_idxs[4002])

# downsample the 512x512 image to 256x256 by keeping every other pixel in each direction
downsampled_pxl_posns = np.arange(0, image.shape[0], 2)
image_downsmpd = image[downsampled_pxl_posns, :][:, downsampled_pxl_posns]

# plot the full-resolution image and the downsampled image
plt.figure(figsize=(5, 5))
plt.imshow(image, origin="lower", vmin=10, vmax=1000, cmap=plt.get_cmap("sdoaia171"))
plt.figure(figsize=(5, 5))
plt.imshow(
    image_downsmpd, origin="lower", vmin=10, vmax=1000, cmap=plt.get_cmap("sdoaia171")
)
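Keeping every other pixel (equivalent to `image[::2, ::2]`) discards three quarters of the pixels outright. A common alternative, swapped in here for illustration rather than taken from the notebook, is to average each non-overlapping 2x2 block, which uses every pixel and reduces noise. A minimal sketch, assuming the image dimensions are divisible by the factor:

```python
import numpy as np


def block_average_downsample(image: np.ndarray, factor: int = 2) -> np.ndarray:
    """Downsample a 2D image by averaging non-overlapping factor x factor pixel blocks."""
    h, w = image.shape
    if h % factor or w % factor:
        raise ValueError("image dimensions must be divisible by factor")
    # split each axis into (blocks, factor) and average over the two factor axes
    return image.reshape(h // factor, factor, w // factor, factor).mean(axis=(1, 3))


example_image = np.arange(16, dtype=np.float32).reshape(4, 4)
print(block_average_downsample(example_image))  # block means: 2.5, 4.5, 10.5, 12.5
```

For a 512x512 AIA image, `block_average_downsample(image)` returns a 256x256 array, the same shape as `image_downsmpd`.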

In [None]:
# Downsample each selected hourly image, skipping times whose image is already on disk

# TODO: change the code to track images_processed_times and don't use images_df
#       - also stop using csvs
#       - save each image as a PNG and name it with the datetime in format
#         dt=yyyy-mm-dd_hhmmss.png e.g. dt=2015-01-01_170400.png

import gc
import glob
import os

import pandas as pd

# get the times of the images that have been processed already: images_processed_times
images_png_folder = "/Users/aishsk6/gd_to_be_archived_big_files/sdo_image_data"
images_processed_paths = glob.glob(os.path.join(images_png_folder, "*.png"))
images_processed_times = {
    # strip the ".png" extension and parse the rest of the file name as a timestamp
    pd.to_datetime(os.path.splitext(os.path.basename(path))[0])
    for path in images_processed_paths
}

for image_time_idx, current_img_time in enumerate(images_wanted_times):
    # skip this time if its image has already been processed
    if current_img_time in images_processed_times:
        print(
            f"Skipping image_time_idx {image_time_idx} as it has been processed already."
        )
        continue

    # fetch the current image and downsample it to 256x256
    image_arr = get_single_solar_image(images_zry_wanted_idxs[image_time_idx])[
        downsampled_pxl_posns, :
    ][:, downsampled_pxl_posns]
    print(f"image size: {image_arr.nbytes / (1024 * 1024):.2f} MiB")

    # # Processing image
    # fig = plt.figure(figsize=(5, 5))
    # plt.imshow(image_arr, origin="lower", vmin=10, vmax=1000, cmap="sdoaia171")
    # plt.savefig(f"{images_png_folder}/{current_img_time}.png")
    # plt.close("all")  # Close the figure manually to release resources

    # # Explicit memory management
    # del image_arr
    # gc.collect()

    print(f"image_time_idx: {image_time_idx} of {len(images_wanted_times)}")
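The TODO above asks for PNG names of the form `dt=yyyy-mm-dd_hhmmss.png`. A pair of hypothetical helpers, sketched below, would build such a name from an observation time and parse it back, so the "already processed" check can read times straight from file names. One assumption to note: the format keeps whole seconds only, so `T_OBS` values with fractional seconds would need truncating (for example with `current_img_time.replace(microsecond=0, nanosecond=0)`) before they are compared with the parsed names.

```python
import os

import pandas as pd

PNG_NAME_FORMAT = "dt=%Y-%m-%d_%H%M%S.png"  # naming convention from the TODO above


def png_name_for_time(obs_time) -> str:
    """Build a file name such as dt=2015-01-01_170400.png from a timestamp."""
    return pd.Timestamp(obs_time).strftime(PNG_NAME_FORMAT)


def time_from_png_path(path: str) -> pd.Timestamp:
    """Recover the UTC timestamp encoded in a dt=yyyy-mm-dd_hhmmss.png file name."""
    return pd.to_datetime(os.path.basename(path), format=PNG_NAME_FORMAT, utc=True)


t = pd.Timestamp("2015-01-01 17:04:00", tz="UTC")
print(png_name_for_time(t))  # dt=2015-01-01_170400.png
```

With these helpers, `images_processed_times` could be built as `{time_from_png_path(p) for p in images_processed_paths}`.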

In [None]:
import sys

sys.getsizeof(images_zry_times) / (1024 ** 2)

In [None]:
images_processed_times

In [None]:
current_img_time