In [1]:
from p_drought_indices.functions.function_clns import load_config, cut_file, subsetting_pipeline
from p_drought_indices.functions.ndvi_functions import downsample, clean_ndvi, compute_ndvi, clean_outliers
from p_drought_indices.vegetation.cloudmask_cleaning import extract_apply_cloudmask, plot_cloud_correction, compute_difference, compute_correlation
import xarray as xr 
import pandas as pd
import yaml
from datetime import datetime, timedelta
import shutil
from shapely.geometry import Polygon, mapping
import geopandas as gpd
import matplotlib.pyplot as plt
from glob import glob
import os
#import datetime as datetime
import time
import numpy as np
import re
from p_drought_indices.vegetation.NDVI_indices import compute_svi, compute_vci
from p_drought_indices.analysis.metrics_table import MetricTable

# Relative path to the project's YAML configuration file; it is handed to
# load_config and to MetricTable in the cells below.
CONFIG_PATH = r"../config.yaml"


### Processing and smoothing NDVI

In [2]:
# Chunking passed to every lazily opened dataset in this cell:
# lat/lon are kept whole (-1), time is split into blocks of 12 steps.
chunks = {"lat": -1, "lon": -1, "time": 12}
## Uncomment for AVHRR data
#path = r'D:\shareVM\MSG\AVHRR\processed\*.nc'
#data = xr.open_mfdataset(path, chunks=chunks)
#data = data.rename({'longitude':'lon', 'latitude':'lat'})

# Gather the processed files whose name starts with "HRSEVIRI_200<digits>".
# The pattern is a raw string: in a plain literal '\d' is an invalid escape
# sequence and raises a DeprecationWarning/SyntaxWarning on recent Pythons.
ndvi_dir = r'D:\shareVM\MSG\msg_data\batch_2\processed'
list_files = [
    os.path.join(ndvi_dir, fname)
    for fname in os.listdir(ndvi_dir)
    if re.match(r'HRSEVIRI_200\d+.*', fname)
]
xr_df = xr.open_mfdataset(list_files, chunks=chunks)

xr_df = clean_outliers(xr_df)

# Country boundaries come from the shapefile referenced in the config.
config = load_config(CONFIG_PATH)
shapefile_path = config['SHAPE']['africa']
gdf = gpd.read_file(shapefile_path)

countries = ['Ethiopia','Kenya','Somalia']

# Keep only the polygons of the selected countries.
subset = gdf[gdf.ADM0_NAME.isin(countries)]
#ds_avh = cut_file(data, subset)

# Cloud-mask dataset, restricted to time steps up to 2009-12-31.
base_dir = r'D:\shareVM\MSG\cloudmask\processed_clouds\batch_2\nc_files\new\ndvi_mask.nc'
cl_df = xr.open_mfdataset(base_dir, chunks=chunks)
cl_df = cl_df.sel(time=slice(cl_df['time'].min(), '2009-12-31'))

# Both datasets go through cut_file with the same country subset
# (presumably a spatial clip -- confirm in function_clns) before the
# cloud mask is extracted and applied.
ds = cut_file(xr_df, subset)
ds_cl = cut_file(cl_df, subset)
mask_clouds, res_xr = extract_apply_cloudmask(ds, ds_cl)

In [5]:
from p_drought_indices.vegetation.cloudmask_cleaning import apply_whittaker

# Run apply_whittaker on the cloud-masked NDVI (presumably a Whittaker
# smoother -- confirm in cloudmask_cleaning) and store the outcome as NetCDF.
smoothed_ndvi_path = r'D:\shareVM\MSG\msg_data\processed\smoothed_ndvi.nc'
result = apply_whittaker(mask_clouds['ndvi'])
result.to_netcdf(smoothed_ndvi_path)

In [9]:
# Reload the NDVI file written by the smoothing cell so that the following
# cells can run without repeating the smoothing step.
smoothed_ndvi_path = r'D:\shareVM\MSG\msg_data\processed\smoothed_ndvi.nc'
veg_df = xr.open_dataset(smoothed_ndvi_path)

In [10]:
# Compute the VCI from the reloaded NDVI variable and write it to NetCDF.
vci_path = r'D:\shareVM\MSG\msg_data\processed\vci_1D.nc'
vci = compute_vci(veg_df['ndvi'])
vci.to_netcdf(vci_path)
# SVI counterpart, currently disabled:
#res = compute_svi(res_xr)
#res.to_netcdf(r'D:\shareVM\MSG\msg_data\processed\svi_1D.nc')



### Loop to apply MetricTable for VCI to each precipitation product

In [11]:
# Build a single table of metrics between 'vci' and every spi_gamma_<n>
# variable, for each country and each product path listed in the config,
# then save it as CSV.
config = load_config(CONFIG_PATH)
product_directory =  r"D:\shareVM\MSG\msg_data\processed"
var = 'vci'
output_csv = r'../data/spi_vci/spi_vci_daily.csv'

