In [None]:
import numpy as np
import matplotlib.pyplot as plt 
import h5py
from mayavi import mlab
import os

# torch stuff
import torch
from torch.utils.data import Dataset, DataLoader

# for visualisation loop:
from matplotlib.animation import ArtistAnimation
from matplotlib import colors
from IPython.display import HTML

# Global matplotlib defaults for every image drawn below: put the [0, 0] array
# element in the lower-left corner and use the viridis colormap.
plt.rcParams["image.origin"] = "lower" 
plt.rcParams["image.cmap"] = "viridis"

# import pyvista as pv
# mlab.init_notebook()

from mpl_toolkits.mplot3d import Axes3D

# Build the simulation output directory from the current working directory by
# substituting the "Summer-Sandbox23/ptpg" part of the path with the gevolution
# output folder; the Newtonian and GR runs live in sub-directories of it.
# NOTE(review): if the working directory does not contain "Summer-Sandbox23/ptpg"
# (different checkout location, backslash separators), str.replace changes nothing
# and the paths below silently end up as "<cwd>newton/" and "<cwd>gr/" — confirm
# the working directory before running.
DataPath = os.path.abspath("").replace("Summer-Sandbox23/ptpg", "NbodySimulation/gevolution-1.2/output/")
newtonPath = DataPath+"newton/"
grPath = DataPath+"gr/"


In [None]:
# Inventory of what each simulation run wrote to its output directory
newtonFiles, grFiles = os.listdir(newtonPath), os.listdir(grPath)

print(newtonFiles, grFiles, sep="\n")



In [None]:
# Extract datasets with .h5 formats
def Extracth5Specifics(filename:str) -> np.ndarray:
    """Read the complete "data" dataset of one HDF5 file into memory.

    Args:
        filename: Path of the HDF5 file to read.

    Returns:
        The content of the file's "data" dataset, read in full via ``[()]``.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the file has no "data" dataset.
    """
    # The context manager closes the file handle even when the lookup or the
    # read fails, so an exception no longer leaks an open file.
    with h5py.File(filename, "r") as h5File:
        return h5File["data"][()]

def Extracth5Data(abspath:str) -> "np.ndarray | list[np.ndarray]":
    """Load the "data" dataset of every ``.h5`` file found in a directory.

    Args:
        abspath: Directory to scan, with or without a trailing path separator.

    Returns:
        A single array when the directory holds exactly one ``.h5`` file,
        otherwise a list of arrays ordered by file name.

    Raises:
        FileNotFoundError: If the directory contains no ``.h5`` file.
    """
    # os.listdir returns entries in arbitrary order; sorting keeps the order of
    # the returned datasets reproducible between runs and machines.
    # endswith() avoids picking up names that merely contain ".h5" (e.g. "x.h5.bak"),
    # and os.path.join makes the trailing separator on abspath optional.
    h5Files = sorted(
        os.path.join(abspath, name)
        for name in os.listdir(abspath)
        if name.endswith(".h5")
    )
    if not h5Files:
        raise FileNotFoundError(f"No .h5 files found in {abspath!r}")
    if len(h5Files) == 1:
        return Extracth5Specifics(h5Files[0])
    return [Extracth5Specifics(filename) for filename in h5Files]

# Read both simulation outputs into memory and report their array shapes
newtonCube = Extracth5Data(newtonPath)
grCube = Extracth5Data(grPath)
print(
    f"Dimension of newton cube: {np.shape(newtonCube)}",
    f"Dimension of gr cube: {np.shape(grCube)}",
    sep="\n",
)


In [None]:
def SliceDataCube(data:np.ndarray, axis:int, index:"int | list | tuple") -> np.ndarray:
    """Select a single position or a contiguous range along one axis of an array.

    Args:
        data: Array to slice; all other axes are kept in full.
        axis: Axis along which to select.
        index: Either an integer position (Python or NumPy integer), which
            removes ``axis`` from the result, or a ``(start, stop)`` pair,
            which keeps ``axis`` with the half-open range ``start:stop``.

    Returns:
        The selected part of ``data`` (basic NumPy indexing, not a copy).
    """
    slices = [slice(None)] * data.ndim
    # Accept NumPy integers too (e.g. values coming from np.arange); they are
    # not instances of int and would otherwise be treated as a (start, stop) pair.
    if isinstance(index, (int, np.integer)):
        slices[axis] = index
    else:
        slices[axis] = slice(index[0], index[1])
    return data[tuple(slices)]

In [None]:
# Sanity check: a single slice along axis 1 should drop that axis
print(SliceDataCube(newtonCube, axis=1, index=1).shape)

