In [None]:
import csv
import itertools
import pathlib
import shutil
from math import ceil, sqrt
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import pyrootutils
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from skimage.metrics import structural_similarity
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.utils.data import ConcatDataset, Dataset
from torchvision.io import read_image
from tqdm.notebook import tqdm

root = pyrootutils.setup_root(search_from=".", indicator=".project-root", pythonpath=True)

from src.data.components.dataset import FetalBrainPlanesDataset, USVideosDataset
from src.data.components.transforms import Affine, HorizontalFlip
from src.data.utils.utils import show_numpy_images, show_pytorch_images
from src.models.fetal_module import FetalLitModule

# Folder holding the US_VIDEOS_2 data, resolved relative to the project root.
path = root.joinpath("data", "US_VIDEOS_2")
# Echo the project root as the cell output.
root

In [None]:
# Remove the `selected` and `labeled` sub-folders of the dataset folder.
# Errors (for example a folder that does not exist) are ignored.
shutil.rmtree(path.joinpath("selected"), ignore_errors=True)
shutil.rmtree(path.joinpath("labeled"), ignore_errors=True)

In [None]:
# Training run to load (the other run is kept commented out for quick switching).
# checkpoint_file = root / "logs" / "train" / "runs" / "2023-02-21_18-51-47"
checkpoint_file = root.joinpath("logs", "train", "runs", "2023-02-23_10-41-08")

# Pick the last `epoch_*.ckpt` file in sorted path order.
checkpoint = sorted(checkpoint_file.glob("checkpoints/epoch_*.ckpt"))[-1]
model = FetalLitModule.load_from_checkpoint(str(checkpoint))
# Evaluation mode: disable randomness, dropout, etc.
model.eval()

# Display the name stored in the checkpoint's `net_spec` hyper-parameter.
model.hparams.net_spec.name

In [None]:
def label_video(video_path: Path):
    """Run the model on every frame of a video under all test-time transforms.

    Each decoded frame is converted to a tensor, every entry of the module-level
    ``transforms`` list is applied to it, and the stacked variants are passed
    through the module-level ``model`` as one batch.

    Args:
        video_path: Path of the video file to open with OpenCV.

    Returns:
        A tuple ``(dense_logits, logits, y_hats)``. ``dense_logits`` stacks, per
        frame, the dense output of the first batch entry (the first transform)
        along dim 0. ``logits`` and ``y_hats`` stack the per-frame classifier
        logits and their softmax along dim 1, so dim 0 indexes the transform and
        dim 1 the frame.
    """
    y_hats = []
    all_logits = []
    base_dense_logits = []

    vidcap = cv2.VideoCapture(str(video_path))
    try:
        for frame in _frame_iter(vidcap, "Label frames"):
            # NOTE(review): OpenCV decodes frames as BGR while PIL assumes RGB; this
            # only matters if the videos are not already grey-scale — confirm.
            frame = PIL.Image.fromarray(frame)
            frame = TF.to_tensor(frame)
            frame = frame.to(model.device)
            frames = torch.stack([transform(frame) for transform in transforms])

            with torch.no_grad():
                dense_logits, logits = model(frames)
                y_hat = F.softmax(logits, dim=1)
                y_hats.append(y_hat)
                all_logits.append(logits)
                base_dense_logits.append(dense_logits[0])
    finally:
        # Always free the decoder handle, even when inference fails part-way through.
        vidcap.release()

    return torch.stack(base_dense_logits, dim=0), torch.stack(all_logits, dim=1), torch.stack(y_hats, dim=1)


def _frame_iter(capture, description):
    """Wrap an opened video capture in a tqdm progress bar that yields its frames.

    Args:
        capture: Capture object providing ``grab``, ``retrieve`` and ``get``
            (a ``cv2.VideoCapture`` in this notebook).
        description: Text shown next to the progress bar.

    Returns:
        A tqdm iterable that yields decoded frames in order until the stream
        ends or a frame can no longer be decoded.
    """

    def iterator():
        while capture.grab():
            ok, frame = capture.retrieve()
            if not ok:
                # Decoding failed after a successful grab: stop rather than yield an invalid frame.
                break
            yield frame

    # The reported frame count can be zero or negative; pass no total to tqdm in that case.
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    return tqdm(
        iterator(),
        desc=description,
        total=frame_count if frame_count > 0 else None,
        position=0,
        leave=False,
    )


