In [None]:
import csv
import itertools
import pathlib
import shutil
from math import ceil, sqrt
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import pyrootutils
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from skimage.metrics import structural_similarity
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.utils.data import ConcatDataset, Dataset
from torchvision.io import read_image
from tqdm.notebook import tqdm

# Locate the repository root (marked by `.project-root`); pythonpath=True makes `src.*` importable below.
root = pyrootutils.setup_root(
    search_from=".",
    indicator=".project-root",
    pythonpath=True,
)

from src.data.components.dataset import FetalBrainPlanesDataset, USVideosDataset
from src.data.components.transforms import Affine, HorizontalFlip
from src.data.utils.utils import show_numpy_images, show_pytorch_images
from src.models.fetal_module import FetalLitModule

# Root folder of the ultrasound video data used throughout the notebook.
path = root.joinpath("data", "US_VIDEOS")
root

In [None]:
# Sandbox for an asymmetric squared error: residuals where x exceeds y are weighted 2x before squaring.
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([2.0, 1.0, 1.0])

loss = x.sub(y)
print(loss)

weight = torch.ones_like(loss).masked_fill(loss > 0, 2)
print(weight)
loss = loss.mul(weight)
print(loss)

loss = loss.square()
print(loss)

loss.mean()

In [None]:
# Remove outputs of earlier selection / labelling runs; missing directories are ignored.
for stale_dir in ("selected", "labeled"):
    shutil.rmtree(path / stale_dir, ignore_errors=True)

In [None]:
# Training run whose latest epoch checkpoint is evaluated below (earlier runs kept for reference).
# checkpoint_file = root / "logs" / "train" / "runs" / "2023-02-21_18-51-47"
# checkpoint_file = root / "logs" / "train" / "runs" / "2023-02-24_18-37-02"
checkpoint_file = root / "logs" / "train" / "runs" / "2023-02-25_19-43-37"

# Sorting is lexicographic, so "latest" assumes zero-padded epoch numbers (epoch_003, ...) — TODO confirm.
checkpoints = sorted(checkpoint_file.glob("checkpoints/epoch_*.ckpt"))
if not checkpoints:
    raise FileNotFoundError(f"No epoch_*.ckpt checkpoints found in {checkpoint_file / 'checkpoints'}")
checkpoint = checkpoints[-1]
model = FetalLitModule.load_from_checkpoint(str(checkpoint))
# disable randomness, dropout, etc...
model.eval()

model.hparams.net_spec.name

In [None]:
def label_video(video_path: Path):
    """Run the model on every frame of a video under each pipeline in the global ``transforms``.

    Relies on the notebook-level ``model`` and ``transforms`` (the latter is built by
    ``lable_transform_video``). Each frame is turned into a batch with one entry per transform.

    Returns:
        ``(dense_logits, logits, y_hats)``, where each is the per-frame model output stacked
        along dim 1. Assuming the model keeps the batch (transform) dimension first, dim 0
        indexes the transform and dim 1 the frame. ``y_hats`` is ``softmax(logits)`` over dim 1
        of each per-frame output.
    """
    y_hats = []
    all_logits = []
    all_dense_logits = []

    vidcap = cv2.VideoCapture(str(video_path))
    try:
        with torch.no_grad():
            for frame in _frame_iter(vidcap, "Label frames"):
                # OpenCV decodes frames in BGR order; PIL / torchvision transforms assume RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = TF.to_tensor(PIL.Image.fromarray(frame)).to(model.device)
                frames = torch.stack([transform(frame) for transform in transforms])

                dense_logits, logits = model(frames)
                y_hats.append(F.softmax(logits, dim=1))
                all_logits.append(logits)
                all_dense_logits.append(dense_logits)
    finally:
        # Free the decoder handle even if inference fails part-way through the video.
        vidcap.release()

    return torch.stack(all_dense_logits, dim=1), torch.stack(all_logits, dim=1), torch.stack(y_hats, dim=1)


