In [None]:
from math import ceil, sqrt
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rootutils
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data import ConcatDataset
from tqdm import tqdm

# Locate the repository root (the directory containing a `.project-root` marker, searching from the
# current directory — per rootutils, upward through its parents) and, because pythonpath=True, put it
# on sys.path. The `src.*` imports below depend on this, so they must stay after this call.
root = rootutils.setup_root(search_from=".", indicator=".project-root", pythonpath=True)

from src.data.components.dataset import FetalBrainPlanesDataset
from src.data.components.transforms import RandomPercentCrop
from src.data.utils.utils import show_pytorch_images

# Data directory inside the repository, passed as data_dir to every dataset below.
database_dir = root / "data"

In [None]:
# Only the train=True split is analysed below. To include the other split as well, add
# FetalBrainPlanesDataset(data_dir=database_dir, train=False) to this list.
dataset = ConcatDataset([FetalBrainPlanesDataset(data_dir=database_dir, train=True)])

In [None]:
# Show 25 randomly chosen samples. The global numpy RNG is unseeded, so the pick changes on every run.
show_pytorch_images(
    [dataset[idx] for idx in np.random.permutation(len(dataset))[:25]],
    tick_labels=True,
)

In [None]:
# Materialise every dataset item into a frame: column 0 holds the image tensor, column 1 the item's
# second element (presumably the label). Every image is loaded here, so this cell is slow.
df = pd.DataFrame(list(tqdm(dataset)))

In [None]:
# One row per image, with its first three tensor dimensions as columns 0, 1 and 2
# (assumed to be channels, height, width — TODO confirm the dataset yields CHW tensors).
df_shape = pd.DataFrame(data=[tuple(img.shape[dim] for dim in range(3)) for img in df[0]])
df_shape

In [None]:
# Number of images per size along dimension 0 (presumably the channel count).
df_shape.loc[:, 0].value_counts()

In [None]:
mean = df_shape.mean()
median = df_shape.median()

# Histograms of df_shape columns 1 and 2 (treated as height and width by print_resolution below),
# with the mean (red, dashed) and the median (blue, dashed) marked on each.
axs = df_shape.hist(column=[1, 2], bins=100, figsize=(20, 5))

# Values recorded from a previous run:
#   column 1: mean 572.415, median 661.0
#   column 2: mean 857.255, median 959.0
# A median of integer sizes can only be a whole or half number, so 857.255 must be a mean.
# 959 / 661 is also the aspect ratio that reproduces print_resolution's recorded outputs.
for ax, col in zip(axs[0], [1, 2]):
    # Attach the labels to the lines themselves so the legend cannot pair a text with another artist.
    ax.axvline(mean[col], color="r", linestyle="dashed", linewidth=2, label=f"mean {mean[col]}")
    ax.axvline(median[col], color="b", linestyle="dashed", linewidth=2, label=f"median {median[col]}")
    ax.legend(loc="upper left")

In [None]:
# Width-to-height ratio of the median image (df_shape column 2 over column 1).
scale = median[2] / median[1]


def print_resolution(width):
    """Print "<height> / <width>", where height keeps the median aspect ratio for the given width."""
    matching_height = width / scale
    print(f"{matching_height:.2f} / {width}")


# Trailing comments give the rounded height / width pairs.
print_resolution(width=80)  # 55 / 80
print_resolution(width=100)  # 70 / 100
print_resolution(width=150)  # 100 / 150
print_resolution(width=240)  # 165 / 240

print_resolution(width=300)  # 205 / 300
print_resolution(width=400)  # 275 / 400
print_resolution(width=500)  # 345 / 500
print_resolution(width=600)  # 415 / 600

