In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from colorcloud.behley2019iccv import SphericalProjection, ProjectionToTensorTransform
from colorcloud.UFGsim2024infufg import UFGSimDataset, ProjectionSimTransform
from colorcloud.biasutti2019riu import RIUNet
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import torch

In [3]:
data_path = '../UFGSim'
ds = UFGSimDataset(data_path)

proj = SphericalProjection(fov_up_deg=15., fov_down_deg=-15., W=440, H=16)
tfms = v2.Compose([
    ProjectionSimTransform(proj),
    ProjectionToTensorTransform(),
])
ds.set_transform(tfms)

bs = 3
dl = DataLoader(ds, bs)

In [4]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [5]:
model = RIUNet(in_channels=4, hidden_channels=(64, 128, 256, 512), n_classes=13).to(device)
print(model)

RIUNet(
  (input_norm): BatchNorm2d(4, eps=1e-05, momentum=None, affine=False, track_running_stats=True)
  (backbone): Sequential(
    (enc): Encoder(
      (blocks): ModuleList(
        (0): Block(
          (conv1): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=circular)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
          (relu1): ReLU()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=circular)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
          (relu2): ReLU()
        )
        (1): Block(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=circular)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
          (relu1): ReLU()
          (conv2): Conv2d(128, 128, ke

In [6]:
loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)

In [None]:
n_epochs = 50
model.train()
for epoch in range(n_epochs):
    for batch in dl:
        img, label, mask = batch
        img, label, mask = img.to(device), label.to(device), mask.to(device)
        label[~mask] = 0
        
        # Compute prediction error
        pred = model(img)
        loss = loss_fn(pred, label)
        loss = loss[mask].mean()
        
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if epoch % 5 == 0:
            print(f"loss: {loss.item():>7f}")

loss: 2.999776
loss: 1.864565
loss: 1.284580
loss: 0.928722
loss: 0.754632
loss: 0.662107
loss: 0.623375
loss: 0.624365
loss: 0.621404
loss: 0.617286
loss: 0.643031
loss: 0.628119
loss: 0.655748
loss: 0.665729
loss: 0.686211
loss: 0.638563
loss: 0.522562
loss: 0.461450
loss: 0.448920
loss: 0.421206
loss: 0.438708
loss: 0.440686
loss: 0.456340
loss: 0.474012
loss: 0.450407
loss: 0.411703
loss: 0.388037
loss: 0.437849
loss: 0.468884
loss: 0.387367
loss: 0.354124
loss: 0.308671
loss: 0.307039
loss: 0.297971
loss: 0.312220
loss: 0.298699
loss: 0.324711
loss: 0.341010
loss: 0.391715
loss: 0.358085
loss: 0.330169
loss: 0.336429
loss: 0.368509
loss: 0.346900
loss: 0.332161
loss: 0.331838
loss: 0.334887
loss: 0.381763
loss: 0.371034
loss: 0.378392
loss: 0.420559
loss: 0.903443
loss: 1.272106
loss: 0.972197
loss: 0.757011
loss: 0.612080
loss: 0.492289
loss: 0.440999
loss: 0.352335
loss: 0.269848
loss: 0.194311
loss: 0.146679
loss: 0.122284
loss: 0.100552
loss: 0.077998
loss: 0.064796
loss: 0.05