In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import prod

In [2]:
class TimeDistributed(nn.Module):
    """
    Apply a per-frame module to every step of a sequence of image-like tensors.

    The leading axes of the input are folded into a single batch axis so the
    wrapped module receives a 4-D tensor; both tensors it returns are then
    unfolded back into two leading axes.

    Args:
        module: Wrapped module. For inputs with more than two dimensions it must
            return a ``(y, reconstruction)`` pair whose first axis is the folded
            batch axis.
        batch_first: If True the input is laid out as (samples, timesteps, ...),
            otherwise as (timesteps, samples, ...).
    """

    def __init__(self, module, batch_first=False):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        """
        Run the wrapped module on every frame of ``x``.

        Args:
            x: Tensor whose last three axes are passed to the module per frame.
                Tensors with at most two dimensions are forwarded unchanged.
        Returns:
            ``module(x)`` for inputs with at most two dimensions; otherwise the
            pair ``(y, reconstruction)`` with the two leading axes restored.
        """
        # No sequence axis to fold: hand the input straight to the module.
        if x.dim() <= 2:
            return self.module(x)

        # Squash samples and timesteps into a single axis:
        # (..., C, H, W) -> (samples * timesteps, C, H, W)
        x_reshape = x.contiguous().view(-1, x.size(-3), x.size(-2), x.size(-1))

        y, reconstruction = self.module(x_reshape)

        # Both outputs share the folded leading axis, so both are unfolded the same way.
        return self._unfold(y, x), self._unfold(reconstruction, x)

    def _unfold(self, flat, x):
        """Split the folded leading axis of ``flat`` back into the two leading axes of ``x``."""
        if self.batch_first:
            lead = (x.size(0), -1)   # (samples, timesteps, ...)
        else:
            lead = (-1, x.size(1))   # (timesteps, samples, ...)
        # reshape (not view) so a non-contiguous module output is still accepted.
        return flat.reshape(lead + tuple(flat.shape[1:]))

In [3]:
def squash(s, dim=-1, eps=1e-8):
    '''
    "Squashing" non-linearity that shrinks short vectors to almost zero length and long vectors to a length slightly below 1
    Eq. (1): v_j = ||s_j||^2 / (1 + ||s_j||^2) * s_j / ||s_j||
    Args:
        s: Vector before activation
        dim: Dimension along which to calculate the norm
        eps: Small constant added under the square root for numerical stability
    Returns:
        Squashed vector, same shape as s
    '''
    squared_norm = torch.sum(s**2, dim=dim, keepdim=True)
    # eps goes inside the sqrt: the derivative of sqrt is infinite at 0, so an
    # all-zero vector would otherwise produce NaN gradients (0 * inf) in backward.
    norm = torch.sqrt(squared_norm + eps)
    return squared_norm / (1 + squared_norm) * s / norm


class PrimaryCapsules(nn.Module):
    """Convolution whose output is regrouped into squashed capsule vectors of length ``dim_caps``."""

    def __init__(self, in_channels, out_channels, dim_caps,
                 kernel_size=9, stride=2, padding=0):
        """
        Initialize the layer.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels; must be a multiple of dim_caps.
            dim_caps: Dimensionality, i.e. length, of the output capsule vector.
            kernel_size: Kernel size of the convolution.
            stride: Stride of the convolution.
            padding: Zero-padding of the convolution.
        Raises:
            ValueError: If dim_caps is not positive or does not divide out_channels.
        """
        super(PrimaryCapsules, self).__init__()
        # Fail at construction with a clear message instead of an opaque
        # view() size error on the first forward pass.
        if dim_caps <= 0 or out_channels % dim_caps != 0:
            raise ValueError(
                "out_channels ({}) must be a positive multiple of dim_caps ({})".format(out_channels, dim_caps)
            )
        self.dim_caps = dim_caps
        self._caps_channel = out_channels // dim_caps
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        """
        Args:
            x: Input of shape (batch_size, in_channels, height, width).
        Returns:
            Squashed capsules of shape (batch_size, num_capsules, dim_caps).
        """
        # (batch_size, out_channels, H, W)
        out = self.conv(x)
        # NOTE(review): view() reinterprets memory, so each capsule is made of
        # dim_caps consecutive elements of the flattened (C, H, W) output rather
        # than dim_caps channels at one spatial location. Kept as-is because
        # changing the grouping would invalidate already-trained weights.
        out = out.view(out.size(0), self._caps_channel, out.size(2), out.size(3), self.dim_caps)
        # Flatten every capsule into one axis: (batch_size, num_capsules, dim_caps)
        out = out.view(out.size(0), -1, self.dim_caps)
        return squash(out)


