In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torchvision import datasets, transforms

import matplotlib

import matplotlib.pyplot as plt

import os
import random
import math

from torch.utils.data.dataset import Dataset

INPUT_PATH = "./output/nodule_npy/"

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
nz = 100
n_epochs = 50
ngf = 64
ngpu = 1

print(device)

cuda:0


In [3]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def chunks(arr, m):
    nchunk = int(math.ceil(len(arr) / float(m)))
    return [arr[i:i + nchunk] for i in range(0, len(arr), nchunk)]

def five_folder(arr, number):
    training_set = []
    test_set = []
    for j in range(len(arr)):
        if number == j:
            test_set.extend(arr[j])
        else:
            training_set.extend(arr[j])
    return training_set, test_set


class MyDataset(Dataset):
    def __init__(self, images):
        self.images = images

    def __getitem__(self, index):#返回的是tensor
        img = self.images[index]
        return img

    def __len__(self):
        return len(self.images)

In [4]:
# Generator
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu

        self.project = nn.Sequential(
          nn.Linear(nz, 1 * 5 * 5 * ngf * 4, bias=False)
        )
        self.deconv = nn.Sequential(
          # input is Z, going into a deconvolution
          # state size. BATCH_SIZE x (ngf*8) x 4 x 4 x 4
#           nn.ConvTranspose3d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),
#           nn.BatchNorm3d(ngf * 4),
#           nn.ReLU(True),
          # state size. BATCH_SIZE x (ngf*4) x 1 x 5 x 5
          nn.ConvTranspose3d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
          nn.BatchNorm3d(ngf * 2),
          nn.ReLU(True),
          # state size. BATCH_SIZE x (ngf*2) x 2 x 10 x 10
          nn.ConvTranspose3d(ngf * 2, ngf, kernel_size=4, stride=2, padding=1, bias=False),
          nn.BatchNorm3d(ngf),
          nn.ReLU(True),
          # state size. BATCH_SIZE x (ngf) x 4 x 20 x 20
          nn.ConvTranspose3d(ngf, 1, kernel_size=4, stride=2, padding=1, bias=False),
          nn.Tanh()
          # state size. BATCH_SIZE x 1 x 8 x 40 x 40
        )

    def forward(self, input):
        x = self.project(input)
        # Conv3d的规定输入数据格式为(batch, channel, Depth, Height, Width)
#         print(x.shape)
        x = x.view(-1, ngf * 4, 1, 5, 5)
#         x = self.deconv(x)
        
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.deconv, x, range(self.ngpu))
        else:
            output = self.deconv(x)

        return output


# Discriminator
def softmax(input, dim=1):
    transposed_input = input.transpose(dim, len(input.size()) - 1)
    softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)), dim=-1)
    return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input.size()) - 1)

class ConvLayer(nn.Module):
    def __init__(self, in_channels=1, out_channels=256, kernel_size=(2,9,9)):
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv3d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=1
                              )

    def forward(self, x):
        return F.relu(self.conv(x))


class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=(2,9,9), num_routes=32 * 12 * 12 * 6):
        super(PrimaryCaps, self).__init__()

        self.capsules = nn.ModuleList([
          nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=(1,2,2), padding=0) 
                      for _ in range(num_capsules)])
  
    def forward(self, x):
        u = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
        u = torch.cat(u, dim=-1)
        return self.squash(u)
  
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor



class DigitCaps(nn.Module):
    def __init__(self, num_capsules=1, num_routes=32 * 12 * 12 * 6, in_channels=8, out_channels=16, num_iterations=3):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.num_iterations = num_iterations
        self.route_weights = nn.Parameter(torch.randn(num_capsules, num_routes, in_channels, out_channels)).to(device)

    def forward(self, x):
        # 矩阵相乘
        # x.size(): [1, batch_size, in_capsules, 1, dim_in_capsule]
        # weight.size(): [num_capsules, 1, num_route, in_channels, out_channels]
        priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]
