In [None]:
from functools import partialmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from neuralop.layers.spectral_convolution import SpectralConv
from neuralop.layers.padding import DomainPadding
from neuralop.layers.fno_block import FNOBlocks
from neuralop.layers.mlp import MLP
from neuralop.models.base_model import BaseModel
import numpy as np
import os
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from neuralop.models.fno import TFNO
from neuralop.models.fno import TFNO3d
from neuralop.training.trainer import Trainer
from neuralop.utils import count_model_params
from neuralop.losses.data_losses import LpLoss, H1Loss
import pdb
import sys
from torch.utils.data.dataset import Subset
import random
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Subset
from torchvision.transforms.functional import normalize
from tqdm import tqdm
from torch.utils.data import random_split
import h5py
import time

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

print(device)

In [2]:
class CustomDataset(Dataset):
    """Map-style dataset of (input, output) pairs stored in an HDF5 file.

    Samples are read from the entries ``sample_{idx}/input`` and
    ``sample_{idx}/output``. The dataset length is the number of top-level
    entries in the file (assumes every top-level entry is a ``sample_*``
    group -- TODO confirm against the file layout).

    Both arrays are converted to float32 tensors and resampled spatially with
    bilinear interpolation before being returned.

    Parameters
    ----------
    hdf5_file_path : str
        Path of the HDF5 file. It is opened read-only and kept open until
        ``close()`` is called or the dataset is garbage-collected.
    scale_factor : float, optional
        Spatial scale factor handed to ``F.interpolate`` for both tensors.
        The default of 0.5 halves each spatial dimension.
    """

    def __init__(self, hdf5_file_path, scale_factor=0.5):
        self.hdf5_file_path = hdf5_file_path
        self.scale_factor = scale_factor
        self.hdf5_file = h5py.File(hdf5_file_path, 'r')
        self.dataset_length = len(self.hdf5_file)

    def __len__(self):
        """Return the number of samples in the file."""
        return self.dataset_length

    def close(self):
        """Release the HDF5 file handle. Safe to call more than once."""
        handle = getattr(self, 'hdf5_file', None)
        if handle is not None:
            self.hdf5_file = None
            handle.close()

    def __del__(self):
        # A finalizer must never raise (e.g. during interpreter shutdown).
        try:
            self.close()
        except Exception:
            pass

    def _resample(self, batched):
        """Bilinearly resample a 4-D ``(N, C, H, W)`` tensor and drop the N axis."""
        resized = F.interpolate(
            batched, scale_factor=self.scale_factor, mode='bilinear', align_corners=False
        )
        return resized.squeeze(0)

    def __getitem__(self, idx):
        """Return ``{'input': tensor, 'output': tensor}`` for sample ``idx``.

        Negative indices count from the end.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[-len(self), len(self))``. Raising
            ``IndexError`` (rather than the ``KeyError`` a missing HDF5 entry
            would produce) lets plain ``for sample in dataset`` terminate.
        RuntimeError
            If the dataset has already been closed.
        """
        if self.hdf5_file is None:
            raise RuntimeError("CustomDataset is closed; the HDF5 file is no longer available")

        if idx < 0:
            idx += self.dataset_length
        if not 0 <= idx < self.dataset_length:
            raise IndexError(f"sample index {idx} out of range for dataset of length {self.dataset_length}")

        input_data = torch.from_numpy(self.hdf5_file[f'sample_{idx}/input'][:]).float()
        output_data = torch.from_numpy(self.hdf5_file[f'sample_{idx}/output'][:]).float()

        # Bilinear interpolation requires 4-D (N, C, H, W) tensors: the input
        # gains a batch and a channel axis, the output only a batch axis
        # (so the stored input is expected to be 2-D and the output 3-D).
        input_data = input_data.unsqueeze(0).unsqueeze(0)
        output_data = output_data.unsqueeze(0)

        return {
            'input': self._resample(input_data),
            'output': self._resample(output_data),
        }

