In [None]:
import pyrootutils

# Locate the project root by searching from "." for the ".project-root" marker file.
# pythonpath=True presumably puts that root on sys.path, which is what lets the
# `src.*` imports in the next cell resolve -- confirm against pyrootutils docs.
root = pyrootutils.setup_root(
    search_from=".",
    indicator=".project-root",
    pythonpath=True,
)
root

In [None]:
from math import ceil, sqrt
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data import ConcatDataset
from tqdm import tqdm

from src.data.components.dataset import FetalPlanesDataset
from src.data.components.transforms import RandomPercentCrop

# Data directory handed to every FetalPlanesDataset constructed below.
database_dir = root.joinpath("data")

In [None]:
# Exploration uses the training split only; uncomment the second entry to
# also pull in the held-out split.
dataset = ConcatDataset(
    datasets=[
        FetalPlanesDataset(data_dir=database_dir, train=True),
        # FetalPlanesDataset(data_dir=database_dir, train=False),
    ]
)

In [None]:
def show(imgs, tick_labels: bool = True):
    """Plot ``(image, label)`` pairs in grayscale on a square grid.

    The grid is ``ceil(sqrt(len(imgs)))`` cells per side and is filled row by row;
    cells past the last pair stay empty.

    Args:
        imgs: sequence of ``(image_tensor, label)`` pairs. Each tensor is detached,
            converted with ``F.to_pil_image`` and then to grayscale before drawing;
            the label becomes the cell's x-axis label.
        tick_labels: when False, strip ticks and tick labels from the drawn cells.
    """
    side = ceil(sqrt(len(imgs)))
    fig, axes = plt.subplots(ncols=side, nrows=side, squeeze=False, figsize=(20, 15))

    for idx, (img, label) in enumerate(imgs):
        row, col = divmod(idx, side)
        ax = axes[row, col]
        gray = F.to_grayscale(F.to_pil_image(img.detach()))
        ax.imshow(np.asarray(gray), cmap="gray")
        ax.set_xlabel(label)
        if not tick_labels:
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    fig.tight_layout()
    plt.show()


# A fixed seed makes this random preview of 25 training images reproducible
# across kernel restarts (the global np.random state is never seeded here).
preview_idx = np.random.default_rng(seed=0).permutation(len(dataset))[:25]
show([dataset[i] for i in preview_idx])

In [None]:
# One row per sample: column 0 holds the image tensor, column 1 its label.
# NOTE(review): this keeps every decoded image in memory although only the
# shapes are used in the next cell.
df = pd.DataFrame(list(tqdm(dataset)))

In [None]:
# One row per image with the sizes of its first three tensor dimensions
# (presumably channels, height, width -- torchvision's CHW layout).
df_shape = pd.DataFrame(
    data=[tuple(img.shape[d] for d in range(3)) for img in df[0]]
)
df_shape

In [None]:
# How many images have each size along dim 0 (presumably the channel count).
df_shape.loc[:, 0].value_counts()

In [None]:
# Distribution of image sizes along dims 1 and 2 (height and width, assuming
# CHW tensors) with the mean (red) and median (blue) marked.
# Recorded values -- a median of integer sizes is a whole or half number:
#   mean   ~ (572.415, 857.255)
#   median =  (661.0, 959.0)
# `shape_mean` avoids clobbering the name `mean`, which a later cell uses for
# pixel statistics; `median` keeps its name because the next cell reads it.
shape_mean = df_shape.mean()
median = df_shape.median()

axs = df_shape.hist(column=[1, 2], bins=100, figsize=(20, 5))

for ax, dim in zip(axs[0], (1, 2)):
    # Label the lines themselves: a bare list of legend labels is paired with the
    # axes' artists in creation order, which can match the histogram bars instead.
    ax.axvline(
        shape_mean[dim],
        color="r",
        linestyle="dashed",
        linewidth=2,
        label=f"mean {shape_mean[dim]}",
    )
    ax.axvline(
        median[dim],
        color="b",
        linestyle="dashed",
        linewidth=2,
        label=f"median {median[dim]}",
    )
    ax.legend(loc="upper left")

In [None]:
# Ratio of the median size along dim 2 to the median size along dim 1
# (width / height if the tensors are CHW).
scale = median.loc[2] / median.loc[1]
80 / scale  # dim-1 size that keeps this ratio when dim 2 is 80

