In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torchvision import datasets, transforms

import matplotlib

import matplotlib.pyplot as plt

import os
import random
import math

from torch.utils.data.dataset import Dataset

INPUT_PATH = "./output/nodule_npy/"

In [2]:
randnum = random.randint(0,100)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 10
nz = 100
n_epochs = 100
ngf = 64
ngpu = 1
LEVELS_NUM = 2

cuda = True if torch.cuda.is_available() else False
print(device)

cuda:0


In [3]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def chunks(arr, m):
    nchunk = int(math.ceil(len(arr) / float(m)))
    return [arr[i:i + nchunk] for i in range(0, len(arr), nchunk)]

def ten_folder(arr, number):
    training_set = []
    test_set = []
    for j in range(len(arr)):
        if number == j:
            test_set.extend(arr[j])
        else:
            training_set.extend(arr[j])
    return training_set, test_set


class MyDataset(Dataset):
    def __init__(self, images,labels):
        self.images = images
        self.labels = labels

    def __getitem__(self, index):#返回的是tensor
        img = self.images[index]
        label = self.labels[index]
        return img, label

    def __len__(self):
        return len(self.images)

In [4]:
# Generator
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        
        self.label_emb = nn.Embedding(LEVELS_NUM, 100)

        self.project = nn.Sequential(
          nn.Linear(nz, 1 * 5 * 5 * ngf * 4, bias=False)
        )
        self.deconv = nn.Sequential(
          # input is Z, going into a deconvolution
          # state size. BATCH_SIZE x (ngf*4) x 1 x 5 x 5
          nn.ConvTranspose3d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
          nn.BatchNorm3d(ngf * 2),
          nn.ReLU(True),
          # state size. BATCH_SIZE x (ngf*2) x 2 x 10 x 10
          nn.ConvTranspose3d(ngf * 2, ngf, kernel_size=4, stride=2, padding=1, bias=False),
          nn.BatchNorm3d(ngf),
          nn.ReLU(True),
          # state size. BATCH_SIZE x (ngf) x 4 x 20 x 20
          nn.ConvTranspose3d(ngf, 1, kernel_size=4, stride=2, padding=1, bias=False),
          nn.Tanh()
          # state size. BATCH_SIZE x 1 x 8 x 40 x 40
        )

    def forward(self, noise, labels):
        gen_input = torch.mul(self.label_emb(labels), noise)
        
        x = self.project(gen_input)
        # Conv3d的规定输入数据格式为(batch, channel, Depth, Height, Width)
        x = x.view(-1, ngf * 4, 1, 5, 5)
        
        if noise.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.deconv, x, range(self.ngpu))
        else:
            output = self.deconv(x)

        return output


# Discriminator
def softmax(input, dim=1):
    transposed_input = input.transpose(dim, len(input.size()) - 1)
    softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)), dim=-1)
    return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input.size()) - 1)

class ConvLayer(nn.Module):
    def __init__(self, in_channels=1, out_channels=256, kernel_size=(2,9,9)):
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv3d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=1
                              )

    def forward(self, x):
        return F.relu(self.conv(x))


class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=(2,9,9), num_routes=32 * 12 * 12 * 6):
        super(PrimaryCaps, self).__init__()

        self.capsules = nn.ModuleList([
          nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=(1,2,2), padding=0) 
                      for _ in range(num_capsules)])
  
    def forward(self, x):
        u = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
        u = torch.cat(u, dim=-1)
        return self.squash(u)
  
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor



class DigitCaps(nn.Module):
    def __init__(self, num_capsules=1, num_routes=32 * 12 * 12 * 6, in_channels=8, out_channels=16, num_iterations=3):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.num_iterations = num_iterations
        self.route_weights = nn.Parameter(torch.randn(num_capsules, num_routes, in_channels, out_channels)).to(device)

    def forward(self, x):
        # 矩阵相乘
        # x.size(): [1, batch_size, in_capsules, 1, dim_in_capsule]
        # weight.size(): [num_capsules, 1, num_route, in_channels, out_channels]
        priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]
        logits = Variable(torch.zeros(*priors.size())).to(device)

        for i in range(self.num_iterations):
            probs = softmax(logits, dim=2)
            outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))

            if i != self.num_routes - 1:
                delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                logits = logits + delta_logits
        
        outputs = outputs.squeeze()