# Absolute path of the HDF5 file that CustomDataset reads its samples from.
processed_data_path = "/data.h5"

In [3]:
# Set seeds for reproducibility. torch is seeded too: the shuffling DataLoader
# and the model's weight initialisation draw from torch's RNG, not from
# `random` or numpy.
seed_value = 1
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)

# Create the custom dataset
custom_dataset = CustomDataset(hdf5_file_path=processed_data_path)

# Number of samples held out for testing
test_set_size = 20

# Fail early with a clear message instead of training on an empty split
if test_set_size >= len(custom_dataset):
    raise ValueError(
        f"test_set_size ({test_set_size}) must be smaller than the dataset "
        f"size ({len(custom_dataset)}) so that the training set is not empty"
    )

# Shuffle all sample indices once (seeded above), then split them:
# the first `test_set_size` go to the test set, the rest to training.
all_indices = list(range(len(custom_dataset)))
random.shuffle(all_indices)

test_indices = all_indices[:test_set_size]
train_indices = all_indices[test_set_size:]

# Both subsets are views onto the same underlying dataset / HDF5 handle
train_dataset = Subset(custom_dataset, train_indices)
test_dataset = Subset(custom_dataset, test_indices)

# Loaders run in the main process (num_workers=0); only training is shuffled
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0, persistent_workers=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=0, persistent_workers=False)

In [None]:
# Peek at the first test batch only, to check the tensor shapes fed to the model.
batch = next(iter(test_loader), None)
if batch is not None:
    x_shape, y_shape = batch['input'].shape, batch['output'].shape
    print(f'Batch X shape: {x_shape}')
    print(f'Batch Y shape: {y_shape}')

