In [None]:
import ee

# Start an Earth Engine session. If Initialize() fails (e.g. no stored credentials),
# run the Authenticate() flow first and only then retry Initialize().
try:
    ee.Initialize()
except Exception:
    ee.Authenticate()
    ee.Initialize()

In [None]:
import os
os.environ['USE_PYGEOS'] = '0'
import geopandas

import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt 
import geopandas as gpd
# import folium
import json
import os
import csv

from datetime import datetime, timedelta

from rasterio.plot import show
import rasterio
from rasterio.merge import merge
from rasterio.transform import from_origin
from rasterio.enums import Resampling
from rasterio.io import MemoryFile

In [None]:
##### Load some prelim files
# Afghanistan's national boundary from the GAUL level-0 collection (server-side EE object).
ee_afg = (ee.FeatureCollection("FAO/GAUL/2015/level0")
          .filter(ee.Filter.eq('ADM0_NAME', 'Afghanistan')))
# District polygons and the tiling grid, read locally with geopandas.
afg_shp = gpd.read_file('/data/afg_satellite/shp/district398_clean/district398_clean.shp')
tiles_25 = gpd.read_file('/data/afg_satellite/Grids/AFG_025_grid.shp')
# Districts belonging to the province named 'Hilmand'.
helmand_shp = afg_shp.loc[afg_shp['PROV_34_NA'] == 'Hilmand']

In [None]:
##afg_shp[afg_shp.DISTID.isin(districts_to_check)]

In [None]:
def shp_to_fc(file)->ee.FeatureCollection:
    """
    Convert the geometries of a GeoDataFrame into an Earth Engine FeatureCollection.

    Each Polygon becomes one geometry-only ee.Feature built from its exterior ring.
    Each MultiPolygon becomes one ee.Feature holding an ee.Geometry.MultiPolygon made of
    its parts' exterior rings. Interior rings (holes) are ignored.

    Parameters
    ----------
    file : object with a ``.geometry`` iterable of shapely geometries (e.g. a GeoDataFrame)

    Returns
    -------
    ee.FeatureCollection with one feature per input geometry, in input order.

    Raises
    ------
    TypeError if a geometry is neither a Polygon nor a MultiPolygon (including None).
    """
    def _exterior_ring(polygon):
        # [[x, y], ...] for the outer ring; any z values are dropped by coords.xy.
        x, y = polygon.exterior.coords.xy
        return np.column_stack((x, y)).tolist()

    features = []
    for geom in file.geometry:
        geom_type = getattr(geom, 'geom_type', None)
        if geom_type == 'Polygon':
            # A polygon is a list of rings; only the exterior ring is passed.
            ee_geom = ee.Geometry.Polygon([_exterior_ring(geom)])
        elif geom_type == 'MultiPolygon':
            ee_geom = ee.Geometry.MultiPolygon(
                [[_exterior_ring(part)] for part in geom.geoms]
            )
        else:
            raise TypeError(f"shp_to_fc: unsupported geometry type {geom_type!r}")
        features.append(ee.Feature(ee_geom))

    return ee.FeatureCollection(features)

def addDate(image):
    """Return `image` with an extra integer band named 'date'.

    The band's value is the image's 'system:time_start' formatted as YYYYMMdd and
    parsed back into a number (e.g. 20150117).
    """
    acquired = ee.Date(image.get('system:time_start'))
    yyyymmdd = ee.Number.parse(acquired.format('YYYYMMdd'))
    date_band = ee.Image(yyyymmdd).toInt().rename('date')
    return image.addBands(date_band)

def add_time(image):
    """Add the image's 'system:time_start' timestamp to it as an extra band."""
    timestamp = image.getNumber('system:time_start')
    return image.addBands(timestamp)

def cleanimg(image):
    """Return `image` with its masked pixels filled by the sentinel value -9999."""
    fill_value = -9999
    return image.unmask(fill_value)

In [None]:
# Sampling window (first half of 2015) and the 16-day-spaced start dates inside it.
start_date, end_date = "2015-01-01", "2015-06-30"
modis_dates = pd.date_range(start=start_date, end=end_date, freq='16D')

