In [None]:
import os
from collections.abc import Callable, Iterable, Iterator, Sequence
from math import ceil, sqrt
from typing import List, Literal, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rootutils
import torch
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as TF
from sklearn.model_selection import GroupShuffleSplit
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import ConcatDataset, Dataset
from torchvision.io import read_image
from tqdm import tqdm

# Locate the repository root by searching from the current directory for the
# `.project-root` marker file. NOTE(review): `pythonpath=True` is presumably what
# makes the `src.*` imports below resolvable from the notebook — confirm in rootutils docs.
root = rootutils.setup_root(search_from=".", indicator=".project-root", pythonpath=True)

from src.data.components.dataset import FetalBrainPlanesDataset
from src.data.components.transforms import RandomPercentCrop
from src.data.utils.utils import show_pytorch_images

# Every dataset path used in this notebook is built relative to `<root>/data`.
database_dir = root / "data"

In [None]:
# Load the FETAL_PLANES metadata table (semicolon-separated CSV) and display it.
img_labels = pd.read_csv(f"{database_dir}/FETAL_PLANES/FETAL_PLANES_DB_data.csv", sep=";")
img_labels
# 12400 rows × 7 columns

In [None]:
# img_labels["9"]=[row["Brain_plane"] if row["Plane"] == "Fetal brain" else row["Plane"] for row in img_labels]

# Class balance of the official training split over the combined "<Plane>.<Brain_plane>" label.
# Note that the CSV column name really is "Train " with a trailing space.
train_labels = img_labels[img_labels["Train "] == 1]
all_classes = train_labels["Plane"] + "." + train_labels["Brain_plane"]

# Per-class sample counts, ordered by class name.
counts = all_classes.value_counts(sort=False).sort_index()

# "balanced" weights are inversely proportional to class frequency.
classes = np.unique(all_classes)
class_weight = compute_class_weight(class_weight="balanced", classes=classes, y=all_classes)

# Side-by-side table of count and weight per class (displayed as the cell output).
pd.DataFrame({"Count": counts[classes], "Weight": class_weight})


# Fetal brain.Other	77	10.287157
# Fetal brain.Trans-cerebellum	375	2.112296
# Fetal brain.Trans-thalamic	873	0.907344
# Fetal brain.Trans-ventricular	295	2.685122

# Fetal abdomen.Not A Brain	353	2.243941
# Fetal femur.Not A Brain	516	1.535099
# Fetal thorax.Not A Brain	1058	0.748687
# Maternal cervix.Not A Brain	981	0.807453
# Other.Not A Brain	2601	0.304541
# 1.12

In [None]:
# Same count/weight table as above, but over the "Brain_plane" column alone and
# over all rows (train and test together).
brain_plane = img_labels["Brain_plane"]
counts = brain_plane.value_counts(sort=False).sort_index()

classes = np.unique(brain_plane)
class_weight = compute_class_weight(class_weight="balanced", classes=classes, y=brain_plane)

pd.DataFrame({"Count": counts[classes], "Weight": class_weight})

In [None]:
# 143 * 0.2 = 115 | 28
# 115 * 0.2 =  92 | 23 - 64% 16% 20%

# (140 - a) * b = c
# 140 * 0.2 = a 28
# 140 * 0.1 = c 14
# b = c / (140 - a) = 0.125

# "Brain_plane" count/weight table restricted to the official training rows ("Train " == 1).
brain_plane = img_labels[img_labels["Train "] == 1]["Brain_plane"]
counts = brain_plane.value_counts(sort=False).sort_index()

classes = np.unique(brain_plane)
class_weight = compute_class_weight(class_weight="balanced", classes=classes, y=brain_plane)

pd.DataFrame({"Count": counts[classes], "Weight": class_weight})

In [None]:
# Keep only fetal-brain images and the four columns used from here on.
# NOTE(review): this overwrites `img_labels` in place and drops the "Plane" column it
# filters on, so running this cell a second time raises KeyError; the cells above that
# use "Plane" also stop working until the CSV is reloaded.
img_labels = img_labels[img_labels["Plane"] == "Fetal brain"]
img_labels = img_labels[["Image_name", "Patient_num", "Brain_plane", "Train "]]
img_labels = img_labels.reset_index(drop=True)

