In [None]:
import os
import h5py
import numpy 
import random
import math
import shutil
from tqdm import tqdm
from path import Path
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.offline as pyo

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [None]:
# `os` is already imported in the first cell; re-imported here so this cell can run on its own.
import os
# Configure PyTorch's CUDA caching allocator through the PYTORCH_CUDA_ALLOC_CONF
# environment variable (option used here: max_split_size_mb=50).
# NOTE(review): presumably this has to run before the first CUDA allocation to take effect — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:50"

# Configuration

In [None]:
# Number of points every dataset sample must contain; the dataset classes below
# assert this size (and crop or pad the real/test scans towards it).
SAMPLE_POINTS = 2000
# Number of per-point segmentation classes: passed to the model as `m` and used as
# the one-hot width in train().
CLASSESS_CNT = 4

# Device

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Uncomment the next line to force CPU execution.
# device = "cpu"
print(f"Using {device} device")

# Data Augmentation

In [None]:
class Normalize(object):
    """Centre a point cloud on its centroid and scale it into the unit sphere."""

    def __call__(self, pointcloud):
        """Return a normalised copy of ``pointcloud``.

        Args:
            pointcloud: 2-D array of shape (num_points, dims), one point per row.

        Returns:
            numpy.ndarray of the same shape whose centroid is the origin and whose
            farthest point has norm 1. If every point coincides (maximum radius
            of 0) the centred, all-zero cloud is returned instead of dividing by
            zero and producing NaNs.
        """
        centered = pointcloud - numpy.mean(pointcloud, axis=0)
        max_radius = numpy.max(numpy.linalg.norm(centered, axis=1))
        # Guard against degenerate clouds: scaling by 0 would fill the result with NaN.
        if max_radius > 0:
            centered = centered / max_radius
        return centered
    
class RandomNoise(object):
    """Jitter every coordinate with zero-mean Gaussian noise."""

    def __init__(self, std=0.02):
        """
        Args:
            std: standard deviation of the noise; the default (0.02) matches the
                value that used to be hard-coded.
        """
        self.std = std

    def __call__(self, pointcloud):
        """Return ``pointcloud`` plus noise drawn from numpy's global RNG.

        Args:
            pointcloud: numpy array; the noise has the same shape.

        Returns:
            A new array; the input is not modified.
        """
        noise = numpy.random.normal(0, self.std, pointcloud.shape)
        return pointcloud + noise
    
class RandomScale(object):
    """Scale each of the three axes by an independent random factor."""

    def __init__(self, low=0.9, high=1.1):
        """
        Args:
            low: lower bound of the uniform scale range.
            high: upper bound of the uniform scale range.

        The defaults reproduce the previously hard-coded 0.9-1.1 range.
        """
        self.low = low
        self.high = high

    def __call__(self, pointcloud):
        """Return ``pointcloud`` with its columns multiplied by random factors.

        Args:
            pointcloud: array of shape (num_points, 3).

        Returns:
            A new (num_points, 3) array; factors come from numpy's global RNG.
        """
        factors = numpy.random.uniform(self.low, self.high, 3)
        # A diagonal matrix scales the axes independently (it is not a rotation).
        scale_mat = numpy.diag(factors)
        return numpy.matmul(pointcloud, scale_mat)

In [None]:
def default_transforms():
    """Build the evaluation pipeline: Normalize followed by ToTensor."""
    steps = [Normalize(), transforms.ToTensor()]
    return transforms.Compose(steps)

def default_transforms_no_normalize():
    """Build the evaluation pipeline without normalisation: ToTensor only."""
    steps = [transforms.ToTensor()]
    return transforms.Compose(steps)

def training_transforms():
    """Build the training pipeline: Normalize, RandomNoise, RandomScale, ToTensor."""
    steps = [Normalize(), RandomNoise(), RandomScale(), transforms.ToTensor()]
    return transforms.Compose(steps)

def training_transforms_no_normalize():
    """Build the training pipeline without normalisation: RandomNoise, RandomScale, ToTensor."""
    steps = [RandomNoise(), RandomScale(), transforms.ToTensor()]
    return transforms.Compose(steps)

# Visualize Part Instance Utility Function

In [None]:
import plotly.graph_objects as go
import plotly.io as pio

