In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torchvision import datasets, transforms

import matplotlib

import matplotlib.pyplot as plt

import os
import random
import math

from torch.utils.data.dataset import Dataset

INPUT_PATH = "./output/nodule_npy/"

In [2]:
randnum = random.randint(0,100)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 25
nz = 100
n_epochs = 100
ngf = 64
ngpu = 1
LEVELS_NUM = 2

cuda = True if torch.cuda.is_available() else False
print(device)

cuda:0


In [3]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def chunks(arr, m):
    nchunk = int(math.ceil(len(arr) / float(m)))
    return [arr[i:i + nchunk] for i in range(0, len(arr), nchunk)]

def five_folder(arr, number):
    training_set = []
    test_set = []
    for j in range(len(arr)):
        if number == j:
            test_set.extend(arr[j])
        else:
            training_set.extend(arr[j])
    return training_set, test_set


class MyDataset(Dataset):
    def __init__(self, images,labels):
        self.images = images
        self.labels = labels

    def __getitem__(self, index):#返回的是tensor
        img = self.images[index]
        label = self.labels[index]
        return img, label

    def __len__(self):
        return len(self.images)

In [4]:
# Generator
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        
        self.label_emb = nn.Embedding(LEVELS_NUM, 100)

        self.project = nn.Sequential(
          nn.Linear(nz, 1 * 5 * 5 * ngf * 4, bias=False)
        )
        self.deconv = nn.Sequential(
          # input is Z, going into a deconvolution
          # state size. BATCH_SIZE x (ngf*4) x 1 x 5 x 5
          nn.ConvTranspose3d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
          nn.BatchNorm3d(ngf * 2),
          nn.ReLU(True),
          # state size. BATCH_SIZE x (ngf*2) x 2 x 10 x 10
          nn.ConvTranspose3d(ngf * 2, ngf, kernel_size=4, stride=2, padding=1, bias=False),
          nn.BatchNorm3d(ngf),
          nn.ReLU(True),
          # state size. BATCH_SIZE x (ngf) x 4 x 20 x 20
          nn.ConvTranspose3d(ngf, 1, kernel_size=4, stride=2, padding=1, bias=False),
          nn.Tanh()
          # state size. BATCH_SIZE x 1 x 8 x 40 x 40
        )

    def forward(self, noise, labels):
        gen_input = torch.mul(self.label_emb(labels), noise)
        
        x = self.project(gen_input)
        # Conv3d的规定输入数据格式为(batch, channel, Depth, Height, Width)
        x = x.view(-1, ngf * 4, 1, 5, 5)
        
        if noise.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.deconv, x, range(self.ngpu))
        else:
            output = self.deconv(x)

        return output


# Discriminator
def softmax(input, dim=1):
    transposed_input = input.transpose(dim, len(input.size()) - 1)
    softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)), dim=-1)
    return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input.size()) - 1)

class ConvLayer(nn.Module):
    def __init__(self, in_channels=1, out_channels=256, kernel_size=(2,9,9)):
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv3d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=1
                              )

    def forward(self, x):
        return F.relu(self.conv(x))


class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=(2,9,9), num_routes=32 * 12 * 12 * 6):
        super(PrimaryCaps, self).__init__()

        self.capsules = nn.ModuleList([
          nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=(1,2,2), padding=0) 
                      for _ in range(num_capsules)])
  
    def forward(self, x):
        u = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
        u = torch.cat(u, dim=-1)
        return self.squash(u)
  
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor



class DigitCaps(nn.Module):
    def __init__(self, num_capsules=1, num_routes=32 * 12 * 12 * 6, in_channels=8, out_channels=16, num_iterations=3):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.num_iterations = num_iterations
        self.route_weights = nn.Parameter(torch.randn(num_capsules, num_routes, in_channels, out_channels)).to(device)

    def forward(self, x):
        # 矩阵相乘
        # x.size(): [1, batch_size, in_capsules, 1, dim_in_capsule]
        # weight.size(): [num_capsules, 1, num_route, in_channels, out_channels]
        priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]
        logits = Variable(torch.zeros(*priors.size())).to(device)

        for i in range(self.num_iterations):
            probs = softmax(logits, dim=2)
            outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))

            if i != self.num_routes - 1:
                delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                logits = logits + delta_logits
        
        outputs = outputs.squeeze()
#         print("OUTPUT:", outputs.shape)
        return outputs
  
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

