In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import open3d as o3d
import json


from torch.utils.data import ConcatDataset, DataLoader, Subset
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

from Helpers.data import SDDataset, PointCloudDataset

# Choose the compute backend in priority order: CUDA GPU, then Apple-silicon
# MPS (only probed if this torch build exposes torch.backends.mps), else CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    else "cpu"
)

print(f'Using: {device}')

Using: cuda


In [26]:
# Cache the built datasets so later sessions can torch.load them directly.
# The paths must match the ones the loading cell reads (including the '.pt'
# suffix); otherwise the cached files are written but never found.
# NOTE(review): requires `trainset` / `testset` to already exist in the kernel
# (e.g. built with SDDataset) -- this cell is not runnable on a fresh kernel.
torch.save(trainset, '../Data/sd_dataset_3072_train_easier.pt')
torch.save(testset, '../Data/sd_dataset_3072_test_easier.pt')

In [28]:
# The datasets were saved whole with torch.save (not just tensors), so
# weights_only=False is needed to unpickle them. That runs arbitrary pickled
# code -- only load files you produced yourself.
trainset = torch.load('../Data/sd_dataset_3072_train_easier.pt', weights_only=False)
testset = torch.load('../Data/sd_dataset_3072_test_easier.pt', weights_only=False)
# To rebuild the datasets from the raw data instead of the cached files:
# trainset = SDDataset('../Data/ModelNet40', '../Data/sampled_points_easier', device, 3072, 'train', object_classes= None)
# testset = SDDataset('../Data/ModelNet40', '../Data/sampled_points_easier', device, 3072, 'test', object_classes= None)

batch_size = 32

train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
# Evaluation needs no shuffling. A fixed order makes the per-epoch validation
# loss deterministic: the short final batch always holds the same samples.
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)


In [33]:
def train_model(key, model, num_epochs, train_loader, val_loader, lr=0.0001, log_every=50000):
    """Train ``model`` to classify the sign of ``data['sd']`` and checkpoint the best epoch.

    For every batch the target is binarized as ``data['sd'] > 0`` and the
    model's output (treated as logits) is fit with binary cross-entropy
    using AdamW. After each epoch the mean validation loss is computed;
    whenever it improves on the best seen so far, ``model.state_dict()`` is
    written to ``./trained_sdf_models/{key}``.

    Args:
        key: File name for the checkpoint inside ``./trained_sdf_models``.
        model: Module called as ``model(latent_rep, xyz)``. Assumed to already
            be on the notebook-level ``device`` -- TODO confirm at call site.
        num_epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of dicts with 'latent_rep', 'xyz' and 'sd' tensors.
        val_loader: Same format as ``train_loader``; used only for evaluation.
        lr: AdamW learning rate (default is the previously hard-coded 1e-4).
        log_every: Print the current batch loss every this many batches.

    Returns:
        ``(train_losses, val_losses)``: per-epoch averages of the per-batch losses.
    """
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    # Create the checkpoint directory up front so the first save cannot fail.
    ckpt_dir = Path('./trained_sdf_models')
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    min_val_loss = np.inf
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # ---- Train one epoch ----
        # train() undoes the eval() set during the previous validation pass, so
        # mode-dependent layers (dropout, batch norm, ...) behave as in training.
        model.train()
        train_loss = 0.0
        for count, data in enumerate(tqdm(train_loader)):
            latent_rep = data['latent_rep'].to(device)
            xyz = data['xyz'].to(device)
            sd = (data['sd'] > 0).float().to(device)

            # reshape_as instead of squeeze(): squeeze() would also drop a
            # batch dimension of size 1 (a short final batch), leaving pred
            # and target with different shapes.
            pred = model(latent_rep, xyz).reshape_as(sd)

            optimizer.zero_grad()
            loss = F.binary_cross_entropy_with_logits(pred, sd)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            if count % log_every == 0:
                print(f'{count}: {loss.item()}')

        train_loss /= len(train_loader)

        # ---- Validation ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data in val_loader:
                latent_rep = data['latent_rep'].to(device)
                xyz = data['xyz'].to(device)
                sd = (data['sd'] > 0).float().to(device)

                pred = model(latent_rep, xyz).reshape_as(sd)
                val_loss += F.binary_cross_entropy_with_logits(pred, sd).item()

        val_loss /= len(val_loader)

        print(f'\nEpoch {epoch+1} \t Train Loss: {train_loss:.5f} \t Val Loss: {val_loss:.5f}')

        # Keep only the checkpoint with the lowest validation loss so far.
        if val_loss < min_val_loss:
            print(f'Val Loss Decreased({min_val_loss:.6f} ---> {val_loss:.6f}) \t Saving The Model')
            min_val_loss = val_loss
            torch.save(model.state_dict(), ckpt_dir / key)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

    return train_losses, val_losses


In [None]:
# NOTE(review): SD_Model_Simple is not defined in any visible cell -- it must be
# defined/imported earlier for this cell to run on a fresh kernel.
model = SD_Model_Simple(512)
model = model.to(device)

num_epochs = 3
key = "mlp_512_simple"
train_losses, val_losses = train_model(key, model, num_epochs, train_loader, test_loader)

results = {
    'train_losses': train_losses,
    'val_losses': val_losses,
}

# json.dump (not json.dumps) writes to the file object; dumps only returns a
# string and rejects the extra positional argument, leaving the file empty.
with open(f'./{key}_results.json', 'w') as f:
    json.dump(results, f, indent=2)


In [None]:
# NOTE(review): SD_Model_Small is not defined in any visible cell -- it must be
# defined/imported earlier for this cell to run on a fresh kernel.
model = SD_Model_Small(512)
model = model.to(device)

num_epochs = 3
key = "mlp_512_small"
train_losses, val_losses = train_model(key, model, num_epochs, train_loader, test_loader)

results = {
    'train_losses': train_losses,
    'val_losses': val_losses,
}

# json.dump (not json.dumps) writes to the file object; dumps only returns a
# string and rejects the extra positional argument, leaving the file empty.
with open(f'./{key}_results.json', 'w') as f:
    json.dump(results, f, indent=2)