In [None]:
import numpy as np
import matplotlib.pyplot as plt 
import h5py
from mayavi import mlab
import os
import matplotlib.cm as cm

# torch stuff
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as ts
# Dedicated, explicitly seeded RNG. It is handed to random_split further down
# so that the train/val/test partition does not depend on the global torch RNG.
generator1 = torch.Generator().manual_seed(42)

# for visualisation loop:
from matplotlib import colors
from matplotlib.animation import ArtistAnimation
from IPython.display import HTML

# Notebook-wide imshow defaults: draw array row 0 at the bottom of the image
# and use the viridis colour map unless a call overrides it.
plt.rcParams["image.origin"] = "lower"
plt.rcParams["image.cmap"] = "viridis"

# import pyvista as pv
# mlab.init_notebook()

from mpl_toolkits.mplot3d import Axes3D

# Build the simulation output directory from the current working directory by
# swapping the notebook's repository sub-path for the gevolution output path.
# NOTE(review): str.replace is a silent no-op when the working directory does
# not contain "Summer-Sandbox23/ptpg"; DataPath is then just the cwd and the two
# paths below are formed without a separator -- confirm the notebook is always
# launched from that folder.
DataPath = os.path.abspath("").replace("Summer-Sandbox23/ptpg", "NbodySimulation/gevolution-1.2/output/")
# Sub-directories holding the Newtonian and the GR run respectively.
newtonPath = DataPath+"newton/"
grPath = DataPath+"gr/"


In [None]:
# List every entry (not only .h5 files) in the two simulation output
# directories, to check that the paths resolve and see what is available.
newtonFiles = os.listdir(newtonPath)
grFiles = os.listdir(grPath)

print(newtonFiles)
print(grFiles)



In [None]:
# Extract datasets with .h5 formats
def Extracth5Specifics(filename:str) -> np.ndarray:
    """Read the ``"data"`` dataset of one HDF5 file into memory.

    Parameters
    ----------
    filename : str
        Path of the HDF5 file; it is opened read-only.

    Returns
    -------
    np.ndarray
        Everything stored under the key ``"data"``.

    Raises
    ------
    KeyError
        Propagated from h5py when the file has no ``"data"`` entry.
    """
    # The context manager releases the file handle even when the lookup or
    # the read below raises, which a manual close() call would not.
    with h5py.File(filename, "r") as h5File:
        # Indexing with the empty tuple reads the whole dataset out of the file.
        return h5File["data"][()]

def Extracth5Data(abspath:str) -> "np.ndarray | list[np.ndarray]":
    """Load the ``"data"`` array of every ``.h5`` file directly inside a directory.

    Parameters
    ----------
    abspath : str
        Directory to scan (not recursive). A trailing path separator is
        optional.

    Returns
    -------
    np.ndarray or list of np.ndarray
        The array itself when exactly one ``.h5`` file is present; otherwise a
        list with one array per file, ordered by file name.

    Raises
    ------
    FileNotFoundError
        If the directory contains no file whose name ends in ``.h5``.
    """
    # os.listdir gives entries in arbitrary order, so sort to make the order
    # of the returned datasets reproducible. endswith() avoids picking up
    # names that merely contain ".h5" (for example "snapshot.h5.bak").
    h5Files = sorted(
        os.path.join(abspath, name)
        for name in os.listdir(abspath)
        if name.endswith(".h5")
    )
    if not h5Files:
        raise FileNotFoundError(f"No .h5 files found in {abspath!r}")
    if len(h5Files) == 1:
        return Extracth5Specifics(h5Files[0])
    return [Extracth5Specifics(filename) for filename in h5Files]

# Load the Newtonian and the GR simulation output. Each call returns a single
# array, or a list of arrays when the directory holds more than one .h5 file.
newtonCube = Extracth5Data(newtonPath)
grCube = Extracth5Data(grPath)
# NOTE(review): this scales the GR data to 1,000,001 times its loaded value,
# and every later cell (difference plot, dataset, training) sees the scaled
# cube. Presumably a temporary sanity check that makes the two classes trivially
# separable -- confirm and remove before interpreting any result. It also only
# makes sense when grCube is a single array; on a list, `*` repeats the list.
grCube = grCube + 1000000 * grCube
print(f"Dimension of newton cube: {np.shape(newtonCube)}")
print(f"Dimension of gr cube: {np.shape(grCube)}")