class CapsNet(nn.Module):
    def __init__(self, ngpu):
        super(CapsNet, self).__init__()
        self.ngpu = ngpu
        
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.D_digit_capsules = DigitCaps()
        self.C_digit_capsules = DigitCaps(num_capsules=LEVELS_NUM)

    def forward(self, data):
        if data.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.conv_layer, data, range(self.ngpu))
            output = nn.parallel.data_parallel(self.primary_capsules, output, range(self.ngpu))
            output = nn.parallel.data_parallel(self.digit_capsules, output, range(self.ngpu))
        else:
            output = self.primary_capsules(self.conv_layer(data))
            D_output = self.D_digit_capsules(output)
            C_output = self.C_digit_capsules(output).transpose(0,1)
            

        D_classes = (D_output ** 2).sum(dim=-1) ** 0.5
        C_classes = (C_output ** 2).sum(dim=-1) ** 0.5
#         print(D_classes, C_classes)
        return D_classes, C_classes


class CapsuleLoss(nn.Module):
    def __init__(self):
        super(CapsuleLoss, self).__init__()

    def forward(self, classes, labels):
        left = F.relu(0.9 - classes, inplace=True) ** 2
        right = F.relu(classes - 0.1, inplace=True) ** 2

        margin_loss = labels * left + 0.5 * (1. - labels) * right
        margin_loss = margin_loss.sum()

        return margin_loss / classes.size(0)

In [5]:
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
discriminator = CapsNet(ngpu).to(device)
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

generator = Generator(ngpu).to(device)
generator.apply(weights_init)

optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

def train(train_loader, test_set, test_level):
    test_loader = torch.utils.data.DataLoader(MyDataset(test_set, test_level), batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    epoch_num = 0
    for epoch in range(n_epochs):
        print("-----THE [{}/{}] epoch start-----".format(epoch + 1, n_epochs))
        for j, (data,labels) in enumerate(train_loader, 0):
#                     print(j, data.shape) #torch.Size([32, 1, 8, 40, 40])

#                     plt.figure()
#                     for lenc in range(data.shape[0]):
#                         for len_img in range(8):
#                             plt.subplot(2, 4, len_img + 1)
#                             pixel_array = data[lenc][0][len_img]
#                             print(pixel_array)
#                             plt.imshow(pixel_array, cmap="gray")
#                         plt.show()


#                     pixel_array = data[0][0][0][20]
#                     print(pixel_array)

#                     data = data / 255.0
#                     pixel_array = data[0][0][0][20]
#                     print(pixel_array)

            
            real_imgs = data.to(device, dtype=torch.float)
            labels = labels.to(device)
            
            # Adversarial ground truths
            valid = Variable(FloatTensor(BATCH_SIZE, 1).fill_(1.0), requires_grad=False)
            fake = Variable(FloatTensor(BATCH_SIZE, 1).fill_(0.0), requires_grad=False)
            
            
            z = Variable(FloatTensor(np.random.normal(0, 1, (BATCH_SIZE, 100))))
            gen_labels = Variable(LongTensor(np.random.randint(0, LEVELS_NUM, BATCH_SIZE)))
    
            # Generate a batch of images
            gen_imgs = generator(z, gen_labels)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            
            discriminator.zero_grad()
            
            # Loss for real images
            real_pred, real_aux = discriminator(real_imgs)
            d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

            # Loss for fake images
            fake_pred, fake_aux = discriminator(gen_imgs.detach())
            d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

            # Total discriminator loss
            d_loss = (d_real_loss + d_fake_loss) / 2

            # Calculate discriminator accuracy
            pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
            gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
            
            pred = np.argmax(pred, axis=1)
            
            d_acc = np.mean(pred == gt)
            
#             print(pred)
#             print(gt)

            d_loss.backward()
            optimizer_D.step()
        
            
            # -----------------
            #  Train Generator
            # -----------------
            
            generator.zero_grad()
            
            # Loss measures generator's ability to fool the discriminator
            validity, pred_label = discriminator(gen_imgs)
            
            g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))

            g_loss.backward()
            optimizer_G.step()
            

            print('[BATCH %d/%d] Loss_D: %.4f Loss_G: %.4f acc: %d%%'
                  % (j + 1, len(train_loader), d_loss.item(), g_loss.item(), 100 * d_acc))
            
            epoch_num += 1
            d_acc_total = 0
            if epoch_num % 50 == 0:
                for j, (data,labels) in enumerate(test_loader, 0):
                    real_imgs = data.to(device, dtype=torch.float)
                    labels = labels.to(device)
                    pred = real_aux.data.cpu().numpy()
                    gt = labels.data.cpu().numpy()
                    pred = np.argmax(pred, axis=1)
                    d_acc = np.mean(pred == gt)
                    d_acc_total += d_acc
                print("[EPOCH %d] TEST ACC is : %.4f%%" % (epoch_num, d_acc_total * 100 / len(test_loader)))

        print("-----THE [{}/{}] epoch end-----".format(epoch + 1, n_epochs))

