In [None]:
pip install pytorch_wavelets

In [2]:
from IPython import get_ipython

# Start from a clean namespace so the notebook behaves the same on every run.
# `run_line_magic` is the supported API (`InteractiveShell.magic(...)` is
# deprecated). get_ipython() returns None when not running under IPython.
_shell = get_ipython()
if _shell is not None:
    _shell.run_line_magic('reset', '-sf')

import pywt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from timeit import default_timer
# from utilities3 import *
from pytorch_wavelets import DWT1D, IDWT1D

# Seed every RNG used below so results are reproducible.
torch.manual_seed(0)
np.random.seed(0)

In [None]:
import h5py
import numpy as np
from sklearn.model_selection import train_test_split

# Load the HDF5 dataset: the "input" and "output" datasets are read fully into
# NumPy arrays while the file is open; the file is closed on leaving the block.
file_path = "/content/Config_2_sample_obs_1pc.hdf5"
with h5py.File(file_path, "r") as hdf:
    x_data, y_data = (np.array(hdf[key]) for key in ("input", "output"))


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LastConditionalBlock(nn.Module):
    """Final conditioning stage: ReLU -> 2x2 average pool -> flatten -> linear.

    The linear layer's input size depends on the spatial size of the incoming
    feature map, which is only known at the first forward pass, so it is an
    ``nn.LazyLinear``. The previous version built a brand-new ``nn.Linear``
    inside ``forward``, which meant:

    * its weights were re-randomized on every call;
    * it was never trained;
    * it was invisible to ``parameters()``, ``state_dict`` and ``.to(device)``.

    The lazy layer is a registered submodule, so none of this applies.

    Args:
        in_channels: channel count of the incoming feature map (stored for
            reference; the flattened size is inferred on the first call).
        out_features: size of the output embedding (default 128).
    """

    def __init__(self, in_channels, out_features=128):
        super(LastConditionalBlock, self).__init__()
        self.in_channels = in_channels
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
        # Input size is materialized on the first forward call.
        self.fc = nn.LazyLinear(out_features)

    def forward(self, x):
        """Map a (batch, C, H, W) feature map to a (batch, out_features) tensor."""
        x = F.relu(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # Flatten except batch dimension
        return self.fc(x)

class ConditionalNetwork(nn.Module):
    """Conditioning network producing multi-scale features from a 1-channel image.

    The network has three upsampling stages, each
    ConvTranspose2d(kernel 3, stride 2, so H -> 2H+1) -> BatchNorm2d -> ReLU,
    followed by ``LastConditionalBlock``. ``forward`` returns the output of
    every stage, each flattened to (batch, -1), as a 4-tuple.
    """

    def __init__(self):
        super(ConditionalNetwork, self).__init__()

        def up_stage(c_in, c_out):
            # One upsampling stage of the conditioning pyramid.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2),
                nn.BatchNorm2d(c_out),
                nn.ReLU()
            )

        self.block1 = up_stage(1, 64)
        self.block2 = up_stage(64, 32)
        self.block3 = up_stage(32, 16)
        self.last_block = LastConditionalBlock(in_channels=16)  # Adjust input channels

    def forward(self, x):
        # Run the stages in sequence, keeping every intermediate output.
        features = []
        h = x
        for stage in (self.block1, self.block2, self.block3, self.last_block):
            h = stage(h)
            features.append(h)
        return tuple(f.view(f.size(0), -1) for f in features)

# Initialize the conditioning network and compute the conditioning features.
# The conditioning network is not handed to any optimizer, so it is evaluated
# under no_grad. This has two effects:
# - it avoids an autograd graph through the conditioning net, shared by every
#   training step (which forced backward(retain_graph=True));
# - it stops gradients from silently accumulating in its parameters.
conditional_net = ConditionalNetwork()
y_data_tensor = torch.from_numpy(y_data).unsqueeze(1).float()
with torch.no_grad():
    c1, c2, c3, c4 = conditional_net(y_data_tensor)


In [5]:
import torch
import numpy as np
import scipy.io
import h5py
import torch.nn as nn