#         priors = priors.to(device)
#         print()
#         print(x[None, :, :, None, :].size())
#         print(self.route_weights[:, None, :, :, :].size())
#         print(priors.size())
#         print()

        logits = Variable(torch.zeros(*priors.size())).to(device)
#         logits = Variable(torch.zeros(*priors.size()))
        for i in range(self.num_iterations):
            probs = softmax(logits, dim=2)
            outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))

            if i != self.num_routes - 1:
                delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                logits = logits + delta_logits
        
#         return outputs.squeeze().transpose(0, 1)
        return outputs.squeeze()
  
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
    
        self.reconstraction_layers = nn.Sequential(
          nn.Linear(16 * 10, 512),
          nn.ReLU(inplace=True),
          nn.Linear(512, 1024),
          nn.ReLU(inplace=True),
          nn.Linear(1024, 784),
          nn.Sigmoid()
        )
      
    def forward(self, x, data):
        classes = (x ** 2).sum(dim=-1) ** 0.5
        classes = F.softmax(classes, dim=-1)

        _, max_length_indices = classes.max(dim=1)
        masked = Variable(torch.sparse.torch.eye(10))
        if USE_CUDA:
              masked = masked.cuda()
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)
    
        reconstructions = self.reconstraction_layers((x * masked[:, :, None, None]).view(x.size(0), -1))
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return reconstructions, masked


class CapsNet(nn.Module):
    def __init__(self, ngpu):
        super(CapsNet, self).__init__()
        self.ngpu = ngpu
        
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        # self.decoder = Decoder()

        # self.mse_loss = nn.MSELoss()
      
    def forward(self, data):
#         output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
        
        if data.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.conv_layer, data, range(self.ngpu))
            output = nn.parallel.data_parallel(self.primary_capsules, output, range(self.ngpu))
            output = nn.parallel.data_parallel(self.digit_capsules, output, range(self.ngpu))
        else:
            output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))

#         print("OUTPUTS: ",output.shape)
        classes = (output ** 2).sum(dim=-1) ** 0.5
#         classes = F.softmax(classes, dim=-1)
        
        return classes, output

        #reconstructions, masked = self.decoder(output, data)
        #return output, reconstructions, masked

    # def loss(self, data, x, target, reconstructions):
      #   return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)

    # def margin_loss(self, x, labels, size_average=True):
      #   batch_size = x.size(0)

      #   v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))

      #   left = F.relu(0.9 - v_c).view(batch_size, -1)
      #   right = F.relu(v_c - 0.1).view(batch_size, -1)

      #   loss = labels * left + 0.5 * (1.0 - labels) * right
      #   loss = loss.sum(dim=1).mean()

      #   return loss

    # def reconstruction_loss(self, data, reconstructions):
      #   loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
      #   return loss * 0.0005


class CapsuleLoss(nn.Module):
    def __init__(self):
        super(CapsuleLoss, self).__init__()
        # self.reconstruction_loss = nn.MSELoss(size_average=False)

    # def forward(self, images, labels, classes, reconstructions):
    def forward(self, classes, labels):
        left = F.relu(0.9 - classes, inplace=True) ** 2
        right = F.relu(classes - 0.1, inplace=True) ** 2

        margin_loss = labels * left + 0.5 * (1. - labels) * right
        margin_loss = margin_loss.sum()

        return margin_loss

        # assert torch.numel(images) == torch.numel(reconstructions)
        # images = images.view(reconstructions.size()[0], -1)
        # reconstruction_loss = self.reconstruction_loss(reconstructions, images)

        # return (margin_loss + 0.0005 * reconstruction_loss) / images.size(0)

In [5]:
def train(cross_part, train_loader, single_test_set):
    netD = CapsNet(ngpu).to(device)
    criterion = CapsuleLoss()

    netG = Generator(ngpu).to(device)
    netG.apply(weights_init)

    fixed_noise = torch.randn(BATCH_SIZE, nz, device=device)
    real_label = 1
    fake_label = 0

    optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    
    for epoch in range(n_epochs):
        print("-----THE [{}/{}] epoch start-----".format(epoch + 1, n_epochs))
        for j, data in enumerate(train_loader, 0):