class RoutingCapsules(nn.Module):
    """Capsule layer that maps input capsules to output capsules with dynamic routing-by-agreement."""

    def __init__(self, in_dim, in_caps, num_caps, dim_caps, num_routing, device: torch.device):
        """
        Initialize the layer.

        Args:
            in_dim: Dimensionality (i.e. length) of each capsule vector.
            in_caps: Number of input capsules if digits layer.
            num_caps: Number of capsules in the capsule layer
            dim_caps: Dimensionality, i.e. length, of the output capsule vector.
            num_routing: Number of iterations during routing algorithm
            device: Stored on the instance for backward compatibility. The routing
                buffers are allocated on the device of the incoming activations.
        """
        super(RoutingCapsules, self).__init__()
        self.in_dim = in_dim
        self.in_caps = in_caps
        self.num_caps = num_caps
        self.dim_caps = dim_caps
        self.num_routing = num_routing
        self.device = device

        # One (dim_caps x in_dim) transformation matrix per (output capsule, input capsule) pair.
        self.W = nn.Parameter(0.01 * torch.randn(1, num_caps, in_caps, dim_caps, in_dim))

    def __repr__(self):
        tab = '  '
        return '\n'.join([
            self.__class__.__name__ + '(',
            tab + '(0): CapsuleLinear(' + str(self.in_dim) + ', ' + str(self.dim_caps) + ')',
            tab + '(1): Routing(num_routing=' + str(self.num_routing) + ')',
            ')',
        ])

    def forward(self, x):
        """
        Args:
            x: Input capsules of shape (batch_size, in_caps, in_dim).
        Returns:
            Output capsules of shape (batch_size, num_caps, dim_caps).
        """
        batch_size = x.size(0)
        # (batch_size, in_caps, in_dim) -> (batch_size, 1, in_caps, in_dim, 1)
        x = x.unsqueeze(1).unsqueeze(4)
        #
        # W @ x =
        # (1, num_caps, in_caps, dim_caps, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) =
        # (batch_size, num_caps, in_caps, dim_caps, 1)
        u_hat = torch.matmul(self.W, x)
        # (batch_size, num_caps, in_caps, dim_caps)
        u_hat = u_hat.squeeze(-1)
        # detach u_hat during routing iterations to prevent gradients from flowing
        temp_u_hat = u_hat.detach()

        # Procedure 1: Routing algorithm
        # Routing logits live on the same device and in the same dtype as the
        # predictions, so the layer keeps working after .to(...) / .half().
        b = torch.zeros(batch_size, self.num_caps, self.in_caps, 1,
                        device=u_hat.device, dtype=u_hat.dtype)

        for route_iter in range(self.num_routing - 1):
            # (batch_size, num_caps, in_caps, 1) -> Softmax along num_caps
            c = F.softmax(b, dim=1)

            # element-wise multiplication
            # (batch_size, num_caps, in_caps, 1) * (batch_size, num_caps, in_caps, dim_caps) ->
            # (batch_size, num_caps, in_caps, dim_caps) sum across in_caps ->
            # (batch_size, num_caps, dim_caps)
            s = (c * temp_u_hat).sum(dim=2)
            # apply "squashing" non-linearity along dim_caps
            v = squash(s)
            # dot product agreement between the current output vj and the prediction uj|i
            # (batch_size, num_caps, in_caps, dim_caps) @ (batch_size, num_caps, dim_caps, 1)
            # -> (batch_size, num_caps, in_caps, 1)
            uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
            b += uv

        # last iteration is done on the original u_hat, without the routing weights update
        c = F.softmax(b, dim=1)
        s = (c * u_hat).sum(dim=2)
        # apply "squashing" non-linearity along dim_caps
        v = squash(s)

        return v

