In [1]:
import torch
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn.functional as F
from ignite.metrics import Accuracy
from tqdm.notebook import tqdm
import json
import logging

---
## 数据部分

In [2]:
class ShapeNet(Dataset):
    """ShapeNet-Part point clouds with per-point part labels.

    Expected layout under ``root``:

    * ``synsetoffset2category.txt`` -- one ``<category name> <synset dir>`` pair
      per non-blank line; the line order defines the object-class index.
    * ``train_test_split/shuffled_<split>_file_list.json`` -- JSON list of
      ``"<prefix>/<synset dir>/<shape id>"`` entries. For ``split='train'`` the
      ``shuffled_val_file_list.json`` entries are appended as well.
    * ``<synset dir>/<shape id>.txt`` -- one point per row; the first six
      columns are returned as features and the last column is the part label.

    Args:
        root: dataset root directory.
        split: ``'train'`` (train + val lists) or another split name that has a
            ``shuffled_<split>_file_list.json`` file (e.g. ``'test'``).
        npoints: number of points sampled (with replacement) per shape.

    Attributes:
        idx_to_class: object-class index -> category name.
        files: point-file path of every shape.
        object_labels: object-class index of every shape, aligned with ``files``.
    """

    def __init__(self, root, split, npoints=1024):
        super(ShapeNet, self).__init__()
        self.npoints = npoints
        self.idx_to_class = {}
        dir_to_idx = {}

        with open(os.path.join(root, 'synsetoffset2category.txt'), 'r') as f:
            # Skip blank lines so they neither crash parsing nor shift indices.
            rows = [line.split() for line in f if line.strip()]
        for idx, fields in enumerate(rows):
            self.idx_to_class[idx] = fields[0]
            dir_to_idx[fields[1]] = idx

        split_dir = os.path.join(root, 'train_test_split')
        list_names = [f'shuffled_{split}_file_list.json']
        if split == 'train':
            # The validation shapes are folded into the training set.
            list_names.append('shuffled_val_file_list.json')

        entries = []
        for name in list_names:
            with open(os.path.join(split_dir, name), 'r') as f:
                entries += json.load(f)

        self.files = []
        self.object_labels = []
        for entry in entries:
            parts = entry.split('/')
            self.files.append(os.path.join(root, parts[1], parts[2] + '.txt'))
            self.object_labels.append(dir_to_idx[parts[1]])

    def __len__(self):
        return len(self.files)

    def pcd_norm(self, points):
        """Center ``points`` on their centroid and scale them into the unit ball.

        Args:
            points: (N, 3) array of coordinates.

        Returns:
            The normalized array. If every point coincides (maximum distance 0),
            the centered array is returned unscaled instead of a 0/0 NaN array.
        """
        centered = points - points.mean(axis=0)
        max_dis = np.sqrt(np.square(centered).sum(axis=1)).max()
        if max_dis > 0:
            centered = centered / max_dis
        return centered

    def __getitem__(self, index):
        """Return ``(points, object_label, part_label)`` for shape ``index``.

        * ``points``: (npoints, 6) float32 -- first six file columns, with the
          first three normalized by :meth:`pcd_norm`.
        * ``object_label``: float32 one-hot vector of length ``len(idx_to_class)``.
        * ``part_label``: (npoints,) int64 -- last file column.
        """
        # atleast_2d keeps a single-row file shaped (1, C) rather than (C,).
        points = np.atleast_2d(np.genfromtxt(self.files[index], dtype=np.float32))
        # Sampling with replacement always yields exactly npoints rows.
        choice = np.random.choice(len(points), self.npoints)
        points = points[choice]

        points[:, 0:3] = self.pcd_norm(points[:, 0:3])

        part_label = points[:, -1].astype(np.int64)
        points = points[:, 0:6]

        object_label = np.zeros(len(self.idx_to_class), dtype=np.float32)
        object_label[self.object_labels[index]] = 1

        return points, object_label, part_label

In [3]:
# Both splits share the same root directory and per-shape point count.
DATA_ROOT = 'shapenetcore_partanno_segmentation_benchmark_v0_normal'
NUM_POINTS = 2048
train_dataset, test_dataset = (ShapeNet(DATA_ROOT, split=name, npoints=NUM_POINTS) for name in ('train', 'test'))

