# DeepCaps
Deep Capsule networks based on [DeepCaps: Going Deeper with Capsule Networks](https://arxiv.org/pdf/1904.09546.pdf).
Implementation based on the original [DeepCaps](https://github.com/brjathu/deepcaps) and [DeepCaps-PyTorch](https://github.com/HopefulRational/DeepCaps-PyTorch) implementation.

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from keras.utils import to_categorical
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
from sklearn import preprocessing
import math


Using TensorFlow backend.


In [0]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Helper Ftions

In [0]:
def squash(s, dim=-1):
  eps=1e-8
  norm = torch.norm(s, dim=dim, keepdim=True)
  return (norm /(1 + norm**2 + eps)) * s

In [0]:
def softmax3D(x, dim):
  return (torch.exp(x) / torch.sum(torch.sum(torch.sum(torch.exp(x), dim=dim[0], keepdim=True), dim=dim[1], keepdim=True), dim=dim[2], keepdim=True))

In [0]:
def one_hot(tensor, num_classes=10):
    return torch.eye(num_classes).cuda().index_select(dim=0, index=tensor.cuda()) # One-hot encode

### Layers

Layer Modules required for DeepCaps network.

In [0]:
class ConvertToCaps(nn.Module):
  def __init__(self):
    '''
    Convert ConvLayer Outputs into Capsules.
    ConvLayer Out: NxCxHxW
    Capsule Shape: NxNCxDxHxW
    NC: Number of capsules.
    D: Capsule Dimension.
    '''
    super().__init__()

  def forward(self, x):
    return x.unsqueeze(2)

In [0]:
class FlattenCaps(nn.Module):
  def __init__(self):
    '''
    Transposes and Flattens capsules.
    Input Shape: NxNCxDxHxW
    Output Shape: Nx[NCxHxW]xD
    '''
    super().__init__()
  def forward(self, x):
    n, nc, d, h, w = x.shape
    x = x.permute(0,3,4,1,2).contiguous()
    return x.view(n,nc*h*w,d)
    

In [0]:
class CapsToScalars(nn.Module):
  def __init__(self):
    '''
    Returns norm of capsule taken along the capsule dimension dim.
    Norm/Length of capsules is the probablity that the 
    object detected by that capsule exists.
    '''
    super().__init__()
  
  def forward(self, x):
    return torch.norm(x,dim=2)
    

In [0]:
class ConvCaps2D(nn.Module):
  def __init__(self, nc_i, dim_i, nc_j, dim_j, kernel_size=3, stride=1,padding=1, r_num=1):
    '''
    2D Convolutional Capsule Layer. 
    Conv2DCaps is similar to a convolutional layer,
    except that its outputs will be squashed 4D tensors.
    i --> current layer.
    j --> next layer.
    Arguments
    ---
    `nc_i`: number of capsules in layer i.
    `dim_i`: dimensions of capsules in layer i.
    `nc_j`: number of capsules in layer j.
    `dim_j`: dimensions of capsules in layer j.
    `kernel_size`: Convolutional filter size.
    `stride`: convolution stride.
    `padding`: convolution padding.
    `r_num`: number of routing iterations.
    '''
    super().__init__()
    self.nc_i = nc_i
    self.dim_i = dim_i
    self.nc_j = nc_j
    self.dim_j = dim_j
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding=padding
    self.r_num=r_num

    in_channels = self.nc_i * self.dim_i
    out_channels = self.nc_j * self.dim_j

    self.conv = nn.Conv2d(in_channels,out_channels,
                          self.kernel_size,self.stride, self.padding)

  def forward(self, x):
    # x.shape: NxNCxDxHxW
    n, nc, dim, h, w = x.shape
    # Reshape x from NxNCxDxHxW --> Nx[NCxD]XHXW
    x = x.view(n,nc*dim, h, w)

    x = self.conv(x)
    h_j, w_j = x.shape[-2:]

    #reshape back to NxNCxDxHxW
    x = x.view(n,self.nc_j,self.dim_j,h_j,w_j)
    
    # Squash and return x.
    return squash(x)


In [0]:
class ConvCaps3D(nn.Module):
  def __init__(self, nc_i, dim_i, nc_j, dim_j,r_num=3, kernel_size=3, padding = (0,1,1)):
    '''
    3D Convolutional Capsule Layer.
    ConvCaps3D uses 3D convolutions with Dynamic Routing
    when num_routings is set greater than 1.
    i --> current layer.
    j --> next layer.
    Arguments
    ---
    `nc_i`: number of capsules in layer i.
    `dim_i`: dimensions of capsules in layer i.
    `nc_j`: number of capsules in layer j.
    `dim_j`: dimensions of capsules in layer j.
    `kernel_size`: Convolutional filter size.
    `stride`: convolution stride.
    `padding`: convolution padding.
    `r_num`: number of routing iterations.
    '''
    super().__init__()
    self.nc_i = nc_i
    self.dim_i = dim_i
    self.nc_j = nc_j
    self.dim_j = dim_j
    self.kernel_size = kernel_size
    self.r_num=r_num


    self.stride = (dim_i,1,1)
    self.padding= padding

    in_channels = 1
    out_channels = self.nc_j * self.dim_j

    self.conv3d = nn.Conv3d(in_channels,out_channels,
                            self.kernel_size,self.stride,self.padding)
    
  def forward(self, x):
    # x.shape = NxNCxDxHxW
    n, nc, dim, h, w = x.shape

    x = x.view(n,nc*dim, h, w)

    x = x.unsqueeze(1)
    x = self.conv3d(x)

    h_j, w_j = x.shape[-2:]

    x = x.view(n, self.nc_i, self.nc_j, self.dim_j ,h_j, w_j)

    # Transpose to NxHxWxDjxNCjxNCi for routing updates.

    x = x.permute(0,4,5,3,2,1)

    # B matrix for routing coefficients.
    # B.shape: NxHxWx1xNCjxNCi
    self.B = x.new(x.shape[0],h_j,w_j,1,self.nc_j,self.nc_i).to(device)

    x = self.update_routing(x, self.r_num)
    
    return x
  
  def update_routing(self, x, num_r=3):
    #x.shape = NxHxWxDjxNCjxNCi
    for ix in range(num_r):
      k = softmax3D(self.B,(1,2,3))
      s = (k * x).sum(dim=-1,keepdim=True)
      s_hat  = squash(s)

      if ix < num_r-1:
        agreements = (s_hat * x).sum(dim=3, keepdim=True)
        self.B = self.B = agreements

    s_hat = s_hat.squeeze(-1)
    batch, h_j, w_j, d_j, n_j  = s_hat.shape

    return s_hat.reshape(batch,n_j,d_j,h_j,w_j)

In [0]:
class MaskCID(nn.Module):
  def __init__(self):
    super().__init__()
  
  def forward(self, x, target=None):
    if target is None:
      #Inference mode
      classes = torch.norm(x,dim=2)
      pred_class = classes.max(dim=1)[1].squeeze()
    else:
      pred_class = target.max(dim=1)[1]
    
    increasing = torch.arange(start=0, end = x.shape[0]).to(device)

    m = torch.stack([increasing,pred_class], dim=1)

    masked = torch.zeros((x.shape[0],1)+x.shape[2:])
    # import pdb; pdb.set_trace()
    for i in increasing:
      masked[i] = x[m[i][0],m[i][1],:].unsqueeze(0)

    return masked.squeeze(-1), pred_class

In [0]:
class DenseCaps_v1(nn.Module):
  def __init__(self, nc=10, num_routes=640, in_dim=8, out_dim=16, routing_iters=3):
      '''
      Dense Capsule Layer.
      '''
      super().__init__()
      self.nc = nc
      self.num_routes = num_routes
      self.r_it = routing_iters

      self.W = nn.Parameter(torch.randn(1,num_routes, nc, out_dim, in_dim) * 0.01).to(device)
      self.bias = nn.Parameter(torch.rand(1, 1, nc, out_dim) * 0.01)

  def forward(self, x):
    x = x.unsqueeze(2).unsqueeze(4)

    u_hat = torch.matmul(self.W,x).squeeze()

    b_ij = x.new(x.shape[0], self.num_routes, self.nc, 1).zero_()

    for ix in range(self.r_it):
      c_ij = F.softmax(b_ij, dim=2)
      s_j = (c_ij * u_hat).sum(dim=-1, keepdim=True) + self.bias
      v_j = squash(s_j, dim=-1)

      if ix<self.r_it-1:
        a_ij = (u_hat * v_j).sum(dim=-1, keepdim=True)
        b_ij = b_ij + a_ij
    v_j = v_j.squeeze()
    return v_j

In [0]:
class DenseCaps_v2(nn.Module):
  def __init__(self, nc=10, num_routes=640, in_dim=8, out_dim=16, routing_iters=3):
      '''
      Dense Capsule Layer.
      '''
      super().__init__()
      self.nc = nc
      self.num_routes = num_routes
      self.out_dim=out_dim
      self.r_it = routing_iters

      self.W = nn.Parameter(torch.Tensor(num_routes, in_dim, nc * out_dim))
      self.bias = nn.Parameter(torch.rand(1, 1, nc, out_dim) * 0.01)
      self.b = nn.Parameter(torch.zeros(num_routes,nc))
      self.reset_params()
  def reset_params(self):
    stdv = 1/math.sqrt(self.num_routes)
    self.W.data.uniform_(-stdv, stdv)

  def forward(self, x):
    x = x.unsqueeze(2)#.unsqueeze(4)

    u_hat = torch.matmul(x,self.W)
    u_hat = u_hat.view(u_hat.size(0),self.num_routes, self.nc, self.out_dim)

    c = F.softmax(self.b)
    s = (c.unsqueeze(2) * u_hat).sum(dim=1)
    v = squash(s)

    if self.r_it > 0:
      bBatch = self.b.expand((u_hat.shape[0],self.num_routes,self.nc))
      for r in range(self.r_it):
        v = v.unsqueeze(1)
        bBatch = bBatch + (u_hat * v).sum(-1)

        c = F.softmax(bBatch.view(-1,self.nc)).view(-1,self.num_routes,self.nc,1)
        s = (c * u_hat).sum(dim=1)
        v = squash(s)
      
    return v

### Networks
Networks that make up the DenseCaps Network.

In [0]:
class Decoder(nn.Module):
  def __init__(self, caps_dim=16, num_caps=1, img_size=28, out_channels=1):
    super().__init__()

    self.num_caps = num_caps
    self.out_channels=1
    self.img_size=img_size

    # self.dense = nn.Linear(caps_dim*num_caps, 7*7*16)
    self.fc = nn.Linear(caps_dim*num_caps,7*7*16).to(device)
    self.relu = nn.ReLU(inplace=True)

    self.reconst_layers1 = nn.Sequential(nn.BatchNorm2d(num_features=16, momentum=0.8),
                                            nn.ConvTranspose2d(in_channels=16, out_channels=64, 
                                                               kernel_size=3, stride=1, padding=1)
                                            )
    self.reconst_layers2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, 
                                                  kernel_size=3, stride=2, padding=1
                                                 )
    self.reconst_layers3 = nn.ConvTranspose2d(in_channels=32, out_channels=16, 
                                                  kernel_size=3, stride=2, padding=1
                                                 )
                                            
    self.reconst_layers4 = nn.ConvTranspose2d(in_channels=16, out_channels=1, 
                                              kernel_size=3, stride=1, padding=1
                                              )
                                        
    self.reconst_layers5 = nn.ReLU()

  def forward(self, x):
    
    batch = x.shape[0]

    x = x.float().to(device)

    x = self.fc(x)
    x=self.relu(x)
    x = x.reshape(-1,16,7,7)

    x = self.reconst_layers1(x)
    x = self.reconst_layers2(x)

    p2d = (1,0,1,0)
    x = F.pad(x, p2d, 'constant')
    x = self.reconst_layers3(x)

    p2d = (1,0,1,0)
    x=F.pad(x,p2d)
    x = self.reconst_layers4(x)

    x = self.reconst_layers5(x)
    x = x.reshape(batch,1,self.img_size,self.img_size)

    return x

In [0]:
class DeepCaps(nn.Module):
  def __init__(self):
    super().__init__()
    
    self.conv2d = nn.Conv2d(1,128,3,1,1)
    self.bn = nn.BatchNorm2d(128,1e-8,momentum=0.99)
    self.toCaps = ConvertToCaps()

    #inSize 28
    self.convcaps1_1 = ConvCaps2D(128,1,32,4,3,2,1,1) 
    self.convcaps1_2 = ConvCaps2D(32,4,32,4,3,1,1,1) #inSize 14
    self.convcaps1_3 = ConvCaps2D(32,4,32,4,3,1,1,1)
    self.convcaps1_4 = ConvCaps2D(32,4,32,4,3,1,1,1)

    #inSize 14
    self.convcaps2_1 = ConvCaps2D(32,4,32,8,3,2,1,1) 
    self.convcaps2_2 = ConvCaps2D(32,8,32,8,3,1,1,1) #inSize 7
    self.convcaps2_3 = ConvCaps2D(32,8,32,8,3,1,1,1)
    self.convcaps2_4 = ConvCaps2D(32,8,32,8,3,1,1,1)

    #inSize 7
    self.convcaps3_1 = ConvCaps2D(32,8,32,8,3,2,1,1) 
    self.convcaps3_2 = ConvCaps2D(32,8,32,8,3,1,1,1) #inSize 4
    self.convcaps3_3 = ConvCaps2D(32,8,32,8,3,1,1,1)
    self.convcaps3_4 = ConvCaps2D(32,8,32,8,3,1,1,1)

    #inSize 4
    self.convcaps4_1 = ConvCaps2D(32,8,32,8,3,2,1,1) 
    self.convcaps3d4 = ConvCaps3D(32,8,32,8,3,3) #inSize 2
    self.convcaps4_3 = ConvCaps2D(32,8,32,8,3,1,1,1)
    self.convcaps4_4 = ConvCaps2D(32,8,32,8,3,1,1,1)

    self.flat_caps = FlattenCaps()
    #numCaps for MNIST: 640
    #self, nc=10, num_routes=640, in_dim=8, out_dim=16, routing_iters=3
    self.digitCaps = DenseCaps_v2()
    
    self.reconNet = Decoder(16,1,28,1)

    self.caps_score = CapsToScalars()
    self.mask = MaskCID()

    self.mse_loss = nn.MSELoss(reduction='none')

  def forward(self, x, target=None):
    x = self.conv2d(x)
    x = self.bn(x)
    x = self.toCaps(x)

    # print(x.shape)

    #Block 1
    x =self.convcaps1_1(x)
    x_skip = self.convcaps1_2(x)
    x = self.convcaps1_3(x)
    x = self.convcaps1_4(x)
    x = x+x_skip

    #Block 2
    x =self.convcaps2_1(x)
    x_skip = self.convcaps2_2(x)
    x = self.convcaps2_3(x)
    x = self.convcaps2_4(x)
    x = x+x_skip

    #Block 3
    x =self.convcaps3_1(x)
    x_skip = self.convcaps3_2(x)
    x = self.convcaps3_3(x)
    x = self.convcaps3_4(x)
    x = x+x_skip
    x1 = x

    #Block 1
    x =self.convcaps4_1(x)
    x_skip = self.convcaps3d4(x)
    x = self.convcaps3_3(x)
    x = self.convcaps3_4(x)
    x = x+x_skip
    x2 = x

    xa = self.flat_caps(x1) # 512 Capsules
    xb = self.flat_caps(x2) # 128 Capsules

    x = torch.cat([xa,xb],dim=-2)

    class_caps = self.digitCaps(x)
    x = self.caps_score(class_caps)
    masked, indices = self.mask(class_caps, target)
    decoded = self.reconNet(masked)

    return class_caps, masked, decoded, indices

  def margin_loss(self, x, labels, lamda, m_plus, m_minus):
    v_c = torch.norm(x, dim=2, keepdim=True)
    tmp1 = F.relu(m_plus - v_c).view(x.shape[0], -1) ** 2
    tmp2 = F.relu(v_c - m_minus).view(x.shape[0], -1) ** 2
    loss = labels*tmp1 + lamda*(1-labels)*tmp2
    loss = loss.sum(dim=1)
    return loss

  def reconst_loss(self, recnstrcted, data):
    loss = self.mse_loss(recnstrcted.view(recnstrcted.shape[0], -1), data.view(recnstrcted.shape[0], -1))
    return 0.4 * loss.sum(dim=1)

  def loss(self, x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1):
    loss = self.margin_loss(x, labels, lamda, m_plus, m_minus) + self.reconst_loss(recnstrcted, data)
    return loss.mean()



### Dataloader


In [0]:
class MNISTData(Dataset):
  def __init__(self,mode='train'):
    super().__init__()
    if 'test' in mode.lower():
      fname = 'mnist_test.csv'
    else:
      fname = 'mnist_train_small.csv'
      
    dataset = pd.read_csv('sample_data/'+fname).values
    # if 'test' in mode:
    #   maskIdx = [3,7]
    # else:
    #   maskIdx = [0,1,2,4,5,6,8,9]
    
    self.xData = dataset[:,1:]/255
    self.yData = dataset[:,0]
    
    # mask = np.isin(self.yData,maskIdx)

    # self.xData = self.xData[mask]
    # self.yData = self.yData[mask]
    self.xData = self.xData[:128]
    self.yData = self.yData[:128]

  def __len__(self):
    return len(self.xData)
  
  def __getitem__(self, idx):
    return np.reshape(self.xData[idx],[1,28,28]),self.yData[idx]#to_categorical(self.yData[idx],num_classes=10)

### Training

#### Plot Helpers

In [0]:
def plot_img(imgs,idx=42, title='Number'):
  plt.figure(figsize=[7,7])
  plt.imshow(imgs[idx],cmap='gray')
  plt.axis('off')
  plt.title(title)

#### Loss Functions

In [0]:
mse_loss = nn.MSELoss(reduction='none')

def margin_loss(x, labels, lamda=0.5, m_plus=0.9, m_minus=0.1):
  v_c = torch.norm(x, dim=2, keepdim=True)
  tmp1 = F.relu(m_plus - v_c).view(x.shape[0], -1) ** 2
  tmp2 = F.relu(v_c - m_minus).view(x.shape[0], -1) ** 2
  loss_ = labels*tmp1 + lamda*(1-labels)*tmp2
  loss_ = loss_.sum(dim=1)
  return loss_
    
def reconst_loss(recnstrcted, data):
  loss = mse_loss(recnstrcted.view(recnstrcted.shape[0], -1), data.view(recnstrcted.shape[0], -1))
  return 0.4 * loss.sum(dim=1)
    
def loss(x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1):
  # print(type(recnstrcted))
  loss_ = margin_loss(x, labels, lamda, m_plus, m_minus) + reconst_loss(recnstrcted, data)
  return loss_.mean()

def accuracy(indices, labels):
  correct = 0.0
  for i in range(indices.shape[0]):
      if float(indices[i]) == labels[i]:
          correct += 1
  return correct

#### Training

In [0]:
model = DeepCaps().to(device)

In [0]:
batch_size = 64
num_epochs = 100
lamda = 0.5
m_plus = 0.9
m_minus = 0.1

In [0]:
# train_data = MNISTData('train')
# test_data = MNISTData('test')
# train_loader = DataLoader(train_data,batch_size,True)
# test_loader = DataLoader(test_data,batch_size,True)

# MNIST
# train_loader = torch.utils.data.DataLoader(
#         datasets.MNIST('../data', train=True, download=True,
#                        transform=transforms.Compose([
#                            transforms.Pad(2), transforms.RandomCrop(28),
#                            transforms.ToTensor()
#                        ])),
#         batch_size=batch_size, shuffle=True)
# test_loader = torch.utils.data.DataLoader(
#         datasets.MNIST('../data', train=False, download=True,
#                        transform=transforms.Compose([
#                            transforms.Pad(2), transforms.RandomCrop(28),
#                            transforms.ToTensor()
#                        ])),
#         batch_size=batch_size, shuffle=True)

# FASHION-MNIST
train_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.Pad(2), transforms.RandomCrop(28),
                           transforms.ToTensor()
                       ])),
        batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST('../data', train=False, download=True,
                       transform=transforms.Compose([
                           transforms.Pad(2), transforms.RandomCrop(28),
                           transforms.ToTensor()
                       ])),
        batch_size=batch_size, shuffle=True)

  0%|          | 0/26421880 [00:00<?, ?it/s]

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz


