# ImageNet


## VICReg


### Usage

In [None]:
# Load ImageNet with paired VICReg views and plot both views of the first sample.
from src.datasets import ImageNetVICReg
from src.constants import content_img_wh
import matplotlib.pyplot as plt

ds = ImageNetVICReg("./data/ImageNet-2012/", img_wh=content_img_wh, download=True)
sample = ds[0]
img1 = sample["view1"]
img2 = sample["view2"]
print(f"Dataset length: {len(ds)}")
print(f"Image shape: {img1.shape}")

# matplotlib wants channels last, so each view's first axis is moved to the end.
img1, img2 = img1.numpy(), img2.numpy()
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
for panel, view in zip(ax, (img1, img2)):
    panel.imshow(view.transpose(1, 2, 0))
plt.show()

### Validation

In [None]:
# Check every VICReg sample of both splits against the expected per-view schema.
from src.datasets import ImageNetVICReg
from tqdm import tqdm
from src.constants import content_img_wh
from mt_pipe.src.test.external.util import validate_nested_obj

# Both augmented views share one spec: float32, values in [0, 1],
# shape (3, *content_img_wh[::-1]). A fresh dict is built per key.
expected_sample_conf = {
    view_key: {
        "shape": (3, *content_img_wh[::-1]),
        "min": 0,
        "max": 1,
        "dtype": "torch.float32",
    }
    for view_key in ("view1", "view2")
}

for split in ("val", "train"):
    ds = ImageNetVICReg("./data/ImageNet-2012/", split, img_wh=content_img_wh)
    for sample in tqdm(ds, desc=split):
        ok, reason = validate_nested_obj(sample, expected_sample_conf)
        assert ok, reason

## Classify


### Usage

In [None]:
# Load ImageNet for classification and show the first image together with its label text.
from src.datasets import ImageNetClassify
import matplotlib.pyplot as plt
from src.constants import content_img_wh

ds = ImageNetClassify("./data/ImageNet-2012/", img_wh=content_img_wh, download=True)
# Lookup indexed by a sample's "lbl" value; presumably returns a human-readable
# class name (attribute name is spelled this way by the dataset class).
label_map = ds.label2txtlable_map

sample = ds[0]
img = sample["img"]
lbl = sample["lbl"]
print(f"Label: {label_map[lbl]}")
print(f"Dataset length: {len(ds)}")
print(f"Image shape: {img.shape}")

# Move the channel axis last for matplotlib.
img_hwc = img.numpy().transpose(1, 2, 0)
plt.imshow(img_hwc)
plt.show()

### Validation

In [None]:
# Check every ImageNet classification sample of both splits against the expected schema.
from src.datasets import ImageNetClassify
from tqdm import tqdm
from src.constants import content_img_wh
from mt_pipe.src.test.external.util import validate_nested_obj

expected_sample_conf = dict(
    img={
        "shape": (3, *content_img_wh[::-1]),
        "min": 0,
        "max": 1,
        "dtype": "torch.float32",
    },
    # Integer class index in the inclusive range 0..999.
    lbl={"dtype": "int", "min": 0, "max": 999},
)

for split in ("val", "train"):
    ds = ImageNetClassify("./data/ImageNet-2012/", split, img_wh=content_img_wh)
    for sample in tqdm(ds, desc=split):
        ok, reason = validate_nested_obj(sample, expected_sample_conf)
        assert ok, reason

# KITTI


## Image Pair

### Usage

In [None]:
# Create (downloading if needed) each KITTI variant, report the sizes, and show one image pair.
from src.datasets import KITTI
from src.constants import flow_img_wh
import matplotlib.pyplot as plt

