In [None]:
import csv
import itertools
import pathlib
import shutil
from math import ceil, sqrt
from pathlib import Path
from pprint import pprint
from typing import Callable, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import rootutils
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from skimage.metrics import structural_similarity
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchmetrics import (
    Accuracy,
    ConfusionMatrix,
    F1Score,
    MaxMetric,
    MeanMetric,
    Precision,
    Recall,
)
from torchvision.io import read_image
from tqdm.notebook import tqdm

# from IPython.display import set_matplotlib_formats
# set_matplotlib_formats('png')

root = rootutils.setup_root(search_from=".", indicator=".project-root", pythonpath=True)

from src.data.components.dataset import (
    FetalBrainPlanesDataset,
    USVideosDataset,
    USVideosFrameDataset,
    VideoQualityDataset,
)
from src.data.components.transforms import (
    Affine,
    HorizontalFlip,
    LabelEncoder,
    RandomPercentCrop,
    VerticalFlip,
)
from src.data.utils.utils import show_numpy_images, show_pytorch_images
from src.models.fetal_module import FetalLitModule
from src.models.quality_module import QualityLitModule

# Folder that holds every dataset referenced by the cells below.
data_dir = root / "data"
# Bare expression: the notebook echoes the resolved project root as the cell output.
root

In [None]:
# Training run whose checkpoint is analysed in the following cells.
# Exactly one `model_log_dir` assignment should be active; the commented lines
# are alternative runs (the trailing comment is the run name).
# model_log_dir = root / "logs" / "train" / "multiruns" / "2023-11-07_21-10-40" / "0"  # frosty-forest-2691


model_log_dir = root / "logs" / "train" / "runs" / "2025-05-04_22-41-24"  # civilized-droid-3054
# model_log_dir = root / "logs" / "train" / "runs" / "2025-05-05_09-31-12"  # carbonite-ewok-3069
# model_log_dir = root / "logs" / "train" / "runs" / "2025-05-04_14-12-45"  # tusken-transport-3040

# model_log_dir = root / "logs" / "train" / "runs" / "2025-05-05_02-41-36"  # tusken-ewok-3060 best test/acc

# model_log_dir = root / "logs" / "train" / "runs" / "2024-06-26_07-45-41"  # neat-aardvark-2941
# model_log_dir = root / "logs" / "train" / "runs" / "2024-06-22_03-38-18"  # fine-lion-2828
# model_log_dir = root / "logs" / "train" / "runs" / "2024-06-21_13-37-18"  # fresh-grass-2813
# model_log_dir = root / "logs" / "train" / "runs" / "2024-06-23_18-16-01"  # prime-butterfly-2873
# model_log_dir = root / "logs" / "train" / "runs" / "2024-06-24_12-02-44"  # lilac-frost-2893

# model_log_dir = root / "logs" / "train" / "runs" / "2024-06-22_21-48-38"  # glowing-sea-2849
# model_log_dir = root / "logs" / "train" / "runs" / "2024-06-21_19-18-09"  # swift-butterfly-2819
# model_log_dir = root / "logs" / "train" / "runs" / "2024-06-20_07-04-24"  # sunny-flower-2780

# model_log_dir = root / "logs" / "train" / "runs" / ""  #

# Take the last `epoch_*.ckpt` in lexicographic filename order.
# NOTE(review): this is the latest epoch only if the epoch number in the file
# name is zero-padded - confirm. Raises IndexError if the run has no such file.
checkpoint = sorted(model_log_dir.glob("checkpoints/epoch_*.ckpt"))[-1]
model = FetalLitModule.load_from_checkpoint(str(checkpoint))
# disable randomness, dropout, etc...
model.eval()

# Bare expression: displayed as the cell output to confirm which network was loaded.
model.hparams.net_spec.name

# Plot Videos Probabilities

In [None]:
def init_counts():
    """Create the zeroed tally table ``{min_probability: {label: 0}}``.

    Reads the notebook-level ``min_probabilities`` and ``label_names``. Every
    threshold gets its own independent per-label dictionary.
    """
    return {
        threshold: {name: 0 for name in label_names}
        for threshold in min_probabilities
    }


def label_videos():
    """Label the frames of every entry found in ``video_dataset_dir / "images"``.

    Each entry is handed to ``label_video``; progress is reported through tqdm.
    """
    video_dirs = [entry for entry in (video_dataset_dir / "images").iterdir()]
    for video_dir in tqdm(video_dirs, desc="Label videos"):
        label_video(video_dir)


def label_video(frames_path: Path):
    """Label all frames of one video directory in chunks of ``batch_size``.

    Args:
        frames_path: Directory whose entries are the frame files of one video.
    """
    all_frames = list(frames_path.iterdir())
    num_batches = ceil(len(all_frames) / batch_size)
    for start in (k * batch_size for k in range(num_batches)):
        # The final chunk may be shorter than batch_size.
        label_frames(all_frames[start : start + batch_size])


# def label_frames(frames):
#     with torch.no_grad():
#         frames = get_frames_tensor(frames)
#         frames = frames.to(model.device)

#         results = [model(t(frames))[1] for t in transforms]
#         logits = torch.mean(torch.stack(results, dim=1), dim=1)

#         y_hats = F.softmax(logits, dim=1)
#         preds = torch.argmax(logits, dim=1)
#         count_labels(y_hats, preds)
#         sample_classes(frames, y_hats, preds)


def label_frames(frames):
    """Classify one batch of frame files and record the outcome.

    The frames are loaded with ``get_frames_tensor`` and moved to the model's
    device. For every entry of the notebook-level ``transforms`` the model is
    run on the transformed batch and its second output is softmaxed over dim 1;
    the resulting distributions are averaged across transforms. The averaged
    probabilities and their argmax are passed to ``count_labels`` and
    ``sample_classes``.

    Args:
        frames: Paths of the frame image files forming this batch.
    """
    with torch.no_grad():
        batch = get_frames_tensor(frames).to(model.device)

        # Stack the per-transform distributions on a new axis 1, then average it away.
        per_transform = torch.stack(
            [F.softmax(model(transform(batch))[1], dim=1) for transform in transforms],
            dim=1,
        )
        probabilities = per_transform.mean(dim=1)

        predictions = probabilities.argmax(dim=1)
        count_labels(probabilities, predictions)
        sample_classes(batch, probabilities, predictions)


