In [2]:
from pathlib import Path
import pandas as pd
import json
import re
from sklearn.model_selection import train_test_split
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm

In [3]:
load_dotenv()

True

# Settings

In [18]:
training_path = Path('../data/training.csv')
chatgpt_results_path = Path('../data/chatgpt/results')
output_path = Path('../data/training_with_chatgpt.csv')
output_small_path = Path('../data/training_with_chatgpt_small.csv')
output_small_train_path = Path('../data/train_small.csv')
output_small_test_path = Path('../data/test_small.csv')
output_train_path = Path('../data/train.csv')
output_test_path = Path('../data/test.csv')

In [5]:
df = pd.read_csv(training_path)
df['chatgpt_raw_response'] = ''
df['chatgpt_extracted_location'] = ''
df['chatgpt_extracted_situation'] = ''
df['chatgpt_extracted_risk'] = ''
df['chatgpt_extracted_message'] = ''
df['objs_ctrls_embedding'] = None
df['message_embedding'] = None
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 124248 entries, 0 to 124247
Data columns (total 13 columns):
 #   Column                       Non-Null Count   Dtype 
---  ------                       --------------   ----- 
 0   town                         124248 non-null  object
 1   section                      124248 non-null  int64 
 2   scenario                     124248 non-null  object
 3   part                         124248 non-null  int64 
 4   objects                      124248 non-null  object
 5   controls                     124248 non-null  object
 6   chatgpt_raw_response         124248 non-null  object
 7   chatgpt_extracted_location   124248 non-null  object
 8   chatgpt_extracted_situation  124248 non-null  object
 9   chatgpt_extracted_risk       124248 non-null  object
 10  chatgpt_extracted_message    124248 non-null  object
 11  objs_ctrls_embedding         0 non-null       object
 12  message_embedding            0 non-null       object
dtypes: int64(2), o

# Combine results

In [6]:
chatgpt_results_files = list(chatgpt_results_path.glob('*.txt'))
chatgpt_results_files.sort(key=lambda x: int(x.stem))

In [7]:
start_json_pattern = r'```(json)*(\n)*'

In [8]:
for index, row in tqdm(df.iterrows(), total=len(df)):
    with open(chatgpt_results_files[index], 'r') as f:
        chatgpt_raw_response = f.read()
        df.at[index, 'chatgpt_raw_response'] = chatgpt_raw_response
        try:
            chatgpt_raw_response = re.sub(start_json_pattern, '', chatgpt_raw_response).strip()
            chatgpt_response = json.loads(chatgpt_raw_response)
        except json.JSONDecodeError:
            print(chatgpt_raw_response)
        df.at[index, 'chatgpt_extracted_location'] = chatgpt_response['location']
        df.at[index, 'chatgpt_extracted_situation'] = chatgpt_response['situation']
        df.at[index, 'chatgpt_extracted_risk'] = chatgpt_response['risk']
        df.at[index, 'chatgpt_extracted_message'] = chatgpt_response['message']

100%|██████████| 124248/124248 [00:10<00:00, 12098.62it/s]


In [9]:
df.head(5)

Unnamed: 0,town,section,scenario,part,objects,controls,chatgpt_raw_response,chatgpt_extracted_location,chatgpt_extracted_situation,chatgpt_extracted_risk,chatgpt_extracted_message,objs_ctrls_embedding,message_embedding
0,Town12,81,StaticCutIn,1,"[(30, 'vehicle.ford.mustang'), (15, 'vehicle.m...","[{'throttle': 0.75, 'steer': 2.133598718501161...","{\n ""location"": ""City street"",\n ""situation""...",City street,Approaching multiple vehicles,medium,"Stay focused, be aware of surrounding vehicles...",,
1,Town12,81,StaticCutIn,1,"[(29, 'vehicle.ford.mustang'), (14, 'vehicle.m...","[{'throttle': 0.75, 'steer': 0.000891609524842...","{\n ""location"": ""City street"",\n ""situation""...",City street,Approaching multiple vehicles and a traffic light,medium,"Stay alert, approaching vehicles and traffic l...",,
2,Town12,81,StaticCutIn,1,"[(28, 'vehicle.ford.mustang'), (13, 'vehicle.m...","[{'throttle': 0.75, 'steer': -0.00017667813517...","{\n ""location"": ""City street"",\n ""situation""...",City street,Approaching multiple vehicles and a traffic light,medium,"Stay focused, be prepared to slow down for the...",,
3,Town12,81,StaticCutIn,1,"[(30, 'vehicle.nissan.patrol_2021'), (30, 'veh...","[{'throttle': 0.0, 'steer': 0.0002305144735146...","{\n ""location"": ""Urban area"",\n ""situation"":...",Urban area,Approaching multiple vehicles and a stop sign,medium,"Stay alert, slow down, and prepare to stop at ...",,
4,Town12,81,StaticCutIn,1,"[(48, 'vehicle.mini.cooper_s_2021'), (7, 'vehi...","[{'throttle': 0.0, 'steer': -0.000179785565705...","{\n ""location"": ""City street"",\n ""situation""...",City street,Approaching a stop sign with vehicles nearby,medium,"Approach the stop sign cautiously, there are v...",,


In [10]:
train, test = train_test_split(df, test_size=24248, random_state=42)
train.to_csv(output_train_path, index=False)
test.to_csv(output_test_path, index=False)

## Downsample

In [11]:
small_df = df.sample(frac=0.1, random_state=42)
small_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 12425 entries, 92953 to 120974
Data columns (total 13 columns):
 #   Column                       Non-Null Count  Dtype 
---  ------                       --------------  ----- 
 0   town                         12425 non-null  object
 1   section                      12425 non-null  int64 
 2   scenario                     12425 non-null  object
 3   part                         12425 non-null  int64 
 4   objects                      12425 non-null  object
 5   controls                     12425 non-null  object
 6   chatgpt_raw_response         12425 non-null  object
 7   chatgpt_extracted_location   12425 non-null  object
 8   chatgpt_extracted_situation  12425 non-null  object
 9   chatgpt_extracted_risk       12425 non-null  object
 10  chatgpt_extracted_message    12425 non-null  object
 11  objs_ctrls_embedding         0 non-null      object
 12  message_embedding            0 non-null      object
