### MODIS Real-Time Download and Processing
Script that downloads MODIS snow-cover images and saves a small window around each grid cell, one image at a time.

1. Find the MODIS tile (h, v) and date for each data point
2. Define a window around each point and cut it out of the tile image
3. Save the windows


In [1]:
import pandas as pd
import geojson as gsn
from pyproj import Proj
from osgeo import gdal
from osgeo import gdalconst

import tempfile
import wget
import math
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import os
import pickle
from collections import defaultdict
from datetime import datetime, timedelta

import xarray as xr
import rioxarray as rxr
from azure.storage.blob import ContainerClient

# Azure Blob Storage location of the 'modis-006' container used below.
modis_account_name = 'modissa'
modis_container_name = 'modis-006'
modis_account_url = f'https://{modis_account_name}.blob.core.windows.net/'
modis_blob_root = f'{modis_account_url}{modis_container_name}/'

# NASA-provided table giving the lat/lon extents of each MODIS tile.
# NOTE: carried over from the tutorial and not used below: the extents are
# unprojected, so lat_lon_to_modis_tile rebuilds the grid from the projection.
modis_tile_extents_url = f'{modis_blob_root}sn_bound_10deg.txt'

# Local cache directory for everything downloaded by this notebook.
temp_dir = os.path.join(tempfile.gettempdir(), 'modis_snow')
os.makedirs(temp_dir, exist_ok=True)
fn = os.path.join(temp_dir, modis_tile_extents_url.rsplit('/', 1)[-1])
# wget.download(modis_tile_extents_url, fn)

# Container client created without credentials (credential=None).
modis_container_client = ContainerClient(
    account_url=modis_account_url,
    container_name=modis_container_name,
    credential=None,
)

#### MODIS/Azure helpers


In [2]:
def lat_lon_to_modis_tile(lat,lon):
    '''Return the MODIS sinusoidal-grid tile (h, v) containing a lat/lon point.

    The grid is a sinusoidal projection on a sphere of radius 6371007.181 m,
    cut into 36 x 18 equal square tiles; h counts eastward from longitude -180
    and v counts southward from the north pole.

    The spherical sinusoidal forward projection is closed-form
    (x = R * lon * cos(lat), y = R * lat, angles in radians), so it is
    evaluated directly instead of constructing a projection object on every
    call (this function is called once per grid cell).

    Args:
        lat: latitude in degrees.
        lon: longitude in degrees, expected within [-180, 180].

    Returns:
        (h, v) as ints, clamped to 0..35 and 0..17 so that points on the east
        edge (lon = 180) or at the south pole map to an existing tile.
    '''
    VERTICAL_TILES = 18
    HORIZONTAL_TILES = 36
    EARTH_RADIUS = 6371007.181
    EARTH_WIDTH = 2 * math.pi * EARTH_RADIUS

    TILE_WIDTH = EARTH_WIDTH / HORIZONTAL_TILES
    TILE_HEIGHT = TILE_WIDTH

    # Sinusoidal projection on the MODIS sphere.
    lat_rad = math.radians(lat)
    x = EARTH_RADIUS * math.radians(lon) * math.cos(lat_rad)
    y = EARTH_RADIUS * lat_rad

    # Offset from the grid's top-left corner (x = -W/2, y = +W/4), in tiles.
    h = (EARTH_WIDTH * .5 + x) / TILE_WIDTH
    v = (EARTH_WIDTH * .25 - y) / TILE_HEIGHT

    # Clamp: the far east edge / south pole land exactly on index 36 / 18.
    h = min(max(int(h), 0), HORIZONTAL_TILES - 1)
    v = min(max(int(v), 0), VERTICAL_TILES - 1)
    return h, v


def list_blobs_in_folder(container_name,folder_name):
    """
    Return the names of all blobs whose name starts with ``folder_name``.

    Note: ``container_name`` is accepted but not used; the listing always goes
    through the module-level ``modis_container_client``.
    """
    blob_iter = modis_container_client.list_blobs(name_starts_with=folder_name)
    return [blob.name for blob in blob_iter]
        
    
