# BackBone

## Usage

In [None]:
import torch
from src.models import BackBone
from src.datasets import ImageNetClassify
from torch.utils.data import DataLoader

# "ConvNeXt" is the backbone used throughout this notebook. An earlier note also
# listed "ResNet18" and "ResNet50" as options, with "ViT-B" still TODO — confirm
# the accepted names against BackBone itself.
backbone = BackBone("ConvNeXt")

bs = 8
ds = ImageNetClassify("./data/ImageNet-2012/")
dl = DataLoader(ds, bs)

# Only output shapes are inspected, so skip building the autograd graph.
with torch.no_grad():
    for batch in dl:
        imgs = batch["img"]
        out = backbone(imgs)
        # `out` maps each output name (e.g. "emb", "l1".."l6") to a tensor.
        for k, v in out.items():
            print(f"{k}: {v.shape}")
        break  # one batch is enough for a shape check

## Validation

In [None]:
from src.models import BackBone
from mt_pipe.src.test.external.util import make_random_nested_tens, validate_nested_obj


bs = 8
w, h = 224, 224
mock_batch_conf = (bs, 3, h, w)

# Random input batch shaped (bs, 3, h, w).
mock_batch = make_random_nested_tens(mock_batch_conf)

backbone = BackBone("ConvNeXt")
out = backbone(mock_batch)

# `dims` (level name -> channel count) is what the decoders are constructed from.
assert hasattr(backbone, "dims"), "BackBone must expose a `dims` attribute"
assert isinstance(backbone.dims, dict), (
    f"`dims` must be a dict, got {type(backbone.dims).__name__}"
)
# Exact type check (not isinstance) so that bools are rejected as channel counts.
assert all(type(v) is int for v in backbone.dims.values()), "`dims` values must be ints"

# Expected outputs for a 224x224 input. The batch dimension follows `bs`, so
# changing the batch size above does not silently break the check.
expected_out_conf = {
    "emb": {"shape": (bs, 768), "dtype": "torch.float32"},
    "l6": {"shape": (bs, 768, 3, 3), "dtype": "torch.float32"},
    "l5": {"shape": (bs, 384, 7, 7), "dtype": "torch.float32"},
    "l4": {"shape": (bs, 192, 14, 14), "dtype": "torch.float32"},
    "l3": {"shape": (bs, 96, 28, 28), "dtype": "torch.float32"},
    "l2": {"shape": (bs, 48, 56, 56), "dtype": "torch.float32"},
    "l1": {"shape": (bs, 48, 112, 112), "dtype": "torch.float32"},
}

valid, msg = validate_nested_obj(out, expected_out_conf)
assert valid, msg

# DepthDecoder

## Usage

In [None]:
# TODO: get the TrianFlow model from @muditha
# This cell is commented out until the DepthDecoder (src.models.depth_decoder)
# is available — see the TODO above. It encodes one KITTI batch with the
# ConvNeXt backbone and prints the feature-pyramid and decoder output shapes.
# from src.models.depth_decoder import DepthDecoder
# from src.models import BackBone
# from src.datasets.kitti import KITTIWithDepth
# from torch.utils.data import DataLoader

# ds = KITTIWithDepth("./data/KITTI-2012/")
# dl = DataLoader(ds, 8)

# encoder = BackBone("ConvNeXt")
# decoder = DepthDecoder(encoder.dims)

# for batch in dl:
#     fp = encoder(batch["img"])
#     for k, v in fp.items():
#         print(k, v.shape)
#     pred = decoder(fp)
#     for k, v in pred.items():
#         print(k, v.shape)
#     break

## Validation

In [None]:
# TODO: get the TrianFlow model from @muditha
# NOTE(review): check the level naming below before re-enabling this cell.
# - Here "l1" is the deepest level (768 channels, 1x2) and "l6" the shallowest.
#   That is the reverse of the BackBone output and of the FlowDecoder mock,
#   where "l6" is the 768-channel level.
# - The decoder config lists an "l7" (3 channels) that has no mock input.
#   Confirm what DepthDecoder expects.
# - torch.Tensor(*shape) allocates uninitialised memory (possibly NaN/inf).
#   Consider torch.rand instead.
# import torch
# from mt_pipe.src.test.external.util import validate_nested_obj
# from src.models.depth_decoder import DepthDecoder

