In [1]:
%load_ext autoreload
%autoreload 2

In [14]:
from colorcloud.cheng2023TransRVNet import TransVRNet
from colorcloud.cheng2023TransRVNet import SemanticSegmentationTask
from colorcloud.cheng2023TransRVNet import TransRVNet_loss
from colorcloud.cheng2023TransRVNet import RandomRotationTransform
from colorcloud.cheng2023TransRVNet import RandomDroppingPointsTransform
from colorcloud.cheng2023TransRVNet import RandomSingInvertingTransform
from colorcloud.UFGsim2024infufg import SemanticSegmentationSimLDM
from colorcloud.UFGsim2024infufg import ProjectionSimVizTransform
from colorcloud.UFGsim2024infufg import UFGSimDataset
from colorcloud.UFGsim2024infufg import UFGSimDataset
from colorcloud.UFGsim2024infufg import SphericalProjection
from colorcloud.UFGsim2024infufg import ProjectionSimTransform
from colorcloud.UFGsim2024infufg import ProjectionToTensorTransformSim

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torchvision
from torchvision.transforms import v2

import matplotlib.pyplot as plt
import numpy as np
import wandb
import time
from datetime import datetime
from tqdm import tqdm

import lightning as L
from lightning.pytorch.loggers import WandbLogger
from torchmetrics.classification import Accuracy
from torchmetrics.segmentation import MeanIoU
from torchmetrics.classification import Dice
from torchmetrics.classification import MulticlassF1Score

## Setup

In [15]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


### Convolution Parameters

In [16]:
# MRCIAM branches p1 and p2 share every layer width; only the input channel
# count of the first block differs (1 for p1, 3 for p2). Unpacking the template
# into a fresh literal gives each branch its own dict (no shared aliasing).
_mrciam_branch_widths = dict(
    b1_out1=32, b1_out2=32, b1_out3=32,
    b2_in=64, b2_out=64,
    b3_in=64, b3_out1=64, b3_out2=64, b3_out3=64,
    output=64,
)
mrciam_p = {
    "p1": {"b1_in": 1, **_mrciam_branch_widths},
    "p2": {"b1_in": 3, **_mrciam_branch_widths},
    "output_conv": 192,
}

# Encoder: two modules; module_1 consumes the 192-channel MRCIAM output.
encoder_p = {
    "module_1": dict(
        in_channels=192,
        conv2_in_channels=128,
        conv2_out_channels=256,
        dilated_conv_out_channels=256,
        residual_out_channels=256,
    ),
    "module_2": dict(
        in_channels=256,
        conv2_in_channels=256,
        conv2_out_channels=256,
        dilated_conv_out_channels=256,
        residual_out_channels=256,
    ),
}

decoder_p = dict(
    in_channels=264,
    conv2_in_channels=128,
    dilated_conv_in_channels=128,
    dilated_conv_out_channels=64,
    output=32,
)

# Bottleneck transformer settings; the embedding width is one eighth of the
# encoder's final residual channel count.
p_bntm = {
    "window_size": (4, 4),
    "embed_dim": encoder_p["module_2"]["residual_out_channels"] // 8,
}

## PyTorch Training

In [17]:
proj = SphericalProjection(fov_up_deg=15., fov_down_deg=-15., W=440, H=16)

# Point-cloud augmentations.
# NOTE(review): aug_tfms is built but not passed to either dataset; TODO decide
# whether it should run before the projection in the training transform.
aug_tfms = v2.Compose([
    RandomDroppingPointsTransform(),
    RandomRotationTransform(),
    RandomSingInvertingTransform(),
])

# Per-sample transform: apply the spherical projection `proj`, then the
# to-tensor transform. Both datasets below need `tfms`; it must be defined in
# this cell, because with it commented out the cell only ran thanks to a stale
# kernel variable.
tfms = v2.Compose([
    ProjectionSimTransform(proj),
    ProjectionToTensorTransformSim(),
])

data_path = '/workspace/data'
train_dataset = UFGSimDataset(data_path=data_path, split='train', transform=tfms)
val_dataset = UFGSimDataset(data_path=data_path, split='valid', transform=tfms)

print("Size of train dataset: ", len(train_dataset))
print("Size of val dataset: ", len(val_dataset))

batch_size = 4

# Settings shared by both loaders; only shuffling differs (train only).
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

Size of train dataset:  330
Size of val dataset:  151


In [18]:
n_epochs = 20

