In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import torch
import torchvision.utils as vutils

from researchlib.single_import import *

In [None]:
# MNIST loaders for both splits, sharing one batch size.
# `train=True/False` presumably selects the train/test split of the dataset.
BATCH_SIZE = 512
train_loader, test_loader = (
    VisionDataset(vision.MNIST, batch_size=BATCH_SIZE, train=split_is_train)
    for split_is_train in (True, False)
)

In [None]:
# ---- Capsule-network sizes ---------------------------------------------------
# NOTE(review): the argument order of layer.RoutingCapsules is assumed to be
# (in_dim, in_caps, out_caps, out_dim, routing_iters) — confirm against researchlib.
CAPS_IN_DIM = 8          # width passed to PrimaryCapsules and to RoutingCapsules
NUM_PRIMARY_CAPS = 1152  # presumably 32 capsule maps x 6 x 6 positions — TODO confirm
NUM_CLASSES = 10         # presumably one output capsule per MNIST digit
CAPS_OUT_DIM = 16
ROUTING_ITERS = 3


def _dense_bn_selu(in_features, out_features):
    """Return the Linear -> BatchNorm1d -> SELU triple used for each decoder hidden layer."""
    return [
        nn.Linear(in_features, out_features),
        nn.BatchNorm1d(out_features),
        nn.SELU(inplace=True),
    ]


# Encoder: 9x9 convolution on the single input channel, then primary capsules,
# then the routing-capsule layer.
encoder = builder([
    nn.Conv2d(1, 256, 9, stride=1, bias=True),
    nn.BatchNorm2d(256),
    nn.SELU(inplace=True),
    layer.PrimaryCapsules(256, 256, CAPS_IN_DIM, kernel_size=9),
    layer.RoutingCapsules(CAPS_IN_DIM, NUM_PRIMARY_CAPS, NUM_CLASSES, CAPS_OUT_DIM, ROUTING_ITERS),
])

# Decoder: CapsuleMasked (presumably masks and flattens the class capsules into
# NUM_CLASSES * CAPS_OUT_DIM = 160 features), two hidden layers, then a Sigmoid
# output reshaped to (-1, 1, 28, 28), i.e. pixel values in [0, 1].
# The reshaped output is registered under 'rc' — presumably the prediction side
# of the reconstruction regulariser configured on the Runner.
decoder = builder([
    layer.CapsuleMasked(),
    *_dense_bn_selu(NUM_CLASSES * CAPS_OUT_DIM, 512),
    *_dense_bn_selu(512, 1024),
    nn.Linear(1024, 28 * 28),
    nn.Sigmoid(),
    Reg(layer.Reshape((-1, 1, 28, 28)), 'rc', get='out'),
])

# Full model. The raw input is registered under 'rc' via an identity layer
# (presumably the reconstruction target). The decoder is wrapped in Auxiliary,
# presumably so it runs as a side branch, and layer.Norm produces the final
# output (presumably capsule lengths for the margin loss — confirm).
model = builder([
    Reg(Identical(), 'rc', get='out'),
    encoder,
    Auxiliary(decoder),
    layer.Norm(),
])

In [None]:
# Training harness. The positional 'adam' / 'margin' are presumably the optimizer
# and loss names. reg_fn / reg_weights add an 'mse' term, weighted 0.1, keyed by
# the 'rc' tag that the model registers on its input and reconstruction.
runner = Runner(
    model,
    train_loader,
    test_loader,
    'adam',
    'margin',
    fp16=False,
    multigpu=True,
    reg_fn={'rc': 'mse'},
    reg_weights={'rc': 0.1},
)

In [None]:
# Presumably (re)initialises the model's weights before training.
# NOTE(review): the exact initialisation scheme lives in researchlib's Runner.init_model — confirm.
runner.init_model()

In [None]:
# Positional arguments of Runner.fit: presumably (epochs, learning rate) — confirm.
NUM_EPOCHS = 1
LEARNING_RATE = 1e-3
runner.fit(NUM_EPOCHS, LEARNING_RATE)

In [None]:
# Visual sanity check: push one test batch through encoder -> decoder and show the
# inputs next to their reconstructions.
# Inference runs in eval mode under no_grad, so that:
#   - BatchNorm uses (and does not update) its running statistics, and
#   - no autograd graph is built.
# Each module's previous train/eval mode is restored afterwards, so later training
# cells see the model exactly as fit() left it.
x = next(iter(test_loader))[0]

# Use the device the encoder's weights live on instead of hard-coding .cuda().
device = next(encoder.parameters()).device
modules = (encoder, decoder)
previous_modes = [m.training for m in modules]
try:
    for m in modules:
        m.eval()
    with torch.no_grad():
        out = encoder(x.to(device))
        res = decoder(out)
finally:
    for m, was_training in zip(modules, previous_modes):
        m.train(was_training)

t1 = vutils.make_grid(x).cpu().numpy().transpose(1, 2, 0)
t2 = vutils.make_grid(res).cpu().numpy().transpose(1, 2, 0)

fig, arr = plt.subplots(1, 2, figsize=(20, 20))
arr[0].imshow(t1)
arr[0].set_title('Test inputs')
arr[1].imshow(t2)
arr[1].set_title('Reconstructions')
for ax in arr:
    ax.axis('off')
plt.show()