In [6]:
input_data = os.listdir(INPUT_PATH)
npy_list = []
npy_level_list = []

for level_num in range(5): #遍历patient文件夹——study指代每一个study文件夹
    if level_num == 2:
        continue
    npy_file_path = os.path.join(INPUT_PATH, "malignancy_" + str(level_num + 1))
    npy_files = os.listdir(npy_file_path)
    for i in npy_files:
        npy_path = os.path.join(npy_file_path, i)
        single_npy = np.load(npy_path)
        single_fliplr_npy = np.fliplr(single_npy)
        single_npy = (single_npy - 127.5) / 127.5
        npy_list.append(single_npy)
        if level_num < 2:
            npy_level_list.append(0)
        else:
            npy_level_list.append(1)
        if level_num > 2:
            single_fliplr_npy = (single_fliplr_npy - 127.5) / 127.5
            npy_list.append(single_fliplr_npy)
            npy_level_list.append(1)
        

#         print(npy_file_path, np.array(npy_list).shape)
random.seed(randnum)
random.shuffle(npy_list)
random.seed(randnum)
random.shuffle(npy_level_list)

npy_chunks = chunks(npy_list, 5)
npy_level_chunks = chunks(npy_level_list, 5)

print("NOW the training STARTS:")
training_set, test_set = five_folder(npy_chunks, 4)
training_level, test_level = five_folder(npy_level_chunks, 4)

