In [None]:
import os
import h5py
import numpy 
import random
import math
import shutil
from tqdm import tqdm
from path import Path
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.offline as pyo

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Configuration

In [None]:
# Number of points every point cloud must contain (asserted by the dataset classes).
SAMPLE_POINTS = 2000
# Number of segmentation classes (one-hot width and segmentation-head output channels).
CLASSESS_CNT = 4

# Device

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# To force CPU execution, override with: device = "cpu"
print(f"Using {device} device")

# Data Augmentation

In [None]:
class Normalize(object):
    """Center a point cloud on its centroid and scale it into the unit sphere."""

    def __call__(self, pointcloud):
        """Return a normalized copy of ``pointcloud``.

        Args:
            pointcloud: 2-D array with one point per row.

        Returns:
            A new array whose centroid is the origin and whose farthest point
            has norm 1. The input array is not modified.
        """
        centered = pointcloud - numpy.mean(pointcloud, axis=0)
        radius = numpy.max(numpy.linalg.norm(centered, axis=1))
        # A degenerate cloud (all points identical) has radius 0; dividing
        # would fill the result with NaN, so leave the centered zeros as-is.
        if radius > 0:
            centered = centered / radius
        return centered
    
class RandomNoise(object):
    """Augmentation that jitters every coordinate with Gaussian noise."""

    def __call__(self, pointcloud):
        """Return ``pointcloud`` plus zero-mean noise with standard deviation 0.02."""
        jitter = numpy.random.normal(loc=0, scale=0.02, size=pointcloud.shape)
        return pointcloud + jitter
    
class RandomScale(object):
    """Augmentation that stretches each axis by an independent random factor."""

    def __call__(self, pointcloud):
        """Scale x, y and z by factors drawn uniformly from [0.9, 1.1)."""
        factors = numpy.random.uniform(0.9, 1.1, 3)
        # Diagonal matrix: each axis is scaled independently, no rotation/shear.
        scale_matrix = numpy.diag(factors)
        return numpy.matmul(pointcloud, scale_matrix)

In [None]:
def default_transforms():
    """Build the evaluation pipeline: normalize the cloud, then convert it to a tensor."""
    steps = [Normalize(), transforms.ToTensor()]
    return transforms.Compose(steps)

def default_transforms_no_normalize():
    """Build the evaluation pipeline without normalization: tensor conversion only."""
    steps = [transforms.ToTensor()]
    return transforms.Compose(steps)