In [None]:
def SliceDataCube(data:np.ndarray, axis:int, index:"int | list | tuple") -> np.ndarray:
    """Cut a slice, or a contiguous range of slices, out of an N-d array.

    Parameters
    ----------
    data : np.ndarray
        Array to slice; any number of dimensions.
    axis : int
        Axis along which to select.
    index : int or (start, stop) pair
        An integer (Python or NumPy) picks one position and drops ``axis``
        from the result. A two-element list/tuple selects ``start:stop`` and
        keeps ``axis``.

    Returns
    -------
    np.ndarray
        The result of basic indexing on ``data`` with full slices on every
        other axis.
    """
    slices = [slice(None)] * data.ndim
    # NumPy integers (e.g. values coming out of np.arange or an index array)
    # are not instances of int, so they have to be accepted explicitly;
    # otherwise they would fall through to the range branch and fail.
    if isinstance(index, (int, np.integer)):
        slices[axis] = index
    else:
        slices[axis] = slice(index[0], index[1])
    return data[tuple(slices)]

In [None]:
# Quick check: picking a single position along axis 1 should drop that axis
# from the shape of the result.
print(np.shape(SliceDataCube(newtonCube, axis=1, index=1)))

In [None]:
def VisualiseSlice(cube:np.ndarray, name:str=None, axis:int=None, difference:bool=False):
    """Build an animation that steps through the slices of a data cube.

    All frames share one colour scale, spanning the 1st to the 99th percentile
    of the whole cube, so the colour bar is valid for every frame and frames
    can be compared with each other.

    Parameters
    ----------
    cube : np.ndarray
        Data to animate; each slice along ``axis`` must be a 2-D image.
    name : str, optional
        Label inserted into the figure title.
    axis : int, optional
        Axis to step along; defaults to 0.
    difference : bool, optional
        Accepted for compatibility with existing callers; currently unused.

    Returns
    -------
    matplotlib.animation.ArtistAnimation
        One frame per slice, 50 ms apart. The figure is closed before
        returning so it is not also rendered as a static image.

    Raises
    ------
    ValueError
        If ``cube`` has no slices along ``axis``.
    """
    axis = axis if axis is not None else 0
    n_frames = cube.shape[axis]
    if n_frames == 0:
        raise ValueError(f"cube has no slices along axis {axis}")

    # A single normalisation for the whole animation. Scaling every frame
    # separately would leave the colour bar correct for one frame only.
    vmin, vmax = np.percentile(cube, [1, 99])
    norm = colors.Normalize(vmin=vmin, vmax=vmax)

    fig, ax = plt.subplots(figsize=(7.5,7.5))
    # ArtistAnimation expects, per frame, the list of artists to show.
    ims = [[ax.imshow(SliceDataCube(cube, axis, i), norm=norm)] for i in range(n_frames)]
    ax.set_title(f"Visualising {name if name is not None else ''} across axis: {axis}")
    # Every image uses the same norm and colour map, so the first one can
    # serve as the mappable for the colour bar.
    fig.colorbar(ims[0][0], ax=ax)
    fig.tight_layout()
    anim = ArtistAnimation(fig, ims, interval=50, blit=True)
    plt.close(fig)
    return anim

def VisualiseDifference(cube1:np.ndarray, cube2:np.ndarray, axis:int=None, absolute:bool=True):
    """Animate the element-wise difference between two cubes.

    Computes ``cube1 - cube2`` (its absolute value when ``absolute`` is true)
    and hands the result to ``VisualiseSlice`` with the title label
    "difference", forwarding ``axis`` unchanged.

    Returns whatever ``VisualiseSlice`` returns.
    """
    delta = cube1 - cube2
    if absolute:
        delta = np.abs(delta)
    return VisualiseSlice(delta, name="difference", axis=axis, difference=True)


In [None]:
# Animate the Newtonian cube along axis 2. The title label has to say
# "newton": this cell plots newtonCube, not the GR data.
newtonAnim = VisualiseSlice(newtonCube, name="newton", axis=2)
HTML(newtonAnim.to_jshtml())


