In [1]:
import h5py

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# models
discrete_encoder_path = 'discrete_encoder.pth'
discrete_decoder_path = 'discrete_decoder.pth'
mapping_path = 'mapping.pth'
unet_path = 'con_unet_full.pth'
continuous_model_path = 'continuous_model.pth'

hidden_dim_discrete = 128

# data preparation
data_file = '/home/ankbzpx/datasets/ShapeNet/ShapeNetRenderingh5_v1/03001627/sdf_train_core.h5'
sample_size = 2048
batch_size = 32
split_ratio = 0.9
depth_size = 256
num_of_workers = 12
# training
num_epoch = 40

In [4]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import torch.nn as nn
import torchvision.transforms.functional as TF
from torch.autograd import Variable
from PIL import Image
import sys
import time
import matplotlib.pyplot as plt
from torchsummary import summary

In [5]:
# reproducible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
torch.manual_seed(0)
np.random.seed(0)

In [6]:
device

device(type='cuda')

In [7]:
class ChairSDFDataset(Dataset):
     
    def __init__(self, h5_file):
        
        self.file_path = h5_file
        self.dataset = None
        
        with h5py.File(self.file_path, 'r') as file:
            self.dataset_len = len(file)
            self.keys = list(file.keys())
            
        self.to_tensor = transforms.ToTensor()
     
    def __len__(self):
        return self.dataset_len
 
    def __getitem__(self, idx):
        
        #start_time = time.time()
        
        if self.dataset is None:
            self.dataset = h5py.File(self.file_path, 'r')
         
        group = self.dataset[self.keys[idx]]
        
        depth_img = self.to_tensor(Image.fromarray(np.array(group['depth_img'])))
        
        #print("--- depth preprocessing %s seconds ---" % (time.time() - start_time))
        
        sample_pt_np = np.array(group['sample_pt']).reshape(-1, 3)
        sample_sdf_np = np.array(group['sample_sdf']).reshape(-1, 1)
        
        # check size correctness and fix incorrect data
        if sample_pt_np.shape[0] != 2048:
            sample_pt_np = np.pad(sample_pt_np, ((0, 2048 - sample_pt_np.shape[0]), (0, 0)), 'reflect')
        if sample_sdf_np.shape[0] != 2048:
            sample_sdf_np = np.pad(sample_sdf_np, ((0, 2048 - sample_sdf_np.shape[0]), (0, 0)), 'reflect')
            
        
        sample_pt = torch.from_numpy(sample_pt_np).float()
        sample_sdf = torch.from_numpy(sample_sdf_np).float()
        # scale sdf
        sample_sdf = torch.sign(sample_sdf)*torch.pow(torch.abs(sample_sdf), 0.25)
        
        #print("--- subsampling %s seconds ---" % (time.time() - start_time))
        
        target_vox = torch.from_numpy(np.array(group['target_vox'])).float()
        
        sample = { 'depth_img': depth_img,
                   'sample_pt':sample_pt,
                   'sample_sdf':sample_sdf,
                   'target_vox':target_vox,
                  }
        
        return sample

In [8]:
train_sdf_dataset = ChairSDFDataset(data_file)

train_sdf_dataloader = DataLoader(train_sdf_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)

In [9]:
len(train_sdf_dataloader)

2634

In [10]:
start_epoch = 0
latent_dim = 256

from models import Discrete_encoder, Mapping, Discrete_decoder, Conditional_UNET

####################
# Discrete Encoder #
####################

discrete_encoder = Discrete_encoder(256).to(device)
discrete_encoder.load_state_dict(torch.load(discrete_encoder_path))
discrete_encoder.eval()

for child in discrete_encoder.children():
    for param in child.parameters():
        param.requires_grad = False


####################
# Mapping #
####################

mapping = Mapping().to(device)
mapping.load_state_dict(torch.load(mapping_path))
mapping.eval()

for child in mapping.children():
    for param in child.parameters():
        param.requires_grad = False
        
# ####################
# # Discrete Decoder #
# ####################
        
discrete_decoder = Discrete_decoder(256).to(device)
discrete_decoder.load_state_dict(torch.load(discrete_decoder_path))
discrete_decoder.eval()

for child in discrete_decoder.children():
    for param in child.parameters():
        param.requires_grad = False
        
########
# UNET #
########

# pre-trained model is loaded within the model
unet = Conditional_UNET(unet_path).to(device)
unet.eval()

for child in unet.children():
    for param in child.parameters():
        param.requires_grad = False

In [11]:
class BottleNect1D(nn.Module):
    def __init__(self, input_dim, expand = 5):
        super(BottleNect1D, self).__init__()
        
        self.block = nn.Sequential(
            nn.Linear(input_dim, expand*input_dim),
            nn.BatchNorm1d(expand*input_dim),
            nn.ReLU(),
            nn.Linear(expand*input_dim, input_dim),
            nn.BatchNorm1d(input_dim),
        )
    
    def forward(self, x):
        return x + self.block(x)