def training_transforms():
    """Build the training pipeline: normalize, add noise, rescale, convert to tensor."""
    steps = [
        Normalize(),
        RandomNoise(),
        RandomScale(),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

def training_transforms_no_normalize():
    """Build the training pipeline without normalization: noise, rescale, tensor conversion."""
    steps = [
        RandomNoise(),
        RandomScale(),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

# Visualize Part Instance Utility Function

In [None]:
def visualize_point_cloud(point_cloud, labels, generative=False):
    """Render a labelled 3-D point cloud as an interactive Plotly scatter plot.

    Args:
        point_cloud: 2-D indexable; columns 0, 1 and 2 are plotted as x, y, z.
        labels: per-point values handed to Plotly as marker colours
            (mapped through the 'Viridis' colour scale).
        generative: accepted for backward compatibility only. It never
            affected the plot: the fixed-axis-range layout it might have
            selected sat in an unreachable ``else`` of ``if True:`` and has
            been removed, so the axes always auto-scale.
    """
    trace = go.Scatter3d(
        x=point_cloud[:, 0],
        y=point_cloud[:, 1],
        z=point_cloud[:, 2],
        mode='markers',
        marker=dict(size=5, color=labels, colorscale='Viridis', opacity=0.5),
    )

    layout = go.Layout(
        scene=dict(
            xaxis=dict(title='X'),
            yaxis=dict(title='Y'),
            zaxis=dict(title='Z'),
        )
    )

    fig = go.Figure(data=[trace], layout=layout)
    pyo.init_notebook_mode(connected=True)
    pyo.iplot(fig)

# Dataset

In [None]:
class GenerativeJawDataset(Dataset):
    """Annotated synthetic jaw point clouds loaded from a single ``.npy`` file.

    Each sample is indexed as a 2-D array: columns 0-2 are x, y, z and
    column 3 is the per-point label.
    """

    def __init__(self, npy_file_path="/kaggle/input/generative-jaw/generative.npy", transform=training_transforms()):
        self.transform = transform
        self.data = numpy.load(npy_file_path)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(pointcloud, label, d, lower)`` for sample ``idx``.

        ``d`` is always 1 (the sample has annotations); ``lower`` is 1 for
        indices below 10 and 0 otherwise.
        """
        sample = self.data[idx]
        xyz = numpy.column_stack([sample[:, axis] for axis in range(3)])
        point_labels = sample[:, 3]

        # The transform output carries a leading dimension; drop it.
        pointcloud = self.transform(xyz)[0]

        has_annotations = 1
        is_lower = 0 if idx >= 10 else 1

        assert pointcloud.shape[0] == SAMPLE_POINTS
        assert pointcloud.shape[1] == 3

        torch.cuda.empty_cache()

        return (
            pointcloud.type(torch.FloatTensor),
            torch.tensor(point_labels).type(torch.LongTensor),
            torch.tensor(has_annotations).type(torch.LongTensor),
            torch.tensor(is_lower).type(torch.LongTensor),
        )

In [None]:
class TestJawDataset(Dataset):
    """Single annotated real jaw scan used for evaluation.

    The ``.npy`` file is indexed as one 2-D array: columns 0-2 are the
    x, y, z coordinates and column 3 is the per-point class label.
    """

    def __init__(self, npy_file_path="/kaggle/input/real-jaw-2000-pointcloud-annotated-classbased/test_jaw.npy", transform=default_transforms()):
        self.data = numpy.load(npy_file_path)
        self.transform = transform

    def __len__(self):
        return 1 # TODO: Edit this later if we have more annotations

    def __getitem__(self, idx):
        """Return ``(pointcloud, label, d)`` for the single test scan.

        The cloud and its labels are truncated or padded (by repeating the
        last point / last label) to exactly ``SAMPLE_POINTS`` entries.
        ``d`` is always 1 because the scan is annotated.

        Raises:
            IndexError: if ``idx`` is outside the dataset, so that plain
                iteration over the dataset terminates.
        """
        if not -len(self) <= idx < len(self):
            raise IndexError(f"TestJawDataset index {idx} out of range for length {len(self)}")

        itemdata = (self.data).astype(numpy.float32) # TODO: Edit this later if we have more annotations
        pointcloud = itemdata[:, :3]  # Extract x, y, z columns
        # Convert labels to a tensor up front so padding below can use torch ops.
        label = torch.tensor(itemdata[:, 3]).type(torch.LongTensor)
        d = 1  # d = 1, as it has annotations
        pointcloud = self.transform(pointcloud)
        pointcloud = pointcloud[0]

        # Truncate to the configured size (no-op when already short enough).
        pointcloud = pointcloud[:SAMPLE_POINTS]
        label = label[:SAMPLE_POINTS]

        # Pad short clouds by repeating the last point and its label.
        missing = SAMPLE_POINTS - pointcloud.shape[0]
        if missing > 0:
            pointcloud = torch.cat([pointcloud, pointcloud[-1].repeat(missing, 1)], dim=0)
            label = torch.cat([label, label[-1].repeat(missing)], dim=0)

        assert pointcloud.shape[0] == SAMPLE_POINTS
        assert pointcloud.shape[1] == 3

        del itemdata
        torch.cuda.empty_cache()

        return pointcloud.type(torch.FloatTensor), label, torch.tensor(d).type(torch.LongTensor)

In [None]:
class RealJawDataset(Dataset):
    """Unannotated real jaw scans stored as space-delimited XYZ text files."""

    def __init__(self, data_dir="/kaggle/input/real-jaw-2000-pointcloud/RealJaw2000", transform=training_transforms()):
        self.data_dir = data_dir
        # Sort BEFORE slicing: os.listdir returns entries in arbitrary order,
        # so slicing first would pick a non-deterministic subset of files.
        self.file_list = sorted(os.listdir(data_dir))[:20]
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def load_xyz_file(self, file_path):
        """Read a space-delimited numeric text file into a numpy array."""
        with open(file_path, 'r') as f:
            data = numpy.loadtxt(f, delimiter=' ')
        return data

    def __getitem__(self, idx):
        """Return ``(pointcloud, label, d, lower)`` for file ``idx``.

        ``label`` is an all-zero placeholder and ``d`` is 0 because these
        scans have no annotations. ``lower`` is 1 for even indices and 0 for
        odd ones. The cloud is truncated or padded (repeating the last
        point) to exactly ``SAMPLE_POINTS`` points.
        """
        file_path = os.path.join(self.data_dir, self.file_list[idx])

        data = self.load_xyz_file(file_path)

        # The first three columns of the file are the x, y, z coordinates.
        pointcloud = numpy.column_stack((data[:,0], data[:,1], data[:,2])).astype(numpy.float32)

        # Create a fixed label (doesnt matter) and d (0)
        label = torch.zeros(SAMPLE_POINTS, dtype=torch.int64)
        d = torch.tensor(0, dtype=torch.int64)

        lower = (idx%2 == 0)

        pointcloud = self.transform(pointcloud)
        pointcloud = pointcloud[0]

        # Truncate to the configured size (no-op when already short enough).
        pointcloud = pointcloud[:SAMPLE_POINTS]

        # Pad short clouds by repeating the last point.
        missing = SAMPLE_POINTS - pointcloud.shape[0]
        if missing > 0:
            pointcloud = torch.cat([pointcloud, pointcloud[-1].repeat(missing, 1)], dim=0)

        assert pointcloud.shape[0] == SAMPLE_POINTS
        assert pointcloud.shape[1] == 3

        del file_path
        del data
        torch.cuda.empty_cache()

        return pointcloud.type(torch.FloatTensor),label, d,torch.tensor(lower).type(torch.LongTensor)

In [None]:
class TrainingDataset(Dataset):
    """Pairs each generative (annotated) sample with a real (unannotated) one.

    Even indices pair real sample ``idx`` with generative sample ``idx // 2``;
    odd indices pair it with generative sample ``(idx - 1) // 2 + 10``.
    """

    # Offset added for odd indices; GenerativeJawDataset flags samples at
    # index >= 10 with lower = 0.
    _ODD_INDEX_OFFSET = 10

    def __init__(self):
        self.generative_dataset = GenerativeJawDataset()
        self.real_dataset = RealJawDataset()

    def __len__(self):
        return (len(self.generative_dataset) + len(self.real_dataset))//2

    def __getitem__(self, idx):
        """Return ``(generative_sample, real_sample)`` for pair ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))`` or one of the
                underlying datasets has no sample for the derived index. The
                error is propagated rather than replaced by a placeholder,
                because a placeholder cannot be batched with real samples and
                would prevent plain iteration from terminating.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"TrainingDataset index {idx} out of range for length {len(self)}")

        if idx % 2 == 0:
            generative_idx = idx // 2
        else:
            generative_idx = (idx - 1) // 2 + self._ODD_INDEX_OFFSET

        return self.generative_dataset[generative_idx], self.real_dataset[idx]
        