In [None]:
# Bar chart of the brain-plane class distribution over all remaining rows.
fig = plt.figure(figsize=(6, 4))
brain_plane = img_labels["Brain_plane"]
counts = brain_plane.value_counts(sort=False).sort_index()
counts.plot(kind="bar")
plt.show()

In [None]:
# Compare the class distribution of the official train and test splits side by side.
# Shared axes keep the bar heights of the two panels directly comparable.
fig, axes = plt.subplots(
    nrows=1,
    ncols=2,
    squeeze=False,
    sharex="all",
    sharey="all",
    figsize=(12, 4),
)

# Left panel: rows flagged as training data.
brain_plane = img_labels[img_labels["Train "] == 1]["Brain_plane"]
counts = brain_plane.value_counts(sort=False).sort_index()
counts.plot(kind="bar", ax=axes[0, 0])
axes[0, 0].set_title("Train")

# Right panel: rows flagged as test data.
brain_plane = img_labels[img_labels["Train "] == 0]["Brain_plane"]
counts = brain_plane.value_counts(sort=False).sort_index()
counts.plot(kind="bar", ax=axes[0, 1])
axes[0, 1].set_title("Test")

plt.show()

In [None]:
# plt_group_split(
#     img_labels,
#     test_size=0.1,  # 0.125 - 10%  0.25 - 20%
#     random_states=list(range(10000)),
#     top_states=10,
# )

# 0.2: [435, 3078, 3462, 9261, 9018, 1386, 1216, 8400, 157, 1631]
# 0.1: [1631, 9018, 5423, 2091, 7735, 9828, 2526, 3683, 6712, 1849]

# 0.1: [8190, 2749, 6106, 8394, 9592, 1585, 990, 2520, 8838, 6802]

In [None]:
from src.data.utils import group_split


def get_class_num(dataset: Dataset) -> Sequence[int]:
    """Count how many samples of ``dataset`` fall into each class.

    Args:
        dataset: Dataset supporting ``len()`` and tuple indexing ``dataset[i, 1]``
            that returns the label name of sample ``i``.

    Returns:
        One count per entry of ``FetalBrainPlanesDataset.labels``, in that order.
        Classes that do not occur are reported as ``0``.

    Raises:
        ValueError: If a sample's label is not listed in ``FetalBrainPlanesDataset.labels``.
    """
    label_names = FetalBrainPlanesDataset.labels
    # An explicit integer dtype keeps the empty-dataset case valid for `bincount`.
    class_ids = torch.tensor(
        [label_names.index(dataset[i, 1]) for i in range(len(dataset))],
        dtype=torch.long,
    )
    # The number of classes is derived from the label list rather than hard-coded;
    # `minlength` pads the result so absent classes still get a zero entry.
    return torch.bincount(class_ids, minlength=len(label_names)).tolist()


# Build the full dataset and hold out 10% of it as a test set, grouping by
# "Patient_num" so the split is made per patient rather than per image.
data_all = FetalBrainPlanesDataset(
    data_dir=database_dir,
    subset="all",
)
data_train, data_test = group_split(
    dataset=data_all,
    test_size=0.1,
    groups=data_all.img_labels["Patient_num"],
    random_state=6106,  # 0.1: [8190, 2749, 6106, 8394, 9592, 1585, 990, 2520, 8838, 6802]
)

# Size and per-class counts of the whole dataset ...
print(len(data_all))
print(get_class_num(data_all))

# ... of the held-out test part (absolute size and fraction of the whole) ...
print(len(data_test), len(data_test) / len(data_all))
print(get_class_num(data_test))

# ... and of the remaining training part.
print(len(data_train), len(data_train) / len(data_all))
print(get_class_num(data_train))

In [None]:
# Label rows of the training part of the split, re-indexed from 0.
# NOTE(review): relies on `data_train` exposing positional `.indices` into `data_all`.
train_idx = data_all.img_labels.iloc[data_train.indices].reset_index(drop=True)
# train_idx = train_idx[train_idx["Brain_plane"] != "Not A Brain"]
# The second reset only matters when the filter above is re-enabled.
train_idx = train_idx.reset_index(drop=True)
train_idx

