In [None]:
import pathlib
import shutil
from math import ceil, sqrt
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import pyrootutils
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data import ConcatDataset, Dataset
from torchvision.io import read_image
from tqdm import tqdm

# Resolve the project root starting from the current directory, using the
# ".project-root" indicator file. pythonpath=True is expected to make the root
# importable (see the pyrootutils documentation) - which is why the project-local
# import below comes after this call rather than with the other imports.
root = pyrootutils.setup_root(search_from=".", indicator=".project-root", pythonpath=True)

# Project-local Lightning module used to restore the trained checkpoints.
from src.models.fetal_module import FetalLitModule

# Show the resolved root as the cell output.
root

In [None]:
# Every ensemble member comes from the same multirun sweep; only the run id and
# the checkpoint file name differ between members.
multirun_dir = root / "logs" / "train" / "multiruns" / "2023-02-03_15-13-23"
selected_checkpoints = [
    ("68", "epoch_015.ckpt"),
    ("118", "epoch_011.ckpt"),
    ("171", "epoch_012.ckpt"),
    ("154", "epoch_010.ckpt"),
    ("173", "epoch_012.ckpt"),
]
checkpoint_files = [
    multirun_dir / run_id / "checkpoints" / checkpoint_name
    for run_id, checkpoint_name in selected_checkpoints
]

# Restore each checkpoint and switch it to evaluation mode
# (disables randomness, dropout, etc.).
models = []
for file in checkpoint_files:
    model = FetalLitModule.load_from_checkpoint(str(file))
    model.eval()
    models.append(model)

In [None]:
# Remove the "labeled" output directory of a previous run (if any) so the
# labeling below starts from an empty tree.
path = root.joinpath("data", "US_VIDEOS")
labeled_dir = path / "labeled"
shutil.rmtree(labeled_dir, ignore_errors=True)

In [None]:
# Preprocessing applied to every video frame before it is fed to the models:
# PIL image -> tensor -> single-channel grayscale -> resize to 55x80
# (torchvision sizes are (height, width)) -> float32.
transforms = T.Compose(
    [
        T.ToTensor(),
        T.Grayscale(),
        T.Resize((55, 80)),
        T.ConvertImageDtype(torch.float32),
    ]
)

# Class names; label_frame() indexes this list with the predicted class index,
# so the order must match the class order the models were trained with
# (NOTE(review): training order is not visible here - confirm).
labels = [
    "Other",
    "Maternal cervix",
    "Fetal abdomen",
    "Fetal brain",
    "Fetal femur",
    "Fetal thorax",
]

# Normalisation modules applied over dim=1 of the model output.
# `logsoftmax` is not referenced by any of the visible cells.
softmax = torch.nn.Softmax(dim=1)
logsoftmax = torch.nn.LogSoftmax(dim=1)


def label_videos(path: pathlib.Path):
    """Label every entry found in ``path / "videos"``.

    Each entry is handed to ``label_video`` together with the output directory
    ``path / "labeled"``, its 1-based position and the total number of entries
    (the last two are only used for the progress-bar description).

    Args:
        path: Directory that contains the ``videos`` sub-directory.
    """
    source_dir = path / "videos"
    target_dir = path / "labeled"
    entries = list(source_dir.iterdir())
    total = len(entries)
    for position, entry in enumerate(entries, start=1):
        label_video(entry, target_dir, position, total)


#         break


def frame_iter(capture, description):
    """Wrap a video capture in a tqdm progress bar that yields its frames.

    Args:
        capture: Object with ``grab()``, ``retrieve()`` and ``get()`` methods
            (a ``cv2.VideoCapture`` in this notebook).
        description: Text shown next to the progress bar.

    Returns:
        A ``tqdm`` iterable yielding the second element of ``capture.retrieve()``
        for as long as ``capture.grab()`` succeeds. The bar total is the value of
        ``cv2.CAP_PROP_FRAME_COUNT`` reported by the capture.
    """

    def frames():
        while True:
            if not capture.grab():
                return
            _, frame = capture.retrieve()
            yield frame

    expected_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    return tqdm(frames(), desc=description, total=expected_frames)


def label_video(video_path: pathlib.Path, images_path: pathlib.Path, it: int, videos: int):
    """Classify every frame of one video and save it under its predicted label.

    Frames are written to ``images_path / <video stem> / <label> / frame<i>.jpg``
    and a per-label count for this video is printed at the end.

    Args:
        video_path: Path of the video file to process.
        images_path: Root directory for the labeled frames.
        it: 1-based position of this video (progress description only).
        videos: Total number of videos (progress description only).
    """
    if not video_path.exists():
        print(f"path {video_path} not exist")
        # Nothing to decode; continuing would only fail later on.
        return

    output_dir = images_path / video_path.stem
    vidcap = cv2.VideoCapture(str(video_path))
    try:
        for i, frame in enumerate(frame_iter(vidcap, f"label video {it}/{videos}")):
            label = label_frame(frame)
            img_path = output_dir / label / ("frame%d.jpg" % i)
            img_path.parent.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(str(img_path), frame)
    finally:
        # Always free the decoder handle, even if labeling a frame fails.
        vidcap.release()

    # The output directory only exists if at least one frame was written.
    if output_dir.exists():
        count_images(output_dir)