model = TransVRNet(mrciam_p, encoder_p, decoder_p, p_bntm).to(device)
# NOTE(review): loss_fn is constructed here, but the training loop uses
# loss_func (plain cross-entropy) instead; TODO pick one.
loss_fn = TransRVNet_loss(device=device, file_name_yaml='ufg-sim.yaml')

# reduction='none' keeps the per-pixel loss so the loop can drop masked-out
# pixels before averaging.
loss_func = torch.nn.CrossEntropyLoss(reduction='none')
opt = AdamW(model.parameters(), lr=5e-4, eps=1e-5)

# Metric objects are shared by the train and val passes; each *_dict collects
# that metric's values under the "train" and "val" keys.
accuracy = Accuracy(task="multiclass", num_classes=model.n_classes).to(device)
accuracy_dict = {"train": [], "val": []}

miou = MeanIoU(num_classes=model.n_classes).to(device)
miou_dict = {"train": [], "val": []}

dice = Dice(num_classes=model.n_classes).to(device)
dice_dict = {"train": [], "val": []}

mcf1s = MulticlassF1Score(num_classes=model.n_classes, average="macro").to(device)
mcf1s_dict = {"train": [], "val": []}

# len(DataLoader) is already the number of batches per epoch (the loader applies
# batch_size itself), so it must not be divided by batch_size again -- doing so
# inflated the averaged loss ~batch_size-fold and could yield 0 steps.
train_steps = len(train_loader)
test_steps = len(val_loader)
H = {"train_loss": [], "test_loss": []}  # per-epoch mean loss history

In [20]:
def _run_epoch(loader, train):
    """Run one pass of `model` over `loader` and return the mean per-batch loss.

    With ``train=True`` the model is put in train mode and `opt` steps after
    every batch; otherwise it runs in eval mode with gradients disabled.
    The shared metric objects are reset first and updated once per batch, so
    calling ``.compute()`` on them afterwards gives the value over this pass.
    """
    model.train(train)
    for metric in (accuracy, miou, dice, mcf1s):
        metric.reset()

    total_loss = 0.0
    n_batches = 0
    with torch.set_grad_enabled(train):
        for batch in loader:
            item = {key: value.to(device) for key, value in batch.items()}
            img = item['frame']
            label = item['label']
            mask = item['mask']

            # Give masked-out pixels class 0 so every label is a valid index;
            # their loss is discarded below and their predictions forced to 0.
            label[~mask] = 0

            # NOTE(review): the last run failed on this call with
            # "TransVRNet.forward() missing 2 required positional arguments:
            # 'x2' and 'x3'". TODO: pass the three inputs the model expects.
            pred = model(img)
            loss = loss_func(pred, label)[mask].mean()

            # Metrics never need gradients.
            with torch.no_grad():
                pred_flat = torch.permute(pred, (0, 2, 3, 1))  # N,C,H,W -> N,H,W,C
                pred_flat = torch.flatten(pred_flat, 0, -2)    # N,H,W,C -> N*H*W,C
                accuracy.update(pred_flat[torch.flatten(mask), :], label[mask])

                pred_labels = torch.argmax(pred, dim=1)
                pred_labels[~mask] = 0
                miou.update(pred_labels, label)
                dice.update(pred_labels, label)
                mcf1s.update(pred_labels, label)

            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()

            # .item() yields a plain float; accumulating the tensor itself would
            # keep every batch's autograd graph alive for the whole epoch.
            total_loss += loss.item()
            n_batches += 1

    return total_loss / max(n_batches, 1)


def _record_epoch_metrics(split):
    """Append each metric's value over the last pass to its history[`split`].

    Returns the same values keyed by short name for printing.
    """
    scores = {}
    for name, metric, history in (
        ("acc", accuracy, accuracy_dict),
        ("miou", miou, miou_dict),
        ("dice", dice, dice_dict),
        ("f1", mcf1s, mcf1s_dict),
    ):
        scores[name] = metric.compute()
        history[split].append(scores[name])
    return scores


start_time = time.time()

for epoch in tqdm(range(n_epochs)):
    avg_train_loss = _run_epoch(train_loader, train=True)
    train_scores = _record_epoch_metrics("train")

    avg_test_loss = _run_epoch(val_loader, train=False)
    test_scores = _record_epoch_metrics("val")

    # Store loss history for graphical visualization
    H["train_loss"].append(avg_train_loss)
    H["test_loss"].append(avg_test_loss)

    print("CURRENT EPOCH: {}/{}".format(epoch + 1, n_epochs))
    print("Train loss: {:.10f}, Test loss {:.4f}".format(avg_train_loss, avg_test_loss))