def visualize_point_cloud(point_cloud, labels, export_svg=False, filename='point_cloud.svg'):
    """Draw a 3-D scatter plot of a point cloud coloured by per-point label.

    Args:
        point_cloud: 2-D indexable whose columns 0, 1 and 2 are read as x, y, z.
        labels: per-point values handed to Plotly as the marker colours.
        export_svg: when truthy the figure is written to ``filename`` as SVG;
            otherwise it is shown interactively.
        filename: output path, only used when ``export_svg`` is truthy.
    """
    xs, ys, zs = (point_cloud[:, axis] for axis in range(3))
    marker_style = {'size': 5, 'color': labels, 'colorscale': 'Viridis', 'opacity': 0.5}
    scatter = go.Scatter3d(x=xs, y=ys, z=zs, mode='markers', marker=marker_style)

    # Hide all three axes and their background planes; keep the canvas transparent.
    hidden_axes = {
        axis_name: {'visible': False, 'showbackground': False}
        for axis_name in ('xaxis', 'yaxis', 'zaxis')
    }
    layout = go.Layout(
        scene=hidden_axes,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
    )

    fig = go.Figure(data=[scatter], layout=layout)

    if not export_svg:
        pio.show(fig)
        return
    pio.write_image(fig, filename, format='svg')

# Example usage
# visualize_point_cloud(point_cloud, labels, export_svg=True, filename='point_cloud.svg')


# Dataset

In [None]:
class GenerativeJawDataset(Dataset):
    """Annotated synthetic jaw point clouds loaded from a single ``.npy`` file.

    Each item of the loaded array is indexed as rows of ``x, y, z, label``.
    The default ``transform`` is built once, when the class is defined.
    """

    def __init__(self, npy_file_path="/kaggle/input/generative-jaw/generative.npy", transform=training_transforms()):
        self.transform = transform
        self.data = numpy.load(npy_file_path)

    def __len__(self):
        """Number of samples (size of the first axis of the loaded array)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(pointcloud, label, d, lower)`` for sample ``idx``.

        ``pointcloud`` is a float tensor of shape (SAMPLE_POINTS, 3), ``label``
        holds the per-point labels as a long tensor, ``d`` is always 1 (the
        sample is annotated) and ``lower`` is 1 for indices below 10, else 0.
        """
        sample = self.data[idx]
        xyz = numpy.column_stack([sample[:, axis] for axis in range(3)])
        point_labels = sample[:, 3]

        # The transform output is indexed with [0] to obtain the (points, 3) tensor.
        cloud = self.transform(xyz)[0]

        has_annotation = 1  # d = 1, as it has annotations
        lower = 0 if idx >= 10 else 1

        assert cloud.shape[0] == SAMPLE_POINTS
        assert cloud.shape[1] == 3

        del sample
        torch.cuda.empty_cache()

        return (
            cloud.type(torch.FloatTensor),
            torch.tensor(point_labels).type(torch.LongTensor),
            torch.tensor(has_annotation).type(torch.LongTensor),
            torch.tensor(lower).type(torch.LongTensor),
        )

In [None]:
class TestJawDataset(Dataset):
    """Single annotated real jaw scan used for evaluation.

    The ``.npy`` file is read as one array whose rows are ``x, y, z, label``.
    The default ``transform`` is built once, when the class is defined.
    """

    def __init__(self, npy_file_path="/kaggle/input/real-jaw-2000-pointcloud-annotated-classbased/test_jaw.npy", transform=default_transforms()):
        self.data = numpy.load(npy_file_path)
        self.transform = transform

    def __len__(self):
        return 1 # TODO: Edit this later if we have more annotations

    def __getitem__(self, idx):
        """Return ``(pointcloud, label, d)``; ``idx`` is ignored (one sample only).

        The cloud and its labels are cropped or padded (by repeating the last
        point and its label) to exactly SAMPLE_POINTS entries. ``d`` is always 1.
        """
        itemdata = (self.data).astype(numpy.float32) # TODO: Edit this later if we have more annotations
        pointcloud = itemdata[:, :3]  # Extract x, y, z columns
        # Convert labels to a tensor up front so that cropping and padding can use
        # the same torch operations as the point cloud. The padding branch used to
        # call torch.cat on a numpy array, which fails for clouds with fewer than
        # SAMPLE_POINTS points.
        label = torch.tensor(itemdata[:, 3])
        d = 1  # d = 1, as it has annotations
        # The transform output is indexed with [0] to obtain the (points, 3) tensor.
        pointcloud = self.transform(pointcloud)[0]

        num_points = pointcloud.shape[0]
        if num_points > SAMPLE_POINTS:
            pointcloud = pointcloud[:SAMPLE_POINTS]
            label = label[:SAMPLE_POINTS]
        elif num_points < SAMPLE_POINTS:
            missing = SAMPLE_POINTS - num_points
            pointcloud = torch.cat([pointcloud, pointcloud[-1].repeat(missing, 1)], dim=0)
            label = torch.cat([label, label[-1].repeat(missing)], dim=0)

        assert pointcloud.shape[0] == SAMPLE_POINTS
        assert pointcloud.shape[1] == 3
        assert label.shape[0] == SAMPLE_POINTS

        del itemdata
        torch.cuda.empty_cache()

        return pointcloud.type(torch.FloatTensor), label.type(torch.LongTensor), torch.tensor(d).type(torch.LongTensor)