dtypes: int64(2), object(11)
memory 

In [12]:
small_df.to_csv(output_small_path, index=False)

# Add embedding for small dataset

In [13]:
client = OpenAI()

In [14]:
def get_embedding(text, model="text-embedding-ada-002"):
    txt = text.replace("\n", " ")
    return client.embeddings.create(input=[txt], model=model).data[0].embedding

In [19]:
small_df = pd.read_csv(output_small_path)
small_df.head(5)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 12425 entries, 0 to 12424
Data columns (total 13 columns):
 #   Column                       Non-Null Count  Dtype  
---  ------                       --------------  -----  
 0   town                         12425 non-null  object 
 1   section                      12425 non-null  int64  
 2   scenario                     12425 non-null  object 
 3   part                         12425 non-null  int64  
 4   objects                      12425 non-null  object 
 5   controls                     12425 non-null  object 
 6   chatgpt_raw_response         12425 non-null  object 
 7   chatgpt_extracted_location   12425 non-null  object 
 8   chatgpt_extracted_situation  12425 non-null  object 
 9   chatgpt_extracted_risk       12425 non-null  object 
 10  chatgpt_extracted_message    12425 non-null  object 
 11  objs_ctrls_embedding         0 non-null      float64
 12  message_embedding            0 non-null      float64
dtypes: float64(2), i

In [30]:
for index, row in tqdm(small_df.iterrows(), total=len(small_df)):
    if not pd.isnull(small_df.at[index, 'objs_ctrls_embedding']) and not pd.isnull(small_df.at[index, 'message_embedding']):
        continue

    if pd.isnull(small_df.at[index, 'objs_ctrls_embedding']):
        objs_ctrls = f"objects: {row['objects']}, controls: {row['controls']}"
        objs_ctrls_embedding = get_embedding(objs_ctrls)
        small_df.at[index, 'objs_ctrls_embedding'] = str(objs_ctrls_embedding)
        small_df.to_csv(output_small_path, index=False)
        
    if pd.isnull(small_df.at[index, 'message_embedding']):
        msg_embedding = get_embedding(small_df.at[index, 'chatgpt_extracted_message'])
        small_df.at[index, 'message_embedding'] = str(msg_embedding)
        small_df.to_csv(output_small_path, index=False)

  0%|          | 0/12425 [00:00<?, ?it/s]

