In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from neuralop.models import TFNO, UNO
from neuralop.training.trainer import Trainer
from neuralop.utils import count_model_params
from neuralop.losses.data_losses import LpLoss, H1Loss
import pdb
import torch.nn as nn
import sys
from astropy.io import fits
from torch.utils.data.dataset import Subset
import random
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Subset
from torchvision.transforms.functional import normalize
from tqdm import tqdm
from torch.utils.data import random_split
import h5py

# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Echo the selected device so the notebook output records where the run executed.
print(device)

In [10]:
import torch.nn.functional as F

from neuralop.layers.mlp import MLP
from neuralop.layers.normalization_layers import AdaIN
from neuralop.layers.skip_connections import skip_connection
from neuralop.layers.spectral_convolution import SpectralConv
from neuralop.utils import validate_scaling_factor

class U_net(nn.Module):
    """Small three-level convolutional encoder/decoder with skip connections.

    The encoder applies three stride-2 convolution stages (the second and
    third each followed by a stride-1 convolution stage).  The decoder applies
    three transposed convolutions that each double the spatial size, and the
    output of every decoder stage is concatenated along the channel axis with
    the encoder feature map (or the raw input) of matching resolution.  A
    final convolution maps the last concatenation to ``output_channels``.

    Because each decoder stage exactly doubles the resolution while each
    stride-2 stage halves it (rounding up for odd ``kernel_size``), the
    concatenations only line up when both spatial sizes are divisible by 8.

    Parameters
    ----------
    input_channels : int
        Number of channels of the input tensor.
    output_channels : int
        Number of channels of every intermediate feature map and of the
        returned tensor.
    kernel_size : int
        Kernel size of the encoder convolutions and of the output convolution;
        padding is ``(kernel_size - 1) // 2``.
    dropout_rate : float
        Dropout probability used after every encoder convolution stage.
    """

    def __init__(self, input_channels, output_channels, kernel_size, dropout_rate):
        super(U_net, self).__init__()
        self.input_channels = input_channels
        self.conv1 = self.conv(input_channels, output_channels, kernel_size=kernel_size, stride=2, dropout_rate=dropout_rate)
        self.conv2 = self.conv(output_channels, output_channels, kernel_size=kernel_size, stride=2, dropout_rate=dropout_rate)
        self.conv2_1 = self.conv(output_channels, output_channels, kernel_size=kernel_size, stride=1, dropout_rate=dropout_rate)
        self.conv3 = self.conv(output_channels, output_channels, kernel_size=kernel_size, stride=2, dropout_rate=dropout_rate)
        self.conv3_1 = self.conv(output_channels, output_channels, kernel_size=kernel_size, stride=1, dropout_rate=dropout_rate)

        # deconv1/deconv0 consume a concatenation of two `output_channels`
        # feature maps, hence the doubled input width.
        self.deconv2 = self.deconv(output_channels, output_channels)
        self.deconv1 = self.deconv(output_channels*2, output_channels)
        self.deconv0 = self.deconv(output_channels*2, output_channels)

        # The last concatenation joins the raw input (`input_channels`) with
        # the deconv0 output (`output_channels`).  Sizing the layer from both
        # counts keeps the network usable when they differ; when they are
        # equal this is the same width as `output_channels * 2`.
        self.output_layer = self.output(input_channels + output_channels, output_channels, kernel_size=kernel_size, stride=1, dropout_rate=dropout_rate)

    def forward(self, x):
        """Run the encoder/decoder and return a tensor with ``output_channels`` channels."""
        # Encoder: resolution 1/2, 1/4, 1/8 of the input.
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2_1(self.conv2(out_conv1))
        out_conv3 = self.conv3_1(self.conv3(out_conv2))

        # Decoder: upsample and concatenate with the matching encoder output.
        out_deconv2 = self.deconv2(out_conv3)
        concat2 = torch.cat((out_conv2, out_deconv2), 1)
        out_deconv1 = self.deconv1(concat2)
        concat1 = torch.cat((out_conv1, out_deconv1), 1)
        out_deconv0 = self.deconv0(concat1)

        # Full-resolution skip from the raw input.
        concat0 = torch.cat((x, out_deconv0), 1)
        out = self.output_layer(concat0)
        return out

    def conv(self, in_planes, output_channels, kernel_size, stride, dropout_rate):
        """Build one encoder stage: Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.1) -> Dropout."""
        return nn.Sequential(
            nn.Conv2d(in_planes, output_channels, kernel_size=kernel_size,
                      stride=stride, padding=(kernel_size - 1) // 2, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(dropout_rate)
        )

    def deconv(self, input_channels, output_channels):
        """Build one decoder stage: ConvTranspose2d (kernel 4, stride 2, padding 1) -> LeakyReLU(0.1)."""
        return nn.Sequential(
            nn.ConvTranspose2d(input_channels, output_channels, kernel_size=4,
                               stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True)
        )

    def output(self, input_channels, output_channels, kernel_size, stride, dropout_rate):
        """Build the final plain convolution.

        ``dropout_rate`` is accepted for signature symmetry with ``conv`` but
        is not used: the output layer has no dropout.
        """
        return nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size,
                         stride=stride, padding=(kernel_size - 1) // 2)

class FNOBlockswithUnet(nn.Module):
    """Stack of FNO layers whose output is augmented by a convolutional U-Net branch.

    Each layer ``index`` computes ``spectral_conv(x) + unet(x) + skip(x)``,
    optionally followed by a channel MLP with its own skip connection and by
    normalisation layers.  A single ``U_net`` instance is created and reused
    by every layer of the stack.

    The U-Net branch is built from 2D convolutions, so only two-dimensional
    ``n_modes`` are supported.  The U-Net output is added to the spectral
    output without any resizing.
    NOTE(review): this presumably requires ``output_scaling_factor`` to leave
    the resolution unchanged - confirm before using a scaling factor.

    Parameters
    ----------
    in_channels, out_channels : int
        Channel counts of the block input and output.
    n_modes : int or sequence of int
        Fourier modes per dimension; an int is wrapped into a one-element list.
    output_scaling_factor : optional
        Validated with ``validate_scaling_factor`` and also forwarded as given
        to ``SpectralConv``.
    n_layers : int
        Number of layers held by this block.
    max_n_modes, rank, fixed_rank_modes, implementation, separable,
    factorization, joint_factorization :
        Forwarded to ``SpectralConv``.
    decomposition_kwargs : dict, optional
        Forwarded to ``SpectralConv``; ``None`` means an empty dict.
    fno_block_precision, fft_norm :
        Stored on the instance only; not otherwise used by this class.
    use_mlp : bool
        Whether each layer gets an ``MLP`` plus an MLP skip connection.
    mlp_dropout, mlp_expansion :
        MLP dropout and hidden width factor
        (``round(out_channels * mlp_expansion)`` hidden channels).
    non_linearity : callable
        Activation function applied between layers.
    stabilizer : str or None
        When ``"tanh"``, ``tanh`` is applied before the spectral/U-Net branches.
    norm : {None, "instance_norm", "group_norm", "ada_in"}
        Kind of normalisation layer to create.
    ada_in_features : optional
        Passed to ``AdaIN`` when ``norm == "ada_in"``.
    preactivation : bool
        Selects ``forward_with_preactivation`` over
        ``forward_with_postactivation``.
    fno_skip, mlp_skip : str
        ``skip_type`` of the respective skip connections.
    SpectralConv : class
        Spectral convolution implementation to instantiate.
    unet_kernel_size, unet_dropout_rate :
        Kernel size and dropout rate handed to ``U_net``.
    **kwargs
        Accepted and ignored.

    Raises
    ------
    ValueError
        If ``n_modes`` is not two-dimensional or ``norm`` is not recognised.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_modes,
        output_scaling_factor=None,
        n_layers=1,
        max_n_modes=None,
        fno_block_precision="full",
        use_mlp=False,
        mlp_dropout=0,
        mlp_expansion=0.5,
        non_linearity=F.gelu,
        stabilizer=None,
        norm=None,
        ada_in_features=None,
        preactivation=False,
        fno_skip="linear",
        mlp_skip="soft-gating",
        separable=False,
        factorization=None,
        rank=1.0,
        SpectralConv=SpectralConv,
        joint_factorization=False,
        fixed_rank_modes=False,
        implementation="factorized",
        decomposition_kwargs=None,
        fft_norm="forward",
        unet_kernel_size=3,
        unet_dropout_rate=0.1,
        **kwargs,
    ):
        super().__init__()
        if isinstance(n_modes, int):
            n_modes = [n_modes]
        self._n_modes = n_modes
        self.n_dim = len(n_modes)

        # The U-Net branch uses Conv2d/BatchNorm2d, which cannot process 1D or
        # 3D feature maps; fail at construction instead of deep inside forward.
        if self.n_dim != 2:
            raise ValueError(
                f"FNOBlockswithUnet only supports 2D inputs because of its "
                f"U-Net branch, but got n_modes={n_modes}"
            )

        # A fresh dict per instance instead of a shared mutable default.
        if decomposition_kwargs is None:
            decomposition_kwargs = {}

        self.output_scaling_factor = validate_scaling_factor(output_scaling_factor, self.n_dim, n_layers)

        self.max_n_modes = max_n_modes
        self.fno_block_precision = fno_block_precision
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_layers = n_layers
        self.joint_factorization = joint_factorization
        self.non_linearity = non_linearity
        self.stabilizer = stabilizer
        self.rank = rank
        self.factorization = factorization
        self.fixed_rank_modes = fixed_rank_modes
        self.decomposition_kwargs = decomposition_kwargs
        self.fno_skip = fno_skip
        self.mlp_skip = mlp_skip
        self.use_mlp = use_mlp
        self.mlp_expansion = mlp_expansion
        self.mlp_dropout = mlp_dropout
        self.fft_norm = fft_norm
        self.implementation = implementation
        self.separable = separable
        self.preactivation = preactivation
        self.ada_in_features = ada_in_features

        # One SpectralConv object holds the weights of all `n_layers` layers;
        # individual layers are addressed by index in forward.
        self.convs = SpectralConv(
            self.in_channels,
            self.out_channels,
            self.n_modes,
            output_scaling_factor=output_scaling_factor,
            max_n_modes=max_n_modes,
            rank=rank,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            separable=separable,
            factorization=factorization,
            decomposition_kwargs=decomposition_kwargs,
            joint_factorization=joint_factorization,
            n_layers=n_layers,
        )

        self.fno_skips = nn.ModuleList(
            [
                skip_connection(
                    self.in_channels,
                    self.out_channels,
                    skip_type=fno_skip,
                    n_dim=self.n_dim,
                )
                for _ in range(n_layers)
            ]
        )

        if use_mlp:
            self.mlp = nn.ModuleList(
                [
                    MLP(
                        in_channels=self.out_channels,
                        hidden_channels=round(self.out_channels * mlp_expansion),
                        dropout=mlp_dropout,
                        n_dim=self.n_dim,
                    )
                    for _ in range(n_layers)
                ]
            )
            self.mlp_skips = nn.ModuleList(
                [
                    skip_connection(
                        self.in_channels,
                        self.out_channels,
                        skip_type=mlp_skip,
                        n_dim=self.n_dim,
                    )
                    for _ in range(n_layers)
                ]
            )
        else:
            self.mlp = None

        # Add U-Net (a single instance, shared by every layer index)
        self.unet = U_net(self.in_channels, self.out_channels, unet_kernel_size, unet_dropout_rate)

        # Each block will have 2 norms if we also use an MLP
        self.n_norms = 1 if self.mlp is None else 2
        if norm is None:
            self.norm = None
        elif norm == "instance_norm":
            self.norm = nn.ModuleList(
                [
                    getattr(nn, f"InstanceNorm{self.n_dim}d")(
                        num_features=self.out_channels
                    )
                    for _ in range(n_layers * self.n_norms)
                ]
            )
        elif norm == "group_norm":
            self.norm = nn.ModuleList(
                [
                    nn.GroupNorm(num_groups=1, num_channels=self.out_channels)
                    for _ in range(n_layers * self.n_norms)
                ]
            )
        elif norm == "ada_in":
            self.norm = nn.ModuleList(
                [
                    AdaIN(ada_in_features, out_channels)
                    for _ in range(n_layers * self.n_norms)
                ]
            )
        else:
            raise ValueError(
                f"Got norm={norm} but expected None or one of "
                "[instance_norm, group_norm, ada_in]"
            )

    def forward(self, x, index=0, output_shape=None):
        """Apply layer ``index`` to ``x``, dispatching on ``self.preactivation``."""
        if self.preactivation:
            return self.forward_with_preactivation(x, index, output_shape)
        else:
            return self.forward_with_postactivation(x, index, output_shape)

    def forward_with_postactivation(self, x, index=0, output_shape=None):
        """Layer ``index`` with normalisation/activation applied after the branches are summed."""
        # Skip branches are computed from the unmodified input and passed
        # through the spectral conv's `transform` with the requested shape.
        x_skip_fno = self.fno_skips[index](x)
        x_skip_fno = self.convs[index].transform(x_skip_fno, output_shape=output_shape)

        if self.mlp is not None:
            x_skip_mlp = self.mlp_skips[index](x)
            x_skip_mlp = self.convs[index].transform(x_skip_mlp, output_shape=output_shape)

        if self.stabilizer == "tanh":
            x = torch.tanh(x)

        # Spectral and U-Net branches see the same (possibly tanh-ed) input.
        x_fno = self.convs(x, index, output_shape=output_shape)
        x_unet = self.unet(x)

        # Only the spectral branch is normalised before the sum.
        if self.norm is not None:
            x_fno = self.norm[self.n_norms * index](x_fno)

        x = x_fno + x_unet + x_skip_fno

        # No activation after the very last layer unless an MLP still follows.
        if (self.mlp is not None) or (index < (self.n_layers - 1)):
            x = self.non_linearity(x)

        if self.mlp is not None:
            x = self.mlp[index](x) + x_skip_mlp

            if self.norm is not None:
                x = self.norm[self.n_norms * index + 1](x)

            if index < (self.n_layers - 1):
                x = self.non_linearity(x)

        return x

    def forward_with_preactivation(self, x, index=0, output_shape=None):
        """Layer ``index`` with activation/normalisation applied before the branches."""
        # Pre-activate
        x = self.non_linearity(x)

        if self.norm is not None:
            x = self.norm[self.n_norms * index](x)

        x_skip_fno = self.fno_skips[index](x)
        x_skip_fno = self.convs[index].transform(x_skip_fno, output_shape=output_shape)

        if self.mlp is not None:
            x_skip_mlp = self.mlp_skips[index](x)
            x_skip_mlp = self.convs[index].transform(x_skip_mlp, output_shape=output_shape)

        if self.stabilizer == "tanh":
            x = torch.tanh(x)

        x_fno = self.convs(x, index, output_shape=output_shape)
        x_unet = self.unet(x)
        x = x_fno + x_unet + x_skip_fno

        if self.mlp is not None:
            if index < (self.n_layers - 1):
                x = self.non_linearity(x)

            if self.norm is not None:
                x = self.norm[self.n_norms * index + 1](x)

            x = self.mlp[index](x) + x_skip_mlp

        return x

    @property
    def n_modes(self):
        """Fourier modes per dimension currently configured for the block."""
        return self._n_modes

    @n_modes.setter
    def n_modes(self, n_modes):
        # Keep the spectral convolution in sync with the block-level value.
        self.convs.n_modes = n_modes
        self._n_modes = n_modes

    def get_block(self, indices):
        """Return a ``SubModule`` bound to ``indices``.

        Raises
        ------
        ValueError
            If the block holds a single layer; use the block itself instead.
        """
        if self.n_layers == 1:
            raise ValueError(
                "A single layer is parametrized, directly use the main class."
            )
        return SubModule(self, indices)

    def __getitem__(self, indices):
        """Indexing shorthand for ``get_block``."""
        return self.get_block(indices)

class SubModule(nn.Module):
    """Thin wrapper that pins a multi-layer module to fixed layer indices.

    Calling the wrapper forwards the input to ``main_module.forward`` together
    with the stored ``indices``; no parameters are copied.
    """

    def __init__(self, main_module, indices):
        nn.Module.__init__(self)
        self.indices = indices
        self.main_module = main_module

    def forward(self, x):
        """Delegate to the wrapped module's ``forward`` with the stored indices."""
        wrapped_forward = self.main_module.forward
        return wrapped_forward(x, self.indices)

In [11]:
from neuralop.layers.spectral_convolution import SpectralConv
from neuralop.layers.padding import DomainPadding
from neuralop.layers.fno_block import FNOBlocks
from neuralop.layers.mlp import MLP
from neuralop.models.base_model import BaseModel

class CustomFNO(BaseModel, name='CustomFNO'):
    """FNO whose second half of layers carries an additional U-Net branch.

    The model is ``lifting MLP -> FNOBlocks -> FNOBlockswithUnet ->
    projection MLP``.  The ``n_layers`` layers are split between the two
    stacks: ``n_layers // 2`` plain FNO layers followed by
    ``n_layers - n_layers // 2`` FNO+U-Net layers.

    Parameters
    ----------
    n_modes : sequence of int
        Fourier modes per dimension; its length defines the dimensionality.
    hidden_channels : int
        Width of both FNO stacks.
    in_channels, out_channels : int
        Channels of the model input and output.
    lifting_channels, projection_channels : int
        Hidden widths of the lifting and projection MLPs (2 layers each).
    n_layers : int
        Total number of FNO layers across both stacks.
    decomposition_kwargs : dict, optional
        Forwarded to both stacks; ``None`` means an empty dict.
    domain_padding, domain_padding_mode :
        Accepted for signature compatibility but NOT applied by this model;
        a warning is emitted if ``domain_padding`` is requested.
    unet_kernel_size, unet_dropout_rate :
        Forwarded to ``FNOBlockswithUnet``.
    output_scaling_factor, max_n_modes, fno_block_precision, use_mlp,
    mlp_dropout, mlp_expansion, non_linearity, stabilizer, norm,
    preactivation, fno_skip, mlp_skip, separable, factorization, rank,
    joint_factorization, fixed_rank_modes, implementation, fft_norm,
    SpectralConv, **kwargs :
        Forwarded unchanged to both ``FNOBlocks`` and ``FNOBlockswithUnet``.
    """

    def __init__(
        self,
        n_modes,
        hidden_channels,
        in_channels=3,
        out_channels=1,
        lifting_channels=256,
        projection_channels=256,
        n_layers=4,
        output_scaling_factor=None,
        max_n_modes=None,
        fno_block_precision="full",
        use_mlp=False,
        mlp_dropout=0,
        mlp_expansion=0.5,
        non_linearity=F.gelu,
        stabilizer=None,
        norm=None,
        preactivation=False,
        fno_skip="linear",
        mlp_skip="soft-gating",
        separable=False,
        factorization=None,
        rank=1.0,
        joint_factorization=False,
        fixed_rank_modes=False,
        implementation="factorized",
        decomposition_kwargs=None,
        domain_padding=None,
        domain_padding_mode="one-sided",
        fft_norm="forward",
        SpectralConv=SpectralConv,
        unet_kernel_size=3,
        unet_dropout_rate=0.1,
        **kwargs
    ):
        super().__init__()
        self.n_dim = len(n_modes)
        self.n_layers = n_layers

        # A fresh dict per instance instead of a shared mutable default.
        if decomposition_kwargs is None:
            decomposition_kwargs = {}

        # Domain padding is not implemented in this model; make that visible
        # instead of silently discarding the request.
        if domain_padding is not None:
            import warnings

            warnings.warn(
                "CustomFNO ignores domain_padding and domain_padding_mode; "
                "no padding is applied.",
                UserWarning,
                stacklevel=2,
            )

        # Define the lifting layers
        self.lifting1 = MLP(
            in_channels=in_channels,
            out_channels=hidden_channels,
            hidden_channels=lifting_channels,
            n_layers=2,
            n_dim=self.n_dim,
        )

        # Define the FNO blocks
        self.fno_blocks = FNOBlocks(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            n_modes=n_modes,
            output_scaling_factor=output_scaling_factor,
            use_mlp=use_mlp,
            mlp_dropout=mlp_dropout,
            mlp_expansion=mlp_expansion,
            non_linearity=non_linearity,
            stabilizer=stabilizer,
            norm=norm,
            preactivation=preactivation,
            fno_skip=fno_skip,
            mlp_skip=mlp_skip,
            max_n_modes=max_n_modes,
            fno_block_precision=fno_block_precision,
            rank=rank,
            fft_norm=fft_norm,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            separable=separable,
            factorization=factorization,
            decomposition_kwargs=decomposition_kwargs,
            joint_factorization=joint_factorization,
            SpectralConv=SpectralConv,
            n_layers=n_layers // 2,  # Use half the layers for FNOBlocks
            **kwargs
        )

        self.fno_blocks_unet = FNOBlockswithUnet(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            n_modes=n_modes,
            output_scaling_factor=output_scaling_factor,
            use_mlp=use_mlp,
            mlp_dropout=mlp_dropout,
            mlp_expansion=mlp_expansion,
            non_linearity=non_linearity,
            stabilizer=stabilizer,
            norm=norm,
            preactivation=preactivation,
            fno_skip=fno_skip,
            mlp_skip=mlp_skip,
            max_n_modes=max_n_modes,
            fno_block_precision=fno_block_precision,
            rank=rank,
            fft_norm=fft_norm,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            separable=separable,
            factorization=factorization,
            decomposition_kwargs=decomposition_kwargs,
            joint_factorization=joint_factorization,
            SpectralConv=SpectralConv,
            n_layers=n_layers - (n_layers // 2),  # Use remaining layers for FNOBlockswithUnet
            unet_kernel_size=unet_kernel_size,
            unet_dropout_rate=unet_dropout_rate,
            **kwargs
        )

        # Define the projection layer
        self.projection = MLP(
            in_channels=hidden_channels,
            out_channels=out_channels,
            hidden_channels=projection_channels,
            n_layers=2,
            n_dim=self.n_dim,
            non_linearity=non_linearity,
        )

    def forward(self, x):
        """Forward pass for the Custom FNO model.

        Parameters
        ----------
        x : tensor
            Input tensor of shape [batch_size, in_channels, H, W]
            (the original author's note gives H = W = 504).

        Returns
        -------
        tensor
            Output of the projection MLP, with ``out_channels`` channels.
        """

        x = self.lifting1(x)  # Result shape: [batch_size, hidden_channels, H, W]

        # Apply FNO blocks for the first half of layers
        for layer_idx in range(self.n_layers // 2):
            x = self.fno_blocks(x, layer_idx)

        # Apply FNOBlockswithUnet for the second half of layers
        for layer_idx in range(self.n_layers - (self.n_layers // 2)):
            x = self.fno_blocks_unet(x, layer_idx)

        x = self.projection(x)

        return x


In [12]:
class CustomDataset(Dataset):
    """Dataset backed by one HDF5 file with a group per sample.

    Sample ``idx`` is read from ``sample_{idx}/input`` and
    ``sample_{idx}/output``.  The dataset length is the number of top-level
    entries in the file.
    NOTE(review): this assumes every top-level entry is a sample group named
    ``sample_0`` ... ``sample_{N-1}`` - confirm against the file writer.

    The file handle is opened on first use and reopened on demand, so the
    dataset can be closed explicitly and can be pickled (the handle itself is
    dropped from the pickled state).

    Parameters
    ----------
    hdf5_file_path : str
        Path of the HDF5 file to read.
    """

    def __init__(self, hdf5_file_path):
        self.hdf5_file_path = hdf5_file_path
        self._hdf5_file = None
        # Accessing the property opens the file, so a bad path still fails here.
        self.dataset_length = len(self.hdf5_file)

    @property
    def hdf5_file(self):
        """Open read-only handle to the HDF5 file (opened lazily)."""
        handle = getattr(self, '_hdf5_file', None)
        if handle is None:
            handle = h5py.File(self.hdf5_file_path, 'r')
            self._hdf5_file = handle
        return handle

    @hdf5_file.setter
    def hdf5_file(self, handle):
        self._hdf5_file = handle

    def close(self):
        """Close the underlying file; a later access reopens it."""
        handle = getattr(self, '_hdf5_file', None)
        if handle is not None:
            self._hdf5_file = None
            handle.close()

    def __del__(self):
        # Destructors may run during interpreter shutdown, when closing can
        # fail for reasons nobody can act on; never let that propagate.
        try:
            self.close()
        except Exception:
            pass

    def __getstate__(self):
        # Open HDF5 handles cannot be pickled; the copy reopens lazily.
        state = self.__dict__.copy()
        state['_hdf5_file'] = None
        return state

    def __len__(self):
        return self.dataset_length

    def __getitem__(self, idx):
        """Return ``{'input': float tensor, 'output': float tensor}`` for sample ``idx``.

        The input gets a new leading axis of size 1; the output is returned
        with the shape stored in the file.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[0, len(self))``.  Raising ``IndexError``
            (rather than the ``KeyError`` of a missing HDF5 group) lets plain
            ``for sample in dataset`` iteration terminate.
        """
        if not 0 <= idx < self.dataset_length:
            raise IndexError(
                f"sample index {idx} out of range for dataset of length {self.dataset_length}"
            )

        hdf5_file = self.hdf5_file
        input_data = torch.from_numpy(hdf5_file[f'sample_{idx}/input'][:]).float()
        output_data = torch.from_numpy(hdf5_file[f'sample_{idx}/output'][:]).float()

        # Add a leading axis of size 1 to the input only.
        input_data = input_data.unsqueeze(0)

        sample = {
            'input':  input_data,
            'output': output_data
        }
        return sample

# Path of the preprocessed HDF5 file; it is passed to CustomDataset when the dataset is created.
processed_data_path = "/data.h5"

In [13]:
# Set seed for reproducibility.
# Python's and NumPy's generators drive the train/test split below; PyTorch's
# generator drives weight initialisation and DataLoader shuffling, so it has
# to be seeded as well for runs to be repeatable.
seed_value = 1
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)

# Create the custom dataset
custom_dataset = CustomDataset(hdf5_file_path=processed_data_path)

# Define the size for the test set
test_set_size = 20

# Fail with a clear message when the hold-out would leave nothing to train on
# (the shuffling training DataLoader cannot be built over an empty subset).
if len(custom_dataset) <= test_set_size:
    raise ValueError(
        f"Dataset has {len(custom_dataset)} samples, which leaves no training "
        f"data after holding out {test_set_size} test samples."
    )

# Generate random indices for the test set and corresponding training set
all_indices = list(range(len(custom_dataset)))
random.shuffle(all_indices)

test_indices = all_indices[:test_set_size]
train_indices = all_indices[test_set_size:]

# Create training dataset using Subset
train_dataset = Subset(custom_dataset, train_indices)

# Create testing dataset using Subset
test_dataset = Subset(custom_dataset, test_indices)

# Data is loaded in the main process (num_workers=0); only the training
# loader reshuffles each epoch.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0, persistent_workers=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=0, persistent_workers=False)

In [None]:
# Sanity check: print the tensor shapes of the first training batch only.
for batch in train_loader:
    for field in ('input', 'output'):
        print(batch[field].shape)
    break

In [None]:
# Build the network and move it to the selected device.
model = CustomFNO(
    n_modes=(64, 64),
    in_channels=1,
    out_channels=99,
    hidden_channels=128,
    use_mlp=True,
    n_layers=4,
    factorization='tucker',
    rank=0.5,
).to(device)

# Report the model size and push the message out of the stdout buffer immediately.
n_params = count_model_params(model)
print(f'\nYour model has {n_params} parameters.', flush=True)

# Optimiser, plus a scheduler that multiplies the learning rate by 0.8 after
# the monitored value has not improved for more than 3 scheduler steps.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.8,
    patience=3,
    verbose=True,
)

# Training objective (H1) and the evaluation criterion (MSE).
h1loss = H1Loss(d=2)
test_loss = nn.MSELoss()

In [None]:
# Print the model's textual representation for inspection.
print(model)

In [None]:
import csv
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim


# Training schedule and checkpointing configuration.
num_epochs = 300
print_frequency = 40              # refresh the progress-bar loss every N batches
save_checkpoint_start_epoch = 100
save_checkpoint_interval = 50
checkpoint_dir = "/result/"

# Make sure the output directory exists before anything is written into it;
# otherwise opening the results file below fails.
os.makedirs(checkpoint_dir, exist_ok=True)

# Per-epoch histories filled in by the training loop.
train_losses = []
test_r2_losses = []
test_mse_losses = []
epoch_durations = []

# Start a fresh results file containing only the header row; the training
# loop appends one row per saved checkpoint.
results_file = os.path.join(checkpoint_dir, "results.csv")
with open(results_file, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['epoch', 'r2', 'relative_error', 'mse', 'mae', 'psnr', 'ssim'])


import time  # the epoch timers below call time.time(); nothing else imports it

for epoch in range(num_epochs):
    start_time = time.time()

    # ---- Train ----
    model.train()
    epoch_loss = 0.0
    train_loader_tqdm = tqdm(train_loader, total=len(train_loader), desc=f'Train Epoch {epoch+1}/{num_epochs}')
    for batch_idx, batch in enumerate(train_loader_tqdm):
        optimizer.zero_grad()
        data, target = batch['input'].to(device), batch['output'].to(device)

        output = model(data)

        loss = h1loss(output, target).mean()

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Only refresh the progress-bar postfix every `print_frequency` batches.
        if batch_idx % print_frequency == 0:
            train_loader_tqdm.set_postfix(loss=loss.item())

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---- Test (every epoch): per-sample R2 and per-batch MSE ----
    model.eval()
    test_loader_tqdm = tqdm(test_loader, total=len(test_loader), desc=f'Test Epoch {epoch+1}/{num_epochs}')
    r2_scores = []
    mse_losses = []

    with torch.no_grad():
        for batch in test_loader_tqdm:
            data, target = batch['input'].to(device), batch['output'].to(device)
            output = model(data)

            output_np = output.cpu().numpy()
            target_np = target.cpu().numpy()

            for true_sample, pred_sample in zip(target_np, output_np):
                r2 = r2_score(true_sample.flatten(), pred_sample.flatten())
                r2_scores.append(r2)

            mse_loss = test_loss(output, target).item()
            mse_losses.append(mse_loss)

    avg_r2 = np.mean(r2_scores)
    avg_mse = np.mean(mse_losses)

    test_r2_losses.append(avg_r2)
    test_mse_losses.append(avg_mse)

    # Update learning rate.
    # NOTE(review): the plateau scheduler monitors the *training* loss here;
    # confirm this is intended rather than a held-out metric.
    scheduler.step(avg_train_loss)

    # End of timer
    end_time = time.time()
    epoch_duration = end_time - start_time
    epoch_durations.append(epoch_duration)

    # Estimate the remaining time from the mean epoch duration so far.
    avg_epoch_duration = np.mean(epoch_durations)
    remaining_epochs = num_epochs - (epoch + 1)
    remaining_time = remaining_epochs * avg_epoch_duration
    hours, rem = divmod(remaining_time, 3600)
    minutes, seconds = divmod(rem, 60)

    print(f'Epoch: {epoch+1}/{num_epochs}, Duration: {epoch_duration:.2f}s, '
          f'Train Loss: {avg_train_loss:.4f}, Test R2: {avg_r2:.4f}, Test MSE: {avg_mse:.4f}')
    print(f'Estimated Remaining Time: {int(hours)}h {int(minutes)}m {int(seconds)}s')

    # ---- Checkpoint + detailed metrics ----
    # Runs every `save_checkpoint_interval` epochs once the start epoch is reached.
    if (epoch + 1) >= save_checkpoint_start_epoch and (epoch + 1) % save_checkpoint_interval == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch+1}: {checkpoint_path}")

        # The model is still in eval mode from the test pass above.
        sample_r2_scores = []
        sample_relative_errors = []
        sample_mses = []
        sample_maes = []
        sample_psnr_values = []
        sample_ssim_values = []

        # First pass over the test targets: global value range, used as
        # `data_range` for PSNR and SSIM.
        min_val = float('inf')
        max_val = float('-inf')

        for sample in test_loader:
            y_true_batch = sample['output'].cpu().numpy()
            min_val = min(min_val, y_true_batch.min())
            max_val = max(max_val, y_true_batch.max())

        data_range = max_val - min_val

        # Second pass: per-sample metrics.
        with torch.no_grad():
            for sample in test_loader:
                x = sample['input'].to(device)
                y_true_batch = sample['output'].cpu().numpy()
                y_pred_batch = model(x).cpu().numpy()

                for true_sample, pred_sample in zip(y_true_batch, y_pred_batch):
                    true_sample_flat = true_sample.flatten()
                    pred_sample_flat = pred_sample.flatten()

                    # R-squared
                    r2 = r2_score(true_sample_flat, pred_sample_flat)
                    sample_r2_scores.append(r2)

                    # Relative error (L2 norm of the error over L2 norm of the target)
                    absolute_error = np.linalg.norm(true_sample_flat - pred_sample_flat, 2)
                    relative_error = absolute_error / np.linalg.norm(true_sample_flat, 2)
                    sample_relative_errors.append(relative_error)

                    # MSE
                    mse = mean_squared_error(true_sample_flat, pred_sample_flat)
                    sample_mses.append(mse)

                    # MAE
                    mae = mean_absolute_error(true_sample_flat, pred_sample_flat)
                    sample_maes.append(mae)

                    # PSNR (on the unflattened sample)
                    psnr_value = psnr(true_sample, pred_sample, data_range=data_range)
                    sample_psnr_values.append(psnr_value)

                    # SSIM (on the unflattened sample).
                    # NOTE(review): all axes of the sample are passed as-is; if
                    # axis 0 is a channel axis, skimage's `channel_axis` may be
                    # what is wanted here - confirm.
                    ssim_value = ssim(true_sample, pred_sample, data_range=data_range)
                    sample_ssim_values.append(ssim_value)

        avg_r2 = np.mean(sample_r2_scores)
        avg_relative_error = np.mean(sample_relative_errors)
        avg_mse = np.mean(sample_mses)
        avg_mae = np.mean(sample_maes)
        avg_psnr = np.mean(sample_psnr_values)
        avg_ssim = np.mean(sample_ssim_values)

        # Append one row per checkpoint to the results file.
        with open(results_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch+1, avg_r2, avg_relative_error, avg_mse, avg_mae, avg_psnr, avg_ssim])

# Write each per-epoch history to its own .npy file in the checkpoint directory.
for history_file, history in (
    ('train_losses.npy', train_losses),
    ('test_losses.npy', test_r2_losses),
    ('test_mse_losses.npy', test_mse_losses),
):
    np.save(os.path.join(checkpoint_dir, history_file), np.array(history))

# Timing summary: mean seconds per epoch, then the total over all epochs.
average_epoch_duration = np.mean(epoch_durations)
print(f'Average Epoch Duration: {average_epoch_duration:.2f}s')

total_training_time = np.sum(epoch_durations)
print(f'Total Training Time: {total_training_time:.2f}s')


# One figure per tracked curve, in order: training loss, test R2, test MSE.
# Each entry is (series, legend label, y-axis label, figure title).
curve_specs = (
    (train_losses, 'Training Loss', 'Loss', 'Training Loss Over Epochs'),
    (test_r2_losses, 'Test R2 Score', 'R2 Score', 'Test R2 Score Over Epochs'),
    (test_mse_losses, 'Test MSE Loss', 'MSE Loss', 'Test MSE Loss Over Epochs'),
)

for series, legend_label, y_label, figure_title in curve_specs:
    plt.figure(figsize=(10, 5))
    plt.plot(series, label=legend_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(figure_title)
    plt.legend()
    plt.show()