In [None]:
class RealJawDataset(Dataset):
    """Unannotated real jaw scans, one space-delimited XYZ text file per sample.

    The default ``transform`` is built once, when the class is defined.
    """

    def __init__(self, data_dir="/kaggle/input/real-jaw-2000-pointcloud/RealJaw2000", transform=training_transforms(), max_files=20):
        """
        Args:
            data_dir: directory containing the XYZ files.
            transform: callable applied to the (points, 3) float32 array.
            max_files: number of files to keep after sorting by name
                (``None`` keeps all of them).
        """
        self.data_dir = data_dir
        # Sort BEFORE truncating: os.listdir returns entries in arbitrary order,
        # so slicing first would select a non-reproducible subset of files.
        self.file_list = sorted(os.listdir(data_dir))[:max_files]
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def load_xyz_file(self, file_path):
        """Load a space-delimited numeric text file into a numpy array."""
        with open(file_path, 'r') as f:
            data = numpy.loadtxt(f, delimiter=' ')
        return data

    def __getitem__(self, idx):
        """Return ``(pointcloud, label, d, lower)`` for file number ``idx``.

        ``pointcloud`` is cropped or padded (by repeating the last point) to
        SAMPLE_POINTS points. ``label`` is an all-zero placeholder, ``d`` is 0
        (no annotations) and ``lower`` is ``idx % 2``.
        """
        file_path = os.path.join(self.data_dir, self.file_list[idx])

        data = self.load_xyz_file(file_path)

        # Keep only the x, y, z columns of the space-delimited file.
        pointcloud = numpy.column_stack((data[:,0], data[:,1], data[:,2])).astype(numpy.float32)

        # Create a fixed label (doesnt matter) and d (0)
        label = torch.zeros(SAMPLE_POINTS, dtype=torch.int64)
        d = torch.tensor(0, dtype=torch.int64)

        lower = idx%2

        # The transform output is indexed with [0] to obtain the (points, 3) tensor.
        pointcloud = self.transform(pointcloud)[0]

        num_points = pointcloud.shape[0]
        if num_points > SAMPLE_POINTS:
            pointcloud = pointcloud[:SAMPLE_POINTS]
        elif num_points < SAMPLE_POINTS:
            missing = SAMPLE_POINTS - num_points
            pointcloud = torch.cat([pointcloud, pointcloud[-1].repeat(missing, 1)], dim=0)

        assert pointcloud.shape[0] == SAMPLE_POINTS
        assert pointcloud.shape[1] == 3

        del file_path
        del data
        torch.cuda.empty_cache()

        return pointcloud.type(torch.FloatTensor),label, d,torch.tensor(lower).type(torch.LongTensor)

In [None]:
class TrainingDataset(Dataset):
    """Concatenation of the real dataset followed by the generative dataset."""

    def __init__(self):
        self.generative_dataset = GenerativeJawDataset()
        self.real_dataset = RealJawDataset()

    def __len__(self):
        """Combined number of real and generative samples."""
        sizes = (len(self.generative_dataset), len(self.real_dataset))
        return sum(sizes)

    def __getitem__(self, idx):
        """Indices below ``len(real_dataset)`` map to real samples, the rest to generative ones."""
        real_count = len(self.real_dataset)
        if idx >= real_count:
            return self.generative_dataset[idx - real_count]
        return self.real_dataset[idx]

