In [None]:
# IPython autoreload extension. In mode 2, all modules (e.g. edited prnn
# source files) are reloaded before each cell executes, so code changes are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from prnn.utils.predictiveNet import PredictiveNet
from prnn.utils.agent import RandomActionAgent
from prnn.analysis.SpatialTuningAnalysis import SpatialTuningAnalysis
from prnn.analysis.representationalGeometryAnalysis import representationalGeometryAnalysis
from prnn.analysis.OfflineTrajectoryAnalysis import OfflineTrajectoryAnalysis
from prnn.analysis.TuningCurveAnalysis import TuningCurveAnalysis



In [None]:
# Folder name handed to the figure/analysis calls below (presumably where the
# figures are written). NOTE(review): it is a relative path, so it resolves
# against the working directory at save time, i.e. after the `%cd ..` cell.
savefolder = "BasicAnalysisFigs"

## Load your trained network

In [None]:
# Move the working directory up one level.
# NOTE(review): this is not idempotent. Re-running the cell climbs another
# level, which would change where relative paths used later (e.g. `savefolder`)
# resolve. Presumably it is meant to run exactly once, from the notebook's own
# directory -- confirm.
%cd ..

In [None]:
# Example network. PredictiveNet.loadNet is given "<netfolder><netname>--s<seed>",
# e.g. "/examplenet/Masked--s8".
# NOTE(review): netfolder starts with "/"; presumably loadNet joins it onto its
# own base directory rather than treating it as a filesystem root -- confirm.
netname = "Masked"
netfolder = "/examplenet/"
exseed = 8
predictiveNet = PredictiveNet.loadNet(f"{netfolder}{netname}--s{exseed}")

## Prediction of observation and global location

In [None]:
# Use the first environment the network was trained on.
env = predictiveNet.EnvLibrary[0]
agentname = 'RandomActionAgent'  # NOTE(review): not used by any cell below
# One probability per discrete action (7 entries, summing to 1.0).
# Presumably ordered like env.action_space -- confirm.
action_probability = np.array([0.15, 0.15, 0.6, 0.1, 0, 0, 0])
agent = RandomActionAgent(env.action_space, action_probability)

# NOTE(review): return values are presumably per-unit place fields, per-unit
# spatial information (SI) scores, and (because trainDecoder=True) a linear
# decoder from pRNN activity to position/observation -- confirm against
# PredictiveNet.calculateSpatialRepresentation. `decoder` is reused by the
# decoding and offline-trajectory cells below.
place_fields, SI, decoder = predictiveNet.calculateSpatialRepresentation(
    env,
    agent,
    trainDecoder=True,
)

Top row: the distribution of decoder error for the actual data compared with randomly shuffled data. Decoder error should be low for the actual data.

State: the actual action sequence across 6 consecutive timesteps.

Observation: the agent's egocentric view at the same 6 consecutive timesteps.

Predicted: the observation predicted for each timestep by the linear decoder returned in the cell above.

Bottom row: the global location predicted for each timestep by the linear decoder.

In [None]:
# Decoding-performance figure described in the markdown cell above, presumably
# saved as `netname` inside `savefolder`. See
# PredictiveNet.calculateDecodingPerformance for the exact meaning of
# trajectoryWindow and timesteps.
predictiveNet.calculateDecodingPerformance(
    env,
    agent,
    decoder,
    savename=netname,
    savefolder=savefolder,
    trajectoryWindow=5,
    timesteps=1000,
)

# Spatial Tuning Analysis

In [None]:
# Spatial tuning analysis of the loaded network. Judging by their names, the two
# flags enable an input-only control and an untrained-network control -- confirm.
STA = SpatialTuningAnalysis(
    predictiveNet,
    inputControl=True,
    untrainedControl=True,
)


In [None]:
# Example-figure from the spatial tuning analysis (presumably example tuning
# curves), named after the network and written to savefolder.
STA.TCExamplesFigure(netname, savefolder)

# Representational Geometry Analysis

In [None]:
# Representational geometry analysis. `sleepnoise` is passed as `noisestd` here
# and reused by both OfflineTrajectoryAnalysis cells below.
sleepnoise = 0.03
isomap_neighbors = 15  # passed as n_neighbors for the Isomap embedding
RGA = representationalGeometryAnalysis(
    predictiveNet,
    noisestd=sleepnoise,
    withIsomap=True,
    n_neighbors=isomap_neighbors,
)

In [None]:
# Wake-vs-sleep geometry figure, named after the network and written to savefolder.
RGA.WakeSleepFigure(netname, savefolder)

# Offline Trajectory Analysis

In [None]:

# Offline (spontaneous) trajectories with adaptation enabled. Uses the same noise
# level as the geometry analysis and the decoder trained earlier.
# NOTE(review): by name, b_adapt / tau_adapt are adaptation strength and time
# constant -- confirm in OfflineTrajectoryAnalysis.
b_adapt = 1
tau_adapt = 100
OTA_adapt = OfflineTrajectoryAnalysis(
    predictiveNet,
    noisestd=sleepnoise,
    withIsomap=False,
    decoder=decoder,
    withAdapt=True,
    b_adapt=b_adapt,
    tau_adapt=tau_adapt,
    compareWake=True,
)

In [None]:
# Spontaneous-trajectory figure for the adaptation run, over trajRange (150, 250).
# NOTE(review): the name 'adaptation' does not include netname, so runs for
# different nets would share a figure name -- confirm this is intended.
OTA_adapt.SpontTrajectoryFigure('adaptation', savefolder, trajRange=(150, 250))


In [None]:
# Offline trajectories driven by an action agent, compared with wake activity.
# NOTE(review): noisemag=0 is passed here but not in the adaptation cell;
# presumably it switches off injected noise even though noisestd is set --
# confirm in OfflineTrajectoryAnalysis.
OTA_query = OfflineTrajectoryAnalysis(
    predictiveNet,
    noisemag=0,
    noisestd=sleepnoise,
    withIsomap=False,
    decoder=decoder,
    actionAgent=True,
    compareWake=True,
)

In [None]:
# Spontaneous-trajectory figure for the action-query run, over trajRange (110, 150).
OTA_query.SpontTrajectoryFigure('actionquery', savefolder, trajRange=(110, 150))


# Tuning Curve Analysis

In [None]:
# Classify units by their tuning curves and plot the summary figure described below.
tuning_curve_analysis = TuningCurveAnalysis(predictiveNet)
# NOTE(review): unlike the other figure calls, this one gets no netname/savefolder,
# so the figure may not be saved alongside the others -- confirm the signature.
tuning_curve_analysis.cellClassificationFigure()

The upper-left panel shows the percentage of network units in each class. The middle panel shows a PCA embedding of the features used to classify each unit, colored by group ID. The right panel shows EV versus SI for each unit, again colored by cell type. Below are example tuning curves for each cell class. The cell types are:
- untuned
- HD_cells: head direction cells, units with a preference for head direction but no spatial preference 
- single_field: a canonical "place cell" with a single centralized, symmetric place field
- border_cells: units that fire preferentially along environmental boundaries
- spatial_HD: some combination of spatial and head direction preferences
- complex_cells: units with high SI that do not fit any of the types above
- dead: units that have no activity 