In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import seaborn as sns
import yaml
from daart.data import DataGenerator, compute_sequence_pad
from daart.eval import get_precision_recall, run_lengths
from daart.io import get_expt_dir
from daart.transforms import ZScore

from daart_utils.data import DataHandler
from daart_utils.models import compute_model_predictions, get_default_hparams
#from daart_utils.paths import data_path, results_path
from daart_utils.plotting import plot_heatmaps
import ssm
from ssm.util import random_rotation, find_permutation

In [3]:
import torch
from daart.io import get_expt_dir, find_experiment
from daart.models import Segmenter, RSLDS

#from daart_utils.session_ids.fly import SESS_IDS_TRAIN_5, SESS_IDS_TEST
#from daart_utils.session_ids.fly import label_names

# Discrete-state names for the simulated SSM dataset (list index = state id).
label_names = ['right', 'left', 'top', 'bottom']

# save predicted states from models
save_states = True
# overwrite predicted states from models
overwrite_states = False
# compute state statistics like median bout duration and behavior ratios
compute_state_stats = False

dataset = 'ssm'
input_type = 'markers'
# input_type = 'features-simba'
sequence_length = 500
batch_size = 8
ignore_background = True
anneal_start = 25
anneal_end = 75


sess_ids_test = ['ssm_v1']

# fill out hparams
model_type = 'dtcn'
tt_expt_dir = 'test'
lambda_weak = 0
lambda_strong = 1
lambda_pred = 0

ss_algo = 'weak'  # 'weak' | 'pseudo_labels' | 'task'

backbone = 'dtcn'
model_class = 'rslds'
hparams = get_default_hparams(
    model_class=model_class, device='cuda', sequence_length=sequence_length, n_lags=4,
    input_type=input_type, backbone=backbone, batch_size=batch_size,
    # take the annealing schedule from the config values above rather than repeating literals
    anneal_start=anneal_start, anneal_end=anneal_end, prob_threshold=0.9,  # pseudo_labels params
)
hparams['sequence_pad'] = compute_sequence_pad(hparams)

# run on CPU regardless of the 'cuda' passed to get_default_hparams above
hparams['device'] = 'cpu'

# load model
# NOTE(review): absolute, user-specific path; move it to a configurable results directory.
version_dir = "/Users/blau/Projects/daart/results_gmdgm/ssm_v1/dtcn/rgt_v1/version_0"
hdir = version_dir  # hparams.yaml is expected next to the checkpoint
model_file = os.path.join(version_dir, 'best_val_model.pt')
arch_file = os.path.join(hdir, 'hparams.yaml')
print('Loading model defined in %s' % arch_file)
with open(arch_file, 'rb') as f:
    hparams_new = yaml.safe_load(f)
hparams_new['device'] = hparams.get('device', 'cpu')
model_0 = RSLDS(hparams_new)
# map_location returns each storage unchanged, i.e. tensors are deserialized onto CPU
model_0.load_state_dict(torch.load(
    model_file, map_location=lambda storage, loc: storage))
model_0.to(hparams_new['device'])
model_0.eval()

Loading model defined in /Users/blau/Projects/daart/results_gmdgm/ssm_v1/dtcn/rgt_v1/version_0/hparams.yaml


FileNotFoundError: [Errno 2] No such file or directory: '/Users/blau/Projects/daart/results_gmdgm/ssm_v1/dtcn/rgt_v1/version_0/hparams.yaml'

In [None]:
def confusion_matrix(true_states, inf_states, num_states):
    """Row-normalized confusion matrix between true and inferred discrete states.

    Parameters
    ----------
    true_states : iterable of array-like
        One integer state sequence per session/trial.
    inf_states : iterable of array-like
        Inferred state sequences, paired element-wise with ``true_states``
        (as with ``zip``, extra sequences in the longer iterable are ignored).
        Each pair must have matching lengths.
    num_states : int
        Number of discrete states K; labels outside ``[0, K)`` are not counted.

    Returns
    -------
    np.ndarray, shape (K, K)
        Entry ``[i, j]`` is the fraction of timepoints with true state ``i``
        that were assigned state ``j``. Rows for true states that never occur
        are NaN (there is nothing to normalize by); this is produced without a
        divide-by-zero RuntimeWarning.
    """
    counts = np.zeros((num_states, num_states))
    totals = np.zeros((num_states, 1))
    for ztrue, zinf in zip(true_states, inf_states):
        # asarray so plain Python lists compare element-wise rather than as a whole
        ztrue = np.asarray(ztrue)
        zinf = np.asarray(zinf)
        for i in range(num_states):
            is_i = ztrue == i  # reused for every inferred state j
            totals[i] += np.sum(is_i)
            for j in range(num_states):
                counts[i, j] += np.sum(is_i & (zinf == j))

    result = np.full((num_states, num_states), np.nan)
    observed = totals[:, 0] > 0
    result[observed] = counts[observed] / totals[observed]
    return result