In [None]:
# Instantiate each dataset with its default path and transform; these instances
# feed the visualisation cells below (test_dataset is also used for evaluation).
generative_dataset = GenerativeJawDataset()
test_dataset = TestJawDataset()
real_dataset = RealJawDataset()

# Visualize Real Data

In [None]:
# Sample 18 of the real scans: data[0] is the point cloud, data[1] the label tensor
# (all zeros for this dataset, so every point receives the same colour).
data = real_dataset[18]
visualize_point_cloud(data[0], data[1])

# Visualize Real Annotated Test Data

In [None]:
# The only annotated real scan: data[0] is the point cloud, data[1] its per-point labels.
data = test_dataset[0]

visualize_point_cloud(data[0], data[1])

# Visualize Generative Jaw Data With Annotations

In [None]:
# Sample 16 of the generative set, coloured by its per-point annotations.
data = generative_dataset[16]
visualize_point_cloud(data[0], data[1])

# Model: PointNet++ with Gradient Reversal

In [None]:
class GradReverse(torch.autograd.Function):
    """Autograd function that is the identity forward and negates gradients backward."""

    @staticmethod
    def forward(ctx, x):
        # Return a view with the same shape so the output is a distinct tensor object.
        return x.view(x.size())

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient.
        return -grad_output


def GRL(x):
    """Apply the gradient-reversal layer to ``x``."""
    reversed_x = GradReverse.apply(x)
    return reversed_x

In [None]:
!git clone https://github.com/yanx27/Pointnet_Pointnet2_pytorch.git

In [None]:
from Pointnet_Pointnet2_pytorch.models.pointnet2_utils import PointNetSetAbstraction, PointNetFeaturePropagation

In [None]:
# Module-level set-abstraction layers. The head classes below assign these very
# objects to their own attributes, so every head that references the same layer
# shares that layer instance:
#   sa1C - first layer used by the classification heads
#   sa1S - first layer used by the segmentation head
#   sa2, sa3 - used by both the classification and the segmentation heads
# NOTE(review): npoint=2000 in sa1S is a literal that matches SAMPLE_POINTS — confirm they are meant to stay in sync.
sa1C = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=3, mlp=[64, 64, 128], group_all=False)
sa1S = PointNetSetAbstraction(npoint=2000, radius=0.2, nsample=32, in_channel=6, mlp=[64, 64, 128], group_all=False)
sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)

In [None]:
class PointNetPPClassHead(nn.Module):
    """Classification head on top of the module-level set-abstraction layers.

    The 1024-dimensional global feature is passed through the gradient-reversal
    layer before the fully connected classifier, so gradients flowing back from
    this head into the set-abstraction layers are negated.
    """

    def __init__(self,num_class=2, normal_channel=False):
        super().__init__()
        self.normal_channel = normal_channel
        # Module-level layer objects (sa1C, sa2, sa3), not fresh copies.
        self.sa1 = sa1C
        self.sa2 = sa2
        self.sa3 = sa3
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_class)

    def forward(self, xyz):
        """Return ``(log_probs, l3_points)`` for a 3-D input whose first axis is the batch.

        When ``normal_channel`` is set, channels from index 3 onwards are split
        off and passed to the first layer as extra point features.
        """
        batch_size, _, _ = xyz.shape
        extra_features = None
        if self.normal_channel:
            extra_features = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]

        l1_xyz, l1_points = self.sa1(xyz, extra_features)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

        hidden = GRL(l3_points.view(batch_size, 1024))
        hidden = self.fc1(hidden)
        hidden = self.drop1(F.relu(self.bn1(hidden)))
        hidden = self.fc2(hidden)
        hidden = self.drop2(F.relu(self.bn2(hidden)))
        log_probs = F.log_softmax(self.fc3(hidden), -1)

        return log_probs, l3_points

