In [72]:
from colorcloud.chen2020mvlidarnet import MVLidarNet, SemanticSegmentationTask
from colorcloud.UFGsim2024infufg import SemanticSegmentationSimLDM, ProjectionSimVizTransform, UFGSimDataset
from colorcloud.UFGsim2024infufg import ProjectionToTensorTransformSim, UFGSimDataset, SphericalProjection, ProjectionSimTransform
from torch.nn import CrossEntropyLoss
import lightning as L
import wandb
from lightning.pytorch.loggers import WandbLogger
from datetime import datetime
from torchvision.transforms import v2
import numpy as np
import matplotlib.pyplot as plt
import torch
import time
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.optim import AdamW
from torchmetrics.classification import Accuracy
from torchmetrics.segmentation import MeanIoU

In [None]:
# Choose the compute backend in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

# Lightning training

In [2]:
# Data module used by the Lightning run; train and eval both use batches of 4.
data = SemanticSegmentationSimLDM(
    train_batch_size=4,
    eval_batch_size=4,
)
data.setup('fit')

# Number of batches in the training dataloader; used below to derive the
# total number of steps handed to the learner.
epoch_steps = len(data.train_dataloader())

In [None]:
# Same backend selection as earlier in the notebook, repeated so this section
# can be run on its own: first available of CUDA / MPS, otherwise CPU.
# The availability checks are wrapped in lambdas so each backend is only
# probed if every higher-priority one turned out to be unavailable.
device = next(
    (
        backend
        for backend, is_available in (
            ("cuda", lambda: torch.cuda.is_available()),
            ("mps", lambda: torch.backends.mps.is_available()),
        )
        if is_available()
    ),
    "cpu",
)
print(f"Using {device} device")

In [4]:
# Network for the Lightning run: 4 input channels, 13 output classes,
# placed on the device selected above.
model = MVLidarNet(n_classes=13, in_channels=4)
model = model.to(device)

In [5]:
# reduction='none' keeps one loss value per element instead of averaging,
# leaving any masking/averaging to the caller.
loss_fn = CrossEntropyLoss(reduction='none')

In [6]:
# Epoch budget for the Lightning run.
n_epochs = 25

# Wrap the network in the Lightning task. The task receives its own
# un-reduced cross-entropy loss, the data module's visualisation transform,
# and the total number of optimisation steps (epochs x batches per epoch).
learner = SemanticSegmentationTask(
    model,
    CrossEntropyLoss(reduction='none'),
    data.viz_tfm,
    total_steps=epoch_steps * n_epochs,
)

In [None]:
# Name the W&B run after the model plus the wall-clock time it was started.
model_name = "UFGSim-MVLidarNet"
timestamp = f"{datetime.now():%d/%m/%Y_%H:%M:%S}"
experiment_name = "_".join((model_name, timestamp))

# Logger for the "colorcloud" project; also attach it to the wrapped network.
wandb_logger = WandbLogger(
    name=experiment_name,
    project="colorcloud",
    log_model="all",
)
wandb_logger.watch(learner.model, log="all")

In [None]:
# Fit the learner on the data module for n_epochs, logging to W&B,
# then write a checkpoint (saved with weights_only=True) to disk.
trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=n_epochs,
)
trainer.fit(learner, data)
trainer.save_checkpoint(
    "ufgsim_mvlidarnet_1.ckpt",
    weights_only=True,
)

In [None]:
# Close the active W&B run now that the Lightning training above is done.
wandb.finish()

# PyTorch Training

In [82]:
# Projection settings: +/-15 degree field of view, output size W=440 by H=16.
proj = SphericalProjection(fov_up_deg=15., fov_down_deg=-15., W=440, H=16)

# Per-sample pipeline: project the scan, then convert the result to tensors.
tfms = v2.Compose([ProjectionSimTransform(proj), ProjectionToTensorTransformSim()])

data_path = '/workspace/data'
batch_size = 4

# Build the train and validation datasets with the same root and transforms.
train_dataset, val_dataset = (
    UFGSimDataset(data_path=data_path, split=split, transform=tfms)
    for split in ('train', 'valid')
)

# Options shared by both loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)