# (root directory, download tag) for each variant, in construction order.
_kitti_sources = [
    ("./data/KITTI-2012", "2012"),
    ("./data/KITTI-2012-multiview", "2012_m"),
    ("./data/KITTI-2015", "2015"),
    ("./data/KITTI-2015-multiview", "2015_m"),
    ("./data/KITTI", "2011"),
]
ds_1, ds_2, ds_3, ds_4, ds_5 = (
    KITTI(root, img_wh=flow_img_wh, download=tag) for root, tag in _kitti_sources
)

print(f"KITTI 2012 length: {len(ds_1)}")
print(f"KITTI 2012 Multiview length: {len(ds_2)}")
print(f"KITTI 2015 length: {len(ds_3)}")
print(f"KITTI 2015 Multiview length: {len(ds_4)}")
print(f"KITTI Raw length: {len(ds_5)}")

sample = ds_1[0]
img1 = sample["img1"]
img2 = sample["img2"]
print(f"Image shape: {tuple(img1.shape)}")

fig, ax = plt.subplots(1, 2)
for panel, frame in zip(ax, (img1, img2)):
    panel.imshow(frame.numpy().transpose(1, 2, 0))
plt.show()

### Validation

In [None]:
# Check every KITTI image pair across all variants and both splits.
from tqdm import tqdm
from src.datasets import KITTI
from src.constants import flow_img_wh
from mt_pipe.src.test.external.util import validate_nested_obj

# Both frames: float32, values in [0, 1], shape (3, *flow_img_wh[::-1]).
expected_sample_conf = {
    frame: {"shape": (3, *flow_img_wh[::-1]), "min": 0, "max": 1, "dtype": "torch.float32"}
    for frame in ("img1", "img2")
}

kitti_roots = (
    "./data/KITTI-2012",
    "./data/KITTI-2012-multiview",
    "./data/KITTI-2015",
    "./data/KITTI-2015-multiview",
    "./data/KITTI",
)

for split in ("val", "train"):
    for root in kitti_roots:
        ds = KITTI(root, split, img_wh=flow_img_wh)
        for sample in tqdm(ds, desc=f"{split} {root}"):
            ok, reason = validate_nested_obj(sample, expected_sample_conf)
            assert ok, reason

## Image Pair and Calibration Matrix

### Usage

In [None]:
# Load KITTI image pairs with their calibration matrix K (and K_inv), report shapes and
# dataset sizes, and show one image pair.
from src.datasets import KITTIWithCalibration
import matplotlib.pyplot as plt
from src.constants import flow_img_wh

# ds_1 is built from the KITTI 2012 download; ds_2 from the raw ("2011") KITTI download.
ds_1 = KITTIWithCalibration("./data/KITTI-2012", img_wh=flow_img_wh, download="2012")
ds_2 = KITTIWithCalibration("./data/KITTI", img_wh=flow_img_wh, download="2011")
sample = ds_1[0]
img1, img2, K, K_inv = sample["img1"], sample["img2"], sample["K"], sample["K_inv"]
print(f"Image shape: {tuple(img1.shape)}")
print(f"Calibration matrix shape: {K.shape}")
# Each label must match the dataset it reports on.
print(f"KITTI 2012 Dataset length: {len(ds_1)}")
print(f"KITTI Raw Dataset length: {len(ds_2)}")

fig, ax = plt.subplots(1, 2)
ax[0].imshow(img1.numpy().transpose(1, 2, 0))
ax[1].imshow(img2.numpy().transpose(1, 2, 0))
plt.show()

### Validation

In [None]:
# Check KITTI image pairs plus calibration matrices for the raw and 2012 sets, both splits.
from tqdm import tqdm
from src.datasets import KITTIWithCalibration
from src.constants import flow_img_wh
from mt_pipe.src.test.external.util import validate_nested_obj

expected_sample_conf = {
    # Frames: float32, values in [0, 1], shape (3, *flow_img_wh[::-1]).
    **{
        frame: {
            "shape": (3, *flow_img_wh[::-1]),
            "min": 0,
            "max": 1,
            "dtype": "torch.float32",
        }
        for frame in ("img1", "img2")
    },
    # TODO: is there a min and max for K and K_inv?
    **{
        matrix: {"shape": (1, 3, 3), "dtype": "torch.float32"}
        for matrix in ("K", "K_inv")
    },
}