In [None]:
import contextlib

# NOTE(review): `districts_to_check` (the DISTIDs to process) is not defined in any visible
# cell; define it before running this cell.
# Years to export. The extraction cell below reads files for 2015-2019, so keep these in sync.
export_years = range(2015, 2020)

for dist in districts_to_check:
    # Grid tiles intersecting this district; one chip is downloaded per tile.
    dist_tiles = tiles_25.sjoin(afg_shp[afg_shp.DISTID == dist])

    print(dist_tiles.shape[0])
    if dist_tiles.empty:
        # Nothing to download, and rasterio's merge() cannot mosaic zero chips.
        print("No tiles intersect district", dist, "- skipping")
        continue
    # Sanity-check plot. plt.show() is not called because matplotlib.pyplot is not
    # imported in this notebook; the figure is shown by the notebook's display machinery.
    ax = dist_tiles.plot()
    ax.set_title(f"{dist}: {dist_tiles.shape[0]} tiles")

    # Tile geometries do not depend on year or product, so convert them to EE only once.
    tile_fcs = {tile_id: shp_to_fc(dist_tiles[dist_tiles.id == tile_id])
                for tile_id in dist_tiles['id'].unique()}

    for year in export_years:
        start_date = f"{year}-01-01"
        end_date = f"{year}-06-30"

        print(year, datetime.now())
        ## Prepare Datasets

        ### landcover: mask of pixels whose land-cover class is 40
        # (named agMask; presumably the cropland class -- confirm against the dataset docs).
        agDataset = ee.Image(f'COPERNICUS/Landcover/100m/Proba-V-C3/Global/{year}').select('discrete_classification')
        agMask = agDataset.eq(40)

        ## modis
        # NOTE: multiply(0.0001) scales every band, including the 'date' band added by
        # addDate; the load/clean cell multiplies date_modis by 10000 to undo this.
        mod_ndvi_date = (ee.ImageCollection("MODIS/061/MOD13Q1")
                         .filterDate(start_date, end_date)
                         .filterBounds(ee_afg)
                         .map(lambda x: x.unmask(-9999).select('NDVI'))
                         .map(addDate)
                         .map(lambda x: x.updateMask(agMask).multiply(0.0001)))

        ### landsat
        l8_ndvi_date = (ee.ImageCollection("LANDSAT/COMPOSITES/C02/T1_L2_8DAY_NDVI")
                        .filterDate(start_date, end_date)
                        .filterBounds(ee_afg)
                        .select(['NDVI'])
                        .map(addDate)
                        .map(lambda x: x.updateMask(agMask)))

        # Per-pixel max-NDVI composites; their 'date' band is the date of that maximum.
        l8_maxNdvi2 = l8_ndvi_date.qualityMosaic('NDVI')
        l8_acquisitionDate = l8_maxNdvi2.select('date')
        mod_maxNdvi2 = mod_ndvi_date.qualityMosaic('NDVI')
        mod_acquisitionDate = mod_maxNdvi2.select('date')

        ## Export Imagery
        exports = [(mod_maxNdvi2, 'NDVI', 250, 'modis'),
                   (mod_acquisitionDate, 'date', 250, 'modis'),
                   (l8_maxNdvi2, 'NDVI', 30, 'landsat'),
                   (l8_acquisitionDate, 'date', 30, 'landsat')]
        for img, bname, scale, outname in exports:
            print(outname)
            pth = f"/data/afg_satellite/revisions/bestDates/{outname}"
            assert os.path.exists(pth), f"output directory does not exist: {pth}"

            print("Downloading Tiles", start_date, datetime.now())
            chips = []
            for ee_dist in tile_fcs.values():
                mod_img = img.clipToBoundsAndScale(geometry=ee_dist, scale=scale)
                if outname == 'modis':
                    mod_img = mod_img.reproject(crs='EPSG:4326', scale=scale)
                # No explicit 'grid' (e.g. {'crsCode': ...}): passing one caused issues.
                chips.append(ee.data.computePixels({
                    "bandIds": [bname],
                    "expression": mod_img,
                    "fileFormat": 'GEOTIFF',
                }))

            print("Merging Tiles", datetime.now())
            # Open each GeoTIFF chip from memory; ExitStack closes the datasets and
            # MemoryFiles once the mosaic has been computed.
            with contextlib.ExitStack() as stack:
                openchips = [stack.enter_context(stack.enter_context(MemoryFile(c)).open())
                             for c in chips]
                outimg, out_trans = merge(openchips)
                out_crs = openchips[0].crs

            out_path = f"{pth}/{outname}_{bname}_{dist}_{year}_jan_june.tif"
            with rasterio.open(out_path, 'w',
                               driver='GTiff',
                               height=outimg.shape[1],
                               width=outimg.shape[2],
                               count=outimg.shape[0],
                               dtype=outimg.dtype,
                               crs=out_crs,  # keep the chips' CRS in the mosaic
                               transform=out_trans) as dest:
                dest.write(outimg)

