# Capsule

核心思想：每个 capsule（胶囊）用一个向量来表示一个特征——向量的模长表示该特征存在的程度，向量的方向表示该特征的具体属性。

详细解释请参阅苏剑林的文章[《揭开迷雾，来一顿美味的Capsule盛宴》](https://kexue.fm/archives/4819)。


In [1]:
import torch
import torchvision
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn

In [8]:
class CapsuleLayer(nn.Module):
    """Fully connected capsule layer with dynamic routing (Sabour et al., 2017).

    Every one of ``in_cap_num`` input capsules (vectors of length ``in_dim``)
    makes a prediction for every one of ``out_cap_num`` output capsules
    (vectors of length ``out_dim``) through its own transformation matrix.
    The predictions are then combined by ``routings`` iterations of
    routing-by-agreement.

    Args:
        in_cap_num: number of input capsules.
        in_dim: length of each input capsule vector.
        out_cap_num: number of output (next-layer) capsules.
        out_dim: length of each output capsule vector.
        routings: number of dynamic-routing iterations; must be >= 1.

    Raises:
        ValueError: if ``routings`` is smaller than 1.
    """

    def __init__(self, in_cap_num, in_dim, out_cap_num, out_dim, routings):
        super(CapsuleLayer, self).__init__()
        if routings < 1:
            # forward() only produces its output inside the routing loop.
            raise ValueError("routings must be >= 1, got {}".format(routings))
        self.out_cap_num = out_cap_num  # number of next-layer capsules
        self.in_cap_num = in_cap_num
        self.routings = routings
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Transformation matrices: W[i, j] (out_dim x in_dim) maps input
        # capsule i to its prediction for output capsule j.
        self.W = nn.Parameter(torch.randn(in_cap_num, out_cap_num, out_dim, in_dim))

    def forward(self, u_vecs):
        """Route the input capsules to the output capsules.

        Args:
            u_vecs: input capsules of shape ``(in_cap_num, in_dim)``. Leading
                batch dimensions ``(..., in_cap_num, in_dim)`` are also
                accepted; each sample is routed independently.

        Returns:
            Output capsules of shape ``(..., out_cap_num, out_dim)``.
        """
        # Prediction vectors u_hat[..., j, i, :] = W[i, j] @ u_vecs[..., i, :].
        # One vectorised call; the result lives on u_vecs' device.
        u_hat = torch.einsum('ijdk,...ik->...jid', self.W, u_vecs)

        # Routing logits b[..., j, i] (output j, input i), starting at zero.
        b = u_hat.new_zeros(u_hat.shape[:-1])

        for r in range(self.routings):
            # Coupling coefficients: each input capsule i distributes itself
            # over the output capsules j, so normalise over j (dim -2).
            c = F.softmax(b, dim=-2)
            # s_j = sum_i c_ij * u_hat_{j|i}  -> (..., out_cap_num, out_dim).
            # Non in-place unsqueeze/squeeze keeps the softmax output intact
            # for autograd.
            s = torch.matmul(c.unsqueeze(-2), u_hat).squeeze(-2)
            v = self.squash(s, dim=-1)
            if r < self.routings - 1:
                # Agreement update b_ij += u_hat_{j|i} . v_j. Skipped on the
                # last iteration because the new logits would be unused.
                b = b + torch.matmul(u_hat, v.unsqueeze(-1)).squeeze(-1)
        return v

    @staticmethod
    def squash(x, p=2, dim=1, keepdim=True):
        """Shrink vectors along ``dim`` to a length in [0, 1), keeping direction.

        Computes ``||x||^2 / (1 + ||x||^2) * x / ||x||``, written as
        ``||x|| / (1 + ||x||^2) * x`` so a zero vector maps to zero without
        a division by zero.

        Args:
            x: input tensor.
            p: order of the norm (2 gives the standard capsule squash).
            dim: dimension holding the vector components.
            keepdim: keep ``dim`` in the norm so the scale broadcasts against
                ``x``. With False the scale generally does not line up with
                ``dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        norm = torch.norm(x, p=p, dim=dim, keepdim=keepdim)
        # The denominator uses the squared norm: the output length is
        # norm**2 / (1 + norm**2), which is always < 1.
        scale = norm / (1 + norm ** 2)
        return scale * x
    

In [9]:
# Run CapsuleLayer once: 5 input capsules of dim 128 -> 4 output capsules
# of dim 16, using 3 routing iterations.
cap = CapsuleLayer(in_cap_num=5, in_dim=128, out_cap_num=4, out_dim=16, routings=3)

# A single random sample: 5 input capsules, each a 128-dim vector.
u_vecs = Variable(torch.randn(5, 128))

vttt = cap(u_vecs)

In [None]:
# Step-by-step walkthrough of the same routing, without the module.
# u: 5 input capsules of dim 128 -> v: 4 output capsules of dim 16

u = Variable(torch.randn((5, 128)))

# W[i, j] (16 x 128) maps input capsule i to its prediction for output capsule j.
W = Variable(torch.randn((5, 4, 16, 128)))

# Routing logits b[j, i] for output capsule j and input capsule i.
b = Variable(torch.zeros((4, 5)))

# Prediction vectors u_hat[j, i] = W[i, j] @ u[i].
u_hat = Variable(torch.zeros((4, 5, 16)))

for j in range(4):
    for i in range(5):
        # torch.mv returns a (16,) vector that matches the u_hat[j, i] slot.
        # torch.mm would return (16, 1), which does not broadcast into it.
        u_hat[j, i] = torch.mv(W[i, j], u[i])

for r in range(3):
    # Each input capsule i distributes itself over the output capsules j,
    # so the softmax normalises over j (dim 0 of b).
    c = F.softmax(b, dim=0)
    # s_j = sum_i c_ij * u_hat[j, i]; non in-place ops leave c untouched.
    s = torch.matmul(c.unsqueeze(1), u_hat).squeeze(1)  # 4*16
    # squash is a static method of CapsuleLayer; there is no module-level squash.
    v = CapsuleLayer.squash(s)  # 4*16
    b = b + torch.matmul(u_hat, v.unsqueeze(2)).squeeze(2)

print(v)