# TODO: Fix the following warning. Occurs when using the kitti_raw
# TODO: error while iterating
#   /home/avishka/anaconda3/envs/wm/lib/python3.10/site-packages/numpy/linalg/linalg.py:562: RuntimeWarning: overflow encountered in cast
#   return wrap(ainv.astype(result_t, copy=False))
for split in ("val", "train"):
    for root in ("./data/KITTI", "./data/KITTI-2012"):
        ds = KITTIWithCalibration(root, split, img_wh=flow_img_wh)
        for sample in tqdm(ds, desc=f"{split} {root}"):
            ok, reason = validate_nested_obj(sample, expected_sample_conf)
            assert ok, reason

## Image with Depth Map

### Usage

In [None]:
# Load KITTI 2012 images with depth maps, report shapes and size, and show one
# image next to its depth map.
from src.datasets import KITTIWithDepth
import matplotlib.pyplot as plt
import numpy as np
from src.constants import flow_img_wh

ds = KITTIWithDepth("./data/KITTI-2012/", img_wh=flow_img_wh, download="2012")

sample = ds[100]
img, depth_map = sample["img"], sample["depth_map"]
print(f"Image shape: {tuple(img.shape)}")
# This is a depth map, not a flow map; print its shape as a tuple like the image's.
print(f"Depth map shape: {tuple(depth_map.shape)}")
print(f"KITTI 2012 Dataset length: {len(ds)}")

fig, ax = plt.subplots(1, 2, figsize=(9, 3))
ax[0].imshow(img.numpy().transpose(1, 2, 0))
# Channel-last single-channel array; matplotlib treats a trailing axis of size 1 as scalar data.
ax[1].imshow(depth_map.numpy().transpose(1, 2, 0))
plt.show()

### Validation

In [None]:
# Check KITTI image/depth samples of the train split for each variant.
from tqdm import tqdm
from src.datasets import KITTIWithDepth
from src.constants import flow_img_wh
from mt_pipe.src.test.external.util import validate_nested_obj

expected_sample_conf = {
    "img": {
        "shape": (3, *flow_img_wh[::-1]),
        "min": 0,
        "max": 1,
        "dtype": "torch.float32",
    },
    # Single-channel depth; only a lower bound of 0 is checked.
    "depth_map": {
        "shape": (1, *flow_img_wh[::-1]),
        "dtype": "torch.float32",
        "min": 0,
    },
}

depth_roots = ["./data/KITTI-2012", "./data/KITTI-2015", "./data/KITTI"]
split = "train"  # only the train split is validated here
for root in depth_roots:
    ds = KITTIWithDepth(root, split, img_wh=flow_img_wh)
    for sample in tqdm(ds, desc=f"{split} {root}"):
        ok, reason = validate_nested_obj(sample, expected_sample_conf)
        assert ok, reason

# Flow Datasets

## KITTI

### Usage

In [None]:
# Load the KITTI 2012/2015 flow sets, report shapes and sizes, and visualize the
# first 2012 sample with plot_warp.
from src.datasets import KITTIWithFlow
from src.constants import flow_img_wh
from src.util.visualize import plot_warp

ds_1 = KITTIWithFlow("./data/KITTI-2012", img_wh=flow_img_wh, download="2012")
ds_2 = KITTIWithFlow("./data/KITTI-2015", img_wh=flow_img_wh, download="2015")
sample = ds_1[0]
img1 = sample["img1"]
img2 = sample["img2"]
flow_map = sample["flow_gt"]
valid_mask = sample["valid"]
occ_mask = sample["occ_gt"]

