<a href="https://colab.research.google.com/github/antonbaumann/MIMO-Unet/blob/main/MIMO_U_Net_NYUv2_depth.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Download data and dependencies

In [None]:
# Load the TensorBoard notebook extension so the `%tensorboard` magic can be used in a later cell.
%load_ext tensorboard

In [None]:
import sys

# download NYU Depth Dataset V2
!wget -c https://www.dropbox.com/s/qtab28cauzalqi7/depth_data.tar.gz?dl=1 -O depth_data.tar.gz
!mkdir data && tar -xzvf depth_data.tar.gz -C data

# clone MIMO U-Net repository
!rm -r MIMO-Unet; git clone https://github.com/antonbaumann/MIMO-Unet.git

# add repository to PATH
sys.path.append('/content/MIMO-Unet/')

# install MIMO U-Net dependencies
!pip install -r MIMO-Unet/requirements.txt

In [None]:
from typing import List
from datetime import datetime

import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger

from mimo.models.mimo_unet import MimoUnetModel
from mimo.tasks.depth.nyuv2_datamodule import NYUv2DepthDataModule
from mimo.tasks.depth.callbacks import OutputMonitor, WandbMetricsDefiner

In [None]:
def default_callbacks(validation: bool = True) -> List[pl.Callback]:
    """Assemble the list of Lightning callbacks used for training.

    Args:
        validation: When True, an additional checkpoint callback that monitors
            ``val_loss`` is appended to the always-present callbacks.

    Returns:
        A list containing an ``OutputMonitor``, a ``ModelCheckpoint`` configured
        with ``save_last=True`` and, if ``validation`` is set, a
        ``ModelCheckpoint`` that keeps the single best checkpoint by ``val_loss``.
    """
    always_enabled = [OutputMonitor(), ModelCheckpoint(save_last=True)]

    if not validation:
        return always_enabled

    # The filename template already spells out its own labels ("epoch-", "valloss-", ...),
    # so automatic insertion of metric names is switched off.
    best_val_checkpoint = ModelCheckpoint(
        monitor="val_loss",
        save_top_k=1,
        filename="epoch-{epoch}-step-{step}-valloss-{val_loss:.8f}-mae-{metric_val/mae_epoch:.8f}",
        auto_insert_metric_name=False,
    )
    always_enabled.append(best_val_checkpoint)
    return always_enabled

# Example: PyTorch Lightning
Initialize `datamodule`, `model`, `logger` and `trainer`

In [None]:
# Values a reader is most likely to tune, collected in one place.
SEED = 1                   # used for Lightning's global seeding and for the model's `seed` argument
LOG_DIR = '/content/logs'  # TensorBoard event files
RUN_DIR = '/content/runs'  # Trainer default root directory

pl.seed_everything(SEED)

# Data: NYU Depth V2 files extracted into `data/` by the download cell.
dm = NYUv2DepthDataModule(
    dataset_dir='data',
    batch_size=32,
    num_workers=3,
    pin_memory=True,
    normalize=True,
)

# Every dropout rate is set to zero in this example.
dropout_rates = {
    'center_dropout_rate': 0.0,
    'final_dropout_rate': 0.0,
    'encoder_dropout_rate': 0.0,
    'core_dropout_rate': 0.0,
    'decoder_dropout_rate': 0.0,
}

# Model: MIMO U-Net with two subnetworks, trained with the 'laplace_nll' loss.
model = MimoUnetModel(
    in_channels=3,
    out_channels=2,
    num_subnetworks=2,
    filter_base_count=21,
    loss='laplace_nll',
    loss_buffer_size=10,
    loss_buffer_temperature=0.3,
    input_repetition_probability=0.0,
    batch_repetitions=1,
    learning_rate=1e-3,
    weight_decay=0.0,
    seed=SEED,
    **dropout_rates,
)

# Logger: TensorBoard event files are written below LOG_DIR.
tensorboard_logger = TensorBoardLogger(save_dir=LOG_DIR)

# Trainer: single GPU, mixed 16-bit precision, 15 epochs.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    precision="16-mixed",
    max_epochs=15,
    log_every_n_steps=200,
    default_root_dir=RUN_DIR,
    logger=tensorboard_logger,
    callbacks=default_callbacks(),
)

In [None]:
# Open the TensorBoard UI inline, reading from the same directory that is passed
# as `save_dir` to the TensorBoardLogger.
%tensorboard --logdir /content/logs