In [None]:
# Instantiate each dataset once for inspection and for the data loaders below.
# Note: TrainingDataset builds its own GenerativeJawDataset and RealJawDataset internally.
generative_dataset = GenerativeJawDataset()
test_dataset = TestJawDataset()
real_dataset = RealJawDataset()
training_dataset = TrainingDataset()

In [None]:
# Number of (generative, real) pairs available for training.
len(training_dataset)

In [None]:
# Sanity check: for every training pair, print the jaw flag (element 3 of each
# sample; 1 is reported as "lower", anything else as "upper") of both members.
for i in range(len(training_dataset)):
    gen, re = training_dataset[i]
    print("generative " + str("lower" if gen[3] == 1 else "upper") + " + real " + str("lower" if re[3] == 1 else "upper"))

# Visualize Real Data

In [None]:
# Plot real sample 1; its labels are the all-zero placeholder, so all points share one colour.
data = real_dataset[1]
visualize_point_cloud(data[0], data[1])

# Visualize Real Annotated Test Data

In [None]:
# Plot the annotated test scan, coloured by its ground-truth labels.
data = test_dataset[0]

visualize_point_cloud(data[0], data[1])

# Visualize Generative Jaw Data With Annotations

In [None]:
# Plot generative sample 16, coloured by its annotation labels.
data = generative_dataset[16]
visualize_point_cloud(data[0], data[1], generative=True)

# Model: PointNet