print(f"Image shape: {tuple(img1.shape)}")
print(f"Flowmap shape: {flow_map.shape}")
print(f"KITTI 2012 Dataset length: {len(ds_1)}")
print(f"KITTI 2015 Dataset length: {len(ds_2)}")
print("valid_mask shape : ", valid_mask.shape)
print("occ_mask shape :", occ_mask.shape)

plot_warp(sample)

### Validation

In [None]:
# Check KITTI flow samples (2012 and 2015) of the train split.
from tqdm import tqdm
from src.datasets import KITTIWithFlow
from src.constants import flow_img_wh
from mt_pipe.src.test.external.util import validate_nested_obj

expected_sample_conf = {
    **{
        frame: {
            "shape": (3, *flow_img_wh[::-1]),
            "min": 0,
            "max": 1,
            "dtype": "torch.float32",
        }
        for frame in ("img1", "img2")
    },
    "flow_gt": {  # TODO: is there a min and max for flow maps?
        "shape": (2, *flow_img_wh[::-1]),
        "dtype": "torch.float32",
    },
    # Masks: numpy float32, checked against unique values [0, 1].
    **{
        mask: {
            "shape": flow_img_wh[::-1],
            "unique": [0, 1],
            "dtype": "numpy.float32",
        }
        for mask in ("valid", "occ_gt")
    },
}

split = "train"  # only the train split is validated here
for root in ("./data/KITTI-2012", "./data/KITTI-2015"):
    ds = KITTIWithFlow(root, split, img_wh=flow_img_wh)
    for sample in tqdm(ds, desc=f"{split} {root}"):
        ok, reason = validate_nested_obj(sample, expected_sample_conf)
        assert ok, reason

## HD1K


### Usage

In [None]:
# Load the HD1K train split, report shapes and size, and visualize the first sample with plot_warp.
from src.datasets import HD1K
from src.constants import flow_img_wh
from src.util.visualize import plot_warp

ds = HD1K("./data/HD1K/", "train", img_wh=flow_img_wh, download=True)
sample = ds[0]
img1 = sample["img1"]
img2 = sample["img2"]
flow = sample["flow_gt"]
occs = sample["occ_gt"]

print(f"Image1 shape: {tuple(img1.shape)}")
print(f"Image2 shape: {tuple(img2.shape)}")
print(f"Flow shape: {tuple(flow.shape)}")
print(f"Occlusions shape: {tuple(occs.shape)}")
print(f"Dataset length: {len(ds)}")

plot_warp(sample)

### Validation

In [None]:
# Check every HD1K sample of both splits against the expected schema.
from tqdm import tqdm
from src.datasets import HD1K
from src.constants import flow_img_wh
from mt_pipe.src.test.external.util import validate_nested_obj

expected_sample_conf = {
    **{
        frame: {
            "shape": (3, *flow_img_wh[::-1]),
            "min": 0,
            "max": 1,
            "dtype": "torch.float32",
        }
        for frame in ("img1", "img2")
    },
    "flow_gt": {  # TODO: is there a min and max for flow maps?
        "shape": (2, *flow_img_wh[::-1]),
        "dtype": "torch.float32",
    },
    # Masks: numpy float32, checked against unique values [0, 1].
    **{
        mask: {
            "shape": flow_img_wh[::-1],
            "unique": [0, 1],
            "dtype": "numpy.float32",
        }
        for mask in ("valid", "occ_gt")
    },
}

for split in ("val", "train"):
    ds = HD1K("./data/HD1K/", split, img_wh=flow_img_wh)
    for sample in tqdm(ds, desc=split):
        ok, reason = validate_nested_obj(sample, expected_sample_conf)
        assert ok, reason

## MPISintel


### Usage

In [None]:
# Visualize a randomly chosen MPISintel training sample with plot_warp.
from src.datasets import MPISintel
from src.util.visualize import plot_warp
import numpy as np
from src.constants import flow_img_wh

