## Generate LCSIF clear-inst and clear-daily, MODIS Period
### Jianing Fang (jf3423@columbia.edu)

In [1]:
import numpy as np
import xarray as xr
import os
import itertools
import matplotlib.pyplot as plt
import multiprocessing as mp
import torch
from torch import nn
from datetime import date, time, timedelta
import datetime
from sklearn.preprocessing import StandardScaler
from matplotlib.colors import LogNorm
import astral
from astral import sun
import skyfield
from skyfield.api import load, wgs84
import pandas as pd
from scipy import interpolate
import netCDF4 as nc

### Predict OCO-2 Local Overpass time

In [None]:
# TLE data obtained from https://www.space-track.org/
#
# Propagate every OCO-2 two-line element set (TLE) from its own epoch up to
# the epoch of the next TLE, sampling the sub-satellite point every 10 s, and
# save the resulting (latitude, longitude, UTC time) track.

stations_url = '../../data/OCO2_TLE.txt'

satellites = load.tle_file(stations_url)
ts = load.timescale()
print('Loaded', len(satellites), 'satellites')
epoches=np.array([sat.epoch.utc_datetime() for sat in satellites])
# epoch_duration[i] is the gap between TLE i and TLE i+1, so it holds one
# element fewer than epoches.
epoch_duration=np.diff(epoches)

observe_time=[]
lons=[]
lats=[]

# The last TLE has no successor (and therefore no duration), so it is not
# propagated; iterating over all of `epoches` would index past the end of
# `epoch_duration`.
for i, epoch in enumerate(epoches[:-1]):
    n_samples = int(epoch_duration[i].total_seconds() / 10)
    # Skip (near-)duplicate TLEs whose validity window is shorter than one
    # 10 s sample; they would otherwise request an empty time list.
    if n_samples < 1:
        continue
    times = pd.date_range(epoch, periods=n_samples, freq='10S').tolist()
    t = ts.from_datetimes(times)
    satellite = satellites[i]
    geocentric = satellite.at(t)
    observe_time.append(t.utc_datetime())
    subpoint = wgs84.subpoint(geocentric)
    lons.append(subpoint.longitude.degrees)
    lats.append(subpoint.latitude.degrees)

lon_array=np.concatenate(lons, axis=0)
lat_array=np.concatenate(lats, axis=0)
time_array=np.concatenate(observe_time, axis=0)
# Calendar date (UTC) of every sample, used to select a single day later on.
date_array=np.array([t.date() for t in time_array])

# One row per sample: (latitude, longitude, UTC datetime).
np.save("../../data/processed/OCO2_Track.npy", np.array([lat_array, lon_array, time_array]).T)

# Reference day used to derive the latitude -> overpass-time relationship.
date_of_interest = date(2014, 9, 6)

# Restrict the propagated track to samples that fall on the reference day.
date_sel = (date_array == date_of_interest)
lat_sel = lat_array[date_sel]
lon_sel = lon_array[date_sel]
time_sel = time_array[date_sel]
samples = np.array([lat_sel, lon_sel, time_sel]).T

# UTC instant of 12:00 local (mean solar) time at each sample's longitude:
# 12:00 UTC shifted by -lon/360 of a day.
noon_time = (
    datetime.datetime.combine(date_of_interest, datetime.time(12,0,0), tzinfo=datetime.timezone.utc)
    + np.array([datetime.timedelta(hours=h) for h in (-lon_sel / 360 * 24)])
)

# Hours between each sample and local noon, expressed relative to a
# 1.6 h (01:36) offset after noon.
time_diff = np.array([t.total_seconds()/3600 for t in (time_sel-noon_time)]) - 1.6

# Lower bound of the usable branch: among samples about 6 h early, take the
# one at the highest latitude.
low_range = (time_diff > -6.5) & (time_diff < -5.5)
low_range_idx = np.flatnonzero(low_range)
max_local_idx = np.argmax(lat_sel[low_range])
low_time = time_diff[low_range_idx[max_local_idx]]