In [5]:
class CustomFNO(BaseModel, name='CustomFNO'):
    """Fourier Neural Operator assembled from neuralop building blocks.

    The network is a straight pipeline:

    1. ``lifting1``   -- 2-layer MLP, ``in_channels -> hidden_channels``
       (width ``lifting_channels`` in between);
    2. optional domain padding;
    3. ``fno_blocks`` -- ``n_layers`` FNO blocks at ``hidden_channels`` width;
    4. removal of the domain padding (if any);
    5. ``projection`` -- 2-layer MLP, ``hidden_channels -> out_channels``
       (width ``projection_channels`` in between).

    Parameters
    ----------
    n_modes : sequence of int
        Fourier modes per spatial dimension. Its length defines the number of
        spatial dimensions (``self.n_dim``).
    hidden_channels : int
        Channel width of the FNO blocks.
    in_channels, out_channels : int
        Channels of the model input and output.
    lifting_channels, projection_channels : int
        Hidden widths of the lifting and projection MLPs.
    n_layers : int
        Number of FNO blocks applied in ``forward``.
    decomposition_kwargs : dict or None
        Forwarded to ``FNOBlocks``. ``None`` (the default) means an empty dict;
        a fresh dict is created per instance so instances never share one.
    domain_padding : None, float or list of float
        When positive, the lifted tensor is padded with ``DomainPadding``
        before the FNO blocks and un-padded before projection. ``None`` or a
        non-positive value disables padding.
    domain_padding_mode : str
        Padding mode handed to ``DomainPadding``.
    output_scaling_factor, max_n_modes, fno_block_precision, use_mlp,
    mlp_dropout, mlp_expansion, non_linearity, stabilizer, norm,
    preactivation, fno_skip, mlp_skip, separable, factorization, rank,
    joint_factorization, fixed_rank_modes, implementation, fft_norm,
    SpectralConv, **kwargs
        Passed through unchanged to ``FNOBlocks`` (``non_linearity`` is also
        used by the projection MLP).
    """

    def __init__(
        self,
        n_modes,
        hidden_channels,
        in_channels=3,
        out_channels=1,
        lifting_channels=256,
        projection_channels=256,
        n_layers=4,
        output_scaling_factor=None,
        max_n_modes=None,
        fno_block_precision="full",
        use_mlp=False,
        mlp_dropout=0,
        mlp_expansion=0.5,
        non_linearity=F.gelu,
        stabilizer=None,
        norm=None,
        preactivation=False,
        fno_skip="linear",
        mlp_skip="soft-gating",
        separable=False,
        factorization=None,
        rank=1.0,
        joint_factorization=False,
        fixed_rank_modes=False,
        implementation="factorized",
        decomposition_kwargs=None,
        domain_padding=None,
        domain_padding_mode="one-sided",
        fft_norm="forward",
        SpectralConv=SpectralConv,
        **kwargs
    ):
        super().__init__()
        self.n_dim = len(n_modes)
        self.n_layers = n_layers

        # Avoid a shared mutable default: every instance gets its own dict.
        if decomposition_kwargs is None:
            decomposition_kwargs = {}

        # Domain padding is only instantiated when a positive amount is requested.
        if isinstance(domain_padding, (list, tuple)):
            use_padding = sum(domain_padding) > 0
        elif isinstance(domain_padding, (float, int)):
            use_padding = domain_padding > 0
        else:
            use_padding = False

        if use_padding:
            self.domain_padding = DomainPadding(
                domain_padding=domain_padding,
                padding_mode=domain_padding_mode,
                output_scaling_factor=output_scaling_factor,
            )
        else:
            self.domain_padding = None

        # Lifting: in_channels -> hidden_channels
        self.lifting1 = MLP(
            in_channels=in_channels,
            out_channels=hidden_channels,
            hidden_channels=lifting_channels,
            n_layers=2,
            n_dim=self.n_dim,
        )

        # Stack of FNO blocks operating at hidden_channels width
        self.fno_blocks = FNOBlocks(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            n_modes=n_modes,
            output_scaling_factor=output_scaling_factor,
            use_mlp=use_mlp,
            mlp_dropout=mlp_dropout,
            mlp_expansion=mlp_expansion,
            non_linearity=non_linearity,
            stabilizer=stabilizer,
            norm=norm,
            preactivation=preactivation,
            fno_skip=fno_skip,
            mlp_skip=mlp_skip,
            max_n_modes=max_n_modes,
            fno_block_precision=fno_block_precision,
            rank=rank,
            fft_norm=fft_norm,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            separable=separable,
            factorization=factorization,
            decomposition_kwargs=decomposition_kwargs,
            joint_factorization=joint_factorization,
            SpectralConv=SpectralConv,
            n_layers=n_layers,
            **kwargs
        )

        # Projection: hidden_channels -> out_channels
        self.projection = MLP(
            in_channels=hidden_channels,
            out_channels=out_channels,
            hidden_channels=projection_channels,
            n_layers=2,
            n_dim=self.n_dim,
            non_linearity=non_linearity,
        )

    def forward(self, x1):
        """Run lifting, the FNO blocks and projection on ``x1``.

        Parameters
        ----------
        x1 : tensor
            Model input; presumably shaped
            ``(batch, in_channels, d_1, ..., d_n_dim)`` -- TODO confirm.

        Returns
        -------
        tensor
            Output of the projection MLP (``out_channels`` channels).
        """
        x = self.lifting1(x1)

        if self.domain_padding is not None:
            x = self.domain_padding.pad(x)

        # FNOBlocks holds all layers; each one is selected by its index.
        for layer_idx in range(self.n_layers):
            x = self.fno_blocks(x, layer_idx)

        if self.domain_padding is not None:
            x = self.domain_padding.unpad(x)

        return self.projection(x)


In [None]:
# 2-D FNO: 64 modes per axis, 1 input channel -> 99 output channels,
# Tucker-factorized spectral weights at rank 0.5.
model_kwargs = dict(
    n_modes=(64, 64),
    in_channels=1,
    out_channels=99,
    hidden_channels=128,
    projection_channels=256,
    use_mlp=True,
    factorization='tucker',
    rank=0.5,
)
model = CustomFNO(**model_kwargs).to(device)