def _frame_iter(capture, description):
    """Wrap ``capture`` in a tqdm progress bar that yields decoded frames until the stream ends."""

    def _frames():
        # grab() advances the stream; retrieve() decodes the grabbed frame as (ok, image).
        while capture.grab():
            _, image = capture.retrieve()
            yield image

    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    return tqdm(_frames(), desc=description, total=frame_count, position=0, leave=False)


def lable_transform_video(video_path, horizontal_flips, rotate_degrees, translates, scales):
    """Label ``video_path`` under every combination of the given augmentation parameters.

    Rebuilds the notebook-level ``transforms`` list, with one pipeline per element of
    flips x rotations x translations x scales (in that nesting order). It then delegates to
    ``label_video``, which reads that global.
    """
    global transforms

    def _pipeline(flip, degrees, shift, zoom):
        return T.Compose(
            [
                T.Grayscale(),
                T.Resize((165, 240)),
                HorizontalFlip(flip=flip),
                Affine(degrees=degrees, translate=shift, scale=zoom),
                T.ConvertImageDtype(torch.float32),
            ]
        )

    param_grid = itertools.product(horizontal_flips, rotate_degrees, translates, scales)
    transforms = [_pipeline(*params) for params in param_grid]

    return label_video(video_path)


# Run this exploratory pass on the CPU.
device = "cpu"
model.to(device)

test_videos = sorted((path / "videos" / "test").iterdir())
video_path = test_videos[0]

# Augmentation grid: 2 flips x 3 rotations x 5 translations x 2 scales = 60 pipelines.
_, logits, y_hats = lable_transform_video(
    video_path=video_path,
    horizontal_flips=[False, True],
    rotate_degrees=[0, -15, 15],
    translates=[(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)],
    scales=[1.0, 1.2],
)
y_hats.shape

In [None]:
def plot_probabilities(y_hats, ax):
    """Draw one line per label in ``label_def``: class probability against frame index."""
    frame_axis = list(range(len(y_hats)))
    for class_idx, label in enumerate(label_def):
        series = y_hats[:, class_idx]
        ax.plot(frame_axis, series, label=label)
    ax.legend()


def plot_best_probabilities(y_hats, ax):
    """Scatter, per class, the frames where that class wins the argmax, at its winning probability."""
    n_classes = y_hats.shape[1]
    frames_by_class = [[] for _ in range(n_classes)]
    probs_by_class = [[] for _ in range(n_classes)]

    for frame_idx, frame_probs in enumerate(y_hats):
        winner = int(torch.argmax(frame_probs))
        frames_by_class[winner].append(frame_idx)
        probs_by_class[winner].append(frame_probs[winner])

    for class_idx, label in enumerate(label_def):
        ax.plot(frames_by_class[class_idx], probs_by_class[class_idx], "o", markersize=2, label=label)

    ax.legend()


def is_stable(y, i):
    """Return True when the full window ``y[i - window : i + window + 1]`` fits inside ``y``
    and none of its values equals 0.0.

    Frames closer than ``window`` (notebook-level global) to either end are always unstable.
    """
    lo, hi = i - window, i + window
    if lo < 0 or hi > len(y) - 1:
        return False
    return all(y[k] != 0.0 for k in range(lo, hi + 1))


def plot_filtered_probabilities(y_hats, ax):
    """Like ``plot_best_probabilities``, but only keep frames that ``is_stable`` accepts.

    For each class, a dense per-frame series is built: the winning probability where that class
    is the argmax, and 0.0 elsewhere. Frames that fail ``is_stable`` on that series are dropped
    before plotting.
    """
    n_frames, n_classes = y_hats.shape[0], y_hats.shape[1]
    series = [[0.0] * n_frames for _ in range(n_classes)]

    for frame_idx, frame_probs in enumerate(y_hats):
        winner = int(torch.argmax(frame_probs))
        series[winner][frame_idx] = frame_probs[winner]

    kept_frames = []
    kept_values = []
    for values in series:
        # stability is judged on the unfiltered series, before anything is dropped
        stable_idx = [j for j in range(len(values)) if is_stable(values, j)]
        kept_frames.append(stable_idx)
        kept_values.append([values[j] for j in stable_idx])

    for class_idx, label in enumerate(label_def):
        ax.plot(kept_frames[class_idx], kept_values[class_idx], "o", markersize=2, label=label)

    ax.legend()