In [None]:
# Attach the wall-clock start time (ISO-8601, second resolution) to the trainer
# before fitting. NOTE(review): presumably read by code inside the MIMO-Unet
# repository — confirm where `started_at` is consumed.
trainer.started_at = datetime.now().isoformat(timespec="seconds")

# Run training (and validation) using the datamodule defined above.
trainer.fit(model, dm)

# Example: raw PyTorch
You can build your own training loop as shown below if you do not want to use `lightning`.

In [None]:
import torch

from mimo.models.mimo_components.model import MimoUNet
from mimo.models.mimo_components.loss_buffer import LossBuffer
from mimo.models.utils import apply_input_transform
from mimo.datasets.nyuv2 import NYUv2DepthDataset
from mimo.losses import LaplaceNLL

In [None]:
# --- configuration -----------------------------------------------------------
out_channels = 1                   # number of predicted target channels
num_subnetworks = 2                # number of MIMO subnetworks
input_repetition_probability = 0   # forwarded to `apply_input_transform` in the training loop
batch_repetitions = 1              # forwarded to `apply_input_transform` in the training loop

# Use the GPU when one is available and fall back to the CPU otherwise, so that
# this cell does not crash on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- model -------------------------------------------------------------------
# The network emits `out_channels * 2` channels: the training loop splits them
# into two parameter maps (`p1`, `p2`) of `out_channels` channels each.
model = MimoUNet(
    in_channels=3,
    out_channels=out_channels * 2,
    num_subnetworks=num_subnetworks,
    filter_base_count=21,
)

model = model.to(device)

# --- loss --------------------------------------------------------------------
# The loss buffer supplies one weight per subnetwork; the training loop multiplies
# the per-subnetwork losses by these weights.
loss_buffer = LossBuffer(
    subnetworks=num_subnetworks,
    temperature=0.3,
    buffer_size=10,
)

criterion = LaplaceNLL()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# --- data --------------------------------------------------------------------
data_train = NYUv2DepthDataset(
    dataset_path='data/depth_train.h5',
    normalize=True,
    shuffle_on_load=True,
)

train_loader = torch.utils.data.DataLoader(
    data_train,
    batch_size=32,
    shuffle=True,
)

data_test = NYUv2DepthDataset(
    dataset_path='data/depth_test.h5',
    normalize=True,
    shuffle_on_load=False,
)

# The test loader keeps the dataset order (no shuffling).
test_loader = torch.utils.data.DataLoader(
    data_test,
    batch_size=32,
    shuffle=False,
)

In [None]:
# train for one epoch
model.train()  # make sure train-mode behaviour is active even if the model was evaluated before

for batch in train_loader:
  image, label = batch["image"].to(device), batch["label"].to(device)

  # Arrange inputs and labels for the MIMO subnetworks.
  # No mask is used in this example, so `mask_transformed` is not needed below.
  image_transformed, label_transformed, mask_transformed = apply_input_transform(
    image,
    label,
    mask=None,
    num_subnetworks=num_subnetworks,
    input_repetition_probability=input_repetition_probability,
    batch_repetitions=batch_repetitions,
  )

  # Reset gradients from the previous iteration. Without this, `backward()` keeps
  # adding to the existing `.grad` buffers and every step after the first would
  # use gradients accumulated over all earlier batches.
  optimizer.zero_grad()

  out = model(image_transformed)
  # Split the output along dim 2 into the two parameter maps expected by the criterion.
  # NOTE(review): presumably `out` is (batch, subnetwork, channel, H, W) — confirm.
  p1 = out[:, :, :out_channels, ...]
  p2 = out[:, :, out_channels:, ...]

  # Point prediction derived from the two parameter maps; it is not used by the
  # loss below and is computed only to show how a prediction is obtained.
  y_pred = criterion.mode(p1, p2)

  # Unreduced loss, averaged over every dim except dim 1, which leaves one loss
  # value per entry of dim 1 (presumably one per subnetwork — confirm).
  raw_loss = criterion.forward(p1, p2, label_transformed, reduce_mean=False, mask=None)
  loss = raw_loss.mean(dim=(0, 2, 3, 4))

  # Fetch the weights BEFORE adding the current loss, so the weights are based
  # on previous iterations only.
  weights = loss_buffer.get_weights().to(loss.device)
  loss_buffer.add(loss.detach())

  weighted_loss = (loss * weights).mean()

  weighted_loss.backward()
  optimizer.step()

  print(f'loss={loss.mean().cpu().detach()}')