# 2.0.1 Creating PyTorch Data Structures

Turning the Schapiro graph and the random-walk algorithms into code that PyTorch can consume.
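This notebook imports the graph and the walk from `prevseg.schapiro` (`graph.schapiro_graph` and `walk.walk_random`), and their implementations aren't shown here. The sketch below is a hypothetical stand-in that assumes the community layout of Schapiro et al. (2013): 5-node communities ("pentagons"), all pairs within a community linked except its two boundary nodes, and boundary nodes linked across neighbouring communities, so every node has degree 4. `schapiro_graph_sketch` and `walk_random_sketch` are illustrative names, not prevseg's API, and the real walk evidently yields something indexable (the dataset below reads `sample[0]`), while this one yields bare node ids.

```python
import random

import networkx as nx


def schapiro_graph_sketch(n_pentagons=3):
    """Hypothetical Schapiro-style community graph (not prevseg's implementation).

    Each community has 5 nodes. Every pair inside a community is linked except
    the first and last (boundary) nodes; each community's last node is linked
    to the next community's first node, closing the communities into a ring.
    """
    G = nx.Graph()
    for p in range(n_pentagons):
        nodes = list(range(5 * p, 5 * p + 5))
        first, last = nodes[0], nodes[-1]
        for i in nodes:
            for j in nodes:
                if i < j and (i, j) != (first, last):
                    G.add_edge(i, j)
    # Bridge edges between consecutive communities (wrapping around at the end)
    for p in range(n_pentagons):
        G.add_edge(5 * p + 4, 5 * ((p + 1) % n_pentagons))
    return G


def walk_random_sketch(G, steps=None, seed=None):
    """Yield node ids of a uniform random walk; never stops when steps is None."""
    rng = random.Random(seed)
    node = rng.choice(list(G.nodes))
    taken = 0
    while steps is None or taken < steps:
        yield node
        node = rng.choice(list(G.neighbors(node)))
        taken += 1


G = schapiro_graph_sketch(n_pentagons=3)
print(G.number_of_nodes(), G.number_of_edges())  # 15 30
print(list(walk_random_sketch(G, steps=10, seed=0)))
```

Because every node has the same degree, a uniform walk cannot tell communities apart from local connectivity alone; the only cue is which nodes tend to follow each other, which is the property the Schapiro design relies on.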

## Jupyter Extensions

Load [watermark](https://github.com/rasbt/watermark) to see the state of the machine and environment that's running the notebook. To make sense of the options, take a look at the [usage](https://github.com/rasbt/watermark#usage) section of the readme.

```python
# Load `watermark` extension
%load_ext watermark
# Display the status of the machine and packages. Add more as necessary.
%watermark -v -n -m -g -b -t -p numpy,matplotlib,seaborn,networkx,torch,pytorch_lightning,prevseg
```

```
Tue Aug 18 2020 21:30:26 

CPython 3.8.5
IPython 7.16.1

numpy 1.19.1
matplotlib 3.2.2
seaborn 0.10.1
networkx 2.4
torch 1.6.0
pytorch_lightning 0.8.5
prevseg 0+untagged.30.g179141d.dirty

compiler   : GCC 7.3.0
system     : Linux
release    : 5.4.0-42-generic
machine    : x86_64
processor  : x86_64
CPU cores  : 4
interpreter: 64bit
Git hash   : 179141d76e24ec9801ed1013bf2bfd33fc31cbb3
Git branch : master
```


Load [autoreload](https://ipython.org/ipython-doc/3/config/extensions/autoreload.html) in mode 1, which reloads only the modules marked with `%aimport` before each cell is executed.

Running `%autoreload 2` instead reloads *all* modules, except those explicitly excluded with `%aimport -module_name`.

```python
# Load `autoreload` extension
%load_ext autoreload
# Mode 1: reload only the modules marked with `%aimport`
%autoreload 1
```

Load `matplotlib` in one of the more `jupyter`-friendly [rich-output modes](https://ipython.readthedocs.io/en/stable/interactive/plotting.html). Available backends include `inline`, `notebook`, and `gtk`, although not all of them work in every environment; `inline` is used here.

```python
# Set the matplotlib mode
%matplotlib inline
```

## Imports

```python
from pathlib import Path

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch
from PIL import Image, ImageOps
from torch.utils.data import Dataset, IterableDataset, DataLoader

%aimport prevseg.index
import prevseg.index as index
%aimport prevseg.schapiro
from prevseg.schapiro import walk, graph
```

## The Dataset Class

```python
class ShapiroFractalsDataset(IterableDataset):
    """Yields batches of fractal images visited by parallel random walks."""

    def __init__(self, batch_size=8, n_pentagons=3, max_steps=None):
        self.batch_size = batch_size
        self.n_pentagons = n_pentagons
        self.max_steps = max_steps

        self.G = graph.schapiro_graph(n_pentagons=n_pentagons)

        # Load one grayscale fractal image per graph node (5 nodes per pentagon)
        assert index.DIR_SCH_FRACTALS.exists()
        paths_fractals = list(index.DIR_SCH_FRACTALS.iterdir())
        np.random.shuffle(paths_fractals)
        self.paths_fractals = paths_fractals[:5*self.n_pentagons]
        self.array_fractals = np.array(
            [np.array(ImageOps.grayscale(Image.open(str(path))))
             for path in self.paths_fractals])

    def iter_single_sample(self):
        # sample[0] is used as the node id that indexes the image array
        for sample in walk.walk_random(self.G, steps=self.max_steps):
            yield self.array_fractals[sample[0]]

    def iter_batch(self):
        # Advance `batch_size` independent walks in lockstep;
        # each step across all walkers becomes one batch
        walkers = [self.iter_single_sample() for _ in range(self.batch_size)]
        for batch in zip(*walkers):
            yield np.array(batch)

    def __iter__(self):
        return self.iter_batch()

iter_ds = ShapiroFractalsDataset()
# The dataset already yields full batches, so disable automatic batching
loader = DataLoader(iter_ds, batch_size=None)
for batch in loader:
    print(batch.shape)
    break  # with max_steps=None the walk may not terminate; inspect one batch
```

```
torch.Size([8, 128, 128])
```
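`ShapiroFractalsDataset` builds each batch by running `batch_size` independent walks and `zip`-ing them, so step *t* of every walker lands in the same batch. Because the dataset already yields stacked arrays, the `DataLoader` is created with `batch_size=None`, which disables PyTorch's automatic batching and only converts each NumPy array to a tensor. The self-contained toy below reproduces that pattern without the fractal files; it assumes a cycle graph and random `uint8` images in place of the Schapiro graph and fractals, and `ToyWalkBatches` is a hypothetical name.

```python
import numpy as np
import networkx as nx
import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyWalkBatches(IterableDataset):
    """Hypothetical stand-in for ShapiroFractalsDataset: random images, no files."""

    def __init__(self, batch_size=4, n_nodes=15, image_size=8, max_steps=10, seed=0):
        self.batch_size = batch_size
        self.max_steps = max_steps
        self.rng = np.random.default_rng(seed)
        # A cycle graph stands in for the Schapiro graph
        self.G = nx.cycle_graph(n_nodes)
        # One random grayscale "image" per node stands in for the fractals
        self.images = self.rng.integers(
            0, 256, size=(n_nodes, image_size, image_size), dtype=np.uint8)

    def iter_walk(self):
        # A bounded uniform random walk that yields the image of each visited node
        node = int(self.rng.integers(len(self.G)))
        for _ in range(self.max_steps):
            yield self.images[node]
            neighbors = list(self.G.neighbors(node))
            node = neighbors[self.rng.integers(len(neighbors))]

    def __iter__(self):
        # zip advances all walkers in lockstep; each step becomes one batch
        walkers = [self.iter_walk() for _ in range(self.batch_size)]
        for step in zip(*walkers):
            yield np.stack(step)


# batch_size=None: the dataset yields whole batches, the loader only converts them
loader = DataLoader(ToyWalkBatches(), batch_size=None)
shapes = [tuple(batch.shape) for batch in loader]
print(len(shapes), shapes[0])  # 10 (4, 8, 8)
```

One caveat of this design: with `num_workers > 0`, every worker process iterates its own copy of an `IterableDataset`, so without per-worker seeding the workers can emit duplicate batches.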
