# In [None]:
# Prototype a ResNet
import numpy as np
from nn_models import simple_dnn
import vaex
import vaex.ml
from abundances import *
from typing import Dict, List


# Input: abundances, gas density, temperature (10 quantities for chem1)
# Output: equilibrium number densities (8 quantities for chem1)
"""
TODO
  - Write architectures for
    - ResNet
  - Check normal ML algs
    - SVMs
    - Decision Trees/Random Forests
  - Functions to standardise input
"""

"""
Pipeline overview:
  - Read in .parquet files
  - Generate dataset (extract density, temperature, EQ number densities)
  - Determine and calculate abundances based on keys
  - Sort input abundance array and output EQ array alphabetically
  - Pass into model with goal:
    - from density, temperature and abundance, map to output EQ number densities
  - Loss function analysis
"""

def load_parquet(directory: str, pattern="chem1"):
  """Open the parquet file(s) in *directory* whose names contain *pattern*.

  The files are selected with the glob ``<directory>/*<pattern>*.parquet``
  (a "/" separator is inserted between the two parts) and passed to
  ``vaex.open``.

  Args:
    directory: folder to search.
    pattern: substring that must appear somewhere in the file name.

  Returns:
    The result of ``vaex.open`` for that glob.
  """
  glob_expr = f"{directory}/*{pattern}*.parquet"
  return vaex.open(glob_expr)


def abundance_dict_to_array(abu_dict: Dict, species: List):
  """Gather the abundance of each species into a 1-D float64 array.

  Args:
    abu_dict: mapping from species name to abundance value.
    species: species names; the output array follows this order.

  Returns:
    numpy float64 array of length ``len(species)`` whose entry ``i`` is
    ``abu_dict[species[i]]``.

  Raises:
    KeyError: if a name in ``species`` is not a key of ``abu_dict``.
  """
  ordered_values = [abu_dict[name] for name in species]
  return np.array(ordered_values, dtype=np.float64)

def scatter_abu(abu_dict: Dict, scatter_threshold=0.1,
                scatter_species=("H2", "CO", "CH", "OH"), *, rng=None):
  """Randomly perturb selected abundances around their current values.

  Each species named in ``scatter_species`` that is present in ``abu_dict``
  has its abundance ``a`` replaced by a uniform draw from
  ``[a - a * scatter_threshold, a + a * scatter_threshold)``. Species that
  are not listed are left untouched.

  Changes ``abu_dict`` in place (the same object is also returned).

  Args:
    abu_dict: mapping from species name to abundance.
    scatter_threshold: maximum relative perturbation (0.1 means +/-10%).
    scatter_species: names of the species to perturb. Names that are not
      keys of ``abu_dict`` are skipped; duplicates are scattered only once.
    rng: optional ``numpy.random.Generator`` for reproducible draws. When
      ``None``, the global ``np.random`` state is used.

  Returns:
    ``abu_dict``, after modification.
  """
  uniform = np.random.uniform if rng is None else rng.uniform
  # dict.fromkeys de-duplicates while keeping order, so a species listed
  # twice is not perturbed twice (which would compound the scatter).
  for s in dict.fromkeys(scatter_species):
    # Skip names absent from this dict so one default list can be shared
    # across abundance sets that do not contain every species.
    if s not in abu_dict:
      continue
    value = abu_dict[s]
    offset = value * scatter_threshold
    abu_dict[s] = uniform(value - offset, value + offset)

  return abu_dict

def create_dataset(df):
  """Format a vaex DataFrame as a dataset for TF inputs (prototype).

  The visible part builds the input/output column-label lists.
  NOTE(review): no return statement is visible here; the body may continue
  beyond this excerpt or still be unfinished — confirm.
  """
  # Output targets: every column whose name ends in "_EQ".
  # Assumes iterating df.columns yields column-name strings — TODO confirm.
  eq_labels = [l for l in list(df.columns) if l.endswith("_EQ")]
  # Inputs: one label per EQ column with "_EQ" removed, followed by the
  # "density" and "temperature" columns.
  # NOTE(review): str.replace strips every "_EQ" occurrence, not only the
  # suffix; this is only safe if species names never contain "_EQ".
  input_labels = [*[l.replace("_EQ", "") for l in eq_labels], "density", "temperature"]
  output_labels = eq_labels
