In [1]:
# Basic setting for Jupyter_notebook to import utils
import os
import sys

notebook_path = os.path.abspath("")
project_root = os.path.abspath(os.path.join(notebook_path, "../../"))

sys.path.append(project_root)

In [2]:
import sys
import os
import xarray as xr
import numpy as np
import pandas as pd
from tqdm import tqdm
from utils import folder_utils

In [3]:
# Example usage

country = "GB"
data_folder = "data"
data_test_category = "test_data"
data_read_category = "raw_data"
data_save_category = "processed_data"
output_folder = "ERA5_DATA"
ddeg_out_lat = 0.25
ddeg_out_lon = 0.125


**The training dataset is [1979,2020];**

**The validation dataset is 2021;**

**The test dataset is 2022.**

# Step 1 Extract data and calculate mean and std
This step we take the hourly T850 data from 1979 to 2020 to compute the mean and std for normalization part.nce.

In [4]:
def extract_T850_compute_mean_std(country, data_folder, data_category, output_folder,start_year=1979, end_year=2020):
    # era5_pressure_level_2022_regrid_850.nc
    input_folder_path = folder_utils.find_folder(
        country, data_folder, data_category, output_folder
    )
    nc_files = [
        os.path.join(input_folder_path, f)
        for f in os.listdir(input_folder_path)
        if f.endswith(".nc") and "regrid_850" in f
        and start_year <= int(f.split('_')[3]) <= end_year
    ]
    ds = xr.open_mfdataset(nc_files, combine="by_coords")

    # Extract t2m data
    t2m_data = ds['t']

    # Compute mean and std in chunks
    mean_list = []
    std_list = []
    for chunk in tqdm(t2m_data):
        chunk_flatten = chunk.values.flatten()
        mean_list.append(np.nanmean(chunk_flatten))
        std_list.append(np.nanstd(chunk_flatten))

    mean_t2m = np.mean(mean_list)
    std_t2m = np.mean(std_list)

    return mean_t2m, std_t2m


In [5]:
def extract_T850_compute_mean_std_tt(country, data_folder, data_category, output_folder, start_year=1979, end_year=2020):
    input_folder_path = folder_utils.find_folder(country, data_folder, data_category, output_folder)
    nc_files = [
        os.path.join(input_folder_path, f)
        for f in os.listdir(input_folder_path)
        if f.endswith(".nc") and "regrid_850" in f and start_year <= int(f.split('_')[3]) <= end_year
    ]
    ds = xr.open_mfdataset(nc_files, combine="by_coords")

    t2m_data = ds['t']

    mean_list = []
    std_list = []

    for chunk in tqdm(t2m_data):
        chunk_flatten = chunk.values.flatten()

        # Exclude NaN and zero values from calculation
        valid_values = chunk_flatten[~np.isnan(chunk_flatten) & (chunk_flatten != 0)]

        if len(valid_values) > 0:
            mean_list.append(np.nanmean(valid_values))
            std_list.append(np.nanstd(valid_values))

    mean_t2m = np.mean(mean_list)
    std_t2m = np.mean(std_list)

    return mean_t2m, std_t2m

In [7]:
mean_t2m_1, std_t2m_1 = extract_T850_compute_mean_std(country, data_folder, data_save_category, output_folder, start_year=2021, end_year=2022)

100%|███████████████████████████████████████████████████████████████████████████| 17520/17520 [00:33<00:00, 521.18it/s]


In [9]:
print(mean_t2m_1)
print(std_t2m_1)

270.0407
34.181057


In [10]:
mean_t2m_2, std_t2m_2 = extract_T850_compute_mean_std_tt(country, data_folder, data_save_category, output_folder, start_year=2021, end_year=2022)

100%|███████████████████████████████████████████████████████████████████████████| 17520/17520 [00:31<00:00, 552.10it/s]


In [11]:
print(mean_t2m_2)
print(std_t2m_2)

274.32727
2.5188115


### Test example

In [29]:
ds

