In [3]:
#export
from torch import nn, cat, Tensor

# Layers

In [4]:
#export
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single dimension.

    The size of dimension 0 is kept and all remaining dimensions are merged,
    so an input of shape ``(N, d1, d2, ...)`` comes out as ``(N, d1*d2*...)``.
    """

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. the result of ``permute``/``transpose``) are handled: it returns
        a view when the memory layout allows it and copies only otherwise,
        whereas ``view`` raises ``RuntimeError`` in that situation.
        """
        return x.reshape(x.size(0), -1)

In [5]:
#export
class Concat(nn.Module):
    """Join several tensors along dimension 1.

    A lone ``Tensor`` is passed through unchanged; anything else is handed to
    ``torch.cat`` as the sequence of tensors to concatenate.
    """

    def forward(self, x):
        """Return ``x`` itself if it is a ``Tensor``, else ``cat(x, dim=1)``."""
        # Anything that is not already a single tensor is treated as a
        # collection of tensors to be concatenated.
        if not isinstance(x, Tensor):
            return cat(x, dim=1)
        return x

In [6]:
#export
class Add(nn.Module):
    """Element-wise sum of several tensors.

    A lone ``Tensor`` is passed through unchanged; a list or tuple of tensors
    is reduced with the builtin ``sum``.
    """

    def forward(self, x):
        """Return ``x`` if it is a ``Tensor``, else the sum of its items.

        Raises:
            TypeError: if ``x`` is neither a ``Tensor`` nor a list/tuple.
                Previously such inputs fell through every branch and the
                method silently returned ``None``.
        """
        if isinstance(x, Tensor):
            return x
        # Tuples are accepted alongside lists so that Add takes the same kinds
        # of multi-input containers as Concat.
        if isinstance(x, (list, tuple)):
            return sum(x)
        raise TypeError(
            f"Add expects a Tensor or a list/tuple of Tensors, got {type(x).__name__}"
        )

# Model generator

In [7]:
#export
class Generator(nn.Module):
    """Build a model from a list of ``(layer, input, output)`` triples.

    Each triple names the key(s) a layer reads from and the key it writes to.
    Layers are registered in ``self.layers`` under their position in ``ts``
    (as a string), and ``self.sequence`` records ``(position, input, output)``
    in the same order so ``forward`` can replay them.
    """

    def __init__(self, ts):
        """Register every layer in ``ts`` and remember its wiring.

        Args:
            ts: iterable of ``(layer, input, output)`` triples. ``input`` is
                either a single key or a list of keys; ``output`` is the key
                the layer's result is stored under.
        """
        super().__init__()
        self.ts = ts
        self.layers = nn.ModuleDict()
        self.sequence = []

        for position, (module, sources, target) in enumerate(ts):
            key = str(position)  # ModuleDict keys must be strings
            self.layers[key] = module
            self.sequence.append((key, sources, target))

    def forward(self, x):
        """Run the layers in order and return the value stored under key ``2``.

        The network input ``x`` is stored under key ``1``. A layer whose input
        spec is a list receives a list of the referenced values; otherwise it
        receives the single referenced value.
        """
        results = {1: x}
        for key, sources, target in self.sequence:
            module = self.layers[key]
            if isinstance(sources, list):
                argument = [results[name] for name in sources]
            else:
                argument = results[sources]
            results[target] = module(argument)
        return results[2]

### Export

In [1]:
# Shell escape: run nb2py.py on layer.ipynb; the recorded output reports the result as exp/nb_layer.py.
!python nb2py.py layer.ipynb

Converted layer.ipynb to exp/nb_layer.py
