In [None]:
import pyrootutils

# Locate the project root by looking for a ".git" or "pyproject.toml" marker,
# starting the search from the current directory. Per the pyrootutils docs,
# pythonpath=True makes the root importable (the `src.` import in the next cell
# relies on that) and dotenv=True loads the root's .env file — confirm against
# the installed pyrootutils version.
root = pyrootutils.setup_root(
    search_from=".",
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)
# Bare expression so the notebook displays the resolved root.
root

In [None]:
import pathlib
from pathlib import Path
from typing import Callable, Optional

import cv2
import numpy as np
import PIL
import PIL.Image
import torch
import torchvision.transforms as T
from torch.utils.data import Dataset
from torchvision.io import read_image
from tqdm import tqdm

from src.models.fetal_module import FetalLitModule

In [None]:
# Checkpoint of the training run used for labelling the videos.
checkpoint_file = str(
    root.joinpath(
        "logs", "train", "runs", "2023-01-08_19-53-10", "checkpoints", "epoch_019.ckpt"
    )
)
model = FetalLitModule.load_from_checkpoint(checkpoint_file)
# Switch to evaluation mode (no dropout or other train-time randomness).
model.eval()

In [None]:
# Preprocessing applied to every frame before it is passed to the model:
# convert to a tensor, reduce to grayscale, cast to float32 and resize to
# (55, 80) — (height, width) per torchvision's Resize convention.
transforms = T.Compose(
    [
        T.ToTensor(),
        T.Grayscale(),
        T.ConvertImageDtype(torch.float32),
        T.Resize((55, 80)),
    ]
)

# Class names. label_frame() indexes this list with the position of the
# model's highest output, so the order must match the model's output order.
labels = [
    "Other",
    "Maternal cervix",
    "Fetal abdomen",
    "Fetal brain",
    "Fetal femur",
    "Fetal thorax",
]


def label_videos(path: pathlib.Path):
    """Label every video found in ``path / "videos"``.

    Each video is handed to ``label_video``, which writes its frames below
    ``path / "labeled"``.

    Args:
        path: Dataset directory containing a ``videos`` sub-directory.

    Raises:
        FileNotFoundError: If ``path / "videos"`` does not exist.
    """
    videos_path = path / "videos"
    images_path = path / "labeled"
    # List the directory once, in sorted order, so the "i/N" progress counter
    # matches what is actually processed and the order is the same on every
    # run. Sub-directories and hidden files (e.g. ".gitkeep") are not videos.
    video_paths = sorted(
        entry
        for entry in videos_path.iterdir()
        if entry.is_file() and not entry.name.startswith(".")
    )
    total = len(video_paths)
    for i, video_path in enumerate(video_paths, start=1):
        label_video(video_path, images_path, i, total)


def frame_iter(capture, description):
    """Wrap a video capture in a tqdm progress bar that yields its frames.

    Args:
        capture: Capture object providing ``grab()``, ``retrieve()`` and
            ``get()`` (e.g. a ``cv2.VideoCapture``).
        description: Text shown next to the progress bar.

    Returns:
        A ``tqdm`` iterable over the frames. Iteration stops when no further
        frame can be grabbed or a grabbed frame cannot be retrieved.
    """

    def iterator():
        while capture.grab():
            # retrieve() returns (success, frame); a frame that was grabbed but
            # could not be decoded must end the stream instead of being yielded
            # as None to the caller.
            ok, frame = capture.retrieve()
            if not ok or frame is None:
                return
            yield frame

    # The frame count may be reported as 0 or a negative number when it is
    # unknown; give tqdm no total in that case rather than a bogus one.
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    return tqdm(
        iterator(),
        desc=description,
        total=frame_count if frame_count > 0 else None,
    )


def label_video(video_path: pathlib.Path, images_path: pathlib.Path, it: int, videos: int):
    """Classify every frame of one video and save it under its predicted label.

    Frames are written to ``images_path / <video stem> / <label> / frame<i>.jpg``
    and a per-label count is printed once the video is done.

    Args:
        video_path: Video file to process.
        images_path: Root directory for the labelled frames.
        it: 1-based position of this video, used in the progress description.
        videos: Total number of videos, used in the progress description.
    """
    if not video_path.exists():
        print(f"path {video_path} not exist")
        # Nothing to process; previously execution fell through to OpenCV.
        return

    output_dir = images_path / video_path.stem
    vidcap = cv2.VideoCapture(str(video_path))
    try:
        for i, frame in enumerate(frame_iter(vidcap, f"label video {it}/{videos}")):
            label = label_frame(frame)
            img_path = output_dir / label / ("frame%d.jpg" % i)
            img_path.parent.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(str(img_path), frame)
    finally:
        # Always free the decoder / file handle, even if labelling fails.
        vidcap.release()

    # A video that yielded no frames has no output directory to count.
    if output_dir.exists():
        count_images(output_dir)


def label_frame(frame):
    """Classify one video frame and return its class name.

    Args:
        frame: Image array as decoded by OpenCV (BGR channel order for colour
            frames).

    Returns:
        The entry of ``labels`` at the index of the model's highest output.
    """
    # OpenCV decodes colour frames as BGR, whereas PIL treats a 3-channel array
    # as RGB. Convert first so that T.Grayscale weights the right channels.
    # NOTE(review): assumes the model was trained on RGB-ordered input — confirm.
    if frame.ndim == 3 and frame.shape[2] == 3:
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    with torch.no_grad():
        image = PIL.Image.fromarray(frame)
        # unsqueeze(0) adds the batch dimension expected by the model call.
        batch = transforms(image).unsqueeze(0)
        y = model(batch)
        pred = y.max(1).indices[0]
    # int() turns the 0-d index tensor into a plain list index.
    return labels[int(pred)]


