# Airbus Hackatuna

In [22]:
import csv
import numpy as np

from tqdm import tqdm

import h5py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import os
import sys
sys.path.append('../')
sys.path.append('../src')

import config as c
import utils_N as u
import lidar_utils as lu

# Turn off warning
import warnings
warnings.filterwarnings('ignore')

In [None]:
class PointCloudDataset(Dataset):
    """Dataset backed by an HDF5 file laid out as ``<landscape>/<frame>`` groups.

    The file is walked once at construction time to record every
    ``(landscape, frame)`` key pair. The ``"points"`` and ``"labels"`` arrays
    of a group are only read when that item is requested, and the file is
    reopened for every read.
    """

    def __init__(self, h5_path):
        """Record ``h5_path`` and list every (landscape, frame) pair it holds."""
        self.h5_path = h5_path
        self.index = []

        with h5py.File(h5_path, "r") as h5_file:
            self.index.extend(
                (landscape, frame)
                for landscape in h5_file.keys()
                for frame in h5_file[landscape].keys()
            )

    def __len__(self):
        """Return the number of indexed (landscape, frame) pairs."""
        return len(self.index)

    def __getitem__(self, idx):
        """Load one frame.

        Returns:
            A ``(points, labels)`` tuple of tensors, cast to ``float32`` and
            ``long`` respectively.
        """
        landscape, frame = self.index[idx]

        with h5py.File(self.h5_path, "r") as h5_file:
            group = h5_file[landscape][frame]
            # `[:]` copies the data into memory before the file is closed.
            points, labels = group["points"][:], group["labels"][:]

        points_tensor = torch.tensor(points, dtype=torch.float32)
        labels_tensor = torch.tensor(labels, dtype=torch.long)
        return (points_tensor, labels_tensor)

In [10]:
class TNet(nn.Module):
    """Regress a ``k x k`` matrix per sample from a ``[B, k, N]`` input.

    Three kernel-size-1 convolutions are applied along the last axis, the
    result is max-pooled over that axis, and three linear layers map the
    pooled vector to ``k * k`` values. The flattened identity matrix is added
    to those values before reshaping.
    """

    def __init__(self, k=3):
        super().__init__()
        self.k = k

        # Convolutions with kernel size 1.
        self.conv1 = nn.Conv1d(k, 64, kernel_size=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=1)
        self.conv3 = nn.Conv1d(128, 1024, kernel_size=1)

        # Linear head applied to the pooled 1024-d vector.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)

        # One normalisation layer after every layer except the final one.
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Return a ``[B, k, k]`` tensor for an input of shape ``[B, k, N]``."""
        conv_stack = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in conv_stack:
            x = F.relu(norm(conv(x)))

        # Collapse the last axis: [B, 1024, N] -> [B, 1024].
        x = x.max(dim=2).values

        for linear, norm in ((self.fc1, self.bn4), (self.fc2, self.bn5)):
            x = F.relu(norm(linear(x)))
        x = self.fc3(x)

        # Adding the identity means an all-zero fc3 output yields the identity
        # matrix; the single row broadcasts across the batch.
        flat_identity = torch.eye(self.k, device=x.device).reshape(1, -1)
        x = x + flat_identity

        return x.view(-1, self.k, self.k)

In [11]:
class PointNetSeg(nn.Module):
    """Per-point classifier producing ``num_classes`` logits for every point.

    The input is passed through a learned ``in_channels x in_channels``
    transform, two kernel-size-1 convolutions, and a learned ``64 x 64``
    feature transform. The resulting 64-d per-point features are concatenated
    with a 1024-d max-pooled feature (repeated for every point) and fed to a
    kernel-size-1 convolutional head.

    Args:
        num_classes: Number of logits produced per point.
        in_channels: Number of features carried by each input point. Defaults
            to 4, the value that used to be hard-coded.
    """

    def __init__(self, num_classes = 4, in_channels = 4):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels

        self.input_transform = TNet(k=in_channels)
        self.feature_transform = TNet(k=64)

        self.conv1 = nn.Conv1d(in_channels, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.conv3 = nn.Conv1d(64, 64, 1)
        self.conv4 = nn.Conv1d(64, 128, 1)
        self.conv5 = nn.Conv1d(128, 1024, 1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(1024)

        # 1088 = 64 per-point channels + 1024 pooled channels.
        self.conv6 = nn.Conv1d(1088, 512, 1)
        self.conv7 = nn.Conv1d(512, 256, 1)
        self.conv8 = nn.Conv1d(256, num_classes, 1)

        self.bn6 = nn.BatchNorm1d(512)
        self.bn7 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Compute per-point logits.

        Args:
            x: Tensor of shape ``[B, N, in_channels]``.

        Returns:
            Tensor of shape ``[B, N, num_classes]`` holding unnormalised
            logits.
        """
        # Unpacking three sizes also rejects inputs that are not 3-D.
        _, num_points, _ = x.size()

        x = x.transpose(2, 1)  # [B, in_channels, N]

        T = self.input_transform(x)
        x = torch.bmm(T, x)

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))

        T_feat = self.feature_transform(x)
        x = torch.bmm(T_feat, x)

        # Keep the transformed 64-d per-point features for the head.
        pointfeat = x

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.bn5(self.conv5(x))  # no ReLU before the max-pool

        # [B, 1024, 1] pooled feature, viewed as [B, 1024, N] without copying;
        # torch.cat below materialises the combined tensor.
        global_feat = torch.max(x, 2, keepdim=True)[0]
        global_feat = global_feat.expand(-1, -1, num_points)

        x = torch.cat([pointfeat, global_feat], 1)  # [B, 1088, N]

        x = F.relu(self.bn6(self.conv6(x)))
        x = F.relu(self.bn7(self.conv7(x)))
        x = self.conv8(x)

        return x.transpose(2, 1)  # [B, N, num_classes]

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# One dataset per pre-processed HDF5 split.
train_ds = PointCloudDataset("../datasets/processed/train.h5")
val_ds = PointCloudDataset("../datasets/processed/val.h5")