In [None]:
# plt_group_split(
#     train_idx,
#     test_size=0.114,  # 0.125 - 10%  0.25 - 20%
#     random_states=list(range(10000)),
#     top_states=10,
# )

# 0.125: [1106, 6652, 9111, 4965, 4161, 98, 1985, 9598, 3151, 8322]

In [None]:
# Full train / validation / test split, grouped by patient at both stages.
data_all = FetalBrainPlanesDataset(
    data_dir=database_dir,
    subset="all",
)
# Stage 1: hold out 10% of everything as the test set.
data_train, data_test = group_split(
    dataset=data_all,
    test_size=0.1,
    groups=data_all.img_labels["Patient_num"],
    random_state=6106,  # 0.1: [8190, 2749, 6106, 8394, 9592, 1585, 990, 2520, 8838, 6802]
)

# Stage 2: carve the validation set out of the remaining ~90%.
# 0.114 of that remainder is roughly 10% of the whole dataset (0.9 * 0.114 ≈ 0.103).
# The groups are re-indexed from 0 so they line up positionally with `data_train`.
data_train, data_val = group_split(
    dataset=data_train,
    test_size=0.114,
    groups=data_all.img_labels.iloc[data_train.indices].reset_index(drop=True)["Patient_num"],
    random_state=6277,  # [7539, 6277, 2613, 6652, 2769, 653, 1084, 3368, 9111, 9101]
)

# Absolute size, fraction of the whole dataset, and per-class counts of each part.
print(len(data_test), len(data_test) / len(data_all))
print(get_class_num(data_test))

print(len(data_val), len(data_val) / len(data_all))
print(get_class_num(data_val))

print(len(data_train), len(data_train) / len(data_all))
print(get_class_num(data_train))

In [None]:
class FetalBrainPlanesDataset(Dataset):
    labels = [
        "Trans-ventricular",
        "Trans-thalamic",
        "Trans-cerebellum",
        "Other",
        "Not A Brain",
    ]

    def __init__(
        self,
        data_dir: str,
        data_name: str = "FETAL_PLANES",
        train: bool = True,
        transform: Callable | None = None,
        target_transform: Callable | None = None,
    ):
        self.dataset_dir = f"{data_dir}/{data_name}"
        self.img_labels = self.load_img_labels(train)
        self.img_dir = f"{self.dataset_dir}/Images"
        self.transform = transform
        self.target_transform = target_transform

    def load_img_labels(self, train: bool):
        img_labels = pd.read_csv(f"{self.dataset_dir}/FETAL_PLANES_DB_data.csv", sep=";")
        img_labels = img_labels[img_labels["Train "] == (1 if train else 0)]
        img_labels = img_labels[["Image_name", "Patient_num", "Brain_plane"]]
        return img_labels.reset_index(drop=True)

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        if isinstance(idx, tuple):
            idx, sub_idx = idx
            if sub_idx == 0:
                return self.get_image(idx)
            elif sub_idx == 1:
                return self.get_label(idx)

        return self.get_image(idx), self.get_label(idx)

    def get_image(self, idx):
        if isinstance(idx, torch.Tensor):
            idx = idx.item()

        img_path = os.path.join(self.img_dir, self.img_labels.Image_name[idx] + ".png")
        image = read_image(img_path)
        if image.shape[0] == 4:
            image = image[:3, :, :]

        if self.transform:
            image = self.transform(image)
        return image

    def get_label(self, idx):
        if isinstance(idx, torch.Tensor):
            idx = idx.item()

        label = self.img_labels.Brain_plane[idx]

        if self.target_transform:
            label = self.target_transform(label)
        return label

In [None]:
# Dataset used for the image-shape exploration below. Wrapped in ConcatDataset so
# the test subset can be appended by un-commenting the second entry.
dataset = ConcatDataset(
    [
        FetalBrainPlanesDataset(data_dir=database_dir, subset="train"),
        # FetalBrainPlanesDataset(data_dir=database_dir, subset="test"),
    ]
)

In [None]:
# Show 25 randomly chosen samples. The permutation is sliced before indexing,
# so only those 25 images are read from disk.
show_pytorch_images(
    [dataset[i] for i in np.random.permutation(len(dataset))[:25]],
    tick_labels=True,
)

