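# Remapping analysis

Compares place-cell sets, active units, remapping and turnover between environments 1 and 2 across runs that vary the hidden-layer size (`hidden*`) and an L1 setting (`l1_*`), each against the `baseline` run.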

In [5]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt

# Make the parent of the notebook's working directory importable, so the
# project-local `experiment` and `utils` modules imported below resolve.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from experiment import Experiment
from utils import print_stats
from itertools import product
from scipy.stats import pearsonr
from gc import collect


# Data and figure directories, given relative to the notebook's working directory.
data_path, save_path = "../data", "../figs"

In [2]:
# Load the baseline run from the configured `data_path` (set in the config cell)
# so that every cell in this notebook reads from the same directory.
exp = Experiment.load_experiment(data_path, 'baseline')

In [68]:
def get_sets(exp):
    """Collect place-cell index sets and fitness values for environments 1 and 2.

    For each environment in turn, grid cells are compiled and place fields are
    reloaded on ``exp`` (so ``exp`` is mutated). Then ``exp.pfs.calc_fitness()``
    and ``exp.pfs.place_cell_idx()`` are read.

    Returns:
        ``(set_env1, set_env2, fitness_env1, fitness_env2)``. The sets hold the
        elements of the ``place_cell_idx`` tensors.
    """
    raw_indices = []
    fitnesses = []
    for env in (1, 2):
        exp.compile_grid_cells(env)
        exp.load_pfs()
        fitnesses.append(exp.pfs.calc_fitness())
        raw_indices.append(exp.pfs.place_cell_idx())
    # Convert to Python sets only after both environments have been processed.
    set_env1, set_env2 = (set(idx.cpu().numpy()) for idx in raw_indices)
    return set_env1, set_env2, fitnesses[0], fitnesses[1]

def get_active(exp, threshold=0.001):
    """Return per-environment masks of units whose place-field scale is >= ``threshold``.

    Args:
        exp: object exposing ``pfs_per_env[1].scales`` and ``pfs_per_env[2].scales``.
        threshold: minimum scale for a unit to count as active.

    Returns:
        A 2-tuple ``(mask_env1, mask_env2)`` of elementwise ``scales >= threshold``.
    """
    return tuple(exp.pfs_per_env[env].scales >= threshold for env in (1, 2))

def print_sets(exp):
    """Print place-cell set statistics and a remapping score for env 1 vs env 2.

    The place-cell index sets ``s1`` (env 1) and ``s2`` (env 2) come from
    ``get_sets(exp)``, which recompiles grid cells and reloads place fields on
    ``exp``.

    Remapping score: over all index pairs ``(i, j)`` with ``i`` in ``s1``,
    ``j`` in ``s2`` and ``i != j``, compute the Euclidean distance between the
    place-field means of ``i`` and ``j`` in env 1, and the same distance in
    env 2. With ``r`` the Pearson correlation of those two distance lists, the
    score is ``1 - max(0, r)``. It is reported as ``nan`` when ``r`` is
    undefined (fewer than two pairs, or a non-finite correlation such as from
    constant distances).

    Args:
        exp: experiment object exposing ``pfs_per_env[1].means`` and
            ``pfs_per_env[2].means``, indexable by unit index.
    """
    s1, s2, *_ = get_sets(exp)
    # reshape(-1, 2) keeps the array 2-D, shape (0, 2), when either set is
    # empty; a bare np.asarray([]) is 1-D and pcs[:, 0] would raise IndexError.
    pcs = np.asarray(list(product(s1, s2)), dtype=np.int64).reshape(-1, 2)
    pcs = pcs[pcs[:, 0] != pcs[:, 1]]

    remap = float('nan')
    if len(pcs) >= 2:  # pearsonr needs at least two samples
        with torch.no_grad():
            dists = []
            for env in (1, 2):
                means = exp.pfs_per_env[env].means
                d = torch.pow(means[pcs[:, 0]] - means[pcs[:, 1]], 2).sum(-1).sqrt()
                # Hand scipy a plain numpy array rather than a torch tensor.
                dists.append(d.cpu().numpy())
        r = pearsonr(dists[0], dists[1]).statistic
        # Previously a NaN r silently became remap == 1 via max(0, nan);
        # an undefined correlation is now reported as nan instead.
        if np.isfinite(r):
            remap = 1 - max(0.0, float(r))

    print("len env1:", len(s1))
    print("len env2:", len(s2))
    print("intersection:", len(s1.intersection(s2)))
    print("union:", len(s1.union(s2)))
    print("s1 - s2:", len(s1 - s2))
    print("s2 - s1:", len(s2 - s1))
    print("Remapping:", remap)

In [69]:
# Runs that differ in hidden-layer size; 'baseline' is slotted in between
# hidden50 and hidden200 (presumably its size lies in that range — TODO confirm).
hiddens = [f'hidden{n}' for n in [20, 50, 200, 500]]
hiddens.insert(2, 'baseline')

# NOTE(review): the last recorded run of this cell stopped at hidden500 with a
# ValueError (boolean-mask assignment size mismatch) raised while building the
# place-cell sets — not yet investigated.
for h in hiddens:
    # Load from the configured data_path, consistent with the other loading cells.
    exp = Experiment.load_experiment(data_path, h)

    print(h + ':')
    print_sets(exp)
    print()

hidden20:
len env1: 5
len env2: 10
intersection: 2
union: 13
s1 - s2: 3
s2 - s1: 8
Remapping: 1.3242740631103516

hidden50:
len env1: 21
len env2: 26
intersection: 10
union: 37
s1 - s2: 11
s2 - s1: 16
Remapping: 0.9539285451173782

baseline:
len env1: 57
len env2: 64
intersection: 38
union: 83
s1 - s2: 19
s2 - s1: 26
Remapping: 0.8008994460105896

hidden200:
len env1: 63
len env2: 81
intersection: 27
union: 117
s1 - s2: 36
s2 - s1: 54
Remapping: 0.8698728382587433

hidden500:


ValueError: NumPy boolean array indexing assignment cannot assign 91 input values to the 211 output values where the mask is true

In [11]:
# For each hidden-size run, count units active (scale >= threshold) in each
# environment, and their overlap across environments.
for h in hiddens:
    # Load from the configured data_path, consistent with the other loading cells.
    exp = Experiment.load_experiment(data_path, h)
    act1, act2 = get_active(exp)
    print(h + ':')
    # .item() copies the scalar to host regardless of device, so .cpu() is not needed.
    print('env1:', act1.sum().item(), '| env2:', act2.sum().item())
    print("intersection of active", torch.logical_and(act1, act2).sum().item())
    print("union of active", torch.logical_or(act1, act2).sum().item())
    print()

hidden20:
env1: 18 | env2: 18
intersection of active 18
union of active 18

hidden50:
env1: 39 | env2: 42
intersection of active 38
union of active 43

baseline:
env1: 89 | env2: 96
intersection of active 85
union of active 100

hidden200:
env1: 98 | env2: 92
intersection of active 67
union of active 123

hidden500:
env1: 166 | env2: 171
intersection of active 97
union of active 240



In [3]:
# Run names for the L1 sweep, in the order analysed below, with 'baseline'
# placed between 'l1_0001' and 'l1_00001'.
lr1s = [
    'l1_01',
    'l1_001',
    'l1_0001',
    'baseline',
    'l1_00001',
    'l1_000001',
    'l1_0',
]

In [8]:
# NOTE(review): `Analysis` is not imported in any visible cell, so this cell
# raises NameError on a fresh kernel — add its import to the imports cell.
for run_name in lr1s:
    # Reclaim the previous iteration's objects before loading the next run.
    collect()
    exp = Experiment.load_experiment(data_path, run_name)
    anl = Analysis(exp, immediate_pc=True)

    print(f"{run_name}:")
    anl.place_cell_stats()
    print()
    del anl, exp

l1_01:
Len env1: 62
Len env2: 55
Intersection: 33
Union: 84
env1 - env2: 29
env2 - env1: 22
Remapping: 0.890890508890152
Turnover: 0.8028039107175797

l1_001:
Len env1: 52
Len env2: 65
Intersection: 35
Union: 82
env1 - env2: 17
env2 - env1: 30
Remapping: 0.7286943793296814
Turnover: 0.7398388981122794

l1_0001:
Len env1: 34
Len env2: 45
Intersection: 20
Union: 59
env1 - env2: 14
env2 - env1: 25
Remapping: 0.6971185803413391
Turnover: 0.7140098691267969

baseline:
Len env1: 57
Len env2: 64
Intersection: 38
Union: 83
env1 - env2: 19
env2 - env1: 26
Remapping: 0.8008994460105896
Turnover: 0.7052341597796143

l1_00001:
Len env1: 46
Len env2: 59
Intersection: 28
Union: 77
env1 - env2: 18
env2 - env1: 31
Remapping: 0.6604875922203064
Turnover: 0.7911699779249448

l1_000001:
Len env1: 41
Len env2: 69
Intersection: 24
Union: 86
env1 - env2: 17
env2 - env1: 45
Remapping: 0.7910145670175552
Turnover: 0.988293897882939

l1_0:
Len env1: 46
Len env2: 36
Intersection: 17
Union: 65
env1 - env2: 29
en

In [9]:
# NOTE(review): `Analysis` is not imported in any visible cell — add it to the
# imports cell so this runs on a fresh kernel.
for lr in lr1s:
    # Free the previous run before loading the next, as the sibling sweep cell does.
    collect()
    exp = Experiment.load_experiment(data_path, lr)
    anl = Analysis(exp)
    # The two per-environment active-unit sets; unpacking fails unless there are exactly two.
    act1, act2 = anl.active_per_env.values()
    n_inter = len(act1.intersection(act2))
    n_union = len(act1.union(act2))

    print(lr + ':')
    print('Len env1:', len(act1), '| Len env2:', len(act2))
    print("Intersection:", n_inter)
    print("Union:", n_union)
    # Guard the ratio: with no active units in either env, IoU is undefined.
    print("IoU:", n_inter / n_union if n_union else float('nan'))
    print("Turnover:", anl.get_turnover(1, 2, all_active=True))
    print()
    del anl, exp

l1_01:
Len env1: 84 | Len env2: 84
Intersection: 73
Union: 95
IoU: 0.7684210526315789
Turnover: 0.38095238095238093

l1_001:
Len env1: 84 | Len env2: 84
Intersection: 75
Union: 93
IoU: 0.8064516129032258
Turnover: 0.3116883116883117

l1_0001:
Len env1: 77 | Len env2: 74
Intersection: 61
Union: 90
IoU: 0.6777777777777778
Turnover: 0.4682434563229265

baseline:
Len env1: 89 | Len env2: 96
Intersection: 85
Union: 100
IoU: 0.85
Turnover: 0.2923486867148839

l1_00001:
Len env1: 79 | Len env2: 81
Intersection: 66
Union: 94
IoU: 0.7021276595744681
Turnover: 0.4666666666666668

l1_000001:
Len env1: 82 | Len env2: 87
Intersection: 74
Union: 95
IoU: 0.7789473684210526
Turnover: 0.3656396653744134

l1_0:
Len env1: 79 | Len env2: 61
Intersection: 52
Union: 88
IoU: 0.5909090909090909
Turnover: 0.5674876847290641