In [None]:
# Run model_0 on every test session and compare its predicted discrete states
# with the hand labels. Fills states_hand, state_probs_0, states_0 and
# state_overlaps_0, all keyed by expt_id.
states_hand = {}
state_overlaps_0 = {}
# state_overlaps_1 = {}
state_probs_0 = {}
# state_probs_1 = {}
states_0 = {}
# states_1 = {}
state_overlaps = {}
print((hparams['trial_splits']))

inferred_latents = []
for expt_id in sess_ids_test:
        
    print(expt_id)
    
    # initialize data handler; point to correct base path
    # FIXME: `data_path` is never defined in this notebook (its import from
    # daart_utils.paths is commented out in the first cell), so this line raises
    # NameError on a fresh kernel.
    handler = DataHandler(expt_id, base_path=os.path.join(data_path, dataset))
    if input_type == 'markers':
        markers_file = handler.get_marker_filepath()
    else:
        markers_file = handler.get_feature_filepath(dirname=input_type)

    # NOTE(review): absolute, user-specific path with a different root from the
    # handler's base_path above; consider deriving both from one data directory.
    hand_labels_file = os.path.join(
                "/Users/blau/Projects/daart/data/", 'labels-hand', expt_id + '_labels.csv')

    # define data generator signals
    signals = ['markers', 'labels_strong']
    transforms = [ZScore(), None]
    paths = [markers_file, hand_labels_file]

    # build data generator
    # NOTE(review): trial_splits is hard-coded here rather than taken from hparams.
    data_gen_test = DataGenerator(
        [expt_id], [signals], [transforms], [paths], device=hparams['device'], 
        batch_size=hparams['batch_size'], trial_splits='1;1;0;0', 
        sequence_pad=hparams['sequence_pad'], sequence_length=hparams['sequence_length'],
        input_type=hparams['input_type'])
    print('----------------------------')
    print(data_gen_test)
    print('----------------------------')
    print('\n')

    # load hand labels
    handler.load_hand_labels()
    # one integer label per timepoint: index of the largest hand-label column
    states = np.argmax(handler.hand_labels.vals, axis=1)
    # NOTE(review): `cutoff` is unused because the truncation below is commented out.
    cutoff = int(np.floor(states.shape[0] / hparams['batch_size'])) * hparams['batch_size']
    #states = states[:cutoff]
    states_hand[expt_id] = states
    
    # compute predictions
    print('computing predictions for model 0...', end='')
    tmp = model_0.predict_labels(data_gen_test, return_scores=True)
    # per-batch label probabilities stacked along the first axis
    labels_pred = np.vstack(tmp['qy_x_probs'][0])
    
    # NOTE(review): both are overwritten on every iteration, so after the loop they
    # hold only the LAST session's values; the regression cell reads them.
    inferred_latents = np.vstack(tmp['qz_xy_mean'][0])
    yhat = np.vstack(tmp['qy_x_probs'][0])
    
    # NOTE(review): `weights` is not used below.
    weights = tmp['pz_mean']
    labels_model_ = np.argmax(labels_pred, axis=1)
    probs_max = np.max(labels_pred, axis=1)
    labels_model = np.copy(labels_model_)
    # labels_model[probs_max < 0.75] = 0
    state_probs_0[expt_id] = labels_pred
    states_0[expt_id] = labels_model
    
    # Trim hand labels to the number of predicted timepoints (presumably the
    # predictions can be shorter because of batching; confirm).
    states = states[:len(labels_model)]
    states_hand[expt_id] = states
    print('done')
    

    print('states shape: ', states.shape)
    print('pred shape: ', states_0[expt_id].shape)
    # Per-session confusion matrix restricted to timepoints whose hand label is > 0.
    # NOTE(review): this drops label 0, which is 'right' in label_names unless the
    # hand-label columns include a background class at index 0 (confirm). The
    # pooled matrix in the next cell does not apply this mask.
    state_overlaps_0[expt_id] = confusion_matrix(
        [states[states > 0]], [states_0[expt_id][states > 0]], num_states=len(label_names))
    


