# Forward predictions of polymer adsorption energies


In [2]:
import time

from sklearn.model_selection import train_test_split

from gpt3forchem.data import get_polymer_data
from gpt3forchem.input import create_single_property_forward_prompts
from gpt3forchem.api_wrappers import fine_tune, query_gpt3, extract_prediction

from pycm import ConfusionMatrix

import pandas as pd

Let's run one fine-tuning and inference for sanity check and then do it a coule of times for statistics.


## Sanity check


In [2]:
df = get_polymer_data()


In [3]:
df_train, df_test = train_test_split(df, train_size=200, random_state=42)


In [4]:
train_prompts = create_single_property_forward_prompts(
    df_train, "deltaGmin_cat", {"deltaGmin_cat": "adsorption energy"}
)

test_prompts = create_single_property_forward_prompts(
    df_test, "deltaGmin_cat", {"deltaGmin_cat": "adsorption energy"}
)


In [5]:
train_size  = len(train_prompts)
test_size = len(test_prompts)

filename_base = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
train_filename = f"run_files/{filename_base}_train_prompts_polymers_{train_size}.jsonl"
valid_filename = f"run_files/{filename_base}_valid_prompts_polymers_{test_size}.jsonl"

train_prompts.to_json(train_filename, orient="records", lines=True)
test_prompts.to_json(valid_filename, orient="records", lines=True)


In [6]:
modelname = fine_tune(train_filename, valid_filename)

wandb: Currently logged in as: kjappelbaum. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.13.1
wandb: Run data is saved locally in /Users/kevinmaikjablonka/git/kjappelbaum/gpt3forchem/experiments/wandb/run-20220818_072642-ft-EGjyqUbr0EbripHqk9KMw6xI
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run ft-EGjyqUbr0EbripHqk9KMw6xI
wandb: ⭐️ View project at https://wandb.ai/kjappelbaum/GPT-3
wandb: 🚀 View run at https://wandb.ai/kjappelbaum/GPT-3/runs/ft-EGjyqUbr0EbripHqk9KMw6xI
wandb: Waiting for W&B process to finish... (success).
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:             elapsed_examples ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:               elapsed_tokens ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:                training_loss █▂▃▂▄▂▂▂▃▃▂▂▂▂▂▂▃▂▂▂▂▂▂▃▂▁▁▅▁▂▃▂▁▁▁▁▁▁▁▁
wandb:   training_sequence_accuracy ▁▁▁▁▁▁▁▁▁▁▁█▁▁▁█▁▁█▁▁▁▁▁▁██▁

🎉 wandb sync completed successfully


In [7]:
test_prompt_subset = test_prompts
completions = query_gpt3(modelname, test_prompt_subset)

Error communicating with OpenAI
Error on row 2466
Error communicating with OpenAI
Error on row 2467
Error communicating with OpenAI
Error on row 2468
Error communicating with OpenAI
Error on row 2469
Error communicating with OpenAI
Error on row 2470
Error communicating with OpenAI
Error on row 2471
Error communicating with OpenAI
Error on row 2472
Error communicating with OpenAI
Error on row 2473
Error communicating with OpenAI
Error on row 2474
Error communicating with OpenAI
Error on row 2475
Error communicating with OpenAI
Error on row 2476
Error communicating with OpenAI
Error on row 2477
Error communicating with OpenAI
Error on row 2478
Error communicating with OpenAI
Error on row 2479
Error communicating with OpenAI
Error on row 2480
Error communicating with OpenAI
Error on row 2481
Error communicating with OpenAI
Error on row 2482
Error communicating with OpenAI
Error on row 2483
Error communicating with OpenAI
Error on row 2484
Error communicating with OpenAI
Error on row 2485


In [11]:
ok_completions = [(i, c) for i, c in enumerate(completions) if c is not None]

In [14]:
indices = [i for i, _ in ok_completions]

In [21]:
test_prompt_subset.iloc[0]['completion'].split('@')[0]

' 3'

In [25]:
predictions = [extract_prediction(completion) for _,completion in ok_completions]
true = [int(test_prompt_subset.iloc[i]['completion'].split('@')[0]) for i,_ in ok_completions]

In [26]:
cm = ConfusionMatrix(true, predictions)

In [27]:
print(cm)

Predict   0         1         2         3         4         
Actual
0         420       76        0         0         0         

1         53        379       65        0         0         

2         0         83        348       66        0         

3         0         0         69        334       79        

