In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import os
import matplotlib.pyplot as plt
from torch.utils.data.dataset import Dataset
import sys
from astropy.io import fits
from torch.utils.data.dataset import Subset
import random
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Subset
from torchvision.transforms.functional import normalize
from tqdm import tqdm
import h5py
import torch.nn.functional as F
from neuralop.losses.data_losses import LpLoss, H1Loss
from neuralop.utils import count_model_params
import time


# Run on the first CUDA device when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print(device)

In [2]:
class CustomDataset(Dataset):
    """Map-style dataset that reads ``{'x', 'y'}`` samples from an HDF5 file.

    Sample ``i`` is read from the entries ``sample_{i}/x`` and ``sample_{i}/y``.
    The dataset length is the number of top-level entries in the file, so this
    assumes the file holds nothing but ``sample_0 .. sample_{N-1}`` groups --
    TODO confirm against the code that writes the file.

    The file is opened in ``__init__`` and stays open until :meth:`close` is
    called (or the dataset is used as a context manager and the block exits).
    """

    def __init__(self, hdf5_file_path):
        """Open ``hdf5_file_path`` read-only and record the number of samples."""
        self.hdf5_file_path = hdf5_file_path
        self.hdf5_file = h5py.File(hdf5_file_path, 'r')
        self.dataset_length = len(self.hdf5_file)

    def __len__(self):
        """Return the number of samples recorded when the file was opened."""
        return self.dataset_length

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with tensors under ``'x'`` and ``'y'``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                Raising ``IndexError`` (rather than letting the HDF5 lookup
                fail with ``KeyError``) is what lets plain iteration such as
                ``for sample in dataset`` terminate cleanly.
        """
        if idx < 0:
            # Rebind instead of ``+=`` so a tensor/array index passed by the
            # caller is never modified in place.
            idx = idx + self.dataset_length
        if not 0 <= idx < self.dataset_length:
            raise IndexError(
                f'sample index {idx} is out of range for a dataset of '
                f'{self.dataset_length} samples'
            )
        sample = {
            'x': torch.from_numpy(self.hdf5_file[f'sample_{idx}/x'][:]),
            'y': torch.from_numpy(self.hdf5_file[f'sample_{idx}/y'][:])
        }
        return sample

    def close(self):
        """Close the underlying file handle; safe to call more than once."""
        handle = getattr(self, 'hdf5_file', None)
        if handle is not None:
            handle.close()
            self.hdf5_file = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
    

In [3]:
class Encoder(nn.Module):
    """Convolutional feature extractor: three stride-2 3x3 convs, then flatten.

    Channel progression is 1 -> 64 -> 128 -> 256; every convolution is
    followed by a ReLU.  The result is flattened to ``(batch, features)``.
    """

    def __init__(self):
        super().__init__()
        # All three stages share the same kernel / stride / padding settings.
        downsample = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(1, 64, **downsample)
        self.conv2 = nn.Conv2d(64, 128, **downsample)
        self.conv3 = nn.Conv2d(128, 256, **downsample)
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Apply conv -> ReLU three times and return the flattened activations."""
        features = x
        for stage in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(stage(features))
        return self.flatten(features)
