In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import scipy

from torch import distributions as dist
from scipy.special import logsumexp, expit
from scipy.stats import entropy

from tqdm.notebook import tqdm

from copy import deepcopy

In [3]:
# Two random 3x3 matrices (uniform on [0, 1)) used as scratch inputs
# for the elementwise-product checks in the next cells.
a = torch.rand(3, 3)
b = torch.rand(3, 3)

In [8]:
# Negated sum of the row-wise inner products of `a` and `b`,
# accumulated one row at a time (compare with the vectorised cell below).
var = 0
for i in range(3):
    var = var - torch.mul(a[i], b[i]).sum()
print(var)

tensor(-1.8267)


In [7]:
# Vectorised counterpart of the row loop: sum of all elementwise products.
(a * b).sum()

tensor(1.8267)

In [2]:
# Machine epsilon of double precision floats.
EPS = np.finfo(float).eps

In [3]:
def cartesian_product(*arrays):
    """Return every combination of one element from each input array.

    Parameters
    ----------
    *arrays : 1-D sequences
        One sequence per output column.

    Returns
    -------
    numpy.ndarray
        Array with one row per combination and one column per input
        array. Rows are ordered with the last array varying fastest,
        and the dtype is the common result type of the inputs.
    """
    n_arrays = len(arrays)

    # One axis per input array, plus a trailing axis that indexes the column.
    shape = [len(arr) for arr in arrays]
    shape.append(n_arrays)
    out = np.empty(shape, dtype=np.result_type(*arrays))

    # np.ix_ yields open-mesh views that broadcast along all other axes.
    open_mesh = np.ix_(*arrays)
    for axis in range(n_arrays):
        out[..., axis] = open_mesh[axis]

    return out.reshape(-1, n_arrays)

def cp_grid_param(grid_size, bounds, methods):
    """Build a parameter grid as a Cartesian product of per-parameter axes.

    Parameters
    ----------
    grid_size : int
        Number of points generated along each free parameter.
    bounds : sequence of (lower, upper) pairs
        One pair per parameter. A parameter whose upper bound is not
        strictly greater than its lower bound is held at its lower bound.
    methods : sequence of callables
        One per parameter; each free parameter's axis is produced by
        ``method(lower, upper, num=grid_size)`` (e.g. ``np.linspace``).

    Returns
    -------
    numpy.ndarray
        Array with one column per parameter and one row per combination
        of the free parameters' axis values (a single row when no
        parameter is free).
    """
    bds = np.asarray(bounds)
    mths = np.asarray(methods)

    lower, upper = bds[:, 0], bds[:, 1]
    free = upper - lower > 0
    fixed = ~free

    # One axis of `grid_size` values for every free parameter.
    axes = [make_axis(*bound, num=grid_size)
            for bound, make_axis in zip(bds[free], mths[free])]
    values = np.atleast_2d(axes)

    var = cartesian_product(*values)

    # Keep at least one row so that fixed parameters are still reported.
    n_rows = max(1, len(var))
    grid = np.zeros((n_rows, len(bds)))
    if free.any():
        grid[:, free] = var
    if fixed.any():
        grid[:, fixed] = lower[fixed]

    return grid


# (lower, upper) bounds per parameter: the first is pinned at 1 (equal
# bounds), the second spans [0, 0.5] on a 100-point linear axis.
bounds = ([1., 1.], [0.0, 0.5])
grid = cp_grid_param(
    grid_size=100,
    bounds=bounds,
    methods=(np.geomspace, np.linspace),
)

In [4]:
# One row per grid point, one column per parameter.
grid.shape

(100, 2)

In [5]:
# Expected information gain about the grid parameters from testing one item.
n_pres = np.array([2, 0])    # number of presentations so far, per item
delta = np.array([1., 0.])   # delay multiplying the forgetting rate, per item

n_item = n_pres.shape[0]
n_param = len(bounds)

# Uniform prior over grid points, kept in log space.
prior = np.ones(len(grid))
prior -= logsumexp(prior)

item = 0