4         0         0         0         61        433       





Overall Statistics : 

95% CI                                                            (0.7597,0.79261)
ACC Macro                                                         0.91046
ARI                                                               0.54833
AUNP                                                              0.86011
AUNU                                                              0.85992
Bangdiwala B                                                      0.60837
Bennett S                                                         0.72019
CBA                                                               0.

In [30]:
import re
import subprocess

In [32]:
modelname = fine_tune(train_filename, valid_filename, 'davinci')

Found potentially duplicated files with name '2022-08-18-07-23-57_train_prompts_polymers_200.jsonl', purpose 'fine-tune' and size 26684 bytes
file-2q7AkuiOuT3jGJD1IpENQmlB
file-gXdwrlEVsMvM8F8FVi4cktI3
Enter file ID to reuse an already uploaded file, or an empty string to upload this file anyway: Uploaded file from run_files/2022-08-18-07-23-57_train_prompts_polymers_200.jsonl: file-n9hBhkfGaJ0LrnHNq2x37qPJ
Found potentially duplicated files with name '2022-08-18-07-23-57_valid_prompts_polymers_2925.jsonl', purpose 'fine-tune' and size 392066 bytes
file-LXtdj0RWxCYIs96cGx4jlFGj
file-MNtEFMqzVyHyX5fa5FbTRN2w
Enter file ID to reuse an already uploaded file, or an empty string to upload this file anyway: Uploaded file from run_files/2022-08-18-07-23-57_valid_prompts_polymers_2925.jsonl: file-JZzkBIg5AYoXVokAPQ0Ddtdb
Created fine-tune: ft-NTqOw8HbPaeOibDQKdzqACYR
Streaming events until fine-tuning is complete...

