In [None]:
from MyDataset import CifarDataModule

from art.project import ArtProject
from art.checks import CheckResultExists, CheckScoreExists, CheckScoreLessThan, CheckScoreGreaterThan
from art.steps import EvaluateBaseline, OverfitOneBatch, Overfit#, TransferLearning
from torchmetrics import Accuracy, Precision, Recall
import torch.nn as nn
from lightning.pytorch.callbacks import EarlyStopping

from MyDataset import CifarDataModule
from checks import CheckClassImagesExist, CheckLenClassNamesEqualToNumClasses
from steps import DataAnalysis
import math

In [None]:
# Notebook configuration: enable IPython's autoreload extension in mode 2, which
# re-imports all modules before each cell runs, so edits to the local helper
# modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2
from lightning import seed_everything
# Fix the random seed through Lightning so repeated runs are comparable.
seed_everything(23)

# Proper data analysis

In [None]:
# Create the project around the CIFAR data module; later cells keep adding
# steps to this same `project` object.
project = ArtProject("Cifar100", CifarDataModule(batch_size=32))

data_analysis_step = DataAnalysis()

# Result names handed to CheckResultExists (presumably the keys the
# data-analysis step is expected to store -- confirm against steps.DataAnalysis).
required_result_names = (
    "number_of_classes",
    "class_names",
    "number_of_examples_in_each_class",
    "img_dimensions",
)
data_analysis_checks = [
    *[CheckResultExists(result_name) for result_name in required_result_names],
    CheckClassImagesExist(),
    CheckLenClassNamesEqualToNumClasses(),
]

project.add_step(data_analysis_step, data_analysis_checks)
project.run_all()


In [None]:
import torch.nn as nn
from torchmetrics import Accuracy

from art.steps import EvaluateBaseline

# Read the class count from the latest run of step 0. In this notebook the only
# step added before this cell is DataAnalysis (presumably index 0 -- confirm
# against ArtProject.get_step).
NUM_CLASSES = project.get_step(0).get_latest_run()["number_of_classes"]

# Metrics shared by the following cells: multiclass accuracy and cross-entropy.
accuracy_metric = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
ce_loss = nn.CrossEntropyLoss()

project.register_metrics([accuracy_metric, ce_loss])

In [None]:
from art.checks import CheckScoreExists
from art.metrics import SkippedMetric
from models.baselines import MlBaseline, HeuristicBaseline, AlreadyExistingResNet20Baseline

# Baselines to evaluate, in the order their steps are added to the project.
baselines = [HeuristicBaseline, MlBaseline, AlreadyExistingResNet20Baseline]

# Add one EvaluateBaseline step per baseline. Each step is checked for an
# accuracy score, and the cross-entropy loss is passed as a skipped metric.
for baseline in baselines:
    baseline_step = EvaluateBaseline(baseline)
    baseline_checks = [CheckScoreExists(metric=accuracy_metric)]
    baseline_skipped_metrics = [SkippedMetric(metric=ce_loss)]
    project.add_step(
        step=baseline_step,
        checks=baseline_checks,
        skipped_metrics=baseline_skipped_metrics,
    )


In [None]:
# Run the project again now that the baseline-evaluation steps have been added
# (presumably this executes the newly added steps -- confirm how
# ArtProject.run_all treats steps that already ran).
project.run_all()

In [None]:
from torch.cuda import is_available

from art.checks import CheckScoreCloseTo
from art.steps import CheckLossOnInit
from models.ResNet import ResNet18

# Cross-entropy of a uniform guess over NUM_CLASSES classes, -ln(1 / NUM_CLASSES).
# This is the reference value the loss-on-init check compares against.
uniform_class_probability = 1 / NUM_CLASSES
EXPECTED_LOSS = -math.log(uniform_class_probability)
print(EXPECTED_LOSS)

In [None]:
# Add the loss-on-init step for ResNet18. Its cross-entropy score must be close
# to EXPECTED_LOSS; rel_tol=0.1 is passed as the relative tolerance.
loss_on_init_step = CheckLossOnInit(ResNet18)
loss_on_init_checks = [
    CheckScoreCloseTo(metric=ce_loss, value=EXPECTED_LOSS, rel_tol=0.1),
]
project.add_step(loss_on_init_step, loss_on_init_checks)

project.run_all()