## This notebook has functions to:
- open a .tif file as a raster dataset
- read the raster dataset and map its spectral bands to common band names
- search a STAC catalog for HLS granules and download their files from S3

In [26]:
def open_file(base, band, band_name, cmr=False, coords=None, start_date=None, end_date=None):
    '''
    Open one band file ``{base}.{band}.tif`` and name the result ``band_name``.

    The file is opened with ``rxr.open_rasterio(..., mask_and_scale=True)`` so
    that masking and scale/offset are applied on read. ``base`` is used verbatim
    as a path prefix; it may be a local path or a URL prefix that the raster
    reader can open (NOTE(review): remote-URL support depends on the rasterio/GDAL
    build — confirm).

    Parameters
    ----------
    base : str
        Path or URL prefix of the granule, without the band suffix.
    band : str
        Band identifier appended to ``base`` (e.g. ``"B02"``, ``"Fmask"``).
    band_name : str
        Name assigned to the returned array (e.g. a common name like ``"NIR"``).
    cmr, coords, start_date, end_date :
        Currently unused. They are kept only so existing callers keep working;
        this function does NOT query CMR or download anything.

    Returns
    -------
    The object returned by ``rxr.open_rasterio`` with its ``name`` set to
    ``band_name`` (presumably an ``xarray.DataArray`` for a single-band GeoTIFF —
    confirm).
    '''
    path = f"{base}.{band}.tif"

    # Open raster and apply scale and offset
    print(f"Reading {path}")
    da = rxr.open_rasterio(path, mask_and_scale=True)
    da.name = band_name

    return da


def read_file_as_array(sat_id,run_id=None,title_id=None,granule=None,sr_key=None):
    '''
    Read the reflectance bands and Fmask of one granule into a single Dataset.

    The band set depends on ``sat_id``. If ``'L30'`` occurs in it, the Landsat
    (L30) band names are used; otherwise the alternative set is used (presumably
    S30 / Sentinel-2 — confirm). Coastal aerosol and cirrus bands are skipped.
    Each file ``{sr_key}.{band}.tif`` is opened with ``open_file`` and renamed to
    a common band name:

        common   'L30' in sat_id   otherwise
        B        B02               B02
        G        B03               B03
        R        B04               B04
        NIR      B05               B8A
        SWIR1    B06               B11
        SWIR2    B07               B12
        Fmask    Fmask             Fmask

    The per-band arrays are combined with ``xr.merge(...,
    combine_attrs="drop_conflicts")``.

    Parameters
    ----------
    sat_id : str
        Product identifier; only checked for the substring ``'L30'``.
    run_id, title_id, granule :
        Currently unused; kept for backward compatibility.
    sr_key : str
        Path/URL prefix shared by the band files (everything before ``.B02.tif``).

    Returns
    -------
    The merged dataset with variables B, G, R, NIR, SWIR1, SWIR2 and Fmask.

    Raises
    ------
    ValueError
        If ``sr_key`` is not given. Without this check the call would try to open
        ``None.B02.tif``.

    Example (hypothetical path)::

        ds = read_file_as_array('HLS.L30', sr_key='/data/T18TYN/HLS.L30.T18TYN.2021001T153553.v2.0')
    '''
    if sr_key is None:
        raise ValueError("sr_key (the path prefix of the band files) is required")

    # Map spectral bands; coastal aerosol and cirrus bands are ignored
    if 'L30' in sat_id:
        sr_bands = ["B02", "B03", "B04", "B05", "B06", "B07", "Fmask"]
    else:
        sr_bands = ["B02", "B03", "B04", "B8A", "B11", "B12", "Fmask"]

    # Positionally aligned with sr_bands
    common_bands = ["B", "G", "R", "NIR", "SWIR1", "SWIR2", "Fmask"]

    print("Reading " + sat_id + " raster data")

    sr_das = [
        open_file(sr_key, band, band_name)
        for band, band_name in zip(sr_bands, common_bands)
    ]

    return xr.merge(sr_das, combine_attrs="drop_conflicts")

    
