# 4.1.3 Lighting Prednet

In an effort to make things more streamlined, this notebook goes through the process of moving and handling the model using `pytorch-lightning` as shown [here](https://github.com/PyTorchLightning/pytorch-lightning). 

## Jupyter Extensions

Load [watermark](https://github.com/rasbt/watermark) to see the state of the machine and environment that's running the notebook. To make sense of the options, take a look at the [usage](https://github.com/rasbt/watermark#usage) section of the readme.

In [1]:
# Load `watermark` extension
%load_ext watermark
# Display the status of the machine and packages. Add more as necessary.
%watermark -v -n -m -g -b -t -p torch,torchvision,pytorch_lightning,jupyterlab,lab

Tue Mar 10 2020 10:50:35 

CPython 3.8.2
IPython 7.13.0

torch 1.4.0
torchvision 0.5.0
pytorch_lightning 0.7.1
jupyterlab 2.0.1
lab 0+untagged.46.gd571ca0.dirty

compiler   : GCC 7.3.0
system     : Linux
release    : 4.15.0-88-generic
machine    : x86_64
processor  : x86_64
CPU cores  : 16
interpreter: 64bit
Git hash   : d571ca0b446908408dcf53afb9389b2207d3a0dd
Git branch : master


Load [autoreload](https://ipython.org/ipython-doc/3/config/extensions/autoreload.html) which will always reload modules marked with `%aimport`.

This behavior can be inverted by running `autoreload 2` which will set everything to be auto-reloaded *except* for modules marked with `%aimport`.

In [2]:
# Load `autoreload` extension
%load_ext autoreload
# Set autoreload behavior
%autoreload 1

## Set the GPU

Make sure we aren't greedy.

In [3]:
!nvidia-smi

Tue Mar 10 10:51:26 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.87.00    Driver Version: 418.87.00    CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  TITAN Xp            Off  | 00000000:04:00.0 Off |                  N/A |
| 23%   26C    P8     7W / 250W |   1656MiB / 12196MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  TITAN Xp            Off  | 00000000:05:00.0 Off |                  N/A |
| 23%   29C    P8     8W / 250W |     10MiB / 12196MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  TITAN Xp            Off  | 00000000:08:00.0 Off |                  N/A |
| 42%   

In [4]:
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


## Installing Pytorch-Lightning

This was done in an earlier instantiation of the notebook and is kept for book-keeping.

In [7]:
!pip install pytorch-lightning

Collecting pytorch-lightning
  Using cached pytorch-lightning-0.7.1.tar.gz (6.0 MB)
Collecting tensorboard>=1.14
  Using cached tensorboard-2.1.1-py3-none-any.whl (3.8 MB)
Collecting future>=0.17.1
  Using cached future-0.18.2.tar.gz (829 kB)
Collecting werkzeug>=0.11.15
  Using cached Werkzeug-1.0.0-py2.py3-none-any.whl (298 kB)
Collecting grpcio>=1.24.3
  Downloading grpcio-1.27.2-cp38-cp38-manylinux2010_x86_64.whl (2.7 MB)
[K     |████████████████████████████████| 2.7 MB 12.9 MB/s eta 0:00:01
[?25hCollecting requests<3,>=2.21.0
  Using cached requests-2.23.0-py2.py3-none-any.whl (58 kB)
Collecting google-auth<2,>=1.6.3
  Using cached google_auth-1.11.2-py2.py3-none-any.whl (76 kB)
Collecting google-auth-oauthlib<0.5,>=0.4.1
  Using cached google_auth_oauthlib-0.4.1-py2.py3-none-any.whl (18 kB)
Collecting protobuf>=3.6.0
  Downloading protobuf-3.11.3-cp38-cp38-manylinux1_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 39.1 MB/s eta 0:00:01
Collecting markdown>

## Imports

In [5]:
from pathlib import Path
import gc

import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from tqdm import tqdm
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

Local imports that may or may not be autoreloaded. This section contains things that will likely have to be re-imported multiple times, and have additions or subtractions made throughout the project.

In [6]:
# Constants to be used throughout the package
%aimport lab
import lab
%aimport lab.index
from lab import index
%aimport lab.breakfast
import lab.breakfast as bk
%aimport lab.breakfast.constants
from lab.breakfast.constants import SEED
# Import the data subdirectories
%aimport lab.breakfast.index
from lab.breakfast.index import (DIR_BREAKFAST, 
                                 DIR_BREAKFAST_DATA, 
                                 DIR_COARSE_SEG, 
                                 DIR_FINE_SEG,
                                 DIR_BK_WEIGHTS,
                                 DIR_BK_CHECKPOINTS,
                                 DIR_BK_LOGS_TB,
                                )
%aimport lab.breakfast.prednet
from lab.breakfast.prednet import PredNet
%aimport lab.breakfast.dataloader
from lab.breakfast.dataloader import Breakfast64DimFVDataset, BreakfastI3DFVDataset

## Previous Pytorch Code

See `wb-4.1.2` for the outputs of the code cells below.

### DataLoader et al

Loading the Dataloader which now has all the I3D data.

```
%%time
ds = BreakfastI3DFVDataset()
```

```
np.random.seed(SEED)

ds_length = len(ds)
indices = list(range(ds_length))
batch_size = 256
n_test = np.maximum(batch_size, 128)

np.random.shuffle(indices)
train_indices, test_indices = indices[n_test:], indices[:n_test]

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = DataLoader(ds, batch_size=batch_size, sampler=train_sampler)
test_loader = DataLoader(ds, batch_size=batch_size, sampler=test_sampler)
```

### Running the Model

```
%%time

num_epochs = 50
n_layers = 4
input_size = 2048
nt = 64 # num of time steps
A_channels = tuple(input_size // (2**i) for i in range(n_layers))
R_channels = tuple(input_size // (2**i) for i in range(n_layers))
lr = 0.000333 # if epoch < 75 else 0.0001

path_checkpoint = DIR_BK_CHECKPOINTS / 'i3d_checkpoint.tar'
path_weights = DIR_BK_WEIGHTS / 'i3d_training.pt'

layer_loss_weights = Variable(torch.FloatTensor([[1.]] + [[0.]]*(n_layers-1)).cuda())
time_loss_weights = 1./(nt - 1) * torch.ones(nt, 1)
time_loss_weights[0] = 0
time_loss_weights = Variable(time_loss_weights.cuda())

model = PredNet(R_channels, A_channels, output_mode='error')
print(model)
if torch.cuda.is_available():
    print('Using GPU.')
    model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

def lr_scheduler(optimizer, epoch):
    if epoch < num_epochs // 2:
        return optimizer
    else:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.0001
        return optimizer
    
train_errors = []

print(f'Running with batch size {batch_size} ({ds_length//batch_size} iterations / epoch)')
    
for epoch in range(num_epochs):
    optimizer = lr_scheduler(optimizer, epoch)
    for batch_idx, (data, path) in enumerate(train_loader):
        data = Variable(data)
        errors = model(data) # batch x n_layers x nt
        loc_batch = errors.size(0)
        errors = torch.mm(errors.view(-1, nt), time_loss_weights) # batch*n_layers x 1
        errors = torch.mm(errors.view(loc_batch, -1), layer_loss_weights)
        errors = torch.mean(errors, axis=0)
        train_errors.append(errors.item())

        optimizer.zero_grad()
        errors.backward()
        optimizer.step()

    if epoch % 2 == 0:
        test_errors = []
        for data, path in test_loader:
            data = Variable(data)
            errors = model(data) # batch x n_layers x nt
            loc_batch = errors.size(0)
            errors = torch.mm(errors.view(-1, nt), time_loss_weights) # batch*n_layers x 1
            errors = torch.mm(errors.view(loc_batch, -1), layer_loss_weights)
            test_errors.append(torch.mean(errors, axis=0).item())
            
        test_error = np.mean(test_errors)
        train_error = np.mean(train_errors)
        print(f'Epoch: {epoch}/{num_epochs}, train_error: {train_error}, '
              f'test error: {test_error}')
        train_errors = []
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'errors': errors,
        }, str(path_checkpoint))
```

```
torch.save(model.state_dict(), str(path_weights))
```

## Pytorch-Lightning

The code below will follow the setup shown on the [intro](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html) documentation page for `pytorch-lightning`. Additionally, the actual model code is lifted from `lab.breakfast.prednet.PredNet` at commit `ab1302580aa631970eac815e335d124a8029ae41`.

In [7]:
# import argparse

# parser = argparse.ArgumentParser()

# parametrize the network
# parser.add_argument('--n_layers', type=int, default=4)
# parser.add_argument('--input_size', type=int, default=2048)
# parser.add_argument('--time_steps', type=int, default=64)
# parser.add_argument('--path_checkpoints', type=str, default=str(DIR_BK_CHECKPOINTS / 'i3d_checkpoint.tar'))
# parser.add_argument('--path_weights', type=str, default=str(DIR_BK_WEIGHTS / 'i3d_training.pt'))
# parser.add_argument('--lr', type=float, default=0.000333)
# parser.add_argument('--output_mode', type=str, default='error')
# parser.add_argument('--n_val', type=int, default=256)
# parser.add_argument('--device', type=str, default='cuda')
# parser.add_argument('--seed', type=int, default=117)
# parser.add_argument('--batch_size', type=int, default=256)

# # add all the available options to the trainer
# parser = pl.Trainer.add_argparse_args(parser)

# args = parser.parse_args()

hps = {
    'n_layers' : 4,
    'input_size' : 2048,
    'time_steps' : 64,
    'path_checkpoints' : DIR_BK_CHECKPOINTS / 'i3d_checkpoint.tar',
    'path_weights' : DIR_BK_WEIGHTS / 'i3d_training.pt',
    'lr' : 0.000333,
    'output_mode' : 'error',
    'device' : 'cuda',
    'n_val' : 256,
    'seed' : 117,
    'batch_size' : 256,
    'n_epochs' : 10,
    'n_workers' : 4,
}

In [8]:
%%time
ds = BreakfastI3DFVDataset()

CPU times: user 10.6 s, sys: 1min 54s, total: 2min 4s
Wall time: 8min 8s


### The Model

In [9]:
import time
import pytorch_lightning as pl
import torch.nn as nn
from torch.nn import functional as F
from functools import wraps
from lab.utils import flatten
from lab.torch.lstm import LSTM
from lab.torch.activations import SatLU


class PredCell(object):
    def __init__(self, parent, layer_num, hps, a_channels, r_channels):
        super().__init__()
        self.parent = parent
        self.layer_num = layer_num
        self.hps = hps
        self.a_channels = a_channels
        self.r_channels = r_channels
        
        # Reccurent
        self.recurrent = LSTM(2 * self.a_channels[self.layer_num],
                              self.r_channels[self.layer_num])
        self.recurrent.reset_parameters()
        
        # Dense
        self.dense = nn.Sequential(
            nn.Linear(self.r_channels[self.layer_num],
                      self.a_channels[self.layer_num]),
            nn.ReLU())
        if self.layer_num == 0:
            self.dense.add_module('satlu', SatLU())
            
        # Update
        if self.layer_num < self.hps['n_layers'] - 1:
            self.update_a = nn.Sequential(
                nn.Linear(
                    2 * self.a_channels[self.layer_num],
                    self.a_channels[self.layer_num + 1]),
                nn.ReLU())
        
        # Build E, R, and H
        self.reset()
        
        # Book keeping
        self.modules = {'recurrent' : self.recurrent, 'dense' : self.dense}
        if hasattr(self, 'update_a'):
            self.modules['update_a'] = self.update_a
        # Hack to appease the pytorch-gods
        for name, module in self.modules.items():
            setattr(self.parent, f'predcell_{self.layer_num}_{name}', module)
            
    def reset(self, batch_size=None):
        batch_size = batch_size or self.hps['batch_size']
        # E, R, and H variables
        self.E = Variable(torch.zeros(
            1,                  # Single time step
            batch_size,
            2 * self.a_channels[self.layer_num])).cuda()
        self.R = Variable(torch.zeros(
            1,                  # Single time step
            batch_size,
            self.r_channels[self.layer_num])).cuda()
        self.H = None
        
class LitPredNet(pl.LightningModule):
    def __init__(self, hps, ds=None):
        super().__init__()
        # Attribute definitions
        self.hps = hps
        self.ds = ds
        self.n_layers = self.hps['n_layers']
        self.output_mode = self.hps['output_mode']
        self.input_size = self.hps['input_size']
        self.time_steps = self.hps['time_steps']
        self.batch_size = self.hps['batch_size']
        
        # Channel sizes
        self.r_channels = [self.input_size // (2**i) 
                           for i in range(self.n_layers)] + [0,] # Convenience
        self.a_channels = [self.input_size // (2**i) 
                           for i in range(self.n_layers)]
        
        # Make sure everything checks out
        default_output_modes = ['prediction', 'error']
        assert self.output_mode in default_output_modes, \
            'Invalid output_mode: ' + str(output_mode)

        # Make all the pred cells
        self.predcells = [PredCell(self,
                                   layer_num,
                                   self.hps,
                                   self.a_channels,
                                   self.r_channels)
                          for layer_num in range(self.n_layers)]
        
        #nn.ParameterList([param for predcell in self.predcells for param in predcell.parameters])

        # How to weight the errors
        # 1 followed by zeros means just minimize error at lowest layer
        self.layer_loss_weights = Variable(torch.FloatTensor(
            [[1.]] + [[0.]]*(self.n_layers-1)).cuda())
        # How much to weight errors at each timestep
        self.time_loss_weights = 1. / (self.time_steps - 1) \
                                 * torch.ones(self.time_steps, 1)
        # Dont count first time step
        self.time_loss_weights[0] = 0
        self.time_loss_weights = Variable(self.time_loss_weights.cuda())
        
        if self.hps['device'] == 'cuda' and torch.cuda.is_available():
            print('Using GPU', flush=True)
            self.cuda()

    def forward(self, input):
        total_error = []
        # Set the expected batch size
        for cell in self.predcells:
            cell.reset(input.size(0))

        for t in range(self.time_steps):
            A = input[:,t,:].unsqueeze(0)
            A = A.type(torch.cuda.FloatTensor)

            # Loop backwards
            for cell in reversed(self.predcells):
                E, R = cell.E, cell.R
                # First time step
                if t == 0:
                    hx = (R, R)
                else:
                    hx = cell.H

                cell.R, cell.H = cell.recurrent(E, hx)

            for cell in self.predcells:
                # Go from R to A_hat
                A_hat = cell.dense(cell.R)

                # Convenience
                if cell.layer_num == 0:
                    frame_prediction = A_hat

                # Split to 2 Es
                pos = F.relu(A_hat - A)
                neg = F.relu(A - A_hat)
                E = torch.cat([pos, neg], 2)
                cell.E = E

                # If not last layer, update stored A
                if cell.layer_num < self.n_layers - 1:
                    A = cell.update_a(E)

            if self.output_mode == 'error':
                mean_error = torch.cat(
                    [torch.mean(cell.E.view(cell.E.size(1), -1),
                                1, keepdim=True)
                     for cell in self.predcells], 1)

                # batch x n_layers
                total_error.append(mean_error)
        
        if self.output_mode == 'error':
            return torch.stack(total_error, 2) # batch x n_layers x nt
        elif self.output_mode == 'prediction':
            return frame_prediction

    def timeit(method):
        """Combination of https://stackoverflow.com/questions/51503672/decorator-for-timeit-timeit-method/51503837#51503837,
        and https://www.geeksforgeeks.org/python-program-to-convert-seconds-into-hours-minutes-and-seconds/"""
        @wraps(method)
        def _time_it(self, *args, **kwargs):
            start = int(round(time.time() * 1000))
            try:
                return method(self, *args, **kwargs)
            finally:
                end_ = int(round(time.time() * 1000)) - start
                if end_ > 1000:
                    time_str = time.strftime("%H:%M:%S", time.gmtime(end_ // 1000))
                    print(f"Total execution time: {time_str}", flush=True)
                
        return _time_it

    @timeit
    def prepare_data(self):
        if self.ds is None:
            print('Loading the i3d data from disk. This can take '
                  'several minutes...', flush=True)
        self.ds = self.ds or BreakfastI3DFVDataset()
        self.ds_length = len(self.ds)
        np.random.seed(self.hps['seed'])
        self.indices = list(range(self.ds_length))
        self.train_sampler = SubsetRandomSampler(self.indices[self.hps['n_val']:])
        self.val_sampler = SubsetRandomSampler(self.indices[:self.hps['n_val']])
        
    def train_dataloader(self):
        return DataLoader(self.ds, 
                          batch_size=self.batch_size, 
                          sampler=self.train_sampler,
                          num_workers=self.hps['n_workers'])
    
    def val_dataloader(self):
        return DataLoader(self.ds, 
                          batch_size=self.batch_size, 
                          sampler=self.val_sampler,
                          num_workers=self.hps['n_workers'])
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hps['lr'])
    
    def _common_step(self, batch, batch_idx, mode):
        data, path = batch
        data = Variable(data)
        errors = self.forward(data) # batch x n_layers x nt
        loc_batch = errors.size(0)
        errors = torch.mm(errors.view(-1, self.time_steps), 
                          self.time_loss_weights) # batch*n_layers x 1
        errors = torch.mm(errors.view(loc_batch, -1), 
                          self.layer_loss_weights)
        errors = torch.mean(errors, axis=0)
        
        if mode == 'train':
            prefix = ''
        else:
            prefix = mode + '_'
            
        logs = {f'{prefix}loss' : errors}
        return {f'{prefix}loss' : errors, 'log' : logs}
    
    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, 'train')
    
    def validation_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, 'val')


### Running the Old-Fashioned Way

In [12]:
model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
batch = None
train_errors, val_errors = None, None
gc.collect()
torch.cuda.empty_cache()

In [13]:
model = LitPredNet(hps, ds=ds)
model.prepare_data()
hps['n_epochs'] = 1
num_epochs = hps['n_epochs']

train_errors, val_errors = [], []
n_iters_per_epoch = model.ds_length // model.batch_size

print(f'Running with batch size {model.batch_size} ({n_iters_per_epoch} iterations / epoch)')

optimizer = model.configure_optimizers()

train_dataloader = model.train_dataloader()
val_dataloader = model.val_dataloader()

for epoch in range(num_epochs):    
    for batch_idx, batch in enumerate(train_dataloader):
        errors = model.training_step(batch, batch_idx)
        train_errors.append(errors['loss'].item())
        optimizer.zero_grad()
        errors['loss'].backward()
        optimizer.step()
        
        if batch_idx % (n_iters_per_epoch // 3) == 0:
            train_error = np.mean(train_errors)
            print(f'Epoch: {epoch+1}, Iteration: {batch_idx}, train_error: {train_error}')
            train_errors = []

    for batch_idx, batch in enumerate(val_dataloader):
        errors = model.val_step(batch, batch_idx)
        val_errors.append(errors['val_loss'].item())

    val_error = np.mean(val_errors)n_epochs
    train_error = np.mean(train_errors)
    print(f'Epoch: {epoch+1}/{num_epochs}, train_error: {train_error}, '
          f'val error: {val_error}')
    train_errors = []
    val_errors = []

Using GPU
Running with batch size 256 (168 iterations / epoch)
Epoch: 1, Iteration: 0, train_error: 5.090212821960449
Epoch: 1, Iteration: 56, train_error: 3.1931983871119365
Epoch: 1, Iteration: 112, train_error: 1.4628551240478243
Epoch: 1/1, train_error: 0.9116474173285745, val error: 0.8152002096176147


Model is backwards compatible.

### Lighting API

In [13]:
model, trainer = NonModelCheckpointne
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
batch = None
train_errors, val_errors = None, None
gc.collect()
torch.cuda.empty_cache()

hps['n_epochs'] = 1

model = LitPredNet(hps, ds=ds)
trainer = pl.Trainer(max_epochs=hps['n_epochs'])
trainer.fit(model)

Using GPU


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…






1

## Adding Intermediate Functionality

### Tensorboard

In [39]:
model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
batch = None
train_errors, val_errors = None, None
gc.collect()
torch.cuda.empty_cache()

logger = pl.loggers.TensorBoardLogger(str(DIR_BK_LOGS_TB), name='litprednet')
model = LitPredNet(hps, ds=ds)
trainer = pl.Trainer(logger=logger,
                     max_epochs=1, 
                   )
trainer.fit(model)

Using GPU


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [53]:
trainer = pl.Trainer(logger=logger,
                     max_epochs=3, 
                   )
trainer.fit(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [11]:
class TBLitPredNet(LitPredNet):
    def _common_step(self, batch, batch_idx, mode):
        data, path = batch
        data = Variable(data)
        errors = self.forward(data) # batch x n_layers x nt
        loc_batch = errors.size(0)
        errors = torch.mm(errors.view(-1, self.time_steps), 
                          self.time_loss_weights) # batch*n_layers x 1
        errors = torch.mm(errors.view(loc_batch, -1), 
                          self.layer_loss_weights)
        errors = torch.mean(errors, axis=0)
        
        if mode == 'train':
            prefix = ''
        else:
            prefix = mode + '_'
            
        self.logger.experiment.add_scalar(f'{prefix.replace("_", "/")}loss', 
                                          errors, self.global_step)
        return {f'{prefix}loss' : errors}

In [77]:
model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
batch = None
train_errors, val_errors = None, None
gc.collect()
torch.cuda.empty_cache()

logger = pl.loggers.TensorBoardLogger(str(DIR_BK_LOGS_TB), name='tblitprednet')
model = TBLitPredNet(hps, ds=ds)

trainer = pl.Trainer(max_epochs=5,
                     train_percent_check=.1, 
                     logger=logger, 
                    )
trainer.fit(model)

Using GPU


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…




1

### Checkpointing

In [79]:
class CPLitPredNet(LitPredNet):
    def _common_step(self, batch, batch_idx, mode):
        data, path = batch
        data = Variable(data)
        errors = self.forward(data) # batch x n_layers x nt
        loc_batch = errors.size(0)
        errors = torch.mm(errors.view(-1, self.time_steps), 
                          self.time_loss_weights) # batch*n_layers x 1
        errors = torch.mm(errors.view(loc_batch, -1), 
                          self.layer_loss_weights)
        errors = torch.mean(errors, axis=0)
        
        if mode == 'train':
            prefix = ''
        else:
            prefix = mode + '_'
            
        self.logger.experiment.add_scalar(f'{prefix}loss', 
                                          errors, self.global_step)
        return {f'{prefix}loss' : errors}

    def validation_epoch_end(self, output):
        out = output[0]
        out['global_step'] = self.global_step
        return out
    
model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
ckpt = None
train_errors, val_errors = None, None
gc.collect()
torch.cuda.empty_cache()

model = CPLitPredNet(hps, ds=ds)
model.name = 'ckpt_test_litprednet'
logger = pl.loggers.TensorBoardLogger(str(DIR_BK_LOGS_TB), name=model.name)

Using GPU


In [None]:
ckpt_dir = DIR_BK_CHECKPOINTS / f'{model.name}_v{logger.version}'
ckpt = pl.callbacks.ModelCheckpoint(
    filepath=str(ckpt_dir / 'bk_i3d_{global_step:05d}_{epoch:03d}_{val_loss:.3f}'),
    verbose=True,
    save_top_k=10,
    period=1
)

In [83]:
trainer = pl.Trainer(default_save_path=str(DIR_BK_CHECKPOINTS),
                     checkpoint_callback=ckpt,
                     epochs=5,
                     logger=logger,
                     )

In [None]:
trainer.fit(model)

Need to remember to use`max_epochs` not `epochs`.

### Hparams + Checkpointing

In [141]:
hps = {
    'model_name' : 'hparams_litprednet',
    'n_layers' : 4,
    'input_size' : 2048,
    'time_steps' : 64,
    'dir_checkpoints' : str(DIR_BK_CHECKPOINTS),
    'dir_weights' : str(DIR_BK_WEIGHTS),
    'dir_logs' : str(DIR_BK_LOGS_TB),
    'lr' : 0.000333,
    'output_mode' : 'error',
    'device' : 'cuda',
    'n_val' : 256,
    'seed' : 117,
    'batch_size' : 256,
    'n_epochs' : 10,
    'n_workers' : 4,
}

In [151]:
import time
import pytorch_lightning as pl
import torch.nn as nn
from torch.nn import functional as F
from functools import wraps
from lab.utils import flatten
from lab.torch.lstm import LSTM
from lab.torch.activations import SatLU


class PredCell(object):
    def __init__(self, parent, layer_num, hparams, a_channels, r_channels):
        super().__init__()
        self.parent = parent
        self.layer_num = layer_num        
        if isinstance(hparams, dict):
            self.hparams = Namespace(**hparams)
        else:
            self.hparams = hparams
        self.a_channels = a_channels
        self.r_channels = r_channels
        
        # Reccurent
        self.recurrent = LSTM(2 * self.a_channels[self.layer_num],
                              self.r_channels[self.layer_num])
        self.recurrent.reset_parameters()
        
        # Dense
        self.dense = nn.Sequential(
            nn.Linear(self.r_channels[self.layer_num],
                      self.a_channels[self.layer_num]),
            nn.ReLU())
        if self.layer_num == 0:
            self.dense.add_module('satlu', SatLU())
            
        # Update
        if self.layer_num < self.hparams.n_layers - 1:
            self.update_a = nn.Sequential(
                nn.Linear(
                    2 * self.a_channels[self.layer_num],
                    self.a_channels[self.layer_num + 1]),
                nn.ReLU())
        
        # Build E, R, and H
        self.reset()
        
        # Book keeping
        self.modules = {'recurrent' : self.recurrent, 'dense' : self.dense}
        if hasattr(self, 'update_a'):
            self.modules['update_a'] = self.update_a
        # Hack to appease the pytorch-gods
        for name, module in self.modules.items():
            setattr(self.parent, f'predcell_{self.layer_num}_{name}', module)
            
    def reset(self, batch_size=None):
        batch_size = batch_size or self.hparams.batch_size
        # E, R, and H variables
        self.E = Variable(torch.zeros(
            1,                  # Single time step
            batch_size,
            2 * self.a_channels[self.layer_num])).cuda()
        self.R = Variable(torch.zeros(
            1,                  # Single time step
            batch_size,
            self.r_channels[self.layer_num])).cuda()
        self.H = None
        
class LitPredNet(pl.LightningModule):
    def __init__(self, hparams, ds=None):
        super().__init__()
        # Attribute definitions
        self.hparams = hparams
        self.n_layers = self.hparams.n_layers
        self.output_mode = self.hparams.output_mode
        self.input_size = self.hparams.input_size
        self.time_steps = self.hparams.time_steps
        self.batch_size = self.hparams.batch_size
        self.ds = ds
        
        # Channel sizes
        self.r_channels = [self.input_size // (2**i) 
                           for i in range(self.n_layers)] + [0,] # Convenience
        self.a_channels = [self.input_size // (2**i) 
                           for i in range(self.n_layers)]
        
        # Make sure everything checks out
        default_output_modes = ['prediction', 'error']
        assert self.output_mode in default_output_modes, \
            'Invalid output_mode: ' + str(output_mode)

        # Make all the pred cells
        self.predcells = [PredCell(self,
                                   layer_num,
                                   self.hparams,
                                   self.a_channels,
                                   self.r_channels)
                          for layer_num in range(self.n_layers)]
        
        #nn.ParameterList([param for predcell in self.predcells for param in predcell.parameters])

        # How to weight the errors
        # 1 followed by zeros means just minimize error at lowest layer
        self.layer_loss_weights = Variable(torch.FloatTensor(
            [[1.]] + [[0.]]*(self.n_layers-1)).cuda())
        # How much to weight errors at each timestep
        self.time_loss_weights = 1. / (self.time_steps - 1) \
                                 * torch.ones(self.time_steps, 1)
        # Dont count first time step
        self.time_loss_weights[0] = 0
        self.time_loss_weights = Variable(self.time_loss_weights.cuda())
        
        if self.hparams.device == 'cuda' and torch.cuda.is_available():
            print('Using GPU', flush=True)
            self.cuda()

    def forward(self, input):
        total_error = []
        # Set the expected batch size
        for cell in self.predcells:
            cell.reset(input.size(0))

        for t in range(self.time_steps):
            A = input[:,t,:].unsqueeze(0)
            A = A.type(torch.cuda.FloatTensor)

            # Loop backwards
            for cell in reversed(self.predcells):
                E, R = cell.E, cell.R
                # First time step
                if t == 0:
                    hx = (R, R)
                else:
                    hx = cell.H

                cell.R, cell.H = cell.recurrent(E, hx)

            for cell in self.predcells:
                # Go from R to A_hat
                A_hat = cell.dense(cell.R)

                # Convenience
                if cell.layer_num == 0:
                    frame_prediction = A_hat

                # Split to 2 Es
                pos = F.relu(A_hat - A)
                neg = F.relu(A - A_hat)
                E = torch.cat([pos, neg], 2)
                cell.E = E

                # If not last layer, update stored A
                if cell.layer_num < self.n_layers - 1:
                    A = cell.update_a(E)

            if self.output_mode == 'error':
                mean_error = torch.cat(
                    [torch.mean(cell.E.view(cell.E.size(1), -1),
                                1, keepdim=True)
                     for cell in self.predcells], 1)

                # batch x n_layers
                total_error.append(mean_error)
        
        if self.output_mode == 'error':
            return torch.stack(total_error, 2) # batch x n_layers x nt
        elif self.output_mode == 'prediction':
            return frame_prediction

    def timeit(method):
        """Combination of https://stackoverflow.com/questions/51503672/decorator-for-timeit-timeit-method/51503837#51503837,
        and https://www.geeksforgeeks.org/python-program-to-convert-seconds-into-hours-minutes-and-seconds/"""
        @wraps(method)
        def _time_it(self, *args, **kwargs):
            start = int(round(time.time() * 1000))
            try:
                return method(self, *args, **kwargs)
            finally:
                end_ = int(round(time.time() * 1000)) - start
                if end_ > 1000:
                    time_str = time.strftime("%H:%M:%S", time.gmtime(end_ // 1000))
                    print(f"Total execution time: {time_str}", flush=True)
                
        return _time_it

    @timeit
    def prepare_data(self):
        if self.ds is None:
            print('Loading the i3d data from disk. This can take '
                  'several minutes...', flush=True)
        self.ds = self.ds or BreakfastI3DFVDataset()
        self.ds_length = len(self.ds)
        np.random.seed(self.hparams.seed)
        self.indices = list(range(self.ds_length))
        self.train_sampler = SubsetRandomSampler(self.indices[self.hparams.n_val:])
        self.val_sampler = SubsetRandomSampler(self.indices[:self.hparams.n_val])
        
    def train_dataloader(self):
        return DataLoader(self.ds, 
                          batch_size=self.batch_size, 
                          sampler=self.train_sampler,
                          num_workers=self.hparams.n_workers)
    
    def val_dataloader(self):
        return DataLoader(self.ds, 
                          batch_size=self.batch_size, 
                          sampler=self.val_sampler,
                          num_workers=self.hparams.n_workers)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    
    def _common_step(self, batch, batch_idx, mode):
        data, path = batch
        data = Variable(data)
        errors = self.forward(data) # batch x n_layers x nt
        loc_batch = errors.size(0)
        errors = torch.mm(errors.view(-1, self.time_steps), 
                          self.time_loss_weights) # batch*n_layers x 1
        errors = torch.mm(errors.view(loc_batch, -1), 
                          self.layer_loss_weights)
        errors = torch.mean(errors, axis=0)
        
        if mode == 'train':
            prefix = ''
        else:
            prefix = mode + '_'
            
        self.logger.experiment.add_scalar(f'{prefix}loss', 
                                          errors, self.global_step)
        return {f'{prefix}loss' : errors}

    def validation_epoch_end(self, output):
        out_dict = {}
        out_dict['val_loss'] = np.mean([out['val_loss'].item() for out in output])
        out_dict['global_step'] = self.global_step
        return out_dict
    
    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, 'train')
    
    def validation_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, 'val')

In [152]:
model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
ckpt = None
train_errors, val_errors = None, None
gc.collect()
torch.cuda.empty_cache()

hparams = Namespace(**hps)
hparams.name = 'hparams_litprednet'

log_dir = Path(hparams.dir_logs) / f'{hparams.name}'
if not log_dir.exists():
    log_dir.mkdir(parents=True)
logger = pl.loggers.TensorBoardLogger(str(log_dir.parent), name=hparams.name)

ckpt_dir = Path(hparams.dir_checkpoints) / f'{hparams.name}_v{logger.version}'
if not ckpt_dir.exists():
    ckpt_dir.mkdir(parents=True)
    
ckpt = pl.callbacks.ModelCheckpoint(
    filepath=str(ckpt_dir / 'bk_i3d_{global_step:05d}_{epoch:03d}_{val_loss:.3f}'),
    verbose=True,
    save_top_k=2,
    period=.25
)

trainer = pl.Trainer(default_save_path=str(DIR_BK_CHECKPOINTS),
                     checkpoint_callback=ckpt,
                     max_epochs=1,
                     logger=logger,
                     val_check_interval=0.25,
                     )

model = LitPredNet(hparams)
model.ds = ds

Using GPU


In [153]:
trainer.fit(model)

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…




1

In [154]:
model = None
gc.collect()
torch.cuda.empty_cache()

model = LitPredNet.load_from_checkpoint(
    str(ckpt_dir / 'bk_i3d_global_step=00167_epoch=000_val_loss=0.872.ckpt'))
model.ds = ds

Using GPU


In [155]:
trainer.fit(model)

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…




1

In [156]:
trainer.max_epochs = 2

In [157]:
trainer.fit(model)

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…




1

In [159]:
model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
ckpt = None
train_errors, val_errors = None, None
gc.collect()
torch.cuda.empty_cache()

hparams = Namespace(**hps)
hparams.name = 'hparams_litprednet'

log_dir = Path(hparams.dir_logs) / f'{hparams.name}'
if not log_dir.exists():
    log_dir.mkdir(parents=True)
logger = pl.loggers.TensorBoardLogger(str(log_dir.parent), name=hparams.name, version=2)

ckpt_dir = Path(hparams.dir_checkpoints) / f'{hparams.name}_v{logger.version}'
if not ckpt_dir.exists():
    ckpt_dir.mkdir(parents=True)
    
ckpt = pl.callbacks.ModelCheckpoint(
    filepath=str(ckpt_dir / 'bk_i3d_{global_step:05d}_{epoch:03d}_{val_loss:.3f}'),
    verbose=True,
    save_top_k=2,
    period=.25
)

trainer = pl.Trainer(default_save_path=str(DIR_BK_CHECKPOINTS),
                     checkpoint_callback=ckpt,
                     max_epochs=1,
                     logger=logger,
                     val_check_interval=0.25,
                     )

model = LitPredNet.load_from_checkpoint(
    str(ckpt_dir / 'bk_i3d_global_step=00671_epoch=001_val_loss=0.329.ckpt'))
model.ds = ds

Using GPU


In [160]:
trainer.fit(model)

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…




1

In [165]:
trainer.fit(model)

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…




1

In [170]:
model.global_step

335

In [174]:
trainer.current_epoch = 0
trainer.fit(model)

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1.0, style=Prog…




1