In [None]:
# Pool hand labels and model-0 predictions over all test sessions and compute a
# single row-normalized confusion matrix (rows: hand label, columns: prediction).
# No `states > 0` mask is applied here, unlike the per-session matrices; to match
# them, index each pair with `[states_hand[expt_id] > 0]`.
states_all_hand = [states_hand[expt_id] for expt_id in sess_ids_test]
states_all_0 = [states_0[expt_id] for expt_id in sess_ids_test]

state_overlaps_all_0 = confusion_matrix(
    states_all_hand, states_all_0, num_states=len(label_names))

In [None]:
# Heatmap of the pooled confusion matrix: row i = hand label, column j = model-0
# prediction; each row is normalized by how often that hand label occurs.
fig = plt.figure(figsize=(5, 5))
im = plt.imshow(state_overlaps_all_0, vmin=0, vmax=1, cmap='Reds')
plt.title('Model SSM Data')
plt.ylabel('Hand label')
plt.xlabel('Predicted state')
plt.yticks(np.arange(len(label_names)), label_names)
plt.xticks(np.arange(len(label_names)), label_names, rotation=45, ha='right')
plt.colorbar(im)
plt.show()

In [None]:
# Draw 10000 samples from the trained model. Below, `y` is argmax'd per step
# (discrete labels), `z` is plotted as a 2-D trajectory and `x` as observation
# traces; these roles are inferred from that usage (confirm against RSLDS.sampler).
y, z, x = model_0.sampler(10000)
# NOTE(review): os and pickle are already imported in the first cell, and copy is
# unused in the visible code; consolidate all imports in the top cell.
import os
import pickle
import copy

# NOTE(review): this rebinds `np` to autograd.numpy for every later cell,
# shadowing the plain numpy imported in the first cell.
import autograd.numpy as np
import autograd.numpy.random as npr
# Seeds numpy's global RNG for draws made after this point. The samples above
# were drawn before the seed is set.
npr.seed(12345)

import matplotlib.pyplot as plt
from matplotlib import gridspec
%matplotlib inline

import seaborn as sns
# State palette: the plotting helpers colour state k with colors[k % len(colors)].
color_names = ["windows blue", "red", "amber", "faded green"]
colors = sns.xkcd_palette(color_names)
sns.set_style("white")
sns.set_context("talk")

def plot_trajectory(z, x, ax=None, ls="-"):
    """Plot a 2-D continuous trajectory, coloured by the active discrete state.

    The sequence is cut wherever ``z`` changes value, and each constant-state
    run is drawn in ``colors[state % len(colors)]``.

    Parameters
    ----------
    z : 1-D array of discrete state ids, one per time step.
    x : array indexed as ``x[t, d]``; only dimensions 0 and 1 are drawn.
    ax : axes to draw on; a new 4x4 figure is created when None.
    ls : matplotlib line style.

    Returns
    -------
    The axes that were drawn on.
    """
    change_points = np.where(np.diff(z))[0] + 1
    bounds = np.concatenate(([0], change_points, [z.size]))
    if ax is None:
        ax = plt.figure(figsize=(4, 4)).gca()
    for seg_start, seg_end in zip(bounds[:-1], bounds[1:]):
        # Include the first point of the next run so consecutive segments join up.
        piece = x[seg_start:seg_end + 1]
        state_colour = colors[z[seg_start] % len(colors)]
        ax.plot(piece[:, 0], piece[:, 1], lw=1, ls=ls, color=state_colour, alpha=1.0)
    return ax



# Convert the sampler's per-step torch tensors to numpy arrays.
# new_y: index of the largest element of each y sample (one discrete label per step).
new_y = []
for ny in y:
    m = torch.argmax(ny).item()
    new_y.append(m)
#print(new_y)
new_y = np.array(new_y)

# z_new: sampled continuous latents, stacked along the first (time) axis.
z_new = []
for zn in z:
    m = zn.detach().numpy()
    z_new.append(m)
z_new = np.array(z_new)

# new_x: sampled x values (plotted as observations later), stacked along the first axis.
new_x = []
for xn in x:
    m = xn.detach().numpy()
    new_x.append(m)
