In [None]:
import torch
import torch.utils.data.dataloader
import numpy as np
import tqdm
import matplotlib.pyplot as plt

import pcs.models.pointconv
import pcs.dataset

In [None]:
# Number of samples per mini-batch; passed to the DataLoader and used to
# derive EPOCH_LENGTH.
BATCH_SIZE = 100
# Number of output classes; passed to PointConvNet as `classes`.
NUM_CLASSES = 8

In [None]:
# TODO: more transforms, normalization, etc.
# TODO: aggregate more neighborhoods
# TODO: data augmentation
# TODO: stratification


def _points_to_tensor(points):
    """Transpose the raw point array and convert it to a float32 tensor."""
    return torch.tensor(points.T, dtype=torch.float32)


def _labels_to_tensor(raw_labels):
    """Subtract 1 from the raw labels and convert them to a long tensor.

    NOTE(review): presumably this maps 1-based class ids to the 0-based
    indices the loss expects -- confirm against the dataset files.
    """
    return torch.tensor(raw_labels - 1, dtype=torch.long)


# Each transform tuple holds exactly one callable, applied by SemSegDataset.
dataset = pcs.dataset.SemSegDataset(
    data_dir="./data/aggregated/bild/",
    point_transforms=(_points_to_tensor,),
    label_transforms=(_labels_to_tensor,),
)

# Reshuffled mini-batches of BATCH_SIZE samples on every pass.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Batches per epoch: ceiling of len(dataset) / BATCH_SIZE, so a trailing
# partial batch still counts as one step.
EPOCH_LENGTH = (len(dataset) + BATCH_SIZE - 1) // BATCH_SIZE

In [None]:
# Network, objective and optimiser shared by the training cells below.
model = pcs.models.pointconv.PointConvNet(
    features=4,
    classes=NUM_CLASSES,
)
# NLLLoss consumes log-probabilities.
# NOTE(review): assumes PointConvNet ends in a log-softmax -- TODO confirm.
loss_fn = torch.nn.NLLLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
)

In [None]:
# Switch the model to training mode before the training loop runs
# (assuming PointConvNet is a torch.nn.Module -- TODO confirm).
model.train()

In [None]:
def train_one_epoch() -> float:
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Uses the notebook-level ``model``, ``loss_fn``, ``optimizer``,
    ``dataloader`` and ``EPOCH_LENGTH``. Progress is shown with a tqdm bar
    whose description is the most recent batch loss.

    Returns:
        The arithmetic mean of the per-batch losses as a plain ``float``,
        or ``nan`` if the dataloader yielded no batches.
    """
    # Re-assert training mode here so the epoch does not silently depend on
    # another cell having been run (or on no later cell having switched the
    # model to eval mode).
    model.train()

    batch_losses = []
    pbar = tqdm.tqdm(dataloader, total=EPOCH_LENGTH)
    for data, labels in pbar:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        # Read the scalar once and reuse it for both bookkeeping and display.
        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        pbar.set_description(f"Loss: {batch_loss}")

    # Avoid numpy's "mean of empty slice" warning when nothing was iterated.
    mean_loss = float(np.mean(batch_losses)) if batch_losses else float("nan")
    print(f"Mean loss: {mean_loss}")
    return mean_loss

# Note: PoC, only 10 epochs.
# Training takes a long time nonetheless.
NUM_EPOCHS = 10

# Mean training loss of each completed epoch, in order. It is reset whenever
# this cell is re-run, so the history always matches a single training run.
losses = []
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch}")
    epoch_loss = train_one_epoch()
    losses.append(epoch_loss)

In [None]:
# TODO: verify correctness
# One point per epoch: the mean training loss collected in `losses`.
fig, ax = plt.subplots()
ax.plot(losses, marker="o")
ax.set(title="Training loss per epoch", xlabel="Epoch", ylabel="Mean loss")
plt.show()