"""
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.rnn = nn.GRUCell(256, 256)
        self.linear = nn.Linear(256, 504 * 504)

    def forward(self, x, seq_length=99):
        outputs = []
        batch_size = x.size(0)
        h_t = torch.zeros(batch_size, 256).to(x.device)  # Initial hidden state

        for _ in range(seq_length):
            x = h_t  # Use the previous hidden state as input
            h_t = self.rnn(x, h_t)  # Update the hidden state
            output_t = self.linear(h_t)
            outputs.append(output_t.view(-1, 1, 504, 504))

        return torch.cat(outputs, dim=1)

"""
class Decoder(nn.Module):
    """LSTM decoder that unrolls an encoding into a sequence of 2-D frames.

    The encoding is fed to the LSTM cell as the input of the first step; every
    later step receives the previous hidden state as its input.  Each hidden
    state is mapped by a linear layer to one ``height x width`` frame.

    Earlier revisions overwrote ``x`` with the zero-initialised hidden state
    before it was ever read, so the output did not depend on the encoder at
    all.  Feeding the encoding at the first step fixes that without adding or
    renaming any parameter (``state_dict`` keys are still ``lstm.*`` and
    ``linear.*``).

    Args:
        hidden_size: width of the LSTM input/hidden/cell state (default 256).
        output_size: ``(height, width)`` of each generated frame
            (default ``(504, 504)``).
    """

    def __init__(self, hidden_size=256, output_size=(504, 504)):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = tuple(output_size)
        # Use LSTMCell for the recurrent network part
        self.lstm = nn.LSTMCell(hidden_size, hidden_size)
        # Linear layer to map the hidden state output of the LSTM to one flattened frame
        self.linear = nn.Linear(hidden_size, self.output_size[0] * self.output_size[1])

    def _encoding_to_input(self, x):
        """Reduce ``x`` to ``(batch, hidden_size)`` so the LSTM cell can consume it.

        Uses parameter-free adaptive average pooling over the flattened feature
        axis.  When the features are a channel-major flatten of
        ``(hidden_size, H, W)`` maps this equals global average pooling per
        channel; when the width already matches it is a no-op.
        """
        features = x.flatten(1)
        if features.size(1) != self.hidden_size:
            features = nn.functional.adaptive_avg_pool1d(
                features.unsqueeze(1), self.hidden_size
            ).squeeze(1)
        return features

    def forward(self, x, seq_length=99):
        """Generate ``seq_length`` frames conditioned on the encoding ``x``.

        Args:
            x: encoder output with the batch on dimension 0.
            seq_length: number of frames to generate.

        Returns:
            Tensor of shape ``(batch, seq_length, height, width)``.
        """
        step_input = self._encoding_to_input(x)
        # Hidden and cell states start at zero, on the same device/dtype as the input.
        h_t = torch.zeros_like(step_input)
        c_t = torch.zeros_like(step_input)
        height, width = self.output_size

        outputs = []
        for _ in range(seq_length):
            h_t, c_t = self.lstm(step_input, (h_t, c_t))
            # Map the hidden state to a frame and add a length-1 sequence axis.
            outputs.append(self.linear(h_t).view(-1, 1, height, width))
            # From the second step on, the previous hidden state is the input.
            step_input = h_t

        # Concatenate the frames along the sequence dimension
        return torch.cat(outputs, dim=1)


class EncoderDecoder(nn.Module):
    """Full model: an ``Encoder`` whose output is passed straight to a ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` and return whatever the decoder produces from that encoding."""
        return self.decoder(self.encoder(x))


In [4]:
processed_data_path = "/data.h5"

# Set seed for reproducibility.  torch is seeded as well: model weight
# initialisation and the DataLoader's shuffling both draw from torch's RNG,
# which random/numpy seeding alone does not control.
seed_value = 1
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)

# Create the custom dataset
custom_dataset = CustomDataset(hdf5_file_path=processed_data_path)

# Define the size for the test set
test_set_size = 20
if test_set_size >= len(custom_dataset):
    # Without this check the training split would silently be empty.
    raise ValueError(
        f'test_set_size ({test_set_size}) must be smaller than the dataset '
        f'size ({len(custom_dataset)})'
    )

# Shuffle all indices once (Python's `random`, seeded above), then carve off
# the first `test_set_size` for testing and keep the rest for training.
all_indices = list(range(len(custom_dataset)))
random.shuffle(all_indices)

test_indices = all_indices[:test_set_size]
train_indices = all_indices[test_set_size:]

# Create training dataset using Subset
train_dataset = Subset(custom_dataset, train_indices)

# Create testing dataset using Subset
test_dataset = Subset(custom_dataset, test_indices)

# Loaders: training batches are reshuffled every epoch, test order is fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0, persistent_workers=False)
test_loaders = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=0, persistent_workers=False)

In [None]:
# Build the network and place its parameters on the selected device.
model = EncoderDecoder().to(device)

n_params = count_model_params(model)
# flush=True pushes the message out of the stdout buffer immediately.
print(f'\nYour model has {n_params} parameters.', flush=True)

# Training objective (H1) and evaluation metric (MSE).
h1loss = H1Loss(d=2)
test_loss = nn.MSELoss()

# Adam with weight decay; the learning rate is multiplied by 0.8 once the
# monitored value has failed to improve for more than `patience` epochs.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.8,
    patience=2,
    verbose=True,
)





In [None]:
train_losses = []
test_losses = []
epoch_durations = []

num_epochs = 300
print_frequency = 40
save_checkpoint_interval = 300
checkpoint_dir = "/result"