def list_hdf_blobs_in_folder(container_name,folder_name):
    """Return only the blob names ending in ``.hdf`` under ``folder_name``."""
    every_blob = list_blobs_in_folder(container_name, folder_name)
    return [name for name in every_blob if name.endswith('.hdf')]

# daynum = '2014236'
def daynum_gen(date_time):
    '''Return the 'YYYYDDD' folder name (year + zero-padded day of year) for a date.

    Example: datetime(2014, 8, 24) -> '2014236'.
    '''
    day_of_year = date_time.timetuple().tm_yday
    return f'{date_time.year}{day_of_year:03d}'

In [3]:
def images_downloader(tiles, centroids, out_dataset, prod_name, verbose = False):
    """Download MODIS tiles and cut a window around each requested grid cell.

    For every (date, h, v) key in ``tiles``, the HDF file(s) of product
    ``prod_name`` for that tile and day are fetched from the Azure container.
    If the requested day has no file, earlier days are tried. When several
    files exist, the one with the most best/good QA pixels is used. The image
    is reprojected to EPSG:4326, and a window of
    ``out_dataset.shape[-2] x out_dataset.shape[-1]`` pixels centred on each
    cell's centroid is written into ``out_dataset``. Windows that run past the
    image edge are zero-padded.

    Args:
        tiles: mapping ``(date, h, v) -> [cell_id, ...]``; ``date`` is a datetime.
        centroids: mapping ``cell_id -> [lon, lat]``.
        out_dataset: preallocated array shaped (image, band, row, column); one
            row per cell is filled in place, in iteration order.
        prod_name: product folder in the container, e.g. 'MOD10A1' or 'MYD10A1'.
        verbose: print per-tile / per-cell diagnostics.

    Returns:
        (cell_ids, out_dataset). ``cell_ids[k] == (cell_id, daynum)`` labels row k
        of ``out_dataset`` with the *requested* day, even if an earlier day's
        image had to be used.

    Raises:
        ValueError: if no file is found for a tile within the fallback window.
    """
    win_rows, win_cols = out_dataset.shape[-2], out_dataset.shape[-1]
    cell_ids = []
    i = 0
    for date_tile in tqdm(tiles.keys()):
        date, h, v = date_tile[0], date_tile[1], date_tile[2]
        daynum_og = daynum_gen(date)  # label rows with the requested day

        blob_names = _find_tile_blobs(prod_name, h, v, date)
        if verbose:
            print("\n", i)
            print('Found {} matching file(s):'.format(len(blob_names)))
            for name in blob_names:
                print(name)

        rds = _open_best_quality([_download_blob(name) for name in blob_names])

        #####reproject#####
        image = rds.rio.reproject(dst_crs="EPSG:4326")
        for var in image.data_vars:
            # Same dtype with keep_attrs=False: drops the per-variable attributes.
            image[var] = image[var].astype(image[var].dtype, keep_attrs=False)

        #####create blocks around centroids#####
        for cell in tiles[date_tile]:
            window, padded = _extract_window(image, centroids[cell], win_rows, win_cols)
            if padded and verbose:
                print("Out of bounds, padding with 0 for day/grid:", daynum_og, cell)
            out_dataset[i] = window
            cell_ids.append((cell, daynum_og))
            i += 1

    return cell_ids, out_dataset


def _find_tile_blobs(prod_name, h, v, date, max_days=5):
    """Return the .hdf blob names for tile (h, v) on ``date`` or, failing that, one of the ``max_days - 1`` preceding days."""
    for days_back in range(max_days):
        daynum = daynum_gen(date - timedelta(days=days_back))
        if days_back:
            print("trying:", daynum)
        folder = prod_name + '/' + '{:0>2d}/{:0>2d}'.format(h, v) + '/' + daynum
        blob_names = list_hdf_blobs_in_folder(modis_container_name, folder)
        if blob_names:
            return blob_names
        print("No file found: tile {} date {}".format((h, v), daynum))
    raise ValueError("Image for tile {} not found within {} day(s) up to {}".format(
        (h, v), max_days, daynum_gen(date)))