In [None]:
# Fetch the reference PointNet implementation.
# NOTE(review): no visible cell imports from this checkout — confirm the clone is still needed.
!git clone https://github.com/yanx27/Pointnet_Pointnet2_pytorch.git

In [None]:
class GradReverse(torch.autograd.Function):
    """Gradient reversal: identity in the forward pass, negated gradient backward."""

    @staticmethod
    def forward(ctx, x):
        # view_as yields a new tensor object with the same values.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output


def GRL(x):
    """Apply the gradient reversal layer to ``x``."""
    return GradReverse.apply(x)

In [None]:
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F


class STN3d(nn.Module):
    """Spatial transformer that predicts a 3x3 transform per input cloud.

    Expects input of shape (batch, channel, num_points) and returns a tensor
    of shape (batch, 3, 3): the network output plus the identity matrix.
    """

    def __init__(self, channel):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Max-pool over the points axis to obtain one global feature per cloud.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Add the flattened identity so the predicted transform starts near
        # identity. It is created directly on x's device (works for any
        # device index, not only the default GPU) and broadcasts over the batch.
        iden = torch.eye(3, device=x.device).flatten()
        x = x + iden
        x = x.view(-1, 3, 3)
        return x


class STNkd(nn.Module):
    """Feature-space transformer that predicts a k x k transform per input.

    Expects input of shape (batch, k, num_points) and returns a tensor of
    shape (batch, k, k): the network output plus the identity matrix.
    """

    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Max-pool over the points axis to obtain one global feature per input.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Add the flattened identity so the predicted transform starts near
        # identity. It is created directly on x's device (works for any
        # device index, not only the default GPU) and broadcasts over the batch.
        iden = torch.eye(self.k, device=x.device).flatten()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x

def feature_transform_reguliarzer(trans):
    """Orthogonality penalty for a batch of square transforms.

    Args:
        trans: tensor of shape (batch, d, d).

    Returns:
        Scalar tensor: the batch mean of the Frobenius norm of
        ``trans @ trans^T - I``; zero when every matrix is orthogonal.
    """
    d = trans.size()[1]
    # Build the identity on the same device as the input instead of assuming
    # the default CUDA device.
    identity = torch.eye(d, device=trans.device)[None, :, :]
    gram = torch.bmm(trans, trans.transpose(2, 1))
    loss = torch.mean(torch.norm(gram - identity, dim=(1, 2)))
    return loss

In [None]:
class PointNetBackbone(nn.Module):
    """PointNet feature extractor with input and feature transforms.

    ``forward`` takes a (batch, 3, num_points) tensor and returns
    ``(global_feature, combined_feature, trans, trans_feat)`` where the
    combined feature concatenates the tiled 1024-d global feature with the
    64-d per-point features along the channel axis.
    """

    def __init__(self):
        super().__init__()
        self.stn = STN3d(3)
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.fstn = STNkd(k=64)

    def forward(self, x):
        _, _, num_points = x.size()

        # Align the raw coordinates with the predicted 3x3 transform.
        trans = self.stn(x)
        x = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        # Align the 64-d per-point features with the predicted 64x64 transform.
        trans_feat = self.fstn(x)
        pointfeat = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1)

        x = F.relu(self.bn2(self.conv2(pointfeat)))
        x = self.bn3(self.conv3(x))

        # Global feature: max over the points axis.
        global_feat = torch.max(x, 2, keepdim=True)[0].view(-1, 1024)
        tiled_global = global_feat.view(-1, 1024, 1).repeat(1, 1, num_points)
        combined = torch.cat([tiled_global, pointfeat], 1)
        return global_feat, combined, trans, trans_feat


In [None]:
class PointNetSegHead(nn.Module):
    """Per-point segmentation head on top of the combined PointNet feature.

    Args:
        backbone: stored as ``self.feat``; not used by ``forward``.
        part_num: number of output classes per point.
        grl: accepted but unused.
    """

    def __init__(self, backbone=None, part_num=4, grl=None):
        super(PointNetSegHead, self).__init__()

        self.feat = backbone

        self.head = nn.Sequential(
            nn.Conv1d(1088, 512, kernel_size=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 256, kernel_size=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, 128, kernel_size=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, part_num, kernel_size=1)
        )

    def forward(self, combined_feature):
        """Map (batch, 1088, num_points) features to per-point class log-probabilities.

        Returns:
            Tensor of shape (batch, num_points, part_num) whose last axis is a
            log-probability distribution over the classes.
        """
        x = self.head(combined_feature)   # (batch, part_num, num_points)
        x = x.transpose(2, 1)             # (batch, num_points, part_num)
        # Normalize over the class axis (last), not over the points axis.
        x = F.log_softmax(x, dim=-1)
        return x