new_x = np.array(new_x)

# Sampled latent trajectory coloured by sampled label; only the left grid cell is used.
# NOTE(review): despite the title, this is a sample from model_0, not the
# ground-truth trajectory.
fig = plt.figure(figsize=(15, 6)) 
gs = gridspec.GridSpec(1, 2, width_ratios=[2, 3]) 
ax0 = plt.subplot(gs[0])
plot_trajectory(new_y, z_new, ax=ax0)
plt.title("True Trajectory")


print('yn', new_y.shape, new_y[:15])
print('zn', z_new.shape, z_new[:15])

In [None]:
# Ground-truth latents and labels for the SSM dataset. These are presumably the
# simulation's true values, given their use as regression targets below.
# NOTE(review): absolute, user-specific paths; derive them from a data directory.
markers_path_latents = '/Users/blau/Projects/daart/daart_utils/data/ssm/latents.npy'
true_latents = np.load(markers_path_latents)


path_labels = '/Users/blau/Projects/daart/daart_utils/data/ssm/labels.npy'
true_labels= np.load(path_labels)

In [None]:
# perform transforms
# Align the model's latent space with the ground truth using plain linear maps:
# regress true latents on inferred latents, and true labels on predicted label
# probabilities.

from sklearn.linear_model import LinearRegression

# NOTE(review): this rebinds the global `z` (the sampler output) to true_latents;
# later cells that read `z` see the true latents. zhat holds only the last test
# session's latents, and fit() requires len(zhat) == len(z).
z = true_latents
zhat = inferred_latents

# find the linear mapping from zhat to z (inferred -> true); no intercept, so R is
# a pure linear change of basis that can also be applied to dynamics as R A R^{-1}.
lr = LinearRegression(fit_intercept=False)
lr.fit(zhat, z)
# coef_ has shape (n_targets, n_features) = (dim z, dim zhat) for 2-D targets,
# so z_t ≈ R @ zhat_t.
R = lr.coef_
print(R.shape, R)
# # compute the updated dynamics matrix for each state; assume the learned matrices are stored in a list called As for simplicity
# As_rotated = [None for _ in range(num_states)]
# for k in range(num_states):
#     As_rotated[k] = R @ As[k] @ np.linalg.inv(R)

# Linear map from predicted label probabilities (last session only) to the true
# labels. coef_ is (n_targets, n_features), or 1-D if true_labels is 1-D.
lr = LinearRegression(fit_intercept=False)
lr.fit(yhat, true_labels)
YM = lr.coef_
print(YM.shape, YM)
    
# z_t = R A R^{-1} z_{t-1}


In [None]:
print(z_new.shape)
# Map every sampled latent into the true-latent frame with R (z_true ≈ R @ z_hat).
# The result is written into a zeros_like buffer, so z_R keeps z_new's shape and dtype.
# NOTE(review): if z_new is confirmed to be (T, D), this loop is equivalent to
# `z_R[...] = z_new @ R.T`.
z_R = np.zeros_like(z_new)
for i, zn in enumerate(z_new):
    zr = np.matmul(R, zn)
    z_R[i] = zr

In [None]:
# Same layout as the sampled-trajectory plot, but the latents are first mapped
# through R into the true-latent frame (z_R).
fig = plt.figure(figsize=(15, 6)) 
gs = gridspec.GridSpec(1, 2, width_ratios=[2, 3]) 
ax0 = plt.subplot(gs[0])
plot_trajectory(new_y, z_R, ax=ax0)
plt.title("Sample Trajectory from Model Weights")

In [None]:

def plot_most_likely_dynamics(model,
        xlim=(-4, 4), ylim=(-3, 3), nxpts=20, nypts=20,
        alpha=0.8, ax=None, figsize=(3, 3), K=4):
    """Quiver plot of each discrete state's dynamics over a 2-D grid.

    Each grid point gets the state ``argmax_k (Rs @ xy + r)[k]`` computed from
    ``model.transitions`` (Rs of shape (K, 2), r of shape (K,)). An arrow
    ``A_k xy + b_k - xy`` from ``model.dynamics`` is drawn there in colour
    ``colors[k % len(colors)]``.

    Parameters
    ----------
    model : object exposing ``transitions.Rs``, ``transitions.r``,
        ``dynamics.As`` and ``dynamics.bs``.
    xlim, ylim : (min, max) extent of the grid.
    nxpts, nypts : number of grid points along x and y.
    alpha : arrow opacity.
    ax : axes to draw on; a new ``figsize`` figure is created when None.
    figsize : size of that new figure.
    K : unused; kept for backward compatibility.

    Returns
    -------
    The axes that were drawn on.
    """
    xy = _grid_points(xlim, ylim, nxpts, nypts)
    # Most likely state at each grid point, independent of any previous state.
    z = np.argmax(xy.dot(model.transitions.Rs.T) + model.transitions.r, axis=1)
    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
    return _quiver_by_state(xy, z, model, ax, alpha)