In [None]:
# Materialise every (image, label) pair into a frame: column 0 holds the image
# tensors, column 1 the labels. This reads the whole dataset into memory.
df = pd.DataFrame(data=iter(tqdm(dataset)))

In [None]:
# One row per image with the three dimensions of its tensor as columns 0, 1 and 2.
df_shape = pd.DataFrame(data=[(i.shape[0], i.shape[1], i.shape[2]) for i in df[0]])
df_shape

In [None]:
# How often each value of the first tensor dimension occurs across the images.
df_shape[0].value_counts()

In [None]:
# Histograms of image-shape columns 1 and 2 with their mean (red) and median (blue).
# The trailing numbers are values recorded from an earlier run. The two medians
# (661.0 and 959.0) are the ones consistent with the `print_resolution` results
# noted further down (e.g. 80 wide -> about 55 high).
mean = df_shape.mean()
median = df_shape.median()

axs = df_shape.hist(column=[1, 2], bins=100, figsize=(20, 5))

axs[0][0].axvline(mean[1], color="r", linestyle="dashed", linewidth=2)  # 572.415
axs[0][1].axvline(mean[2], color="r", linestyle="dashed", linewidth=2)  # 857.255
axs[0][0].axvline(median[1], color="b", linestyle="dashed", linewidth=2)  # 661.0
axs[0][1].axvline(median[2], color="b", linestyle="dashed", linewidth=2)  # 959.0
axs[0][0].legend([f"mean {mean[1]}", f"median {median[1]}"], loc="upper left")
axs[0][1].legend([f"mean {mean[2]}", f"median {median[2]}"], loc="upper left")

In [None]:
# Ratio of the median of shape column 2 to the median of shape column 1;
# `print_resolution` divides a width by it to obtain the matching height.
scale = median[2] / median[1]


def print_resolution(width):
    """Print ``"<height> / <width>"`` for the given width.

    The height is ``width`` divided by the notebook-level ``scale`` and is shown
    with two decimals; ``width`` is printed as passed in.
    """
    print(f"{width / scale:.2f} / {width}")


# Candidate input widths with the matching height. The trailing comments give the
# printed height rounded to a convenient multiple of 5.
print_resolution(80)  # 55 / 80
print_resolution(100)  # 70 / 100
print_resolution(150)  # 100 / 150
print_resolution(240)  # 165 / 240

print_resolution(300)  # 205 / 300
print_resolution(400)  # 275 / 400
print_resolution(500)  # 345 / 500
print_resolution(600)  # 415 / 600

In [None]:
def get_mean_std(dataset):
    """Estimate the per-channel mean and standard deviation of a dataset.

    Uses ``var[X] = E[X**2] - E[X]**2``, where both expectations are averages of
    per-image averages. Every image therefore carries the same weight, which equals
    the pixel-level statistic only when all images have the same spatial size.

    Args:
        dataset: Iterable of ``(image, label)`` pairs; each image is a floating-point
            tensor whose dimensions 1 and 2 are averaged over (dimension 0 is kept).

    Returns:
        ``(mean, std)``: two tensors with one entry per index of dimension 0.

    Raises:
        ValueError: If the dataset yields no samples.
    """
    channels_sum = 0
    channels_sqrd_sum = 0
    num_images = 0

    for data, _ in tqdm(dataset):
        channels_sum += torch.mean(data, dim=[1, 2])
        channels_sqrd_sum += torch.mean(data**2, dim=[1, 2])
        num_images += 1

    if num_images == 0:
        raise ValueError("cannot compute mean/std of an empty dataset")

    mean = channels_sum / num_images
    variance = channels_sqrd_sum / num_images - mean**2
    # Rounding can push the variance of (near-)constant data slightly below zero,
    # which would turn the square root into NaN; clamp it first.
    std = torch.clamp(variance, min=0) ** 0.5

    return mean, std