26427392it [00:00, 76096183.18it/s]                              


Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw


32768it [00:00, 443357.18it/s]
  5%|▍         | 212992/4422102 [00:00<00:02, 1861086.58it/s]

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


4423680it [00:00, 23152133.52it/s]                           
8192it [00:00, 161233.84it/s]


Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw
Processing...
Done!


In [0]:
def train(train_loader, model, num_epochs, lr=0.001, batch_size=64, lamda=0.5, m_plus=0.9,  m_minus=0.1):
    optimizer = torch.optim.Adam(model.parameters(), lr)
    lambda1 = lambda epoch: 0.5**(epoch // 10)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

    for epoch in tqdm(range(num_epochs)):
      loss_track=0.
      for batch_idx, (data, label_) in enumerate(train_loader):
        data = data.float().to(device)
        labels = one_hot(label_.to(device))
        optimizer.zero_grad()
        outputs, masked, recnstrcted, indices = model(data, labels)
        loss_val = loss(outputs, recnstrcted, data, labels, lamda, m_plus, m_minus)
        loss_track += loss_val.item()
        # print(recnstrcted.shape)
        # plot_img(recnstrcted.squeeze().cpu().detach().numpy())
        loss_val.backward()
        optimizer.step()

      print(f'EP {epoch}, Loss: {loss_track}, Accuracy: {accuracy(indices, label_.cpu())/indices.shape[0]}')
      loss_track=0.
      lr_scheduler.step()

In [0]:
train(train_loader,model,30)

  3%|▎         | 1/30 [01:46<51:36, 106.76s/it]

EP 0, Loss: 59756.566802978516, Accuracy: 1.0


  7%|▋         | 2/30 [03:32<49:43, 106.55s/it]

EP 1, Loss: 59699.96783065796, Accuracy: 1.0


 10%|█         | 3/30 [05:18<47:52, 106.38s/it]

EP 2, Loss: 59672.4952545166, Accuracy: 1.0


KeyboardInterrupt: ignored

In [0]:
x_,y_ = iter(train_loader).next()

In [0]:
y_[0]

In [0]:
x = x_.float().to(device)
y = one_hot(y_)

outputs, masked, recnstrcted, indices = model(x)

In [0]:
pred = torch.argmax(torch.norm(outputs,dim=-1),dim=1)

In [0]:
y_

In [0]:
pred

In [0]:
plot_img(x_.squeeze(),3)

In [0]:
plot_img(recnstrcted.squeeze().cpu().detach().numpy(),3)

#### Testing

In [0]:
def test(model, test_loader, loss, batch_size, lamda=0.5, m_plus=0.9, m_minus=0.1):
  test_loss = 0.0
  correct = 0.0
  for batch_idx, (data, label) in enumerate(test_loader):
    data, labels = data.cuda(), one_hot(label.cuda())
    outputs, masked_output, recnstrcted, indices = model(data)
    
    loss_test = model.loss(outputs, recnstrcted, data, labels, lamda, m_plus, m_minus)
    test_loss += loss_test.data
    indices_cpu, labels_cpu = indices.cpu(), label.cpu()

    correct += accuracy(indices_cpu, labels_cpu)

  print("\nTest Loss: ", test_loss/len(test_loader.dataset), "; Test Accuracy: " , correct/len(test_loader.dataset) * 100,'\n')

In [0]:
 test(model, test_loader, loss, batch_size, lamda, m_plus, m_minus)