# def label_frames(frames):
#     with torch.no_grad():
#         frames = get_frames_tensor(frames)

#         y_hats_models = []
#         for model in models:
#             frames = frames.to(model.device)
#             results = [model(t(frames))[1] for t in transforms]
#             logits = torch.mean(torch.stack(results, dim=1), dim=1)

#             y_hats = F.softmax(logits, dim=1)
#             y_hats_models.append(y_hats)

#         y_hats = torch.mean(torch.stack(y_hats_models, dim=1), dim=1)
#         preds = torch.argmax(y_hats, dim=1)
#         count_labels(y_hats, preds)
#         sample_classes(frames, y_hats, preds)


def get_frames_tensor(frame_paths):
    """Load image files from disk and stack them into one tensor batch.

    Each file is read with OpenCV, wrapped in a PIL image and converted with
    ``TF.to_tensor``; the per-frame tensors are stacked along a new dim 0.
    All frames must therefore convert to tensors of the same shape.

    Args:
        frame_paths: Iterable of paths to image files.

    Returns:
        A tensor with one entry per path along dim 0.

    Raises:
        FileNotFoundError: If OpenCV cannot read one of the files.
    """
    frames = []
    for frame_path in frame_paths:
        frame = cv2.imread(str(frame_path))
        if frame is None:
            # cv2.imread signals a missing/unreadable/non-image file by returning
            # None instead of raising; fail here with the offending path rather
            # than with an opaque error inside PIL.
            raise FileNotFoundError(f"Could not read image file: {frame_path}")
        # NOTE(review): cv2.imread yields BGR channel order while PIL assumes RGB.
        # Harmless for grey-valued frames; confirm it matches the training
        # pipeline before changing it.
        frame = PIL.Image.fromarray(frame)
        frame = TF.to_tensor(frame)
        frames.append(frame)
    return torch.stack(frames, dim=0)


def count_labels(y_hats, preds):
    """Tally predictions per confidence threshold in the global ``counts``.

    For every frame, the probability of its predicted class is compared with
    each value in ``min_probabilities``; every threshold it strictly exceeds
    gets its counter for the predicted label incremented.

    Args:
        y_hats: Per-frame class probabilities, indexable by class index.
        preds: Per-frame predicted class indices.
    """
    for probabilities, predicted in zip(y_hats, preds):
        passed = (t for t in min_probabilities if probabilities[predicted] > t)
        for threshold in passed:
            counts[threshold][label_names[predicted]] += 1


def init_samples():
    """Create the empty sample store ``{label: {min_probability: []}}``.

    Reads the notebook-level ``label_names`` and ``min_probabilities``. Every
    (label, threshold) pair gets its own independent list.
    """
    return {
        name: {threshold: [] for threshold in min_probabilities}
        for name in label_names
    }


def sample_classes(frames, y_hats, preds):
    """File each frame into the confidence bucket of its predicted label.

    Buckets are the half-open intervals ``(lower, upper]`` between consecutive
    ``min_probabilities`` values, the last one closed at 1.0. The frame is
    moved to the CPU and appended to ``samples[label][lower]``.

    Args:
        frames: Per-frame image tensors.
        y_hats: Per-frame class probabilities.
        preds: Per-frame predicted class indices.
    """
    for frame, probabilities, predicted in zip(frames, y_hats, preds):
        bucket_bounds = zip(min_probabilities, min_probabilities[1:] + [1.0])
        for lower, upper in bucket_bounds:
            if lower < probabilities[predicted] <= upper:
                samples[label_names[predicted]][lower].append(frame.to("cpu"))


# --- Configuration and run for the video-labelling functions defined above ---
# These names are module-level globals read by those functions.
video_dataset_dir = data_dir / "US_VIDEOS_ssim_0.7"
label_names = FetalBrainPlanesDataset.labels
# Confidence thresholds; also the lower bounds of the sampling buckets.
min_probabilities = [0.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.96, 0.97, 0.98, 0.99, 0.995]
# Fresh accumulators: re-running this cell resets both tables before labelling.
counts = init_counts()
samples = init_samples()

# Augmentation grid for this pass: a single, identity-like combination
# (no flips, 0 degrees, no translation, scale 1.0).
horizontal_flips = [False]
vertical_flips = [False]
rotate_degrees = [0]
translates = [(0.0, 0.0)]
scales = [1.0]

# One preprocessing pipeline per combination of the grid above.
transforms = [
    torch.nn.Sequential(
        T.Grayscale(),
        T.Resize((165, 240), antialias=False),
        HorizontalFlip(flip=horizontal_flip),
        VerticalFlip(flip=vertical_flip),
        Affine(degrees=rotate_degree, translate=translate, scale=scale),
        T.ConvertImageDtype(torch.float32),
    )
    for horizontal_flip, vertical_flip, rotate_degree, translate, scale in itertools.product(
        horizontal_flips,
        vertical_flips,
        rotate_degrees,
        translates,
        scales,
    )
]
# Number of frames sent through the model at once.
batch_size = 32

# Optional device selection; frames are moved to `model.device` in label_frames.
# model.to("cpu")
# model.to("cuda")

# Runs inference over the whole video dataset and fills `counts` and `samples`.
label_videos()

In [None]:
def plot_videos_probabilities(counts):
    """Draw per-label frame counts, one bar series per confidence threshold.

    All series share the same x positions (the labels), so the bars are drawn
    on top of each other in the iteration order of ``counts``.

    Args:
        counts: Mapping ``{min_probability: {label: count}}``.

    Returns:
        The created matplotlib Figure, so the caller can save or further
        customise it explicitly instead of relying on pyplot's current figure.
    """
    with plt.style.context("seaborn-v0_8-muted"):
        fig, ax = plt.subplots(figsize=(15, 8))

    for min_prob, count in counts.items():
        ax.bar(list(count.keys()), list(count.values()), label=str(min_prob))

    ax.legend()
    ax.set_title("Probabilities on video dataset")
    return fig


