In [1]:
from dataset import CifarDataModule

from art.project import ArtProject
from art.checks import CheckResultExists, CheckScoreExists, CheckScoreLessThan, CheckScoreGreaterThan
from art.steps import EvaluateBaseline, OverfitOneBatch, Overfit#, TransferLearning
from torchmetrics import Accuracy, Precision, Recall
import torch.nn as nn
from lightning.pytorch.callbacks import EarlyStopping

from dataset import CifarDataModule
from checks import CheckClassImagesExist, CheckLenClassNamesEqualToNumClasses
from steps import DataAnalysis
import math

In [2]:
%load_ext autoreload
%autoreload 2
from lightning import seed_everything
seed_everything(23)

Global seed set to 23


23

# Proper data analysis

In [3]:
project = ArtProject("Cifar100", CifarDataModule(batch_size=32))
project.add_step(DataAnalysis(), [
    CheckResultExists("number_of_classes"),
    CheckResultExists("class_names"),
    CheckResultExists("number_of_examples_in_each_class"),
    CheckResultExists("img_dimensions"),
    CheckClassImagesExist(),
    CheckLenClassNamesEqualToNumClasses()])
project.run_all()


Summary: 
Step: Data analysis, Model: , Passed: True. Results:



In [4]:
from torchmetrics import Accuracy
import torch.nn as nn
from art.steps import EvaluateBaseline

NUM_CLASSES = project.get_step(0).get_latest_run()["number_of_classes"]
accuracy_metric, ce_loss = Accuracy(task="multiclass", num_classes = NUM_CLASSES), nn.CrossEntropyLoss()
project.register_metrics([accuracy_metric, ce_loss])

In [5]:
from art.metrics import SkippedMetric
from models.baselines import MlBaseline, HeuristicBaseline, AlreadyExistingResNet20Baseline
from art.checks import CheckScoreExists
baselines = [HeuristicBaseline, MlBaseline, AlreadyExistingResNet20Baseline]
for baseline in baselines:
    project.add_step(
        step = EvaluateBaseline(baseline), 
        checks = [CheckScoreExists(metric=accuracy_metric)],
        skipped_metrics=[SkippedMetric(metric=ce_loss)]
    )


In [6]:
project.run_all()

Summary: 
Step: Data analysis, Model: , Passed: True. Results:

Step: Evaluate Baseline, Model: HeuristicBaseline, Passed: True. Results:
	MulticlassAccuracy-validate: 0.012299999594688416
Step: Evaluate Baseline, Model: MlBaseline, Passed: True. Results:
	MulticlassAccuracy-validate: 0.15649999678134918
Step: Evaluate Baseline, Model: AlreadyExistingResNet20Baseline, Passed: True. Results:
	MulticlassAccuracy-validate: 0.6883000135421753


In [16]:
from models.ResNet import ResNet18
from art.steps import CheckLossOnInit
from art.checks import CheckScoreCloseTo
from torch.cuda import is_available
project.to("cuda" if is_available() else "cpu")

EXPECTED_LOSS = -math.log(1/NUM_CLASSES)
print(EXPECTED_LOSS)

4.605170185988091


In [17]:
project.add_step(
        CheckLossOnInit(ResNet18),
        [CheckScoreCloseTo(metric=ce_loss,
                           value=EXPECTED_LOSS, rel_tol=0.1)]
    )

project.run_all()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Calculating loss on init
Validating model ResNet18


Using cache found in C:\Users\matri/.cache\torch\hub\facebookresearch_semi-supervised-ImageNet1K-models_master


Validation: 0it [00:00, ?it/s]

Summary: 
Step: Data analysis, Model: , Passed: True. Results:

Step: Evaluate Baseline, Model: HeuristicBaseline, Passed: True. Results:
	MulticlassAccuracy-validate: 0.012299999594688416
Step: Evaluate Baseline, Model: MlBaseline, Passed: True. Results:
	MulticlassAccuracy-validate: 0.15649999678134918
Step: Evaluate Baseline, Model: AlreadyExistingResNet20Baseline, Passed: True. Results:
	MulticlassAccuracy-validate: 0.6883000135421753
Step: Check Loss On Init, Model: ResNet18, Passed: True. Results:
	MulticlassAccuracy-validate: 0.008419999852776527
	CrossEntropyLoss-validate: 5.054237365722656
Step: Check Loss On Init, Model: ResNet18, Passed: True. Results:
	MulticlassAccuracy-validate: 0.008419999852776527
	CrossEntropyLoss-validate: 5.054237365722656
Step: Check Loss On Init, Model: ResNet18, Passed: True. Results:
	MulticlassAccuracy-validate: 0.008419999852776527
	CrossEntropyLoss-validate: 5.054237365722656