# Only the training split is reshuffled on every pass.
train_loader = DataLoader(dataset=train_ds, batch_size=8, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=8, shuffle=False)

model = PointNetSeg(num_classes=4)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [26]:
fields = ['epoch', 'train_loss']
rows = []
epochs = tqdm(range(c.EPOCHS), desc = "Epochs : ")

# torch.save does not create missing directories, so make sure the target
# folder exists before the first checkpoint is written.
checkpoint_path = f'../models/PointNetSeg_{device}.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in epochs:
    model.train()
    total_loss = 0
    # Use a separate name for the progress bar. Rebinding `train_loader` here
    # would wrap the previous epoch's tqdm in a new one on every epoch (and on
    # every re-run of this cell), nesting progress bars and hiding the
    # DataLoader from later cells.
    batch_bar = tqdm(train_loader, desc = "batch beepobed :")

    for points, labels in batch_bar:
        points = points.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(points)  # [B, N, num_classes]
        # CrossEntropyLoss expects the class dimension second: [B, C, N].
        loss = criterion(logits.permute(0, 2, 1), labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # The same file is overwritten after every epoch, so it always holds the
    # most recent weights.
    torch.save(
        model.state_dict(),
        checkpoint_path
    )
#    rows.append([epoch, total_loss])
    print(f"Epoch {epoch} | Train loss: {total_loss / len(train_loader):.4f}")

# CSV export of the per-epoch loss is currently disabled.
#with open(os.path.join(c.OUT_DIR, 'ouputs/CSVs/', 'PointNetSef.csv'), 'w') as csv_file:
#    csv_writer = csv.writer(csv_file)
#    csv_writer.writerow(fields)
#    csv_writer.writerows(rows)

batch beepobed :: 100%|██████████| 75/75 [00:42<00:00,  1.78it/s]
Epochs :  20%|██        | 1/5 [00:42<02:48, 42.21s/it]

Epoch 0 | Train loss: 0.6759


batch beepobed :: 100%|██████████| 75/75 [00:42<00:00,  1.77it/s]
Epochs :  40%|████      | 2/5 [01:24<02:06, 42.31s/it]

Epoch 1 | Train loss: 0.6543


batch beepobed :: 100%|██████████| 75/75 [00:42<00:00,  1.77it/s]
Epochs :  60%|██████    | 3/5 [02:07<01:24, 42.39s/it]

Epoch 2 | Train loss: 0.6199


batch beepobed :: 100%|██████████| 75/75 [00:42<00:00,  1.78it/s]
Epochs :  80%|████████  | 4/5 [02:49<00:42, 42.32s/it]

Epoch 3 | Train loss: 0.6160


batch beepobed :: 100%|██████████| 75/75 [00:43<00:00,  1.74it/s]
Epochs : 100%|██████████| 5/5 [03:32<00:00, 42.47s/it]

Epoch 4 | Train loss: 0.5980





In [27]:
# Put the model in evaluation mode before scoring the validation split.
model.eval()

# Running counts of correctly labelled points and of all points seen.
correct, total = 0, 0

# Gradients are not needed for scoring.
with torch.no_grad():
    for points, labels in val_loader:
        points, labels = points.to(device), labels.to(device)

        logits = model(points)
        # The predicted class is the index of the largest logit on the last axis.
        preds = torch.argmax(logits, dim=-1)

        correct += preds.eq(labels).sum().item()
        total += labels.numel()

print(f"Validation accuracy: {correct / total:.4f}")

Validation accuracy: 0.7855
