# ContentLearner

## Usage

In [None]:
# Build an ImageNet VICReg loader and run one batch through a ContentLearner
# that wraps a ConvNeXt backbone, printing the shape of every output entry.
from torch.utils.data import DataLoader

from src.datasets import ImageNetVICReg
from src.learners import ContentLearner
from src.models import BackBone

bs = 8
ds = ImageNetVICReg("./data/ImageNet-2012/")
dl = DataLoader(ds, batch_size=bs)

backbone = BackBone("ConvNeXt")
ln = ContentLearner(backbone)
ln.set_devices([0, 1])

# Only the first batch is needed for the demo, hence the unconditional break.
for batch in dl:
    out = ln(batch)
    for name, tensor in out.items():
        print(f"{name}: {tensor.shape}")
    break

## Validate

In [None]:
# Validate ContentLearner on a mock batch of two views: the keys, shapes and
# dtypes of its outputs must match expected_out_conf.
import torch
from src.models import BackBone
from src.learners import ContentLearner
from src.constants import content_img_wh
from mt_pipe.src.test.external.util import validate_nested_obj

bs = 8

backbone = BackBone("ConvNeXt").cuda(0)
ln = ContentLearner(backbone)
ln.set_devices([0, 1])

# content_img_wh is (width, height) while the tensors are laid out
# (batch, channels, height, width), hence the [::-1].
view_shape = (bs, 3, *content_img_wh[::-1])
# torch.Tensor(*shape) returns uninitialized memory, which may contain
# NaN/Inf and make the check flaky. Random values keep the mock well-defined.
mock_batch = {
    "view1": torch.rand(view_shape),
    "view2": torch.rand(view_shape),
}

out = ln(mock_batch)
expected_out_conf = {
    "X_one": {"shape": (bs, 768), "dtype": "torch.float32"},
    "X_two": {"shape": (bs, 768), "dtype": "torch.float32"},
    "Y_one": {"shape": (bs, 8192), "dtype": "torch.float32"},
    "Y_two": {"shape": (bs, 8192), "dtype": "torch.float32"},
}

valid, msg = validate_nested_obj(out, expected_out_conf)
assert valid, msg

# FlowLearner

## Usage

In [None]:
# Run a single MPI-Sintel batch through a FlowLearner whose encoder is given
# as a config dict, and print the type of every output entry.
from torch.utils.data import DataLoader

from src.datasets import MPISintel
from src.learners import FlowLearner

bs = 8
ds = MPISintel("./data/MPISintel/", img_wh=[832, 256])
dl = DataLoader(ds, batch_size=bs)

# The encoder is passed as {"target": <import path>} rather than an instance.
ln = FlowLearner({"target": "src.models.encoders.PWCEncoder"})
ln.set_devices([0, 1])

batch = next(iter(dl))
info = ln(batch)
for key, value in info.items():
    print(f"{key}: {type(value)}")

## Validate

In [None]:
# Validate FlowLearner's nested outputs on a mock image pair: the flow
# prediction, one feature pyramid per image, and the multi-scale
# flow / mask / diff pyramids. Expected shapes are derived from `bs` and
# `img_hw`, so changing either keeps the check consistent.
import torch
from src.models import BackBone
from src.learners import FlowLearner
from mt_pipe.src.test.external.util import validate_nested_obj

bs = 8
img_hw = [64, 128]
img_h, img_w = img_hw

# NOTE(review): unlike the other cells, this one puts everything on GPU 1
# (cuda(1), set_devices([1, 1])). Confirm this is intentional.
backbone = BackBone("ConvNeXt").cuda(1)
ln = FlowLearner(backbone)
ln.set_devices([1, 1])

# torch.Tensor(*shape) would be uninitialized memory (possibly NaN/Inf).
# Random values keep the mock input well-defined.
mock_batch = {
    "img1": torch.rand(bs, 3, *img_hw),
    "img2": torch.rand(bs, 3, *img_hw),
}


def _spec(*shape):
    """Return the validate_nested_obj spec for a float32 tensor of `shape`."""
    return {"shape": tuple(shape), "dtype": "torch.float32"}