plot_videos_probabilities(counts)
# Save BEFORE plt.show(): once the figure has been shown (and, with the inline
# notebook backend, closed) plt.savefig operates on a new, empty current figure
# and would write a blank PDF.
plt.savefig(str(model_log_dir / "plot_videos_probabilities.pdf"), bbox_inches="tight")
plt.show()

In [None]:
def show_samples(label, probability, images):
    """Show a random grid of stored frames for one label and confidence bucket.

    Frames are drawn (with replacement) from the global
    ``samples[label][probability]`` and laid out on a square grid just large
    enough to hold them; unused grid cells stay empty.

    Args:
        label: Label name, a key of ``samples``.
        probability: Lower bound of the bucket, a key of ``samples[label]``.
        images: Number of frames to draw.

    Returns:
        The matplotlib Figure containing the grid.

    Raises:
        KeyError: If ``label`` or ``probability`` is not a key of ``samples``.
        ValueError: If the selected bucket holds no frames.
    """
    bucket = samples[label][probability]
    if len(bucket) == 0:
        # torch.randint rejects an empty range with an opaque error; name the bucket instead.
        raise ValueError(f"No samples collected for label {label!r} at probability {probability}")

    idxs = torch.randint(0, len(bucket), (images,))
    chosen = [bucket[idx] for idx in idxs]

    n = ceil(sqrt(len(chosen)))
    figsize = 16
    scale = 165 / 230
    fig, axes = plt.subplots(ncols=n, nrows=n, squeeze=False, figsize=(figsize, figsize * scale))

    # axes.flat walks the grid row by row, i.e. position == row * n + column.
    for position, ax in enumerate(axes.flat):
        if position < len(chosen):
            img = TF.to_grayscale(TF.to_pil_image(chosen[position].detach()))
            ax.imshow(np.asarray(img), cmap="gray")
        # Hide ticks on every cell, including the unused ones.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    fig.suptitle(f"Probability {probability}", fontsize=16)
    fig.tight_layout(h_pad=0.1, w_pad=0.1)

    return fig


# Valid bucket keys (copied from the labelling cell for reference):
# min_probabilities = [0.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.96, 0.97, 0.98, 0.99, 0.995]
# Draws 9 random "Other" frames from the bucket whose lower bound is 0.5.
# The selection is unseeded, so it changes on every run.
show_samples("Other", 0.5, 9)
plt.show()

# Test-Time Augmentation

In [None]:
label_names = FetalBrainPlanesDataset.labels

# Test split used by `evaluate()` below. Only grayscale + resize are applied
# here; the dtype conversion is part of each TTA pipeline in `transforms`.
test = FetalBrainPlanesDataset(
    data_dir=data_dir,
    subset="test",
    transform=torch.nn.Sequential(
        T.Grayscale(),
        T.Resize((165, 240), antialias=False),
    ),
    target_transform=LabelEncoder(labels=label_names),
)

In [None]:
def evaluate():
    """Run every transform in ``transforms`` over the whole ``test`` dataset.

    The dataset is read by index in chunks of ``batch_size``. Each chunk is
    moved to the model's device and pushed through ``evaluate_x`` once per
    transform; the per-transform outputs are stacked on a new axis 1.

    Returns:
        ``(logits, targets)``: all chunk outputs concatenated along dim 0 and
        the matching targets concatenated along dim 0.
    """
    num_batches = ceil(len(test) / batch_size)

    logits_batches, y_batches = [], []

    for batch_idx in tqdm(range(num_batches), desc="Evaluate"):
        first = batch_idx * batch_size
        last = min(first + batch_size, len(test))  # final chunk may be short

        xs, ys = [], []
        for sample_idx in range(first, last):
            sample_x, sample_y = test[sample_idx]
            xs.append(sample_x)
            ys.append(sample_y)

        x = torch.stack(xs).to(model.device)
        per_transform = [evaluate_x(x, tran) for tran in tqdm(transforms, desc="Transforms", leave=False)]

        logits_batches.append(torch.stack(per_transform, dim=1))
        y_batches.append(torch.stack(ys))

    return torch.cat(logits_batches, dim=0), torch.cat(y_batches, dim=0)


def evaluate_x(x, transform):
    """Apply one transform to a batch and return the model's logits on the CPU.

    Args:
        x: Input batch, already on the model's device.
        transform: Callable applied to ``x`` before the forward pass.

    Returns:
        The second element of the model's two-element output, moved to the CPU.
    """
    with torch.no_grad():
        _, logits = model(transform(x))
        return logits.cpu()


# --- Test-time-augmentation grid: one active line per parameter, alternatives commented ---
# horizontal_flips = [False]
horizontal_flips = [False, True]

# vertical_flips = [False]
vertical_flips = [False, True]

# rotate_degrees = [0]
rotate_degrees = [0, -5, 5]
# rotate_degrees = [0, -5, -10, 5, 10]
# rotate_degrees = [0, -5, -10, -15, 5, 10, 15]

translates = [(0.0, 0.0)]
# translates = [(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)]
# translates = list(itertools.product([0.0, 0.1, -0.1], [0.0, 0.1, -0.1]))

# scales = [1.0]
scales = [1.0, 1.10]
# scales = [1.0, 1.05, 1.10]
# scales = [1.0, 1.05, 1.10, 1.15]
# scales = [1.0, 1.05, 1.10, 1.15, 1.20]
# scales = [1.0, 1.05, 1.10, 1.15, 1.20, 1.25]
# scales = [1.0, 1.05, 1.10, 1.15, 1.20, 1.25, 1.30]

# One dict per grid combination. `params[i]` describes `transforms[i]`, which in
# turn corresponds to column i (axis 1) of the tensor returned by evaluate().
# NOTE: the key "vertical_flips" is plural but holds a single bool; the scoring
# cells below build dicts with the same key, so it must stay in sync with them.
params = [
    {
        "horizontal_flip": horizontal_flip,
        "vertical_flips": vertical_flips,
        "rotate_degree": rotate_degree,
        "translate": translate,
        "scale": scale,
    }
    for horizontal_flip, vertical_flips, rotate_degree, translate, scale in itertools.product(
        horizontal_flips,
        vertical_flips,
        rotate_degrees,
        translates,
        scales,
    )
]