In [None]:
class PointNetPPSeghead(nn.Module):
    """Per-point segmentation head: set abstraction, gradient reversal, feature propagation."""

    def __init__(self, num_classes=4, normal_channel=False):
        super().__init__()
        additional_channel = 3 if normal_channel else 0
        self.normal_channel = normal_channel
        # Module-level layer objects (sa1S, sa2, sa3), not fresh copies.
        self.sa1 = sa1S
        self.sa2 = sa2
        self.sa3 = sa3
        self.fp3 = PointNetFeaturePropagation(in_channel=1280, mlp=[256, 256])
        self.fp2 = PointNetFeaturePropagation(in_channel=384, mlp=[256, 128])
        self.fp1 = PointNetFeaturePropagation(in_channel=137+additional_channel, mlp=[128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz, cls_label_one_hot):
        """Return ``(log_probs, l3_points)``.

        ``log_probs`` is log-softmaxed over the class channel and then permuted
        so that the class axis comes last. ``l3_points`` is the global feature
        after gradient reversal.
        """
        # Unpacking also enforces that the input has exactly three dimensions.
        batch_size, num_channels, num_points = xyz.shape

        # Set abstraction
        l0_points = xyz
        l0_xyz = xyz[:, :3, :] if self.normal_channel else xyz
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

        # Gradient reversal on every abstraction output: gradients coming back
        # from the propagation layers reach the set-abstraction layers negated.
        l1_xyz, l1_points = GRL(l1_xyz), GRL(l1_points)
        l2_xyz, l2_points = GRL(l2_xyz), GRL(l2_points)
        l3_xyz, l3_points = GRL(l3_xyz), GRL(l3_points)

        # Feature propagation, coarse to fine
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        # The label input is repeated on three channels and concatenated with the
        # coordinates and the input features for the finest level.
        label_channels = cls_label_one_hot.unsqueeze(dim=1).repeat(1, 3, 1)
        skip_features = torch.cat([label_channels, l0_xyz, l0_points], 1)
        l0_points = self.fp1(l0_xyz, l1_xyz, skip_features, l1_points)

        # Per-point classifier
        feat = F.relu(self.bn1(self.conv1(l0_points)))
        scores = self.conv2(self.drop1(feat))
        log_probs = F.log_softmax(scores, dim=1).permute(0, 2, 1)
        return log_probs, l3_points

In [None]:
class PointNetDomainAdaptationModel(nn.Module):
    """Bundle of two classification heads and one segmentation head.

    Args:
        m: number of segmentation classes.
        k1: number of classes of the domain classifier.
        k2: number of classes of the lower/upper classifier.
    """

    def __init__(self, m=4, k1=2, k2=2):
        super().__init__()

        self.m, self.k1, self.k2 = m, k1, k2

        # Heads
        self.domain_classifier = PointNetPPClassHead(num_class=k1)
        self.lower_upper_classifier = PointNetPPClassHead(num_class=k2)
        self.segmentation_head = PointNetPPSeghead(num_classes=self.m)

    def forward(self, x, cls_label):
        """Run all three heads on ``x``.

        Returns:
            ``(d_hat, d_feat, l_hat, l_feat, seg_hat, seg_feat)`` - each head's
            prediction followed by the feature tensor it returns.
        """
        domain_out = self.domain_classifier(x)
        lower_upper_out = self.lower_upper_classifier(x)
        segmentation_out = self.segmentation_head(x, cls_label)

        d_hat, d_trans_feat = domain_out
        l_hat, l_trans_feat = lower_upper_out
        seg_hat, seg_trans_feat = segmentation_out
        return d_hat, d_trans_feat, l_hat, l_trans_feat, seg_hat, seg_trans_feat

In [None]:
# dummpy = torch.randn(40,3,2000)
# dumpyyy = torch.randn(40,2000)
# model = PointNetDomainAdaptationModel()
# model(dummpy,dumpyyy)
# import gc
# del dummpy
# del dumpyyy
# del model
# gc.collect()

# Training Setup: Dataset, Model, Losses, and Training Loop

In [None]:
# Combined training set: real (unannotated) samples first, then generative (annotated) samples.
training_dataset = TrainingDataset()

In [None]:
# Total number of training samples (real + generative); shown as the cell output.
len(training_dataset)

In [None]:
# Build the model with CLASSESS_CNT segmentation classes and two classes for each
# classification head, then move it to the selected device.
model = PointNetDomainAdaptationModel(m=CLASSESS_CNT,k1=2, k2=2).to(device)

In [None]:
# Count the number of trainable parameters (those with requires_grad=True)
num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of trainable parameters:", num_parameters)

# Estimate the size of all parameters in megabytes (MB).
# The factor 4 assumes 4 bytes per element; 1024 ** 2 converts bytes to MB.
model_size_mb = sum(p.numel() for p in model.parameters()) * 4 / (1024 ** 2)  # assuming 4 bytes per parameter
print("Model size:", model_size_mb, "MB")

# Estimate the size of every tensor in the model's state dictionary in megabytes (MB),
# using the same 4-bytes-per-element assumption.
model_state_dict_size_mb = sum(p.numel() for p in model.state_dict().values()) * 4 / (1024 ** 2)
print("Model state dictionary size:", model_state_dict_size_mb, "MB")

In [None]:
class PointNetSegLoss(torch.nn.Module):
    """Cross-entropy segmentation loss for channel-last per-point predictions."""

    def __init__(self):
        super(PointNetSegLoss, self).__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, pred, target, trans_feat):
        """Compute the mean cross-entropy over all points.

        Args:
            pred: class scores. For inputs with more than two dimensions the
                LAST axis is treated as the class axis, e.g. (batch, points,
                classes) as produced by the segmentation head in this notebook.
            target: either one-hot / class-probability targets with the same
                shape as ``pred``, or integer class indices shaped like ``pred``
                without its last axis.
            trans_feat: unused; kept for interface compatibility.

        Returns:
            Scalar loss tensor.
        """
        # nn.CrossEntropyLoss expects the class axis at dim 1. Passing
        # (batch, points, classes) tensors directly would apply the softmax over
        # the points instead of the classes, so flatten to (batch*points, classes).
        if pred.dim() > 2:
            num_classes = pred.shape[-1]
            if target.shape == pred.shape:
                pred = pred.reshape(-1, num_classes)
                target = target.reshape(-1, num_classes)
            elif target.shape == pred.shape[:-1]:
                pred = pred.reshape(-1, num_classes)
                target = target.reshape(-1)
        loss = self.cross_entropy_loss(pred, target)
        return loss