import operator
from functools import reduce
from functools import partial

"""
This code is taken from the repo: https://github.com/zongyi-li/fourier_neural_operator

The associated article is Fourier Neural Operator
"""

# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# reading data
class MatReader(object):
    """Read named fields from a MATLAB ``.mat`` file.

    Files ``scipy.io.loadmat`` can parse are read with it. Otherwise (for
    example v7.3 files, which are HDF5 containers) the file is opened read-only
    with ``h5py``, and fields are transposed back to MATLAB's axis order.

    Args:
        file_path: path to the file.
        to_torch: return ``torch.Tensor`` from ``read_field``.
        to_cuda: move the tensor to CUDA (only when ``to_torch``).
        to_float: cast the field to ``float32``.
    """

    def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True):
        super(MatReader, self).__init__()

        self.to_torch = to_torch
        self.to_cuda = to_cuda
        self.to_float = to_float

        self.file_path = file_path

        self.data = None
        self.old_mat = None
        self._load_file()

    def _load_file(self):
        try:
            self.data = scipy.io.loadmat(self.file_path)
            self.old_mat = True
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # are not swallowed. Mode 'r' is explicit: older h5py defaulted to
            # append mode, which would silently create an empty file for a
            # missing path instead of raising.
            self.data = h5py.File(self.file_path, 'r')
            self.old_mat = False

    def load_file(self, file_path):
        """Switch to a different file and load it."""
        self.file_path = file_path
        self._load_file()

    def read_field(self, field):
        """Return ``field`` converted according to the to_float/to_torch/to_cuda flags."""
        x = self.data[field]

        if not self.old_mat:
            # h5py: materialize the dataset, then reverse all axes (HDF5 stores
            # MATLAB arrays with reversed dimension order).
            x = x[()]
            x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1))

        if self.to_float:
            x = x.astype(np.float32)

        if self.to_torch:
            x = torch.from_numpy(x)

            if self.to_cuda:
                x = x.cuda()

        return x

    def set_cuda(self, to_cuda):
        self.to_cuda = to_cuda

    def set_torch(self, to_torch):
        self.to_torch = to_torch

    def set_float(self, to_float):
        self.to_float = to_float

# normalization, pointwise gaussian
class UnitGaussianNormalizer(object):
    """Pointwise Gaussian normalization.

    Mean and std are computed per position over the sample axis (dim 0).

    Args:
        x: training data; shape ntrain*n, ntrain*T*n or ntrain*n*T.
        eps: added to the std to avoid division by zero.
    """

    def __init__(self, x, eps=0.00001):
        super(UnitGaussianNormalizer, self).__init__()

        # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T
        self.mean = torch.mean(x, 0)
        self.std = torch.std(x, 0)
        self.eps = eps

    def encode(self, x):
        """Return (x - mean) / (std + eps)."""
        x = (x - self.mean) / (self.std + self.eps)
        return x

    def decode(self, x, sample_idx=None):
        """Invert ``encode``, optionally using statistics only at ``sample_idx``.

        Raises:
            ValueError: if ``sample_idx[0]`` has more dimensions than ``mean``
                (previously this fell through and raised UnboundLocalError).
        """
        if sample_idx is None:
            std = self.std + self.eps  # n
            mean = self.mean
        elif len(self.mean.shape) == len(sample_idx[0].shape):
            std = self.std[sample_idx] + self.eps  # batch*n
            mean = self.mean[sample_idx]
        elif len(self.mean.shape) > len(sample_idx[0].shape):
            std = self.std[:, sample_idx] + self.eps  # T*batch*n
            mean = self.mean[:, sample_idx]
        else:
            raise ValueError(
                "sample_idx[0] has more dimensions ({}) than the stored mean ({})".format(
                    len(sample_idx[0].shape), len(self.mean.shape)))

        # x is in shape of batch*n or T*batch*n
        x = (x * std) + mean
        return x

    def cuda(self):
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()

    def cpu(self):
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()

