In [1]:
import torch
from torch import optim
from torch import  nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable
from torchvision import transforms, utils

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import h5py as h5
from collections import OrderedDict
from utils import make_hyparam_string, save_new_pickle, read_pickle, Save_Voxels, generateZ,plotVoxelVisdom
import os
import time
import numpy as np
from data_utils import plot_3d_mesh, VertDataset, ResizeTo
import pathlib

from utils import var_or_cuda, plot_losess
from model import _G, _D
from lr_sh import  MultiStepLR
from visdom import Visdom 

In [2]:
import torch


class _G(torch.nn.Module):
    def __init__(self):
        super(_G, self).__init__()
        self.cube_len = 64
        self.z_size = 100
        self.bias = False


        padd = (0, 0, 0)
        self.layer1 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(self.z_size + 2, self.cube_len*8, kernel_size=4, stride=2, bias=self.bias, padding=padd),
            torch.nn.BatchNorm3d(self.cube_len*8),
            torch.nn.ReLU(inplace = True)
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(self.cube_len*8, self.cube_len*4, kernel_size=4, stride=2, bias=self.bias, padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(self.cube_len*4),
            torch.nn.ReLU(inplace = True)
            #torch.nn.LeakyReLU(self.self.leak_value)
        )
                
        self.layer3 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(self.cube_len*4, self.cube_len*2, kernel_size=4, stride=2, bias=self.bias, padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(self.cube_len*2),
            torch.nn.ReLU()
        )
        self.layer4 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(self.cube_len*2, self.cube_len, kernel_size=4, stride=2, bias=self.bias, padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(self.cube_len),
            torch.nn.ReLU()
        )
        self.layer5 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(self.cube_len*1, 1, kernel_size=4, stride=2, bias=self.bias, padding=(1, 1, 1)),
            #torch.nn.BatchNorm3d(1), #lets try opening this later 
            #torch.nn.Linear(-1,1),
            torch.nn.ReLU(),
            #torch.nn.LeakyReLU(self.self.leak_value),
            torch.nn.Tanh() ###
            #torch.nn.Sigmoid()
        )
      

    def forward(self, x, one_hot_labels):
        # shape (batch size,2, 1,1,1)
        one_hot_labels = one_hot_labels.unsqueeze(2).unsqueeze(3).unsqueeze(4)

        out = x.view(-1, self.z_size, 1, 1, 1)
        #print("start",out.size()) # torch.Size([64, 100, 1, 1, 1])

        # concatenate x (z noise vector) with the one hot labels
        out = torch.cat([out, one_hot_labels], dim=1)
        print("Concatination of noise vector and one hot labels", out.size()) # torch.size([64,102,1,1,1])

        out = self.layer1(out)
        #print("G:L1",out.size())  #torch.Size([64, 512, 4, 4, 4])


        out = self.layer2(out)
        #print("G:L2",out.size()) # torch.Size([64, 256, 8, 8, 8])
        out = self.layer3(out)
        #print("G:L3",out.size()) # torch.Size([64, 128, 16, 16, 16])
        out = self.layer4(out)
        #print("G:L4",out.size()) #  torch.Size([64, 64, 32, 32, 32])
        out = self.layer5(out)
        #print("G:L5",out.size()) #  torch.Size([64, 1, 64, 64, 64])
        out = out.view(-1, self.cube_len*self.cube_len*self.cube_len)
        #print("G:Final",out.size()) # torch.Size([64, 32768])

        return out


class _D(torch.nn.Module):
    def __init__(self):
        super(_D, self).__init__()
        self.cube_len = 64
        self.bias = False
        self.leak_value = 0.2

        padd = (0,0,0)
        if self.cube_len == 32:
            padd = (1,1,1)
        # self.process_labels = torch.nn.Sequential(
        #     torch.nn.Conv3d(2, self.cube_len, kernel_size=4, stride=2, bias=self.bias, padding=(1, 1, 1) ),
        #     torch.nn.LeakyReLU(0.2)
        # )
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv3d(1+2, self.cube_len, kernel_size=4, stride=2, bias=self.bias, padding=(1, 1, 1)),
            torch.nn.LayerNorm([64, 32, 32, 32]),
            torch.nn.LeakyReLU(self.leak_value)
        )

        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv3d(self.cube_len, self.cube_len*2, kernel_size=4, stride=2, bias=self.bias, padding=(1, 1, 1)),
	    torch.nn.LayerNorm([128, 16, 16, 16]),
            torch.nn.LeakyReLU(self.leak_value)
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv3d(self.cube_len*2, self.cube_len*4, kernel_size=4, stride=2, bias=self.bias, padding=(1, 1, 1)),
            torch.nn.LayerNorm([256, 8, 8, 8]),
            torch.nn.LeakyReLU(self.leak_value)
        )
        self.layer4 = torch.nn.Sequential(
            torch.nn.Conv3d(self.cube_len*4, self.cube_len*8, kernel_size=4, stride=2, bias=self.bias, padding=(1, 1, 1)),
            torch.nn.LayerNorm([512, 4, 4, 4]),
            torch.nn.LeakyReLU(self.leak_value)
        )
        self.layer5 = torch.nn.Sequential(
            torch.nn.Conv3d(self.cube_len*8, 1, kernel_size=4, stride=2, bias=self.bias, padding=(0,0,0)),
            #torch.nn.LeakyReLU(self.leak_value)
            #torch.nn.Sigmoid(),
        )

    def forward(self, x, one_hot_labels):
        # shape (batch, 2, 1,1,1)
        one_hot_labels = one_hot_labels.unsqueeze(2).unsqueeze(3).unsqueeze(4)
        one_hot_labels = one_hot_labels.repeat(1,1,64,64,64)
        #print("One hot labels in D ",one_hot_labels.shape)

        out = x.view(-1, 1, self.cube_len, self.cube_len, self.cube_len)
        #print("start D",out.size()) #  torch.Size([64, 1, 64, 64, 64])

        # concatenate processed labels and x
        out = torch.cat([out, one_hot_labels], dim=1)

        out = self.layer1(out)
        #print("D:L1",out.size()) #  torch.Size([64, 64, 32, 32, 32])



        out = self.layer2(out)
        #print("D:L2",out.size()) #  torch.Size([64, 128, 16, 16, 16])
        out = self.layer3(out)
        #print("D:L3",out.size()) #  torch.Size([64, 256, 8, 8, 8])
        out = self.layer4(out)
        #print("D:L4",out.size()) #  torch.Size([64, 512, 4, 4, 4])
        out = self.layer5(out)
        #print("D:L5",out.size()) #  torch.Size([64, 1, 1, 1, 1])
        out = out.view(-1,1) 
        #print("final",out.size()) # torch.Size([64, 1])
        return out

