# CrossEntropy Loss

## Create MySimpleTeaDataset

In [1]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Sequence, Tuple
import os
import json
from PIL import Image


class MySimpleTeaDataset(Dataset):
    """Image-classification dataset driven by a split file under ``root``.

    Files read at construction time (paths relative to ``root``):

    * ``annotations/classification/splits/<split>.txt`` -- one image path per
      line, each joined onto ``root``. The class name of an image is the
      second-to-last ``/``-separated component of the joined path, i.e. the
      image's parent directory.
    * ``annotations/classification/classes.json`` -- a JSON object whose keys
      are integer labels (parsed with ``int``) and whose values are class
      names.

    Items are ``(label, image)`` pairs: ``label`` is the integer label of the
    image's class and ``image`` is the RGB image after ``ToTensor`` and,
    unless ``wh`` is ``None``, ``Resize``.

    Args:
        root: Dataset root directory.
        wh: Target size forwarded to ``transforms.Resize``; pass ``None`` to
            skip resizing.
            NOTE(review): torchvision's ``Resize`` documents a sequence size
            as ``(height, width)`` while the name ``wh`` suggests the
            opposite order -- confirm before using a non-square size.
        split: Name of the split file (without ``.txt``) to load.

    Raises:
        FileNotFoundError: If the split file or ``classes.json`` is missing.
        KeyError: If an image's class name is not listed in ``classes.json``.
    """

    def __init__(
        self,
        root: str,
        wh: Sequence[int] = (224, 224),
        split: str = "train",
    ) -> None:
        super().__init__()

        ann_dir = os.path.join(root, "annotations", "classification")

        # find paths (blank lines are skipped so they never become a bogus
        # "root/" entry)
        split_file = os.path.join(ann_dir, "splits", f"{split}.txt")
        with open(split_file, encoding="utf-8") as handler:
            paths = [
                os.path.join(root, entry)
                for entry in handler.read().strip().split("\n")
                if entry.strip()
            ]

        # The class of each image is the name of its parent directory.
        classes = [p.split("/")[-2] for p in paths]

        with open(os.path.join(ann_dir, "classes.json"), encoding="utf-8") as handler:
            lbl_cls_map = {int(i): c for (i, c) in json.load(handler).items()}
        cls_lbl_map = {c: i for (i, c) in lbl_cls_map.items()}
        cls_lbls = [cls_lbl_map[c] for c in classes]

        # ToTensor comes first, so Resize operates on a tensor.
        trans_lst = [transforms.ToTensor()]
        if wh is not None:
            # list(...) hands Resize the same list type the original default
            # provided, whatever sequence type the caller passed in.
            trans_lst.append(transforms.Resize(list(wh), antialias=True))
        self.transforms = transforms.Compose(trans_lst)

        self.split = split
        self.paths = paths
        self.classes = classes
        self.cls_lbl_map = cls_lbl_map
        self.lbl_cls_map = lbl_cls_map
        self.cls_lbls = cls_lbls

    def __len__(self) -> int:
        """Return the number of images listed in the split file."""
        return len(self.paths)

    def __getitem__(self, index) -> Tuple[int, torch.Tensor]:
        """Load the image at ``index`` and return ``(label, image)``."""
        path = self.paths[index]
        img = Image.open(path).convert("RGB")
        img = self.transforms(img)
        lbl = self.cls_lbls[index]
        return lbl, img

## Train

In [2]:
from torch.nn import CrossEntropyLoss
from torch.nn import Linear
from torch.utils.data import DataLoader, Subset
from torchvision.models import resnet18
from torchvision.models import ResNet18_Weights
from src.util import GradAnalyzer
from torch.optim import Adam
from tqdm import tqdm
import time


# Cross-entropy demo: fine-tune a ResNet-18 on a small mock subset while
# GradAnalyzer records gradients, then report the elapsed time.

# NOTE(review): hard-coded device index passed to ``.to()`` -- presumably a
# CUDA ordinal; confirm it exists on the machine running this cell.
device = 2
bs = 16  # batch size
mock_bc = 4  # number of mock batches; the subset holds bs * mock_bc samples
ep_count = 2  # number of epochs

# ResNet-18 with torchvision's default weights; the final layer is replaced
# by a fresh 512 -> 34 linear head.
# NOTE(review): 34 is presumably the number of classes in the dataset's
# classes.json -- confirm.
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = Linear(512, 34)
model.to(device)

