In [70]:
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


We can add context in two ways to help improve the model performance:

1. Train on multiple small datasets at the same time (e.g. Xe, Kr, CH4, N2, H2O Henry coefficients) instead of trainin only on one dataset.

2. Train with additional information about the gases and test if this can be used by the model to "extrapolate" to unseen guests

All available outputs are

- "outputs.pbe.bandgap",
- "outputs.Xe-henry_coefficient-mol--kg--Pa",
- "outputs.Kr-henry_coefficient-mol--kg--Pa",
- "outputs.H2O-henry_coefficient-mol--kg--Pa",
- "outputs.H2S-henry_coefficient-mol--kg--Pa",
- "outputs.CO2-henry_coefficient-mol--kg--Pa",
- "outputs.CH4-henry_coefficient-mol--kg--Pa",
- "outputs.O2-henry_coefficient-mol--kg--Pa",


In [71]:
from gpt3forchem.data import get_mof_data, discretize
from sklearn.model_selection import train_test_split
from gpt3forchem.input import (
    create_single_property_forward_prompts,
    create_single_property_forward_prompts_multiple_targets,
)
from gpt3forchem.api_wrappers import extract_prediction, fine_tune, query_gpt3
from collections import Counter
import numpy as np 
from pycm import ConfusionMatrix
import time

In [3]:
df = get_mof_data()


  return HashableDataFrame(pd.read_csv(os.path.join(datadir, "mof.csv")))


In [25]:
df = df.dropna(subset=["outputs.H2O-henry_coefficient-mol--kg--Pa"])


Let's get the logs of the Henry coefficients.


In [30]:
features = [
    "outputs.Xe-henry_coefficient-mol--kg--Pa",
    "outputs.Kr-henry_coefficient-mol--kg--Pa",
    "outputs.H2O-henry_coefficient-mol--kg--Pa",
    "outputs.H2S-henry_coefficient-mol--kg--Pa",
    "outputs.CO2-henry_coefficient-mol--kg--Pa",
    "outputs.CH4-henry_coefficient-mol--kg--Pa",
    "outputs.O2-henry_coefficient-mol--kg--Pa",
]


In [36]:
for feature in features:
    df[feature + '_log'] = np.log10(df[feature] + 1e-40)

Since the datasets are relatively small, we will only work with three bins.

In [69]:
for feature in features:

    discretize(
        df, f"{feature}_log", n_bins=3, labels=["low", "medium", "high"]
    )


In [42]:
Counter(df["outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat"])


Counter({'medium': 47, 'low': 99, 'high': 7})

In [73]:
train_df, test_df = train_test_split(
    df, train_size=0.8, stratify=df["outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat"]
)


# 1. Train on multiple small datasets at the same time


Let's take H2O Henry coefficient as our target. We first train a model just on it and then train a model on it and a bunch of Henry coefficients of other molecules (but we still only test on H2O).


In [74]:
train_prompts = create_single_property_forward_prompts(train_df, "outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat", {'outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat': 'H2O Henry coefficient'}, representation_col="info.mofid.mofid_clean", encode_value=False)

In [75]:
test_prompts = create_single_property_forward_prompts(
    test_df,
    "outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat",
    {"outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat": "H2O Henry coefficient"},
    representation_col="info.mofid.mofid_clean",
    encode_value=False,
)


In [76]:
filename_base = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
train_filename = f"run_files/{filename_base}_train_prompts_mof_h2o.jsonl"
valid_filename = f"run_files/{filename_base}_valid_prompts_mof_h2o.jsonl"

train_prompts.to_json(train_filename, orient="records", lines=True)
test_prompts.to_json(valid_filename, orient="records", lines=True)

In [78]:
fine_tune(train_filename, valid_filename)

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

'ada:ft-lsmoepfl-2022-09-13-20-55-28'

In [79]:
completions = query_gpt3('ada:ft-lsmoepfl-2022-09-13-20-55-28', test_prompts)

In [80]:
predictions = [
    extract_prediction(completions, i) for i in range(len(completions['choices']))
]

