In [1]:
import os
import urllib.request
import zipfile

# URL of the dataset
url = "http://modelnet.cs.princeton.edu/ModelNet40.zip"

# Directory where the dataset will be downloaded
data_dir = "/kaggle/working/ModelNet40.zip"  # Use a path within Kaggle's writable area

# Download the dataset
print("Downloading dataset...")
urllib.request.urlretrieve(url, data_dir)
print("Download completed.")

# Extract the dataset
print("Extracting dataset...")
with zipfile.ZipFile(data_dir, 'r') as zip_ref:
    zip_ref.extractall(os.path.dirname(data_dir))
print("Extraction completed.")

# Path to the extracted dataset
DATA_DIR = os.path.join(os.path.dirname(data_dir), "ModelNet40")
print(f"Dataset extracted to {DATA_DIR}")


Downloading dataset...
Download completed.
Extracting dataset...
Extraction completed.
Dataset extracted to /kaggle/working/ModelNet40


In [2]:
# Libraries
!pip install trimesh
import torch
import torch.utils.data as data
import glob
import numpy as np
from torch.utils.data import Dataset, DataLoader
import trimesh
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# Parameter setup
NUM_POINTS = 1024
NUM_CLASSES = 40  # ModelNet40
BATCH_SIZE = 64
poly_degree = 4  # Polynomial degree of Jacobi Polynomial
FEATURE = 6  # (x,y,z) + (nx,ny,nz)
ALPHA = 1.0  # \alpha in Jacobi Polynomial
BETA = 1.0  # \beta in Jacobi Polynomial
SCALE = 3.0  # To control the size of tensor A in the manuscript
MAX_EPOCHS = 300
SSL_EPOCHS = 300  # Epochs for self-supervised pre-training
direction = '/kaggle/working/ModelNet40'

# Data Augmentation Functions
def random_rotation(point_cloud):
    """Apply random rotation to the point cloud."""
    theta = np.random.uniform(0, 2 * np.pi)
    rotation_matrix = np.array([
        [np.cos(theta), -np.sin(theta), 0],
        [np.sin(theta), np.cos(theta), 0],
        [0, 0, 1]
    ])
    rotated_points = np.dot(point_cloud[:, :3], rotation_matrix)
    return np.concatenate([rotated_points, point_cloud[:, 3:]], axis=1)

def add_noise(point_cloud, sigma=0.01):
    """Add Gaussian noise to the point cloud."""
    noise = np.random.normal(0, sigma, point_cloud.shape)
    return point_cloud + noise

def jitter_point_cloud(point_cloud, sigma=0.01, clip=0.05):
    """Add jitter to the point cloud."""
    noise = np.clip(sigma * np.random.randn(*point_cloud.shape), -clip, clip)
    return point_cloud + noise

def augment_point_cloud(point_cloud):
    """Apply a series of augmentations to the point cloud."""
    point_cloud = random_rotation(point_cloud)
    point_cloud = add_noise(point_cloud)
    point_cloud = jitter_point_cloud(point_cloud)
    return point_cloud



Collecting trimesh
  Downloading trimesh-4.6.6-py3-none-any.whl.metadata (18 kB)
Downloading trimesh-4.6.6-py3-none-any.whl (709 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m709.3/709.3 kB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: trimesh
Successfully installed trimesh-4.6.6


In [3]:
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        """Compute contrastive loss for positive pairs (z1, z2)."""
        batch_size = z1.size(0)

        # Normalize embeddings
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)

        # Compute similarity matrix
        similarity_matrix = torch.exp(torch.mm(z1, z2.T) / self.temperature)

        # Mask to remove self-similarity
        mask = torch.eye(batch_size, device=z1.device, dtype=torch.bool)
        positives = similarity_matrix[mask].view(batch_size, -1)

        # Compute negatives
        negatives = similarity_matrix[~mask].view(batch_size, -1)

        # Compute contrastive loss
        loss = -torch.log(positives / (positives + torch.sum(negatives, dim=1, keepdim=True)))
        return loss.mean()