optimizer = Adam(model.parameters())
# Only the first bs * mock_bc samples are used, so each epoch is mock_bc batches.
ds = Subset(MySimpleTeaDataset("datasets/tea-grade-v2"), range(0, bs * mock_bc))
dl = DataLoader(ds, bs)
loss_fn = CrossEntropyLoss()
grad_analyzer = GradAnalyzer(model, "out/grad-analyzer-demo/cross-entropy")

# perf_counter() is monotonic and intended for measuring elapsed intervals;
# time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
for ep in range(ep_count):
    for batch in tqdm(dl, desc=f"EPOCH: {ep}"):
        # The dataset yields (label, image), in that order.
        lbls, imgs = batch
        lbls, imgs = lbls.to(device), imgs.to(device)
        out = model(imgs)
        loss = loss_fn(out, lbls)

        optimizer.zero_grad()
        loss.backward()
        # Called after backward() and before the optimizer step.
        grad_analyzer.batch_step()
        optimizer.step()

    grad_analyzer.epoch_step(ep)
end = time.perf_counter()

# The measured interval includes the grad_analyzer calls, not only the
# forward/backward passes.
print(f"Time taken for CrossEntropy Loss: {end-start}")

EPOCH: 0: 100%|██████████| 4/4 [00:00<00:00,  4.60it/s]
EPOCH: 1: 100%|██████████| 4/4 [00:00<00:00,  4.37it/s]


Time taken for CrossEntropy Loss: 38.63070821762085


# MMCR

In [4]:
from src.learners import BackBoneLearner
from src.datasets import ClassDataset
from src.datasets.collate_fns import Augmentor
from torch.utils.data import DataLoader, Subset
from src.datasets.collate_fns import aug_collate_fn
from src.losses import MMCRLoss
from src.util import Logger
from functools import partial
from torch.optim import Adam
from tqdm import tqdm
import time
from src.constants import img_wh

# MMCR demo: train a BackBoneLearner with MMCRLoss on a small mock subset
# while the project Logger records gradients, then report the elapsed time.

logger = Logger(1)
# NOTE(review): hard-coded device index -- confirm it exists on the machine
# running this cell.
device = 2
bs = 32  # batch size
mock_bc = 4  # number of mock batches; the subset holds bs * mock_bc samples
ep_count = 2  # number of epochs

model = BackBoneLearner(enc_name="ResNet18", enc_params={"drop_out":0.2}, logger=logger, devices=[device])
optimizer = Adam(model.parameters())
# Only the first bs * mock_bc samples are used, so each epoch is mock_bc batches.
ds = Subset(ClassDataset("datasets/tea-grade-v2", resize_wh=[300, 300]), range(0, bs * mock_bc))
aug = Augmentor(img_wh)
# Batches are built by aug_collate_fn with the augmentor and aug_count=6 bound in.
dl = DataLoader(ds, bs, collate_fn=partial(aug_collate_fn, aug=aug, aug_count=6))
# NOTE(review): the loss is created with device=0 while the model is given
# devices=[2] -- confirm the mismatch is intentional.
loss_fn = MMCRLoss(device=0, lamb=0.01)
logger.init_plotter("out/grad-analyzer-demo/mmcr", model)

# perf_counter() is monotonic and intended for measuring elapsed intervals;
# time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
for ep in range(ep_count):
    for batch in tqdm(dl, desc=str(ep)):
        # NOTE(review): lbls/imgs are not used below -- the model and the loss
        # both receive the whole batch dict.
        lbls, imgs = batch["lbl"], batch["img"]
        out = model(batch)
        # The loss returns a mapping; its "tot" entry is what gets backpropagated.
        loss = loss_fn(out, batch)["tot"]

        optimizer.zero_grad()
        loss.backward()
        # Called after backward() and before the optimizer step.
        logger.batch_step(analyze_grads=True)
        optimizer.step()

    # NOTE(review): keyword is ``analyze_grad`` here but ``analyze_grads`` in
    # batch_step above -- confirm both against Logger's signatures.
    logger.step(ep, analyze_grad=True)
end = time.perf_counter()

# The measured interval includes the logger calls, not only the
# forward/backward passes.
print(f"Time taken for MMCR Loss: {end-start}")

0: 100%|██████████| 4/4 [00:02<00:00,  1.70it/s]
1: 100%|██████████| 4/4 [00:02<00:00,  1.88it/s]


Time taken for MMCR Loss: 36.26873970031738
