In [12]:
from colorcloud.UFGsim2024infufg import SemanticSegmentationSimLDM, ProjectionSimVizTransform, UFGSimDataset
from colorcloud.biasutti2019riu import SemanticSegmentationTask, RIUNet
import matplotlib.pyplot as plt
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
import lightning as L
import wandb
import time
from lightning.pytorch.loggers import WandbLogger
from datetime import datetime
import torch
import numpy as np
from tqdm import tqdm
from torchmetrics.classification import Accuracy

In [3]:
# Location of the UFGSim data on disk.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
data_path = '/workspace/UFGSim/'

# Data module that later provides the train/test dataloaders; it is set up for
# both stages up front. (Presumably a Lightning DataModule — TODO confirm.)
data = SemanticSegmentationSimLDM()
for stage in ('fit', 'test'):
    data.setup(stage)

# Stand-alone dataset objects for the default split and the 'test' split.
train_ds = UFGSimDataset(data_path)
test_ds = UFGSimDataset(data_path, split='test')

In [4]:
# Choose the compute backend in order of preference: CUDA, then Apple MPS,
# then plain CPU as the fallback.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [8]:
# --- Training configuration -------------------------------------------------
n_epochs = 25

# reduction='none' keeps one loss value per element so that the training loop
# can select entries with the batch mask before averaging.
loss_func = CrossEntropyLoss(reduction='none')

# Network, moved to the selected device, and its optimizer.
model = RIUNet(
    in_channels=4,
    hidden_channels=(64, 128, 256, 512),
    n_classes=13,
).to(device)
opt = AdamW(params=model.parameters(), lr=5e-4, eps=1e-5)

# Step counts derived from the dataset sizes.
# NOTE(review): the divisors 8 and 16 presumably mirror the train/test loader
# batch sizes — TODO confirm against the data module.
train_steps = len(train_ds) // 8
test_steps = len(test_ds) // 16

# One multiclass accuracy metric per split, both living on the model's device.
train_acc, test_acc = (
    Accuracy(task="multiclass", num_classes=model.n_classes).to(device)
    for _ in range(2)
)

# History of per-epoch average losses, filled in by the training loop.
H = {name: [] for name in ("train_loss", "test_loss")}

In [None]:
def _masked_batch_step(batch, metric):
    """Run the model on one batch and score only the positions selected by the mask.

    Args:
        batch: dict of tensors with at least the keys 'frame', 'label' and
            'mask'; every value is moved to ``device``.
        metric: torchmetrics metric whose state is updated with the masked
            predictions and labels of this batch.

    Returns:
        Scalar tensor: mean cross-entropy over the masked positions. It is
        still attached to the autograd graph when gradients are enabled.
    """
    item = {key: value.to(device) for key, value in batch.items()}
    img = item['frame']
    mask = item['mask']
    # Labels outside the mask are overwritten with class 0; their loss values
    # are discarded by the mask indexing below.
    # (Presumably this guards against invalid label values there — TODO confirm.)
    label = item['label'].masked_fill(~mask, 0)

    pred = model(img)
    # loss_func uses reduction='none', so the mask can pick entries before averaging.
    loss = loss_func(pred, label)[mask].mean()

    # N,C,H,W -> N,H,W,C, then boolean indexing with the N,H,W mask -> K,C
    pred_m = torch.permute(pred, (0, 2, 3, 1))[mask]
    label_m = label[mask]
    metric.update(pred_m.detach(), label_m)
    return loss


print("!!! TRAINING !!!")
start_time = time.time()

# Defined up front so the final report does not fail if no epoch is run.
current_train_acc = float("nan")
current_test_acc = float("nan")

for epochs in tqdm(range(n_epochs)):

    # Start every epoch with clean metric state so accuracies are per-epoch.
    train_acc.reset()
    test_acc.reset()

    total_train_loss = 0.0
    total_test_loss = 0.0
    n_train_batches = 0
    n_test_batches = 0

    model.train()
    for batch in data.train_dataloader():
        train_loss = _masked_batch_step(batch, train_acc)

        opt.zero_grad()
        train_loss.backward()
        opt.step()

        # Detach so the running sum does not carry autograd history.
        total_train_loss += train_loss.detach()
        n_train_batches += 1

    model.eval()
    with torch.no_grad():
        for batch in data.test_dataloader():
            # Test batches update the test metric (not the training one).
            test_loss = _masked_batch_step(batch, test_acc)
            total_test_loss += test_loss
            n_test_batches += 1

    # Average over the batches actually iterated rather than over step counts
    # derived from an assumed batch size.
    avg_train_loss = float(total_train_loss) / max(n_train_batches, 1)
    avg_test_loss = float(total_test_loss) / max(n_test_batches, 1)

    # Accuracy accumulated over the whole epoch, not just the last batch.
    current_train_acc = train_acc.compute().item()
    current_test_acc = test_acc.compute().item()

    H["train_loss"].append(avg_train_loss)
    H["test_loss"].append(avg_test_loss)

    print("CURRENT EPOCH: {}/{}".format(epochs + 1, n_epochs))
    print("Train loss: {:.10f}, Test loss {:.4f}".format(avg_train_loss, avg_test_loss))

end_time = time.time()
print("Training took {:.2f}s".format(end_time - start_time))
print("Accuracy: {:.4f} on training and {:.4f} on testing".format(current_train_acc, current_test_acc))

In [None]:
# Plot the per-epoch average losses recorded in H, one curve per split.
plt.style.use("ggplot")
fig, ax = plt.subplots()
for series in ("train_loss", "test_loss"):
    ax.plot(H[series], label=series)
ax.set_title("Training Loss on Dataset")
ax.set_xlabel("Epoch #")
ax.set_ylabel("Loss")
ax.legend(loc="lower left")

In [14]:
# Serialize the whole module object (not just its state_dict) to disk.
# NOTE(review): loading this file back requires the model class to be importable.
torch.save(obj=model, f="ufgsim_riunet_torch.pt")