In [1]:
%load_ext autoreload
%autoreload 2

In [9]:
from gpt3forchem.data import get_photoswitch_data
from gpt3forchem.input import create_single_property_forward_prompts
from sklearn.model_selection import train_test_split
from gpt3forchem.api_wrappers import fine_tune, query_gpt3, extract_prediction, ensemble_fine_tune, multiple_query_gpt3
import time
from pycm import ConfusionMatrix
from gpt3forchem.baselines import GPRBaseline, compute_fragprints
import pandas as pd
import numpy as np


import matplotlib.pyplot as plt

plt.style.use(["science", "nature"])


In [3]:
data = get_photoswitch_data()


In [4]:
prompts = create_single_property_forward_prompts(
    data,
    "wavelength_cat",
    {"wavelength_cat": "transition wavelength"},
    representation_col="SMILES",
)


In [5]:
prompts = prompts.sample(frac=0.5)

In [6]:
train_prompts, test_prompts = train_test_split(
    prompts, test_size=0.2, random_state=None
)


In [7]:
models = ensemble_fine_tune(train_prompts, test_prompts, num_models=5)

Fine tuning on 5 train files and 5 valid files
Fine-tune ft-0sMbEhCOyzwtyDu4KXcCcWBQ has the status "pending" and will not be logged
🎉 wandb sync completed successfully
Fine-tune ft-0sMbEhCOyzwtyDu4KXcCcWBQ has the status "pending" and will not be logged
🎉 wandb sync completed successfully
Fine-tune ft-0sMbEhCOyzwtyDu4KXcCcWBQ has the status "pending" and will not be logged
🎉 wandb sync completed successfully
Uploaded file from run_files/2022-09-01-13-33-35_train__ensemble_2_125.jsonl: file-hq8u626g7lvBn2brPfG5oqr1
Uploaded file from run_files/2022-09-01-13-33-35_valid__ensemble_2_40.jsonl: file-TSWNDsqoAFdORx2gdgrvfpPy
Created fine-tune: ft-93BW8r44toCI0uGx7RGezjqu
Streaming events until fine-tuning is complete...

(Ctrl-C will interrupt the stream, but not cancel the fine-tune)
[2022-09-01 13:33:44] Created fine-tune: ft-93BW8r44toCI0uGx7RGezjqu

Stream interrupted (client disconnected).
To resume the stream, run:

  openai api fine_tunes.follow -i ft-93BW8r44toCI0uGx7RGezjqu

 
Uplo

In [8]:
models

(#5) ['ada:ft-lsmoepfl-2022-09-01-11-35-43','ada:ft-lsmoepfl-2022-09-01-11-37-42',None,None,'ada:ft-lsmoepfl-2022-09-01-11-39-37']

In [17]:
completions =   multiple_query_gpt3(models, test_prompts)

In [18]:
completions

[{'choices': [<OpenAIObject at 0x2967e40e0> JSON: {
     "finish_reason": "length",
     "index": 0,
     "logprobs": null,
     "text": " 2@@@@@@ 1@@@@@@"
   },
   <OpenAIObject at 0x2967d4e00> JSON: {
     "finish_reason": "length",
     "index": 1,
     "logprobs": null,
     "text": " 0@@@@@@@@@@@@@@"
   },
   <OpenAIObject at 0x2967d4ea0> JSON: {
     "finish_reason": "length",
     "index": 2,
     "logprobs": null,
     "text": " 2@@@@@@@@@@@@@@"
   },
   <OpenAIObject at 0x2967d4130> JSON: {
     "finish_reason": "length",
     "index": 3,
     "logprobs": null,
     "text": " 0@@@@@@@@@@@@@@"
   },
   <OpenAIObject at 0x2967d70e0> JSON: {
     "finish_reason": "length",
     "index": 4,
     "logprobs": null,
     "text": " 2@@@@@@@=@@@@"
   },
   <OpenAIObject at 0x2967d7090> JSON: {
     "finish_reason": "length",
     "index": 5,
     "logprobs": null,
     "text": " 0@@@@@@@@@@@@@@"
   },
   <OpenAIObject at 0x2967d72c0> JSON: {
     "finish_reason": "length",
     "index"

In [19]:
len(completions)

3

In [23]:
predictions = [
    [
        int(extract_prediction(completions[j], i))
        for i in range(len(completions[j]["choices"]))
    ]
    for j in range(len(completions))
]


In [33]:
pred = np.array(predictions).T

In [46]:
true = [
    int(test_prompts.iloc[i]["completion"].split("@")[0])
    for i in range(len(predictions[0]))
]

predictions are mostly the same but not always. Let's get the performance per prediction and for the majority vote.

In [34]:
maj_vote = np.apply_along_axis(lambda x: np.argmax(np.bincount(x)), axis=1, arr=pred)


In [35]:
maj_vote

array([2, 0, 2, 0, 2, 0, 2, 2, 2, 0, 0, 2, 2, 0, 2, 0, 2, 2, 2, 0, 0, 2,
       1, 2, 2, 0, 1, 2, 2, 0, 2, 0, 2, 0, 1, 0, 0, 2, 0, 0])

In [41]:
len(true)

3

In [49]:
cm_maj = ConfusionMatrix(maj_vote, true)

In [50]:
cm_0 = ConfusionMatrix(predictions[0], true)
cm_1 = ConfusionMatrix(predictions[1], true)
cm_2 = ConfusionMatrix(predictions[2], true)

In [52]:
print(cm_0.F1_Macro, cm_1.F1_Macro, cm_2.F1_Macro, cm_maj.F1_Macro)

0.35218390804597705 0.3517241379310345 0.5881481481481481 0.33766749379652605


In [53]:
print(cm_0.F1_Micro, cm_1.F1_Micro, cm_2.F1_Micro, cm_maj.F1_Micro)

0.575 0.575 0.675 0.575


In [54]:
print(cm_0.ACC_Macro, cm_1.ACC_Macro, cm_2.ACC_Macro, cm_maj.ACC_Macro)

0.8299999999999998 0.8299999999999998 0.8699999999999999 0.8299999999999998