In [81]:
true = test_prompts['completion'].apply(lambda x: x.split('@')[0].strip())

In [82]:
cm = ConfusionMatrix(actual_vector=true.to_list(), predict_vector=predictions)

In [84]:
print(cm)

Predict      high         low          medium       
Actual
high         0            0            1            

low          0            19           1            

medium       0            9            1            





Overall Statistics : 

95% CI                                                            (0.47673,0.81359)
ACC Macro                                                         0.76344
ARI                                                               0.12517
AUNP                                                              0.54329
AUNU                                                              0.52276
Bangdiwala B                                                      0.61356
Bennett S                                                         0.46774
CBA                                                               0.25952
CSI                                                               None
Chi-Squared                                                       None
Chi-Squ

Now, let's use the same train/test split - but add additional outputs.

In [85]:
train_prompts_aug = create_single_property_forward_prompts_multiple_targets(
    train_df,
    [
        "outputs.Xe-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.Kr-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.H2S-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.CO2-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.CH4-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.O2-henry_coefficient-mol--kg--Pa_log_cat",
    ],
    {
        "outputs.Xe-henry_coefficient-mol--kg--Pa_log_cat": "Xe Henry coefficient",
        "outputs.Kr-henry_coefficient-mol--kg--Pa_log_cat": "Kr Henry coefficient",
        "outputs.H2S-henry_coefficient-mol--kg--Pa_log_cat": "H2S Henry coefficient",
        "outputs.CO2-henry_coefficient-mol--kg--Pa_log_cat": "CO2 Henry coefficient",
        "outputs.CH4-henry_coefficient-mol--kg--Pa_log_cat": "CH4 Henry coefficient",
        "outputs.O2-henry_coefficient-mol--kg--Pa_log_cat": "O2 Henry coefficient",
    },
    representation_col="info.mofid.mofid_clean",
    encode_value=False,
)


In [97]:
import pandas as pd

In [98]:
train_prompts_aug_ = []

for i, row in train_prompts_aug.iterrows():
    if not 'nan' in  row['completion']:
        train_prompts_aug_.append(row)

train_prompts_aug = pd.DataFrame(train_prompts_aug_)

In [99]:
train_augment_filename = f"run_files/{filename_base}_train_prompts_mof_aug.jsonl"

train_prompts_aug.to_json(train_augment_filename, orient="records", lines=True)

In [100]:
fine_tune(train_augment_filename, valid_filename)

Fine-tune ft-lOtEpqD3BeNqeZo71xdB83n6 has the status "pending" and will not be logged
🎉 wandb sync completed successfully


'ada:ft-lsmoepfl-2022-09-13-21-17-29'

In [101]:
completions_aug = query_gpt3('ada:ft-lsmoepfl-2022-09-13-21-17-29', test_prompts)

In [102]:
predictions_aug = [
    extract_prediction(completions_aug, i) for i in range(len(completions_aug['choices']))
]

In [103]:
true = test_prompts['completion'].apply(lambda x: x.split('@')[0].strip())

In [104]:
cm_aug = ConfusionMatrix(actual_vector=true.to_list(), predict_vector=predictions_aug)

In [106]:
print(cm_aug)

Predict      high         low          medium       
Actual
high         0            0            1            

low          3            2            15           

medium       0            0            10           





Overall Statistics : 

95% CI                                                            (0.21563,0.55856)
ACC Macro                                                         0.5914
ARI                                                               -0.10601
AUNP                                                              0.56905
AUNU                                                              0.53968
Bangdiwala B                                                      0.34323
Bennett S                                                         0.08065
CBA                                                               0.16154
CSI                                                               -0.17179
Chi-Squared                                                       3.27885


Actually, I was stupid - I didn't even include the water data in the training set.

In [107]:
train_prompts_aug_full = pd.concat([train_prompts_aug, train_prompts])

In [108]:
train_augment_full_filename = f"run_files/{filename_base}_train_prompts_mof_aug_full.jsonl"