# mock_input = {
#     "emb": torch.Tensor(8, 768),
#     "l1": torch.Tensor(8, 768, 1, 2),
#     "l2": torch.Tensor(8, 384, 2, 4),
#     "l3": torch.Tensor(8, 192, 4, 8),
#     "l4": torch.Tensor(8, 96, 8, 16),
#     "l5": torch.Tensor(8, 48, 16, 32),
#     "l6": torch.Tensor(8, 48, 32, 64),
# }
# decoder = DepthDecoder(
#     {"l7": 3, "l6": 48, "l5": 48, "l4": 96, "l3": 192, "l2": 384, "l1": 768},
#     ["l7", "l6", "l5", "l4", "l3", "l2", "l1"],
# )
# out = decoder(mock_input)
# expected_out_conf = {
#     "l2": {"shape":(8, 1, 2, 4), "dtype":"torch.float32"},
#     "l3": {"shape":(8, 1, 4, 8), "dtype":"torch.float32"},
#     "l4": {"shape":(8, 1, 8, 16), "dtype":"torch.float32"},
#     "l5": {"shape":(8, 1, 16, 32), "dtype":"torch.float32"},
#     "l6": {"shape":(8, 1, 32, 64), "dtype":"torch.float32"},
#     "l7": {"shape":(8, 1, 64, 128), "dtype":"torch.float32"},
# }
# valid, msg  = validate_nested_obj(out, expected_out_conf)
# assert valid, msg

# FlowDecoder

## Main

### Usage

In [None]:
import torch
from src.models import BackBone
from src.datasets import KITTI
from torch.utils.data import DataLoader
from src.models.flow_decoder import FlowDecoder
from src.constants import flow_img_wh

device = 0  # CUDA device index
backbone = BackBone("ConvNeXt").cuda(device)
decoder = FlowDecoder(backbone.dims).cuda(device)

bs = 8
ds = KITTI("./data/KITTI-2015/", img_wh=flow_img_wh)
dl = DataLoader(ds, bs)

# Only output shapes are inspected, so skip building the autograd graph
# (avoids holding activations for two backbone passes on the GPU).
with torch.no_grad():
    for batch in dl:
        batch_one = batch["img1"].cuda(device)
        batch_two = batch["img2"].cuda(device)

        # Each frame is encoded separately; the decoder consumes both pyramids.
        feature_pyramid_one = backbone(batch_one)
        feature_pyramid_two = backbone(batch_two)
        out = decoder(feature_pyramid_one, feature_pyramid_two, flow_img_wh)

        # Handle values that are a tensor, None, or an iterable of tensors whose
        # entries may be None. None must be checked before the iterable branch,
        # otherwise iterating it raises TypeError.
        for k, v in out.items():
            if v is None or isinstance(v, torch.Tensor):
                print(f"{k}: {None if v is None else v.shape}")
            else:
                for v2 in v:
                    print(f"{k}: {None if v2 is None else v2.shape}")
        break  # one batch is enough for a shape check

### Validation

In [None]:
import torch
from src.models.flow_decoder import FlowDecoder
from mt_pipe.src.test.external.util import validate_nested_obj

bs = 8
# Passed in the same position as `flow_img_wh` in the usage cell, so presumably
# (width, height); consistent with the 64x128 top-level outputs expected below.
img_wh = [128, 64]
img_w, img_h = img_wh

decoder = FlowDecoder({"l1": 48, "l2": 48, "l3": 96, "l4": 192, "l5": 384, "l6": 768})

# Mock feature-pyramid shapes, following the BackBone output layout:
# "l6" is the deepest, 768-channel level and "l1" the shallowest.
pyramid_shapes = {
    "emb": (bs, 768),
    "l6": (bs, 768, 1, 2),
    "l5": (bs, 384, 2, 4),
    "l4": (bs, 192, 4, 8),
    "l3": (bs, 96, 8, 16),
    "l2": (bs, 48, 16, 32),
    "l1": (bs, 48, 32, 64),
}
# torch.rand gives initialised float32 values. torch.Tensor(*shape) would return
# uninitialised memory (possibly NaN/inf), making the check non-deterministic.
feature_pyramid_one = {k: torch.rand(shape) for k, shape in pyramid_shapes.items()}
feature_pyramid_two = {k: torch.rand(shape) for k, shape in pyramid_shapes.items()}
out = decoder(feature_pyramid_one, feature_pyramid_two, img_wh)

# Every output is a 4-level pyramid: full resolution first, then halved per level.
flow_level_hw = [(img_h // 2**i, img_w // 2**i) for i in range(4)]
out_channels = {
    "optical_flows": 2,
    "optical_flows_rev": 2,
    "img1_valid_masks": 1,
    "img2_valid_masks": 1,
    "fwd_flow_diff_pyramid": 2,
    "bwd_flow_diff_pyramid": 2,
}
expected_out_conf = {
    key: [
        {"shape": (bs, channels, lvl_h, lvl_w), "dtype": "torch.float32"}
        for lvl_h, lvl_w in flow_level_hw
    ]
    for key, channels in out_channels.items()
}

valid, msg = validate_nested_obj(out, expected_out_conf)
assert valid, msg