# In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

class XDataset(data.Dataset):
    """Pairs each rendered image with the point cloud of the same model.

    Directory layout (as built by the paths below):
        image_root/<image_type>/<specific_type>/rendering/*.png
        point_cloud_root/<image_type>/<specific_type>/pointcloud_{2048,1024}.npy

    Every rendering of one ``<image_type>/<specific_type>`` model shares the
    same point cloud file. Compatible with torch.utils.data.DataLoader.
    """

    def __init__(self, image_root, point_cloud_root, transform=None, use_2048=True):
        """
        Args:
            image_root: root directory of the rendered images.
            point_cloud_root: root directory of the ``.npy`` point clouds.
            transform: optional callable applied to each RGB PIL image.
            use_2048: load ``pointcloud_2048.npy`` when True, otherwise
                ``pointcloud_1024.npy``.
        """
        self.image_root = image_root
        self.point_cloud_root = point_cloud_root
        self.transform = transform
        self.use_2048 = use_2048

        # Each entry is (image_type, specific_type, path to a .png rendering).
        # Listings are sorted so that sample indices are reproducible across
        # machines (os.listdir order is arbitrary).
        self.ids = []
        for image_type in self._subdirs(image_root):
            type_dir = os.path.join(image_root, image_type)
            for specific_type in self._subdirs(type_dir):
                render_dir = os.path.join(type_dir, specific_type, 'rendering')
                for file_name in sorted(os.listdir(render_dir)):
                    if file_name.endswith('.png'):
                        self.ids.append(
                            (image_type, specific_type, os.path.join(render_dir, file_name))
                        )

    @staticmethod
    def _subdirs(root):
        """Return sorted names of sub-directories of `root`, skipping `.tgz` archives.

        Plain files (archives, READMEs, ...) are ignored instead of crashing
        the later ``os.listdir`` call on them.
        """
        return sorted(
            name for name in os.listdir(root)
            if not name.endswith('.tgz') and os.path.isdir(os.path.join(root, name))
        )

    def __getitem__(self, index):
        """Return ``(image, point_cloud)`` for sample `index`.

        The image is loaded as RGB, passed through ``self.transform`` if set,
        and converted with ``np.asarray``. The point cloud is the array stored
        in the model's ``.npy`` file.
        """
        image_type, specific_type, image_path = self.ids[index]

        # Scope the file handle; convert() returns an independent image.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        image = np.asarray(image)

        n_points = 2048 if self.use_2048 else 1024
        point_cloud_path = os.path.join(
            self.point_cloud_root, image_type, specific_type, f'pointcloud_{n_points}.npy'
        )
        point_cloud = np.load(point_cloud_path)

        return image, point_cloud

    def __len__(self):
        """Number of rendered images found under ``image_root``."""
        return len(self.ids)

def get_loader(image_root, point_cloud_root, use_2048, transform, batch_size, shuffle, num_workers):
    """Create a DataLoader that batches samples from an XDataset.

    Args:
        image_root: directory handed to XDataset as the image root.
        point_cloud_root: directory handed to XDataset as the point-cloud root.
        use_2048: forwarded to XDataset to choose the 2048- or 1024-point file.
        transform: optional per-image transform forwarded to XDataset.
        batch_size, shuffle, num_workers: passed unchanged to DataLoader.

    Returns:
        A torch.utils.data.DataLoader over the dataset.
    """
    dataset = XDataset(
        image_root,
        point_cloud_root,
        transform=transform,
        use_2048=use_2048,
    )
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
    }
    return torch.utils.data.DataLoader(dataset=dataset, **loader_options)

# In [None]:
ndf = 4
class _F_(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 256 x 256
            nn.Conv2d(3, int(ndf/2) , 4, stride=2, padding=1, bias=False), 
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf/2) x 128 x 128
            nn.Conv2d(int(ndf/2), ndf, 4, stride=2, padding=1, bias=False), 
            nn.BatchNorm2d(ndf),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 64 x 64
            nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 32 x 32
            nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 16 x 16 
            nn.Conv2d(ndf * 4, ndf * 8, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 8 x 8
            nn.Conv2d(ndf * 8, ndf * 16, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*16) x 4 x 4
            nn.Conv2d(ndf * 16, 1, 4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
            # state size. 1
        )
    def forward(self, input):
        return self.main(input)

    
class Destructor_domain(nn.Module):
    """MLP classifier producing one probability per sample.

    Each input sample is flattened to a vector. It then passes through three
    Linear -> BatchNorm1d -> ReLU stages (in_features -> 512 -> 256 -> 128)
    and a final Linear -> sigmoid. In the training cell it is applied to the
    encodings of synthetic and of real images.
    """

    def __init__(self, in_features=1024):
        """
        Args:
            in_features: length of each flattened input sample (default 1024).
        """
        super().__init__()
        self.fc1 = nn.Linear(in_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, input):
        """Return a (batch, 1) tensor of probabilities in [0, 1]."""
        # reshape (not view) so non-contiguous inputs are accepted too.
        input = input.reshape(input.size(0), -1)
        out = torch.relu(self.bn1(self.fc1(input)))
        out = torch.relu(self.bn2(self.fc2(out)))
        out = torch.relu(self.bn3(self.fc3(out)))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc4(out))
    
    