# Log-probability of a successful recall at every grid point.
init_fr, rep_effect = grid.T
logp_success = -(init_fr * (1 - rep_effect) ** (n_pres[item] - 1)) * delta[item]

# Posterior after a success, and its KL divergence from the prior.
post_success = prior + logp_success
post_success -= logsumexp(post_success)
ig_success = entropy(np.exp(post_success), qk=np.exp(prior))

# Posterior after a failure, and its KL divergence from the prior.
post_failure = prior + np.log(1 - np.exp(logp_success))
post_failure -= logsumexp(post_failure)
ig_failure = entropy(np.exp(post_failure), qk=np.exp(prior))

# Prior-predictive probability of a success.
marg_p_success = np.exp(prior + logp_success).sum()

print(ig_success, ig_failure)
# Information gain averaged over the two possible outcomes.
print(marg_p_success*ig_success + (1 - marg_p_success)*ig_failure)

0.010559783919909747 0.008949329037300823
0.009718164172281003


In [6]:
# Learn a categorical policy over items that favours items with a high
# expected information gain, regularised towards a uniform action prior.
n_epoch = 100
n_sample = 10
lr = 0.5

n_pres = np.arange(0, 5)
delta = np.ones(len(n_pres))

n_item = len(n_pres)
n_param = len(bounds)

prior_action = dist.Categorical(logits=torch.ones(n_item))

logits_action = torch.nn.Parameter(torch.ones(n_item))

opt = torch.optim.Adam([logits_action, ], lr=lr)

# Uniform prior over grid points, kept in log space.
prior = np.ones(len(grid))
prior -= logsumexp(prior)

# The expected information gain depends only on the item (the prior, grid,
# `n_pres` and `delta` never change during training), so it is computed once
# per item here instead of once per sample per epoch.
init_fr, rep_effect = grid.T
expected_ig_per_item = np.zeros(n_item)
for item in range(n_item):
    if n_pres[item] == 0:
        # Never-presented items are assigned zero gain.
        continue

    logp_success = -init_fr*(1-rep_effect)**(n_pres[item]-1)*delta[item]

    post_success = prior + logp_success
    post_success -= logsumexp(post_success)
    ig_success = entropy(np.exp(post_success), np.exp(prior))

    post_failure = prior + np.log(1 - np.exp(logp_success))
    post_failure -= logsumexp(post_failure)
    ig_failure = entropy(np.exp(post_failure), np.exp(prior))

    marg_p_success = np.sum(np.exp(prior + logp_success))

    expected_ig_per_item[item] = (marg_p_success*ig_success
                                  + (1 - marg_p_success)*ig_failure)

expected_ig_per_item = torch.as_tensor(expected_ig_per_item, dtype=torch.float32)

# Lower bound applied before the log: if every sampled item has zero gain the
# sum is exactly 0 and log(0) would push inf/NaN gradients into the logits.
min_sum_terms = torch.finfo(torch.float32).eps

with tqdm(total=n_epoch) as pbar:
    for _ in range(n_epoch):

        opt.zero_grad()

        q = dist.Categorical(logits=logits_action)
        item_samples = q.sample((n_sample, ))

        # Sum over samples of q(item) * expected gain of that item.
        sum_terms = (q.log_prob(item_samples).exp()
                     * expected_ig_per_item[item_samples]).sum()

        loss = (- torch.log(torch.clamp(sum_terms, min=min_sum_terms))
                + torch.distributions.kl_divergence(q, prior_action))

        loss.backward()
        opt.step()

        pbar.set_postfix({"loss": loss.item()})
        pbar.update()

dist.Categorical(logits=logits_action).probs.data

  0%|          | 0/100 [00:00<?, ?it/s]

tensor([0.1070, 0.1070, 0.0975, 0.1394, 0.5491])

### Trajectories

In [7]:
# Score every possible sequence of `n_step` item choices by the summed
# expected information gain of its steps (the prior is not updated between
# steps; only `n_pres` and `delta` of the chosen item change).
current_n_pres = np.arange(1, 3)
current_delta = np.ones(len(current_n_pres))