end_time = time.time()
print("Training took {:.2f}s".format(end_time - start_time))
print("Accuracy: {:.4f} on training and {:.4f} on testing".format(train_scores["acc"], test_scores["acc"]))
print("Mean IOU: {:.4f} on training and {:.4f} on testing".format(train_scores["miou"], test_scores["miou"]))
print("Dice: {:.4f} on training and {:.4f} on testing".format(train_scores["dice"], test_scores["dice"]))
print("F1 Macro: {:.4f} on training and {:.4f} on testing".format(train_scores["f1"], test_scores["f1"]))

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]


TypeError: TransVRNet.forward() missing 2 required positional arguments: 'x2' and 'x3'

In [None]:
# Save model: torch.save pickles the whole nn.Module object into model_name.
# NOTE(review): a pickled module can only be reloaded where the TransVRNet class
# is importable under the same path; model.state_dict() is more portable.
model_name = "ufgsim_transRVNet_torch.pt"
torch.save(obj=model, f=model_name)

## Plotting

In [None]:
# Train vs. validation loss, one point per entry of the H history.
plt.style.use("ggplot")
fig, ax = plt.subplots()
for key in ("train_loss", "test_loss"):
    ax.plot(H[key], label=key)
ax.set(title="Training Loss on Dataset", xlabel="Epoch #", ylabel="Loss")
ax.legend(loc="lower left")
plt.show()

In [None]:
# Accuracy
plt.style.use("ggplot")
fig, ax = plt.subplots()

# Metric values are torch tensors (possibly on the GPU); move them to host.
train_accuracy = [x.cpu().numpy() for x in accuracy_dict["train"]]
val_accuracy = [x.cpu().numpy() for x in accuracy_dict["val"]]

ax.plot(train_accuracy, label="train_accuracy")
ax.plot(val_accuracy, label="val_accuracy")
ax.set(title="Accuracy", xlabel="Epoch #", ylabel="Accuracy")
ax.legend(loc="lower right")
# Explicit True: a bare grid() toggles, and ggplot already enables the grid,
# so grid() with no argument would have hidden it.
ax.grid(True)
plt.show()

In [None]:
# Mean IoU
# Set the style here too so the cell does not depend on an earlier cell.
plt.style.use("ggplot")
fig, ax = plt.subplots()

# Metric values are torch tensors (possibly on the GPU); move them to host.
train_miou = [x.cpu().numpy() for x in miou_dict["train"]]
val_miou = [x.cpu().numpy() for x in miou_dict["val"]]

ax.plot(train_miou, label="train_miou")
ax.plot(val_miou, label="val_miou")
ax.set(title="Mean IOU", xlabel="Epoch #", ylabel="miou")
ax.legend(loc="lower right")
# Explicit True: a bare grid() toggles and would hide ggplot's default grid.
ax.grid(True)
plt.show()

In [None]:
# Dice
# Set the style here too so the cell does not depend on an earlier cell.
plt.style.use("ggplot")
fig, ax = plt.subplots()

# Metric values are torch tensors (possibly on the GPU); move them to host.
train_dice = [x.cpu().numpy() for x in dice_dict["train"]]
val_dice = [x.cpu().numpy() for x in dice_dict["val"]]

ax.plot(train_dice, label="train_dice")
ax.plot(val_dice, label="val_dice")
ax.set(title="Dice", xlabel="Epoch #", ylabel="dice score")
ax.legend(loc="lower right")
# Explicit True: a bare grid() toggles and would hide ggplot's default grid.
ax.grid(True)
plt.show()

In [None]:
# F1 Macro
# Set the style here too so the cell does not depend on an earlier cell.
plt.style.use("ggplot")
fig, ax = plt.subplots()

# Metric values are torch tensors (possibly on the GPU); move them to host.
train_mcf1s = [x.cpu().numpy() for x in mcf1s_dict["train"]]
val_mcf1s = [x.cpu().numpy() for x in mcf1s_dict["val"]]

ax.plot(train_mcf1s, label="train_mcf1s")
ax.plot(val_mcf1s, label="val_mcf1s")
ax.set(title="F1 Macro", xlabel="Epoch #", ylabel="mcf1s score")
ax.legend(loc="lower right")
# Explicit True: a bare grid() toggles and would hide ggplot's default grid.
ax.grid(True)
plt.show()