In [1]:
%pylab inline
import numpy as np
import torch
import os

from torch import nn
from torch import optim
from torch.nn import functional as F
from torch import autograd
from torch.autograd import Variable
import nibabel as nib
from torch.utils.data.dataset import Dataset
from torch.utils.data import dataloader
from skimage.transform import resize
from nilearn import plotting
from ADNI_dataset import *
from BRATS_dataset import *
from ATLAS_dataset import *
from Model_VAEGAN import *
from Model_alphaWGAN import Discriminator as alpha_D
from utils import *
from utils import sinkhorn_pointcloud as sp

Populating the interactive namespace from numpy and matplotlib


# Configuration

In [2]:
BATCH_SIZE=4
max_epoch = 100
gpu = True
workers = 4

reg = 5e-10

gamma = 20
beta = 10

Use_BRATS= True
Use_ATLAS = False

#setting latent variable sizes
latent_dim = 1000

gpu_0 = 2
torch_seed = 4
r_g = torch.manual_seed(torch_seed)

In [3]:
trainset = ADNIdataset(augmentation=True)
train_loader = torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,
                                          shuffle=True,num_workers=workers)
if Use_BRATS:
    #'flair' or 't2' or 't1ce'
    trainset = BRATSdataset(imgtype='flair')
    train_loader = torch.utils.data.DataLoader(trainset,batch_size = BATCH_SIZE, shuffle=True,
                                               num_workers=workers)
if Use_ATLAS:
    trainset = ATLASdataset(augmentation=True)
    train_loader = torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,
                                          shuffle=True,num_workers=workers)

In [4]:
G = Generator(noise = latent_dim)
D = Discriminator()
E = Encoder(gpu_ind=gpu_0)
# E = alpha_D(out_class = latent_dim, is_dis=False)

G.cuda(gpu_0)
D.cuda(gpu_0)
E.cuda(gpu_0)