#                     print(j, data.shape) #torch.Size([32, 1, 8, 40, 40])

#                     plt.figure()
#                     for lenc in range(data.shape[0]):
#                         for len_img in range(8):
#                             plt.subplot(2, 4, len_img + 1)
#                             pixel_array = data[lenc][0][len_img]
#                             print(pixel_array)
#                             plt.imshow(pixel_array, cmap="gray")
#                         plt.show()


#                     pixel_array = data[0][0][0][20]
#                     print(pixel_array)

#                     data = data / 255.0
#                     pixel_array = data[0][0][0][20]
#                     print(pixel_array)
            ############################
            # (1) Update D network: maximize Lm(D(x), T = 0) + Lm(D(G(z)), T = 1)
            ###########################
            # train with real
            netD.zero_grad()
            real_cpu = data.to(device, dtype=torch.float)
            batch_size = real_cpu.size(0)
            label = torch.full((batch_size,), real_label, device=device)   

            classes_d_real, output_d_real = netD(real_cpu)
#                     print("ERR_D_REAL:",classes_d_real)
            errD_real = criterion(classes_d_real, label)
            errD_real.backward()
            D_x = output_d_real.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            classes_d_fake, output_d_fake = netD(fake.detach())

#                     print("ERR_D_FAKE:",classes_d_fake)
            errD_fake = criterion(classes_d_fake, label)
            errD_fake.backward()
            D_G_z1 = output_d_fake.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize Lm(D(G(z)), T=0)
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            netD.eval()

            classes_g, output_g = netD(fake)
#                     print("ERRG:",classes_g)
            errG = criterion(classes_g, label)
            errG.backward()
            D_G_z2 = output_g.mean().item()
            optimizerG.step()