# Upper bound of the usable branch: among samples about 6 h late, take the
# one at the lowest latitude.
high_range = (time_diff < 6.5) & (time_diff > 5.5)
high_range_idx = np.flatnonzero(high_range)
min_local_idx = np.argmin(lat_sel[high_range])
high_time = time_diff[high_range_idx[min_local_idx]]

# Keep only samples strictly between the two bounds and fit
# latitude -> time difference by linear interpolation.
valid_time = (time_diff < high_time) & (time_diff > low_time)
f = interpolate.interp1d(lat_sel[valid_time], time_diff[valid_time], kind="linear")

# Row 0: latitudes, row 1: matching time differences (hours).
np.save("../../data/processed/latitude_time_diff_sep_6_2014.npy", np.array([lat_sel[valid_time], time_diff[valid_time]]))

In [2]:
# Root of the gap-filled MODIS input archive; it is listed for per-year
# sub-directories, each of which is listed for input files.
MODIS_GAPFILLED_DIR="MCD43C4.006_16day/"
# Processed-data directory (not referenced by name in the cells reviewed here).
PROCESSED_DIR="../../data/processed/"
# Root of the output tree; one sub-directory per year is created beneath it
# and the predicted SIF NetCDF files are written there.
SIF_MODIS_OUT_DIR_V5="../../data/processed/SIF_MODIS_16day_v5"
# Directory from which the training/validation dataset (train_val.nc) is read.
sif_processed_dir = os.path.join("../../data/", "processed")
# Figure output directory (not referenced in the cells reviewed here).
fig_dir="./figs"

In [4]:
# Create the output root and one sub-directory per year (2000-2021).
# exist_ok=True makes the cell safe to re-run: plain os.mkdir raises
# FileExistsError as soon as any of these directories already exists.
os.makedirs(SIF_MODIS_OUT_DIR_V5, exist_ok=True)
for i in range(2000, 2022, 1):
    os.makedirs(os.path.join(SIF_MODIS_OUT_DIR_V5, str(i)), exist_ok=True)

In [3]:
# Load the saved latitude -> overpass-time fit. The array is saved with
# row 0 = latitudes and row 1 = time differences in hours.
latitude_time_diff=np.load("../../data/processed/latitude_time_diff_sep_6_2014.npy")
# Linear interpolator from latitude to time difference; it is read as the
# global `f` by the compute_*_for_fitted_overpass functions.
# NOTE(review): with scipy's default settings interp1d raises ValueError for
# latitudes outside the fitted range — confirm callers stay inside it.
f=interpolate.interp1d(latitude_time_diff[0], latitude_time_diff[1], kind="linear")

In [4]:
def compute_cos_sza_for_fitted_overpass(latitude, date_of_interest):
    """Return cos(solar zenith angle) at the fitted overpass time for a latitude.

    The overpass moment is 13:36 on ``date_of_interest`` plus the offset (in
    hours) that the module-level interpolator ``f`` returns for ``latitude``.
    The sun position is evaluated for an observer at longitude 0 with
    atmospheric refraction enabled.

    Parameters
    ----------
    latitude : float
        Latitude in degrees; must be accepted by ``f``.
    date_of_interest : datetime.date
        Calendar day for which the angle is computed.

    Returns
    -------
    float
        Cosine of the solar zenith angle (not clipped, may be negative).
    """
    offset_hours = float(f(latitude))
    observer = astral.LocationInfo(latitude=latitude, longitude=0).observer
    overpass_moment = (
        datetime.datetime.combine(date_of_interest, datetime.time(13,36))
        + datetime.timedelta(hours=offset_hours)
    )
    zenith_deg = astral.sun.zenith(observer,
                                   dateandtime=overpass_moment,
                                   with_refraction=True)
    # Degrees -> radians, then cosine.
    return np.cos(zenith_deg / 180 * np.pi)