# normalization, Gaussian
class GaussianNormalizer(object):
    """Global Gaussian normalization using a single scalar mean and std over all of ``x``."""

    def __init__(self, x, eps=0.00001):
        super(GaussianNormalizer, self).__init__()

        self.mean, self.std = torch.mean(x), torch.std(x)
        self.eps = eps

    def _scale(self):
        # Denominator shared by encode/decode; eps guards against a zero std.
        return self.std + self.eps

    def encode(self, x):
        """Return (x - mean) / (std + eps)."""
        return (x - self.mean) / self._scale()

    def decode(self, x, sample_idx=None):
        """Invert ``encode``; ``sample_idx`` is accepted for API parity and ignored."""
        return (x * self._scale()) + self.mean

    def cuda(self):
        self.mean, self.std = self.mean.cuda(), self.std.cuda()

    def cpu(self):
        self.mean, self.std = self.mean.cpu(), self.std.cpu()


# normalization, scaling by range
class RangeNormalizer(object):
    """Affine rescaling of each feature to ``[low, high]``.

    Uses the per-feature min/max over the sample axis (dim 0) of ``x``.

    Note: a feature that is constant across samples has ``max == min`` and
    produces an infinite scale factor.
    """

    def __init__(self, x, low=0.0, high=1.0):
        super(RangeNormalizer, self).__init__()
        mymin = torch.min(x, 0)[0].view(-1)
        mymax = torch.max(x, 0)[0].view(-1)

        # encode(v) = a*v + b maps mymin -> low and mymax -> high.
        self.a = (high - low)/(mymax - mymin)
        self.b = -self.a*mymax + high

    def encode(self, x):
        """Rescale ``x`` (any shape whose trailing dims match the fitted data)."""
        s = x.size()
        # reshape (not view) so non-contiguous inputs are accepted.
        x = x.reshape(s[0], -1)
        x = self.a*x + self.b
        x = x.reshape(s)
        return x

    def decode(self, x, sample_idx=None):
        """Invert ``encode``. ``sample_idx`` is accepted for API parity with the
        other normalizers and is ignored."""
        s = x.size()
        x = x.reshape(s[0], -1)
        x = (x - self.b)/self.a
        x = x.reshape(s)
        return x

    def cuda(self):
        self.a = self.a.cuda()
        self.b = self.b.cuda()

    def cpu(self):
        self.a = self.a.cpu()
        self.b = self.b.cpu()