[-0.028873074799776077, 0.0040616123005747795, 0.010773960500955582, -0.016445966437458992, -0.029443126171827316, 0.008194481022655964, -0.006854861509054899, 0.010823840275406837, -0.028459789231419563, -0.04392241686582565, -0.017643073573708534, 0.045233532786369324, -0.004040235187858343, 0.004581783432513475, 0.00663040392100811, 0.017928099259734154, 0.006933243479579687, -0.04195574298501015, 0.013253682292997837, -0.015491131693124771, 0.0019827079959213734, 0.007310902234166861, -0.00930607970803976, -0.0030533347744494677, -0.03685378655791283, 0.0034220863599330187, 0.005729010794311762, -0.01922496408224106, 0.0007731314399279654, -0.037851374596357346, 0.01691625826060772, -0.03440256789326668, -0.022317489609122276, -0.04044510796666145, -0.01441516075283289, -0.003883471479639411, 0.0015213232254609466, -0.03169482573866844, 0.029614141210913658, 0.002262389287352562, -0.0005085922311991453, 0.01173592172563076, -0.005729010794311762, -0.015576639212667942, -0.000455817




ValueError: Must have equal len keys and value when setting with an iterable

In [31]:
small_df.at[0, 'objs_ctrls_embedding'] = [-0.028873074799776077, 0.0040616123005747795, 0.010773960500955582, -0.016445966437458992, -0.029443126171827316, 0.008194481022655964, -0.006854861509054899, 0.010823840275406837, -0.028459789231419563, -0.04392241686582565, -0.017643073573708534, 0.045233532786369324, -0.004040235187858343, 0.004581783432513475, 0.00663040392100811, 0.017928099259734154, 0.006933243479579687, -0.04195574298501015, 0.013253682292997837, -0.015491131693124771, 0.0019827079959213734, 0.007310902234166861, -0.00930607970803976, -0.0030533347744494677, -0.03685378655791283, 0.0034220863599330187, 0.005729010794311762, -0.01922496408224106, 0.0007731314399279654, -0.037851374596357346, 0.01691625826060772, -0.03440256789326668, -0.022317489609122276, -0.04044510796666145, -0.01441516075283289, -0.003883471479639411, 0.0015213232254609466, -0.03169482573866844, 0.029614141210913658, 0.002262389287352562, -0.0005085922311991453, 0.01173592172563076, -0.005729010794311762, -0.015576639212667942, -0.000455817993497476, 0.015947172418236732, 0.010004391893744469, 0.003069367492571473, -0.03739533573389053, 0.02170468494296074, 0.00325819686986506, 0.019281970337033272, -0.027048911899328232, -0.02123439311981201, 0.029029838740825653, 0.006359629798680544, -0.003496905555948615, 0.020023036748170853, -0.029614141210913658, -0.021220142021775246, 0.016588479280471802, -0.005123332142829895, -0.025794800370931625, -0.002365710912272334, -0.016616981476545334, 0.006651780568063259, 0.012762012891471386, 0.0014055316569283605, 0.01846964657306671, 0.024740206077694893, 0.016959013417363167, 0.005034261383116245, 0.013780979439616203, -0.0028128447011113167, -0.010674201883375645, -0.03149531036615372, -0.012740636244416237, 0.006804981734603643, 0.002611545380204916, -0.008679023943841457, 0.021619178354740143, -0.027105918154120445, -0.037281326949596405, 0.04044510796666145, 0.014265522360801697, -0.005329975392669439, -0.007375032640993595, 0.009990140795707703, 0.009313205257058144, -0.019951779395341873, -0.0027113042306154966, 0.01865491457283497, 0.013196676969528198, 0.0070223137736320496, -0.005843020975589752, 0.021647680550813675, -0.026934903115034103, 0.019495738670229912, -0.0027736537158489227, -0.011472273617982864, 0.0032866992987692356, -0.004888185765594244, -0.02642185613512993, -0.01150790136307478, -0.011764423921704292, -0.029101096093654633, 0.013289310038089752, 0.007795445155352354, 0.016474468633532524, -0.0059962221421301365, -0.007809696719050407, 0.01996603049337864, -0.00026253514806739986, -0.024526437744498253, -0.020080041140317917, 0.019438734278082848, 0.00018459849525243044, 0.0045354668982326984, -0.014864075928926468, -0.03024119697511196, 0.0002335872413823381, 0.0428963266313076, 0.0018348511075600982, -0.03619822859764099, 0.021619178354740143, -0.0056541915982961655, 0.010944976471364498, -0.02521049790084362, -0.006131609436124563, -0.016673987731337547, 0.018925687298178673, 0.019524240866303444, 0.014286899007856846, -0.0018972004763782024, -0.0009690864244475961, 0.019552743062376976, -0.008415375836193562, 0.013346315361559391, -0.01968100480735302, -0.011657539755105972, -0.00448558758944273, 0.027276933193206787, -0.015790408477187157, 0.016959013417363167, -0.002554540289565921, 0.005679131485521793, 0.006185051519423723, 0.011322635225951672, -0.011279881000518799, -0.015533885918557644, 0.0003333461354486644, 0.0073964097537100315, 0.003354392945766449, 0.009227697737514973, -0.011607659980654716, 0.002292673336341977, 0.024198658764362335, -0.005354915279895067, 0.02056458406150341, 0.005269407294690609, 0.0126408776268363, 0.029728151857852936, 0.0011650414671748877, -0.02918660268187523, 0.009747869335114956, 0.0159614235162735, 0.01705877110362053, 0.010047146119177341, -0.023728366941213608, -0.016004176810383797, -0.021961208432912827, 0.03252140060067177, -0.007375032640993595, -0.019595498219132423, -0.013809481635689735, 0.03200835362076759, 0.029699649661779404, 0.014621804468333721, -0.04660165682435036, 0.00042620208114385605, 0.011500775814056396, 0.017571818083524704, 0.03283492848277092, 0.041186172515153885, 0.01313967164605856, -0.032891932874917984, -0.007866702042520046, -0.04158521071076393, -0.004214813467115164, -0.0035913202445954084, 0.04255429655313492, 0.027747225016355515, -0.0036732652224600315, -0.005191025324165821, -0.6206713318824768, -0.0183271337300539, 0.017714329063892365, -0.02273077704012394, -0.0019559869542717934, 0.04577508196234703, 0.010389176197350025, -0.0059356545098125935, -0.025524025782942772, 0.030868252739310265, -0.009227697737514973, 0.03705330565571785, 0.003733833087608218, -0.010795338079333305, -0.017985103651881218, -0.01134401187300682, 0.024868467822670937, -0.04266830533742905, 0.004556844010949135, 0.022930294275283813, -0.013431822881102562, 0.010203910060226917, 0.009940261952579021, -0.006769353523850441, 0.0033686442766338587, 0.004873934667557478, -0.020593086257576942, 0.007531796582043171, 0.02099212259054184, 0.0006796075031161308, -0.012954405508935452, -0.013132546097040176, 0.02407039701938629, 0.0032528527081012726, 0.029671145603060722, -0.0070472536608576775, -0.034345563501119614, 0.026820892468094826, 0.019766513258218765, 0.0360272116959095, -0.012099329382181168, -0.01691625826060772, 0.02223198302090168, 0.002663206309080124, -0.0011846369598060846, 0.006844172719866037, 0.014963834546506405, 0.002371055306866765, 0.006861987058073282, -0.035229142755270004, 0.024797212332487106, 0.01641746424138546, -0.00014774559531360865, -0.01751481182873249, 0.04472048953175545, -0.0401315800845623, 0.030127186328172684, 0.009263326413929462, 0.01006852276623249, -0.0014776786556467414, 0.007845324464142323, 0.011529278010129929, -0.026336349546909332, -0.018284380435943604, -0.01435815542936325, 0.02623658999800682, 0.0027611837722361088, 0.013068415224552155, -0.03978955000638962, -0.04098665714263916, -0.010268040932714939, 0.018341386690735817, -0.007246771361678839, -0.0055544329807162285, 0.02397063747048378, -0.014123009517788887, 0.04674416780471802, -0.009006802923977375, -0.008871416561305523, 0.02489697001874447, -0.007246771361678839, -0.0009089639061130583, -0.013837983831763268, -0.02623658999800682, 0.02308705821633339, 0.004300320986658335, -0.014201391488313675, -0.01882592961192131, 0.027818480506539345, -0.002909040777012706, 0.01846964657306671, 0.01925346627831459, 0.011764423921704292, -0.01673099212348461, -0.0062456196174025536, 0.008379747159779072, -0.0214909166097641, 0.004895311780273914, 0.009177818894386292, -0.023742618039250374, -0.023600105196237564, -0.0064878910779953, 0.008244359865784645, 0.005087703932076693, 0.029870664700865746, 0.015519633889198303, -0.01751481182873249, -0.002896570833399892, 0.004378702957183123, -0.026265092194080353, -0.010025769472122192, 0.010752583853900433, -0.0041542453691363335, 0.018597908318042755, -0.010902222245931625, -0.03457358479499817, -0.016331955790519714, -0.007204017601907253, 0.013560084626078606, -0.020336564630270004, 0.028915828093886375, 0.017258288338780403, 0.011336886323988438, -0.008094722405076027, 0.004129305947571993, -0.010531689040362835, 0.02244575135409832, -0.00730377621948719, -0.022645270451903343, -0.02027955837547779, 0.05096254497766495, 0.03765185922384262, 0.009576854296028614, -0.01982351765036583, 0.022559762001037598, -0.009170693345367908, 0.0022018214222043753, -0.01705877110362053, 0.023785371333360672, -0.009804874658584595, -0.0032599782571196556, 0.007239645812660456, 0.012968656606972218, -0.004214813467115164, 0.016688238829374313, -0.010168282315135002, -0.003308076411485672, 0.02935761772096157, -0.012427108362317085, -0.03197985142469406, 0.0002990540233440697, -0.027248430997133255, -0.02627934329211712, 0.020821107551455498, -0.0012719259830191731, -0.011800052598118782, -0.027063162997364998, -0.011750172823667526, 0.017073022201657295, -0.006266996264457703, -0.009747869335114956, 0.004613848868757486, 0.003334797453135252, -0.01744355633854866, 0.00417562248185277, -0.009526974521577358, -0.02361435629427433, 0.011607659980654716, -0.04474899172782898, -0.028003748506307602, -0.0028146260883659124, -0.003420304972678423, 0.012149208225309849, 0.009370210580527782, -0.022032465785741806, 0.02694915421307087, -0.03534315153956413, -0.03237888962030411, -0.0035770691465586424, -0.001460755243897438, -0.006398820783942938, 0.01882592961192131, -0.005985533818602562, 0.01610393635928631, 0.039419014006853104, -0.01122287567704916, 0.01567639783024788, 0.01829863153398037, -0.031153278425335884, 0.04674416780471802, -0.008515134453773499, 0.017571818083524704, 0.0067123486660420895, 0.00872177816927433, -0.0176573246717453, 0.011700293980538845, 0.010952102020382881, -0.00249575381167233, 0.008921295404434204, 0.02134840376675129, 0.04024558886885643, 0.026122579351067543, 0.011301257647573948, -0.02198971062898636, 0.000678716809488833, -0.009690864011645317, -0.001384154660627246, -0.006224242504686117, 0.010488935746252537, 0.0021555046550929546, 0.030839750543236732, -0.001964893890544772, -0.011265629902482033, 0.008679023943841457, 0.014864075928926468, 0.01627495139837265, -0.017714329063892365, -0.0025028795935213566, -0.011108865961432457, -0.020977871492505074, 0.0009744306444190443, 0.0024619069881737232, 0.0533282570540905, -0.008322741836309433, -0.02351459674537182, -0.0016442403430119157, 0.015377121046185493, 0.03779437020421028, 0.010959227569401264, -0.013681219890713692, -0.016160940751433372, -0.005240905098617077, 0.013588586822152138, 0.005297909956425428, 0.004193436354398727, -0.010624323040246964, 0.026336349546909332, -0.014586175791919231, 0.02826027013361454, -0.013617089949548244, 0.029272111132740974, 0.024084648117423058, 0.036711275577545166, -0.009142190217971802, 0.004670854192227125, 0.0005063654971309006, 0.014158638194203377, 0.004417893942445517, -0.021120384335517883, 0.02840278297662735, 0.02039356902241707, -0.005251593422144651, -0.015533885918557644, 0.003003455465659499, 0.0054689254611730576, -0.010795338079333305, 0.011351137422025204, 0.013111169449985027, 0.025908811017870903, 0.005354915279895067, 0.029443126171827316, 0.002554540289565921, 0.00169322919100523, -0.00310321431607008, 0.03488711267709732, -0.0022000400349497795, -0.04620262235403061, -0.006452262867242098, 0.014443662948906422, -0.014529170468449593, -0.02719142474234104, -0.026721132919192314, -0.01227034442126751, -0.023072807118296623, 0.01846964657306671, 0.0020646529737859964, 0.017400801181793213, 0.0007147903088480234, 0.019652502611279488, 0.0290725938975811, 0.014058878645300865, -0.021690433844923973, 0.02137690596282482, 0.012455610558390617, -0.007688560523092747, -0.00925620086491108, -0.027747225016355515, 0.025053733959794044, -0.0006079057930037379, 0.0019666755106300116, 0.0034470262471586466, 0.02935761772096157, 0.006138734985142946, 0.0005994440871290863, 0.00040504784556105733, 0.029842162504792213, 0.00971224159002304, -0.0048846229910850525, 0.01755756512284279, 0.013638466596603394, -0.007318027783185244, -0.0018651351565495133, 0.006256307940930128, 0.0043965172953903675, 0.029058340936899185, 0.010203910060226917, 0.030697237700223923, -0.03759485483169556, -0.00474211061373353, -0.03212236613035202, 0.013916365802288055, -0.00860064197331667, -0.018711918964982033, 0.0011552436044439673, 0.0044499593786895275, -0.014764316380023956, 0.018270129337906837, 0.0038692201487720013, 0.01896844245493412, 0.010716956108808517, -0.024839965626597404, -0.029029838740825653, -0.02489697001874447, 0.005529493093490601, -0.011358262971043587, 0.023414839059114456, 0.013802356086671352, -0.0007455196464434266, -0.017671575769782066, -0.01789959706366062, -0.024141652509570122, -0.013916365802288055, 0.002262389287352562, 0.010089900344610214, -0.020222553983330727, 0.01145089603960514, 0.0143795320764184, -0.017571818083524704, 0.009997266344726086, 0.0013788104988634586, -0.017742833122611046, -0.018498150631785393, 0.004204125143587589, -0.03805089369416237, -0.006861987058073282, -0.00992600992321968, 0.012277469970285892, 0.02198971062898636, 0.02442667819559574, -0.013588586822152138, -0.005112643353641033, 0.01815611869096756, 0.008294239640235901, -0.020721348002552986, -0.01274776179343462, -0.0033258902840316296, 0.017358047887682915, 0.0021163136698305607, -0.01858365722000599, -0.017073022201657295, 0.03007018193602562, 0.0074035353027284145, 0.025823302567005157, 0.021761691197752953, 0.00981912575662136, -0.008486632257699966, 0.0026151081547141075, -0.009049557149410248, 0.011721670627593994, 0.005483176559209824, 0.02783273160457611, 0.019096702337265015, 0.013438948430120945, -0.021462414413690567, -0.0007722407463006675, -0.009491346776485443, -0.010766834951937199, 0.006438011769205332, 0.0015863445587456226, -0.017814088612794876, -0.0077384402975440025, -0.018426893278956413, -0.012762012891471386, -0.01292590331286192, -0.005401231814175844, -0.010025769472122192, 0.029614141210913658, -0.013481702655553818, -0.0055544329807162285, 0.017101524397730827, -0.025794800370931625, -0.0023835250176489353, -0.023913633078336716, 0.008237234316766262, -0.0018847306491807103, 0.006822796072810888, -0.023058556020259857, -0.002331864321604371, 0.013709723018109798, 0.00694036902859807, 0.00601047370582819, -0.012911651283502579, 0.008572139777243137, 0.012142082676291466, -0.010374925099313259, -0.0287590641528368, 0.01187843456864357, 0.0001702358858892694, -0.00948422122746706, 0.0044036428444087505, 0.02361435629427433, 0.005924965720623732, 0.0014696622965857387, 0.02673538401722908, -0.004574657883495092, -0.021718937903642654, 0.015248860232532024, -0.022787783294916153, 0.0074462890625, 0.017543314024806023, -0.005408357363194227, -0.004891749005764723, -0.003179814899340272, 0.0018170371185988188, 0.021975459530949593, -0.03215086832642555, -0.011750172823667526, -0.0002079794940073043, -0.0015480442671105266, -0.012662254273891449, -0.0007606616127304733, -0.012598123401403427, -0.004553281236439943, 0.018555155023932457, -0.00730377621948719, -0.009683738462626934, -0.023144064471125603, -0.03303444758057594, -0.022616766393184662, -0.016602730378508568, -0.012177711352705956, 0.005729010794311762, -0.030697237700223923, -0.006787167862057686, -0.009021054953336716, -0.01097347866743803, 0.01363134104758501, 0.009683738462626934, 0.012733510695397854, 0.008814411237835884, -0.027077414095401764, 0.001046577701345086, -0.017457807436585426, 0.008244359865784645, 0.014465040527284145, 0.0016950105782598257, -0.01001864392310381, -0.01286889798939228, -0.009405839256942272, -0.005900026299059391, 0.006060353014618158, -0.00454615568742156, 0.009149315766990185, -0.02893008105456829, 0.0035450037103146315, 0.0012656910112127662, 0.00996163859963417, -0.023771120235323906, 0.02035081572830677, -0.023272326216101646, 0.010132653638720512, 0.030583227053284645, 0.02254551090300083, 0.04138569161295891, 0.015761906281113625, 0.017742833122611046, -0.025324508547782898, -0.01822737604379654, -0.033490486443042755, -0.03765185922384262, -0.010538814589381218, 0.00893554650247097, 0.02740519493818283, 0.01982351765036583, 0.03491561487317085, -0.028431285172700882, 0.007567424792796373, -0.0055686840787529945, 0.011949690990149975, 0.03588470071554184, -0.001647803233936429, 0.01503509096801281, -0.029158100485801697, 0.00448558758944273, 0.0004237526445649564, -0.015291613526642323, 0.00756029924377799, 0.009933135472238064, -0.02435542270541191, 0.008465254679322243, 0.004873934667557478, 0.02166193164885044, -0.008679023943841457, -0.02198971062898636, -0.01701601780951023, -0.001170385628938675, -3.710340752149932e-05, -1.6297108231810853e-05, -0.04463497921824455, -0.022032465785741806, 0.01078108698129654, 0.015733404085040092, -0.014322527684271336, -0.0027629651594907045, 0.01503509096801281, -0.026578620076179504, 0.0035895388573408127, -0.0014874764019623399, 0.011372514069080353, -0.010503186844289303, -0.012434233911335468, -0.026008570566773415, -0.02025105617940426, -0.02351459674537182, -0.009455718100070953, 0.02726268209517002, 0.001726185204461217, -0.005839458201080561, -0.018954191356897354, 0.0186834167689085, -0.015662146732211113, -0.0023443340323865414, -0.014101632870733738, -0.004713607951998711, -0.0021875700913369656, -0.009192069992423058, -0.03768036141991615, -0.016887756064534187, 0.0020699971355497837, 0.024198658764362335, -0.0034559331834316254, 6.836852662672754e-06, -0.035542670637369156, -0.004286069888621569, 0.028516793623566628, 0.011308383196592331, 0.05361328274011612, 0.0018001137068495154, 0.032464396208524704, 0.02992766909301281, -0.009747869335114956, 0.0023425526451319456, -0.01221333909779787, 0.0005900916876271367, 0.010374925099313259, 0.01801360584795475, 0.01001864392310381, 0.004097240511327982, -0.03243589401245117, 0.0010617197258397937, -0.004350200295448303, -0.003056897548958659, -0.03138129785656929, 0.027861235663294792, 0.027775727212429047, 0.005992659367620945, -0.01730104349553585, -0.017628822475671768, -0.03343348205089569, -0.010097025893628597, -0.00328491791151464, 0.004571095108985901, 0.02407039701938629, -0.014450788497924805, -0.03360449895262718, 0.02868780866265297, -0.023585854098200798, 0.0033009506296366453, -0.0025189120788127184, -0.014571924693882465, 0.01617519184947014, 0.003049771999940276, -0.010403428226709366, 0.005376291926950216, 0.007845324464142323, 0.02304430492222309, -0.012070826254785061, -0.006769353523850441, 0.024440929293632507, 0.005960593931376934, 0.00015097440336830914, -0.013089792802929878, 0.01889718510210514, 0.03360449895262718, -0.0006911866366863251, -0.011814303696155548, 0.015106347389519215, 0.013752476312220097, -0.009284703060984612, 0.0018197091994807124, -0.00042108053457923234, -0.011821429245173931, -0.018384139984846115, 0.003819340607151389, 0.017529062926769257, 0.015206106007099152, -0.018925687298178673, 0.0028021561447530985, -0.01577615737915039, -0.023785371333360672, -0.018056361004710197, -0.003929787781089544, 0.009420090354979038, -0.0005126004107296467, -0.023699862882494926, -0.023129813373088837, -0.0041008032858371735, 0.001402859459631145, 0.026991907507181168, -0.01829863153398037, 0.0401315800845623, 0.04081564024090767, -0.030611731112003326, 0.00612448388710618, -0.05871523544192314, 0.01886868290603161, -0.043665893375873566, 0.027861235663294792, 0.001460755243897438, -0.0059570311568677425, -0.019096702337265015, 0.011664665304124355, -0.03266391158103943, -0.014450788497924805, 0.009648110717535019, -0.008629144169390202, 0.005027135834097862, 0.00474211061373353, 0.010382050648331642, 0.021049126982688904, 0.0287590641528368, -0.027789978310465813, -0.0074462890625, 0.015476880595088005, -0.013766727410256863, 0.04551855847239494, 0.019452985376119614, -0.009812000207602978, 0.010895096696913242, -0.0005758404149673879, 0.016004176810383797, 0.03200835362076759, -0.025324508547782898, -0.009762120433151722, -0.009405839256942272, 0.015491131693124771, -0.01648871973156929, 0.009177818894386292, -0.0004239753179717809, -0.029029838740825653, -0.010325046256184578, 0.008002088405191898, -0.0004827617958653718, -0.028017999604344368, 0.02833152748644352, 0.018925687298178673, 0.003056897548958659, 0.022174978628754616, -0.004400080069899559, -0.0032332572154700756, -0.019795015454292297, -0.016303453594446182, -0.016816500574350357, 0.00325819686986506, 0.0051660859026014805, 0.03445957228541374, 0.01908245123922825, -0.005778890568763018, -0.020008783787488937, -0.00011890903988387436, -0.025538276880979538, 0.0011418830836191773, -0.010930724442005157, 0.030440714210271835, 0.014621804468333721, -0.010389176197350025, -0.03369000554084778, 0.004852558020502329, 0.005073452368378639, 0.019068200141191483, -0.022060967981815338, -0.006819233298301697, 0.010631448589265347, -0.010588694363832474, -0.0011160526191815734, 0.019096702337265015, -0.04415043815970421, -0.019524240866303444, 0.011130242608487606, 0.019581247121095657, 0.04169921949505806, 0.03773736581206322, 0.004353763535618782, 0.004749236162751913, 0.0007143449620343745, -0.017187032848596573, 0.006042539142072201, -0.014180014841258526, -0.002424497390165925, -0.029243608936667442, -0.031010765582323074, -0.012484113685786724, 0.0040544867515563965, 0.020550332963466644, -0.015134849585592747, 0.001886512036435306, 0.013054164126515388, -0.0006239384529180825, -0.040160082280635834, -0.0070935701951384544, -0.023742618039250374, -0.006861987058073282, 0.008842913433909416, -0.00708288187161088, 0.00201655481941998, 0.0380793958902359, 0.02709166705608368, 0.033918026834726334, -0.010089900344610214, 0.013859361410140991, 0.0016647266456857324, 0.002533163409680128, -0.006266996264457703, -0.015149100683629513, 0.014493542723357677, -0.018669165670871735, 0.029243608936667442, -0.023386335000395775, 0.015020839869976044, -0.005080577917397022, -0.007809696719050407, -0.02180444449186325, 0.02754770778119564, -0.005639940500259399, -0.015334367752075195, 0.01734379678964615, 0.009548351168632507, -0.020821107551455498, -0.004282507114112377, -0.020977871492505074, -0.0040081702172756195, -0.020664343610405922, 0.0003912419197149575, 0.0063382526859641075, 0.0032546340953558683, -0.02297304943203926, 0.01068845298141241, 0.016460217535495758, -0.03750934451818466, -0.009177818894386292, 0.17660175263881683, -0.012313098646700382, 0.031153278425335884, 0.04751373827457428, -0.01804210990667343, 0.01272638514637947, 0.0008230109233409166, -0.01748630963265896, -0.026051323860883713, 0.01748630963265896, -0.002253482351079583, 0.0055081164464354515, -0.012648003175854683, -0.003972542006522417, 0.0037979637272655964, 0.027718722820281982, -0.02208947017788887, -0.026193836703896523, 0.0031709077302366495, -0.0023425526451319456, -0.010139779187738895, 9.25775893847458e-05, 0.0006836156826466322, -0.028915828093886375, 0.018783174455165863, -0.0134888282045722, -0.002139472169801593, 0.024882718920707703, 0.023728366941213608, 0.02723417989909649, -0.0036982048768550158, -0.017172781750559807, -0.0038692201487720013, -0.0028520356863737106, 0.011301257647573948, -0.018455395475029945, -0.02120589092373848, 0.01610393635928631, 0.010367799550294876, 0.023272326216101646, 0.012113580480217934, -0.0428963266313076, 0.014044627547264099, -0.017457807436585426, -0.036540258675813675, -0.006609026808291674, -0.02432692050933838, -0.009954513050615788, 0.0013476358726620674, 0.002921510487794876, -0.029699649661779404, -0.019795015454292297, 0.03474459797143936, 0.048881858587265015, -0.001532011665403843, -0.009897507727146149, 0.016146689653396606, -0.0032866992987692356, 0.017144279554486275, -0.00324394553899765, -0.008508008904755116, 0.03924800083041191, -0.004802678246051073, 0.02716292254626751, 0.0077313147485256195, 0.006662469357252121, -0.006078166887164116, -0.010089900344610214, 0.01665973663330078, -0.003199410391971469, -0.004567532334476709, 0.007617304567247629, 0.006455825641751289, 0.00356816197745502, -0.013396195136010647, -0.03386101871728897, 0.026507364585995674, -0.0040758633986115456, 0.02730543538928032, 0.014137260615825653, -0.021291399374604225, -0.009313205257058144, 0.016089685261249542, -0.004086552187800407, -0.007795445155352354, -0.035086628049612045, 0.02535301074385643, 0.012648003175854683, -0.018284380435943604, 0.018027858808636665, -0.0032884806860238314, -0.018555155023932457, -0.00860064197331667, 0.005369166377931833, -0.004571095108985901, 0.007823947817087173, -0.0018686979310587049, 0.027134420350193977, -0.007310902234166861, 0.006220679730176926, -0.03571368381381035, -0.001545372186228633, 0.01670248992741108, -0.0038050892762839794, 0.007660058327019215, -0.0067265997640788555, 0.023813873529434204, -0.005971282720565796, 0.014479291625320911, -0.041642215102910995, -0.004186310805380344, -0.03303444758057594, -0.0015712026506662369, -0.003464840352535248, 0.01049606129527092, 0.017571818083524704, -0.0015765468124300241, -0.018483897671103477, 0.011978193186223507, -0.009769245982170105, 0.005707634147256613, -0.041471198201179504, -0.022673772647976875, -0.0063240015879273415, -0.008358370512723923, -0.0007521999068558216, -0.05024998262524605, 0.016759494319558144, 0.01720128394663334, 0.004453522153198719, -0.00981912575662136, -0.042639803141355515, -0.00530147273093462, 0.012705008499324322, 0.002709522843360901, 0.008358370512723923, -0.0012211557477712631, -0.02198971062898636, 0.006341815460473299, 0.00033445953158661723, 0.00868614949285984, 0.014372406527400017, 0.03545716404914856, -0.02137690596282482, -0.004553281236439943, -0.01272638514637947, 0.005240905098617077, -0.0074462890625, -0.03300594538450241, -0.019880523905158043, -0.011001980863511562, 0.002973171416670084, -0.001324477489106357, -0.023272326216101646, 0.00922057218849659, -0.016460217535495758, 0.00035182826104573905, -0.03691079095005989, 0.0041471198201179504, 0.025253253057599068, -0.018711918964982033, -0.004556844010949135, -0.00849375780671835, -0.0017413272289559245, -0.009334582835435867, -0.013723974116146564, -0.18355637788772583, 0.03340497985482216, 0.018127616494894028, -0.02549552358686924, 0.004891749005764723, -0.0033953653182834387, 0.0242129098623991, -0.003012362401932478, -0.01911095529794693, -0.025680789723992348, 0.03488711267709732, -0.036426249891519547, -0.0332624651491642, -0.010367799550294876, 0.007164826616644859, -0.005422608461230993, -0.006128046661615372, 0.01554813701659441, 0.007852450013160706, 0.01865491457283497, 0.0346875935792923, -0.02754770778119564, 0.003518282435834408, 0.011351137422025204, -0.02156217396259308, 0.011800052598118782, 0.0021946956403553486, 0.021177388727664948, -0.015419875271618366, -0.0143795320764184, 0.01779983751475811, 0.007510419934988022, 0.033918026834726334, 0.024911221116781235, 0.008215857669711113, 0.021818695589900017, -0.015875915065407753, -0.026393353939056396, -0.007595927454531193, 0.026849394664168358, -0.003548566484823823, 0.00448558758944273, -0.015462629497051239, 0.011572032235562801, 0.013324938714504242, 0.01047468464821577, 0.009412964805960655, -0.012327349744737148, 0.004906000103801489, -0.008750280365347862, 0.026550117880105972, -0.039903558790683746, -0.004827618133276701, 0.01068845298141241, 0.02726268209517002, 0.02904408983886242, 0.002741588279604912, 0.0015391373308375478, -0.004278944339603186, -0.02191845513880253, -0.01722978614270687, -0.014579050242900848, 0.005312161520123482, -0.004471336491405964, -0.03309145197272301, -0.00687980093061924, -0.017614571377635002, 0.01289027463644743, -0.034374065697193146, 0.01655997708439827, -0.009918884374201298, 0.013859361410140991, 0.009177818894386292, 0.03505812585353851, 0.0040616123005747795, -0.012469862587749958, 0.001027872902341187, 0.03340497985482216, 0.013025661930441856, -8.055307989707217e-05, -0.01963825151324272, 0.051076553761959076, -0.02166193164885044, -0.0038692201487720013, 0.028887325897812843, -0.003521845443174243, 0.022531259804964066, 0.007845324464142323, -0.00024271696747746319, -0.01012552808970213, 0.01024666428565979, -0.025908811017870903, 0.0074106608517467976, -0.012120706029236317, -0.0159614235162735, 0.01503509096801281, -0.014229894615709782, 0.005693382583558559, 0.012427108362317085, -0.017671575769782066, -0.0116147855296731, -0.0070151882246136665, -0.00992600992321968, -0.014087381772696972, -0.00907805934548378, 0.028915828093886375, -0.020863860845565796, 0.031010765582323074, 0.03463058918714523, 0.01631770469248295, 0.004767050035297871, 0.005486739333719015, -0.0052266535349190235, 0.010845216922461987, 0.0014981648419052362, -0.00037899473682045937, 0.003381113987416029, -0.017400801181793213, 0.019709507003426552, 0.01351020485162735, 0.005971282720565796, -0.018483897671103477, 0.002946450375020504, 0.017315294593572617, 0.010218161158263683, 0.005012884736061096, -0.0982767641544342, -0.04144269600510597, 0.03656876087188721, 0.007424912415444851, 3.5071486763627036e-06, 0.027960993349552155, -0.008985426276922226, 0.03400353342294693, -0.03121028281748295, 0.0019346100743860006, -0.0016041586641222239, -0.011272755451500416, 0.003976104781031609, -0.010139779187738895, 0.01517760381102562, 0.0033632998820394278, 0.0009877912234514952, 0.00530147273093462, -0.01996603049337864, 0.019695255905389786, -0.025181995704770088, -0.0017493434716016054, 0.0018170371185988188, 0.008422501385211945, -6.296166975516826e-05, 0.028132008388638496, -0.020479075610637665, 0.031922847032547, -0.002941105980426073, 0.009334582835435867, 0.01911095529794693, -0.00245478143915534, 0.008308490738272667, -0.01851240172982216, -0.01854090392589569, 0.0023300829343497753, -0.020265307277441025, -0.018198873847723007, 0.03722431883215904, -0.01134401187300682, 0.005875086411833763, -0.00262401532381773, -0.008365496061742306, -0.006103106774389744, -0.01734379678964615, -0.024768708273768425, -0.03762335702776909, 0.028317276388406754, 0.0005985533935017884, 0.006441574543714523, -0.01634620875120163, 0.01148652471601963, -0.0004257567343302071, -0.00040126233943738043, -0.010802463628351688, 0.016360459849238396, 0.039447519928216934, 0.0030070182401686907, -0.028559546917676926, 0.03021269477903843, -0.013780979439616203, 0.0040687378495931625, -0.013624215498566628, 0.013054164126515388, -0.00567200593650341, -0.001279942225664854, -0.028559546917676926, -0.02027955837547779, 0.038421425968408585, -0.016089685261249542, -0.014949583448469639, 0.0014598645502701402, -0.0062776850536465645, 0.01825587823987007, -0.0291153471916914, 0.0059819710440933704, 0.00417562248185277, -0.014051753096282482, 0.008301365189254284, -0.0214481633156538, -0.017258288338780403, 0.011707419529557228, -0.01822737604379654, -0.016018427908420563, 0.03924800083041191, 0.009904633276164532, 0.00089872075477615, 0.0033704256638884544, 0.005262281745672226, -0.015932921320199966, 0.01517760381102562, 0.0020593085791915655, 0.03705330565571785, -0.033661503344774246, 0.001427799230441451, 0.014921081252396107, 0.003773024072870612, 0.029272111132740974, -0.020764101296663284, 0.008807285688817501, 0.009106562472879887, -0.0027772164903581142, -0.06253457814455032, 0.0066874087788164616, -0.015576639212667942, -0.00321722449734807, -0.010146904736757278, -0.016303453594446182, 0.005433297250419855, -0.02935761772096157, 0.01577615737915039, 0.015234609134495258, -0.01283326931297779, 0.012469862587749958, -0.010987729765474796, 0.004606723319739103, -0.020236805081367493, -0.01641746424138546, 0.015875915065407753, -0.0004293195379432291, 0.013018536381423473, 0.002134127775207162, 0.010937850922346115, -0.018341386690735817, 0.0126408776268363, 0.019581247121095657, -0.0025937312748283148, 0.027148671448230743, 0.006584087386727333, 0.04087264463305473, 0.002768309321254492, -0.025025231763720512, 0.020878111943602562, -0.021034875884652138, 0.006017599254846573, 0.017258288338780403, 0.012298846617341042, -0.004214813467115164, -0.0033757698256522417, 0.020678594708442688, 0.024668950587511063, -0.020792605355381966, -0.01801360584795475, -0.03876345604658127, -0.0040152957662940025, -0.0048774974420666695, -0.03138129785656929, -0.007610178552567959, 0.0023051430471241474, 0.01097347866743803, -0.005939217284321785, -0.0038157778326421976, 0.0008457238436676562, 0.01301141083240509, -0.01932472363114357, -0.0056007495149970055, -0.007346530444920063, -0.006573398597538471, -0.0009975889697670937, -0.005183899775147438, -0.001525776693597436, -0.012455610558390617, 0.03340497985482216, -0.0012772701447829604, 0.022217731922864914, -0.031751830130815506, -0.007417786400765181, -0.008736029267311096, -0.01097347866743803, -0.013674094341695309, 0.009754994884133339, -0.00555799575522542, -0.013125420548021793, -0.03414604440331459, 0.01808486320078373, 0.021889952942728996, -0.0036429811734706163, -0.009284703060984612, -0.001442050444893539, 0.018355637788772583, -0.004139994271099567, 0.018141867592930794, 0.05067751929163933, -0.013731099665164948, -0.03323396295309067, -0.0015533885452896357, 0.04765624925494194, -0.01631770469248295, -0.0074819172732532024, 0.012961531057953835, 0.015576639212667942, -0.02559528313577175, -0.038307417184114456, 0.017757084220647812, 0.010481810197234154, -0.00443214550614357, 0.025552529841661453, -0.00010187432053498924, -0.01886868290603161, 0.012648003175854683, -0.012512615881860256, 0.037423837929964066, -0.009192069992423058, 0.002406683284789324, -0.01989477500319481, -0.017386550083756447, -0.013659843243658543, -0.013723974116146564, 0.001402859459631145, -0.009648110717535019, -0.02773297391831875, 0.016146689653396606, 0.018854431807994843, -0.0011151619255542755, -0.0166312325745821, 0.010168282315135002, -0.014892578125, 0.01278339046984911, -0.009135064668953419, 0.005739699583500624, -0.02900133654475212, 0.020236805081367493, 0.024768708273768425, 0.016602730378508568, -0.000755762739572674, 0.004898874554783106, 0.03844992816448212, 0.047884270548820496, 0.00677291676402092, -0.016232198104262352, -0.014514919370412827, 0.0159614235162735, 0.0032475083135068417, -0.012505490332841873, -0.017970852553844452, -0.014144386164844036, -0.025338759645819664, -0.019909026101231575, -0.019167959690093994, -0.013702597469091415, -0.004043797962367535, 0.07712788134813309, 0.0007619976531714201, -0.021248646080493927, 0.0019274844089522958, 0.010396302677690983, 0.015149100683629513, 0.02113463543355465, 0.0070757558569312096, -0.030839750543236732, -0.004018858540803194, 0.01517760381102562, -0.012861772440373898, 0.008650521747767925, -0.04634513333439827, 0.01216346025466919, 0.004058049526065588, -0.01526311133056879, -0.003619822906330228, -0.009534100070595741, -0.019766513258218765, 0.003297387855127454, 0.014514919370412827, 0.002087811240926385, 0.00934170838445425, 0.010923598892986774, 0.0023176129907369614, -0.007574550341814756, -0.007271711248904467, -0.034801602363586426, -0.02446943148970604, -0.009869005531072617, 0.01911095529794693, -0.026322098448872566, -0.022402998059988022, -0.03212236613035202, 0.008443878032267094, -0.013574335724115372, -0.0035539106465876102, 0.03463058918714523, -0.009313205257058144, -0.008422501385211945, 0.07148437201976776, -0.018455395475029945, -0.014229894615709782, 0.021533669903874397, -0.009192069992423058, 0.01049606129527092, -0.004945191089063883, -0.04169921949505806]

ValueError: Must have equal len keys and value when setting with an iterable

In [109]:
small_train, small_test = train_test_split(small_df, test_size=2425, random_state=42)
small_train.to_csv(output_small_train_path, index=False)
small_test.to_csv(output_small_test_path, index=False)