In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.checkpoint import checkpoint
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from dataset import VoxDataset, ToTensor
import matplotlib.pyplot as plt
from tqdm import tqdm

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Instantiate the project dataset, handing it a ToTensor instance as its
# `transform` argument.
dataset = VoxDataset(
    transform=ToTensor(),
)

In [3]:
# Split the dataset into train/test subsets: 80% for training, the rest for
# testing.
train_ratio = 0.8
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
# Derive the test size by subtraction so the two sizes always add up to
# dataset_size, regardless of how int() truncated train_size.
test_size = dataset_size - train_size

# Seed the split explicitly. Without a generator, random_split draws from the
# global RNG, so every kernel restart would put different samples in the test
# set and losses from different runs could not be compared.
split_seed = 42
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

In [4]:
# Training loader: batches of 64, reshuffled every epoch, with pinned host
# memory requested.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    pin_memory=True,
)

# Evaluation loader: one sample per batch, in a fixed order.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
)

In [5]:
# Largest value in the first element of sample 0 (the half of each pair that
# the training loop unpacks as `emb`). Tensor.max() reduces in a single call;
# the builtin max() would instead iterate over the tensor element by element.
dataset[0][0].max()

tensor(0.5718, grad_fn=<UnbindBackward0>)

In [6]:
# dataiter = iter(dataloader)
# emb, mesh = next(dataiter)

In [7]:
# emb.dtype

In [8]:
# mesh.dtype

