In [94]:
from pathlib import Path
import pandas as pd
import json
import re
from sklearn.model_selection import train_test_split
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm

In [None]:
# Load variables from a .env file into the process environment.
# NOTE(review): presumably this supplies the API credentials that `OpenAI()`
# picks up further down — confirm.
load_dotenv()

# Settings

In [106]:
# All inputs and outputs of this notebook live under one data directory.
_data_dir = Path('../data')

# Inputs
training_path = _data_dir / 'training.csv'
chatgpt_results_path = _data_dir / 'chatgpt' / 'results'

# Outputs
output_path = _data_dir / 'training_with_chatgpt.csv'
output_train_path = _data_dir / 'train.csv'
output_test_path = _data_dir / 'test.csv'

In [97]:
# Load the base table and append the placeholder columns that later cells fill:
# empty strings for the ChatGPT text fields, None for the embedding.
df = pd.read_csv(training_path).assign(
    chatgpt_raw_response='',
    chatgpt_extracted_location='',
    chatgpt_extracted_situation='',
    chatgpt_extracted_risk='',
    chatgpt_extracted_message='',
    objs_ctrls_embedding=None,
)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 124248 entries, 0 to 124247
Data columns (total 12 columns):
 #   Column                       Non-Null Count   Dtype 
---  ------                       --------------   ----- 
 0   town                         124248 non-null  object
 1   section                      124248 non-null  int64 
 2   scenario                     124248 non-null  object
 3   part                         124248 non-null  int64 
 4   objects                      124248 non-null  object
 5   controls                     124248 non-null  object
 6   chatgpt_raw_response         124248 non-null  object
 7   chatgpt_extracted_location   124248 non-null  object
 8   chatgpt_extracted_situation  124248 non-null  object
 9   chatgpt_extracted_risk       124248 non-null  object
 10  chatgpt_extracted_message    124248 non-null  object
 11  objs_ctrls_embedding         0 non-null       object
dtypes: int64(2), object(10)
memory usage: 11.4+ MB


# Combine results

In [98]:
# Collect the saved result files, ordered by the integer encoded in each file
# stem (a plain lexicographic sort would put "10" before "2").
chatgpt_results_files = sorted(
    chatgpt_results_path.glob('*.txt'),
    key=lambda result_file: int(result_file.stem),
)

In [99]:
# Matches a Markdown code fence: three backticks, optionally followed by "json"
# and by newlines. It is passed to re.sub further down to strip fences from the
# raw reply; because the "json" part is optional, a closing fence matches too.
start_json_pattern = r'```(json)*(\n)*'

In [101]:
# Copy every saved ChatGPT reply into the frame and split it into its fields.
# Files are matched to rows by position: chatgpt_results_files[index] is used
# for the row labelled `index`.
for index in tqdm(df.index, total=len(df)):
    with open(chatgpt_results_files[index], 'r') as f:
        chatgpt_raw_response = f.read()
    df.at[index, 'chatgpt_raw_response'] = chatgpt_raw_response

    try:
        # Strip Markdown code fences before decoding the JSON payload.
        chatgpt_response = json.loads(
            re.sub(start_json_pattern, '', chatgpt_raw_response).strip()
        )
    except json.JSONDecodeError:
        # Skip the row so its extracted columns keep their existing values.
        # Falling through here would reuse `chatgpt_response` from the previous
        # iteration and silently attach another row's fields to this one.
        print(f'Row {index}: response is not valid JSON:\n{chatgpt_raw_response}')
        continue

    for field in ('location', 'situation', 'risk', 'message'):
        df.at[index, f'chatgpt_extracted_{field}'] = chatgpt_response[field]

100%|██████████| 124248/124248 [00:05<00:00, 20782.58it/s]


In [102]:
# Spot-check the extracted fields on the first rows (head() defaults to 5 rows).
df.head()

