In [None]:
# Cell 1 - Imports
import numpy as np
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from helper.read_GEOSldas import read_ObsFcstAna, read_tilecoord, read_obs_param

# Cell 2 - Function Definition
def compute_monthly_stats(expdir,expid,domain,this_month,tc,obs_param,var_list):
    """
    Accumulate one month of GEOS LDAS observation/forecast/analysis sums.

    The ensemble-average ``ldas_ObsFcstAna`` files are read every 3 hours,
    from 03:00 on the first day of the month of ``this_month`` up to and
    including 00:00 on the first day of the following month.

    Parameters:
    -----------
    expdir : str
        Experiment root directory (joined with ``expid`` to form the path).
    expid : str
        Experiment ID
    domain : str
        Domain name
    this_month : datetime
        Any datetime within the month to process; its day and time of day
        are ignored.
    tc : dict
        Tile coordinates; must provide ``'N_tile'`` (number of tiles).
    obs_param : list
        Observation parameters, one entry per observation species.
    var_list : list
        Names of ObsFcstAna fields to accumulate. Must include
        ``'obs_obs'``, ``'obs_fcst'`` and ``'obs_ana'``, which are used for
        the validity mask and the cross products.

    Returns:
    --------
    tuple
        N_data, data_sum, data2_sum, oxf_sum, oxa_sum, fxa_sum.
        N_data, oxf_sum, oxa_sum and fxa_sum are float arrays of shape
        (N_tile, n_species). data_sum / data2_sum map each name in
        ``var_list`` to an array of the same shape holding sums / sums of
        squares.
    """
    import os

    n_tile = tc['N_tile']
    n_spec = len(obs_param)

    # Analysis output is expected every 3 hours.
    time_step = timedelta(hours=3)

    # Clear the sub-hour fields too: otherwise a ``this_month`` with nonzero
    # minutes/seconds shifts every timestamp and the generated file names
    # (which embed HHMM) would not exist.
    start_time = this_month.replace(day=1, hour=3, minute=0, second=0, microsecond=0)
    end_time = start_time + relativedelta(months=1)

    # Accumulators: one row per tile, one column per species
    N_data = np.zeros((n_tile, n_spec))
    oxf_sum = np.zeros((n_tile, n_spec))
    oxa_sum = np.zeros((n_tile, n_spec))
    fxa_sum = np.zeros((n_tile, n_spec))
    data_sum = {var: np.zeros((n_tile, n_spec)) for var in var_list}
    data2_sum = {var: np.zeros((n_tile, n_spec)) for var in var_list}

    date_time = start_time
    while date_time < end_time:
        fname = os.path.join(
            expdir, expid, 'output', domain, 'ana', 'ens_avg',
            date_time.strftime('Y%Y'), date_time.strftime('M%m'),
            expid + '.ens_avg.ldas_ObsFcstAna.'
            + date_time.strftime('%Y%m%d_%H%M') + 'z.bin',
        )

        print("Processing file:", fname)

        OFA = read_ObsFcstAna(fname)

        # A record without observations contributes nothing
        if len(OFA['obs_tilenum']) > 0:
            data_tile = process_timestep(OFA, tc, obs_param, var_list, n_tile, n_spec)

            # Only (tile, species) cells that received an observation count
            is_valid = ~np.isnan(data_tile['obs_obs'])
            update_sums(is_valid, data_tile, N_data, oxf_sum, oxa_sum, fxa_sum,
                        data_sum, data2_sum)

        date_time += time_step

    return N_data, data_sum, data2_sum, oxf_sum, oxa_sum, fxa_sum

# Cell 3 - Helper Functions
def process_timestep(OFA, tc, obs_param, var_list, n_tile, n_spec):
    """
    Grid a single ObsFcstAna record onto (tile, species) arrays.

    Returns a dict that maps each name in ``var_list`` to an (n_tile, n_spec)
    float array. The array starts as all-NaN and is filled column by column,
    one species per column, by ``process_species``.
    """
    grid_shape = (n_tile, n_spec)
    gridded = {}
    for name in var_list:
        gridded[name] = np.full(grid_shape, np.nan)

    for col in range(n_spec):
        spec_param = obs_param[col]
        species_id = int(spec_param['species'])
        process_species(OFA, tc, spec_param, species_id, col, var_list, gridded)

    return gridded

