### Figure 4

Import packages and figure setup

In [None]:
from pandas.api.types import CategoricalDtype
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import datetime
import fed3bandit as f3b
import copy
from scipy.stats import ttest_rel, ttest_ind

# Global figure style: large font for publication panels, and automatic
# tight_layout on every figure so labels are not clipped.
plt.rcParams["font.size"] = 28
plt.rcParams["figure.autolayout"] = True

Download data

In [None]:
# Data download goes here. It must define the per-session logs and the
# infusion-time lookups that the pre-processing cell reads.
_required_inputs = ("p_saline_data", "p_mk801_data", "saline_injectiontimes", "mk801_injectiontimes")
_missing_inputs = [name for name in _required_inputs if name not in globals()]
if _missing_inputs:
    # Fail here with a clear message instead of reporting success and then
    # crashing with a bare NameError in the next cell.
    raise NameError("Data download did not define: " + ", ".join(_missing_inputs))

print("Data downloaded successfully")

Data pre-processing

In [None]:
"""Here we get the slice of the data that goes from 0 to 8 hours after the intra-striatal infusion."""
post_hours = 8


def _slice_post_infusion(sessions, injection_times, hours):
    """Keep only the rows logged strictly between an infusion and `hours` after it.

    sessions: dict mapping session name -> dict mapping mouse ID -> DataFrame
        whose first column holds timestamps comparable to the injection time.
    injection_times: dict mapping session name -> infusion time for that session.
    hours: length of the post-infusion window.

    Returns a new dict with the same session / mouse keys holding the sliced
    DataFrames; the inputs are not modified.
    """
    window = datetime.timedelta(hours=hours)
    sliced = {}
    for session, mice in sessions.items():
        start = injection_times[session]
        end = start + window
        sliced[session] = {}
        for mouse, log in mice.items():
            timestamps = log.iloc[:, 0]
            # Open interval on both ends: rows exactly at the infusion time or
            # exactly at the window end are excluded.
            sliced[session][mouse] = log[(timestamps > start) & (timestamps < end)]
    return sliced


# The saline groups
saline_slices = _slice_post_infusion(p_saline_data, saline_injectiontimes, post_hours)

# The MK801 groups
mk801_slices = _slice_post_infusion(p_mk801_data, mk801_injectiontimes, post_hours)


### Data Analysis

Sample behavior after saline or MK801 infusion (Panel 4B).

In [None]:
# Panel 4B: the same mouse (C50F5) after a saline session and after an MK801 session.
sample_saline = saline_slices["saline_3"]["C50F5"]
sample_mk801 = mk801_slices["mk801_1"]["C50F5"]


def _plot_sample_session(session_log, action_color):
    """Draw one session: the true-probability trace in red and binned actions in `action_color`.

    session_log: one mouse's post-infusion log, as accepted by fed3bandit.
    action_color: matplotlib colour for the binned-action trace.
    Returns (fig, ax). The axes frame is hidden so only the two traces show.
    """
    fig, ax = plt.subplots(figsize=(6, 3))
    # First element of f3b.true_probs — presumably the left-port reward
    # probability (originally named c_trueleft); TODO confirm against fed3bandit docs.
    true_left = f3b.true_probs(session_log)[0].to_list()
    binned_actions = f3b.binned_paction(session_log)
    ax.plot(true_left, c="red", linewidth=3)
    ax.plot(binned_actions, c=action_color, linewidth=3)
    ax.axis("off")
    return fig, ax


_plot_sample_session(sample_saline, "darkcyan")
_plot_sample_session(sample_mk801, "olive");

Calculation of pokes, pellets, pokes per pellet, win-stay, and lose-shift metrics.
The metrics are calculated with the FED3Bandit package; for its documentation and source code, see:

FED3Bandit package documentation: https://fed3bandit.readthedocs.io/en/latest/analysis/fed3live_api.html

FED3Bandit package source code: https://github.com/AlexLM96/fed3bandit/blob/main/fed3bandit/fed3bandit/fed3bandit.py



In [None]:
def _group_metrics(group_slices):
    """Compute pellets, pokes, win-stay and lose-shift for every mouse in one group.

    group_slices: dict mapping session name -> dict mapping mouse ID -> sliced log.
    Returns a dict with keys "pellets", "pokes", "ws", "ls". Each value is a
    DataFrame indexed by mouse ID with a single column (0) holding the metric.

    NOTE(review): sessions are merged per mouse with dict-update semantics, so a
    mouse present in several sessions of the same group keeps only the value
    from the LAST session iterated. Confirm each mouse has one session per group.
    """
    metric_funcs = {
        "pellets": f3b.count_pellets,
        "pokes": f3b.count_pokes,
        "ws": f3b.win_stay,
        "ls": f3b.lose_shift,
    }
    per_mouse = {metric: {} for metric in metric_funcs}
    for session_logs in group_slices.values():
        for metric, func in metric_funcs.items():
            per_mouse[metric].update(
                {mouse: [func(log)] for mouse, log in session_logs.items()}
            )
    # One-element lists + transpose -> mouse-indexed frame with one column,
    # which is the shape the saline/MK801 concatenation expects.
    return {metric: pd.DataFrame(values).T for metric, values in per_mouse.items()}