# Per-run frames are collected in a list and concatenated once at the end:
# concatenating inside the loop copies the growing table on every iteration
# and starts from an empty DataFrame, which pandas deprecates in concat.
metric_frames = []
for country in ['Ethiopia','Somalia','Kenya']:
    for product_dir in [config['SPI']['IMERG']['path'], config['SPI']['GPCC']['path'], config['SPI']['CHIRPS']['path'], config['SPI']['ERA5']['path']]:
        for late in [30, 60, 90, 180]:
            var_target = f"spi_gamma_{late}"
            spi = MetricTable(product_directory, product_dir, var, var_target, CONFIG_PATH, countries=[country])
            spi.compute_metrics_soil(freq="daily")
            table_df = spi.df_cover
            metric_frames.append(table_df)

table_metrics = pd.concat(metric_frames, ignore_index=True)

# Make sure the target folder exists so the long computation above is not
# lost to a missing-directory error at write time.
os.makedirs(os.path.dirname(output_csv), exist_ok=True)
# The index is written on purpose: the loading cell drops the first column.
table_metrics.to_csv(output_csv)

In [None]:
import seaborn as sns
def plot_data(data, metric, precp_idx, country, product):
    """Show the distribution of one metric for one index/country/product.

    Rows of ``data`` are kept when their ``precp_idx``, ``country`` and
    ``product`` columns equal the given values. The non-null values of
    column ``metric`` are drawn with ``sns.displot`` (KDE enabled), a title
    is set and the figure is shown. Nothing is returned.
    """
    same_index = data['precp_idx'] == precp_idx
    same_country = data['country'] == country
    same_product = data['product'] == product
    values = data[metric].loc[same_index & same_country & same_product]

    sns.displot(values.dropna(), kde=True)
    plt.title(f"metric {metric} for {precp_idx} {product} in {country}")
    plt.show()

def group_plot(data, country, product, metric='far'):
    """Overlay one histogram of ``metric`` per ``precp_idx`` value.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding the ``country``, ``product`` and ``precp_idx`` columns
        plus the column named by ``metric``.
    country, product
        Values that the ``country`` and ``product`` columns must equal for
        a row to be plotted.
    metric : str, default 'far'
        Column to plot. The default keeps the previously hard-coded column.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    selection = data.loc[(data['country'] == country)
                         & (data['product'] == product)]
    fig, ax = plt.subplots()
    # One labelled, semi-transparent histogram per index, so the overlaid
    # groups can be told apart in the legend.
    for idx_name, frame in selection.groupby('precp_idx'):
        frame[metric].plot(kind='hist', ax=ax, alpha=0.5, label=str(idx_name))
    ax.set_xlabel(metric)
    ax.set_title(f"{metric} for {product} in {country}")
    # Only build a legend when something was drawn (empty selection otherwise).
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    plt.show()

def plot_data_spi(data, metric, country, product):
    """Draw one histogram panel of ``metric`` for every ``precp_idx`` value.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding the ``precp_idx``, ``country`` and ``product`` columns
        plus the column named by ``metric``.
    metric : str
        Column whose non-null values are plotted.
    country, product
        Values that the ``country`` and ``product`` columns must equal.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # Take the index names from the frame that was passed in, not from a
    # notebook-level variable.
    spi_list = data['precp_idx'].unique()
    if len(spi_list) == 0:
        return

    # squeeze=False keeps ``axes`` two-dimensional even with a single panel.
    fig, axes = plt.subplots(1, len(spi_list), squeeze=False,
                             figsize=(4 * len(spi_list), 4))

    for ax, spi in zip(axes[0], spi_list):
        y = data[metric].loc[(data['precp_idx'] == spi)
                             & (data['country'] == country)
                             & (data['product'] == product)]
        y = y.dropna()

        # histplot is the axes-level counterpart of displot, so it can draw
        # into the panel reserved for this index.
        sns.histplot(y, kde=True, ax=ax)
        ax.set_title(str(spi))

    fig.suptitle(f"metric {metric} for {product} in {country}")
    plt.show()

### Load data and plot 

In [None]:
# Read the saved metrics table; the first CSV column is the index written
# by to_csv, so it is discarded.
metrics_csv = r'../data/spi_vci/spi_vci_daily.csv'
data = pd.read_csv(metrics_csv).iloc[:, 1:]

In [None]:
# Plot the distribution of each metric per country for one SPI variable and
# one product, using the table loaded from CSV in the previous cell, so the
# cell also works on a fresh kernel without re-running the metric loop.
precp_idx= 'spi_gamma_180'
product="IMERG"

for metric in ['pod','far','accuracy']:
    print(f"Printing {metric}")
    for country in ['Ethiopia','Kenya','Somalia']:
        # plot_data only displays the figure and returns nothing.
        plot_data(data, metric, precp_idx, country, product)