#         print("OUTPUT:", outputs.shape)
        return outputs
  
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

class CapsNet(nn.Module):
    def __init__(self, ngpu):
        super(CapsNet, self).__init__()
        self.ngpu = ngpu
        
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.D_digit_capsules = DigitCaps()
        self.C_digit_capsules = DigitCaps(num_capsules=LEVELS_NUM)

    def forward(self, data):
        if data.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.conv_layer, data, range(self.ngpu))
            output = nn.parallel.data_parallel(self.primary_capsules, output, range(self.ngpu))
            D_output = nn.parallel.data_parallel(self.D_digit_capsules, output, range(self.ngpu))
            C_output = nn.parallel.data_parallel(self.C_digit_capsules, output, range(self.ngpu))
            C_output = C_output.transpose(0,1)
        else:
            output = self.primary_capsules(self.conv_layer(data))
            D_output = self.D_digit_capsules(output)
            C_output = self.C_digit_capsules(output).transpose(0,1)
            

        D_classes = (D_output ** 2).sum(dim=-1) ** 0.5
        C_classes = (C_output ** 2).sum(dim=-1) ** 0.5
#         print(D_classes, C_classes)
        return D_classes, C_classes


class CapsuleLoss(nn.Module):
    def __init__(self):
        super(CapsuleLoss, self).__init__()

    def forward(self, classes, labels):
        left = F.relu(0.9 - classes, inplace=True) ** 2
        right = F.relu(classes - 0.1, inplace=True) ** 2

        margin_loss = labels * left + 0.5 * (1. - labels) * right
        margin_loss = margin_loss.sum()

        return margin_loss / classes.size(0)

In [5]:
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
discriminator = CapsNet(ngpu).to(device)
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

generator = Generator(ngpu).to(device)
generator.apply(weights_init)

optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

d_losses = []
g_losses = []
train_acc = []
test_acc = []

def train(train_loader, test_set, test_level):
    test_loader = torch.utils.data.DataLoader(MyDataset(test_set, test_level), batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    epoch_num = 0
    for epoch in range(n_epochs):
        print("-----THE [{}/{}] epoch start-----".format(epoch + 1, n_epochs))
        for j, (data,labels) in enumerate(train_loader, 0):
#                     print(j, data.shape) #torch.Size([32, 1, 8, 40, 40])

#                     plt.figure()
#                     for lenc in range(data.shape[0]):
#                         for len_img in range(8):
#                             plt.subplot(2, 4, len_img + 1)
#                             pixel_array = data[lenc][0][len_img]
#                             print(pixel_array)
#                             plt.imshow(pixel_array, cmap="gray")
#                         plt.show()


#                     pixel_array = data[0][0][0][20]
#                     print(pixel_array)

#                     data = data / 255.0
#                     pixel_array = data[0][0][0][20]
#                     print(pixel_array)

            
            real_imgs = data.to(device, dtype=torch.float)
            labels = labels.to(device)
            
            # Adversarial ground truths
            valid = Variable(FloatTensor(BATCH_SIZE, 1).fill_(1.0), requires_grad=False)
            fake = Variable(FloatTensor(BATCH_SIZE, 1).fill_(0.0), requires_grad=False)
            
            
            z = Variable(FloatTensor(np.random.normal(0, 1, (BATCH_SIZE, 100))))
            gen_labels = Variable(LongTensor(np.random.randint(0, LEVELS_NUM, BATCH_SIZE)))
    
            # Generate a batch of images
            gen_imgs = generator(z, gen_labels)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            
            discriminator.train()
            discriminator.zero_grad()
            
            # Loss for real images
            real_pred, real_aux = discriminator(real_imgs)
            d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

            # Loss for fake images
            fake_pred, fake_aux = discriminator(gen_imgs.detach())
            d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

            # Total discriminator loss
            d_loss = (d_real_loss + d_fake_loss) / 2

            # Calculate discriminator accuracy
            pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
            gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
            
            pred = np.argmax(pred, axis=1)
            
            d_acc = np.mean(pred == gt)
            
#             print(pred)
#             print(gt)

            d_loss.backward()
            optimizer_D.step()
        
            
            # -----------------
            #  Train Generator
            # -----------------
            
            generator.zero_grad()
            
            # Loss measures generator's ability to fool the discriminator
            validity, pred_label = discriminator(gen_imgs)
            
            g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))

            g_loss.backward()
            optimizer_G.step()
            

            print('[BATCH %d/%d] Loss_D: %.4f Loss_G: %.4f acc: %.1f%%'
                  % (j + 1, len(train_loader), d_loss.item(), g_loss.item(), 100 * d_acc))
            
            d_losses.append(d_loss)
            g_losses.append(g_loss)
            train_acc.append(d_acc)
            
            epoch_num += 1
            d_acc_total = 0
            if epoch_num % 50 == 0:
                discriminator.eval()
                for j, (data,labels) in enumerate(test_loader, 0):
                    real_imgs = data.to(device, dtype=torch.float)
                    labels = labels.to(device)
                    real_pred, real_aux = discriminator(real_imgs)
                    pred = real_aux.data.cpu().numpy()
                    pred = np.argmax(pred, axis=1)
                    gt = labels.data.cpu().numpy()
                    d_acc = np.mean(pred == gt)
                    d_acc_total += d_acc
                single_test_acc = d_acc_total / len(test_loader)
                print("[EPOCH %d] TEST ACC is : %.1f%%" % (epoch_num, 100 * single_test_acc))
                test_acc.append(single_test_acc)
                

        print("-----THE [{}/{}] epoch end-----".format(epoch + 1, n_epochs))

