In [4]:
import xarray as xr 
import numpy as np 
import matplotlib.pyplot as plt
import scipy.stats as stats
from scipy.signal import windows
import torch
import torch.nn as nn

In [5]:
# Regional SST subset (region and date range as encoded in the file name).
DATA_PATH = "SST_ABSO_075.00E_085.00E_10.00S_00.00S_20240101_20240229_region.nc"
ds = xr.open_dataset(DATA_PATH)

In [6]:
def power_spec(image):
    """Radially averaged power spectrum of a square 2-D image.

    The squared FFT magnitude is averaged over annuli of radial wavenumber
    |k| in [0.5, 1.5), [1.5, 2.5), ... up to npix // 2 (integer wavenumbers,
    i.e. cycles per image). Each annulus mean is then multiplied by the annulus
    area pi * (k_hi**2 - k_lo**2). The k = 0 (mean) term lies below the first
    edge and is excluded.

    Parameters
    ----------
    image : 2-D array, assumed square; only ``image.shape[0]`` is used to
        build the wavenumber grid.

    Returns
    -------
    kvals : 1-D array of bin-centre wavenumbers (1, 2, ..., npix // 2).
    Abins : 1-D array of area-weighted mean power per bin (NaN if a bin is empty).
    """
    npix = image.shape[0]

    # Integer wavenumbers in FFT ordering: 0, 1, ..., -1.
    wavenumbers = np.fft.fftfreq(npix) * npix
    kx, ky = np.meshgrid(wavenumbers, wavenumbers)
    k_radial = np.sqrt(kx**2 + ky**2).flatten()
    power = (np.abs(np.fft.fftn(image)) ** 2).flatten()

    edges = np.arange(0.5, npix // 2 + 1, 1.)
    centres = (edges[:-1] + edges[1:]) / 2
    mean_power = stats.binned_statistic(k_radial, power,
                                        statistic="mean",
                                        bins=edges)[0]
    annulus_area = np.pi * (edges[1:] ** 2 - edges[:-1] ** 2)
    return centres, mean_power * annulus_area

def resize(image, split = 48):
    """Crop `image` so its first two dimensions are whole multiples of `split`.

    Rows and columns past the last complete split x split tile are dropped
    from the end (bottom / right). The crop is done by basic slicing, so for
    NumPy arrays the result is a view of the input.
    """
    rows, cols = image.shape[0], image.shape[1]
    keep_rows = rows - rows % split
    keep_cols = cols - cols % split
    return image[:keep_rows, :keep_cols]

def differentiate(image):
    """Forward-difference gradients of a 2-D image.

    Uses the same convention as the difference operators in
    `absolute_reconstruction`, so the two are consistent:
      * x_grad[r, c] = image[r, c+1] - image[r, c]  (along axis 1), 0 in the last column;
      * y_grad[r, c] = image[r+1, c] - image[r, c]  (along axis 0), 0 in the last row.

    Parameters
    ----------
    image : 2-D array-like.

    Returns
    -------
    (x_grad, y_grad) : float arrays with the same shape as `image`.

    Raises
    ------
    ValueError if `image` is not 2-D.
    """
    image = np.asarray(image, dtype=float)
    if image.ndim != 2:
        raise ValueError(f"differentiate expects a 2-D image, got shape {image.shape}")
    x_grad = np.zeros_like(image)
    y_grad = np.zeros_like(image)
    # The last column / row has no forward neighbour, so it stays 0.
    x_grad[:, :-1] = image[:, 1:] - image[:, :-1]
    y_grad[:-1, :] = image[1:, :] - image[:-1, :]
    return x_grad, y_grad

def subgrid(image, split = 48):
    """Tile a 2-D image into non-overlapping split x split sub-grids.

    Trailing rows/columns that do not fill a whole tile are discarded on both
    axes (the same crop `resize` performs).

    Parameters
    ----------
    image : 2-D array-like.
    split : tile edge length in pixels.

    Returns
    -------
    float64 array of shape (n_x, n_y, split, split), with
    ``out[i, j] == image[i*split:(i+1)*split, j*split:(j+1)*split]``.
    This is always a new array; it never aliases `image`.

    Raises
    ------
    ValueError if `image` is not 2-D.
    """
    image = np.asarray(image)
    if image.ndim != 2:
        raise ValueError(f"subgrid expects a 2-D image, got shape {image.shape}")
    n_x, n_y = image.shape[0] // split, image.shape[1] // split
    # Crop both axes to whole tiles before reshaping.
    cropped = image[:n_x * split, :n_y * split]
    # (n_x*split, n_y*split) -> (n_x, split, n_y, split) -> (n_x, n_y, split, split)
    tiles = cropped.reshape(n_x, split, n_y, split).swapaxes(1, 2)
    return tiles.astype(np.float64)  # astype copies by default

def reconstruct_subgrids(subgridded):
    """Inverse of `subgrid`: stitch a 4-D tile array back into one 2-D image.

    Parameters
    ----------
    subgridded : array of shape (x_subgrids, y_subgrids, split_x, split_y).

    Returns
    -------
    New float64 array of shape (x_subgrids * split_x, y_subgrids * split_y).
    Tile [i, j] occupies rows i*split_x:(i+1)*split_x and columns
    j*split_y:(j+1)*split_y.
    """
    n_rows, n_cols, tile_h, tile_w = subgridded.shape
    # (n_rows, n_cols, tile_h, tile_w) -> (n_rows, tile_h, n_cols, tile_w),
    # so a row-major reshape lays tiles out side by side.
    interleaved = np.swapaxes(subgridded, 1, 2)
    stitched = interleaved.reshape(n_rows * tile_h, n_cols * tile_w)
    return stitched.astype(np.float64)  # astype always copies

def apply_model(grad_x, grad_y, model, split=48, batch_size=32):
    """Run `model` tile-by-tile over a pair of gradient fields and stitch the result.

    Both inputs are cut into split x split tiles with `subgrid`; trailing
    rows/columns that do not fill a tile are dropped. Each tile is fed to
    `model` as a 2-channel image (channel 0 = grad_x, channel 1 = grad_y).
    The outputs are reassembled with `reconstruct_subgrids`.

    Parameters
    ----------
    grad_x, grad_y : 2-D arrays of the same shape.
    model : callable, typically a torch.nn.Module, mapping a (B, 2, split, split)
        tensor to a (B, 2, split, split) tensor.
    split : tile edge length in pixels.
    batch_size : number of tiles per forward pass; bounds peak memory.

    Returns
    -------
    (grad_x_reconstructed, grad_y_reconstructed) : float64 arrays cropped to a
    whole number of tiles.
    """
    subgridded_x, subgridded_y = subgrid(grad_x, split), subgrid(grad_y, split)
    x_grids, y_grids = subgridded_x.shape[0], subgridded_x.shape[1]
    # (x_grids, y_grids, 2, h, w) -> flat batch (x_grids * y_grids, 2, h, w)
    stacked = np.stack((subgridded_x, subgridded_y), axis=2)
    tiles = stacked.reshape(x_grids * y_grids, *stacked.shape[2:])

    # Inputs must live on the same device (and use the same dtype) as the
    # model's weights. A CPU tensor fed to a CUDA model raises an error.
    params = getattr(model, "parameters", None)
    first_param = next(params(), None) if callable(params) else None
    device = first_param.device if first_param is not None else torch.device("cpu")
    dtype = first_param.dtype if first_param is not None else torch.float32

    outputs = np.zeros(tiles.shape)
    with torch.no_grad():  # inference only: don't build an autograd graph
        for start in range(0, len(tiles), batch_size):
            batch = torch.as_tensor(tiles[start:start + batch_size], dtype=dtype, device=device)
            outputs[start:start + batch_size] = model(batch).detach().cpu().numpy()

    outputs = outputs.reshape(x_grids, y_grids, *tiles.shape[1:])
    grad_x_reconstructed = reconstruct_subgrids(outputs[:, :, 0])
    grad_y_reconstructed = reconstruct_subgrids(outputs[:, :, 1])
    return grad_x_reconstructed, grad_y_reconstructed

def absolute_reconstruction(G_x, G_y, l4):
    """Recover an absolute (dim x dim) field from its x and y gradient channels.

    Solves the least-squares problem

        min_u ||D_x u - g_x||^2 + ||D_y u - g_y||^2

    through its normal equations (D_x^T D_x + D_y^T D_y) u = D_x^T g_x + D_y^T g_y.
    Here u, g_x and g_y are the row-major flattened fields, and:
      * D_x is the forward difference along axis 1 (u[r, c+1] - u[r, c]),
        zero for the last column;
      * D_y is the forward difference along axis 0 (u[r+1, c] - u[r, c]),
        zero for the last row.

    Gradients only determine u up to an additive constant, so the system
    matrix is singular. Pixel (0, 0) is pinned to 0 to make it solvable. The
    constant does not matter anyway, because the result is min-max rescaled
    onto the value range of the low-resolution reference `l4`.

    Parameters
    ----------
    G_x, G_y : square 2-D arrays of identical shape (dim x dim).
    l4 : array giving the target value range; NaNs (e.g. land) are ignored.

    Returns
    -------
    (dim x dim) float array spanning [nanmin(l4), nanmax(l4)]. A flat
    reconstruction (no gradient structure) maps to nanmin(l4) everywhere
    instead of dividing 0 by 0.

    Raises
    ------
    ValueError if G_x is not square or G_y's shape differs from G_x's.
    """
    from scipy import sparse
    from scipy.sparse.linalg import spsolve

    G_x = np.asarray(G_x, dtype=float)
    G_y = np.asarray(G_y, dtype=float)
    dim = G_x.shape[0] # dim x dim image
    if G_x.shape != (dim, dim) or G_y.shape != G_x.shape:
        raise ValueError(
            f"expected square gradients of equal shape, got {G_x.shape} and {G_y.shape}")
    N = dim**2

    solution = np.zeros(N)
    if dim > 1:
        # 1-D forward difference; the last row is zero (no neighbour past the edge).
        main = -np.ones(dim)
        main[-1] = 0.0
        d1 = sparse.diags([main, np.ones(dim - 1)], [0, 1], shape=(dim, dim))
        eye = sparse.identity(dim)
        # Sparse operators keep O(N) non-zeros instead of dense N x N (= dim**4) matrices.
        grad_x = sparse.kron(eye, d1, format="csr")  # differences within a row (axis 1)
        grad_y = sparse.kron(d1, eye, format="csr")  # differences between rows (axis 0)

        # Normal equations A u = b (A is a graph Laplacian).
        A = (grad_x.T @ grad_x + grad_y.T @ grad_y).tocsc()
        b = grad_x.T @ G_x.ravel() + grad_y.T @ G_y.ravel()

        # Pin u[0] = 0. The reduced Laplacian of a connected grid is
        # non-singular, and the dropped equation is implied by the rest
        # (the rows of A and the entries of b both sum to zero).
        solution[1:] = spsolve(A[1:, 1:], b[1:])

    # re-shape flattened vector to (dim x dim) matrix
    reconstructed = solution.reshape(dim, dim)

    # Min-max rescale onto the reference range; nan-aware because L4 fields may contain NaNs.
    l4_min, l4_max = np.nanmin(l4), np.nanmax(l4)
    recon_min, recon_max = reconstructed.min(), reconstructed.max()
    span = recon_max - recon_min
    if span > 0:
        normalized_recon = (reconstructed - recon_min) / span
    else:
        normalized_recon = np.zeros_like(reconstructed)
    return normalized_recon * (l4_max - l4_min) + l4_min

In [7]:
class ERes_Block(nn.Module):
    """Enhanced residual block computing x + 0.1 * conv2(relu(conv1(x))).

    conv1 maps `input_channels` -> 64 feature maps and conv2 maps 64 back to
    `input_channels`. Both are 3x3 with stride 1 and 'same' padding, so the
    output has the input's shape and can be added to the skip connection.
    """

    def __init__(self, input_channels = 64):
        super().__init__()
        # Attribute names determine the state_dict keys, which must match the
        # checkpoint loaded into EDSR, so they must not be renamed.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding='same')
        self.conv2 = nn.Conv2d(64, input_channels, kernel_size=3, stride=1, padding='same')
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = self.conv2(self.relu(self.conv1(x)))
        # Scaled residual added back onto the untouched input.
        return x + residual * 0.1
        
class EDSR(nn.Module):
    """EDSR-style network built around a single, repeatedly applied residual block.

    Forward pass:
        conv1 (in -> 64) + ReLU
        -> `self.erb` applied `N_blocks` times
        -> conv2 (64 -> 64) + ReLU
        -> add the conv1 features (long skip connection)
        -> conv3 (64 -> out)
    All convolutions are 3x3 with 'same' padding, so spatial size is preserved.

    NOTE: the *same* ERes_Block instance (same weights) is reused on every
    iteration, rather than N independent blocks. The attribute layout
    (conv1, conv2, conv3, erb) must stay as is to load the saved checkpoint.
    """

    def __init__(self, input_channels = 1, output_channels = 1, N_blocks=16):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, dilation=1, padding='same')
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding='same')
        self.conv3 = nn.Conv2d(64, output_channels, kernel_size=3, padding='same')
        self.erb = ERes_Block(64)
        self.N_blocks = N_blocks
        self.relu = nn.ReLU()

    def forward(self, x):
        features = self.relu(self.conv1(x.float()))  # input cast to float32
        body = features
        for _ in range(self.N_blocks):
            body = self.erb(body)
        body = self.relu(self.conv2(body)) + features  # long skip connection
        return self.conv3(body)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_path = 'EDSR_Filtered_Data'
