## `spatial_corrmap` function and supporting functions

In [9]:
import time
from matplotlib import pylab as plt
import xarray as xr
import numpy as np
from numpy import zeros,arange
from scipy.optimize import leastsq
import numpy as np
import sys
import os
import tempfile
import earthaccess


def spatial_corrmap(granule_ssh, lat_halfwin, lon_halfwin, lats=None, lons=None, f_notnull=0.5):
    """
    Get a 2D map of SSH-SST spatial correlation coefficients. The SSH dataset is
    shortname SEA_SURFACE_HEIGHT_ALT_GRIDS_L4_2SATS_5DAY_6THDEG_V_JPL1812
    and the SST dataset is shortname MW_OI-REMSS-L4-GLOB-v5.0.

    At each gridpoint, the spatial correlation is computed over a lat, lon window
    of size 2*lat_halfwin x 2*lon_halfwin. This is done for each gridpoint in
    the datasets. Spatial correlation is computed from the SSH, SST anomalies,
    which are computed in turn as the deviations from a fitted 2D surface over
    the window (a new 2D surface is fitted for each window around each gridpoint).
    Windows are not wrapped across the 0/360 longitude seam, so near the seam they
    are truncated.

    Inputs
    ------
    granule_ssh: earthaccess.results.DataGranule
        Granule info for the SSH file, including download path and S3 location.
    lat_halfwin, lon_halfwin: floats
        Half window size in degrees for latitude and longitude dimensions, respectively.
    lats, lons: None or 1D array-like
        These make up the latitude, longitude grid on which to compute correlations.
        Each one that is None independently defaults to the corresponding axis of the
        SSH grid. Note that regardless of the lats, lons passed to this function, it
        will still use the gridpoints of the SSH product for the actual computations.
    f_notnull: float (default = 0.5)
        Fraction (0 to 1) of elements in a window that have to be non-NaN in both SSH
        and SST. The fraction is computed as the number of jointly non-null elements
        divided by the total number expected in a full window, so for windows truncated
        at the grid edges, the 'ghost' elements beyond the edge count as NaN. At least
        3 jointly non-null elements are always required, since the linear surface fit
        has 3 coefficients.

    Returns
    -------
    coef: 2D numpy array, shape (len(lats), len(lons))
        Spatial correlation coefficients: NaN where a window has too few valid points,
        0 where either anomaly field has zero variance over the shared valid points.

    lats, lons: 1D numpy arrays.
        Latitudes and longitudes creating the 2D grid that 'coef' was calculated on.
    """
    # Load datafiles, shift SST longitudes to (0, 360) bounds, and interpolate SST to
    # the SSH grid. roll_coords=False is explicit: only the data are rolled, and the
    # longitude labels are shifted by +180 on the next line, so the result must not
    # depend on the default of the installed xarray version.
    ssh, sst = load_sst_ssh(granule_ssh)
    sst = sst.roll(lon=len(sst['lon'])//2, roll_coords=False)
    sst['lon'] = sst['lon']+180
    sst = sst.interp(lon=ssh['Longitude']).interp(lat=ssh['Latitude'])

    # anomaly() builds its x, y grid with np.meshgrid(lon, lat), i.e. with shape
    # (n_lat, n_lon), so make the data layout match explicitly.
    ssh = ssh.transpose('Latitude', 'Longitude')
    sst = sst.transpose('Latitude', 'Longitude')

    # Compute window size and threshold number of non-NaN points. Never accept fewer
    # than 3 points: the 3-coefficient surface fit in anomaly() cannot run on fewer
    # (with f_notnull=0 an all-NaN window would otherwise crash the fit).
    dlat = (ssh['Latitude'][1]-ssh['Latitude'][0]).item()
    dlon = (ssh['Longitude'][1]-ssh['Longitude'][0]).item()
    nx_win = 2*round(lon_halfwin/dlon)
    ny_win = 2*round(lat_halfwin/dlat)
    n_thresh = max(nx_win*ny_win*f_notnull, 3)

    # Map of booleans, True where both sst and ssh are non-NaN. Used to determine
    # whether a window has enough valid values to compute the correlation:
    notnul = (sst*ssh).notnull()

    # Each axis of the output grid falls back to the SSH grid independently.
    if lats is None:
        lats = ssh['Latitude'].data
    if lons is None:
        lons = ssh['Longitude'].data

    # Compute spatial correlations over the whole map:
    coef = []
    for lat_cen in lats:
        for lon_cen in lons:

            # Create window for both sst and ssh with xr.sel:
            lat_bottom = lat_cen - lat_halfwin
            lat_top = lat_cen + lat_halfwin
            lon_left = lon_cen - lon_halfwin
            lon_right = lon_cen + lon_halfwin
            ssh_win = ssh.sel(Longitude=slice(lon_left, lon_right), Latitude=slice(lat_bottom, lat_top))
            sst_win = sst.sel(Longitude=slice(lon_left, lon_right), Latitude=slice(lat_bottom, lat_top))

            # If the number of jointly non-NaN values in the window is below the
            # threshold, append NaN; else compute anomalies and their correlation:
            notnul_win = notnul.sel(Longitude=slice(lon_left, lon_right), Latitude=slice(lat_bottom, lat_top))
            n_notnul = notnul_win.sum().item()
            if n_notnul < n_thresh:
                coef.append(np.nan)
                continue

            # Compute anomalies relative to a fitted surface in this window:
            ssha, _ = anomaly(ssh_win['Longitude'], ssh_win['Latitude'], ssh_win.data)
            ssta, _ = anomaly(sst_win['Longitude'], sst_win['Latitude'], sst_win.data)
            coef.append(anomaly_corrcoef(ssta, ssha))

    return np.array(coef).reshape((len(lats), len(lons))), np.array(lats), np.array(lons)


def anomaly_corrcoef(a, b):
    """
    Pearson correlation coefficient of two anomaly arrays over their shared
    non-NaN points.

    Inputs
    ------
    a, b: array-like with the same number of elements
        Anomaly fields (any shape; both are flattened).

    Returns
    -------
    float
        The correlation coefficient; 0 if either array has zero variance over the
        shared points, NaN if there are no shared non-NaN points.
    """
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    # Use only points valid in BOTH arrays and re-centre on that common sample, so
    # the covariance and both variances are computed over the same points even when
    # the two fields are NaN in different places.
    shared = ~(np.isnan(a) | np.isnan(b))
    if not shared.any():
        return np.nan
    a = a[shared] - a[shared].mean()
    b = b[shared] - b[shared].mean()
    var_a = np.mean(a*a)
    var_b = np.mean(b*b)
    if var_a == 0 or var_b == 0:
        # Constant field: treat as uncorrelated instead of dividing by zero.
        return 0
    return np.mean(a*b) / np.sqrt(var_a*var_b)


def load_sst_ssh(granule_ssh):
    """
    Load one SSH field and the SST field for the same day, both into local memory.

    Inputs
    ------
    granule_ssh: earthaccess.results.DataGranule
        Granule info for a file from the
        SEA_SURFACE_HEIGHT_ALT_GRIDS_L4_2SATS_5DAY_6THDEG_V_JPL1812 collection
        (the notebook also uses the JPL2205 version).

    Returns
    -------
    ssh, sst: xarray.DataArray
        The 'SLA' variable of the SSH granule at its first time index, and the
        'analysed_sst' variable at its first time index from the
        MW_OI-REMSS-L4-GLOB-v5.0 file stamped 120000 (noon UTC) on the same date.
    """
    # SSH: open the granule through earthaccess and load it eagerly.
    ssh_handle = earthaccess.open([granule_ssh])[0]
    ssh = xr.load_dataset(ssh_handle)['SLA'][0, ...]

    # SST: the date is the first 8 characters of the last '_'-separated field of
    # the SSH GranuleUR; it names the matching noon file on S3.
    granule_ur = granule_ssh['umm']['GranuleUR']
    date = granule_ur.split('_')[-1][:8]
    s3path_sst = 's3://podaac-ops-cumulus-protected/MW_OI-REMSS-L4-GLOB-v5.0/%s120000-REMSS-L4_GHRSST-SSTfnd-MW_OI-GLOB-v02.0-fv05.0.nc' % date
    sst_handle = earthaccess.open([s3path_sst], provider='POCLOUD')[0]
    sst = xr.load_dataset(sst_handle)['analysed_sst'][0, ...]

    return ssh, sst


def anomaly(lon, lat, p, kind='linear'):
    """
    Get anomalies for a variable over a 2D map. Done by fitting a 2D surface
    to the data and taking the anomaly as the difference between each data point
    and the surface.

    This is mostly a wrapper for fit2Dsurf(), which does the anomaly calculation:
    it expands the 1D lon, lat coordinates into 2D grids with np.meshgrid and
    passes them on together with p. This wrapper could be extended e.g. to take
    inputs with various shapes and reformat them to work with fit2Dsurf(), but
    currently has basic functionality.

    Inputs
    ------
    lon, lat: 1D array-like
        Longitude and latitude arrays, or more generally, the x, y coordinates (don't
        need to have units of degrees e.g.).
    p: 2D array-like
        Variable to get anomalies for. Should have shape (len(lat), len(lon)), which
        is the layout produced by np.meshgrid(lon, lat).
    kind: str
        (Default = 'linear'). Functional form of the fit surface, passed through to
        fit2Dsurf(): 'linear' or 'quadratic'.

    Returns
    -------
    va, vm: 2D NumPy arrays
        Anomalies (va) and mean surface fit (vm).

    Import requirements
    -------------------
    numpy
    """
    x1, y1 = np.meshgrid(lon, lat)
    va, vm = fit2Dsurf(x1, y1, p, kind=kind)
    return va, vm


def fit2Dsurf(x, y, p, kind='linear'):
    """
    Get anomalies for a variable over a 2D map. Done by fitting a 2D surface
    to the data and taking the anomaly as the difference between each data point
    and the surface. Surface can either be a linear or quadratic function.

    Inputs
    ------
    x, y, p: 2D array-like, all same size.
        Variables to use to fit the function p(x, y). x, y are the independent vars
        and p is the dependent var. NaN values of p are excluded from the fit and
        stay NaN in the returned anomalies.
    kind: str
        (Default = 'linear'). Either 'linear' or 'quadratic' to specify the
        functional form of the fit surface:
            linear:    a + b*x + c*y
            quadratic: a + b*x + c*y + d*x**2 + e*y**2 + f*x*y

    Returns
    -------
    va, vm: 2D NumPy arrays
        Anomalies (va) and mean surface fit (vm).

    Raises
    ------
    ValueError
        If kind is not 'linear' or 'quadratic', or if p has fewer non-NaN values
        than the surface has coefficients.

    Import requirements
    -------------------
    from scipy.optimize import leastsq
    numpy
    """
    # Depending on the fit function chosen, define a function that outputs the 2D
    # surface for given polynomial coefficients and independent vars.
    if kind == 'linear':
        def surface(coefs, x0, y0):
            a, b, c = coefs
            return a + b*x0 + c*y0
        n_coefs = 3
    elif kind == 'quadratic':
        def surface(coefs, x0, y0):
            a, b, c, d, e, f = coefs
            return a + b*x0 + c*y0 + d*x0**2 + e*y0**2 + f*x0*y0
        n_coefs = 6
    else:
        raise ValueError("kind must be 'linear' or 'quadratic', got %r" % (kind,))

    # Difference between the data and the surface; this is what leastsq minimizes.
    def err(coefs, x0, y0, p0):
        return p0 - surface(coefs, x0, y0)

    # Prep arrays and remove NaNs:
    x = np.asarray(x)
    y = np.asarray(y)
    p = np.asarray(p)
    msk = ~np.isnan(p.ravel())
    pf = p.ravel()[msk]
    xf = x.ravel()[msk]
    yf = y.ravel()[msk]
    if pf.size < n_coefs:
        raise ValueError(
            "need at least %d non-NaN values for a %s fit, got %d" % (n_coefs, kind, pf.size)
        )

    # Initial values of the polynomial coefficients. If all points share one x (or
    # one y) value, the slope estimate would divide by zero, so start it at 0.
    p_range = pf.max() - pf.min()
    x_range = xf.max() - xf.min()
    y_range = yf.max() - yf.min()
    dpdx = p_range/x_range if x_range != 0 else 0.0
    dpdy = p_range/y_range if y_range != 0 else 0.0
    c0 = [pf.mean(), dpdx, dpdy]
    if kind == 'quadratic':
        # Start the quadratic terms at exactly 0. leastsq approximates the Jacobian
        # with forward differences whose step is proportional to |coefficient|
        # (a fixed step is used only when it is 0). A tiny nonzero start such as
        # 1e-22 gives a step too small to change the residuals, so those Jacobian
        # columns are zero and the quadratic terms would never move.
        c0 += [0.0, 0.0, 0.0]

    # Fit:
    coef = leastsq(err, c0, args=(xf, yf, pf))[0]
    vm = surface(coef, x, y)  # mean surface
    va = p - vm  # anomaly
    return va, vm

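### Copy the following block of code to a .py file *"ssh_sst_correlation_test.py"*, with the `@profile` decorator on the `spatial_corrmap()` function
The code includes all the functions along with a single call to the `spatial_corrmap()` function with the specified arguments. Note that `profile` is not imported anywhere: `kernprof -l` injects it when it runs the script, so run the file with `kernprof` (as below) rather than with plain `python`.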

import time
from matplotlib import pylab as plt
import xarray as xr
import numpy as np
from numpy import zeros,arange
from scipy.optimize import leastsq
import numpy as np
import sys
import os
import tempfile
import earthaccess


@profile
def spatial_corrmap(granule_ssh, lat_halfwin, lon_halfwin, lats=None, lons=None, f_notnull=0.5):
    """
    Get a 2D map of SSH-SST spatial correlation coefficients. The SSH dataset is
    shortname SEA_SURFACE_HEIGHT_ALT_GRIDS_L4_2SATS_5DAY_6THDEG_V_JPL1812
    and the SST dataset is shortname MW_OI-REMSS-L4-GLOB-v5.0.

    At each gridpoint, the spatial correlation is computed over a lat, lon window
    of size 2*lat_halfwin x 2*lon_halfwin. This is done for each gridpoint in
    the datasets. Spatial correlation is computed from the SSH, SST anomalies,
    which are computed in turn as the deviations from a fitted 2D surface over
    the window (a new 2D surface is fitted for each window around each gridpoint).
    Windows are not wrapped across the 0/360 longitude seam, so near the seam they
    are truncated.

    Inputs
    ------
    granule_ssh: earthaccess.results.DataGranule
        Granule info for the SSH file, including download path and S3 location.
    lat_halfwin, lon_halfwin: floats
        Half window size in degrees for latitude and longitude dimensions, respectively.
    lats, lons: None or 1D array-like
        These make up the latitude, longitude grid on which to compute correlations.
        Each one that is None independently defaults to the corresponding axis of the
        SSH grid. Note that regardless of the lats, lons passed to this function, it
        will still use the gridpoints of the SSH product for the actual computations.
    f_notnull: float (default = 0.5)
        Fraction (0 to 1) of elements in a window that have to be non-NaN in both SSH
        and SST. The fraction is computed as the number of jointly non-null elements
        divided by the total number expected in a full window, so for windows truncated
        at the grid edges, the 'ghost' elements beyond the edge count as NaN. At least
        3 jointly non-null elements are always required, since the linear surface fit
        has 3 coefficients.

    Returns
    -------
    coef: 2D numpy array, shape (len(lats), len(lons))
        Spatial correlation coefficients: NaN where a window has too few valid points,
        0 where either anomaly field has zero variance over the shared valid points.

    lats, lons: 1D numpy arrays.
        Latitudes and longitudes creating the 2D grid that 'coef' was calculated on.
    """
    # Load datafiles, shift SST longitudes to (0, 360) bounds, and interpolate SST to
    # the SSH grid. roll_coords=False is explicit: only the data are rolled, and the
    # longitude labels are shifted by +180 on the next line, so the result must not
    # depend on the default of the installed xarray version.
    ssh, sst = load_sst_ssh(granule_ssh)
    sst = sst.roll(lon=len(sst['lon'])//2, roll_coords=False)
    sst['lon'] = sst['lon']+180
    sst = sst.interp(lon=ssh['Longitude']).interp(lat=ssh['Latitude'])

    # anomaly() builds its x, y grid with np.meshgrid(lon, lat), i.e. with shape
    # (n_lat, n_lon), so make the data layout match explicitly.
    ssh = ssh.transpose('Latitude', 'Longitude')
    sst = sst.transpose('Latitude', 'Longitude')

    # Compute window size and threshold number of non-NaN points. Never accept fewer
    # than 3 points: the 3-coefficient surface fit in anomaly() cannot run on fewer
    # (with f_notnull=0 an all-NaN window would otherwise crash the fit).
    dlat = (ssh['Latitude'][1]-ssh['Latitude'][0]).item()
    dlon = (ssh['Longitude'][1]-ssh['Longitude'][0]).item()
    nx_win = 2*round(lon_halfwin/dlon)
    ny_win = 2*round(lat_halfwin/dlat)
    n_thresh = max(nx_win*ny_win*f_notnull, 3)

    # Map of booleans, True where both sst and ssh are non-NaN. Used to determine
    # whether a window has enough valid values to compute the correlation:
    notnul = (sst*ssh).notnull()

    # Each axis of the output grid falls back to the SSH grid independently.
    if lats is None:
        lats = ssh['Latitude'].data
    if lons is None:
        lons = ssh['Longitude'].data

    # Compute spatial correlations over the whole map:
    coef = []
    for lat_cen in lats:
        for lon_cen in lons:

            # Create window for both sst and ssh with xr.sel:
            lat_bottom = lat_cen - lat_halfwin
            lat_top = lat_cen + lat_halfwin
            lon_left = lon_cen - lon_halfwin
            lon_right = lon_cen + lon_halfwin
            ssh_win = ssh.sel(Longitude=slice(lon_left, lon_right), Latitude=slice(lat_bottom, lat_top))
            sst_win = sst.sel(Longitude=slice(lon_left, lon_right), Latitude=slice(lat_bottom, lat_top))

            # If the number of jointly non-NaN values in the window is below the
            # threshold, append NaN; else compute anomalies and their correlation:
            notnul_win = notnul.sel(Longitude=slice(lon_left, lon_right), Latitude=slice(lat_bottom, lat_top))
            n_notnul = notnul_win.sum().item()
            if n_notnul < n_thresh:
                coef.append(np.nan)
                continue

            # Compute anomalies relative to a fitted surface in this window:
            ssha, _ = anomaly(ssh_win['Longitude'], ssh_win['Latitude'], ssh_win.data)
            ssta, _ = anomaly(sst_win['Longitude'], sst_win['Latitude'], sst_win.data)
            coef.append(anomaly_corrcoef(ssta, ssha))

    return np.array(coef).reshape((len(lats), len(lons))), np.array(lats), np.array(lons)


def anomaly_corrcoef(a, b):
    """
    Pearson correlation coefficient of two anomaly arrays over their shared
    non-NaN points.

    Inputs
    ------
    a, b: array-like with the same number of elements
        Anomaly fields (any shape; both are flattened).

    Returns
    -------
    float
        The correlation coefficient; 0 if either array has zero variance over the
        shared points, NaN if there are no shared non-NaN points.
    """
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    # Use only points valid in BOTH arrays and re-centre on that common sample, so
    # the covariance and both variances are computed over the same points even when
    # the two fields are NaN in different places.
    shared = ~(np.isnan(a) | np.isnan(b))
    if not shared.any():
        return np.nan
    a = a[shared] - a[shared].mean()
    b = b[shared] - b[shared].mean()
    var_a = np.mean(a*a)
    var_b = np.mean(b*b)
    if var_a == 0 or var_b == 0:
        # Constant field: treat as uncorrelated instead of dividing by zero.
        return 0
    return np.mean(a*b) / np.sqrt(var_a*var_b)


def load_sst_ssh(granule_ssh):
    """
    Load one SSH field and the SST field for the same day, both into local memory.

    Inputs
    ------
    granule_ssh: earthaccess.results.DataGranule
        Granule info for a file from the
        SEA_SURFACE_HEIGHT_ALT_GRIDS_L4_2SATS_5DAY_6THDEG_V_JPL1812 collection
        (the script also uses the JPL2205 version).

    Returns
    -------
    ssh, sst: xarray.DataArray
        The 'SLA' variable of the SSH granule at its first time index, and the
        'analysed_sst' variable at its first time index from the
        MW_OI-REMSS-L4-GLOB-v5.0 file stamped 120000 (noon UTC) on the same date.
    """
    # SSH: open the granule through earthaccess and load it eagerly.
    ssh_handle = earthaccess.open([granule_ssh])[0]
    ssh = xr.load_dataset(ssh_handle)['SLA'][0, ...]

    # SST: the date is the first 8 characters of the last '_'-separated field of
    # the SSH GranuleUR; it names the matching noon file on S3.
    granule_ur = granule_ssh['umm']['GranuleUR']
    date = granule_ur.split('_')[-1][:8]
    s3path_sst = 's3://podaac-ops-cumulus-protected/MW_OI-REMSS-L4-GLOB-v5.0/%s120000-REMSS-L4_GHRSST-SSTfnd-MW_OI-GLOB-v02.0-fv05.0.nc' % date
    sst_handle = earthaccess.open([s3path_sst], provider='POCLOUD')[0]
    sst = xr.load_dataset(sst_handle)['analysed_sst'][0, ...]

    return ssh, sst


def anomaly(lon, lat, p, kind='linear'):
    """
    Get anomalies for a variable over a 2D map. Done by fitting a 2D surface
    to the data and taking the anomaly as the difference between each data point
    and the surface.

    This is mostly a wrapper for fit2Dsurf(), which does the anomaly calculation:
    it expands the 1D lon, lat coordinates into 2D grids with np.meshgrid and
    passes them on together with p. This wrapper could be extended e.g. to take
    inputs with various shapes and reformat them to work with fit2Dsurf(), but
    currently has basic functionality.

    Inputs
    ------
    lon, lat: 1D array-like
        Longitude and latitude arrays, or more generally, the x, y coordinates (don't
        need to have units of degrees e.g.).
    p: 2D array-like
        Variable to get anomalies for. Should have shape (len(lat), len(lon)), which
        is the layout produced by np.meshgrid(lon, lat).
    kind: str
        (Default = 'linear'). Functional form of the fit surface, passed through to
        fit2Dsurf(): 'linear' or 'quadratic'.

    Returns
    -------
    va, vm: 2D NumPy arrays
        Anomalies (va) and mean surface fit (vm).

    Import requirements
    -------------------
    numpy
    """
    x1, y1 = np.meshgrid(lon, lat)
    va, vm = fit2Dsurf(x1, y1, p, kind=kind)
    return va, vm


def fit2Dsurf(x, y, p, kind='linear'):
    """
    Get anomalies for a variable over a 2D map. Done by fitting a 2D surface
    to the data and taking the anomaly as the difference between each data point
    and the surface. Surface can either be a linear or quadratic function.

    Inputs
    ------
    x, y, p: 2D array-like, all same size.
        Variables to use to fit the function p(x, y). x, y are the independent vars
        and p is the dependent var. NaN values of p are excluded from the fit and
        stay NaN in the returned anomalies.
    kind: str
        (Default = 'linear'). Either 'linear' or 'quadratic' to specify the
        functional form of the fit surface:
            linear:    a + b*x + c*y
            quadratic: a + b*x + c*y + d*x**2 + e*y**2 + f*x*y

    Returns
    -------
    va, vm: 2D NumPy arrays
        Anomalies (va) and mean surface fit (vm).

    Raises
    ------
    ValueError
        If kind is not 'linear' or 'quadratic', or if p has fewer non-NaN values
        than the surface has coefficients.

    Import requirements
    -------------------
    from scipy.optimize import leastsq
    numpy
    """
    # Depending on the fit function chosen, define a function that outputs the 2D
    # surface for given polynomial coefficients and independent vars.
    if kind == 'linear':
        def surface(coefs, x0, y0):
            a, b, c = coefs
            return a + b*x0 + c*y0
        n_coefs = 3
    elif kind == 'quadratic':
        def surface(coefs, x0, y0):
            a, b, c, d, e, f = coefs
            return a + b*x0 + c*y0 + d*x0**2 + e*y0**2 + f*x0*y0
        n_coefs = 6
    else:
        raise ValueError("kind must be 'linear' or 'quadratic', got %r" % (kind,))

    # Difference between the data and the surface; this is what leastsq minimizes.
    def err(coefs, x0, y0, p0):
        return p0 - surface(coefs, x0, y0)

    # Prep arrays and remove NaNs:
    x = np.asarray(x)
    y = np.asarray(y)
    p = np.asarray(p)
    msk = ~np.isnan(p.ravel())
    pf = p.ravel()[msk]
    xf = x.ravel()[msk]
    yf = y.ravel()[msk]
    if pf.size < n_coefs:
        raise ValueError(
            "need at least %d non-NaN values for a %s fit, got %d" % (n_coefs, kind, pf.size)
        )

    # Initial values of the polynomial coefficients. If all points share one x (or
    # one y) value, the slope estimate would divide by zero, so start it at 0.
    p_range = pf.max() - pf.min()
    x_range = xf.max() - xf.min()
    y_range = yf.max() - yf.min()
    dpdx = p_range/x_range if x_range != 0 else 0.0
    dpdy = p_range/y_range if y_range != 0 else 0.0
    c0 = [pf.mean(), dpdx, dpdy]
    if kind == 'quadratic':
        # Start the quadratic terms at exactly 0. leastsq approximates the Jacobian
        # with forward differences whose step is proportional to |coefficient|
        # (a fixed step is used only when it is 0). A tiny nonzero start such as
        # 1e-22 gives a step too small to change the residuals, so those Jacobian
        # columns are zero and the quadratic terms would never move.
        c0 += [0.0, 0.0, 0.0]

    # Fit:
    coef = leastsq(err, c0, args=(xf, yf, pf))[0]
    vm = surface(coef, x, y)  # mean surface
    va = p - vm  # anomaly
    return va, vm


if __name__=="__main__":
    earthaccess.login()
    granules_ssh = earthaccess.search_data(
        short_name="SEA_SURFACE_HEIGHT_ALT_GRIDS_L4_2SATS_5DAY_6THDEG_V_JPL2205",
        temporal=("2010-01-01", "2011-12-31"),
        )
    lats = np.arange(-80, 80, 2)
    lons = np.arange(0, 359, 2)
    coef, lats, lons = spatial_corrmap(granules_ssh[-1], 3, 3, lats=lats, lons=lons, f_notnull=0.5)

### Install the `line_profiler` tool and run it on the .py file

In [10]:
# line_profiler provides the `kernprof` command used in the next cell.
!pip install line_profiler



In [11]:
# Run the script under kernprof: -l profiles the @profile-decorated function line by
# line and -v prints the per-line timings when the run finishes (see output below).
!kernprof -lv ssh_sst_correlation_test.py

Granules found: 146
 Opening 1 granules, approx size: 0.01 GB
using endpoint: https://archive.podaac.earthdata.nasa.gov/s3credentials
QUEUEING TASKS | : 1it [00:00, 776.87it/s]
PROCESSING TASKS | : 100%|████████████████████████| 1/1 [00:00<00:00,  4.62it/s]
COLLECTING RESULTS | : 100%|███████████████████| 1/1 [00:00<00:00, 20971.52it/s]
QUEUEING TASKS | : 1it [00:00, 2487.72it/s]
PROCESSING TASKS | : 100%|████████████████████████| 1/1 [00:00<00:00, 66.34it/s]
COLLECTING RESULTS | : 100%|███████████████████| 1/1 [00:00<00:00, 24105.20it/s]
Wrote profile results to ssh_sst_correlation_test.py.lprof
Timer unit: 1e-06 s

Total time: 48.5762 s
File: ssh_sst_correlation_test.py
Function: spatial_corrmap at line 142

Line #      Hits         Time  Per Hit   % Time  Line Contents
   142                                           @profile
   143                                           def spatial_corrmap(granule_ssh, lat_halfwin, lon_halfwin, lats=None, lons=None, f_notnull=0.5):
   144       

### Line profiler results summarized
* The three `.sel()` calls (from the `xarray` package) each take ~15% of the computation time, for a total of almost 50% of the total computation time.
* The two calls to `anomaly()` each take ~15% of the computation time, for a total of almost 30% of the total computation time.

## Testing whether three `.sel()` calls on separate DataArrays are slower than merging the arrays into one Dataset and calling `.sel()` once

In [14]:
# Log in to Earthdata, then search the SSH collection for 2010-2011 granules.
earthaccess.login()
ssh_search_kwargs = {
    "short_name": "SEA_SURFACE_HEIGHT_ALT_GRIDS_L4_2SATS_5DAY_6THDEG_V_JPL2205",
    "temporal": ("2010-01-01", "2011-12-31"),
}
granules_ssh = earthaccess.search_data(**ssh_search_kwargs)

Granules found: 146


In [15]:
# Load one day of SSH and SST (first granule found above) for the .sel() timing tests.
ssh, sst = load_sst_ssh(granule_ssh=granules_ssh[0])

 Opening 1 granules, approx size: 0.01 GB
using endpoint: https://archive.podaac.earthdata.nasa.gov/s3credentials


QUEUEING TASKS | : 0it [00:00, ?it/s]

PROCESSING TASKS | :   0%|          | 0/1 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/1 [00:00<?, ?it/s]

QUEUEING TASKS | : 0it [00:00, ?it/s]

PROCESSING TASKS | :   0%|          | 0/1 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/1 [00:00<?, ?it/s]

In [16]:
# Shift SST longitudes to (0, 360) bounds and interpolate SST onto the SSH grid.
# roll_coords=False is explicit: only the data are rolled, and the longitude labels
# are shifted by +180 on the next line, so the result must not depend on the
# default of the installed xarray version.
sst = sst.roll(lon=len(sst['lon'])//2, roll_coords=False)
sst['lon'] = sst['lon']+180
sst = sst.interp(lon=ssh['Longitude']).interp(lat=ssh['Latitude'])

# Map of booleans, True where both sst and ssh are non-NaN. Will be used to determine
# if there are enough non-NaN values to compute the correlation for a given window:
notnul = (sst*ssh).notnull()

In [17]:
notnul = notnul.rename("notnul_sst*ssh")
merged = xr.merge([ssh, sst, notnul], compat="equals")

In [18]:
%%time

# Baseline: three separate .sel() calls, one per DataArray, on the same 6x6-degree window.
ssh_win = ssh.sel(Longitude=slice(0, 6), Latitude=slice(-3, 3))
sst_win = sst.sel(Longitude=slice(0, 6), Latitude=slice(-3, 3))
notnul_win = notnul.sel(Longitude=slice(0, 6), Latitude=slice(-3, 3))

CPU times: user 2.6 ms, sys: 0 ns, total: 2.6 ms
Wall time: 2.1 ms


In [19]:
%%time
# Alternative: a single .sel() call on the merged Dataset holding ssh, sst and the mask.
merged_win = merged.sel(Longitude=slice(0, 6), Latitude=slice(-3, 3))

CPU times: user 720 µs, sys: 0 ns, total: 720 µs
Wall time: 708 µs


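Looks like we can cut the computation time of the `.sel()` calls by ~67%!
Since these lines take ~50% of the running time of the entire function, this translates to an expected ~33% reduction in total computation time (0.67 × 50% ≈ 33%). This estimate comes from timing a single window; the full-map runs below measure the actual end-to-end improvement.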

## Revise `spatial_corrmap()` function and test for computation improvements

In [20]:
def spatial_corrmap_new(granule_ssh, lat_halfwin, lon_halfwin, lats=None, lons=None, f_notnull=0.5):
    """
    Get a 2D map of SSH-SST spatial correlation coefficients. The SSH dataset is
    shortname SEA_SURFACE_HEIGHT_ALT_GRIDS_L4_2SATS_5DAY_6THDEG_V_JPL1812
    and the SST dataset is shortname MW_OI-REMSS-L4-GLOB-v5.0.

    At each gridpoint, the spatial correlation is computed over a lat, lon window
    of size 2*lat_halfwin x 2*lon_halfwin. This is done for each gridpoint in
    the datasets. Spatial correlation is computed from the SSH, SST anomalies,
    which are computed in turn as the deviations from a fitted 2D surface over
    the window (a new 2D surface is fitted for each window around each gridpoint).
    Windows are not wrapped across the 0/360 longitude seam, so near the seam they
    are truncated. SSH, SST and the validity mask are merged into one Dataset so
    that each window needs a single .sel() call.

    Inputs
    ------
    granule_ssh: earthaccess.results.DataGranule
        Granule info for the SSH file, including download path and S3 location.
    lat_halfwin, lon_halfwin: floats
        Half window size in degrees for latitude and longitude dimensions, respectively.
    lats, lons: None or 1D array-like
        These make up the latitude, longitude grid on which to compute correlations.
        Each one that is None independently defaults to the corresponding axis of the
        SSH grid. Note that regardless of the lats, lons passed to this function, it
        will still use the gridpoints of the SSH product for the actual computations.
    f_notnull: float (default = 0.5)
        Fraction (0 to 1) of elements in a window that have to be non-NaN in both SSH
        and SST. The fraction is computed as the number of jointly non-null elements
        divided by the total number expected in a full window, so for windows truncated
        at the grid edges, the 'ghost' elements beyond the edge count as NaN. At least
        3 jointly non-null elements are always required, since the linear surface fit
        has 3 coefficients.

    Returns
    -------
    coef: 2D numpy array, shape (len(lats), len(lons))
        Spatial correlation coefficients: NaN where a window has too few valid points,
        0 where either anomaly field has zero variance over the shared valid points.

    lats, lons: 1D numpy arrays.
        Latitudes and longitudes creating the 2D grid that 'coef' was calculated on.
    """
    # Load datafiles, shift SST longitudes to (0, 360) bounds, and interpolate SST to
    # the SSH grid. roll_coords=False is explicit: only the data are rolled, and the
    # longitude labels are shifted by +180 on the next line, so the result must not
    # depend on the default of the installed xarray version.
    ssh, sst = load_sst_ssh(granule_ssh)
    sst = sst.roll(lon=len(sst['lon'])//2, roll_coords=False)
    sst['lon'] = sst['lon']+180
    sst = sst.interp(lon=ssh['Longitude']).interp(lat=ssh['Latitude'])

    # anomaly() builds its x, y grid with np.meshgrid(lon, lat), i.e. with shape
    # (n_lat, n_lon), so make the data layout match explicitly.
    ssh = ssh.transpose('Latitude', 'Longitude')
    sst = sst.transpose('Latitude', 'Longitude')

    # Compute window size and threshold number of non-NaN points. Never accept fewer
    # than 3 points: the 3-coefficient surface fit in anomaly() cannot run on fewer
    # (with f_notnull=0 an all-NaN window would otherwise crash the fit).
    dlat = (ssh['Latitude'][1]-ssh['Latitude'][0]).item()
    dlon = (ssh['Longitude'][1]-ssh['Longitude'][0]).item()
    nx_win = 2*round(lon_halfwin/dlon)
    ny_win = 2*round(lat_halfwin/dlat)
    n_thresh = max(nx_win*ny_win*f_notnull, 3)

    # Map of booleans, True where both sst and ssh are non-NaN. Used to determine
    # whether a window has enough valid values to compute the correlation:
    notnul = (sst*ssh).notnull()

    # Combine all needed DataArrays into a single Dataset for more efficient indexing:
    ######################## Updated code ########################
    notnul = notnul.rename("notnul")  # xr.merge() needs named DataArrays
    mergeddata = xr.merge([ssh, sst, notnul], compat="equals")
    ######################## Updated code ########################

    # Each axis of the output grid falls back to the SSH grid independently.
    if lats is None:
        lats = ssh['Latitude'].data
    if lons is None:
        lons = ssh['Longitude'].data

    # Compute spatial correlations over the whole map:
    coef = []
    for lat_cen in lats:
        for lon_cen in lons:

            # Create one window over all merged variables with xr.sel:
            lat_bottom = lat_cen - lat_halfwin
            lat_top = lat_cen + lat_halfwin
            lon_left = lon_cen - lon_halfwin
            lon_right = lon_cen + lon_halfwin
            ######################## Updated code ########################
            data_win = mergeddata.sel(Longitude=slice(lon_left, lon_right), Latitude=slice(lat_bottom, lat_top))
            ######################## Updated code ########################

            # If the number of jointly non-NaN values in the window is below the
            # threshold, append NaN; else compute anomalies and their correlation:
            n_notnul = data_win["notnul"].sum().item()
            if n_notnul < n_thresh:
                coef.append(np.nan)
                continue

            # Compute anomalies relative to a fitted surface in this window:
            ######################## Updated code ########################
            ssha, _ = anomaly(data_win['Longitude'], data_win['Latitude'], data_win['SLA'].data)
            ssta, _ = anomaly(data_win['Longitude'], data_win['Latitude'], data_win['analysed_sst'].data)
            ######################## Updated code ########################
            coef.append(anomaly_corrcoef(ssta, ssha))

    return np.array(coef).reshape((len(lats), len(lons))), np.array(lats), np.array(lons)


def anomaly_corrcoef(a, b):
    """
    Pearson correlation coefficient of two anomaly arrays over their shared
    non-NaN points.

    Inputs
    ------
    a, b: array-like with the same number of elements
        Anomaly fields (any shape; both are flattened).

    Returns
    -------
    float
        The correlation coefficient; 0 if either array has zero variance over the
        shared points, NaN if there are no shared non-NaN points.
    """
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    # Use only points valid in BOTH arrays and re-centre on that common sample, so
    # the covariance and both variances are computed over the same points even when
    # the two fields are NaN in different places.
    shared = ~(np.isnan(a) | np.isnan(b))
    if not shared.any():
        return np.nan
    a = a[shared] - a[shared].mean()
    b = b[shared] - b[shared].mean()
    var_a = np.mean(a*a)
    var_b = np.mean(b*b)
    if var_a == 0 or var_b == 0:
        # Constant field: treat as uncorrelated instead of dividing by zero.
        return 0
    return np.mean(a*b) / np.sqrt(var_a*var_b)

**Run old and new functions**

In [21]:
# Coarse 2-degree target grid for the timing runs below. np.arange excludes the
# stop value, so lats run -80..78 and lons run 0..358.
lats, lons = np.arange(-80, 80, 2), np.arange(0, 359, 2)

In [22]:
%%time
# Baseline timing: original spatial_corrmap() (three .sel() calls per window). The
# returned lats/lons are arrays equal to the inputs, so the next cell uses the same grid.
coef, lats, lons = spatial_corrmap(granules_ssh[-1], 3, 3, lats=lats, lons=lons, f_notnull=0.5)

 Opening 1 granules, approx size: 0.01 GB
using endpoint: https://archive.podaac.earthdata.nasa.gov/s3credentials


QUEUEING TASKS | : 0it [00:00, ?it/s]

PROCESSING TASKS | :   0%|          | 0/1 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/1 [00:00<?, ?it/s]

QUEUEING TASKS | : 0it [00:00, ?it/s]

PROCESSING TASKS | :   0%|          | 0/1 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/1 [00:00<?, ?it/s]

CPU times: user 28.1 s, sys: 8.44 ms, total: 28.1 s
Wall time: 29.1 s


In [23]:
%%time
# Same inputs as the previous cell, but with spatial_corrmap_new() (one .sel() on a merged Dataset per window).
coef, lats, lons = spatial_corrmap_new(granules_ssh[-1], 3, 3, lats=lats, lons=lons, f_notnull=0.5)

 Opening 1 granules, approx size: 0.01 GB
using endpoint: https://archive.podaac.earthdata.nasa.gov/s3credentials


QUEUEING TASKS | : 0it [00:00, ?it/s]

PROCESSING TASKS | :   0%|          | 0/1 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/1 [00:00<?, ?it/s]

QUEUEING TASKS | : 0it [00:00, ?it/s]

PROCESSING TASKS | :   0%|          | 0/1 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/1 [00:00<?, ?it/s]

CPU times: user 20.3 s, sys: 53.4 ms, total: 20.3 s
Wall time: 21.2 s