In [4]:
# PointNet-KAN with SSL Support
# PointNet-KAN with SSL Support
class PointNetKANSSL(nn.Module):
    def __init__(self, input_channels, output_channels, scaling=SCALE):
        super(PointNetKANSSL, self).__init__()
        self.jacobikan5 = KANshared(input_channels, int(1024 * scaling), poly_degree)
        self.jacobikan6 = KAN(int(1024 * scaling), output_channels, poly_degree)
        self.bn5 = nn.BatchNorm1d(int(1024 * scaling))

        # Projection head for SSL
        self.projection_head = nn.Sequential(
            nn.Linear(output_channels, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

    def forward(self, x, return_projection=False):
        x = self.jacobikan5(x)
        x = self.bn5(x)
        global_feature = F.adaptive_max_pool1d(x, output_size=1).squeeze(-1)
        x = self.jacobikan6(global_feature)

        if return_projection:
            return self.projection_head(x)
        return x

# Dataset with SSL Support
class PointCloudDataset(Dataset):
    def __init__(self, points, labels, ssl=False):
        self.points = points
        self.labels = labels
        self.ssl = ssl  # Flag for self-supervised learning
        self.normalize()

    def normalize(self):
        for i in range(self.points.shape[0]):
            spatial_coords = self.points[i, :, :3]
            normals = self.points[i, :, 3:]
            centroid = spatial_coords.mean(axis=0, keepdims=True)
            spatial_coords -= centroid
            furthest_distance = torch.max(torch.sqrt(torch.sum(spatial_coords ** 2, axis=1, keepdims=True)))
            spatial_coords /= furthest_distance
            self.points[i] = torch.cat((spatial_coords, normals), dim=1)

    def __len__(self):
        return len(self.points)

    def __getitem__(self, idx):
        point = self.points[idx]
        label = self.labels[idx]

        if self.ssl:
            # Return two augmented views for SSL
            point1 = augment_point_cloud(point.cpu().numpy())
            point2 = augment_point_cloud(point.cpu().numpy())
            return torch.tensor(point1, dtype=torch.float32), torch.tensor(point2, dtype=torch.float32), label
        else:
            return point, label

In [5]:
# KANshared and KAN Classes
class KANshared(nn.Module):
    def __init__(self, input_dim, output_dim, degree, a=ALPHA, b=BETA):
        super(KANshared, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.a = a
        self.b = b
        self.degree = degree

        self.jacobi_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
        nn.init.normal_(self.jacobi_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))

    def forward(self, x):
        batch_size, input_dim, num_points = x.shape
        x = x.permute(0, 2, 1).contiguous() 
        x = torch.tanh(x) 

        jacobi = torch.ones(batch_size, num_points, self.input_dim, self.degree + 1, device=x.device)

        if self.degree > 0:
            jacobi[:, :, :, 1] = ((self.a - self.b) + (self.a + self.b + 2) * x) / 2

        for i in range(2, self.degree + 1):
            A = (2*i + self.a + self.b - 1)*(2*i + self.a + self.b)/((2*i) * (i + self.a + self.b))
            B = (2*i + self.a + self.b - 1)*(self.a**2 - self.b**2)/((2*i)*(i + self.a + self.b)*(2*i+self.a+self.b-2))
            C = -2*(i + self.a -1)*(i + self.b -1)*(2*i + self.a + self.b)/((2*i)*(i + self.a + self.b)*(2*i + self.a + self.b -2))
            jacobi[:, :, :, i] = (A*x + B)*jacobi[:, :, :, i-1].clone() + C*jacobi[:, :, :, i-2].clone()

        jacobi = jacobi.permute(0, 2, 3, 1)  
        y = torch.einsum('bids,iod->bos', jacobi, self.jacobi_coeffs) 
        return y

class KAN(nn.Module):
    def __init__(self, input_dim, output_dim, degree, a=ALPHA, b=BETA):
        super(KAN, self).__init__()
        self.inputdim = input_dim
        self.outdim   = output_dim
        self.a        = a
        self.b        = b
        self.degree   = degree

        self.jacobi_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
        nn.init.normal_(self.jacobi_coeffs, mean=0.0, std=1/(input_dim * (degree + 1)))

    def forward(self, x):
        x = torch.reshape(x, (-1, self.inputdim)) 
        x = torch.tanh(x)
        
        jacobi = torch.ones(x.shape[0], self.inputdim, self.degree + 1, device=x.device)
        if self.degree > 0:
            jacobi[:, :, 1] = ((self.a - self.b) + (self.a + self.b + 2) * x) / 2

        for i in range(2, self.degree + 1):
            A = (2*i + self.a + self.b - 1)*(2*i + self.a + self.b)/((2*i) * (i + self.a + self.b))
            B = (2*i + self.a + self.b - 1)*(self.a**2 - self.b**2)/((2*i)*(i + self.a + self.b)*(2*i+self.a+self.b-2))
            C = -2*(i + self.a -1)*(i + self.b -1)*(2*i + self.a + self.b)/((2*i)*(i + self.a + self.b)*(2*i + self.a + self.b -2))
            jacobi[:, :, i] = (A*x + B)*jacobi[:, :, i-1].clone() + C*jacobi[:, :, i-2].clone()

        y = torch.einsum('bid,iod->bo', jacobi, self.jacobi_coeffs) 
        y = y.view(-1, self.outdim)
        return y

In [6]:
# Parse Dataset
def parse_dataset(num_points=NUM_POINTS):
    train_points_with_normals = []
    train_labels = []
    test_points_with_normals = []
    test_labels = []
    class_map = {}

    DATA_DIR = direction
    folders = glob.glob(os.path.join(DATA_DIR, "*"))

    for i, folder in enumerate(folders):
        print("processing class: {}".format(os.path.basename(folder)))
        class_map[i] = folder.split("/")[-1]
        train_files = glob.glob(os.path.join(folder, "train/*"))
        test_files = glob.glob(os.path.join(folder, "test/*"))

        for f in train_files:
            mesh = trimesh.load(f)
            points, face_indices = mesh.sample(num_points, return_index=True)
            normals = mesh.face_normals[face_indices]
            points_with_normals = np.concatenate([points, normals], axis=1)
            train_points_with_normals.append(points_with_normals)
            train_labels.append(i)

        for f in test_files:
            mesh = trimesh.load(f)
            points, face_indices = mesh.sample(num_points, return_index=True)
            normals = mesh.face_normals[face_indices]
            points_with_normals = np.concatenate([points, normals], axis=1)
            test_points_with_normals.append(points_with_normals)
            test_labels.append(i)

    train_points = torch.tensor(np.array(train_points_with_normals), dtype=torch.float32)
    test_points = torch.tensor(np.array(test_points_with_normals), dtype=torch.float32)
    train_labels = torch.tensor(np.array(train_labels), dtype=torch.long)
    test_labels = torch.tensor(np.array(test_labels), dtype=torch.long)

    return train_points, test_points, train_labels, test_labels, class_map

# Load Data
train_points, test_points, train_labels, test_labels, CLASS_MAP = parse_dataset(NUM_POINTS)

# Create Datasets and DataLoaders
train_dataset_ssl = PointCloudDataset(train_points, train_labels, ssl=True)
train_loader_ssl = DataLoader(train_dataset_ssl, batch_size=BATCH_SIZE, shuffle=True)

train_dataset = PointCloudDataset(train_points, train_labels)
test_dataset = PointCloudDataset(test_points, test_labels)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Training Loop for SSL with Fixed Loss
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PointNetKANSSL(input_channels=FEATURE, output_channels=NUM_CLASSES, scaling=SCALE).to(device)
contrastive_loss = ContrastiveLoss(temperature=0.5)
optimizer_ssl = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_ssl, T_max=SSL_EPOCHS)

for epoch in range(SSL_EPOCHS):
    model.train()
    running_loss = 0.0

    for points1, points2, _ in train_loader_ssl:
        points1, points2 = points1.to(device), points2.to(device)
        points1 = points1.transpose(1, 2)
        points2 = points2.transpose(1, 2)

        z1 = model(points1, return_projection=True)
        z2 = model(points2, return_projection=True)

        loss = contrastive_loss(z1, z2)

        optimizer_ssl.zero_grad()
        loss.backward()
        optimizer_ssl.step()

        running_loss += loss.item()

    scheduler.step()
    epoch_loss = running_loss / len(train_loader_ssl)
    print(f"SSL Epoch {epoch + 1}/{SSL_EPOCHS}, Loss: {epoch_loss:.4f}, LR: {optimizer_ssl.param_groups[0]['lr']:.6f}")


processing class: bench
processing class: stairs
processing class: xbox
processing class: sink
processing class: guitar
processing class: night_stand
processing class: laptop
processing class: dresser
processing class: glass_box
processing class: range_hood
processing class: mantel
processing class: bed
processing class: keyboard
processing class: cone
processing class: stool
processing class: cup
processing class: vase
processing class: table
processing class: car
processing class: plant
processing class: door
processing class: airplane
processing class: chair
processing class: desk
processing class: wardrobe
processing class: monitor
processing class: radio
processing class: lamp
processing class: toilet
processing class: bathtub
processing class: bookshelf
processing class: piano
processing class: bowl
processing class: tv_stand
processing class: sofa
processing class: curtain


  stacked = np.column_stack(stacked).round().astype(np.int64)


processing class: flower_pot
processing class: bottle
processing class: tent
processing class: person
SSL Epoch 1/300, Loss: 2.7936, LR: 0.001000
SSL Epoch 2/300, Loss: 2.5263, LR: 0.001000
SSL Epoch 3/300, Loss: 2.4826, LR: 0.001000
SSL Epoch 4/300, Loss: 2.4619, LR: 0.001000
SSL Epoch 5/300, Loss: 2.4735, LR: 0.000999
SSL Epoch 6/300, Loss: 2.4530, LR: 0.000999
SSL Epoch 7/300, Loss: 2.4364, LR: 0.000999
SSL Epoch 8/300, Loss: 2.4675, LR: 0.000998
SSL Epoch 9/300, Loss: 2.4087, LR: 0.000998
SSL Epoch 10/300, Loss: 2.3841, LR: 0.000997
SSL Epoch 11/300, Loss: 2.3951, LR: 0.000997
SSL Epoch 12/300, Loss: 2.3959, LR: 0.000996
SSL Epoch 13/300, Loss: 2.4169, LR: 0.000995
SSL Epoch 14/300, Loss: 2.4026, LR: 0.000995
SSL Epoch 15/300, Loss: 2.3609, LR: 0.000994
SSL Epoch 16/300, Loss: 2.3521, LR: 0.000993
SSL Epoch 17/300, Loss: 2.3517, LR: 0.000992
SSL Epoch 18/300, Loss: 2.3612, LR: 0.000991
SSL Epoch 19/300, Loss: 2.4264, LR: 0.000990
SSL Epoch 20/300, Loss: 2.4240, LR: 0.000989
SSL Epo

In [7]:

# Fine-Tuning for Classification
model.projection_head = nn.Linear(NUM_CLASSES, NUM_CLASSES).to(device)  # Replace projection head
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print("Starting Fine-Tuning for Classification...")
for epoch in range(MAX_EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for points, labels in train_loader:
        points, labels = points.to(device), labels.to(device)
        points = points.transpose(1, 2)

        outputs = model(points)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * points.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_accuracy = 100 * correct / total
    print(f"Epoch {epoch + 1}/{MAX_EPOCHS}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")

# Evaluation
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0

with torch.no_grad():
    for points, labels in test_loader:
        points, labels = points.to(device), labels.to(device)
        points = points.transpose(1, 2)

        outputs = model(points)
        loss = criterion(outputs, labels)

        val_loss += loss.item() * points.size(0)
        _, predicted = torch.max(outputs, 1)
        val_total += labels.size(0)
        val_correct += (predicted == labels).sum().item()

val_loss /= len(test_loader.dataset)
val_accuracy = 100 * val_correct / val_total
print(f"Test Loss: {val_loss:.4f}, Test Accuracy: {val_accuracy:.2f}%")



Starting Fine-Tuning for Classification...
Epoch 1/300, Loss: 152.5357, Accuracy: 28.50%
Epoch 2/300, Loss: 12.4527, Accuracy: 63.13%
Epoch 3/300, Loss: 5.4092, Accuracy: 74.21%
Epoch 4/300, Loss: 3.5992, Accuracy: 78.68%
Epoch 5/300, Loss: 2.7902, Accuracy: 81.60%
Epoch 6/300, Loss: 2.1975, Accuracy: 83.96%
Epoch 7/300, Loss: 1.9012, Accuracy: 84.85%
Epoch 8/300, Loss: 1.6425, Accuracy: 86.05%
Epoch 9/300, Loss: 1.4757, Accuracy: 87.36%
Epoch 10/300, Loss: 1.1452, Accuracy: 88.64%
Epoch 11/300, Loss: 1.0635, Accuracy: 89.33%
Epoch 12/300, Loss: 1.2547, Accuracy: 88.90%
Epoch 13/300, Loss: 1.0798, Accuracy: 89.61%
Epoch 14/300, Loss: 0.8637, Accuracy: 91.18%
Epoch 15/300, Loss: 0.9536, Accuracy: 90.55%
Epoch 16/300, Loss: 0.8049, Accuracy: 91.82%
Epoch 17/300, Loss: 0.6954, Accuracy: 92.25%
Epoch 18/300, Loss: 0.8583, Accuracy: 91.14%
Epoch 19/300, Loss: 0.8304, Accuracy: 91.53%
Epoch 20/300, Loss: 0.7775, Accuracy: 92.17%
Epoch 21/300, Loss: 0.8624, Accuracy: 92.03%
Epoch 22/300, Loss