def label_frame(frame):
    """Predict the class name of a single video frame with the model ensemble.

    The frame is preprocessed with the notebook-level ``transforms``, every model
    in ``models`` produces softmax probabilities, the probabilities are summed
    and the class with the highest total wins.

    Args:
        frame: Image array as produced by OpenCV
            (presumably H x W x 3, BGR channel order - confirm for other sources).

    Returns:
        The entry of ``labels`` at the predicted class index.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if not models:
        raise ValueError("no models loaded")

    # OpenCV decodes to BGR, while PIL / torchvision's Grayscale assume RGB;
    # without the swap the red and blue luma weights would be exchanged.
    if frame.ndim == 3 and frame.shape[2] == 3:
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Add the batch dimension expected by the models.
    batch = transforms(PIL.Image.fromarray(frame)).unsqueeze(0)

    with torch.no_grad():
        summed_probs = None
        for model in models:
            probs = softmax(model(batch))
            # Out-of-place sum so no model output is mutated.
            summed_probs = probs if summed_probs is None else summed_probs + probs

    pred = torch.argmax(summed_probs, dim=1).item()
    return labels[pred]


def count_images(images_path: pathlib.Path, label_names: Optional[List[str]] = None):
    """Print how many entries each label sub-directory of one video contains.

    Args:
        images_path: Directory holding one sub-directory per label.
        label_names: Labels that are always reported (with 0 when absent).
            Defaults to the notebook-level ``labels`` list.

    A missing ``images_path`` (e.g. a video without frames) and stray
    non-directory entries are tolerated instead of raising.
    """
    if label_names is None:
        label_names = labels

    count = {label: 0 for label in label_names}
    if images_path.is_dir():
        for label_dir in images_path.iterdir():
            if label_dir.is_dir():
                count[label_dir.name] = sum(1 for _ in label_dir.iterdir())
    print(count)


# Print tensors with two decimals while labeling; the finally-clause restores the
# default print options even when labeling fails part-way through.
torch.set_printoptions(precision=2)
try:
    path = root / "data" / "US_VIDEOS"
    label_videos(path)
finally:
    torch.set_printoptions(profile="default")

In [None]:
# Class names used as the keys of the counts computed below.
# NOTE(review): this repeats the `labels` list of the labeling cell verbatim;
# keep the two in sync (or define the list once).
labels = [
    "Other",
    "Maternal cervix",
    "Fetal abdomen",
    "Fetal brain",
    "Fetal femur",
    "Fetal thorax",
]


def count_all_images(images_path: pathlib.Path, label_names: Optional[List[str]] = None):
    """Count labeled frames per label across all videos.

    Expects the layout ``images_path / <video> / <label> / <frame files>``.

    Args:
        images_path: Root directory of the labeled frames.
        label_names: Labels that are always present in the result (0 when no
            frame carries them). Defaults to the notebook-level ``labels`` list.

    Returns:
        Dict mapping label name to the number of entries found. Label
        directories that are not in ``label_names`` are counted as well instead
        of raising ``KeyError``; stray files and a missing root are ignored.
    """
    if label_names is None:
        label_names = labels

    count = {label: 0 for label in label_names}
    if not images_path.is_dir():
        return count

    for video_dir in images_path.iterdir():
        if not video_dir.is_dir():
            continue
        for label_dir in video_dir.iterdir():
            if not label_dir.is_dir():
                continue
            found = sum(1 for _ in label_dir.iterdir())
            count[label_dir.name] = count.get(label_dir.name, 0) + found
    return count


# Report the total number of labeled frames per class over all videos.
path = root.joinpath("data", "US_VIDEOS", "labeled")
images = count_all_images(path)
for key in images:
    item = images[key]
    print(f"{key}: {item}")

In [None]:
class USVideosDataset(Dataset):
    """Dataset over the frames written by the labeling step.

    Expects the layout ``<data_dir>/US_VIDEOS/labeled/<video>/<label>/<frame>``.
    ``self.items`` is a list of ``(image path as str, label name)`` tuples in a
    deterministic (sorted) order, grouped by label.
    """

    def __init__(
        self,
        data_dir: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        """Index all labeled frames below ``data_dir``.

        Args:
            data_dir: Directory that contains ``US_VIDEOS/labeled``.
            transform: Optional callable applied to the loaded image.
            target_transform: Optional callable applied to the label name.
        """
        data_dir = Path(data_dir) / "US_VIDEOS" / "labeled"
        images = self.find_images(data_dir)
        self.items = [
            (str(image_path), label)
            for label, image_paths in images.items()
            for image_path in image_paths
        ]

        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def find_images(images_path: pathlib.Path):
        """Collect frame paths per label.

        Directory listings are sorted because ``iterdir`` yields entries in
        arbitrary order; stray non-directory entries at the video / label level
        (and non-file entries at the frame level) are skipped.

        Returns:
            Dict mapping label name to a list of frame paths.
        """
        images = {}
        for video_dir in sorted(images_path.iterdir()):
            if not video_dir.is_dir():
                continue
            for label_dir in sorted(video_dir.iterdir()):
                if not label_dir.is_dir():
                    continue
                frames = sorted(entry for entry in label_dir.iterdir() if entry.is_file())
                images.setdefault(label_dir.name, []).extend(frames)
        return images

    def __len__(self):
        """Return the number of indexed frames."""
        return len(self.items)

    def __getitem__(self, idx):
        """Load the frame at ``idx`` and return ``(image, label)``.

        Raises:
            IndexError: If ``idx`` is not smaller than the dataset length.
        """
        if idx >= len(self):
            raise IndexError("list index out of range")

        img_path, label = self.items[idx]
        image = read_image(img_path)
        # Compare against None: a callable may legitimately be falsy.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


# Index the labeled frames and report how many were found.
data_dir = root.joinpath("data")
dataset = USVideosDataset(data_dir=str(data_dir))
print(len(dataset))

In [None]:
def show(imgs, tick_labels: bool = True):
    """Plot ``(image, label)`` pairs in a near-square grid of grayscale images.

    Args:
        imgs: Sized, indexable collection of ``(image tensor, label)`` pairs.
        tick_labels: When False, axis ticks and tick labels are removed.

    An empty collection is a no-op (``plt.subplots`` cannot build a 0x0 grid).
    Only as many rows as needed are created and unused cells are hidden.
    """
    total = len(imgs)
    if total == 0:
        return

    cols = ceil(sqrt(total))
    rows = ceil(total / cols)

    fig, axes = plt.subplots(ncols=cols, nrows=rows, squeeze=False, figsize=(20, 15))

    # axes.flat walks the grid row by row, matching the image order.
    for idx, ax in enumerate(axes.flat):
        if idx >= total:
            ax.axis("off")
            continue

        img, label = imgs[idx]
        img = F.to_grayscale(F.to_pil_image(img.detach()))
        ax.imshow(np.asarray(img), cmap="gray")
        ax.set_xlabel(label)
        if not tick_labels:
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    fig.tight_layout()
    plt.show()


# Plot every item of the dataset in a single grid, with axis ticks removed.
show(dataset, tick_labels=False)

In [None]:
# Re-create the dataset so this cell does not depend on the earlier one.
dataset = USVideosDataset(data_dir=str(root.joinpath("data")))
print(len(dataset))


def compare_image(imgs):
    """Print histogram-based distances between consecutive images.

    For each pair of neighbouring items three values are printed: the
    Bhattacharyya distance of the histograms, ``1 - normalised template match``
    of the histograms, and the Euclidean (L2) distance of the histograms.

    Args:
        imgs: Dataset-like object exposing ``items``, a list of
            ``(image path, label)`` tuples (e.g. ``USVideosDataset``). The paths
            are taken from this argument, not from a notebook-level variable.

    Raises:
        ValueError: If an image file cannot be read by OpenCV.
    """
    hists = []
    for img_path, _label in imgs.items:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"could not read image {img_path}")
        # 256-bin histogram of channel 0.
        hists.append(cv2.calcHist([img], [0], None, [256], [0, 256]))

    for previous, current in zip(hists, hists[1:]):
        hist_diff = cv2.compareHist(previous, current, cv2.HISTCMP_BHATTACHARYYA)

        # Both inputs have the same size, so the match result is a single value.
        template_probability_match = cv2.matchTemplate(
            previous, current, cv2.TM_CCOEFF_NORMED
        )[0][0]
        template_diff = 1 - template_probability_match

        # Euclidean distance over the bins; axis=0 keeps the one-element array
        # shape produced by the former element-wise loop.
        c = np.linalg.norm(previous - current, axis=0)

        print(f"hist {hist_diff}, template {template_diff}, L2 {c}")

        # TODO: also try entropy and SSIM
        # (NOTE(review): "ssi" in the original note presumably meant SSIM - confirm)


# Print the pairwise histogram distances for the dataset created above.
compare_image(dataset)
# show(dataset, tick_labels=False)