In [None]:
def get_mean_std(dataset):
    """Compute the per-channel pixel mean and standard deviation of a dataset.

    Uses var[X] = E[X**2] - E[X]**2 and averages per-image statistics over the dataset.
    This equals the pixel-level statistic only when every image has the same number of
    pixels, which holds for the resized datasets built in this notebook.

    Args:
        dataset: sized iterable of ``(image, target)`` pairs. Each image is reduced over
            dims 1 and 2, so its leading dimension is treated as the channel axis.

    Returns:
        ``(mean, std)``: tensors with one value per channel.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    num_items = len(dataset)
    if num_items == 0:
        raise ValueError("get_mean_std needs a non-empty dataset")

    channels_sum = 0
    channels_sqrd_sum = 0
    for data, _ in tqdm(dataset):
        channels_sum += torch.mean(data, dim=[1, 2])
        channels_sqrd_sum += torch.mean(data**2, dim=[1, 2])

    mean = channels_sum / num_items
    # Float cancellation can push E[X**2] - E[X]**2 slightly below zero for near-constant
    # data. Clamp it so the square root never produces NaN.
    variance = torch.clamp(channels_sqrd_sum / num_items - mean**2, min=0)
    std = variance.sqrt()

    return mean, std


def find_mean_std(height, width):
    """Print the grayscale mean/std of the training split after resizing it to (height, width).

    Images are converted to grayscale, resized and cast to float32 before get_mean_std runs.
    """
    pipeline = torch.nn.Sequential(
        T.Grayscale(),
        T.Resize((height, width)),
        T.ConvertImageDtype(torch.float32),
    )
    train_split = FetalBrainPlanesDataset(data_dir=database_dir, train=True, transform=pipeline)

    mean, std = get_mean_std(train_split)
    print(f"For {height} / {width} mean is {mean.item():.2f} std is {std.item():.2f}")


# Full-precision values recorded earlier:
#   mean: 0.16958117485046387, 0.16958120197261892
#   std:  0.1906554251909256,  0.19065533816416103
# Every resolution below rounds to mean 0.17 / std 0.19. These figures match the
# T.Normalize(mean=0.17, std=0.19) option in the preview cell.
find_mean_std(height=55, width=80)
find_mean_std(height=165, width=240)
find_mean_std(height=415, width=600)

In [None]:
# Preview 49 random training images after the base pipeline: grayscale, resize to 165 x 240,
# then float conversion. The commented-out steps are optional augmentations and normalisation
# that can be switched on to inspect their effect.
train = FetalBrainPlanesDataset(
    data_dir=database_dir,
    train=True,
    transform=torch.nn.Sequential(
        T.Grayscale(),
        # RandomPercentCrop(max_percent=20),
        T.Resize((165, 240), antialias=False),
        # T.RandomHorizontalFlip(p=0.5),
        # T.RandomAffine(degrees=15, fill=255),
        # T.RandomAffine(degrees=0, translate=(0.1, 0.1), fill=255),
        # T.RandomAffine(degrees=0, scale=(1.0, 1.2), fill=255),
        # T.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(1.0, 1.2), fill=255),
        T.ConvertImageDtype(torch.float32),
        # T.Normalize(mean=0.17, std=0.19),
    ),
)

# Slice the permutation *before* indexing. Indexing every position and slicing the list
# afterwards would load and transform the whole dataset just to display 49 images.
show_pytorch_images([train[i] for i in np.random.permutation(len(train))[:49]]).show()

In [None]:
# Preview 49 random training images with an automatic augmentation policy. The policy runs before
# ConvertImageDtype because torchvision's auto-augment transforms expect uint8 tensors. The
# commented-out lines are alternative policies to compare.
train = FetalBrainPlanesDataset(
    data_dir=database_dir,
    train=True,
    transform=torch.nn.Sequential(
        T.Grayscale(),
        T.Resize((165, 240)),
        T.AutoAugment(T.AutoAugmentPolicy.IMAGENET),
        # T.RandAugment(),
        # T.TrivialAugmentWide(),
        # T.AugMix(),
        T.ConvertImageDtype(torch.float32),
    ),
)

# Slice the permutation *before* indexing so only the 49 displayed images are loaded and augmented.
show_pytorch_images([train[i] for i in np.random.permutation(len(train))[:49]]).show()

In [None]:
def group_split_label(
    dataset: pd.DataFrame, test_size: float, groups, random_state: int = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split the rows of `dataset` into (train, test), keeping every group on one side only.

    Args:
        dataset: rows to split.
        test_size: forwarded as ``test_size`` to ``GroupShuffleSplit``.
        groups: group id per row (e.g. a patient column), aligned with `dataset`.
        random_state: seed for the shuffle, forwarded to ``GroupShuffleSplit``.

    Returns:
        ``(train_rows, test_rows)``, selected positionally with ``iloc``.
    """
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    train_positions, test_positions = next(iter(splitter.split(dataset, groups=groups)))
    return dataset.iloc[train_positions], dataset.iloc[test_positions]


def get_similarity(train, test, test_size):
    """Score how closely a train/test label split follows the requested `test_size` ratio.

    For each label, a perfectly proportional split satisfies
    ``n_train * test_size == n_test * (1 - test_size)``. The score is the Euclidean norm
    of those per-label residuals, so **lower is better** and 0 is a perfect match.

    Returns -1 when the two splits do not contain exactly the same set of labels.
    """
    counts_train = train.value_counts(sort=False).sort_index()
    counts_test = test.value_counts(sort=False).sort_index()

    same_labels = counts_train.index.tolist() == counts_test.index.tolist()
    if not same_labels:
        return -1

    squared_residuals = (
        (n_train * test_size - n_test * (1 - test_size)) ** 2
        for n_train, n_test in zip(counts_train, counts_test)
    )
    return sum(squared_residuals, 0) ** 0.5


