# 2.0.3 Running PredNet on Fractal Data

This notebook gets the fractal data working with the existing PredNet implementation.

## Jupyter Extensions

Load [watermark](https://github.com/rasbt/watermark) to see the state of the machine and environment that's running the notebook. To make sense of the options, take a look at the [usage](https://github.com/rasbt/watermark#usage) section of the readme.

In [1]:
# Load `watermark` extension
%load_ext watermark
# Display the status of the machine and other non-code related info
%watermark -n -m -g -b -t -h

Thu Aug 20 2020 10:58:46 

compiler   : GCC 7.3.0
system     : Linux
release    : 5.4.0-42-generic
machine    : x86_64
processor  : x86_64
CPU cores  : 4
interpreter: 64bit
host name  : apra-x3
Git hash   : 45cc25f141a3ec65725320f89cf953e8381ed860
Git branch : master


Load [autoreload](https://ipython.org/ipython-doc/3/config/extensions/autoreload.html). With `%autoreload 1`, modules marked with `%aimport` are reloaded every time before code is executed.

This behavior can be inverted by running `%autoreload 2`, which auto-reloads every module *except* those excluded with `%aimport -<module>`.

In [2]:
# Load `autoreload` extension
%load_ext autoreload
# Set autoreload behavior
%autoreload 1

Load `matplotlib` in one of the more `jupyter`-friendly [rich-output modes](https://ipython.readthedocs.io/en/stable/interactive/plotting.html). Some options (that may or may not have worked) are `inline`, `notebook`, and `gtk`.

In [3]:
# Set the matplotlib mode
%matplotlib inline

## Imports

In [4]:
import gc
import logging
from argparse import Namespace
from pathlib import Path

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch
import pytorch_lightning as pl
from PIL import Image, ImageOps
from torch.utils.data import IterableDataset, DataLoader

%aimport prevseg.constants
import prevseg.constants as const
%aimport prevseg.index
import prevseg.index as index
%aimport prevseg.dataloaders.schapiro
import prevseg.dataloaders.schapiro as sch
%aimport prevseg.schapiro
from prevseg.schapiro import walk, graph
%aimport prevseg.models.prednet
import prevseg.models.prednet as prednet

# Keep track of versions of everything
%watermark -v -iv

prevseg           0+untagged.40.g45cc25f.dirty
pytorch_lightning 0.8.5
PIL.Image         7.2.0
torch             1.6.0
networkx          2.4
logging           0.5.1.2
numpy             1.19.1
CPython 3.8.5
IPython 7.16.1


## Adding Fixed Path-Lengths

In [5]:
max_steps = 2
epochs = 2
batch_size = 2

iter_ds = sch.ShapiroFractalsDataset(batch_size=batch_size, max_steps=max_steps)
loader = DataLoader(iter_ds, batch_size=None)
for _ in range(epochs):
    for i, batch in enumerate(loader):
        print(i, batch.shape)
        if i > max_steps+5:
            print('bad')
            break

Created mapping as follows:
{0: '94', 1: '100', 2: '71', 3: '37', 4: '28', 5: '67', 6: '57', 7: '27', 8: '2', 9: '36', 10: '78', 11: '42', 12: '41', 13: '13', 14: '77'}
0 torch.Size([2, 128, 128])
1 torch.Size([2, 128, 128])
0 torch.Size([2, 128, 128])
1 torch.Size([2, 128, 128])
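
The `max_steps` argument fixes the number of steps in each random walk over the graph. As a rough sketch of what a fixed-length walk generator could look like, the following uses only the standard library and a plain adjacency dict. The function body, the `adjacency` and `seed` parameters, the toy ring graph, and the `(node, next_node)` tuple format are all assumptions for illustration; this is not the `prevseg.schapiro.walk` implementation.

```python
import random


def walk_random(adjacency, steps, seed=None):
    """Yield exactly `steps` (node, next_node) pairs from a random walk.

    Hypothetical stand-in: `adjacency` maps each node to a list of neighbors.
    """
    rng = random.Random(seed)
    node = rng.choice(sorted(adjacency))
    for _ in range(steps):
        next_node = rng.choice(adjacency[node])
        yield node, next_node
        node = next_node


# Toy graph (assumption): a 5-node ring, not the Schapiro community graph.
ring = {i: [(i - 1) % 5, (i + 1) % 5] for i in range(5)}

pairs = list(walk_random(ring, steps=10, seed=0))
path = [node for node, _ in pairs]
print(len(path))
```

Because the generator stops after a fixed number of steps, every path has the same length and paths can later be stacked into a single array.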


In [6]:
class ShapiroFractalsDataset(IterableDataset):
    def __init__(self, batch_size=32, n_pentagons=3, max_steps=128, n_paths=128, mapping=None):
        self.batch_size = batch_size
        self.n_pentagons = n_pentagons
        self.max_steps = max_steps
        self.n_paths = n_paths
        self.mapping = mapping
        
        self.G = graph.schapiro_graph(n_pentagons=n_pentagons)
        
        self.load_node_stimuli()
        
        self.mapping = {node : path.stem
                        for node, path in zip(range(len(self.G.nodes)),
                                              self.paths_data)}
        print(f'Created mapping as follows:\n{self.mapping}')
        
    def load_node_stimuli(self):
        # Load the fractal images into memory
        assert index.DIR_SCH_FRACTALS.exists()
        if self.mapping:
            self.paths_data = [index.DIR_SCH_FRACTALS / (name+'.tiff')
                               for name in self.mapping.values()]
        else:
            paths_data = list(index.DIR_SCH_FRACTALS.iterdir())
            np.random.shuffle(paths_data)
            self.paths_data = paths_data[:5*self.n_pentagons]
        self.array_data = np.array(
            [np.array(ImageOps.grayscale(Image.open(str(path))))
             for path in self.paths_data])
        
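    def iter_batch_sample(self):
        # One generator per batch element; each yields the stimulus array
        # indexed by `sample[0]` along an independent random walk.
        def iter_single_sample():
            for sample in walk.walk_random(self.G, steps=self.max_steps):
                yield self.array_data[sample[0]]
        # zip advances all walks in lockstep: one tuple of `batch_size`
        # arrays per step.
        yield from zip(*[iter_single_sample() for _ in range(self.batch_size)])

    def iter_batch_dataset(self):
        for _ in range(self.n_paths):
            # Stack to (steps, batch, ...) and move the batch axis first.
            yield np.moveaxis(np.array(list(self.iter_batch_sample())), 0, 1)

    def __iter__(self):
        return self.iter_batch_dataset()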

    
max_steps = 10
epochs = 2
batch_size = 2
n_paths = 5

iter_ds = ShapiroFractalsDataset(batch_size=batch_size, max_steps=max_steps, n_paths=n_paths)
loader = DataLoader(iter_ds, batch_size=None)
for _ in range(epochs):
    for i, batch in enumerate(loader):
        print(i, batch.shape)
        if i > max_steps+5:
            print('bad')
            break
            
mapping = iter_ds.mapping

Created mapping as follows:
{0: '90', 1: '94', 2: '6', 3: '35', 4: '95', 5: '25', 6: '45', 7: '62', 8: '37', 9: '19', 10: '13', 11: '85', 12: '58', 13: '18', 14: '29'}
0 torch.Size([2, 10, 128, 128])
1 torch.Size([2, 10, 128, 128])
2 torch.Size([2, 10, 128, 128])
3 torch.Size([2, 10, 128, 128])
4 torch.Size([2, 10, 128, 128])
0 torch.Size([2, 10, 128, 128])
1 torch.Size([2, 10, 128, 128])
2 torch.Size([2, 10, 128, 128])
3 torch.Size([2, 10, 128, 128])
4 torch.Size([2, 10, 128, 128])
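
The `[2, 10, 128, 128]` shape comes from zipping one generator per batch element and then moving the batch axis to the front. The sketch below reproduces only that reshaping step with toy data. The 4x4 random arrays stand in for the 128x128 fractals, and the helper names are hypothetical.

```python
import numpy as np


def iter_single_sample(n_steps, rng):
    """Yield one toy 4x4 'image' per step (stand-in for a fractal)."""
    for _ in range(n_steps):
        yield rng.random((4, 4))


def iter_batch(batch_size, n_steps, rng):
    """Yield, per step, a tuple holding one frame from each walk."""
    walks = [iter_single_sample(n_steps, rng) for _ in range(batch_size)]
    yield from zip(*walks)


rng = np.random.default_rng(0)

# Collecting the zipped steps gives steps x batch x H x W.
steps_first = np.array(list(iter_batch(batch_size=2, n_steps=10, rng=rng)))
# Moving axis 0 to position 1 gives batch x steps x H x W.
batch = np.moveaxis(steps_first, 0, 1)

print(steps_first.shape)  # (10, 2, 4, 4)
print(batch.shape)        # (2, 10, 4, 4)
```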


In [7]:
max_steps = 10
epochs = 2
batch_size = 2
n_paths = 5

class ShapiroResnetEmbeddingDataset(ShapiroFractalsDataset):
    def load_node_stimuli(self):
        # Load the precomputed embeddings into memory
        assert index.DIR_SCH_FRACTALS_EMB.exists()
        if self.mapping:
            self.paths_data = [index.DIR_SCH_FRACTALS_EMB / (name+'.npy')
                               for name in self.mapping.values()]
        else:
            paths_data = list(index.DIR_SCH_FRACTALS_EMB.iterdir())
            np.random.shuffle(paths_data)
            self.paths_data = paths_data[:5*self.n_pentagons]
        self.array_data = np.array(
            [np.load(str(path)) for path in self.paths_data])

iter_ds = ShapiroResnetEmbeddingDataset(
    batch_size=batch_size, 
    max_steps=max_steps, 
    n_paths=n_paths,
    mapping=mapping,
)
loader = DataLoader(iter_ds, batch_size=None)
for _ in range(epochs):
    for i, batch in enumerate(loader):
        print(i, batch.shape)
        if i > max_steps+5:
            print('bad')
            break

Created mapping as follows:
{0: '90', 1: '94', 2: '6', 3: '35', 4: '95', 5: '25', 6: '45', 7: '62', 8: '37', 9: '19', 10: '13', 11: '85', 12: '58', 13: '18', 14: '29'}
0 torch.Size([2, 10, 2048])
1 torch.Size([2, 10, 2048])
2 torch.Size([2, 10, 2048])
3 torch.Size([2, 10, 2048])
4 torch.Size([2, 10, 2048])
0 torch.Size([2, 10, 2048])
1 torch.Size([2, 10, 2048])
2 torch.Size([2, 10, 2048])
3 torch.Size([2, 10, 2048])
4 torch.Size([2, 10, 2048])
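
Passing the same `mapping` to both dataset classes keeps each graph node tied to the same fractal, whether it is loaded as an image or as an embedding. A minimal sketch of that lookup follows. The `stimulus_paths` helper and the directory names are hypothetical, and only the first three entries of the mapping printed above are used.

```python
from pathlib import Path


def stimulus_paths(mapping, directory, suffix):
    """Build one path per node, in node order, from a node -> name mapping."""
    return [Path(directory) / f"{name}{suffix}" for name in mapping.values()]


# First three entries of the mapping printed above.
mapping = {0: '90', 1: '94', 2: '6'}

# Directory names are placeholders, not the values in `prevseg.index`.
tiffs = stimulus_paths(mapping, 'fractals', '.tiff')
embeddings = stimulus_paths(mapping, 'fractals_emb', '.npy')

# Node i points at the same fractal name in both lists.
for node, (tiff, emb) in enumerate(zip(tiffs, embeddings)):
    print(node, tiff.name, emb.name)
```

This relies on dicts preserving insertion order, so position `i` in each list corresponds to node `i`.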


## PredNet

In [25]:
class PredNetTracked(prednet.PredNetTracked):
    @prednet.PredNet.timeit
    def prepare_data(self):
        self.ds = self.ds or ShapiroResnetEmbeddingDataset(
            batch_size=self.batch_size, 
            n_pentagons=self.hparams.n_pentagons, 
            max_steps=self.hparams.max_steps, 
            n_paths=self.hparams.n_paths)
        self.ds_val = ShapiroResnetEmbeddingDataset(
            batch_size=self.batch_size, 
            n_pentagons=self.hparams.n_pentagons, 
            max_steps=self.hparams.max_steps, 
            n_paths=1,
            mapping=self.ds.mapping)
        
#         n_test, n_val = self.hparams.n_test, self.hparams.n_val
        
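    def train_dataloader(self):
        # NOTE: trains on self.ds_val (n_paths=1) and validates on self.ds;
        # the two datasets look swapped relative to the method names.
        return DataLoader(self.ds_val, 
                          batch_size=None,
                          num_workers=self.hparams.n_workers)
    
    def val_dataloader(self):
        return DataLoader(self.ds, 
                          batch_size=None,
                          num_workers=self.hparams.n_workers)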
    
    def _common_step(self, batch, batch_idx, mode):
        data = batch
        errors = self.forward(data) # batch x n_layers x nt

        loc_batch = errors.size(0)
        errors = torch.mm(errors.view(-1, self.time_steps), 
                          self.time_loss_weights) # batch*n_layers x 1
        errors = torch.mm(errors.view(loc_batch, -1), 
                          self.layer_loss_weights)
        errors = torch.mean(errors, axis=0)
        
        if mode == 'train':
            prefix = ''
        else:
            prefix = mode + '_'
            
        self.logger.experiment.add_scalar(f'{prefix}loss', 
                                          errors, self.global_step)
        return {f'{prefix}loss' : errors}
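
The two `torch.mm` calls in `_common_step` reduce the error tensor first over time and then over layers, each with its own weight vector. The sketch below reproduces the same arithmetic in NumPy so it runs without a GPU or PyTorch. The weight values are assumptions chosen for illustration (ignore the first time step, count only the lowest layer); the notebook does not show the actual `time_loss_weights` or `layer_loss_weights`.

```python
import numpy as np

batch, n_layers, nt = 4, 2, 3
rng = np.random.default_rng(0)
errors = rng.random((batch, n_layers, nt))  # batch x n_layers x nt

# Assumed weights, not taken from the notebook.
time_w = np.array([[0.0], [0.5], [0.5]])  # nt x 1
layer_w = np.array([[1.0], [0.0]])        # n_layers x 1

# Weighted sum over time: (batch*n_layers) x 1
per_layer = errors.reshape(-1, nt) @ time_w
# Weighted sum over layers: batch x 1
per_sample = per_layer.reshape(batch, -1) @ layer_w
# Mean over the batch: shape (1,)
loss = per_sample.mean(axis=0)

# With these weights the loss equals the mean layer-0 error over steps 1..nt-1.
expected = errors[:, 0, 1:].mean(axis=1).mean()
print(np.isclose(loss[0], expected))  # True
```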

In [26]:
model, trainer = None, None
train_dataloader, val_dataloader = None, None
errors, optimizer = None, None
ckpt = None
train_errors, val_errors = None, None
res = None
gc.collect()
torch.cuda.empty_cache()

ModelClass = PredNetTracked
hparams = const.DEFAULT_HPARAMS
hparams.n_layers = 2
hparams.batch_size = 256
hparams.max_steps = 2
hparams.n_paths = 2
hparams.n_pentagons = 3
hparams.time_steps = hparams.max_steps
hparams.exp_name = 'schapiro_test'
hparams.name = f'{ModelClass.name}_{hparams.exp_name}'

log_dir = Path(hparams.dir_logs) / f'{hparams.name}'
log_dir.mkdir(parents=True, exist_ok=True)
logger = pl.loggers.TensorBoardLogger(str(log_dir.parent), name=hparams.name)

ckpt_dir = Path(hparams.dir_checkpoints) / f'{hparams.name}_v{logger.version}'
ckpt_dir.mkdir(parents=True, exist_ok=True)
    
ckpt = pl.callbacks.ModelCheckpoint(
    filepath=str(ckpt_dir / (hparams.exp_name+'_{global_step:05d}_{epoch:03d}_{val_loss:.3f}')),
    verbose=True,
    save_top_k=1,
)

trainer = pl.Trainer(checkpoint_callback=ckpt,
                     max_epochs=2,
                     logger=logger,
                     gpus=1
                     )

model = ModelClass(hparams)
model.ds = None

                                                                      

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]




In [27]:
trainer.fit(model)


  | Name                 | Type       | Params
----------------------------------------------------
0 | predcell_0_recurrent | LSTM       | 67 M  
1 | predcell_0_dense     | Sequential | 4 M   
2 | predcell_0_update_a  | Sequential | 4 M   
3 | predcell_1_recurrent | LSTM       | 12 M  
4 | predcell_1_dense     | Sequential | 1 M   


Created mapping as follows:
{0: '1', 1: '60', 2: '95', 3: '100', 4: '14', 5: '2', 6: '63', 7: '58', 8: '96', 9: '55', 10: '99', 11: '50', 12: '7', 13: '89', 14: '12'}
Created mapping as follows:
{0: '1', 1: '60', 2: '95', 3: '100', 4: '14', 5: '2', 6: '63', 7: '58', 8: '96', 9: '55', 10: '99', 11: '50', 12: '7', 13: '89', 14: '12'}
Validation sanity check:  50%|█████     | 1/2 [00:00<00:00,  9.85it/s]
<class 'torch.Tensor'>
Epoch 1: : 4it [00:00,  8.74it/s, loss=0.203, v_num=24]
Validating: 0it [00:00, ?it/s]
Epoch 1: : 5it [00:00,  8.68it/s, loss=0.203, v_num=24]
Epoch 1: : 9it [00:00, 13.29it/s, loss=0.203, v_num=24]


Epoch 00000: val_loss reached 0.18061 (best 0.18061), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v24/schapiro_test_global_step=00003_epoch=000_val_loss=0.181.ckpt as top 1


<class 'torch.Tensor'>
Epoch 1: : 12it [00:03,  3.46it/s, loss=0.203, v_num=24]
Epoch 2: : 4it [00:00,  9.06it/s, loss=0.185, v_num=24] 
Validating: 0it [00:00, ?it/s][A
Epoch 2: : 8it [00:00, 12.85it/s, loss=0.185, v_num=24]
Validating: 5it [00:00, 12.81it/s][A


Epoch 00001: val_loss reached 0.15046 (best 0.15046), saving model to /home/apra/work/predictive-event-segmentation/models/checkpoints/prednet_tracked_schapiro_test_v24/schapiro_test_global_step=00007_epoch=001_val_loss=0.150.ckpt as top 1


<class 'torch.Tensor'>
Epoch 2: : 12it [00:03,  3.34it/s, loss=0.185, v_num=24]
Epoch 2: : 12it [00:03,  3.33it/s, loss=0.185, v_num=24]


1