In [9]:
class AutoEncoder(nn.Module):
    """Stack of Linear + Tanh layers mapping an input vector to an output vector.

    Despite the name, ``forward`` currently runs only the decoder: the input
    is fed straight into ``self.decoder``. ``self.encoder`` is still
    constructed (so the module's attributes and state_dict keys are
    unchanged) but it is not used. Note that the encoder emits 192 features
    while the default decoder expects 384, so the encoder cannot simply be
    re-enabled without also changing the decoder's first layer.

    Args:
        decoder_sizes: Sequence of layer widths for the decoder, input width
            first and output width last. Each consecutive pair becomes one
            ``nn.Linear`` followed by ``nn.Tanh``. Defaults to
            ``DEFAULT_DECODER_SIZES`` (384 -> ... -> 15069).

    Raises:
        ValueError: If ``decoder_sizes`` has fewer than two entries, since at
            least one Linear layer is needed.
    """

    # Default decoder widths, input first. Consecutive pairs define the
    # in/out features of each Linear layer.
    DEFAULT_DECODER_SIZES = (
        384, 500, 768, 1000, 1536, 2000, 3072,
        5000, 6144, 8000, 10000, 12288, 15069,
    )

    def __init__(self, decoder_sizes=DEFAULT_DECODER_SIZES):
        super().__init__()

        decoder_sizes = tuple(int(size) for size in decoder_sizes)
        if len(decoder_sizes) < 2:
            raise ValueError(
                "decoder_sizes needs at least an input and an output width, "
                f"got {decoder_sizes!r}"
            )

        # Unused by forward(); built first so parameter registration order
        # stays encoder -> decoder.
        self.encoder = nn.Sequential(
            nn.Linear(384, 192),
            nn.Tanh(),
        )

        # Every Linear, including the last one, is followed by Tanh, so the
        # module's output is bounded to [-1, 1].
        decoder_layers = []
        for in_features, out_features in zip(decoder_sizes[:-1], decoder_sizes[1:]):
            decoder_layers.append(nn.Linear(in_features, out_features))
            decoder_layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``self.decoder(x)``; the encoder is intentionally bypassed."""
        return self.decoder(x)

    


In [10]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

# Model, loss (mean absolute error) and optimizer used by the training loop.
model = AutoEncoder().to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-5,
    weight_decay=1e-5,
)

Device: cuda:0


In [11]:
def model_checkpoint(x):
    """Run the notebook-level ``model`` on ``x`` via ``torch.utils.checkpoint``.

    The whole model is wrapped as a single checkpointed call, using the
    non-reentrant implementation (``use_reentrant=False``).

    Args:
        x: Input passed unchanged to ``model``.

    Returns:
        Whatever ``model(x)`` returns.
    """
    wrapped_module = model
    return checkpoint(wrapped_module, x, use_reentrant=False)

In [12]:
# Train `model` for `num_epochs` passes over `train_dataloader`.
num_epochs = 20
outputs = []     # one (epoch, last-batch input, last-batch reconstruction) per epoch
train_loss = []  # per-batch loss values, stored as plain Python floats

for epoch in range(num_epochs):

    for emb, mesh in tqdm(train_dataloader):
        emb = emb.to(device)
        mesh = mesh.to(device)

        # Forward pass goes through the checkpointing wrapper.
        recon = model_checkpoint(emb)
        loss = criterion(recon, mesh)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record the scalar value, not the tensor: appending `loss` itself
        # would keep every batch's loss tensor (and its autograd node) alive
        # on the device for the whole run.
        train_loss.append(loss.item())

        # Drop autograd history so the tensors kept past this iteration
        # (see `outputs` below) do not hold on to the graph.
        emb = emb.detach()
        mesh = mesh.detach()
        recon = recon.detach()

    # NOTE: this reports the loss of the final batch of the epoch, not the
    # epoch mean.
    print(f"Epoch:{epoch+1}, Loss:{loss.item():.4f}")
    outputs.append((epoch, emb, recon))

100%|██████████| 1346/1346 [19:24<00:00,  1.16it/s]


Epoch:1, Loss:0.0105


100%|██████████| 1346/1346 [21:44<00:00,  1.03it/s]


Epoch:2, Loss:0.0104


100%|██████████| 1346/1346 [18:53<00:00,  1.19it/s]


Epoch:3, Loss:0.0116


100%|██████████| 1346/1346 [21:40<00:00,  1.04it/s]


Epoch:4, Loss:0.0122


100%|██████████| 1346/1346 [22:25<00:00,  1.00it/s]


Epoch:5, Loss:0.0109


  4%|▍         | 51/1346 [00:42<18:28,  1.17it/s]

: 

In [None]:
# Average of every per-batch training loss recorded so far.
mean_train_loss = sum(train_loss) / len(train_loss)
print(f"Train Loss: {mean_train_loss}")

Train Loss: 0.03406316414475441


In [None]:
# Evaluate the trained model on the held-out split.
decoded = []    # one VoxDataset.to_mesh_points(...) result per test sample
test_loss = []  # per-sample reconstruction loss, stored as plain Python floats

model.eval()
with torch.no_grad():
    # The loop variables are deliberately NOT called `inputs, outputs`:
    # `outputs` is the list built by the training cell, and rebinding it here
    # would silently replace that list with the last test target.
    for test_emb, test_mesh in tqdm(test_dataloader):
        test_emb = test_emb.to(device)
        test_mesh = test_mesh.to(device)

        # Go through model(...) so evaluation uses the same forward path as
        # training (forward currently just calls the decoder).
        reconstructed_output = model(test_emb)

        reconstruction_loss = criterion(reconstructed_output, test_mesh)

        # to_mesh_points is defined in the project's dataset module (not shown
        # here); presumably it converts the flat network output into point
        # coordinates. TODO confirm.
        decoded.append(VoxDataset.to_mesh_points(reconstructed_output))

        # .item() stores a float instead of keeping a device tensor per sample.
        test_loss.append(reconstruction_loss.item())

100%|██████████| 200/200 [00:03<00:00, 59.40it/s]


In [None]:
# Average of the per-sample reconstruction losses from the evaluation loop.
mean_test_loss = sum(test_loss) / len(test_loss)
print(f"Test Loss: {mean_test_loss}")

Test Loss: 0.01122639887034893


In [None]:
import open3d as o3d
import numpy as np
from pathlib import Path

# Write every entry of `decoded` to ./Test/data<index>.ply as a point cloud.
output_dir = Path("./Test")
# Create the target folder up front: write_point_cloud does not create
# missing directories.
output_dir.mkdir(parents=True, exist_ok=True)

for index, points in enumerate(decoded):
    pcd = o3d.geometry.PointCloud()
    # Assumes each `decoded` entry is something Vector3dVector accepts
    # (e.g. an (N, 3) float array). TODO confirm against to_mesh_points.
    pcd.points = o3d.utility.Vector3dVector(points)
    # o3d.visualization.draw_geometries([pcd])

    ply_path = output_dir / f"data{index}.ply"
    # write_point_cloud signals failure through its boolean return value
    # rather than by raising, so check it instead of discarding it.
    if not o3d.io.write_point_cloud(str(ply_path), pcd):
        print(f"Failed to write {ply_path}")


In [None]:
# Test Loss: 1486.4837646484375