In [1]:
from data.dataset import GWF_HP_Dataset
from data.dataloader import DataLoader
from data.transforms import NormalizeTransform, ComposeTransform, ReduceTo2DTransform, PowerOfTwoTransform
from visualization.visualize_data import plot_datapoint
from data.utils_save import save_pickle
from networks.unet_leiterrl import TurbNetG, UNet, weights_init

import os
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import h5py
from tqdm.auto import tqdm

# Draw all figure text, axis labels and tick labels in one colour.
COLOR = 'white'
mpl.rcParams.update({
    rc_key: COLOR
    for rc_key in ('text.color', 'axes.labelcolor', 'xtick.color', 'ytick.color')
})

%load_ext autoreload
%autoreload 2

In [2]:
# TODO: it should not be necessary to cut off the edges - exponential behaviour at the edges?! Fix it otherwise!
#transforms = ComposeTransform([NormalizeTransform(), CutOffEdgesTransform()])

# DATASET = GWF_HP_Dataset(dataset_name ="dataset_HDF5_testtest", transform = NormalizeTransform(), 
#                  input_vars=["Liquid Y-Velocity [m_per_y]", "Liquid Z-Velocity [m_per_y]", 
#                  "Liquid_Pressure [Pa]", "Material_ID", "Temperature [C]"],
#                  output_vars=["Liquid_Pressure [Pa]", "Temperature [C]"])
# DATALOADER = DataLoader(DATASET, batch_size=2, shuffle=True, drop_last=False)
# 
# print(f"Dataset size: {len(DATASET)}")
# print(f"Dataloader size: {len(DATALOADER)}")
# 
# save_pickle({"dataset": DATASET, "dataloader" : DATALOADER}, "dataset_HDF5_testtest_and_dataloader.p")
# 
# DATASET[0]['x'][0,:,:,:].shape
# plot_datapoint(DATASET, run_id=1, view="side_hp")

# TODO Data Augmentation

In [3]:
# TODO:  data augmentation

# Splitting Data

In [4]:
# # split dataset into train, val, test
# datasets = {}
# for mode in ['train', 'val', 'test']:
#     temp_dataset = GWF_HP_Dataset(
#         dataset_name ="dataset_HDF5_testtest", transform = NormalizeTransform(), 
#         input_vars=["Liquid Y-Velocity [m_per_y]", "Liquid Z-Velocity [m_per_y]", 
#         "Liquid_Pressure [Pa]", "Material_ID", "Temperature [C]"],
#         output_vars=["Liquid_Pressure [Pa]", "Temperature [C]"],
#         mode=mode, split={'train': 0.6, 'val': 0.2, 'test': 0.2}
#     )
#     datasets[mode] = temp_dataset
# 
# # Create a dataloader for each split.
# dataloaders = {}
# for mode in ['train', 'val', 'test']:
#     temp_dataloader = DataLoader(
#         dataset=datasets[mode],
#         batch_size=2,
#         shuffle=True,
#         drop_last=False,
#     )
#     dataloaders[mode] = temp_dataloader

## test splitting data

In [5]:
# print(datasets["train"].runs)
# print(datasets["train"][0].keys())
# print(len(dataloaders["train"]))
# 
# for batch in dataloaders["train"]:
#     print(batch['x'].shape)
#     print(batch['y'].shape)
#     break

## test visualization of data

In [6]:
# plot_datapoint(datasets["train"], run_id=0, view="side_hp")


# simplest test NN (linear)

# test TurbNetG (from Rapha, from somebody else) 2D testcase on my data

In [7]:
# split dataset into train, val, test