In [None]:
# Animate |newtonCube - grCube| along axis 0 (absolute=True is the default).
# NOTE(review): grCube was multiplied by 1,000,001 in the loading cell, so
# this shows the difference to the rescaled GR data -- confirm that is intended.
diffAnim = VisualiseDifference(newtonCube, grCube, axis=0)
HTML(diffAnim.to_jshtml())

In [None]:
# Custom class

class TestCubes(Dataset):
    """Binary-classification dataset of 2-D slices taken from two data cubes.

    Sample indices ``0 .. halflength-1`` address the Newtonian cube (label
    0.0); the remaining indices address the GR cube (label 1.0). Inside each
    half, position ``k`` maps to ``axis = k // n`` and ``index = k % n`` with
    ``n = newtonCube.shape[0]``: all slices along axis 0 come first, then
    those along axis 1, and so on.

    Known limitations, kept from the original design: the index mapping
    assumes both cubes are cubic and have the same shape, and ``stride`` only
    shortens the reported length -- it does not change which slice an index
    refers to.

    Parameters
    ----------
    newtonCube, grCube : np.ndarray
        The two data cubes (anything with ``shape``/``ndim`` and NumPy-style
        indexing works).
    stride : int, optional
        Divisor applied to the dataset length. For ``stride != 1`` images are
        returned without the leading channel dimension.
    transform : object, optional
        Any truthy value turns on per-image standardisation (subtract the
        mean, divide by the standard deviation). The object itself is never
        called.
    additionalInfo : bool, optional
        When true, each sample also carries the ``"axis"`` and ``"index"`` it
        was cut from.
    """

    def __init__(self, newtonCube, grCube, stride=1, transform=None, additionalInfo=False):
        self.newtonCube = newtonCube
        self.grCube = grCube
        self.stride = stride
        self.length = len(self)
        # Boundary between the Newtonian half and the GR half of the index range.
        self.halflength = int(self.length/2.)
        self.transform = transform
        self.additionalInfo = additionalInfo

    def __len__(self):
        """Total number of axis-aligned slices in both cubes, divided by ``stride``."""
        newtonShape = self.newtonCube.shape
        grShape = self.grCube.shape
        newtonLength = 0
        grLength = 0
        # An N-d cube has shape[i] slices along axis i; add them up over all axes.
        for i in range(len(newtonShape)):
            newtonLength += newtonShape[i]
            grLength += grShape[i]
        return int((newtonLength+grLength)/self.stride)

    def _getSlice(self, data:np.ndarray, axis:int, index:"int | list | tuple"):
        """Return ``data`` at ``index`` along ``axis``.

        An integer (Python or NumPy) selects one position and drops the axis;
        a two-element sequence selects the range ``index[0]:index[1]``.
        """
        slices = [slice(None)] * data.ndim
        if isinstance(index, (int, np.integer)):
            slices[axis] = index
        else:
            slices[axis] = slice(index[0], index[1])
        return data[tuple(slices)]

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Parameters
        ----------
        idx : int
            Sample index. NumPy integers and 0-d tensors are accepted, and
            negative values count from the end.

        Returns
        -------
        dict
            ``"image"``: float32 tensor holding the slice (with a leading
            channel dimension when ``stride == 1``); ``"label"``: float32
            tensor of shape ``(1,)``, 0.0 for Newton and 1.0 for GR; plus
            ``"axis"`` and ``"index"`` when ``additionalInfo`` is set.

        Raises
        ------
        IndexError
            If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        # Normalise to a plain Python int first: NumPy integers and 0-d
        # tensors would otherwise leak into the slicing logic, where they are
        # not recognised as a single position.
        requested = int(idx)
        idx = requested + self.length if requested < 0 else requested
        if not 0 <= idx < self.length:
            raise IndexError(f"index {requested} is out of range for a dataset of length {self.length}")

        NEWTON = idx < self.halflength
        if not NEWTON:
            idx = idx-self.halflength
        # Only valid for cubic cubes and stride 1 (see the class docstring).
        axis = idx // self.newtonCube.shape[0]
        index = idx % self.newtonCube.shape[0]
        if NEWTON:
            slice_data = self._getSlice(self.newtonCube, axis, index)
            label = torch.tensor([0.0], dtype=torch.float32)
        else:
            slice_data = self._getSlice(self.grCube, axis, index)
            label = torch.tensor([1.0], dtype=torch.float32)

        image = torch.tensor(slice_data, dtype=torch.float32)
        if self.stride == 1:
            # Add the channel dimension that Conv2d layers expect.
            image = image.unsqueeze(0)
        sample = {"image": image, "label": label}

        if self.additionalInfo:
            sample["axis"] = axis
            sample["index"] = index
        if self.transform:
            # Per-image standardisation; the transform object only acts as a switch.
            sample["image"] = (sample["image"]-torch.mean(sample["image"]))/torch.std(sample["image"])

        return sample

    def __str__(self, idx=None):
        """Summarise cube shapes and stride; ``idx`` is accepted but unused."""
        returnString = "Dataset info:\n----------------------\n"
        returnString += f"  Newton cube size: {self.newtonCube.shape}\n"
        returnString += f"  GR cube size: {self.grCube.shape}\n"
        returnString += f"  Stride: {self.stride}\n"
        return returnString

    def printImage(self, idx):
        """Return a text report for sample ``idx``.

        The report lists every non-image entry of the sample followed by the
        mean, min, max, std and median of the image tensor. Despite the name,
        nothing is printed; the string is returned.
        """
        returnString = ""
        sample = self.__getitem__(idx)
        image = sample["image"]
        returnString += "Image info (Newton:0, GR:1):\n"
        for key, val in sample.items():
            if key != "image":
                returnString += f"  {key}: {val}\n"
        returnString += "Basic statistics:\n"
        basicStat = {
            "mean": torch.mean(image),
            "min": torch.min(image),
            "max": torch.max(image),
            "std": torch.std(image),
            "median": torch.median(image)
        }
        for key, val in basicStat.items():
            returnString += f"  {key}: {val}\n"
        returnString += "\n"
        return returnString