In [3]:
# Build the models.
D = _D()
G = _G()

# Move the models to the GPU *before* creating the optimizers, as the
# torch.optim documentation recommends, so the solvers are built on the
# parameters that are actually used afterwards.
if torch.cuda.is_available():
    print("using cuda")
    D.cuda()
    G.cuda()

# Optimizer hyper-parameters: separate learning rates for D and G, shared Adam betas.
d_lr = 1e-2
g_lr = 0.00025
beta = (0.9, 0.99)

# Create the solvers.
D_solver = optim.Adam(D.parameters(), lr=d_lr, betas=beta)
G_solver = optim.Adam(G.parameters(), lr=g_lr, betas=beta)

pickle_path = "./pickled_model"
# Load checkpoint if available.
# read_pickle comes from the project's utils module; judging by the recorded
# cell output it restores G, D and both optimizers -- confirm against utils.
read_pickle(pickle_path, G, G_solver, D, D_solver)

using cuda
905 ./pickled_model
Done loading G 905
Done loading G_Opti 905
Done loading D 905
Done loading D_Optim 905


In [14]:
# Draw one latent vector and generate a single voxel grid for a chosen class.
z1 = generateZ(1, z_size=100)

# label 1 ([0,1]) lumbar, 0 ([1,0]) thoracic
# Create the label on whatever device the generator lives on instead of
# calling .cuda() unconditionally, so the cell also runs on CPU-only machines.
device = next(G.parameters()).device
label = torch.tensor([[1.0, 0.0]], device=device)  # shape (1, 2)
print(label.shape)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    gout1 = G(z1, label)

# G returns flattened cubes; restore the (n, 64, 64, 64) layout before saving.
samples = gout1.cpu().numpy().reshape(-1, 64, 64, 64)
Save_Voxels(samples, './ls', 1)

torch.Size([1, 2])
Concatination of noise vector and one hot labels torch.Size([1, 102, 1, 1, 1])


# Visualization

In [15]:
from visdom import Visdom
import pickle
import skimage.measure as sk
from skimage.measure import marching_cubes_lewiner
import h5py as h5
import numpy as np 
from data_utils import plot_3d_mesh
import matplotlib.pyplot as plt



In [16]:
def getVFByMarchingCubes(voxels, threshold=0.5):
    """Return the (vertices, faces) of the iso-surface extracted from a voxel grid.

    Runs ``marching_cubes_lewiner`` on ``voxels`` at iso-level ``threshold``
    and discards the last two values it returns.
    """
    vertices, faces, _unused_a, _unused_b = marching_cubes_lewiner(voxels, level=threshold)
    return vertices, faces


def plotVoxelVisdom(voxels, visdom, title, threshold=0.5):
    """Render a voxel grid as a semi-transparent mesh in a Visdom window.

    Args:
        voxels: voxel grid handed to ``getVFByMarchingCubes``.
        visdom: Visdom client whose ``mesh`` method is used for plotting.
        title: title shown on the Visdom pane.
        threshold: iso-level used for surface extraction. Defaults to 0.5,
            the value previously applied implicitly, so existing calls behave
            the same.
    """
    vertices, faces = getVFByMarchingCubes(voxels, threshold=threshold)
    visdom.mesh(X=vertices, Y=faces, opts=dict(opacity=0.5, title=title))


In [17]:
# Connection settings for the local Visdom server.
DEFAULT_PORT = 8097
DEFAULT_HOSTNAME = "http://localhost"

# Pass the settings by keyword: the second positional parameter of Visdom()
# is `endpoint`, not `port`, so a positional port would land in the wrong slot.
viz = Visdom(server=DEFAULT_HOSTNAME, port=DEFAULT_PORT, ipv6=False)

In [13]:
# Load one saved voxel sample from ./ls and show it in Visdom.
number = '001'
filename = "./ls/" + str(number) + ".pkl"

with open(filename, "rb") as pkl_file:
    # Drop singleton dimensions from the unpickled array.
    voxels = pickle.load(pkl_file).squeeze()

# The file is no longer needed once the array is in memory.
plotVoxelVisdom(voxels, viz, "0-Thoracic")
    