def plot_inf(model,
        xlim=(-4, 4), ylim=(-3, 3), nxpts=20, nypts=20,
        alpha=0.8, ax=None, figsize=(3, 3), K=4):
    """Quiver plot whose state assignment is propagated point-to-point along the grid.

    Grid points are visited in raster order (x fastest). The first point gets
    state 0. Each later point gets ``argmax_k (Rs[prev] @ xy + r[prev])[k]``,
    where ``prev`` is the state of the point visited immediately before it.
    Arrows are drawn as in ``plot_most_likely_dynamics``.

    Assumes ``model.transitions.Rs`` is indexed (prev_state, K, 2) and ``r`` is
    (prev_state, K), which is the layout make_inf_model builds; TODO confirm
    against the trained model.

    Parameters and return value are the same as for ``plot_most_likely_dynamics``
    (``K`` is unused).
    """
    xy = _grid_points(xlim, ylim, nxpts, nypts)
    Rs, r = model.transitions.Rs, model.transitions.r
    z = np.zeros(len(xy), dtype=int)
    for t in range(1, len(xy)):
        prev = z[t - 1]  # state of the immediately preceding grid point
        z[t] = np.argmax(xy[t].dot(Rs[prev].T) + r[prev])
    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
    return _quiver_by_state(xy, z, model, ax, alpha)


def _grid_points(xlim, ylim, nxpts, nypts):
    """Return an (nxpts * nypts, 2) array of grid points, with x varying fastest."""
    X, Y = np.meshgrid(np.linspace(*xlim, nxpts), np.linspace(*ylim, nypts))
    return np.column_stack((X.ravel(), Y.ravel()))


def _quiver_by_state(xy, z, model, ax, alpha):
    """Draw one quiver per discrete state k at the grid points labelled k, then label axes."""
    for k, (A, b) in enumerate(zip(model.dynamics.As, model.dynamics.bs)):
        in_k = z == k
        if not np.any(in_k):
            continue  # this state owns no grid point
        dxydt_m = xy.dot(A.T) + b - xy
        ax.quiver(xy[in_k, 0], xy[in_k, 1],
                  dxydt_m[in_k, 0], dxydt_m[in_k, 1],
                  color=colors[k % len(colors)], alpha=alpha)

    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')

    plt.tight_layout()

    return ax



In [None]:
# build true rslds
# Sizes read as globals by make_true_model and make_inf_model below.
T = 10000       # presumably the number of time steps; not read by the code in this cell
K = 4           # number of discrete states (second argument of ssm.SLDS)
D_obs = 10      # observation dimension (first argument of ssm.SLDS)
D_latent = 2    # continuous latent dimension (third argument of ssm.SLDS)