n_params = count_model_params(model)
print(f'\nYour model has {n_params} parameters.')
sys.stdout.flush()  # flush the stdout buffer

# H1 loss in two spatial dimensions, Adam with weight decay, and a scheduler
# that multiplies the learning rate by 0.8 after 3 epochs without improvement.
h1loss = H1Loss(d=2)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=2e-3,
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.8,
    patience=3,
    verbose=True,
)

In [None]:
import csv
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim

# ---- Training configuration ----
num_epochs = 300
print_frequency = 40               # refresh the tqdm loss postfix every N training batches
save_checkpoint_start_epoch = 100  # first epoch at which a checkpoint may be written
save_checkpoint_interval = 50      # checkpoints are written on multiples of this epoch count
checkpoint_dir = "/result/"
results_file = os.path.join(checkpoint_dir, "results.csv")

# The directory must exist before the CSV, checkpoints or .npy files are written.
os.makedirs(checkpoint_dir, exist_ok=True)

# Per-epoch histories filled by the training loop
train_losses = []
test_losses = []      # note: holds the mean test R2 of each epoch, not a loss
epoch_durations = []

# Start a fresh results file containing only the header row;
# metric rows are appended at every checkpoint.
with open(results_file, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['epoch', 'r2', 'relative_error', 'mse', 'mae', 'psnr', 'ssim'])


# Value range of the test targets, used by PSNR/SSIM. The test set is fixed,
# so it is computed once (at the first checkpoint) and reused afterwards.
data_range = None