# Inference device; swap in the commented line to use a GPU when one is available.
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
model.to(device)

# Augmentation grid: every combination of these settings becomes one transform.
horizontal_flips = [False, True]
rotate_degrees = [0, -15, 15]
translates = [(0.0, 0.0)]
scales = [1.0]

# Nested `for` clauses enumerate the grid in the same order as a cartesian product
# (the right-most setting varies fastest).
transforms = [
    T.Compose(
        [
            T.Grayscale(),
            T.Resize((165, 240)),
            HorizontalFlip(flip=flip),
            Affine(degrees=degrees, translate=shift, scale=zoom),
            T.ConvertImageDtype(torch.float32),
        ]
    )
    for flip in horizontal_flips
    for degrees in rotate_degrees
    for shift in translates
    for zoom in scales
]

# Label the first video in sorted path order; the dense logits are discarded here.
video_path = sorted(path.joinpath("videos").iterdir())[0]
_, logits, y_hats = label_video(video_path)

In [None]:
def plot_probabilities(y_hats, ax):
    """Draw one line per label showing its column of ``y_hats`` over the row index.

    Args:
        y_hats: 2-D array-like indexed as ``[row, label]``.
        ax: Axes-like object receiving the lines and the legend.
    """
    frame_indices = list(range(y_hats.shape[0]))

    # Labels come from the module-level `label_def`; column k belongs to label k.
    for column, name in enumerate(label_def):
        ax.plot(frame_indices, y_hats[:, column], label=name)

    ax.legend()


def plot_best_probabilities(y_hats, ax):
    """Scatter, per label, only the rows where that label has the highest value.

    Args:
        y_hats: 2-D tensor indexed as ``[row, label]``.
        ax: Axes-like object receiving the markers and the legend.
    """
    num_labels = y_hats.shape[1]
    # One (initially empty) list of x and y values per label.
    x = [[] for _ in range(num_labels)]
    y = [[] for _ in range(num_labels)]

    for row_index, row in enumerate(y_hats):
        winner = torch.argmax(row)
        x[winner].append(row_index)
        y[winner].append(row[winner])

    # Labels come from the module-level `label_def`.
    for column, name in enumerate(label_def):
        ax.plot(x[column], y[column], "o", markersize=2, label=name)

    ax.legend()


def is_stable(y, i):
    """Tell whether position ``i`` sits inside a full window of non-zero values.

    The half-width of the window is the module-level ``window``.

    Args:
        y: Sequence of values; an entry equal to ``0.0`` counts as missing.
        i: Position to test.

    Returns:
        ``False`` when the window ``[i - window, i + window]`` does not fit
        inside ``y`` or contains a ``0.0`` entry, ``True`` otherwise.
    """
    first = i - window
    last = i + window

    # A window cut off by either end of the sequence is never stable.
    if first < 0 or last > len(y) - 1:
        return False

    return all(not (y[j] == 0.0) for j in range(first, last + 1))


def plot_filtered_probabilities(y_hats, ax):
    """Scatter, per label, the winning values that pass the ``is_stable`` check.

    For every row the highest value is kept under its label and all other
    labels stay at ``0.0``. Positions rejected by ``is_stable`` are then
    dropped before plotting.

    Args:
        y_hats: 2-D tensor indexed as ``[row, label]``.
        ax: Axes-like object receiving the markers and the legend.
    """
    num_rows, num_labels = y_hats.shape[0], y_hats.shape[1]
    x = [list(range(num_rows)) for _ in range(num_labels)]
    y = [[0.0] * num_rows for _ in range(num_labels)]

    for row_index, row in enumerate(y_hats):
        winner = torch.argmax(row)
        y[winner][row_index] = row[winner]

    for column in range(num_labels):
        scores = y[column]
        # Stability is judged on the unfiltered series, then the survivors are selected.
        kept = [j for j in range(len(scores)) if is_stable(scores, j)]
        x[column] = [x[column][j] for j in kept]
        y[column] = [scores[j] for j in kept]

    # Labels come from the module-level `label_def`.
    for column, name in enumerate(label_def):
        ax.plot(x[column], y[column], "o", markersize=2, label=name)

    ax.legend()