def make_true_model():
    """Build the ground-truth 4-state recurrent-only SLDS used as the reference model.

    States 0 and 1 are random rotations (angles pi/24 and pi/48) whose fixed
    points are (+2, 0) and (-2, 0). States 2 and 3 use identity dynamics plus a
    constant drift of +0.1 ("right") and -0.25 ("left") along x_1. The transition
    weights (Rs, r) split the plane into the regions x_1 > 2, x_1 < -2, x_2 > 0
    and x_2 < 0. Dynamics noise variances are fixed at 1e-4.

    Reads the module-level constants K, D_obs and D_latent. The parameters below
    assume K == 4 and D_latent == 2.

    Returns
    -------
    ssm.SLDS
        The configured model. Emission parameters other than ``inv_etas`` keep
        their ssm defaults.
    """
    As = [random_rotation(D_latent, np.pi/24.),
          random_rotation(D_latent, np.pi/48.)]

    # Set the center points for each system
    centers = [np.array([+2.0, 0.]),
               np.array([-2.0, 0.])]
    # b = -(A - I) c makes each center c a fixed point: A c + b = c
    bs = [-(A - np.eye(D_latent)).dot(center) for A, center in zip(As, centers)]

    # Add a "right" drift state: x_1 moves by +0.1 per step
    As.append(np.eye(D_latent))
    bs.append(np.array([+0.1, 0.]))

    # Add a "left" drift state: x_1 moves by -0.25 per step
    As.append(np.eye(D_latent))
    bs.append(np.array([-0.25, 0.]))

    # Construct multinomial regression to divvy up the space
    w1, b1 = np.array([+1.0, 0.0]), np.array([-2.0])   # x + b > 0 -> x > -b
    w2, b2 = np.array([-1.0, 0.0]), np.array([-2.0])   # -x + b > 0 -> x < b
    w3, b3 = np.array([0.0, +1.0]), np.array([0.0])    # y > 0
    w4, b4 = np.array([0.0, -1.0]), np.array([0.0])    # y < 0
    # np.vstack instead of np.row_stack, which is a deprecated alias in NumPy 2.0
    Rs = np.vstack((100*w1, 100*w2, 10*w3, 10*w4))
    r = np.concatenate((100*b1, 100*b2, 10*b3, 10*b4))

    rslds = ssm.SLDS(D_obs, K, D_latent,
                     transitions="recurrent_only",
                     dynamics="diagonal_gaussian",
                     emissions="gaussian_orthog",
                     single_subspace=True)
    rslds.dynamics.mu_init = np.tile(np.array([[0, 1]]), (K, 1))
    rslds.dynamics.sigmasq_init = 1e-4 * np.ones((K, D_latent))
    rslds.dynamics.As = np.array(As)
    rslds.dynamics.bs = np.array(bs)
    rslds.dynamics.sigmasq = 1e-4 * np.ones((K, D_latent))

    rslds.transitions.Rs = Rs
    rslds.transitions.r = r

    rslds.emissions.inv_etas = np.log(1e-2) * np.ones((1, D_obs))
    return rslds


true_rslds = make_true_model()


# build inferred rslds
# NOTE(review): `inv` is referenced only in commented-out code inside make_inf_model.
from numpy.linalg import inv
def make_inf_model():
    """Build an ssm.SLDS holding hard-coded inferred rSLDS parameters.

    The numbers are presumably copied from the trained daart RSLDS (model_0).
    NOTE(review): confirm this, and consider reading them from the checkpoint so
    they cannot go stale when the model is retrained.

    Shapes after reshaping: As (4, 2, 2), bs (4, 2), Rs (4, 4, 2), r (4, 4).
    Unlike make_true_model (Rs (K, D), r (K,)), Rs/r here carry one block per
    previous state, matching how plot_inf indexes them (``Rs[prev]``, ``r[prev]``).
    That is not the (K, D) layout make_true_model gives the "recurrent_only"
    transitions, so check before sampling from this model.

    Reads the module-level constants K, D_obs and D_latent.

    Returns
    -------
    ssm.SLDS
    """
    As = np.array([[ 0.9268,  0.0551],
                   [-0.2080,  0.9323],
                   [ 0.9952,  0.0399],
                   [-0.1017,  0.9670],
                   [ 0.9807,  0.0068],
                   [ 0.0152,  0.9635],
                   [ 0.9846,  0.0100],
                   [ 0.0219,  0.9676]])
    As = np.reshape(As, (4, 2, 2))
    # Optional: express the dynamics in the true-latent frame, A' = R A R^{-1},
    # with R from the regression cell:
    # for i, A in enumerate(As):
    #     As[i] = np.matmul(np.matmul(R, A), inv(R))

    bs = np.array([-0.1113, -0.0556, -0.0226,  0.1760, -0.0124, -0.0799,  0.0792,  0.1404])
    bs = np.reshape(bs, (4, 2))

    # Per-previous-state multinomial regression weights
    Rs = np.array([[-1.2110, -1.1055],
                   [ 1.3356, -0.2874],
                   [ 2.8323,  0.1543],
                   [ 1.1719,  0.8615],
                   [-0.5288, -0.3797],
                   [ 1.4611,  0.7638],
                   [ 0.1845, -1.2972],
                   [-0.5560, -0.6553],
                   [-0.9565, -0.7659],
                   [-0.7451,  1.1582],
                   [ 1.4352, -0.9208],
                   [-1.0619,  0.8751],
                   [ 0.0094, -0.9992],
                   [ 1.5631,  0.0360],
                   [ 0.3146, -1.5100],
                   [-0.3560,  0.7558]])
    Rs = np.reshape(Rs, (4, 4, 2))
    # Optional: the weights act on the latent, so in the true frame they become
    # Rs[i] @ inv(R) (not R @ Rs[i], which has mismatched dimensions).

    r = np.array([ 1.7372, -0.9438, -1.5308, -0.4251, -0.4980,  0.3336, -0.1542, -0.7011,
                  -0.9896, -1.5549,  0.4939, -1.6408, -1.2595, -0.6701, -0.3222,  0.6581])
    r = np.reshape(r, (4, 4))

    rslds = ssm.SLDS(D_obs, K, D_latent,
                     transitions="recurrent_only",
                     dynamics="diagonal_gaussian",
                     emissions="gaussian_orthog",
                     single_subspace=True)
    rslds.dynamics.mu_init = np.tile(np.array([[0, 1]]), (K, 1))
    rslds.dynamics.sigmasq_init = 1e-4 * np.ones((K, D_latent))
    rslds.dynamics.As = np.array(As)
    rslds.dynamics.bs = np.array(bs)
    rslds.dynamics.sigmasq = 1e-4 * np.ones((K, D_latent))

    rslds.transitions.Rs = Rs
    rslds.transitions.r = r

    rslds.emissions.inv_etas = np.log(1e-2) * np.ones((1, D_obs))
    return rslds

