# General Description:

For CRDM

Plots choices as a function of the lottery amount, subdivided by lottery probability and ambiguity level.

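Importing libraries. (No Google Drive mount is performed in this notebook; data are read from the `split_dir` / `utility_dir` paths set in the main cell below.)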

In [None]:
%matplotlib widget
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from glob import glob
import os,sys
# Show (up to) 999 rows and columns when printing DataFrames in the notebook.
pd.set_option('display.max_rows', 999)
pd.set_option('display.max_columns', 999)

In [None]:
# split dataframe by gains/losses
# split dataframe by gains/losses
def get_by_domain(df,domain='gain',task='crdm',verbose=False):
    """Select the trials that belong to one outcome domain.

    Rows are split on the sign of the ``<task>_sure_amt`` column:
    ``domain='gain'`` keeps rows with a positive sure amount and
    ``domain='loss'`` keeps rows with a negative one. Rows whose sure amount
    is exactly 0 are dropped by both. Any other ``domain`` value returns all
    rows unfiltered.

    Parameters
    ----------
    df : pandas.DataFrame
        Trial table containing a ``<task>_sure_amt`` column.
    domain : str
        ``'gain'`` or ``'loss'``.
    task : str
        Column prefix, e.g. ``'crdm'``.
    verbose : bool
        Print the domain being processed. (The old default was the string
        ``'False'``, which is truthy and therefore always printed.)

    Returns
    -------
    pandas.DataFrame
        A copy of the selected rows. Callers can add columns to it without
        triggering pandas' SettingWithCopyWarning.
    """
    if verbose:
        print('Working on this domain: {}'.format(domain))
    # select by domain: gain/loss
    safe_col = '{}_sure_amt'.format(task)
    if domain=='gain':
        df = df.loc[df[safe_col]>0]
    elif domain=='loss':
        df = df.loc[df[safe_col]<0]
    return df.copy()

In [None]:
def count_tuples(listA):
    """Count how many times each distinct item occurs in ``listA``.

    Items must be hashable. Uses ``collections.Counter`` (one pass, O(n))
    instead of calling ``list.count`` once per unique item (O(n^2)).

    Returns
    -------
    (unique_items, item_count) : tuple of two lists
        ``unique_items[i]`` occurs ``item_count[i]`` times in ``listA``.
        Items are listed in order of first occurrence.
    """
    from collections import Counter
    counts = Counter(listA)
    return list(counts.keys()), list(counts.values())

In [None]:
def tabulate_col(fn,df,col='crdm_sure_amt'):
    """Count each distinct value of ``df[col]``, print the table and save it as CSV.

    Parameters
    ----------
    fn : str
        Output CSV path. Its parent directory is created if needed. A bare
        filename (no directory part) is written to the current working
        directory; previously this crashed in ``os.makedirs('')``.
    df : pandas.DataFrame
        Table containing ``col``.
    col : str
        Column whose values are counted. Counts are sorted by value.
    """
    count_df = df[col].value_counts().sort_index().reset_index()
    print(count_df)
    # create .csv file with this info
    print("Saving to: {}".format(fn))
    out_dir = os.path.dirname(fn)
    if out_dir:
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs
        os.makedirs(out_dir, exist_ok=True)
    count_df.to_csv(fn)

In [None]:
def count_lott_p_sure_amt_amb(fn='',df=None,out_dir='csv'):
    """Tabulate how often each CRDM design value occurs; save one CSV per column.

    The columns ``crdm_sure_amt``, ``crdm_amb_lev``, ``crdm_lott_p`` and
    ``crdm_lott_amt`` are each counted with ``tabulate_col`` and written to
    ``<out_dir>/<column>.csv``. These paths do not depend on the subject, so
    every call overwrites the previous call's files.

    Parameters
    ----------
    fn : str
        Path of a subject's CRDM CSV. It is read only when ``df`` is not
        supplied, or when ``df`` is empty and ``fn`` is non-empty. After
        reading, ``crdm_lott_amt`` is added and only gain-domain rows are kept.
    df : pandas.DataFrame or None
        Already-prepared trials; must already contain ``crdm_lott_amt``.
        (The old default was ``[]``, a list, whose missing ``.empty``
        attribute made ``count_lott_p_sure_amt_amb(fn)`` crash.)
    out_dir : str
        Directory for the output CSVs (default ``'csv'``, as before).
    """
    # Only fall back to reading fn when there is actually a path to read;
    # an empty filtered frame with fn='' used to call pd.read_csv('').
    if df is None or (df.empty and fn):
        df = pd.read_csv(fn)
        # lottery amount = crdm_lott_top + crdm_lott_bot
        df['crdm_lott_amt'] = df['crdm_lott_top'] + df['crdm_lott_bot']
        # keep only gain-domain trials (positive crdm_sure_amt)
        df = get_by_domain(df,domain='gain',task='crdm',verbose=True)
    for col in ['crdm_sure_amt','crdm_amb_lev','crdm_lott_p','crdm_lott_amt']:
        # separate name so the input path fn is not shadowed
        out_fn = os.path.join(out_dir,'{}.csv'.format(col))
        tabulate_col(out_fn,df,col=col)

In [None]:
def drop_blank(df):
    """Return the trials that have a recorded ``crdm_choice``.

    A boolean ``responded`` column (``crdm_choice`` is not NaN) is added to
    the returned frame. When any rows are dropped, a warning is printed and
    the index is reset; otherwise the original index is kept.

    The input frame is left untouched. The function works on a copy, which
    also avoids SettingWithCopyWarning when ``df`` is a filtered view.
    """
    df_len = df.shape[0]
    df = df.copy()
    df['responded'] = df['crdm_choice'].notna()
    if not df['responded'].all():
        non_responses_nb = int((~df['responded']).sum())
        print('\n**WARNING** We dropped {0} of {1} non responses that were left blank'.format(non_responses_nb,df_len))
        df = df.loc[df['responded'],:].reset_index(drop=True)
    return df
        

In [None]:
def get_subject(fn):
    """Derive the subject ID from a CRDM file path.

    Takes the file's basename and removes every occurrence of ``'_crdm.csv'``,
    e.g. ``'.../007/crdm/007_crdm.csv'`` -> ``'007'``.
    """
    base_name = os.path.basename(fn)
    return base_name.replace('_crdm.csv', '')

In [None]:
def _axis_extent(values, fallback):
    """Return (lo, hi, span) of ``values``; return ``fallback`` when the span is zero or NaN."""
    lo, hi = values.min(), values.max()
    # `not hi > lo` is True for a single distinct value and for all-NaN data,
    # both of which would give a zero/NaN np.arange step.
    if not hi > lo:
        return fallback
    return lo, hi, hi - lo


def plot_3D_choice(utility_dir,df,subj,trials='risky'):
    """Scatter one subject's CRDM choices in 3-D and save the figure as a PNG.

    Axes: x = ``crdm_lott_amt`` (lottery amount), y = ``crdm_sure_amt``
    (safe amount). z is ``crdm_lott_p`` for ``trials='risky'`` (rows with
    ``crdm_amb_lev == 0``) or ``crdm_amb_lev`` for ``trials='ambiguity'``
    (rows with ``crdm_amb_lev > 0``). Points are coloured by ``crdm_choice``,
    which is used as an index: 0 -> 'safe bet', 1 -> 'lottery'. This assumes
    the column only holds 0/1 codes; confirm against the data.

    Side effects: calls ``count_lott_p_sure_amt_amb`` on the selected rows,
    saves the figure to
    ``<utility_dir>/<subj>/crdm/<subj>_crdm_<trials>_choices.png`` and shows it.
    If no rows match ``trials``, a message is printed and nothing is written.

    Raises
    ------
    ValueError
        If ``trials`` is neither ``'risky'`` nor ``'ambiguity'``. Previously
        this surfaced as an UnboundLocalError on ``zcol``.
    """
    xcol = 'crdm_lott_amt'
    ycol = 'crdm_sure_amt'
    if trials=='risky':
        # select no ambiguous trials
        df = df.loc[df['crdm_amb_lev']==0]
        zcol = 'crdm_lott_p'
        zlabel = 'probability lottery'
    elif trials=='ambiguity':
        df = df.loc[df['crdm_amb_lev']>0]
        zcol = 'crdm_amb_lev'
        zlabel = 'ambiguity level'
    else:
        raise ValueError("trials must be 'risky' or 'ambiguity', got {!r}".format(trials))
    if df.empty:
        # Nothing to count or plot; min/max of an empty column would be NaN.
        print('No {} trials for {}; skipping plot.'.format(trials, subj))
        return
    count_lott_p_sure_amt_amb(df=df)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # x: a single lottery amount gets a unit-wide window centred on it
    x_mid = df[xcol].min()
    x_lo, x_hi, xrange = _axis_extent(df[xcol], (x_mid - 0.5, x_mid + 0.5, 1.0))
    # y: a single sure amount keeps the original fixed [0, 1] window
    y_lo, y_hi, yrange = _axis_extent(df[ycol], (0, 1, 1.0))
    x = np.arange(x_lo, x_hi, xrange/20.0)
    y = np.arange(y_lo, y_hi, yrange/20.0)
    xx, yy = np.meshgrid(x, y)
    # one faint constant-z plane per z level (presumably a visual guide)
    for z in df[zcol].unique():
        zz = np.full(xx.shape,z)
        ax.plot_surface(xx,yy,zz,alpha=0.1,color='k')
    colors = ['b','r']
    labels = ['safe bet', 'lottery']
    for c in df['crdm_choice'].unique():
        sel = df['crdm_choice']==c
        ax.scatter(df.loc[sel,xcol],df.loc[sel,ycol],df.loc[sel,zcol],
                   label=labels[int(c)],color=colors[int(c)])

    plt.xticks(np.arange(x_lo, x_hi+0.1*xrange, xrange/2.0))
    ax.set_zticks(df[zcol].unique().tolist())
    plt.yticks(np.arange(y_lo, y_hi+0.1*yrange, yrange/2.0))
    ax.set_xlabel('lottery amount ($)')
    ax.view_init(elev=10,azim=-35)
    ax.set_ylabel('safe amount ($)')
    ax.set_zlabel(zlabel)
    ax.legend()
    plt.title('{} {}'.format(subj,trials))
    plt.gca().invert_xaxis()
    plt.tight_layout()
    fig_fn = os.path.join(utility_dir,subj,'crdm/{}_crdm_{}_choices.png'.format(subj,trials))
    os.makedirs(os.path.dirname(fig_fn), exist_ok=True)
    print("Saving to: {}".format(fig_fn))
    plt.savefig(fig_fn)
    plt.show()
    


In [None]:
split_dir = '/Volumes/UCDN/datasets/IDM_ado/split/'
utility_dir = '/Volumes/UCDN/datasets/IDM_ado/utility/'

cols = ['crdm_amb_lev','crdm_sure_amt','crdm_lott_p','crdm_lott_amt']

# get set of all good data files for analysis (pattern <split_dir>/*/crdm/*.csv);
# globbed once and reused below
good_files = sorted(glob(os.path.join(split_dir, '*/crdm/*.csv')))
if not good_files:
    # passing the message to sys.exit prints it to stderr and exits with
    # status 1: finding no input is an error, not success
    sys.exit("No good files available. Check file path.")

fn_list = good_files
domain = 'gain'

for subj_fn in fn_list:
    subj = get_subject(subj_fn)
    # count_lott_p_sure_amt_amb(subj_fn)
    print(subj_fn)
    df = pd.read_csv(subj_fn)
    df = get_by_domain(df,domain=domain,task='crdm',verbose=True)
    df = drop_blank(df)
    # lottery amount = crdm_lott_top + crdm_lott_bot
    df['crdm_lott_amt'] = df['crdm_lott_top'] + df['crdm_lott_bot']

    plot_3D_choice(utility_dir,df,subj,trials='risky')
    plot_3D_choice(utility_dir,df,subj,trials='ambiguity')



In [None]:
# def plot_2d_choice(idx,df,subj_fn,utility_dir,domain):
#     cols = ['crdm_amb_lev','crdm_sure_amt','crdm_lott_p','crdm_lott_amt']

#     # crdm_lott_p
#     xcol=cols[2]
#     # ambig lev
#     ycol=cols[0]

#     # ylabels = ['Immediate $2','Immediate $5', 'Immediate $15']
#     # titles = ['Delay Wait: Now', 'Delay Wait: 1 month','Delay Wait: 3 months','Delay Wait: 12 months','Delay Wait: 60 months']

#     xax = cols[3]
#     yax = 'crdm_choice'


#     ylabels = ['Amb 0.0','Amb 24.0', 'Amb 50.0', 'Amb 74.0']
#     xtitle = ['Lott 13.0p', 'Lott 25.0p','Lott 38.0p','Lott 50.0p','Lott 75.0p']
#     # ylabels = get_labels(df,col=ycol)
#     # xtitle = get_labels(df,col=xcol)

#     plt.figure(idx,figsize=(20,10))
#     xcol_vals = sorted(df[xcol].unique())
#     ycol_vals = sorted(df[ycol].unique())
#     index=0
#     #loop through each delay-wait subgroup for each smaller sooner amount 
#     for iy, yv in enumerate(ycol_vals):
#         for ix, xv in enumerate(xcol_vals):
#             # index = 1+iy+(len(xcol_vals)*ix)
#             # print('({},{}) with ({},{}) and index: {}'.format(ix,iy,xv,yv,index))
#             # print('title:{} ylabel: {}'.format(xtitle[ix],ylabels[iy]))
#             index += 1
#             plt.subplot(len(ycol_vals),len(xcol_vals),index)
#             plt.ylim([-0.1,1.1])
#             plt.xlim([df[xax].min()-1,df[xax].max()+1])
#             if domain=='gain':
#                 plt.xlim([0,df[xax].max()+1])
#             elif domain=='loss':
#                 plt.xlim([df[xax].min()-1,0])

#             if (ix == 0):
#                 plt.ylabel(ylabels[iy],fontsize=12)
#                 plt.yticks([0,1],['0-safe','1-lottery'])
#             else:
#                 plt.yticks([0,1])
#             if (iy == 0):
#                 plt.title(xtitle[ix],fontsize=12)
#             if (iy == len(ycol_vals)-1):
#                 plt.xlabel(xax,fontsize=15)
      
#             x = []
#             y = []
#             #get dataframe with the appropriate smaller sooner and delay-wait time values
#             idf = df.loc[(df[xcol] == xv) & (df[ycol] == yv)]
#             if not len(idf):
#                 continue
#             x = x + idf[xax].tolist()
#             y = y + idf[yax].tolist()
      
#             #sort by x and y in order to connect lines properly on final plots
#             x,y = zip(*sorted(zip(x,y)))
#             plt.plot(x,y,'*-')


#     subj = get_subject(subj_fn)

#     subj_crdm_dir = os.path.join(utility_dir,subj,'crdm')
#     if not os.path.exists(subj_crdm_dir):
#         print('Making subjects crdm directory : {}'.format(subj_crdm_dir))
#         os.makedirs(subj_crdm_dir)

#     plt.suptitle('{} {}'.format(subj,domain), fontsize=25)
#     fig_fn = os.path.join(utility_dir,subj,'crdm/{}_crdm_plot_lottery_amt_choice_{}.png'.format(subj,domain))
#     print("Saving to: {}".format(fig_fn))
#     plt.savefig(fig_fn)
#     plt.show()