def plot_base_probabilities(y_hats):
    """Grid of diagnostics with one row per entry of ``y_hats``.

    Columns: raw probabilities, argmax winners only, and stability-filtered winners. All
    panels share axes; the x range spans ``y_hats.shape[1]`` frames and the y range is [0, 1].
    """
    panel_fns = (plot_probabilities, plot_best_probabilities, plot_filtered_probabilities)
    n_cols = len(panel_fns)
    n_rows = len(y_hats)
    fig, axes = plt.subplots(
        ncols=n_cols,
        nrows=n_rows,
        sharex=True,
        sharey=True,
        squeeze=False,
        tight_layout=True,
        figsize=(10 * n_cols, 5 * n_rows),
    )

    for row in range(n_rows):
        for col, panel_fn in enumerate(panel_fns):
            panel_fn(y_hats[row], axes[row, col])

    for row_axes in axes:
        first_ax = row_axes[0]
        first_ax.set_xlim(left=0, right=y_hats.shape[1])
        first_ax.set_ylim(bottom=0, top=1)


label_def = FetalBrainPlanesDataset.labels

# half-width used by is_stable, and the softmax temperature
window = 3
temperature = 1.0
y_hats_ = (logits / temperature).softmax(dim=2)
plot_base_probabilities(y_hats_[:1].cpu())

In [None]:
# Toy check of the top-1 masking used by calculate_quality, on scores laid out like its input
# (dim 0 = transform, dim 1 = frame, dim 2 = class).
y_hats_ = torch.tensor(
    [
        [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0], [1.0, 3.0, 2.0], [1.0, 3.0, 2.0]],
        [[1.0, 4.0, 2.0], [2.0, 4.0, 3.0], [3.0, 4.0, 6.0], [1.0, 4.0, 2.0], [1.0, 5.0, 2.0]],
    ]
)
print(y_hats_.shape)

pred = torch.argmax(y_hats_, dim=2)
print(pred.shape)
print(pred)

# Pass num_classes explicitly: one_hot otherwise infers pred.max() + 1 classes, and the
# broadcast fails whenever the last class never wins the argmax.
y_hats_ = y_hats_ * F.one_hot(pred, num_classes=y_hats_.shape[2])

print(y_hats_)

In [None]:
# Prototype of the temporal consistency filter (window clipped at the sequence edges).
# Zeroes y_hats_ in place wherever a neighbour within `window` frames predicts another class.
window = 1
n_rows, n_cols = pred.shape
for i in range(n_rows):
    for j in range(n_cols):
        neighbours = pred[i, max(0, j - window) : min(j + window + 1, n_cols)]
        if (neighbours != pred[i, j]).any():
            y_hats_[i, j, pred[i, j]] = 0

y_hats_

In [None]:
# Variant that also treats frames closer than `window` to either end as unstable.
# The bounds check is explicit instead of relying on negative slice indices wrapping around.
window = 1
span = 2 * window + 1
for i in range(pred.shape[0]):
    for j in range(pred.shape[1]):
        lo, hi = j - window, j + window + 1
        window_fits = lo >= 0 and hi <= pred.shape[1]
        if not window_fits or torch.sum(pred[i, lo:hi] == pred[i, j]) < span:
            y_hats_[i, j, pred[i, j]] = 0

y_hats_

In [None]:
y_hats_.size()

In [None]:
# Average over transforms, then score each frame as
# (plane mass - other mass) / (number of non-zero plane entries).
# With the 3-class toy tensor, the first two classes are treated as planes here.
y_hats__ = y_hats_.mean(dim=0)
print(y_hats__.shape)
print(y_hats__)

plates = y_hats__[:, :2]
other = y_hats__[:, 2:]

quality = plates.sum(dim=1) - other.sum(dim=1)
print(quality)

