In [None]:
""" THIS FUNCTION CALCULATES A LINEAR REGRESSION OF 2 VARIABLES WITH MORE THAN 1 DIMENSION

Code written by: Hrishi (https://hrishichandanpurkar.blogspot.com/2017/09/vectorized-functions-for-correlation.html)

Uploaded by Paloma Trascasa-Castro

"""


# First import some modules to make this function work

import numpy as np  # Required for the array sum / square-root / abs operations
import xarray as xr
from scipy import stats, signal  # Required for detrending data and computing regression

# Define the function "lag_linregress_3D"

def lag_linregress_3D(x, y, lagx=0, lagy=0):
    """
    Compute lagged covariance, correlation and linear regression of y on x.

    Input: Two xr.DataArrays of any dimensions with the first dim being time.
    Thus the input data could be a 1D time series, or for example, have three
    dimensions (time, lat, lon). Datasets can be provided in any order, but
    note that the regression slope and intercept are calculated for y with
    respect to x.

    Lag values can be assigned to either of the data, with lagx shifting x and
    lagy shifting y by the specified number of time steps. The shift is applied
    along the dimension named 'time', so that dimension must exist whenever a
    non-zero lag is requested.

    Output: the tuple (cov, cor, slope, intercept, pval, stderr), i.e. the
    covariance (normalised by n), correlation, regression slope and intercept,
    two-sided p-value, and standard error of the slope between the two
    datasets, all reduced along their aligned first (time) axis.

    NOTE: the sample size n is the full length of the first axis after
    alignment; missing values inside the series are not subtracted from it.
    """
    # 1. Ensure that the data are properly aligned to each other.
    x, y = xr.align(x, y)

    # 2. Add lag information if any, and shift the data accordingly.
    if lagx != 0:
        # If x lags y by 1, x must be shifted 1 step backwards. The time steps
        # vacated by the shift are filled with NaN and have to be dropped.
        # how='all' removes only time steps that are NaN everywhere, so
        # isolated missing points (e.g. a masked region in a lat/lon field)
        # do not cause otherwise valid time steps to be discarded.
        x = x.shift(time=-lagx).dropna(dim='time', how='all')
        # Re-align so that y adjusts to the changed time coordinates of x.
        x, y = xr.align(x, y)

    if lagy != 0:
        y = y.shift(time=-lagy).dropna(dim='time', how='all')
        x, y = xr.align(x, y)

    # 3. Compute data length, mean and standard deviation along the time axis.
    n = x.shape[0]
    xmean = x.mean(axis=0)
    ymean = y.mean(axis=0)
    xstd = x.std(axis=0)
    ystd = y.std(axis=0)

    # 4. Compute covariance along the time axis.
    cov = np.sum((x - xmean) * (y - ymean), axis=0) / n

    # 5. Compute correlation along the time axis.
    cor = cov / (xstd * ystd)

    # 6. Compute regression slope and intercept.
    slope = cov / (xstd ** 2)
    intercept = ymean - xmean * slope

    # 7. Compute the t-statistic (n - 2 degrees of freedom), the standard
    #    error of the slope and the p-value.
    tstats = cor * np.sqrt(n - 2) / np.sqrt(1 - cor ** 2)
    stderr = slope / tstats

    # Two-sided p-value: take the survival function of |t| so that negative
    # correlations yield the same p-value as positive ones of equal magnitude
    # (without the abs, negative t-statistics would give values above 1).
    pval = stats.t.sf(np.abs(tstats), n - 2) * 2
    pval = xr.DataArray(pval, dims=cor.dims, coords=cor.coords)

    return cov, cor, slope, intercept, pval, stderr

In [None]:
# Usage: prepare your data, then run the function.
#
# The inputs must be xr.DataArray objects. In this example two 3-D cubes
# (time, latitude, longitude) are wrapped. To regress wind onto sst,
# sst is the predictor (x) and wind is the response (y).

x = xr.DataArray(sst.data)
y = xr.DataArray(wind.data)

# Compute the covariance, correlation, slope, intercept, p-value and
# standard error of the linear regression.

(cov, cor, slope,
 intercept, pval, stderr) = lag_linregress_3D(x, y)

print(slope)