#loss function with rel/abs Lp loss
class LpLoss(object):
    """L^p loss between batches of fields, per sample over flattened values.

    ``__call__`` returns the relative loss ||x - y||_p / ||y||_p. ``abs`` returns
    the absolute loss scaled by h^(d/p), with h = 1 / (x.size(1) - 1), the
    spacing of a uniform mesh. With ``reduction`` the per-sample values are
    averaged (``size_average``) or summed; otherwise they are returned as-is.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        # Dimension and Lp-norm type must be positive
        assert d > 0 and p > 0

        self.d, self.p = d, p
        self.reduction = reduction
        self.size_average = size_average

    def _reduce(self, per_sample):
        # Collapse per-sample losses according to the reduction flags.
        if not self.reduction:
            return per_sample
        return torch.mean(per_sample) if self.size_average else torch.sum(per_sample)

    def abs(self, x, y):
        batch = x.size()[0]

        # Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0)

        diff = x.view(batch, -1) - y.view(batch, -1)
        return self._reduce((h**(self.d/self.p)) * torch.norm(diff, self.p, 1))

    def rel(self, x, y):
        batch = x.size()[0]
        y_flat = y.reshape(batch, -1)
        diff_norms = torch.norm(x.reshape(batch, -1) - y_flat, self.p, 1)
        y_norms = torch.norm(y_flat, self.p, 1)
        return self._reduce(diff_norms / y_norms)

    def __call__(self, x, y):
        return self.rel(x, y)

# Sobolev norm (HS norm)
# where we also compare the numerical derivatives between the output and target
class HsLoss(object):
    """Relative Sobolev (H^k, k <= 2) loss for 2-D fields, evaluated in Fourier space.

    Inputs are viewed as (batch, nx, ny, -1) and FFT'd over the two spatial
    axes. Fourier coefficients are weighted by wavenumber magnitude so that
    derivative mismatches are also penalized.

    Args:
        d, p: dimension and norm order (must be positive).
        k: number of derivative orders to include (0, 1 or 2 are handled).
        a: per-order weights, default ``[1] * k``.
        group: when False, use one combined weight sqrt(1 + sum a_i^2 |k|^(2i));
            when True, average the separate relative losses of each order.
        size_average, reduction: reduction flags, as in ``rel``.
    """

    def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True):
        super(HsLoss, self).__init__()

        # Dimension and Lp-norm type must be positive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.k = k
        self.balanced = group
        self.reduction = reduction
        self.size_average = size_average

        if a is None:
            a = [1,] * k
        self.a = a

    def rel(self, x, y):
        """Per-sample relative p-norm error, reduced per the flags."""
        num_examples = x.size()[0]
        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)
        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms/y_norms)
            else:
                return torch.sum(diff_norms/y_norms)
        return diff_norms/y_norms

    def __call__(self, x, y, a=None):
        """Compute the loss; ``a`` overrides the per-order weights for this call only."""
        nx = x.size()[1]
        ny = x.size()[2]
        k = self.k
        balanced = self.balanced
        # Previously ``a`` was accepted but always overwritten with self.a.
        if a is None:
            a = self.a
        x = x.view(x.shape[0], nx, ny, -1)
        y = y.view(y.shape[0], nx, ny, -1)

        # Wavenumber magnitudes laid out to match the FFT ordering along each axis.
        k_x = torch.cat((torch.arange(start=0, end=nx//2, step=1),torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx,1).repeat(1,ny)
        k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1),torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1,ny).repeat(nx,1)
        k_x = torch.abs(k_x).reshape(1,nx,ny,1).to(x.device)
        k_y = torch.abs(k_y).reshape(1,nx,ny,1).to(x.device)

        x = torch.fft.fftn(x, dim=[1, 2])
        y = torch.fft.fftn(y, dim=[1, 2])

        if balanced == False:
            # Single combined Sobolev weight.
            weight = 1
            if k >= 1:
                weight += a[0]**2 * (k_x**2 + k_y**2)
            if k >= 2:
                weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4)
            weight = torch.sqrt(weight)
            loss = self.rel(x*weight, y*weight)
        else:
            # Average of the separate relative losses for orders 0..k.
            loss = self.rel(x, y)
            if k >= 1:
                weight = a[0] * torch.sqrt(k_x**2 + k_y**2)
                loss += self.rel(x*weight, y*weight)
            if k >= 2:
                weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4)
                loss += self.rel(x*weight, y*weight)
            loss = loss / (k+1)

        return loss

# print the number of parameters
def count_params(model):
    """Return the number of scalar parameters in ``model``.

    Complex parameters count twice (real and imaginary parts). ``numel()`` is
    used so 0-d (scalar) parameters count as 1. The previous
    ``reduce(operator.mul, shape)`` raised TypeError on their empty shape.
    """
    return sum(p.numel() * (2 if p.is_complex() else 1) for p in model.parameters())

In [6]:
# %%
""" Def: 1d Wavelet layer """
class WaveConv1d(nn.Module):
    """1-D wavelet integral layer: DWT -> learned channel mixing -> inverse DWT.

    The coarsest approximation coefficients and the coarsest detail
    coefficients are each mixed across channels with their own learned weights
    (one weight per coefficient position). Finer detail coefficients pass
    through unchanged.

    Args:
        in_channels, out_channels: channel counts.
        level: number of DWT decomposition levels (J).
        dummy: example input, passed once through the DWT to infer the number of
            coarsest coefficients (``modes1``). Presumably (batch, in_channels,
            length); NOTE(review): its length must match the signals later passed
            to ``forward``, or the weight shapes will not line up.
        wavelet: wavelet name for pytorch_wavelets (default 'db6').
    """

    def __init__(self, in_channels, out_channels, level, dummy, wavelet='db6'):
        super(WaveConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet = wavelet
        # Registered submodule: follows the layer on .to(device) and is reused in forward.
        self.dwt_ = DWT1D(wave=self.wavelet, J=self.level, mode='symmetric').to(dummy.device)
        self.mode_data, _ = self.dwt_(dummy)
        self.modes1 = self.mode_data.shape[-1]

        self.scale = (1 / (in_channels*out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1))

    # Convolution
    def mul1d(self, input, weights):
        # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        # Previously a new DWT/IDWT was built on every call and placed on the
        # module-level ``device`` rather than on ``x.device``.
        x_ft, x_coeff = self.dwt_(x)

        # Multiply the final low pass and high pass coefficients
        out_ft = self.mul1d(x_ft, self.weights1)
        x_coeff[-1] = self.mul1d(x_coeff[-1], self.weights2)

        # Reconstruct the signal on the same device as the input.
        idwt = IDWT1D(wave=self.wavelet, mode='symmetric').to(x.device)
        return idwt((out_ft, x_coeff))

""" The forward operation """
class WNO1d(nn.Module):
    """1-D Wavelet Neural Operator.

    Pipeline: lift with fc0 -> 4 x (WaveConv1d + 1x1 Conv1d, GELU between) ->
    project with fc1/fc2.

    ``forward`` appends a grid coordinate in [0, 1] along the last dim, so the
    input's last dim must be ``in_features`` (``fc0`` takes ``in_features + 1``).
    The output's last dim is 1.

    NOTE(review): ``__init__`` builds its wavelet layers from the module-level
    global ``dummy_data``, not from the ``x_train`` argument. ``x_train`` is
    accepted but never read, and ``hidden_dim`` is also unused. The class
    therefore only works while a global ``dummy_data`` exists.
    TODO confirm the intended shape and switch to ``x_train`` once callers
    pass a correctly shaped tensor.
    """
    def __init__(self,in_features, hidden_dim=256, width=64, level=4, x_train=None):
        super(WNO1d, self).__init__()

        """
        The WNO network. It contains 4 layers of the Wavelet integral layer.
        1. Lift the input using v(x) = self.fc0 .
        2. 4 layers of the integral operators v(+1) = g(K(.) + W)(v).
            W is defined by self.w_; K is defined by self.conv_.
        3. Project the output of last layer using self.fc1 and self.fc2.

        input: the solution of the initial condition and location (a(x), x)
        input shape: (batchsize, x=s, c=2)
        output: the solution of a later timestep
        output shape: (batchsize, x=s, c=1)
        """
        self.in_features = in_features
        self.level = level
        self.width = width
        # NOTE(review): reads the global `dummy_data`, not the `x_train` parameter.
        self.dummy_data = dummy_data
        self.padding = 2 # pad the domain when required (padding is commented out in forward)
        # in_features input values plus one appended grid coordinate
        # (the "2" below only holds when in_features == 1).
        self.fc0 = nn.Linear(in_features+1, self.width) # input channel is 2: (a(x), x)

        # Wavelet integral operators (K) and pointwise linear paths (W).
        self.conv0 = WaveConv1d(self.width, self.width, self.level, self.dummy_data)
        self.conv1 = WaveConv1d(self.width, self.width, self.level, self.dummy_data)
        self.conv2 = WaveConv1d(self.width, self.width, self.level, self.dummy_data)
        self.conv3 = WaveConv1d(self.width, self.width, self.level, self.dummy_data)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)

        # Projection back to one output value per position.
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Apply the operator to a 3-D tensor whose last dim is ``in_features``."""
        # Grid of shape (batch, x.shape[1], 1) appended as an extra feature.
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        # Channels-first for the Conv1d / wavelet layers.
        x = x.permute(0, 2, 1)
        # x = F.pad(x, [0,self.padding])

        # Layer 1: wavelet operator + 1x1 linear path, then GELU.
        x1 = self.conv0(x)
        x2 = self.w0(x)
        x = x1 + x2
        x = F.gelu(x)

        # Layer 2
        x1 = self.conv1(x)
        x2 = self.w1(x)
        x = x1 + x2
        x = F.gelu(x)

        # Layer 3
        x1 = self.conv2(x)
        x2 = self.w2(x)
        x = x1 + x2
        x = F.gelu(x)

        # Layer 4 (no activation before the projection).
        x1 = self.conv3(x)
        x2 = self.w3(x)
        x = x1 + x2

        # x = x[..., :-self.padding]
        # Back to channels-last and project to a single output channel.
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """Return a (batch, shape[1], 1) tensor of evenly spaced points in [0, 1] on ``device``."""
        # The grid of the solution
        batchsize, size_x = shape[0], shape[1]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
        return gridx.to(device)