def count_images(images_path: pathlib.Path, label_names: Optional[list] = None):
    """Print and return the number of entries in each label directory.

    Args:
        images_path: Directory holding one sub-directory per label.
        label_names: Labels that always appear in the result, with 0 when they
            have no directory. Defaults to the notebook-level ``labels`` list.

    Returns:
        Dict mapping each label name to the number of entries in its directory.
    """
    if label_names is None:
        label_names = labels
    count = {label: 0 for label in label_names}
    # The directory does not exist when no frame was written for the video.
    if images_path.is_dir():
        for label_dir in images_path.iterdir():
            # Stray files (e.g. ".DS_Store") are not label directories.
            if label_dir.is_dir():
                count[label_dir.name] = sum(1 for _ in label_dir.iterdir())
    print(count)
    return count


# Label every video below <root>/data/US_VIDEOS/videos.
path = root.joinpath("data", "US_VIDEOS")
label_videos(path)

In [None]:
# Class names used as the keys of the per-label totals computed below. This
# repeats the list defined in the labelling cell; keep the two in sync.
labels = [
    "Other",
    "Maternal cervix",
    "Fetal abdomen",
    "Fetal brain",
    "Fetal femur",
    "Fetal thorax",
]


def count_all_images(images_path: pathlib.Path, label_names: Optional[list] = None):
    """Count labelled frames per label across all videos.

    Expects the layout ``images_path / <video> / <label> / <frame files>``.

    Args:
        images_path: Directory holding one sub-directory per video.
        label_names: Labels that always appear in the result, with 0 when no
            video has a directory for them. Defaults to the notebook-level
            ``labels`` list.

    Returns:
        Dict mapping each label name to its total number of entries.

    Raises:
        FileNotFoundError: If ``images_path`` does not exist.
    """
    if label_names is None:
        label_names = labels
    count = {label: 0 for label in label_names}
    for video_dir in images_path.iterdir():
        # Stray files next to the video directories (e.g. ".DS_Store").
        if not video_dir.is_dir():
            continue
        for label_dir in video_dir.iterdir():
            if not label_dir.is_dir():
                continue
            # .get() keeps a directory with an unexpected name from raising
            # KeyError; it is reported under its own name instead.
            found = sum(1 for _ in label_dir.iterdir())
            count[label_dir.name] = count.get(label_dir.name, 0) + found
    return count


# Per-label totals over every labelled video, one line per label.
path = root.joinpath("data", "US_VIDEOS", "labeled")
images = count_all_images(path)
for key, item in images.items():
    print(f"{key}: {item}")

In [None]:
class USVideosDataset(Dataset):
    """Dataset over the frames stored under ``<data_dir>/US_VIDEOS/labeled``.

    Expects the layout ``<video>/<label>/<frame files>``. Each item is
    ``(image, label)`` where ``label`` is the name of the label directory the
    frame was found in.
    """

    def __init__(
        self,
        data_dir: str,
        max_images: int,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        seed: Optional[int] = None,
    ):
        """Collect at most ``max_images`` randomly chosen frames per label.

        Args:
            data_dir: Directory containing ``US_VIDEOS/labeled``.
            max_images: Upper bound on the number of frames kept per label.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each label.
            seed: If given, the per-label selection is drawn from a generator
                seeded with this value, so the same frames are picked on every
                run. If ``None``, NumPy's global random state is used.
        """
        data_dir = Path(data_dir) / "US_VIDEOS" / "labeled"
        images = self.find_images(data_dir)
        rng = None if seed is None else np.random.default_rng(seed)
        self.items = []
        for key, item in images.items():
            if rng is None:
                order = np.random.permutation(len(item))
            else:
                order = rng.permutation(len(item))
            self.items.extend((str(item[idx]), key) for idx in order[:max_images])

        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def find_images(images_path: pathlib.Path):
        """Group all frame paths below ``images_path`` by label directory name.

        Directories are walked in sorted order so the result (and therefore a
        seeded selection) does not depend on the file system's listing order.
        Entries that are not directories are ignored at the video and label
        levels.

        Returns:
            Dict mapping label name to a list of paths.
        """
        images = {}
        for video_dir in sorted(images_path.iterdir()):
            if not video_dir.is_dir():
                continue
            for label_dir in sorted(video_dir.iterdir()):
                if not label_dir.is_dir():
                    continue
                images.setdefault(label_dir.name, []).extend(sorted(label_dir.iterdir()))
        return images

    def __len__(self):
        """Return the number of selected frames."""
        return len(self.items)

    def __getitem__(self, idx):
        """Load the frame at ``idx`` and return ``(image, label)``.

        Raises:
            IndexError: If ``idx`` is past the end of the dataset.
        """
        if idx >= len(self):
            raise IndexError("list index out of range")

        img_path, label = self.items[idx]
        image = read_image(img_path)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


# Build the dataset with up to 1000 frames per label and report its size.
data_dir = root.joinpath("data")
dataset = USVideosDataset(str(data_dir), 1000)
print(len(dataset))