# Labels used for the legends of all plots in this notebook.
label_def = FetalBrainPlanesDataset.labels

# Half-width of the window read by `is_stable`.
window = 3
# The logits are divided by this temperature before the softmax.
temperature = 2.0
y_hats = torch.softmax(logits / temperature, dim=2)

# One row of panels per entry along dim 0 of `y_hats`, three views per row.
ncols = 3
nrows = y_hats.shape[0]
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(ncols * 10, nrows * 5),
    sharex=True,
    sharey=True,
    squeeze=False,
    tight_layout=True,
)

for i, y_hat in enumerate(y_hats):
    plot_probabilities(y_hat, axes[i][0])
    plot_best_probabilities(y_hat, axes[i][1])
    plot_filtered_probabilities(y_hat, axes[i][2])

In [None]:
# Display the dimensions of the temperature-scaled probabilities.
y_hats.size()

In [None]:
# Small hand-checkable tensor of shape (2, 5, 3) for trying out the masking logic.
y_hats_ = torch.tensor(
    [
        [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0], [1.0, 3.0, 2.0], [1.0, 3.0, 2.0]],
        [[1.0, 4.0, 2.0], [2.0, 4.0, 3.0], [3.0, 4.0, 5.0], [1.0, 3.0, 2.0], [1.0, 3.0, 2.0]],
    ]
)

# Index of the largest value along the last dimension.
pred = y_hats_.argmax(dim=2)
print(pred)

# Keep only the largest value of each triple; every other entry becomes 0.
mask = torch.logical_not(F.one_hot(pred))
y_hats_ = y_hats_.masked_fill(mask, 0.0)
print(y_hats_)

In [None]:
# Zero the kept value of every position whose neighbours (within `window` steps,
# clipped at both ends of the row) do not all share its predicted index.
window = 1
for i in range(pred.size(0)):
    for j in range(pred.size(1)):
        min_j = max(j - window, 0)
        max_j = min(pred.size(1), j + window + 1)

        if torch.any(pred[i, min_j:max_j] != pred[i, j]):
            y_hats_[i, j, pred[i, j]] = 0

y_hats_

In [None]:
# Alternative stability check: count the neighbours within `window` steps that share
# position j's predicted index, and zero the kept value unless the whole window agrees.
# NOTE(review): the slice bounds are not clamped. For j < window the start index is
# negative and is interpreted relative to the end of the row; near the end of the row
# the slice is cut short. In both cases fewer than 1 + 2 * window matches can be counted,
# so border positions are always zeroed here (unlike the clamped variant of this loop).
window = 1
for i in range(pred.shape[0]):
    for j in range(pred.shape[1]):
        if torch.sum(pred[i, j - window : j + window + 1] == pred[i, j]) < 1 + 2 * window:
            y_hats_[i, j, pred[i, j]] = 0

y_hats_

In [None]:
# Display the dimensions of the masked tensor.
y_hats_.size()

In [None]:
# Average the masked values over dim 0.
y_hats__ = y_hats_.mean(dim=0)
print(y_hats__.shape)
print(y_hats__)

# Split the columns: the first two versus the remaining ones.
plates = y_hats__[:, :2]
other = y_hats__[:, 2:]

quality = plates.sum(dim=1) - other.sum(dim=1)
print(quality)

# Normalise by the number of positive entries among the first two columns, then
# show the result with every value that is not strictly positive replaced by 0.
quality = quality / (plates > 0).sum(dim=1)
zaro_mask = ~(quality > 0)
quality.masked_fill(zaro_mask, 0.0)

