In [None]:
def print_deep_loss_pack(loss_pack, key=""):
    """Recursively print every leaf value of a (possibly nested) loss pack.

    Nested keys are joined with ``:``, so ``{"a": {"b": t}}`` prints
    ``a:b: <value>``.

    Args:
        loss_pack: Either a dict (including dict subclasses such as
            ``OrderedDict``) whose values are nested dicts or leaves, or a
            single leaf. A leaf may be ``None``, an object exposing ``.item()``
            (e.g. a 0-dim tensor), or a plain Python value.
        key: Prefix accumulated from the enclosing dict keys; leave it empty
            for the top-level call.
    """
    # isinstance (not `type(...) != dict`) so dict subclasses such as
    # OrderedDict are recursed into instead of being treated as a leaf.
    if isinstance(loss_pack, dict):
        for k, v in loss_pack.items():
            # Compare against "" explicitly: a falsy key such as 0 must still
            # be used as a prefix.
            print_deep_loss_pack(v, f"{key}:{k}" if key != "" else k)
        return
    # Tensors / NumPy scalars are unwrapped to a Python number; None and
    # plain Python values are printed unchanged.
    value = loss_pack.item() if hasattr(loss_pack, "item") else loss_pack
    print(f"{key}: {value}")

# ContentLoss

## Usage

In [None]:
from src.learners import ContentLearner
from src.models import BackBone
from src.datasets import ImageNetVICReg
from src.losses import ContentLoss
from torch.utils.data import DataLoader

# Only the first batch of this loader is used below.
ds = ImageNetVICReg("./data/ImageNet-2012")
dl = DataLoader(ds, 8)

bb = BackBone("ConvNeXt").cuda(0)
ln = ContentLearner(bb)
ln.set_devices([0, 1])
loss_fn = ContentLoss()

# Take exactly one batch. next(iter(...)) raises StopIteration on an empty
# dataset instead of silently printing nothing, as a for/break loop would.
batch = next(iter(dl))
info = ln(batch)
loss_pack = loss_fn(info, batch)
print_deep_loss_pack(loss_pack)

## Validation

In [None]:
from src.losses import ContentLoss
import torch
from mt_pipe.src.test.external.util import validate_nested_obj

# torch.Tensor(*sizes) returns *uninitialised* memory, which may contain
# NaN/Inf and makes this check non-deterministic. Use seeded random data.
torch.manual_seed(0)
bs = 8

loss_fn = ContentLoss()
mock_info = {
    "X_one": torch.randn(bs, 768),
    "X_two": torch.randn(bs, 768),
    "Y_one": torch.randn(bs, 8192),
    "Y_two": torch.randn(bs, 8192),
}
mock_batch = {
    # Image-like inputs with values in [0, 1).
    "view1": torch.rand(bs, 3, 128, 128),
    "view2": torch.rand(bs, 3, 128, 128),
}
loss_pack = loss_fn(mock_info, mock_batch)

# Every reported term must be a scalar float32 tensor.
expected_loss_pack_conf = {
    "tot": {"shape": (), "dtype": "torch.float32"},
    "Content_X": {
        "tot": {"shape": (), "dtype": "torch.float32"},
        "Inv": {"shape": (), "dtype": "torch.float32"},
        "Var": {"shape": (), "dtype": "torch.float32"},
        "Cov": {"shape": (), "dtype": "torch.float32"},
    },
    "Content_Y": {
        "tot": {"shape": (), "dtype": "torch.float32"},
        "Inv": {"shape": (), "dtype": "torch.float32"},
        "Var": {"shape": (), "dtype": "torch.float32"},
        "Cov": {"shape": (), "dtype": "torch.float32"},
    },
}
valid, msg = validate_nested_obj(loss_pack, expected_loss_pack_conf)
assert valid, msg

# FlowLoss

## SSL

### Usage

In [None]:
from src.learners import FlowLearner
from src.models import BackBone
from src.datasets import KITTI
from src.losses import SSLFlowLoss
from torch.utils.data import DataLoader

