In [1]:
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
from PIL import Image
import numpy as np
import glob
import random
from models import Encoder, Tnet, _F_, Destructor_domain, Destructor_shape
import torch.nn as nn
import warnings
warnings.filterwarnings('ignore')

In [2]:
class XDataset(data.Dataset):
    """Dataset yielding ``(image, point_cloud, real_image)`` triples.

    The dataset is indexed by point cloud: item ``index`` loads the
    ``index``-th ``*2048.npy`` file found under ``point_cloud_root``.  The two
    images returned with it are drawn uniformly at random on every access --
    one rendering from ``image_root`` and one picture matching
    ``real_image_pattern`` -- so they are not paired with the point cloud or
    with each other.

    Args:
        image_root: Directory searched for ``*/*/rendering/*.png`` renderings.
        point_cloud_root: Directory searched for ``*/*/*2048.npy`` point clouds.
        transform: Optional callable applied to both images.
        use_2048: Stored on the instance.  The point-cloud glob always matches
            ``*2048.npy`` regardless of this flag.
        real_image_pattern: Glob pattern for the real photographs.  Defaults
            to ``'pix3d/*/*.png'``, resolved against the working directory.
    """

    def __init__(self, image_root, point_cloud_root, transform=None, use_2048 = True,
                 real_image_pattern='pix3d/*/*.png'):
        self.image_root = image_root
        self.point_cloud_root = point_cloud_root
        self.transform = transform
        self.use_2048 = use_2048
        # os.path.join makes the roots work with or without a trailing slash,
        # and sorting gives a stable index -> file mapping across runs
        # (glob returns entries in arbitrary filesystem order).
        self.pc_list = sorted(glob.glob(os.path.join(point_cloud_root, "*", "*", "*2048.npy")))
        self.syn_list = sorted(glob.glob(os.path.join(image_root, "*", "*", "rendering", "*.png")))
        self.N = len(self.syn_list)
        self.l = sorted(glob.glob(real_image_pattern))
        self.n = len(self.l)

    @staticmethod
    def _pick(paths, description):
        """Return a uniformly random entry of ``paths``; ValueError if it is empty."""
        if not paths:
            raise ValueError("XDataset found no {} to sample from".format(description))
        return random.choice(paths)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB image and release the file handle."""
        with Image.open(path) as handle:
            # convert() returns a new, fully loaded image, so it stays usable
            # after the underlying file is closed.
            return handle.convert('RGB')

    def __getitem__(self, index):
        """Return ``(image, point_cloud, real_image)`` for point cloud ``index``.

        Raises:
            ValueError: If no synthetic renderings or no real images were found.
            IndexError: If ``index`` is outside the list of point clouds.
        """
        image = self._load_rgb(self._pick(self.syn_list, "synthetic renderings"))
        real_image = self._load_rgb(self._pick(self.l, "real images"))

        if self.transform is not None:
            image = self.transform(image)
            real_image = self.transform(real_image)

        point_cloud = np.load(self.pc_list[index])

        return image, point_cloud, real_image

    def __len__(self):
        """Number of point-cloud files found."""
        return len(self.pc_list)

def get_loader(image_root, point_cloud_root, use_2048, transform, batch_size, shuffle, num_workers):
    """Build a ``torch.utils.data.DataLoader`` over the custom X dataset.

    An ``XDataset`` is created from ``image_root``, ``point_cloud_root``,
    ``transform`` and ``use_2048``; ``batch_size``, ``shuffle`` and
    ``num_workers`` are handed straight to the loader.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
    }
    return torch.utils.data.DataLoader(
        dataset=XDataset(image_root, point_cloud_root, transform=transform, use_2048=use_2048),
        **loader_options
    )


In [3]:
# Training Configuration

# Dataset locations.
image_root = "/datasets/cs253-wi20-public/ShapeNetRendering/"
point_cloud_root = "/datasets/cs253-wi20-public/ShapeNet_pointclouds/"

# DataLoader settings.
batch_size = 16
shuffle = True
num_workers = 8
use_2048 = True

# Image preprocessing: resize, centre-crop to an img_size x img_size square,
# then convert to a tensor.  interpolation=2 is PIL's integer code for
# bilinear resampling.
img_size = 256
transform = transforms.Compose([
    transforms.Resize(img_size, interpolation=2),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
])