def plt_value_counts(ax, dataset, tile=None, title=None):
    """Draw a bar chart of the value counts of `dataset` on `ax`, ordered by value.

    Args:
        ax: matplotlib Axes to draw on.
        dataset: pandas Series whose values are counted.
        tile: axis title. This is the original, misspelled keyword, kept for existing callers.
        title: axis title. Takes precedence over `tile` when both are given.

    Returns:
        The Axes that was drawn on.
    """
    counts = dataset.value_counts(sort=False).sort_index()
    counts.plot(kind="bar", ax=ax)
    heading = title or tile
    if heading:
        ax.set_title(heading)
    return ax


def plt_group_split(dataset: pd.DataFrame, test_size: float, random_states: List[int], top_states: int = None):
    """Find seeds for a patient-grouped split whose class balance best matches `test_size`.

    Each seed in `random_states` splits `dataset` with `group_split_label`, grouping rows by
    ``dataset["Patient_num"]``. The split is scored with `get_similarity` on the ``Brain_plane``
    column (lower is better). Seeds whose train and validation parts do not contain the same set
    of classes are discarded.

    The best `top_states` seeds (all kept seeds when `top_states` is falsy) are plotted as per-class
    count bar charts, train on the left and validation on the right, and then printed.

    Args:
        dataset: frame with at least ``Patient_num`` and ``Brain_plane`` columns.
        test_size: validation size forwarded to the splitter.
        random_states: integer seeds to try.
        top_states: number of best seeds to plot and print. ``None`` or 0 means all.

    Returns:
        The plotted seeds, best first. This is the same list that is printed.
    """
    # Keep only (score, seed) pairs. Storing two DataFrame copies per seed over thousands of seeds
    # wastes memory, and GroupShuffleSplit with an integer seed reproduces its split exactly.
    scored = []
    for random_state in tqdm(random_states):
        tr, val = group_split_label(
            dataset,
            test_size=test_size,
            groups=dataset["Patient_num"],
            random_state=random_state,
        )

        similarity = get_similarity(tr.Brain_plane, val.Brain_plane, test_size)
        if similarity >= 0:
            scored.append((similarity, random_state))

    # Lowest score first, ties broken by seed.
    scored.sort()
    best = scored[:top_states] if top_states else scored
    seeds = [random_state for _, random_state in best]

    if not best:
        # plt.subplots(nrows=0) would raise, so report the situation instead of plotting.
        print(f"No seed produced a split with every Brain_plane class on both sides (test size {test_size}).")
        return seeds

    # Size the grid by the seeds actually kept, so a large top_states does not leave empty rows.
    nrows = len(best)
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=2,
        sharex="all",
        squeeze=False,
        figsize=(20, 3 * nrows),
    )
    fig.suptitle(f"Test size {test_size}")
    for i, (similarity, random_state) in enumerate(best):
        # Recompute the split for this seed; integer seeds make this identical to the scored split.
        tr, val = group_split_label(
            dataset,
            test_size=test_size,
            groups=dataset["Patient_num"],
            random_state=random_state,
        )
        plt_value_counts(axes[i, 0], tr.Brain_plane, tile=f"Seed {random_state}")
        plt_value_counts(axes[i, 1], val.Brain_plane, tile=f"Similarity {similarity}")

    plt.show()
    print(seeds)
    return seeds


# Rank 10,000 seeds for patient-grouped train/validation splits of the training labels.
# The trailing comment on each call lists the 10 seeds printed by that run, best (lowest score)
# first. The test_size=0.15 and 0.1 variants are kept commented out for reference.
plt_group_split(
    train.img_labels,
    test_size=0.2,
    random_states=list(range(10000)),
    top_states=10,
)  # 564, 3097, 1683, 4951, 5724, 8910, 9486, 7023, 5907, 9759
# plt_group_split(
#     train.img_labels,
#     test_size=0.15,
#     random_states=list(range(10000)),
#     top_states=10,
# )  # 943, 9787, 4935, 6588, 6893, 697, 6347, 5785, 4, 7765
# plt_group_split(
#     train.img_labels,
#     test_size=0.1,
#     random_states=list(range(10000)),
#     top_states=10,
# )  # 2251, 3084, 9456, 8902, 1208, 9959, 2696, 2086, 4063, 9126

In [None]:
# Number of training images whose Brain_plane label is "Other".
len(train.img_labels.loc[train.img_labels["Brain_plane"] == "Other"])