# img_wh=[128, 64] matches the h, w = 64, 128 used in the validation cell.
ds = KITTI("./data/KITTI-2012", img_wh=[128, 64])
dl = DataLoader(ds, 8)

bb = BackBone("ConvNeXt").cuda(0)
ln = FlowLearner(bb)
ln.set_devices([0, 1])

# NOTE(review): the positional argument (1) and the loss_pixel weight (9) are
# carried over unchanged; document their meaning in SSLFlowLoss.
loss_fn = SSLFlowLoss(1, loss_weights={"loss_pixel": 9})

# Take exactly one batch; fails loudly on an empty dataset instead of the
# silent no-op a for/break loop gives.
batch = next(iter(dl))
info = ln(batch)
loss_pack = loss_fn(info, batch)
print_deep_loss_pack(loss_pack)

### Validation

In [None]:
from src.losses import SSLFlowLoss
import torch
from mt_pipe.src.test.external.util import make_random_nested_tens
from mt_pipe.src.test.external.util import validate_nested_obj

# Seed for reproducibility (assumes make_random_nested_tens draws from torch's
# global RNG — confirm).
torch.manual_seed(0)

loss_fn = SSLFlowLoss()

h, w = 64, 128
bs = 8
# The coarsest pyramid level is downsampled 64x, so both sides must divide.
assert h % 64 == 0 and w % 64 == 0, (h, w)

mock_batch = {
    # torch.rand instead of torch.Tensor(...): the latter is uninitialised
    # memory (possibly NaN/Inf), which makes the pixel/SSIM terms
    # non-deterministic.
    "img1": torch.rand(bs, 3, h, w),
    "img2": torch.rand(bs, 3, h, w),
}

# Encoder feature pyramid as (level name, channels, downsampling factor).
# Shapes are derived from bs/h/w so the mock stays consistent if they change.
_PYRAMID_LEVELS = [
    ("l6", 768, 64),
    ("l5", 384, 32),
    ("l4", 192, 16),
    ("l3", 96, 8),
    ("l2", 48, 4),
    ("l1", 48, 2),
]
# Downsampling factors of the flow / mask pyramids, finest first.
_FLOW_SCALES = [1, 2, 4, 8]