In [None]:
class PointNetDomainAdaptationModel(nn.Module):
    """PointNet backbone plus segmentation head, with an optional second input.

    ``forward(x1)`` returns ``(seg_hat, trans, seg_trans_feat)``.
    ``forward(x1, x2)`` returns ``(seg_hat, seg_trans_feat, x1_global, x2_global)``,
    where the global features come from running the shared backbone on each input.
    """

    def __init__(self, m=CLASSESS_CNT):
        super().__init__()

        # Number of segmentation classes.
        self.m = m

        self.backbone = PointNetBackbone()
        self.segmentation_head = PointNetSegHead(part_num=self.m)

    def forward(self, x1, x2=None):
        x1_global, x_combined, trans, seg_trans_feat = self.backbone(x1)
        seg_hat = self.segmentation_head(x_combined)

        # Single-input mode: segmentation output and both transforms.
        if x2 is None:
            return seg_hat, trans, seg_trans_feat

        # Two-input mode: only the global feature of the second cloud is kept.
        x2_global, _, _, _ = self.backbone(x2)
        return seg_hat, seg_trans_feat, x1_global, x2_global

# Finally

In [None]:
# Build the model and move its parameters and buffers to the selected device.
model = PointNetDomainAdaptationModel(m=CLASSESS_CNT).to(device)

In [None]:
# Count the number of trainable parameters
num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of trainable parameters:", num_parameters)

# Calculate the model size in megabytes (MB)
model_size_mb = sum(p.numel() for p in model.parameters()) * 4 / (1024 ** 2)  # assuming 4 bytes per parameter
print("Model size:", model_size_mb, "MB")

# Calculate the size of the model's state dictionary in megabytes (MB)
# state_dict() also holds buffers (e.g. BatchNorm running statistics), so this
# figure can exceed the parameter-only size; 4 bytes per element is assumed here too.
model_state_dict_size_mb = sum(p.numel() for p in model.state_dict().values()) * 4 / (1024 ** 2)
print("Model state dictionary size:", model_state_dict_size_mb, "MB")

In [None]:
class PointNetSegLoss(torch.nn.Module):
    """Segmentation cross-entropy plus the feature-transform orthogonality penalty."""

    def __init__(self, mat_diff_loss_scale=0.001):
        super(PointNetSegLoss, self).__init__()
        # Weight of the orthogonality regularizer in the total loss.
        self.mat_diff_loss_scale = mat_diff_loss_scale

        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, pred, target, trans_feat):
        """Compute the segmentation loss.

        Args:
            pred: per-point class scores, shape (batch, num_points, num_classes);
                a 2-D (batch, num_classes) input is passed through unchanged.
            target: either one-hot / probability targets with the same shape
                as ``pred``, or integer class indices of shape
                (batch, num_points).
            trans_feat: (batch, d, d) feature transforms to regularize.

        Returns:
            Scalar tensor: cross-entropy + ``mat_diff_loss_scale`` * regularizer.
        """
        if pred.dim() == 3:
            # nn.CrossEntropyLoss expects the class axis at dim 1, so move it
            # there; without this the points axis is treated as the classes.
            logits = pred.transpose(1, 2)
            if target.dim() == 3:
                target = target.transpose(1, 2)
        else:
            logits = pred

        loss = self.cross_entropy_loss(logits, target)
        mat_diff_loss = feature_transform_reguliarzer(trans_feat)
        total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale
        return total_loss