In [None]:
def VisualiseSlice(cube:np.ndarray, name:str=None, axis:int=None, difference:bool=False):
    """Build an animation that steps through the slices of a cube along one axis.

    Each frame is colour-scaled independently between the 1st and 99th
    percentile of its own slice.

    Args:
        cube: Array to visualise.
        name: Label inserted into the figure title (omitted when ``None``).
        axis: Axis to step along; defaults to 0 when ``None``.
        difference: When ``False``, frames use a colour scale centred on zero
            whenever the slice's percentile range straddles zero. When
            ``True``, frames always use a plain linear scale.

    Returns:
        The ``ArtistAnimation`` holding one image per slice. The figure is
        closed before returning, so nothing is drawn until the animation is
        rendered by the caller.
    """
    fig, ax = plt.subplots(figsize=(7,7))
    axis = 0 if axis is None else axis
    ims = []
    for i in range(cube.shape[axis]):
        image = SliceDataCube(cube, axis, i)
        temp_min, temp_max = np.percentile(image, [1, 99])
        # TwoSlopeNorm raises ValueError unless vmin < vcenter < vmax, so only
        # centre on zero when the slice really has values on both sides of it;
        # otherwise fall back to the linear scale instead of crashing.
        if not difference and temp_min < 0 < temp_max:
            temp_norm = colors.TwoSlopeNorm(vmin=temp_min, vcenter=0, vmax=temp_max)
            frame = ax.imshow(image, norm=temp_norm)
        else:
            frame = ax.imshow(image, vmin=temp_min, vmax=temp_max)
        ims.append([frame])
    ax.set_title(f"Visualising {name if name is not None else ''} across axis: {axis}")
    anim = ArtistAnimation(fig, ims, interval=50, blit=True)
    # Close the figure so the notebook does not also show a static copy of it
    plt.close(fig)
    return anim

def VisualiseDifference(cube1:np.ndarray, cube2:np.ndarray, axis:int=None, absolute:bool=True):
    """Animate the element-wise difference ``cube1 - cube2`` slice by slice.

    Args:
        cube1: Minuend array.
        cube2: Subtrahend array.
        axis: Axis to step along, forwarded to ``VisualiseSlice``.
        absolute: When ``True`` the absolute value of the difference is shown,
            otherwise the signed difference.

    Returns:
        Whatever ``VisualiseSlice`` returns for the difference array.
    """
    residual = cube1 - cube2
    if absolute:
        residual = np.abs(residual)
    return VisualiseSlice(residual, name="difference", axis=axis, difference=True)


In [None]:
# Animate the GR cube slice by slice along axis 0 and embed it as an interactive JS player.
# NOTE(review): the variable is called newtonAnim but it holds the animation of grCube.
newtonAnim = VisualiseSlice(grCube, name="gr", axis=0)
HTML(newtonAnim.to_jshtml())


In [None]:
# Animate the absolute difference |newtonCube - grCube| (absolute=True is the default)
# along axis 0 and embed it as an interactive JS player.
diffAnim = VisualiseDifference(newtonCube, grCube, axis=0)
HTML(diffAnim.to_jshtml())

In [None]:
# Custom class

class TestCubes(Dataset):
    """Map-style dataset of labelled slices cut from a Newtonian and a GR data cube.

    Items are ordered as: all slices of ``newtonCube`` (perpendicular to axis 0
    first, then axis 1, and so on), followed by all slices of ``grCube`` in the
    same order. Along each axis only every ``stride``-th position is used. The
    cubes may have different extents along different axes.

    Args:
        newtonCube: Array exposing ``shape``/``ndim`` and NumPy-style indexing.
        grCube: Array exposing ``shape``/``ndim`` and NumPy-style indexing.
        stride: Step between consecutive slice positions along an axis (>= 1).

    Attributes:
        length: Total number of slices at construction time.
        halflength: ``length`` halved, rounded down.
        stride: The stride passed to the constructor.

    Raises:
        ValueError: If ``stride`` is smaller than 1.
    """

    def __init__(self, newtonCube, grCube, stride=1):
        if stride < 1:
            raise ValueError(f"stride must be a positive integer, got {stride}")
        self.newtonCube = newtonCube
        self.grCube = grCube
        # stride has to be set before __len__ is called: the length depends on it
        self.stride = stride
        self.length = self.__len__()
        self.halflength = int(self.length/2.)

    def _slicesPerAxis(self, cube):
        """Return, for each axis of ``cube``, how many strided slices it yields."""
        return [len(range(0, extent, self.stride)) for extent in cube.shape]

    def __len__(self):
        """Return the total number of slices available from both cubes."""
        return sum(self._slicesPerAxis(self.newtonCube)) + sum(self._slicesPerAxis(self.grCube))

    def getSlice(self, data:np.ndarray, axis:int, index:"int | list | tuple"):
        """Select one position or a ``(start, stop)`` range along ``axis`` of ``data``.

        An integer ``index`` (Python or NumPy integer) removes ``axis`` from the
        result; a two-element ``index`` keeps it with the range ``start:stop``.
        """
        slices = [slice(None)] * data.ndim
        if isinstance(index, (int, np.integer)):
            slices[axis] = index
        else:
            slices[axis] = slice(index[0], index[1])
        return data[tuple(slices)]

    def __getitem__(self, idx):
        """Return ``(tensor, label)`` for item ``idx``.

        ``label`` is the string "newton" or "gr". Negative indices count from
        the end.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        idx = int(idx)
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"index {idx} is out of range for a dataset of length {total}")

        newtonCounts = self._slicesPerAxis(self.newtonCube)
        newtonTotal = sum(newtonCounts)
        if idx < newtonTotal:
            cube, counts, label = self.newtonCube, newtonCounts, "newton"
        else:
            # Re-base the index so it counts from the first GR slice
            idx -= newtonTotal
            cube, counts, label = self.grCube, self._slicesPerAxis(self.grCube), "gr"

        # Walk through the axes until the remaining index falls inside one;
        # the bounds check above guarantees that the loop ends with a break.
        for axis, count in enumerate(counts):
            if idx < count:
                break
            idx -= count

        slice_data = self.getSlice(cube, axis=axis, index=idx * self.stride)
        return torch.tensor(slice_data), label



In [None]:
# batch_size = 1
# train_dataloader = DataLoader()