model = EDSR(2,2).to(device)
# Inference only. EDSR has no dropout/batch-norm today, so eval() does not
# change its outputs, but it keeps inference correct if such layers are added.
# It is called before the load so the cell still displays the load report.
model.eval()
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# For a plain state_dict, torch >= 1.13 also accepts weights_only=True.
model.load_state_dict(torch.load(model_path,map_location=device))

<All keys matched successfully>

In [3]:
# Requires the dataset, helper-function and model cells to have run first.
# Running this cell before them raises NameError for `ds`.
TIME_INDEX = 0  # a single snapshot for both the gradients and the L4 reference
SPLIT = 48      # tile size used by apply_model / subgrid

ds['sst_gradx'] = ds.sst.differentiate('lat')
ds['sst_grady'] = ds.sst.differentiate('lon')
sx = ds.sst_gradx.isel(time=TIME_INDEX).data
sy = ds.sst_grady.isel(time=TIME_INDEX).data
gx, gy = apply_model(sx, sy, model, split=SPLIT)

# apply_model crops to whole tiles; crop the reference identically so the two
# panels (and the min/max used for rescaling) cover the same pixels.
image = resize(ds.sst.isel(time=TIME_INDEX).data, SPLIT)

# NOTE(review): absolute_reconstruction treats its first argument as the
# difference along axis 1 (columns) and the second as along axis 0 (rows).
# gx here derives from d/dlat; if sst is laid out (time, lat, lon), that is an
# axis-0 derivative. Confirm the model's channel convention and whether the
# arguments should be swapped.
abs_reconstruction = absolute_reconstruction(gx, gy, image)

vmin, vmax = np.nanmin(image), np.nanmax(image)  # shared colour scale for comparison
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(image, vmin=vmin, vmax=vmax)
ax[0].set_title(f"L4 SST (time index {TIME_INDEX})")
ax[1].imshow(abs_reconstruction, vmin=vmin, vmax=vmax)
ax[1].set_title("Reconstruction from model gradients");

NameError: name 'ds' is not defined