# (level, channels, stride relative to the input image) of the feature pyramid.
pyramid_levels = [
    ("l6", 768, 64),
    ("l5", 384, 32),
    ("l4", 192, 16),
    ("l3", 96, 8),
    ("l2", 48, 4),
    ("l1", 48, 2),
]
# Strides of the multi-scale flow / mask / diff outputs, finest first.
flow_strides = [1, 2, 4, 8]


def _feature_pyramid_spec():
    """Expected per-image pyramid: a (bs, 768) "emb" entry plus the l6..l1 maps."""
    spec = {"emb": _spec(bs, 768)}
    for level, channels, stride in pyramid_levels:
        spec[level] = _spec(bs, channels, img_h // stride, img_w // stride)
    return spec


def _multiscale_spec(channels):
    """Expected list of `channels`-channel maps, one per entry of flow_strides."""
    return [_spec(bs, channels, img_h // s, img_w // s) for s in flow_strides]


# These shapes assume both sides of img_hw are divisible by the coarsest
# stride (64).
expected_out_conf = {
    "flow_pred": _spec(bs, 2, img_h, img_w),
    "feature_pyramid_one": _feature_pyramid_spec(),
    "feature_pyramid_two": _feature_pyramid_spec(),
    "optical_flows": _multiscale_spec(2),
    "optical_flows_rev": _multiscale_spec(2),
    "img1_valid_masks": _multiscale_spec(1),
    "img2_valid_masks": _multiscale_spec(1),
    "fwd_flow_diff_pyramid": _multiscale_spec(2),
    "bwd_flow_diff_pyramid": _multiscale_spec(2),
}

out = ln(mock_batch)
valid, msg = validate_nested_obj(out, expected_out_conf)
assert valid, msg

# DepthLearner

## Usage

In [None]:
# NOTE(review): this DepthLearner usage example is disabled. Every line is
# commented out, and the "depth_learner" child in the LearnerMux cell is
# commented out as well. The reason is not recorded here; confirm before
# re-enabling.
# from src.learners import DepthLearner
# from src.datasets.kitti import KITTIWithDepth
# from torch.utils.data import DataLoader
# from src.models import BackBone

# bs = 4
# ds = KITTIWithDepth("./data/KITTI-2012/")
# dl = DataLoader(ds, bs)

# backbone = BackBone("ConvNeXt").cuda(0)
# ln = DepthLearner(backbone)
# ln.set_devices([0, 1])

# for batch in dl:
#     out = ln(batch)
#     for k, v in out.items():
#         print(f"{k}: {v.shape}")
#     break

## Validate

In [None]:
# NOTE(review): this DepthLearner validation is disabled (fully commented out).
# If it is re-enabled:
#   - the mock tensors hard-code a batch of 4 instead of using `bs`, so the
#     two drift apart if `bs` changes;
#   - the depth input size reuses `flow_img_wh` from src.constants.
# import torch
# from src.learners import DepthLearner
# from src.models import BackBone
# from src.constants import flow_img_wh
# from mt_pipe.src.test.external.util import validate_nested_obj


# bs = 4

# backbone = BackBone("ConvNeXt").cuda(0)
# ln = DepthLearner(backbone)
# ln.set_devices([0, 1])

# mock_batch = {
#     "img": torch.Tensor(4, 3, *flow_img_wh[::-1]),
#     "depth_map": torch.Tensor(4, 1, *flow_img_wh[::-1]),
# }
# out = ln(mock_batch)

# expected_out_conf = {
#     "l2": {
#         "shape": (bs, 1, *[int(d / 2**5) for d in flow_img_wh[::-1]]),
#         "dtype": "torch.float32",
#     },
#     "l3": {
#         "shape": (bs, 1, *[int(d / 2**4) for d in flow_img_wh[::-1]]),
#         "dtype": "torch.float32",
#     },
#     "l4": {
#         "shape": (bs, 1, *[int(d / 2**3) for d in flow_img_wh[::-1]]),
#         "dtype": "torch.float32",
#     },
#     "l5": {
#         "shape": (bs, 1, *[int(d / 2**2) for d in flow_img_wh[::-1]]),
#         "dtype": "torch.float32",
#     },
#     "l6": {
#         "shape": (bs, 1, *[int(d / 2**1) for d in flow_img_wh[::-1]]),
#         "dtype": "torch.float32",
#     },
#     "l7": {"shape": (bs, 1, *flow_img_wh[::-1]), "dtype": "torch.float32"},
#     "pred": {"shape": (bs, 1, *flow_img_wh[::-1]), "dtype": "torch.float32"},
# }
# valid, msg = validate_nested_obj(out, expected_out_conf)
# assert valid, msg

# All Learners

## Usage

In [None]:
# Combine the flow and content learners behind one LearnerMux (configured with
# a single ConvNeXt encoder config) and feed it one merged KITTI + ImageNet batch.
from torch.utils.data import DataLoader

from mt_pipe.src.util.learner_mux import LearnerMux
from src.datasets import KITTIWithCalibration, ImageNetVICReg

ds1 = KITTIWithCalibration("./data/KITTI-2012", img_wh=[128, 64])
dl1 = DataLoader(ds1, batch_size=4)
ds2 = ImageNetVICReg("./data/ImageNet-2012", img_wh=[128, 64])
dl2 = DataLoader(ds2, batch_size=4)

# Child learners keyed by name. Each gives the learner's import path
# ("target") and optional in/out key maps.
children = {
    "flow_learner": {
        "target": "src.learners.FlowLearner",
        "out_map": {"flow_path": "flow"},
    },
    # "depth_learner": {
    #     "target": "src.learners.DepthLearner",
    #     "in_map": {
    #         "depth_path_1": {"img1": "img"},
    #         "depth_path_2": {"img2": "img"},
    #     },
    #     "out_map": {
    #         "depth_path_1": "depth1",
    #         "depth_path_2": "depth2",
    #     },
    # },
    "content_learner": {
        "target": "src.learners.ContentLearner",
        "out_map": {"content_path": "content"},
    },
}
encoder_conf = {
    "target": "src.models.backbone.BackBone",
    "params": {"enc_name": "ConvNeXt"},
}

ln = LearnerMux(chldrn=children, encoder=encoder_conf)
ln.set_devices([0, 1])

# Merge the first batch of each loader into a single dict and run it.
# NOTE(review): on a key collision, batch2's entry silently wins. Confirm the
# two datasets emit disjoint keys.
for batch1, batch2 in zip(dl1, dl2):
    batch = {**batch1, **batch2}
    out = ln(batch)
    break

# ClassLearner

## Usage

In [None]:
# Run ClassLearner on one ImageNet classification batch, once with an already
# built backbone and once with a backbone config dict.
from torch.utils.data import DataLoader

from src.datasets import ImageNetClassify
from src.learners import ClassLearner
from src.models import BackBone

bs = 32
ds = ImageNetClassify("./data/ImageNet-2012/")
dl = DataLoader(ds, batch_size=bs)

pre_loaded_backbone = BackBone("ConvNeXt").cuda(0)
backbone_conf = {"target": "src.models.BackBone", "params": {"enc_name": "ConvNeXt"}}

# `encoder` may be given in either form.
for backbone in (pre_loaded_backbone, backbone_conf):
    ln = ClassLearner(encoder=backbone, n_classes=ds.n_classes)
    ln.set_devices([0, 1])

    # Optional: freeze the encoder's weights, then load a checkpoint.
    # NOTE(review): `load_ckeckpoint` looks like a misspelling of "checkpoint".
    # It must match the method name on the learner; confirm in src.learners.
    for weight in ln.encoder.parameters():
        weight.requires_grad = False
    ln.load_ckeckpoint("models/mock.ckpt")

    # Only the first batch is needed for the demo.
    for batch in dl:
        out = ln(batch)
        for name, tensor in out.items():
            print(f"{name}: {tensor.shape}")
        break

## Validate

In [None]:
# Validate ClassLearner's "logits" output, shape (bs, n_classes), for both
# ways of supplying the encoder. Sizes are named once so the mock batch and
# the expectation cannot drift apart.
import torch
from src.learners import ClassLearner
from src.models import BackBone
from mt_pipe.src.test.external.util import validate_nested_obj

bs = 32
n_classes = 1000
img_hw = (128, 128)

# torch.Tensor(*shape) returns uninitialized memory (possibly NaN/Inf), so
# use random placeholders instead. Labels are valid class indices but keep
# the original float dtype.
mock_batch = {
    "img": torch.rand(bs, 3, *img_hw),
    "lbl": torch.randint(0, n_classes, (bs,)).float(),
}
expected_out_conf = {"logits": {"shape": (bs, n_classes), "dtype": "torch.float32"}}


pre_loaded_backbone = BackBone("ConvNeXt").cuda(0)
backbone_conf = {"target": "src.models.BackBone", "params": {"enc_name": "ConvNeXt"}}

for backbone in [pre_loaded_backbone, backbone_conf]:  # backbone can be any of these
    ln = ClassLearner(encoder=backbone, n_classes=n_classes)
    ln.set_devices([0, 1])
    # NOTE(review): `load_ckeckpoint` looks like a misspelling of "checkpoint".
    # It must match the learner's method name; confirm in src.learners.
    ln.load_ckeckpoint("models/mock.ckpt")  # optionally load a checkpoint
    out = ln(mock_batch)
    valid, msg = validate_nested_obj(out, expected_out_conf)
    assert valid, msg

# SegmentLearner

## Usage

In [None]:
# Run SegmentLearner on one COCO segmentation batch, once with an already
# built backbone and once with a backbone config dict.
from torch.utils.data import DataLoader

from src.datasets import COCOSegment
from src.learners import SegmentLearner
from src.models import BackBone

bs = 8
ds = COCOSegment("./data/COCO-2017/")
dl = DataLoader(ds, batch_size=bs)

pre_loaded_backbone = BackBone("ConvNeXt").cuda(0)
backbone_conf = {"target": "src.models.BackBone", "params": {"enc_name": "ConvNeXt"}}

# `encoder` may be given in either form.
for backbone in (pre_loaded_backbone, backbone_conf):
    ln = SegmentLearner(encoder=backbone, n_classes=len(ds.classes))
    ln.set_devices([0, 1])
    # Optional: load a checkpoint.
    # NOTE(review): `load_ckeckpoint` looks like a misspelling of "checkpoint".
    # It must match the learner's method name; confirm in src.learners.
    ln.load_ckeckpoint("models/mock.ckpt")
    # Only the first batch is needed for the demo.
    for batch in dl:
        out = ln(batch)
        for name, tensor in out.items():
            print(f"{name}: {tensor.shape}")
        break

## Validate

In [None]:
# Validate SegmentLearner's "seg" output for both ways of supplying the
# encoder. The expected output has one channel per class at the input's
# spatial size. Sizes are named once so mock batch and expectation agree.
import torch
from src.learners import SegmentLearner
from src.models import BackBone
from mt_pipe.src.test.external.util import validate_nested_obj

bs = 5
n_classes = 80
img_hw = (480, 640)

# torch.Tensor(*shape) returns uninitialized memory (possibly NaN/Inf), so
# use random placeholders. Only shapes/dtypes are validated.
mock_batch = {
    "img": torch.rand(bs, 3, *img_hw),
    "seg": torch.rand(bs, n_classes, *img_hw),
}
expected_out_conf = {
    "seg": {"shape": (bs, n_classes, *img_hw), "dtype": "torch.float32"}
}

pre_loaded_backbone = BackBone("ConvNeXt").cuda(0)
backbone_conf = {"target": "src.models.BackBone", "params": {"enc_name": "ConvNeXt"}}

for backbone in [pre_loaded_backbone, backbone_conf]:  # backbone can be any of these
    ln = SegmentLearner(encoder=backbone, n_classes=n_classes)
    ln.set_devices([0, 1])
    # NOTE(review): `load_ckeckpoint` looks like a misspelling of "checkpoint".
    # It must match the learner's method name; confirm in src.learners.
    ln.load_ckeckpoint("models/mock.ckpt")  # optionally load a checkpoint
    out = ln(mock_batch)
    valid, msg = validate_nested_obj(out, expected_out_conf)
    assert valid, msg