In [83]:
# Epoch budget for the plain-PyTorch run (overrides the Lightning value above).
n_epochs = 20

# Fresh network, loss and optimiser for this run.
model = MVLidarNet(in_channels=4, n_classes=13).to(device)
# reduction='none' keeps the per-element loss so the training loop can drop
# masked-out pixels before averaging.
loss_func = torch.nn.CrossEntropyLoss(reduction='none')
opt = AdamW(model.parameters(), lr=5e-4, eps=1e-5)

# Metric objects plus the histories the training loop appends to
# (one entry per processed batch).
accuracy = Accuracy(task="multiclass", num_classes=model.n_classes).to(device)
accuracy_dict = {"train": [], "val": []}

miou = MeanIoU(num_classes=model.n_classes).to(device)
miou_dict = {"train": [], "val": []}

# len(DataLoader) is already the number of batches per epoch. Dividing it by
# batch_size again (as this cell used to) under-counts the steps, which
# inflates the averaged losses and yields 0 for very small loaders.
train_steps = len(train_loader)
test_steps = len(val_loader)
H = {"train_loss": [], "test_loss": []} # store loss history

In [None]:
def _masked_batch_step(batch):
    """Run one forward pass on `batch` and score it on the masked-in pixels.

    Every tensor in `batch` is moved to `device`; the 'frame', 'label' and
    'mask' entries are then used as network input, target and validity mask.

    Returns:
        (loss, acc, iou): the mean loss over masked-in pixels (still attached
        to the autograd graph so the caller can backpropagate), and the
        batch values returned by the `accuracy` and `miou` metric objects.
    """
    item = {key: value.to(device) for key, value in batch.items()}
    img, label, mask = item['frame'], item['label'], item['mask']

    # Overwrite masked-out targets with class 0; their loss entries are
    # discarded by the mask indexing just below.
    label[~mask] = 0

    pred = model(img)
    loss = loss_func(pred, label)[mask].mean()

    # Accuracy is computed on masked-in pixels only.
    pred_f = torch.permute(pred, (0, 2, 3, 1)) # N,C,H,W -> N,H,W,C
    pred_f = torch.flatten(pred_f, 0, -2)      # N,H,W,C -> N*H*W,C
    mask_f = torch.flatten(mask)               # N,H,W   -> N*H*W
    acc = accuracy(pred_f[mask_f, :], label[mask])

    # For mIoU, masked-out pixels are forced to class 0 in both the
    # prediction and the target (the target was zeroed above).
    pred_labels = torch.argmax(pred, dim=1)
    pred_labels[~mask] = 0
    iou = miou(pred_labels, label)

    return loss, acc, iou


start_time = time.time()

for epoch in tqdm(range(n_epochs)):

    # ---- training pass ----
    model.train()
    total_train_loss = 0.0
    n_train_batches = 0

    for batch in train_loader:
        train_loss, current_train_acc, current_train_miou = _masked_batch_step(batch)

        opt.zero_grad()
        train_loss.backward()
        opt.step()

        accuracy_dict["train"].append(current_train_acc.detach())
        miou_dict["train"].append(current_train_miou.detach())

        # Accumulate a plain float: summing the loss tensor itself would keep
        # a reference to every step's autograd history for the whole epoch.
        total_train_loss += train_loss.item()
        n_train_batches += 1

    # ---- validation pass ----
    model.eval()
    total_test_loss = 0.0
    n_test_batches = 0

    with torch.no_grad():
        for batch in val_loader:
            test_loss, current_test_acc, current_test_miou = _masked_batch_step(batch)

            accuracy_dict["val"].append(current_test_acc)
            miou_dict["val"].append(current_test_miou)

            total_test_loss += test_loss.item()
            n_test_batches += 1

    # Average over the batches actually processed (guarding the empty case)
    # rather than over a step count computed elsewhere.
    avg_train_loss = total_train_loss / max(n_train_batches, 1)
    avg_test_loss = total_test_loss / max(n_test_batches, 1)

    # Store loss history for graphical visualization
    H["train_loss"].append(avg_train_loss)
    H["test_loss"].append(avg_test_loss)

    print("CURRENT EPOCH: {}/{}".format(epoch + 1, n_epochs))
    print("Train loss: {:.10f}, Test loss {:.4f}".format(avg_train_loss, avg_test_loss))