Unnamed: 0,Array,Chunk
Bytes,136.88 MiB,68.44 MiB
Shape,"(17520, 32, 64)","(8760, 32, 64)"
Dask graph,2 chunks in 5 graph layers,2 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 136.88 MiB 68.44 MiB Shape (17520, 32, 64) (8760, 32, 64) Dask graph 2 chunks in 5 graph layers Data type float32 numpy.ndarray",64  32  17520,

Unnamed: 0,Array,Chunk
Bytes,136.88 MiB,68.44 MiB
Shape,"(17520, 32, 64)","(8760, 32, 64)"
Dask graph,2 chunks in 5 graph layers,2 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [31]:
t2m_data

Unnamed: 0,Array,Chunk
Bytes,136.88 MiB,68.44 MiB
Shape,"(17520, 32, 64)","(8760, 32, 64)"
Dask graph,2 chunks in 5 graph layers,2 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 136.88 MiB 68.44 MiB Shape (17520, 32, 64) (8760, 32, 64) Dask graph 2 chunks in 5 graph layers Data type float32 numpy.ndarray",64  32  17520,

Unnamed: 0,Array,Chunk
Bytes,136.88 MiB,68.44 MiB
Shape,"(17520, 32, 64)","(8760, 32, 64)"
Dask graph,2 chunks in 5 graph layers,2 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


### Processing

In [47]:
mean_t2m, std_t2m = extract_T850_compute_mean_std(country, data_folder, data_save_category, output_folder, start_year=1979, end_year=2020)

100%|█████████████████████████████████████████████████████████████████████████| 368184/368184 [13:57<00:00, 439.80it/s]


In [6]:
mean_t2m_true, std_t2m_true = extract_T850_compute_mean_std_tt(country, data_folder, data_save_category, output_folder, start_year=1979, end_year=2020)

100%|█████████████████████████████████████████████████████████████████████████| 368184/368184 [14:33<00:00, 421.40it/s]


In [48]:
mean_t2m

269.50018

In [49]:
std_t2m

34.130245

In [7]:
mean_t2m_true

273.77817

In [8]:
std_t2m_true

2.5819736

In [6]:
mean_t2m = 269.50018
std_t2m = 34.130245

In [5]:
filelist_train = []
filelist_validation =[]
filelist_test=[]
input_folder_path = folder_utils.find_folder(
    country, data_folder, data_save_category, output_folder
)

for year in range (1979,2021):
    file_path =  os.path.join(input_folder_path, f"era5_pressure_level_{year}_regrid_850.nc")
    filelist_train.append (file_path)
    
filelist_validation.append(os.path.join(input_folder_path, f"era5_pressure_level_2021_regrid_850.nc"))
filelist_test.append(os.path.join(input_folder_path, f"era5_pressure_level_2022_regrid_850.nc"))


# Step 2 load model setting

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from utils import model_utils
from bilinear_interpolation import BilinearInterpolation

### ttest

In [5]:
class STN(nn.Module):
    def __init__(self, input_shape=(1, 32, 64), sampling_size=(8, 16), num_classes=10):
        super(STN, self).__init__()
        self.input_shape = input_shape
        self.sampling_size = sampling_size
        self.num_classes = num_classes

        # Note: PyTorch uses B, C, H, W ordering while TensorFlow uses B, H, W, C
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.input_shape[0], 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        self.conv5 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        self.locnet = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * self.sampling_size[0] * self.sampling_size[1], 500),
            nn.ReLU(),
            nn.Linear(500, 200),
            nn.ReLU(),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU(),
            # Initialize weights here if necessary_
            nn.Linear(50, 6),
        )  # The six-dimensional torch is the radial transformation parameter

        # Initialize the weights of the last Linear layer
        (
            self.locnet[-1].weight.data,
            self.locnet[-1].bias.data,
        ) = model_utils.get_initial_weights_torch(50)

        self.bilinear_interpolation = BilinearInterpolation(self.sampling_size)

        self.upconv1 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=2, padding=1), nn.ReLU()  # up6
        )

        self.conv6 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        self.upconv2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=2, padding=1), nn.ReLU()  # up7
        )

        self.conv7 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        self.conv10 = nn.Conv2d(32, 1, kernel_size=5, padding=2)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x5 = self.conv5(x3)

        # Apply locnet to flattened x5
        theta = self.locnet(x5)

        # Use bilinear interpolation on x with `theta`
        x_transformed = self.bilinear_interpolation(
            x, theta
        )  # x is the input image, theta is the transformation parameter

        up6 = F.interpolate(x_transformed, scale_factor=2, mode="nearest")
        up6 = self.upconv1(up6)
        up6 = torch.cat([up6, x2], 1)
        x6 = self.conv6(up6)

        up7 = F.interpolate(x6, scale_factor=2, mode="nearest")
        up7 = self.upconv2(up7)
        up7 = torch.cat([up7, x1], 1)
        x7 = self.conv7(up7)

        x10 = self.conv10(x7)
        return x10

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = STN()
model.to(device)
print(model)