In [None]:
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss on the Euclidean distance of two embeddings."""

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, target=1):
        """Return the mean contrastive loss.

        With ``target == 0`` the squared distance is penalized; with
        ``target == 1`` the squared shortfall below ``margin`` is penalized.
        """
        distance = F.pairwise_distance(output1, output2, keepdim=True)
        squared_distance = torch.pow(distance, 2)
        squared_shortfall = torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
        per_pair = (1 - target) * squared_distance + (target) * squared_shortfall
        return torch.mean(per_pair)

In [None]:
class TotalLoss(nn.Module):
    """Training objective: currently the segmentation loss only.

    The contrastive (domain-adaptation) term is switched off; ``con_loss`` is
    kept as an attribute so it can be re-enabled in ``forward``.
    """

    def __init__(self, mat_diff_loss_scale=0.001):
        super(TotalLoss, self).__init__()
        self.mat_diff_loss_scale = mat_diff_loss_scale

        # Lossess
        # Forward the scale so the constructor argument actually takes effect.
        self.seg_loss = PointNetSegLoss(mat_diff_loss_scale=mat_diff_loss_scale)
        self.con_loss = ContrastiveLoss()

    def forward(self, seg_hat, seg, seg_trans_feat, x1_global, x2_global):
        """Return the scalar training loss.

        ``x1_global`` and ``x2_global`` are accepted for interface stability
        but unused while domain adaptation is off. To turn it back on, add
        ``self.con_loss(x1_global, x2_global)`` to the segmentation loss.
        """
        segmentation_loss = self.seg_loss(seg_hat, seg, seg_trans_feat)
        total_loss = segmentation_loss  #### D.A (contrastive term) is OFF
        total_loss = total_loss.mean()
        return total_loss

In [None]:
def calculate_iou(predictions, targets, num_classes=None):
    """Intersection-over-union between predictions and targets.

    Args:
        predictions: tensor of predicted class ids, or a boolean mask.
        targets: tensor of ground-truth class ids (same shape), or a boolean mask.
        num_classes: optional number of classes ``0..num_classes-1`` to
            evaluate. When omitted, the classes present in either tensor are used.

    Returns:
        0-dim float tensor. For two boolean masks this is the plain binary
        IoU. For class-id tensors it is the mean of the per-class IoUs;
        classes absent from both tensors are skipped. An empty union (nothing
        to compare) yields 0 rather than NaN.
    """
    zero = torch.zeros((), device=predictions.device)

    # Boolean masks: bitwise AND / OR are the set intersection / union.
    if predictions.dtype == torch.bool and targets.dtype == torch.bool:
        union = (predictions | targets).sum()
        if union == 0:
            return zero
        return (predictions & targets).sum() / union

    # Class ids: bitwise ops on the ids themselves are meaningless, so build
    # one boolean mask per class and average the per-class IoUs.
    if num_classes is None:
        classes = sorted(set(predictions.unique().tolist()) | set(targets.unique().tolist()))
    else:
        classes = range(num_classes)

    per_class = []
    for cls in classes:
        pred_mask = predictions == cls
        target_mask = targets == cls
        union = (pred_mask | target_mask).sum()
        if union == 0:
            continue
        per_class.append((pred_mask & target_mask).sum().float() / union)

    if not per_class:
        return zero
    return torch.stack(per_class).mean()

In [None]:
import gc

In [None]:
highest_iou = 0.0
def test(dataloader, model, loss_function):
    """Evaluate ``model`` on ``dataloader`` and checkpoint the best result.

    Averages the per-batch score from ``calculate_iou`` over all batches.
    Whenever that average matches or beats the module-level ``highest_iou``,
    the record is updated and the weights are saved. ``loss_function`` is
    accepted but not used.
    """
    global highest_iou
    model.eval()
    iou_sum = 0.0
    batch_count = len(dataloader)

    with torch.no_grad():
        for _, batch_data in enumerate(tqdm(dataloader)):
            # Clouds arrive as (batch, points, 3); the model wants (batch, 3, points).
            pointcloud = batch_data[0].permute(0, 2, 1).to(device)
            label = batch_data[1].to(device)
            d = batch_data[2].to(device)

            # Single-input forward pass; pick the best class per point.
            seg_hat, seg_trans, seg_trans_feat = model(pointcloud)
            predicted = torch.argmax(seg_hat, dim=2)

            iou = calculate_iou(predicted, label)
            iou_sum += iou.item()

            # Drop references so the cache flush below can release memory.
            del pointcloud, label, seg_hat, seg_trans, seg_trans_feat, iou, predicted
            gc.collect()
            torch.cuda.empty_cache()

    average_iou = iou_sum / batch_count
    if average_iou >= highest_iou:
        highest_iou = average_iou
        torch.save(model.state_dict(), "/kaggle/working/model_state_highest.pth")
    print("Average IoU: " + str(average_iou))
    print("")
    print("Highest IoU: " + str(highest_iou))

In [None]:
def train(dataloader, model, optimizer, loss_function):
    """Run one training epoch and return the summed batch loss.

    Each batch is a pair ``(generative_batch, real_batch)``. Only the
    generative labels are used for supervision; the real point clouds are fed
    to the model as its second input.

    Returns:
        The sum (not the mean) of ``loss.item()`` over all batches.
    """
    torch.cuda.empty_cache()
    model.train()
    loss_tot = 0.0
    for data0, data1 in tqdm(dataloader):
        # Clouds arrive as (batch, points, 3); the model wants (batch, 3, points).
        pointcloud0 = data0[0].permute(0, 2, 1).to(device)
        label0 = data0[1].to(device)
        pointcloud1 = data1[0].permute(0, 2, 1).to(device)

        # Zeroing the gradients
        optimizer.zero_grad()

        # One-hot targets, cast in place on the device: .float() avoids the
        # device -> CPU -> device round trip of .type(torch.FloatTensor).to(device).
        seg0 = torch.nn.functional.one_hot(label0, CLASSESS_CNT).float()

        # Get Scores
        seg_hat, seg_trans_feat, x1_global, x2_global = model(pointcloud0, pointcloud1)

        # Calculate Loss
        loss = loss_function(seg_hat, seg0, seg_trans_feat, x1_global, x2_global)

        # Backpropagation
        loss.backward()

        # Update
        optimizer.step()

        loss_tot += loss.item()

        # Drop references so the cache flush below can release memory.
        del data0, data1, pointcloud0, pointcloud1, label0
        del seg0, seg_hat, seg_trans_feat, x1_global, x2_global, loss
        gc.collect()
        torch.cuda.empty_cache()

    print(f'training loss: {(loss_tot):>0.5f}')
    return loss_tot

In [None]:
def visualize_single_test_data(dataloader, model):
    """Plot the model's predicted labels for the first sample of the first batch."""
    model.eval()
    with torch.no_grad():
        for _, batch_data in enumerate(tqdm(dataloader)):
            # Clouds arrive as (batch, points, 3); the model wants (batch, 3, points).
            pointcloud = batch_data[0].permute(0, 2, 1).to(device)
            label = batch_data[1].long().to(device)
            d = batch_data[2].long().to(device)

            # Single-input forward pass; pick the best class per point.
            seg_hat, _trans, _trans_feat = model(pointcloud)
            predicted = torch.argmax(seg_hat, dim=2).cpu()

            # Back to (batch, points, 3) on the CPU for plotting.
            points_cpu = pointcloud.clone().permute(0, 2, 1).cpu()
            visualize_point_cloud(points_cpu[0], predicted[0])

            # Only the first batch is visualized.
            break