def compute_daily_sza_for_fitted_overpass(latitude, date_of_interest):
    """Return the daily mean of max(cos(solar zenith angle), 0) for a latitude.

    The cosine is sampled every 10 minutes (1/144 of a day) over a window
    running from half a day before to half a day after the fitted overpass
    moment (13:36 on ``date_of_interest`` plus the offset in hours given by
    the module-level interpolator ``f``), for an observer at longitude 0 with
    atmospheric refraction enabled. Negative cosines (sun below the horizon)
    are set to zero before averaging.

    Parameters
    ----------
    latitude : float
        Latitude in degrees; must be accepted by ``f``.
    date_of_interest : datetime.date
        Calendar day on which the window is centred.

    Returns
    -------
    float
        Mean of the clipped cosine samples.
    """
    offset_hours = float(f(latitude))
    # Offsets, in days, of each sample relative to the overpass moment.
    day_fractions = np.arange(-0.5,0.501, 1/(6*24))

    observer = astral.LocationInfo(latitude=latitude, longitude=0).observer
    overpass_moment = (
        datetime.datetime.combine(date_of_interest, datetime.time(13,36))
        + datetime.timedelta(hours=offset_hours)
    )

    cos_samples = []
    for p in day_fractions:
        zenith_deg = astral.sun.zenith(observer,
                                       dateandtime=overpass_moment + datetime.timedelta(days=p),
                                       with_refraction=True)
        cos_samples.append(np.cos(zenith_deg / 180 * np.pi))
    cos_samples = np.array(cos_samples)

    # Night-time samples contribute zero illumination.
    cos_samples[cos_samples < 0] = 0
    return np.mean(cos_samples)

In [6]:
def try_gpu(i=0):
    """Return ``cuda:i`` if at least ``i + 1`` CUDA devices are visible, else the CPU device."""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')


# feedforward model construct function
def construct_model(input_dim, hidden_dim, n_hidden_layers, drop_out=None):
    """Build a fully connected ReLU network with a single scalar output.

    Layout: ``Linear(input_dim, hidden_dim)`` + ReLU, followed by
    ``n_hidden_layers - 1`` blocks of ``Linear(hidden_dim, hidden_dim)`` + ReLU,
    followed by ``Linear(hidden_dim, 1)``. The first hidden layer is always
    present, even when ``n_hidden_layers`` is less than 1.

    Parameters
    ----------
    input_dim : int
        Number of input features.
    hidden_dim : int
        Width of every hidden layer.
    n_hidden_layers : int
        Number of hidden layers.
    drop_out : float or None, optional
        Dropout probability applied after every hidden ReLU. ``None`` or ``0``
        disables dropout, in which case no Dropout modules are inserted and
        the layer indices (state-dict keys) are unaffected.

    Returns
    -------
    torch.nn.Sequential
        The model, moved to the device returned by ``try_gpu()``.
    """
    layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
    if drop_out:
        # Use the requested probability here too; this layer previously
        # ignored `drop_out` and always used a fixed p=0.2.
        layers.append(nn.Dropout(p=drop_out))
    for _ in range(n_hidden_layers - 1):
        layers.append(nn.Linear(hidden_dim, hidden_dim))
        layers.append(nn.ReLU())
        if drop_out:
            layers.append(nn.Dropout(p=drop_out))
    layers.append(nn.Linear(hidden_dim, 1))
    return nn.Sequential(*layers).to(device=try_gpu())

In [7]:
# Architecture of the trained network: 2 hidden layers of 64 units.
hidden_dim=64
n_hidden_layers=2
# 3 input features; the prediction code feeds (red, nir, cos(SZA)) columns.
net= construct_model(3, hidden_dim, n_hidden_layers)
# File name of the saved state dict inside model_dir.
model_name="layer_2_neuron_64_08-16-2022_16-22-24_lr0.001_batchsize1024"
model_dir="./models"
# Tensors are deserialised onto the CPU and then copied into `net`'s
# parameters by load_state_dict.
net.load_state_dict(torch.load(os.path.join(model_dir, model_name), map_location=torch.device('cpu')))
# Switch to inference mode; the trailing semicolon only suppresses the
# notebook's echo of the returned module.
net.eval();