# One augmentation pipeline per entry of `params`, in the same order.
transforms = [
    torch.nn.Sequential(
        HorizontalFlip(flip=param["horizontal_flip"]),
        VerticalFlip(flip=param["vertical_flips"]),
        Affine(degrees=param["rotate_degree"], translate=param["translate"], scale=param["scale"]),
        T.ConvertImageDtype(torch.float32),
    )
    for param in params
]

batch_size = 32 * 6

# model.to("cpu")
model.to("cuda")

# NOTE: despite the name, `y_hats` holds what evaluate() returns - the raw
# per-transform logits; the metric helpers apply softmax themselves.
y_hats, target = evaluate()

print(y_hats.shape)
print(target.shape)
# ~20:50

In [None]:
def load_model(model_log_dir: Path, tta_transforms: Optional[dict] = None):
    """Load the last ``epoch_*.ckpt`` of a training run in evaluation mode.

    Args:
        model_log_dir: Run directory containing a ``checkpoints`` sub-folder.
        tta_transforms: Optional test-time-augmentation specification. When
            given, it is passed to ``FetalLitModule.create_transforms`` and the
            result is stored on the model as ``tta_transforms``.

    Returns:
        The loaded ``FetalLitModule`` with ``eval()`` applied.

    Raises:
        FileNotFoundError: If the run directory has no ``checkpoints/epoch_*.ckpt``.
    """
    # Lexicographic order: the last entry is the latest epoch only if the epoch
    # number in the file name is zero-padded.
    checkpoints = sorted(model_log_dir.glob("checkpoints/epoch_*.ckpt"))
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoints/epoch_*.ckpt found in {model_log_dir}")

    model = FetalLitModule.load_from_checkpoint(str(checkpoints[-1]))
    # disable randomness, dropout, etc...
    model.eval()

    if tta_transforms is not None:
        model.tta_transforms = FetalLitModule.create_transforms(tta_transforms)

    return model


# TTA specification handed to every ensemble member through load_model().
tta_transforms = {
    "horizontal_flips": [False, True],
    "vertical_flips": [False, True],
    "rotate_degrees": [0, -5, 5],
    "translates": [[0.0, 0.0]],
    "scales": [1.0, 1.10],
}

# Ensemble members: one loaded model per training run listed below.
models = [
    load_model(model_log_dir, tta_transforms)
    for model_log_dir in [
        # seed 5724
        root / "logs" / "train" / "runs" / "2025-05-04_22-41-24",  # civilized-droid-3054
        root / "logs" / "train" / "runs" / "2025-05-05_09-31-12",  # carbonite-ewok-3069
        root / "logs" / "train" / "runs" / "2025-05-04_14-12-45",  # tusken-transport-3040
        # seed 8910,
        root / "logs" / "train" / "runs" / "2025-05-19_00-12-17",  # dainty-totem-3234
        # seed 9759
        root / "logs" / "train" / "runs" / "2025-05-19_11-58-46",  # lunar-grass-3253
    ]
]

# Bare expression: shows the ensemble size as the cell output.
len(models)

In [None]:
def multimodel_evaluate():
    """Evaluate every model in ``models`` on ``test_data_loader``.

    For each batch, every model's ``forward_tta`` is called without gradient
    tracking and its second output is collected on the CPU; the per-model
    outputs are stacked on a new axis 1.

    Returns:
        ``(y_hats, targets)``: per-batch stacks concatenated along dim 0 and
        the matching targets concatenated along dim 0.
    """
    y_hat_batches = []
    y_batches = []

    for x, y in tqdm(test_data_loader, desc="Evaluate"):
        per_model = []
        for model in models:
            with torch.no_grad():
                # Models may sit on different devices; move the batch to each one.
                _, y_hat = model.forward_tta(x.to(model.device))

            per_model.append(y_hat.cpu())

        y_hat_batches.append(torch.stack(per_model, dim=1))
        y_batches.append(y)

    return torch.cat(y_hat_batches, dim=0), torch.cat(y_batches, dim=0)


# Redefines `test` for the ensemble evaluation: unlike the earlier definition,
# the dtype conversion is part of the dataset transform here.
# Relies on `label_names` from an earlier cell.
test = FetalBrainPlanesDataset(
    data_dir=data_dir,
    subset="test",
    transform=torch.nn.Sequential(
        T.Grayscale(),
        T.Resize((165, 240), antialias=False),
        T.ConvertImageDtype(torch.float32),
    ),
    target_transform=LabelEncoder(labels=label_names),
)

# shuffle=False keeps predictions and targets aligned with the dataset order.
test_data_loader = DataLoader(
    dataset=test,
    batch_size=32,
    num_workers=8,
    pin_memory=True,
    shuffle=False,
)

# Overwrites the `y_hats` / `target` globals; axis 1 now indexes ensemble members.
y_hats, target = multimodel_evaluate()
print(y_hats.shape)
print(target.shape)

# Timing notes from earlier runs (presumably keyed by num_workers - TODO confirm):
# 0 - 4:20
# 8 - 4min

In [None]:
def accuracy(y_hats, target):
    """Macro-averaged accuracy of the softmax-averaged ensemble prediction.

    Args:
        y_hats: Scores indexed as (sample, member, class); softmax is taken
            over dim 2 and the members (dim 1) are averaged.
        target: Ground-truth class indices, one per sample.

    Returns:
        The accuracy as a Python float, or 0.0 when there are no members.
    """
    has_members = y_hats.size(1) != 0
    if not has_members:
        return 0.0

    # Alternative (disabled): average the raw scores, i.e. y_hats.mean(dim=1).
    probabilities = F.softmax(y_hats, dim=2).mean(dim=1)
    predicted = probabilities.argmax(dim=1)

    metric = Accuracy(task="multiclass", num_classes=len(label_names), average="macro")
    return metric(predicted, target).item()