quality = quality / (plates > 0).sum(dim=1)
zaro_mask = ~(quality > 0)  # also true for NaN produced by 0/0
quality.masked_fill(zaro_mask, 0.0)

In [None]:
def calculate_quality(y_hats: Tensor, window_size: Optional[int] = None):
    """Turn per-transform class probabilities into a per-frame quality score.

    Steps:
      1. Keep only each (transform, frame)'s top-1 probability.
      2. Zero that probability unless every neighbouring frame within ``window_size`` frames
         (window clipped at the sequence edges) has the same top-1 class.
      3. Average over transforms.
      4. Score each frame as
         (sum of plane probabilities - sum of other probabilities) / (#plane entries > 0),
         treating the first three classes as planes. Non-positive scores become 0.

    Args:
        y_hats: probabilities indexed ``[transform, frame, class]`` (e.g. softmaxed logits from
            ``lable_transform_video``). The input tensor is not modified.
        window_size: half-width of the consistency window. ``None`` (the default) falls back to
            the notebook-level ``window`` variable, matching the previous behaviour.

    Returns:
        ``(y_hats, quality)``: the filtered probabilities averaged over transforms, indexed
        ``[frame, class]``, and the per-frame quality, indexed ``[frame]``.
    """
    half_window = window if window_size is None else window_size

    # select highest prediction
    pred = torch.argmax(y_hats, dim=2)
    y_hats = y_hats * F.one_hot(pred, num_classes=y_hats.shape[2])

    # Remove inconsistent predictions, vectorised over transforms and frames: a frame stays
    # only if every neighbour within the edge-clipped window shares its top-1 class.
    n_frames = pred.shape[1]
    stable = torch.ones_like(pred, dtype=torch.bool)
    for offset in range(1, min(half_window, n_frames - 1) + 1):
        agree = pred[:, offset:] == pred[:, :-offset]
        stable[:, :-offset] &= agree  # frame j vs frame j + offset
        stable[:, offset:] &= agree  # frame j vs frame j - offset
    y_hats = y_hats.masked_fill(~stable.unsqueeze(-1), 0.0)

    # average of all transformations
    y_hats = torch.mean(y_hats, dim=0)

    planes = y_hats[:, :3]
    other = y_hats[:, 3:]
    quality = torch.sum(planes, dim=1) - torch.sum(other, dim=1)
    # Clamp the count so frames without any plane prediction give a finite (<= 0) score, not 0/0.
    quality = quality / torch.sum(planes > 0, dim=1).clamp(min=1)
    quality = quality.masked_fill(~(quality > 0), 0.0)

    return y_hats, quality


def plot_quality(y_hats: Tensor, quality: Tensor):
    """Three stacked panels: non-zero per-class probabilities, quality as points, quality as a line.

    ``y_hats`` is indexed ``[frame, class]`` with one column per label in ``label_def``.
    Returns the created figure.
    """
    fig, (class_ax, dots_ax, line_ax) = plt.subplots(ncols=1, nrows=3, tight_layout=True, figsize=(10, 15))

    for class_idx, label in enumerate(label_def):
        frames, probs = extract_nonzero_values(y_hats[:, class_idx])
        class_ax.plot(frames, probs, "o", markersize=2, label=label)
        class_ax.legend()

    frame_idx = range(len(quality))
    dots_ax.plot(frame_idx, quality, "o", markersize=2, color="tab:gray")
    line_ax.plot(frame_idx, quality, color="tab:gray")

    for ax in (class_ax, dots_ax, line_ax):
        ax.set_xlim(left=0, right=len(quality))
        ax.set_ylim(bottom=0, top=1)

    return fig


def extract_nonzero_values(y_hats):
    """Return ``(indices, values)`` for the strictly positive entries of a 1-D sequence, in order."""
    positive = [(idx, value) for idx, value in enumerate(y_hats) if value > 0]
    indices = [idx for idx, _ in positive]
    values = [value for _, value in positive]
    return indices, values


window = 3
temperature = 1.0
label_def = FetalBrainPlanesDataset.labels

