In [1]:
from pathlib import Path

from torch_geometric.datasets import ModelNet
import torch_geometric.transforms as T

# ModelNet10 is stored in a "modelnet10" folder under the current working directory.
current_path = Path.cwd()
dataset_dir = current_path / "modelnet10"

# Passed to ModelNet as `pre_transform`: sample 1024 points per mesh (faces
# removed, normals included), then apply NormalizeScale.
pre_transform = T.Compose(
    [
        T.SamplePoints(1024, remove_faces=True, include_normals=True),
        T.NormalizeScale(),
    ]
)


def _load_modelnet10(is_train):
    """Build one ModelNet10 split; both splits share the root and pre_transform."""
    return ModelNet(
        dataset_dir,
        name="10",
        train=is_train,
        transform=None,
        pre_transform=pre_transform,
        pre_filter=None,
    )


train_dataset = _load_modelnet10(True)
test_dataset = _load_modelnet10(False)

Downloading http://vision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip
Extracting /workspace/book_writing/actual_note/modelnet10/ModelNet10.zip
Processing...
Done!


In [1]:
# Depends on `train_dataset` from the dataset-loading cell: on a fresh kernel
# that cell has to be executed first, otherwise this cell raises NameError.
num_train_samples = len(train_dataset)
first_sample = train_dataset[0]
print("train_dataset len:", num_train_samples)
print(first_sample)

NameError: name 'train_dataset' is not defined

In [3]:
# Look at the sampled point coordinates of the first training example.
first_pos = train_dataset[0].pos
print(first_pos.shape)
print(first_pos)

torch.Size([1024, 3])
tensor([[ 0.4117, -0.3452, -0.2691],
        [-0.6268, -0.0072,  0.4082],
        [-0.5076, -0.5131,  0.3200],
        ...,
        [ 0.4676, -0.7442,  0.4082],
        [-0.5076, -0.3066,  0.3075],
        [-0.2543, -0.8059,  0.2235]])


In [4]:
# PyG 2.x ships its mini-batch loaders in `torch_geometric.loader`; the
# `torch_geometric.data.DataLoader` name is a deprecated alias of this class.
from torch_geometric.loader import DataLoader

# Unshuffled loader used only to pull a single mini-batch for inspection.
dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
batch = next(iter(dataloader))
print(batch)

DataBatch(pos=[32768, 3], y=[32], batch=[32768], ptr=[33])




In [5]:
from torch_geometric.nn import global_max_pool
import torch.nn as nn

class SymmFunction(nn.Module):
    """Shared per-point MLP followed by a max-pool over the points of each cloud.

    The MLP maps every 3-D point to 512 features; `global_max_pool` then reduces
    the points that share a `batch.batch` index to a single 512-D vector.
    """

    def __init__(self):
        super().__init__()
        layers = []
        # Two Linear -> BatchNorm -> ReLU stages, then a plain linear projection.
        for n_in, n_out in ((3, 64), (64, 128)):
            layers += [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU()]
        layers.append(nn.Linear(128, 512))
        self.shared_mlp = nn.Sequential(*layers)

    def forward(self, batch):
        """Return one pooled 512-D feature row per cloud in `batch`.

        `batch` must expose `pos` (point coordinates, 3 columns) and `batch`
        (the per-point cloud index handed to `global_max_pool`).
        """
        point_features = self.shared_mlp(batch.pos)
        return global_max_pool(point_features, batch.batch)

# Smoke test: run a freshly constructed SymmFunction on the mini-batch `batch`
# taken from the inspection DataLoader and print the shape of the pooled output.
f = SymmFunction()
print(batch)
y = f(batch)
print(y.shape)

DataBatch(pos=[32768, 3], y=[32], batch=[32768], ptr=[33])
torch.Size([32, 512])


In [6]:
class InputTNet(nn.Module):
    """Predicts one 3x3 transform per cloud from the raw point coordinates.

    A per-point MLP (3 -> 64 -> 128 -> 1024) is max-pooled over each cloud and a
    second MLP (1024 -> 512 -> 256 -> 9) produces nine values that are reshaped
    to 3x3 and added to the identity matrix.
    """

    def __init__(self):
        super().__init__()

        def fc_block(n_in, n_out):
            # Linear -> BatchNorm -> ReLU triple used throughout the network.
            return [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU()]

        self.input_mlp = nn.Sequential(
            *fc_block(3, 64),
            *fc_block(64, 128),
            *fc_block(128, 1024),
        )
        self.output_mlp = nn.Sequential(
            *fc_block(1024, 512),
            *fc_block(512, 256),
            nn.Linear(256, 9),
        )

    def forward(self, x, batch):
        """Return a (num_clouds, 3, 3) tensor of transforms.

        x:     per-point input with 3 columns.
        batch: per-point cloud index passed to `global_max_pool`.
        """
        point_features = self.input_mlp(x)
        cloud_features = global_max_pool(point_features, batch)
        offsets = self.output_mlp(cloud_features).view(-1, 3, 3)
        # The identity broadcasts over the leading (cloud) dimension.
        identity = torch.eye(3, device=offsets.device)
        return identity + offsets