class Destructor_shape(nn.Module):
    """MLP classifier producing one probability per sample.

    Each input sample is flattened to a vector. It then passes through three
    Linear -> BatchNorm1d -> ReLU stages (in_features -> 512 -> 256 -> 128)
    and a final Linear -> sigmoid. In the training cell it is applied to
    point-cloud encodings from ``enc`` and to image encodings from ``f``.
    """

    def __init__(self, in_features=1024):
        """
        Args:
            in_features: length of each flattened input sample (default 1024).
        """
        super().__init__()
        self.fc1 = nn.Linear(in_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, input):
        """Return a (batch, 1) tensor of probabilities in [0, 1]."""
        # reshape (not view) so non-contiguous inputs are accepted too.
        input = input.reshape(input.size(0), -1)
        out = torch.relu(self.bn1(self.fc1(input)))
        out = torch.relu(self.bn2(self.fc2(out)))
        out = torch.relu(self.bn3(self.fc3(out)))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc4(out))

# In [None]:
# ---------------------------------------------------------------------------
# Adversarial training of the image encoder `f`.
#
# Two discriminators are trained alongside `f`:
#   * d_domain separates synthetic-image encodings (label 0) from
#     real-image encodings (label 1);
#   * d_shape separates point-cloud encodings from the frozen `enc`
#     (label 1) from image encodings produced by `f` (label 0).
# `f` is then updated with flipped labels so its encodings fool both.
#
# NOTE(review): this cell expects a `dataloader` yielding
# (synthetic image, real image, synthetic point cloud) batches. It is not
# created in this notebook, and get_loader()/XDataset yield 2-tuples.
# NOTE(review): d_domain/d_shape expect 1024 features per sample, but _F_
# currently outputs one value per sample. Confirm the encoder's output size.
# ---------------------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

REAL_LABEL = 1.0  # real images; also the "point cloud" label for d_shape
SYN_LABEL = 0.0   # synthetic images; also the "image" label for d_shape

criterion = nn.BCELoss()
num_epochs = 1000
learning_rate = 1e-3
f = _F_().to(device)
d_domain = Destructor_domain().to(device)
d_shape = Destructor_shape().to(device)
optimizerf = torch.optim.Adam(f.parameters(), lr=learning_rate)
optimizer_domain = torch.optim.Adam(d_domain.parameters(), lr=learning_rate)
optimizer_shape = torch.optim.Adam(d_shape.parameters(), lr=learning_rate)

# Pretrained point-cloud encoder. It is only a fixed target (no optimizer
# updates it), so freeze it and run it in eval mode.
# NOTE(review): torch.load unpickles arbitrary Python objects. Only load
# trusted checkpoints. Recent PyTorch releases may also need
# weights_only=False to load a whole pickled module.
enc = torch.load('encoder', map_location=device)
enc.eval()
for param in enc.parameters():
    param.requires_grad_(False)


def _bce(output, label):
    """BCE between `output` and a float target of the same shape filled with `label`."""
    # full_like matches output's shape (batch, 1) and float dtype; the original
    # torch.full((b_size,), 1) produced an int64 (batch,) target BCELoss rejects.
    return criterion(output, torch.full_like(output, label))


for epoch in range(num_epochs):
    running_d_loss = 0.0
    running_g_loss = 0.0
    num_batches = 0

    for i_syn, i_real, pc_syn in dataloader:
        i_syn = i_syn.to(device)
        i_real = i_real.to(device)
        pc_syn = pc_syn.to(device)

        enc_syn = f(i_syn)
        enc_real = f(i_real)
        with torch.no_grad():
            enc_pc = enc(pc_syn)

        # --- Discriminator steps: detach f's encodings so only D learns. ---
        optimizer_domain.zero_grad()
        errD = (_bce(d_domain(enc_syn.detach()), SYN_LABEL)
                + _bce(d_domain(enc_real.detach()), REAL_LABEL))
        errD.backward()
        optimizer_domain.step()

        optimizer_shape.zero_grad()
        lossD = (_bce(d_shape(enc_pc), REAL_LABEL)
                 + _bce(d_shape(enc_syn.detach()), SYN_LABEL)
                 + _bce(d_shape(enc_real.detach()), SYN_LABEL))
        lossD.backward()
        optimizer_shape.step()

        # --- Encoder step: flipped labels against the updated discriminators. ---
        # Fresh discriminator forwards avoid back-propagating through a graph
        # that a previous backward() already freed. The point-cloud term is
        # omitted because `enc` is frozen and cannot send gradient to f.
        optimizerf.zero_grad()
        errG = (_bce(d_domain(enc_syn), REAL_LABEL)
                + _bce(d_domain(enc_real), SYN_LABEL))
        lossG = (_bce(d_shape(enc_syn), REAL_LABEL)
                 + _bce(d_shape(enc_real), REAL_LABEL))
        (errG + lossG).backward()
        optimizerf.step()

        running_d_loss += (errD + lossD).item()
        running_g_loss += (errG + lossG).item()
        num_batches += 1

    if num_batches:
        print(f'epoch {epoch}: D loss {running_d_loss / num_batches:.4f}, '
              f'G loss {running_g_loss / num_batches:.4f}')
        
        
        
         