In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from einops.layers.torch import Rearrange
import numpy as np
import os
import matplotlib.pyplot as plt
from torch.utils.data.dataset import Dataset
import sys
from astropy.io import fits
from torch.utils.data.dataset import Subset
import random
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Subset
from torchvision.transforms.functional import normalize
from tqdm import tqdm
import h5py
from IPython.display import HTML
from neuralop.losses.data_losses import LpLoss, H1Loss
from neuralop.utils import count_model_params
import time
from torchvision.transforms import Resize
import torch.nn.functional as F

# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print(device)

In [2]:
class CustomDataset(Dataset):
    """Map-style dataset reading ``sample_{idx}/x`` and ``sample_{idx}/y`` from one HDF5 file.

    ``len()`` is the number of top-level entries in the file. The file is
    therefore assumed to contain exactly the groups ``sample_0`` ...
    ``sample_{N-1}``; this is not validated here.

    The HDF5 handle is process-local. It is dropped when the dataset is
    pickled, and it is reopened on first access in any process other than the
    one that opened it. This lets the dataset work with
    ``DataLoader(num_workers > 0)``: h5py handles cannot be pickled, and
    sharing one across forked workers is unsafe.
    """

    def __init__(self, hdf5_file_path):
        self.hdf5_file_path = hdf5_file_path
        # Opened eagerly (as before) to count samples; reused in this process.
        self.hdf5_file = h5py.File(hdf5_file_path, 'r')
        self._owner_pid = os.getpid()
        self.dataset_length = len(self.hdf5_file)

    def _file(self):
        """Return an open handle owned by the current process, (re)opening it if needed."""
        if self.hdf5_file is None or self._owner_pid != os.getpid():
            self.hdf5_file = h5py.File(self.hdf5_file_path, 'r')
            self._owner_pid = os.getpid()
        return self.hdf5_file

    def close(self):
        """Close the handle if this process opened it; a later ``__getitem__`` reopens it."""
        if self.hdf5_file is not None and self._owner_pid == os.getpid():
            self.hdf5_file.close()
        self.hdf5_file = None

    def __getstate__(self):
        # h5py objects cannot be pickled: ship only the path and let the
        # receiving process open its own handle on first access.
        state = self.__dict__.copy()
        state['hdf5_file'] = None
        return state

    def __len__(self):
        return self.dataset_length

    def __getitem__(self, idx):
        """Return ``{'x': tensor, 'y': tensor}`` for sample ``idx`` (dtype as stored in the file)."""
        h5 = self._file()
        return {
            'x': torch.from_numpy(h5[f'sample_{idx}/x'][:]),
            'y': torch.from_numpy(h5[f'sample_{idx}/y'][:]),
        }
    