In [None]:
class PointNetClassLoss(torch.nn.Module):
    """Cross-entropy loss for the classification heads."""

    def __init__(self):
        super().__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, pred, target, trans_feat):
        """Return the cross-entropy between ``pred`` and ``target``; ``trans_feat`` is unused."""
        return self.cross_entropy_loss(pred, target)

In [None]:
class TotalLoss(nn.Module):
    """Sum of the masked segmentation loss and the two classification losses."""

    def __init__(self):
        super(TotalLoss, self).__init__()

        # Losses
        self.seg_loss = PointNetSegLoss()
        self.class_loss = PointNetClassLoss()

    def forward(self, d_hat, d, d_A_feat,l_hat, l, l_A_feat, seg_hat, seg, seg_A_feat, domain_adapation_factor=0.5):
        """Compute the total training loss.

        Args:
            d_hat, d, d_A_feat: domain prediction, domain target and feature.
                ``d`` doubles as the per-sample "has annotations" flag (1/0).
            l_hat, l, l_A_feat: lower/upper prediction, target and feature.
            seg_hat, seg, seg_A_feat: segmentation prediction, target and
                feature; ``seg_hat`` and ``seg`` are indexed along axis 0 per sample.
            domain_adapation_factor: currently unused; kept for compatibility.

        Returns:
            Scalar tensor: segmentation loss averaged over annotated samples
            only, plus both classification losses.
        """
        domain_classification_loss = self.class_loss(d_hat, d, d_A_feat)
        label_classification_loss = self.class_loss(l_hat, l, l_A_feat)

        # Mask the segmentation loss per sample. Multiplying a batch-level mean
        # by `d` only rescales it, so samples without annotations (d == 0, dummy
        # labels) would still contribute to the segmentation loss.
        annotated = d.reshape(-1).to(seg_hat.dtype)
        per_sample_loss = torch.stack([
            self.seg_loss(seg_hat[i:i + 1], seg[i:i + 1], seg_A_feat)
            for i in range(seg_hat.shape[0])
        ])
        # clamp avoids 0/0 when the batch contains no annotated sample.
        segmentation_loss = (annotated * per_sample_loss).sum() / annotated.sum().clamp(min=1.0)

        total_loss = segmentation_loss + domain_classification_loss + label_classification_loss
        return total_loss

