## 胶囊网络

### 基本结构

In [5]:
import torch
from torch import nn

![](./capsnet.jpg)

**结构解读**

- 普通卷积层 $Conv1$：基本的卷积层，卷积核较大（$9 \times 9$），因此感受野较大

- 预胶囊层 $PrimaryCaps$：为胶囊层准备输入，本质上是卷积运算，最终输出为 [batch, caps_num, caps_length] 的三维数据，其中 batch 为批大小

- 胶囊层 $DigitCaps$：胶囊层，目的是代替最后一层全连接层，输出为 $10$ 个胶囊

    - $caps\_num$ 为胶囊的数量 图中为 $10$

    - $caps\_length$ 为每个胶囊的长度（每个胶囊为一个向量，该向量包括 $caps\_length$ 个分量） 图中为 $16$
       

### 胶囊网络组件的代码实现

**胶囊网络激活函数**

![](./squash.png)

In [3]:
def squash(inputs, axis=-1):
    """
    Capsule activation: shrink each vector so its length lies in [0, 1).

    The direction of every vector along `axis` is kept; only its length is
    rescaled to norm**2 / (1 + norm**2).
    :param inputs: tensor holding the capsule vectors
    :param axis: dimension along which each capsule vector is laid out
    :return: tensor of the same size as `inputs`
    """
    length = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    squared_length = length**2
    # The small constant keeps the division finite for all-zero vectors.
    factor = squared_length / (1 + squared_length) / (length + 1e-8)
    return factor * inputs

### 预胶囊层

In [6]:
class PrimaryCapsule(nn.Module):
    """
    Convolution followed by a reshape into capsules and the `squash` activation.

    :param in_channels: input channels of the convolution
    :param out_channels: output channels of the convolution
    :param dim_caps: length of each capsule vector
    :param kernel_size: kernel size of the convolution
    :param stride: stride of the convolution
    :param padding: padding of the convolution
    Shape:
        - Input: (batch, in_channels, width, height)
        - Output: (batch, num_caps, dim_caps)
    """

    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super().__init__()
        # Length of each capsule vector.
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        batch = x.size(0)
        feature_maps = self.conv2d(x)
        # Regroup the convolution output into [batch, num_caps, dim_caps].
        capsules = feature_maps.view(batch, -1, self.dim_caps)
        # Capsule activation.
        return squash(capsules)

### 胶囊层

**动态路由算法**

![](./动态路由.png)

In [7]:
class DenseCapsule(nn.Module):
    """
    Fully connected capsule layer with dynamic routing.

    Every input capsule is connected to every output capsule through its own
    [out_dim_caps, in_dim_caps] transformation matrix; the coupling between
    them is refined iteratively by the routing loop in `forward`.

    :param in_num_caps: number of input capsules
    :param in_dim_caps: length (dimension) of each input capsule
    :param out_num_caps: number of output capsules
    :param out_dim_caps: length (dimension) of each output capsule
    :param routings: number of dynamic-routing iterations, must be > 0
    Shape:
        - Input: (batch, in_num_caps, in_dim_caps)
        - Output: (batch, out_num_caps, out_dim_caps)
    """

    def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
        super(DenseCapsule, self).__init__()
        self.in_num_caps = in_num_caps
        self.in_dim_caps = in_dim_caps
        self.out_num_caps = out_num_caps
        self.out_dim_caps = out_dim_caps
        self.routings = routings
        # Transformation weights, size [out_num_caps, in_num_caps, out_dim_caps, in_dim_caps]:
        # every input capsule is connected to every output capsule.
        self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

    def forward(self, x):
        """
        Route the input capsules to the output capsules.

        :param x: input capsules, size=[batch, in_num_caps, in_dim_caps]
        :return: output capsules, size=[batch, out_num_caps, out_dim_caps]
        """
        # Part 1: map the inputs to prediction vectors.
        # - x[:, None, :, :, None] expands x from [batch, in_num_caps, in_dim_caps]
        #   to [batch, 1, in_num_caps, in_dim_caps, 1].
        # - torch.matmul with weight [out_num_caps, in_num_caps, out_dim_caps, in_dim_caps]
        #   broadcasts to [batch, out_num_caps, in_num_caps, out_dim_caps, 1].
        # - torch.squeeze drops the trailing axis:
        #   [batch, out_num_caps, in_num_caps, out_dim_caps].
        x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)
        # Gradient-free copy used by the iterations that only update `b`.
        x_hat_detached = x_hat.detach()

        # Part 2: dynamic routing.
        # The prior for coupling coefficient, initialized as zeros.
        # b.size = [batch, out_num_caps, in_num_caps]
        # Allocated on the input's device so the layer also works off the CPU.
        b = torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps, device=x.device)

        assert self.routings > 0, 'The \'routings\' should be > 0.'
        for i in range(self.routings):
            # Softmax over the output capsules; does not change the size of b.
            # c.size = [batch, out_num_caps, in_num_caps]
            c = torch.softmax(b, dim=1)

            # At last iteration, use `x_hat` to compute `outputs` in order to backpropagate gradient
            if i == self.routings - 1:
                # c.size expanded to [batch, out_num_caps, in_num_caps, 1           ]
                # x_hat.size     =   [batch, out_num_caps, in_num_caps, out_dim_caps]
                # => outputs.size=   [batch, out_num_caps, 1,           out_dim_caps]
                # The weighted sum over the input capsules is s_j of the routing
                # algorithm; squash turns it into the output capsule v_j.
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
            else:  # Otherwise, use `x_hat_detached` to update `b`. No gradients flow on this path.
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))

                # outputs.size       =[batch, out_num_caps, 1,           out_dim_caps]
                # x_hat_detached.size=[batch, out_num_caps, in_num_caps, out_dim_caps]
                # => b.size          =[batch, out_num_caps, in_num_caps]
                # Agreement (dot product) between each prediction and the current output.
                b = b + torch.sum(outputs * x_hat_detached, dim=-1)

        return torch.squeeze(outputs, dim=-2)

