In [None]:
# Vanilla network architecture: per-point shared MLP only (no T-Net, no global feature)
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplePointNet(nn.Module):
    """Per-point classifier made only of shared (kernel-size-1) 1-D convolutions.

    ``mlp1`` lifts the 3 input channels to 1024 features and ``mlp2`` maps
    those features to ``num_classes`` scores. Because every convolution has
    kernel size 1, the same weights are applied to each point; no pooling
    across points is performed in this variant.

    Args:
        num_classes: Number of output channels (scores) produced per point.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Shared feature extractor: 3 -> 64 -> 128 -> 1024.
        # Each stage is Conv1d(kernel=1) + BatchNorm1d + ReLU, in that order.
        extractor_layers = []
        for c_in, c_out in ((3, 64), (64, 128), (128, 1024)):
            extractor_layers.extend(
                (nn.Conv1d(c_in, c_out, 1), nn.BatchNorm1d(c_out), nn.ReLU())
            )
        self.mlp1 = nn.Sequential(*extractor_layers)

        # Per-point classification head: 1024 -> 512 -> 256 -> num_classes.
        # The last convolution emits raw scores (no normalisation/activation).
        head_layers = []
        for c_in, c_out in ((1024, 512), (512, 256)):
            head_layers.extend(
                (nn.Conv1d(c_in, c_out, 1), nn.BatchNorm1d(c_out), nn.ReLU())
            )
        head_layers.append(nn.Conv1d(256, num_classes, 1))
        self.mlp2 = nn.Sequential(*head_layers)

    def forward(self, x):
        """Score every point.

        Args:
            x: Tensor of shape (batch_size, 3, num_points).

        Returns:
            Tensor of shape (batch_size, num_classes, num_points).
        """
        per_point_features = self.mlp1(x)    # (batch_size, 1024, num_points)
        return self.mlp2(per_point_features)  # (batch_size, num_classes, num_points)



In [None]:
# Simple architecture with an input T-Net applied before the shared MLP

class TNet(nn.Module):
    """Transformation network: predicts one k x k matrix per batch element.

    The input is passed through three shared (kernel-size-1) convolutions,
    max-pooled over the point dimension, and then through three fully
    connected layers whose last output is reshaped to (k, k).

    The final layer is initialised with zero weights and a flattened identity
    matrix as bias, so a freshly constructed network outputs the identity
    transform for every input.

    Note: in training mode ``bn4``/``bn5`` receive a 2-D (batch, channels)
    tensor, for which PyTorch's BatchNorm1d needs more than one sample per
    batch.

    Args:
        k: Number of input channels and side length of the predicted matrix.
    """

    def __init__(self, k=3):
        super().__init__()
        self.k = k

        # Shared per-point feature extractor: k -> 64 -> 128 -> 1024.
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn3 = nn.BatchNorm1d(1024)

        # Regressor on the pooled feature: 1024 -> 512 -> 256 -> k*k.
        self.fc1 = nn.Linear(1024, 512)
        self.bn4 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, k * k)

        # Start from the identity transform. Initialise in place under
        # no_grad instead of going through the legacy `.data` attribute.
        nn.init.zeros_(self.fc3.weight)
        with torch.no_grad():
            self.fc3.bias.copy_(torch.eye(k).view(-1))

    def forward(self, x):
        """Predict the transform.

        Args:
            x: Tensor of shape (B, k, N).

        Returns:
            Tensor of shape (B, k, k).
        """
        batch_size = x.size(0)
        x = F.relu(self.bn1(self.conv1(x)))   # (B, 64, N)
        x = F.relu(self.bn2(self.conv2(x)))   # (B, 128, N)
        x = F.relu(self.bn3(self.conv3(x)))   # (B, 1024, N)
        x = torch.max(x, dim=2).values        # max over points -> (B, 1024)
        x = F.relu(self.bn4(self.fc1(x)))     # (B, 512)
        x = F.relu(self.bn5(self.fc2(x)))     # (B, 256)
        x = self.fc3(x)                       # (B, k*k)
        return x.view(batch_size, self.k, self.k)

class SimplePointNet(nn.Module):
    """Per-point classifier with a learned 3x3 input transform.

    A ``TNet(k=3)`` predicts one 3x3 matrix per batch element; the input is
    multiplied by that matrix and then scored by two stacks of shared
    (kernel-size-1) convolutions: ``mlp1`` (3 -> 1024 features) and ``mlp2``
    (1024 -> ``num_classes`` scores). No pooling across points is performed
    in this class itself.

    Args:
        num_classes: Number of output channels (scores) produced per point.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Input T-Net
        self.input_transform = TNet(k=3)

        # Shared MLP for feature extraction
        self.mlp1 = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )

        # Shared layers for per-point classification
        self.mlp2 = nn.Sequential(
            nn.Conv1d(1024, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, num_classes, 1)  # Output per point
        )

    def forward(self, x):
        """Score every point.

        Args:
            x: Tensor of shape (batch_size, 3, num_points).

        Returns:
            Tensor of shape (batch_size, num_classes, num_points).

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        # Explicit rank check (replaces an unpacking of x.shape whose names
        # were never used); a wrong rank still raises ValueError.
        if x.ndim != 3:
            raise ValueError(
                "SimplePointNet expects input of shape "
                f"(batch_size, 3, num_points), got {tuple(x.shape)}"
            )

        # Apply input transform
        trans = self.input_transform(x)  # (B, 3, 3)
        x = torch.bmm(trans, x)          # (B, 3, 3) @ (B, 3, N) -> (B, 3, N)

        x = self.mlp1(x)                 # (B, 1024, N)
        return self.mlp2(x)              # (B, num_classes, N)


In [1]:
# adding global feature to each point
# adding dropout after concatenating global feature to each point
class PointNet(nn.Module):
    """Per-point classifier that combines local and global features.

    Three shared (kernel-size-1) convolutions lift each point to 64, 128 and
    1024 channels. The 1024-channel features are max-pooled over the point
    dimension into one global feature per batch element, which is broadcast
    back to every point and concatenated with the saved 64-channel per-point
    features (64 + 1024 = 1088 channels). Dropout (p=0.3) is applied to the
    concatenation before ``mlp2`` maps it to ``num_classes`` scores.

    Args:
        num_classes: Number of output channels (scores) produced per point.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Shared MLP for feature extraction (split to capture intermediate features)
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.bn1 = nn.BatchNorm1d(64)

        self.conv2 = nn.Conv1d(64, 128, 1)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn3 = nn.BatchNorm1d(1024)

        # MLP after concatenating global feature to each point (64 + 1024 = 1088)
        self.mlp2 = nn.Sequential(
            nn.Conv1d(1088, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Conv1d(512, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Conv1d(256, self.num_classes, 1)
        )

    def forward(self, x):
        """Score every point.

        Args:
            x: Tensor of shape (B, 3, N).

        Returns:
            Tensor of shape (B, num_classes, N).

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        # Explicit rank check; a wrong rank raised ValueError before as well
        # (from unpacking x.shape), so the exception type is unchanged.
        if x.ndim != 3:
            raise ValueError(
                f"PointNet expects input of shape (B, 3, N), got {tuple(x.shape)}"
            )
        num_points = x.size(2)

        x = F.relu(self.bn1(self.conv1(x)))  # (B, 64, N)
        point_feat = x                       # Save 64-dim per-point features

        x = F.relu(self.bn2(self.conv2(x)))  # (B, 128, N)
        x = F.relu(self.bn3(self.conv3(x)))  # (B, 1024, N)

        global_feat = torch.max(x, 2, keepdim=True)[0]  # (B, 1024, 1)
        # expand() broadcasts with a stride-0 view; torch.cat below makes the
        # only copy, so no intermediate (B, 1024, N) tensor is allocated.
        global_feat = global_feat.expand(-1, -1, num_points)  # (B, 1024, N)

        x = torch.cat([point_feat, global_feat], 1)      # (B, 1088, N)
        x = F.dropout(x, p=0.3, training=self.training)  # Extra dropout after concat
        return self.mlp2(x)                              # (B, num_classes, N)


In [None]:
# Previous architecture (global feature + dropout) with an input T-Net added
class TNet(nn.Module):
    """Transformation network: predicts one k x k matrix per batch element.

    The input is passed through three shared (kernel-size-1) convolutions,
    max-pooled over the point dimension, and then through three fully
    connected layers whose last output is reshaped to (k, k).

    The final layer is initialised with zero weights and a flattened identity
    matrix as bias, so a freshly constructed network outputs the identity
    transform for every input.

    Note: in training mode ``bn4``/``bn5`` receive a 2-D (batch, channels)
    tensor, for which PyTorch's BatchNorm1d needs more than one sample per
    batch.

    Args:
        k: Number of input channels and side length of the predicted matrix.
    """

    def __init__(self, k=3):
        super().__init__()
        self.k = k

        # Shared per-point feature extractor: k -> 64 -> 128 -> 1024.
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn3 = nn.BatchNorm1d(1024)

        # Regressor on the pooled feature: 1024 -> 512 -> 256 -> k*k.
        self.fc1 = nn.Linear(1024, 512)
        self.bn4 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, k * k)

        # Start from the identity transform. Initialise in place under
        # no_grad instead of going through the legacy `.data` attribute.
        nn.init.zeros_(self.fc3.weight)
        with torch.no_grad():
            self.fc3.bias.copy_(torch.eye(k).view(-1))

    def forward(self, x):
        """Predict the transform.

        Args:
            x: Tensor of shape (B, k, N).

        Returns:
            Tensor of shape (B, k, k).
        """
        batch_size = x.size(0)
        x = F.relu(self.bn1(self.conv1(x)))   # (B, 64, N)
        x = F.relu(self.bn2(self.conv2(x)))   # (B, 128, N)
        x = F.relu(self.bn3(self.conv3(x)))   # (B, 1024, N)
        x = torch.max(x, dim=2).values        # max over points -> (B, 1024)
        x = F.relu(self.bn4(self.fc1(x)))     # (B, 512)
        x = F.relu(self.bn5(self.fc2(x)))     # (B, 256)
        x = self.fc3(x)                       # (B, k*k)
        return x.view(batch_size, self.k, self.k)

class PointNet(nn.Module):
    """Per-point classifier with an input transform and a global feature.

    A ``TNet(k=3)`` predicts one 3x3 matrix per batch element and the input
    is multiplied by it. Three shared (kernel-size-1) convolutions then lift
    each point to 64, 128 and 1024 channels. The 1024-channel features are
    max-pooled over the point dimension into one global feature per batch
    element, which is broadcast back to every point and concatenated with the
    saved 64-channel per-point features (64 + 1024 = 1088 channels). Dropout
    (p=0.3) is applied to the concatenation before ``mlp2`` maps it to
    ``num_classes`` scores.

    Args:
        num_classes: Number of output channels (scores) produced per point.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Input T-Net
        self.input_transform = TNet(k=3)

        # Shared MLP for feature extraction (split to capture intermediate features)
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.bn1 = nn.BatchNorm1d(64)

        self.conv2 = nn.Conv1d(64, 128, 1)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn3 = nn.BatchNorm1d(1024)

        # MLP after concatenating global feature to each point (64 + 1024 = 1088)
        self.mlp2 = nn.Sequential(
            nn.Conv1d(1088, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Conv1d(512, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Conv1d(256, self.num_classes, 1)
        )

    def forward(self, x):
        """Score every point.

        Args:
            x: Tensor of shape (B, 3, N).

        Returns:
            Tensor of shape (B, num_classes, N).

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        # Explicit rank check; a wrong rank raised ValueError before as well
        # (from unpacking x.shape), so the exception type is unchanged.
        if x.ndim != 3:
            raise ValueError(
                f"PointNet expects input of shape (B, 3, N), got {tuple(x.shape)}"
            )
        num_points = x.size(2)

        # Apply input transform
        trans = self.input_transform(x)                  # (B, 3, 3)
        x = torch.bmm(trans, x)                          # (B, 3, 3) @ (B, 3, N) -> (B, 3, N)

        x = F.relu(self.bn1(self.conv1(x)))              # (B, 64, N)
        point_feat = x                                   # Save 64-dim per-point features

        x = F.relu(self.bn2(self.conv2(x)))              # (B, 128, N)
        x = F.relu(self.bn3(self.conv3(x)))              # (B, 1024, N)

        global_feat = torch.max(x, 2, keepdim=True)[0]   # (B, 1024, 1)
        # expand() broadcasts with a stride-0 view; torch.cat below makes the
        # only copy, so no intermediate (B, 1024, N) tensor is allocated.
        global_feat = global_feat.expand(-1, -1, num_points)  # (B, 1024, N)

        x = torch.cat([point_feat, global_feat], 1)      # (B, 1088, N)
        x = F.dropout(x, p=0.3, training=self.training)  # Extra dropout after concat
        return self.mlp2(x)                              # (B, num_classes, N)