In [7]:
class RealNVPBlock(nn.Module):
    def __init__(self, in_features, hidden_dim=256, width=64, level=4, x_train=None):
        super(RealNVPBlock, self).__init__()
        self.in_features = in_features
        self.hidden_dim = hidden_dim

        # Define the scale and translation networks
        self.scale_net = WNO1d(in_features, hidden_dim, width, level, x_train.permute(0, 2, 1))
        self.translation_net = WNO1d(in_features, hidden_dim,width, level, x_train.permute(0, 2, 1))

    def forward(self, x1, x2, c1):
        # Ensure x1 has the correct shape
        x1 = x1.unsqueeze(1)  # Add a channel dimension
        x1_concat_c1 = torch.cat((x1, c1.unsqueeze(1)), dim=-1)
        # Compute scale and translation parameters
        s = self.scale_net(x1)
        t = self.translation_net(x1)

        # Apply scale and translation to x2
        x2_transformed = x2 * torch.exp(s) + t

        # Return transformed x1 and x2
        return x1.squeeze(1), x2_transformed

    def inverse(self, y1, y2, c4):
        # Ensure the sizes of tensors match for concatenation
        if y1.size(-1) != c4.size(-1):
            # Expand c4 to match the shape of y1
            c4_expanded = c4.unsqueeze(1).expand(-1, y1.size(1), -1)
        else:
            c4_expanded = c4.unsqueeze(1)

        # Concatenate y1 with c4
        y1_concat_c4 = torch.cat((y1, c4_expanded), dim=-1)

        # Define the scale and translation networks for the inverse pass
        scale_net_inv = WNO1d(y1_concat_c4.size(-1), self.hidden_dim, width=64, level=4, x_train=y1_concat_c4)
        translation_net_inv = WNO1d(y1_concat_c4.size(-1), self.hidden_dim, width=64, level=4, x_train=y1_concat_c4)

        # Compute inverse scale and translation parameters
        inv_s = scale_net_inv(y1_concat_c4)
        inv_t = translation_net_inv(y1_concat_c4)

        # Apply inverse scale and translation to y2
        x2_inverse = (y2 - inv_t) * torch.exp(-inv_s)

        # Return transformed x1 and x2
        return y1.squeeze(1), x2_inverse