In [7]:
class FeatureTNet(nn.Module):
    """Predicts one 64x64 transform per cloud from 64-D per-point features.

    A per-point MLP (64 -> 64 -> 128 -> 1024) is max-pooled over each cloud and
    a second MLP (1024 -> 512 -> 256 -> 4096) produces the entries of a 64x64
    matrix that is added to the identity matrix.
    """

    def __init__(self):
        super().__init__()

        def fc_block(n_in, n_out):
            # Linear -> BatchNorm -> ReLU triple used throughout the network.
            return [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU()]

        self.input_mlp = nn.Sequential(
            *fc_block(64, 64),
            *fc_block(64, 128),
            *fc_block(128, 1024),
        )
        self.output_mlp = nn.Sequential(
            *fc_block(1024, 512),
            *fc_block(512, 256),
            nn.Linear(256, 64 * 64),
        )

    def forward(self, x, batch):
        """Return a (num_clouds, 64, 64) tensor of transforms.

        x:     per-point features with 64 columns.
        batch: per-point cloud index passed to `global_max_pool`.
        """
        point_features = self.input_mlp(x)
        cloud_features = global_max_pool(point_features, batch)
        offsets = self.output_mlp(cloud_features).view(-1, 64, 64)
        # The identity broadcasts over the leading (cloud) dimension.
        identity = torch.eye(64, device=offsets.device)
        return identity + offsets

In [8]:
class PointNetClassification(nn.Module):
    """PointNet-style classifier producing 10 class scores per point cloud.

    Pipeline: input T-Net (3x3) -> per-point MLP to 64 features -> feature
    T-Net (64x64) -> per-point MLP to 1024 features -> max-pool per cloud ->
    classification MLP with dropout.
    """

    def __init__(self):
        super().__init__()

        def fc_block(n_in, n_out, dropout=None):
            # Linear -> BatchNorm -> ReLU, optionally followed by Dropout.
            block = [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU()]
            if dropout is not None:
                block.append(nn.Dropout(p=dropout))
            return block

        # Sub-modules are created in the same order as they are applied.
        self.input_tnet = InputTNet()
        self.mlp1 = nn.Sequential(*fc_block(3, 64), *fc_block(64, 64))
        self.feature_tnet = FeatureTNet()
        self.mlp2 = nn.Sequential(
            *fc_block(64, 64),
            *fc_block(64, 128),
            *fc_block(128, 1024),
        )
        self.mlp3 = nn.Sequential(
            *fc_block(1024, 512, dropout=0.3),
            *fc_block(512, 256, dropout=0.3),
            nn.Linear(256, 10),
        )

    def forward(self, batch_data):
        """Classify every cloud in `batch_data`.

        Reads `batch_data.pos` (point coordinates) and `batch_data.batch`
        (per-point cloud index).

        Returns a tuple `(scores, input_transform, feature_transform)`: the
        unnormalised class scores from the last linear layer and the per-cloud
        3x3 and 64x64 matrices predicted by the two T-Nets.
        """
        cloud_index = batch_data.batch
        points = batch_data.pos

        # Give every point the 3x3 matrix of its own cloud and apply it.
        input_transform = self.input_tnet(points, cloud_index)
        per_point_matrix = input_transform[cloud_index]
        points = torch.bmm(per_point_matrix, points.view(-1, 3, 1)).view(-1, 3)

        features = self.mlp1(points)

        # Same scheme in the 64-D feature space.
        feature_transform = self.feature_tnet(features, cloud_index)
        per_point_matrix = feature_transform[cloud_index]
        features = torch.bmm(per_point_matrix, features.view(-1, 64, 1)).view(-1, 64)

        features = self.mlp2(features)
        pooled = global_max_pool(features, cloud_index)
        scores = self.mlp3(pooled)

        return scores, input_transform, feature_transform

In [9]:
import torch
from torch.utils.tensorboard import SummaryWriter

# Training hyper-parameters.
num_epoch = 400
batch_size = 32

# Use the first GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = PointNetClassification().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# StepLR multiplies the learning rate by `gamma` after every `step_size`
# calls to `scheduler.step()`.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=num_epoch // 4,
    gamma=0.5,
)

# TensorBoard events and model checkpoints share this directory.
log_dir = current_path / "log_modelnet10_classification"
log_dir.mkdir(exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)

# Only the training loader is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

criteria = torch.nn.CrossEntropyLoss()

In [10]:
from tqdm import tqdm

# Weight of the feature-transform orthogonality regulariser in the total loss.
REG_LOSS_WEIGHT = 0.001
# A checkpoint is written every this many epochs (and after the last epoch).
CHECKPOINT_EVERY = 10