In [None]:
input_data = os.listdir(INPUT_PATH)
npy_list = []
npy_level_list = []

for level_num in range(5): #遍历patient文件夹——study指代每一个study文件夹
    if level_num == 2:
        continue
    npy_file_path = os.path.join(INPUT_PATH, "malignancy_" + str(level_num + 1))
    npy_files = os.listdir(npy_file_path)
    for i in npy_files:
        npy_path = os.path.join(npy_file_path, i)
        single_npy = np.load(npy_path)
        single_fliplr_npy = np.fliplr(single_npy)
        single_npy = (single_npy - 127.5) / 127.5
        npy_list.append(single_npy)
        if level_num < 2:
            npy_level_list.append(0)
        else:
            npy_level_list.append(1)
        if level_num > 2:
            single_fliplr_npy = (single_fliplr_npy - 127.5) / 127.5
            npy_list.append(single_fliplr_npy)
            npy_level_list.append(1)

random.seed(randnum)
random.shuffle(npy_list)
random.seed(randnum)
random.shuffle(npy_level_list)

npy_chunks = chunks(npy_list, 10)
npy_level_chunks = chunks(npy_level_list, 10)

print("NOW the training STARTS:")
training_set, test_set = ten_folder(npy_chunks, 9)
training_level, test_level = ten_folder(npy_level_chunks, 9)