def _download_blob(blob_name):
    """Download ``blob_name`` into ``temp_dir`` unless already cached; return the local path."""
    filename = os.path.join(temp_dir, blob_name.replace('/', '_'))
    if not os.path.isfile(filename):
        wget.download(modis_blob_root + blob_name, filename)
    return filename


def _open_best_quality(paths):
    """Open the file(s); with several candidates keep the one with the most best/good QA pixels (first wins ties)."""
    if len(paths) == 1:
        return rxr.open_rasterio(paths[0])
    best, best_score = None, -1
    for path in paths:
        rds = rxr.open_rasterio(path)
        # Assumes NDSI_Snow_Cover_Basic_QA codes 0 = best, 1 = good, larger =
        # worse or flag values -- TODO confirm against the product user guide.
        score = int((rds.NDSI_Snow_Cover_Basic_QA.values < 2).sum())
        if score > best_score:
            best, best_score = rds, score
    return best


def _extract_window(image, center, n_rows, n_cols):
    """Cut a (variable, n_rows, n_cols) array around ``center`` = (x, y); return (window, padded)."""
    # Nearest pixel to the centre along each axis.
    x_idx = int(np.nanargmin(np.abs(image.x.values - center[0])))
    y_idx = int(np.nanargmin(np.abs(image.y.values - center[1])))
    x0, y0 = x_idx - n_cols // 2, y_idx - n_rows // 2
    x1, y1 = x0 + n_cols, y0 + n_rows

    sub = image[dict(x=slice(max(x0, 0), x1), y=slice(max(y0, 0), y1))]
    if 'band' in sub.dims:
        sub = sub.squeeze('band', drop=True)
    window = sub.to_array().transpose('variable', 'y', 'x').to_numpy()

    # Zero-fill whatever part of the window falls outside the image.
    pad = ((0, 0),
           (max(0, -y0), max(0, y1 - image.sizes['y'])),
           (max(0, -x0), max(0, x1 - image.sizes['x'])))
    padded = any(before or after for before, after in pad)
    if padded:
        window = np.pad(window, pad, mode='constant', constant_values=0)
    return window, padded


Ingest training + testing geodata and timestamps

Note: paths are currently absolute (specific to one machine); they can be made configurable so the notebook works on both machines.

In [4]:
path = "C:/Users/Matt/Documents/Python Scripts/SnowComp/dat/grid_cells_2b.geojson"
# GeoJSON text is UTF-8 (RFC 7946). Say so explicitly: without `encoding`,
# open() uses the platform default, which on Windows is usually cp1252.
with open(path, encoding="utf-8") as f:
    gj = gsn.load(f)
print(len(gj['features']))

20759


Estimate centroids for the lat/lon calculations by taking the mean of each cell's corner points. This is not the true centroid, because it ignores the projection and great-circle distances.

In [5]:
# cell_id -> approximate centroid [lon, lat]
centroids = {}

for feature in gj['features']:
    ring = feature['geometry']['coordinates'][0]
    # Each cell is a 5-point ring whose fifth point repeats the first;
    # make sure that holds for every feature.
    assert len(ring) == 5
    # Lazy centroid: plain mean of the four distinct corners.
    centroids[feature['properties']['cell_id']] = list(np.mean(ring[0:4], axis=0))

1. Ingest the training, testing, and submission datasets
2. Find which tile (time, h, v) each image is stored in
3. Store them by cell_id so the centroids can be looked up later

In [6]:
# The unnamed first CSV column holds the cell ids.
submission = pd.read_csv(
    "C:/Users/Matt/Documents/Python Scripts/SnowComp/dat/submission_format_2b.csv"
).rename(columns={"Unnamed: 0": "cell_id"})
DATE = "2022-02-10"  # ISO date of the images to fetch
# submission

## Download relevant images

Process and save smaller images one by one


## Submission images

In [None]:
cell_ids = submission['cell_id']
target_date = datetime.fromisoformat(DATE)

# Group cells by (date, h, v): one tile download serves every cell inside it.
tiles_sub = defaultdict(list)
counter_sub = 0
for cell in tqdm(cell_ids):
    # centroids store [lon, lat]; the tile lookup takes (lat, lon).
    modis_tile = lat_lon_to_modis_tile(centroids[cell][1], centroids[cell][0])
    tiles_sub[(target_date, *modis_tile)].append(cell)
    counter_sub += 1

