/
figure2d_training_probability.py
136 lines (117 loc) · 6.3 KB
/
figure2d_training_probability.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quantify the variability across labs in the time mice take to reach trained status.
@author: Guido Meijer, Miles Wells
16 Jan 2020
"""
from os.path import join
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import seaborn as sns
from ibl_pipeline import subject
from ibl_pipeline.analyses import behavior as behavior_analysis
from paper_behavior_functions import (seaborn_style, institution_map, query_subjects,
group_colors, figpath, load_csv, CUTOFF_DATE,
FIGURE_HEIGHT, FIGURE_WIDTH, QUERY)
from lifelines import KaplanMeierFitter
# Settings
# fig_path is the directory the figures below are saved into (see the savefig calls).
# seaborn_style() is a project helper; presumably it applies the shared plotting
# style for the paper figures -- confirm in paper_behavior_functions.
fig_path = figpath()
seaborn_style()
# Collect one row per training session, either from the database or from a csv file
if QUERY is True:
    # Every mouse that entered training, with no training criterion applied
    mice_started_training = query_subjects(criterion=None)

    # Mice whose most recent session lies after the cut-off date and is still labelled
    # 'in_training', minus those with an entry in subject.Death. Their final training
    # outcome is not known yet, so they are left out of the analysis.
    latest_session = mice_started_training.aggr(
        behavior_analysis.SessionTrainingStatus,
        session_start_time='max(session_start_time)')
    after_cutoff = 'session_start_time > "%s"' % CUTOFF_DATE
    still_training = (((latest_session * behavior_analysis.SessionTrainingStatus)
                       - subject.Death)
                      & 'training_status = "in_training"'
                      & after_cutoff)
    use_subjects = mice_started_training - still_training

    # Get training status and training time in number of sessions and trials
    session_query = (use_subjects
                     * behavior_analysis.SessionTrainingStatus
                     * behavior_analysis.PsychResults)
    ses = session_query.proj(
        'subject_nickname', 'training_status', 'n_trials_stim', 'institution_short')
    ses = ses.fetch(format='frame').reset_index()

    # n_trials_stim holds a sequence of counts per session; add them up per session
    ses['n_trials'] = list(map(sum, ses['n_trials_stim']))
    ses = ses.drop(columns='n_trials_stim').dropna()

    # Chronological order within each mouse, so the last row of a mouse is its
    # most recent session
    ses = ses.sort_values(by=['subject_nickname', 'session_start_time'])
else:
    # Load in sessions from csv file
    ses = load_csv('Fig2d.csv').dropna()

# Select mice that started training before cut off date
# NOTE(review): applied after both branches; the original indentation of this
# statement could not be recovered -- confirm it is not meant for the csv branch only.
ses = ses.groupby('subject_uuid').filter(
    lambda mouse_sessions: mouse_sessions['session_start_time'].min() < CUTOFF_DATE)
# Construct dataframe from query
# One summary row per mouse. The rows are collected in a list and the frame is built
# once at the end: growing a DataFrame cell by cell with .loc copies it on every new
# row, and the per-mouse selection is done once by groupby instead of re-scanning all
# sessions for every column.
training_rows = []
# sort=False keeps the mice in order of first appearance in ses, i.e. the same order
# as ses['subject_nickname'].unique(). Rows with a missing nickname would be skipped
# by groupby, but ses has already been through dropna().
for nickname, mouse_ses in ses.groupby('subject_nickname', sort=False):
    status = mouse_ses['training_status']
    in_training = status == 'in_training'
    untrainable = status == 'untrainable'
    training_rows.append({
        'nickname': nickname,
        'lab': mouse_ses['institution_short'].values[0],
        # Number of sessions labelled 'in_training' or 'untrainable'
        'sessions': int((in_training | untrainable).sum()),
        # Trials are summed over the 'in_training' sessions only
        'trials': mouse_ses.loc[in_training, 'n_trials'].sum(),
        # Status and start time of the last row of this mouse; this is the most
        # recent session provided ses is ordered by session_start_time
        'status': status.values[-1],
        'date': mouse_ses['session_start_time'].values[-1],
    })
# Passing the column list keeps the expected columns even when there are no mice
training_time = pd.DataFrame(
    training_rows, columns=['nickname', 'lab', 'sessions', 'trials', 'status', 'date'])
# Transform training status into boolean
# 0 for mice whose last status is 'untrainable' or 'in_training', 1 for every other
# status; stored as floats
not_trained = training_time['status'].isin(['untrainable', 'in_training'])
training_time['trained'] = np.where(not_trained, 0.0, 1.0)

# Add lab number
# The first element returned by institution_map() is used as the lab -> number mapping
lab_to_number = institution_map()[0]
training_time['lab_number'] = training_time['lab'].map(lab_to_number)
training_time = training_time.sort_values(by='lab_number')
# %% PLOT
# Set figure style and color palette
lab_colors = group_colors()
ylim = [-0.02, 1.02]


def _plot_trained_probability(data, duration_col, colors, fig_width, xlabel, xlim,
                              file_stem, x_formatter=None):
    """
    Plot the Kaplan-Meier cumulative density of mice reaching trained status.

    One step curve is drawn per lab (colored by position in the sorted unique
    lab numbers) and one black curve for all mice pooled. The figure is saved in
    fig_path as <file_stem>.pdf and <file_stem>.png (300 dpi).

    Parameters
    ----------
    data : pandas.DataFrame
        One row per mouse with columns 'lab_number', 'trained', 'nickname' and
        duration_col.
    duration_col : str
        Column holding the duration passed to the fitter ('sessions' or 'trials').
    colors : sequence
        Colors indexed by the position of each lab in np.unique(data['lab_number']).
    fig_width : float
        Figure width; the height is FIGURE_HEIGHT.
    xlabel : str
        Label of the x-axis.
    xlim : list
        Limits of the x-axis.
    file_stem : str
        File name without extension.
    x_formatter : matplotlib.ticker.Formatter, optional
        Major tick formatter for the x-axis; the default formatter is kept if None.

    Returns
    -------
    fig, ax : the created matplotlib figure and axes.
    """
    fig, ax = plt.subplots(1, 1, figsize=(fig_width, FIGURE_HEIGHT))
    kmf = KaplanMeierFitter()

    # Per-lab curves; mice with trained == 0 are passed as unobserved events
    for i, lab in enumerate(np.unique(data['lab_number'])):
        in_lab = data['lab_number'] == lab
        kmf.fit(data.loc[in_lab, duration_col].values,
                event_observed=data.loc[in_lab, 'trained'])
        ax.step(kmf.cumulative_density_.index.values, kmf.cumulative_density_.values,
                color=colors[i])

    # All mice pooled
    kmf.fit(data[duration_col].values, event_observed=data['trained'])
    ax.step(kmf.cumulative_density_.index.values, kmf.cumulative_density_.values,
            color='black')

    ax.set(ylabel='Reached proficiency', xlabel=xlabel, xlim=xlim, ylim=ylim)
    if x_formatter is not None:
        ax.xaxis.set_major_formatter(x_formatter)
    ax.set_title('All labs: %d mice' % data['nickname'].nunique())

    sns.despine(trim=True, offset=5)
    plt.tight_layout()
    seaborn_style()
    plt.savefig(join(fig_path, file_stem + '.pdf'))
    plt.savefig(join(fig_path, file_stem + '.png'), dpi=300)
    return fig, ax


# Cumulative probability of being trained as a function of training day
f, ax1 = _plot_trained_probability(
    training_time, 'sessions', lab_colors, FIGURE_WIDTH/5,
    xlabel='Training day', xlim=[0, 60],
    file_stem='figure2d_probability_trained')

# Plot the same figure as a function of trial number
# Tick labels are shown in thousands of trials, e.g. 20000 -> '20K'
format_fcn = ticker.FuncFormatter(lambda x, pos: '{:,.0f}'.format(x / 1e3) + 'K')
f, ax1 = _plot_trained_probability(
    training_time, 'trials', lab_colors, FIGURE_WIDTH/3,
    xlabel='Trial', xlim=[0, 40e3],
    file_stem='figure2d_probability_trained_trials', x_formatter=format_fcn)