# First line: all columns of axis 1 combined; second line: column 0 alone.
# The trailing numbers are values recorded from earlier runs.
print(accuracy(y_hats, target))  # 0.8330667018890381, 0.834805965423584
print(accuracy(y_hats[:, [0]], target))
# print(accuracy(y_hats[:, [0,1,2]], target))
# print(accuracy(y_hats[:, [0,3,4]], target))

In [None]:
def precision(y_hats, target):
    """Macro-averaged precision of the ensemble prediction.

    NOTE(review): unlike ``accuracy``, ``recall`` and ``f1Score`` in this
    notebook, this function averages the raw scores across members (dim 1)
    without applying softmax first. With more than one member the predicted
    class can therefore differ from the one those functions use - confirm
    which averaging is intended.

    Args:
        y_hats: Scores indexed as (sample, member, class).
        target: Ground-truth class indices, one per sample.

    Returns:
        The precision as a Python float, or 0.0 when there are no members.
    """
    has_members = y_hats.size(1) != 0
    if not has_members:
        return 0.0

    # Alternative (disabled): F.softmax(y_hats, dim=2) before averaging.
    mean_scores = y_hats.mean(dim=1)
    predicted = mean_scores.argmax(dim=1)

    metric = Precision(task="multiclass", num_classes=len(label_names), average="macro")
    return metric(predicted, target).item()


# First line: all columns of axis 1 combined; second line: column 0 alone.
# The trailing numbers are values recorded from earlier runs.
print(precision(y_hats, target))  # 0.7358236312866211, 0.7428303956985474
print(precision(y_hats[:, [0]], target))
# print(precision(y_hats[:, [0,1,2]], target))
# print(precision(y_hats[:, [0,3,4]], target))

In [None]:
def recall(y_hats, target):
    """Macro-averaged recall of the softmax-averaged ensemble prediction.

    Args:
        y_hats: Scores indexed as (sample, member, class); softmax is taken
            over dim 2 and the members (dim 1) are averaged.
        target: Ground-truth class indices, one per sample.

    Returns:
        The recall as a Python float, or 0.0 when there are no members.
    """
    has_members = y_hats.size(1) != 0
    if not has_members:
        return 0.0

    # Alternative (disabled): average the raw scores, i.e. y_hats.mean(dim=1).
    probabilities = F.softmax(y_hats, dim=2).mean(dim=1)
    predicted = probabilities.argmax(dim=1)

    metric = Recall(task="multiclass", num_classes=len(label_names), average="macro")
    return metric(predicted, target).item()


# First line: all columns of axis 1 combined; second line: column 0 alone.
print(recall(y_hats, target))  #
print(recall(y_hats[:, [0]], target))
# print(recall(y_hats[:, [0,1,2]], target))
# print(recall(y_hats[:, [0,3,4]], target))

In [None]:
def f1Score(y_hats, target):
    """Macro-averaged F1 score of the softmax-averaged ensemble prediction.

    Args:
        y_hats: Scores indexed as (sample, member, class); softmax is taken
            over dim 2 and the members (dim 1) are averaged.
        target: Ground-truth class indices, one per sample.

    Returns:
        The F1 score as a Python float, or 0.0 when there are no members.
    """
    has_members = y_hats.size(1) != 0
    if not has_members:
        return 0.0

    # Alternative (disabled): average the raw scores, i.e. y_hats.mean(dim=1).
    probabilities = F.softmax(y_hats, dim=2).mean(dim=1)
    predicted = probabilities.argmax(dim=1)

    metric = F1Score(task="multiclass", num_classes=len(label_names), average="macro")
    return metric(predicted, target).item()


# First line: all columns of axis 1 combined; second line: column 0 alone.
# The trailing numbers are values recorded from earlier runs.
print(f1Score(y_hats, target))  # 0.7551986575126648, 0.766685962677002
print(f1Score(y_hats[:, [0]], target))
# print(f1Score(y_hats[:, [0,1,2]], target))
# print(f1Score(y_hats[:, [0,3,4]], target))

In [None]:
def all_stats(y_hats, target):
    """Per-class accuracy, precision, recall and F1 of the ensemble prediction.

    The prediction is the argmax of the member-averaged softmax (softmax over
    dim 2, mean over dim 1). Metrics use ``average="none"``, i.e. one value
    per class.

    Args:
        y_hats: Scores indexed as (sample, member, class).
        target: Ground-truth class indices, one per sample.

    Returns:
        Dict mapping metric name to a NumPy array of per-class values.
    """
    pred = F.softmax(y_hats, dim=2).mean(dim=1).argmax(dim=1)

    metric_classes = {
        "Accuracy": Accuracy,
        "Precision": Precision,
        "Recall": Recall,
        "F1-score": F1Score,
    }
    num_classes = len(label_names)

    stats = {}
    for title, metric_cls in metric_classes.items():
        metric = metric_cls(task="multiclass", num_classes=num_classes, average="none")
        stats[title] = metric(pred, target).cpu().numpy()
    return stats


# Per-class metric table (rows: class labels, columns: metrics), printed with
# three decimals; the float format only applies inside this context.
with pd.option_context("display.float_format", "{:0.3f}".format):
    print(pd.DataFrame(all_stats(y_hats, target), index=label_names))

In [None]:
import seaborn as sns


def confusion_matrix(y_hats, target):
    """Confusion matrices (raw and row-normalised) of the ensemble prediction.

    The prediction is the argmax of the member-averaged softmax (softmax over
    dim 2, mean over dim 1).

    Args:
        y_hats: Scores indexed as (sample, member, class).
        target: Ground-truth class indices, one per sample.

    Returns:
        ``(cm, cm_norm)`` as NumPy arrays: the count matrix and the matrix
        computed with ``normalize="true"``. When ``y_hats`` has no members
        both are all-zero ``(num_classes, num_classes)`` matrices, so callers
        can always unpack two arrays.
    """
    num_classes = len(label_names)

    if y_hats.size(1) == 0:
        # Keep the two-matrix contract instead of returning a bare scalar,
        # which would break `cm, cm_norm = confusion_matrix(...)`.
        empty = np.zeros((num_classes, num_classes))
        return empty.astype(np.int64), empty

    y_hat = torch.mean(F.softmax(y_hats, dim=2), dim=1)
    pred = torch.argmax(y_hat, dim=1)

    cm = ConfusionMatrix(task="multiclass", num_classes=num_classes)
    cm_norm = ConfusionMatrix(task="multiclass", num_classes=num_classes, normalize="true")

    return cm(pred, target).cpu().numpy(), cm_norm(pred, target).cpu().numpy()