class Continuous(nn.Module):
    def __init__(self, pt_dim = 3, con_dim = 32, latent_dim = 256):
        super(Continuous, self).__init__()
        
        self.de_pt =  nn.Sequential(
            nn.Linear(pt_dim + con_dim + latent_dim, latent_dim),
            nn.BatchNorm1d(latent_dim),
            BottleNect1D(latent_dim),
        )
        
        self.de_1 = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.BatchNorm1d(latent_dim),
            BottleNect1D(latent_dim),
        )
        
        self.de_2 = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.BatchNorm1d(latent_dim),
            BottleNect1D(latent_dim),
            nn.Linear(latent_dim, 1),
            nn.Tanh(),
        )
        
    def forward(self, pt, con, z):
        
        fea = self.de_pt(torch.cat((torch.cat((pt, con), 1), z), 1))
        out = self.de_1(fea) + fea
        out = self.de_2(out)
        
        return out

In [12]:
continuous = Continuous().to(device)
continuous.load_state_dict(torch.load(continuous_model_path))

<All keys matched successfully>

In [13]:
summary(continuous, [(3, ), (32, ), (256, )])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 256]          74,752
       BatchNorm1d-2                  [-1, 256]             512
            Linear-3                 [-1, 1280]         328,960
       BatchNorm1d-4                 [-1, 1280]           2,560
              ReLU-5                 [-1, 1280]               0
            Linear-6                  [-1, 256]         327,936
       BatchNorm1d-7                  [-1, 256]             512
      BottleNect1D-8                  [-1, 256]               0
            Linear-9                  [-1, 256]          65,792
      BatchNorm1d-10                  [-1, 256]             512
           Linear-11                 [-1, 1280]         328,960
      BatchNorm1d-12                 [-1, 1280]           2,560
             ReLU-13                 [-1, 1280]               0
           Linear-14                  [

In [14]:
l1loss = nn.L1Loss()

In [15]:
lr = 1e-3

model = continuous

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [16]:
def save_checkpoint(state):
    torch.save(state, 'continuous_model_checkpoint.pth.tar')

In [17]:
checkpoint = torch.load('continuous_model_checkpoint.pth.tar')
start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['continuous_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

In [18]:
# advanced indexing 2x2x2 context from voxel
def getContext(sample_pt_query, vox):
    
    # sample_pt bxmx3
    # vox bxcxdimxdimxdim
    
    channel_size = vox.shape[1]
    batch_size, sample_size, _ = sample_pt_query.shape
    meshgrid_base = torch.Tensor(np.meshgrid(np.arange(0, batch_size), np.arange(0, channel_size), np.arange(0, 2), np.arange(0, 2), np.arange(0, 2))).int()
    context = torch.empty((batch_size, sample_size, channel_size, 2, 2, 2))
    
    for j in range(context.shape[1]):
        context[:, j, :, :, :, :] = vox[
                    meshgrid_base[0].long(),
                    meshgrid_base[1].long(),
                    (meshgrid_base[2] + sample_pt_query[:, j, 0].reshape(1, -1, 1, 1, 1)).long(), 
                    (meshgrid_base[3] + sample_pt_query[:, j, 1].reshape(1, -1, 1, 1, 1)).long(), 
                    (meshgrid_base[4] + sample_pt_query[:, j, 2].reshape(1, -1, 1, 1, 1)).long()
                ].transpose(0, 1)
    
    # b x c x m x 2 x 2 x 2
    return context.transpose(1, 2)

In [19]:
def trilinearInterpolation(context, dx, dy, dz):
    
    v0 = context[:, :, :, 0, 0, 0]*(1-dx)*(1-dy)*(1-dz)
    v1 = context[:, :, :, 1, 0, 0]*dx*(1-dy)*(1-dz)
    v2 = context[:, :, :, 0, 1, 0]*(1-dx)*dy*(1-dz)
    v3 = context[:, :, :, 1, 1, 0]*dx*dy*(1-dz)
    v4 = context[:, :, :, 0, 0, 1]*(1-dx)*(1-dy)*dz
    v5 = context[:, :, :, 1, 0, 1]*dx*(1-dy)*dz
    v6 = context[:, :, :, 0, 1, 1]*(1-dx)*dy*dz
    v7 = context[:, :, :, 1, 1, 1]*dx*dy*dz
    
    # b x c x m 1
    return v0 + v1 + v2 + v3 + v4 + v5 + v6 + v7

In [20]:
#########################################
#test sub module

import time

vox_size = 32
latent_dim = 256
con_dim = 32

batch_len = len(train_sdf_dataloader)

print("Starting Training Loop...")

start_time = time.time()

for epoch in range(start_epoch, num_epoch):
    
    loss_list = []
    loss_batch = []
    
    print("Epoch: ", epoch)
    
    count = 0
    
    for i, data in enumerate(train_sdf_dataloader):
        
        ####################
        # Data preparation #
        ####################
        
        # b x 1 x 256 x 256
        depth_img = data['depth_img'].to(device)
        # b x 128 x 1 x 1
        z = discrete_encoder(depth_img)
        # b x n x 3
        # DO NOT scale by np.sqrt(3)
        sample_pt = data['sample_pt']
        # b x n x 1
        sample_sdf = data['sample_sdf']
        
        # b x 16 x 64 x 64 x 64
        target_vox = data['target_vox'].to(device)
        vox_feature = unet(target_vox, z)
        #vox_feature = unet(torch.sigmoid(discrete_decoder(mapping(z))))
        
        ####################
        # indexing context #
        ####################
        
        # stay with cpu for v-ram efficiency
        sample_pt_normalized = sample_pt + torch.tensor([0.5, 0.5, 0.5])
        # (0, vox_size-1)
        sample_pt_scale = torch.clamp(sample_pt_normalized* (vox_size-1), 0, (vox_size-1)-1e-5)
        # (0, vox_size-2)
        sample_pt_query = torch.clamp((sample_pt_scale).int(), 0, (vox_size-2))
        sample_pt_distance = sample_pt_scale - sample_pt_query
        
        context = getContext(sample_pt_query, vox_feature)
        
        dx = sample_pt_distance[:, :, 0].unsqueeze(1)
        dy = sample_pt_distance[:, :, 1].unsqueeze(1)
        dz = sample_pt_distance[:, :, 2].unsqueeze(1)
        # local feature
        con = trilinearInterpolation(context, dx, dy, dz)
        
        ################################
        # Reshape input & forward pass #
        ################################
        
        sample_pt = sample_pt.transpose(-1, -2).to(device)
        con = con.to(device)
        z = z.squeeze(-1).squeeze(-1).repeat(1, 1, sample_size)
        sample_sdf = sample_sdf.transpose(-1, -2).to(device)
        
        
        sample_pt = sample_pt.transpose(-1, -2).reshape(-1, 3)
        con = con.transpose(-1, -2).reshape(-1, con_dim)
        z = z.transpose(-1, -2).reshape(-1, latent_dim)
        sample_sdf = sample_sdf.transpose(-1, -2).reshape(-1, 1)
        
        
        optimizer.zero_grad()
        
        pred_sdf = model(sample_pt, con, z)
        
        loss_l1 = l1loss(pred_sdf, sample_sdf)
        
        loss = loss_l1
        
        loss.backward()
        
        loss_list.append(loss_l1.item())
        loss_batch.append(loss_l1.item())
        
        optimizer.step()
        
        #scheduler.step()
        
        if count != 0 and count % 10 == 0:
            loss_batch_avg = np.average(loss_batch)
            
            print("Batch: ", count, ", l1 Loss: ", loss_batch_avg, ", Time: %s s" % (time.time() - start_time))
            
            if count % 500 == 0:
                torch.save(model.state_dict(), continuous_model_path)
                
            loss_batch.clear()
            
        count += 1
        
    print("Epoch: ", epoch, ', l1 loss: ', np.average(loss_list))
    
    loss_list.clear()
    
    torch.save(model.state_dict(), continuous_model_path)
    
    save_checkpoint({
        'epoch': epoch + 1,
        'continuous_state_dict': model.state_dict(),
        'optimizer' : optimizer.state_dict(),
    })
    
print("Training finished")

Starting Training Loop...
Epoch:  25
Batch:  10 , l1 Loss:  0.09254280274564569 , Time: 17.20931100845337 s
Batch:  20 , l1 Loss:  0.09253654256463051 , Time: 30.998764753341675 s
Batch:  30 , l1 Loss:  0.09230871349573136 , Time: 45.400079011917114 s
Batch:  40 , l1 Loss:  0.09249885976314545 , Time: 58.660292625427246 s
Batch:  50 , l1 Loss:  0.09258602485060692 , Time: 73.28973507881165 s
Batch:  60 , l1 Loss:  0.09254356026649475 , Time: 87.38306188583374 s
Batch:  70 , l1 Loss:  0.09247088059782982 , Time: 100.73967361450195 s
Batch:  80 , l1 Loss:  0.09226948395371437 , Time: 116.43030500411987 s
Batch:  90 , l1 Loss:  0.09188556298613548 , Time: 130.42106103897095 s
Batch:  100 , l1 Loss:  0.0919592835009098 , Time: 144.8918879032135 s
Batch:  110 , l1 Loss:  0.09223691299557686 , Time: 158.51891708374023 s
Batch:  120 , l1 Loss:  0.09238492324948311 , Time: 172.48464608192444 s
Batch:  130 , l1 Loss:  0.09232599958777428 , Time: 187.02431297302246 s
Batch:  140 , l1 Loss:  0.09

KeyboardInterrupt: 