In [1]:
%%capture
! pip install arviz==0.11.0
! pip install pymc3==3.10.0

In [4]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import seaborn as sns
from pymc3 import math

from theano import tensor as tt
from scipy import stats

# Chapter 10 - Memory retention
  
This chapter is about estimating the relationship between memory retention and time.
The model considered here is a simplified version of the exponential decay model. It assumes that the probability that an item is remembered after a period of time $t$ has elapsed is $\theta_{t} = \exp(-\alpha t)+\beta$, with the restriction $0 < \theta_{t} < 1$. The $\alpha$ parameter is the rate at which information decays. The $\beta$ parameter is a baseline level of remembering that is assumed to persist even after very long time periods, since $\theta_{t} \to \beta$ as $t$ grows.
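To see why the restriction matters, the short sketch below evaluates the unclipped curve $\exp(-\alpha t)+\beta$ at the retention intervals used later in this chapter. The values $\alpha=0.2$ and $\beta=0.25$ are hypothetical, picked only for illustration: with them the curve starts above 1 at $t=1$ (about 1.07) and settles to $\beta$ for long intervals.

```python
import numpy as np

# Retention intervals used later in the chapter.
t = np.array([1, 2, 4, 7, 12, 21, 35, 59, 99, 200])

def retention_prob(alpha, beta, t):
    """Unclipped exponential-decay retention curve exp(-alpha*t) + beta."""
    return np.exp(-alpha * t) + beta

# Hypothetical parameter values, chosen only to illustrate the shape.
theta = retention_prob(alpha=0.2, beta=0.25, t=t)
for ti, th in zip(t, theta):
    print(f"t={ti:>3}  theta={th:.3f}")
```

Nothing in the formula itself keeps $\theta_{t}$ below 1, which is why the model needs an explicit bound.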
  
## 10.1 No individual differences


$$ \alpha \sim \text{Beta}(1,1)$$
$$ \beta \sim \text{Beta}(1,1)$$
$$ \theta_{j} = \min\left(1, \exp(-\alpha t_{j})+\beta\right)$$
$$ k_{ij} \sim \text{Binomial}(\theta_{j},n)$$
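Spelled out, the log-likelihood of this model is a sum of binomial log-probabilities over the observed (non-missing) cells. This is simply the standard binomial log-pmf applied cell by cell, not an expression given elsewhere in the text:

$$ \log p(k \mid \alpha, \beta) = \sum_{(i,j)\,\text{observed}} \left[ \log\binom{n}{k_{ij}} + k_{ij}\log\theta_{j} + (n-k_{ij})\log(1-\theta_{j}) \right] $$

If the $\min$ sets $\theta_{j}=1$ for a cell with $k_{ij}<n$, the last term becomes $(n-k_{ij})\log 0 = -\infty$, so the log-likelihood itself is $-\infty$ at that point.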

The above model is very sensitive to the starting values of its parameters. We can specify a starting value for each parameter by passing a `testval` when the random variable (RV) is created:
```python
alpha = pm.Beta('alpha', alpha=1, beta=1, testval=.30)
```
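One concrete way a poor starting point can fail is that the log-likelihood is already $-\infty$ there. The sketch below checks this with NumPy and SciPy (not PyMC3) at two points: $(\alpha, \beta) = (0.5, 0.5)$, the mean of a Beta(1, 1) prior, which as far as I know is what PyMC3 falls back to when no `testval` is given; and $\alpha = 0.30$ with a hypothetical $\beta = 0.25$. The function name `log_likelihood` is illustrative, and the fourth subject is omitted because it has no observations.

```python
import numpy as np
from scipy import stats

t = np.array([1, 2, 4, 7, 12, 21, 35, 59, 99, 200])
n = 18
# Observed counts for the three subjects with data; NaN marks the missing last interval.
k = np.array([
    [18, 18, 16, 13, 9, 6, 4, 4, 4, np.nan],
    [17, 13,  9,  6, 4, 4, 4, 4, 4, np.nan],
    [14, 10,  6,  4, 4, 4, 4, 4, 4, np.nan],
])

def log_likelihood(alpha, beta):
    """Binomial log-likelihood of the clipped decay model over observed cells."""
    theta = np.minimum(1.0, np.exp(-alpha * t) + beta)  # shape (nt,)
    theta = np.broadcast_to(theta, k.shape)             # same theta for every subject
    observed = ~np.isnan(k)
    return stats.binom.logpmf(k[observed], n, theta[observed]).sum()

# (0.5, 0.5) is the Beta(1, 1) mean; beta=0.25 in the second pair is hypothetical.
for a, b in [(0.5, 0.5), (0.30, 0.25)]:
    print(f"alpha={a}, beta={b}: log-likelihood = {log_likelihood(a, b):.2f}")
```

At $(0.5, 0.5)$, $\theta_{1}$ is clipped to 1 while the second subject recalled only 17 of 18 items at $t=1$, so that cell has probability zero; at the second point every $\theta_{j}$ stays below 1 and the log-likelihood is finite.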

With a bad starting value, NUTS has a hard time sampling and we get a `Bad initial energy` error. The reason is that clipping $\theta$ at 1 gives a zero gradient wherever $\exp(-\alpha t_{j})+\beta > 1$, which is a problem because NUTS needs the gradient to work.
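The flat region can be checked numerically. The sketch below uses central finite differences as a stand-in for Theano's automatic differentiation; the helper names and both $(\alpha, \beta)$ points are hypothetical. At the first point every $\theta_{j}$ is clipped to 1, so small changes in either parameter leave $\theta$ unchanged; at the second point nothing is clipped and the derivative with respect to $\beta$ is 1, as expected from the formula.

```python
import numpy as np

t = np.array([1, 2, 4, 7, 12, 21, 35, 59, 99, 200])

def theta_clipped(alpha, beta):
    """Retention probabilities with the min(1, .) clipping from the model."""
    return np.minimum(1.0, np.exp(-alpha * t) + beta)

def finite_diff_grad(alpha, beta, h=1e-6):
    """Central finite-difference derivatives of theta w.r.t. alpha and beta."""
    d_alpha = (theta_clipped(alpha + h, beta) - theta_clipped(alpha - h, beta)) / (2 * h)
    d_beta = (theta_clipped(alpha, beta + h) - theta_clipped(alpha, beta - h)) / (2 * h)
    return d_alpha, d_beta

# Hypothetical point where exp(-alpha*t) + beta > 1 at every interval (all clipped).
clipped_da, clipped_db = finite_diff_grad(alpha=0.001, beta=0.9)
# Hypothetical point where no theta is clipped, for contrast.
free_da, free_db = finite_diff_grad(alpha=0.3, beta=0.25)

print("clipped: d/dalpha =", clipped_da, " d/dbeta =", clipped_db)
print("free:    d/dalpha =", free_da.round(3), " d/dbeta =", free_db.round(3))
```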

In [7]:
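t = np.array([1, 2, 4, 7, 12, 21, 35, 59, 99, 200])  # retention intervals
nt = len(t)  # number of intervals
# slist = [0,1,2,3]
ns = 4  # number of subjects
# ns x nt matrix in which every row is a copy of t
tmat = np.repeat(t, ns).reshape(nt, -1).T
# Items remembered per subject (rows) and interval (columns); -999 marks missing
# data: the last interval is missing for every subject, and subject 4 has no data.
k1 = np.ma.masked_values([18, 18, 16, 13, 9, 6, 4, 4, 4, -999,
                          17, 13,  9,  6, 4, 4, 4, 4, 4, -999,
                          14, 10,  6,  4, 4, 4, 4, 4, 4, -999,
                          -999, -999, -999, -999, -999, -999, -999, -999, -999, -999], 
                          value=-999).reshape(ns,-1)
n = 18  # number of items tested at each interval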
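A quick inspection makes the layout of these arrays concrete. The sketch below rebuilds the same `tmat` and `k1`, confirms their shapes, counts the observed (unmasked) cells per subject with `MaskedArray.count`, and shows the observed proportion remembered; masked entries stay masked through the division.

```python
import numpy as np

t = np.array([1, 2, 4, 7, 12, 21, 35, 59, 99, 200])
nt, ns, n = len(t), 4, 18
tmat = np.repeat(t, ns).reshape(nt, -1).T
k1 = np.ma.masked_values([18, 18, 16, 13, 9, 6, 4, 4, 4, -999,
                          17, 13,  9,  6, 4, 4, 4, 4, 4, -999,
                          14, 10,  6,  4, 4, 4, 4, 4, 4, -999,
                          -999, -999, -999, -999, -999, -999, -999, -999, -999, -999],
                         value=-999).reshape(ns, -1)

print("shapes:", tmat.shape, k1.shape)
# Number of non-missing intervals for each subject.
print("observed cells per subject:", k1.count(axis=1))
# Proportion of the n items remembered; missing cells print as '--'.
print((k1 / n).round(2))
```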