In [4]:
class CapsuleNetwork(nn.Module):
    """
    Capsule network: conv -> primary capsules -> routed class capsules, plus a
    fully connected decoder that reconstructs the input from the winning capsule.
    """

    def __init__(self, img_shape, channels, primary_dim, num_classes, out_dim, num_routing, device: torch.device, kernel_size=9):
        """
        Args:
            img_shape: Input shape as (channels, height, width).
            channels: Number of feature maps of both convolutions.
            primary_dim: Length of each primary capsule vector.
            num_classes: Number of output (class) capsules.
            out_dim: Length of each output capsule vector.
            num_routing: Number of routing iterations.
            device: Stored on the instance and forwarded to the routing layer.
            kernel_size: Kernel size shared by both convolutions.
        """
        super(CapsuleNetwork, self).__init__()
        self.img_shape = img_shape
        self.num_classes = num_classes
        self.device = device

        self.conv1 = nn.Conv2d(img_shape[0], channels, kernel_size, stride=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        self.primary = PrimaryCapsules(channels, channels, primary_dim, kernel_size)

        # Count the primary capsules from the exact convolution output sizes.
        # A closed-form shortcut that halves (img - 2*(kernel_size-1)) is only
        # right when that quantity is even on both axes.
        primary_caps = self._count_primary_caps(channels, primary_dim, kernel_size)
        self.digits = RoutingCapsules(primary_dim, primary_caps, num_classes, out_dim, num_routing, device=self.device)

        self.decoder = nn.Sequential(
            nn.Linear(out_dim * num_classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, int(prod(img_shape)))
        )

    @staticmethod
    def _conv_out_size(size, kernel, stride=1, padding=0):
        """Output length of an un-dilated convolution along one axis."""
        return (size + 2 * padding - kernel) // stride + 1

    def _count_primary_caps(self, channels, primary_dim, kernel_size):
        """Number of capsule vectors produced by ``self.primary`` for ``self.img_shape``."""
        conv = self.primary.conv
        positions = 1
        for axis in (0, 1):
            after_conv1 = self._conv_out_size(self.img_shape[axis + 1], kernel_size)
            positions *= self._conv_out_size(
                after_conv1, conv.kernel_size[axis], conv.stride[axis], conv.padding[axis]
            )
        return (channels // primary_dim) * positions

    def forward(self, x):
        """
        Args:
            x: Batch of inputs of shape (batch_size, *img_shape).
        Returns:
            preds: Capsule lengths of shape (batch_size, num_classes).
            reconstructions: Decoder output of shape (batch_size, *img_shape).
        """
        out = self.conv1(x)
        out = self.relu(out)
        out = self.primary(out)
        # (batch_size, num_classes, out_dim)
        out = self.digits(out)
        # The length of each output capsule is used as the class score.
        preds = torch.norm(out, dim=-1)

        # Reconstruct the *predicted* image: keep only the longest capsule.
        _, max_length_idx = preds.max(dim=1)
        # Build the one-hot mask alongside the activations so the model is not
        # tied to the device it was constructed with.
        y = torch.eye(self.num_classes, device=out.device, dtype=out.dtype)
        # (batch_size, num_classes, 1)
        y = y.index_select(dim=0, index=max_length_idx).unsqueeze(2)

        reconstructions = self.decoder((out * y).view(out.size(0), -1))
        reconstructions = reconstructions.view(-1, *self.img_shape)

        return preds, reconstructions

In [5]:
class CapLSTM(nn.Module):
    """
    Sequence model: a capsule network encodes every frame, an LSTM runs over the
    per-frame capsule lengths, and a linear head maps each LSTM output to a scalar.

    Input layout is (seq_len, batch, *img_shape).
    """

    def __init__(self, img_shape=(11, 64, 64), channels=256, primary_dim=8, num_classes=100,
                 out_dim=16, num_routing=3, hidden_size=100, num_layers=4, device=None):
        """
        All arguments are optional; the defaults reproduce the original fixed configuration.

        Args:
            img_shape: Shape of one frame as (channels, height, width).
            channels: Number of feature maps in the encoder convolutions.
            primary_dim: Length of each primary capsule vector.
            num_classes: Number of output capsules, i.e. the LSTM input size.
            out_dim: Length of each output capsule vector.
            num_routing: Number of routing iterations.
            hidden_size: LSTM hidden size.
            num_layers: Number of stacked LSTM layers.
            device: Device passed to the encoder; defaults to CPU.
        """
        super().__init__()

        if device is None:
            device = torch.device("cpu")

        self.Encoder = TimeDistributed(CapsuleNetwork(img_shape=img_shape, channels=channels, primary_dim=primary_dim, num_classes=num_classes, out_dim=out_dim, num_routing=num_routing, device=device))
        self.RNN = nn.LSTM(num_classes, hidden_size, num_layers)
        self.Output = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (seq_len, batch, *img_shape).
        Returns:
            The head's output at the last timestep, shape (batch, 1), and the
            reconstruction of the last timestep, shape (batch, *img_shape).
        """
        # (seq_len, batch, num_classes), (seq_len, batch, *img_shape)
        out, reconstruction = self.Encoder(x)
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
        out, _ = self.RNN(out)
        out = self.Output(out)
        return out[-1], reconstruction[-1]

In [6]:
# Build the capsule-encoder + LSTM model with its default configuration.
net = CapLSTM()

In [7]:
# Random smoke-test input laid out as (seq_len, batch, channels, height, width).
test_input = torch.randn((2, 2, 11, 64, 64))

In [8]:
# Forward pass; the model returns the output and the reconstruction of the last timestep.
test_output, test_reconstruction = net(test_input)

In [9]:
# Display both shapes: (batch, 1) for the output and (batch, 11, 64, 64) for the reconstruction.
test_output.shape, test_reconstruction.shape

(torch.Size([2, 1]), torch.Size([2, 11, 64, 64]))