# 2.1.0 Running PredNet on Fractal Data

This notebook gets the Schapiro fractal data working with the existing PredNet implementation.

## Jupyter Extensions

Load [watermark](https://github.com/rasbt/watermark) to see the state of the machine and environment that's running the notebook. To make sense of the options, take a look at the [usage](https://github.com/rasbt/watermark#usage) section of the readme.

In [1]:
# Load `watermark` extension
%load_ext watermark
# Display the status of the machine and other non-code related info
%watermark -n -m -g -b -t -h

Fri Aug 21 2020 19:42:35 

compiler   : GCC 7.3.0
system     : Linux
release    : 5.4.0-42-generic
machine    : x86_64
processor  : x86_64
CPU cores  : 4
interpreter: 64bit
host name  : apra-x3
Git hash   : 3a3dff83da3f7b5527060c5ee1d9569065b84bea
Git branch : master


Load [autoreload](https://ipython.org/ipython-doc/3/config/extensions/autoreload.html), which reloads modules before each cell is executed. With `%autoreload 1`, only modules imported with `%aimport` are reloaded.

This behavior can be inverted by running `%autoreload 2`, which auto-reloads every module *except* those excluded with `%aimport -<module>`.

In [2]:
# Load `autoreload` extension
%load_ext autoreload
# Set autoreload behavior
%autoreload 1
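
As a rough mental model, `%autoreload 1` automates what `importlib.reload` does by hand for the modules listed with `%aimport`. The sketch below is not how the extension is implemented; it uses a throwaway module called `scratch_mod` (a hypothetical name) to show why an edited source file has no effect until the module is reloaded.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Skip bytecode caching so the rewritten source is always recompiled.
sys.dont_write_bytecode = True

with tempfile.TemporaryDirectory() as tmp:
    module_file = Path(tmp) / "scratch_mod.py"  # hypothetical module name
    module_file.write_text("VALUE = 1\n")
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()

    import scratch_mod
    before = scratch_mod.VALUE

    # Edit the source on disk, as you would in an editor.
    module_file.write_text("VALUE = 1 + 1  # edited\n")
    stale = scratch_mod.VALUE          # the imported module has not changed yet
    importlib.reload(scratch_mod)      # roughly the step autoreload triggers
    after = scratch_mod.VALUE

    sys.path.remove(tmp)

print(before, stale, after)  # 1 1 2
```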

Load `matplotlib` in one of the more `jupyter`-friendly [rich-output modes](https://ipython.readthedocs.io/en/stable/interactive/plotting.html). Options include `inline`, `notebook`, and `gtk`, although not all of them have been confirmed to work in this environment.

In [3]:
# Set the matplotlib mode
%matplotlib inline

## Imports

In [4]:
import gc
import logging
from argparse import Namespace
from pathlib import Path

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch
import pytorch_lightning as pl
from PIL import Image, ImageOps
from torch.utils.data import IterableDataset, DataLoader

%aimport prevseg.constants
import prevseg.constants as const
%aimport prevseg.index
import prevseg.index as index
%aimport prevseg.dataloaders.schapiro
import prevseg.dataloaders.schapiro as sch
%aimport prevseg.schapiro
from prevseg.schapiro import walk, graph
%aimport prevseg.models.prednet
import prevseg.models.prednet as prednet

# Keep track of versions of everything
%watermark -v -iv

networkx          2.4
torch             1.6.0
numpy             1.19.1
prevseg           0+untagged.47.g3a3dff8.dirty
pytorch_lightning 0.8.5
logging           0.5.1.2
PIL.Image         7.2.0
CPython 3.8.5
IPython 7.16.1


## Adding Fixed Path-Lengths

In [5]:
max_steps = 2
epochs = 2
batch_size = 2

iter_ds = sch.ShapiroFractalsDataset(batch_size=batch_size, max_steps=max_steps)
loader = DataLoader(iter_ds, batch_size=None)
for _ in range(epochs):
    for i, batch in enumerate(loader):
        print(i, batch.shape)
        if i > max_steps+5:
            print('bad')
            break

Created mapping as follows:
{0: '39', 1: '78', 2: '91', 3: '36', 4: '9', 5: '55', 6: '23', 7: '101', 8: '52', 9: '35', 10: '49', 11: '100', 12: '46', 13: '10', 14: '59'}
0 torch.Size([2, 2, 128, 128])
1 torch.Size([2, 2, 128, 128])
2 torch.Size([2, 2, 128, 128])
3 torch.Size([2, 2, 128, 128])
4 torch.Size([2, 2, 128, 128])
5 torch.Size([2, 2, 128, 128])
6 torch.Size([2, 2, 128, 128])
7 torch.Size([2, 2, 128, 128])
8 torch.Size([2, 2, 128, 128])
bad
0 torch.Size([2, 2, 128, 128])
1 torch.Size([2, 2, 128, 128])
2 torch.Size([2, 2, 128, 128])
3 torch.Size([2, 2, 128, 128])
4 torch.Size([2, 2, 128, 128])
5 torch.Size([2, 2, 128, 128])
6 torch.Size([2, 2, 128, 128])
7 torch.Size([2, 2, 128, 128])
8 torch.Size([2, 2, 128, 128])
bad
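
The dataset draws its sequences from walks on the graph returned by `graph.schapiro_graph`. The `prevseg` implementation is not shown in this notebook, so the sketch below rests on an assumption: it builds a ring of five-node communities in which every node has four neighbours, and then takes a fixed-length random walk on it. `community_graph` and `random_walk` are illustrative names, not part of `prevseg`.

```python
import random

def community_graph(n_communities=3, size=5):
    """Adjacency sets for a ring of densely connected communities (assumed layout)."""
    adj = {node: set() for node in range(n_communities * size)}
    for c in range(n_communities):
        nodes = list(range(c * size, (c + 1) * size))
        first, last = nodes[0], nodes[-1]
        # Connect every pair inside the community except its two boundary nodes.
        for i in nodes:
            for j in nodes:
                if i < j and (i, j) != (first, last):
                    adj[i].add(j)
                    adj[j].add(i)
        # Bridge this community's last node to the next community's first node.
        next_first = ((c + 1) % n_communities) * size
        adj[last].add(next_first)
        adj[next_first].add(last)
    return adj

def random_walk(adj, steps, rng):
    """Yield exactly `steps` nodes, each adjacent to the one before it."""
    node = rng.choice(sorted(adj))
    for _ in range(steps):
        yield node
        node = rng.choice(sorted(adj[node]))

adj = community_graph()
path = list(random_walk(adj, steps=10, rng=random.Random(0)))
print(len(adj), path)
```

With `n_communities=3` the sketch has 15 nodes, which matches the 15 entries in the mapping printed above.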


In [6]:
class ShapiroFractalsDataset(IterableDataset):
    modes = set(('random', 'euclidean', 'hamiltonian'))
    def __init__(self, batch_size=32, n_pentagons=3, max_steps=128, n_paths=128, mapping=None, 
                 mode='random', debug=False):
        self.batch_size = batch_size
        self.n_pentagons = n_pentagons
        self.max_steps = max_steps
        self.n_paths = n_paths
        self.mapping = mapping
        self.mode = mode
        self.debug = debug
        assert self.mode in self.modes
        
        self.G = graph.schapiro_graph(n_pentagons=n_pentagons)
        
        self.load_node_stimuli()
        
        self.mapping = {node : path.stem
                        for node, path in zip(range(len(self.G.nodes)),
                                              self.paths_data)}
        print(f'Created mapping as follows:\n{self.mapping}')
        
        if self.debug:
            self.sample_transform = lambda sample : sample
        else:
            self.sample_transform = lambda sample : self.array_data[sample]            
        
    def load_node_stimuli(self):
        # Load the fractal images into memory
        assert index.DIR_SCH_FRACTALS.exists()
        if self.mapping:
            self.paths_data = [index.DIR_SCH_FRACTALS / (name+'.tiff')
                               for name in self.mapping.values()]
        else:
            paths_data = list(index.DIR_SCH_FRACTALS.iterdir())
            np.random.shuffle(paths_data)
            self.paths_data = paths_data[:5*self.n_pentagons]
        self.array_data = np.array(
            [np.array(ImageOps.grayscale(Image.open(str(path))))
             for path in self.paths_data])
        
    def iter_single_sample(self): 
        if self.mode == 'random':
            iter_walk = walk.walk_random(self.G, steps=self.max_steps)
        elif self.mode == 'euclidean':
            iter_walk = walk.walk_euclidean(self.G)
        elif self.mode == 'hamiltonian':
            iter_walk = walk.walk_hamiltonian(self.G)
        
        for sample in iter_walk:
            yield self.sample_transform(sample[0])
        
    def iter_batch_sample(self):
        yield from zip(*[self.iter_single_sample() for _ in range(self.batch_size)])
        
    def iter_batch_dataset(self):       
        for _ in range(self.n_paths):
            yield np.moveaxis(np.array(list(self.iter_batch_sample())), 0, 1)
        
    def __iter__(self):
        return self.iter_batch_dataset()

    
max_steps = 10
epochs = 2
batch_size = 2
n_paths = 5

iter_ds = ShapiroFractalsDataset(batch_size=batch_size, max_steps=max_steps, n_paths=n_paths)
loader = DataLoader(iter_ds, batch_size=None)
for _ in range(epochs):
    for i, batch in enumerate(loader):
        print(i, batch.shape)
        if i > max_steps+5:
            print('bad')
            break
            
mapping = iter_ds.mapping

Created mapping as follows:
{0: '53', 1: '98', 2: '69', 3: '65', 4: '85', 5: '33', 6: '75', 7: '20', 8: '16', 9: '28', 10: '64', 11: '72', 12: '76', 13: '99', 14: '54'}
0 torch.Size([2, 10, 128, 128])
1 torch.Size([2, 10, 128, 128])
2 torch.Size([2, 10, 128, 128])
3 torch.Size([2, 10, 128, 128])
4 torch.Size([2, 10, 128, 128])
0 torch.Size([2, 10, 128, 128])
1 torch.Size([2, 10, 128, 128])
2 torch.Size([2, 10, 128, 128])
3 torch.Size([2, 10, 128, 128])
4 torch.Size([2, 10, 128, 128])
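
The batch layout comes from two transposes in `iter_batch_sample` and `iter_batch_dataset`. The sketch below replays them on plain labels instead of images (similar to what the dataset yields with `debug=True`), with a hypothetical `fake_walk` generator standing in for `walk.walk_random`.

```python
def fake_walk(walk_id, steps):
    """Stand-in for one graph walk: yields `steps` labelled samples."""
    for t in range(steps):
        yield f"w{walk_id}t{t}"

batch_size, max_steps = 2, 3

# zip(*walks) advances all walks in lockstep: one tuple per time step.
time_major = list(zip(*[fake_walk(w, max_steps) for w in range(batch_size)]))

# Swapping the first two axes (np.moveaxis(..., 0, 1) in the dataset)
# regroups the steps so that each row is one complete walk.
batch_major = [list(row) for row in zip(*time_major)]

print(time_major)   # [('w0t0', 'w1t0'), ('w0t1', 'w1t1'), ('w0t2', 'w1t2')]
print(batch_major)  # [['w0t0', 'w0t1', 'w0t2'], ['w1t0', 'w1t1', 'w1t2']]
```

With image stimuli each label is replaced by a 128 x 128 array, which is why the batches above have shape `(batch_size, max_steps, 128, 128)`.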


In [7]:
max_steps = 10
epochs = 2
batch_size = 2
n_paths = 5

class ShapiroResnetEmbeddingDataset(ShapiroFractalsDataset):
    def load_node_stimuli(self):
        # Load the precomputed fractal embeddings into memory
        assert index.DIR_SCH_FRACTALS_EMB.exists()
        if self.mapping:
            self.paths_data = [index.DIR_SCH_FRACTALS_EMB / (name+'.npy')
                               for name in self.mapping.values()]
        else:
            paths_data = list(index.DIR_SCH_FRACTALS_EMB.iterdir())
            np.random.shuffle(paths_data)
            self.paths_data = paths_data[:5*self.n_pentagons]
        self.array_data = np.array(
            [np.array(np.load(str(path)))
             for path in self.paths_data])    

iter_ds = ShapiroResnetEmbeddingDataset(
    batch_size=batch_size, 
    max_steps=max_steps, 
    n_paths=n_paths,
    mapping=mapping,
)
loader = DataLoader(iter_ds, batch_size=None)
for _ in range(epochs):
    for i, batch in enumerate(loader):
        print(i, batch.shape)
        if i > max_steps+5:
            print('bad')
            break

Created mapping as follows:
{0: '53', 1: '98', 2: '69', 3: '65', 4: '85', 5: '33', 6: '75', 7: '20', 8: '16', 9: '28', 10: '64', 11: '72', 12: '76', 13: '99', 14: '54'}
0 torch.Size([2, 10, 2048])
1 torch.Size([2, 10, 2048])
2 torch.Size([2, 10, 2048])
3 torch.Size([2, 10, 2048])
4 torch.Size([2, 10, 2048])
0 torch.Size([2, 10, 2048])
1 torch.Size([2, 10, 2048])
2 torch.Size([2, 10, 2048])
3 torch.Size([2, 10, 2048])
4 torch.Size([2, 10, 2048])
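
Passing `mapping=mapping` is what keeps the two datasets aligned: node `i` refers to the same fractal whether the model sees raw images or embeddings. Here is a minimal sketch of that lookup, with made-up directory names standing in for `index.DIR_SCH_FRACTALS` and `index.DIR_SCH_FRACTALS_EMB`, and a hypothetical helper `stimulus_paths`:

```python
from pathlib import Path

# Hypothetical locations; the real ones come from prevseg.index.
DIR_IMAGES = Path("fractals")
DIR_EMBEDDINGS = Path("fractals_emb")

def stimulus_paths(mapping, directory, suffix):
    """Return one file path per graph node, in node order."""
    return [directory / f"{mapping[node]}{suffix}" for node in sorted(mapping)]

mapping = {0: "53", 1: "98", 2: "69"}  # first three entries of the mapping above
image_paths = stimulus_paths(mapping, DIR_IMAGES, ".tiff")
embedding_paths = stimulus_paths(mapping, DIR_EMBEDDINGS, ".npy")

for node, (img, emb) in enumerate(zip(image_paths, embedding_paths)):
    print(node, img.name, emb.name)
```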


## PredNet

In [8]:
class PredNetTracked(prednet.PredNetTracked):
    @prednet.PredNet.timeit
    def prepare_data(self):
        self.ds = self.ds or ShapiroResnetEmbeddingDataset(
            batch_size=self.batch_size, 
            n_pentagons=self.hparams.n_pentagons, 
            max_steps=self.hparams.max_steps, 
            n_paths=self.hparams.n_paths,
            debug=self.hparams.debug)
        self.ds_val = ShapiroResnetEmbeddingDataset(
            batch_size=self.batch_size, 
            n_pentagons=self.hparams.n_pentagons,
            n_paths=1,
            mapping=self.ds.mapping,
            mode='euclidean',
            debug=self.hparams.debug)
        
#         n_test, n_val = self.hparams.n_test, self.hparams.n_val
        
    def train_dataloader(self):
        return DataLoader(self.ds, 
                          batch_size=None,
                          num_workers=self.hparams.n_workers)
    
    def val_dataloader(self):
        return DataLoader(self.ds_val, 
                          batch_size=None,
                          num_workers=self.hparams.n_workers)
    
    def _common_step(self, batch, batch_idx, mode):
        data = batch
        errors = self.forward(data) # batch x n_layers x nt

        loc_batch = errors.size(0)
        errors = torch.mm(errors.view(-1, self.time_steps), 
                          self.time_loss_weights) # batch*n_layers x 1
        errors = torch.mm(errors.view(loc_batch, -1), 
                          self.layer_loss_weights)
        errors = torch.mean(errors, axis=0)
        
        if mode == 'train':
            prefix = ''
        else:
            prefix = mode + '_'
            
        self.logger.experiment.add_scalar(f'{prefix}loss', 
                                          errors, self.global_step)
        return {f'{prefix}loss' : errors}
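
`_common_step` collapses the error tensor in three stages: a weighted sum over time, a weighted sum over layers, then a mean over the batch. The plain-Python sketch below spells out the same arithmetic on a tiny example. The weight values are made up for illustration; the real `time_loss_weights` and `layer_loss_weights` are defined in `prevseg.models.prednet`, which is not shown here.

```python
def weighted_loss(errors, time_weights, layer_weights):
    """errors[b][l][t] -> scalar, mirroring the two torch.mm calls and the mean."""
    per_sample = []
    for sample in errors:  # loop over the batch
        # Stage 1: weighted sum over time for each layer.
        per_layer = [sum(e * w for e, w in zip(layer, time_weights))
                     for layer in sample]
        # Stage 2: weighted sum over layers.
        per_sample.append(sum(e * w for e, w in zip(per_layer, layer_weights)))
    # Stage 3: mean over the batch.
    return sum(per_sample) / len(per_sample)

# Batch of 2 samples, 2 layers, 3 time steps.
errors = [
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    [[0.0, 2.0, 4.0], [2.0, 2.0, 2.0]],
]
time_weights = [0.0, 0.5, 0.5]  # illustrative: ignore the first time step
layer_weights = [1.0, 0.0]      # illustrative: only the lowest layer counts

loss = weighted_loss(errors, time_weights, layer_weights)
print(loss)  # 2.75
```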

In [64]:
model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
ckpt = None
train_errors, val_errors = None, None
res = None
gc.collect()
torch.cuda.empty_cache()

ModelClass = PredNetTracked
hparams = const.DEFAULT_HPARAMS
hparams.n_layers = 2
hparams.batch_size = 256 + 128
hparams.max_steps = 128
hparams.n_paths = 16
hparams.n_pentagons = 3
hparams.time_steps = hparams.max_steps
hparams.exp_name = 'schapiro_test'
hparams.name = f'{ModelClass.name}_{hparams.exp_name}'
hparams.debug = False
hparams.n_workers = 2
hparams.lr = 0.001

log_dir = Path(hparams.dir_logs) / f'{hparams.name}'
log_dir.mkdir(parents=True, exist_ok=True)
logger = pl.loggers.TensorBoardLogger(str(log_dir.parent), name=hparams.name)

ckpt_dir = Path(hparams.dir_checkpoints) / f'{hparams.name}_v{logger.version}'
ckpt_dir.mkdir(parents=True, exist_ok=True)
    
ckpt = pl.callbacks.ModelCheckpoint(
    filepath=str(ckpt_dir / (hparams.exp_name+'_{global_step:05d}_{epoch:03d}_{val_loss:.3f}')),
    verbose=True,
    save_top_k=1,
)

trainer = pl.Trainer(checkpoint_callback=ckpt,
                     max_epochs=20,
                     logger=logger,
                     gpus=1
                     )

model = ModelClass(hparams)
model.ds = None

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
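
The `filepath` template explains the checkpoint names that appear in the training log: each formatted field ends up prefixed with its own name. The helper below is a simplified approximation of that expansion written for this notebook, not Lightning's own code, and `expand_checkpoint_name` is a hypothetical name.

```python
import re

def expand_checkpoint_name(template, metrics):
    """Approximate expansion: turn '{key:fmt}' into 'key=<formatted value>'."""
    # Insert 'key=' in front of every '{key...}' placeholder.
    prefixed = re.sub(r"\{(\w+)", r"\1={\1", template)
    return prefixed.format(**metrics) + ".ckpt"

template = "schapiro_test_{global_step:05d}_{epoch:03d}_{val_loss:.3f}"
name = expand_checkpoint_name(
    template, {"global_step": 31, "epoch": 0, "val_loss": 0.15527}
)
print(name)  # schapiro_test_global_step=00031_epoch=000_val_loss=0.155.ckpt
```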


In [65]:
trainer.fit(model)


  | Name                 | Type       | Params
----------------------------------------------------
0 | predcell_0_recurrent | LSTM       | 67 M  
1 | predcell_0_dense     | Sequential | 4 M   
2 | predcell_0_update_a  | Sequential | 4 M   
3 | predcell_1_recurrent | LSTM       | 12 M  
4 | predcell_1_dense     | Sequential | 1 M   


Created mapping as follows:
{0: '1', 1: '60', 2: '95', 3: '100', 4: '14', 5: '2', 6: '63', 7: '58', 8: '96', 9: '55', 10: '99', 11: '50', 12: '7', 13: '89', 14: '12'}
Created mapping as follows:
{0: '1', 1: '60', 2: '95', 3: '100', 4: '14', 5: '2', 6: '63', 7: '58', 8: '96', 9: '55', 10: '99', 11: '50', 12: '7', 13: '89', 14: '12'}
Epoch 00000: val_loss reached 0.15527 (best 0.15527), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v47/schapiro_test_global_step=00031_epoch=000_val_loss=0.155.ckpt as top 1
Epoch 1: : 34it [03:53,  6.88s/it, loss=0.157, v_num=47]
Epoch 00001: val_loss reached 0.15085 (best 0.15085), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v47/schapiro_test_global_step=00063_epoch=001_val_loss=0.151.ckpt as top 1
Epoch 2: : 34it [03:53,  6.86s/it, loss=0.152, v_num=47]
Epoch 00002: val_loss reached 0.14926 (best 0.14926), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v47/schapiro_test_global_step=00095_epoch=002_val_loss=0.149.ckpt as top 1
Epoch 3: : 34it [03:52,  6.84s/it, loss=0.149, v_num=47]
Epoch 00003: val_loss reached 0.14920 (best 0.14920), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v47/schapiro_test_global_step=00127_epoch=003_val_loss=0.149.ckpt as top 1
Epoch 4: : 34it [03:52,  6.83s/it, loss=0.149, v_num=47]
Epoch 00004: val_loss  was not in top 1
Epoch 5: : 34it [03:50,  6.77s/it, loss=0.149, v_num=47]
Epoch 00005: val_loss  was not in top 1
Epoch 6: : 34it [03:50,  6.78s/it, loss=0.149, v_num=47]
Epoch 00006: val_loss  was not in top 1
Epoch 7: : 34it [03:50,  6.77s/it, loss=0.149, v_num=47]
Epoch 00007: val_loss  was not in top 1
Epoch 8: : 34it [03:50,  6.77s/it, loss=0.149, v_num=47]
Epoch 00008: val_loss reached 0.14918 (best 0.14918), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v47/schapiro_test_global_step=00287_epoch=008_val_loss=0.149.ckpt as top 1
Epoch 9: : 34it [03:52,  6.84s/it, loss=0.149, v_num=47]
Epoch 00009: val_loss  was not in top 1
Epoch 10: : 34it [03:49,  6.74s/it, loss=0.149, v_num=47]
Epoch 00010: val_loss  was not in top 1
Epoch 11: : 34it [03:49,  6.75s/it, loss=0.149, v_num=47]
Epoch 00011: val_loss reached 0.14911 (best 0.14911), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v47/schapiro_test_global_step=00383_epoch=011_val_loss=0.149.ckpt as top 1
Epoch 12: : 34it [03:52,  6.84s/it, loss=0.149, v_num=47]
Epoch 00012: val_loss reached 0.14901 (best 0.14901), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v47/schapiro_test_global_step=00415_epoch=012_val_loss=0.149.ckpt as top 1
Epoch 13: : 34it [03:52,  6.83s/it, loss=0.149, v_num=47]
Epoch 00013: val_loss reached 0.14899 (best 0.14899), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v47/schapiro_test_global_step=00447_epoch=013_val_loss=0.149.ckpt as top 1
Epoch 14: : 34it [03:52,  6.83s/it, loss=0.149, v_num=47]
Epoch 00014: val_loss  was not in top 1
Epoch 15: : 34it [03:49,  6.75s/it, loss=0.149, v_num=47]
Epoch 00015: val_loss reached 0.14899 (best 0.14899), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v47/schapiro_test_global_step=00511_epoch=015_val_loss=0.149.ckpt as top 1
Epoch 16: : 34it [03:51,  6.82s/it, loss=0.149, v_num=47]
Epoch 00016: val_loss  was not in top 1
Epoch 17: : 34it [03:49,  6.75s/it, loss=0.149, v_num=47]
Epoch 00017: val_loss  was not in top 1
Epoch 18: : 34it [03:49,  6.76s/it, loss=0.149, v_num=47]
Epoch 00018: val_loss  was not in top 1
Epoch 19: : 34it [03:50,  6.78s/it, loss=0.149, v_num=47]
Epoch 00019: val_loss  was not in top 1
Epoch 20: : 34it [03:49,  6.76s/it, loss=0.149, v_num=47]


1
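
The checkpoint names in the log show `global_step` advancing by 32 per epoch (31, 63, 95, ...) even though `n_paths = 16`, and the progress bar reaches 34 iterations once validation is included. A likely cause is `n_workers = 2`: with an `IterableDataset`, every `DataLoader` worker iterates its own full copy of the dataset, so the 16 training paths are produced twice and the single validation path becomes 2. If that duplication is unwanted, one option is to give each worker a share of `n_paths`. The helper below is a sketch under that assumption; `paths_for_worker` is a hypothetical name, and in a real `__iter__` the worker id and worker count would come from `torch.utils.data.get_worker_info()`.

```python
def paths_for_worker(n_paths, worker_id, num_workers):
    """Number of paths one worker should generate so all shares sum to n_paths."""
    base, extra = divmod(n_paths, num_workers)
    # The first `extra` workers take one additional path each.
    return base + (1 if worker_id < extra else 0)

# Two workers splitting the 16 training paths from the run above.
shares = [paths_for_worker(16, worker_id, 2) for worker_id in range(2)]
print(shares, sum(shares))  # [8, 8] 16

# Uneven case: 5 paths over 2 workers.
uneven = [paths_for_worker(5, worker_id, 2) for worker_id in range(2)]
print(uneven)  # [3, 2]
```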

The best learning rate was 0.000333; the next run uses this value.

## Hierarchical PredNet

In [20]:
model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
ckpt = None
train_errors, val_errors = None, None
res = None
gc.collect()
torch.cuda.empty_cache()

ModelClass = PredNetTracked
hparams = const.DEFAULT_HPARAMS
hparams.n_layers = 3
hparams.batch_size = 256 + 32
hparams.max_steps = 128
hparams.n_paths = 16
hparams.n_pentagons = 3
hparams.time_steps = hparams.max_steps
hparams.exp_name = 'schapiro_test'
hparams.name = f'{ModelClass.name}_{hparams.exp_name}'
hparams.debug = False
hparams.n_workers = 2
hparams.lr = 0.000333

log_dir = Path(hparams.dir_logs) / f'{hparams.name}'
log_dir.mkdir(parents=True, exist_ok=True)
logger = pl.loggers.TensorBoardLogger(str(log_dir.parent), name=hparams.name)

ckpt_dir = Path(hparams.dir_checkpoints) / f'{hparams.name}_v{logger.version}'
ckpt_dir.mkdir(parents=True, exist_ok=True)
    
ckpt = pl.callbacks.ModelCheckpoint(
    filepath=str(ckpt_dir / (hparams.exp_name+'_{global_step:05d}_{epoch:03d}_{val_loss:.3f}')),
    verbose=True,
    save_top_k=1,
)

trainer = pl.Trainer(checkpoint_callback=ckpt,
                     max_epochs=25,
                     logger=logger,
                     gpus=1
                     )

model = ModelClass(hparams)
model.ds = None

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [21]:
trainer.fit(model)


  | Name                 | Type       | Params
----------------------------------------------------
0 | predcell_0_recurrent | LSTM       | 67 M  
1 | predcell_0_dense     | Sequential | 4 M   
2 | predcell_0_update_a  | Sequential | 4 M   
3 | predcell_1_recurrent | LSTM       | 16 M  
4 | predcell_1_dense     | Sequential | 1 M   
5 | predcell_1_update_a  | Sequential | 1 M   
6 | predcell_2_recurrent | LSTM       | 3 M   
7 | predcell_2_dense     | Sequential | 262 K 


Created mapping as follows:
{0: '1', 1: '60', 2: '95', 3: '100', 4: '14', 5: '2', 6: '63', 7: '58', 8: '96', 9: '55', 10: '99', 11: '50', 12: '7', 13: '89', 14: '12'}
Created mapping as follows:
{0: '1', 1: '60', 2: '95', 3: '100', 4: '14', 5: '2', 6: '63', 7: '58', 8: '96', 9: '55', 10: '99', 11: '50', 12: '7', 13: '89', 14: '12'}
Epoch 00000: val_loss reached 0.11156 (best 0.11156), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v52/schapiro_test_global_step=00031_epoch=000_val_loss=0.112.ckpt as top 1
Epoch 1: : 34it [03:52,  6.83s/it, loss=0.116, v_num=52]
Epoch 00001: val_loss reached 0.10622 (best 0.10622), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v52/schapiro_test_global_step=00063_epoch=001_val_loss=0.106.ckpt as top 1
Epoch 2: : 34it [03:54,  6.90s/it, loss=0.107, v_num=52]
Epoch 00002: val_loss reached 0.10588 (best 0.10588), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v52/schapiro_test_global_step=00095_epoch=002_val_loss=0.106.ckpt as top 1
Epoch 3: : 34it [03:54,  6.90s/it, loss=0.106, v_num=52]
Epoch 00003: val_loss  was not in top 1
Epoch 4: : 34it [03:53,  6.86s/it, loss=0.106, v_num=52]
Epoch 00004: val_loss reached 0.10575 (best 0.10575), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v52/schapiro_test_global_step=00159_epoch=004_val_loss=0.106.ckpt as top 1
Epoch 5: : 34it [03:56,  6.96s/it, loss=0.106, v_num=52]
Epoch 00005: val_loss reached 0.10571 (best 0.10571), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v52/schapiro_test_global_step=00191_epoch=005_val_loss=0.106.ckpt as top 1
Epoch 6: : 34it [03:55,  6.94s/it, loss=0.106, v_num=52]
Epoch 00006: val_loss reached 0.10567 (best 0.10567), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v52/schapiro_test_global_step=00223_epoch=006_val_loss=0.106.ckpt as top 1
Epoch 7: : 34it [03:55,  6.92s/it, loss=0.106, v_num=52]
Epoch 00007: val_loss reached 0.10558 (best 0.10558), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v52/schapiro_test_global_step=00255_epoch=007_val_loss=0.106.ckpt as top 1
Epoch 8: : 34it [03:55,  6.92s/it, loss=0.106, v_num=52]
Epoch 9: : 24it [03:01,  7.55s/it, loss=0.106, v_num=52]

Traceback (most recent call last):
  File "/home/apra/miniconda3/envs/g2/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/apra/miniconda3/envs/g2/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/apra/miniconda3/envs/g2/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/home/apra/miniconda3/envs/g2/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/home/apra/miniconda3/envs/g2/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/apra/miniconda3/envs/g2/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/apra/miniconda3/envs/g2/lib/python3.8/multiprocessing/co




1