In [1]:
import argparse
import os
from tqdm import tqdm
import numpy as np
import random
# from PIL import Image
import torch
import torch.nn as nn
# import os.path as osp
import torchio as tio
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from torch.utils.data import DataLoader
from torchio.transforms.augmentation.intensity.random_bias_field import \
    RandomBiasField
from torchio.transforms.augmentation.intensity.random_noise import RandomNoise
from torchvision import transforms

import models.models as models
import utils.confusion as confusion
import utils.my_trainer as trainer
import utils.train_result as train_result
from datasets.dataset import load_data
from utils.data_class import BrainDataset

In [2]:
def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = True #この行をFalseにすると再現性はとれるが、速度が落ちる
    torch.backends.cudnn.deterministic = True
    return
fix_seed(0)

In [3]:
CLASS_MAP = {"CN": 0, "AD": 1}
SEED_VALUE = 0

In [4]:
data = load_data(kinds=["ADNI2"], classes=["CN"], unique=False, blacklist=True)

                                                                                                                                                                                                                                   

In [5]:
pids = []
voxels = np.zeros((len(data), 80, 96, 80))
labels = np.zeros(len(data))
for i in tqdm(range(len(data))):
    pids.append(data[i]["pid"])
    voxels[i] = data[i]["voxel"]
    labels[i] = CLASS_MAP[data[i]["label"]]
pids = np.array(pids)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 146/146 [00:00<00:00, 359.95it/s]


In [6]:
import torch
from torch.utils.data import Dataset
import numpy as np
from datasets.dataset import CLASS_MAP
class BrainDataset(Dataset):
    def __init__(self, voxels, labels, transform=None):
        self.voxels = voxels
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.voxels)
    def __getitem__(self, index):
        voxel = self.voxels[index]
        label = self.labels[index]
        if self.transform:
            voxel = self.transform(voxel, self.phase)
        voxel = self._preprocess(voxel)
        return voxel, label
    def _preprocess(self, voxel):
        cut_range = 4
        voxel = np.clip(voxel, 0, cut_range * np.std(voxel))
        voxel = normalize(voxel, np.min(voxel), np.max(voxel))
        voxel = voxel[np.newaxis, ]
        return voxel.astype('f')
    def __call__(self, index):
        return self.__getitem__(index)

def normalize(voxel: np.ndarray, floor: int, ceil: int) -> np.ndarray:
    return (voxel - floor) / (ceil - floor)

In [7]:
gss = GroupShuffleSplit(test_size=0.2, random_state=42)
tid, vid = list(gss.split(voxels, groups=pids))[0]
train_voxels = voxels[tid]
val_voxels = voxels[vid]
train_labels = labels[tid]
val_labels = labels[vid]

train_dataset = BrainDataset(train_voxels, train_labels)
val_dataset = BrainDataset(val_voxels, val_labels)

In [8]:
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

num_workers = 2
batch_size = 16
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