Encoder(
  (conv1): Conv3d(1, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
  (conv2): Conv3d(64, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
  (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv3d(128, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
  (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv3d(256, 512, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
  (bn4): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (mean): Sequential(
    (0): Linear(in_features=32768, out_features=2048, bias=True)
    (1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=2048, out_features=1000, bias=True)
  )
  (logvar): Sequential(
    (0): Linear(in_features=32768, out_features=2048, bias=True)
    (1): BatchNorm1d(2

In [5]:
g_optimizer = optim.Adam(G.parameters(), lr=0.0001)
d_optimizer = optim.Adam(D.parameters(), lr=0.0001)
e_optimizer = optim.Adam(E.parameters(), lr=0.0001)

# Training

In [6]:
N_EPOCH = 300

criterion_bce = nn.BCELoss()
criterion_l1 = nn.L1Loss()
d_iter = 1

pth = 'vae_checkpoint'
# load the highest savepoints of all models
df = load_loss(path=pth)
iteration = load_checkpoint(G, D, E, None, '_iter', path=pth)

real_y = torch.ones((BATCH_SIZE, 1)).cuda()
fake_y = torch.zeros((BATCH_SIZE, 1)).cuda()

FileNotFoundError: [Errno 2] No such file or directory: './vae_checkpoint/D_iter0.pth'

In [None]:
for epoch in range(init_epoch, N_EPOCH):
    step = 0
    while step < len(train_loader):
        
        real_images = gen_load.__next__().cuda(gpu_0)
        _batch_size = real_images.size(0)
        z_rand = torch.randn((_batch_size, latent_dim), requires_grad=False).cuda(gpu_0)
        ###############################################
        # Train D 
        ###############################################
        for i in range(d_iter):
            d_optimizer.zero_grad()
            mean,logvar,code = E(real_images)
#             code = E(real_images).view(_batch_size,-1).cuda(gpu_0)
            d_real_loss = -D(real_images).mean()
            x_hat = G(code)
            x_rand = G(z_rand)
            
#             gradient_penalty_h = calc_gradient_penalty(D,real_images, x_hat, cuda_ind=gpu_0)
#             gradient_penalty_r = calc_gradient_penalty(D,real_images, x_rand, cuda_ind=gpu_0)
            
            d_recon_loss = D(x_hat).mean()
            d_fake_loss = D(x_rand).mean()

            dis_loss = d_recon_loss+d_real_loss + d_fake_loss# + gradient_penalty_r + gradient_penalty_h
            dis_loss.backward(retain_graph=True)
            d_optimizer.step()
        
        ###############################################
        # Train G
        ###############################################
        g_optimizer.zero_grad()
        
        gen_img_loss = -(d_real_loss + d_recon_loss+ d_fake_loss)
        rec_loss = ((G(code) - real_images)**2).mean()
        
        err_dec = gamma* rec_loss + gen_img_loss
        
        err_dec.backward(retain_graph=True)
        g_optimizer.step()
        ###############################################
        # Train E KL
        ###############################################
        mean,logvar,code = E(real_images)
        prior_loss = 1+logvar-mean.pow(2) - logvar.exp()
        prior_loss = (-0.5*torch.sum(prior_loss))/torch.numel(mean.data)
        err_enc = prior_loss + beta*rec_loss
        ## Train E WL ##############################
#         z_hat = E(real_images).view(_batch_size,-1).cuda(gpu_0)
#         code = z_hat
#         w_dist = sp.sinkhorn_loss(torch.transpose(z_rand, 0, 1), torch.transpose(z_hat, 0, 1), 0.1, 1000, 100, gpu=gpu_0)
#         err_enc = w_dist + beta*rec_loss
        ###############################################
        
        e_optimizer.zero_grad()
        err_enc.backward()
        e_optimizer.step()
        ###############################################
        # Visualization
        ###############################################
        if step % 400 == 0:
            print('[{}/{}]'.format(epoch,N_EPOCH),
                  'D: {:<8.3}'.format(dis_loss.item()), 
                  'En: {:<8.3}'.format(err_enc.item()),
                  'De: {:<8.3}'.format(err_dec.item()) 
                  )
            
            featmask = np.squeeze((0.5*real_images[0]+0.5).data.cpu().numpy())
            featmask = nib.Nifti1Image(featmask,affine = np.eye(4))
            plotting.plot_img(featmask,title="X_Real")
            plotting.show()
            
            featmask = np.squeeze((0.5*G(code)[0]+0.5).data.cpu().numpy())
            featmask = nib.Nifti1Image(featmask,affine = np.eye(4))
            plotting.plot_img(featmask,title="X_DEC")
            plotting.show()
            
            featmask = np.squeeze((0.5*G(z_rand)[0]+0.5).data.cpu().numpy())
            featmask = nib.Nifti1Image(featmask,affine = np.eye(4))
            plotting.plot_img(featmask,title="X_rand")
            plotting.show()

        loss_dict = {
            'd_real_loss': [d_real_loss.item()],
            'd_recon': [d_recon_loss.item()],
            'd_fake': [d_fake_loss.item()],
            'err_enc': [err_enc.item()],
        }
        step += 1
        
    df = add_loss(df, loss_dict)
    write_loss(df, path=pth)
    print(f'G loss: {err_dec.item()}')
    
    viz_pca_tsne([E], trainset, is_tsne=True, latent_size=latent_dim, index=iteration, is_cd=True, gpu_ind=gpu_1)
    viz_pca_tsne([G], trainset, is_tsne=True, latent_size=latent_dim, index=iteration, gpu_ind=gpu_1)
    torch.save(G.state_dict(),f'./{pth}/G_VG_ep_'+str(epoch)+'.pth')
    torch.save(D.state_dict(),f'./{pth}/D_VG_ep_'+str(epoch)+'.pth')
    torch.save(E.state_dict(),f'./{pth}/E_VG_ep_'+str(epoch)+'.pth')
    
    

In [None]:
%debug

In [None]:
torch.unsqueeze(torch.randn((64, 64, 64)), 1).shape