def find_mean_std(height, width):
    """Print the mean and std of the grayscale training images at one resolution.

    The training subset is loaded with a grayscale -> resize to ``(height, width)``
    -> float32 conversion pipeline and passed to ``get_mean_std``. Both statistics
    are printed with two decimals.
    """
    preprocessing = torch.nn.Sequential(
        T.Grayscale(),
        T.Resize((height, width)),
        T.ConvertImageDtype(torch.float32),
    )
    train = FetalBrainPlanesDataset(data_dir=database_dir, subset="train", transform=preprocessing)

    # `.item()` below requires one-element tensors, i.e. a single channel after Grayscale.
    mean, std = get_mean_std(train)
    print(f"For {height} / {width} mean is {mean.item():.2f} std is {std.item():.2f}")


# Unrounded values recorded from earlier runs:
# print(mean.item())  # 0.16958117485046387, 0.16958120197261892
# print(std.item())  # 0.1906554251909256,  0.19065533816416103
# The recorded statistics are the same (to two decimals) at all three resolutions.
find_mean_std(55, 80)  # 0.17 / 0.19
find_mean_std(165, 240)  # 0.17 / 0.19
find_mean_std(415, 600)  # 0.17 / 0.19

In [None]:
# Preview of the preprocessing pipeline at 165 x 240. The commented-out entries are
# augmentation candidates that can be switched on one at a time.
train = FetalBrainPlanesDataset(
    data_dir=database_dir,
    subset="train",
    transform=torch.nn.Sequential(
        T.Grayscale(),
        # RandomPercentCrop(max_percent=20),
        T.Resize((165, 240), antialias=False),
        # T.RandomHorizontalFlip(p=0.5),
        # T.RandomAffine(degrees=15, fill=255),
        # T.RandomAffine(degrees=0, translate=(0.1, 0.1), fill=255),
        # T.RandomAffine(degrees=0, scale=(1.0, 1.2), fill=255),
        # T.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(1.0, 1.2), fill=255),
        T.ConvertImageDtype(torch.float32),
        # T.Normalize(mean=0.17, std=0.19),
    ),
)

# Slice the permutation, not the finished list: only the 49 displayed samples are
# read and transformed instead of the entire training set.
show_pytorch_images([train[i] for i in np.random.permutation(len(train))[:49]]).show()

In [None]:
# Preview of a torchvision automatic augmentation policy applied after resizing.
# The commented-out entries are alternative policies to try in its place.
train = FetalBrainPlanesDataset(
    data_dir=database_dir,
    subset="train",
    transform=torch.nn.Sequential(
        T.Grayscale(),
        T.Resize((165, 240)),
        T.AutoAugment(T.AutoAugmentPolicy.IMAGENET),
        # T.RandAugment(),
        # T.TrivialAugmentWide(),
        # T.AugMix(),
        T.ConvertImageDtype(torch.float32),
    ),
)

# Slice the permutation, not the finished list: only the 49 displayed samples are
# read and augmented instead of the entire training set.
show_pytorch_images([train[i] for i in np.random.permutation(len(train))[:49]]).show()