train_loader = torch.utils.data.DataLoader(MyDataset(training_set, training_level), batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
train(train_loader, test_set, test_level)


NOW the training STARTS:
-----THE [1/100] epoch start-----


  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


[2 3 1 1 3 3 1 0 3 3 0 1 4 1 3 2 1 1 3 1]
[2 2 2 0 3 1 1 4 2 2 4 1 2 0 4 3 3 3 0 4]
[BATCH 1/980] Loss_D: 3.1254 Loss_G: 6.9276 acc: 20%
[1 0 0 4 4 0 4 0 0 0 1 1 1 4 1 1 1 1 1 1]
[4 2 2 1 2 2 4 1 0 1 4 3 4 3 2 4 0 1 2 1]
[BATCH 2/980] Loss_D: 1.9911 Loss_G: 4.1289 acc: 20%
[4 4 4 2 4 4 4 4 4 4 4 1 4 1 1 4 1 1 1 1]
[1 2 1 4 2 2 1 1 0 2 2 4 4 1 1 4 4 2 3 1]
[BATCH 3/980] Loss_D: 1.9247 Loss_G: 3.5481 acc: 25%
[4 4 4 1 4 4 4 4 4 4 1 1 4 4 1 1 1 4 1 1]
[2 2 2 2 3 1 0 1 3 3 1 2 2 3 1 1 3 1 1 3]
[BATCH 4/980] Loss_D: 1.8817 Loss_G: 3.3577 acc: 20%
[4 4 4 4 4 4 4 4 4 4 4 4 1 4 4 4 4 4 1 4]
[3 0 2 0 3 2 4 1 2 2 0 4 2 2 0 1 0 2 2 0]
[BATCH 5/980] Loss_D: 1.8455 Loss_G: 3.2162 acc: 10%
[4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4]
[1 1 2 2 2 2 1 1 2 0 0 4 1 4 2 4 3 2 3 3]
[BATCH 6/980] Loss_D: 1.8607 Loss_G: 3.1473 acc: 15%
[4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4]
[1 1 2 3 1 1 0 0 2 3 4 3 1 4 3 1 4 0 2 3]
[BATCH 7/980] Loss_D: 1.8067 Loss_G: 3.1213 acc: 15%
[4 4 4 4 4 4 4 4 4 4 4 4 1 1 4 4 4 4 4 4]

[BATCH 60/980] Loss_D: 1.4490 Loss_G: 2.1568 acc: 15%
[4 4 4 4 4 4 4 4 4 4 1 4 4 4 4 3 3 4 1 4]
[3 3 3 2 3 2 2 1 1 1 0 0 3 0 1 1 3 2 1 0]
[BATCH 61/980] Loss_D: 1.3575 Loss_G: 2.2123 acc: 10%
[4 4 4 4 4 4 3 4 4 4 4 4 4 4 1 3 4 4 3 4]
[4 2 2 4 0 2 3 1 4 2 4 3 4 0 1 1 0 1 1 4]
[BATCH 62/980] Loss_D: 1.4545 Loss_G: 2.1239 acc: 40%
[4 4 4 4 4 4 1 4 4 4 3 4 3 4 4 1 4 3 3 4]
[1 2 0 2 2 2 4 2 1 2 4 2 3 4 1 3 1 3 4 2]
[BATCH 63/980] Loss_D: 1.4776 Loss_G: 2.2406 acc: 15%
[4 4 4 4 4 4 4 4 4 4 1 0 0 4 4 4 3 4 3 4]
[2 0 0 2 4 1 1 2 3 1 0 3 3 4 0 2 4 4 3 2]
[BATCH 64/980] Loss_D: 1.3452 Loss_G: 2.2061 acc: 20%
[4 4 3 4 4 4 4 4 4 4 1 4 4 4 4 4 0 1 4 4]
[2 1 2 1 0 0 3 1 2 1 2 3 2 2 0 1 0 1 1 3]
[BATCH 65/980] Loss_D: 1.4414 Loss_G: 2.1326 acc: 10%
[4 4 4 4 4 4 4 4 4 4 4 3 3 0 4 4 4 4 4 4]
[1 2 1 0 3 4 2 2 3 0 1 3 4 4 1 2 2 0 1 3]
[BATCH 66/980] Loss_D: 1.4397 Loss_G: 1.9826 acc: 10%
[4 4 4 4 4 4 4 4 4 4 4 1 4 1 4 0 4 3 4 4]
[2 2 3 2 1 0 2 1 3 1 4 3 1 3 3 0 3 3 4 2]
[BATCH 67/980] Loss_D: 1.4274 Loss

[4 4 4 4 4 4 4 4 4 3 4 4 4 4 4 4 4 4 4 0]
[3 2 2 2 1 1 2 2 2 2 3 2 4 0 3 2 0 4 4 3]
[BATCH 120/980] Loss_D: 1.4446 Loss_G: 1.5629 acc: 15%
[4 4 4 4 4 4 4 4 3 4 4 4 4 4 0 4 4 4 4 4]
[3 3 0 4 2 2 4 2 2 4 2 4 3 0 3 2 4 2 2 3]
[BATCH 121/980] Loss_D: 1.6778 Loss_G: 1.5858 acc: 25%
[0 4 4 0 4 4 4 4 4 4 4 4 4 4 0 4 4 4 4 4]
[1 2 3 0 3 1 2 2 2 0 3 1 4 3 4 4 0 4 4 2]
[BATCH 122/980] Loss_D: 1.4254 Loss_G: 1.5801 acc: 25%
[4 4 4 4 0 4 4 4 4 4 4 4 4 4 4 0 0 4 0 4]
[1 1 2 0 2 2 2 0 3 1 2 0 2 2 1 3 0 2 0 3]
[BATCH 123/980] Loss_D: 1.3381 Loss_G: 1.5706 acc: 10%
[4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4]
[3 1 2 0 2 1 3 2 3 2 3 3 1 2 2 3 1 2 1 2]
[BATCH 124/980] Loss_D: 1.4874 Loss_G: 1.5658 acc: 0%
[4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4]
[1 0 3 2 1 2 1 0 4 1 0 4 4 3 0 3 4 0 2 2]
[BATCH 125/980] Loss_D: 1.4020 Loss_G: 1.5539 acc: 20%
[4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 0 4 4]
[2 1 1 1 2 2 1 1 2 1 3 3 0 3 1 1 0 2 1 4]
[BATCH 126/980] Loss_D: 1.3510 Loss_G: 1.5614 acc: 5%
[4 4 4 4 4 4 4 4 4 4 4 4 4 4 

KeyboardInterrupt: 