def init_data(reduce_to_2D = True, dataset_name = "dataset_HDF5_testtest", batch_size = 2):
    """Build the train/val/test datasets and one dataloader per split.

    Args:
        reduce_to_2D: If True, samples go through NormalizeTransform,
            PowerOfTwoTransform and ReduceTo2DTransform, and the first
            training sample is checked to have a rank-3 ``'x'`` entry.
            If False, only NormalizeTransform is applied and the rank check
            is skipped (it could never pass for non-reduced data).
        dataset_name: Name handed to ``GWF_HP_Dataset``; defaults to the
            dataset that used to be hard-coded here.
        batch_size: Batch size handed to every ``DataLoader``; defaults to
            the value that used to be hard-coded here.

    Returns:
        Tuple ``(datasets, dataloaders)``; both are dicts keyed by
        ``'train'``, ``'val'`` and ``'test'``.

    Raises:
        ValueError: If ``reduce_to_2D`` is True and the training split
            yields no sample, so the rank check cannot be performed.
        AssertionError: If ``reduce_to_2D`` is True and ``'x'`` of the first
            training sample does not have exactly three dimensions.
    """
    modes = ['train', 'val', 'test']

    if reduce_to_2D:
        transforms = ComposeTransform([NormalizeTransform(), PowerOfTwoTransform(), ReduceTo2DTransform()])
    else:
        transforms = ComposeTransform([NormalizeTransform()]) # PowerOfTwoTransform()

    # One dataset per split; all splits share the same transform pipeline.
    datasets = {}
    for mode in modes:
        datasets[mode] = GWF_HP_Dataset(
            dataset_name = dataset_name, transform = transforms,
            input_vars=["Liquid Y-Velocity [m_per_y]", "Liquid Z-Velocity [m_per_y]",
            "Liquid_Pressure [Pa]", "Material_ID", "Temperature [C]"],
            output_vars=["Liquid_Pressure [Pa]", "Temperature [C]"],
            mode=mode, split={'train': 0.6, 'val': 0.2, 'test': 0.2}
        )

    # Create a dataloader for each split.
    dataloaders = {}
    for mode in modes:
        dataloaders[mode] = DataLoader(
            dataset=datasets[mode],
            batch_size=batch_size,
            shuffle=True,
            drop_last=False,
        )

    # Sanity check, only meaningful when the 2D reduction was requested:
    # the first training sample must have a rank-3 'x'
    # (presumably channels + two spatial axes - TODO confirm).
    if reduce_to_2D:
        first_sample = next(iter(datasets["train"]), None)
        if first_sample is None:
            raise ValueError("Training split is empty, cannot check that data is 2D")
        assert len(first_sample['x'].shape) == 3, "Data is not 2D"

    return datasets, dataloaders

# Build the 2D-reduced datasets and dataloaders; the training cell iterates dataloaders_2D["train"].
datasets_2D, dataloaders_2D = init_data(reduce_to_2D=True)

Directory of currently used dataset is: /home/pelzerja/Development/simulation_groundtruth_pflotran/Phd_simulation_groundtruth/approach2_dataset_generation_simplified/dataset_HDF5_testtest
Directory of currently used dataset is: /home/pelzerja/Development/simulation_groundtruth_pflotran/Phd_simulation_groundtruth/approach2_dataset_generation_simplified/dataset_HDF5_testtest
Directory of currently used dataset is: /home/pelzerja/Development/simulation_groundtruth_pflotran/Phd_simulation_groundtruth/approach2_dataset_generation_simplified/dataset_HDF5_testtest


In [10]:
from torch.optim import Adam
from torch.nn import MSELoss
from torch.utils.tensorboard import SummaryWriter

In [12]:
# Train the UNet on the 2D training split with a plain MSE objective.
loss_fn = MSELoss()
n_epochs = 100 #60000

#model = TurbNetG(channelExponent=4, in_channels=5, out_channels=2)
# Channel counts presumably mirror the 5 input_vars / 2 output_vars of the dataset - TODO confirm.
model = UNet(in_channels=5, out_channels=2)
optimizer = Adam(model.parameters(), lr=0.0001) #0.0004
# model.to(device)
model.apply(weights_init)
# Make the training mode explicit so re-running this cell after an eval pass still trains.
model.train()

# Per-batch loss values, appended in iteration order over all epochs.
loss_hist = []
epochs = tqdm(range(n_epochs), desc = "epochs")
for epoch in epochs:
    for batch_idx, data_point in enumerate(dataloaders_2D["train"]):
        # TODO: clarify in which format input and target have to be provided
        x = data_point["x"]
        y = data_point["y"]

        # The optimizer holds every model parameter, so one reset clears all gradients.
        optimizer.zero_grad()

        y_out = model(x)
        mse_loss = loss_fn(y_out, y)
        # Total loss; currently just the MSE term.
        loss = mse_loss

        loss.backward()
        optimizer.step()
        # Progress bar shows the loss of the most recent batch.
        epochs.set_postfix_str(f"loss: {loss.item():.4f}")

        loss_hist.append(loss.item())

epochs: 100%|██████████| 100/100 [00:23<00:00,  4.25it/s, loss: 1.7626]