n_item = len(current_n_pres)
n_param = len(bounds)

# Uniform prior over grid points, kept in log space.
prior = np.ones(len(grid))
prior -= logsumexp(prior)

n_step = 2
inter_trial_interval = 2.

all_traj = cartesian_product(*(np.arange(n_item) for _ in range(n_step)))
results = []

init_fr, rep_effect = grid.T

for traj in all_traj:
    expected_ig = 0.

    # Each trajectory starts from the same item state.
    n_pres = current_n_pres.copy()
    delta = current_delta.copy()

    for item in traj:

        # Items never presented before contribute no information gain.
        if n_pres[item] != 0:
            logp_success = -init_fr*(1-rep_effect)**(n_pres[item]-1)*delta[item]

            post_success = prior + logp_success
            post_success -= logsumexp(post_success)
            ig_success = entropy(np.exp(post_success), np.exp(prior))

            post_failure = prior + np.log(1 - np.exp(logp_success))
            post_failure -= logsumexp(post_failure)
            ig_failure = entropy(np.exp(post_failure), np.exp(prior))

            marg_p_success = np.sum(np.exp(prior + logp_success))

            expected_ig += marg_p_success*ig_success + (1 - marg_p_success)*ig_failure

        # Presenting the item updates its counters for the following steps.
        n_pres[item] += 1
        delta[item] = inter_trial_interval

    print(traj, expected_ig)

    results.append(expected_ig)

[0 0] 0.012636287668632572
[0 1] 0.009718164172281003
[1 0] 0.009718164172281003
[1 1] 0.05421082840397357


In [8]:
# Learn one categorical policy per time step so that sampled item sequences
# have a high summed expected information gain, regularised towards a
# uniform action prior.
n_epoch = 1000
n_sample = 10
lr = 0.2

current_n_pres = np.arange(1, 3)
current_delta = np.ones(len(current_n_pres))

n_item = len(current_n_pres)
n_param = len(bounds)

n_step = 2
inter_trial_interval = 3.

prior_action = dist.Categorical(logits=torch.ones(n_item))

logits_action = torch.nn.Parameter(torch.ones((n_step, n_item)))

opt = torch.optim.Adam([logits_action, ], lr=lr)

# Uniform prior over grid points, kept in log space.
prior = np.ones(len(grid))
prior -= logsumexp(prior)

init_fr, rep_effect = grid.T

with tqdm(total=n_epoch) as pbar:
    for _ in range(n_epoch):

        opt.zero_grad()

        # Accumulated over ALL sampled trajectories of this epoch. Resetting
        # it per sample (as before) made only the last sample contribute to
        # the gradient.
        sum_terms = 0

        for smp in range(n_sample):

            # Each sampled trajectory starts from the same item state.
            n_pres = deepcopy(current_n_pres)
            delta = deepcopy(current_delta)

            for t in range(n_step):

                q = dist.Categorical(logits=logits_action[t])
                item = q.sample()

                if n_pres[item] == 0:
                    # Never-presented items are assigned zero gain.
                    expected_ig = 0.0

                else:
                    logp_success = -init_fr*(1-rep_effect)**(n_pres[item]-1)*delta[item]

                    post_success = prior + logp_success
                    post_success -= logsumexp(post_success)
                    ig_success = entropy(np.exp(post_success), np.exp(prior))

                    post_failure = prior + np.log(1 - np.exp(logp_success))
                    post_failure -= logsumexp(post_failure)
                    ig_failure = entropy(np.exp(post_failure), np.exp(prior))

                    marg_p_success = np.sum(np.exp(prior + logp_success))

                    expected_ig = marg_p_success*ig_success + (1 - marg_p_success)*ig_failure

                sum_terms = sum_terms + q.log_prob(item).exp() * expected_ig

                # Presenting the item updates its counters for later steps.
                n_pres[item] += 1
                delta[item] = inter_trial_interval

        # KL penalty of every step's policy against the action prior. It does
        # not depend on the samples, so it is computed once per epoch and
        # actually added to the loss (it used to be accumulated but unused).
        sum_kl_div_prior = 0
        for t in range(n_step):
            q = dist.Categorical(logits=logits_action[t])
            sum_kl_div_prior = sum_kl_div_prior + torch.distributions.kl_divergence(q, prior_action)

        loss = - torch.log(sum_terms) + sum_kl_div_prior

        loss.backward()
        opt.step()

        pbar.set_postfix({"loss": loss.item()})
        pbar.update()

