In [1]:
# Mount Google Drive into the Colab filesystem at /content/drive;
# later cells read the dataset and model code from paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install trimesh

Collecting trimesh
  Downloading trimesh-4.5.3-py3-none-any.whl.metadata (18 kB)
Downloading trimesh-4.5.3-py3-none-any.whl (704 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/704.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m704.8/704.8 kB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trimesh
Successfully installed trimesh-4.5.3


In [3]:
import os

# Configure PyTorch's CUDA caching allocator through its environment variable.
# This cell runs before the cell that imports torch.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

In [4]:
import torch
import torch.nn as nn
import json
import numpy as np
import pandas as pd
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import trimesh
import scipy.io
from torchvision.datasets import ImageFolder
import sys
sys.path.append('/content/drive/My Drive/pix3d/byop-2425/models')
from gan2 import CompactGenerator, CompactDiscriminator
# import sys
# sys.path.append("../models")
# from gan import Discriminator, Generator

In [None]:
# Load one sample OBJ model and print its raw vertex and face arrays.
mesh = trimesh.load('../../model/bed/IKEA_BEDDINGE/model.obj')

# trimesh.load may hand back either a Scene (a container of named geometries)
# or a single Trimesh. Only a Scene has `.geometry`, so unwrap it when present
# and otherwise use the mesh directly instead of failing with AttributeError.
if isinstance(mesh, trimesh.Scene):
    mesh_v = next(iter(mesh.geometry.values()))  # first geometry in the scene
else:
    mesh_v = mesh

vertices = np.array(mesh_v.vertices)
faces = np.array(mesh_v.faces)
print(vertices)
print(faces)

In [5]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [6]:
# Load the pix3d.json annotation file from Drive and turn it into a DataFrame.
# NOTE: despite its name, `json_dir` is the path of the JSON *file*, not a directory.
json_dir = "/content/drive/MyDrive/pix3d/pix3d.json"
# Pass the encoding explicitly so parsing does not depend on the machine's
# locale default (JSON text is UTF-8).
with open(json_dir, 'r', encoding='utf-8') as f:
    data = json.load(f)
df = pd.DataFrame(data)

In [16]:
# Read the sample table that is later passed to pix3d_dataset to build the training set.
# NOTE(review): presumably a subset of the pix3d.json records with 'img', 'mask' and
# 'voxel' columns — confirm how this CSV was produced. It lives under /content rather
# than the mounted Drive folder used by the other data paths.
sample_df = pd.read_csv("/content/sample_data.csv")

In [7]:
class pix3d_dataset(Dataset):
    """Dataset that yields image / mask / voxel samples listed in a DataFrame.

    Every row of ``dataframe`` must provide the columns ``'img'``, ``'mask'``
    and ``'voxel'``; each holds a file path relative to ``data_dir``.

    Args:
        dataframe: table with one row per sample.
        transform: optional callable applied to the RGB image and, separately,
            to the single-channel mask. Because it is called twice, a transform
            with randomness would not be synchronized between image and mask.
        data_dir: prefix joined to every path by plain string concatenation,
            so it is expected to end with a path separator.
    """

    def __init__(self, dataframe, transform=None, data_dir='/content/drive/MyDrive/pix3d/'):
        self.transform = transform
        self.dataframe = dataframe
        self.data_dir = data_dir

    def __len__(self):
        """Return the number of samples (rows in the DataFrame)."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Returns:
            dict with keys ``'image'`` (RGB image, transformed if a transform
            was given), ``'mask'`` (mode-'L' image, transformed likewise) and
            ``'voxel'`` (float32 tensor with a leading channel axis added).
        """
        # Look the row up once instead of once per column.
        row = self.dataframe.iloc[idx]

        # convert() returns an independent in-memory image, so the underlying
        # file can be closed deterministically as soon as the conversion is done.
        with Image.open(self.data_dir + row['img']) as img_file:
            image = img_file.convert('RGB')
        with Image.open(self.data_dir + row['mask']) as mask_file:
            mask = mask_file.convert('L')

        # The .mat file stores the occupancy grid under the 'voxel' key.
        voxel = scipy.io.loadmat(self.data_dir + row['voxel'])['voxel']

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # unsqueeze(0) adds the channel axis in front of the spatial axes.
        voxel = torch.tensor(voxel, dtype=torch.float32).unsqueeze(0)
        return {
            'image': image,
            'mask': mask,
            'voxel': voxel,
        }


In [17]:
# --- Training configuration -------------------------------------------------
# Note: an earlier experiment (no longer active) used SmallGenerator /
# SmallDiscriminator with latent_dim = 50 and hidden_dim = 32, trained with
# SGD (momentum 0.9) and nn.BCELoss.
lr = 0.0002
batch_size = 1
num_epochs = 5

# Resize every input to 128x128 and convert it to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)

# Dataset over the sample table, served in shuffled batches from the main process.
dataset = pix3d_dataset(sample_df, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

In [20]:
# Sanity check: pull a single batch, move it to the device and report its shapes.
for i, data in enumerate(dataloader):
    real_images, real_voxels = (data[key].to(device) for key in ('image', 'voxel'))
    print("Image shape:", real_images.shape)
    print("Voxel shape:", real_voxels.shape)
    break  # only the first batch is needed

Image shape: torch.Size([1, 3, 128, 128])
Voxel shape: torch.Size([1, 1, 128, 128, 128])


In [18]:
import tqdm as tqdm

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F




def train_gan(generator, discriminator, dataloader, device, num_epochs=5, lr=0.0002, latent_dim=50,
              checkpoint_dir='.', log_every=100):
    """Train an image-conditioned voxel GAN with the standard BCE objective.

    Each step first updates the discriminator on a real (voxel, image) pair
    and a generated pair, then updates the generator so that its output is
    classified as real.

    Args:
        generator: module called as ``generator(images, z)`` returning voxels.
        discriminator: module called as ``discriminator(voxels, images)``; its
            output is compared against labels of shape ``(batch, 1)`` with
            ``nn.BCELoss``, so it must be a probability of that shape.
        dataloader: iterable of dicts with ``'image'`` and ``'voxel'`` tensors.
        device: device the batches, labels and noise are placed on.
        num_epochs: number of passes over ``dataloader``.
        lr: Adam learning rate for both optimizers.
        latent_dim: length of the noise vector fed to the generator.
        checkpoint_dir: directory the per-epoch ``.pth`` files are written to
            (created if missing). The default keeps the original behavior of
            saving into the current working directory.
        log_every: print the losses every this many steps; ``0``/``None``
            disables step logging.

    Returns:
        None. Side effects: both models are updated in place and two state
        dicts are saved after every epoch.
    """
    import os  # local import keeps the function self-contained within its cell

    optimizer_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    criterion = nn.BCELoss()

    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in tqdm.tqdm(range(num_epochs)):
        for i, data in enumerate(dataloader):
            real_images = data['image'].to(device)
            real_voxels = data['voxel'].to(device)

            batch_size = real_images.size(0)
            # Allocate labels and noise directly on the target device rather
            # than creating them on the CPU and copying them over.
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # ---- Discriminator step ----
            optimizer_d.zero_grad()

            # real data
            outputs = discriminator(real_voxels, real_images)
            d_loss_real = criterion(outputs, real_labels)

            # fake data: generated once per step and reused for the generator
            # update below
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_voxels = generator(real_images, z)

            # detach() keeps the discriminator loss from back-propagating
            # into the generator
            outputs = discriminator(fake_voxels.detach(), real_images)
            d_loss_fake = criterion(outputs, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_d.step()

            # ---- Generator step ----
            optimizer_g.zero_grad()

            # Reuse the fake batch from above (its graph through the generator
            # is still intact) instead of running a second, redundant forward
            # pass; only the discriminator has changed since it was produced.
            outputs = discriminator(fake_voxels, real_images)
            g_loss = criterion(outputs, real_labels)

            g_loss.backward()
            optimizer_g.step()

            if log_every and (i + 1) % log_every == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

        # Checkpoint both networks at the end of every epoch.
        torch.save(generator.state_dict(),
                   os.path.join(checkpoint_dir, f'generator_epoch_{epoch+1}.pth'))
        torch.save(discriminator.state_dict(),
                   os.path.join(checkpoint_dir, f'discriminator_epoch_{epoch+1}.pth'))

# Usage: build the compact models and train them.
# latent_dim is defined once so the generator's constructor and the noise
# sampled inside train_gan cannot drift apart.
latent_dim = 50
generator = CompactGenerator(img_dim=3, hidden_dim=16, latent_dim=latent_dim).to(device)
discriminator = CompactDiscriminator(voxel_dim=1, img_dim=3, hidden_dim=16).to(device)
# Pass the hyperparameters from the configuration cell explicitly; previously
# `lr` and `num_epochs` were defined there but ignored in favor of the
# function's defaults.
train_gan(generator, discriminator, dataloader, device,
          num_epochs=num_epochs, lr=lr, latent_dim=latent_dim)

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch [1/5], Step [100/100], D Loss: 1.3964293003082275, G Loss: 0.7098605036735535


 20%|██        | 1/5 [07:32<30:08, 452.04s/it]

Epoch [2/5], Step [100/100], D Loss: 1.3865364789962769, G Loss: 0.6872125864028931


 40%|████      | 2/5 [13:17<19:28, 389.44s/it]

Epoch [3/5], Step [100/100], D Loss: 1.4212024211883545, G Loss: 0.6563372015953064


 60%|██████    | 3/5 [19:04<12:20, 370.11s/it]

Epoch [4/5], Step [100/100], D Loss: 1.428591012954712, G Loss: 0.6658902764320374


 80%|████████  | 4/5 [24:53<06:01, 361.82s/it]

Epoch [5/5], Step [100/100], D Loss: 1.3729350566864014, G Loss: 0.6639695763587952


100%|██████████| 5/5 [30:37<00:00, 367.59s/it]
