In [92]:
from importlib import reload

import daml_stage
import maite.protocols.image_classification as ic
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.models import ResNet50_Weights, resnet50
from utils import collect_metrics, collect_report_consumables, load_models_and_datasets, run_stages

reload(daml_stage)
from daml_stage import DamlStage

# Configure Pipeline Stages


### Panel Inputs


In [93]:
# Allowed values for each panel input. These tuples encode the option lists the
# panel offers, so the selections below can be validated before any pipeline
# work starts (a typo fails here instead of deep inside a stage).
MODEL_OPTIONS = ("CenterNet V2", "visdrone-yolo")
SPLIT_OPTIONS = ("dev_train", "dev_val", "dev_test", "op_train", "op_val", "op_test")
METRIC_OPTIONS = ("Accuracy", "mAP")
SCOPE_OPTIONS = ("Base", "Target", "Both")

# NOTE(review): model_str, base_dataset_split, target_dataset_split and metric are
# not referenced by the cells in this notebook section -- the backend below
# hard-codes a ResNet-50 on MNIST. Confirm whether they should be wired in.

# One of MODEL_OPTIONS
model_str = "CenterNet V2"

# One of SPLIT_OPTIONS
base_dataset_split = "dev_train"

# One of SPLIT_OPTIONS
target_dataset_split = "dev_val"

# One of METRIC_OPTIONS
metric = "Accuracy"

# Float; passed to load_models_and_datasets(target_performance=...)
performance = 0.92

# Each one of SCOPE_OPTIONS; passed to DamlStage as the *_opt arguments.
linting = "Both"
bias_detection = "Both"
feasibility = "Both"
sufficiency = "Both"


def check_choice(name, value, allowed):
    """Raise ``ValueError`` if panel input ``name`` is not one of ``allowed``.

    Args:
        name: Name of the panel variable; used only in the error message.
        value: The selected value.
        allowed: Sequence of permitted values.

    Raises:
        ValueError: If ``value`` is not in ``allowed``.
    """
    if value not in allowed:
        raise ValueError(f"{name}={value!r} is not one of {list(allowed)}")


check_choice("model_str", model_str, MODEL_OPTIONS)
check_choice("base_dataset_split", base_dataset_split, SPLIT_OPTIONS)
check_choice("target_dataset_split", target_dataset_split, SPLIT_OPTIONS)
check_choice("metric", metric, METRIC_OPTIONS)
check_choice("linting", linting, SCOPE_OPTIONS)
check_choice("bias_detection", bias_detection, SCOPE_OPTIONS)
check_choice("feasibility", feasibility, SCOPE_OPTIONS)
check_choice("sufficiency", sufficiency, SCOPE_OPTIONS)
# bool is a subclass of int, so reject it explicitly.
if isinstance(performance, bool) or not isinstance(performance, (int, float)):
    raise TypeError(f"performance must be a number, got {type(performance).__name__}")

### Backend Script


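Load the model under evaluation (ResNet-50 with torchvision's `DEFAULT` weights) and a comparison model (ResNet-50 with `IMAGENET1K_V1` weights).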


In [94]:
# Model under evaluation: ResNet-50 with torchvision's DEFAULT pretrained weights.
# `preprocess` holds the matching inference transforms; the dataset cell reuses it.
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# Pass `weights` by keyword: torchvision declares it keyword-only and accepts it
# positionally only through a deprecation shim.
model = resnet50(weights=weights)

# Comparison model: same architecture, IMAGENET1K_V1 weights.
comparison_weights = ResNet50_Weights.IMAGENET1K_V1
comparison_model = resnet50(weights=comparison_weights)

# NOTE(review): both models are left in nn.Module's default (training) mode;
# confirm the pipeline switches them to eval() before running inference.

# Only a cell's last expression is displayed, so check MAITE compliance of both
# models in a single expression rather than discarding the first result.
all(isinstance(m, ic.Model) for m in (model, comparison_model))

True

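Load the datasets. MNIST's training split serves as the development dataset and its test split serves as the operational dataset. Images are replicated to 3 channels and passed through the model's preprocessing transforms.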


In [95]:
class MaiteMNIST(Dataset):
    """MNIST wrapper whose items are ``(image, label, metadata)`` triples.

    The wrapped torchvision ``MNIST`` dataset yields ``(image, label)`` pairs;
    this class appends an empty metadata dict as the third element.

    Args:
        train: Load MNIST's training split when True, its test split otherwise.
        transforms: Optional callable forwarded to ``MNIST`` as its ``transform``
            argument (applied to each image).
        root: Directory where MNIST is stored / downloaded to. Defaults to the
            path this class previously hard-coded.
    """

    def __init__(self, train=True, transforms=None, root="../data/"):
        # NOTE: the `transforms` parameter shadows the torchvision `transforms`
        # module inside this method; the name is kept for caller compatibility.
        self.dataset = MNIST(root, train=train, transform=transforms, download=True)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Each underlying item is a single (image, label) pair.
        img, label = self.dataset[idx]
        # NOTE(review): label is passed through unchanged from the wrapped dataset;
        # confirm downstream consumers don't expect another target encoding, and
        # whether the metadata dict should carry fields such as a datum id.
        return img, label, {}


# MNIST images are single-channel: replicate them to 3 channels first, then
# apply the model's preprocessing transforms (`preprocess`, from the model cell).
t = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        preprocess,
    ]
)

# Development data = MNIST training split; operational data = MNIST test split.
dev_dataset = MaiteMNIST(train=True, transforms=t)
op_dataset = MaiteMNIST(train=False, transforms=t)

# NOTE(review): if ic.Dataset is a runtime-checkable Protocol, isinstance() only
# verifies that the required methods exist, not what they return.
print(f"Dev Dataset is MAITE compliant: {isinstance(dev_dataset, ic.Dataset)}")
print(f"Op Dataset is MAITE compliant: {isinstance(op_dataset, ic.Dataset)}")

Dev Dataset is MAITE compliant: True
Op Dataset is MAITE compliant: True


Create the DAML stage, configured from the linting, bias-detection, feasibility and sufficiency panel inputs.


In [96]:
# A single DAML stage; each *_opt argument is a panel input from above.
ds = DamlStage(
    linting_opt=linting,
    bias_opt=bias_detection,
    feasibility_opt=feasibility,
    sufficiency_opt=sufficiency,
)
# The pipeline helpers below take a list of stages.
stages = [ds]

# Pipeline


In [97]:
# Hand the models, datasets and target performance to every stage.
load_models_and_datasets(
    stages=stages,
    # Models: the one under evaluation plus the comparison model.
    model=model,
    comparison_model=comparison_model,
    # Datasets: MNIST train split (dev) and test split (op).
    dev_dataset=dev_dataset,
    op_dataset=op_dataset,
    # Panel input `performance`.
    target_performance=performance,
)

In [98]:
# Run the configured stages. The recorded "Cache miss" output presumably means
# results were computed rather than loaded from a cache -- confirm in utils.run_stages.
run_stages(stages=stages)

Cache miss


In [99]:
# Gather the metrics produced by the stages; in the recorded run the result is
# a dict keyed by dataset split name.
collect_metrics(stages=stages)

Returning metrics


{'dev_train': {'ber': 0.27, 'ber_lower': 0.147005976119332, 'bias': 0.5},
 'op_val': {'ber': 0.33, 'ber_lower': 0.18375981682120068, 'bias': 0.5}}

In [100]:
# NOTE(review): in the recorded run this returns exactly the same dict as
# collect_metrics(stages=stages) -- confirm whether that is intended.
collect_report_consumables(stages=stages)

Returning Gradient parameters


{'dev_train': {'ber': 0.27, 'ber_lower': 0.147005976119332, 'bias': 0.5},
 'op_val': {'ber': 0.33, 'ber_lower': 0.18375981682120068, 'bias': 0.5}}