In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell executes, so edits to library code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from researchlib.single_import import *

In [3]:
# Build the train and test loaders from the two MNIST variants with a shared batch size.
# NOTE(review): the boolean handed to MNIST presumably selects the training split
# (True) versus the held-out split (False) — confirm against researchlib.
train_loader, test_loader = (
    FromDataset(MNIST(split_flag), batch_size=256) for split_flag in (True, False)
)

In [4]:
# Encoder: convolution + SELU, followed by the primary and routing capsule layers.
# NOTE(review): the positional arguments of PrimaryCapsules / RoutingCapsules are
# passed through unchanged; their meaning is defined by researchlib — confirm there.
conv_stem = nn.Conv2d(1, 256, 9, stride=1, bias=True)
conv_activation = nn.SELU(inplace=True)
primary_caps = PrimaryCapsules(256, 256, 8, kernel_size=9)
routing_caps = RoutingCapsules(8, 1152, 10, 16, 3)

encoder = builder([conv_stem, conv_activation, primary_caps, routing_caps])

# Decoder: capsule masking, then a three-layer MLP (SELU between layers, sigmoid
# at the end) whose 784 outputs are reshaped to (-1, 1, 28, 28).
# NOTE(review): the 160 input features presumably correspond to the routing
# capsule output flattened after masking — confirm against CapsuleMasked.
decoder_layers = [CapsuleMasked()]
for in_features, out_features in [(160, 512), (512, 1024)]:
    decoder_layers += [nn.Linear(in_features, out_features), nn.SELU(inplace=True)]
decoder_layers += [
    nn.Linear(1024, 784),
    nn.Sigmoid(),
    Reshape((-1, 1, 28, 28)),
]

decoder = builder(decoder_layers)

# Full model: both Reg wrappers share the key 'rc', one around an identity layer
# placed before the encoder and one around the decoder placed after it.
# NOTE(review): presumably Reg records each wrapped module's output under 'rc' so a
# regularisation loss can compare input and reconstruction, and out_through=True
# lets the encoder output continue on to Norm — confirm against researchlib.
input_tap = Reg(Identical(), 'rc', get='out')
reconstruction_tap = Reg(decoder, 'rc', get='out', out_through=True)

model = builder([input_tap, encoder, reconstruction_tap, Norm()])

In [5]:
# Gather the Runner configuration in one place, then construct it.
# The 'rc' key of reg_fn matches the key used by the Reg wrappers in the model.
runner_config = dict(
    train_loader=train_loader,
    test_loader=test_loader,
    model=model,
    loss_fn=CapsuleLoss(),
    optimizer='adam',
    fp16=False,
    reg_fn={'rc': F.mse_loss},
)
runner = Runner(**runner_config)

In [6]:
# Train the model and report accuracy.
# NOTE(review): the positional arguments are presumably the epoch count (20) and
# the learning rate (1e-3) — confirm against Runner.fit.
train_metrics = [Acc()]
runner.fit(20, 1e-3, metrics=train_metrics)

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

KeyboardInterrupt: 

In [None]:
# Run validation with the accuracy metric.
# NOTE(review): presumably evaluates on the test_loader given to Runner — confirm.
validation_metrics = [Acc()]
runner.validate(metrics=validation_metrics)

In [None]:
# Take the first batch from the test loader and keep its first element as x.
# NOTE(review): element 0 is presumably the image tensor (labels after it) — confirm.
first_batch = next(iter(test_loader))
x = first_batch[0]

In [None]:
import torchvision.utils as vutils

In [None]:
# Run one batch through the encoder and decoder, then show the input batch next
# to the decoder output.

# Inference only: disable autograd so no graph is built or kept alive on the GPU.
with torch.no_grad():
    out = encoder(x.cuda())
    res = decoder(out)

# make_grid tiles a batch into a single channel-first image; move it to the CPU
# and transpose to channel-last, which is the layout imshow expects.
t1 = vutils.make_grid(x).cpu().numpy().transpose(1, 2, 0)
t2 = vutils.make_grid(res.detach()).cpu().numpy().transpose(1, 2, 0)

fig, arr = plt.subplots(1, 2, figsize=(20, 20))
panels = zip(arr, (t1, t2), ("Input batch", "Decoder reconstruction"))
for ax, grid, title in panels:
    ax.imshow(grid)
    ax.set_title(title)  # label each panel so the figure stands on its own
    ax.axis("off")       # pixel-index ticks carry no information for image grids
plt.show()