Unnamed: 0,town,section,scenario,part,objects,controls,chatgpt_raw_response,chatgpt_extracted_location,chatgpt_extracted_situation,chatgpt_extracted_risk,chatgpt_extracted_message,objs_ctrls_embedding
0,Town12,81,StaticCutIn,1,"[(30, 'vehicle.ford.mustang'), (15, 'vehicle.m...","[{'throttle': 0.75, 'steer': 2.133598718501161...","{\n ""location"": ""City street"",\n ""situation""...",City street,Approaching multiple vehicles,medium,"Stay focused, be aware of surrounding vehicles...",
1,Town12,81,StaticCutIn,1,"[(29, 'vehicle.ford.mustang'), (14, 'vehicle.m...","[{'throttle': 0.75, 'steer': 0.000891609524842...","{\n ""location"": ""City street"",\n ""situation""...",City street,Approaching multiple vehicles and a traffic light,medium,"Stay alert, approaching vehicles and traffic l...",
2,Town12,81,StaticCutIn,1,"[(28, 'vehicle.ford.mustang'), (13, 'vehicle.m...","[{'throttle': 0.75, 'steer': -0.00017667813517...","{\n ""location"": ""City street"",\n ""situation""...",City street,Approaching multiple vehicles and a traffic light,medium,"Stay focused, be prepared to slow down for the...",
3,Town12,81,StaticCutIn,1,"[(30, 'vehicle.nissan.patrol_2021'), (30, 'veh...","[{'throttle': 0.0, 'steer': 0.0002305144735146...","{\n ""location"": ""Urban area"",\n ""situation"":...",Urban area,Approaching multiple vehicles and a stop sign,medium,"Stay alert, slow down, and prepare to stop at ...",
4,Town12,81,StaticCutIn,1,"[(48, 'vehicle.mini.cooper_s_2021'), (7, 'vehi...","[{'throttle': 0.0, 'steer': -0.000179785565705...","{\n ""location"": ""City street"",\n ""situation""...",City street,Approaching a stop sign with vehicles nearby,medium,"Approach the stop sign cautiously, there are v...",


# Add embedding

In [103]:
# Shared API client used by get_embedding().
# NOTE(review): constructed without arguments, so no credentials are set here —
# presumably they come from the environment populated by load_dotenv(); confirm.
client = OpenAI()

In [105]:
def get_embedding(text, model="text-embedding-ada-002"):
    """Return the embedding of `text` produced by the module-level `client`.

    Args:
        text: String to embed. Newlines are replaced by single spaces before
            the request is sent.
        model: Name of the embedding model passed to the API.

    Returns:
        The `embedding` attribute of the first item in the response data.
    """
    single_line = " ".join(text.split("\n"))
    response = client.embeddings.create(model=model, input=[single_line])
    first_result = response.data[0]
    return first_result.embedding

In [107]:
# Fill in the missing `objs_ctrls_embedding` values, resuming from the CSV
# checkpoint at `output_path` when one exists.

def _save_checkpoint(frame):
    """Write `frame` to `output_path` through a temporary sibling file.

    The CSV is written next to the target and then moved over it, so an
    interrupt in the middle of the write leaves the previous checkpoint intact.
    """
    tmp_path = output_path.with_name(output_path.name + '.tmp')
    frame.to_csv(tmp_path, index=False)
    tmp_path.replace(output_path)


# Number of newly embedded rows between checkpoints. Rewriting the whole CSV
# after every single row makes the total I/O quadratic in the number of rows.
checkpoint_every = 500

# Resume from the checkpoint if there is one; otherwise continue with the
# in-memory `df` built by the cells above.
if output_path.exists():
    df = pd.read_csv(output_path)

# An all-empty column comes back from CSV as float64, which cannot hold a list.
df['objs_ctrls_embedding'] = df['objs_ctrls_embedding'].astype(object)

missing_index = df.index[df['objs_ctrls_embedding'].isna()]
unsaved = 0
try:
    for index in tqdm(missing_index, total=len(missing_index)):
        text = f"objects: {df.at[index, 'objects']}, controls: {df.at[index, 'controls']}"
        df.at[index, 'objs_ctrls_embedding'] = get_embedding(text)
        unsaved += 1
        if unsaved >= checkpoint_every:
            _save_checkpoint(df)
            unsaved = 0
finally:
    # Runs on normal completion, on API errors and on KeyboardInterrupt, so
    # work done since the last checkpoint is not lost.
    if unsaved:
        _save_checkpoint(df)

  0%|          | 240/124248 [01:27<12:34:17,  2.74it/s] 


KeyboardInterrupt: 

In [109]:
# Hold out a fixed number of rows (24248) as the test set; random_state pins
# the shuffle so the split is reproducible.
# NOTE(review): with the 124248 rows reported by df.info() earlier this leaves
# 100000 training rows — confirm that round number is the intent.
train, test = train_test_split(df, test_size=24248, random_state=42)

In [110]:
# Persist both splits as CSV files without the index column.
for split_frame, split_path in ((train, output_train_path), (test, output_test_path)):
    split_frame.to_csv(split_path, index=False)