# Pass the target size by keyword, as the validation cell and every other dataset cell do,
# so it cannot land in the wrong positional parameter.
dataset = MPISintel("./data/MPISintel/", "train", img_wh=flow_img_wh)
idx = np.random.randint(0, len(dataset))  # upper bound is exclusive
print(f"Index: {idx}")
sample = dataset[idx]
plot_warp(sample)

### Validation

In [None]:
# Check every MPISintel training sample against the expected schema.
from src.datasets import MPISintel
from tqdm import tqdm
from src.constants import flow_img_wh
from mt_pipe.src.test.external.util import validate_nested_obj

# Flow components are bounded by the larger image side in either direction.
flow_limit = max(flow_img_wh)

expected_sample_conf = {
    **{
        frame: {
            "shape": (3, *flow_img_wh[::-1]),
            "min": 0,
            "max": 1,
            "dtype": "torch.float32",
        }
        for frame in ("img1", "img2")
    },
    "flow_gt": {
        "shape": (2, *flow_img_wh[::-1]),
        "min": -flow_limit,
        "max": flow_limit,
        "dtype": "torch.float32",
    },
    # Masks: numpy float32, checked against unique values [0, 1].
    **{
        mask: {
            "shape": flow_img_wh[::-1],
            "unique": [0, 1],
            "dtype": "numpy.float32",
        }
        for mask in ("occ_gt", "valid")
    },
}

split = "train"  # only the train split is validated here
ds = MPISintel("./data/MPISintel/", split, img_wh=flow_img_wh)
for sample in tqdm(ds, desc=split):
    ok, reason = validate_nested_obj(sample, expected_sample_conf)
    assert ok, reason

## FlyingThings


### Usage

In [None]:
# Load the FlyingThings train split, report shapes and size, and show the first image pair.
from src.datasets import FlyingThings
import matplotlib.pyplot as plt
from src.constants import flow_img_wh

ds = FlyingThings("./data/FlyingThings/", "train", img_wh=flow_img_wh, download=True)
sample = ds[0]
img1 = sample["img1"]
img2 = sample["img2"]

print(f"Image1 shape: {tuple(img1.shape)}")
print(f"Image2 shape: {tuple(img2.shape)}")
print(f"Dataset length: {len(ds)}")

fig, ax = plt.subplots(1, 2)
for panel, frame in zip(ax, (img1, img2)):
    panel.imshow(frame.numpy().transpose(1, 2, 0))
plt.show()

### Validation

In [None]:
# Check every FlyingThings sample of both splits against the expected frame schema.
from tqdm import tqdm
from src.datasets import FlyingThings
from src.constants import flow_img_wh
from mt_pipe.src.test.external.util import validate_nested_obj

# Both frames: float32, values in [0, 1], shape (3, *flow_img_wh[::-1]).
expected_sample_conf = {
    frame: {
        "shape": (3, *flow_img_wh[::-1]),
        "min": 0,
        "max": 1,
        "dtype": "torch.float32",
    }
    for frame in ("img1", "img2")
}

for split in ("val", "train"):
    ds = FlyingThings("./data/FlyingThings/", split, img_wh=flow_img_wh)
    for sample in tqdm(ds, desc=split):
        ok, reason = validate_nested_obj(sample, expected_sample_conf)
        assert ok, reason

## FlyingChairs


### Usage

In [None]:
# Show a random FlyingChairs training sample. The third constructor argument is None;
# NOTE(review): presumably this means no resizing — confirm against FlyingChairs.
from src.datasets import FlyingChairs
from src.util.visualize import plot_warp
import numpy as np
from src.constants import flow_img_wh

dataset = FlyingChairs("./data/FlyingChairs/", "train", None)
idx = np.random.randint(0, len(dataset))  # upper bound is exclusive
print(f"Index: {idx}")
sample = dataset[idx]
img1 = sample["img1"]
img2 = sample["img2"]
flow = sample["flow_gt"]
occs = sample["occ_gt"]