def search_stac_for_HLS(pt, dt_min, dt_max, cloudcover_max=50, lim=100, url='https://cmr.earthdata.nasa.gov/cloudstac/LPCLOUD',
                        collections=None):
    '''
    Search a STAC catalog for HLS items intersecting ``pt`` and return asset links.

    Parameters
    ----------
    pt :
        Geometry passed as ``intersects`` to the search (presumably a
        GeoJSON-like dict or shapely geometry — confirm against callers).
    dt_min, dt_max : str
        Start and end of the time range, joined as ``"dt_min/dt_max"``.
    cloudcover_max :
        Currently unused; no cloud-cover filter is applied.
    lim : int
        Passed as ``limit`` to ``catalog.search``.
    url : str
        STAC API root opened with ``Client.open``.
    collections : list of str, optional
        Collections to search. Defaults to ``['HLSL30.v2.0', 'HLSS30.v2.0']``.

    Returns
    -------
    list of str
        The ``href`` of every asset of every matched item. The list is empty
        when nothing matched.
    '''
    # None sentinel instead of a shared mutable list default
    if collections is None:
        collections = ['HLSL30.v2.0', 'HLSS30.v2.0']

    catalog = Client.open(url)

    search = catalog.search(
        collections=collections,
        intersects=pt,
        datetime=dt_min + '/' + dt_max,
        limit=lim,
    )

    links = []

    # matched() can issue its own API request, so query it once and reuse it
    n_matched = search.matched()
    if n_matched == 0:
        print('No granules found at point', pt, 'from', dt_min, 'to', dt_max)
        return links

    print('Found', n_matched, 'granules at point', pt, 'from', dt_min, 'to', dt_max)
    # NOTE: get_all_items() is deprecated in newer pystac_client releases in
    # favour of item_collection(); kept here to match the installed version.
    for item in search.get_all_items():
        links.extend(asset.href for asset in item.assets.values())

    return links


def fix_links(src_link, src_dirs, dst_dir, meta_dir, add_tile_dir=True):
    '''
    Translate a source link (e.g. an S3 URL) into a local destination path.

    - Links containing ``'.xml'`` map to ``meta_dir/<basename>``.
    - Otherwise every ``src_dir`` in ``src_dirs`` is replaced by ``dst_dir``. When
      ``add_tile_dir`` is true, the path is then rebuilt by splitting on ``'/'``:
      component 2 is replaced by the tile id. The tile id is the third
      ``'.'``-separated field of component 3. Components are counted from 0, and
      a leading ``'/'`` yields an empty component 0.

    Example::

        fix_links('s3://bucket/HLSL30.020/HLS.L30.T18TYN.2021001T153553.v2.0/<file>',
                  ['s3://bucket'], '/data', '/meta')
        -> '/data/T18TYN/HLS.L30.T18TYN.2021001T153553.v2.0/<file>'

    Parameters
    ----------
    src_link : str
        Link to translate.
    src_dirs : str or iterable of str
        Prefixes to replace with ``dst_dir``. A single string is treated as one
        prefix.
    dst_dir : str
        Replacement root directory.
    meta_dir : str
        Directory for ``.xml`` metadata files.
    add_tile_dir : bool
        If false, return the prefix-replaced path without inserting the tile
        directory. Previously this flag was ignored.

    Returns
    -------
    str
        The destination path.
    '''
    if '.xml' in src_link:
        return os.path.join(meta_dir, os.path.basename(src_link))

    # Iterating a bare string would replace it character by character
    if isinstance(src_dirs, str):
        src_dirs = [src_dirs]

    dst_link = src_link
    for src_dir in src_dirs:
        dst_link = dst_link.replace(src_dir, dst_dir)

    if not add_tile_dir:
        return dst_link

    # Layout assumption (index-based): component 3 is the granule folder whose
    # name carries the tile id, e.g. 'HLS.L30.T18TYN....' -> 'T18TYN'
    dst_splits = dst_link.split('/')
    tile_id = dst_splits[3].split('.')[2]
    return '/'.join(dst_splits[0:2] + [tile_id] + dst_splits[3:])