(Ctrl-C will interrupt the stream, but not cancel the fine-tune)
[2022-08-18 

wandb: Currently logged in as: kjappelbaum. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.13.1
wandb: Run data is saved locally in /Users/kevinmaikjablonka/git/kjappelbaum/gpt3forchem/experiments/wandb/run-20220818_221159-ft-NTqOw8HbPaeOibDQKdzqACYR
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run ft-NTqOw8HbPaeOibDQKdzqACYR
wandb: ⭐️ View project at https://wandb.ai/kjappelbaum/GPT-3
wandb: 🚀 View run at https://wandb.ai/kjappelbaum/GPT-3/runs/ft-NTqOw8HbPaeOibDQKdzqACYR
wandb: Waiting for W&B process to finish... (success).
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:             elapsed_examples ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:               elapsed_tokens ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:                training_loss █▄▃▃▃▃▃▃▄▄▄▃▃▃▂▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▂▂▂▂▁
wandb:   training_sequence_accuracy ▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁█▁██▁██▁██▁

🎉 wandb sync completed successfully


In [3]:
test_prompts = pd.read_json("run_files/2022-08-18-07-23-57_train_prompts_polymers_200.jsonl", orient="records", lines=True)

In [7]:
completions = query_gpt3("davinci:ft-lsmoepfl-2022-08-19-02-11-53", test_prompts)

In [10]:
ok_completions = [(i, c) for i, c in enumerate(completions) if c is not None]
predictions = [extract_prediction(completion) for _,completion in ok_completions]
true = [int(test_prompts.iloc[i]['completion'].split('@')[0]) for i,_ in ok_completions]

In [11]:
cm = ConfusionMatrix(true, predictions)

In [12]:
print(cm)

Predict  0        1        2        3        4        
Actual
0        33       3        0        0        0        

1        0        44       1        0        0        

2        0        3        34       3        0        

3        0        0        5        32       5        

4        0        0        0        4        33       





Overall Statistics : 

95% CI                                                            (0.83496,0.92504)
ACC Macro                                                         0.952
ARI                                                               0.7291
AUNP                                                              0.92441
AUNU                                                              0.92471
Bangdiwala B                                                      0.77877
Bennett S                                                         0.85
CBA                                                               0.8554
CSI                                 

In [14]:
import openai

In [18]:
# |export
def query_gpt3(model, df, temperature=0, max_tokens=10, sleep=5, one_by_one=False):
    if one_by_one:
        completions = []
        for i, row in df.iterrows():
            try:
                completion = openai.Completion.create(
                    model=model,
                    prompt=row["prompt"],
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                completions.append(completion)
                time.sleep(sleep)
            except Exception as e:
                print(e)
                print(f"Error on row {i}")
                completions.append(None)
    else: 
        completions = openai.Completion.create(
                    model=model,
                    prompt=df["prompt"].to_list(),
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
    return completions

In [24]:
res = query_gpt3("davinci:ft-lsmoepfl-2022-08-19-02-11-53", test_prompts.iloc[:5])

In [28]:
len(res['choices'])

5

In [29]:
def extract_prediction(completion, i=0):
    return completion["choices"][i]["text"].split("@")[0].strip()


In [30]:
extract_prediction(res)

'2'

In [44]:
predictions = [
    extract_prediction(res, i)
    for i, completion in enumerate(res["choices"][0])
]

In [43]:
res['choices'][0]

<OpenAIObject at 0x170b428b0> JSON: {
  "finish_reason": "length",
  "index": 0,
  "logprobs": null,
  "text": " 2@@@### 3@@@### 2@@"
}

In [47]:
cm.ACC_Macro

0.952

In [48]:
from fastcore.basics import chunked

In [66]:
for chunk in chunked(test_prompts['prompt'], 20):
    print(chunk)

['what is the adsorption energy of B-R-B-B-R-R-R-R-R-R-A-B-W-B-W-A-B-A-W-W-A-R-B-W-B-W-B-W-B-B-R-W-R-R-W-B-W-R###', 'what is the adsorption energy of R-A-R-B-R-W-W-B-W-R-B-R-B-A-B-B-W-B-W-A-R-R-B-B-A-W-B-R###', 'what is the adsorption energy of W-R-W-R-B-A-W-W-R-A-W-A-A-W-R-A-B-W-A-B-R-B-R-A-W-W-W-R-R-R-W-R-A-A-A-A-W-A###', 'what is the adsorption energy of W-R-W-B-W-A-A-A-R-W-B-R-W-A-W-W-R-B-W-B-R-R-A-A###', 'what is the adsorption energy of A-A-W-R-R-B-R-W-B-A-R-A-W-B-A-R-A-W-R-B-W-A-W-A###', 'what is the adsorption energy of R-W-W-A-B-B-W-B-W-B-A-R-A-R-R-W-R-R-A-W-R-R###', 'what is the adsorption energy of R-A-B-W-A-B-A-W-B-A-R-W-B-W-W-B-R-A-W-B-A-W-A-B-R-W-B-W-W-A###', 'what is the adsorption energy of W-R-B-A-R-W-W-R-A-B-R-W-A-R-W-W-W-W-W-B-R-W-B-W-A-W###', 'what is the adsorption energy of W-R-W-B-A-A-A-B-B-W-B-R-R-W-R-A-W-R-R-R-A-W-A-A-B-A-B-R###', 'what is the adsorption energy of R-W-A-B-R-R-A-A-R-B-A-A-A-B-R-A-R-B-A-R-B-W-W-A-W-R-A-A-B-A###', 'what is the adsorption energy of

In [63]:
pd.DataFrame.from_records(chunk)

Unnamed: 0,0,1
0,180,prompt what is the adsorption energy of...
1,181,prompt what is the adsorption energy of...
2,182,prompt what is the adsorption energy of...
3,183,prompt what is the adsorption energy of...
4,184,prompt what is the adsorption energy of...
5,185,prompt what is the adsorption energy of...
6,186,prompt what is the adsorption energy of...
7,187,prompt what is the adsorption energy of...
8,188,prompt what is the adsorption energy of...
9,189,prompt what is the adsorption energy of...


In [57]:
a = [res["choices"][:2], res["choices"][2:4]]

In [59]:
sum(a, [])

[<OpenAIObject at 0x170b428b0> JSON: {
   "finish_reason": "length",
   "index": 0,
   "logprobs": null,
   "text": " 2@@@### 3@@@### 2@@"
 },
 <OpenAIObject at 0x170b42a90> JSON: {
   "finish_reason": "length",
   "index": 1,
   "logprobs": null,
   "text": " 3@@@ 3@@@ 3@@@ 3"
 },
 <OpenAIObject at 0x170b42ef0> JSON: {
   "finish_reason": "length",
   "index": 2,
   "logprobs": null,
   "text": " 1@@@### 2@@@### 3@@"
 },
 <OpenAIObject at 0x170b3c130> JSON: {
   "finish_reason": "length",
   "index": 3,
   "logprobs": null,
   "text": " 4@@@ 4@@@ 3@@@ 3"
 }]