In [None]:
def calculate_iou(predictions, targets):
    """Intersection-over-union between predicted and target labels.

    Args:
        predictions: tensor of predicted class indices, or a boolean mask.
        targets: tensor with the same number of elements as ``predictions``.

    Returns:
        Scalar float tensor. For two boolean masks this is the binary IoU.
        Otherwise it is the mean of the per-class IoUs over every class that
        occurs in either tensor. Returns 0.0 when there is nothing to compare
        (empty input, or two all-False masks) instead of NaN.
    """
    predictions = predictions.reshape(-1)
    targets = targets.reshape(-1)

    if predictions.dtype == torch.bool and targets.dtype == torch.bool:
        pred_masks = predictions.unsqueeze(0)
        target_masks = targets.unsqueeze(0)
    else:
        # Bitwise &/| on integer class ids does not measure overlap (1 & 2 == 0),
        # so build one boolean mask per class and compare those instead.
        classes = torch.unique(torch.cat([predictions, targets]))
        pred_masks = predictions.unsqueeze(0) == classes.unsqueeze(1)
        target_masks = targets.unsqueeze(0) == classes.unsqueeze(1)

    intersection = (pred_masks & target_masks).sum(dim=1).float()
    union = (pred_masks | target_masks).sum(dim=1).float()

    valid = union > 0
    if not bool(valid.any()):
        return torch.zeros((), device=predictions.device)
    iou = (intersection[valid] / union[valid]).mean()
    return iou

In [None]:
import gc

In [None]:
# Best average IoU seen so far across calls to test().
highest_iou = 0.0
def test(dataloader, model, loss_function):
    """Evaluate ``model`` on ``dataloader`` and checkpoint the best weights.

    The per-batch IoU (via ``calculate_iou``) is averaged over the number of
    batches. Whenever that average is at least the best value seen so far, the
    global ``highest_iou`` is updated and the state dict is saved to
    /kaggle/working/model_state_highest.pth. ``loss_function`` is unused.
    """
    global highest_iou
    model.eval()
    iou_sum = 0.0
    num_batches = len(dataloader)

    with torch.no_grad():
        for batch, data in enumerate(tqdm(dataloader)):
            pointcloud = data[0].permute(0, 2, 1).to(device)
            label = data[1].to(device)
            d = data[2].to(device)

            # Forward pass; an all-zero tensor is fed as the class-label input.
            blank_labels = torch.zeros_like(label).to(device)
            d_hat, d_trans_feat, l_hat, l_trans_feat, seg_hat, seg_trans_feat = model(pointcloud, blank_labels)
            del blank_labels

            # Most likely class per point (class axis is the last one).
            predicted = torch.argmax(seg_hat, dim=2)

            iou = calculate_iou(predicted, label)
            iou_sum += iou.item()

            # Release per-batch tensors before the next iteration.
            del pointcloud, label, d
            del d_hat, d_trans_feat, l_hat, l_trans_feat, seg_hat, seg_trans_feat
            del iou, predicted
            gc.collect()
            torch.cuda.empty_cache()

    average_iou = iou_sum / num_batches
    if not average_iou < highest_iou:
        highest_iou = average_iou
        torch.save(model.state_dict(), "/kaggle/working/model_state_highest.pth")
    print("Average IoU: " + str(average_iou))
    print("")
    print("Highest IoU: " + str(highest_iou))