In [9]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.AvgPool3d(2, 2)
        self.conv1 = nn.Conv3d(1, 12, 3, padding=1)
        self.conv2 = nn.Conv3d(12, 12, 3, padding=1)
        self.conv3 = nn.Conv3d(12, 12, 3, padding=1)
        self.conv4 = nn.Conv3d(12, 24, 3, padding=1)
        self.conv5 = nn.Conv3d(24, 24, 3, padding=1)
        self.conv6 = nn.Conv3d(24, 32, 3, padding=1)
        self.conv7 = nn.Conv3d(32, 32, 3, padding=1)
        self.conv8 = nn.Conv3d(32, 32, 3, padding=1)
        self.conv9 = nn.Conv3d(32, 32, 3, padding=1)
        self.conv10 = nn.Conv3d(32, 48, 3, padding=1)
        self.conv11 = nn.Conv3d(48, 48, 3, padding=1)
        self.conv12 = nn.Conv3d(48, 48, 3, padding=1)
        self.conv13 = nn.Conv3d(48, 48, 3, padding=1)
        self.mu = nn.Conv3d(48, 1 , 1, padding=0)
        self.logvar = nn.Conv3d(48, 1 , 1, padding=0)
    
        self.batchnorm3d1 = nn.BatchNorm3d(12)
        self.batchnorm3d2 = nn.BatchNorm3d(12)
        self.batchnorm3d3 = nn.BatchNorm3d(12)
        self.batchnorm3d4 = nn.BatchNorm3d(24)
        self.batchnorm3d5 = nn.BatchNorm3d(24)
        self.batchnorm3d6 = nn.BatchNorm3d(32)
        self.batchnorm3d7 = nn.BatchNorm3d(32)
        self.batchnorm3d8 = nn.BatchNorm3d(32)
        self.batchnorm3d9 = nn.BatchNorm3d(32)
        self.batchnorm3d10 = nn.BatchNorm3d(48)
        self.batchnorm3d11 = nn.BatchNorm3d(48)
        self.batchnorm3d12 = nn.BatchNorm3d(48)
        self.batchnorm3d13 = nn.BatchNorm3d(48)
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.batchnorm3d1(self.conv1(x)))
        x = self.relu(self.batchnorm3d2(self.conv2(x)))
        x = self.pool(x)
        x = self.relu(self.batchnorm3d3(self.conv3(x)))
        x = self.relu(self.batchnorm3d4(self.conv4(x)))
        x = self.pool(x)
        x = self.relu(self.batchnorm3d5(self.conv5(x)))
        x = self.relu(self.batchnorm3d6(self.conv6(x)))
        x = self.pool(x)
        h = self.relu(self.batchnorm3d7(self.conv7(x)))
        x = self.relu(self.batchnorm3d8(self.conv8(h)))
        x = self.batchnorm3d9(self.conv9(x))
        x = self.relu(x+h)
        x = self.relu(self.batchnorm3d10(self.conv10(x)))
        x = self.pool(x)
        h = self.relu(self.batchnorm3d11(self.conv11(x)))
        x = self.relu(self.batchnorm3d12(self.conv12(h)))
        x = self.batchnorm3d13(self.conv13(x))
        x = self.relu(x+h)
        mu = self.mu(x)
        logvar = self.logvar(x)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.upsamp = nn.Upsample(scale_factor=2, mode="nearest")
        self.deconv1 = nn.Conv3d(1, 48, 1, padding=0)
        self.deconv2 = nn.Conv3d(48, 48, 3, padding=1)
        self.deconv3 = nn.Conv3d(48, 48, 3, padding=1)
        self.deconv4 = nn.Conv3d(48, 32, 3, padding=1)
        self.deconv5 = nn.Conv3d(32, 32, 3, padding=1)
        self.deconv6 = nn.Conv3d(32, 32, 3, padding=1)
        self.deconv7 = nn.Conv3d(32, 32, 3, padding=1)
        self.deconv8 = nn.Conv3d(32, 24, 3, padding=1)
        self.deconv9 = nn.Conv3d(24, 24, 3, padding=1)
        self.deconv10 = nn.Conv3d(24, 24, 3, padding=1)
        self.deconv11 = nn.Conv3d(24, 12, 3, padding=1)
        self.deconv12 = nn.Conv3d(12, 12, 3, padding=1)
        self.deconv13 = nn.Conv3d(12, 12, 3, padding=1)
        self.deconv14 = nn.Conv3d(12, 1, 3, padding=1) 
        

        self.batchnorm_d3d1 = nn.BatchNorm3d(48)
        self.batchnorm_d3d2 = nn.BatchNorm3d(48)
        self.batchnorm_d3d3 = nn.BatchNorm3d(48)
        self.batchnorm_d3d4 = nn.BatchNorm3d(32)
        self.batchnorm_d3d5 = nn.BatchNorm3d(32)
        self.batchnorm_d3d6 = nn.BatchNorm3d(32)
        self.batchnorm_d3d7 = nn.BatchNorm3d(32)
        self.batchnorm_d3d8 = nn.BatchNorm3d(24)
        self.batchnorm_d3d9 = nn.BatchNorm3d(24)
        self.batchnorm_d3d10 = nn.BatchNorm3d(24)
        self.batchnorm_d3d11 = nn.BatchNorm3d(12)
        self.batchnorm_d3d12 = nn.BatchNorm3d(12)
        self.batchnorm_d3d13 = nn.BatchNorm3d(12)

        self.relu = nn.ReLU()
        
    def forward(self, x):
        h = self.relu(self.batchnorm_d3d1(self.deconv1(x)))
        x = self.relu(self.batchnorm_d3d2(self.deconv2(h)))
        x = self.batchnorm_d3d3(self.deconv3(x))
        x = self.relu(x+h)
        x = self.relu(self.batchnorm_d3d4(self.deconv4(x)))
        x = self.upsamp(x)
        h = self.relu(self.batchnorm_d3d5(self.deconv5(x)))
        x = self.relu(self.batchnorm_d3d6(self.deconv6(h)))
        x = self.batchnorm_d3d7(self.deconv7(x))
        x = self.relu(x+h)
        x = self.relu(self.batchnorm_d3d8(self.deconv8(x)))
        x = self.upsamp(x)
        x = self.relu(self.batchnorm_d3d9(self.deconv9(x)))
        x = self.upsamp(x)
        x = self.relu(self.batchnorm_d3d10(self.deconv10(x)))
        x = self.relu(self.batchnorm_d3d11(self.deconv11(x)))
        x = self.upsamp(x)
        x = self.relu(self.batchnorm_d3d12(self.deconv12(x)))
        x = self.relu(self.batchnorm_d3d13(self.deconv13(x)))
        x = self.relu(self.deconv14(x))
        return x