In [None]:
import gc

# NOTE(review): `districts_to_check` is not defined in any visible cell; define it first.
# Fill value of masked pixels in the int32 Landsat 'date' rasters.
LANDSAT_DATE_NODATA = -2147483648


def _date_pixels(path, drop_value=None):
    """Read band 1 of a best-date GeoTIFF and return its valid pixels as points.

    Pixels that are 0, NaN (float rasters only) or equal to `drop_value` are dropped.
    Points are placed at pixel centres and tagged EPSG:4326 -- this assumes the raster
    coordinates are lon/lat (TODO confirm for the Landsat exports).
    Returns a GeoDataFrame with a 'date' column (raster dtype) and point geometry.
    """
    with rasterio.open(path) as src:
        band = src.read(1)
        transform = src.transform
    valid = band != 0
    if drop_value is not None:
        valid &= band != drop_value
    if np.issubdtype(band.dtype, np.floating):
        valid &= ~np.isnan(band)
    rows, cols = np.nonzero(valid)
    if rows.size:
        xs, ys = rasterio.transform.xy(transform, rows, cols)  # pixel centres
    else:
        xs, ys = [], []
    return gpd.GeoDataFrame({'date': band[rows, cols]},
                            geometry=gpd.points_from_xy(xs, ys),
                            crs="EPSG:4326")


def _first_mode(values):
    """Most frequent value; ties resolve to the smallest (Series.mode is sorted)."""
    return values.mode().iloc[0]


dfs = {'landsat': [],
       'modis': []}
for distid in districts_to_check:
    print(distid, datetime.now())
    dist_shape = afg_shp[afg_shp.DISTID.isin([distid])]
    for y in range(2015, 2020):
        print(y)
        for img_type in ['landsat', 'modis']:
            path = f"/data/afg_satellite/revisions/bestDates/{img_type}/{img_type}_date_{distid}_{y}_jan_june.tif"
            drop_value = LANDSAT_DATE_NODATA if img_type == 'landsat' else None
            gdf = _date_pixels(path, drop_value)

            # Most common best-NDVI date among the district's valid pixels.
            gdf_dists = gdf.sjoin(dist_shape)
            best = (gdf_dists
                    .groupby(['PROV_34_NA', 'DIST_34_NA', 'DISTID', 'PROVID'])
                    .date.agg(_first_mode)
                    .reset_index())
            best['year'] = y
            dfs[img_type].append(best)


df_landsat = pd.concat(dfs['landsat'])
df_modis = pd.concat(dfs['modis'])
dfs = None
gc.collect()

df_final = df_modis.merge(df_landsat[['DISTID', 'year', 'date']], on = ['DISTID', 'year'], suffixes = ('_modis', '_landsat'))
df_final.to_csv("/data/afg_satellite/revisions/bestDates/landsat_modis_top10.csv", index = False)

In [None]:
#### Load and Clean Data

def _yyyymmdd_to_datetime(values):
    """Parse YYYYMMDD numbers (possibly stored as floats) into datetimes.

    Values are rounded before the integer cast, so float noise such as 20150100.9999999
    becomes 20150101 instead of being truncated to 20150100.
    """
    return pd.to_datetime(values.round().astype('int64').astype(str), format='%Y%m%d')