cm, cm_norm = confusion_matrix(y_hats, target)

# Cell annotations of the form "<count>\n<row-normalised percentage>%"; only
# used by the disabled `annot=cm_txt` heatmap option.
cm_txt = np.array(
    [
        [f"{cell}\n{cell_norm * 100:.2f}%" for cell, cell_norm in zip(row, row_norm)]
        for row, row_norm in zip(cm, cm_norm)
    ]
)


fig = plt.figure(figsize=(6, 6))

# Using Seaborn heatmap to create the plot.
# Colours come from the row-normalised matrix, the annotations from the counts.
fx = sns.heatmap(
    cm_norm,
    # format 0-1 probability
    # annot=True,
    # fmt=".3f",
    # format - number of classes
    annot=cm,
    fmt=".0f",
    # format - number of classes + % probability
    # annot=cm_txt,
    # fmt="",
    square=True,
    cmap=sns.dark_palette("#69d", as_cmap=True),
    cbar=False,
    xticklabels=label_names,
    # yticklabels=label_names,
    # Hand-wrapped row labels; must stay in the same order as `label_names`.
    yticklabels=["Trans\nventricular", "Trans\nthalamic", "Trans\ncerebellum", "Other", "Not A Brain"],
    annot_kws={"size": 12},
)

# labels the title and x, y axis of plot
# fx.set_title("Plotting Confusion Matrix using Seaborn\n\n")
fx.set_xlabel("Predicted Labels", size=16)
fx.set_ylabel("True Labels", size=16)

fx.set_xticklabels(fx.get_xticklabels(), rotation=30, fontsize=12)
fx.set_yticklabels(fx.get_yticklabels(), rotation=0, fontsize=12)

plt.show()
# savefig does not create missing directories; make sure the (working-directory
# relative) output folder exists so a fresh checkout does not fail here.
Path("plots").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/confusion_matrix.pdf", bbox_inches="tight")

# Delete

In [None]:
# Scratch evaluation of another run (this section is headed "Delete").
# NOTE: overwrites the globals `model_log_dir`, `checkpoint`, `model`, `y_hats`
# and `target`, and relies on `evaluate`, `accuracy`, `transforms`, `test` and
# `batch_size` as left by the cells executed before it.
model_log_dir = root / "logs" / "train" / "multiruns" / "2023-11-10_08-36-16" / "0"  # graceful-plasma-2705

checkpoint = sorted(model_log_dir.glob("checkpoints/epoch_*.ckpt"))[-1]
model = FetalLitModule.load_from_checkpoint(str(checkpoint))
# disable randomness, dropout, etc...
model.eval()

model.to("cuda")

y_hats, target = evaluate()

# All transforms combined vs. the first transform alone.
print(accuracy(y_hats, target))
print(accuracy(y_hats[:, [0]], target))

# ----

In [None]:
# One row per entry of `params`, with the accuracy obtained from the matching
# single column of `y_hats` (axis 1) on its own.
individual_score = pd.DataFrame(params).assign(
    accuracy=[accuracy(y_hats[:, [column]], target) for column in range(y_hats.size(1))]
)

# Lift pandas' row/column truncation so the whole table is printed.
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    print(individual_score)

