# CrossEntropy

In [1]:
import torch
from src.losses import CrossEntropyLoss
from src.util import Logger
import numpy as np

# Smoke test: CrossEntropyLoss must reduce mock logits to a finite, non-negative
# float32 scalar, both without (B, out_D) and with (B, K, out_D) an
# augmentation axis.
# NOTE(review): device index 2 is hard-coded; presumably this requires a machine
# with at least three CUDA devices — adjust for the local setup.
device = 2
logger = Logger(1)

loss_fn_ce = CrossEntropyLoss(device=device, logger=logger)
loss_fn_ce_aug = CrossEntropyLoss(device=device, logger=logger, has_aug_ax=True)

# Seed so the mock inputs (and therefore the loss values) are reproducible.
torch.manual_seed(0)

# B: batch size, K: augmentations per sample, C/H/W: image channels/height/width.
B, K, C, H, W = 16, 8, 3, 224, 224
out_D = 20  # number of classes (last dimension of the logits)

# torch.Tensor(*sizes) returns *uninitialized* memory, which may hold NaN/inf and
# makes the loss arbitrary; use real random values instead. torch.randint yields
# int64 class ids directly (no float round-trip through numpy).
mock_out_ce = {"logits": torch.randn(B, out_D)}
mock_batch_ce = {
    "lbl": torch.randint(0, out_D, (B,)),
    "img": torch.rand(B, C, H, W),
}
mock_out_ce_aug = {"logits": torch.randn(B, K, out_D)}
mock_batch_ce_aug = {
    "lbl": torch.randint(0, out_D, (B, K)),
    "img": torch.rand(B, K, C, H, W),
}

loss_ce = loss_fn_ce(mock_out_ce, mock_batch_ce, plot=False)
loss_ce_aug = loss_fn_ce_aug(mock_out_ce_aug, mock_batch_ce_aug, plot=False)

for name, loss in (("ce", loss_ce), ("ce_aug", loss_ce_aug)):
    assert loss.shape == torch.Size([]), name
    assert loss.dtype == torch.float, name
    # With finite logits and valid class ids, cross-entropy is finite and
    # non-negative, so NaN is no longer an acceptable outcome.
    assert torch.isfinite(loss), name
    assert loss.item() >= 0, name

# MMCR

In [3]:
import torch
from src.losses import MMCRLoss
from src.util import Logger

# Smoke test: MMCRLoss must reduce mock (batch, augmentation, embedding) inputs
# to a finite float32 scalar.
# NOTE(review): device index 2 is hard-coded; presumably this requires a machine
# with at least three CUDA devices — adjust for the local setup.
device = 2
logger = Logger(1)

loss_fn = MMCRLoss(logger=logger, device=device, lamb=0.01)

# Seed so the mock inputs (and therefore the loss value) are reproducible.
torch.manual_seed(0)

# B: batch size, K: augmentations per sample, C/H/W: image channels/height/width.
B, K, C, H, W = 16, 8, 3, 224, 224
emb_D, n_classes = 768, 20

# torch.Tensor(*sizes) returns *uninitialized* memory (possibly NaN/inf); use
# real random values. Labels are integer class ids of shape (B, K), the same
# form the ConcatLoss cell feeds to MMCRLoss.
mock_batch = {
    "lbl": torch.randint(0, n_classes, (B, K)),
    "img": torch.rand(B, K, C, H, W),
}
mock_out = {"embs": torch.randn(B, K, emb_D)}
loss = loss_fn(mock_out, mock_batch, plot=False)

assert loss.shape == torch.Size([])
assert loss.dtype == torch.float
# No sign constraint is asserted for this loss; only require a real number.
assert torch.isfinite(loss)

# ConcatLoss

In [5]:
from omegaconf import OmegaConf
from src.util import Logger
import torch
import numpy as np
from src.losses import ConcatLoss, MMCRLoss, CrossEntropyLoss

# Smoke test: ConcatLoss must produce a finite float32 scalar whether it is built
# from an OmegaConf config (loss classes given by import path + params) or from
# a dict of already-instantiated losses.
# NOTE(review): device index 2 is hard-coded; presumably this requires a machine
# with at least three CUDA devices — adjust for the local setup.
device = 2
logger = Logger(1)

loss_fn_aug = MMCRLoss(logger=logger, device=device, lamb=0.01)
loss_fn_ce = CrossEntropyLoss(device=device, logger=logger, has_aug_ax=True)

conf1 = OmegaConf.create(
    {
        "aug": {"target": "src.losses.MMCRLoss", "params": {"lamb": 0.01}},
        "cross_entropy": {
            "target": "src.losses.CrossEntropyLoss",
            "params": {"has_aug_ax": True},
        },
    }
)
conf2 = {
    "aug": loss_fn_aug,
    "cross_entropy": loss_fn_ce,
}

# Seed so the mock inputs (and therefore the loss values) are reproducible.
torch.manual_seed(0)

# Sizes are loop-invariant. B: batch size, K: augmentations per sample,
# C/H/W: image channels/height/width.
B, K, C, H, W = 16, 8, 3, 224, 224
emb_D, out_D = 768, 20

for conf in [conf1, conf2]:
    conc_loss_fn = ConcatLoss(device, logger, conf)

    # Fresh mock data per configuration, in case a loss mutates the dicts.
    # torch.Tensor(*sizes) would return uninitialized memory (possibly NaN/inf),
    # so real random values are used instead.
    mock_batch = {
        "lbl": torch.randint(0, out_D, (B, K)),
        "img": torch.rand(B, K, C, H, W),
    }
    mock_out = {
        "embs": torch.randn(B, K, emb_D),
        "logits": torch.randn(B, K, out_D),
    }
    # The third positional argument is presumably the same `plot` flag the
    # other cells pass by keyword — confirm against ConcatLoss's signature.
    loss = conc_loss_fn(mock_out, mock_batch, False)
    assert loss.dtype == torch.float
    assert loss.shape == torch.Size([])
    assert torch.isfinite(loss)