# 2.1.2 Stacked LSTMs on Fractals

PredNet representations match those measured with human fMRI. What about LSTMs?

## Jupyter Extensions

Load [watermark](https://github.com/rasbt/watermark) to see the state of the machine and environment that's running the notebook. To make sense of the options, take a look at the [usage](https://github.com/rasbt/watermark#usage) section of the readme.

In [1]:
# Load `watermark` extension
%load_ext watermark
# Display the status of the machine and other non-code related info
%watermark -n -m -g -b -t -h

Mon Sep 07 2020 19:06:14 

compiler   : GCC 7.3.0
system     : Linux
release    : 5.4.0-45-generic
machine    : x86_64
processor  : x86_64
CPU cores  : 4
interpreter: 64bit
host name  : apra-x3
Git hash   : fd56f033dc8b98d92ac6d31c2ea68f0df697413a
Git branch : master


Load [autoreload](https://ipython.org/ipython-doc/3/config/extensions/autoreload.html) which will always reload modules marked with `%aimport`.

This behavior can be inverted by running `%autoreload 2`, which sets everything to be auto-reloaded *except* for modules marked with `%aimport`.

In [2]:
# Load `autoreload` extension
%load_ext autoreload
# Set autoreload behavior
%autoreload 1

Load `matplotlib` in one of the more `jupyter`-friendly [rich-output modes](https://ipython.readthedocs.io/en/stable/interactive/plotting.html). Available options include `inline`, `notebook`, and `gtk`, though not all of them may work in every environment.

In [3]:
# Set the matplotlib mode
%matplotlib inline

## Imports

In [4]:
import gc
import logging
from argparse import Namespace
from pathlib import Path

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import pytorch_lightning as pl
from PIL import Image, ImageOps
from torch.utils.data import IterableDataset, DataLoader

%aimport prevseg.constants
import prevseg.constants as const
%aimport prevseg.index
import prevseg.index as index
%aimport prevseg.dataloaders.schapiro
import prevseg.dataloaders.schapiro as sch
%aimport prevseg.schapiro
from prevseg.schapiro import walk, graph
%aimport prevseg.models.prednet
import prevseg.models.prednet as pn
%aimport prevseg.torch.lstm
import prevseg.torch.lstm as lstm
%aimport prevseg.torch.activations
import prevseg.torch.activations as act


# Keep track of versions of everything
%watermark -v -iv

numpy             1.19.1
logging           0.5.1.2
networkx          2.4
pytorch_lightning 0.8.5
prevseg           0+untagged.85.gfd56f03.dirty
torch             1.6.0
PIL.Image         7.2.0
CPython 3.8.5
IPython 7.16.1


## Defining the Stacked LSTM

In [5]:
class LSTMCell(pn.PredCellTracked):
    """PredNet-style cell reduced to a single recurrent module.

    The parent's dense and update builders are replaced with no-ops, so the
    only sub-module this cell constructs is the one from ``build_recurrent``.
    """
    name = 'lstmcell'

    def __init__(self, parent, layer_num, hparams, a_channels, r_channels, *args, **kwargs):
        # Instance attributes shadow the parent's builder methods; they must be
        # installed before the parent constructor runs (which presumably calls
        # them) so that nothing is built for the dense/update stages.
        def _build_nothing(*_args, **_kwargs):
            return None

        self.build_dense = _build_nothing
        self.build_update = _build_nothing
        super().__init__(parent, layer_num, hparams, a_channels, r_channels, *args, **kwargs)

    def build_recurrent(self):
        """Create this layer's recurrent module and initialise its parameters.

        The input size is ``a_channels[layer_num]`` only: no top-down term
        from ``r_channels[layer_num + 1]`` is included.
        """
        layer = self.layer_num
        rnn = self.RecurrentClass(self.a_channels[layer],
                                  self.r_channels[layer])
        rnn.reset_parameters()
        return rnn

    def reset(self, batch_size=None):
        """Zero the recurrent state and empty every tracking list.

        Args:
            batch_size: Batch size for the fresh state; any falsy value
                falls back to ``self.hparams.batch_size``.
        """
        if not batch_size:
            batch_size = self.hparams.batch_size
        # (single time step, batch, recurrent channels of this layer)
        state_shape = (1, batch_size, self.r_channels[self.layer_num])
        device = self.parent.dev
        self.R = torch.zeros(state_shape, device=device)
        # Two independent zero tensors form the initial recurrent state pair.
        self.H = tuple(torch.zeros(state_shape, device=device) for _ in range(2))
        for list_name in ('hidden_full_list', 'hidden_diff_list',
                          'representation_full_list', 'representation_diff_list'):
            setattr(self, list_name, [])

    def update_parent(self, module_names=('recurrent',)):
        """Delegate to the parent, restricted by default to ``recurrent``.

        ``recurrent`` is the only sub-module this cell builds.
        """
        return super().update_parent(module_names=module_names)

class LSTMStacked(pn.PredNetTrackedSchapiro):
    """Stacked-LSTM model built on the PredNet tracking/Schapiro scaffolding.

    Each cell consumes the output ``R`` of the cell below it; the first cell
    consumes the raw input frame. ``self.dense`` (Linear -> ReLU -> SatLU)
    maps the top cell's output back to the input size, producing ``A_hat``.
    """
    name = 'lstmstacked'

    def __init__(self, hparams, CellClass=LSTMCell, a_channels=None,
                 r_channels=None, *args, **kwargs):
        """Build the stack and its dense read-out.

        Args:
            hparams: Hyper-parameter namespace. Here it reads
                ``layer_loss_mode`` (must be ``None``), ``input_size`` and
                ``n_layers``. The parent class may read more.
            CellClass: Cell type used for the layers.
            a_channels: Per-layer input sizes. Defaults to ``input_size``
                repeated ``n_layers`` times.
            r_channels: Per-layer recurrent sizes. Defaults to ``a_channels``
                followed by a trailing ``0``.
        """
        # Assertions for how it should be used
        assert hparams.layer_loss_mode is None

        if a_channels is None:
            a_channels = [hparams.input_size] * hparams.n_layers
        if r_channels is None:
            # The trailing 0 presumably stands for the (absent) layer above
            # the top one. TODO: confirm against the parent class.
            r_channels = list(a_channels) + [0,]
        # The parent builds the per-layer cells from a_channels / r_channels.
        super().__init__(hparams=hparams, CellClass=CellClass, r_channels=r_channels,
                         a_channels=a_channels, *args, **kwargs)
        # Read-out: top layer's recurrent output -> input-sized A_hat.
        self.dense = nn.Sequential(
            nn.Linear(self.r_channels[hparams.n_layers - 1],
                      self.a_channels[0]),
            nn.ReLU())
        self.dense.add_module('satlu', act.SatLU())

    def forward(self, input):
        """Run every cell over each time step of ``input``.

        ``input`` is sliced as ``input[:, t, :]``, so dimension 1 is time
        (presumably ``(batch, time, features)``; confirm with
        ``check_input_shape``). The result depends on ``self.output_mode``:

        * ``'error'``: ``|A_hat - frame|`` for every step, stacked on dim 2.
        * ``'eval'``: ``A_hat`` for every step, stacked on dim 2.
        * ``'prediction'``: ``A_hat`` of the last step only.

        Raises:
            ValueError: If ``self.output_mode`` is none of the above, or the
                input has zero time steps. These cases previously surfaced as
                an opaque ``torch.stack`` error or an ``UnboundLocalError``.

        NOTE(review): this override accepts only ``input``, yet notebook cells
        call ``model.forward(..., output_mode=..., run_num=..., tb_labels=...)``
        and index the result like a dict (``outs['hidden_diff']``). Confirm how
        the parent class expects those calls to be routed.
        """
        _, time_steps, *_ = self.check_input_shape(input)
        if self.output_mode not in ('error', 'eval', 'prediction'):
            raise ValueError(
                f"unsupported output_mode {self.output_mode!r}; "
                "expected 'error', 'eval' or 'prediction'")
        if time_steps == 0:
            raise ValueError('forward() requires at least one time step')

        total_output = []
        for t in range(time_steps):
            # Leading singleton dimension = single time step for the recurrent call.
            self.frame = input[:, t, :].unsqueeze(0).to(self.dev, torch.float)
            A = self.frame
            for cell in self.cells:
                # At t == 0 both recurrent state slots are seeded from R (all
                # zeros right after LSTMCell.reset); afterwards the state
                # returned by the previous step is reused.
                hx = (cell.R, cell.R) if t == 0 else cell.H
                cell.R, cell.H = cell.recurrent(A, hx)
                # Optional tracking
                cell.track_hidden(self.output_mode, hx)
                cell.track_representation(self.output_mode, A)
                A = cell.R

            A_hat = self.dense(A)
            # NOTE(review): A_hat at step t is computed from frame t itself and
            # compared with frame t, so 'error' measures reconstruction of the
            # current frame rather than prediction of the next one. Confirm
            # this is intended.
            if self.output_mode == 'error':
                total_output.append(torch.abs(A_hat - self.frame))
            elif self.output_mode == 'eval':
                total_output.append(A_hat)

        if self.output_mode == 'prediction':
            return A_hat
        return torch.stack(total_output, 2)

## Training the Model

In [6]:
import copy

# Drop references from a previous run so gc / empty_cache can reclaim memory.
model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
ckpt = None
train_errors, val_errors = None, None
res = None
gc.collect()
torch.cuda.empty_cache()

# Work on a copy: assigning attributes on const.DEFAULT_HPARAMS directly would
# mutate the shared module-level object and leak these overrides to every
# other user of it in this kernel.
hparams = copy.deepcopy(const.DEFAULT_HPARAMS)

ModelClass = LSTMStacked
hparams.layer_loss_mode = None
hparams.n_layers = 2
hparams.batch_size = 256 + 128 + 64
hparams.max_steps = 128
hparams.n_paths = 16
hparams.n_pentagons = 3
hparams.time_steps = hparams.max_steps
hparams.exp_name = 'schapiro_test'
hparams.name = f'{ModelClass.name}_{hparams.exp_name}'
hparams.debug = False
hparams.n_workers = 4
hparams.lr = 0.001

log_dir = Path(hparams.dir_logs) / hparams.name
log_dir.mkdir(parents=True, exist_ok=True)
logger = pl.loggers.TensorBoardLogger(str(log_dir.parent), name=hparams.name)

# The checkpoint directory is suffixed with the logger's version number.
ckpt_dir = Path(hparams.dir_checkpoints) / f'{hparams.name}_v{logger.version}'
ckpt_dir.mkdir(parents=True, exist_ok=True)

ckpt = pl.callbacks.ModelCheckpoint(
    filepath=str(ckpt_dir / (hparams.exp_name+'_{global_step:05d}_{epoch:03d}_{val_loss:.3f}')),
    verbose=True,
    save_top_k=1,
)

trainer = pl.Trainer(checkpoint_callback=ckpt,
                     max_epochs=20,
                     logger=logger,
                     gpus=1
                     )

model = ModelClass(hparams)
# Presumably cleared so the dataset is (re)built during fit. TODO: confirm.
model.ds = None
model

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
    Found GPU1 GeForce GTX 670 which is of cuda capability 3.0.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability that we support is 3.5.
    
GeForce GTX 670 with CUDA capability sm_30 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_61 sm_70 sm_75 compute_37.
If you want to use the GeForce GTX 670 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



LSTMStacked(
  (lstmcell_0_recurrent): LSTM(
    (i2h): Linear(in_features=2048, out_features=8192, bias=True)
    (h2h): Linear(in_features=2048, out_features=8192, bias=True)
  )
  (lstmcell_1_recurrent): LSTM(
    (i2h): Linear(in_features=2048, out_features=8192, bias=True)
    (h2h): Linear(in_features=2048, out_features=8192, bias=True)
  )
  (dense): Sequential(
    (0): Linear(in_features=2048, out_features=2048, bias=True)
    (1): ReLU()
    (satlu): SatLU (min_val=0, max_val=255)
  )
)

In [None]:
# Train the model with the trainer, logger and checkpoint callback configured in the setup cell.
trainer.fit(model)


  | Name                 | Type       | Params
----------------------------------------------------
0 | lstmcell_0_recurrent | LSTM       | 33 M  
1 | lstmcell_1_recurrent | LSTM       | 33 M  
2 | dense                | Sequential | 4 M   


Created mapping as follows:
{0: '1', 1: '60', 2: '95', 3: '100', 4: '14', 5: '2', 6: '63', 7: '58', 8: '96', 9: '55', 10: '99', 11: '50', 12: '7', 13: '89', 14: '12'}
Created mapping as follows:
{0: '1', 1: '60', 2: '95', 3: '100', 4: '14', 5: '2', 6: '63', 7: '58', 8: '96', 9: '55', 10: '99', 11: '50', 12: '7', 13: '89', 14: '12'}
Epoch 1: : 20it [02:28,  7.44s/it, loss=0.397, v_num=36]              

In [None]:
# ShapiroResnetEmbeddingDataset is not imported directly in this notebook; it
# is taken from `sch` (prevseg.dataloaders.schapiro, otherwise unused here).
# NOTE(review): confirm the class lives in that module.
iter_ds = sch.ShapiroResnetEmbeddingDataset(
    batch_size=1,
    max_steps=hparams.max_steps,
    n_paths=1,
    mapping=model.ds.mapping,  # reuse the mapping created for the training dataset
    mode='euclidean')
# batch_size=None disables the DataLoader's automatic batching; the dataset
# is constructed with its own batch_size above.
loader = DataLoader(iter_ds, batch_size=None)

# Exhaust the loader, keeping only the last (data, nodes) pair it yields.
for data, nodes in loader:
    pass

In [None]:
# Shape of the last `data` batch yielded by the loader.
data.shape

In [None]:
# The forward walk followed by the same walk reversed in time. The reversed
# copy drops its first frame (the last forward frame), so the turn-around step
# is not repeated. Only the time axis (dim 1) is flipped: flipping dims (0, 1)
# would also reverse the batch order, which is only harmless for batch size 1.
data_all = torch.cat((data, torch.flip(data, (1,))[:, 1:, :]), 1)
data_all.shape

In [None]:
# Inference only: no_grad avoids building an autograd graph for this pass.
# NOTE(review): LSTMStacked.forward as defined above takes only `input`;
# confirm these keyword arguments are accepted by the model's forward.
with torch.no_grad():
    outs = model.forward(data_all, output_mode='eval', run_num='fwd_rev',
                         tb_labels=['nodes'])

In [None]:
# Flatten the node labels of the walk to 1-D. The length was hard-coded to 30,
# which fails for walks of any other length (e.g. hparams.max_steps = 128).
nodes = np.array(nodes).reshape(-1)
nodes

In [None]:
# Node labels for the forward walk followed by its reverse, without repeating
# the turn-around node (the same construction as data_all).
reversed_tail = np.flip(nodes)[1:]
nodes_all = np.concatenate([nodes, reversed_tail])
nodes_all.shape

In [None]:
# List each step index with its node label, to locate community transitions.
for i, val in enumerate(nodes_all):
    print(i, val)

In [None]:
# Step indices marked with dotted vertical lines in the plots below. They are
# hand-picked from the node listing above: presumably community transitions,
# with 29/30 bracketing the forward -> reversed turn-around of a 30-step walk.
# Valid only for this particular walk. TODO: confirm.
borders = [9, 19, 29, 30, 40, 50]

In [None]:
# Draw the community graph the walks were sampled from, using the same
# n_pentagons as the model's hparams instead of a duplicated literal.
G = graph.schapiro_graph(n_pentagons=hparams.n_pentagons)
nx.draw(G, with_labels=True, font_weight='bold')
plt.show()

### Prediction Errors

In [None]:
# Inference only: no_grad avoids building an autograd graph for this pass.
with torch.no_grad():
    outs_pe = model.forward(data_all, output_mode='error', run_num='fwd_rev',
                            tb_labels=['nodes'])

In [None]:
# Shape of the prediction-error output returned by the forward pass above.
outs_pe.shape

In [None]:
# Detach before moving to the CPU, so no autograd-tracked copy is made.
outs_array = outs_pe[0, :, :].detach().cpu().numpy()
outs_array.shape

In [None]:
# One stacked subplot per row of outs_array on top of a blank full-figure axis.
# NOTE(review): if outs_pe is the stacked 'error' tensor, these rows index the
# batch rather than layers. Confirm the 'Layer' labels.
fig = plt.figure()
ax_large = fig.add_subplot(111)

n_rows = len(outs_array)
for i, out in enumerate(outs_array):
    # (rows, cols, index) form; the 3-digit code breaks beyond 9 rows.
    ax = fig.add_subplot(n_rows, 1, i + 1)
    ax.plot(out)
    ax.set_ylabel(f'Layer {i+1}')
    for b in borders:
        ax.axvline(b, ls=':')
    if i == n_rows - 1:
        ax.set_xlabel('Step')

ax_large.set_xticks([])
ax_large.set_yticks([])
fig.set_size_inches(16, 9)

### Prediction Error Differences

In [None]:
# One subplot per entry of outs['error_diff'], one value per step.
fig = plt.figure()
ax_large = fig.add_subplot(111)

error_diffs = outs['error_diff']
n_rows = len(error_diffs)
for i, out in enumerate(error_diffs):
    # (rows, cols, index) form; the 3-digit code breaks beyond 9 rows.
    ax = fig.add_subplot(n_rows, 1, i + 1)
    # detach() also works for grad-tracking tensors, where np.array(...) fails.
    # reshape(-1) replaces the hard-coded 59 (= 30 forward + 29 reversed steps).
    ax.plot(out.detach().cpu().numpy().reshape(-1))
    ax.set_ylabel(f'Layer {i+1}')
    for b in borders:
        ax.axvline(b, ls=':')
    if i == n_rows - 1:
        ax.set_xlabel('Step')

ax_large.set_xticks([])
ax_large.set_yticks([])
fig.set_size_inches(16, 9)

### Hidden State Differences

In [None]:
# One subplot per entry of outs['hidden_diff'], one value per step.
fig = plt.figure()
ax_large = fig.add_subplot(111)

hidden_diffs = outs['hidden_diff']
n_rows = len(hidden_diffs)
for i, out in enumerate(hidden_diffs):
    # (rows, cols, index) form; the 3-digit code breaks beyond 9 rows.
    ax = fig.add_subplot(n_rows, 1, i + 1)
    values = out.detach().cpu().numpy().reshape(-1)
    # Step 0 is omitted (presumably the difference from the initial state),
    # but x stays in step units so the values line up with the border lines.
    ax.plot(np.arange(1, len(values)), values[1:])
    ax.set_ylabel(f'Layer {i+1}')
    for b in borders:
        ax.axvline(b, ls=':')
    if i == n_rows - 1:
        ax.set_xlabel('Step')

ax_large.set_xticks([])
ax_large.set_yticks([])
fig.set_size_inches(16, 9)

### Alternating Within vs Between Communities

In [None]:
# Hand-made walk in triples: presumably a within-community triple, then one
# that hops across a community border and back (e.g. 10, 9, 10). Confirm
# against the graph drawn above.
test_nodes = [6, 8, 9,
              10, 9, 10,
              13, 12, 14,
              0, 14, 0,
              1, 2, 4,
              5, 4, 5]
# (1 sequence, len(test_nodes) steps, feature size); the feature size is
# inferred instead of being hard-coded to 2048.
test_data = np.array([iter_ds.array_data[n]
                      for n in test_nodes]).reshape((1, len(test_nodes), -1))

In [None]:
# Inference only: no_grad avoids building an autograd graph for this pass.
with torch.no_grad():
    border_outs = model.forward(torch.Tensor(test_data),
                                output_mode='eval',
                                run_num='border_walk_3',
                                tb_labels=['nodes'])

In [None]:
# One subplot per entry of border_outs['hidden_diff'] for the hand-made walk.
fig = plt.figure()
ax_large = fig.add_subplot(111)

hidden_diffs = border_outs['hidden_diff']
n_rows = len(hidden_diffs)
for i, out in enumerate(hidden_diffs):
    # (rows, cols, index) form; the 3-digit code breaks beyond 9 rows.
    ax = fig.add_subplot(n_rows, 1, i + 1)
    values = out.detach().cpu().numpy().reshape(len(test_nodes))
    # Step 0 is omitted, but x stays in step units (aligned with test_nodes).
    ax.plot(np.arange(1, len(values)), values[1:])
    ax.set_ylabel(f'Layer {i+1}')
    if i == n_rows - 1:
        ax.set_xlabel('Step')

ax_large.set_xticks([])
ax_large.set_yticks([])
fig.set_size_inches(16, 9)