In [None]:
def gen(left, n):
    """Enumerate every length-``n`` sequence over the values in ``left``.

    The first position varies fastest and the last position slowest, e.g.
    ``gen(["-", "+"], 2)`` gives
    ``[["-", "-"], ["+", "-"], ["-", "+"], ["+", "+"]]``.

    Args:
        left: Re-iterable collection of values allowed at each position.
        n: Length of each generated sequence; must be at least 1.

    Returns:
        A list of ``len(left) ** n`` lists.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    rows = [[v] for v in left]
    # Extend every existing prefix by one more position, n - 1 times.
    for _ in range(n - 1):
        rows = [row + [v] for v in left for row in rows]
    return rows


def cumulative_accuracy(row):
    """Accuracy of the transforms selected by one "-"/"+" switch row.

    ``params[0]`` is the baseline combination. For every key marked ``"-"`` in
    ``row`` a transform must carry the baseline value to be selected; keys with
    any other marker do not restrict the selection. The selected columns of the
    global ``y_hats`` are scored with ``accuracy``.

    Args:
        row: Mapping from parameter name (the keys of ``params[0]``) to a marker.

    Returns:
        The accuracy of the selected transforms combined.
    """
    baseline = params[0]

    def is_selected(param):
        # A key only vetoes a transform when it is switched off ("-") and the
        # transform deviates from the baseline value for that key.
        return all([row[key] != "-" or param[key] == value for key, value in baseline.items()])

    idxs = [i for i, param in enumerate(params) if is_selected(param)]
    return accuracy(y_hats[:, idxs], target)


# One row per on/off combination of the five augmentation parameters
# ("-" = baseline value only, "+" = every value of that parameter).
cumulative_score = pd.DataFrame(gen(["-", "+"], 5), columns=params[0].keys())
cumulative_score["accuracy"] = [cumulative_accuracy(switches) for _, switches in cumulative_score.iterrows()]
cumulative_score

In [None]:
def aggregate_score(horizontal_flips, vertical_flips, rotate_degrees, translates, scales):
    """Score the ensemble of all evaluated transforms inside a parameter grid.

    The Cartesian product of the five value lists is matched by equality
    against the global ``params``; the matching columns of the global
    ``y_hats`` are scored together.

    NOTE: combinations that are not present in ``params`` contribute nothing,
    so a grid larger than the one that was evaluated is silently reduced to
    the part that was.

    Args:
        horizontal_flips: Horizontal-flip values to include.
        vertical_flips: Vertical-flip values to include.
        rotate_degrees: Rotation angles to include.
        translates: Translation pairs to include.
        scales: Scale factors to include.

    Returns:
        Tuple ``(accuracy, precision, recall, f1Score)`` for the selection.
    """
    keys = ("horizontal_flip", "vertical_flips", "rotate_degree", "translate", "scale")
    requested = [
        dict(zip(keys, combination))
        for combination in itertools.product(
            horizontal_flips,
            vertical_flips,
            rotate_degrees,
            translates,
            scales,
        )
    ]

    idxs = [i for i, param in enumerate(params) for wanted in requested if param == wanted]
    selected = y_hats[:, idxs]

    return (
        accuracy(selected, target),
        precision(selected, target),
        recall(selected, target),
        f1Score(selected, target),
    )


# Scores (accuracy, precision, recall, F1) for the active grid below; toggle the
# commented alternatives to score a sub-grid of what evaluate() produced.
aggregate_score(
    # horizontal_flips=[False],
    horizontal_flips=[False, True],
    # vertical_flips=[False],
    vertical_flips=[False, True],
    # rotate_degrees=[0],
    rotate_degrees=[0, -5, 5],
    # rotate_degrees=[0, -5, -10, 5, 10],
    # rotate_degrees=[0, -5, -10, -15, 5, 10, 15],
    translates=[(0.0, 0.0)],
    # translates=[(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)],
    # translates=list(itertools.product([0.0, 0.1, -0.1], [0.0, 0.1, -0.1])),
    # scales=[1.0],
    scales=[1.0, 1.10],
    # scales=[1.0, 1.05],
    # scales=[1.0, 1.05, 1.10],
    # scales=[1.0, 1.05, 1.10, 1.15],
    # scales=[1.0, 1.05, 1.10, 1.15, 1.20],
    # scales=[1.0, 1.05, 1.10, 1.15, 1.20, 1.25],
    # scales=[1.0, 1.05, 1.10, 1.15, 1.20, 1.25, 1.30],
)

In [None]:
# Grid: both flips, rotations 0/+-5, five translations, three scales.
# NOTE(review): aggregate_score only uses combinations present in the global
# `params`; values not evaluated beforehand are silently ignored.
aggregate_score(
    horizontal_flips=[False, True],
    vertical_flips=[False, True],
    rotate_degrees=[0, -5, 5],
    translates=[(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)],
    scales=[1.0, 1.05, 1.10],
)

In [None]:
# Grid: both flips, rotations 0/+-5, five translations, five scales.
# NOTE(review): aggregate_score only uses combinations present in the global
# `params`; values not evaluated beforehand are silently ignored.
aggregate_score(
    horizontal_flips=[False, True],
    vertical_flips=[False, True],
    rotate_degrees=[0, -5, 5],
    translates=[(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)],
    scales=[1.0, 1.05, 1.10, 1.15, 1.20],
)

In [None]:
# Grid: both flips, rotations 0/+-5/+-10, five translations, three scales.
# NOTE(review): aggregate_score only uses combinations present in the global
# `params`; values not evaluated beforehand are silently ignored.
aggregate_score(
    horizontal_flips=[False, True],
    vertical_flips=[False, True],
    rotate_degrees=[0, -5, -10, 5, 10],
    translates=[(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)],
    scales=[1.0, 1.05, 1.10],
)

In [None]:
# Grid: both flips, rotations 0/+-5/+-10, five translations, five scales.
# NOTE(review): aggregate_score only uses combinations present in the global
# `params`; values not evaluated beforehand are silently ignored.
aggregate_score(
    horizontal_flips=[False, True],
    vertical_flips=[False, True],
    rotate_degrees=[0, -5, -10, 5, 10],
    translates=[(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)],
    scales=[1.0, 1.05, 1.10, 1.15, 1.20],
)

In [None]:
def search_aggregate_score(horizontal_flips, vertical_flips, rotate_degrees, translates, scales):
    """Score every combination of candidate value lists.

    Each argument is a list of candidate value lists for one augmentation
    parameter. Every combination (one candidate per parameter) is passed to
    ``aggregate_score``.

    Returns:
        A list with one dict per combination, holding the ``"score"`` tuple
        returned by ``aggregate_score`` and the candidate lists that produced it.
    """
    results = []
    candidate_grid = itertools.product(
        horizontal_flips,
        vertical_flips,
        rotate_degrees,
        translates,
        scales,
    )
    for h_flips, v_flips, degrees, shifts, zooms in candidate_grid:
        results.append(
            {
                "score": aggregate_score(h_flips, v_flips, degrees, shifts, zooms),
                "horizontal_flip": h_flips,
                "vertical_flips": v_flips,
                "rotate_degree": degrees,
                "translate": shifts,
                "scale": zooms,
            }
        )
    return results


# Each argument lists the candidate value sets to try for one parameter.
scores = search_aggregate_score(
    horizontal_flips=[[False], [False, True]],
    vertical_flips=[[False], [False, True]],
    rotate_degrees=[[0], [0, -5, 5], [0, -5, -10, 5, 10]],  # [0, -5, -10, -15, 5, 10, 15]
    translates=[[(0.0, 0.0)], [(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)]],
    scales=[[1.0], [1.0, 1.05, 1.10], [1.0, 1.05, 1.10, 1.15, 1.20]],
)

print("Best score")
# Index 3 of the score tuple is the F1 score (see aggregate_score's return order).
pprint(max(scores, key=lambda x: x["score"][3]))
print("")
# sorted(scores, key=lambda x: x["score"][0], reverse=True)

In [None]:
def greed_search_aggregate_score(horizontal_flips, vertical_flips, rotate_degrees, translates, scales):
    """Find the best-scoring sub-grid among all non-empty value subsets.

    Every parameter's value list is expanded with ``perm`` into all of its
    non-empty combinations, and every cross-combination is scored with
    ``aggregate_score``. Despite the name the search is exhaustive, so its
    cost grows exponentially with the length of each value list.

    Returns:
        The dict (score tuple plus the chosen value subsets) whose first score
        component, the accuracy, is highest; the earliest one on ties.
    """
    subset_grid = itertools.product(
        perm(horizontal_flips),
        perm(vertical_flips),
        perm(rotate_degrees),
        perm(translates),
        perm(scales),
    )

    candidates = []
    for h_flips, v_flips, degrees, shifts, zooms in subset_grid:
        candidates.append(
            {
                "score": aggregate_score(h_flips, v_flips, degrees, shifts, zooms),
                "horizontal_flip": h_flips,
                "vertical_flips": v_flips,
                "rotate_degree": degrees,
                "translate": shifts,
                "scale": zooms,
            }
        )

    def by_accuracy(candidate):
        return candidate["score"][0]

    return max(candidates, key=by_accuracy)


def perm(values):
    """Return every non-empty combination of ``values``, shortest first.

    Despite the name these are combinations (order-preserving subsets), not
    permutations: ``perm([1, 2])`` gives ``[(1,), (2,), (1, 2)]``.
    """
    sizes = range(1, len(values) + 1)
    by_size = (itertools.combinations(values, size) for size in sizes)
    return list(itertools.chain.from_iterable(by_size))


# Exhaustive subset search over the active value lists below.
# NOTE: the number of aggregate_score calls is the product of (2**len - 1) over
# the five lists, so long lists make this cell very slow.
greed_search_aggregate_score(
    # horizontal_flips=[False],
    horizontal_flips=[False, True],
    # vertical_flips=[False],
    vertical_flips=[False, True],
    # rotate_degrees=[0],
    # rotate_degrees=[0, -5, 5],
    rotate_degrees=[0, -5, -10, 5, 10],
    # rotate_degrees=[0, -5, -10, -15, 5, 10, 15],
    # translates=[(0.0, 0.0)],
    translates=[(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)],
    # translates=list(itertools.product([0.0, 0.1, -0.1], [0.0, 0.1, -0.1])),
    # scales=[1.0],
    # scales=[1.0, 1.05],
    # scales=[1.0, 1.05, 1.10],
    # scales=[1.0, 1.05, 1.10, 1.15],
    scales=[1.0, 1.05, 1.10, 1.15, 1.20],
    # scales=[1.0, 1.05, 1.10, 1.15, 1.20, 1.25],
    # scales=[1.0, 1.05, 1.10, 1.15, 1.20, 1.25, 1.30],
)

# Summary

In [None]:
# Results log. Each group below is one augmentation grid followed by recorded
# scores, presumably one line per run in the order listed here - TODO confirm;
# the meaning of the two columns is not stated.
# NOTE: executing this cell repeatedly reassigns the grid globals; only the
# last group's values remain afterwards.
# neat-aardvark-2941
# fine-lion-2828
# fresh-grass-2813
# prime-butterfly-2873
# lilac-frost-2893

horizontal_flips = [False]
vertical_flips = [False]
rotate_degrees = [0]
translates = [(0.0, 0.0)]
scales = [1.0]
# 0.7860984802246094 0.7706203460693359
# 0.7940822243690491 0.7465180754661560
# 0.7963624000549316 0.7576579451560974
# 0.8050662279129028 0.7383250594139099
# 0.7890244722366333 0.7428253293037415

horizontal_flips = [False, True]
vertical_flips = [False, True]
rotate_degrees = [0, -5, 5]
translates = [(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)]
scales = [1.0, 1.05, 1.10]
# 0.7951812148094177 0.7918230891227720
# 0.8120790719985962 0.7782592773437500
# 0.7909176945686340 0.7713555693626404
# 0.7989473342895508 0.7635297179222107
# 0.7967666387557983 0.7601540088653564

horizontal_flips = [False, True]
vertical_flips = [False, True]
rotate_degrees = [0, -5, 5]
translates = [(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)]
scales = [1.0, 1.05, 1.10, 1.15, 1.20]
# 0.7991837859153748 0.7918857336044310
# 0.8086183667182922 0.7752655148506165
# 0.7911577224731445 0.7767201662063599
# 0.8024842739105225 0.7718890905380249
# 0.8025256395339966 0.7683259844779968

horizontal_flips = [False, True]
vertical_flips = [False, True]
rotate_degrees = [0, -5, -10, 5, 10]
translates = [(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)]
scales = [1.0, 1.05, 1.10]
# 0.7910797595977783 0.7876892685890198
# 0.8088076710700989 0.7740747332572937
# 0.7852406501770020 0.7671951055526733
# 0.8035995960235596 0.7657317519187927
# 0.7966942787170410 0.7604588270187378

horizontal_flips = [False, True]
vertical_flips = [False, True]
rotate_degrees = [0, -5, -10, 5, 10]
translates = [(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)]
scales = [1.0, 1.05, 1.10, 1.15, 1.20]
# 0.7969303131103516 0.7968350648880005
# 0.8065788745880127 0.7735700011253350
# 0.7919420599937439 0.7783745527267456
# 0.8007959127426147 0.7693001031875610
# 0.8012878894805908 0.7667553424835205

In [None]:
# Results log for a single older run; the numbers are recorded scores
# (presumably accuracy - TODO confirm) for the grid and scale variants noted.
# playful-haze-2111
# mix_ra
# model_log_dir = root / "logs" / "train" / "multiruns" / "2023-04-22_14-32-35" / "4"

horizontal_flips = [False, True]
vertical_flips = [False]
rotate_degrees = [0, -5, -10, 5, 10]
translates = [(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)]
scales = [1.0, 1.05, 1.10, 1.15, 1.20]

# 0.7875436544418335
# 0.8037662506103516
# 0.821164608001709


# scales=[1.0, 1.05, 1.10]
# 0.8078333735466003

# scales=[1.0, 1.05, 1.10, 1.15]
# 0.8159346580505371

# scales=[1.0, 1.05, 1.10, 1.15, 1.20]
# 0.821164608001709

# scales=[1.0, 1.05, 1.10, 1.15, 1.20, 1.25]
#

# scales=[1.0, 1.05, 1.10, 1.15, 1.20, 1.25, 1.30]
# 0.8252905011177063