In [1]:
from colorcloud.chen2020mvlidarnet import MVLidarNet, SemanticSegmentationTask
from colorcloud.UFGsim2024infufg import SemanticSegmentationSimLDM, ProjectionSimVizTransform
from torch.nn import CrossEntropyLoss
import lightning as L
import wandb
from lightning.pytorch.loggers import WandbLogger
from datetime import datetime
import torch
import numpy as np

In [2]:
# Seed Python/NumPy/torch RNGs (and dataloader worker processes) up front so that
# data shuffling and the model's weight initialisation in later cells are
# reproducible under Restart Kernel -> Run All.
SEED = 42
L.seed_everything(SEED, workers=True)

data = SemanticSegmentationSimLDM(eval_batch_size=4, train_batch_size=4)
data.setup('fit')
# Number of training batches per epoch; used later to compute `total_steps`
# for the learner.
epoch_steps = len(data.train_dataloader())

In [None]:
# Choose the compute device by preference: CUDA GPU, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [4]:
# Build MVLidarNet for 4 input channels and 13 output classes, then place it on
# the device chosen above (nn.Module.to returns the module itself).
model = MVLidarNet(
    in_channels=4,
    n_classes=13,
)
model = model.to(device)

In [5]:
# Unreduced cross-entropy: returns one loss value per element instead of a mean.
loss_fn = CrossEntropyLoss(reduction='none')

In [6]:
n_epochs = 25
# Pass the `loss_fn` defined above instead of constructing a second, identical
# CrossEntropyLoss here; otherwise changes made to `loss_fn` (e.g. class
# weights) would silently not reach training.
learner = SemanticSegmentationTask(
    model,
    loss_fn,
    data.viz_tfm,
    # Total optimizer steps over the whole run: epochs x batches per epoch.
    total_steps=n_epochs*epoch_steps
)

In [None]:
model_name = "UFGSim-MVLidarNet"
# ISO-style timestamp: sorts chronologically and contains no '/' or ':', so the
# run name is also safe to reuse as a file or directory name.
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
experiment_name = f'{model_name}_{timestamp}'
# log_model="all" makes Lightning log checkpoints to W&B as artifacts during training.
wandb_logger = WandbLogger(project="colorcloud", name=experiment_name, log_model="all")
# Track gradients and parameters of the wrapped network in W&B.
wandb_logger.watch(learner.model, log="all")

In [None]:
# Train for `n_epochs` with W&B logging, then save a weights-only checkpoint.
trainer = L.Trainer(logger=wandb_logger, max_epochs=n_epochs)
trainer.fit(model=learner, datamodule=data)

checkpoint_path = "ufgsim_mvlidarnet_1.ckpt"
trainer.save_checkpoint(checkpoint_path, weights_only=True)

In [None]:
# Close the active W&B run (the one used by `wandb_logger`) so it is marked as
# finished and any remaining data is uploaded.
wandb.finish()