In [None]:
# Build the slice dataset without standardisation (transform=False) and with
# the axis/index of each slice attached to the sample.
L = TestCubes(newtonCube, grCube, transform=False, additionalInfo=True)
# Report sample 192 (= 64*3). NOTE(review): if both cubes are 64^3 this is the
# first GR sample (dataset length 384, half 192) -- confirm the cube size.
print(L.printImage(64*3))

In [None]:
# Rewind the generator to its original seed before splitting. Without this,
# every re-run of the cell continues from the advanced RNG state and yields a
# different train/val/test partition, which makes results depend on how often
# the cell was executed. On a fresh kernel the split is unchanged.
generator1.manual_seed(generator1.initial_seed())

# 70 % / 20 % / 10 % random partition of the slice dataset.
[train, val, test] = random_split(L, [0.7,0.2,0.1], generator=generator1)

# Mini-batches of 32 samples, served in the fixed (already randomised)
# order of each subset.
train_loader = DataLoader(train, batch_size=32)
val_loader = DataLoader(val, batch_size=32)
test_loader = DataLoader(test, batch_size=32)

In [None]:
# Convolutional Neural network

class SkeletonCNN(nn.Module):
    """Small CNN that maps a single-channel 64x64 image to one probability.

    Two convolutional stages are followed by three fully connected stages;
    the last one ends in a sigmoid, so the output has shape ``(batch, 1)``
    with values in [0, 1] and can be fed to ``nn.BCELoss``. The first linear
    layer is sized for 32*16*16 features, which fixes the spatial input size
    at 64x64.
    """

    def __init__(self):
        super().__init__()
        ### LAYER 1 (Convolutional) ### (1, 64, 64) -> (64, 64, 64)
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3,3), stride=1, padding=1),       # (1, 64, 64) -> (64, 64, 64)
            nn.ReLU(),                                                      # -
            nn.Dropout(0.25),                                               # -
        )

        ### LAYER 2 (Convolutional) ### (64, 64, 64) -> (32, 16, 16)
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=(3,3), stride=1, padding=1),      # (64, 64, 64) -> (32, 64, 64)
            nn.ReLU(),                                                      # -
            nn.MaxPool2d(kernel_size=(4,4)),                                # (32, 64, 64) -> (32, 16, 16)
        )

        ### LAYER 3 (Flatten + fully connected) ###   (32, 16, 16) -> (256)
        self.layer3 = nn.Sequential(
            nn.Flatten(),                                                   # (32, 16, 16) -> (8192)
            nn.Linear(32*16*16, 256),                                       # (8192) -> (256)
            nn.ReLU(),                                                      # -
            nn.Dropout(0.25),                                               # -
        )

        ### LAYER 4 (Fully connected) ###
        self.layer4 = nn.Sequential(
            nn.Linear(256, 16),                                             # (256) -> (16)
            nn.ReLU(),                                                      # -
            nn.Dropout(0.25),                                               # -
        )

        ### LAYER 5 (Output) ###
        self.output = nn.Sequential(
            nn.Linear(16, 1),                                               # (16) -> (1)
            nn.Sigmoid(),                                                   # probability in [0, 1]
        )

        # Execution order of the hidden stages. Each stage is already
        # registered as a sub-module through its own attribute above, so a
        # plain list is enough here.
        self.layers = [self.layer1, self.layer2, self.layer3, self.layer4]

    def forward(self, X):
        """Run ``X`` of shape (batch, 1, 64, 64) through the network.

        Returns a tensor of shape (batch, 1) holding sigmoid outputs.
        """
        for layer in self.layers:
            X = layer(X)
        return self.output(X)

    def printSummary(self, input:tuple=(1,64,64)):
        """Print a per-layer summary via ``torchsummary.summary``.

        Parameters
        ----------
        input : tuple, optional
            Shape of one input sample without the batch dimension.

        Returns
        -------
        Whatever ``summary`` returns.
        """
        # torchsummary assumes device="cuda" by default and then feeds the
        # model a CUDA tensor whenever a GPU is visible, which fails for
        # weights that live on the CPU. Tell it where the weights actually are.
        device = "cuda" if next(self.parameters()).is_cuda else "cpu"
        return summary(self, input_size=input, device=device)