print("total squares:", counter_sub)

  0%|          | 0/20759 [00:00<?, ?it/s]

Load Terra Submission Data

In [None]:
product = 'MOD10A1'  # Terra

# Preallocate (image, band, row, column); images_downloader fills one row per cell.
dataset_sub_t = np.empty((counter_sub, 7, 21, 21))
cell_ids_sub, dataset_sub_t = images_downloader(tiles_sub, centroids, dataset_sub_t, product)

#####save output#####
# output_path = "C:/Users/Matt/Dropbox/SnowComp/realtimeData/"+ "Modis_subT_"+str(DATE)+".npy"
# np.save(output_path,dataset_sub_t)

# Row labels (cell_id, requested day) for the downloaded windows.
path_ids = f"C:/Users/Matt/Dropbox/SnowComp/realtimeData/cell_snow_ids_sub{DATE}.pkl"
with open(path_ids, 'wb') as handle:
    pickle.dump(cell_ids_sub, handle)

Load Aqua Submission Data

In [None]:
product = 'MYD10A1'  # Aqua

# Preallocate (image, band, row, column); images_downloader fills one row per cell.
dataset_sub_a = np.empty((counter_sub, 7, 21, 21))
# tiles_sub is iterated in the same order as in the Terra cell, so the
# cell_ids_sub reassigned here matches the Terra labels row for row.
cell_ids_sub, dataset_sub_a = images_downloader(tiles_sub, centroids, dataset_sub_a, product)

#####save output#####
# output_path = "C:/Users/Matt/Dropbox/SnowComp/realtimeData/"+ "Modis_subA_"+str(DATE)+".npy"
# np.save(output_path,dataset_sub_a)

## Recombine images, save

In [None]:
# Keep band 0 of each satellite (presumably the first data variable,
# NDSI_Snow_Cover -- confirm) -> (image, 2, row, column), scaled by 1/255.
terra_band0 = dataset_sub_t[:, 0:1, :, :]
aqua_band0 = dataset_sub_a[:, 0:1, :, :]
sub_dataset = np.concatenate((terra_band0, aqua_band0), axis=1) / 255

output_path = f"C:/Users/Matt/Dropbox/SnowComp/realtimeData/Modis_sub_{DATE}.npy"
np.save(output_path, sub_dataset)



### Sanity Checks

In [None]:
# check how many filled with missing or empty
def data_quality_checker(dataset):
    """Report images whose first band is entirely above 100.

    Args:
        dataset: array shaped (image, band, row, column).

    Returns:
        (bad_images, which_idx): the images flagged as bad, and the boolean
        mask (one entry per image) that selected them.
    """
    # An image is "all missing" when every pixel of band 0 exceeds 100
    # (values above 100 are treated here as flags, not data).
    which_idx = np.all(dataset[:, 0, :, :] > 100, axis=(1, 2))
    print("all missing:", int(which_idx.sum()), "of:", dataset.shape[0])
    return dataset[which_idx], which_idx

def random_bad_plot(bad_images):
    """Plot band 0 of one randomly chosen image from ``bad_images`` and print its index.

    Args:
        bad_images: array shaped (image, band, row, column).

    Returns:
        The chosen index, or None when ``bad_images`` is empty (a message is
        printed instead, since random.randrange(0) would raise).
    """
    if bad_images.shape[0] == 0:
        print("no bad images to plot")
        return None
    idx = random.randrange(bad_images.shape[0])
    plt.imshow(bad_images[idx, 0, :, :])
    print(idx)
    return idx
    
# Check the Aqua and Terra submission arrays built above
# (dataset_sub_a / dataset_sub_t).
bad_images, which_idx_a = data_quality_checker(dataset_sub_a)
bad_images, which_idx_t = data_quality_checker(dataset_sub_t)

# Cells whose window is entirely missing in both satellites.
print("overlapping:",  np.sum(which_idx_a & which_idx_t))