In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from colorcloud.UFGsim2024infufg import SemanticSegmentationSimLDM, ProjectionSimVizTransform
from colorcloud.biasutti2019riu import SemanticSegmentationTask, RIUNet
from torch.nn import CrossEntropyLoss
import lightning as L
import wandb
from lightning.pytorch.loggers import WandbLogger
from datetime import datetime
import torch
import numpy as np

In [3]:
# Build the simulated-data datamodule and prepare its 'fit' stage so the
# training dataloader exists before the learner is constructed.
data = SemanticSegmentationSimLDM()
data.setup('fit')
# Number of training batches per epoch; multiplied by n_epochs when building
# the learner to obtain total_steps.
epoch_steps = len(data.train_dataloader())

In [3]:
# Pick the best available torch backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [4]:
# Training length; together with epoch_steps it fixes the learner's total_steps.
n_epochs = 25
# RIUNet over 4 input channels predicting 13 classes. The loss keeps
# per-element values (reduction='none'), so reduction is left to the task.
learner = SemanticSegmentationTask(
    model=RIUNet(
        in_channels=4,
        hidden_channels=(64, 128, 256, 512),
        n_classes=13,
    ),
    loss_fn=CrossEntropyLoss(reduction='none'),
    viz_tfm=data.viz_tfm,
    total_steps=epoch_steps * n_epochs,
)

In [None]:
# Unique W&B run name built from the current local time. strftime zero-pads
# every field (e.g. "UFGSim-2024-03-05_09-07-02"), so run names have a fixed
# width and sort chronologically. The hand-concatenated str() version produced
# ambiguous names like "2024-3-5_9-7-2". This also avoids binding a global
# called `time`, which would shadow the stdlib module.
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
name = 'UFGSim-' + timestamp
# Log this run to the "colorcloud" W&B project under the generated run name.
# log_model="all" asks the logger to upload model checkpoints as well.
wandb_logger = WandbLogger(
    name=name,
    project="colorcloud",
    log_model="all",
)
# Track the underlying network with log="all" (W&B watch option).
wandb_logger.watch(learner.model, log="all")

In [None]:
# Train the learner for n_epochs with W&B logging, then write the trained
# weights to a local checkpoint file.
trainer = L.Trainer(logger=wandb_logger, max_epochs=n_epochs)
trainer.fit(model=learner, datamodule=data)
# weights_only=True stores only the model weights; the inference section
# reloads this file via SemanticSegmentationTask.load_from_checkpoint.
trainer.save_checkpoint("ufgsim_riunet_354.ckpt", weights_only=True)

In [None]:
# Finish the W&B run used by the training logger before moving on to inference.
wandb.finish()

## Inference tryouts

Reload the trained checkpoint and run the model on a single frame from the prediction split.

In [4]:
# Prediction-split datamodule, used for the visualisation transform and to
# size total_steps.
data = SemanticSegmentationSimLDM()
data.setup('predict')
epoch_steps = len(data.predict_dataloader())
n_epochs = 25
# NOTE(review): this is derived from the *predict* loader length, not the
# training loader used at L40, so it differs from the training-time value.
# It appears to be passed only to satisfy the task constructor — confirm it
# is unused at inference.
total_steps = n_epochs * epoch_steps

# Fresh network / loss / transform objects handed to the task constructor.
# The checkpoint's weights are then loaded into them.
model = RIUNet(
    in_channels=4,
    hidden_channels=(64, 128, 256, 512),
    n_classes=13,
)
loss_fn = CrossEntropyLoss(reduction='none')
viz_tfm = data.viz_tfm

loaded_model = SemanticSegmentationTask.load_from_checkpoint(
    "ufgsim_riunet_354.ckpt",
    model=model,
    loss_fn=loss_fn,
    viz_tfm=viz_tfm,
    total_steps=total_steps,
)

In [5]:
# Datamodule set up for the 'predict' stage; its ds_predict dataset supplies
# the frame used for single-frame inference.
# NOTE(review): duplicates the prediction datamodule already built in the
# checkpoint-loading cell (`data` after setup('predict')). `epoch_steps` is
# recomputed here but not used by the cells that follow in this view.
datas = SemanticSegmentationSimLDM()
datas.setup('predict')
epoch_steps = len(datas.predict_dataloader())

In [6]:
# Take sample index 2 of the prediction dataset. Each sample unpacks into three
# items; only the first (the model input fed to the network later) is kept.
# NOTE(review): binding `_` here shadows IPython's last-output variable.
frame, _, _ = datas.ds_predict[2]
# NumPy copy used only to inspect the shape. The recorded output is
# (4, 16, 440); the leading 4 matches RIUNet's in_channels=4.
frames = np.array(frame)
frames.shape

(4, 16, 440)

In [17]:
# Single-frame inference with the reloaded checkpoint.
loaded_model.to(device)
loaded_model.eval()  # inference mode for layers such as dropout / batch norm
with torch.no_grad():
    # Add a leading batch axis and move to the model's device WITHOUT rebinding
    # `frame`. The old version overwrote `frame` with the batched tensor, so
    # re-running this cell stacked an extra dimension each time and broke the
    # forward pass. The NumPy round-trip it used is also unnecessary.
    batch = frame.unsqueeze(0).to(device)
    # Drop only the batch axis. A bare squeeze() would also remove any other
    # size-1 axis. Assumes output is (1, n_classes, H, W) — TODO confirm.
    y_hat = loaded_model(batch).squeeze(0)
    # Per-pixel predicted class index: argmax over the class axis.
    argmax = torch.argmax(y_hat, dim=0)
    pred = argmax.cpu().numpy()