In [10]:
import torch.nn.functional as F
class SoftintroVAE(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def get_mu_var(self, x):
        mu, logvar =self.encoder(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.rand_like(std)
        return mu + eps * std

    def decode(self, x):
        x = self.decoder(x)
        return x

    def encode(self, x):
        mu, logvar = self.get_mu_var(x)
        return self.reparameterize(mu, logvar)

    # def loss(self, x_re, x, mu, logvar):
    #     re_err = nn.MSEloss(x_re, x)
    #     kld = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp())
    #     return re_err + kld
    
    def sample(self, z):
        y = self.decode(z)
        return y
    
    def sample_with_noise(self, num_samples=1):
        z=torch.randn(num_samples,150)
        return self.decode(z)
    
    def forward(self, x):
        mu, logvar = self.get_mu_var(x)
        z = self.reparameterize(mu, logvar)
        x_re = self.decode(z)
        return  x_re, z, mu, logvar
    


In [11]:
def calc_kl(logvar, mu, mu_o=0.0, logvar_o=0.0,reduce='sum'):
    if not isinstance(mu_o, torch.Tensor):
        mu_o = torch.tensor(mu_o).to(mu.device)
    if not isinstance(logvar_o, torch.Tensor):
        logvar_o = torch.tensor(logvar_o).to(mu.device)
    kl = -0.5 * (1 + logvar - logvar_o - logvar.exp() / torch.exp(logvar_o) - (mu - mu_o).pow(2) / torch.exp(
        logvar_o)).sum(1)
    if reduce == 'sum':
        kl = torch.sum(kl)
    elif reduce == 'mean':
        kl = torch.mean(kl)
    return kl

def calc_reconstruction_loss(x,rec_x):
    x = x.view(x.size(0),-1)
    rec_x = rec_x.view(rec_x.size(0),-1)
    
    rec_err = F.mse_loss(rec_x,x,reduction = 'none')
    rec_err = rec_err.sum(1)
    rec_err = rec_err.mean()
    return rec_err

In [12]:
import os
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda:5" if torch.cuda.is_available() and True else "cpu")
model = SoftintroVAE()
log_path = "./logs/" + "output" + "_softintrovae/"
# model = torch.nn.DataParallel(model, device_ids=[5, 6])
model.to(device)
optimizer_e = optim.Adam(model.encoder.parameters(), 2e-4)
optimizer_d = optim.Adam(model.decoder.parameters(), 2e-4)

e_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_e, milestones=(350,), gamma=0.1)
d_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_d, milestones=(350,), gamma=0.1)

scale = 1 / (80 * 96 * 80)

