In [7]:
from holodecml.data import num_particles_dict, load_raw_datasets, scale_images, unet_bin_xy
from holodecml.losses import unet_loss, unet_loss_xy
from holodecml.models import custom_unet, custom_jnet, custom_jnet_full
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import xarray as xr
import os
from os.path import join
import yaml
from tensorflow.keras.optimizers import Adam


In [2]:
# Root directory of the holodec data files (absolute, cluster-specific path).
path_data = "/glade/p/cisl/aiml/ai4ess_hackathon/holodec/"

# Dataset identifiers to load: the integers 1..10 plus the "12-25" dataset.
num_particles = [*range(1, 11), "12-25"]
output_cols = ["x", "y", "hid"]

# One subset value per entry of num_particles (same order): 500 for each of the
# ten integer-labelled datasets and 13 times that for "12-25".
subset = [500] * 10 + [500 * 13]

# NOTE(review): scaler_out and h are not referenced in the cells visible here;
# confirm they are still needed.
scaler_out = MinMaxScaler()
bin_factor = 10
h = 0


In [8]:
# Accumulate each particle-count dataset separately; the lists are stacked into
# single arrays in a later cell.
train_inputs_list, train_outputs_list = [], []
valid_inputs_list, valid_outputs_list = [], []

for num, sub in zip(num_particles, subset):
    print(num, sub)

    # The validation split is requested with a tenth of the training subset.
    train_inputs, train_outputs = load_raw_datasets(
        path_data, num, 'train', output_cols, sub)
    valid_inputs, valid_outputs = load_raw_datasets(
        path_data, num, 'valid', output_cols, sub // 10)

    print(f"num_particle: {num}")
    print(f"train shapes: input - {train_inputs.shape}\t output - {train_outputs.shape}")
    print(f"valid shapes: input - {valid_inputs.shape}\t output - {valid_outputs.shape}")

    # Scale the training images, then hand the returned scaler to the
    # validation call.
    # NOTE(review): scaler_in is produced anew on every iteration, so each
    # dataset may end up scaled independently of the others -- confirm that
    # this is intended before stacking them together.
    train_inputs, scaler_in = scale_images(train_inputs)
    valid_inputs, _ = scale_images(valid_inputs, scaler_in)

    # Replace the raw output tables with whatever unet_bin_xy derives from the
    # (scaled) images, the raw outputs and bin_factor.
    train_outputs = unet_bin_xy(train_inputs, train_outputs, bin_factor)
    valid_outputs = unet_bin_xy(valid_inputs, valid_outputs, bin_factor)

    train_inputs_list.append(train_inputs)
    train_outputs_list.append(train_outputs)
    valid_inputs_list.append(valid_inputs)
    valid_outputs_list.append(valid_outputs)


1 500
num_particle: 1
train shapes: input - (500, 600, 400)	 output - (500, 3)
valid shapes: input - (50, 600, 400)	 output - (50, 3)
2 500
num_particle: 2
train shapes: input - (500, 600, 400)	 output - (1000, 3)
valid shapes: input - (50, 600, 400)	 output - (100, 3)
3 500
num_particle: 3
train shapes: input - (500, 600, 400)	 output - (1500, 3)
valid shapes: input - (50, 600, 400)	 output - (150, 3)
4 500
num_particle: 4
train shapes: input - (500, 600, 400)	 output - (2000, 3)
valid shapes: input - (50, 600, 400)	 output - (200, 3)
5 500
num_particle: 5
train shapes: input - (500, 600, 400)	 output - (2500, 3)
valid shapes: input - (50, 600, 400)	 output - (250, 3)
6 500
num_particle: 6
train shapes: input - (500, 600, 400)	 output - (3000, 3)
valid shapes: input - (50, 600, 400)	 output - (300, 3)
7 500
num_particle: 7
train shapes: input - (500, 600, 400)	 output - (3500, 3)
valid shapes: input - (50, 600, 400)	 output - (350, 3)
8 500
num_particle: 8
train shapes: input - (500, 

In [10]:
# Collapse the per-dataset lists into one array per split by stacking along the
# first axis. The loop-local names from the loading cell are rebound here.
train_inputs, train_outputs, valid_inputs, valid_outputs = (
    np.vstack(chunks)
    for chunks in (train_inputs_list, train_outputs_list,
                   valid_inputs_list, valid_outputs_list)
)


In [None]:
# Inspection cell: display the stacked validation outputs.
valid_outputs

In [None]:
# Inspection cell: show the shape of the stacked training inputs.
train_inputs.shape

In [5]:
def load_jnet_datasets_xy_1to25(path_data, num_particles, output_cols,
                                subset=False, bin_factor=False, input_col="image"):
    """Load and stack raw train/valid data across several datasets.

    For every entry of ``num_particles`` the 'train' and 'valid' splits are
    read with ``load_raw_datasets``; the per-dataset results are then stacked
    along the first axis with ``np.vstack``. No scaling or binning is applied
    here.

    Args:
        path_data: Forwarded unchanged to ``load_raw_datasets``.
        num_particles: Iterable of dataset identifiers (e.g. 1, 2, ...,
            "12-25"); one train load and one valid load is issued per entry.
        output_cols: Forwarded unchanged to ``load_raw_datasets``.
        subset: Subset argument for the train loads. Either a sequence with
            exactly one entry per item of ``num_particles``, or a single
            scalar / ``False`` / ``None`` that is applied to every dataset.
            The valid load receives ``sub // 10``; falsy entries are forwarded
            as-is to both loads.
        bin_factor: Currently unused; kept so existing callers keep working.
            NOTE(review): confirm whether binning was meant to happen here.
        input_col: Currently unused; kept so existing callers keep working.

    Returns:
        Tuple ``(train_inputs, train_outputs, valid_inputs, valid_outputs)``
        of stacked arrays.

    Raises:
        ValueError: If ``subset`` is a sequence whose length differs from
            ``num_particles`` (``zip`` would otherwise silently drop
            datasets).
    """
    num_particles = list(num_particles)

    # A lone scalar (including the default False) or None applies to every
    # dataset; passing it straight to zip() would raise TypeError.
    if subset is None or np.isscalar(subset):
        subset = [subset] * len(num_particles)
    else:
        subset = list(subset)

    if len(subset) != len(num_particles):
        raise ValueError(
            f"subset has {len(subset)} entries but num_particles has "
            f"{len(num_particles)}; they must match one-to-one."
        )

    train_inputs_list = []
    train_outputs_list = []
    valid_inputs_list = []
    valid_outputs_list = []
    for num, sub in zip(num_particles, subset):
        # Forward falsy values (False/None/0) untouched: None cannot be
        # floor-divided and the others carry no size to scale down.
        valid_sub = sub // 10 if sub else sub

        train_inputs, train_outputs = load_raw_datasets(
            path_data, num, 'train', output_cols, sub)
        valid_inputs, valid_outputs = load_raw_datasets(
            path_data, num, 'valid', output_cols, valid_sub)

        train_inputs_list.append(train_inputs)
        train_outputs_list.append(train_outputs)
        valid_inputs_list.append(valid_inputs)
        valid_outputs_list.append(valid_outputs)

    train_inputs = np.vstack(train_inputs_list)
    train_outputs = np.vstack(train_outputs_list)
    valid_inputs = np.vstack(valid_inputs_list)
    valid_outputs = np.vstack(valid_outputs_list)

    return train_inputs, train_outputs, valid_inputs, valid_outputs