STN(
  (conv1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): ReLU()
  )
  (conv5): Sequential(
    (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): ReL

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
input_data = input_data.to(device)

In [None]:
print(1)

In [None]:
summary(model, input_size=(1, 32, 64))


In [5]:
class SimpleModel3(nn.Module):
    def __init__(self, input_shape=(1, 32, 64)):
        super(SimpleModel3, self).__init__()
        self.input_shape = input_shape

        self.conv1 = nn.Sequential(
            nn.Conv2d(self.input_shape[0], 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        self.conv5 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x5 = self.conv5(x3)
        return x5

model3 = SimpleModel3().to(device)
summary(model3, input_size=(1, 32, 64), device=device)

NameError: name 'device' is not defined

In [50]:
def get_initial_weights_torch(output_size):
    b = torch.zeros(2, 3, dtype=torch.float32)
    # Identity transformation: set the main diagonal to 1
    b[0, 0] = 1
    b[1, 1] = 1

    # Initialize the weights to zero
    # W = np.zeros((output_size, 6), dtype="float32")
    W = np.zeros((6,output_size), dtype="float32")
    W = torch.tensor(W, dtype=torch.float)
    # b = torch.tensor(b.flatten(), dtype=torch.float)
    b = torch.as_tensor(b.flatten(), dtype=torch.float) # Don't compute gradient for b when initializing weights

    return (W, b)

(tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

### tt2

In [5]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [6]:
class SimpleModel3(nn.Module):
    def __init__(self, input_shape=(1, 32, 64), sampling_size=(8, 16)):
        super(SimpleModel3, self).__init__()
        self.input_shape = input_shape
        self.sampling_size = sampling_size  # Add this line

        self.conv1 = nn.Sequential(
            nn.Conv2d(self.input_shape[0], 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        self.conv5 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        self.locnet = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * self.sampling_size[0] * self.sampling_size[1], 500),
            nn.ReLU(),
            nn.Linear(500, 200),
            nn.ReLU(),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 6),
        )

         # Initialize the weights of the last Linear layer
        (
            self.locnet[-1].weight.data,
            self.locnet[-1].bias.data,
        ) = model_utils.get_initial_weights_torch(50)

        self.bilinear_interpolation = BilinearInterpolation(self.sampling_size, device='cpu')

        self.upconv1 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=2, padding=1), nn.ReLU()  # up6
        )

        self.conv6 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x5 = self.conv5(x3)
        # print(x5.shape)
        theta = self.locnet(x5)
        # print(theta.shape)
        # Use bilinear interpolation on x with `theta`
        x_transformed = self.bilinear_interpolation(
            x, theta
        )  # x is the input image, theta is the transformation parameter
        # return x_transformed
        # up6 = F.interpolate(x_transformed, scale_factor=2, mode="nearest")
        # up6 = self.upconv1(up6)
        # up6 = torch.cat([up6, x2], 1)
        # x6 = self.conv6(up6)


        return x_transformed


        
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device('cpu')
# device_str = "cuda" if "cuda" in str(device) else "cpu"
model3 = SimpleModel3().to(device)
# summary(model3, input_size=(1, 32, 64), device=device_str )
summary(model3, input_size=(1, 32, 64), device='cpu')

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


x0:

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  4,  4,  4,  4,  4,  4,  4,  4,  8,  8,
          8,  8,  8,  8,  8,  8, 12, 12, 12, 12, 12, 12, 12, 12, 17, 17, 17, 17,
         17, 17, 17, 17, 21, 21, 21, 21, 21, 21, 21, 21, 25, 25, 25, 25, 25, 25,
         25, 25, 29, 29, 29, 29, 29, 29, 29, 29, 34, 34, 34, 34, 34, 34, 34, 34,
         38, 38, 38, 38, 38, 38, 38, 38, 42, 42, 42, 42, 42, 42, 42, 42, 46, 46,
         46, 46, 46, 46, 46, 46, 51, 51, 51, 51, 51, 51, 51, 51, 55, 55, 55, 55,
         55, 55, 55, 55, 59, 59, 59, 59, 59, 59, 59, 59, 63, 63, 63, 63, 63, 63,
         63, 63],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  4,  4,  4,  4,  4,  4,  4,  4,  8,  8,
          8,  8,  8,  8,  8,  8, 12, 12, 12, 12, 12, 12, 12, 12, 17, 17, 17, 17,
         17, 17, 17, 17, 21, 21, 21, 21, 21, 21, 21, 21, 25, 25, 25, 25, 25, 25,
         25, 25, 29, 29, 29, 29, 29, 29, 29, 29, 34, 34, 34, 34, 34, 34, 34, 34,
         38, 38, 38, 38, 38, 38, 38, 38, 42, 42, 42, 42, 42, 42, 42, 42, 46, 46,
     