inf_rslds = make_inf_model()

# Flow field of the ground-truth model.
plt.figure(figsize=(6,4))
ax = plt.subplot(111)
# Axis limits: per-dimension max |value| plus 1.
# NOTE(review): `z` here is true_latents (rebound in the regression cell), not the sampler's z.
lim = abs(z).max(axis=0) + 1
plot_most_likely_dynamics(true_rslds, xlim=(-lim[0], lim[0]), ylim=(-lim[1], lim[1]), ax=ax)
plt.title("True Dynamics")

# Flow field of the inferred model, using fixed limits equal to the plotting
# functions' default ranges, so the two figures are not necessarily on the same scale.
plt.figure(figsize=(6,4))
ax = plt.subplot(111)
lim = 4, 3
plot_inf(inf_rslds, xlim=(-lim[0], lim[0]), ylim=(-lim[1], lim[1]), ax=ax)
plt.title("Inferred Dynamics, rSLDS")

In [None]:
# plot x, y
#y, z, x = model_0.sampler(10000)
def plot_observations(z, y, ax=None, ls="-", lw=1):
    """Plot every observation channel against time, coloured by discrete state.

    Parameters
    ----------
    z : 1-D array of discrete state ids, one per time step.
    y : (T, N) array; each of the N columns is drawn as its own line.
    ax : axes to draw on; a new 4x4 figure is created when None.
    ls, lw : line style and width forwarded to ``ax.plot``.

    Returns
    -------
    The axes that were drawn on.
    """
    boundaries = np.concatenate(([0], np.where(np.diff(z))[0] + 1, [z.size]))
    if ax is None:
        ax = plt.figure(figsize=(4, 4)).gca()
    n_steps, n_channels = y.shape
    time = np.arange(n_steps)
    segments = list(zip(boundaries[:-1], boundaries[1:]))
    for channel in range(n_channels):
        trace = y[:, channel]
        for first, last in segments:
            # +1 so neighbouring segments share an endpoint and the line is continuous
            ax.plot(time[first:last + 1], trace[first:last + 1],
                    lw=lw, ls=ls, color=colors[z[first] % len(colors)], alpha=1.0)
    return ax
# Sampled observation traces, coloured by the sampled discrete labels (new_y).
fig = plt.figure(figsize=(15, 6)) 
gs = gridspec.GridSpec(1, 2, width_ratios=[2, 3]) 


ax1 = plt.subplot(gs[0])
# NOTE(review): `[:1000, :1]` keeps only the first observation channel, and only
# the left grid cell is used.
plot_observations(new_y[:1000], new_x[:1000,:1], ax=ax1)
plt.title("Observations for first 1000 time steps")
plt.tight_layout()

In [None]:
# Show the loaded model's printed representation.
print(model_0)