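# Remapping analysis

This notebook compares models of different hidden size. For each model it shows which units are place cells in environment 1 versus environment 2. It also gives a remapping score based on how the pairwise place-field distances change between the two environments.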

In [1]:
import os
import sys
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import pearsonr

# Add the parent of the current working directory to sys.path so the local modules below import.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from experiment import Experiment
from utils import print_stats


# Relative locations of input data and saved figures.
# NOTE(review): neither name is referenced in the cells shown; load_experiment is called with the literal 'data'.
data_path, save_path = "../data", "../figs"

In [2]:
# NOTE(review): this `exp` is reassigned by the hidden-size loop below before it is read,
# and the baseline is reloaded again further down, so this load looks redundant.
exp = Experiment.load_experiment('data', 'baseline')

In [68]:
def get_sets(exp):
    """Collect place-cell indices and fitness for environment 1 and environment 2.

    For each env id (1, then 2) this calls ``exp.compile_grid_cells(env)`` and
    ``exp.load_pfs()``, then reads ``exp.pfs.calc_fitness()`` and
    ``exp.pfs.place_cell_idx()``. Side effect: ``exp`` is left configured for env 2.

    Returns:
        (set of env-1 place-cell indices, set of env-2 place-cell indices,
         env-1 fitness, env-2 fitness)
    """
    fitness, indices = [], []
    for env in (1, 2):
        exp.compile_grid_cells(env)
        exp.load_pfs()
        fitness.append(exp.pfs.calc_fitness())
        indices.append(exp.pfs.place_cell_idx())
    # Convert to Python sets only after both envs are loaded, as the tensors are
    # read back in the same order as before.
    idx_env1, idx_env2 = (set(idx.cpu().numpy()) for idx in indices)
    return idx_env1, idx_env2, fitness[0], fitness[1]

def get_active(exp, threshold=0.001):
    """Return ``(mask_env1, mask_env2)`` where each mask is ``scales >= threshold``.

    ``scales`` is read from ``exp.pfs_per_env[1]`` and ``exp.pfs_per_env[2]``.
    Presumably these are torch tensors; the notebook combines the masks with
    ``torch.logical_and`` / ``torch.logical_or``.
    """
    env1, env2 = (exp.pfs_per_env[k] for k in (1, 2))
    return env1.scales >= threshold, env2.scales >= threshold

def _pairwise_field_distances(means, pcs):
    """Euclidean distance between ``means[i]`` and ``means[j]`` for every row ``(i, j)`` of ``pcs``."""
    idx = torch.as_tensor(pcs, dtype=torch.long, device=means.device)
    return torch.pow(means[idx[:, 0]] - means[idx[:, 1]], 2).sum(-1).sqrt()


def _remapping_score(exp, s1, s2):
    """Remapping score ``1 - max(0, r)`` for place-cell sets ``s1`` (env 1) and ``s2`` (env 2).

    ``r`` is the Pearson correlation between the pairwise distances of
    ``exp.pfs_per_env[1].means`` and of ``exp.pfs_per_env[2].means``. It is taken
    over all pairs in ``product(s1, s2)`` with self-pairs removed.

    Returns ``nan`` in two cases: when fewer than two pairs exist, or when ``r``
    is undefined (e.g. all distances are equal).
    """
    # The reshape keeps a (0, 2) array when either set is empty, so the column
    # indexing below cannot raise.
    pcs = np.asarray(list(product(s1, s2)), dtype=np.int64).reshape(-1, 2)
    pcs = pcs[pcs[:, 0] != pcs[:, 1]]
    # NOTE(review): these are ordered pairs. Two cells that are both in s1 and s2
    # appear as (a, b) and (b, a) with identical distances, so they are counted twice.
    if len(pcs) < 2:
        return float("nan")  # pearsonr needs at least two points

    with torch.no_grad():
        d1 = _pairwise_field_distances(exp.pfs_per_env[1].means, pcs)
        d2 = _pairwise_field_distances(exp.pfs_per_env[2].means, pcs)
    r = pearsonr(d1.cpu().numpy(), d2.cpu().numpy()).statistic
    # np.maximum propagates nan. The builtin max(0, nan) returns 0, which would
    # report an undefined correlation as complete remapping (1.0).
    return float(1 - np.maximum(0.0, r))


def print_sets(exp):
    """Print the overlap of the env-1 and env-2 place-cell sets of ``exp`` and its remapping score.

    Calls ``get_sets(exp)``, which recompiles grid cells and reloads place fields
    for env 1 and then env 2, so ``exp`` is left configured for env 2.
    The score is ``1 - max(0, r)``, where ``r`` is the Pearson correlation between
    pairwise place-field distances in the two envs. It prints ``nan`` when ``r``
    cannot be computed.
    """
    s1, s2, *_ = get_sets(exp)
    remap = _remapping_score(exp, s1, s2)

    print("len env1:", len(s1))
    print("len env2:", len(s2))
    print("intersection:", len(s1.intersection(s2)))
    print("union:", len(s1.union(s2)))
    print("s1 - s2:", len(s1 - s2))
    print("s2 - s1:", len(s2 - s1))
    print("Remapping:", remap)

In [69]:
# Model variants to compare. 'baseline' sits between hidden50 and hidden200
# (presumably its hidden size lies in that range; confirm).
hiddens = [f'hidden{n}' for n in (20, 50)] + ['baseline'] + [f'hidden{n}' for n in (200, 500)]

# NOTE(review): the saved run raised ValueError for hidden500 inside print_sets;
# the traceback source is not in this notebook.
for h in hiddens:
    exp = Experiment.load_experiment('data', h)
    print(f'{h}:')
    print_sets(exp)
    print()

hidden20:
len env1: 5
len env2: 10
intersection: 2
union: 13
s1 - s2: 3
s2 - s1: 8
Remapping: 1.3242740631103516

hidden50:
len env1: 21
len env2: 26
intersection: 10
union: 37
s1 - s2: 11
s2 - s1: 16
Remapping: 0.9539285451173782

baseline:
len env1: 57
len env2: 64
intersection: 38
union: 83
s1 - s2: 19
s2 - s1: 26
Remapping: 0.8008994460105896

hidden200:
len env1: 63
len env2: 81
intersection: 27
union: 117
s1 - s2: 36
s2 - s1: 54
Remapping: 0.8698728382587433

hidden500:


ValueError: NumPy boolean array indexing assignment cannot assign 91 input values to the 211 output values where the mask is true

In [11]:
# Per model: how many place fields are active (get_active's default threshold)
# in each env, and how many are active in both / in either.
for h in hiddens:
    exp = Experiment.load_experiment('data', h)
    act1, act2 = get_active(exp)
    print(f'{h}:')
    print('env1:', act1.sum().cpu().item(), '| env2:', act2.sum().cpu().item())
    print("intersection of active", (act1 & act2).sum().cpu().item())
    print("union of active", (act1 | act2).sum().cpu().item())
    print()

hidden20:
env1: 18 | env2: 18
intersection of active 18
union of active 18

hidden50:
env1: 39 | env2: 42
intersection of active 38
union of active 43

baseline:
env1: 89 | env2: 96
intersection of active 85
union of active 100

hidden200:
env1: 98 | env2: 92
intersection of active 67
union of active 123

hidden500:
env1: 166 | env2: 171
intersection of active 97
union of active 240



In [26]:
# Reload the baseline model: the loops above left `exp` bound to the last entry of `hiddens`.
exp = Experiment.load_experiment('data', 'baseline')

In [27]:
# Place-cell index sets (s1: env 1, s2: env 2) and fitness values; get_sets leaves `exp` configured for env 2.
s1, s2, f1, f2 = get_sets(exp)

In [18]:
# Masks of place fields with scale >= get_active's default threshold (0.001) in env 1 / env 2.
act1, act2 = get_active(exp)

In [64]:
# Remapping score for the baseline sets s1 / s2 is 1 - max(0, r), where r is the Pearson
# correlation between pairwise place-field distances in env 1 and in env 2.
# It is clipped at 0 as in print_sets so it stays in [0, 1]; the unclipped form can exceed 1.
pcs = np.asarray(list(product(s1, s2)), dtype=np.int64).reshape(-1, 2)  # (0, 2) if a set is empty
pcs = pcs[pcs[:, 0] != pcs[:, 1]]  # drop self-pairs

with torch.no_grad():
    d1 = torch.pow(exp.pfs_per_env[1].means[pcs[:, 0]] - exp.pfs_per_env[1].means[pcs[:, 1]], 2).sum(-1).sqrt()
    d2 = torch.pow(exp.pfs_per_env[2].means[pcs[:, 0]] - exp.pfs_per_env[2].means[pcs[:, 1]], 2).sum(-1).sqrt()

if len(pcs) >= 2:
    # np.maximum propagates nan; the builtin max(0, nan) would silently give 0.
    remap = float(1 - np.maximum(0.0, pearsonr(d1.cpu().numpy(), d2.cpu().numpy()).statistic))
else:
    remap = float("nan")  # pearsonr needs at least two points
remap