In [None]:
# The logits are divided by this temperature before the softmax.
temperature = 3.0
y_hats_ = torch.softmax(logits / temperature, dim=2)

# print(y_hats_.shape)
# y_hats_ = y_hats.clone()

# Keep only the highest probability of every (transform, frame) pair.
pred = y_hats_.argmax(dim=2)
print(pred.shape)
pred_mask = torch.logical_not(F.one_hot(pred, num_classes=y_hats_.shape[2]))
print(pred_mask.shape)
y_hats_.masked_fill_(pred_mask, 0.0)

# Drop predictions whose neighbourhood (clipped at the ends) is not unanimous.
window = 3
for i in range(pred.size(0)):
    for j in range(pred.size(1)):
        min_j = max(j - window, 0)
        max_j = min(pred.size(1), j + window + 1)

        if torch.any(torch.ne(pred[i, min_j:max_j], pred[i, j])):
            y_hats_[i, j, pred[i, j]] = 0

# Average over dim 0, then score: first three columns minus the rest, divided by
# the number of positive entries among the first three columns.
y_hats_ = y_hats_.mean(dim=0)
plates = y_hats_[:, :3]
other = y_hats_[:, 3:]
quality = plates.sum(dim=1) - other.sum(dim=1)
quality = quality / (plates > 0).sum(dim=1)
# Everything that is not strictly positive is set to 0.
zaro_mask = ~(quality > 0)
quality.masked_fill_(zaro_mask, 0.0)

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 5), squeeze=False, tight_layout=True)
axes[0][0].plot(range(len(quality)), quality)
axes[0][0].set_xlim(left=0, right=len(quality))
axes[0][0].set_ylim(bottom=0, top=1)
fig.show()

In [None]:
# Training run to load (the other run is kept commented out for quick switching).
checkpoint_file = root.joinpath("logs", "train", "runs", "2023-02-21_18-51-47")
# checkpoint_file = root / "logs" / "train" / "runs" / "2023-02-22_12-35-41"

# Pick the last `epoch_*.ckpt` file in sorted path order.
checkpoint = sorted(checkpoint_file.glob("checkpoints/epoch_*.ckpt"))[-1]
model = FetalLitModule.load_from_checkpoint(str(checkpoint))
# Evaluation mode: disable randomness, dropout, etc.
model.eval()

# Display the name stored in the checkpoint's `net_spec` hyper-parameter.
model.hparams.net_spec.name

In [None]:
def create_dataset(path: Path):
    """Label every video under ``path / "videos"`` and store per-video data and plots.

    The ``data`` and ``plots`` sub-folders of ``path`` are deleted and recreated
    empty, so each run starts from a clean output tree. For every video the
    frames are labelled, a quality score is computed, and one CSV file plus one
    plot are written.

    Args:
        path: Dataset folder containing a ``videos`` sub-folder.
    """
    videos_path = path / "videos"
    data_path = path / "data"
    plots_path = path / "plots"

    # Rebuild the output folders: they must exist before the savers write into them.
    for output_dir in (data_path, plots_path):
        shutil.rmtree(output_dir, ignore_errors=True)
        output_dir.mkdir(parents=True, exist_ok=True)

    videos = list(videos_path.iterdir())
    for video_path in tqdm(videos, desc="Label videos", position=1):
        dense_logits, y_hats = label_video(video_path)
        y_hats, quality = calculate_quality(y_hats)
        save_processed_video(data_path, video_path.stem, dense_logits, quality)
        save_quality_plot(plots_path, video_path.stem, y_hats, quality)