train_prompts_aug_full.to_json(train_augment_full_filename, orient="records", lines=True)

In [109]:
fine_tune(train_augment_full_filename, valid_filename)

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

'ada:ft-lsmoepfl-2022-09-14-06-15-26'

In [110]:
completions = query_gpt3('ada:ft-lsmoepfl-2022-09-14-06-15-26', test_prompts)

In [111]:
predictions = [extract_prediction(completions, i) for i in range(len(completions['choices']))]

In [112]:
cm = ConfusionMatrix(actual_vector=true.to_list(), predict_vector=predictions)

In [113]:
print(cm)

Predict      high         low          medium       
Actual
high         0            1            0            

low          0            19           1            

medium       0            8            2            





Overall Statistics : 

95% CI                                                            (0.51286,0.84198)
ACC Macro                                                         0.78495
ARI                                                               0.06454
AUNP                                                              0.5671
AUNU                                                              0.54737
Bangdiwala B                                                      0.61864
Bennett S                                                         0.51613
CBA                                                               0.29286
CSI                                                               None
Chi-Squared                                                       None
Chi-Squa

Ok, at least this gave a small performance boost.

# 2. Train w/ additional information about the gases


### Using only the chemical name of the molecules as context


We train one a range of gases and then test on one it has not seen before. 

Not sure what the right baseline for this is. The easiest is perhaps to use a dummy model. 

In [115]:
from sklearn.dummy import DummyClassifier

In [116]:
dummy_stratified = DummyClassifier(strategy="stratified")

In [117]:
train_true = train_prompts['completion'].apply(lambda x: x.split('@')[0].strip())

In [118]:
dummy_stratified.fit(train_true, train_true)

In [119]:
dummy_predictions = dummy_stratified.predict(true)

In [120]:
cm_dummy = ConfusionMatrix(actual_vector=true.to_list(), predict_vector=dummy_predictions)

In [121]:
print(cm_dummy)

Predict      high         low          medium       
Actual
high         0            1            0            

low          0            11           9            

medium       0            6            4            





Overall Statistics : 

95% CI                                                            (0.30795,0.65979)
ACC Macro                                                         0.65591
ARI                                                               -0.04079
AUNP                                                              0.46753
AUNU                                                              0.48084
Bangdiwala B                                                      0.27959
Bennett S                                                         0.22581
CBA                                                               0.2859
CSI                                                               None
Chi-Squared                                                       None
Chi-Squ

