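# In [None]:
# Modify the following variables before each experiment:
#   exptname          output file name
#   filename_climo    input (forecast model climatology) file: 64/256 grid, 3hr/12hr interval
#   filename_truth    input (nature run) file the observations are drawn from
#   nobs              number of observations per level (fraction observed = nobs/(64*64) or nobs/(256*256))
#   pickarctan        number of arctan-transformed obs: 0 is fully linear, nobs is fully nonlinear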

from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
from netCDF4 import Dataset
import sys, time, os
from sqgturb import SQG, rfft2, irfft2, cartdist,enkf_update,gaspcohn, bulk_ensrf

from EnSF_Sparse_obs_dct import EnSF
from skimage.restoration import inpaint
import cv2
from scipy.stats import ortho_group
from sklearn import decomposition
import cvxpy as cp
#plot
import scipy.fft
# from scipy.fftpack import fft, dct, idct
import torch
import torch_dct as dct
from joblib import Parallel, delayed


def l1_solve(components_, masked_img, mask, lambda_reg):
    """
    Fit L1-regularized coefficients to the masked entries of an image.

    Solves
        minimize_v  || mask * (A @ v - b) ||_2^2  +  lambda_reg * ||v||_1
    with A = components_.T (one column per component), b = masked_img
    flattened, and the flattened mask applied elementwise to the residual.
    Entries where mask is 0 therefore do not contribute to the misfit.

    Returns:
        The optimal coefficient vector ``v.value`` from cvxpy (presumably
        None if the solver does not find a solution -- confirm with cvxpy docs).
    """
    design = components_.T
    target = masked_img.ravel()
    weights = mask.ravel()

    coeffs = cp.Variable(np.shape(design)[1])
    misfit = cp.norm(cp.multiply(weights, (design @ coeffs - target)), 2) ** 2
    penalty = lambda_reg * cp.norm(coeffs, 1)

    cp.Problem(cp.Minimize(misfit + penalty)).solve()
    return coeffs.value

def random_indices_from_flatten(shape, num_samples):
    """
    Draw distinct random cells of a square ``shape`` x ``shape`` grid.

    A random permutation of the flattened grid is truncated to
    ``num_samples`` entries. Each flat index is then split into a
    row-major (row, col) pair.

    Returns:
        (rows, cols) integer tensors of equal length.
    """
    picked = torch.randperm(shape * shape)[:num_samples]
    return picked // shape, picked % shape

def unused_indices(shape,data_rows,data_cols):
    """
    Return the cells of a ``shape`` x ``shape`` grid that were NOT selected.

    Args:
        shape: side length of the square grid (int).
        data_rows, data_cols: 1-D integer torch tensors with the selected
            (row, col) positions. They may live on any device; flat indices
            outside the grid are ignored.

    Returns:
        (remaining_rows, remaining_cols): int64 CPU tensors listing the
        unselected cells in ascending row-major order. The result is empty,
        but still int64, when every cell is covered.

    Note:
        When ``len(data_rows) == shape**2`` the legacy sentinel
        ``(tensor([0]), tensor([0]))`` is returned instead. It is kept for
        backward compatibility with existing callers.
    """
    if len(data_rows) == shape ** 2:
        # Legacy sentinel, preserved for existing callers.
        return torch.tensor([0]), torch.tensor([0])

    n_cells = shape * shape
    flat_selected = (data_rows * shape + data_cols).reshape(-1).cpu().long()
    # Match the old set-difference semantics: out-of-grid indices are ignored.
    in_grid = (flat_selected >= 0) & (flat_selected < n_cells)

    # A boolean mask, unlike a Python set, keeps the output sorted. It also keeps
    # it int64 when nothing remains (torch.tensor([]) would be float32).
    unused = torch.ones(n_cells, dtype=torch.bool)
    unused[flat_selected[in_grid]] = False
    remaining_indices = torch.arange(n_cells)[unused]

    return remaining_indices // shape, remaining_indices % shape
    
def k_largest_indices(tensor: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Locate the ``k`` largest entries of a 2-D tensor.

    Args:
        tensor: 2-D input tensor.
        k: how many of the largest entries to return.

    Returns:
        (rows, cols) index tensors, ordered from largest to smallest value
        (the order ``torch.topk`` yields).
    """
    _, flat_idx = torch.topk(tensor.reshape(-1), k)
    width = tensor.size(1)
    return flat_idx // width, flat_idx % width

def get_indices_with_sum_less_than_k(rows: int, cols: int, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    List every (i, j) of a rows x cols grid whose index sum i + j is below k.

    Args:
        rows (int): number of rows.
        cols (int): number of columns.
        k (int): exclusive upper bound on i + j.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: row indices and column indices of
        the qualifying cells, in row-major order.
    """
    grid_r = torch.arange(rows).unsqueeze(1).expand(rows, cols).reshape(-1)
    grid_c = torch.arange(cols).unsqueeze(0).expand(rows, cols).reshape(-1)
    below = grid_r + grid_c < k
    return grid_r[below], grid_c[below]

def get_ensemble_covariance(num_ensemble,ensemble):
    """
    Sample covariance of an ensemble stored one member per row.

    Args:
        num_ensemble: number of members, used for the N-1 normalization.
        ensemble: 2-D torch tensor of shape (members, state_dim).

    Returns:
        (state_dim, state_dim) tensor equal to A^T A / (num_ensemble - 1),
        where A holds the member anomalies about the ensemble mean.
    """
    anomalies = ensemble - ensemble.mean(dim=0)
    return anomalies.T @ anomalies / (num_ensemble - 1)

def total_variation_loss(x, height=64, width=64):
    """
    Anisotropic total-variation penalty, one value per image in a batch.

    ``x`` is viewed as ``(batch, 1, height, width)``, so it must be
    view-compatible with that shape. For each image the result is the mean
    absolute difference between vertically adjacent pixels plus the mean
    absolute difference between horizontally adjacent pixels.

    Returns:
        1-D tensor with one entry per image.
    """
    imgs = x.view(-1, 1, height, width)
    vertical = torch.diff(imgs, dim=2).abs().mean(dim=(1, 2, 3))
    horizontal = torch.diff(imgs, dim=3).abs().mean(dim=(1, 2, 3))
    return vertical + horizontal

def inpaint_task(i):
    """
    Biharmonic-inpaint image ``i`` of the module-level ``masked_img`` stack.

    The region handed to skimage as the fill mask is ``1.0 - img_mask``,
    i.e. the pixels where the module-level ``img_mask`` is 0. This function
    is not called by the assimilation loop, which uses ``process_image``
    instead.
    """
    fill_region = 1.0 - img_mask
    return inpaint.inpaint_biharmonic(masked_img[i], fill_region)

# Function for advancing models
def advance_task(nanal):
    """
    Run ``models[nanal].advance`` on ensemble member ``nanal`` and return its result.

    Reads the module-level ``models`` list and ``pvens`` array. It is meant to
    be dispatched through joblib ``Parallel``.
    NOTE(review): with a process-based joblib backend the model object is
    copied into the worker, so state that ``advance`` updates on the model
    (e.g. its clock) is presumably not reflected in this process -- verify.
    """
    member_model = models[nanal]
    member_state = pvens[nanal]
    return member_model.advance(member_state)

def process_image(i):
    """
    OpenCV-inpaint image ``i`` of the module-level ``cv2_raw_img`` stack.

    Uses Navier-Stokes inpainting (``cv2.INPAINT_NS``, radius 2). Per OpenCV's
    convention, the pixels filled are those that are nonzero in the
    module-level ``cv2_mask``. The inpainted image is divided by 255 before
    being returned.
    """
    filled = cv2.inpaint(cv2_raw_img[i], cv2_mask, 2, cv2.INPAINT_NS)
    return filled / 255

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# horizontal covariance localization length scale in meters.
hcovlocal_scale = 1 #float(sys.argv[1])

covinflate1 = 1
covinflate2 = -1
exptname = os.getenv('exptname','Arctan_EnSF_12hrly_64_cv2_25per')   #  Sparse  Arctan
# get OMP_NUM_THREADS (threads to use) from environment.
threads = int(os.getenv('OMP_NUM_THREADS','1'))

diff_efold = None # use diffusion from climo file

profile = False # turn on profiling?

use_letkf = True  # use LETKF
global_enkf = False # global EnSRF solve
read_restart = False
# if savedata not None, netcdf filename will be defined by env var 'exptname'
# if savedata = 'restart', only last time is saved (so expt can be restarted)
#savedata = True
#savedata = 'restart'
savedata = None
#nassim = 101
#nassim_spinup = 1
nassim = 100 # assimilation times to run
nassim_spinup = 100

direct_insertion = False
if direct_insertion: print('# direct insertion!')

nanals = 20 # ensemble members

oberrstdev = 1. # ob error standard deviation in K
oberrstdev_arctan = 0.01 # ob error standard deviation for arctan-transformed obs

# nature run created using sqg_run.py.
filename_climo = 'sqg_N64_12hrly.nc' # file name for forecast model climo
# perfect model
filename_truth = 'sqg_N64_12hrly.nc' # file name for nature run to draw obs

print('# filename_modelclimo=%s' % filename_climo)
print('# filename_truth=%s' % filename_truth)

# random number generators (fixed seeds where reproducibility is wanted).
rsobs = np.random.RandomState(42) # fixed seed: ob locations and ob errors
rsics = np.random.RandomState() # unseeded: varies run to run (initial conditions)
rsarctan = np.random.RandomState(98) # fixed seed: which obs are arctan-transformed
rsjump = np.random.RandomState(10) # fixed seed: "jump" perturbation

# ---------------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------------
nc_climo = Dataset(filename_climo)
# parameter used to scale PV to temperature units.
scalefact = nc_climo.f*nc_climo.theta0/nc_climo.g
x = nc_climo.variables['x'][:]
y = nc_climo.variables['y'][:]
x, y = np.meshgrid(x, y)
# meshgrid returns 2-D arrays of shape (len(y), len(x)). len() of a 2-D array
# is only its row count, so take both grid sizes from the shape.
ny, nx = x.shape
dt = nc_climo.dt
if diff_efold is None: diff_efold = nc_climo.diff_efold
pvens = np.empty((nanals,2,ny,nx),np.float32)
if not read_restart:
    pv_climo = nc_climo.variables['pv']
    indxran = rsics.choice(pv_climo.shape[0],size=nanals,replace=False)
else:
    ncinit = Dataset('%s_restart.nc' % exptname, mode='r', format='NETCDF4_CLASSIC')
    ncinit.set_auto_mask(False)
    pvens[:] = ncinit.variables['pv_b'][-1,...]/scalefact
    tstart = ncinit.variables['t'][-1]

# initialize qg model instances for each ensemble member.
models = []
for nanal in range(nanals):
    if not read_restart:
        # Cold start: every member is the first climo state plus Gaussian noise
        # (std 1000, unscaled model PV units). NOTE: the random climo draws in
        # `indxran` are currently not used for the initial state.
        # A restart keeps the state read from the restart file. pv_climo only
        # exists on a cold start, so it must not be referenced here on a restart.
        pvens[nanal] = pv_climo[0] + np.random.normal(0,1000,size=(2,ny,nx))
    models.append(
        SQG(pvens[nanal],
            nsq=nc_climo.nsq,f=nc_climo.f,dt=dt,U=nc_climo.U,H=nc_climo.H,
            r=nc_climo.r,tdiab=nc_climo.tdiab,symmetric=nc_climo.symmetric,
            diff_order=nc_climo.diff_order,diff_efold=diff_efold,threads=threads))
if read_restart: ncinit.close()


# vertical localization scale
Lr = np.sqrt(models[0].nsq)*models[0].H/models[0].f
vcovlocal_fact = gaspcohn(np.array(Lr/hcovlocal_scale))
#vcovlocal_fact = 0.0 # no increment at opposite boundary
#vcovlocal_fact = 1.0 # no vertical localization

print('# use_letkf=%s global_enkf=%s' % (use_letkf,global_enkf))
print("# hcovlocal=%g vcovlocal=%s diff_efold=%s covinf1=%s covinf2=%s nanals=%s" %\
     (hcovlocal_scale/1000.,vcovlocal_fact,diff_efold,covinflate1,covinflate2,nanals))

# ---------------------------------------------------------------------------
# Observation network
# ---------------------------------------------------------------------------
# nobs > 0: at each ob time, nobs ob locations are randomly sampled (without
#           replacement) from the model grid.
# nobs < 0: fixed network of -nobs locations. They are drawn at random once
#           (before the assimilation loop) and then reused at every ob time.
nobs = -1024  # number of obs per level; other values tried: 2048, nx*ny//16
# Fraction of grid points observed. abs() because nobs may still carry the
# "fixed network" minus sign at this point.
print('Obs Precentage is ', abs(nobs)/(nx*ny))
# nature run
nc_truth = Dataset(filename_truth)
pv_truth = nc_truth.variables['pv']
# resolve the sign convention: from here on nobs > 0 and `fixed` holds the mode
if nobs < 0:
    nskip = -nobs
    fixed = True
    nobs = nskip
    print('# fixed network nobs = %s' % nobs)
else:
    fixed = False
    print('# random network nobs = %s' % nobs)
print('fixed is', fixed)

# set up arrays for obs and localization function.
# NOTE(review): oberrvar, covlocal, covlocal_tmp, xens and obcovlocal are not
# read by the EnSF assimilation loop in this script as shown (xens is
# reassigned there). They look like leftovers from an LETKF/EnSRF version.
oberrvar = oberrstdev**2*np.ones(nobs,float)
pvob = np.empty((2,nobs),float)  # observations for both levels, refilled each cycle
covlocal = np.empty((ny,nx),float)
covlocal_tmp = np.empty((nobs,nx*ny),float)
xens = np.empty((nanals,2,nx*ny),float)
if not use_letkf:
    obcovlocal = np.empty((nobs,nobs),float)
else:
    obcovlocal = None

if global_enkf: # model-space localization matrix
    n = 0
    covlocal_modelspace = np.empty((nx*ny,nx*ny),float)
    x1 = x.reshape(nx*ny); y1 = y.reshape(nx*ny)
    for n in range(nx*ny):
        dist = cartdist(x1[n],y1[n],x1,y1,nc_climo.L,nc_climo.L)
        covlocal_modelspace[n,:] = gaspcohn(dist/hcovlocal_scale)

obtimes = nc_truth.variables['t'][:]
if read_restart:
    timeslist = obtimes.tolist()
    ntstart = timeslist.index(tstart)
    print('# restarting from %s.nc ntstart = %s' % (exptname,ntstart))
else:
    ntstart = 0
assim_interval = obtimes[1]-obtimes[0]
assim_timesteps = int(np.round(assim_interval/models[0].dt))
print('# assim interval = %s secs (%s time steps)' % (assim_interval,assim_timesteps))
print('# ntime,pverr_a,pvsprd_a,pverr_b,pvsprd_b,obinc_b,osprd_b,obinc_a,obsprd_a,omaomb/oberr,obbias_b,inflation,tr(P^a)/tr(P^b)')

# initialize model clock
for nanal in range(nanals):
    models[nanal].t = obtimes[ntstart]
    models[nanal].timesteps = assim_timesteps

# initialize output file (only when savedata is set).
if savedata is not None:
    nc = Dataset('%s.nc' % exptname, mode='w', format='NETCDF4_CLASSIC')

    # Global attributes describing the experiment, written in a fixed order.
    for _attr_name, _attr_value in (
        ('r', models[0].r),
        ('f', models[0].f),
        ('U', models[0].U),
        ('L', models[0].L),
        ('H', models[0].H),
        ('nanals', nanals),
        ('hcovlocal_scale', hcovlocal_scale),
        ('vcovlocal_fact', vcovlocal_fact),
        ('oberrstdev', oberrstdev),
        ('g', nc_climo.g),
        ('theta0', nc_climo.theta0),
        ('nsq', models[0].nsq),
        ('tdiab', models[0].tdiab),
        ('dt', models[0].dt),
        ('diff_efold', models[0].diff_efold),
        ('diff_order', models[0].diff_order),
        ('filename_climo', filename_climo),
        ('filename_truth', filename_truth),
        ('symmetric', models[0].symmetric),
    ):
        setattr(nc, _attr_name, _attr_value)
    del _attr_name, _attr_value

    # Dimensions ('t' is unlimited so cycles can be appended).
    xdim = nc.createDimension('x',models[0].N)
    ydim = nc.createDimension('y',models[0].N)
    z = nc.createDimension('z',2)
    t = nc.createDimension('t',None)
    obs = nc.createDimension('obs',nobs)
    ens = nc.createDimension('ens',nanals)

    # PV fields: truth (pv_t) and ensemble fields (pv_c, pv_b, pv_a).
    pv_t = nc.createVariable('pv_t',np.float32,('t','z','y','x'),zlib=True)
    pv_c = nc.createVariable('pv_c',np.float32,('t','ens','z','y','x'),zlib=True)
    pv_b = nc.createVariable('pv_b',np.float32,('t','ens','z','y','x'),zlib=True)
    pv_a = nc.createVariable('pv_a',np.float32,('t','ens','z','y','x'),zlib=True)
    for _pv_var in (pv_a, pv_b, pv_c):
        _pv_var.units = 'K'
    del _pv_var

    inf = nc.createVariable('inflation',np.float32,('t','z','y','x'),zlib=True)
    pv_obs = nc.createVariable('obs',np.float32,('t','obs'))
    x_obs = nc.createVariable('x_obs',np.float32,('t','obs'))
    y_obs = nc.createVariable('y_obs',np.float32,('t','obs'))

    # eady pv scaled by g/(f*theta0) so du/dz = d(pv)/dy
    # Coordinate variables.
    xvar = nc.createVariable('x',np.float32,('x',))
    xvar.units = 'meters'
    yvar = nc.createVariable('y',np.float32,('y',))
    yvar.units = 'meters'
    zvar = nc.createVariable('z',np.float32,('z',))
    zvar.units = 'meters'
    tvar = nc.createVariable('t',np.float32,('t',))
    tvar.units = 'seconds'
    ensvar = nc.createVariable('ens',np.int32,('ens',))
    ensvar.units = 'dimensionless'

    # x and y share the same uniform grid spacing L/N starting at 0.
    _grid_coords = np.arange(0,models[0].L,models[0].L/models[0].N)
    xvar[:] = _grid_coords
    yvar[:] = _grid_coords
    del _grid_coords
    zvar[:] = [0, models[0].H]
    ensvar[:] = np.arange(1,nanals+1)

# ---------------------------------------------------------------------------
# Cycling state
# ---------------------------------------------------------------------------
# initialize kinetic energy error/spread spectra
kespec_errmean = None; kespec_sprdmean = None

ncount = 0
nanals2 = 4 # ensemble members used for kespec spread

# Per-element spread of the initial ensemble (unscaled PV units). It is passed
# to the EnSF constructors in the loop.
init_std_x_state = (pvens.reshape(nanals,2*nx*ny)).std(axis=0)

# "Jump" experiment: at ob times jump..jump+2 the obs are drawn from
# pv_truth + jumpnoise, and errors are measured against that perturbed truth.
# With jump = 3000 and nassim = 100 the jump never triggers.
jump = 3000
jumpnoise = rsjump.normal(0,900,size=(1,2,ny,nx))
jumpnoise_reshape = jumpnoise.reshape(2,ny*nx)

# Ob locations as flat grid indices, shared by both levels. With a fixed
# network they are reused every cycle; otherwise they are redrawn in the loop.
indxob = np.sort(rsobs.choice(nx*ny,nobs,replace=False))
indx_unob = np.setdiff1d(np.arange(nx*ny), indxob)
obs_save = np.zeros((nassim,nobs))  # ob locations used at each cycle

# ---------------------------------------------------------------------------
# Assimilation cycles
# ---------------------------------------------------------------------------
# savedata == 'restart' keeps only the last cycle; any other non-None value
# saves every cycle.
for ntime in range(nassim):

    # NOTE(review): a model/ob clock consistency check used to live here but is
    # disabled. Members are advanced in joblib workers (advance_task), so
    # models[...].t in this process may not advance. Verify before re-enabling.
    save_this_cycle = savedata is not None and not (savedata == 'restart' and ntime != nassim-1)

    t1 = time.time()
    if not fixed:
        if ntime == 0:
            print('RRRRRRR')
        # randomly choose points from model grid
        if nobs == nx*ny:
            indxob = np.arange(nx*ny)
        else:
            indxob = np.sort(rsobs.choice(nx*ny,nobs,replace=False))
        indx_unob = np.setdiff1d(np.arange(nx*ny), indxob)
    else:
        if ntime == 0:
            print("not Random")
    obs_save[ntime] = indxob

    # Indices into the flattened 2-level state (level k occupies
    # [k*nx*ny, (k+1)*nx*ny)), plus the 2-D (row, col) position of each ob.
    indxob_ensf = np.concatenate((indxob, indxob+nx*ny), axis=None)
    indxobs_rows = torch.tensor(indxob // nx, device='cuda')
    indxobs_cols = torch.tensor(indxob % nx, device='cuda')
    indx_unob_rows, indx_unob_cols = unused_indices(nx,indxobs_rows,indxobs_cols)
    indx_unob_ensf = np.concatenate((indx_unob, indx_unob+nx*ny), axis=None)

    # number of obs per level passed through arctan (0 = all linear, nobs = all nonlinear)
    pickarctan = nobs
    if ntime == 0:
        print('ARCTAN  percetage is',pickarctan/nobs*100)
    arctan_index = np.sort(rsarctan.choice(nobs, pickarctan, replace=False))
    arctan_index_ensf = np.concatenate((arctan_index, arctan_index+nobs), axis=None)
    in_jump = jump <= ntime <= jump+2
    linear_ob = ~np.isin(np.arange(nobs), arctan_index)
    for k in range(2):
        # surface temp obs, in scaled (K) units
        truth_k = pv_truth[ntime+ntstart,k,:,:].ravel()[indxob]
        if in_jump:
            pvob[k] = scalefact*(truth_k + jumpnoise_reshape[k,indxob])
        else:
            pvob[k] = scalefact*truth_k
        # add ob errors: oberrstdev_arctan on arctan-transformed obs, oberrstdev on the rest
        pvob[k,arctan_index] = np.arctan(pvob[k,arctan_index]) + rsobs.normal(scale=oberrstdev_arctan,size=int(pickarctan))
        pvob[k,linear_ob] += rsobs.normal(scale=oberrstdev,size=int(nobs - pickarctan))

    # background (prior) error and spread
    pvensmean_b = pvens.mean(axis=0).copy()
    if in_jump:
        # The obs were drawn from pv_truth + jumpnoise, so measure the error
        # against that same perturbed truth (same sign as the analysis error).
        pverr_b = (scalefact*(pvensmean_b-pv_truth[ntime+ntstart]-jumpnoise))**2
    else:
        pverr_b = (scalefact*(pvensmean_b-pv_truth[ntime+ntstart]))**2
    pvsprd_b = ((scalefact*(pvensmean_b-pvens))**2).sum(axis=0)/(nanals-1)

    if save_this_cycle:
        pv_t[ntime] = pv_truth[ntime+ntstart]
        pv_b[ntime,:,:,:] = scalefact*pvens

    # EnSF filters for the observed and unobserved parts of the state.
    EnSF_Update_obs = EnSF(n_dim = nobs*2, ensemble_size = nanals, eps_alpha=0.05, device='cuda',
                           obs_sigma = oberrstdev, euler_steps = 1000, scalefact = scalefact,
                           init_std_x_state = init_std_x_state[indxob_ensf], ISarctan=False)
    EnSF_Update_unobs = EnSF(n_dim = nx*ny*2 - nobs*2, ensemble_size = nanals, eps_alpha=0.05, device='cuda',
                             obs_sigma = oberrstdev, euler_steps = 1000, scalefact = scalefact,
                             init_std_x_state = init_std_x_state[indx_unob_ensf], ISarctan=False)

    # create 1d state vector.
    xens = pvens.reshape(nanals,2*nx*ny).copy()
    # Stage 1: EnSF update of the observed grid points against pvob.
    xens_obs = EnSF_Update_obs.state_update_normalized(
        x_input = xens[:,indxob_ensf], obs_input = pvob.reshape(2*nobs), arcindex=arctan_index_ensf)
    xens[:,indxob_ensf] = xens_obs.cpu().numpy()
    xens = xens.reshape(nanals,2,ny,nx)

    # Stage 2: inpaint the unobserved points of each (member, level) field from
    # the updated observed points (2*nanals images of ny x nx).
    n_images = 2*nanals
    img_mask = np.zeros((ny, nx))
    img_mask[indxob // nx, indxob % nx] = 1.0
    masked_img = xens.reshape(n_images,ny,nx)* scalefact
    # per-image min/max scaling to [0, 1] so the fields fit in uint8 for cv2
    masked_img_min = np.min(masked_img,axis = (1,2))[:, None, None]
    masked_img_max = np.max(masked_img,axis = (1,2))[:, None, None]
    masked_img = (masked_img - masked_img_min)/ (masked_img_max - masked_img_min)
    masked_img = masked_img * img_mask
    recovered_img = np.zeros((n_images,ny,nx))
    # cv2: process_image reads the module-level cv2_raw_img / cv2_mask
    cv2_raw_img = (masked_img * 255.0).astype(np.uint8)
    cv2_mask = ((1-img_mask) * 255.0).astype(np.uint8)

    results_img = Parallel(n_jobs=-1)(delayed(process_image)(i) for i in range(n_images))
    recovered_img[:] = np.array(results_img)

    # undo the [0, 1] scaling (back to scaled PV units)
    recovered_img = recovered_img * (masked_img_max - masked_img_min) + masked_img_min

    # Stage 3: EnSF update of the unobserved points, using the inpainted
    # fields as pseudo-observations.
    # NOTE(review): the pseudo-obs are in scaled units while x_input is
    # unscaled; presumably EnSF handles this via `scalefact` -- confirm.
    xens = xens.reshape(nanals,2*nx*ny)
    unobs = recovered_img.reshape(nanals,2*nx*ny)[:,indx_unob_ensf]
    xens_unobs = EnSF_Update_unobs.un_state_update_normalized(
        x_input = xens[:,indx_unob_ensf], obs_input = unobs, arcindex=np.array([]))
    xens[:,indx_unob_ensf] = xens_unobs.cpu().numpy()
    # back to 3d state vector
    pvens = xens.reshape((nanals,2,ny,nx)).copy()

    t2 = time.time()

    if save_this_cycle:
        pv_c[ntime,:,:,:] = scalefact*pvens
    pvensmean_a = pvens.mean(axis=0)
    # print out analysis error, spread and innov stats for background
    if in_jump:
        print('jumpppp')
        pverr_a = (scalefact*(pvensmean_a-pv_truth[ntime+ntstart]-jumpnoise))**2
        print(ntstart)
    else:
        pverr_a = (scalefact*(pvensmean_a-pv_truth[ntime+ntstart]))**2
    pvsprd_a = ((scalefact*(pvensmean_a-pvens))**2).sum(axis=0)/(nanals-1)
    print("%s %g %g %g %g " %\
    (ntime+ntstart,np.sqrt(pverr_a.mean()),np.sqrt(pvsprd_a.mean()),np.sqrt(pverr_b.mean()),np.sqrt(pvsprd_b.mean())))

    # save data.
    if save_this_cycle:
        pv_a[ntime,:,:,:] = scalefact*pvens
        tvar[ntime] = obtimes[ntime+ntstart]
        nc.sync()

    # run forecast ensemble to next analysis time (advance_task reads pvens)
    t1 = time.time()
    pvens_updated = Parallel(n_jobs=-1)(delayed(advance_task)(nanal) for nanal in range(nanals))
    for nanal, updated_pven in enumerate(pvens_updated):
        pvens[nanal] = updated_pven
    t2 = time.time()
    if ntime < 10:
        print('cpu time for ens forecast',t2-t1)

# close with the same condition used to create the file
if savedata is not None: nc.close()