y_hats_ = (logits / temperature).softmax(dim=2)
y_hats_, quality = calculate_quality(y_hats_)
plot_quality(y_hats_.cpu(), quality.cpu()).show()

In [None]:
def calculate_quality(y_hats: Tensor, window_size: Optional[int] = None):
    """Turn per-transform class probabilities into a per-frame quality score (best-plane variant).

    Steps:
      1. Keep only each (transform, frame)'s top-1 probability.
      2. Zero that probability unless every neighbouring frame within ``window_size`` frames
         (window clipped at the sequence edges) has the same top-1 class.
      3. Average over transforms.
      4. Score each frame as (best plane probability - sum of all remaining probabilities),
         i.e. ``2 * max(planes) - sum(all)``, treating the first three classes as planes.
         Non-positive scores become 0.

    Args:
        y_hats: probabilities indexed ``[transform, frame, class]`` (e.g. softmaxed logits from
            ``lable_transform_video``). The input tensor is not modified.
        window_size: half-width of the consistency window. ``None`` (the default) falls back to
            the notebook-level ``window`` variable, matching the previous behaviour.

    Returns:
        ``(y_hats, quality)``: the filtered probabilities averaged over transforms, indexed
        ``[frame, class]``, and the per-frame quality, indexed ``[frame]``.
    """
    half_window = window if window_size is None else window_size

    # select highest prediction
    pred = torch.argmax(y_hats, dim=2)
    y_hats = y_hats * F.one_hot(pred, num_classes=y_hats.shape[2])

    # Remove inconsistent predictions, vectorised over transforms and frames: a frame stays
    # only if every neighbour within the edge-clipped window shares its top-1 class.
    n_frames = pred.shape[1]
    stable = torch.ones_like(pred, dtype=torch.bool)
    for offset in range(1, min(half_window, n_frames - 1) + 1):
        agree = pred[:, offset:] == pred[:, :-offset]
        stable[:, :-offset] &= agree  # frame j vs frame j + offset
        stable[:, offset:] &= agree  # frame j vs frame j - offset
    y_hats = y_hats.masked_fill(~stable.unsqueeze(-1), 0.0)

    # average of all transformations
    y_hats = torch.mean(y_hats, dim=0)

    # (the best plane prediction - sum of the rest of the predictions)
    best_plane = torch.amax(y_hats[:, :3], dim=1)
    quality = best_plane * 2 - torch.sum(y_hats, dim=1)
    quality = quality.masked_fill(~(quality > 0), 0.0)

    return y_hats, quality


window = 3
temperature = 1.0
label_def = FetalBrainPlanesDataset.labels

y_hats_ = (logits / temperature).softmax(dim=2)
y_hats_, quality = calculate_quality(y_hats_)
plot_quality(y_hats_.cpu(), quality.cpu()).show()

In [None]:
def load_logits(data_path: Path):
    """Concatenate the first tensor stored in every file of ``data_path`` into one 2-D tensor.

    Each file is visited in sorted name order and must hold exactly three items. The first item
    is viewed as ``(-1, last_dim)``, and the rows from all files are stacked along dim 0.

    NOTE(review): ``torch.load`` may unpickle arbitrary objects — only use on trusted files.
    """
    chunks = []
    for file_path in sorted(data_path.iterdir()):
        first, _, _ = torch.load(file_path)
        chunks.append(first.view(-1, first.size(-1)))
    return torch.cat(chunks)


def save_std_mean(data_path: Path, logits):
    """Save the column-wise population ``(std, mean)`` of ``logits`` to ``<data_path>/std_mean.pt``.

    Also prints the shape of the std tensor.
    """
    stats = torch.std_mean(logits, dim=0, unbiased=False)
    torch.save(stats, f"{data_path}/std_mean.pt")
    std, _ = stats
    print(std.shape)


# Std/mean of the training-split dense logits, saved next to the split folders.
# Use a dedicated name so the notebook-level `path` (US_VIDEOS) is not clobbered for later cells.
logits_data_path = root / "data" / "US_VIDEOS_tran" / "data"
dense = load_logits(logits_data_path / "train")
save_std_mean(logits_data_path, dense)

