In [None]:
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
from PIL import Image
import numpy as np
import glob
import random
from models import Encoder, Tnet, _F_, Destructor_domain, Destructor_shape
import torch
import torch.nn as nn

In [None]:
class XDataset(data.Dataset):
    """Pairs ShapeNet renderings with their point clouds and a random real photo.

    Each item is ``(image, point_cloud, real_image)``:

    * ``image`` -- a rendered PNG found under
      ``<image_root>/<image_type>/<specific_type>/rendering/``;
    * ``point_cloud`` -- ``np.load`` of
      ``<point_cloud_root>/<image_type>/<specific_type>/pointcloud_{2048,1024}.npy``;
    * ``real_image`` -- a PNG drawn uniformly at random from the files matching
      ``real_image_glob`` (default ``pix3d/*/*.png``, relative to the CWD).

    Compatible with ``torch.utils.data.DataLoader``.
    """

    def __init__(self, image_root, point_cloud_root, transform=None, use_2048=True,
                 real_image_glob='pix3d/*/*.png'):
        """
        Args:
            image_root: root directory of the rendered images, laid out as
                ``<image_type>/<specific_type>/rendering/*.png``.
            point_cloud_root: root directory of the point clouds, mirroring the
                ``<image_type>/<specific_type>`` layout of ``image_root``.
            transform: optional callable applied to both the rendered and the
                real image.
            use_2048: load ``pointcloud_2048.npy`` if True, else
                ``pointcloud_1024.npy``.
            real_image_glob: glob pattern selecting the pool of real images.

        Raises:
            FileNotFoundError: if ``real_image_glob`` matches no files; every
                item needs a real image, so the dataset would be unusable.
        """
        self.image_root = image_root
        self.point_cloud_root = point_cloud_root
        self.transform = transform
        self.use_2048 = use_2048

        # (image_type, specific_type, rendering_path) for every rendered PNG.
        # Listings are sorted because os.listdir order is arbitrary; sorting
        # makes the index -> sample mapping reproducible across runs.
        self.ids = []
        for image_type in self._subdirs(image_root):
            type_dir = os.path.join(image_root, image_type)
            for specific_type in self._subdirs(type_dir):
                rendering_dir = os.path.join(type_dir, specific_type, 'rendering')
                if not os.path.isdir(rendering_dir):
                    continue
                for file_name in sorted(os.listdir(rendering_dir)):
                    if file_name.endswith('.png'):
                        self.ids.append((image_type, specific_type,
                                         os.path.join(rendering_dir, file_name)))

        # Pool of real images (attribute names `l` / `n` kept for callers).
        self.l = sorted(glob.glob(real_image_glob))
        self.n = len(self.l)
        if self.n == 0:
            raise FileNotFoundError(
                'No real images match {!r}'.format(real_image_glob))

    @staticmethod
    def _subdirs(root):
        """Sorted names of the sub-directories of `root`, skipping `.tgz` entries."""
        return [name for name in sorted(os.listdir(root))
                if not name.endswith('.tgz')
                and os.path.isdir(os.path.join(root, name))]

    def _sample_real_image_path(self):
        """Path of a real image chosen uniformly at random from the pool."""
        # randrange excludes its upper bound, so the index is always in
        # [0, n - 1]; randint(0, n) could return n and overrun the list.
        return self.l[random.randrange(self.n)]

    @staticmethod
    def _load_rgb(path):
        """Open the image at `path` and return an RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``(image, point_cloud, real_image)`` for sample `index`."""
        image_type, specific_type, image_path = self.ids[index]

        image = self._load_rgb(image_path)
        real_image = self._load_rgb(self._sample_real_image_path())
        if self.transform is not None:
            image = self.transform(image)
            real_image = self.transform(real_image)

        pc_name = 'pointcloud_2048.npy' if self.use_2048 else 'pointcloud_1024.npy'
        point_cloud = np.load(os.path.join(self.point_cloud_root, image_type,
                                           specific_type, pc_name))
        return image, point_cloud, real_image

    def __len__(self):
        """Number of rendered images (one item per rendering)."""
        return len(self.ids)

def get_loader(image_root, point_cloud_root, use_2048, transform, batch_size, shuffle, num_workers):
    """Build a DataLoader over an ``XDataset``.

    Args:
        image_root: root directory of the rendered images.
        point_cloud_root: root directory of the point clouds.
        use_2048: forwarded to ``XDataset`` (2048- vs 1024-point clouds).
        transform: image transform forwarded to ``XDataset``.
        batch_size: samples per batch.
        shuffle: whether to reshuffle the dataset every epoch.
        num_workers: number of loader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding batched
        ``(image, point_cloud, real_image)`` tuples.
    """
    dataset = XDataset(image_root, point_cloud_root,
                       transform=transform, use_2048=use_2048)
    loader_options = dict(batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=num_workers)
    return torch.utils.data.DataLoader(dataset=dataset, **loader_options)