In [None]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and linearly embed each one.

    A convolution with ``kernel_size == stride == patch_size`` is a per-patch
    linear projection. ``[B, in_channels, H, W]`` -> ``[B, N, emb_size]`` with
    ``N = (H // patch_size) * (W // patch_size)``, patches in row-major order.
    """

    def __init__(self, in_channels=1, patch_size=16, emb_size=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # [B, emb_size, H/P, W/P]
        x = x.flatten(2)  # [B, emb_size, N]
        x = x.transpose(1, 2)  # [B, N, emb_size]
        return x

class ViT(nn.Module):
    """Patch-token ViT that predicts a dense multi-channel map at input resolution.

    Each patch token is decoded into its own ``output_channels x P x P`` tile.
    The tiles are placed back at their grid positions, and the result is
    bilinearly resampled to ``image_size`` when ``image_size`` is not a
    multiple of ``patch_size``.

    Args:
        in_channels: channels of the input image.
        patch_size: side length of the square patches.
        emb_size: token width.
        depth: number of transformer encoder layers.
        num_heads: attention heads per layer (must divide ``emb_size``).
        output_channels: channels of the predicted map.
        image_size: side length of the square input and output (previously
            hard-coded to 504).

    Shapes: ``[B, in_channels, image_size, image_size]`` ->
    ``[B, output_channels, image_size, image_size]``.
    """

    def __init__(self, in_channels=1, patch_size=16, emb_size=768, depth=12, num_heads=12, output_channels=99, image_size=504):
        super().__init__()
        self.patch_size = patch_size
        self.output_channels = output_channels
        self.image_size = image_size
        # Patches per side: the stride-P conv produces floor(image_size / P) of them.
        self.grid_size = image_size // patch_size
        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size)
        self.pos_embed = nn.Parameter(torch.randn(1, self.grid_size ** 2, emb_size))
        # Tokens arrive as [B, N, E]. With the default batch_first=False the
        # encoder would treat B as the sequence axis and attend across samples.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=emb_size, nhead=num_heads, batch_first=True),
            num_layers=depth,
        )
        # Kept as a Sequential so parameter names (to_output_shape.0.*) are unchanged.
        self.to_output_shape = nn.Sequential(
            nn.Linear(emb_size, patch_size * patch_size * output_channels),
        )

    def forward(self, x):
        batch = x.shape[0]
        tokens = self.patch_embed(x) + self.pos_embed  # [B, N, E]
        tokens = self.transformer(tokens)
        patches = self.to_output_shape(tokens)  # [B, N, C*P*P], features ordered (c, ph, pw)
        g, p, c = self.grid_size, self.patch_size, self.output_channels
        # Put every tile back at its (row, col) grid position:
        # [B, gh, gw, C, ph, pw] -> [B, C, gh, ph, gw, pw] -> [B, C, gh*ph, gw*pw]
        out = (
            patches.reshape(batch, g, g, c, p, p)
            .permute(0, 3, 1, 4, 2, 5)
            .reshape(batch, c, g * p, g * p)
        )
        if g * p != self.image_size:
            # e.g. (504 // 16) * 16 = 496: resample to the full input resolution.
            out = F.interpolate(
                out,
                size=(self.image_size, self.image_size),
                mode='bilinear',
                align_corners=False,
                antialias=True,
            )
        return out

In [3]:
class VisionTransformer(nn.Module):
    """Transformer encoder that regresses a full-resolution multi-channel map from one pooled vector.

    Pipeline: patch-embed -> prepend CLS token -> add learned positions ->
    transformer encoder -> mean over all tokens (CLS included) -> one Linear
    emitting a half-resolution map -> bilinear upsampling to ``image_size``.

    Args:
        image_size: side length of the square input and of the output map.
        patch_size: side length of the square patches.
        output_channels: channels of the predicted map.
        d_model: token width.
        nhead: attention heads (must divide ``d_model``).
        num_layers: number of encoder layers.
        in_channels: input image channels (previously hard-coded to 1).

    Note: ``decoder`` maps ``d_model`` -> ``output_channels * (image_size // 2) ** 2``.
    With the defaults that is ~805M weights (~3.2 GB in float32), which
    dominates the parameter count.
    """

    def __init__(self, image_size=504, patch_size=8, output_channels=99, d_model=128, nhead=8, num_layers=2, in_channels=1):
        super().__init__()
        self.image_size = image_size
        self.output_channels = output_channels
        # Side of the map produced by the decoder before upsampling.
        self.decoded_size = image_size // 2
        # [B, C, H, W] -> [B, d_model, N]; transposed to [B, N, d_model] in forward,
        # the same token layout PatchEmbedding uses.
        self.patch_embed = nn.Sequential(
            nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size),
            nn.Flatten(start_dim=2),
        )
        num_patches = (image_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))

        # batch_first=True lets tokens stay [B, N, d_model]; no per-call axis swap needed.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=d_model * 2, batch_first=True),
            num_layers=num_layers,
        )
        self.decoder = nn.Linear(d_model, output_channels * self.decoded_size * self.decoded_size)

    def forward(self, x):
        batch, _, _, _ = x.shape
        tokens = self.patch_embed(x).transpose(1, 2)  # [B, N, d_model]
        cls_tokens = self.cls_token.expand(batch, -1, -1)  # [B, 1, d_model]
        tokens = torch.cat((cls_tokens, tokens), dim=1) + self.pos_embed  # [B, N + 1, d_model]
        tokens = self.transformer(tokens)
        pooled = tokens.mean(dim=1)  # aggregate all tokens, CLS included
        out = self.decoder(pooled).view(batch, self.output_channels, self.decoded_size, self.decoded_size)
        return F.interpolate(out, size=(self.image_size, self.image_size), mode='bilinear', align_corners=False)

In [4]:
processed_data_path = "/data.h5"

# ---- reproducibility: seed every RNG used downstream ----
seed_value = 1
random.seed(seed_value)        # train/test split below
np.random.seed(seed_value)
torch.manual_seed(seed_value)  # torch default generator: later model init and DataLoader shuffling

# Create the custom dataset
custom_dataset = CustomDataset(hdf5_file_path=processed_data_path)

# ---- random train/test split ----
test_set_size = 20
# An empty training set would make len(train_loader) == 0 and the epoch average divide by zero.
assert len(custom_dataset) > test_set_size, (
    f"need more than {test_set_size} samples to leave a training set, got {len(custom_dataset)}"
)

all_indices = list(range(len(custom_dataset)))
random.shuffle(all_indices)

test_indices = all_indices[:test_set_size]
train_indices = all_indices[test_set_size:]

train_dataset = Subset(custom_dataset, train_indices)
test_dataset = Subset(custom_dataset, test_indices)

# num_workers=0: samples are loaded in the main process.
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0, persistent_workers=False)
test_loaders = DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=0, persistent_workers=False)


In [None]:
# Build the model directly on the selected device and report its size.
model = VisionTransformer().to(device)

n_params = count_model_params(model)
# flush=True so the count appears before the (long) training output.
print(f'\nYour model has {n_params} parameters.', flush=True)

# Training objective: neuralop H1Loss with d=2; evaluation metric: plain MSE.
h1loss = H1Loss(d=2)
test_loss = nn.MSELoss()

# Adam with L2 weight decay. The plateau scheduler multiplies the LR by 0.8
# once the monitored loss has not improved for more than `patience` steps.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.8,
    patience=2,
    verbose=True,
)



In [None]:
train_losses = []       # mean training loss per epoch (average of per-batch losses)
test_losses = []        # per-element test MSE per epoch
epoch_durations = []    # wall-clock seconds per epoch (train + eval)

num_epochs = 300
print_frequency = 40            # log every N training batches
save_checkpoint_interval = 300  # save weights every N epochs
checkpoint_dir = "/result/"
# torch.save / np.save below fail if the target directory does not exist.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    start_time = time.time()

    # ---- training ----
    model.train()
    epoch_train_loss = 0.0
    for batch_idx, sample in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = sample['x'].to(device), sample['y'].to(device)
        output = model(data)
        loss = h1loss(output, target).mean()
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        epoch_train_loss += batch_loss

        if batch_idx % print_frequency == 0:
            print(f'Train Epoch: {epoch+1}/{num_epochs}, Batch: {batch_idx}/{len(train_loader)}, Loss: {batch_loss}')

    avg_train_loss = epoch_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---- evaluation ----
    model.eval()
    epoch_test_loss = 0.0
    n_test_samples = 0
    with torch.no_grad():
        for sample in test_loaders:
            data, target = sample['x'].to(device), sample['y'].to(device)
            output = model(data)
            # MSELoss is a mean over the batch; weight it by the batch size so a
            # short final batch does not skew the epoch average.
            n_in_batch = data.shape[0]
            epoch_test_loss += test_loss(output, target).item() * n_in_batch
            n_test_samples += n_in_batch

    avg_test_loss = epoch_test_loss / n_test_samples
    test_losses.append(avg_test_loss)

    epoch_duration = time.time() - start_time
    epoch_durations.append(epoch_duration)

    # LR decisions use the training loss; the test split stays untouched by training choices.
    scheduler.step(avg_train_loss)

    print(f'Epoch: {epoch+1}/{num_epochs}, Duration: {epoch_duration:.2f}s, Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}')

    # save checkpoint
    if (epoch + 1) % save_checkpoint_interval == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch+1}: {checkpoint_path}")

# average time of epoch duration
avg_epoch_duration = np.mean(epoch_durations)
print(f'Average Epoch Duration: {avg_epoch_duration:.2f}s')

# save loss curves
np.save(os.path.join(checkpoint_dir, 'train_losses.npy'), np.array(train_losses))
np.save(os.path.join(checkpoint_dir, 'test_losses.npy'), np.array(test_losses))

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label='Training Loss', color='blue')
ax.set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
ax.legend()
plt.show()

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(test_losses, label='Testing Loss', color='red')
ax.set(xlabel='Epoch', ylabel='Loss', title='Testing Loss')
ax.legend()
plt.show()

In [None]:
from sklearn.metrics import mean_squared_error, r2_score

# Per-sample metrics on the held-out test set, gathered in a single inference pass.
sample_r2 = []               # R^2 of each flattened sample
sample_relative_errors = []  # ||y - y_hat||_2 / ||y||_2 (inf if y is all zeros)
sample_mses = []             # mean squared error of each sample
sample_absolute_errors = []  # mean absolute error of each sample

# The encoder layers use dropout: evaluate in eval mode explicitly rather than
# relying on the training cell having left the model there.
model.eval()
with torch.no_grad():
    for sample in test_loaders:
        x = sample['x'].to(device)
        y_true_batch = sample['y'].numpy()
        y_pred_batch = model(x).cpu().numpy()

        for true_sample, pred_sample in zip(y_true_batch, y_pred_batch):
            y_true = true_sample.flatten()
            y_pred = pred_sample.flatten()
            residual = y_true - y_pred
            sample_r2.append(r2_score(y_true, y_pred))
            sample_relative_errors.append(np.linalg.norm(residual, 2) / np.linalg.norm(y_true, 2))
            sample_mses.append(mean_squared_error(y_true, y_pred))
            sample_absolute_errors.append(np.mean(np.abs(residual)))

# Overall Mean Squared Error
overall_mse = np.mean(sample_mses)
print(f'Overall Mean Squared Error: {overall_mse}')

# Best-predicted test sample (lowest MSE) and its position in test-loader order.
min_sample_mse = min(sample_mses)
min_index = np.argmin(sample_mses)

# Mean Absolute Error (MAE)
mae = np.mean(sample_absolute_errors)
print(f"Mean Absolute Error (MAE): {mae}")

# Overall Relative Error
overall_relative_error = np.mean(sample_relative_errors)
print(f'Overall Relative Error: {overall_relative_error}')

# Overall R-squared
overall_r2 = np.mean(sample_r2)
print(f'Overall R-squared: {overall_r2}')