## install & load packages

In [4]:
!pip -q install git+https://github.com/mwshinn/PyDDM
import pyddm
import pyddm.plot
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pyddm import Sample

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for pyddm (pyproject.toml) ... [?25l[?25hdone


## prep data

In [5]:
# Load the trial-level data. inference_tidy.csv has to be uploaded manually
# into the Files tab on every new runtime.
# Trials with a missing RT are dropped straight away.
df = pd.read_csv('inference_tidy.csv').dropna(subset=['RT'])

# 8 drift rates - one 80% and one 50% drift per period; incongruent trials use the negated congruent drift

In [6]:
def eightDrifts(t, trueCongruence, signal1_onset, noise2_onset, signal2_onset,
                noise1Drift_80, noise1Drift_50, signal1Drift_80, signal1Drift_50,
                noise2Drift_80, noise2Drift_50, signal2Drift_80, signal2Drift_50):
  """Piecewise-constant drift rate over a noise/signal/noise/signal trial.

  The onset times split the trial into four periods:
    noise 1  : t < signal1_onset
    signal 1 : signal1_onset <= t < noise2_onset
    noise 2  : noise2_onset <= t < signal2_onset
    signal 2 : t >= signal2_onset

  Each period has a `*_80` and a `*_50` drift parameter.
  - 'congruent' trials return the period's `*_80` drift unchanged.
  - 'incongruent' trials return its negation.
  - Any other `trueCongruence` value returns the period's `*_50` drift.

  Returns None only if no period predicate holds, which can happen only
  when one of the comparisons involves NaN.
  """
  # (is this period active?, *_80 drift, *_50 drift), checked in trial order
  schedule = (
      (t < signal1_onset, noise1Drift_80, noise1Drift_50),
      (signal1_onset <= t < noise2_onset, signal1Drift_80, signal1Drift_50),
      (noise2_onset <= t < signal2_onset, noise2Drift_80, noise2Drift_50),
      (t >= signal2_onset, signal2Drift_80, signal2Drift_50),
  )
  for active, drift_80, drift_50 in schedule:
    if not active:
      continue
    if trueCongruence == 'congruent':
      return drift_80
    if trueCongruence == 'incongruent':
      # the evidence favours the other boundary, so the same magnitude is flipped
      return -drift_80
    return drift_50
  return None

# Eight-drift model: eightDrifts picks one drift per trial period and negates
# the *_80 drift on incongruent trials, so every drift bound here is non-negative.
model = pyddm.gddm(
    drift=eightDrifts,
    starting_position=0,
    bound="B",
    nondecision='ndt',
    T_dur=4.1,
    parameters={
        'B': (0.01, 10),
        'ndt': (0.01, 0.5),
        # noise1Drift_80, noise1Drift_50, signal1Drift_80, ... signal2Drift_50
        **{f'{period}Drift_{cue}': (0, 3)
           for period in ('noise1', 'signal1', 'noise2', 'signal2')
           for cue in ('80', '50')},
    },
    conditions=['trueCue', 'trueCongruence', 'coherence',
                'signal1_onset', 'noise2_onset', 'signal2_onset'],
)


# Interactive explorer for the eight-drift model, using fixed example onset times.
pyddm.plot.model_gui_jupyter(
    model,
    conditions=dict(
        trueCue=[0.5, 0.8],
        trueCongruence=['congruent', 'incongruent', 'neutral'],
        signal1_onset=[0.8],
        noise2_onset=[1.2],
        signal2_onset=[2.2],
    ),
)

HBox(children=(VBox(children=(FloatSlider(value=2.45412500489799, continuous_update=False, description='noise1…

Output()

In [None]:
# NOTE(review): this repeats the eight-drift GUI call from the previous cell verbatim
# and was never executed (In [None]). Consider removing it, or repurposing it
# (e.g. pass `sample=` so data are overlaid on the model predictions).
pyddm.plot.model_gui_jupyter(
    model,
    conditions={
        'trueCue': [0.5, 0.8],
        'trueCongruence': ['congruent', 'incongruent', 'neutral'],
        'signal1_onset': [0.8],
        'noise2_onset': [1.2],
        'signal2_onset': [2.2],
    },
)



# 12 drift rates - separate drift rates for congruent and incongruent trials in each period

In [5]:
def twelveDrifts(t, trueCongruence, signal1_onset, noise2_onset, signal2_onset,
                 n1_80_cong, n1_80_incong, n1_50, s1_80_cong, s1_80_incong, s1_50,
                 n2_80_cong, n2_80_incong, n2_50, s2_80_cong, s2_80_incong, s2_50):
  """Piecewise-constant drift rate with separate congruent/incongruent drifts.

  Periods are delimited by the onset times:
    n1 (noise 1)  : t < signal1_onset
    s1 (signal 1) : signal1_onset <= t < noise2_onset
    n2 (noise 2)  : noise2_onset <= t < signal2_onset
    s2 (signal 2) : t >= signal2_onset

  Each period has three drift parameters:
  - `*_80_cong` is returned when trueCongruence == 'congruent'.
  - `*_80_incong` is returned when trueCongruence == 'incongruent'. It is used
    as given, with no sign flip.
  - `*_50` is returned for any other congruence value.

  Returns None only if no period predicate holds, which requires a NaN
  in one of the comparisons.
  """
  in_noise1 = t < signal1_onset
  in_signal1 = signal1_onset <= t < noise2_onset
  in_noise2 = noise2_onset <= t < signal2_onset
  in_signal2 = t >= signal2_onset

  if in_noise1:
    cong, incong, fifty = n1_80_cong, n1_80_incong, n1_50
  elif in_signal1:
    cong, incong, fifty = s1_80_cong, s1_80_incong, s1_50
  elif in_noise2:
    cong, incong, fifty = n2_80_cong, n2_80_incong, n2_50
  elif in_signal2:
    cong, incong, fifty = s2_80_cong, s2_80_incong, s2_50
  else:
    return None

  if trueCongruence == 'congruent':
    return cong
  if trueCongruence == 'incongruent':
    return incong
  return fifty

# Twelve-drift model. Congruent and 50% drifts are bounded to [0, 3], and
# incongruent drifts to [-3, 0], because twelveDrifts uses the incongruent drift
# without negating it.
model = pyddm.gddm(
    drift=twelveDrifts,
    starting_position=0,
    bound="B",
    nondecision='ndt',
    T_dur=4.1,
    parameters={
        'B': (0.01, 10),
        'ndt': (0.01, 0.5),
        # n1_80_cong, n1_80_incong, n1_50, s1_80_cong, ... s2_50
        **{f'{period}_{suffix}': bounds
           for period in ('n1', 's1', 'n2', 's2')
           for suffix, bounds in (('80_cong', (0, 3)),
                                  ('80_incong', (-3, 0)),
                                  ('50', (0, 3)))},
    },
    conditions=['trueCue', 'trueCongruence', 'coherence',
                'signal1_onset', 'noise2_onset', 'signal2_onset'],
)


# Interactive explorer for the twelve-drift model, using the same example onset times.
pyddm.plot.model_gui_jupyter(
    model,
    conditions=dict([
        ('trueCue', [0.5, 0.8]),
        ('trueCongruence', ['congruent', 'incongruent', 'neutral']),
        ('signal1_onset', [0.8]),
        ('noise2_onset', [1.2]),
        ('signal2_onset', [2.2]),
    ]),
)

HBox(children=(VBox(children=(FloatSlider(value=0.21297746064548867, continuous_update=False, description='n1_…

Output()

In [7]:
import pickle

In [11]:
# Round-trip check: pickle a default (unfitted) model to disk so the next cell can
# verify that it loads back.
# The `with` block closes and flushes the file deterministically instead of
# leaving the handle to garbage collection.
# (The original bare `test_model` line in the middle of this cell displayed
# nothing, so it was dropped.)
test_model = pyddm.gddm()

with open('test_model.p', 'wb') as model_file:
    pickle.dump(test_model, model_file)

In [13]:
# Reload the model pickled by the previous cell and display it.
# NOTE: despite the name, this is the default (unfitted) test model, not a fit result.
# pickle.load can execute arbitrary code, so only load files this notebook wrote itself.
with open('test_model.p', 'rb') as model_file:
    fitted_model = pickle.load(model_file)

fitted_model

Model(name='', drift=DriftConstant(drift=0), noise=NoiseConstant(noise=1), bound=BoundConstant(B=1), IC=ICPointRatio(x0=0), overlay=OverlayChain(overlays=[OverlayNonDecision(nondectime=0), OverlayUniformMixture(umixturecoef=0.02)]), dx=0.005, dt=0.005, T_dur=2.0)