In [1]:
from dataset import CifarDataModule

from art.project import ArtProject
from art.checks import CheckResultExists, CheckScoreExists, CheckScoreLessThan, CheckScoreGreaterThan
from art.steps import EvaluateBaseline, OverfitOneBatch, Overfit#, TransferLearning
from torchmetrics import Accuracy, Precision, Recall
import torch.nn as nn
from lightning.pytorch.callbacks import EarlyStopping

from dataset import CifarDataModule
from checks import CheckClassImagesExist, CheckLenClassNamesEqualToNumClasses
from steps import DataAnalysis

In [2]:
%load_ext autoreload
%autoreload 2
from lightning import seed_everything
seed_everything(23)

Global seed set to 23


23

# Proper data analysis

In [3]:
# Create the "Cifar100" project on top of the CIFAR data module, then register
# the data-analysis step together with the checks its results must pass.
cifar_data = CifarDataModule(batch_size=32)
project = ArtProject("Cifar100", cifar_data)

data_analysis_step = DataAnalysis()
data_analysis_checks = [
    # Every key the data-analysis step is expected to produce.
    CheckResultExists("number_of_classes"),
    CheckResultExists("class_names"),
    CheckResultExists("number_of_examples_in_each_class"),
    CheckResultExists("img_dimensions"),
    # Project-specific consistency checks (defined in the local `checks` module).
    CheckClassImagesExist(),
    CheckLenClassNamesEqualToNumClasses(),
]
project.add_step(data_analysis_step, data_analysis_checks)
project.run_all()


Summary: 
Step: Data analysis, Model: , Passed: False. Results:



In [4]:
# `Accuracy` and `nn` are already imported in the first cell; re-importing
# them here only scatters dependencies across the notebook.

# Size the metrics from the data-analysis results. Step 0 is the DataAnalysis
# step registered above; its latest run holds "number_of_classes".
# NOTE(review): that step reports "Passed: False" above. If the key is missing,
# this lookup is expected to fail (presumably KeyError); confirm the step's results first.
data_analysis_run = project.get_step(0).get_latest_run()
NUM_CLASSES = data_analysis_run["number_of_classes"]

accuracy_metric = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
ce_loss = nn.CrossEntropyLoss()
project.register_metrics([accuracy_metric, ce_loss])

In [5]:
from art.metrics import SkippedMetric
from models.baselines import MlBaseline, HeuristicBaseline, AlreadyExistingSolutionBaseline

# Baselines to evaluate; each gets its own EvaluateBaseline step.
# HeuristicBaseline and AlreadyExistingSolutionBaseline are imported as well
# and can be appended to this list to compare them too.
# NOTE(review): re-running this cell registers the steps again.
BASELINES = [MlBaseline]
for baseline in BASELINES:
    project.add_step(
        step=EvaluateBaseline(baseline),
        # Check that an accuracy score exists for this baseline.
        checks=[CheckScoreExists(metric=accuracy_metric)],
        # Cross-entropy is registered as a skipped metric for baseline steps.
        skipped_metrics=[SkippedMetric(metric=ce_loss)],
    )


In [6]:
# Run the project's registered steps: data analysis, then the MlBaseline evaluation.
# NOTE(review): the summary below lists only the data-analysis step, with
# "Passed: False". Presumably later steps do not run after a failed step;
# confirm against ArtProject.run_all before trusting the baseline results.
project.run_all()

Summary: 
Step: Data analysis, Model: , Passed: False. Results:

