In [None]:
import goes2go
import pandas as pd
import numpy as np
import xarray as xr
from goes2go import GOES
from datetime import datetime, timedelta
import matplotlib.pyplot as plt

import warnings
# Suppress every warning category for the rest of the session; this also covers
# anything emitted while importing the modules that follow.
# NOTE(review): this hides deprecation/FutureWarnings from xarray, pandas and fsspec too —
# consider narrowing to specific categories/modules once the noisy source is known.
warnings.simplefilter("ignore")

import fsspec
# Anonymous (unsigned) S3 access, so public buckets can be read without credentials.
fs = fsspec.filesystem('s3', anon=True)

# Keyword arguments forwarded to fs.open(): read remote files through fsspec's block
# cache, which keeps fixed-size blocks and evicts the least recently used ones.
# The block size should be in the MB range; tune it to the size of the files read.
fsspec_caching = dict(
    cache_type="blockcache",
    block_size=8 * 2**20,  # 8 MiB per block
)

#### Find all files between the satellite's launch and now

Searching a long time period in a single `goes2go` call is slow, so the full search will need a loop over shorter time windows (**TODO**). The cell below searches a single 30-minute window as an example. I also noticed that `goes2go.goes_nearesttime` is much slower than searching a time range.

In [None]:
# Example search: full-disk (domain='F') ABI-L2-MCMIP files in one 30-minute window.
# download=False only lists the matching files instead of fetching them.
abi_files = goes2go.goes_timerange(
    product="ABI-L2-MCMIP",
    domain='F',
    start=datetime(2018, 1, 1, 0, 0),
    end=datetime(2018, 1, 1, 0, 0) + timedelta(minutes=30),
    download=False,
)

In [None]:
# Inspect the search result (presumably a pandas DataFrame with a 'file' column — confirm)
abi_files

#### Other Data Products

In [None]:
# Cloud height at 2 km: list (download=False) the ABI-L2-ACHA2KMF files
# in a one-hour search window.
acha_files = goes2go.goes_timerange(
    product="ABI-L2-ACHA2KMF",
    start=datetime(2018, 4, 1, 10, 0),
    end=datetime(2018, 4, 1, 10, 0) + timedelta(hours=1),
    download=False,
)

In [None]:
# Cloud pressure at 2 km: list (download=False) the ABI-L2-ACHP2KMF files
# in a one-hour search window.
achp_files = goes2go.goes_timerange(
    product="ABI-L2-ACHP2KMF",
    start=datetime(2023, 4, 1, 10, 0),
    end=datetime(2023, 4, 1, 10, 0) + timedelta(hours=1),
    download=False,
)

In [None]:
# Cloud optical depth at 2 km: list (download=False) the ABI-L2-COD2KMF files
# in a one-hour search window.
cod_files = goes2go.goes_timerange(
    product="ABI-L2-COD2KMF",
    start=datetime(2023, 4, 1, 10, 0),
    end=datetime(2023, 4, 1, 10, 0) + timedelta(hours=1),
    download=False,
)

In [None]:
# Cloud temperature at 2 km: list (download=False) the ABI-L2-ACHTF files
# in a one-hour search window.
acht_files = goes2go.goes_timerange(
    product="ABI-L2-ACHTF",
    start=datetime(2020, 3, 1, 10, 0),
    end=datetime(2020, 3, 1, 10, 0) + timedelta(hours=1),
    download=False,
)

In [None]:
# Clear sky mask at 2 km: list (download=False) the ABI-L2-ACMF files
# in a one-hour search window.
acm_files = goes2go.goes_timerange(
    product="ABI-L2-ACMF",
    start=datetime(2022, 4, 1, 10, 0),
    end=datetime(2022, 4, 1, 10, 0) + timedelta(hours=1),
    download=False,
)

In [None]:
# Cloud phase at 2 km: list (download=False) the ABI-L2-ACTPF files
# in a one-hour search window.
actp_files = goes2go.goes_timerange(
    product="ABI-L2-ACTPF",
    start=datetime(2022, 4, 1, 10, 0),
    end=datetime(2022, 4, 1, 10, 0) + timedelta(hours=1),
    download=False,
)

In [None]:
# Cloud particle size at 2 km: list (download=False) the ABI-L2-CPSF files
# in a one-hour search window.
cps_files = goes2go.goes_timerange(
    product="ABI-L2-CPSF",
    start=datetime(2022, 3, 1, 10, 0),
    end=datetime(2022, 3, 1, 10, 0) + timedelta(hours=1),
    download=False,
)

#### Open example file

In [None]:
# 0-based position, within the search result, of the file to open
index = 0
# Stream the file from S3 through the block cache and open it with the h5netcdf engine.
# .iloc selects by position: plain Series[...] indexing is label-based and would fail
# (or pick a different file) if the result's index were not 0..n-1.
# The fsspec file object is deliberately not closed here, because xarray may still
# read from it lazily when variables are accessed later.
goes_data = xr.open_dataset(
    fs.open(abi_files['file'].iloc[index], **fsspec_caching),
    engine="h5netcdf",
)

In [None]:
# Inspect the opened dataset (dimensions, coordinates, variables and attributes)
goes_data

In [None]:
# Quick look at channel 14 over the whole image; the trailing ';' hides the artist repr
goes_data["CMI_C14"].plot();

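#### Create a sub-patch from the full disk

- Mask out all but a central fraction of the full-disk image (`fov_radius`, measured relative to the image center-to-corner distance rather than the disk edge)
- Sample patch centers from the unmasked region, which biases patches towards the center of the disk
- Reject patches that contain non-zero data quality flags and sample again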

In [None]:
# Short alias for the example dataset; this binds the same object (no copy is made)
ds = goes_data

In [None]:
def create_fov_mask(shape, fov_radius):
    """
    Build a boolean mask selecting a centered circular field of view.

    A pixel is kept when its distance from the array center, divided by the
    center-to-corner distance, is at most ``fov_radius``. With ``fov_radius=1``
    every pixel is kept; smaller values keep a progressively smaller central
    disc. (In the degenerate 1x1 case the normalisation is 0/0 and nothing is kept.)

    Parameters
    ----------
    shape : tuple of int
        ``(n_rows, n_cols)`` of the output mask; only the first two entries are used.
    fov_radius : float
        Radius of the kept region as a fraction of the center-to-corner distance.

    Returns
    -------
    numpy.ndarray of bool
        Mask of shape ``(n_rows, n_cols)``; True inside the field of view.
    """
    n_rows, n_cols = shape[0], shape[1]
    row_c, col_c = n_rows // 2, n_cols // 2

    # Broadcastable index vectors: a column of row indices and a row of column indices
    rows = np.arange(n_rows)[:, np.newaxis]
    cols = np.arange(n_cols)[np.newaxis, :]
    radius = np.sqrt((cols - col_c) ** 2 + (rows - row_c) ** 2)

    # Distance from the center to the (0, 0) corner, used as the normalising length
    corner_radius = np.sqrt(col_c ** 2 + row_c ** 2)
    return radius / corner_radius <= fov_radius

In [None]:
def check_quality_flags(ds, channels=range(1, 17), max_flag=0):
    """
    Check the per-channel data quality flag (DQF) variables of a dataset.

    Flag values:
    0 --> good pixel quality
    1 --> conditionally usable pixel quality
    2 --> out of range pixel quality
    3 --> no value pixel quality
    4 --> focal plane temperature threshold exceeded pixel quality

    Parameters
    ----------
    ds : mapping of variable name -> array-like (e.g. an xarray.Dataset)
        Must contain ``DQF_Cnn`` variables (two-digit channel number) for every
        channel in ``channels``.
    channels : iterable of int, optional
        Channel numbers to check. Defaults to all 16 channels (1..16).
    max_flag : int, optional
        Largest flag value still accepted. The default 0 accepts only good-quality
        pixels; 1 would also accept conditionally usable pixels.

    Returns
    -------
    bool
        True if every checked pixel has a flag <= ``max_flag``, False as soon as
        one channel contains a larger flag.

    NOTE(review): a NaN compares False against ``max_flag``, so if fill values are
    decoded to NaN when the file is opened, those pixels pass this check — confirm.
    """
    # Check each channel individually so we can exit at the first bad one
    for channel in channels:
        if (ds[f'DQF_C{channel:02d}'] > max_flag).any().item():
            return False
    return True

In [None]:
class CenterWeightedCropDatasetEditor:
    """
    Randomly crop a fixed-size patch from a dataset with ``x``/``y`` dimensions,
    restricting patch centers to a central field of view.

    Candidate patch centers are the pixels that (a) lie inside the mask returned by
    ``create_fov_mask`` for ``fov_radius`` and (b) leave the whole patch inside the
    image. One candidate is drawn uniformly at random. The resulting patch is
    returned if ``check_quality_flags`` accepts it; otherwise another candidate is
    drawn, up to ``max_attempts`` draws in total.

    NOTE: the field of view is measured from the image center towards the image
    corner, not towards the edge of the Earth disk.

    Parameters
    ----------
    patch_shape : tuple of int
        ``(n_x, n_y)``: patch size in pixels along ``x`` and along ``y``.
    fov_radius : float, optional
        Passed to ``create_fov_mask``; fraction of the center-to-corner distance
        within which patch centers may lie. Default 0.6.
    max_attempts : int, optional
        Maximum number of patches drawn before giving up. Default 10.
    rng : numpy.random.Generator, optional
        Generator used for sampling, for reproducible crops. If None, the global
        ``np.random`` state is used.
    """

    def __init__(self, patch_shape, fov_radius=0.6, max_attempts=10, rng=None):
        self.patch_shape = patch_shape
        self.fov_radius = fov_radius
        self.max_attempts = max_attempts
        self.rng = rng

    def _draw_index(self, n):
        """Return a uniformly drawn integer in ``[0, n)``."""
        if self.rng is None:
            return np.random.randint(0, n)
        return int(self.rng.integers(0, n))

    def _candidate_centers(self, nx, ny):
        """Return ``(x_idx, y_idx)`` index arrays of the admissible patch centers."""
        size_x, size_y = self.patch_shape[0], self.patch_shape[1]
        # create_fov_mask takes (rows, cols); rows follow y and columns follow x.
        valid = create_fov_mask(shape=(ny, nx), fov_radius=self.fov_radius)
        # Keep only centers whose patch [c - size // 2, c - size // 2 + size) fits
        # in the image; otherwise negative start indices would wrap around and
        # end indices would run past the array.
        xs = np.arange(nx)[np.newaxis, :]
        ys = np.arange(ny)[:, np.newaxis]
        valid &= (xs >= size_x // 2) & (xs - size_x // 2 + size_x <= nx)
        valid &= (ys >= size_y // 2) & (ys - size_y // 2 + size_y <= ny)
        y_idx, x_idx = np.nonzero(valid)
        return x_idx, y_idx

    def __call__(self, ds):
        """
        Crop one patch from ``ds``.

        Parameters
        ----------
        ds : xarray.Dataset
            Dataset with ``x`` and ``y`` dimensions and the DQF variables read by
            ``check_quality_flags``.

        Returns
        -------
        tuple
            ``(patch_ds, xmin, ymin)``: the cropped dataset (exactly ``patch_shape``
            pixels) and the integer index of its first pixel along ``x`` and ``y``.

        Raises
        ------
        AssertionError
            If the dataset is smaller than the patch.
        ValueError
            If no pixel qualifies as a patch center (e.g. ``fov_radius < 0``).
        RuntimeError
            If no patch passes the quality check within ``max_attempts`` draws.
        """
        nx, ny = ds['x'].shape[0], ds['y'].shape[0]
        assert nx >= self.patch_shape[0], 'Invalid dataset shape: %s' % str(ds['x'].shape)
        assert ny >= self.patch_shape[1], 'Invalid dataset shape: %s' % str(ds['y'].shape)

        x_idx, y_idx = self._candidate_centers(nx, ny)
        if len(x_idx) == 0:
            raise ValueError('No valid patch centers for fov_radius=%r and patch_shape=%r'
                             % (self.fov_radius, self.patch_shape))

        for _ in range(self.max_attempts):
            k = self._draw_index(len(x_idx))
            xmin = int(x_idx[k]) - self.patch_shape[0] // 2
            ymin = int(y_idx[k]) - self.patch_shape[1] // 2
            # Crop by position so the patch has exactly patch_shape pixels,
            # independent of whether the coordinate values increase or decrease.
            patch_ds = ds.isel(x=slice(xmin, xmin + self.patch_shape[0]),
                               y=slice(ymin, ymin + self.patch_shape[1]))
            if check_quality_flags(patch_ds):
                return patch_ds, xmin, ymin
            print('Found patch with bad quality flags, trying again...')

        # Fail loudly instead of implicitly returning None, which would make the
        # caller's tuple unpacking crash with an unrelated TypeError.
        raise RuntimeError('Could not find patch without bad quality flags after %d cropping attempts'
                           % self.max_attempts)

In [None]:
# Edge length, in pixels, of the square patch to crop
patch_size = 1024

crop = CenterWeightedCropDatasetEditor(
    fov_radius=0.6,
    patch_shape=(patch_size,) * 2,
)
# Crop one patch; xmin/ymin are its starting pixel indices along x and y
patch_ds, xmin, ymin = crop(ds)

In [None]:
# Channel 6 over the cropped patch; the trailing ';' hides the artist repr
patch_ds["CMI_C06"].plot();