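# Remapping analysis

For each saved experiment (`hidden20`, `hidden50`, `baseline`, `hidden200`, `hidden500`), this notebook does two comparisons between environment 1 and environment 2. First, it compares the cells classified as place cells. Second, it compares the cells classified as active cells. Remapping is scored as $1 - \max(0, r)$, where $r$ is the Pearson correlation between the pairwise cell distances (`pairwise_distances`) computed in the two environments.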

In [1]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt

# Make the parent directory of the current working directory (the project root,
# when the notebook is launched from its own folder) importable, without adding
# it twice on re-runs.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from experiment import Experiment
from utils import print_stats
from itertools import product
from scipy.stats import pearsonr


# Input data and output figure directories, relative to the working directory.
data_path, save_path = "../data", "../figs"

In [2]:
class Analysis:
    """Compare cell populations of one experiment across environments.

    Only these members of ``exp`` are used: ``pfs_per_env`` (mapping
    env -> object exposing ``get_active_cells()`` and ``pairwise_distances(pairs)``),
    ``compile_grid_cells(env)``, ``load_pfs()`` and ``pfs`` (exposing
    ``get_place_cells()``).
    """

    def __init__(self, exp, immediate_pc=False):
        """Store ``exp`` and compute the active cells of every environment.

        Args:
            exp: the loaded experiment to analyse.
            immediate_pc: if True, also compute place cells for every environment
                now; otherwise they are computed on first use by
                ``get_remapping`` / ``place_cell_stats``.
        """
        self.exp = exp
        self.find_active_cells()
        if immediate_pc:
            self.find_place_cells()

    def find_active_cells(self):
        """Set ``self.active_per_env``: env -> set of active cell ids."""
        self.active_per_env = {env: set(pfs.get_active_cells().cpu().numpy())
                               for env, pfs in self.exp.pfs_per_env.items()}

    def find_place_cells(self):
        """Set ``self.place_cells_per_env``: env -> set of place-cell ids.

        Side effect: calls ``self.exp.compile_grid_cells(env)`` and
        ``self.exp.load_pfs()`` for each environment in turn (presumably leaving
        ``self.exp`` configured for the last environment iterated).
        """
        self.place_cells_per_env = dict()
        for env in self.exp.pfs_per_env.keys():
            # Always use self.exp: a notebook-global ``exp`` may be a different
            # experiment, or may not exist at all.
            self.exp.compile_grid_cells(env)
            self.exp.load_pfs()
            self.place_cells_per_env[env] = set(self.exp.pfs.get_place_cells().cpu().numpy())

    def _ensure_place_cells(self):
        """Compute place cells if the constructor was called with ``immediate_pc=False``."""
        if not hasattr(self, 'place_cells_per_env'):
            self.find_place_cells()

    def get_remapping(self, env1, env2):
        """Remapping score between ``env1`` and ``env2``.

        Builds every ordered pair ``(a, b)`` with ``a`` a place cell of ``env1``,
        ``b`` a place cell of ``env2`` and ``a != b``. It asks each environment for
        the distance of every pair and returns ``1 - max(0, r)``, where ``r`` is
        the Pearson correlation between the two distance vectors. A score of 0
        means the pairwise geometry is perfectly (positively) correlated. A score
        of 1 means it is uncorrelated or anti-correlated.

        Returns:
            float in [0, 1]. A NaN correlation (e.g. constant distances) counts
            as no correlation and gives 1.0. Returns NaN when fewer than two pairs
            exist, because a correlation is undefined there.
        """
        self._ensure_place_cells()
        pc1, pc2 = self.place_cells_per_env[env1], self.place_cells_per_env[env2]
        # reshape keeps the array 2-D even when the list is empty.
        pairs = np.asarray([(a, b) for a, b in product(pc1, pc2) if a != b]).reshape(-1, 2)
        if len(pairs) < 2:
            return float('nan')

        with torch.no_grad():
            # NOTE(review): assumes pairwise_distances returns a 1-D tensor with one
            # distance per row of ``pairs`` — TODO confirm.
            d1 = self.exp.pfs_per_env[env1].pairwise_distances(pairs).cpu().numpy()
            d2 = self.exp.pfs_per_env[env2].pairwise_distances(pairs).cpu().numpy()
        r = float(pearsonr(d1, d2).statistic)
        # pearsonr yields NaN for constant input; count that as no correlation.
        if np.isnan(r):
            r = 0.0
        return 1.0 - max(0.0, r)

    def place_cell_stats(self, env1=1, env2=2):
        """Print place-cell set sizes, overlaps and the remapping score for two environments.

        The defaults assume the environments are keyed ``1`` and ``2`` in
        ``exp.pfs_per_env``.
        """
        self._ensure_place_cells()
        pc1, pc2 = self.place_cells_per_env[env1], self.place_cells_per_env[env2]

        print("Len env1:", len(pc1))
        print("Len env2:", len(pc2))
        print("Intersection:", len(pc1.intersection(pc2)))
        print("Union:", len(pc1.union(pc2)))
        print("env1 - env2:", len(pc1 - pc2))
        print("env2 - env1:", len(pc2 - pc1))
        print("Remapping:", self.get_remapping(env1, env2))

In [3]:
# Experiments to compare, in the order they are reported below.
hiddens = ['hidden20', 'hidden50', 'baseline', 'hidden200', 'hidden500']

for model_name in hiddens:
    # Bound to the global name ``exp`` on purpose: code elsewhere in this
    # notebook may read that global.
    exp = Experiment.load_experiment(data_path, model_name)
    analysis = Analysis(exp, immediate_pc=True)

    print(f'{model_name}:')
    analysis.place_cell_stats()
    print()

hidden20:
Len env1: 5
Len env2: 10
Intersection: 2
Union: 13
env1 - env2: 3
env2 - env1: 8
Remapping: 1

hidden50:
Len env1: 21
Len env2: 26
Intersection: 10
Union: 37
env1 - env2: 11
env2 - env1: 16
Remapping: 0.9539285451173782

baseline:
Len env1: 57
Len env2: 64
Intersection: 38
Union: 83
env1 - env2: 19
env2 - env1: 26
Remapping: 0.8008994460105896

hidden200:
Len env1: 63
Len env2: 81
Intersection: 27
Union: 117
env1 - env2: 36
env2 - env1: 54
Remapping: 0.8698728382587433



ValueError: NumPy boolean array indexing assignment cannot assign 91 input values to the 211 output values where the mask is true

In [4]:
# Overlap of active cells between the two environments, per experiment.
for h in hiddens:
    exp = Experiment.load_experiment(data_path, h)
    anl = Analysis(exp)
    # Exactly two environments are expected; unpacking raises ValueError otherwise.
    act1, act2 = anl.active_per_env.values()
    shared = act1 & act2
    either = act1 | act2
    print(h + ':')
    print('Len env1:', len(act1), '| Len env2:', len(act2))
    print("Intersection:", len(shared))
    print("Union:", len(either))
    # IoU is undefined (not a ZeroDivisionError) when neither environment has active cells.
    print("IoU:", len(shared) / len(either) if either else float('nan'))
    print()

hidden20:
Len env1: 18 | Len env2: 18
Intersection: 18
Union: 18
IoU: 1.0

hidden50:
Len env1: 39 | Len env2: 42
Intersection: 38
Union: 43
IoU: 0.8837209302325582

baseline:
Len env1: 89 | Len env2: 96
Intersection: 85
Union: 100
IoU: 0.85

hidden200:
Len env1: 98 | Len env2: 92
Intersection: 67
Union: 123
IoU: 0.5447154471544715

hidden500:
Len env1: 166 | Len env2: 171
Intersection: 97
Union: 240
IoU: 0.4041666666666667