def label_video(video_path: Path):
    """Run the model on every frame of a video under all test-time transforms.

    Each decoded frame is converted to a tensor, every entry of the module-level
    ``transforms`` list is applied to it, and the stacked variants are passed
    through the module-level ``model`` as one batch.

    Args:
        video_path: Path of the video file to open with OpenCV.

    Returns:
        A tuple ``(dense_logits, y_hats)``. ``dense_logits`` stacks, per frame,
        the dense output of the first batch entry (the first transform) along
        dim 0. ``y_hats`` stacks the per-frame softmax of the classifier logits
        along dim 1, so dim 0 indexes the transform and dim 1 the frame.
    """
    y_hats = []
    base_dense_logits = []

    vidcap = cv2.VideoCapture(str(video_path))
    try:
        for frame in frame_iter(vidcap, "Label frames"):
            # NOTE(review): OpenCV decodes frames as BGR while PIL assumes RGB; this
            # only matters if the videos are not already grey-scale — confirm.
            frame = PIL.Image.fromarray(frame)
            frame = TF.to_tensor(frame)
            frame = frame.to(model.device)
            frames = torch.stack([transform(frame) for transform in transforms])

            with torch.no_grad():
                dense_logits, logits = model(frames)
                y_hat = F.softmax(logits, dim=1)
                y_hats.append(y_hat)
                base_dense_logits.append(dense_logits[0])
    finally:
        # Always free the decoder handle, even when inference fails part-way through.
        vidcap.release()

    return torch.stack(base_dense_logits, dim=0), torch.stack(y_hats, dim=1)


def frame_iter(capture, description):
    """Wrap an opened video capture in a tqdm progress bar that yields its frames.

    Args:
        capture: Capture object providing ``grab``, ``retrieve`` and ``get``
            (a ``cv2.VideoCapture`` in this notebook).
        description: Text shown next to the progress bar.

    Returns:
        A tqdm iterable that yields decoded frames in order until the stream
        ends or a frame can no longer be decoded.
    """

    def iterator():
        while capture.grab():
            ok, frame = capture.retrieve()
            if not ok:
                # Decoding failed after a successful grab: stop rather than yield an invalid frame.
                break
            yield frame

    # The reported frame count can be zero or negative; pass no total to tqdm in that case.
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    return tqdm(
        iterator(),
        desc=description,
        total=frame_count if frame_count > 0 else None,
        position=0,
        leave=False,
    )


def calculate_quality(y_hats: Tensor, window: int = 3, num_planes: int = 3):
    """Turn per-transform class probabilities into a per-frame quality score.

    Steps:
      1. Keep only the highest probability of every (transform, frame) pair.
      2. Zero that probability when the predictions inside the surrounding
         window (clipped at both ends of the video) are not unanimous.
      3. Average over the transforms (dim 0).
      4. Score each frame as (sum over plane classes - sum over the other
         classes) / (number of plane classes with a positive average); every
         score that is not strictly positive becomes 0.

    The input tensor is not modified.

    Args:
        y_hats: Probabilities indexed as ``[transform, frame, class]``.
        window: Number of neighbouring frames on each side that must share a
            frame's prediction for it to be kept.
        num_planes: Number of leading class columns treated as plane classes.

    Returns:
        A tuple ``(averaged, quality)``: the filtered probabilities averaged
        over dim 0 (``[frame, class]``) and the per-frame quality score.
    """
    # Work on a copy so the caller's tensor is left untouched.
    y_hats = y_hats.clone()

    # select highest prediction
    pred = torch.argmax(y_hats, dim=2)
    pred_mask = F.one_hot(pred, num_classes=y_hats.shape[2]) == 0
    y_hats.masked_fill_(pred_mask, 0.0)

    # remove predictions that are inconsistent
    num_frames = pred.shape[1]
    for i in range(pred.shape[0]):
        for j in range(num_frames):
            min_j = max(0, j - window)
            max_j = min(j + window + 1, num_frames)

            if not torch.all(torch.eq(pred[i, min_j:max_j], pred[i, j])):
                y_hats[i, j, pred[i, j]] = 0

    # average of all transformations
    y_hats = torch.mean(y_hats, dim=0)

    planes = y_hats[:, :num_planes]
    other = y_hats[:, num_planes:]
    quality = torch.sum(planes, dim=1) - torch.sum(other, dim=1)

    # A frame without any positive plane average would divide by zero; its numerator is
    # never positive, so dividing by 1 instead still ends up masked to 0 below.
    active_planes = torch.sum(planes > 0, dim=1).clamp(min=1)
    quality = quality / active_planes
    quality = quality.masked_fill(~(quality > 0), 0.0)

    return y_hats, quality