print(f"Image1 shape: {tuple(img1.shape)}")
print(f"Image2 shape: {tuple(img2.shape)}")
print(f"Flow shape: {tuple(flow.shape)}")
print(f"Occlusions: {tuple(occs.shape)}")

plot_warp(sample)

### Validation

In [None]:
# Check every FlyingChairs sample of both splits against the expected schema.
from tqdm import tqdm
from src.datasets import FlyingChairs
from src.constants import flow_img_wh
from mt_pipe.src.test.external.util import validate_nested_obj

# Flow components are bounded by the larger image side in either direction.
flow_limit = max(flow_img_wh)

expected_sample_conf = {
    **{
        frame: {
            "shape": (3, *flow_img_wh[::-1]),
            "min": 0,
            "max": 1,
            "dtype": "torch.float32",
        }
        for frame in ("img1", "img2")
    },
    "flow_gt": {
        "shape": (2, *flow_img_wh[::-1]),
        "min": -flow_limit,
        "max": flow_limit,
        "dtype": "torch.float32",
    },
    # These keys carry a None spec; how None is interpreted is up to validate_nested_obj.
    "occ_gt": None,
    "valid": None,
}

for split in ("val", "train"):
    ds = FlyingChairs("./data/FlyingChairs/", split, img_wh=flow_img_wh)
    for sample in tqdm(ds, desc=split):
        ok, reason = validate_nested_obj(sample, expected_sample_conf)
        assert ok, reason

# ConcatSet

## Usage

In [None]:
# Combine FlyingChairs and FlyingThings with ConcatSet and show the first image pair.
from mt_pipe.src.datasets import ConcatSet
import matplotlib.pyplot as plt

# One entry per component dataset, each used once (reps=1), in the same order as the roots.
component_targets = ["src.datasets.FlyingChairs", "src.datasets.FlyingThings"]
ds = ConcatSet(
    root=["./data/FlyingChairs", "./data/FlyingThings"],
    conf=[{"target": target, "reps": 1} for target in component_targets],
)

sample = ds[0]
img1 = sample["img1"]
img2 = sample["img2"]

print(f"Image1 shape: {tuple(img1.shape)}")
print(f"Image2 shape: {tuple(img2.shape)}")
print(f"Dataset length: {len(ds)}")

# Frames are stacked vertically.
fig, ax = plt.subplots(2)
for panel, frame in zip(ax, (img1, img2)):
    panel.imshow(frame.numpy().transpose(1, 2, 0))
plt.show()

# Parallel DataLoader

## Usage

In [None]:
# Drive an ImageNet VICReg loader and a KITTI calibration loader together via
# ParallelDataLoader, print the three loader lengths, and the keys of the first combined batch.
from src.datasets import ImageNetVICReg, KITTIWithCalibration
from mt_pipe.src.util.data import ParallelDataLoader
from torch.utils.data import DataLoader

# NOTE(review): unlike other cells, no img_wh is passed — confirm the dataset defaults
# produce equally sized images so default batching works.
ds1 = ImageNetVICReg("./data/ImageNet-2012/")
ds2 = KITTIWithCalibration("./data/KITTI-2012/")
# Batch size 256 for ImageNet, 8 for KITTI.
dl1, dl2 = DataLoader(ds1, 256), DataLoader(ds2, 8)
dl3 = ParallelDataLoader([dl1, dl2])

print(*(len(loader) for loader in (dl1, dl2, dl3)))
for batch in dl3:
    print(batch.keys())
    break  # only the first batch is inspected

# Segmentation Datasets

## COCO

### Usage

In [None]:
# Load COCO 2017 segmentation with img_wh=None and plot the first sample's segments.
from src.datasets import COCOSegment
from src.constants import content_img_wh
from src.util.visualize import plot_segs

ds = COCOSegment("./data/COCO-2017", img_wh=None, download=True)
# Class names are handed to plot_segs for labelling.
classes = ds.classes
sample = ds[0]

