In [38]:
from google_api import generate
from llama_api import generate_llama

In [39]:
import random

# Instruction snippets keyed by experimental condition. Each value starts with
# a space so it can be inserted directly after a sentence of the task preamble.
# NOTE(review): the wording of the first two entries looks garbled ("after seen
# at least dots as possible"); it is kept verbatim because it is part of the
# prompt text that is sent to the models -- confirm before changing.
conditions = dict(
    high_time_pressure_condition=' This is a time critical task. Respond after seen at least dots as possible.',
    both_condition=' Please respond after seen at least dots as possible but also make sure that you are as accurate as possible.',
    high_accuracy_pressure_condition=" It is very important to be as accurate as possible. Don't rush your decision.",
    low_reward_condition=' This task is not rewarded.',
    medium_reward_condition=' This task is rewarded with $0.1 for each correct response.',
    high_reward_condition=' This task is rewarded with $1 for each correct response.',
)


def create_prompt(coherence,
                  number_of_dots,
                  direction,
                  preamble=None,
                  appendix=None,
                  condition='',
                  total_num_of_dots=1000):
    """Build the text prompt for one dot-motion trial.

    The prompt is ``preamble``, followed by one sentence per dot (in random
    order), followed by ``appendix``, all joined by single spaces.

    Args:
        coherence: Fraction of the dots that move in ``direction``; the
            remaining dots move the opposite way.
        number_of_dots: Number of dot sentences to include in the prompt.
        direction: ``'right'`` makes the coherent dots move right; any other
            value makes them move left.
        preamble: Text placed before the dot sentences. Defaults to the
            standard task instructions.
        appendix: Text placed after the dot sentences. Defaults to the
            standard response instructions.
        condition: Extra instruction inserted into the default preamble
            (ignored when ``preamble`` is given).
        total_num_of_dots: Total dot count mentioned in the default preamble
            (ignored when ``preamble`` is given).

    Returns:
        The assembled prompt string.

    Raises:
        ValueError: If ``coherence * number_of_dots`` is not (within a small
            tolerance) a whole number.
    """
    # NOTE(review): the default texts contain what look like typos (e.g.
    # "main directory of th dots"). They are kept verbatim because they are
    # the stimuli sent to the models -- confirm before changing.
    if preamble is None:
        preamble = f'In this task you see dots moving. Please indicate the main direction of the dots. In total there are {total_num_of_dots} dots. Some of them might move left and other might move right but you can only look at one point at a time. You can decide to either look at more dots then respond with "continue". But you can also decide that you looked at enough dots. Then respond with "left" or "right".{condition} Remember: Your decision is to either look at more dots (continue) or to decide on the main direction of the dots (left or right).'
    if appendix is None:
        appendix = 'Please indicate if you want to look at more dots (continue), then you will see another dot and can decide again. If you decide that you looked at enough dots, decide on the main directory of th dots. The only acceptable responses are "left", "right" or "continue". Make sure to not respond with anything else then "left", "right" or "continue".'

    expected_coherent = coherence * number_of_dots
    if abs(expected_coherent - round(expected_coherent)) > 0.004:
        raise ValueError('coherence * number_of_dots must be an integer')

    # Round instead of truncating: floating point error makes e.g.
    # (1 - 0.9) * 10 == 0.9999999999999998, which int() would turn into 0 and
    # silently drop a dot from the prompt.
    num_coherent = int(round(expected_coherent))
    num_opposite = int(round(number_of_dots * (1 - coherence)))

    if direction == 'right':
        coherent_text = 'This dot is moving right.'
        opposite_text = 'This dot is moving left.'
    else:
        coherent_text = 'This dot is moving left.'
        opposite_text = 'This dot is moving right.'

    trials = [coherent_text] * num_coherent + [opposite_text] * num_opposite
    random.shuffle(trials)
    return ' '.join([preamble, *trials, appendix])

In [40]:
def get_accuracy(response, direction):
    """Score a model response against the true main direction.

    Args:
        response: Raw text returned by the model (``None`` is treated as an
            empty response).
        direction: The true main direction, e.g. ``'left'`` or ``'right'``.

    Returns:
        ``'continue'`` if the response contains the word "continue",
        ``True`` if it contains ``direction``, and ``False`` otherwise
        (including for an empty response).
    """
    answer = (response or '').strip().lower()
    # Only test whether the keyword occurs in the answer. The reverse test
    # (answer in keyword) is true for an empty string and for fragments such
    # as 't', which would misclassify empty or garbage responses.
    if 'continue' in answer:
        return 'continue'
    return direction in answer
# Fractions of dots moving in the main direction: 0.5 to 1.0 in steps of 0.1.
coherences = [tenths / 10 for tenths in range(5, 11)]
# Numbers of dots shown in a single prompt: 10 to 100 in steps of 10.
times = list(range(10, 101, 10))
# Possible main directions of motion.
directions = 'left right'.split()