#                     print("The batch data shape is {}".format(data.shape))

            print('[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (j + 1, len(train_loader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))


        print("-----THE [{}/{}] epoch end-----".format(epoch + 1, n_epochs))
    
    ###TEST###
    test_loader = torch.utils.data.DataLoader(MyDataset(single_test_set), batch_size=BATCH_SIZE, shuffle=True)
    test_sum = 0
    for i, test_data in enumerate(test_loader, 0):
        test_data = test_data.to(device, dtype=torch.float)
        test_classes, test_outputs = netD(test_data)
        
        test_classes = test_classes.cpu()
        test_classes = test_classes.detach().numpy()
        thresh = 0.5

        test_classes[test_classes > thresh] = 1
        test_classes[test_classes < thresh] = 0
        
        single_test_sum = test_classes.sum()
        test_sum += single_test_sum
        
    print("{}/{} is discriminated for right. The ACCURACY is {}%.".format(test_sum, len(single_test_set), 100 * test_sum / len(test_loader)))
    
    return test_sum
    

In [6]:
input_data = os.listdir(INPUT_PATH)

for radiologist in input_data: #遍历patient文件夹——study指代每一个study文件夹
    if not radiologist.startswith('.'): #忽略.DS文件        
        npy_file_path = os.path.join(INPUT_PATH, radiologist)
        npy_files = os.listdir(npy_file_path)
        npy_list = []
        for i in npy_files:
            npy_path = os.path.join(npy_file_path, i)
            single_npy = np.load(npy_path)
            #print(single_npy.shape)
            npy_list.append(single_npy)

#         print(npy_file_path, np.array(npy_list).shape)
        random.shuffle(npy_list)
        npy_chunks = chunks(npy_list, 5)

        print("THIS IS the data from {}. AND NOW the 5-folder cross-valiation start:".format(npy_file_path))
        
        acc_total = 0
        for i in range(5):
            #training
            print("-------------THE {} part(as the test set) cross-valiation start----------------------".format(i + 1))
            single_training_set, single_test_set = five_folder(npy_chunks, i)
#             print(i, np.array(single_training_set).shape, np.array(single_test_set).shape)
            train_loader = torch.utils.data.DataLoader(MyDataset(single_training_set), batch_size=BATCH_SIZE, shuffle=True)
            
            #print(len(train_loader), type(train_loader))
            single_test_acc = train(i, train_loader,single_test_set)
            acc_total += single_test_acc
        
        print("THIS IS the data from {}. {}/{} is discriminated for right. The TOTAL ACCURACY is {}%.".format(npy_file_path, acc_total, len(npy_list), 100 * acc_total / len(npy_list)))

THIS IS the data from ./output/nodule_npy/Radiologist_1. AND NOW the 5-folder cross-valiation start:
-------------THE 1 part(as the test set) cross-valiation start----------------------
-----THE [1/50] epoch start-----
[1/10] Loss_D: 25.8905 Loss_G: 25.8261 D(x): 0.0000 D(G(z)): -0.0000 / 0.0001
[2/10] Loss_D: 23.3455 Loss_G: 25.8196 D(x): 0.0031 D(G(z)): 0.0001 / 0.0001
[3/10] Loss_D: 21.2638 Loss_G: 25.8824 D(x): 0.0058 D(G(z)): 0.0001 / 0.0000
[4/10] Loss_D: 19.2242 Loss_G: 25.8095 D(x): 0.0093 D(G(z)): 0.0001 / 0.0001
[5/10] Loss_D: 17.7059 Loss_G: 25.6931 D(x): 0.0119 D(G(z)): 0.0002 / 0.0003
[6/10] Loss_D: 18.0797 Loss_G: 25.6661 D(x): 0.0117 D(G(z)): 0.0004 / 0.0003
[7/10] Loss_D: 18.2215 Loss_G: 25.6250 D(x): 0.0117 D(G(z)): 0.0004 / 0.0004
[8/10] Loss_D: 14.8814 Loss_G: 25.6188 D(x): 0.0178 D(G(z)): 0.0005 / 0.0004
[9/10] Loss_D: 16.6658 Loss_G: 25.4610 D(x): 0.0149 D(G(z)): 0.0005 / 0.0006
[10/10] Loss_D: 11.5287 Loss_G: 19.1062 D(x): 0.0171 D(G(z)): 0.0007 / 0.0006
-----THE 

[8/10] Loss_D: 8.2138 Loss_G: 21.1244 D(x): 0.0359 D(G(z)): 0.0079 / 0.0073
[9/10] Loss_D: 9.1252 Loss_G: 20.1058 D(x): 0.0350 D(G(z)): 0.0075 / 0.0084
[10/10] Loss_D: 6.5256 Loss_G: 15.2907 D(x): 0.0347 D(G(z)): 0.0085 / 0.0081
-----THE [10/50] epoch end-----
-----THE [11/50] epoch start-----
[1/10] Loss_D: 9.0028 Loss_G: 20.9683 D(x): 0.0341 D(G(z)): 0.0083 / 0.0073
[2/10] Loss_D: 7.8301 Loss_G: 20.6153 D(x): 0.0379 D(G(z)): 0.0073 / 0.0076
[3/10] Loss_D: 8.9047 Loss_G: 20.8001 D(x): 0.0341 D(G(z)): 0.0078 / 0.0075
[4/10] Loss_D: 9.1514 Loss_G: 20.8980 D(x): 0.0357 D(G(z)): 0.0076 / 0.0073
[5/10] Loss_D: 7.3957 Loss_G: 20.1204 D(x): 0.0388 D(G(z)): 0.0074 / 0.0085
[6/10] Loss_D: 10.0494 Loss_G: 20.7337 D(x): 0.0325 D(G(z)): 0.0086 / 0.0076
[7/10] Loss_D: 8.3849 Loss_G: 20.0096 D(x): 0.0344 D(G(z)): 0.0077 / 0.0088
[8/10] Loss_D: 13.6858 Loss_G: 20.8117 D(x): 0.0249 D(G(z)): 0.0088 / 0.0075
[9/10] Loss_D: 10.7940 Loss_G: 19.5862 D(x): 0.0293 D(G(z)): 0.0076 / 0.0093
[10/10] Loss_D: 9.

KeyboardInterrupt: 