# Load / Test Outputs of `Forward Simulation.py`

In [None]:
import sys
import os

# Put the directory three levels above the current working directory on
# sys.path (presumably so the project-local imports below can be resolved).
notebook_dir = os.getcwd()
project_root = os.path.abspath(
    os.path.join(notebook_dir, os.pardir, os.pardir, os.pardir)
)
# Only append once, so re-running the cell does not grow sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)



import random
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import numpy as np
import torch

from Components import data_loader


### Load Data from root directory
Load all data files in a given root folder

In [None]:

# Folder holding the demo Forward Simulation output files.
root_dir = "../../Debug Data/Demo Data/Forward Simulation Outputs"

# Build the dataset from the folder given via the root_dir keyword.
dataset = data_loader.SimDataset(root_dir=root_dir)
dataset_len = len(dataset)
print("Number of Output fields:", dataset_len)

# Pick one entry at random (both bounds inclusive) and show it.
sample_index = random.randint(0, dataset_len - 1)
data_sample = dataset[sample_index]
print(data_sample)

### Load from list of files
Loads data files given their paths as a list

In [None]:
# Explicit paths of the individual output files to load.
file_list = [
    "../../Debug Data/Demo Data/Forward Simulation Outputs/Sample_Data_000.pt",
    "../../Debug Data/Demo Data/Forward Simulation Outputs/Sample_Data_002.pt",
    "../../Debug Data/Demo Data/Forward Simulation Outputs/Sample_Data_004.pt",
]

# Build the dataset from the path list given via the file_list keyword.
dataset = data_loader.SimDataset(file_list=file_list)
dataset_len = len(dataset)
print("Number of Output fields:", dataset_len)

# Pick one entry at random (both bounds inclusive) and show it.
sample_index = random.randint(0, dataset_len - 1)
data_sample = dataset[sample_index]
print(data_sample)

### Load Data from CSV file

In [None]:
# CSV file passed to SimDataset via the csv_file keyword.
csv_file = "../../Debug Data/Demo Data/Forward Simulation Outputs/absolute_sample_data.csv"
# NOTE(review): the helper below is not imported in the visible notebook cells;
# import it before uncommenting this line.
#create_csv.from_root_dir(root_dir, csv_file)

dataset = data_loader.SimDataset(csv_file=csv_file)
dataset_len = len(dataset)
print("Number of Output fields:", dataset_len)

# Pick one entry at random (both bounds inclusive) and show it.
sample_index = random.randint(0, dataset_len - 1)
data_sample = dataset[sample_index]
print(data_sample)

### Access Data
What data can you extract from our Forward Simulation Outputs, and how?

In [None]:
# A single-file dataset is enough to walk through the available attributes.
file_list = ["../../Debug Data/Demo Data/Forward Simulation Outputs/Sample_Data_000.pt"]
dataset = data_loader.SimDataset(file_list=file_list)

# First (and only) entry of the dataset.
data_sample = dataset[0]

# Each print names a group of attributes; the assignments that follow it show
# how those attributes are read from the sample.
print(" data_sample data:")

print("  - Simulator Output Field")
field = data_sample.field

print("  - Amplitude/Phase data")
amp, phase = data_sample.amp, data_sample.phase

print("  - Simulation Space Spatial resolution")
spatial_resolution = data_sample.spatial_resolution

print("  - Simulation Space Grid Shape")
grid_shape = data_sample.grid_shape

print("  - Position/Offset + their unit")
position, offset, unit = data_sample.position, data_sample.offset, data_sample.unit

print("  - Axis/Angle (degree)")
axis, angle = data_sample.axis, data_sample.angle

print("  - Post-Processing transform functions")
transforms = data_sample.transforms
    

### Concatenate Datasets
Combine multiple different datasets into one

In [None]:
  
# First dataset: samples 000-002.
file_list1 = [
    "../../Debug Data/Demo Data/Forward Simulation Outputs/Sample_Data_000.pt",
    "../../Debug Data/Demo Data/Forward Simulation Outputs/Sample_Data_001.pt",
    "../../Debug Data/Demo Data/Forward Simulation Outputs/Sample_Data_002.pt",
]
dataset1 = data_loader.SimDataset(file_list=file_list1)