In [8]:
import torch
import torch.nn as nn

class RealNVPBlock_fnn(nn.Module):
    def __init__(self, in_features, c4_features, hidden_dim=256):
        super(RealNVPBlock_fnn, self).__init__()
        self.in_features = in_features
        self.c4_features = c4_features
        self.hidden_dim = hidden_dim

        # Define the scale and translation networks
        self.scale_net = nn.Sequential(
            nn.Linear(in_features + c4_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, in_features)
        )

        self.translation_net = nn.Sequential(
            nn.Linear(in_features + c4_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, in_features)
        )

    def forward(self, x1, x2, c4):
        # Ensure x1 has the correct shape
        x1 = x1.unsqueeze(1)  # Add a channel dimension

        # Concatenate x1 with c4
        x1_concat_c4 = torch.cat((x1, c4.unsqueeze(1)), dim=-1)

        # Compute scale and translation parameters
        s = self.scale_net(x1_concat_c4)
        t = self.translation_net(x1_concat_c4)

        # Apply scale and translation to x2
        x2_transformed = x2 * torch.exp(s) + t

        # Return transformed x1 and x2
        return x1.squeeze(1), x2_transformed

    def inverse(self, y_data, c4):
        # Ensure y_data has the correct shape
        c4_reshaped = c4.repeat(1, y_data.shape[1], 1)

        # Concatenate y_data with c4
        y_data_concat_c4 = torch.cat((y_data, c4_reshaped), dim=-1)

        # Flatten the concatenated tensor
        y_data_concat_flat = y_data_concat_c4.view(y_data_concat_c4.size(0), -1)

        # Define the scale and translation networks for the inverse pass
        scale_net_inv = nn.Sequential(
            nn.Linear(y_data_concat_flat.shape[1], self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, 256)  # Output size should match the last dimension of y_data
        )

        translation_net_inv = nn.Sequential(
            nn.Linear(y_data_concat_flat.shape[1], self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, 256)  # Output size should match the last dimension of y_data
        )

        # Compute inverse scale and translation parameters
        inv_s = scale_net_inv(y_data_concat_flat)
        inv_t = translation_net_inv(y_data_concat_flat)

        # print(inv_s.shape)
        # Reshape inv_s and inv_t to match the shape of y_data
        inv_s = inv_s.view(y_data.shape[0], y_data.shape[1], -1)
        inv_t = inv_t.view(y_data.shape[0], y_data.shape[1], -1)
        # print("inv_s",inv_s)
        # print("Shapes - y_data:", y_data.shape, "inv_s:", inv_s.shape, "inv_t:", inv_t.shape)

        # Apply inverse scale and translation to y_data
        x2_inverse = (y_data - inv_t) * torch.exp(-inv_s)

        # Return transformed x1 and x2
        return y_data.squeeze(1), x2_inverse


