In [2]:
import torch
import torch.nn as nn
import json
import numpy as np
import pandas as pd
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import trimesh
import scipy.io
from torchvision.datasets import ImageFolder
import sys
sys.path.append("../models")
from gan import Discriminator, Generator

In [2]:
mesh = trimesh.load('../../model/bed/IKEA_BEDDINGE/model.obj')
mesh_v = list(mesh.geometry.values())[0]
vertices = np.array(mesh_v.vertices)
faces = np.array(mesh_v.faces)
print(vertices)
print(faces)

[[-0.43324038 -0.07641411  0.10654314]
 [-0.43324038 -0.07641411  0.10654314]
 [-0.43324038 -0.07641411  0.10654314]
 [ 0.43324038 -0.07641411 -0.22230652]
 [ 0.43324038 -0.07641411  0.10654314]
 [ 0.43324038 -0.07641411  0.10654314]
 [ 0.43324038 -0.07641411  0.10654314]
 [-0.43324038 -0.07641411 -0.22230652]
 [-0.43324038 -0.07641411 -0.22230652]
 [-0.43324038 -0.07641411 -0.22230652]
 [-0.43324038 -0.0459173  -0.22230652]
 [-0.43324038 -0.0459173  -0.22230652]
 [-0.43324038 -0.0459173  -0.22230652]
 [-0.43324038 -0.04256963 -0.22186802]
 [-0.43324038 -0.04256963 -0.22186802]
 [-0.43324038 -0.04256963 -0.22186802]
 [-0.43324038 -0.039453   -0.2205761 ]
 [-0.43324038 -0.039453   -0.2205761 ]
 [-0.43324038 -0.039453   -0.2205761 ]
 [-0.43324038 -0.03677486 -0.21852035]
 [-0.43324038 -0.03677486 -0.21852035]
 [-0.43324038 -0.03677486 -0.21852035]
 [-0.43324038 -0.03471911 -0.21584221]
 [-0.43324038 -0.03471911 -0.21584221]
 [-0.43324038 -0.03471911 -0.21584221]
 [-0.43324038 -0.03342719

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
json_dir = "../../pix3d.json"
with open(json_dir, 'r') as f:
    data = json.load(f)
df = pd.DataFrame(data)

In [5]:
class pix3d_dataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.transform = transform
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        img_path = "../../" + self.dataframe.iloc[idx]['img']
        mask_path = "../../" + self.dataframe.iloc[idx]['mask']
        voxel_path = "../../" + self.dataframe.iloc[idx]['voxel']

        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')
        voxel = scipy.io.loadmat(voxel_path)['voxel']

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        voxel = torch.tensor(voxel, dtype=torch.float32).unsqueeze(0)
        sample = {
            'image': image,
            'mask': mask,
            'voxel': voxel
        }
        return sample


In [6]:
latent_dim = 100
hidden_dim = 64
lr = 0.0002
batch_size = 32
num_epochs = 100

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(), 
])
dataset = pix3d_dataset(df, transform = transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

generator = Generator(img_dim=3, hidden_dim=hidden_dim, latent_dim=latent_dim, output_dim=128).to(device)
discriminator = Discriminator(voxel_dim=1, img_dim=3, hidden_dim=hidden_dim).to(device)

optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCELoss()

In [8]:
for i, data in enumerate(dataloader):
    real_images = data['image'].to(device)
    real_voxels = data['voxel'].to(device)
    print("Image shape:", real_images.shape)
    print("Voxel shape:", real_voxels.shape)
    break

Image shape: torch.Size([4, 3, 128, 128])
Voxel shape: torch.Size([4, 1, 128, 128, 128])


In [7]:
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        real_images = data['image'].to(device)
        real_voxels = data['voxel'].to(device)

        print(f"Real images shape: {real_images.shape}")
        print(f"Real voxels shape: {real_voxels.shape}")

        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        #trainign Discriminator
        optimizer_d.zero_grad()
        
        #real data
        outputs = discriminator(real_voxels, real_images)
        d_loss_real = criterion(outputs, real_labels)
        
        #fake data
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_voxels = generator(real_images, z)
        
        print(f"Fake voxels shape: {fake_voxels.shape}")
        
        outputs = discriminator(fake_voxels.detach(), real_images)
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        optimizer_g.zero_grad()
        
        fake_voxels = generator(real_images, z)
        outputs = discriminator(fake_voxels, real_images)
        g_loss = criterion(outputs, real_labels)
        
        g_loss.backward()
        optimizer_g.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

    # Optional: Save models after each epoch
    torch.save(generator.state_dict(), f'generator_epoch_{epoch+1}.pth')
    torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch+1}.pth')

Real images shape: torch.Size([32, 3, 128, 128])
Real voxels shape: torch.Size([32, 1, 128, 128, 128])


: 