In [4]:
# Class-index mapping, then the number of shapes in the train and test splits.
print(train_dataset.idx_to_class, len(train_dataset), len(test_dataset), sep='\n')

{0: 'Airplane', 1: 'Bag', 2: 'Cap', 3: 'Car', 4: 'Chair', 5: 'Earphone', 6: 'Guitar', 7: 'Knife', 8: 'Lamp', 9: 'Laptop', 10: 'Motorbike', 11: 'Mug', 12: 'Pistol', 13: 'Rocket', 14: 'Skateboard', 15: 'Table'}
14007
2874


In [5]:
# Only shuffling differs between the two loaders.
_loader_kwargs = dict(batch_size=16, num_workers=8)
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [6]:
# Pull one training batch and show the shape of each returned tensor.
a, b, c = next(iter(train_dataloader))
for batch_tensor in (a, b, c):
    print(batch_tensor.shape)

torch.Size([16, 2048, 6])
torch.Size([16, 16])
torch.Size([16, 2048])


---
## 模型部分

In [7]:
class PointNet(nn.Module):
    """PointNet-style part-segmentation network (without transform nets).

    Five per-point layers (1x1 convolutions) extract features. The last one is
    max-pooled over the points into a global feature. Every intermediate
    feature, the tiled global feature and the tiled object label are
    concatenated per point and mapped by four more per-point layers to one
    logit per part class.

    Args:
        class_num: number of part classes (output channels).
        num_object_classes: length of the object-label vector passed to
            :meth:`forward`. Defaults to 16, the previously hard-coded value.
        in_channels: number of per-point input features. Defaults to 6.
    """

    def __init__(self, class_num, num_object_classes=16, in_channels=6):
        super(PointNet, self).__init__()
        self.mlp_1 = nn.Sequential(nn.Conv1d(in_channels, 64, kernel_size=1),
                                   nn.BatchNorm1d(64),
                                   nn.ReLU(inplace=True))
        self.mlp_2 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=1),
                                   nn.BatchNorm1d(128),
                                   nn.ReLU(inplace=True))
        self.mlp_3 = nn.Sequential(nn.Conv1d(128, 128, kernel_size=1),
                                   nn.BatchNorm1d(128),
                                   nn.ReLU(inplace=True))
        self.mlp_4 = nn.Sequential(nn.Conv1d(128, 512, kernel_size=1),
                                   nn.BatchNorm1d(512),
                                   nn.ReLU(inplace=True))
        # No ReLU here: these features are max-pooled as-is.
        self.mlp_5 = nn.Sequential(nn.Conv1d(512, 2048, kernel_size=1),
                                   nn.BatchNorm1d(2048))
        # Head input = out1..out5 + tiled global feature + tiled object label
        # (4944 channels with the default 16 object classes).
        concat_channels = 64 + 128 + 128 + 512 + 2048 + 2048 + num_object_classes
        self.mlp_6 = nn.Sequential(nn.Conv1d(concat_channels, 256, kernel_size=1),
                                   nn.BatchNorm1d(256),
                                   nn.ReLU(inplace=True))
        self.mlp_7 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1),
                                   nn.BatchNorm1d(256),
                                   nn.ReLU(inplace=True))
        self.mlp_8 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1),
                                   nn.BatchNorm1d(128),
                                   nn.ReLU(inplace=True))
        self.mlp_9 = nn.Sequential(nn.Conv1d(128, class_num, kernel_size=1))

    def forward(self, points, object_labels):
        """
        Args:
            points: (b, n, in_channels) per-point features.
            object_labels: (b, num_object_classes) object-label vectors.

        Returns:
            (b, class_num, n) per-point part logits.
        """
        n = points.shape[1]
        points = points.permute(0, 2, 1)
        # expand() tiles without copying; torch.cat materializes the result once.
        object_labels = object_labels.unsqueeze(dim=2).expand(-1, -1, n)

        out1 = self.mlp_1(points)
        out2 = self.mlp_2(out1)
        out3 = self.mlp_3(out2)
        out4 = self.mlp_4(out3)
        out5 = self.mlp_5(out4)

        global_feature = out5.max(dim=2, keepdim=True)[0].expand(-1, -1, n)

        final_feature = torch.cat((out1, out2, out3, out4, out5, global_feature, object_labels), dim=1)
        return self.mlp_9(self.mlp_8(self.mlp_7(self.mlp_6(final_feature))))