In [None]:
# Training loop with inverse pass and customized loss calculation
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Convert NumPy arrays to PyTorch tensors
x_data = torch.tensor(x_data, dtype=torch.float).unsqueeze(-1)
x_data = x_data.view(1, 4096, 1)
y_data = torch.tensor(y_data, dtype=torch.float)
# print(x_data.shape)
# Reshape the tensor from (1, 4096, 1) to (1, 4096)
x_data_reshaped = x_data.view(1, 4096)

# Split the tensor into two tensors of shapes (1, 2048) each
x1 = x_data_reshaped[:, :2048]
x2 = x_data_reshaped[:, 2048:]

# Printing shapes for verification
# print("Shape of x1:", x1.shape)
# print("Shape of x2:", x2.shape)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.utils import clip_grad_norm_
import torch.nn.functional as F

# Convert NumPy arrays to PyTorch tensors
x_data = torch.tensor(x_data, dtype=torch.float).unsqueeze(-1)
y_data = torch.tensor(y_data, dtype=torch.float)


x_data_res = x_data.squeeze(-1)
dummy_data = x_data_res
train_dataset = TensorDataset(x_data, y_data)
batch_size = 16  # Adjust batch size as needed
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Define your RealNVP blocks
block1 = RealNVPBlock(in_features=2048, x_train=x_data_res)
block2 = RealNVPBlock(in_features=2048, x_train=x_data_res)
block3 = RealNVPBlock(in_features=2048, x_train=x_data_res)
block4 = RealNVPBlock(in_features=2048, x_train=x_data_res)
block4_inv = RealNVPBlock_fnn(in_features=768, c4_features=128)

# Define optimizer and loss function
optimizer1 = optim.Adam(block1.parameters(), lr=1e-4)  # Adjust learning rate as needed
optimizer2 = optim.Adam(block2.parameters(), lr=1e-4)
optimizer3 = optim.Adam(block3.parameters(), lr=1e-4)
optimizer4 = optim.Adam(block4_inv.parameters(), lr=1e-4)
loss_function = nn.SmoothL1Loss()

# Learning rate scheduler
scheduler1 = ReduceLROnPlateau(optimizer1, mode='min', patience=5, factor=0.1, verbose=True)
scheduler2 = ReduceLROnPlateau(optimizer2, mode='min', patience=5, factor=0.1, verbose=True)
scheduler3 = ReduceLROnPlateau(optimizer3, mode='min', patience=5, factor=0.1, verbose=True)
scheduler4 = ReduceLROnPlateau(optimizer4, mode='min', patience=5, factor=0.1, verbose=True)

