# 1.2.1 Evaluating Hidden States

Adding functionality to view hidden state activity through a video.

## Jupyter Extensions

Load [watermark](https://github.com/rasbt/watermark) to see the state of the machine and environment that's running the notebook. To make sense of the options, take a look at the [usage](https://github.com/rasbt/watermark#usage) section of the readme.

In [2]:
# Load `watermark` extension
%load_ext watermark
# Display the status of the machine and packages. Add more as necessary.
%watermark -v -n -m -g -b -t -p torch,torchvision,pytorch_lightning,jupyterlab,prevseg

The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark
Wed Mar 18 2020 17:13:03 

CPython 3.8.2
IPython 7.13.0

torch 1.4.0
torchvision 0.5.0
pytorch_lightning 0.7.1
jupyterlab 2.0.1
prevseg 0+untagged.11.g393f577.dirty

compiler   : GCC 7.3.0
system     : Linux
release    : 4.15.0-88-generic
machine    : x86_64
processor  : x86_64
CPU cores  : 16
interpreter: 64bit
Git hash   : 393f5775d28aeec38742df7b2b394f0ae6179d2a
Git branch : master


Load [autoreload](https://ipython.org/ipython-doc/3/config/extensions/autoreload.html), which reloads modules before each cell is executed. With `%autoreload 1`, only modules marked with `%aimport` are reloaded.

Running `%autoreload 2` inverts this behavior: every module is auto-reloaded *except* those excluded with `%aimport -<module>`.

In [3]:
# Load `autoreload` extension
%load_ext autoreload
# Set autoreload behavior
%autoreload 1
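
As a rough illustration of what autoreload saves you from doing by hand, the sketch below edits a module on disk and reloads it with `importlib.reload`. It is not part of the notebook: `demo_mod` is a hypothetical throwaway module, and the real extension does more than a plain reload (it also tries to patch existing objects).

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Skip bytecode caching so a quick edit is never masked by a stale .pyc file.
sys.dont_write_bytecode = True

with tempfile.TemporaryDirectory() as tmp:
    # `demo_mod` is a hypothetical module name used only for this sketch.
    source = Path(tmp) / 'demo_mod.py'
    source.write_text('VALUE = 1\n')
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()

    import demo_mod
    before = demo_mod.VALUE

    # Edit the file on disk, as you would in an editor.
    source.write_text('VALUE = 22\n')
    stale = demo_mod.VALUE          # Unchanged: nothing has reloaded the module.

    importlib.reload(demo_mod)      # Roughly what autoreload does before a cell runs.
    after = demo_mod.VALUE

    # Clean up so the temporary module does not linger in the session.
    sys.path.remove(tmp)
    del sys.modules['demo_mod']
```

Without the reload, the session keeps using the old value even though the file has changed, which is the situation `%aimport` is meant to avoid for the `prevseg` modules.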

## Imports

In [4]:
import gc
import time
from argparse import Namespace
from pathlib import Path
from functools import wraps

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
import seaborn as sns
import matplotlib.pyplot as plt
from torch.nn import functional as F, GRU
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

Local imports, only some of which are autoreloaded. This section contains modules that will likely have to be re-imported many times, and the list will grow or shrink over the course of the project.

In [5]:
# Constants to be used throughout the package
%aimport prevseg
import prevseg as pes
%aimport prevseg.index
from prevseg import index
# Import the data subdirectories
%aimport prevseg.models.prednet
import prevseg.models.prednet as prednet
%aimport prevseg.dataloaders.breakfast
import prevseg.dataloaders.breakfast as bk
%aimport prevseg.constants
import prevseg.constants as const

from prevseg.torch.lstm import LSTM
from prevseg.torch.activations import SatLU

## Set the GPU

Make sure we aren't greedy.

In [6]:
!nvidia-smi

Wed Mar 18 17:13:35 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.87.00    Driver Version: 418.87.00    CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  TITAN Xp            Off  | 00000000:04:00.0 Off |                  N/A |
| 34%   57C    P2    84W / 250W |   9331MiB / 12196MiB |     14%      Default |
+-------------------------------+----------------------+----------------------+
|   1  TITAN Xp            Off  | 00000000:05:00.0 Off |                  N/A |
| 26%   38C    P8     9W / 250W |     10MiB / 12196MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  TITAN Xp            Off  | 00000000:08:00.0 Off |                  N/A |
| 54%   

In [7]:
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1
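
Picking the idle GPU by eye from the table can also be scripted. The sketch below is not part of the notebook: `pick_least_used_gpu` and `query_gpus` are hypothetical helpers. They assume the CSV form printed by `nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits`, and the sample text simply restates the two complete rows of the table above.

```python
import os
import subprocess


def pick_least_used_gpu(csv_text):
    """Return the index of the GPU with the least memory in use.

    `csv_text` holds one "index, memory_used_mib" row per GPU.
    """
    rows = []
    for line in csv_text.strip().splitlines():
        index, used = (field.strip() for field in line.split(','))
        rows.append((int(used), int(index)))
    # Tuples compare by memory first, then by index to break ties.
    return min(rows)[1]


def query_gpus():
    """Ask nvidia-smi for per-GPU memory use (needs an NVIDIA driver)."""
    return subprocess.run(
        ['nvidia-smi', '--query-gpu=index,memory.used',
         '--format=csv,noheader,nounits'],
        capture_output=True, text=True, check=True).stdout


# GPUs 0 and 1 from the table above, written in the queried CSV form.
sample = '0, 9331\n1, 10\n'
gpu = pick_least_used_gpu(sample)

# This has to be set before CUDA is initialised in the process.
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
```

On a machine with a driver installed, `pick_least_used_gpu(query_gpus())` would replace the hard-coded sample.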


## Loading the Data

In [10]:
%%time
ds = bk.BreakfastI3DFVDataset()

CPU times: user 14.1 s, sys: 8min 22s, total: 8min 36s
Wall time: 15min 6s


## Previous Implementation

See `wb-1.2.0` for the implementation in action; the relevant portion is reproduced below.

```python
class PredCell(object):
    """Organizational class."""
    def __init__(self, parent, layer_num, hparams, a_channels, r_channels, 
                 RecurrentClass=LSTM):
        super().__init__()
        self.parent = parent
        self.layer_num = layer_num
        self.hparams = hparams
        self.a_channels = a_channels
        self.r_channels = r_channels
        self.RecurrentClass = RecurrentClass
        
        # Recurrent
        self.recurrent = self.build_recurrent()
        # Dense
        self.dense = self.build_dense()
        # Update
        self.update_a = self.build_update()
        # upsample - set at cell level for future
        self.upsample = nn.Upsample(scale_factor=2)
        
        # Build E, R, and H
        self.reset()
        # Book-keeping
        self.update_parent()
            
    def build_recurrent(self):
        recurrent = self.RecurrentClass(
            2 * (self.a_channels[self.layer_num] +
                 self.r_channels[self.layer_num+1]),
            #+ self.r_channels[self.layer_num+1],
            self.r_channels[self.layer_num])
        recurrent.reset_parameters()
        return recurrent
    
    def build_dense(self):
        dense = nn.Sequential(
            nn.Linear(self.r_channels[self.layer_num],
                      self.a_channels[self.layer_num]),
            nn.ReLU())
        if self.layer_num == 0:
            dense.add_module('satlu', SatLU())
        return dense
        
    def build_update(self):
        if self.layer_num < self.hparams.n_layers - 1:
            return nn.Sequential(
                nn.Linear(
                    2 * self.a_channels[self.layer_num],
                    self.a_channels[self.layer_num + 1]),
                nn.ReLU())
        else:
            return None
            
    def reset(self, batch_size=None):
        batch_size = batch_size or self.hparams.batch_size
        # E, R, and H variables
        self.E = torch.zeros(1,                  # Single time step
                             batch_size,
                             2*self.a_channels[self.layer_num],
                             device=self.parent.device)
        self.R = torch.zeros(1,                  # Single time step
                             batch_size,
                             self.r_channels[self.layer_num],
                             device=self.parent.device)
        self.H = None
        
    def update_parent(self):
        self.modules = {'recurrent' : self.recurrent, 'dense' : self.dense}
        if hasattr(self, 'update_a') and self.update_a is not None:
            self.modules['update_a'] = self.update_a
        # Hack to appease the pytorch-gods
        for name, module in self.modules.items():
            setattr(self.parent, f'predcell_{self.layer_num}_{name}', module)


class PredNet(pl.LightningModule):
    name = 'prednet'
    def __init__(self, hparams, ds=None, CellClass=PredCell):
        super().__init__()
        # Attribute definitions
        self.hparams = hparams
        self.n_layers = self.hparams.n_layers
        self.output_mode = self.hparams.output_mode
        self.input_size = self.hparams.input_size
        self.time_steps = self.hparams.time_steps
        self.batch_size = self.hparams.batch_size
        self.layer_loss_mode = self.hparams.layer_loss_mode
        self.ds = ds
        self.CellClass = CellClass
        
        if self.hparams.device == 'cuda' and torch.cuda.is_available():
            print('Using GPU', flush=True)
            self.device = torch.device('cuda')
        else:
            print('Using CPU', flush=True)
            self.device = torch.device('cpu')

        # Put together the model
        self.build_model()

    def build_model(self):        
        # Channel sizes
        self.r_channels = [self.input_size // (2**i) 
                           for i in range(self.n_layers)] + [0,] # Convenience
        self.a_channels = [self.input_size // (2**i) 
                           for i in range(self.n_layers)]
        
        # Make sure everything checks out
        default_output_modes = ['prediction', 'error']
        assert self.output_mode in default_output_modes, \
            'Invalid output_mode: ' + str(self.output_mode)

        # Make all the pred cells
        self.predcells = [self.CellClass(self,
                                         layer_num,
                                         self.hparams,
                                         self.a_channels,
                                         self.r_channels)
                          for layer_num in range(self.n_layers)]
        
        # How to weight the errors
        # 1 followed by zeros means just minimize error at lowest layer
        self.layer_loss_weights = self.build_layer_loss_weights(
            self.layer_loss_mode)
        # How much to weight errors at each timestep
        self.time_loss_weights = self.build_time_loss_weights()
        
    def build_layer_loss_weights(self, mode='first'):
        if mode == 'first':
            first = torch.zeros(self.n_layers, 1, device=self.device)
            first[0][0] = 1
            return first
        elif mode == 'all':
            return 1. / (self.n_layers-1) * torch.ones(self.n_layers, 1,
                                                       device=self.device)
        else:
            raise Exception(f'Invalid layer loss mode "{mode}".')
            
    def build_time_loss_weights(self, time_steps=None):
        time_steps = time_steps or self.time_steps
        # How much to weight errors at each timestep
        time_loss_weights = 1. / (time_steps-1) * torch.ones(time_steps, 1,
                                                             device=self.device)
        # Don't count first time step
        time_loss_weights[0] = 0
        return time_loss_weights
    
    def check_input_shape(self, input):
        batch_size, time_steps, *input_size = input.shape
        
        # Reset batch_size-dependent things
        if batch_size != self.batch_size:
            self.batch_size = batch_size
            for cell in self.predcells:
                cell.reset(self.batch_size)
                
        # Reset time_step-dependent things
        if time_steps != self.time_steps:
            self.time_steps = time_steps
            self.time_loss_weights = self.build_time_loss_weights(
                self.time_steps)
            
        return batch_size, time_steps, *input_size
    
    def top_down_pass(self, t):
        # Loop backwards
        for l, cell in reversed(list(enumerate(self.predcells))):
            E, R = cell.E, cell.R
            # First time step
            if t == 0:
                hx = (R, R)
            else:
                hx = cell.H

            # If not in the last layer, upsample R from the layer above and
            # concatenate it onto E
            if l < self.n_layers - 1:
                E = torch.cat((E,  cell.upsample(self.predcells[l+1].R)), 2)

            cell.R, cell.H = cell.recurrent(E, hx)
            
    def bottom_up_pass(self):
        for cell in self.predcells:
            # Go from R to A_hat
            A_hat = cell.dense(cell.R)

            # Convenience
            if self.output_mode == 'prediction' and cell.layer_num == 0:
                self.frame_prediction = A_hat

            # Split to 2 Es
            pos = F.relu(A_hat - self.A)
            neg = F.relu(self.A - A_hat)
            E = torch.cat([pos, neg], 2)
            cell.E = E

            # If not last layer, update stored A
            if cell.layer_num < self.n_layers - 1:
                self.A = cell.update_a(E)
            
    def forward(self, input):
        _, time_steps, *_ = self.check_input_shape(input)
        
        total_error = []

        for t in range(time_steps):
            self.A = input[:,t,:].unsqueeze(0).to(self.device, torch.float)
            
            # Loop from top layer to update R and H
            self.top_down_pass(t)
            # Loop bottom up to get E and A
            self.bottom_up_pass()
            
            if self.output_mode == 'error':
                mean_error = torch.cat(
                    [torch.mean(cell.E.view(cell.E.size(1), -1),
                                1, keepdim=True)
                     for cell in self.predcells], 1)
                # batch x n_layers
                total_error.append(mean_error)
        
        if self.output_mode == 'error':
            return torch.stack(total_error, 2) # batch x n_layers x nt
        elif self.output_mode == 'prediction':
            return self.frame_prediction

    def timeit(method):
        """Combination of https://stackoverflow.com/questions/51503672/decorator-for-timeit-timeit-method/51503837#51503837,
        and https://www.geeksforgeeks.org/python-program-to-convert-seconds-into-hours-minutes-and-seconds/"""
        @wraps(method)
        def _time_it(self, *args, **kwargs):
            start = int(round(time.time() * 1000))
            try:
                return method(self, *args, **kwargs)
            finally:
                end_ = int(round(time.time() * 1000)) - start
                if end_ > 1000:
                    time_str = time.strftime("%H:%M:%S",
                                             time.gmtime(end_ // 1000))
                    print(f"Total execution time: {time_str}", flush=True)
                
        return _time_it

    @timeit
    def prepare_data(self):
        if self.ds is None:
            print('Loading the i3d data from disk. This can take '
                  'several minutes...', flush=True)
        self.ds = self.ds or bk.BreakfastI3DFVDataset()
        self.ds_length = len(self.ds)
        np.random.seed(self.hparams.seed)
        self.indices = list(range(self.ds_length))
        self.train_sampler = SubsetRandomSampler(
            self.indices[self.hparams.n_val:])
        self.val_sampler = SubsetRandomSampler(
            self.indices[:self.hparams.n_val])
        
    def train_dataloader(self):
        return DataLoader(self.ds, 
                          batch_size=self.batch_size, 
                          sampler=self.train_sampler,
                          num_workers=self.hparams.n_workers)
    
    def val_dataloader(self):
        return DataLoader(self.ds, 
                          batch_size=self.batch_size, 
                          sampler=self.val_sampler,
                          num_workers=self.hparams.n_workers)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    
    def _common_step(self, batch, batch_idx, mode):
        data, path = batch
        errors = self.forward(data) # batch x n_layers x nt
        loc_batch = errors.size(0)
        errors = torch.mm(errors.view(-1, self.time_steps), 
                          self.time_loss_weights) # batch*n_layers x 1
        errors = torch.mm(errors.view(loc_batch, -1), 
                          self.layer_loss_weights)
        errors = torch.mean(errors, axis=0)
        
        if mode == 'train':
            prefix = ''
        else:
            prefix = mode + '_'
            
        self.logger.experiment.add_scalar(f'{prefix}loss', 
                                          errors, self.global_step)
        return {f'{prefix}loss' : errors}

    def validation_epoch_end(self, output):
        out_dict = {}
        out_dict['val_loss'] = np.mean([out['val_loss'].item()
                                        for out in output])
        out_dict['global_step'] = self.global_step
        return out_dict
    
    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, 'train')
    
    def validation_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, 'val')
```
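
The loss reduction in `_common_step` is easier to follow on a small example. The following is a dependency-free sketch rather than the notebook's code: plain lists stand in for tensors, and the names `time_weights`, `layer_weights`, and `weighted_loss` are made up for illustration. It mirrors the two matrix products in the listing: errors of shape batch x n_layers x nt are first weighted over time (ignoring the first step), then over layers, then averaged over the batch.

```python
def time_weights(time_steps):
    """Uniform weights over time, with the first step ignored."""
    weights = [1.0 / (time_steps - 1)] * time_steps
    weights[0] = 0.0
    return weights


def layer_weights(n_layers, mode='first'):
    """'first' keeps only the lowest layer; 'all' uses the listing's scaling."""
    if mode == 'first':
        return [1.0] + [0.0] * (n_layers - 1)
    if mode == 'all':
        return [1.0 / (n_layers - 1)] * n_layers
    raise ValueError(f'Invalid layer loss mode "{mode}".')


def weighted_loss(errors, t_weights, l_weights):
    """Reduce errors[batch][layer][time] to a single scalar."""
    per_sample = []
    for sample in errors:
        # Weighted sum over time for every layer ...
        per_layer = [sum(e * w for e, w in zip(layer, t_weights))
                     for layer in sample]
        # ... then a weighted sum over layers.
        per_sample.append(sum(e * w for e, w in zip(per_layer, l_weights)))
    # Finally, average over the batch.
    return sum(per_sample) / len(per_sample)


# Two samples, two layers, three time steps.
errors = [[[9.0, 1.0, 3.0], [5.0, 5.0, 5.0]],
          [[9.0, 2.0, 4.0], [5.0, 5.0, 5.0]]]
loss = weighted_loss(errors, time_weights(3), layer_weights(2, 'first'))
```

With `mode='first'`, the second layer's errors never reach the loss, and the large error at the first time step is dropped by the zero weight.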

## Adding Hidden State Tracking

In [53]:
class PredCellTracked(prednet.PredCell):
    """PredCell that optionally records hidden states and errors."""
    def __init__(self, parent, layer_num, hparams, a_channels, r_channels,
                 *args, **kwargs):
        super().__init__(parent, layer_num, hparams, a_channels, r_channels,
                         *args, **kwargs)
        # Tracking: one entry is appended per time step
        self.hidden_full_list = []
        self.hidden_diff_list = []
        self.error_full_list = []
        self.error_diff_list = []
        self.previous_hidden = None
        self.previous_error = None

    def track_hidden(self, output_mode, R):
        # Track hidden states if desired. R is (1, batch, channels);
        # permute to (batch, 1, channels) so the batch dimension comes first.
        if 'hidden_full' in self.parent.track and self.parent.output_mode == 'eval':
            self.hidden_full_list.append(R.permute(1, 0, 2))
        if 'hidden_diff' in self.parent.track and self.parent.output_mode == 'eval':
            # Mean squared change relative to the previous hidden state
            diff = torch.mean(
                (R.permute(1, 0, 2) - self.R.permute(1, 0, 2))**2,
                2)
            self.hidden_diff_list.append(diff)

    def track_error(self, output_mode, E):
        # Track errors if desired
        if 'error_full' in self.parent.track and self.parent.output_mode == 'eval':
            self.error_full_list.append(E.permute(1, 0, 2))
        if 'error_diff' in self.parent.track and self.parent.output_mode == 'eval':
            # Mean squared change relative to the previous error
            diff = torch.mean(
                (E.permute(1, 0, 2) - self.E.permute(1, 0, 2))**2,
                2)
            self.error_diff_list.append(diff)
            
class PredNetTracked(prednet.PredNet):
    name = 'prednet_tracked'
    def __init__(self, hparams, track=None, CellClass=PredCellTracked, *args,
                 **kwargs):
        self.track = track or ['hidden_diff', 'error_diff']
        super().__init__(hparams, CellClass=CellClass, *args, **kwargs)
        
    def top_down_pass(self):
        # Loop backwards
        for l, cell in reversed(list(enumerate(self.predcells))):
            # First time step
            if self.t == 0:
                hx = (cell.R, cell.R)
            else:
                hx = cell.H

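            # If not in the last layer, upsample R from the layer above and
            # concatenate it onto E. A local E is used so that cell.E keeps
            # its shape for the error tracking in bottom_up_pass.
            E = cell.E
            if l < self.n_layers - 1:
                E = torch.cat((E, cell.upsample(self.predcells[l+1].R)), 2)

            # Update the values of R and H
            R, H = cell.recurrent(E, hx)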

            # Optional tracking
            cell.track_hidden(self.output_mode, R)

            # Update cell state
            cell.R, cell.H = R, H
            
    def bottom_up_pass(self):
        for cell in self.predcells:
            # Go from R to A_hat
            A_hat = cell.dense(cell.R)

            # Convenience
            if self.output_mode == 'prediction' and cell.layer_num == 0:
                self.frame_prediction = A_hat

            # Split to 2 Es
            pos = F.relu(A_hat - self.A)
            neg = F.relu(self.A - A_hat)
            E = torch.cat([pos, neg], 2)
            
            # Optional Error tracking
            cell.track_error(self.output_mode, E)

            # Update cell error
            cell.E = E

            # If not last layer, update stored A
            if cell.layer_num < self.n_layers - 1:
                self.A = cell.update_a(E)
            
    def forward(self, input, output_mode=None, track=None):
        self.output_mode = output_mode or self.output_mode
        # Allow the tracked quantities to be overridden per call
        self.track = track or self.track
        _, time_steps, *_ = self.check_input_shape(input)
        
        self.total_error = []
        
        for self.t in range(time_steps):
            self.A = input[:,self.t,:].unsqueeze(0).to(self.device, torch.float)
            # Loop from top layer to update R and H
            self.top_down_pass()
            # Loop bottom up to get E and A
            self.bottom_up_pass()
            # Track desired outputs
            self.track_outputs()
        
        return self.return_output()
        
    def track_outputs(self):
        if self.output_mode == 'error':
            mean_error = torch.cat(
                [torch.mean(cell.E.view(cell.E.size(1), -1),
                            1, keepdim=True)
                 for cell in self.predcells], 1)
            # batch x n_layers
            self.total_error.append(mean_error)
            
    def return_output(self):
        if self.output_mode == 'error':
            return torch.stack(self.total_error, 2) # batch x n_layers x nt
        elif self.output_mode == 'prediction':
            return self.frame_prediction
        elif self.output_mode == 'eval':
            return self.eval_outputs()

    def eval_outputs(self):
        outputs = {}
        for tracked in self.track:
            outputs[tracked] = [getattr(cell, tracked+'_list')
                                for cell in self.predcells]
        return outputs
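
The `hidden_diff` and `error_diff` quantities are the mean squared change of a layer's state between consecutive time steps, computed per sample. A dependency-free sketch of that metric follows. Plain lists stand in for tensors, `mean_squared_change` and `track_changes` are hypothetical names, and the all-zero starting state follows `reset` in the earlier listing.

```python
def mean_squared_change(previous, current):
    """Mean over features of (current - previous)**2 for one sample."""
    return sum((c - p) ** 2 for p, c in zip(previous, current)) / len(current)


def track_changes(states):
    """Return one change value per time step, starting from an all-zero state."""
    previous = [0.0] * len(states[0])
    changes = []
    for current in states:
        changes.append(mean_squared_change(previous, current))
        previous = current
    return changes


# Three time steps of a four-unit state for a single sample.
states = [[2.0, 2.0, 2.0, 2.0],
          [2.0, 2.0, 2.0, 2.0],
          [2.0, 4.0, 2.0, 2.0]]
changes = track_changes(states)
```

The first entry is large only because the state starts at zero, which is likely why the first value in each row of the `error_diff` output further down is orders of magnitude larger than the rest.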

In [46]:
# model, trainer = None, None
# gc.collect()
# torch.cuda.empty_cache()

model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
ckpt = None
train_errors, val_errors = None, None
res = None
gc.collect()
torch.cuda.empty_cache()

ModelClass = prednet.PredNetTracked
hparams = const.DEFAULT_HPARAMS
hparams.name = ModelClass.name

log_dir = Path(hparams.dir_logs) / f'{hparams.name}'
if not log_dir.exists():
    log_dir.mkdir(parents=True)
logger = pl.loggers.TensorBoardLogger(str(log_dir.parent), name=hparams.name)

ckpt_dir = Path(hparams.dir_checkpoints) / f'{hparams.name}_v{logger.version}'
if not ckpt_dir.exists():
    ckpt_dir.mkdir(parents=True)
    
ckpt = pl.callbacks.ModelCheckpoint(
    filepath=str(ckpt_dir / 'bk_i3d_{global_step:05d}_{epoch:03d}_{val_loss:.3f}'),
    verbose=True,
    save_top_k=3,
)

trainer = pl.Trainer(default_save_path=str(index.DIR_CHECKPOINTS),
                     checkpoint_callback=ckpt,
                     max_epochs=150,
                     logger=logger,
                     gpus=1
                     )

# model = ModelClass(hparams)
# model.ds = ds

# model, trainer = None, None
# train_dataloader, val_dataloader = None, None
# errors, optimizer = None, None
# ckpt = None
# train_errors, val_errors = None, None
# res = None
# gc.collect()
# torch.cuda.empty_cache()

# trainer = pl.Trainer(default_save_path=str(index.DIR_CHECKPOINTS),
#                      max_epochs=1,
#                      gpus=1,
#                      logger=logger,
#                     )

model = ModelClass(hparams)
model.ds = ds

Using GPU


In [None]:
trainer.fit(model)


In [38]:
model.logger = logger

In [40]:
model.logger

<pytorch_lightning.loggers.tensorboard.TensorBoardLogger at 0x7fa43d878e20>

In [None]:
dataloader = model.train_dataloader()

In [21]:
for batch, path in dataloader:
    break

In [22]:
batch.shape

torch.Size([256, 64, 2048])

In [23]:
mini_batch = batch[:2,:,:]
mini_batch.shape

torch.Size([2, 64, 2048])

In [31]:
model.run_num = None

In [44]:
res = model.forward(mini_batch, 'eval')

In [27]:
res['error_diff'][0]

tensor([[1.8889e+02, 5.0149e-02, 3.7198e-03, 1.2581e-03, 1.8696e-03, 1.2821e-03,
         1.3875e-03, 2.8682e-03, 1.8658e-03, 2.4622e-03, 2.1080e-03, 1.1290e-03,
         2.6420e-03, 1.4306e-03, 2.0050e-03, 6.8406e-04, 1.1928e-03, 7.5844e-04,
         2.2347e-03, 1.2908e-03, 1.3033e-03, 2.2954e-03, 3.0578e-03, 1.8767e-03,
         5.8039e-03, 1.8915e-03, 1.8540e-03, 2.0353e-03, 2.6830e-03, 1.2019e-03,
         3.9745e-03, 2.0309e-03, 2.2514e-03, 1.8984e-03, 3.5779e-03, 1.6918e-03,
         1.6891e-03, 7.3595e-04, 1.9989e-03, 1.5965e-03, 1.5693e-03, 1.7259e-03,
         3.3160e-03, 1.2352e-03, 3.3798e-03, 1.2150e-03, 2.7234e-03, 1.1300e-03,
         1.6021e-03, 2.1845e-03, 1.5845e-03, 2.6526e-03, 1.9591e-03, 2.2918e-03,
         1.3289e-03, 1.3255e-03, 2.6525e-03, 1.4677e-03, 3.1222e-03, 1.9648e-03,
         4.3290e-03, 1.4332e-03, 3.8543e-03, 8.6919e-04],
        [3.3024e+02, 1.4147e-01, 8.1604e-02, 1.4242e-01, 2.3399e-01, 1.5287e-01,
         1.6023e-01, 2.8309e-01, 1.7085e-01, 2.1056

In [80]:
for cell in model.predcells:
    cell.error_diff_list = []