def _orthogonality_loss(feature_transform):
    """Mean Frobenius norm of (A @ A^T - I) over a batch of square matrices.

    feature_transform: tensor of shape (num_clouds, d, d).
    Returns a scalar tensor.
    """
    dim = feature_transform.shape[1]
    # The identity broadcasts over the batch dimension; its size follows the
    # transform instead of being hard-coded.
    identity = torch.eye(dim, device=feature_transform.device, dtype=feature_transform.dtype)
    gram = torch.bmm(feature_transform, feature_transform.transpose(1, 2))
    return torch.norm(gram - identity, dim=(1, 2)).mean()


def _run_epoch(model, dataloader, criteria, device, optimizer=None, desc=None):
    """Run one full pass over `dataloader` and return sample-weighted mean metrics.

    When `optimizer` is given the model is put in training mode and updated
    after every mini-batch; otherwise the model is put in eval mode and no
    gradients are recorded.

    Returns a dict with the keys "loss", "class_loss", "reg_loss", "accuracy".
    """
    training = optimizer is not None
    model.train(training)

    totals = {"loss": 0.0, "class_loss": 0.0, "reg_loss": 0.0}
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for batch_data in tqdm(dataloader, total=len(dataloader), desc=desc, leave=False):
            batch_data = batch_data.to(device)

            pred_y, _, feature_transform = model(batch_data)
            true_y = batch_data.y
            # One prediction row per cloud, so this is the mini-batch size.
            n_clouds = pred_y.shape[0]

            class_loss = criteria(pred_y, true_y)
            reg_loss = _orthogonality_loss(feature_transform)
            loss = class_loss + reg_loss * REG_LOSS_WEIGHT

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Weight per-batch means by the batch size so that a smaller last
            # batch does not skew the epoch average.
            totals["loss"] += loss.item() * n_clouds
            totals["class_loss"] += class_loss.item() * n_clouds
            totals["reg_loss"] += reg_loss.item() * n_clouds
            correct += int((pred_y.argmax(dim=1) == true_y).sum())
            seen += n_clouds

    denom = max(seen, 1)  # avoid ZeroDivisionError on an empty dataloader
    metrics = {name: total / denom for name, total in totals.items()}
    metrics["accuracy"] = correct / denom
    return metrics


def _log_metrics(writer, prefix, metrics, epoch):
    """Write every metric as the TensorBoard scalar `<prefix>/<name>`."""
    for name in ("loss", "class_loss", "reg_loss", "accuracy"):
        writer.add_scalar(f"{prefix}/{name}", metrics[name], epoch)


try:
    for epoch in tqdm(range(num_epoch), desc="epoch"):
        train_metrics = _run_epoch(
            model, train_dataloader, criteria, device, optimizer=optimizer, desc="train"
        )
        # The scheduler's step_size is expressed in epochs (num_epoch // 4), so
        # it must advance once per epoch. Stepping it after every mini-batch
        # halves the learning rate every `step_size` batches instead.
        scheduler.step()

        if epoch % CHECKPOINT_EVERY == 0 or epoch == num_epoch - 1:
            model_path = log_dir / f"model_{epoch:06}.pth"
            torch.save(model.state_dict(), model_path)

        _log_metrics(writer, "train_epoch", train_metrics, epoch)

        test_metrics = _run_epoch(model, test_dataloader, criteria, device, desc="test")
        _log_metrics(writer, "test_epoch", test_metrics, epoch)
finally:
    # Make sure logged scalars reach disk even if the run is interrupted.
    writer.flush()

  0%|          | 0/400 [00:00<?, ?it/s]
  0%|          | 0/125 [00:00<?, ?it/s][A
  1%|          | 1/125 [00:00<01:36,  1.28it/s][A
  3%|▎         | 4/125 [00:00<00:22,  5.40it/s][A
  6%|▌         | 7/125 [00:01<00:12,  9.19it/s][A
  8%|▊         | 10/125 [00:01<00:09, 12.53it/s][A
 10%|█         | 13/125 [00:01<00:07, 15.29it/s][A
 13%|█▎        | 16/125 [00:01<00:06, 17.30it/s][A
 15%|█▌        | 19/125 [00:01<00:05, 18.79it/s][A
 18%|█▊        | 22/125 [00:01<00:05, 20.20it/s][A
 20%|██        | 25/125 [00:01<00:04, 21.34it/s][A
 22%|██▏       | 28/125 [00:01<00:04, 21.61it/s][A
 25%|██▍       | 31/125 [00:02<00:04, 22.22it/s][A
 27%|██▋       | 34/125 [00:02<00:04, 22.54it/s][A
 30%|██▉       | 37/125 [00:02<00:03, 22.50it/s][A
 32%|███▏      | 40/125 [00:02<00:03, 22.30it/s][A
 34%|███▍      | 43/125 [00:02<00:03, 22.08it/s][A
 37%|███▋      | 46/125 [00:02<00:04, 18.78it/s][A
 39%|███▉      | 49/125 [00:02<00:03, 19.85it/s][A
 42%|████▏     | 52/125 [00:03<00:03

KeyboardInterrupt: 