In [None]:
# !pip install catalyst

## Imports

In [1]:
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from catalyst import dl, utils
import catalyst.contrib.nn.schedulers as schedulers


# Let cuDNN auto-tune convolution algorithms (inputs here have a fixed size).
torch.backends.cudnn.benchmark = True
torch.use_deterministic_algorithms(False)

# Seed torch's global RNG, which is used by nn.Linear initialisation, torchvision's
# random transforms and DataLoader shuffling, so runs are repeatable.
# Note: with cudnn.benchmark enabled, GPU results may still differ slightly between runs.
SEED = 42
torch.manual_seed(SEED)

## Dataset

In [2]:
#TODO Аналогичный датасет

class CelebaSpoofDataset(Dataset):
    pass

## Utils

In [2]:
from torchvision.datasets import FakeData


def get_loaders(size, batch_size, num_workers=0):
    """Build the train/valid DataLoaders.

    The data is currently torchvision ``FakeData`` with 2 classes, a stand-in
    until ``CelebaSpoofDataset`` is implemented.

    Args:
        size: side length (pixels) of the square center crop applied to every image.
        batch_size: number of samples per batch for both loaders.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        dict with keys "train" and "valid" mapping to DataLoaders; only the
        train loader shuffles.
    """
    # Standard ImageNet channel mean/std, shared by both pipelines.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    data_transforms = {
        "train": transforms.Compose(
            [
                transforms.CenterCrop((size, size)),
                transforms.RandomRotation((-10, 10)),
                transforms.RandomHorizontalFlip(0.2),
                transforms.ToTensor(),
                normalize,
            ]
        ),
        "valid": transforms.Compose(
            [
                transforms.CenterCrop((size, size)),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    }

    train_set = FakeData(num_classes=2, transform=data_transforms["train"])
    # FakeData derives each sample from (index + random_offset). With the default
    # offset for both splits, "valid" would contain exactly the train images and
    # labels, which makes validation metrics meaningless. Shift it past the train indices.
    valid_set = FakeData(
        num_classes=2,
        transform=data_transforms["valid"],
        random_offset=len(train_set),
    )

    # To use real data, replace FakeData with CelebaSpoofDataset, e.g.
    #   CelebaSpoofDataset(LOCAL_ROOT,
    #                      os.path.join("metas", "intra_test", "train_label.json"),
    #                      data_transforms["train"])
    # and the same with "test_label.json" / data_transforms["valid"] for "valid".

    return {
        "train": DataLoader(
            train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
        ),
        "valid": DataLoader(
            valid_set, batch_size=batch_size, shuffle=False, num_workers=num_workers
        ),
    }


def get_model(num_classes, weights_path="mobilenet_v3_large-8738ca79.pth"):
    """Build a MobileNetV3-Large initialised from a local checkpoint, with a new head.

    Args:
        num_classes: number of outputs of the replacement final classifier layer.
        weights_path: path to a state_dict for ``models.mobilenet_v3_large``.
            The default keeps the previous hard-coded file name, resolved
            relative to the current working directory.

    Returns:
        The model, moved to CUDA when available and to CPU otherwise.
    """
    model = models.mobilenet_v3_large()
    # map_location="cpu" lets the checkpoint load even if it was saved from GPU
    # tensors on a machine without CUDA; the model is moved to its device below.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Swap only the last classifier layer; everything before it keeps the loaded weights.
    old_head = model.classifier[-1]
    model.classifier[-1] = nn.Linear(old_head.in_features, num_classes)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model.to(device)


def bn_freeze(model):
    """Freeze every batch-norm layer of ``model`` in place.

    For each BatchNorm1d/2d/3d/SyncBatchNorm module:
    - its affine weight and bias, when present, stop receiving gradients;
    - the module is put in eval mode, so it normalises with its stored
      running statistics.

    Note: ``nn.Module.train()`` recurses into all children. A later
    ``model.train()`` call (for example by a training runner) therefore switches
    these layers back to training mode. Re-apply this function after such a
    call if the statistics must stay fixed.

    Args:
        model: the network to modify.

    Returns:
        The same ``model`` object, for chaining.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for module in model.modules():
        if not isinstance(module, bn_types):
            continue
        # With affine=False, `weight`/`bias` exist but are None, so hasattr() is
        # not a sufficient guard before calling requires_grad_.
        for param in (module.weight, module.bias):
            if param is not None:
                param.requires_grad_(False)
        module.eval()
    return model


## Metrics

In [3]:
from typing import Optional, Dict, Tuple, List
from collections import defaultdict

from catalyst.metrics._classification import StatisticsMetric


class FprTprMetric(StatisticsMetric):
    """Binary-classification metric that reports FPR and TPR.

    Catalyst's ``StatisticsMetric`` (``mode="binary"``) accumulates the
    confusion-matrix counts. This class turns them into
    ``FPR = fp / (fp + tn)`` and ``TPR = tp / (tp + fn)``.

    Args:
        zero_division: value reported for a rate whose denominator is 0.
        compute_on_call: forwarded to ``StatisticsMetric``.
        prefix: prepended to the ``"FPR"``/``"TPR"`` keys of the key-value outputs.
        suffix: appended to the ``"FPR"``/``"TPR"`` keys of the key-value outputs.
    """

    def __init__(
        self,
        zero_division: int = 0,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ):
        """Init FprTprMetric instance"""
        super().__init__(
            num_classes=2,
            mode="binary",
            compute_on_call=compute_on_call,
            prefix=prefix,
            suffix=suffix,
        )
        self.zero_division = zero_division
        # Own copies of the affixes, so key naming does not depend on how the
        # installed catalyst version stores prefix/suffix internally.
        self._key_prefix = prefix or ""
        self._key_suffix = suffix or ""
        self.reset()

    @staticmethod
    def _convert_metrics_to_kv(fpr_value: float, tpr_value: float) -> Dict[str, float]:
        """
        Pack the two rates into an un-prefixed key-value dict.
        Args:
            fpr_value: FPR value
            tpr_value: TPR value
        Returns:
            dict ``{"FPR": fpr_value, "TPR": tpr_value}``
        """
        kv_metrics = {
            "FPR": fpr_value,
            "TPR": tpr_value,
        }
        return kv_metrics

    def _add_affixes(self, kv_metrics: Dict[str, float]) -> Dict[str, float]:
        """Apply the configured prefix/suffix to every metric name."""
        return {
            f"{self._key_prefix}{name}{self._key_suffix}": value
            for name, value in kv_metrics.items()
        }

    def reset(self) -> None:
        """Reset the accumulated tp/fp/fn/tn statistics."""
        self.statistics = defaultdict(float)

    def update(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Tuple[float, float]:
        """
        Update statistics and return the intermediate (fpr, tpr).
        Args:
            outputs: per-class scores (e.g. logits); argmax over dim 1 gives
                the predicted label
            targets: target labels
        Returns:
            tuple (fpr, tpr) computed from the counts returned by
            ``StatisticsMetric.update`` for this call
        """
        predictions = outputs.argmax(dim=1)
        tn, fp, fn, tp, _ = super().update(outputs=predictions, targets=targets)
        fpr_value, tpr_value = get_fpr_tpr(
            tp=tp,
            fp=fp,
            fn=fn,
            tn=tn,
            zero_division=self.zero_division,
        )
        return fpr_value, tpr_value

    def update_key_value(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, float]:
        """
        Update statistics and return the intermediate metrics as a dict.
        Args:
            outputs: per-class scores (e.g. logits)
            targets: target labels
        Returns:
            dict of intermediate metrics, keys carry the configured prefix/suffix
        """
        fpr_value, tpr_value = self.update(outputs=outputs, targets=targets)
        kv_metrics = self._convert_metrics_to_kv(
            fpr_value=fpr_value,
            tpr_value=tpr_value,
        )
        return self._add_affixes(kv_metrics)

    def compute(self) -> Tuple[float, float]:
        """
        Compute metrics from the accumulated statistics.
        Returns:
            tuple of metrics: fpr, tpr
        """
        # @TODO: ddp hotfix, could be done better
        if self._is_ddp:
            # Imported lazily: only needed for distributed runs. It was previously
            # referenced without any import, which raised NameError under DDP.
            from catalyst.utils.distributed import all_gather

            for key in self.statistics:
                gathered: List[float] = all_gather(self.statistics[key])
                self.statistics[key] = sum(gathered)

        fpr_value, tpr_value = get_fpr_tpr(
            tp=self.statistics["tp"],
            fp=self.statistics["fp"],
            fn=self.statistics["fn"],
            tn=self.statistics["tn"],
            zero_division=self.zero_division,
        )
        return fpr_value, tpr_value

    def compute_key_value(self) -> Dict[str, float]:
        """
        Compute metrics from all accumulated statistics.
        Returns:
            dict of metrics, keys carry the configured prefix/suffix
        """
        fpr_value, tpr_value = self.compute()
        kv_metrics = self._convert_metrics_to_kv(
            fpr_value=fpr_value,
            tpr_value=tpr_value,
        )
        return self._add_affixes(kv_metrics)


def get_fpr_tpr(tp, fp, fn, tn, zero_division):
    """Return ``(fpr, tpr)`` computed from confusion-matrix counts.

    Each rate falls back to ``zero_division`` when its denominator is empty.
    """
    return (
        fpr(fp=fp, tn=tn, zero_division=zero_division),
        tpr(tp=tp, fn=fn, zero_division=zero_division),
    )


def fpr(fp: int, tn: int, zero_division: int):
    """False positive rate ``fp / (fp + tn)``.

    Returns ``zero_division`` when there are no actual negatives (fp == tn == 0).
    """
    no_negatives = fp == 0 and tn == 0
    return zero_division if no_negatives else fp / (fp + tn)


def tpr(tp: int, fn: int, zero_division: int):
    """True positive rate (recall) ``tp / (tp + fn)``.

    Returns ``zero_division`` when there are no actual positives (tp == fn == 0).
    """
    no_positives = tp == 0 and fn == 0
    return zero_division if no_positives else tp / (tp + fn)


In [4]:
class FprTprCallback(dl.BatchMetricCallback):
    """Catalyst batch-metric callback that logs ``FprTprMetric`` (FPR and TPR).

    Args:
        input_key: batch key holding the model scores (the metric argmaxes them).
        target_key: batch key holding the ground-truth labels.
        zero_division: forwarded to ``FprTprMetric``.
        log_on_batch: forwarded to ``BatchMetricCallback``.
        prefix: forwarded to ``FprTprMetric``.
        suffix: forwarded to ``FprTprMetric``.
    """

    def __init__(
        self,
        input_key: str,
        target_key: str,
        zero_division: int = 0,
        log_on_batch: bool = True,
        prefix: str = None,
        suffix: str = None,
    ):
        metric = FprTprMetric(
            zero_division=zero_division,
            prefix=prefix,
            suffix=suffix,
        )
        super().__init__(
            metric=metric,
            input_key=input_key,
            target_key=target_key,
            log_on_batch=log_on_batch,
        )


## Model training

In [5]:
# Hyper-parameters
LR = 1e-3    # peak learning rate (also OneCycleLR's max_lr)
EPOCHS = 5
BS = 8       # batch size
SIZE = 224   # side of the square crop, in pixels

loaders = get_loaders(SIZE, BS)
# NOTE(review): if the runner calls model.train() at the start of each train
# loader, the eval() applied by bn_freeze is undone there. Confirm for the
# catalyst version in use.
model = bn_freeze(get_model(num_classes=2))

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)
# The schedule is stepped once per batch, so its length is defined in train batches.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=LR,
    epochs=EPOCHS,
    steps_per_epoch=len(loaders["train"]),
)

runner = dl.SupervisedRunner(
    input_key="features",
    output_key="logits",
    target_key="targets",
    loss_key="loss",
)


In [7]:
callbacks = [
    # FPR / TPR of the argmax prediction, computed from the "logits" output.
    FprTprCallback(input_key="logits", target_key="targets"),
    # Keep the 2 checkpoints with the lowest validation loss.
    dl.CheckpointCallback(
        use_runner_logdir=True,
        loader_key="valid",
        metric_key="loss",
        minimize=True,
        save_n_best=2,
    ),
    # Step the scheduler after every batch (matches OneCycleLR's steps_per_epoch).
    dl.SchedulerCallback(mode="batch"),
]
# Optionally also log a confusion matrix by appending:
#   dl.ConfusionMatrixCallback(input_key="logits", target_key="targets", num_classes=2)

# NOTE(review): verbose=True draws a notebook progress bar. The recorded
# "IProgress not found" ImportError suggests ipywidgets is missing in this kernel.
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    num_epochs=EPOCHS,
    callbacks=callbacks,
    logdir="./logs",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
)


ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir logs/