dist.Categorical(logits=logits_action).probs

  0%|          | 0/1000 [00:00<?, ?it/s]

In [9]:
# Same exhaustive scoring as the 2-step cell, for 4 items over 5 steps:
# each trajectory's score is the sum of its per-step expected information
# gains; the scores are collected in `results` (aligned with `all_traj`).
n_item = 4

current_n_pres = np.ones(n_item)
current_delta = np.ones(n_item)

n_param = len(bounds)

# Uniform prior over grid points, kept in log space.
prior = np.ones(len(grid))
prior -= logsumexp(prior)

n_step = 5
inter_trial_interval = 0.1

all_traj = cartesian_product(*(np.arange(n_item) for _ in range(n_step)))
results = []

init_fr, rep_effect = grid.T

for traj in tqdm(all_traj):
    expected_ig = 0.

    # Each trajectory starts from the same item state.
    n_pres = current_n_pres.copy()
    delta = current_delta.copy()

    for item in traj:

        # Items never presented before contribute no information gain.
        if n_pres[item] != 0:
            logp_success = -init_fr*(1-rep_effect)**(n_pres[item]-1)*delta[item]

            post_success = prior + logp_success
            post_success -= logsumexp(post_success)
            ig_success = entropy(np.exp(post_success), np.exp(prior))

            post_failure = prior + np.log(1 - np.exp(logp_success))
            post_failure -= logsumexp(post_failure)
            ig_failure = entropy(np.exp(post_failure), np.exp(prior))

            marg_p_success = np.sum(np.exp(prior + logp_success))

            expected_ig += marg_p_success*ig_success + (1 - marg_p_success)*ig_failure

        # Presenting the item updates its counters for the following steps.
        n_pres[item] += 1
        delta[item] = inter_trial_interval

    results.append(expected_ig)

In [10]:
# Trajectories whose score equals the maximum score (ties included).
all_traj[np.asarray(results) == np.max(results)]

### Debug `deltas` computation 

In [11]:
# Part 1: step-by-step simulation of a random teaching schedule.
# `delta` tracks, per item, the time elapsed since it was last shown
# (or since the start if it was never shown); `n_pres` counts presentations.
current_iter = 0
current_ss = 0

n_item = 10
n_iter_per_session = 2
time_per_iter = 3
break_length = 10
n_session = 3


delta = np.zeros(n_item)
n_pres = np.zeros(n_item)

trajectory = []
delays = []

done = False
while not done:

    item = np.random.choice(np.arange(n_item))
    trajectory.append(item)

    # The last iteration of a session is followed by a break instead of
    # the regular inter-iteration time.
    current_iter += 1
    if current_iter < n_iter_per_session:
        time_elapsed = time_per_iter
    else:
        current_iter = 0
        current_ss += 1
        time_elapsed = break_length

    done = current_ss >= n_session

    # Time passes for every item...
    delta += time_elapsed
    # ...and restarts from this step for the item just shown.
    delta[item] = time_elapsed
    # Count the presentation.
    n_pres[item] += 1

    delays.append(time_elapsed)


print("delays", delays)
print("n_pres", n_pres)
print("delta", delta)


# Part 2: recompute the same quantities directly from the whole trajectory.
traj = np.asarray(trajectory)

delta = np.zeros(n_item)
n_pres = np.zeros(n_item)

# Delay following each step: regular iterations, then a break, per session.
delays = np.tile(
    [time_per_iter] * (n_iter_per_session - 1) + [break_length],
    n_session)

n_step = len(trajectory)