train_loader = torch.utils.data.DataLoader(MyDataset(training_set, training_level), batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
train(train_loader, test_set, test_level)


NOW the training STARTS:
-----THE [1/100] epoch start-----


  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


[BATCH 1/478] Loss_D: 2.9057 Loss_G: 3.3267 acc: 60.0%
[BATCH 2/478] Loss_D: 1.4589 Loss_G: 2.8863 acc: 55.0%
[BATCH 3/478] Loss_D: 1.4474 Loss_G: 2.7142 acc: 40.0%
[BATCH 4/478] Loss_D: 1.4095 Loss_G: 2.6243 acc: 30.0%
[BATCH 5/478] Loss_D: 1.3702 Loss_G: 2.5671 acc: 55.0%
[BATCH 6/478] Loss_D: 1.4049 Loss_G: 2.5381 acc: 65.0%
[BATCH 7/478] Loss_D: 1.3326 Loss_G: 2.5156 acc: 45.0%
[BATCH 8/478] Loss_D: 1.3203 Loss_G: 2.5047 acc: 60.0%
[BATCH 9/478] Loss_D: 1.3479 Loss_G: 2.4930 acc: 60.0%
[BATCH 10/478] Loss_D: 1.2950 Loss_G: 2.4890 acc: 55.0%
[BATCH 11/478] Loss_D: 1.2729 Loss_G: 2.4768 acc: 55.0%
[BATCH 12/478] Loss_D: 1.3355 Loss_G: 2.4704 acc: 50.0%
[BATCH 13/478] Loss_D: 1.3289 Loss_G: 2.4564 acc: 70.0%
[BATCH 14/478] Loss_D: 1.3103 Loss_G: 2.4628 acc: 40.0%
[BATCH 15/478] Loss_D: 1.2667 Loss_G: 2.4612 acc: 80.0%
[BATCH 16/478] Loss_D: 1.3279 Loss_G: 2.4525 acc: 75.0%
[BATCH 17/478] Loss_D: 1.2601 Loss_G: 2.4490 acc: 45.0%
[BATCH 18/478] Loss_D: 1.2357 Loss_G: 2.4478 acc: 55.0%
[

[BATCH 146/478] Loss_D: 1.1644 Loss_G: 1.1179 acc: 40.0%
[BATCH 147/478] Loss_D: 1.0387 Loss_G: 1.1323 acc: 50.0%
[BATCH 148/478] Loss_D: 0.9704 Loss_G: 1.1203 acc: 60.0%
[BATCH 149/478] Loss_D: 1.0130 Loss_G: 1.1126 acc: 55.0%
[BATCH 150/478] Loss_D: 1.0028 Loss_G: 1.1157 acc: 50.0%
[EPOCH 150] TEST ACC is : 55.6%
[BATCH 151/478] Loss_D: 0.9375 Loss_G: 1.0987 acc: 55.0%
[BATCH 152/478] Loss_D: 1.0086 Loss_G: 1.1027 acc: 65.0%
[BATCH 153/478] Loss_D: 0.9701 Loss_G: 1.0999 acc: 50.0%
[BATCH 154/478] Loss_D: 0.9438 Loss_G: 1.0927 acc: 50.0%
[BATCH 155/478] Loss_D: 1.0773 Loss_G: 1.1114 acc: 60.0%
[BATCH 156/478] Loss_D: 0.9628 Loss_G: 1.1163 acc: 50.0%
[BATCH 157/478] Loss_D: 0.9394 Loss_G: 1.1006 acc: 45.0%
[BATCH 158/478] Loss_D: 1.0947 Loss_G: 1.1058 acc: 60.0%
[BATCH 159/478] Loss_D: 0.9822 Loss_G: 1.1026 acc: 40.0%
[BATCH 160/478] Loss_D: 0.9754 Loss_G: 1.0986 acc: 40.0%
[BATCH 161/478] Loss_D: 1.0029 Loss_G: 1.0990 acc: 70.0%
[BATCH 162/478] Loss_D: 0.9227 Loss_G: 1.0993 acc: 45.0%

[BATCH 289/478] Loss_D: 0.8947 Loss_G: 1.0500 acc: 25.0%
[BATCH 290/478] Loss_D: 0.9398 Loss_G: 1.0424 acc: 60.0%
[BATCH 291/478] Loss_D: 0.8334 Loss_G: 1.0408 acc: 50.0%
[BATCH 292/478] Loss_D: 0.9273 Loss_G: 1.0373 acc: 55.0%
[BATCH 293/478] Loss_D: 0.8495 Loss_G: 1.0341 acc: 45.0%
[BATCH 294/478] Loss_D: 0.9596 Loss_G: 1.0383 acc: 35.0%
[BATCH 295/478] Loss_D: 0.9674 Loss_G: 1.0382 acc: 40.0%
[BATCH 296/478] Loss_D: 0.9763 Loss_G: 1.0442 acc: 55.0%
[BATCH 297/478] Loss_D: 0.8682 Loss_G: 1.0502 acc: 55.0%
[BATCH 298/478] Loss_D: 0.8437 Loss_G: 1.0425 acc: 50.0%
[BATCH 299/478] Loss_D: 0.9604 Loss_G: 1.0461 acc: 50.0%
[BATCH 300/478] Loss_D: 0.8653 Loss_G: 1.0466 acc: 65.0%
[EPOCH 300] TEST ACC is : 56.5%
[BATCH 301/478] Loss_D: 0.9112 Loss_G: 1.0529 acc: 65.0%
[BATCH 302/478] Loss_D: 0.9130 Loss_G: 1.0601 acc: 55.0%
[BATCH 303/478] Loss_D: 0.9334 Loss_G: 1.0623 acc: 50.0%
[BATCH 304/478] Loss_D: 0.8941 Loss_G: 1.0597 acc: 50.0%
[BATCH 305/478] Loss_D: 0.9818 Loss_G: 1.0682 acc: 55.0%