# Metrics of the saline group
saline_metrics = _group_metrics(saline_slices)
saline_pellets = saline_metrics["pellets"]
saline_pokes = saline_metrics["pokes"]
saline_ws = saline_metrics["ws"]
saline_ls = saline_metrics["ls"]

# Metrics of the MK801 group
mk801_metrics = _group_metrics(mk801_slices)
mk801_pellets = mk801_metrics["pellets"]
mk801_pokes = mk801_metrics["pokes"]
mk801_ws = mk801_metrics["ws"]
mk801_ls = mk801_metrics["ls"]

Concatenate the saline and MK801 groups per mouse and reshape them to long format for plotting.

In [None]:
# Ordered treatment category so "Sal" is always sorted and plotted before
# "MK801" (matching the darkcyan/olive palette order). Defined here because it
# is needed by this cell and was otherwise undefined on a fresh kernel.
cat_size_order = CategoricalDtype(categories=["Sal", "MK801"], ordered=True)


def _combine_groups(saline_df, mk801_df):
    """Join saline and MK801 per-mouse metrics and melt them to long format.

    saline_df, mk801_df: single-column DataFrames indexed by mouse ID.
    The frames are outer-joined on mouse ID, so a mouse missing from one group
    gets NaN for that group.

    Returns (wide, long): `wide` has columns Mouse/Sal/MK801; `long` has columns
    Mouse/variable/value with `variable` cast to `cat_size_order` and rows
    sorted by it.
    """
    wide = pd.concat([saline_df, mk801_df], axis=1).reset_index()
    wide.columns = ["Mouse", "Sal", "MK801"]
    long = pd.melt(wide, id_vars="Mouse")
    long["variable"] = long["variable"].astype(cat_size_order)
    return wide, long.sort_values(by="variable")


all_pellets, m_all_pellets = _combine_groups(saline_pellets, mk801_pellets)
all_pokes, m_all_pokes = _combine_groups(saline_pokes, mk801_pokes)
all_ws, m_all_ws = _combine_groups(saline_ws, mk801_ws)
all_ls, m_all_ls = _combine_groups(saline_ls, mk801_ls)

Pellets plot and statistics (Panel 4C)

In [None]:
# Panel 4C: pellets earned in the post-infusion window, saline vs. MK801.
fig, ax = plt.subplots(figsize=(5, 8))
# Two treatment levels -> exactly two colours (extra palette entries are unused
# and make seaborn warn).
sns.boxplot(x="variable", y="value", data=m_all_pellets, palette=["darkcyan", "olive"],
            boxprops={"linewidth": 2.5}, whiskerprops={"linewidth": 2.5}, ax=ax)
sns.swarmplot(x="variable", y="value", data=m_all_pellets,
              palette=["silver", "silver"], s=10, ax=ax)
ax.set_ylabel("Pellets")
ax.set_xlabel("")
sns.despine()
ax.spines["bottom"].set_linewidth(2)
ax.spines["left"].set_linewidth(2)
ax.set_ylim(0, 200)
ax.set_yticks(np.arange(0, 180, 40))

# Independent two-sample t-test; NaN values (mice without data in one
# condition) are dropped via nan_policy="omit".
# NOTE(review): the same mouse IDs appear in both conditions (e.g. C50F5), so a
# paired test (ttest_rel, already imported) may be the appropriate choice —
# confirm which test the figure reports.
pellets_sal = m_all_pellets.loc[m_all_pellets["variable"] == "Sal", "value"]
pellets_mk801 = m_all_pellets.loc[m_all_pellets["variable"] == "MK801", "value"]
mk801_pellets_ttest = ttest_ind(pellets_sal, pellets_mk801, nan_policy="omit")

print(mk801_pellets_ttest)

Pokes plot and statistics (Panel 4D)

In [None]:
# Panel 4D: total pokes in the post-infusion window, saline vs. MK801.
fig, ax = plt.subplots(figsize=(5, 8))

sns.boxplot(
    data=m_all_pokes,
    x="variable",
    y="value",
    palette=["darkcyan", "olive"],
    boxprops={"linewidth": 2.5},
    whiskerprops={"linewidth": 2.5},
    ax=ax,
)
sns.swarmplot(data=m_all_pokes, x="variable", y="value",
              palette=["silver", "silver"], s=10, ax=ax)

ax.set_xlabel("")
ax.set_ylabel("Pokes")
for side in ("bottom", "left"):
    ax.spines[side].set_linewidth(2)
ax.set_ylim(0, 340)
ax.set_yticks(np.arange(0, 340, 80))
sns.despine()

# Independent two-sample t-test; NaN values are dropped via nan_policy="omit".
# NOTE(review): mouse IDs are shared across conditions, so a paired test
# (ttest_rel) may be more appropriate — confirm which test the figure reports.
is_saline = m_all_pokes["variable"] == "Sal"
is_mk801 = m_all_pokes["variable"] == "MK801"
mk801_pokes_ttest = ttest_ind(
    m_all_pokes.loc[is_saline, "value"],
    m_all_pokes.loc[is_mk801, "value"],
    nan_policy="omit",
)
