# BackBoneLearner
## Usage

In [None]:
from src.learners import BackBoneLearner
from src.datasets import ClassDataset
from src.datasets.collate_fns import Augmentor
from torch.utils.data import DataLoader
from src.datasets.collate_fns import aug_collate_fn
from src.losses import MMCRLoss
from src.util import Logger
from functools import partial

# Demo: construct a BackBoneLearner with a ResNet50 encoder, pull one
# augmented batch from the tea-grade dataset, and print the label/image/
# embedding shapes together with the MMCR loss for that batch.
logger = Logger(1)
# Device index shared by the learner and the loss
# (presumably a CUDA GPU id — confirm on the target machine).
device = 2

model = BackBoneLearner(
    enc_name="ResNet50",
    enc_params={"drop_out": 0.2},
    logger=logger,
    devices=[device],
)
ds = ClassDataset("datasets/tea-grade-v2", resize_wh=[300, 300])
aug = Augmentor([224, 224])
# Each sample is expanded into aug_count augmented views by the collate fn.
collate = partial(aug_collate_fn, aug=aug, aug_count=8)
dl = DataLoader(ds, 16, collate_fn=collate)
loss_fn = MMCRLoss(logger=logger, device=device, lamb=0.01)

# Inspect only the first batch; if the loader yields nothing, nothing is printed.
for batch in dl:
    lbls = batch["lbl"]
    imgs = batch["img"]
    out = model(batch)
    loss = loss_fn(out, batch, False)
    report = (
        ("lbls.shape", lbls.shape),
        ("imgs.shape", imgs.shape),
        ('out["embs"].shape', out["embs"].shape),
        ("loss", loss),
    )
    for tag, value in report:
        print(tag, value)
    break

# ClassifierLearner
## Usage

In [None]:
from src.learners import ClassifierLearner
from src.datasets import ClassDataset
from src.datasets.collate_fns import Augmentor
from torch.utils.data import DataLoader
from src.datasets.collate_fns import aug_collate_fn
from src.losses import MMCRLoss, CrossEntropyLoss
from src.util import Logger
from functools import partial

# Demo: construct a ClassifierLearner (ResNet50 encoder plus a decoder/head),
# pull one augmented batch, and print shapes together with the sum of the
# MMCR and cross-entropy losses for that batch.
logger = Logger(1)
# Device index shared by the learner and both losses
# (presumably a CUDA GPU id — confirm on the target machine).
device = 2

model = ClassifierLearner(
    # NOTE(review): two entries, both the same device — presumably one per
    # sub-module (encoder / head); confirm against ClassifierLearner.
    devices=[device, device],
    logger=logger,
    # NOTE(review): meaning of 34 is not visible here — confirm.
    head_trail=[34],
    enc_name="ResNet50",
    enc_params={"drop_out": 0.2},
    dec_params={"drop_out": 0.2, "use_batch_norm": True},
)
ds = ClassDataset("datasets/tea-grade-v2", resize_wh=[300, 300])
aug = Augmentor([224, 224])
# Each sample is expanded into aug_count augmented views by the collate fn.
collate = partial(aug_collate_fn, aug=aug, aug_count=8)
dl = DataLoader(ds, 16, collate_fn=collate)
loss_fn1 = MMCRLoss(device=device, logger=logger, lamb=0.01)
# has_aug_ax=True presumably tells the loss that the batch carries the extra
# augmentation axis produced above — confirm against CrossEntropyLoss.
loss_fn2 = CrossEntropyLoss(device=device, logger=logger, has_aug_ax=True)

# Inspect only the first batch; if the loader yields nothing, nothing is printed.
for batch in dl:
    lbls = batch["lbl"]
    imgs = batch["img"]
    out = model(batch)
    loss1 = loss_fn1(out, batch, False)
    loss2 = loss_fn2(out, batch, False)
    loss = loss1 + loss2
    report = (
        ("lbls.shape", lbls.shape),
        ("imgs.shape", imgs.shape),
        ('out["embs"].shape', out["embs"].shape),
        ("loss", loss),
    )
    for tag, value in report:
        print(tag, value)
    break