In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torchvision import datasets, transforms

import os

# Root directory of the nodule .npy data, relative to the notebook's working directory.
INPUT_PATH = "./output/nodule_npy/"

In [6]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

BATCH_SIZE = 25  # samples per forward pass in the smoke test
nz = 100  # NOTE(review): not used in this chunk; presumably a latent-vector size — confirm

print(device)

cuda:0


In [7]:
# Discriminator
def softmax(input, dim=1):
    """Softmax of ``input`` along dimension ``dim``.

    Kept as a thin wrapper so existing callers keep working. ``F.softmax``
    accepts ``dim`` directly (including negative values), so there is no need
    to transpose ``dim`` to the last axis, flatten, normalise and reshape back.

    Args:
        input: tensor of rank >= 1.
        dim: dimension whose entries sum to 1 after the call.

    Returns:
        Tensor with the same shape as ``input``.
    """
    return F.softmax(input, dim=dim)

class ConvLayer(nn.Module):
    """First capsule-network stage: a single 3-D convolution followed by ReLU.

    The convolution uses stride 1 and no padding, so every spatial axis
    shrinks by ``kernel_size - 1`` along that axis.

    Args:
        in_channels: channels of the input volume.
        out_channels: number of feature maps produced.
        kernel_size: 3-D kernel extent, forwarded to ``nn.Conv3d``.
    """

    def __init__(self, in_channels=1, out_channels=256, kernel_size=(9,9,2)):
        super().__init__()
        # ``conv`` is the state_dict key prefix; keep the attribute name stable.
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=1)

    def forward(self, x):
        """Convolve ``x`` and clamp negatives to zero."""
        features = self.conv(x)
        return F.relu(features)


class PrimaryCaps(nn.Module):
    """Primary capsule layer built from ``num_capsules`` parallel 3-D convolutions.

    Each convolution produces ``out_channels`` feature maps with stride
    (2, 2, 1). Every element of one convolution's output becomes one route.
    The ``num_capsules`` convolutions supply the components of that route's
    capsule vector, which is then squashed.

    Args:
        num_capsules: length of each primary capsule vector (number of convs).
        in_channels: channels of the incoming feature volume.
        out_channels: feature maps per convolution.
        kernel_size: 3-D kernel extent.
        num_routes: not used by this layer; kept for signature compatibility.
            It should equal the number of elements of one convolution's
            output per sample (out_channels * D1 * D2 * D3), so that it
            matches ``DigitCaps(num_routes=...)``.
        eps: small constant added under the square root in ``squash``.

    ``forward`` returns a tensor of shape ``(batch, routes, num_capsules)``.
    """

    # Class-level default so instances created before ``eps`` existed
    # (e.g. unpickled whole models) still have it.
    eps = 1e-8

    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=(9,9,2), num_routes=32 * 12 * 12 * 6, eps=1e-8):
        super(PrimaryCaps, self).__init__()
        self.eps = eps
        self.capsules = nn.ModuleList([
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=(2, 2, 1), padding=0)
            for _ in range(num_capsules)])

    def forward(self, x):
        batch_size = x.size(0)
        # Flatten each conv output to (batch, routes, 1), then concatenate on the
        # last axis so component k of every capsule comes from convolution k.
        u = torch.cat([capsule(x).view(batch_size, -1, 1) for capsule in self.capsules], dim=-1)
        return self.squash(u)

    def squash(self, input_tensor):
        """Squash along the last axis: v = |s|^2 / (1 + |s|^2) * s / |s|.

        ``eps`` under the square root keeps an all-zero vector at zero instead
        of producing 0/0 = NaN.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        scale = squared_norm / (1. + squared_norm)
        return scale * input_tensor / torch.sqrt(squared_norm + self.eps)



class DigitCaps(nn.Module):
    """Output capsule layer with iterative routing.

    Args:
        num_capsules: number of output capsules.
        num_routes: number of input capsules; must match ``x.size(1)``.
        in_channels: length of each input capsule vector.
        out_channels: length of each output capsule vector.
        num_iterations: routing iterations (must be >= 1).
        eps: small constant added under the square root in ``squash``.

    ``forward`` maps ``(batch, num_routes, in_channels)`` to
    ``(batch, num_capsules, out_channels)`` for every batch size, including 1.
    """

    # Class-level default so instances created before ``eps`` existed still work.
    eps = 1e-8

    def __init__(self, num_capsules=10, num_routes=32 * 12 * 12 * 6, in_channels=8, out_channels=16, num_iterations=3, eps=1e-8):
        super(DigitCaps, self).__init__()
        if num_iterations < 1:
            # forward would otherwise reach its return with ``outputs`` unset.
            raise ValueError("num_iterations must be >= 1, got %r" % (num_iterations,))

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.num_iterations = num_iterations
        self.eps = eps
        # One (in_channels x out_channels) matrix per (output capsule, route) pair.
        self.route_weights = nn.Parameter(torch.randn(num_capsules, num_routes, in_channels, out_channels))

    def forward(self, x):
        # Prediction vectors via batched matrix multiplication:
        #   x[None, :, :, None, :]          -> (1, batch, routes, 1, in_channels)
        #   route_weights[:, None, ...]     -> (num_capsules, 1, routes, in_channels, out_channels)
        #   priors                          -> (num_capsules, batch, routes, 1, out_channels)
        priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]

        # zeros_like places the logits on priors' device and dtype, so the layer
        # also runs on CPU (a hard-coded .cuda() does not).
        logits = torch.zeros_like(priors)
        for i in range(self.num_iterations):
            # NOTE(review): coefficients are normalised over the routes axis
            # (dim=2). Sabour et al. (2017) normalise over output capsules
            # (dim=0 here); kept unchanged to preserve behaviour — confirm intent.
            probs = F.softmax(logits, dim=2)
            outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))

            # The logit update is only consumed by the next iteration, so it is
            # skipped on the last one.
            if i != self.num_iterations - 1:
                delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                logits = logits + delta_logits

        # outputs: (num_capsules, batch, 1, 1, out_channels). Index out the two
        # singleton axes explicitly (a bare squeeze() would also drop the batch
        # axis when batch == 1), then put batch first.
        return outputs[:, :, 0, 0, :].transpose(0, 1)

    def squash(self, input_tensor):
        """Squash along the last axis: v = |s|^2 / (1 + |s|^2) * s / |s|.

        ``eps`` under the square root keeps an all-zero vector at zero instead
        of producing 0/0 = NaN.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        scale = squared_norm / (1. + squared_norm)
        return scale * input_tensor / torch.sqrt(squared_norm + self.eps)