end_time = time.time()
print("Training took {:.2f}s".format(end_time - start_time))
# NOTE: the two lines below report the metrics of the *last* processed batch
# of the final epoch, not an average over the epoch.
print("Accuracy: {:.4f} on training and {:.4f} on testing".format(current_train_acc, current_test_acc))
print("Mean IOU: {:.4f} on training and {:.4f} on testing".format(current_train_miou, current_test_miou))

In [None]:
# Save model
# The module object itself (not just its state_dict) is passed to torch.save.
# NOTE(review): reloading such a file presumably requires the MVLidarNet class
# to be importable at load time - confirm before sharing the artifact.
# `model_name` is reused here as the output file name; earlier in the notebook
# it held the W&B experiment prefix.
model_name = "ufgsim_mvlidar_torch.pt"
torch.save(model, model_name)

## Plotting

In [None]:
# Per-epoch average train/validation loss recorded in H by the training loop.
plt.style.use("ggplot")
fig, ax = plt.subplots()
for series in ("train_loss", "test_loss"):
    ax.plot(H[series], label=series)
ax.set_title("Training Loss on Dataset")
ax.set_xlabel("Epoch #")
ax.set_ylabel("Loss")
ax.legend(loc="lower left")
plt.show()

In [None]:
# Accuracy
plt.style.use("ggplot")
plt.figure()

# Per-batch values exactly as recorded by the training loop.
train_accuracy = [x.cpu().numpy() for x in accuracy_dict["train"]]
val_accuracy = [x.cpu().numpy() for x in accuracy_dict["val"]]

# The histories hold one value per *batch*, and the train/val loaders have
# different lengths, so plotting them raw against an "Epoch #" axis is
# misleading. Collapse each run of consecutive batches belonging to one epoch
# into its mean so both curves share a true per-epoch x-axis.
# Assumes the histories come from a single uninterrupted run of the loop; a
# trailing partial epoch is averaged over the batches it has.
train_steps_per_epoch = max(len(train_loader), 1)
val_steps_per_epoch = max(len(val_loader), 1)

train_accuracy_epoch = [
    np.mean(train_accuracy[i:i + train_steps_per_epoch])
    for i in range(0, len(train_accuracy), train_steps_per_epoch)
]
val_accuracy_epoch = [
    np.mean(val_accuracy[i:i + val_steps_per_epoch])
    for i in range(0, len(val_accuracy), val_steps_per_epoch)
]

plt.plot(train_accuracy_epoch, label="train_accuracy")
plt.plot(val_accuracy_epoch, label="val_accuracy")
plt.title("Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Accuracy")
plt.legend(loc="lower left")
plt.grid()
plt.show()

In [None]:
# Mean IoU
plt.figure()

# Per-batch values exactly as recorded by the training loop.
train_miou = [x.cpu().numpy() for x in miou_dict["train"]]
val_miou = [x.cpu().numpy() for x in miou_dict["val"]]

# The histories hold one value per *batch*, and the train/val loaders have
# different lengths, so plotting them raw against an epoch axis is misleading.
# Average each epoch's batches so both curves share a per-epoch x-axis.
# Assumes the histories come from a single uninterrupted run of the loop; a
# trailing partial epoch is averaged over the batches it has.
train_steps_per_epoch = max(len(train_loader), 1)
val_steps_per_epoch = max(len(val_loader), 1)

train_miou_epoch = [
    np.mean(train_miou[i:i + train_steps_per_epoch])
    for i in range(0, len(train_miou), train_steps_per_epoch)
]
val_miou_epoch = [
    np.mean(val_miou[i:i + val_steps_per_epoch])
    for i in range(0, len(val_miou), val_steps_per_epoch)
]

plt.plot(train_miou_epoch, label="train_miou")
plt.plot(val_miou_epoch, label="val_miou")
plt.title("Mean IOU")
plt.xlabel("Epoch #")
plt.ylabel("miou")
plt.legend(loc="lower left")
plt.grid()
plt.show()