In [8]:
# Rebuild the feature scaler from the training/validation dataset so that
# inputs at prediction time are standardised with the same statistics.
train_val_ds = xr.open_dataset(os.path.join(sif_processed_dir, "train_val.nc"))
# After the transpose, columns are: band-1 reflectance, band-2 reflectance,
# cos(SZA) (SZA converted from degrees), SIF at 757 nm.
XY = np.stack([train_val_ds.Nadir_Reflectance_Band1.values,
train_val_ds.Nadir_Reflectance_Band2.values,
np.cos(train_val_ds.SZA.values / 180 * np.pi), train_val_ds.SIF_757nm.values]).T
scaler = StandardScaler()
# Fit on the three input columns only; the SIF column (index 3) is excluded.
scaler.fit(XY[:,0:3])

In [9]:
# Mask array; cells where it equals 1 are set to NaN in the prediction
# outputs. NOTE(review): presumably on the same (3600, 7200) grid as the
# MODIS inputs, since it is used to index arrays of that shape — confirm.
water_mask=np.load("../../data/processed/water_mask.npy")

In [29]:
def generate_sif_prediction_for_a_year(year_list):
    """Run ``generate_sif_prediction`` on each file of one year unless the year looks done.

    ``year_list`` is a list of "/"-separated input file paths belonging to one
    year. The year's output directory is treated as complete, and nothing is
    recomputed, when it already contains at least 24 entries.

    NOTE(review): ``SIF_AVHRR_OUT_DIR_V5`` and ``generate_sif_prediction`` are
    not defined in the portion of this file under review; unless they are
    defined elsewhere, calling this function raises ``NameError``. It looks
    like a leftover from an AVHRR variant of this notebook — confirm whether
    it can be removed.
    """
    # The year name is taken from the 7th "/"-separated component (index 6)
    # of the first path, so this relies on a fixed directory depth.
    if len(os.listdir(os.path.join(SIF_AVHRR_OUT_DIR_V5, year_list[0].split("/")[6]))) < 24:
        for file in year_list:
            generate_sif_prediction(file)
    print(year_list[0].split("/")[6] + " finished!")