In [None]:
def group_split_label(
    dataset: pd.DataFrame, test_size: float, groups, random_state: int = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split a label table into train and test rows without splitting any group.

    Args:
        dataset: Table to split.
        test_size: Passed to ``GroupShuffleSplit`` as ``test_size``.
        groups: Group id per row of ``dataset``; passed to ``split`` as ``groups``.
        random_state: Seed forwarded to the splitter.

    Returns:
        ``(train_rows, test_rows)``, both selected from ``dataset`` by position.
    """
    # A single split is requested, so only the first yielded index pair is consumed.
    first_split = next(
        GroupShuffleSplit(test_size=test_size, n_splits=1, random_state=random_state).split(
            dataset, groups=groups
        )
    )
    train_part, test_part = (dataset.iloc[positions] for positions in first_split)
    return train_part, test_part


def get_similarity(train, test, test_size):
    """Score how far a split is from keeping every class at the ``test_size`` ratio.

    For each class the term ``n_train * test_size - n_test * (1 - test_size)`` is
    zero exactly when the class is divided in the requested proportion. The score is
    the Euclidean norm of these terms, so lower is better and 0 is a perfect split.

    Args:
        train: Series of class labels on the training side.
        test: Series of class labels on the test side.
        test_size: Target fraction of samples on the test side.

    Returns:
        The score, or ``-1`` when the two sides do not contain the same set of classes.
    """
    per_class_train = train.value_counts(sort=False).sort_index()
    per_class_test = test.value_counts(sort=False).sort_index()

    # A class missing from one side makes the split unusable; report it with -1.
    if per_class_train.index.tolist() != per_class_test.index.tolist():
        return -1

    squared_error = 0
    for n_train, n_test in zip(per_class_train, per_class_test):
        imbalance = n_train * test_size - n_test * (1 - test_size)
        squared_error += imbalance**2
    return squared_error**0.5


def plt_value_counts(ax, dataset, tile=None):
    """Draw a bar chart of the value counts of ``dataset`` on ``ax``.

    Args:
        ax: Axes to draw on.
        dataset: Series whose values are counted; bars are ordered by value.
        tile: Optional axes title; nothing is set when it is falsy.
    """
    dataset.value_counts(sort=False).sort_index().plot(kind="bar", ax=ax)
    if not tile:
        return
    ax.set_title(tile)


def plt_group_split(dataset: pd.DataFrame, test_size: float, random_states: List[int], top_states: int | None = None):
    """Search seeds for the best class-balanced group split and plot the winners.

    Every seed in ``random_states`` is used for one patient-grouped split of
    ``dataset``. Splits are ranked by ``get_similarity`` (lower is better, ties
    broken by seed), and the best ones are plotted as train/validation bar charts.
    The list of their seeds is printed last.

    Args:
        dataset: Label table with ``Patient_num`` and ``Brain_plane`` columns.
        test_size: Fraction of samples targeted for the held-out side.
        random_states: Seeds to evaluate.
        top_states: Number of best splits to show; all valid splits when falsy.
    """

    def split_with(seed):
        # The split is fully determined by the seed, so it can be recomputed on demand.
        return group_split_label(
            dataset,
            test_size=test_size,
            groups=dataset["Patient_num"],
            random_state=seed,
        )

    # Keep only (score, seed) while searching: holding on to both DataFrames for
    # thousands of seeds costs a lot of memory although only a few get plotted.
    ranked = []
    for random_state in tqdm(random_states):
        tr, val = split_with(random_state)
        similarity = get_similarity(tr.Brain_plane, val.Brain_plane, test_size)
        # Negative scores mark splits where a class is missing on one side.
        if similarity >= 0:
            ranked.append((similarity, random_state))

    ranked.sort()
    if top_states:
        ranked = ranked[:top_states]

    if not ranked:
        # Nothing to plot; a zero-row subplot grid would raise.
        print([])
        return

    # One row per split actually shown, so fewer valid splits than `top_states`
    # does not leave empty rows in the figure.
    nrows = len(ranked)
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=2,
        sharex="all",
        squeeze=False,
        figsize=(20, 3 * nrows),
    )
    fig.suptitle(f"Test size {test_size}")
    for i, (similarity, random_state) in enumerate(ranked):
        tr, val = split_with(random_state)
        plt_value_counts(axes[i, 0], tr.Brain_plane, tile=f"Seed {random_state}")
        plt_value_counts(axes[i, 1], val.Brain_plane, tile=f"Similarity {similarity}")

    plt.show()
    print([random_state for _, random_state in ranked])


# Evaluate seeds 0..9999 for a 20% hold-out of the training labels and show the
# ten best. The trailing comment lists the seeds printed by an earlier run.
plt_group_split(
    train.img_labels,
    test_size=0.2,
    random_states=list(range(10000)),
    top_states=10,
)  # 564, 3097, 1683, 4951, 5724, 8910, 9486, 7023, 5907, 9759
# plt_group_split(
#     train.img_labels,
#     test_size=0.15,
#     random_states=list(range(10000)),
#     top_states=10,
# )  # 943, 9787, 4935, 6588, 6893, 697, 6347, 5785, 4, 7765
# plt_group_split(
#     train.img_labels,
#     test_size=0.1,
#     random_states=list(range(10000)),
#     top_states=10,
# )  # 2251, 3084, 9456, 8902, 1208, 9959, 2696, 2086, 4063, 9126

In [None]:
# Number of training samples whose brain plane is labelled "Other".
len(train.img_labels[train.img_labels.Brain_plane == "Other"])

## Image augmentation