def _pyramid_conf():
    """Fresh shape config for one encoder feature pyramid (emb + l6..l1)."""
    conf = {"emb": [bs, 768]}
    for name, channels, factor in _PYRAMID_LEVELS:
        conf[name] = [bs, channels, h // factor, w // factor]
    return conf


def _scale_conf(channels):
    """Fresh list of per-scale shapes with `channels` channels, finest first."""
    return [[bs, channels, h // s, w // s] for s in _FLOW_SCALES]


mock_info_conf = {
    "feature_pyramid_one": _pyramid_conf(),
    "feature_pyramid_two": _pyramid_conf(),
    "optical_flows": _scale_conf(2),
    "optical_flows_rev": _scale_conf(2),
    "img1_valid_masks": _scale_conf(1),
    "img2_valid_masks": _scale_conf(1),
    "fwd_flow_diff_pyramid": _scale_conf(2),
    "bwd_flow_diff_pyramid": _scale_conf(2),
}
mock_info = make_random_nested_tens(mock_info_conf)
loss_pack = loss_fn(mock_info, mock_batch)

# Every reported term must be a scalar float32 tensor.
expected_loss_pack_conf = {
    "tot": {"shape": (), "dtype": "torch.float32"},
    "loss_pixel": {"shape": (), "dtype": "torch.float32"},
    "loss_ssim": {"shape": (), "dtype": "torch.float32"},
    "loss_flow_smooth": {"shape": (), "dtype": "torch.float32"},
    "loss_flow_consis": {"shape": (), "dtype": "torch.float32"},
}
valid, msg = validate_nested_obj(loss_pack, expected_loss_pack_conf)
assert valid, msg

## GTFlowLoss

TODO: add usage and validation cells for `GTFlowLoss`.

# DepthLoss

## SSL

### Usage

In [None]:
# NOTE(review): this whole cell is commented out, so the self-supervised
# DepthLoss usage is not exercised by "Run All". Depth is likewise commented
# out of the combined ConcatLoss / LearnerMux cell further down.
# NOTE(review): the flow learner target here is "src.learners.flow.FlowLearner",
# while the live LearnerMux cell uses "src.learners.FlowLearner"; confirm the
# path before re-enabling, or delete this cell.
# from mt_pipe.src.util.learner_mux import LearnerMux
# from src.datasets import KITTIWithCalibration
# from src.losses.depth_loss import DepthLoss
# from torch.utils.data import DataLoader

# loss_fn = DepthLoss(1)

# ds = KITTIWithCalibration("./data/KITTI-2012")
# dl = DataLoader(ds, 4)

# ln = LearnerMux(
#     chldrn={
#         "flow_learner": {
#             "target": "src.learners.flow.FlowLearner",
#             "in_map": {"flow_path": "full"},
#             "out_map": {"flow_path": "flow"},
#         },
#         "depth_learner": {
#             "target": "src.learners.DepthLearner",
#             "in_map": {
#                 "depth_path_1": {"img1": "img"},
#                 "depth_path_2": {"img2": "img"},
#             },
#             "out_map": {
#                 "depth_path_1": "depth1",
#                 "depth_path_2": "depth2",
#             },
#         },
#     },
#     encoder={
#         "target": "src.models.backbone.BackBone",
#         "params": {
#             "enc_name": "ConvNeXt",
#         },
#     },
# )
# ln.set_devices([0, 1])

# for batch in dl:
#     info = ln(batch)
#     loss_pack = loss_fn(info, batch)
#     print_deep_loss_pack(loss_pack)
#     break

### Validation

In [None]:
# NOTE(review): this whole cell is commented out, so DepthLoss is not validated
# by "Run All". The mock below looks stale relative to the live SSLFlowLoss
# validation cell:
#   - feature-pyramid level names run the opposite way (here "l1" is the
#     coarsest, (4, 768, 1, 2); there "l1" is the finest, [8, 48, 32, 64]);
#   - flow outputs use per-level dicts ("flow_fwd", "img1_valid_mask", ...)
#     instead of the lists "optical_flows", "img1_valid_masks", ....
# Also, DepthLoss(0) here vs DepthLoss(1) in the usage cell; confirm which
# first argument is intended. Re-derive the mock from current learner outputs
# before re-enabling.
# import torch
# from src.losses.depth_loss import DepthLoss
# from mt_pipe.src.test.external.util import make_random_nested_tens
# from mt_pipe.src.test.external.util import validate_nested_obj

# loss_fn = DepthLoss(0)

# mock_batch = {
#     "img1": torch.Tensor(4, 3, 64, 128),
#     "img2": torch.Tensor(4, 3, 64, 128),
#     "K": torch.Tensor(4, 1, 3, 3),
#     "K_inv": torch.Tensor(4, 1, 3, 3),
# }
# mock_info_conf = {
#     "flow": {
#         "feature_pyramid_one": {
#             "emb": (4, 768),
#             "l1": (4, 768, 1, 2),
#             "l2": (4, 384, 2, 4),
#             "l3": (4, 192, 4, 8),
#             "l4": (4, 96, 8, 16),
#             "l5": (4, 48, 16, 32),
#             "l6": (4, 48, 32, 64),
#         },
#         "feature_pyramid_two": {
#             "emb": (4, 768),
#             "l1": (4, 768, 1, 2),
#             "l2": (4, 384, 2, 4),
#             "l3": (4, 192, 4, 8),
#             "l4": (4, 96, 8, 16),
#             "l5": (4, 48, 16, 32),
#             "l6": (4, 48, 32, 64),
#         },
#         "warp_feature_pyramid_fwd": {
#             "l2": (4, 384, 2, 4),
#             "l3": (4, 192, 4, 8),
#             "l4": (4, 96, 8, 16),
#             "l5": (4, 48, 16, 32),
#             "l6": (4, 48, 32, 64),
#             "l7": (4, 3, 64, 128),
#         },
#         "warp_feature_pyramid_bwd": {
#             "l2": (4, 384, 2, 4),
#             "l3": (4, 192, 4, 8),
#             "l4": (4, 96, 8, 16),
#             "l5": (4, 48, 16, 32),
#             "l6": (4, 48, 32, 64),
#             "l7": None,
#         },
#         "flow_fwd": {
#             "l2": (4, 2, 2, 4),
#             "l3": (4, 2, 4, 8),
#             "l4": (4, 2, 8, 16),
#             "l5": (4, 2, 16, 32),
#             "l6": (4, 2, 32, 64),
#             "l7": (4, 2, 64, 128),
#         },
#         "flow_bwd": {
#             "l2": (4, 2, 2, 4),
#             "l3": (4, 2, 4, 8),
#             "l4": (4, 2, 8, 16),
#             "l5": (4, 2, 16, 32),
#             "l6": (4, 2, 32, 64),
#             "l7": (4, 2, 64, 128),
#         },
#         "img1_valid_mask": {
#             "l2": (4, 1, 2, 4),
#             "l3": (4, 1, 4, 8),
#             "l4": (4, 1, 8, 16),
#             "l5": (4, 1, 16, 32),
#             "l6": (4, 1, 32, 64),
#             "l7": (4, 1, 64, 128),
#         },
#         "img2_valid_mask": {
#             "l2": (4, 1, 2, 4),
#             "l3": (4, 1, 4, 8),
#             "l4": (4, 1, 8, 16),
#             "l5": (4, 1, 16, 32),
#             "l6": (4, 1, 32, 64),
#             "l7": (4, 1, 64, 128),
#         },
#         "img1_flow_diff_mask": {
#             "l2": (4, 2, 2, 4),
#             "l3": (4, 2, 4, 8),
#             "l4": (4, 2, 8, 16),
#             "l5": (4, 2, 16, 32),
#             "l6": (4, 2, 32, 64),
#             "l7": (4, 2, 64, 128),
#         },
#         "img2_flow_diff_mask": {
#             "l2": (4, 2, 2, 4),
#             "l3": (4, 2, 4, 8),
#             "l4": (4, 2, 8, 16),
#             "l5": (4, 2, 16, 32),
#             "l6": (4, 2, 32, 64),
#             "l7": (4, 2, 64, 128),
#         },
#     },
#     "depth1": {
#         "l2": (4, 1, 2, 4),
#         "l3": (4, 1, 4, 8),
#         "l4": (4, 1, 8, 16),
#         "l5": (4, 1, 16, 32),
#         "l6": (4, 1, 32, 64),
#         "l7": (4, 1, 64, 128),
#         "pred": (4, 1, 64, 128),
#     },
#     "depth2": {
#         "l2": (4, 1, 2, 4),
#         "l3": (4, 1, 4, 8),
#         "l4": (4, 1, 8, 16),
#         "l5": (4, 1, 16, 32),
#         "l6": (4, 1, 32, 64),
#         "l7": (4, 1, 64, 128),
#         "pred": (4, 1, 64, 128),
#     },
# }
# mock_info = make_random_nested_tens(mock_info_conf)
# loss_pack = loss_fn(mock_info, mock_batch)


# expected_loss_pack_conf = {
#     "tot": {"shape": (), "dtype": "torch.float32"},
#     "pt_depth_loss": {"shape": (), "dtype": "torch.float32"},
#     "pj_depth_loss": {"shape": (), "dtype": "torch.float32"},
#     "flow_loss": {"shape": (), "dtype": "torch.float32"},
#     "depth_smooth_loss": {"shape": (), "dtype": "torch.float32"},
# }
# tentative_none_mask = {
#     "pt_depth_loss": None,
#     "pj_depth_loss": None,
#     "flow_loss": None,
#     "depth_smooth_loss": None,
# }
# valid, msg = validate_nested_obj(loss_pack, expected_loss_pack_conf, tentative_none_mask)
# assert valid, msg

## Ground Truth

### Usage

In [None]:
# NOTE(review): this whole cell is commented out, so GTDepthLoss usage is not
# exercised by "Run All". It calls ln.set_devices([0, 0]) where every live
# cell uses [0, 1]; confirm whether single-GPU placement is intended.
# from src.learners import DepthLearner
# from src.datasets import KITTIWithDepth
# from src.losses import GTDepthLoss
# from torch.utils.data import DataLoader

# ds = KITTIWithDepth("./data/KITTI-2015/")
# dl = DataLoader(ds, 8)
# batch = next(iter(dl))

# loss_fn = GTDepthLoss(device=0)
# encoder_conf = {"target": "src.models.BackBone", "params": {"enc_name": "ConvNeXt"}}
# ln = DepthLearner(encoder_conf)
# ln.set_devices([0, 0])
# info = ln(batch)
# loss_pack = loss_fn(info, batch)

# print_deep_loss_pack(loss_pack)

### Validation

In [None]:
# NOTE(review): this whole cell is commented out, so GTDepthLoss is not
# validated by "Run All". The expected config below spells scalar shapes as
# [] while every live validation cell uses (); confirm validate_nested_obj
# treats them the same (a tuple never compares equal to a list in Python).
# from src.losses import GTDepthLoss
# from mt_pipe.src.test.external.util import validate_nested_obj, make_random_nested_tens

# loss_fn = GTDepthLoss()
# mock_info_conf = {
#     "l2": [8, 1, 2, 4],
#     "l3": [8, 1, 4, 8],
#     "l4": [8, 1, 8, 16],
#     "l5": [8, 1, 16, 32],
#     "l6": [8, 1, 32, 64],
#     "l7": [8, 1, 64, 128],
#     "pred": [8, 1, 64, 128],
# }
# mock_info = make_random_nested_tens(mock_info_conf)
# mock_batch_conf = {"img": [8, 3, 64, 128], "depth_map": [8, 1, 64, 128]}
# mock_batch = make_random_nested_tens(mock_batch_conf)
# loss_pack = loss_fn(mock_info, mock_batch)
# expected_loss_pack_conf = {
#     "tot": {"shape": [], "dtype": "torch.float32"},
#     "L1": {"shape": [], "dtype": "torch.float32"},
#     "Smooth": {"shape": [], "dtype": "torch.float32"},
# }
# valid, msg = validate_nested_obj(loss_pack, expected_loss_pack_conf)
# assert valid, msg

# All Losses

## Usage

In [None]:
from mt_pipe.src.util.learner_mux import LearnerMux
from src.datasets import KITTIWithCalibration, ImageNetVICReg
from torch.utils.data import DataLoader
from mt_pipe.src.losses import ConcatLoss

# Combined loss over the content and flow branches (depth is disabled here).
# NOTE(review): the first argument maps each loss name to 0 — presumably a
# device index; confirm against ConcatLoss.
loss_fn = ConcatLoss(
    {
        "content": 0,
        "flow": 0,
        # "depth":0,
    },
    conf={
        "content": {
            "target": "src.losses.ContentLoss",
            "branch": "content",
            "params": {
                "loss_weights": {"vc_loss_X": [0.01, 0.04], "vic_loss_Y": [25, 1, 1]}
            },
        },
        "flow": {
            "target": "src.losses.SSLFlowLoss",
            "branch": "flow",
            # "params": {
            #     "loss_weights": {
            #         "cycle_loss": 0.2,
            #         "reconstruction_loss": 1,
            #         "reconstruction_loss_coeffs": [1, 1, 1],
            #         "regression_loss": 1,
            #         "smooth_loss": 75,
            #         "vc_loss": 1,
            #         "vc_loss_coeffs": {
            #             "l1": [0.01, 0.04],
            #             "l2": [0.01, 0.04],
            #             "l3": [0.01, 0.001],
            #             "l4": [0.01, 0],
            #             "l5": [0.001, 0],
            #             "l6": [0.0001, 0],
            #         },
            #     }
            # },
        },
        # "depth": {"target": "src.losses.DepthLoss"},
    },
)

# KITTI feeds the flow branch and ImageNet the content branch, as in the
# single-loss usage cells above.
ds1 = KITTIWithCalibration("./data/KITTI-2012", img_wh=[128, 64])
dl1 = DataLoader(ds1, 4)
ds2 = ImageNetVICReg("./data/ImageNet-2012", img_wh=[64, 64])
dl2 = DataLoader(ds2, 4)

ln = LearnerMux(
    chldrn={
        "flow_learner": {
            "target": "src.learners.FlowLearner",
            "in_map": {"flow_path": "full"},
            "out_map": {"flow_path": "flow"},
        },
        # "depth_learner": {
        #     "target": "src.learners.DepthLearner",
        #     "in_map": {
        #         "depth_path_1": {"img1": "img"},
        #         "depth_path_2": {"img2": "img"},
        #     },
        #     "out_map": {
        #         "depth_path_1": "depth1",
        #         "depth_path_2": "depth2",
        #     },
        # },
        "content_learner": {
            "target": "src.learners.ContentLearner",
            "in_map": {"content_path": "full"},
            "out_map": {"content_path": "content"},
        },
    },
    encoder={
        "target": "src.models.backbone.BackBone",
        "params": {
            "enc_name": "ConvNeXt",
        },
    },
)
ln.set_devices([0, 1])

# One batch from each loader. next(iter(...)) fails loudly on an empty
# dataset, where the former zip/break loop silently did nothing.
batch1 = next(iter(dl1))
batch2 = next(iter(dl2))
# {**batch1, **batch2} would let batch2 silently overwrite any key batch1 also
# provides, feeding one branch the other dataset's tensors. Refuse that.
shared_keys = batch1.keys() & batch2.keys()
assert not shared_keys, f"batch key collision between datasets: {sorted(shared_keys)}"
batch = {**batch1, **batch2}
info = ln(batch)
loss_pack = loss_fn(info, batch)
print_deep_loss_pack(loss_pack)

# Segmentation Loss

## Usage

In [None]:
from src.learners import SegmentLearner
from src.models import BackBone
from src.datasets import COCOSegment
from src.losses import SegmentationLoss
from torch.utils.data import DataLoader

ds = COCOSegment("./data/COCO-2017")
dl = DataLoader(ds, 8)

bb = BackBone("ConvNeXt").cuda(0)
ln = SegmentLearner(bb, n_classes=80)
ln.set_devices([0, 1])
loss_fn = SegmentationLoss()

# Take exactly one batch; fails loudly on an empty dataset instead of the
# silent no-op a for/break loop gives.
batch = next(iter(dl))
info = ln(batch)
loss_pack = loss_fn(info, batch)
print_deep_loss_pack(loss_pack)

## Validation

In [None]:
from src.losses import SegmentationLoss
import torch
from mt_pipe.src.test.external.util import validate_nested_obj

# torch.Tensor(*sizes) returns uninitialised memory: logits may be NaN/Inf and
# targets fall outside the [0, 1] range BCE requires. Use seeded, well-formed
# inputs instead.
torch.manual_seed(0)
# NOTE(review): 80 presumably is the class dimension (SegmentLearner uses
# n_classes=80), and there is no batch dimension, unlike the other mocks;
# confirm the layout SegmentationLoss expects.
seg_shape = (80, 128, 128)

loss_fn = SegmentationLoss()
mock_info = {
    # Raw logits (the loss reports a BCEWithLogits term).
    "seg": torch.randn(*seg_shape)
}
mock_batch = {
    # Binary float mask with values in {0, 1}.
    "seg": (torch.rand(*seg_shape) > 0.5).float()
}
loss_pack = loss_fn(mock_info, mock_batch)

# Every reported term must be a scalar float32 tensor.
expected_loss_pack_conf = {
    "tot": {"shape": (), "dtype": "torch.float32"},
    "Dice": {"shape": (), "dtype": "torch.float32"},
    "BCEWithLogits": {"shape": (), "dtype": "torch.float32"},
}
valid, msg = validate_nested_obj(loss_pack, expected_loss_pack_conf)
assert valid, msg