for item in range(n_item):
    item_pres = traj == item
    n_pres_traj = np.sum(item_pres)
    n_pres[item] += n_pres_traj
    if n_pres_traj > 0:
        # Sum the delays from the item's last presentation onwards.
        idx_last_pres = np.flatnonzero(item_pres)[-1]
        delta[item] = np.sum(delays[idx_last_pres:])
    else:
        # Never shown: the whole schedule has elapsed.
        delta[item] = np.sum(delays)

print("delays", delays)
print("n_pres", n_pres)
print("delta", delta)

In [12]:
# Brute-force search over every item sequence for the remaining steps:
# each sequence is scored by the mean, over items, of a sigmoid of
# (recall probability at the end of the schedule - threshold).
n_item = 3

inv_temp = 10.

threshold = 0.9

time_per_iter = 2
n_iter_per_session = 2
break_length = 10
n_session = 2

t = 0

t_max = n_session*n_iter_per_session

t_remaining = t_max - t

initial_forget_rates = np.ones(n_item) * 0.01
repetition_rates = np.ones(n_item) * 0.2

# Delay following each remaining step: regular iterations, then a break.
delays = np.tile(
    [time_per_iter] * (n_iter_per_session - 1) + [break_length],
    n_session)[t:]

current_n_pres = np.zeros(n_item)
current_delta = np.ones(n_item)

print("N possible trajectories", n_item**t_remaining)

all_traj = cartesian_product(*(np.arange(n_item)
                               for _ in range(t_remaining)))

results = []

for traj in all_traj:

    # Each trajectory starts from the same item state.
    n_pres = current_n_pres.copy()
    delta = current_delta.copy()

    for item in range(n_item):

        item_pres = traj == item
        n_pres_traj = np.sum(item_pres)
        n_pres[item] += n_pres_traj
        if n_pres_traj > 0:
            # Time elapsed since the item's last presentation.
            idx_last_pres = np.flatnonzero(item_pres)[-1]
            delta[item] = np.sum(delays[idx_last_pres:])
        else:
            # Never shown in this trajectory: the whole schedule elapses.
            delta[item] += np.sum(delays)

    # Recall probability stays 0 for items that were never presented.
    p = np.zeros(n_item)

    view = n_pres > 0
    rep = n_pres[view] - 1.
    delta = delta[view]

    init_fr = initial_forget_rates[view]
    rep_eff = repetition_rates[view]

    forget_rate = init_fr * (1 - rep_eff) ** rep
    logp_recall = - forget_rate * delta

    p[view] = np.exp(logp_recall)

    learning_reward = np.mean(expit(inv_temp * (p - threshold)))
    results.append(learning_reward)

max_r = np.max(results)
best = np.asarray(results) == max_r
print("Best possible reward", max_r)
print("N best", np.sum(best))
print("Best are:0\n", all_traj[best])

In [68]:
# Learn one categorical policy per remaining step so that sampled item
# sequences obtain a high learning reward, regularised towards a uniform
# action prior; then evaluate the greedy (arg-max) trajectory.
#
# NOTE: `current_n_pres`, `current_delta`, `initial_forget_rates` and
# `repetition_rates` are read from the brute-force cell above.
n_item = 3

inv_temp = 100.

threshold = 0.9

time_per_iter = 2
n_iter_per_session = 2
break_length = 10
n_session = 2

t = 0

n_epochs = 1000
n_sample = 10  # previously leaked in from an earlier cell
lr = 0.5

t_max = n_session*n_iter_per_session

t_remaining = t_max - t

logits_action = torch.nn.Parameter(torch.ones((t_remaining, n_item)))

opt = torch.optim.Adam([logits_action, ], lr=lr)

prior_action = dist.Categorical(logits=torch.ones(n_item))

# Delay following each remaining step: regular iterations, then a break.
delays = np.tile(
    [time_per_iter
     for _ in range(n_iter_per_session - 1)]
    + [break_length, ],
    n_session)[t:]

n_step = t_remaining