# Test transform operations for the quality plot

In [None]:
# Fall back to the CPU when no GPU is available instead of failing on model.to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Build the directory explicitly: an earlier cell may have rebound `path` to the US_VIDEOS_tran folder.
video_dir = root / "data" / "US_VIDEOS" / "videos" / "test"
video_path = sorted(video_dir.iterdir())[0]

window = 3
temperature = 1.0

In [None]:
def plot_quality(y_hats: Tensor, quality: Tensor):
    """Two side-by-side panels: non-zero per-class probabilities and the quality curve.

    ``y_hats`` is indexed ``[frame, class]`` with one column per label in ``label_def``.
    Returns the created figure.
    """
    fig, (class_ax, quality_ax) = plt.subplots(ncols=2, nrows=1, tight_layout=True, figsize=(20, 5))

    for class_idx, label in enumerate(label_def):
        frames, probs = extract_nonzero_values(y_hats[:, class_idx])
        class_ax.plot(frames, probs, "o", markersize=2, label=label)
        class_ax.legend()
    quality_ax.plot(range(len(quality)), quality, color="tab:gray")

    for ax in (class_ax, quality_ax):
        ax.set_xlim(left=0, right=len(quality))
        ax.set_ylim(bottom=0, top=1)

    return fig

In [None]:
# Baseline: identity transform only (1 pipeline).
_, logits, y_hats = lable_transform_video(
    video_path=video_path,
    horizontal_flips=[False],
    rotate_degrees=[0],
    translates=[(0.0, 0.0)],
    scales=[1.0],
)
y_hats.shape

In [None]:
# Temperature-scaled softmax over classes; plot the single pipeline.
y_hats_ = (logits / temperature).softmax(dim=2)
plot_base_probabilities(y_hats_[:1].cpu())

In [None]:
# Store the averaged output under a new name so re-running this cell does not feed the
# 2-D result back into calculate_quality.
frame_probs, quality = calculate_quality(y_hats_)
plot_quality(frame_probs.cpu(), quality.cpu()).show()

In [None]:
# Horizontal flip on/off (2 pipelines).
_, logits, y_hats = lable_transform_video(
    video_path=video_path,
    horizontal_flips=[False, True],
    rotate_degrees=[0],
    translates=[(0.0, 0.0)],
    scales=[1.0],
)

y_hats.shape

In [None]:
# Temperature-scaled softmax over classes; plot both flip pipelines.
y_hats_ = (logits / temperature).softmax(dim=2)
plot_base_probabilities(y_hats_[:2].cpu())

In [None]:
# Store the averaged output under a new name so re-running this cell does not feed the
# 2-D result back into calculate_quality.
frame_probs, quality = calculate_quality(y_hats_)
plot_quality(frame_probs.cpu(), quality.cpu()).show()

In [None]:
# Rotations of 0 and +/-15 degrees (3 pipelines).
_, logits, y_hats = lable_transform_video(
    video_path=video_path,
    horizontal_flips=[False],
    rotate_degrees=[0, -15, 15],
    translates=[(0.0, 0.0)],
    scales=[1.0],
)
y_hats.shape

In [None]:
# Temperature-scaled softmax over classes; plot all three rotation pipelines.
y_hats_ = (logits / temperature).softmax(dim=2)
plot_base_probabilities(y_hats_[:3].cpu())

In [None]:
# Store the averaged output under a new name so re-running this cell does not feed the
# 2-D result back into calculate_quality.
frame_probs, quality = calculate_quality(y_hats_)
plot_quality(frame_probs.cpu(), quality.cpu()).show()

In [None]:
# Centre plus four diagonal 10% shifts (5 pipelines).
diagonal_shifts = [(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)]
_, logits, y_hats = lable_transform_video(
    video_path=video_path,
    horizontal_flips=[False],
    rotate_degrees=[0],
    translates=diagonal_shifts,
    scales=[1.0],
)
y_hats.shape