def save_processed_video(data_path: Path, video: str, dense_logits: Tensor, quality: Tensor):
    """Write one CSV row per frame: the frame's dense logits followed by its quality.

    Values are written as plain numbers (not tensor representations) so the
    file can be read back with any CSV reader. The output folder is created
    when it does not exist.

    Args:
        data_path: Folder receiving the CSV file.
        video: File name (without extension) of the CSV to write.
        dense_logits: Per-frame dense logits, indexed by frame along dim 0.
        quality: Per-frame quality scores, indexed by frame.
    """
    out_dir = Path(data_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    # newline="" lets the csv module control line endings itself.
    with open(out_dir / f"{video}.csv", "w", newline="") as f:
        writer = csv.writer(f)
        for features, score in zip(dense_logits, quality):
            writer.writerow([*features.flatten().tolist(), score.item()])


def save_quality_plot(plots_path: Path, video: str, y_hats: Tensor, quality: Tensor):
    """Save a three-panel figure with the kept probabilities and the quality curve.

    Panel 1 scatters the positive values of every label column of ``y_hats``,
    panel 2 scatters the quality score and panel 3 draws it as a line. The
    figure is written as ``<video>.png`` into ``plots_path`` (created when
    missing) and closed afterwards.

    Args:
        plots_path: Folder receiving the image.
        video: File name (without extension) of the image to write.
        y_hats: Values indexed as ``[frame, label]``.
        quality: Per-frame quality scores.
    """
    out_dir = Path(plots_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    fig, axes = plt.subplots(ncols=1, nrows=3, sharex=True, sharey=True, tight_layout=True, figsize=(10, 15))
    try:
        # Labels come from the module-level `label_def`.
        for i, label in enumerate(label_def):
            x, y = extract_nonzero_values(y_hats[:, i])
            axes[0].plot(x, y, "o", markersize=2, label=label)
        axes[0].legend()

        frame_indices = range(len(quality))
        axes[1].plot(frame_indices, quality, "o", color="tab:gray")
        axes[2].plot(frame_indices, quality, color="tab:gray")

        # The axes are shared, so limits set on one panel apply to all three.
        axes[0].set_xlim(left=0, right=len(quality))
        axes[0].set_ylim(bottom=0, top=1)

        fig.savefig(out_dir / f"{video}.png")
    finally:
        # Called once per video: close the figure so they do not pile up in memory.
        plt.close(fig)


def extract_nonzero_values(y_hats):
    """Collect the positions and values of the strictly positive entries.

    Args:
        y_hats: Iterable of comparable values.

    Returns:
        A tuple ``(x, y)`` of two lists: the indices whose value is greater
        than zero and the values themselves, in their original order.
    """
    positive = [(index, value) for index, value in enumerate(y_hats) if value > 0]
    return [index for index, _ in positive], [value for _, value in positive]


# Inference device; swap in the commented line to use a GPU when one is available.
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
model.to(device)

# Augmentation grid: every combination of these settings becomes one transform.
horizontal_flips = [False, True]
rotate_degrees = [0, -15, 15]
translates = [(0.0, 0.0)]
scales = [1.0]

transforms = [
    T.Compose(
        [
            T.Grayscale(),
            T.Resize((165, 240)),
            HorizontalFlip(flip=horizontal_flip),
            Affine(degrees=rotate_degree, translate=translate, scale=scale),
            T.ConvertImageDtype(torch.float32),
        ]
    )
    for horizontal_flip, rotate_degree, translate, scale in itertools.product(
        horizontal_flips, rotate_degrees, translates, scales
    )
]

# `create_dataset` writes its results to disk and returns nothing, so its result is not
# bound to a name (binding it to `y_hats` would replace the earlier tensor with None).
create_dataset(path)