def trajectory_reward(traj):
    """Return the learning reward obtained by following `traj`.

    Parameters
    ----------
    traj : numpy.ndarray
        Item index chosen at each of the `t_remaining` steps.

    Returns
    -------
    float
        Mean over items of ``expit(inv_temp * (p - threshold))`` where `p`
        is each item's recall probability at the end of the schedule
        (0 for items never presented).

    Reads the module-level `current_n_pres`, `current_delta`, `delays`,
    `initial_forget_rates`, `repetition_rates`, `inv_temp` and `threshold`.
    """
    n_pres = deepcopy(current_n_pres)
    delta = deepcopy(current_delta)

    for item in range(n_item):

        item_pres = traj == item
        n_pres_traj = np.sum(item_pres)
        n_pres[item] += n_pres_traj
        if n_pres_traj == 0:
            # Never shown in this trajectory: the whole schedule elapses.
            delta[item] += np.sum(delays)
        else:
            # Time elapsed since the item's last presentation.
            idx_last_pres = np.arange(t_remaining)[item_pres][-1]
            delta[item] = np.sum(delays[idx_last_pres:])

    p = np.zeros(n_item)

    view = n_pres > 0
    rep = n_pres[view] - 1.
    delta = delta[view]

    init_fr = initial_forget_rates[view]
    rep_eff = repetition_rates[view]

    forget_rate = init_fr * (1 - rep_eff) ** rep
    logp_recall = - forget_rate * delta

    p[view] = np.exp(logp_recall)

    return np.mean(expit(inv_temp * (p - threshold)))


with tqdm(total=n_epochs, position=1, leave=False) as pbar:

    for epoch in range(n_epochs):

        opt.zero_grad()

        # `logits=` must be passed by keyword: the first positional parameter
        # of Categorical is `probs`, which would misinterpret the logits.
        q = dist.Categorical(logits=logits_action)
        trajectories = q.sample((n_sample, ))

        # KL penalty of every step's policy against the action prior.
        # Separate names are used so that `q` (the joint per-step policy) and
        # `t` are not overwritten while trajectories are being scored.
        kl_to_prior = 0
        for step in range(n_step):
            q_step = dist.Categorical(logits=logits_action[step])
            kl_to_prior = kl_to_prior + torch.distributions.kl_divergence(q_step, prior_action)

        loss = 0

        for trajectory in trajectories:

            # Score the trajectory that was actually sampled (the previous
            # code compared against a stale `traj` from another cell).
            learning_reward = trajectory_reward(trajectory.numpy())

            # Probability of the whole trajectory times its reward.
            loss = loss - q.log_prob(trajectory).sum().exp() * learning_reward

            # The penalty is added once per sampled trajectory, which keeps
            # the original relative weighting of the two loss terms.
            loss = loss + kl_to_prior

        loss.backward()
        opt.step()

        pbar.set_postfix({"loss": loss.item()})
        pbar.update()


# Greedy trajectory: most likely item at every step.
traj = np.argmax(logits_action.detach().numpy(), axis=1)
learning_reward = trajectory_reward(traj)

print(traj)
print(learning_reward)

  0%|          | 0/1000 [00:00<?, ?it/s]

[1 0 1 0]
0.5365492649422604


In [69]:
# Sanity check of the row ordering: all length-2 sequences over two items.
all_traj = cartesian_product(np.arange(2), np.arange(2))
all_traj

array([[0, 0],
       [0, 1],
       [1, 0],
       [1, 1]])

In [78]:
# Check that summing per-step KL divergences against a uniform prior gives
# the same value as one batched KL computation summed over steps.
n_step, n_item = 5, 4
logits_action = torch.rand((n_step, n_item))

# Step by step: one categorical per row of logits.
p = dist.Categorical(logits=torch.ones(n_item))
loss = 0
for t in range(n_step):
    q = dist.Categorical(logits=logits_action[t])
    loss = loss + torch.distributions.kl_divergence(q, p)
print(loss)

# Batched: all steps at once, then summed.
p = dist.Categorical(logits=torch.ones((n_step, n_item)))
q = dist.Categorical(logits=logits_action)
loss = torch.distributions.kl_divergence(q, p).sum()
print(loss)

tensor(0.2134)
tensor(0.2134)
