In [16]:
# Set working directory to the Marigold repo root
%cd /home/boxcat/workspace/labwork/depthmodel/Marigold2

import sys
import os
path = '/home/boxcat/workspace/labwork/depthmodel/Marigold2/script/depth/tumor.ipynb'
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(path), "..", "..")))


import yaml
import torch
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image

import argparse
import logging
import shutil
import torch
from datetime import datetime, timedelta
from omegaconf import OmegaConf
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
from typing import List, Union


from marigold import MarigoldDepthPipeline
# Marigold-specific
from src.util.config_util import (
    find_value_in_omegaconf,
    recursive_load_config,
)
from marigold import MarigoldDepthPipeline
from src.dataset import BaseDepthDataset, DatasetMode, get_dataset
from src.dataset.mixed_sampler import MixedBatchSampler
from src.trainer import get_trainer_cls
from src.util.depth_transform import (
    DepthNormalizerBase,
    get_depth_normalizer,
)
from src.dataset import DatasetMode

# Absolute path of the Marigold repo checkout (same directory the notebook `%cd`s into).
# NOTE(review): not referenced by any cell visible here — confirm it is still needed.
home_dir = '/home/boxcat/workspace/labwork/depthmodel/Marigold2/'


/home/boxcat/workspace/labwork/depthmodel/Marigold2


In [18]:
# Training config to inspect; change this to point at your own YAML file
config_path = "./config/train_marigold_mytumor.yaml"
print(f"Loading config from: {config_path}")

# Load the config through the repo helper (pulls in nested dataset configs)
cfg = recursive_load_config(config_path)

# Optional: dump the full config tree as YAML for inspection
cfg_yaml = OmegaConf.to_yaml(cfg)
print(cfg_yaml)


Loading config from: ./config/train_marigold_mytumor.yaml
logging:
  filename: logging.log
  format: ' %(asctime)s - %(levelname)s -%(filename)s - %(funcName)s >> %(message)s'
  console_level: 20
  file_level: 10
wandb:
  project: marigold
dataset:
  train:
    batch_size: 2
    name: tumor_dataset
    type: paired_image_depth
    dir: ''
    filenames: data_split/tumor_depth/filenames_train.txt
    depth_format: image
    image_extension: png
    depth_extension: png
    has_invalid_depth: false
    shuffle: true
    drop_last: true
    infinite: true
  val:
  - name: tumor_dataset
    disp_name: mytumor_val
    dir: ''
    filenames: data_split/tumor_depth/filenames_val.txt
    depth_format: image
    image_extension: png
    depth_extension: png
    has_invalid_depth: false
  vis:
  - name: tumor_dataset
    disp_name: mytumor_vis
    dir: ''
    filenames: data_split/tumor_depth/filenames_vis.txt
    resize_to_hw:
    - 480
    - 640
model:
  name: marigold_pipeline
  pretrained_pa

In [25]:
# Base data directory passed to the dataset factory (relative to the working directory)
base_data_dir = './data'

# Depth normalizer built from the `depth_normalization` section of the config
normalizer_cfg = cfg.depth_normalization
depth_transform: DepthNormalizerBase = get_depth_normalizer(cfg_normalizer=normalizer_cfg)

In [34]:
# Build the training dataset via the `src.dataset.get_dataset` factory,
# passing the augmentation settings and the depth normalizer defined earlier.
train_dataset: Union[BaseDepthDataset, List[BaseDepthDataset]] = get_dataset(
    cfg.dataset.train,
    base_data_dir=base_data_dir,
    mode=DatasetMode.TRAIN,
    augmentation_args=cfg.augmentation,
    depth_transform=depth_transform,
)

# Quick sanity check: concrete dataset class and its size
dataset_cls = type(train_dataset)
print(f"Dataset loaded: {dataset_cls}")
print(f"Number of samples: {len(train_dataset)}")


Creating dataset: tumor_dataset (DatasetMode.TRAIN)
Dataset loaded: <class 'src.dataset.mytumor_dataset.MyTumorDataset'>
Number of samples: 1680