def process_species(OFA, tc, obs_param_spec, this_species, ispec, var_list, data_tile,
                    verbose=False):
    """
    Scatter one species' observations from an ObsFcstAna record into ``data_tile``.

    Observations are selected where ``OFA['obs_species'] == this_species``.
    When ``obs_param_spec['assim'] == 'T'``, only those with
    ``OFA['obs_assim'] == 1`` are kept. For every name in ``var_list``, the
    selected values are written to ``data_tile[var][obs_tilenum - 1, ispec]``.
    If several selected observations map to the same tile, only one value
    ends up stored (NumPy fancy-index assignment with repeated indices).

    Parameters
    ----------
    OFA : dict
        ObsFcstAna record with ``'obs_species'``, ``'obs_tilenum'``
        (1-based), the fields in ``var_list`` and, for assimilated species,
        ``'obs_assim'``.
    tc : dict
        Tile coordinates. Not used here; kept for interface compatibility.
    obs_param_spec : dict
        Parameters of this species; its ``'assim'`` entry is read.
    this_species : int
        Species number to select.
    ispec : int
        Column of ``data_tile`` arrays that receives this species.
    var_list : list
        Names of the fields to copy.
    data_tile : dict
        Arrays indexed [tile, species]; modified in place.
    verbose : bool, optional
        If True, print the first few 0-based tile indices being written.

    Returns
    -------
    None
    """
    species_mask = OFA['obs_species'] == this_species

    if obs_param_spec['assim'] == 'T':
        # For assimilated species, keep only observations that were assimilated
        mask = np.logical_and(species_mask, OFA['obs_assim'] == 1)
    else:
        mask = species_mask

    # obs_tilenum is 1-based; convert to 0-based row indices of data_tile
    tile_idx = OFA['obs_tilenum'][mask] - 1

    if tile_idx.size == 0:
        print(f"Warning: No data for species {this_species}")
        return

    if verbose:
        print("tile_idx: ", tile_idx[:5])

    for var in var_list:
        data_tile[var][tile_idx, ispec] = OFA[var][mask]

def update_sums(is_valid, data_tile, N_data, oxf_sum, oxa_sum, fxa_sum, data_sum, data2_sum):
    """
    Fold one timestep into the running accumulators, in place.

    Only cells where the boolean mask ``is_valid`` is True are touched:
    - ``N_data`` is incremented by one.
    - The obs*fcst, obs*ana and fcst*ana products are added to ``oxf_sum``,
      ``oxa_sum`` and ``fxa_sum``.
    - For every key of ``data_sum``, the value is added to ``data_sum[key]``
      and its square to ``data2_sum[key]``.
    """
    N_data[is_valid] += 1

    obs = data_tile['obs_obs'][is_valid]
    fcst = data_tile['obs_fcst'][is_valid]
    oxf_sum[is_valid] += obs * fcst

    ana = data_tile['obs_ana'][is_valid]
    oxa_sum[is_valid] += obs * ana
    fxa_sum[is_valid] += fcst * ana

    for name, running in data_sum.items():
        picked = data_tile[name][is_valid]
        running[is_valid] += picked
        data2_sum[name][is_valid] += picked ** 2

# Cell 4 - Example Usage
if __name__ == '__main__':
    import os

    # Set parameters (date_time may be any time within the month to process)
    date_time = datetime(2018,8,6,9)
    expdir = '/Users/amfox/Desktop/GEOSldas_diagnostics/test_data/CYGNSS_Experiments/DAv8_M36_Aus/'
    expid = 'DAv8_M36_Aus'
    domain = 'SMAP_EASEv2_M36_GLOBAL'
    var_list = ['obs_obs', 'obs_obsvar', 'obs_fcst', 'obs_fcstvar', 'obs_ana', 'obs_anavar']

    rc_out_dir = os.path.join(expdir, expid, 'output', domain, 'rc_out')

    # Read coordinate and parameter files
    ftc = os.path.join(rc_out_dir, expid + '.ldas_tilecoord.bin')
    tc = read_tilecoord(ftc)

    # Obs-parameter file for 00z on the 1st of the processed month, derived
    # from date_time so it cannot disagree with the month being processed.
    # NOTE(review): assumes such a file exists for every month — confirm.
    fop = os.path.join(
        rc_out_dir, date_time.strftime('Y%Y'), date_time.strftime('M%m'),
        expid + '.ldas_obsparam.' + date_time.strftime('%Y%m01_0000z') + '.txt',
    )
    obs_param = read_obs_param(fop)

    # Compute statistics
    N_data, data_sum, data2_sum, oxf_sum, oxa_sum, fxa_sum = \
           compute_monthly_stats(expdir,expid,domain,date_time,tc,obs_param,var_list)

# Cell 5 - Visualization (optional)
# Plotting of the accumulated statistics is done in the following cell.

In [None]:
print('N_data:', N_data)
print('Max:', np.max(N_data))
print('Min:', np.min(N_data))
print('Mean:', np.mean(N_data))
print('Shape:', N_data.shape)
print('tc[com_lon] length:', len(tc['com_lon']))
print('tc[com_lat] length:', len(tc['com_lat']))

import matplotlib.pyplot as plt

# Column of N_data to map (species index, same order as obs_param)
ispec_plot = 9
if not 0 <= ispec_plot < N_data.shape[1]:
    raise ValueError(
        f"ispec_plot={ispec_plot} is out of range for {N_data.shape[1]} species"
    )

# Map, for one species, how many valid observations each tile accumulated
plt.figure(figsize=(10, 6))
sc = plt.scatter(tc['com_lon'], tc['com_lat'], c=N_data[:, ispec_plot], s=1, cmap='plasma', alpha=0.5)
plt.colorbar(sc, label="Number of valid observations")
plt.title(f"Observation count per tile (species index {ispec_plot})")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.grid(True)
plt.show()