class Decoder(nn.Module):
    """Reconstruction decoder.

    It keeps only the longest capsule, zeroes the others, and maps the masked
    capsules to an image through a three-layer MLP with a sigmoid output.

    Args:
        num_classes: number of capsules in the input (default 10).
        capsule_dim: length of each capsule vector (default 16).
        output_shape: per-sample shape of the reconstruction (default
            ``(1, 28, 28)``). NOTE(review): this default is MNIST-sized and
            does not match the ``(1, 40, 40, 8)`` volumes this notebook feeds
            the discriminator; pass the right shape before enabling the decoder.
    """

    def __init__(self, num_classes=10, capsule_dim=16, output_shape=(1, 28, 28)):
        super(Decoder, self).__init__()
        self.output_shape = tuple(output_shape)
        out_features = 1
        for size in self.output_shape:
            out_features *= size

        # Attribute name (sic) kept as-is: it determines the state_dict keys.
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(capsule_dim * num_classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, out_features),
            nn.Sigmoid()
        )

    def forward(self, x, data):
        """Reconstruct each sample from its longest capsule.

        Args:
            x: capsule vectors of shape ``(batch, num_classes, capsule_dim)``.
            data: original input batch; unused, kept for interface compatibility.

        Returns:
            ``(reconstructions, masked)``, where:
            - ``reconstructions`` has shape ``(batch, *output_shape)`` with
              values in [0, 1];
            - ``masked`` is the ``(batch, num_classes)`` one-hot selection mask.
        """
        # Capsule lengths. argmax is taken on them directly; a softmax first
        # is monotonic and would not change the selected index.
        lengths = (x ** 2).sum(dim=-1) ** 0.5
        max_length_indices = lengths.argmax(dim=1)  # (batch,)

        # One-hot rows built on x's device and dtype, so no CUDA flag is needed.
        masked = torch.eye(x.size(1), device=x.device, dtype=x.dtype).index_select(0, max_length_indices)

        # Zero every capsule except the selected one, then flatten per sample.
        flat = (x * masked.unsqueeze(-1)).reshape(x.size(0), -1)
        reconstructions = self.reconstraction_layers(flat)
        reconstructions = reconstructions.view(x.size(0), *self.output_shape)

        return reconstructions, masked


