In [29]:
%load_ext autoreload 
%autoreload 2

In [20]:
from gpt3forchem.data import get_photoswitch_data
from gpt3forchem.input import create_single_property_forward_prompts
from gpt3forchem.api_wrappers import fine_tune, query_gpt3, extract_prediction

from sklearn.model_selection import train_test_split
import time

from pycm import ConfusionMatrix

In [3]:
data = get_photoswitch_data()

In [6]:
train_data, test_data = train_test_split(data, train_size=0.8, random_state=None, stratify=data["wavelength_cat"])

In [7]:
train_prompts = create_single_property_forward_prompts(
    train_data,
    "wavelength_cat",
    {"wavelength_cat": "transition wavelength"},
    representation_col="SMILES",
    smiles_augmentation=True,
)


In [8]:
test_prompts = create_single_property_forward_prompts(
    test_data,
    "wavelength_cat",
    {"wavelength_cat": "transition wavelength"},
    representation_col="SMILES",
    smiles_augmentation=True,
)

In [23]:
test_prompts_unaugmented = create_single_property_forward_prompts(
    test_data,
    "wavelength_cat",
    {"wavelength_cat": "transition wavelength"},
    representation_col="SMILES",
    smiles_augmentation=False,
)

In [10]:
train_size = len(train_prompts)
test_size = len(test_prompts)

filename_base = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
train_filename = (
    f"run_files/{filename_base}_train_prompts_photoswitch_augmented_{train_size}.jsonl"
)
valid_filename = (
    f"run_files/{filename_base}_valid_prompts_photoswitch_augmented_{test_size}.jsonl"
)

train_prompts.to_json(train_filename, orient="records", lines=True)
test_prompts.to_json(valid_filename, orient="records", lines=True)

In [12]:
fine_tune(train_filename, valid_filename, "ada")

Fine-tune ft-3oU8P4lxLqhlKnFnU3HpvPah has the status "running" and will not be logged
🎉 wandb sync completed successfully


'ada:ft-lsmoepfl-2022-09-05-15-28-54'

In [13]:
completions = query_gpt3('ada:ft-lsmoepfl-2022-09-05-15-28-54', test_prompts)

In [24]:
completions_unaugmented = query_gpt3('ada:ft-lsmoepfl-2022-09-05-15-28-54', test_prompts_unaugmented)

In [17]:
predictions_augmented = [
    int(extract_prediction(completions, i)) for i in range(len(completions["choices"]))
]


In [18]:
true_augmented = [int(e.split("@")[0]) for e in test_prompts["completion"]]


In [25]:
predictions_unaugmented = [
    int(extract_prediction(completions, i)) for i in range(len(completions_unaugmented["choices"]))
]
true_unaugmented = [int(e.split("@")[0]) for e in test_prompts_unaugmented["completion"]]


In [26]:
cm_unaugmented = ConfusionMatrix(actual_vector=true_unaugmented, predict_vector=predictions_unaugmented)

In [27]:
print(cm_unaugmented)

Predict  0        1        2        3        4        
Actual
0        11       5        10       1        0        

1        11       2        8        2        0        

2        5        6        8        1        0        

3        1        1        2        1        1        

4        1        1        0        0        0        





Overall Statistics : 

95% CI                                                            (0.18218,0.38192)
ACC Macro                                                         0.71282
ARI                                                               0.01064
AUNP                                                              0.49858
AUNU                                                              0.50582
Bangdiwala B                                                      0.11047
Bennett S                                                         0.10256
CBA                                                               0.18373
CSI                          

In [21]:
cm = ConfusionMatrix(true_augmented, predictions_augmented)


In [22]:
print(cm)

Predict   0         1         2         3         4         
Actual
0         237       24        9         0         0         

1         27        148       51        4         0         

2         0         11        176       13        0         

3         0         0         18        42        0         

4         0         0         0         1         19        





Overall Statistics : 

95% CI                                                            (0.76923,0.82564)
ACC Macro                                                         0.91897
ARI                                                               0.55122
AUNP                                                              0.86197
AUNU                                                              0.87752
Bangdiwala B                                                      0.66116
Bennett S                                                         0.74679
CBA                                                               0