In [None]:
# Temperature-scaled softmax over classes; plot all five translation pipelines.
y_hats_ = (logits / temperature).softmax(dim=2)
plot_base_probabilities(y_hats_[:5].cpu())

In [None]:
# Store the averaged output under a new name so re-running this cell does not feed the
# 2-D result back into calculate_quality.
frame_probs, quality = calculate_quality(y_hats_)
plot_quality(frame_probs.cpu(), quality.cpu()).show()

In [None]:
# Scale 1.0 and 1.2 (2 pipelines).
_, logits, y_hats = lable_transform_video(
    video_path=video_path,
    horizontal_flips=[False],
    rotate_degrees=[0],
    translates=[(0.0, 0.0)],
    scales=[1.0, 1.2],
)
y_hats.shape

In [None]:
# Temperature-scaled softmax over classes; plot both scale pipelines.
y_hats_ = (logits / temperature).softmax(dim=2)
plot_base_probabilities(y_hats_[:2].cpu())

In [None]:
# Store the averaged output under a new name so re-running this cell does not feed the
# 2-D result back into calculate_quality.
frame_probs, quality = calculate_quality(y_hats_)
plot_quality(frame_probs.cpu(), quality.cpu()).show()

In [None]:
# Combined grid: 2 flips x 3 rotations x 5 translations x 2 scales = 60 pipelines.
diagonal_shifts = [(0.0, 0.0), (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)]
_, logits, y_hats = lable_transform_video(
    video_path=video_path,
    horizontal_flips=[False, True],
    rotate_degrees=[0, -15, 15],
    translates=diagonal_shifts,
    scales=[1.0, 1.2],
)
y_hats.shape

In [None]:
# Temperature-scaled softmax over classes; plot only the first 4 pipelines to keep the figure small.
y_hats_ = (logits / temperature).softmax(dim=2)
plot_base_probabilities(y_hats_[:4].cpu())

In [None]:
# Store the averaged output under a new name so re-running this cell does not feed the
# 2-D result back into calculate_quality.
frame_probs, quality = calculate_quality(y_hats_)
plot_quality(frame_probs.cpu(), quality.cpu()).show()

In [None]:
# Wider grid: 2 flips x 5 rotations x 9 translations x 2 scales = 180 pipelines.
shift_grid = list(itertools.product([0.0, 0.1, -0.1], [0.0, 0.1, -0.1]))
_, logits, y_hats = lable_transform_video(
    video_path=video_path,
    horizontal_flips=[False, True],
    rotate_degrees=[0, -15, 15, -30, 30],
    translates=shift_grid,
    scales=[1.0, 1.2],
)
y_hats.shape

In [None]:
# Temperature-scaled softmax over classes; plot only the first 4 pipelines to keep the figure small.
y_hats_ = (logits / temperature).softmax(dim=2)
plot_base_probabilities(y_hats_[:4].cpu())

In [None]:
# Store the averaged output under a new name so re-running this cell does not feed the
# 2-D result back into calculate_quality.
frame_probs, quality = calculate_quality(y_hats_)
plot_quality(frame_probs.cpu(), quality.cpu()).show()

In [None]:
# Widest grid: 2 flips x 5 rotations x 9 translations x 3 scales = 270 pipelines.
shift_grid = list(itertools.product([0.0, 0.1, -0.1], [0.0, 0.1, -0.1]))
_, logits, y_hats = lable_transform_video(
    video_path=video_path,
    horizontal_flips=[False, True],
    rotate_degrees=[0, -15, 15, -30, 30],
    translates=shift_grid,
    scales=[1.0, 1.2, 1.5],
)
y_hats.shape

In [None]:
# Temperature-scaled softmax over classes; plot only the first 6 pipelines to keep the figure small.
y_hats_ = (logits / temperature).softmax(dim=2)
plot_base_probabilities(y_hats_[:6].cpu())

In [None]:
# Store the averaged output under a new name so re-running this cell does not feed the
# 2-D result back into calculate_quality.
frame_probs, quality = calculate_quality(y_hats_)
plot_quality(frame_probs.cpu(), quality.cpu()).show()