In [21]:
def generate_sif_prediction_modis(file):
    """Predict clear-sky SIF for one gap-filled MODIS file and write it to NetCDF.

    Steps:
      1. Parse year, month and half-month period from the file name.
      2. Compute, per latitude row, cos(SZA) at the fitted overpass time and
         the daily mean of the clipped cos(SZA).
      3. Feed (red, nir, cos(SZA)) of every valid pixel through ``scaler`` and
         ``net`` to obtain instantaneous SIF.
      4. Scale instantaneous SIF to a daily value with
         ``sif / cos_sza * cos_daily_sza``.
      5. Write the four variables to
         ``SIF_MODIS_OUT_DIR_V5/<year directory>/<file name>``.

    Parameters
    ----------
    file : str
        Path to the input file, laid out as ``.../<year directory>/<file name>``.
        The second dot-separated token of the file name must be ``YYYYMM``
        followed by a period letter; ``"a"`` selects the first half of the
        month, anything else the second half.

    Notes
    -----
    Relies on the module-level ``scaler``, ``net``, ``water_mask``,
    ``SIF_MODIS_OUT_DIR_V5`` and the ``compute_*_for_fitted_overpass`` helpers.
    Assumes ``red_filled``/``nir_filled`` carry a leading axis whose first
    slice is a (3600, 7200) latitude x longitude grid — TODO confirm against
    the input product.
    """
    n_lat, n_lon = 3600, 7200
    # Latitude rows within this many cells of either edge are left as NaN.
    # NOTE(review): presumably because the fitted overpass-time interpolator
    # does not cover those latitudes — confirm.
    lat_margin = 164

    # Take the year directory and file name from the tail of the path instead
    # of a fixed component index, so the result does not depend on how deep
    # the input root is in the filesystem (or on the path separator).
    file_name = os.path.basename(file)
    year_dir = os.path.basename(os.path.dirname(file))

    date_token = file_name.split(".")[1]
    year_str, month_str, period = date_token[0:4], date_token[4:6], date_token[6]
    if period == "a":
        # First half of the month: sun geometry for day 8, time stamp on the 15th.
        day = 8
        file_date = np.datetime64(year_str + "-" + month_str + "-15")
    else:
        # Second half: sun geometry for day 24, time stamp on the month's last day.
        day = 24
        file_date = (np.datetime64(year_str + "-" + month_str) + np.timedelta64(1, 'M')).astype('datetime64[D]') - np.timedelta64(1, 'D')
    doi = date(int(year_str), int(month_str), day)

    # The context manager closes the input even if a later step raises.
    with xr.open_dataset(file) as ds:
        red = ds.red_filled.values[0]
        nir = ds.nir_filled.values[0]
        # A pixel is usable only when both bands are present.
        valid_flag_modis = np.invert(np.isnan(red)) & np.invert(np.isnan(nir))
        red_valid_modis = red[valid_flag_modis]
        nir_valid_modis = nir[valid_flag_modis]

        # Sun geometry depends on latitude only: compute one value per row
        # and repeat it across all longitudes.
        inner_lats = ds.latitude.values[lat_margin:-lat_margin]

        computed_cos_sza = np.full(n_lat, np.nan)
        computed_cos_sza[lat_margin:-lat_margin] = np.array(
            [compute_cos_sza_for_fitted_overpass(lat, doi) for lat in inner_lats])
        computed_cos_sza_array = np.tile(computed_cos_sza, (n_lon, 1)).T

        computed_cos_daily_sza = np.full(n_lat, np.nan)
        computed_cos_daily_sza[lat_margin:-lat_margin] = np.array(
            [compute_daily_sza_for_fitted_overpass(lat, doi) for lat in inner_lats])
        computed_cos_daily_sza_array = np.tile(computed_cos_daily_sza, (n_lon, 1)).T

        computed_cos_sza_valid_modis = computed_cos_sza_array[valid_flag_modis]
        computed_cos_daily_sza_valid_modis = computed_cos_daily_sza_array[valid_flag_modis]

        # Feature order must match the columns the scaler was fitted on:
        # red, nir, cos(SZA).
        data_matrix_modis = np.array([red_valid_modis, nir_valid_modis, computed_cos_sza_valid_modis]).T
        scaled_data_matrix_modis = scaler.transform(data_matrix_modis)

        with torch.no_grad():
            predicted_modis = net(torch.tensor(scaled_data_matrix_modis).float().to(try_gpu())).cpu().numpy().squeeze()

        def _to_masked_grid(valid_values):
            # Scatter per-pixel values back onto the full grid; invalid
            # pixels and water pixels end up as NaN.
            grid = np.full((n_lat, n_lon), np.nan)
            grid[valid_flag_modis] = valid_values
            grid[water_mask == 1] = np.nan
            return grid

        def _to_time_slice(grid):
            # Wrap a 2-D grid as a single-time-step DataArray.
            return xr.DataArray(np.expand_dims(grid, axis=0),
                                coords=[[file_date,], ds.latitude, ds.longitude],
                                dims=["time", "latitude", "longitude"])

        sif_modis = _to_masked_grid(predicted_modis)
        cos_sza_modis = _to_masked_grid(computed_cos_sza_valid_modis)
        cos_daily_sza_modis = _to_masked_grid(computed_cos_daily_sza_valid_modis)

        modis_sif_ds = xr.Dataset({
            "sif_modis_clear_inst": _to_time_slice(sif_modis),
            # Instantaneous -> daily: rescale by the ratio of daily-mean to
            # overpass-time illumination.
            "sif_modis_clear_daily": _to_time_slice(sif_modis / cos_sza_modis * cos_daily_sza_modis),
            "cos_sza_modis": _to_time_slice(cos_sza_modis),
            "cos_daily_sza_modis": _to_time_slice(cos_daily_sza_modis),
        })
        try:
            modis_sif_ds.to_netcdf(os.path.join(SIF_MODIS_OUT_DIR_V5, year_dir, file_name))
        finally:
            modis_sif_ds.close()