# Create the output directory before training starts.  The first write into
# it happens only after `save_checkpoint_interval` epochs, so a missing
# directory would otherwise throw away the whole run.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    start_time = time.time()

    # ---- training pass: optimise the H1 loss over every training batch ----
    model.train()
    epoch_train_loss = 0.0
    for batch_idx, sample in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = sample['x'].to(device), sample['y'].to(device)
        output = model(data)
        loss = h1loss(output, target)
        loss = loss.mean()
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item()

        if batch_idx % print_frequency == 0:
            print(f'Train Epoch: {epoch+1}/{num_epochs}, Batch: {batch_idx}/{len(train_loader)}, Loss: {loss.item()}')

    # Average of the per-batch losses for this epoch.
    avg_train_loss = epoch_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---- evaluation pass: MSE on the held-out batches, no gradients ----
    model.eval()
    epoch_test_loss = 0.0
    with torch.no_grad():
        for sample in test_loaders:
            data, target = sample['x'].to(device), sample['y'].to(device)
            output = model(data)
            loss = test_loss(output, target)
            loss = loss.mean()
            epoch_test_loss += loss.item()

    avg_test_loss = epoch_test_loss / len(test_loaders)
    test_losses.append(avg_test_loss)

    # The recorded duration covers both the training and the evaluation pass.
    end_time = time.time()
    epoch_duration = end_time - start_time
    epoch_durations.append(epoch_duration)

    # The scheduler monitors the *training* loss (not the test loss).
    scheduler.step(avg_train_loss)

    print(f'Epoch: {epoch+1}/{num_epochs}, Duration: {epoch_duration:.2f}s, Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}')

    # save checkpoint (model weights only)
    if (epoch + 1) % save_checkpoint_interval == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch+1}: {checkpoint_path}")

# average time of epoch duration
avg_epoch_duration = np.mean(epoch_durations)
print(f'Average Epoch Duration: {avg_epoch_duration:.2f}s')

# save loss histories next to the checkpoints
np.save(os.path.join(checkpoint_dir, 'train_losses.npy'), np.array(train_losses))
np.save(os.path.join(checkpoint_dir, 'test_losses.npy'), np.array(test_losses))

# One figure per loss curve; the two losses are different metrics (H1 vs MSE),
# so they are deliberately not drawn on a shared axis.
for curve_values, curve_label, curve_color in (
    (train_losses, 'Training Loss', 'blue'),
    (test_losses, 'Testing Loss', 'red'),
):
    plt.figure(figsize=(10, 5))
    plt.plot(curve_values, label=curve_label, color=curve_color)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(curve_label)
    plt.legend()
    plt.show()

In [None]:
from sklearn.metrics import mean_squared_error, r2_score

# Per-sample metrics on the held-out set.  All four metrics are computed from
# the same predictions, so the test set is run through the model once rather
# than once per metric.
sample_r2 = []
sample_relative_errors = []
sample_mses = []
sample_absolute_errors = []
min_mse_sample = None

# Do not rely on an earlier cell having left the model in evaluation mode.
model.eval()
with torch.no_grad():
    for sample in test_loaders:
        x = sample['x'].to(device)
        y_true_batch = sample['y'].numpy()
        y_pred_batch = model(x).cpu().numpy()

        # Score each ground-truth / prediction pair in the batch separately.
        for true_sample, pred_sample in zip(y_true_batch, y_pred_batch):
            true_flat = true_sample.flatten()
            pred_flat = pred_sample.flatten()
            residual = true_flat - pred_flat

            # R-squared of the prediction against the ground truth.
            sample_r2.append(r2_score(true_flat, pred_flat))

            # Relative L2 error: ||true - pred||_2 / ||true||_2.
            # NOTE(review): not guarded against an all-zero ground truth.
            sample_relative_errors.append(
                np.linalg.norm(residual, 2) / np.linalg.norm(true_flat, 2)
            )

            # Mean squared error and mean absolute error of this sample.
            sample_mses.append(mean_squared_error(true_flat, pred_flat))
            sample_absolute_errors.append(np.mean(np.abs(residual)))

# Calculate overall Mean Squared Error
overall_mse = np.mean(sample_mses)
print(f'Overall Mean Squared Error: {overall_mse}')

# Find the minimum MSE value and its index (position in test-loader order)
min_sample_mse = min(sample_mses)
min_index = np.argmin(sample_mses)

# Calculate Mean Absolute Error (MAE)
mae = np.mean(sample_absolute_errors)

print(f"Mean Absolute Error (MAE): {mae}")

# Calculate overall Relative Error
overall_relative_error = np.mean(sample_relative_errors)
print(f'Overall Relative Error: {overall_relative_error}')

# Calculate overall R-squared
overall_r2 = np.mean(sample_r2)
print(f'Overall R-squared: {overall_r2}')