# 2.0.1 Creating PyTorch Data Structures

This notebook turns the Schapiro graph and its walking algorithms into dataset classes that a PyTorch `DataLoader` can consume.

## Jupyter Extensions

Load [watermark](https://github.com/rasbt/watermark) to see the state of the machine and environment that's running the notebook. To make sense of the options, take a look at the [usage](https://github.com/rasbt/watermark#usage) section of the readme.

In [50]:
# Load `watermark` extension
%load_ext watermark
# Display the status of the machine and other non-code related info
%watermark -n -m -g -b -t -h

The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark
Tue Aug 18 2020 22:00:18 

compiler   : GCC 7.3.0
system     : Linux
release    : 5.4.0-42-generic
machine    : x86_64
processor  : x86_64
CPU cores  : 4
interpreter: 64bit
host name  : apra-x3
Git hash   : 9dcef66145f8f506529041f0eb7a84a740a26a16
Git branch : master


Load [autoreload](https://ipython.org/ipython-doc/3/config/extensions/autoreload.html). In mode 1 (`%autoreload 1`), it reloads only the modules marked with `%aimport` before each cell runs.

Running `%autoreload 2` inverts this behavior: every module is auto-reloaded *except* those explicitly excluded with `%aimport -<module>`.

In [3]:
# Load `autoreload` extension
%load_ext autoreload
# Set autoreload behavior
%autoreload 1

Load `matplotlib` in one of the more `jupyter`-friendly [rich-output modes](https://ipython.readthedocs.io/en/stable/interactive/plotting.html). Options include `inline`, `notebook`, and `gtk`; which of them work depends on the environment.

In [4]:
# Set the matplotlib mode
%matplotlib inline

## Imports

In [98]:
import logging
from pathlib import Path

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch
from PIL import Image, ImageOps
from torch.utils.data import Dataset, IterableDataset, DataLoader

%aimport prevseg.index
import prevseg.index as index
%aimport prevseg.schapiro
from prevseg.schapiro import walk, graph

# Keep track of versions of everything
%watermark -v -iv

networkx  2.4
logging   0.5.1.2
prevseg   0+untagged.30.g179141d.dirty
PIL.Image 7.2.0
torch     1.6.0
numpy     1.19.1
CPython 3.8.5
IPython 7.16.1


## The Fractal Dataloader Class

In [100]:
max_steps = 3
class ShapiroFractalsDataset(IterableDataset):
    def __init__(self, batch_size=2, n_pentagons=3, max_steps=max_steps):
        self.batch_size = batch_size
        self.n_pentagons = n_pentagons
        self.max_steps = max_steps
        
        self.G = graph.schapiro_graph(n_pentagons=n_pentagons)
        
        self.load_node_stimuli()
        
        self.mapping = {node : path.stem for node, path in zip(range(len(self.G.nodes)),
                                                               self.paths_data)}
        print(f'Created mapping as follows:\n{self.mapping}')
        
    def load_node_stimuli(self):
        # Load one randomly chosen grayscale fractal image per graph node
        assert index.DIR_SCH_FRACTALS.exists()
        paths_data = list(index.DIR_SCH_FRACTALS.iterdir())
        np.random.shuffle(paths_data)
        # Each pentagon has 5 nodes, so keep 5 * n_pentagons images
        self.paths_data = paths_data[:5*self.n_pentagons]
        self.array_data = np.array(
            [np.array(ImageOps.grayscale(Image.open(str(path))))
             for path in self.paths_data])
        
    def iter_single_sample(self):
        # Yield the node id (first element) from each step of one random walk
        for sample in walk.walk_random(self.G, steps=self.max_steps):
            #yield self.array_data[sample[0]]
            yield sample[0]
        
    def iter_batch_dataset(self):
        # Advance batch_size independent walks in lockstep; zip stops
        # as soon as any one walk is exhausted
        batch = zip(*[self.iter_single_sample() for _ in range(self.batch_size)])
        for nodes in batch:
            yield np.array(nodes)
        
    def __iter__(self):
        return self.iter_batch_dataset()
    
iter_ds = ShapiroFractalsDataset()
loader = DataLoader(iter_ds, batch_size=None)
epochs = 3
for _ in range(epochs):
    for i, batch in enumerate(loader):
        #print(i, batch.shape)
        print(i, batch)
        if i > max_steps+5:
            print('bad')
            break

Created mapping as follows:
{0: '90', 1: '99', 2: '70', 3: '74', 4: '35', 5: '47', 6: '62', 7: '29', 8: '49', 9: '89', 10: '81', 11: '41', 12: '32', 13: '1', 14: '13'}
0 tensor([1, 2])
1 tensor([4, 4])
2 tensor([2, 5])
0 tensor([ 2, 10])
1 tensor([ 0, 13])
2 tensor([14, 10])
0 tensor([9, 0])
1 tensor([7, 1])
2 tensor([6, 2])
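
The dataset above combines two ideas: a random walk over a community-structured graph, and `zip` over several independent walk generators to form batches. The following is a minimal standard-library sketch of both, not the `prevseg` implementation. It assumes the usual Schapiro layout, in which each community is a fully connected group of five nodes minus the edge between its two boundary nodes, and each boundary node links to the neighboring community. The helper names `build_schapiro_adjacency`, `random_walk`, and `batched_walks` are hypothetical.

```python
import random


def build_schapiro_adjacency(n_pentagons=3):
    """Return {node: sorted neighbor list} for an assumed Schapiro-style graph."""
    n_nodes = 5 * n_pentagons
    adj = {node: set() for node in range(n_nodes)}
    for c in range(n_pentagons):
        members = list(range(5 * c, 5 * c + 5))
        first, last = members[0], members[-1]
        # Fully connect the community, except for its two boundary nodes
        for a in members:
            for b in members:
                if a != b and {a, b} != {first, last}:
                    adj[a].add(b)
        # Link this community's last node to the next community's first node
        nxt = (last + 1) % n_nodes
        adj[last].add(nxt)
        adj[nxt].add(last)
    return {node: sorted(nbrs) for node, nbrs in adj.items()}


def random_walk(adj, steps, rng):
    """Yield `steps` node ids, moving to a uniformly chosen neighbor each time."""
    node = rng.choice(sorted(adj))
    for _ in range(steps):
        yield node
        node = rng.choice(adj[node])


def batched_walks(adj, batch_size, steps, rng):
    """Yield tuples holding one node per walk, advancing all walks in lockstep."""
    walks = [random_walk(adj, steps, rng) for _ in range(batch_size)]
    # zip pulls one item from every generator and stops when any is exhausted
    yield from zip(*walks)


adj = build_schapiro_adjacency(n_pentagons=3)
rng = random.Random(0)
batches = list(batched_walks(adj, batch_size=2, steps=3, rng=rng))
for i, batch in enumerate(batches):
    print(i, batch)
```

Because `zip` ends with the shortest input, the number of batches equals the number of steps in each walk, which matches the three batches per epoch printed above.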


## Embeddings Dataloader

In [57]:
class ShapiroResnetEmbeddingDataset(ShapiroFractalsDataset):
    def load_node_stimuli(self):
        # Load one precomputed embedding per graph node into memory
        assert index.DIR_SCH_FRACTALS_EMB.exists()
        paths_data = list(index.DIR_SCH_FRACTALS_EMB.iterdir())
        np.random.shuffle(paths_data)
        # Each pentagon has 5 nodes, so keep 5 * n_pentagons embeddings
        self.paths_data = paths_data[:5*self.n_pentagons]
        self.array_data = np.array(
            [np.array(np.load(str(path)))
             for path in self.paths_data])

iter_ds = ShapiroResnetEmbeddingDataset()
loader = DataLoader(iter_ds, batch_size=None)
for batch in loader:
    print(batch.shape)
    break

torch.Size([8, 2048])
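
The printed shape suggests one 2048-dimensional vector per walk in the batch, which is what looking up each node id in `array_data` would produce. A minimal sketch of that lookup follows, using plain lists and a toy 4-dimensional table in place of the real embeddings. The function name `lookup_batch` is hypothetical and is not part of `prevseg`.

```python
def lookup_batch(table, node_batch):
    """Map a batch of node ids to their feature vectors (rows of `table`)."""
    return [table[node] for node in node_batch]


# Toy table: 15 nodes with 4 features each (the notebook's embeddings have 2048)
table = [[float(node)] * 4 for node in range(15)]

features = lookup_batch(table, (9, 0))
shape = (len(features), len(features[0]))
print(shape)
```

With a NumPy array as the table, indexing it with an integer array of node ids performs the same row lookup in a single step.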


## Trying Import

In [103]:
%aimport prevseg.dataloaders.schapiro
import prevseg.dataloaders.schapiro as schapiro

max_steps=3

iter_ds = ShapiroFractalsDataset(batch_size=2, max_steps=max_steps)
loader = DataLoader(iter_ds, batch_size=None)

epochs = 3
for _ in range(epochs):
    for i, batch in enumerate(loader):
        print(i, batch.shape)
        if i > max_steps+5:
            print('bad')
            break

Created mapping as follows:
{0: '40', 1: '46', 2: '54', 3: '3', 4: '94', 5: '18', 6: '4', 7: '13', 8: '82', 9: '19', 10: '7', 11: '11', 12: '64', 13: '38', 14: '41'}
0 torch.Size([2])
1 torch.Size([2])
2 torch.Size([2])
0 torch.Size([2])
1 torch.Size([2])
2 torch.Size([2])
0 torch.Size([2])
1 torch.Size([2])
2 torch.Size([2])
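
Every epoch in the output above yields three batches again rather than finding the dataset exhausted. That follows from `__iter__` building a new generator each time the loader is iterated, so each epoch starts fresh walks. A minimal sketch of the difference between a re-iterable object and a one-shot generator follows. It uses a hypothetical `Walks` class and no PyTorch.

```python
class Walks:
    """Iterable that builds a fresh generator on every __iter__ call."""

    def __init__(self, steps):
        self.steps = steps

    def __iter__(self):
        return (step for step in range(self.steps))


# Iterating the object itself restarts the sequence each epoch
reusable = Walks(steps=3)
epochs_a = [list(reusable) for _ in range(2)]

# Holding on to a single iterator exhausts it after the first epoch
one_shot = iter(Walks(steps=3))
epochs_b = [list(one_shot) for _ in range(2)]

print(epochs_a)
print(epochs_b)
```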
