In [1]:
import torch
import torch.utils.data.dataloader
import numpy as np
import tqdm
import matplotlib.pyplot as plt

import pcs.models.pointconv_simple
import pcs.dataset

In [2]:
# Fix the torch RNG so torch-driven randomness (e.g. DataLoader shuffling)
# is repeatable across "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Number of samples per DataLoader batch.
BATCH_SIZE = 2
# Number of classes passed to the model as `classes=`.
NUM_CLASSES = 8

In [3]:
# TODO: more transforms, normalization, etc.
# TODO: aggregate more neighborhoods
# TODO: data augmentation
# TODO: stratification


def _to_point_tensor(points):
    """Transpose the raw point array and convert it to a float32 tensor."""
    return torch.tensor(points.T, dtype=torch.float32)


def _to_label_tensor(raw_labels):
    """Subtract 1 from the raw labels and convert them to a long tensor."""
    # Presumably the stored labels start at 1 and the loss expects
    # 0-based class ids -- TODO confirm against the dataset.
    return torch.tensor(raw_labels - 1, dtype=torch.long)


# Each transform tuple holds exactly one callable.
dataset = pcs.dataset.SemSegDataset(
    data_dir="./data/aggregated/bild/",
    point_transforms=(_to_point_tensor,),
    label_transforms=(_to_label_tensor,),
)

# Reshuffle the samples every epoch.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset, batch_size=BATCH_SIZE, shuffle=True
)

In [4]:
# Batches per epoch: len(dataset) / BATCH_SIZE rounded up, so a trailing
# partial batch is counted as well.
EPOCH_LENGTH = (len(dataset) + BATCH_SIZE - 1) // BATCH_SIZE

In [5]:
# Network under training. `features=4` is presumably the number of per-point
# input feature channels -- TODO confirm against PointConvNet.
model = pcs.models.pointconv_simple.PointConvNet(
    classes=NUM_CLASSES,
    features=4,
)
# NLLLoss consumes log-probabilities; this assumes the model's final layer is
# a log-softmax -- TODO confirm.
loss_fn = torch.nn.NLLLoss()
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [6]:
# Switch the model to training mode. `train()` returns the module itself, so
# as the cell's last expression it also displays the architecture shown below.
model.train()

PointConvNet(
  (sa1): FeatureEncoder(
    (mlp_convs): ModuleList(
      (0): Conv2d(7, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (mlp_bns): ModuleList(
      (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (weightnet): WeightNet(
      (mlp_convs): ModuleList(
        (0): Conv2d(3, 8, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
      )
      (mlp_bns): ModuleList(
        (0-1): 2 x BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (linear): Linear(in_features=1024, out_features=64, bias=True)
    (bn_linear): BatchNorm1d(64, eps=1

In [None]:
def train_one_epoch() -> float:
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Uses the notebook-level globals ``model``, ``dataloader``, ``loss_fn``,
    ``optimizer`` and ``EPOCH_LENGTH``. For every batch it zeroes the
    gradients, runs the forward pass, back-propagates the loss and takes one
    optimizer step, while showing the current batch loss on a progress bar.

    Returns:
        The arithmetic mean of the per-batch losses as a built-in ``float``
        (``nan`` when the dataloader yields no batches). The value is also
        printed.
    """
    # Do not depend on another cell having put the model in training mode:
    # an intervening ``model.eval()`` must not silently leak into training.
    model.train()

    batch_losses = []
    pbar = tqdm.tqdm(dataloader, total=EPOCH_LENGTH)
    for data, labels in pbar:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        # Extract the Python scalar once and reuse it for both the running
        # list and the progress-bar text.
        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        pbar.set_description(f"Loss: {batch_loss}")

    # Guard the empty case explicitly instead of letting numpy emit a
    # RuntimeWarning for the mean of an empty array.
    mean_loss = float(np.mean(batch_losses)) if batch_losses else float("nan")
    print(f"Mean loss: {mean_loss}")
    return mean_loss

# Mean training loss of each epoch, in order; plotted in the next cell.
losses = []
# Proof of concept: train for only 10 epochs.
# Even this short run takes a long time.
for epoch in range(10):
    print(f"Epoch {epoch}")
    epoch_loss = train_one_epoch()
    losses.append(epoch_loss)

In [None]:
# TODO: verify correctness
# Learning curve: one point per epoch, taken from `losses`.
fig, ax = plt.subplots()
ax.plot(losses)
# Label the figure so it can be read on its own when the notebook is skimmed.
ax.set(title="Training loss per epoch", xlabel="Epoch", ylabel="Mean loss")
plt.show()