class CapsNet(nn.Module):
    """Capsule-network discriminator: ConvLayer -> PrimaryCaps -> DigitCaps.

    Each stage uses its default constructor arguments. The default
    ``num_routes = 32 * 12 * 12 * 6`` of the capsule layers corresponds to
    inputs of shape ``(batch, 1, 40, 40, 8)``. Inputs that yield a different
    route count will not fit ``DigitCaps.route_weights``.

    The reconstruction decoder and the margin/reconstruction losses are
    disabled. Their former code is kept as comments at the end of the class.
    """

    def __init__(self):
        super().__init__()
        # These attribute names are the state_dict key prefixes; keep them stable.
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        # Disabled: self.decoder = Decoder()
        # Disabled: self.mse_loss = nn.MSELoss()

    def forward(self, data):
        """Run the three stages on ``data`` and return the DigitCaps output."""
        conv_features = self.conv_layer(data)
        primary_caps = self.primary_capsules(conv_features)
        return self.digit_capsules(primary_caps)

    # ---- Disabled reference code (decoder + losses) ------------------------
    # Former tail of forward():
    #     reconstructions, masked = self.decoder(output, data)
    #     return output, reconstructions, masked
    #
    # def loss(self, data, x, target, reconstructions):
    #     return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)
    #
    # def margin_loss(self, x, labels, size_average=True):
    #     batch_size = x.size(0)
    #     v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))
    #     left = F.relu(0.9 - v_c).view(batch_size, -1)
    #     right = F.relu(v_c - 0.1).view(batch_size, -1)
    #     loss = labels * left + 0.5 * (1.0 - labels) * right
    #     loss = loss.sum(dim=1).mean()
    #     return loss
    #
    # def reconstruction_loss(self, data, reconstructions):
    #     loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
    #     return loss * 0.0005


class CapsuleLoss(nn.Module):
    """Capsule margin loss, summed over every element (not averaged).

    For each element the loss is
        T * max(0, 0.9 - v)^2 + 0.5 * (1 - T) * max(0, v - 0.1)^2
    where ``v`` comes from ``classes`` and ``T`` from ``labels``.

    forward(classes, labels):
        classes: per-class scores; presumably capsule lengths of shape
            (batch, num_classes). NOTE(review): CapsNet.forward returns capsule
            vectors, so callers likely need to take their norms first — confirm.
        labels: target indicators (1 = class present, 0 = absent), broadcastable
            with ``classes``.

    The reconstruction term is currently disabled (see the comment below).
    """

    def __init__(self):
        super().__init__()
        # Disabled: self.reconstruction_loss = nn.MSELoss(size_average=False)

    def forward(self, classes, labels):
        present_penalty = F.relu(0.9 - classes) ** 2
        absent_penalty = F.relu(classes - 0.1) ** 2
        per_element = labels * present_penalty + 0.5 * (1. - labels) * absent_penalty
        return per_element.sum()

    # Former reconstruction term (disabled); it used
    # forward(self, images, labels, classes, reconstructions):
    #     assert torch.numel(images) == torch.numel(reconstructions)
    #     images = images.view(reconstructions.size()[0], -1)
    #     reconstruction_loss = self.reconstruction_loss(reconstructions, images)
    #     return (margin_loss + 0.0005 * reconstruction_loss) / images.size(0)


# Smoke test: push one random batch through the discriminator and check the
# output shape. For the default configuration and input
# (BATCH_SIZE, 1, 40, 40, 8) the stages produce:
#   ConvLayer   -> (BATCH_SIZE, 256, 32, 32, 7)
#   PrimaryCaps -> (BATCH_SIZE, 32 * 12 * 12 * 6, 8) = (BATCH_SIZE, 27648, 8)
#   DigitCaps   -> (BATCH_SIZE, 10, 16)
# TODO: replace the random tensor with real nodule volumes:
#   - iterate over the sub-folders of INPUT_PATH, skipping hidden entries such
#     as .DS_Store;
#   - load each .npy file with np.load;
#   - convert with torch.from_numpy(...).float() and add the batch/channel
#     axes before calling netD.
# The model needs tensors, not file names.

torch.manual_seed(0)  # reproducible weights and input across kernel restarts

netD = CapsNet().to(device)

n = torch.randn(BATCH_SIZE, 1, 40, 40, 8, device=device)
# Shape check only: skip building the autograd graph. The routing
# intermediates alone are several hundred MB per tensor at this size.
with torch.no_grad():
    netD_result = netD(n)
print(netD_result.shape)
assert netD_result.shape == (BATCH_SIZE, 10, 16), netD_result.shape

torch.Size([5, 10, 16])