for epoch in range(num_epochs):
    start_time = time.time()

    # ---- Train for one epoch ----
    model.train()
    epoch_loss = 0.0
    train_loader_tqdm = tqdm(train_loader, total=len(train_loader), desc=f'Train Epoch {epoch+1}/{num_epochs}')
    for batch_idx, batch in enumerate(train_loader_tqdm):
        optimizer.zero_grad()
        data, target = batch['input'].to(device), batch['output'].to(device)

        output = model(data)
        loss = h1loss(output, target).mean()

        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() forces a device synchronisation.
        loss_value = loss.item()
        epoch_loss += loss_value

        if batch_idx % print_frequency == 0:
            train_loader_tqdm.set_postfix(loss=loss_value)

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---- Per-epoch test metric: mean per-sample R2 ----
    model.eval()
    test_loader_tqdm = tqdm(test_loader, total=len(test_loader), desc=f'Test Epoch {epoch+1}/{num_epochs}')
    r2_scores = []

    with torch.no_grad():
        for batch in test_loader_tqdm:
            data = batch['input'].to(device)
            # Targets are only needed as numpy arrays, so they stay on the CPU.
            targets = batch['output'].cpu().numpy()
            output = model(data).cpu().numpy()

            for true_sample, pred_sample in zip(targets, output):
                r2_scores.append(r2_score(true_sample.flatten(), pred_sample.flatten()))

    avg_r2 = np.mean(r2_scores)
    test_losses.append(avg_r2)

    # The plateau scheduler is driven by the training loss
    scheduler.step(avg_train_loss)

    # ---- Timing and ETA ----
    end_time = time.time()
    epoch_duration = end_time - start_time
    epoch_durations.append(epoch_duration)

    avg_epoch_duration = np.mean(epoch_durations)
    remaining_epochs = num_epochs - (epoch + 1)
    remaining_time = remaining_epochs * avg_epoch_duration
    hours, rem = divmod(remaining_time, 3600)
    minutes, seconds = divmod(rem, 60)

    print(f'Epoch: {epoch+1}/{num_epochs}, Duration: {epoch_duration:.2f}s, '
          f'Train Loss: {avg_train_loss:.4f}, Test R2: {avg_r2:.4f}')
    print(f'Estimated Remaining Time: {int(hours)}h {int(minutes)}m {int(seconds)}s')

    # ---- Checkpoint + full metric report ----
    if (epoch + 1) >= save_checkpoint_start_epoch and (epoch + 1) % save_checkpoint_interval == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch+1}: {checkpoint_path}")

        # One-off pass over the test targets to get their global min/max.
        if data_range is None:
            min_val = float('inf')
            max_val = float('-inf')
            for sample in test_loader:
                y_true_batch = sample['output'].cpu().numpy()
                min_val = min(min_val, y_true_batch.min())
                max_val = max(max_val, y_true_batch.max())
            data_range = max_val - min_val

        sample_r2_scores, sample_relative_errors, sample_mses, sample_maes, sample_psnr_values, sample_ssim_values = [], [], [], [], [], []

        # The model is still in eval mode from the per-epoch test pass above.
        with torch.no_grad():
            for sample in test_loader:
                x = sample['input'].to(device)
                y_true_batch = sample['output'].cpu().numpy()
                y_pred_batch = model(x).cpu().numpy()

                for true_sample, pred_sample in zip(y_true_batch, y_pred_batch):
                    true_sample_flat = true_sample.flatten()
                    pred_sample_flat = pred_sample.flatten()

                    # R-squared
                    sample_r2_scores.append(r2_score(true_sample_flat, pred_sample_flat))

                    # Relative L2 error: ||true - pred|| / ||true||
                    absolute_error = np.linalg.norm(true_sample_flat - pred_sample_flat, 2)
                    relative_error = absolute_error / np.linalg.norm(true_sample_flat, 2)
                    sample_relative_errors.append(relative_error)

                    # MSE / MAE
                    sample_mses.append(mean_squared_error(true_sample_flat, pred_sample_flat))
                    sample_maes.append(mean_absolute_error(true_sample_flat, pred_sample_flat))

                    # PSNR on the un-flattened sample
                    psnr_value = psnr(true_sample, pred_sample, data_range=data_range)
                    sample_psnr_values.append(psnr_value)

                    # SSIM on the un-flattened sample.
                    # NOTE(review): no channel_axis is given, so a multi-channel
                    # sample is presumably treated as one N-D volume rather than
                    # per-channel images -- confirm this is intended.
                    ssim_value = ssim(true_sample, pred_sample, data_range=data_range)
                    sample_ssim_values.append(ssim_value)

        avg_r2 = np.mean(sample_r2_scores)
        avg_relative_error = np.mean(sample_relative_errors)
        avg_mse = np.mean(sample_mses)
        avg_mae = np.mean(sample_maes)
        avg_psnr = np.mean(sample_psnr_values)
        avg_ssim = np.mean(sample_ssim_values)

        # Append one row of averaged metrics for this checkpoint
        with open(results_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch+1, avg_r2, avg_relative_error, avg_mse, avg_mae, avg_psnr, avg_ssim])

# Persist both per-epoch histories next to the checkpoints
for history_filename, history in (
    ('train_losses.npy', train_losses),
    ('test_losses.npy', test_losses),
):
    np.save(os.path.join(checkpoint_dir, history_filename), np.array(history))

# Timing summary: mean time per epoch and total wall-clock training time
average_epoch_duration = np.mean(epoch_durations)
total_training_time = np.sum(epoch_durations)
print(f'Average Epoch Duration: {average_epoch_duration:.2f}s')
print(f'Total Training Time: {total_training_time:.2f}s')


def _plot_history(values, label, ylabel, title):
    """Draw one per-epoch curve in its own 10x5 figure and display it."""
    plt.figure(figsize=(10, 5))
    plt.plot(values, label=label)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.show()


# Training loss first, then the test R2 score
_plot_history(train_losses, 'Training Loss', 'Loss', 'Training Loss Over Epochs')
_plot_history(test_losses, 'Test R2 Score', 'R2 Score', 'Test R2 Score Over Epochs')