# Regularization
weight_decay = 1e-5
reg1 = torch.optim.lr_scheduler.LambdaLR(optimizer1, lambda epoch: 1 / (1 + weight_decay * epoch))
reg2 = torch.optim.lr_scheduler.LambdaLR(optimizer2, lambda epoch: 1 / (1 + weight_decay * epoch))
reg3 = torch.optim.lr_scheduler.LambdaLR(optimizer3, lambda epoch: 1 / (1 + weight_decay * epoch))
reg4 = torch.optim.lr_scheduler.LambdaLR(optimizer4, lambda epoch: 1 / (1 + weight_decay * epoch))

# Training loop with inverse pass and customized loss calculation
num_epochs = 500  # Adjust number of epochs as needed
log_interval = 100
output_losses = []

for epoch in range(num_epochs):
    block1.train()
    block2.train()
    block3.train()
    block4.train()
    total_output_loss = 0

    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
        # Forward pass
        x1_out, x2_out = block1(x1, x2, c1)
        x1_out_2, x2_out_2 = block2(x1_out, x2_out, c2)
        x1_out_3, x2_out_3 = block3(x1_out_2, x2_out_2, c3)
        x1_out_4, x2_out_4 = block4(x1_out_3, x2_out_3, c4)
        y1_inverse, y2_inverse = block4_inv.inverse(y_data, c4)
        y1_inverse_3, y2_inverse_3 = block3.inverse(y1_inverse, y2_inverse, c3)
        y1_inverse_2, y2_inverse_2 = block2.inverse(y1_inverse_3, y2_inverse_3, c2)
        y1_inverse_1, y2_inverse_1 = block1.inverse(y1_inverse_2, y2_inverse_2, c1)

        # Concatenate y1_inverse_1 and y2_inverse_1 along the first dimension
        concatenated_inverse = torch.cat((y1_inverse_1, y2_inverse_1), dim=1)

        # Compute output loss
        output_loss = loss_function(concatenated_inverse, x_batch[:, 2048:])
        total_output_loss += output_loss.item()

        # Backpropagation
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        optimizer4.zero_grad()
        output_loss.backward(retain_graph=True)

        # # Gradient clipping
        # clip_grad_norm_(block1.parameters(), max_norm=1)
        # clip_grad_norm_(block2.parameters(), max_norm=1)
        # clip_grad_norm_(block3.parameters(), max_norm=1)
        # clip_grad_norm_(block4.parameters(), max_norm=1)
        # clip_grad_norm_(block4_inv.parameters(), max_norm=1)

        optimizer1.step()
        optimizer2.step()
        optimizer3.step()
        optimizer4.step()
        plt.figure(figsize=(10,5))

        # Iterate over each batch
        for i in range(min(x_batch.size(0), concatenated_inverse.size(0))):
            m = x_batch[i, 2048:2048+512].detach().numpy().reshape(-1)
            n = (concatenated_inverse[i, :512].detach().numpy().reshape(-1))
            plt.plot(m[300:], label=f'Actual Input Batch {i}')
            plt.plot(n[300:], label=f'Concatenated Inverse Batch {i}')

        plt.title('Comparison between Actual Input and Concatenated Inverse')
        plt.xlabel('Index')
        plt.ylabel('Value')
        plt.legend()
        plt.grid(True)

        # plt.xlim(0, 512)  # Set the limits for the x-axis (index)
        # plt.ylim(-5, 5)

        plt.show()

    # Learning rate scheduling
    scheduler1.step(total_output_loss)
    scheduler2.step(total_output_loss)
    scheduler3.step(total_output_loss)
    scheduler4.step(total_output_loss)

    # Regularization
    reg1.step()
    reg2.step()
    reg3.step()
    reg4.step()



    output_losses.append(total_output_loss / len(train_loader))
    print('Epoch {} Average Output Loss: {:.6f}'.format(epoch+1, output_losses[-1]))