df_final = pd.read_csv("/data/afg_satellite/revisions/bestDates/landsat_modis_top10.csv")
# The MODIS export multiplies every band (including 'date') by 0.0001; undo that scaling.
df_final['date_modis'] = df_final['date_modis']*10000
landsat_dates = np.concatenate([pd.date_range(f"{y}-01-01", f"{y}-06-30", freq = '8D') for y in range(2015, 2020)])
modis_dates = np.concatenate([pd.date_range(f"{y}-01-01", f"{y}-06-30", freq = '16D') for y in range(2015, 2020)])

# 1-based sequential labels for each composite start date across all years.
landsat_date_labels = {dt: i + 1 for i, dt in enumerate(landsat_dates)}
modis_date_labels = {dt: i + 1 for i, dt in enumerate(modis_dates)}

df_final['date_modis'] = _yyyymmdd_to_datetime(df_final['date_modis'])
df_final['date_landsat'] = _yyyymmdd_to_datetime(df_final['date_landsat'])

df_final['month_modis'] = df_final['date_modis'].dt.month
df_final['month_landsat'] = df_final['date_landsat'].dt.month
day_gap = (df_final['date_modis'] - df_final['date_landsat']).dt.days.abs()
df_final['day_diff'] = day_gap.astype(int)

# Start of the MODIS window [start, next start) containing each Landsat date.
# Dates outside every window become NaT rather than raising.
landsat_bins = pd.cut(df_final['date_landsat'], modis_dates, right = False)
bin_starts = landsat_bins.cat.categories.left
bin_codes = landsat_bins.cat.codes.to_numpy()
df_final['landsat_to_modis'] = pd.Series(bin_starts[bin_codes], index=df_final.index).where(bin_codes >= 0)
df_final['landsat_eq_modis'] = (df_final.landsat_to_modis == df_final.date_modis)*1
df_final['landsat_to_modis_label'] = df_final['landsat_to_modis'].map(modis_date_labels)
df_final['modis_label'] = df_final['date_modis'].map(modis_date_labels)
df_final['label_diff'] = df_final['modis_label'] - df_final['landsat_to_modis_label']
df_final['modis_landsat_diff_weeks'] = day_gap/7

In [None]:
## check for a single year: 2015 rows only, ordered by province then district
(df_final
    .loc[df_final['year'] == 2015, ['PROV_34_NA', 'DIST_34_NA', 'year', 'date_modis', 'date_landsat']]
    .sort_values(['PROV_34_NA', 'DIST_34_NA']))

In [None]:
## construct comparison grid
# One row per (province, district). For each year, show the MODIS (16-day) and Landsat (8-day)
# composite window starting at that row's date, as "start - end" strings.
# Province is part of the index: district names alone are not guaranteed unique, and
# pivot() raises on duplicate index entries.
tmp = (df_final
       .pivot(index = ['PROV_34_NA', 'DIST_34_NA'],
              columns = 'year',
              values = ['date_modis', 'date_landsat'])
       .sort_index())

# Flatten ('date_modis', 2015) -> 'modis_2015' from the actual pivot columns, so a
# missing year cannot shift names onto the wrong data.
tmp.columns = [f"{value.replace('date_', '', 1)}_{year}" for value, year in tmp.columns]
years = sorted(df_final['year'].unique())
cols = [f'{source}_{y}' for y in years for source in ('modis', 'landsat')]
tmp = tmp[cols]

# Inclusive window end offsets: 16-day MODIS and 8-day Landsat composites.
window_days = {'modis': 15, 'landsat': 7}
for c in tmp.columns:
    start = tmp[c]
    end = start + pd.Timedelta(days = window_days[c.split('_')[0]])
    # Vectorised formatting; NaT (district missing a year) yields NaN instead of raising.
    tmp[c] = start.dt.strftime('%Y-%m-%d') + ' - ' + end.dt.strftime('%Y-%m-%d')

In [None]:
# Display the per-district MODIS/Landsat comparison grid built in the previous cell.
tmp