from torchsummary import summary



In [None]:
# Fresh network with randomly initialised weights.
model = SkeletonCNN()
# Binary cross-entropy on probabilities; it matches the sigmoid output of the
# network and the float 0/1 labels produced by the dataset.
loss_fn = nn.BCELoss()
# Adam over all network parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Print the layer-by-layer summary for the default (1, 64, 64) input shape.
model.printSummary()

In [None]:


n_epochs = 5
for epoch in range(n_epochs):
    # --- Training pass ---
    # train() switches the Dropout layers on; it has to be set again every
    # epoch because the validation pass below turns them off.
    model.train()
    for batch in train_loader:
        inputs = batch["image"]
        labels = batch["label"]
        y_pred = model(inputs)
        loss = loss_fn(y_pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # --- Validation pass ---
    # eval() disables Dropout so the accuracy reflects the deterministic
    # network, and no_grad() skips building the autograd graph.
    model.eval()
    acc = 0
    count = 0
    tol = 1e-7
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch["image"]
            labels = batch["label"]
            y_pred = model(inputs)
            # A prediction counts as correct when the probability rounds to the label.
            acc += (torch.abs(torch.round(y_pred) - labels)<tol).float().sum()
            count += len(labels)
    acc /= count
    # The loss shown is that of the last training mini-batch of the epoch.
    print(f"Epoch: [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f}, Acc: {acc*100:.2f} %")

    


In [None]:
# Testing: 
correct = 0
count = 0
tol = 1e-7
# Evaluate with Dropout disabled; otherwise predictions are randomly
# perturbed and the reported accuracy varies from run to run.
model.eval()
# No gradients are needed for evaluation.
with torch.no_grad():
    for batch in test_loader:
        inputs = batch["image"]
        labels = batch["label"]
        y_pred = model(inputs)
        # A prediction counts as correct when the probability rounds to the label.
        correct += (torch.abs(torch.round(y_pred) - labels)<tol).float().sum()
        count += len(labels)
acc = correct / count

print(f"Correct: [{correct} / {count}] giving accuracy of {acc*100:.2f} %")