In [None]:
def train(dataloader, model, optimizer, loss_function):
    """Run one training epoch and return the SUM of the batch losses.

    Each batch provides ``(pointcloud, label, d, l)``. The labels are one-hot
    encoded with CLASSESS_CNT classes for the segmentation target, and an
    all-zero tensor is fed to the model as its class-label input.
    """
    torch.cuda.empty_cache()
    dataset_size = len(dataloader.dataset)
    model.train()
    loss_tot = 0.0
    for batch, data in enumerate(tqdm(dataloader)):
        pointcloud = data[0].permute(0, 2, 1).to(device)
        label = data[1].to(device)
        d = data[2].to(device)
        l = data[3].to(device)

        # Reset gradients accumulated by the previous step.
        optimizer.zero_grad()

        # One-hot float target for the segmentation loss.
        seg = torch.nn.functional.one_hot(label, CLASSESS_CNT).type(torch.FloatTensor).to(device)

        # Forward pass
        blank_labels = torch.zeros_like(label).to(device)
        d_hat, d_trans_feat, l_hat, l_trans_feat, seg_hat, seg_trans_feat = model(pointcloud, blank_labels)
        del blank_labels

        # Loss, backward pass and parameter update
        loss = loss_function(d_hat, d, d_trans_feat, l_hat, l, l_trans_feat, seg_hat, seg, seg_trans_feat)
        loss.backward()
        optimizer.step()

        loss_tot += loss.item()

        # Release per-batch tensors before the next iteration.
        del pointcloud, label, d, l, seg
        del d_hat, d_trans_feat, l_hat, l_trans_feat, seg_hat, seg_trans_feat
        del loss
        gc.collect()
        torch.cuda.empty_cache()

    # The total is reported as-is; it is not divided by the number of batches.
    print(f'training loss: {(loss_tot):>0.5f}')
    return loss_tot

In [None]:
def visualize_single_test_data(dataloader, model):
    """Plot the model's predicted segmentation for the first sample of the first batch.

    Args:
        dataloader: yields batches whose first three entries are
            ``(pointcloud, label, d)``.
        model: called as ``model(pointcloud, cls_label)``; the fifth returned
            value is read as per-point class scores with the class axis last.
    """
    model.eval()
    with torch.no_grad():
        for batch, data in enumerate(tqdm(dataloader)):
            pointcloud, label, d = (data[0].permute(0, 2, 1)).to(device), data[1].long().to(device), data[2].long().to(device)
            # Feed an all-zero class-label input, exactly as train() and test()
            # do. Passing the ground-truth labels here would hand the model the
            # answer and make the plot inconsistent with the reported IoU.
            d_hat, d_trans_feat,l_hat, l_trans_feat, seg_hat, seg_trans_feat = model(pointcloud, torch.zeros_like(label))
            predicted = torch.argmax(seg_hat, dim=2).cpu()

            # Back to (batch, points, channels) on the CPU for plotting.
            pointcloud_2 = pointcloud.clone().permute(0,2,1).cpu()
            visualize_point_cloud(pointcloud_2[0], predicted[0])
            # Only the first batch is visualised.
            break

In [None]:
# Training Hyperparameters
epochs = 500
batch_size = 10
learning_rate = 1e-4
# NOTE(review): `momentum` is not passed to the Adam optimizer below and is not
# referenced anywhere else in this cell — confirm whether it is still needed.
momentum=0.9
# NOTE(review): 0.5 looks unusually large for an Adam weight decay — confirm it is intentional.
weight_decay=0.5

# Dataloaders: shuffled training batches, fixed-order evaluation batches
training_data_loader = DataLoader(training_dataset, batch_size, shuffle = True)
testing_dataloader = DataLoader(test_dataset, batch_size, shuffle = False)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Loss
loss_function = TotalLoss().to(device)

# Scheduler
# NOTE(review): min_lr (1e-4) equals the initial learning_rate (1e-4), so presumably
# the scheduler can never lower the learning rate — confirm the intended floor.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,'min', factor=0.1, patience=10, cooldown=5, min_lr=1e-4, verbose=True)

# Training: each epoch trains, evaluates (test() saves the best-IoU weights),
# steps the scheduler on the value returned by train() (the summed batch loss),
# and overwrites the latest checkpoint.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = train(training_data_loader, model, optimizer, loss_function)
    test(testing_dataloader, model,loss_function )
    scheduler.step(train_loss)
    torch.save(model.state_dict(), "/kaggle/working/model_state.pth")
print("Done!")


# Visualize Best Predictions

In [None]:
# Restore the weights that test() saved when the average IoU was highest.
model.load_state_dict(torch.load("/kaggle/working/model_state_highest.pth"))

In [None]:
# First plot: the model's predicted segmentation of the test scan.
visualize_single_test_data(testing_dataloader, model)

# Second plot: the same test scan coloured by its ground-truth labels, for comparison.
data = test_dataset[0]
visualize_point_cloud(data[0], data[1])