In [8]:
# Toy sizes for checking the broadcasting of the matmul used in DenseCapsule.forward.
in_num_caps, in_dim_caps = 32, 6
out_num_caps, out_dim_caps = 10, 16
# Same layout as DenseCapsule.weight.
weight = torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps)

In [9]:
# Show the layout: [out_num_caps, in_num_caps, out_dim_caps, in_dim_caps].
weight.shape

torch.Size([10, 32, 16, 6])

In [14]:
# One dummy sample of input capsules: [batch, in_num_caps, in_dim_caps].
x = torch.randn(1, in_num_caps, in_dim_caps)
# Same expansion as x[:, None, :, :, None]:
# [batch, in_num_caps, in_dim_caps] -> [batch, 1, in_num_caps, in_dim_caps, 1].
x = x.unsqueeze(1).unsqueeze(-1)
x.shape

torch.Size([1, 1, 32, 6, 1])

In [15]:
# Batched matrix product (same as torch.matmul): the leading dimensions of
# weight and x are broadcast against each other.
res = weight @ x

In [16]:
# Show the broadcast result: [batch, out_num_caps, in_num_caps, out_dim_caps, 1].
res.shape

torch.Size([1, 10, 32, 16, 1])

### 胶囊网络整体结构

In [17]:
class CapsuleNet(nn.Module):
    """
    A Capsule Network classifier (Conv1 -> PrimaryCaps -> DigitCaps).

    :param input_size: data size = [channels, width, height]
    :param classes: number of classes
    :param routings: number of routing iterations
    :raises ValueError: if width or height is too small for the two 9x9 convolutions
    Shape:
        - Input: (batch, channels, width, height), optional (batch, classes)
          labels which are accepted but not used.
        - Output: (batch, classes), the length of each class capsule.
    """
    def __init__(self, input_size, classes, routings):
        super(CapsuleNet, self).__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings

        conv_channels = 256
        kernel_size = 9
        primary_stride = 2
        primary_dim_caps = 8

        # Layer 1: Just a conventional Conv2D layer
        self.conv1 = nn.Conv2d(input_size[0],
                               conv_channels,
                               kernel_size=kernel_size,
                               stride=1,
                               padding=0)

        # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_caps, dim_caps]
        self.primarycaps = PrimaryCapsule(conv_channels,
                                          conv_channels,
                                          primary_dim_caps,
                                          kernel_size=kernel_size,
                                          stride=primary_stride,
                                          padding=0)

        # Spatial size left after the two unpadded convolutions
        # (stride 1, then stride `primary_stride`); 28 -> 20 -> 6 for MNIST.
        grid = [((size - kernel_size + 1) - kernel_size) // primary_stride + 1
                for size in (input_size[1], input_size[2])]
        if min(grid) < 1:
            raise ValueError('input_size %r is too small for two %dx%d convolutions'
                             % (tuple(input_size), kernel_size, kernel_size))
        # PrimaryCapsule regroups its conv output into vectors of length
        # `primary_dim_caps`, giving 32 * 6 * 6 capsules for 28x28 inputs.
        in_num_caps = conv_channels // primary_dim_caps * grid[0] * grid[1]

        # Layer 3: Capsule layer. Routing algorithm works here.
        self.digitcaps = DenseCapsule(in_num_caps=in_num_caps,
                                      in_dim_caps=primary_dim_caps,
                                      out_num_caps=classes,
                                      out_dim_caps=16,
                                      routings=routings)

        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        """
        Compute the length of every class capsule.

        :param x: input batch, size=[batch, channels, width, height]
        :param y: optional labels; currently unused, kept for interface compatibility
        :return: capsule lengths, size=[batch, classes]
        """
        x = self.relu(self.conv1(x))
        x = self.primarycaps(x)
        x = self.digitcaps(x)
        # The length of each output capsule serves as the score of its class.
        length = x.norm(dim=-1)
        return length