In [8]:
def train_loop(dataloader, model, loss_fn, metric_fn, optimizer, device, cur_epoch, total_epoch, show_gap, interval):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: yields ``(points, object_labels, part_labels)`` batches.
        model: called as ``model(points, object_labels)``.
        loss_fn: criterion applied as ``loss_fn(prediction, part_labels)``.
        metric_fn: object with ``reset()``, ``update((y_pred, y))`` and
            ``compute()``. It is only evaluated for batches whose result is
            displayed.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device every batch tensor is moved to.
        cur_epoch, total_epoch: shown in the progress-bar description.
        show_gap: a progress bar is shown only when ``cur_epoch % show_gap == 0``.
        interval: on shown epochs, the loss/accuracy postfix is refreshed
            every ``interval`` batches.
    """
    model.train()
    show = cur_epoch % show_gap == 0
    pbar = tqdm(dataloader, desc=f'Epoch {cur_epoch}/{total_epoch}', unit='batch') if show else dataloader

    for i, (x1, x2, y) in enumerate(pbar):
        x1, x2, y = x1.to(device), x2.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x1, x2)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()

        # The metric only feeds the progress-bar postfix, so skip it (and any
        # device synchronization it implies) for batches that are not shown.
        if show and i % interval == 0:
            metric_fn.reset()
            metric_fn.update((y_pred.detach(), y))
            acc = metric_fn.compute()
            pbar.set_postfix_str(f'loss={loss.item():.4f}, acc={acc:.4f}')

In [9]:
# Best test mIoU seen so far and the epoch it occurred at; updated by test_loop.
best_miou = 0
best_epoch = 0

# Number of part ids owned by each object category, in category-index order
# (Airplane, Bag, Cap, Car, Chair, Earphone, Guitar, Knife, Lamp, Laptop,
#  Motorbike, Mug, Pistol, Rocket, Skateboard, Table). Ids are allocated
# contiguously, so category k owns [sum(counts[:k]), sum(counts[:k + 1])).
_PARTS_PER_OBJECT = (4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3)
object_to_part = {
    obj: list(range(sum(_PARTS_PER_OBJECT[:obj]), sum(_PARTS_PER_OBJECT[:obj + 1])))
    for obj in range(len(_PARTS_PER_OBJECT))
}

def _get_eval_logger(log_path):
    """Return a logger that appends plain messages to the file ``log_path``.

    ``logging.basicConfig`` does nothing once the root logger already has a
    handler, so a dedicated logger with its own FileHandler is used instead.
    The handler is attached only once per path, so calling this every epoch
    does not duplicate lines.
    """
    logger = logging.getLogger('pointnet_partseg.eval')
    logger.setLevel(logging.INFO)
    logger.propagate = False
    target = os.path.abspath(log_path)
    if not any(isinstance(h, logging.FileHandler) and h.baseFilename == target for h in logger.handlers):
        handler = logging.FileHandler(log_path)
        handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(handler)
    return logger


def _shape_miou(pred_labels, true_labels, part_ids):
    """Mean IoU of one shape over the part ids of its category.

    A part absent from both the prediction and the ground truth scores 1.
    """
    ious = []
    for part in part_ids:
        pred_mask = pred_labels == part
        true_mask = true_labels == part
        union = torch.sum(pred_mask | true_mask).item()
        if union == 0:
            ious.append(1)
        else:
            ious.append(torch.sum(pred_mask & true_mask).item() / union)
    return np.mean(ious)


def test_loop(dataloader, model, loss_fn, device, cur_epoch, path, show_gap, log_dir):
    """Evaluate ``model``, report part-segmentation mIoU and checkpoint on a new best.

    Each shape's prediction is restricted to the part ids of its object category
    (``object_to_part``). The shape's per-part IoUs are then averaged. The
    reported mIoU is the mean over all shapes; per-category means are logged.

    Args:
        dataloader: yields ``(points, object_labels, part_labels)`` where the
            object label's argmax is the category index; ``dataloader.dataset``
            must expose ``idx_to_class``.
        model: called as ``model(points, object_labels)``; output indexed as
            (batch, part class, point).
        loss_fn: criterion whose per-batch values are averaged into the test loss.
        device: device every batch tensor is moved to.
        cur_epoch: epoch index stored in ``best_epoch`` and written to the log.
        path: where ``model.state_dict()`` is saved when the mIoU is at least the
            module-level ``best_miou``.
        show_gap: results are logged/printed only when ``cur_epoch % show_gap == 0``.
        log_dir: path of the log *file* (despite the name) results are appended to.

    Side effects:
        Updates the module-level ``best_miou`` / ``best_epoch`` and may write a
        checkpoint to ``path``.
    """
    model.eval()
    steps = len(dataloader)
    idx_to_class = dataloader.dataset.idx_to_class
    num_objects = len(object_to_part)
    total_loss = 0.0
    object_mious = [[] for _ in range(num_objects)]
    logger = _get_eval_logger(log_dir)

    with torch.no_grad():
        for x1, x2, y in dataloader:
            x1, x2, y = x1.to(device), x2.to(device), y.to(device)
            y_pred = model(x1, x2)
            total_loss += loss_fn(y_pred, y).item()

            # Softmax is monotonic, so argmax over raw logits picks the same part.
            logits = y_pred.permute(0, 2, 1)
            for i in range(len(logits)):
                obj = x2[i].argmax().item()
                parts = object_to_part[obj]
                # Restrict to this category's (contiguous) part ids, then shift
                # the local argmax back to global part ids.
                pred = logits[i, :, parts].argmax(dim=-1) + parts[0]
                object_mious[obj].append(_shape_miou(pred, y[i], parts))

    # Empty categories / loaders yield NaN explicitly instead of numpy warnings
    # or a ZeroDivisionError.
    class_mious = [np.mean(m) if m else float('nan') for m in object_mious]
    all_mious = [v for m in object_mious for v in m]
    miou = np.mean(all_mious) if all_mious else float('nan')
    loss = total_loss / steps if steps else float('nan')

    global best_miou, best_epoch
    if miou >= best_miou:
        torch.save(model.state_dict(), path)
        best_miou = miou
        best_epoch = cur_epoch

    if cur_epoch % show_gap == 0:
        logger.info(f'Epoch {cur_epoch}\n')
        for i in range(num_objects):
            logger.info(f'{idx_to_class[i]}:   {class_mious[i]:.4f}')
        logger.info(f'test_loss={loss:.4f}, test_miou={miou:.4f}')
        logger.info('-------------------------------------------------------')
        print(f'test_loss={loss:.4f}, test_miou={miou:.4f}')

---
## 开始训练

In [10]:
# Fall back to CPU so the notebook still runs on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# 50 output channels = total number of part ids (0..49) in object_to_part.
pointnet = PointNet(50).to(device)
loss_fn = nn.CrossEntropyLoss()
metric_acc = Accuracy(device=device)
optimizer = torch.optim.Adam(pointnet.parameters(), lr=0.001, weight_decay=1e-4)
# T_max should match the number of training epochs (`epochs` in the training cell).
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 250)

In [None]:
# Training run: one train pass, one evaluation pass and one LR-scheduler step per epoch.
epochs = 250
show_gap = 1
save_path = 'pointnet_partseg.pth'
log_path = 'pointnet_partseg.log'
for epoch in range(epochs):
    train_loop(train_dataloader, pointnet, loss_fn, metric_acc, optimizer, device,
               cur_epoch=epoch, total_epoch=epochs, show_gap=show_gap, interval=1)
    test_loop(test_dataloader, pointnet, loss_fn, device,
              cur_epoch=epoch, path=save_path, show_gap=show_gap, log_dir=log_path)
    lr_scheduler.step()

In [None]:
# Best shape-averaged test mIoU and the (0-based) epoch it was reached at.
print(*(best_miou, best_epoch))