In [None]:
# Training configuration ----------------------------------------------------
# Dataset locations, relative to the notebook's working directory.
image_root = "./../../../datasets/cs253-wi20-public/ShapeNetRendering/"
point_cloud_root = "./../../../datasets/cs253-wi20-public/ShapeNet_pointclouds/"

# Loader settings.
batch_size = 4
shuffle = True
num_workers = 8
use_2048 = True  # load pointcloud_2048.npy rather than pointcloud_1024.npy

# Resize the shorter side to img_size, take the centred img_size x img_size
# crop, then convert to a tensor. interpolation=2 is PIL's BILINEAR code.
img_size = 256
transform = transforms.Compose([
    transforms.Resize(img_size, interpolation=2),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
])

data_loader = get_loader(image_root=image_root,
                         point_cloud_root=point_cloud_root,
                         use_2048=use_2048,
                         transform=transform,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         num_workers=num_workers)

In [None]:
# Adversarial training of the image encoder f --------------------------------
# * d_domain learns to score synthetic-image features as 1 and real-image
#   features as 0; f is trained with the labels flipped.
# * d_shape learns to score point-cloud features from the frozen encoder
#   `enc` as 1 and image features as 0; f is trained so that image features
#   are scored as 1.
# NOTE(review): nn.BCELoss expects the discriminators to output
# probabilities in [0, 1] (e.g. a final sigmoid) -- confirm in models.py.
criterion = nn.BCELoss()
num_epochs = 1000
learning_rate = 1e-3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

f = _F_().to(device)
d_domain = Destructor_domain().to(device)
d_shape = Destructor_shape().to(device)
optimizerf = torch.optim.Adam(f.parameters(), lr=learning_rate)
optimizer_domain = torch.optim.Adam(d_domain.parameters(), lr=learning_rate)
optimizer_shape = torch.optim.Adam(d_shape.parameters(), lr=learning_rate)


def _load_full_module(path, map_location):
    """Load a whole module saved with ``torch.save(model, path)``.

    torch >= 2.6 defaults to ``weights_only=True``, which refuses pickled
    modules; torch < 1.13 has no ``weights_only`` argument at all.
    Unpickling can execute arbitrary code -- only load trusted files.
    """
    try:
        return torch.load(path, map_location=map_location, weights_only=False)
    except TypeError:
        return torch.load(path, map_location=map_location)


def _bce_against(output, label):
    """BCE loss of `output` against a constant float `label`.

    The target is built from `output` itself so dtype, shape and device always
    match -- including for a final batch smaller than `batch_size`.
    """
    return criterion(output, torch.full_like(output, label))


# Pretrained point-cloud encoder. No optimizer is built for it, so keep it
# frozen: eval mode, and its features are computed under torch.no_grad().
enc = _load_full_module('encoder', device).to(device)
enc.eval()

for epoch in range(num_epochs):
    loss_f = 0.0
    loss_domain = 0.0
    loss_shape = 0.0
    num_batches = 0

    for i_syn, pc_syn, i_real in data_loader:
        # The models live on `device`, so the batch must always follow them.
        i_syn = i_syn.to(device)
        i_real = i_real.to(device)
        pc_syn = pc_syn.transpose(1, 2).float().to(device)

        enc_syn = f(i_syn)
        enc_real = f(i_real)
        with torch.no_grad():
            enc_pc = enc(pc_syn)

        # Domain discriminator: synthetic -> 1, real -> 0. Features are
        # detached so this loss does not push gradients into f.
        optimizer_domain.zero_grad()
        errD = (_bce_against(d_domain(enc_syn.detach()), 1.0)
                + _bce_against(d_domain(enc_real.detach()), 0.0))
        errD.backward()
        optimizer_domain.step()

        # Shape discriminator: point-cloud features -> 1, image features -> 0.
        optimizer_shape.zero_grad()
        lossD = (_bce_against(d_shape(enc_pc), 1.0)
                 + _bce_against(d_shape(enc_syn.detach()), 0.0)
                 + _bce_against(d_shape(enc_real.detach()), 0.0))
        lossD.backward()
        optimizer_shape.step()

        # Encoder f: fool both (just-updated) discriminators. f is stepped once
        # on the summed loss, because a second backward through a graph whose
        # parameters were already stepped in place raises an autograd error.
        # No point-cloud term is needed: enc_pc does not depend on f.
        optimizerf.zero_grad()
        errG = (_bce_against(d_domain(enc_syn), 0.0)
                + _bce_against(d_domain(enc_real), 1.0))
        lossG = (_bce_against(d_shape(enc_syn), 1.0)
                 + _bce_against(d_shape(enc_real), 1.0))
        (errG + lossG).backward()
        optimizerf.step()

        loss_f += errG.item() + lossG.item()
        loss_domain += errD.item()
        loss_shape += lossD.item()
        num_batches += 1

    # Report mean loss per batch; max() guards against an empty loader.
    num_batches = max(num_batches, 1)
    print('epoch [{}/{}], F_loss:{:.4f}, Domain_loss:{:.4f}, Shape_loss:{:.4f}'.format(
        epoch + 1, num_epochs, loss_f / num_batches, loss_domain / num_batches,
        loss_shape / num_batches))