# Second dataset: samples 003-004.
# (The folder name is "Demo Data", matching every other path in this notebook.)
file_list2 = [
    "../../Debug Data/Demo Data/Forward Simulation Outputs/Sample_Data_003.pt",
    "../../Debug Data/Demo Data/Forward Simulation Outputs/Sample_Data_004.pt",
]
dataset2 = data_loader.SimDataset(file_list=file_list2)

# Combine both datasets with the `+` operator.
dataset = dataset1 + dataset2

dataset_len = len(dataset)

print("Number of Output fields:", dataset_len)

print(dataset)

### Simple Visualization

In [None]:
# Load every output file found under the demo folder.
root_dir = "../../Debug Data/Demo Data/Forward Simulation Outputs"
dataset = data_loader.SimDataset(root_dir=root_dir)
dataset_len = len(dataset)

# Choose a random sample to visualize (both bounds inclusive).
sample_index = random.randint(0, dataset_len - 1)
data_sample = dataset[sample_index]

# Convert amplitude and phase to NumPy arrays for plotting:
# detach from autograd, move to CPU, then convert.
amp = data_sample.amp.detach().cpu().numpy()
phase = data_sample.phase.detach().cpu().numpy()


def base_plot(image, support, title, units, extent=None, cmap="jet", vmin=None, vmax=None, grid=False):
    """Plot a single 2D image with a colorbar and return the figure and axes.

    Args:
        image: 2D array; it is transposed (``image.T``) before being shown
            with ``origin="lower"``.
        support: Sequence whose first two entries give the size of the image
            along its two axes. Used for the default ``extent`` and for the
            tick range when ``grid`` is enabled.
        title: Title text of the axes.
        units: Divisor applied to ``support`` to convert it to display units.
        extent: Optional ``[left, right, bottom, top]`` passed to ``imshow``.
            Defaults to ``[0, support[0] / units, 0, support[1] / units]``.
        cmap: Colormap passed to ``imshow``.
        vmin: Lower color limit passed to ``imshow``.
        vmax: Upper color limit passed to ``imshow``.
        grid: If true, place ticks every 5 display units and draw dashed
            grid lines.

    Returns:
        The ``(fig, ax)`` pair of the newly created figure.
    """
    fig, ax = plt.subplots(figsize=(6.4, 4.8), dpi=150)

    if extent is None:
        extent = [0, support[0] / units, 0, support[1] / units]

    im = ax.imshow(
        image.T,
        cmap=cmap,
        extent=extent,
        origin="lower",
        vmin=vmin,
        vmax=vmax,
    )

    # Colorbar: name the parent axes explicitly and label ticks with up to
    # 7 significant digits.
    cbar = fig.colorbar(im, ax=ax)
    cbar.formatter = ticker.FuncFormatter(lambda x, _: f"{x:.7}")
    cbar.update_ticks()

    # Plot Title
    ax.set_title(title)

    # Grid Lines: use the Axes methods rather than pyplot's "current axes"
    # so the function never touches a figure other than the one it created.
    if grid:
        ax.set_xticks(range(0, int(support[0] / units) + 1, 5))
        ax.set_yticks(range(0, int(support[1] / units) + 1, 5))
        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

    return fig, ax


# Size of the simulation space along each axis: resolution * number of grid
# points. Both values are read from the sample being plotted, so the result
# does not depend on a `spatial_resolution` variable left over from another cell.
spatial_support = [
    resolution * n_points
    for resolution, n_points in zip(data_sample.spatial_resolution, data_sample.grid_shape)
]

# NOTE(review): the amplitude plot scales by data_sample.unit while the phase
# plot uses a hard-coded 1e-6 — confirm that both are intended.
fig, ax = base_plot(amp, spatial_support, "Amp", units=data_sample.unit, cmap="gray")
fig, ax = base_plot(phase, spatial_support, "Phase", units=1e-6, cmap="gray")