print(f"Dataset length: {len(ds)}")
plot_segs(sample, classes)

### Validation

In [None]:
# Check every COCO segmentation sample of both splits against the expected schema.
from src.datasets import COCOSegment
from tqdm import tqdm
from mt_pipe.src.test.external.util import validate_nested_obj
from src.constants import content_img_wh


expected_sample_conf = {
    "img": {
        "dtype": "torch.float32",
        "max": 1,
        "min": 0,
        "shape": [3, *content_img_wh[::-1]],
    },
    # Segment ids are checked against the range [0, 80].
    "seg": {
        "dtype": "torch.float32",
        "shape": content_img_wh[::-1],
        "unique_range": [0, 80],
    },
}

for split in ("val", "train"):
    ds = COCOSegment("./data/COCO-2017", split, img_wh=content_img_wh)
    for sample in tqdm(ds, desc=split):
        ok, reason = validate_nested_obj(sample, expected_sample_conf)
        assert ok, reason

## PascalVOC 2012

### Usage

In [None]:
# Load PascalVOC 2012 segmentation and plot the first sample's segments with their class names.
from src.datasets import PascalVOC
from src.util.visualize import plot_segs
from src.constants import flow_img_wh

ds = PascalVOC("./data/PascalVOC-2012", img_wh=flow_img_wh, download=True)
classes = ds.class_names

sample = ds[0]
img = sample["img"]
seg = sample["seg"]

print(f"Dataset length: {len(ds)}")
plot_segs(sample, classes)

### Validation

In [None]:
# Check every PascalVOC segmentation sample of both splits against the expected schema.
from src.datasets import PascalVOC
from tqdm import tqdm
from mt_pipe.src.test.external.util import validate_nested_obj
from src.constants import flow_img_wh


expected_sample_conf = dict(
    img={
        "dtype": "torch.float32",
        "max": 1,
        "min": 0,
        "shape": [3, *flow_img_wh[::-1]],
    },
    # Segment ids are checked against the range [0, 21].
    seg={
        "dtype": "torch.float32",
        "shape": flow_img_wh[::-1],
        "unique_range": [0, 21],
    },
)

for split in ("val", "train"):
    ds = PascalVOC("./data/PascalVOC-2012", split, img_wh=flow_img_wh)
    for sample in tqdm(ds, desc=split):
        ok, reason = validate_nested_obj(sample, expected_sample_conf)
        assert ok, reason

## DAVIS 2017

### Usage

In [None]:
# Load DAVIS 2017 segmentation and plot the first sample's segments.
from src.datasets import Davis
from src.util.visualize import plot_segs
from src.constants import flow_img_wh

ds = Davis("./data/Davis", img_wh=flow_img_wh, download=True)
# NOTE(review): `file_contents` is used as the class-name list for plotting — confirm it
# holds class names rather than raw file text.
classes = ds.file_contents

sample = ds[0]
img = sample["img"]
seg = sample["seg"]

print(f"Dataset length: {len(ds)}")
plot_segs(sample, classes)

### Validation

In [None]:
# Check every DAVIS segmentation sample of both splits against the expected schema.
from src.datasets import Davis
from tqdm import tqdm
from mt_pipe.src.test.external.util import validate_nested_obj
from src.constants import flow_img_wh


expected_sample_conf = {
    "img": {
        "dtype": "torch.float32",
        "max": 1,
        "min": 0,
        "shape": [3, *flow_img_wh[::-1]],
    },
    # Segment ids are checked against the range [0, 60].
    "seg": {
        "dtype": "torch.float32",
        "shape": flow_img_wh[::-1],
        "unique_range": [0, 60],
    },
}

for split in ["val", "train"]:
    # Pass the split through (as every other validation cell does) so each iteration
    # validates a different split. No download is requested here, matching the
    # other validation cells.
    ds = Davis("./data/Davis", split, img_wh=flow_img_wh)
    for sample in tqdm(ds, desc=split):
        valid, msg = validate_nested_obj(sample, expected_sample_conf)
        assert valid, msg