def get_temp_creds(url, user, password):
    '''
    Request temporary credentials from an endpoint that first answers with a redirect.

    An example endpoint is
    ``'https://data.lpdaac.earthdatacloud.nasa.gov/s3credentials'``.

    The first request is made with redirects disabled, and the ``Location``
    header of its response is read. The second request goes to that location
    with HTTP basic auth ``(user, password)``. Its body is decoded as JSON and
    returned.

    Parameters
    ----------
    url : str
        Credentials endpoint.
    user, password : str
        Credentials sent as basic auth to the redirect target.

    Returns
    -------
    The decoded JSON body of the second response.
    '''
    first_hop = requests.get(url, allow_redirects=False)
    redirect_target = first_hop.headers['Location']

    authed_response = requests.get(redirect_target, auth=(user, password))
    return authed_response.json()



def make_dirs(dst_links):
    '''
    Create the parent directory of every destination path (best-effort).

    Each distinct parent directory is created with ``exist_ok=True``. Paths that
    are bare filenames have no parent directory and are skipped. A failure on
    one directory is printed and does not stop the remaining ones. Previously
    the first failure aborted the whole batch.

    Parameters
    ----------
    dst_links : iterable of str
        File paths whose parent directories should exist.
    '''
    # dict.fromkeys de-duplicates the directories while keeping first-seen order
    for dir_name in dict.fromkeys(os.path.dirname(link) for link in dst_links):
        if not dir_name:
            # os.makedirs('') raises; a bare filename needs no directory
            continue
        try:
            os.makedirs(dir_name, exist_ok=True)
        except OSError as e:
            print(e)
        
def download_data(s3_links, local_links, s3_session):
    '''
    Download S3 objects to local files, pair by pair, best-effort.

    Parameters
    ----------
    s3_links : sequence of str
        Object links of the form ``s3://<bucket>/<key>``. The bucket is taken
        from each link; previously it was hard-coded to ``'lp-prod-protected'``.
    local_links : sequence of str
        Destination paths, paired by position with ``s3_links``. Their parent
        directories presumably must already exist (see ``make_dirs``).
    s3_session :
        Object exposing ``download_file(Bucket=..., Key=..., Filename=...)``.
        Presumably a boto3 S3 client — confirm.

    Pairs whose local path contains ``'.xml'`` or ``'.jpg'`` are skipped.
    Each download failure is printed, and the loop continues with the next pair.

    Raises
    ------
    ValueError
        If ``s3_links`` and ``local_links`` differ in length.
    '''
    if len(s3_links) != len(local_links):
        raise ValueError(
            f"s3_links ({len(s3_links)}) and local_links ({len(local_links)}) "
            "must have the same length"
        )

    for i, (s3_link, local_link) in enumerate(zip(s3_links, local_links)):
        # ignore XML files for now, figure out how to get them later because they
        # contain useful information; JPG files are skipped as well
        if '.xml' in local_link or '.jpg' in local_link:
            continue

        if s3_link.startswith('s3://'):
            s3_link = s3_link[len('s3://'):]
        # Split only at the first '/', so a key that repeats the bucket name is
        # left intact
        s3_bucket, _, s3_key = s3_link.partition('/')

        try:
            print(i, s3_bucket, s3_key, local_link)
            # No local open() beforehand: it would truncate the target and leave
            # an empty file behind when the download fails
            s3_session.download_file(Bucket=s3_bucket,
                                     Key=s3_key,
                                     Filename=local_link)
        except Exception as e:
            # Deliberate best-effort: report and continue with the next file
            print(e)
    