In [22]:
# Walk the gap-filled MODIS archive once and build two views of it:
#   MODIS_FILE_LIST                - every input path, year by year, sorted by name
#   MODIS_FILE_YEAR_LIST_AVAILABLE - the same paths grouped into one list per year
MODIS_FILE_LIST = []
MODIS_FILE_YEAR_LIST_AVAILABLE = []

for MODIS_YEAR in sorted(os.listdir(MODIS_GAPFILLED_DIR)):
    year_list = []
    for MODIS_FILE in sorted(os.listdir(os.path.join(MODIS_GAPFILLED_DIR, MODIS_YEAR))):
        file = os.path.join(MODIS_GAPFILLED_DIR, MODIS_YEAR, MODIS_FILE)
        year_list.append(file)
    MODIS_FILE_YEAR_LIST_AVAILABLE.append(year_list)
    # The flat list is the concatenation of the per-year lists.
    MODIS_FILE_LIST.extend(year_list)

In [23]:
def generate_modis_prediction_for_a_year(year_list):
    """Generate SIF predictions for every input file of one year.

    The year is skipped when its output directory under
    ``SIF_MODIS_OUT_DIR_V5`` already holds at least 24 entries. A
    ``"<year> finished!"`` line is printed either way.

    Parameters
    ----------
    year_list : list of str
        Input file paths of a single year, each laid out as
        ``.../<year directory>/<file name>``. An empty list is a no-op.
    """
    if not year_list:
        return
    # The year directory is the parent folder of the input files; reading it
    # from the path tail avoids depending on a fixed number of leading path
    # components.
    year_name = os.path.basename(os.path.dirname(year_list[0]))
    if len(os.listdir(os.path.join(SIF_MODIS_OUT_DIR_V5, year_name))) < 24:
        for file in year_list:
            generate_sif_prediction_modis(file)
    print(year_name + " finished!")

In [24]:
# Process the years in parallel, one task per year, on 6 worker processes.
pool = mp.Pool(6)
jobs = []

try:
    for year_list in MODIS_FILE_YEAR_LIST_AVAILABLE:
        if not year_list:
            # A year directory without input files has nothing to process.
            continue
        # The year directory is the parent folder of the input files; reading
        # it from the path tail avoids depending on a fixed path depth.
        year_name = os.path.basename(os.path.dirname(year_list[0]))
        # Only submit years whose output directory is still incomplete.
        if len(os.listdir(os.path.join(SIF_MODIS_OUT_DIR_V5, year_name))) < 24:
            job = pool.apply_async(generate_modis_prediction_for_a_year,(year_list,))
            jobs.append(job)

    # Wait for every task; get() re-raises any exception from the worker.
    for job in jobs:
        job.get()
finally:
    # Shut the workers down even when a task failed, so no processes are
    # left behind.
    pool.close()
    pool.join()


2002 finished!
2000 finished!
2001 finished!
2005 finished!
2003 finished!
2004 finished!
2007 finished!
2008 finished!
2006 finished!
2009 finished!
2011 finished!
2010 finished!
2013 finished!
2012 finished!
2014 finished!
2016 finished!
2015 finished!
2017 finished!
2019 finished!
2020 finished!
2018 finished!
2021 finished!