In [35]:
# Build one evaluation DataLoader per entry listed under `cfg.dataset.val`
val_loaders: List[DataLoader] = []
for _val_dict in cfg.dataset.val:
    _val_dataset = get_dataset(_val_dict, base_data_dir=base_data_dir, mode=DatasetMode.EVAL)
    print(f"Validation dataset loaded: {type(_val_dataset)}")
    print(f"Number of validation samples: {len(_val_dataset)}")

    # Evaluate one sample at a time, in dataset order
    _val_loader = DataLoader(
        _val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=cfg.dataloader.num_workers,
    )
    val_loaders.append(_val_loader)

Creating dataset: tumor_dataset (DatasetMode.EVAL)
Validation dataset loaded: <class 'src.dataset.mytumor_dataset.MyTumorDataset'>
Number of validation samples: 420


In [36]:
# Plain shuffled DataLoader for debugging (MixedBatchSampler is not used here);
# num_workers=0 keeps loading in the main process.
dataloader = DataLoader(dataset=train_dataset, batch_size=4, num_workers=0, shuffle=True)


In [37]:
# Pull the first batch from a fresh iterator over the debug DataLoader
batch_iter = iter(dataloader)
batch = next(batch_iter)

# Display the raw batch contents
batch

{'rgb_int': tensor([[[[117, 119, 120,  ..., 252, 255, 252],
           [123, 116, 121,  ..., 254, 255, 252],
           [121, 122, 115,  ..., 252, 254, 255],
           ...,
           [103, 102, 103,  ..., 148, 148, 145],
           [105, 104, 104,  ..., 147, 143, 144],
           [100, 101, 105,  ..., 145, 146, 144]],
 
          [[253, 255, 255,  ..., 122, 120, 119],
           [255, 252, 255,  ..., 120, 119, 119],
           [255, 255, 252,  ..., 124, 118, 120],
           ...,
           [142, 144, 146,  ...,  99, 102, 102],
           [147, 145, 143,  ..., 103, 100, 101],
           [144, 144, 147,  ..., 102, 106,  97]],
 
          [[104, 109, 105,  ..., 108, 107, 110],
           [105, 107, 109,  ..., 108, 109, 101],
           [109, 107, 108,  ..., 109, 106, 109],
           ...,
           [197, 196, 199,  ..., 199, 195, 197],
           [196, 202, 203,  ..., 199, 198, 199],
           [200, 199, 199,  ..., 200, 196, 198]]],
 
 
         [[[122, 118, 121,  ..., 251, 254, 255]

In [39]:
# List the fields present in the batch dictionary
batch_fields = batch.keys()
print(batch_fields)

dict_keys(['rgb_int', 'rgb_norm', 'depth_raw_linear', 'depth_filled_linear', 'valid_mask_raw', 'valid_mask_filled', 'depth_raw_norm', 'depth_filled_norm', 'index', 'rgb_relative_path'])


In [None]:
# Visualize the first sample of the batch: RGB image next to its depth map.
# The batch dictionary has no "rgb" / "depth" entries (see the printed
# `batch.keys()`), so read the normalized RGB and the raw linear depth instead.
rgb = batch["rgb_norm"][0]             # assumed [3, H, W] — TODO confirm
depth = batch["depth_raw_linear"][0]   # assumed [1, H, W] — TODO confirm

# Undo RGB normalization (assuming [-1, 1] → [0, 1]); clamp so that any value
# falling slightly outside that range is not mis-rendered in the PIL image.
rgb_img = ((rgb + 1) / 2.0).clamp(0.0, 1.0)
rgb_pil = to_pil_image(rgb_img)

# Plot side-by-side
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

axes[0].imshow(rgb_pil)
axes[0].set_title("RGB Image")
axes[0].axis("off")

# Drop singleton dimensions and move to a CPU NumPy array for matplotlib
depth_np = depth.squeeze().detach().cpu().numpy()
axes[1].imshow(depth_np, cmap="viridis")
axes[1].set_title("Depth Map")
axes[1].axis("off")

plt.tight_layout()
plt.show()