In [13]:
start_time = time.time()
num_vae = 0
beta_rec=1.0
beta_kl=1.0
beta_neg=1.0
cur_iter = 0
gamma_r = 1e-8
kls_real = []
kls_fake = []
kls_rec = []
rec_errs = []
epochs = 1

for epoch in tqdm(range(epochs)):
    diff_kls = []
    model.train()
    batch_kls_real = []
    batch_kls_fake = []
    batch_kls_rec = []
    batch_rec_errs = []
    
    for batch,labels in train_dataloader:
        b_size = batch.size(0)

        noise_batch = torch.randn(size=(b_size,1,5,6,5)).to(device)
        real_batch = batch.to(device)

        # =============== Update E ================== 
        
        fake = model.decode(noise_batch)   

        real_mu, real_logvar = model.get_mu_var(real_batch)
        z = model.reparameterize(real_mu, real_logvar)
        rec = model.decode(z)

        loss_rec = calc_reconstruction_loss(real_batch, rec)
        lossE_real_kl = calc_kl(real_logvar, real_mu, reduce="mean")

        rec_rec, z_rec, rec_mu, rec_logvar     = model( rec.detach())
        rec_fake, z_fake, fake_mu, fake_logvar = model(fake.detach())

        fake_kl_e = calc_kl(fake_logvar, fake_mu, reduce="none")
        rec_kl_e = calc_kl(rec_logvar, rec_mu, reduce="none")

        print("fake：")
        print( fake.size() )
        print("rec_fake：")
        print( rec_fake.size() )
        
        
        loss_fake_rec = calc_reconstruction_loss(fake, rec_fake)
        loss_rec_rec = calc_reconstruction_loss(rec, rec_rec)
        
        
        print("loss_fake_rec：")
        print( loss_fake_rec )
        print(" fake_kl_e：" )
        print(fake_kl_e)        
        

        exp_elbo_fake = (-2 * scale * (beta_rec * loss_fake_rec + beta_neg * fake_kl_e)).exp().mean()
        exp_elbo_rec = (-2 * scale * (beta_rec * loss_rec_rec + beta_neg * rec_kl_e)).exp().mean()

        lossE = scale * (beta_rec * loss_rec + beta_kl * lossE_real_kl) + 0.25 * (exp_elbo_fake + exp_elbo_rec)

        optimizer_e.zero_grad()
        lossE.backward()
        optimizer_e.step()
        print("finish updateE")
        # ========= Update D ==================
        loss_rec = calc_reconstruction_loss(real_batch,rec.detach())
        rec_rec, z_rec, rec_mu, rec_logvar = model(rec.detach())
        rec_fake, z_fake, fake_mu, fake_logvar = model(fake.detach())

        rec_rec = model.decode(z_rec)
        rec_fake = model.decode(z_fake)

        
        loss_rec_rec = calc_reconstruction_loss(rec.detach(), rec_rec)
        loss_fake_rec = calc_reconstruction_loss(fake.detach(), rec_fake)

        rec_kl = calc_kl(rec_logvar, rec_mu, reduce="mean")
        fake_kl = calc_kl(fake_logvar, fake_mu, reduce="mean")

        lossD = scale * (loss_rec * beta_rec + (rec_kl + fake_kl) * 0.5 * beta_kl + \
                                     gamma_r * 0.5 * beta_rec * (loss_rec_rec + loss_fake_rec))

        optimizer_d.zero_grad()
        lossD.backward()
        optimizer_d.step()
        print("finish updateD")

#             # statistics for plotting later
#             diff_kls.append(-lossE_real_kl.data.cpu().item() + fake_kl.data.cpu().item())
#             batch_kls_real.append(lossE_real_kl.data.cpu().item())
#             batch_kls_fake.append(fake_kl.cpu().item())
#             batch_kls_rec.append(rec_kl.data.cpu().item())
#             batch_rec_errs.append(loss_rec.data.cpu().item())

#     e_scheduler.step()
#     d_scheduler.step()

