In [1]:
import sherpa

# helper functions
from transfer_learning import NeuralNet_sherpa_optimize
from dataset_loader import (
    data_loader,
    all_filter,
    get_descriptors,
    one_filter,
    data_scaler,
)

# modules
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

import os, sys
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error

from tqdm import tqdm
from scipy.stats import pearsonr

import matplotlib.pyplot as plt

# Hyperparameter search space sampled by the random search below.
parameters = [
    # Learning rate, drawn on a logarithmic scale.
    sherpa.Continuous("lr", [2e-4, 0.1], scale="log"),
    # Disabled for now: sherpa.Discrete(name='Epoch', range=[10,100]),
    # Integer-valued parameter "H_l1" (presumably the first hidden layer width
    # -- confirm against NeuralNet_sherpa_optimize).
    sherpa.Discrete("H_l1", [10, 300]),
    # Activation candidates are passed as plain strings; how they are turned
    # into modules is decided by the consumer of the trial parameters.
    sherpa.Choice(
        "activate",
        [
            "nn.Hardswish",
            "nn.PReLU",
            "nn.ReLU",
            "nn.Sigmoid",
            "nn.LeakyReLU",
        ],
    ),
]

# Random search capped at ten trials.
algorithm = sherpa.algorithms.RandomSearch(max_num_trials=10)

# lower_is_better=False: the study ranks larger objective values as better.
# The web dashboard is switched off.
study = sherpa.Study(
    parameters,
    algorithm,
    lower_is_better=False,
    disable_dashboard=True,
)
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The dataset path is resolved relative to the current working directory.
base_path = os.getcwd()
file_name = "data/CrystGrowthDesign_SI.csv"

"""
Data description.

    Descriptors:
        'void fraction', 'Vol. S.A.', 'Grav. S.A.', 'Pore diameter Limiting', 'Pore diameter Largest'
    Source task:
        'H2@100 bar/243K (wt%)'
    Target tasks:
        'H2@100 bar/130K (wt%)', 'CH4@100 bar/298 K (mg/g)', '5 bar Xe mol/kg', '5 bar Kr mol/kg'
"""

# Input feature columns shared by the source and target tasks.
descriptor_columns = [
    "void fraction",
    "Vol. S.A.",
    "Grav. S.A.",
    "Pore diameter Limiting",
    "Pore diameter Largest",
]
# Column used as the source-task property.
one_filter_columns = ["H2@100 bar/243K (wt%)"]
# Column used as the target-task property.
another_filter_columns = ["H2@100 bar/130K (wt%)"]

# Read the full table from disk.
data = data_loader(base_path, file_name)

# Pull out the source-task property first, then the descriptor columns
# (same call order as before; the helpers are defined in dataset_loader).
one_property = one_filter(data, one_filter_columns)
descriptors = get_descriptors(data, descriptor_columns)

# Build float32 arrays: X keeps its 2-D layout, y is flattened to one value
# per row of X.
X = np.array(descriptors.values, dtype=np.float32)
y = np.array(one_property.values, dtype=np.float32)
y = y.reshape(len(X))

# Scale the inputs, then the targets. data_scaler is given a column vector
# for y, and the result is flattened back to 1-D afterwards.
X = data_scaler(X)
y = data_scaler(y.reshape(-1, 1))
y = y.reshape(len(X))

# Small subset for the transfer trials. This is legacy code that stays in the
# notebook because it depends on the data loaded above and therefore cannot
# be moved into the .py helper module.
# Fixed random_state so the same 100 rows are drawn on every run.
data_small = data.sample(n=100, random_state=1)

# Target-task property and descriptors for the sampled rows.
another_property = one_filter(data_small, another_filter_columns)
descriptors_small = get_descriptors(data_small, descriptor_columns)

# float32 arrays; y_small is flattened to one value per row of X_small.
# NOTE(review): unlike X / y, these arrays are not passed through data_scaler
# here -- confirm that this is intentional.
X_small = np.array(descriptors_small.values, dtype=np.float32)
y_small = np.array(another_property.values, dtype=np.float32)
y_small = y_small.reshape(len(X_small))

In [7]:
from Statistics_helper import stratified_cluster_sample

# Call the sampler ten times with identical arguments and print the first row
# of the fourth returned object on each pass. Every pass rebinds t_1, t_2,
# y_1 and y_2, so only the results of the final call remain afterwards.
# NOTE(review): the meaning of the leading 1 and trailing 5 is defined by
# stratified_cluster_sample and is not visible here -- confirm.
for i in range(10):
    t_1, t_2, y_1, y_2 = stratified_cluster_sample(
        1,
        data,
        descriptor_columns,
        one_filter_columns[0],
        5,
    )
    print(y_2.iloc[0])

1660     8.2
9662     4.2
1574     6.4
11429    6.9
66       NaN
4134     NaN
4985     NaN
8920     NaN
236      NaN
2776     NaN
7913     NaN
8980     NaN
1333     NaN
10494    NaN
10649    NaN
12198    NaN
4451     NaN
8014     NaN
8207     NaN
9327     NaN
Name: H2@100 bar/243K (wt%), dtype: float64
10649    16.9
12198     8.6
1333      4.6
10494    11.9
66        NaN
4134      NaN
4985      NaN
8920      NaN
236       NaN
2776      NaN
7913      NaN
8980      NaN
4451      NaN
8014      NaN
8207      NaN
9327      NaN
1574      NaN
1660      NaN
9662      NaN
11429     NaN
Name: H2@100 bar/243K (wt%), dtype: float64
1660     8.2
9662     4.2
1574     6.4
11429    6.9
66       NaN
4134     NaN
4985     NaN
8920     NaN
236      NaN
2776     NaN
7913     NaN
8980     NaN
1333     NaN
10494    NaN
10649    NaN
12198    NaN
4451     NaN
8014     NaN
8207     NaN
9327     NaN
Name: H2@100 bar/243K (wt%), dtype: float64
1660     8.2
9662     4.2
1574     6.4
11429    6.9
66       NaN
413

In [12]:
# Show y_1 as last bound by the stratified_cluster_sample loop
# (assumes no cell executed in between has rebound it -- TODO confirm).
y_1

Unnamed: 0,10799,13338,1188,10395,10540,11616,10124,527,13439,9988,...,6182,6594,6654,6939,9606,11330,11431,11745,12233,12666
H2@100 bar/243K (wt%),16.8,6.9,6.6,14.2,10.8,5.5,10.7,11.8,6.1,10.4,...,,,,,,,,,,
H2@100 bar/243K (wt%),,,,,,,,,,,...,,,,,,,,,,
H2@100 bar/243K (wt%),,,,,,,,,,,...,,,,,,,,,,
H2@100 bar/243K (wt%),,,,,,,,,,,...,,,,,,,,,,
H2@100 bar/243K (wt%),,,,,,,,,,,...,7.6,7.4,9.3,6.1,5.1,4.8,6.4,4.7,8.0,4.1
