In [2]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import open3d as o3d
import os
import torch
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
import sys
from pathlib import Path
import time

sys.path.append(str(Path.cwd().parent))

from Helpers.data import PointCloudDataset

if torch.cuda.is_available():
    device = "cuda"

elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f'Using: {device}')

Using: cpu


In [3]:
train_dataset = PointCloudDataset("../ModelNet40", 5000, 'train')
train_loader = DataLoader(train_dataset, batch_size = 16, shuffle = False)

9843


KeyboardInterrupt: 

In [28]:
class MLPEncoder(nn.Module):

    def __init__(self, config):
        super().__init__()

        self.fc1 = nn.Linear(config.input_dim, config.hidden_dim1)
        self.fc2 = nn.Linear(config.hidden_dim1, config.hidden_dim2)
        self.fc3 = nn.Linear(config.hidden_dim2, config.latent_dim)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = F.gelu(self.fc2(x))
        x = self.fc3(x)

        return x

class MLPDecoder(nn.Module):

    def __init__(self, config):
        super().__init__()

        self.fc1 = nn.Linear(config.latent_dim, config.hidden_dim2)
        self.fc2 = nn.Linear(config.hidden_dim2, config.hidden_dim1)
        self.fc3 = nn.Linear(config.hidden_dim1, config.input_dim)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = F.gelu(self.fc2(x))
        x = self.fc3(x)

        return x


class Autoencoder(nn.Module):

    def __init__(self, config):
        super().__init__()

        self.encoder = MLPEncoder(config)
        self.decoder = MLPDecoder(config)

    def forward(self,x):
        latent_rep = self.encoder(x)
        out = self.decoder(latent_rep)
        return out


# class Autoencoder(nn.Module):

#     def __init__(self, config):
#         super().__init__()

#         self.fc = nn.Linear(config.input_dim, config.input_dim)

#     def forward(self,x):
#         y = self.fc(x)
#         return x + (.00001 * y)
    

class PointCloudAutoencoder(nn.Module):
    def __init__(self, config):
        super(PointCloudAutoencoder, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.latent_dim)
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(config.latent_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.input_dim)
        )

    def forward(self, x):
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return reconstructed


In [30]:
point_cloud_size = 5000 


# @dataclass 
# class MLPAutoEncoderConfig:
#     input_dim = point_cloud_size * 3
#     hidden_dim1 = 3072
#     hidden_dim2 = 1048
#     latent_dim = 512

# @dataclass 
# class MLPAutoEncoderConfig:
#     input_dim = point_cloud_size * 3
#     hidden_dim1 = point_cloud_size * 3
#     hidden_dim2 = point_cloud_size * 3
#     latent_dim = point_cloud_size * 3

@dataclass 
class MLPAutoEncoderConfig:
    input_dim = point_cloud_size * 3
    hidden_dim = 2048
    latent_dim = 512


config = MLPAutoEncoderConfig()

# model = Autoencoder(config).to(device)
model = PointCloudAutoencoder(config).to(device)

optim = torch.optim.AdamW(model.parameters(), lr= 1e-4)

epochs = 100


report_rate = 600

s= time.time()

for epoch in range(epochs):

    running_loss = 0 

    batch_count = 0 

    for i, data in enumerate(train_loader):

        x = data['points']
        x = x.view(x.shape[0], -1).to(device)
        
        optim.zero_grad()

        pred = model(x)

        loss = F.mse_loss(pred, x)

        loss.backward()

        optim.step()

        running_loss += loss.item()

        batch_count +=1

    if epoch % 1 == 0:
        print(f'Epoch {epoch:<3} Epoch Loss: {running_loss / batch_count}')

        # if i % report_rate == report_rate - 1:
        #     print(f'Batch {i:<3} Running Loss: {running_loss / report_rate}')
        #     running_loss = 0

print(time.time() - s)
    

Epoch 0   Epoch Loss: 2.3480648486699917e+19
Epoch 1   Epoch Loss: 2.495094256946952e+20
Epoch 2   Epoch Loss: 2.1314188879035213e+20
Epoch 3   Epoch Loss: 7.413677845910648e+19
Epoch 4   Epoch Loss: 2.2948326830167376e+19
Epoch 5   Epoch Loss: 2.0969642647735427e+19
Epoch 6   Epoch Loss: 1.798323757158154e+19
Epoch 7   Epoch Loss: 2.57058508589087e+19
Epoch 8   Epoch Loss: 7.426546715074205e+18
Epoch 9   Epoch Loss: 9.568644576636707e+18
Epoch 10  Epoch Loss: 1.0529830274308594e+19
Epoch 11  Epoch Loss: 1.2385103378147912e+19
Epoch 12  Epoch Loss: 8.210757999839613e+18
Epoch 13  Epoch Loss: 6.844340003392166e+18
Epoch 14  Epoch Loss: 1.3062039865639199e+19
Epoch 15  Epoch Loss: 2.0046769900805444e+19
Epoch 16  Epoch Loss: 1.7305529961227743e+19
Epoch 17  Epoch Loss: 5.178540398903014e+18
Epoch 18  Epoch Loss: 6.932961992673825e+18
Epoch 19  Epoch Loss: 7.116343701280102e+18
Epoch 20  Epoch Loss: 9.438354456650093e+18
Epoch 21  Epoch Loss: 1.1172786477994236e+19
Epoch 22  Epoch Loss: 7

KeyboardInterrupt: 

In [None]:
x.view(x.shape[0], -1).shape

torch.Size([16, 15000])