In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to library code on disk are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from researchlib.single_import import *

In [3]:
# Mini-batch size used by both loaders; defined once so the two splits cannot
# drift apart when it is tuned.
BATCH_SIZE = 512

# MNIST loaders: `train=True` selects the training split, `train=False` the
# held-out split. Both objects are iterated like data loaders further down.
train_loader = VisionDataset(vision.MNIST, batch_size=BATCH_SIZE, train=True)
test_loader = VisionDataset(vision.MNIST, batch_size=BATCH_SIZE, train=False)

In [4]:
# --- Capsule-network hyperparameters ---------------------------------------
# Named once here so that the sizes which must agree between the encoder and
# the decoder are not repeated as bare literals.
CONV_CHANNELS = 256      # channels of the first conv and of the primary-capsule conv
KERNEL_SIZE = 9          # kernel size of both convolutions
PRIMARY_CAPS_DIM = 8     # input capsule length (init log: CapsuleLinear(8, 16))
NUM_PRIMARY_CAPS = 1152  # presumably 32 capsule maps x 6 x 6 positions for 28x28 input -- TODO confirm
NUM_CLASSES = 10         # presumably one output capsule per MNIST digit -- TODO confirm
CLASS_CAPS_DIM = 16      # output capsule length (init log: CapsuleLinear(8, 16))
ROUTING_ITERS = 3        # init log: Routing(num_routing=3)
IMAGE_SIZE = 28          # side length used when reshaping the reconstruction

# The decoder's first Linear layer takes 160 inputs, which equals
# NUM_CLASSES * CLASS_CAPS_DIM; presumably the flattened (masked) capsule
# outputs -- TODO confirm against layer.CapsuleMasked.
DECODER_IN_FEATURES = NUM_CLASSES * CLASS_CAPS_DIM  # 160
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE              # 784

# Encoder: conv stem -> primary capsules -> routed output capsules.
encoder = builder([
    nn.Conv2d(1, CONV_CHANNELS, KERNEL_SIZE, stride=1, bias=True),
    nn.BatchNorm2d(CONV_CHANNELS),
    nn.SELU(inplace=True),
    layer.PrimaryCapsules(CONV_CHANNELS, CONV_CHANNELS, PRIMARY_CAPS_DIM, kernel_size=KERNEL_SIZE),
    layer.RoutingCapsules(PRIMARY_CAPS_DIM, NUM_PRIMARY_CAPS, NUM_CLASSES, CLASS_CAPS_DIM, ROUTING_ITERS),
])

# Decoder: three fully connected layers ending in a sigmoid, reshaped to a
# single-channel IMAGE_SIZE x IMAGE_SIZE image. The final Reshape is wrapped in
# Reg(..., 'rc', get='out'); 'rc' is the same key the Runner's reg_fn and
# reg_weights use (presumably the reconstruction term -- TODO confirm).
decoder = builder([
    layer.CapsuleMasked(),
    nn.Linear(DECODER_IN_FEATURES, 512),
    nn.BatchNorm1d(512),
    nn.SELU(inplace=True),
    nn.Linear(512, 1024),
    nn.BatchNorm1d(1024),
    nn.SELU(inplace=True),
    nn.Linear(1024, IMAGE_PIXELS),
    nn.Sigmoid(),
    Reg(layer.Reshape((-1, 1, IMAGE_SIZE, IMAGE_SIZE)), 'rc', get='out'),
])

# Full model: the input passes through an Identical() layer registered under
# 'rc' as well (presumably the reconstruction target -- TODO confirm), then the
# encoder, the decoder attached via Auxiliary(...), and finally layer.Norm().
model = builder([
    Reg(Identical(), 'rc', get='out'),
    encoder,
    Auxiliary(decoder),
    layer.Norm()
])

In [5]:
# Bundle the model and both loaders into the training harness.
# 'adam' and 'margin' are presumably the optimizer and loss identifiers --
# TODO confirm against Runner's signature.
# reg_fn / reg_weights are keyed by 'rc', the same tag passed to the Reg(...)
# wrappers in the model definition: 'mse' with weight 0.1.
runner = Runner(
    model,
    train_loader,
    test_loader,
    'adam',
    'margin',
    fp16=False,
    multigpu=True,
    reg_fn={'rc': 'mse'},
    reg_weights={'rc': 0.1},
)

In [6]:
# Initialise the model's parameters; the log printed below reports
# `xavier_normal` being applied to each module in turn.
runner.init_model()

Init xavier_normal: Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
Init xavier_normal: Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
Init xavier_normal: _PrimaryCapsules(
  (conv): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
)
Init xavier_normal: _RoutingCapsules(
  (0): CapsuleLinear(8, 16)
  (1): Routing(num_routing=3)
)
Init xavier_normal: Sequential(
  (0): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
  (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): SELU(inplace)
  (3): _PrimaryCapsules(
    (conv): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
  )
  (4): _RoutingCapsules(
    (0): CapsuleLinear(8, 16)
    (1): Routing(num_routing=3)
  )
)
Init xavier_normal: Sequential(
  (0): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
  (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): SELU(inplace)
  (3): _PrimaryCapsules(
    (conv): Conv2d(256, 256, kernel_size=(9, 9), st

In [7]:
# Train the model. The progress bars below show an outer bar with max=1 and an
# inner bar with max=118, so the first argument is presumably the number of
# epochs; 1e-3 is presumably the learning rate -- TODO confirm against
# Runner.fit.
# NOTE(review): the saved output ends in KeyboardInterrupt, i.e. this run was
# stopped manually before it finished.
runner.fit(1, 1e-3)

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, max=118), HTML(value='')))

Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/opt/c

KeyboardInterrupt: 

In [None]:
import torch
import torchvision.utils as vutils

# Image tensor of one test batch (first element of the batch tuple).
x = next(iter(test_loader))[0]

# Run the forward pass as pure inference:
#   * eval() makes the BatchNorm layers use their running statistics instead
#     of normalising with (and updating from) this test batch;
#   * no_grad() avoids building an autograd graph for the whole batch.
# The previous train/eval mode of each module is restored afterwards so that
# a later training call is unaffected.
_encoder_was_training = encoder.training
_decoder_was_training = decoder.training
encoder.eval()
decoder.eval()
try:
    with torch.no_grad():
        out = encoder(x.cuda())
        res = decoder(out)
finally:
    encoder.train(_encoder_was_training)
    decoder.train(_decoder_was_training)

# make_grid tiles the batch into one image tensor; transpose(1, 2, 0) moves the
# channel axis last, which is the layout imshow expects.
# No .detach() is needed: tensors produced under no_grad() carry no graph.
t1 = vutils.make_grid(x).cpu().numpy().transpose(1, 2, 0)
t2 = vutils.make_grid(res).cpu().numpy().transpose(1, 2, 0)

# Side-by-side panels: test inputs on the left, decoder output on the right.
fig, arr = plt.subplots(1, 2, figsize=(20, 20))
arr[0].imshow(t1)
arr[0].set_title('Input')
arr[1].imshow(t2)
arr[1].set_title('Reconstruction')
plt.show()