## Cityscapes

### Usage

In [None]:
# from src.datasets import Cityscape
# from src.util.visualize import plot_segs
# from src.constants import flow_img_wh

# ds = Cityscape("./data/Cityscapes", img_wh=flow_img_wh, download=True)
# classes = ds.class_names

# sample = ds[0]
# img, seg = sample["img"], sample["seg"]

# print(f"Dataset length: {len(ds)}")
# plot_segs(sample, classes)

### Validation

In [None]:
# from src.datasets import Cityscape
# from tqdm import tqdm
# from mt_pipe.src.test.external.util import validate_nested_obj
# from src.constants import flow_img_wh


# expected_sample_conf = {
#     "img": {
#         "shape": (3, *flow_img_wh[::-1]),
#         "min": 0,
#         "max": 1,
#         "dtype": "torch.float32",
#     },
#     "seg": {
#         "shape": (flow_img_wh[::-1]),
#         "min": 0,
#         "max": (len(classes)+1),
#         "dtype": "torch.float32",
#     },
# }

# for split in ["val", "train", "test"]:
#     ds = Cityscape("./data/Cityscapes", split, img_wh=flow_img_wh)
#     for sample in tqdm(ds, desc=split):
#         valid, msg = validate_nested_obj(sample, expected_sample_conf)
#         assert valid, msg

## ADE20K

### Usage

In [None]:
# from src.datasets import ADE20k
# import matplotlib.pyplot as plt
# from src.constants import flow_img_wh

# ds = ADE20k("./data/ADE20K-2021", img_wh=flow_img_wh, download=True)
# classes = ds.name_list

# sample = ds[0]
# img, seg = sample["img"], sample["seg"]

# print(f"Dataset length: {len(ds)}")
# print(f"Image shape: {img.shape}, Image min: {img.min()}, Image max: {img.max()}")
# print(f"Segment shape: {seg.shape}, Segment min: {seg.min()}, Segment max: {seg.max()}")

# img, seg = img.numpy(), seg.numpy()
# window_idx = classes.index("windowpane, window")
# podium_idx = classes.index("podium")

# fig, ax = plt.subplots(1, 3, figsize=(12, 4))
# ax[0].imshow(img.transpose(1, 2, 0))
# ax[0].set_title("Original Image")
# ax[1].imshow(seg[window_idx])
# ax[1].set_title("Window")
# ax[2].imshow(seg[podium_idx])
# ax[2].set_title("Podium")
# plt.show()


# ds = ADE20k("./data/ADE20K-2021", "val")
# sample = ds[0]
# img, seg = sample["img"], sample["seg"]

# print(f"Dataset length: {len(ds)}")
# print(f"Image shape: {img.shape}, Image min: {img.min()}, Image max: {img.max()}")
# print(
#     f"Segment shape: {seg.shape}, Segment min: {seg.min()}, Segment max: {seg.max()}"
# )

### Validation

In [None]:
# from src.datasets import ADE20k
# from tqdm import tqdm
# from mt_pipe.src.test.external.util import validate_nested_obj
# from src.constants import flow_img_wh


# expected_sample_conf = {
#     "img": {
#         "shape": (3, *flow_img_wh[::-1]),
#         "min": 0,
#         "max": 1,
#         "dtype": "torch.float32",
#     },
#     "seg": {
#         "shape": (3687, *flow_img_wh[::-1]),
#         "min": 0,
#         "max": 1,
#         "dtype": "torch.float32",
#     },
# }

# for split in ["val", "train"]:
#     ds = ADE20k("./data/ADE20K-2021", split, img_wh=flow_img_wh)
#     for sample in tqdm(ds, desc=split):
#         valid, msg = validate_nested_obj(sample, expected_sample_conf)
#         assert valid, msg