data_loader = get_loader(
    image_root, point_cloud_root, use_2048, transform, batch_size, shuffle, num_workers
)
def init_weights(m):
    """Xavier-uniform initialise the weight of a Conv1d or Linear module.

    Meant to be passed to ``Module.apply``, which calls it once per
    sub-module.  Only ``nn.Conv1d`` and ``nn.Linear`` weights are touched;
    biases and every other module type (e.g. ``nn.Conv2d``) keep whatever
    initialisation they already have.

    Args:
        m: The module visited by ``apply``.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        # In-place variant; the un-suffixed nn.init.xavier_uniform is deprecated.
        nn.init.xavier_uniform_(m.weight)

In [None]:
# Adversarial training of the image encoder `f` against two discriminators:
#   * d_domain tells synthetic renderings from real photographs,
#   * d_shape tells point-cloud encodings (from the frozen `enc`) from image
#     encodings.
# Per batch: update d_domain, update f against d_domain, update d_shape,
# update f against d_shape.
criterion = nn.BCELoss()
num_epochs = 1000
learning_rate = 1e-5

# One device for models, inputs and targets, so the cell behaves the same on
# GPU machines and also runs where CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

f = _F_().apply(init_weights).to(device)
d_domain = Destructor_domain().apply(init_weights).to(device)
d_shape = Destructor_shape().apply(init_weights).to(device)
optimizerf = torch.optim.Adam(f.parameters(), lr=learning_rate)
optimizer_domain = torch.optim.Adam(d_domain.parameters(), lr=learning_rate)
optimizer_shape = torch.optim.Adam(d_shape.parameters(), lr=learning_rate)

# NOTE: torch.load unpickles arbitrary Python objects -- only load checkpoints
# from a trusted source.  `enc` has no optimizer, so it is never updated here.
enc = torch.load('encoder-Copy1', map_location=device).to(device)


def _targets(value, count):
    """Return ``count`` copies of ``value`` as a float tensor on ``device``.

    BCELoss needs floating-point targets; an integer fill value would make
    torch.full produce an integer tensor on current PyTorch releases.
    """
    return torch.full((count,), float(value), device=device)


for epoch in range(num_epochs):
    loss_f = 0.0
    loss_domain = 0.0
    loss_shape = 0.0
    num_batches = 0

    for (i_syn, pc_syn, i_real) in data_loader:
        b_size = i_syn.shape[0]
        i_syn = i_syn.to(device)
        i_real = i_real.to(device)
        # Swap dims 1 and 2 of the point-cloud batch before encoding.
        # presumably (batch, points, 3) -> (batch, 3, points) -- TODO confirm
        pc_syn = pc_syn.transpose(1,2).to(device)

        # assumes both discriminators emit one probability per sample with
        # shape (batch,) -- TODO confirm against models.py
        ones = _targets(1, b_size)
        zeros = _targets(0, b_size)

        # --- 1. domain discriminator: synthetic = 1, real = 0 ---------------
        optimizer_domain.zero_grad()
        enc_syn = f(i_syn)
        output = d_domain(enc_syn)
        errD_syn = criterion(output, ones)
        enc_real = f(i_real)
        output1 = d_domain(enc_real)
        errD_real = criterion(output1, zeros)
        errD = errD_syn + errD_real
        errD.backward()
        optimizer_domain.step()

        # --- 2. encoder vs. domain discriminator: labels flipped ------------
        optimizerf.zero_grad()  # also drops the grads step 1 left on f
        enc_syn = f(i_syn)
        output = d_domain(enc_syn)
        errG_syn = criterion(output, zeros)
        enc_real = f(i_real)
        output1 = d_domain(enc_real)
        errG_real = criterion(output1, ones)
        errG = 0.01*(errG_syn + errG_real)
        errG.backward()
        optimizerf.step()

        # --- 3. shape discriminator: point cloud = 1, images = 0 ------------
        optimizer_shape.zero_grad()
        enc_syn = f(i_syn)
        enc_real = f(i_real)
        enc_pc = enc(pc_syn.float())
        output = d_shape(enc_pc)
        errD_pc = criterion(output, ones)
        output1 = d_shape(enc_syn)
        errD_syn1 = criterion(output1, zeros)
        output2 = d_shape(enc_real)
        errD_real1 = criterion(output2, zeros)
        lossD = errD_pc + errD_syn1 + errD_real1
        lossD.backward()
        optimizer_shape.step()

        # --- 4. encoder vs. shape discriminator: labels flipped -------------
        optimizerf.zero_grad()  # also drops the grads step 3 left on f
        enc_syn = f(i_syn)
        enc_real = f(i_real)
        enc_pc = enc(pc_syn.float())
        output = d_shape(enc_pc)
        output1 = d_shape(enc_syn)
        output2 = d_shape(enc_real)
        # NOTE(review): errG_pc does not pass through f, so unless enc shares
        # parameters with f it only affects the logged value, not the update.
        errG_pc = criterion(output, zeros)
        errG_syn1 = criterion(output1, ones)
        errG_real1 = criterion(output2, ones)
        lossG = 0.01*(errG_pc + errG_syn1 + errG_real1)
        lossG.backward()
        optimizerf.step()

        # .item() yields plain Python floats, so the running totals hold no
        # tensors.
        loss_f += errG.item() + lossG.item()
        loss_domain += errD.item()
        loss_shape += lossD.item()
        num_batches += 1

    # Report the mean per-batch loss.  Each accumulated value is already a
    # batch mean, so the divisor is the batch count, not the last batch size.
    batches_seen = max(num_batches, 1)
    print('epoch [{}/{}], F_loss:{:.4f}, Domain_loss:{:.4f}, Shape_loss:{:.4f}'.format(epoch + 1, num_epochs, loss_f/batches_seen, loss_domain/batches_seen, loss_shape/batches_seen ))
        
                 