Ok, adding no water training data leads to a worse performance than the dummy case ... :(

But, perhaps, GPT-3 can easier use this information about the molecule if we also provide it with the name

In [122]:
train_prompts_aug_w_name = create_single_property_forward_prompts_multiple_targets(
    train_df,
    [
        "outputs.Xe-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.Kr-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.H2S-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.CO2-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.CH4-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.O2-henry_coefficient-mol--kg--Pa_log_cat",
    ],
    {
        "outputs.Xe-henry_coefficient-mol--kg--Pa_log_cat": "Xenon (Xe) Henry coefficient",
        "outputs.Kr-henry_coefficient-mol--kg--Pa_log_cat": "Krypton (Kr) Henry coefficient",
        "outputs.H2S-henry_coefficient-mol--kg--Pa_log_cat": "hydrogen disulfide (H2S) Henry coefficient",
        "outputs.CO2-henry_coefficient-mol--kg--Pa_log_cat": "carbon dioxide (CO2) Henry coefficient",
        "outputs.CH4-henry_coefficient-mol--kg--Pa_log_cat": "methane (CH4) Henry coefficient",
        "outputs.O2-henry_coefficient-mol--kg--Pa_log_cat": "oxygen (O2) Henry coefficient",
    },
    representation_col="info.mofid.mofid_clean",
    encode_value=False,
)

train_prompts_aug_w_name_h2o = create_single_property_forward_prompts_multiple_targets(
    train_df,
    [
        "outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat",
    ],
    {
        "outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat": "water (H2O) Henry coefficient",
    },
    representation_col="info.mofid.mofid_clean",
    encode_value=False,
)

test_prompts_aug_w_name = create_single_property_forward_prompts_multiple_targets(
    test_df,
    [
        "outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat",
    ],
    {
        "outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat": "water (H2O) Henry coefficient",
    },
    representation_col="info.mofid.mofid_clean",
    encode_value=False,
)

In [123]:
filename_base = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
train_prompts_aug_w_name_filename = f"run_files/{filename_base}_train_prompts_mof_no_h2o.jsonl"
valid_filename = f"run_files/{filename_base}_valid_prompts_mof_h2o.jsonl"

train_prompts_aug_w_name.to_json(train_prompts_aug_w_name_filename, orient="records", lines=True)
test_prompts_aug_w_name.to_json(valid_filename, orient="records", lines=True)

In [124]:
fine_tune(train_prompts_aug_w_name_filename, valid_filename)

Fine-tune ft-zyt0bJQNVG30gv5q43Yx4gAE has the status "pending" and will not be logged
🎉 wandb sync completed successfully


'ada:ft-lsmoepfl-2022-09-14-07-20-41'

In [125]:
predictions_w_name_no_h2o = query_gpt3('ada:ft-lsmoepfl-2022-09-14-07-20-41', test_prompts_aug_w_name)

In [127]:
predictions_w_name_no_h2o_ = [extract_prediction(predictions_w_name_no_h2o, i) for i in range(len(predictions_w_name_no_h2o['choices']))]

In [129]:
print(ConfusionMatrix(actual_vector=true.to_list(), predict_vector=predictions_w_name_no_h2o_))

Predict      high         low          medium       
Actual
high         0            1            0            

low          0            16           4            

medium       0            8            2            





Overall Statistics : 

95% CI                                                            (0.40694,0.75435)
ACC Macro                                                         0.72043
ARI                                                               -0.03621
AUNP                                                              0.49567
AUNU                                                              0.49856
Bangdiwala B                                                      0.46429
Bennett S                                                         0.37097
CBA                                                               0.28
CSI                                                               None
Chi-Squared                                                       None
Chi-Squar

Now, let's try adding even more context to the prompts (data taken from our simulation inputs, https://github.com/lsmo-epfl/aiida-lsmo/blob/develop/aiida_lsmo/calcfunctions/ff_data.yaml)

In [130]:
train_prompts_aug_w_name_and_molinfo = create_single_property_forward_prompts_multiple_targets(
    train_df,
    [
        "outputs.Xe-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.Kr-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.H2S-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.CO2-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.CH4-henry_coefficient-mol--kg--Pa_log_cat",
        "outputs.O2-henry_coefficient-mol--kg--Pa_log_cat",
    ],
    {
        "outputs.Xe-henry_coefficient-mol--kg--Pa_log_cat": "Xenon (Xe, critical temperature 289.74 K, critial pressure 5840000 Pa, radius 1.985 A) Henry coefficient",
        "outputs.Kr-henry_coefficient-mol--kg--Pa_log_cat": "Krypton (Kr, critical temperature 209.35 K, critial pressure 5502000 Pa, radius 1.83 A) Henry coefficient",
        "outputs.H2S-henry_coefficient-mol--kg--Pa_log_cat": "hydrogen disulfide (H2S, critical temperature 373.53 K, critial pressure 8963000 Pa, radius 1.74 A) Henry coefficient",
        "outputs.CO2-henry_coefficient-mol--kg--Pa_log_cat": "carbon dioxide (CO2, critical temperature 304.19 K, critial pressure 7382000 Pa, radius 1.525 A) Henry coefficient",
        "outputs.CH4-henry_coefficient-mol--kg--Pa_log_cat": "methane (CH4, critical temperature 190.56 K, critial pressure 4599000 Pa, radius 1.865 A) Henry coefficient",
        "outputs.O2-henry_coefficient-mol--kg--Pa_log_cat": "oxygen (O2, critical temperature 154.58 K, critial pressure 5043000 Pa, radius 1.51 A) Henry coefficient",
    },
    representation_col="info.mofid.mofid_clean",
    encode_value=False,
)

train_prompts_aug_w_name_h2o_and_molinfo = create_single_property_forward_prompts_multiple_targets(
    train_df,
    [
        "outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat",
    ],
    {
        "outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat": "water (H2O, critical temperature 647.16 K, critial pressure 22055000 Pa, radius 1.58 A) Henry coefficient",
    },
    representation_col="info.mofid.mofid_clean",
    encode_value=False,
)

test_prompts_aug_w_name_and_molinfo = create_single_property_forward_prompts_multiple_targets(
    test_df,
    [
        "outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat",
    ],
    {
        "outputs.H2O-henry_coefficient-mol--kg--Pa_log_cat": "water (H2O, critical temperature 647.16 K, critial pressure 22055000 Pa, radius 1.58 A) Henry coefficient",
    },
    representation_col="info.mofid.mofid_clean",
    encode_value=False,
)

In [131]:
filename_base = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
train_prompts_aug_w_name_w_molinfo_filename = f"run_files/{filename_base}_train_prompts_mof_no_h2o_w_molinfo.jsonl"
valid_filename_w_molinfo = f"run_files/{filename_base}_valid_prompts_mof_h2o_w_molinfo.jsonl"

train_prompts_aug_w_name_and_molinfo.to_json(train_prompts_aug_w_name_w_molinfo_filename, orient="records", lines=True)
test_prompts_aug_w_name_and_molinfo.to_json(valid_filename_w_molinfo, orient="records", lines=True)

In [139]:
train_prompts_aug_w_name_and_molinfo_ = []
for i, row in train_prompts_aug_w_name_and_molinfo.iterrows(): 
    if not 'nan' in row['completion']:
        train_prompts_aug_w_name_and_molinfo_.append(row)

train_prompts_aug_w_name_and_molinfo = pd.DataFrame(train_prompts_aug_w_name_and_molinfo_)

test_prompts_aug_w_name_and_molinfo_ = []
for i, row in test_prompts_aug_w_name_and_molinfo.iterrows(): 
    if not 'nan' in row['completion']:
        test_prompts_aug_w_name_and_molinfo_.append(row)

test_prompts_aug_w_name_and_molinfo = pd.DataFrame(test_prompts_aug_w_name_and_molinfo_)

train_prompts_aug_w_name_and_molinfo.to_json(train_prompts_aug_w_name_w_molinfo_filename, orient="records", lines=True)
test_prompts_aug_w_name_and_molinfo.to_json(valid_filename_w_molinfo, orient="records", lines=True)

In [140]:
fine_tune(train_prompts_aug_w_name_w_molinfo_filename, valid_filename_w_molinfo)

Fine-tune ft-kZLGUysgsbz5j94f9PFzCBDQ has the status "pending" and will not be logged
🎉 wandb sync completed successfully


'ada:ft-lsmoepfl-2022-09-14-09-07-33'

In [141]:
completions_w_molinfo = query_gpt3('ada:ft-lsmoepfl-2022-09-14-09-07-33', test_prompts_aug_w_name_and_molinfo)

In [142]:
predictions_w_molinfo = [extract_prediction(completions_w_molinfo, i) for i in range(len(completions_w_molinfo['choices']))]

In [143]:
print(ConfusionMatrix(actual_vector=true.to_list(), predict_vector=predictions_w_molinfo))

Predict      high         low          medium       
Actual
high         0            0            1            

low          1            2            17           

medium       2            1            7            





Overall Statistics : 

95% CI                                                            (0.13053,0.45011)
ACC Macro                                                         0.52688
ARI                                                               0.03759
AUNP                                                              0.47597
AUNU                                                              0.45866
Bangdiwala B                                                      0.16933
Bennett S                                                         -0.06452
CBA                                                               0.12667
CSI                                                               -0.41778
Chi-Squared                                                       1.984
C

Interestingly, adding more info makes things worse...