# 2.1.2 Stacked LSTMs on Fractals

PredNet representations match those measured with human fMRI. Do stacked LSTMs match them as well?

## Jupyter Extensions

Load [watermark](https://github.com/rasbt/watermark) to see the state of the machine and environment that's running the notebook. To make sense of the options, take a look at the [usage](https://github.com/rasbt/watermark#usage) section of the readme.

In [1]:
# Load the `watermark` extension
%load_ext watermark
# Record the run date (-n) and time (-t), machine info (-m), git hash (-g),
# git branch (-b) and hostname (-h) so results can be traced to an environment
%watermark -n -m -g -b -t -h

Sun Sep 06 2020 17:31:06 

compiler   : GCC 7.3.0
system     : Linux
release    : 4.15.0-112-generic
machine    : x86_64
processor  : x86_64
CPU cores  : 16
interpreter: 64bit
host name  : serrep5
Git hash   : 559e267c9765a31f726bd422b7937cc37ebde795
Git branch : master


Load [autoreload](https://ipython.org/ipython-doc/3/config/extensions/autoreload.html) and set it to mode 1, which reloads the modules marked with `%aimport` every time before code is executed.

Running `%autoreload 2` instead reloads *all* modules before executing code, except those explicitly excluded with `%aimport -module_name`.

In [2]:
# Load `autoreload` extension
%load_ext autoreload
# Mode 1: before executing code, reload only the modules registered with
# `%aimport` (the prevseg modules in the imports cell)
%autoreload 1

Load `matplotlib` in one of the more `jupyter`-friendly [rich-output modes](https://ipython.readthedocs.io/en/stable/interactive/plotting.html). Options include `inline`, `notebook`, and `gtk`, though not every backend works in every environment.

In [3]:
# Set the matplotlib mode: render figures as static images below each cell
%matplotlib inline

## Imports

In [4]:
import copy
import gc
import logging
from argparse import Namespace
from pathlib import Path

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch
import pytorch_lightning as pl
from PIL import Image, ImageOps
from torch.utils.data import IterableDataset, DataLoader

%aimport prevseg.constants
import prevseg.constants as const
%aimport prevseg.index
import prevseg.index as index
%aimport prevseg.dataloaders.schapiro
import prevseg.dataloaders.schapiro as sch
%aimport prevseg.schapiro
from prevseg.schapiro import walk, graph
%aimport prevseg.models.prednet
import prevseg.models.prednet as prednet

# Keep track of versions of everything: interpreter (-v) and imported packages (-iv)
%watermark -v -iv

logging           0.5.1.2
numpy             1.19.1
pytorch_lightning 0.8.5
torch             1.6.0
PIL.Image         7.2.0
prevseg           0+untagged.75.g559e267.dirty
networkx          2.4
CPython 3.8.5
IPython 7.16.1


## Testing the Embedding Dataset


In [7]:
# Smoke-test the ResNet embedding dataset with a tiny configuration.
max_steps = 2
epochs = 2
batch_size = 2
n_paths = 2

iter_ds = sch.ShapiroResnetEmbeddingDataset(
    batch_size=batch_size,
    max_steps=max_steps,
    n_paths=n_paths,
)
# batch_size=None disables DataLoader's automatic batching: the dataset is
# given its own batch_size and yields ready-made batches.
loader = DataLoader(iter_ds, batch_size=None)

for epoch in range(epochs):
    n_batches = 0
    for i, batch in enumerate(loader):
        print(i, batch[0].shape, batch[1])
        n_batches += 1
        if i == n_paths:
            # Stop instead of spinning forever if the iterable never ends.
            break
    # Each epoch should yield at most `n_paths` batches; fail loudly otherwise
    # so the problem is not lost in a Run-All.
    assert n_batches <= n_paths, (
        f'epoch {epoch}: loader yielded more than n_paths={n_paths} batches')

Created mapping as follows:
{0: '64', 1: '89', 2: '86', 3: '51', 4: '56', 5: '83', 6: '12', 7: '38', 8: '8', 9: '92', 10: '65', 11: '70', 12: '90', 13: '44', 14: '74'}
0 torch.Size([2, 2, 2048]) [[9, 0], [10, 3]]
1 torch.Size([2, 2, 2048]) [[1, 1], [0, 2]]
0 torch.Size([2, 2, 2048]) [[13, 10], [11, 9]]
1 torch.Size([2, 2, 2048]) [[7, 12], [8, 11]]


## Defining the Stacked LSTM

In [11]:
class LSTMStacked(prednet.PredNetTrackedSchapiro):
    """Stacked-LSTM model for the Schapiro-graph experiments.

    Currently a placeholder: it inherits everything from
    ``prednet.PredNetTrackedSchapiro`` unchanged.
    """
    # NOTE(review): `ModelClass.name` is used to build the log/checkpoint
    # names; unless the parent defines `name` per subclass, this class
    # inherits the parent's value and its runs share the PredNet naming.
    # Consider overriding `name` here — TODO confirm.

## Loading Saved Weights (TODO: no checkpoint is loaded yet; the model below is freshly initialized)

In [None]:
# Drop references from any previous run so memory (incl. GPU memory) can be freed.
model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
ckpt = None
train_errors, val_errors = None, None
res = None
gc.collect()
torch.cuda.empty_cache()

ModelClass = LSTMStacked
# Work on a copy: assigning into `const.DEFAULT_HPARAMS` directly would leak
# these overrides into every other user of the module-level defaults and
# make re-running this cell order-dependent.
hparams = copy.copy(const.DEFAULT_HPARAMS)
hparams.n_layers = 2
hparams.batch_size = 256 + 128
hparams.max_steps = 128
hparams.n_paths = 16
hparams.n_pentagons = 3
hparams.time_steps = hparams.max_steps
hparams.exp_name = 'schapiro_test'
hparams.name = f'{ModelClass.name}_{hparams.exp_name}'
hparams.debug = False
hparams.n_workers = 2
hparams.lr = 0.001

log_dir = Path(hparams.dir_logs) / hparams.name
log_dir.mkdir(parents=True, exist_ok=True)
logger = pl.loggers.TensorBoardLogger(str(log_dir.parent), name=hparams.name)

ckpt_dir = Path(hparams.dir_checkpoints) / f'{hparams.name}_v{logger.version}'
ckpt_dir.mkdir(parents=True, exist_ok=True)

# Keep only the single best checkpoint; the `{...}` fields in the filename
# template are filled in by Lightning when saving.
ckpt = pl.callbacks.ModelCheckpoint(
    filepath=str(ckpt_dir / (hparams.exp_name + '_{global_step:05d}_{epoch:03d}_{val_loss:.3f}')),
    verbose=True,
    save_top_k=1,
)

trainer = pl.Trainer(checkpoint_callback=ckpt,
                     max_epochs=20,
                     logger=logger,
                     # Use one GPU when available; fall back to CPU (0) instead of failing.
                     gpus=int(torch.cuda.is_available()),
                     )

model = ModelClass(hparams)
# NOTE(review): this clears the model's dataset, yet the evaluation cell reads
# `model.ds.mapping`; confirm whether this line is still intended.
model.ds = None

In [None]:
# Build a single evaluation walk (mode='euclidean'). Reuse the model's node
# mapping when the model still holds a dataset, so node indices refer to the
# same stimuli; otherwise the dataset creates a new mapping of its own.
ds_kwargs = {}
if getattr(model, 'ds', None) is not None:
    ds_kwargs['mapping'] = model.ds.mapping
else:
    logging.warning('model.ds is not set; the evaluation dataset will create a new node mapping')

iter_ds = sch.ShapiroResnetEmbeddingDataset(
    batch_size=1,
    max_steps=hparams.max_steps,
    n_paths=1,
    mode='euclidean',
    **ds_kwargs,
)
loader = DataLoader(iter_ds, batch_size=None)

# Expected to yield a single batch (n_paths=1); keep the last one seen.
for data, nodes in loader:
    pass

In [None]:
# Shape of the evaluation batch (presumably (batch, steps, embedding_dim) — TODO confirm)
data.shape

In [None]:
# Forward walk followed by the same walk in reverse, along the time axis (dim 1).
# The first reversed step duplicates the turning point, so it is dropped.
# Only dim 1 is flipped: flipping dim 0 too would append each sample's
# reversal to a *different* sample whenever the batch has more than one walk.
data_all = torch.cat((data, torch.flip(data, (1,))[:, 1:, :]), 1)
assert data_all.shape[1] == 2 * data.shape[1] - 1
data_all.shape

In [None]:
# Evaluation only: disable autograd so the returned tensors carry no graph
# (saves memory and lets the plotting cells convert them to NumPy directly).
with torch.no_grad():
    outs = model.forward(data_all, output_mode='eval', run_num='fwd_rev',
                         tb_labels=['nodes'])

In [None]:
# Flatten the node labels of the single walk to 1-D. reshape(-1) follows the
# actual walk length (hparams.max_steps) instead of hard-coding 30 steps.
nodes = np.array(nodes).reshape(-1)
nodes

In [None]:
# Node labels for the forward + reversed walk, skipping the repeated
# turning-point node so they line up step-for-step with `data_all`.
nodes_all = np.concatenate([nodes, nodes[::-1][1:]])
nodes_all.shape

In [None]:
# Print every step index next to the node visited at that step.
for step, node in enumerate(nodes_all):
    print(f'{step} {node}')

In [None]:
# Hand-picked step indices into `nodes_all`, drawn as dotted lines in the plots
# below. They presumably mark community transitions and the forward/reverse
# turning point, read off the printed step/node list.
# NOTE(review): only valid for one specific sampled walk; re-derive whenever
# the walk or `hparams.max_steps` changes.
borders = [9, 19, 29, 30, 40, 50]

In [None]:
# Draw the Schapiro graph with the pentagon count from the config cell, so the
# picture stays in sync with the experiment settings.
G = graph.schapiro_graph(n_pentagons=hparams.n_pentagons)
fig, ax = plt.subplots()
nx.draw(G, ax=ax, with_labels=True, font_weight='bold')
ax.set_title(f'Schapiro graph ({hparams.n_pentagons} pentagons)')
plt.show()

### Prediction Errors

In [None]:
# Prediction errors for the forward + reversed walk; inference only, so no
# autograd graph is needed.
with torch.no_grad():
    outs_pe = model.forward(data_all, output_mode='error', run_num='fwd_rev',
                            tb_labels=['nodes'])

In [None]:
# Shape of the prediction-error output; the next cells index it as
# [sequence, layer, step] (assumed layout — TODO confirm)
outs_pe.shape

In [None]:
# First sequence of the batch as a NumPy array (rows are plotted per layer).
outs_array = outs_pe[0, :, :].detach().cpu().numpy()
outs_array.shape

In [None]:
# One panel per layer sharing the step axis; dotted lines mark `borders`.
# plt.subplots replaces the 3-digit add_subplot code (which only supports up
# to 9 panels) and the empty full-figure axes that overlapped the panels.
fig, axes = plt.subplots(len(outs_array), 1, sharex=True, squeeze=False,
                         figsize=(16, 9))
for i, (ax, out) in enumerate(zip(axes[:, 0], outs_array)):
    ax.plot(out)
    ax.set_ylabel(f'Layer {i+1}')
    for b in borders:
        ax.axvline(b, ls=':')
axes[-1, 0].set_xlabel('Step')
fig.suptitle('Prediction errors per layer (forward + reversed walk)')
plt.show()

### Prediction Error Differences

In [None]:
# Per-layer prediction-error differences along the forward + reversed walk.
error_diffs = outs['error_diff']
fig, axes = plt.subplots(len(error_diffs), 1, sharex=True, squeeze=False,
                         figsize=(16, 9))
for i, (ax, out) in enumerate(zip(axes[:, 0], error_diffs)):
    # One value per input step; the length follows `data_all` instead of a
    # hard-coded 59. detach() keeps the NumPy conversion safe if grad is on.
    ax.plot(out.detach().cpu().numpy().reshape(data_all.shape[1]))
    ax.set_ylabel(f'Layer {i+1}')
    for b in borders:
        ax.axvline(b, ls=':')
axes[-1, 0].set_xlabel('Step')
fig.suptitle('Prediction error differences per layer')
plt.show()

### Hidden State Differences

In [None]:
# Per-layer hidden-state differences along the forward + reversed walk.
hidden_diffs = outs['hidden_diff']
fig, axes = plt.subplots(len(hidden_diffs), 1, sharex=True, squeeze=False,
                         figsize=(16, 9))
for i, (ax, out) in enumerate(zip(axes[:, 0], hidden_diffs)):
    trace = out.detach().cpu().numpy().reshape(data_all.shape[1])
    # Step 0 is omitted (presumably a diff against the initial state — TODO
    # confirm), but points keep their original step index on the x-axis so the
    # `borders` lines stay aligned; plotting trace[1:] alone shifted them by one.
    steps = np.arange(len(trace))
    ax.plot(steps[1:], trace[1:])
    ax.set_ylabel(f'Layer {i+1}')
    for b in borders:
        ax.axvline(b, ls=':')
axes[-1, 0].set_xlabel('Step')
fig.suptitle('Hidden-state differences per layer')
plt.show()

### Alternating Within vs Between Communities

In [None]:
# Hand-built walk in groups of three nodes, alternating (per the section title)
# between within-community moves and back-and-forth moves across a community
# boundary, e.g. 10-9-10. Assumes communities {0-4}, {5-9}, {10-14} — TODO confirm.
test_nodes = [6, 8, 9,
              10, 9, 10,
              13, 12, 14,
              0, 14, 0,
              1, 2, 4,
              5, 4, 5]
# Stack the node embeddings into a single-sequence batch (1, steps, embedding_dim);
# -1 infers the embedding width instead of hard-coding 2048.
test_data = np.array([iter_ds.array_data[n]
                      for n in test_nodes]).reshape((1, len(test_nodes), -1))
assert test_data.shape[:2] == (1, len(test_nodes))

In [None]:
# Run the alternating test walk through the model; inference only, so no
# autograd graph. as_tensor(float32) matches the dtype torch.Tensor produced.
with torch.no_grad():
    border_outs = model.forward(torch.as_tensor(test_data, dtype=torch.float32),
                                output_mode='eval',
                                run_num='border_walk_3',
                                tb_labels=['nodes'])

In [None]:
# Per-layer hidden-state differences along the alternating test walk.
border_hidden_diffs = border_outs['hidden_diff']
fig, axes = plt.subplots(len(border_hidden_diffs), 1, sharex=True, squeeze=False,
                         figsize=(16, 9))
for i, (ax, out) in enumerate(zip(axes[:, 0], border_hidden_diffs)):
    trace = out.detach().cpu().numpy().reshape(len(test_nodes))
    # Omit step 0 but keep the original step indices on the x-axis, so step k
    # of `test_nodes` is plotted at x == k.
    steps = np.arange(len(trace))
    ax.plot(steps[1:], trace[1:])
    ax.set_ylabel(f'Layer {i+1}')
axes[-1, 0].set_xlabel('Step')
fig.suptitle('Hidden-state differences on the alternating within/between walk')
plt.show()