In [None]:
# Training Hyperparameters
epochs = 500
batch_size = 50
learning_rate = 1e-3
# NOTE(review): momentum is defined but never passed to the Adam optimizer below.
momentum=0.9
# NOTE(review): 0.5 is an unusually large weight decay for Adam — confirm it is intended.
weight_decay=0.5

# Dataloader
training_data_loader = DataLoader(training_dataset, batch_size, shuffle = True)
testing_dataloader = DataLoader(test_dataset, batch_size, shuffle = False)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Loss 
loss_function = TotalLoss(mat_diff_loss_scale=0.001).to(device)

# Scheduler
# Cuts the learning rate by 10x (down to min_lr) after 10 epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=1e-5, eps=1e-08)

# Training
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = train(training_data_loader, model, optimizer, loss_function)
    # test() also saves the best-scoring weights to model_state_highest.pth.
    test(testing_dataloader, model,loss_function )
    # The scheduler monitors the summed training loss returned by train().
    scheduler.step(train_loss)
    # Always keep the latest weights, independent of the best checkpoint.
    torch.save(model.state_dict(), "/kaggle/working/model_state.pth")
print("Done!")


# Visualize Best Predictions

In [None]:
# Restore the weights saved by test() for the best average IoU seen during training.
model.load_state_dict(torch.load("/kaggle/working/model_state_highest.pth"))

In [None]:
# Plot the restored model's predictions for the first test sample.
visualize_single_test_data(testing_dataloader, model)

In [None]:
# Plot the same test scan with its ground-truth labels for comparison.
data = test_dataset[0]
visualize_point_cloud(data[0], data[1])