RuntimeError: index 2048 is out of bounds for dimension 2 with size 2048

In [None]:
class STN(nn.Module):
    def __init__(self, input_shape=(1, 32, 64), sampling_size=(8, 16), num_classes=10):
        super(STN, self).__init__()
        self.input_shape = input_shape
        self.sampling_size = sampling_size
        self.num_classes = num_classes

        # Note: PyTorch uses B, C, H, W ordering while TensorFlow uses B, H, W, C
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.input_shape[0], 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        self.conv5 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        self.locnet = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * self.sampling_size[0] * self.sampling_size[1], 500),
            nn.ReLU(),
            nn.Linear(500, 200),
            nn.ReLU(),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU(),
            # Initialize weights here if necessary_
            nn.Linear(50, 6),
        )  # The six-dimensional torch is the radial transformation parameter

        # Initialize the weights of the last Linear layer
        (
            self.locnet[-1].weight.data,
            self.locnet[-1].bias.data,
        ) = model_utils.get_initial_weights_torch(50)

        self.bilinear_interpolation = BilinearInterpolation(self.sampling_size)

        self.upconv1 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=2, padding=1), nn.ReLU()  # up6
        )

        self.conv6 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        self.upconv2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=2, padding=1), nn.ReLU()  # up7
        )

        self.conv7 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        self.conv10 = nn.Conv2d(32, 1, kernel_size=5, padding=2)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x5 = self.conv5(x3)

        # Apply locnet to flattened x5
        theta = self.locnet(x5)

        # Use bilinear interpolation on x with `theta`
        x_transformed = self.bilinear_interpolation(
            x, theta
        )  # x is the input image, theta is the transformation parameter

        up6 = F.interpolate(x_transformed, scale_factor=2, mode="nearest")
        up6 = self.upconv1(up6)
        up6 = torch.cat([up6, x2], 1)
        x6 = self.conv6(up6)

        up7 = F.interpolate(x6, scale_factor=2, mode="nearest")
        up7 = self.upconv2(up7)
        up7 = torch.cat([up7, x1], 1)
        x7 = self.conv7(up7)

        x10 = self.conv10(x7)
        return x10

In [None]:
conda create --name datf python=3.10