In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from researchlib.single_import import *

In [3]:
# MNIST loaders, batch size 128, pixel normalisation disabled (normalize=False).
# Only the training loader shuffles: evaluation metrics do not depend on sample
# order, and a fixed test order keeps the batch pulled later with
# `next(iter(test_loader))` (and any sample indexed from it) the same across runs.
train_loader = FromPublic('mnist', 'train', batch_size=128, normalize=False, shuffle=True, pin_memory=True, num_workers=4)
test_loader = FromPublic('mnist', 'test', batch_size=128, normalize=False, shuffle=False, pin_memory=True, num_workers=4)

In [12]:
# Capsule auto-encoder for MNIST.
#
# The encoder output is consumed downstream as (batch, NUM_DIGIT_CAPS,
# DIGIT_CAPS_DIM); the decoder takes that tensor flattened and emits a
# 1 x IMG_SIDE x IMG_SIDE image with values in [0, 1] (final Sigmoid).
NUM_DIGIT_CAPS = 10   # number of output capsules (one per digit class)
DIGIT_CAPS_DIM = 16   # length of each output capsule vector
IMG_SIDE = 28         # MNIST images are 28 x 28

encoder = builder([
    nn.Conv2d(1, 256, 9, stride=1, bias=True),
    nn.SELU(inplace=True),
    PrimaryCapsules(256, 256, 8, kernel_size=9),
    # NOTE(review): 1152 is presumably the number of primary capsules produced
    # by the layer above (8-dim each) -- confirm against PrimaryCapsules.
    RoutingCapsules(8, 1152, NUM_DIGIT_CAPS, DIGIT_CAPS_DIM, 3)
])

decoder = builder([
    # Input width is the flattened capsule tensor: 10 * 16 = 160.
    nn.Linear(NUM_DIGIT_CAPS * DIGIT_CAPS_DIM, 512),
    nn.SELU(inplace=True),
    nn.Linear(512, 1024),
    nn.SELU(inplace=True),
    # Output width is one value per pixel: 28 * 28 = 784.
    nn.Linear(1024, IMG_SIDE * IMG_SIDE),
    nn.Sigmoid(),
    Reshape((-1, 1, IMG_SIDE, IMG_SIDE))
])

# Full model: the encoder followed by the reconstruction wrapper around the decoder.
model = builder([
    encoder,
    CapReconstructRegularized(decoder)
])

In [13]:
# Bundle the model, both data loaders, the capsule loss and an Adam optimizer
# (selected by name) into the training/evaluation driver.
capsule_loss = CapsuleLoss()
runner = Runner(
    train_loader=train_loader,
    test_loader=test_loader,
    model=model,
    loss_fn=capsule_loss,
    optimizer='adam',
)

In [14]:
# Train while tracking capsule accuracy. The outer progress bar runs to 20, so
# the first positional argument is the epoch count; the second is presumably
# the learning rate -- confirm against Runner.fit.
n_epochs = 20
learning_rate = 1e-3
runner.fit(n_epochs, learning_rate, metrics=[CapsuleAcc()])

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

0.94475
123.92405063291139

Test set: Average loss: 0.0510


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

0.9774
124.62025316455696

Test set: Average loss: 0.0515


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

KeyboardInterrupt: 

In [23]:
# Evaluate the current model, reporting capsule accuracy.
eval_metrics = [CapsuleAcc()]
runner.validate(metrics=eval_metrics)

torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size([128])
torch.Size([128]) torch.Size

In [None]:
# Pull one batch from the test loader and keep only its first element
# (the images); the remaining elements of the batch are not used below.
first_test_batch = next(iter(test_loader))
x = first_test_batch[0]

In [None]:
import torchvision.utils as vutils

In [None]:
# Reconstruct the sampled test batch: encode, keep only the longest output
# capsule of every image, decode, and show inputs next to reconstructions.
with torch.no_grad():  # inference only -- no autograd graph is needed
    out = encoder(x.cuda())

    # The capsule with the largest vector norm is taken as the prediction.
    out_m = torch.norm(out, dim=-1)
    _, v_max_index = out_m.max(dim=1)

    # One-hot mask over the capsule axis, shaped (batch, n_capsules, 1) so it
    # broadcasts across every capsule dimension. The identity matrix is sized
    # from `out` and created on its device instead of hard-coding 10 / cuda.
    y = torch.eye(out.size(1), device=out.device)
    y = y.index_select(dim=0, index=v_max_index).unsqueeze(2)
    masked = (out * y).view(out.size(0), -1)

    res = decoder(masked)

# make_grid tiles the batch into a single CHW image; transpose to HWC for imshow.
t1 = vutils.make_grid(x).cpu().numpy().transpose(1, 2, 0)
t2 = vutils.make_grid(res).cpu().numpy().transpose(1, 2, 0)

fig, arr = plt.subplots(1, 2, figsize=(20, 20))
arr[0].imshow(t1)
arr[0].set_title('Input batch')
arr[1].imshow(t2)
arr[1].set_title('Decoder reconstruction')
plt.show()

In [None]:
import copy  # no longer needed by this cell; kept so other cells relying on the name still run

# Latent traversal: take one test image, keep only its winning output capsule,
# then shift one capsule dimension at a time by a small offset and decode, to
# see what each dimension of that capsule controls in the reconstruction.
sample_index = 10                                        # which image of batch `x` to inspect
offsets = np.linspace(-0.1, 0.1, 20, endpoint=False)     # -0.10, -0.09, ..., 0.09 (exactly 20 steps)

with torch.no_grad():  # inference only -- no autograd graph is needed
    single_x = x[sample_index]
    single_x = single_x[None, :, :]  # add a leading batch axis
    out = encoder(single_x.cuda())

    # The capsule with the largest vector norm is taken as the prediction.
    out_m = torch.norm(out, dim=-1)
    _, v_max_index = out_m.max(dim=1)

    # One-hot mask over the capsule axis, shaped (1, n_capsules, 1).
    y = torch.eye(out.size(1), device=out.device)
    y = y.index_select(dim=0, index=v_max_index).unsqueeze(2)
    masked = (out * y).detach()

    n_dims = out.size(2)
    # One decoded 28x28 RGB image (make_grid expands grayscale to 3 channels)
    # per (dimension, offset) pair.
    imgs = np.zeros((n_dims * len(offsets), 28, 28, 3))

    count = 0
    for dim in range(n_dims):
        for n in offsets:
            mod_masked = masked.clone()
            # Apply the offset to the winning capsule only. Adding it to every
            # capsule would make the masked-out (zeroed) capsules non-zero and
            # feed the decoder an input that is no longer masked.
            mod_masked[:, :, dim] += float(n) * y[:, :, 0]
            mod_masked = mod_masked.view(out.size(0), -1)

            res = decoder(mod_masked)
            t2 = vutils.make_grid(res).cpu().numpy().transpose(1, 2, 0)
            imgs[count] = t2
            count += 1

# Presumably (rows, cols) = (capsule dimensions, offsets) -- confirm against plot_montage.
plot_montage(imgs, n_dims, len(offsets))
    