#     if epoch > num_vae - 1:
#         kls_real.append(np.mean(batch_kls_real))
#         kls_fake.append(np.mean(batch_kls_fake))
#         kls_rec.append(np.mean(batch_kls_rec))
#         rec_errs.append(np.mean(batch_rec_errs))

print("finish")

  0%|                                                                                                                                                                                                        | 0/1 [00:00<?, ?it/s]

fake：
torch.Size([16, 1, 80, 96, 80])
rec_fake：
torch.Size([16, 1, 80, 96, 80])
loss_fake_rec：
tensor(130085.8594, device='cuda:5', grad_fn=<MeanBackward0>)
 fake_kl_e：
tensor([[[[2.2913e-02, 1.9557e-01, 1.3048e-02, 1.0788e-02, 1.0712e-02],
          [3.1289e-02, 1.0232e-01, 1.2547e-01, 1.7928e-02, 3.7994e-02],
          [1.7323e-02, 1.3314e-01, 7.7875e-02, 1.1606e-02, 9.4795e-02],
          [4.5933e-02, 7.6209e-03, 7.3626e-03, 8.2397e-02, 3.3868e-03],
          [1.7113e-01, 6.0803e-03, 1.9291e-03, 3.4593e-02, 3.5160e-03],
          [2.6783e-03, 8.8660e-03, 9.5907e-03, 2.4829e-03, 1.2732e-01]],

         [[3.2884e-02, 4.5563e-01, 1.8236e-01, 4.9763e-01, 6.0247e-02],
          [1.6455e-01, 1.8482e-01, 1.7157e-01, 1.3582e-01, 2.7969e-01],
          [4.8953e-02, 3.6124e-01, 1.8813e-01, 3.9489e-02, 1.0061e-01],
          [1.7998e-01, 4.6219e-03, 6.6463e-02, 9.6421e-03, 2.3054e-02],
          [2.1452e-01, 1.0886e-01, 8.7899e-03, 3.3141e-02, 1.4633e-01],
          [1.7059e-01, 2.3587e-02, 3.

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:29<00:00, 29.90s/it]

finish updateD
finish





In [14]:
# import matplotlib.pyplot as plt
# train_loss_list, val_loss_list = [], []
# start_time = time.time()
# for epoch in tqdm(range(1)):
#     net.train()
#     train_loss, val_loss = 0.0, 0.0
#     for inputs, labels in train_dataloader:
#         inputs = inputs.to(device)
#         optimizer_e.zero_grad()
#         optimizer_d.zero_grad()
#         x_re,mu,logvar = net.forward(inputs)
#         loss = net.module.loss(x_re, inputs, mu,logvar)
#         loss.backward()
#         optimizer_e.step()
#         optimizer_d.step()
#         train_loss += loss.item()
#     train_loss /= len(train_dataloader.dataset)
#     print(train_loss)
    
#     net.eval()
#     with torch.no_grad():
#         for inputs, labels in val_dataloader:
#             inputs = inputs.to(device)
#             x_re, mu, logvar = net.forward(inputs)
#             #labels = labels.to(device)
#             loss = net.module.loss(x_re, inputs, mu,logvar)

#             val_loss += loss.item()

#     val_loss /= len(val_dataloader)

#     elapsed_time = time.time()
#     print("Epoch [%3d], loss: %f, val_loss: %f, elapsed time %d秒"
#             % (
#                 epoch + 1,
#                 train_loss,
#                 val_loss,
#                 elapsed_time - start_time,
#             )
#     )
#     train_loss_list.append(train_loss)
#     val_loss_list.append(val_loss)

# # print("Finish")
# torch.save(net.state_dict(), log_path + "vae_weight.pth")
# os.makedirs(log_path + "/img", exist_ok=True)
# epoch = 1
# plt.rcParams["font.size"] = 18
# fig1, ax1 = plt.subplots(figsize=(10, 10))
# ax1.plot(range(1, epoch + 1), train_loss, label="train_loss")
# ax1.plot(range(1, epoch + 1), val_loss, label="val_loss")
# ax1.set_title("loss")
# ax1.set_xlabel("epoch")
# ax1.set_ylabel("loss")
# ax1.legend()
# fig1.savefig(log_path + "/img/loss.png")