In [None]:
def get_mean_std(dataset):
    """Estimate per-channel pixel mean and standard deviation over a dataset.

    Uses Var[X] = E[X**2] - E[X]**2 and averages per-image statistics. That average
    equals the pixel-level statistic only when every image has the same number of
    pixels (e.g. after a fixed ``Resize``, as in the call below).

    Args:
        dataset: iterable of ``(image, label)`` pairs. ``image`` is a tensor whose
            dims 1 and 2 are reduced over (channel-first layout assumed).

    Returns:
        ``(mean, std)`` tensors with one entry per channel. They use the images'
        floating dtype, or float64 when the images are not floating point.

    Raises:
        ValueError: if ``dataset`` yields no images.
    """
    channels_sum = 0.0
    channels_sqrd_sum = 0.0
    n_images = 0
    out_dtype = torch.float64

    for data, _ in tqdm(dataset):
        if data.is_floating_point():
            out_dtype = data.dtype
        # Accumulate in float64: thousands of float32 additions visibly drift.
        data64 = data.to(torch.float64)
        channels_sum = channels_sum + data64.mean(dim=(1, 2))
        channels_sqrd_sum = channels_sqrd_sum + (data64**2).mean(dim=(1, 2))
        n_images += 1

    if n_images == 0:
        raise ValueError("get_mean_std() requires a non-empty dataset")

    mean = channels_sum / n_images
    # Rounding can push E[X^2] - E[X]^2 slightly below zero; clamp so sqrt is not NaN.
    var = (channels_sqrd_sum / n_images - mean**2).clamp(min=0)

    return mean.to(out_dtype), var.sqrt().to(out_dtype)


# Per-channel pixel statistics of the training split after grayscale conversion,
# float32 conversion and resizing to (165, 240).
train = FetalPlanesDataset(
    data_dir=database_dir,
    train=True,
    transform=torch.nn.Sequential(
        T.Grayscale(),
        T.ConvertImageDtype(dtype=torch.float32),
        T.Resize(size=(165, 240)),
    ),
)

mean, std = get_mean_std(train)
# Grayscale leaves one channel, so .item() extracts the single value.
# Recorded runs -- mean: 0.16958117485046387, 0.16958120197261892
#                  std:  0.1906554251909256,  0.19065533816416103
print(mean.item(), std.item(), sep="\n")

In [None]:
# train = FetalPlanesDataset(
#     data_dir=database_dir,
#     train=True,
#     transform=torch.nn.Sequential(
#         T.Grayscale(),
#         T.ConvertImageDtype(torch.float32),
#         RandomPercentCrop(max_percent=20),
#         T.RandomHorizontalFlip(p=0.5),
#         # T.RandomAffine(degrees=15, fill=1),
#         T.RandomAffine(degrees=0, translate=(0.1, 0.1), fill=1),
#         # T.RandomAffine(degrees=0, scale=(1.0, 1.2), fill=1),
#         # T.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(1.0, 1.2)),
#         T.Resize((170, 250)),  # (55, 80), (70, 100), (100, 150), (170, 250)
#         # Normalize(mean=mean, std=std)
#     ),
# )

# Preview the AutoAugment (ImageNet policy) pipeline on 49 random training images.
# AutoAugment runs before ConvertImageDtype, i.e. on the dataset's native image
# dtype (presumably uint8 -- confirm in FetalPlanesDataset).
train = FetalPlanesDataset(
    data_dir=database_dir,
    train=True,
    transform=torch.nn.Sequential(
        T.Grayscale(),
        T.AutoAugment(policy=T.AutoAugmentPolicy.IMAGENET),
        T.ConvertImageDtype(dtype=torch.float32),
        T.Resize(size=(170, 250)),
    ),
)

show(
    [train[idx] for idx in np.random.permutation(len(train))[:49]],
    tick_labels=False,
)

