In [1]:
### Lagged regression of y on x (lag = 0 gives an ordinary, unlagged regression).
def lagged_regress(x,y,lag):
    """Regress ``y`` on a time-shifted copy of ``x`` and return the slope(s).

    ``x`` is shifted along ``time`` with ``x.shift(time=lag)``. Time steps where
    the shifted ``x`` is not finite are dropped, the two arrays are aligned on
    ``time``, and ``y`` is regressed on the lagged ``x`` with sklearn's
    ``LinearRegression`` (which fits an intercept by default). ``lag = 0`` is
    an ordinary, unlagged regression.

    Parameters
    ----------
    x : xarray.DataArray
        Predictor. It is used as a single regressor column, so it is expected
        to be 1-D along ``time``.
    y : xarray.DataArray
        Predictand with a ``time`` dimension. Handling depends on ``y.ndim``:

        * 3-D: must also have ``latitude`` and ``longitude``; one slope per
          grid point.
        * 2-D: one slope per element of the non-time dimension.
        * 1-D: a single time series; confidence intervals are returned too.
    lag : int or iterable of int
        A single integer when ``y`` is 3-D; an iterable of integers when ``y``
        is 1-D or 2-D (one result per lag, stacked along a new ``lag`` dim).

    Returns
    -------
    3-D ``y``: xarray.DataArray
        Slopes unstacked back onto ``latitude``/``longitude`` (presumably with
        an extra length-1 ``dim_1`` coming from ``LinearRegression.coef_`` --
        TODO confirm).
    2-D ``y``: xarray.DataArray
        Slopes with dims ``('lag', <non-time dim of y>)``.
    1-D ``y``: tuple of xarray.DataArray
        ``(slopes, conf_int)``: slopes along ``lag``, and the 95 % confidence
        interval of the slope with dims ``('lag', 'dim_1')`` where ``dim_1``
        holds ``[lower, upper]``.

    Raises
    ------
    ValueError
        If ``y`` is not 1-, 2- or 3-D.
    """
    def _grid_trend(x_lagged, y_stack):
        # Slopes for a (time, allpoints)-stacked field, unstacked back to the
        # original grid. Any NaN in y_stack makes LinearRegression raise.
        valid = np.isfinite(x_lagged)
        a_lag, b_stack = xr.align(x_lagged[valid], y_stack[valid])
        reg = LinearRegression().fit(a_lag.values.reshape(-1,1), b_stack)

        # Place betas in xarray, restore the stacked index, and unstack.
        reg_da = xr.DataArray(data=reg.coef_,
                              coords={'dim_0':y_stack.allpoints.values})
        reg_rename = reg_da.rename({'dim_0':'allpoints'})
        reg_xr = reg_rename.reindex_like(y_stack)
        return reg_xr.unstack('allpoints')

    def _masked_aligned(x_lagged, y_in):
        # Drop time steps where the lagged predictor is not finite (the edge
        # steps filled by shift), then align both arrays on time.
        valid = np.isfinite(x_lagged).values
        return xr.align(x_lagged.isel(time=valid), y_in.isel(time=valid))

    def _series_trend(x_lagged, y_in):
        # Slope of every non-time element of y_in on x_lagged. A multi-target
        # least-squares fit solves each column independently, so this equals
        # one fit per column; columns with non-finite values get NaN instead of
        # failing the whole fit.
        x_fit, y_fit = _masked_aligned(x_lagged, y_in)
        other_dims = [d for d in y_fit.dims if d != 'time']
        y_fit = y_fit.transpose('time', *other_dims)
        predictor = x_fit.values.reshape(-1, 1)
        target = y_fit.values.reshape(y_fit.sizes['time'], -1)  # (time, n_columns)

        ok = np.isfinite(target).all(axis=0)
        slopes = np.full(target.shape[1], np.nan)
        if ok.any():
            reg = LinearRegression().fit(predictor, target[:, ok])
            slopes[ok] = reg.coef_[:, 0]  # coef_ is (n_targets, n_features)

        # Reuse y's non-time coordinates for the result (0-d for a 1-D y).
        template = y_fit.isel(time=0, drop=True)
        return xr.DataArray(slopes.reshape(template.shape),
                            coords=template.coords, dims=template.dims)

    def _slope_conf_int(x_lagged, y_in):
        # 95 % confidence interval of the slope, returned as a (1, 2) array
        # [[lower, upper]] so the caller's squeeze(dim='dim_0') always applies.
        x_fit, y_fit = _masked_aligned(x_lagged, y_in)
        y_vals = y_fit.values
        if not np.isfinite(y_vals).all():
            return xr.DataArray(np.full((1, 2), np.nan))
        x_vals = x_fit.values
        # Explicit constant column so the OLS model matches LinearRegression,
        # which fits an intercept; row 0 of conf_int is then the intercept.
        exog = np.column_stack([np.ones_like(x_vals), x_vals])
        fitted = sm.OLS(y_vals, exog).fit()
        return xr.DataArray(fitted.conf_int(alpha=0.05)[1:])

    if y.ndim == 3:
        # Single lag: stack lat/lon into 'allpoints' and regress all points at once.
        y_stack = y.stack(allpoints=['latitude','longitude'])
        return _grid_trend(x.shift(time=lag), y_stack)

    if y.ndim not in (1, 2):
        raise ValueError(f"lagged_regress expects a 1-, 2- or 3-D y, got {y.ndim}-D")

    trends = []
    errs = []
    for n in lag:
        x_lag = x.shift(time=n)
        trends.append(_series_trend(x_lag, y))
        if y.ndim == 1:
            errs.append(_slope_conf_int(x_lag, y))

    trend_xr = xr.concat(trends,dim='lag').assign_coords({'lag':lag})
    if y.ndim == 2:
        return trend_xr

    err_xr = xr.concat(errs,dim='lag').assign_coords({'lag':lag}).squeeze(dim='dim_0')
    return trend_xr, err_xr