In [41]:
import json
# Start an empty results table: one list per recorded column plus the scalar
# total number of dots. Writing it out replaces any existing 'stabFlex.json'.
data = {
    column: []
    for column in ('coherence', 'direction', 'num_of_dots', 'prompt',
                   'response', 'accuracy', 'model', 'condition')
}
data['total_num_of_dots'] = 1000
with open('stabFlex.json', 'w') as f:
    json.dump(data, f)

In [50]:
import time

# Load the results collected so far so that new trials are appended to them.
with open('stabFlex.json', 'r') as f:
    data = json.load(f)

# When True, wait 4 seconds after every request.
is_sleep = False
# Models to sample from for each trial. Names containing 'gemini' are sent to
# `generate`, all others to `generate_llama`.
# Alternatives: 'gemini-1.5-flash-002', 'gemini-1.0-pro-001', 'mixtral-8x22b',
# 'llama2-7b', 'llama3-8b'.
models = ['llama3.1-70b']

total_num = 1000
# Loop-invariant, so it is recorded once instead of on every trial.
data['total_num_of_dots'] = total_num
i = 0
for c in conditions:
    for timestep in times:
        for direction in directions:
            for coherence in coherences:
                prompt = create_prompt(coherence, timestep, direction,
                                       condition=conditions[c],
                                       total_num_of_dots=total_num)
                model = random.choice(models)
                if 'gemini' in model:
                    response = generate(prompt)
                else:
                    response = generate_llama(prompt, model)
                accuracy = get_accuracy(response, direction)
                if is_sleep:
                    time.sleep(4)
                data['coherence'].append(coherence)
                data['direction'].append(direction)
                data['num_of_dots'].append(timestep)
                data['prompt'].append(prompt)
                data['response'].append(response)
                data['accuracy'].append(accuracy)
                data['model'].append(model)
                data['condition'].append(c)
                # Progress output: one dot per trial, one line per sweep over
                # the coherence levels.
                print('.', end='')
                i += 1
                if not i % len(coherences):
                    print('')
                # Save after every trial so that completed trials are on disk
                # if a later request fails.
                with open('stabFlex.json', 'w') as f:
                    json.dump(data, f)

......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......
......


In [49]:
import pandas as pd
import matplotlib.pyplot as plt
# Reload the saved results and count, for every (num_of_dots, coherence)
# pair, how often each accuracy value occurred; the accuracy values become
# the columns of the printed table.
with open('stabFlex.json', 'r') as f:
    data = json.load(f)

df = pd.DataFrame(data)

result = (
    df.groupby(['num_of_dots', 'coherence'])['accuracy']
    .value_counts()
    .unstack()
)
print(result)
# # For each number of dots, get the counts of 'correct', 'incorrect' and None values
# 
#     
# 
# print(df['coherence'].unique())
# # remove all rows where coherence is above 1 or below 0
# df = df[df['coherence'] <= 1]
# df = df[df['model'] != 'llama3-8']
# 
# # use get_accuracy to get the accuracy in each row 
# df['accuracy_new'] = df.apply(lambda x: get_accuracy(x['response'], x['direction']), axis=1)
# df.to_csv('psychometric.csv')
# 
# df = df[df['num_of_dots'] == 100]
# 
# 
# # plot the data mean coherence vs mean accuracy with model as hue
# plt.figure(figsize=(10, 6))
# 
# 
# for m in df['model'].unique():
#     plt.plot(df[df['model'] == m].groupby('coherence')['coherence'].mean(), df[df['model'] == m].groupby('coherence')['accuracy_new'].mean(), 'o')
# plt.errorbar(df.groupby('coherence')['coherence'].mean(), df.groupby('coherence')['accuracy_new'].mean())
# 
# 
# # plot the full average
# plt.xlabel('Coherence')
# plt.ylabel('Accuracy')
# plt.title('Accuracy vs Coherence')
# plt.legend(df['model'].unique())
# plt.show()
# 
# 
# 
# # plot the data points
# 
# # plt.plot(df.groupby('coherence')['accuracy'].mean())
# 
# plt.show()

accuracy               False  True  continue
num_of_dots coherence                       
10          0.5          1.0   2.0       7.0
            0.6          NaN   2.0       7.0
            0.7          NaN   3.0       6.0
            0.8          NaN   5.0       4.0
            0.9          NaN   8.0       1.0
            1.0          NaN   8.0       1.0
20          0.5          4.0   NaN       4.0
            0.6          1.0   2.0       5.0
            0.7          NaN   7.0       1.0
            0.8          NaN   7.0       1.0
            0.9          NaN   8.0       NaN
            1.0          NaN   8.0       NaN
30          0.5          1.0   1.0       5.0
            0.6          NaN   4.0       2.0
            0.7          1.0   2.0       3.0
            0.8          NaN   6.0       NaN
            0.9          1.0   5.0       NaN
            1.0          NaN   6.0       NaN
40          0.5          3.0   1.0       2.0
            0.6          NaN   3.0       3.0
          