In [None]:
def group_split_label(
    dataset: pd.DataFrame, test_size: float, groups, random_state: Optional[int] = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split the rows of ``dataset`` into train/test parts that share no group.

    Args:
        dataset: frame whose rows are split.
        test_size: forwarded to ``GroupShuffleSplit``; a float is the proportion
            of groups placed in the test part.
        groups: one group label per row of ``dataset``. Rows with the same label
            always land in the same part.
        random_state: seed for a reproducible split. ``None`` draws from NumPy's
            global random state, so each call may differ.

    Returns:
        ``(train_rows, test_rows)`` selected positionally via ``iloc``.
    """
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    # n_splits=1, so the generator yields exactly one (train, test) index pair.
    train_idx, test_idx = next(splitter.split(dataset, groups=groups))
    return dataset.iloc[train_idx], dataset.iloc[test_idx]


def get_similarity(train, test, test_size):
    """Score how far a split's per-label counts are from the ``test_size`` ratio.

    For a perfectly stratified split, every label satisfies
    ``n_train * test_size == n_test * (1 - test_size)``. The score is the
    Euclidean norm of those per-label differences, so 0 is best.

    Args:
        train: pandas Series of labels in the train part.
        test: pandas Series of labels in the test part.
        test_size: fraction of the data meant to be in the test part.

    Returns:
        The non-negative score, or -1 when the two parts do not contain exactly
        the same set of labels.
    """
    train_count = train.value_counts(sort=False).sort_index()
    test_count = test.value_counts(sort=False).sort_index()

    # The score is only defined when both parts contain every label.
    if train_count.index.tolist() != test_count.index.tolist():
        return -1

    squared_error = sum(
        (n_train * test_size - n_test * (1 - test_size)) ** 2
        for n_train, n_test in zip(train_count, test_count)
    )
    return squared_error**0.5


def plt_value_counts(ax, dataset, tile=None):
    """Draw a bar chart of ``dataset``'s value counts, ordered by value, on ``ax``.

    Args:
        ax: matplotlib axes to draw on.
        dataset: pandas Series whose values are counted.
        tile: optional axes title (the name is kept for existing callers). A
            falsy value leaves the title untouched.
    """
    dataset.value_counts(sort=False).sort_index().plot(kind="bar", ax=ax)
    if not tile:
        return
    ax.set_title(tile)


def plt_group_split(
    dataset: pd.DataFrame,
    test_size: float,
    random_states: List[int],
    top_states: Optional[int] = None,
    group_column: str = "Patient_num",
    label_column: str = "Plane",
):
    """Search seeds for the most label-balanced group split and plot the best ones.

    For each seed, ``dataset`` is split with ``group_split_label``, grouped by
    ``group_column`` so that no group appears in both parts. The split is then
    scored by ``get_similarity`` on ``label_column`` (lower is better). Splits
    whose parts do not contain the same set of labels are discarded.

    The best ``top_states`` splits (all valid ones when ``top_states`` is falsy),
    ordered by score and then seed, are drawn one per row: train counts on the
    left, validation counts on the right. Their seeds are then printed.

    Args:
        dataset: frame to split; must contain ``group_column`` and ``label_column``.
        test_size: fraction forwarded to ``group_split_label`` and ``get_similarity``.
        random_states: integer seeds to try.
        top_states: number of best splits to plot; ``None``/0 plots every valid one.
        group_column: column whose values must not be shared between the parts.
        label_column: column whose class balance is scored and plotted.
    """

    def split(random_state):
        # Integer seeds make this deterministic, so a split can be rebuilt on demand.
        return group_split_label(
            dataset,
            test_size=test_size,
            groups=dataset[group_column],
            random_state=random_state,
        )

    # Keep only (score, seed). Holding both frames for every candidate seed wastes
    # memory, and the handful of splits that get plotted are rebuilt below.
    scored = []
    for random_state in tqdm(random_states):
        tr, val = split(random_state)
        similarity = get_similarity(tr[label_column], val[label_column], test_size)
        if similarity >= 0:
            scored.append((similarity, random_state))

    scored.sort()  # by score, ties broken by seed
    best = scored[:top_states] if top_states else scored
    if not best:
        # plt.subplots(nrows=0) would raise; report the outcome instead.
        print(f"No split with test size {test_size} keeps every {label_column} value in both parts.")
        return

    # Size the grid by the splits actually available, so there are no empty rows.
    nrows = len(best)
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=2,
        sharex="all",
        squeeze=False,
        figsize=(20, 3 * nrows),
    )
    fig.suptitle(f"Test size {test_size}")
    for row, (similarity, random_state) in enumerate(best):
        tr, val = split(random_state)
        plt_value_counts(axes[row, 0], tr[label_column], tile=f"Seed {random_state}")
        plt_value_counts(axes[row, 1], val[label_column], tile=f"Similarity {similarity}")

    plt.show()
    print([random_state for _, random_state in best])


# For each held-out fraction, search seeds 0-999 and plot the 3 most
# label-balanced patient-grouped splits.
# Seeds noted from an earlier run. Several are >= 1000, so they came from a wider
# seed range than the one searched here -- re-run to refresh:
#   0.2  -> 9521, 3397, 4078, 6127, 1434, 4424, 8613, 8823, 9185, 3218
#   0.15 -> 34, 1208, 2971, 9081, 8517, 8176, 640, 3679, 5951, 8733
#   0.1  -> 8166
for test_size in (0.2, 0.15, 0.1):
    plt_group_split(
        train.img_labels,
        test_size=test_size,
        random_states=list(range(1000)),
        top_states=3,
    )