In [None]:
import csv
import itertools
import pathlib
import shutil
from math import ceil, sqrt
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import rootutils
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from skimage.metrics import structural_similarity
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision.io import read_image
from tqdm.notebook import tqdm

# from IPython.display import set_matplotlib_formats
# set_matplotlib_formats('png')

# Locate the project root (searching upward from the current directory for a
# `.project-root` marker file) and put it on the Python path so that the
# project-local `src.*` imports below resolve. Those imports must stay AFTER
# this call.
root = rootutils.setup_root(search_from=".", indicator=".project-root", pythonpath=True)

from src.data.components.dataset import (
    FetalBrainPlanesDataset,
    USVideosDataset,
    USVideosFrameDataset,
    VideoQualityDataset,
)
from src.data.components.transforms import Affine, HorizontalFlip
from src.data.utils.utils import show_numpy_images, show_pytorch_images
from src.models.fetal_module import FetalLitModule
from src.models.quality_module import QualityLitModule

# Data directory passed as `data_dir` to the datasets built in later cells.
path = root / "data"
# Display the resolved project root as the cell output.
root

In [None]:
import re

# Restore the trained quality model from the latest-epoch checkpoint of one run.
# NOTE: despite its name, `checkpoint_file` is the run directory, not a file.
checkpoint_file = root / "logs" / "train" / "runs" / "2023-03-07_05-55-06"


def _checkpoint_sort_key(ckpt_path: Path) -> Tuple[int, str]:
    """Sort key for ``epoch_<N>*.ckpt`` files: numeric epoch first, then filename.

    Sorting on the parsed integer keeps ``epoch_10`` after ``epoch_9`` even when
    the epoch is not zero-padded (a plain lexicographic sort gets that wrong).
    Names without a parsable epoch sort first (epoch -1).
    """
    match = re.match(r"epoch_(\d+)", ckpt_path.name)
    return (int(match.group(1)) if match else -1, ckpt_path.name)


checkpoint_candidates = sorted(
    checkpoint_file.glob("checkpoints/epoch_*.ckpt"), key=_checkpoint_sort_key
)
if not checkpoint_candidates:
    # Fail with a clear message instead of an IndexError from `[-1]`.
    raise FileNotFoundError(
        f"No 'epoch_*.ckpt' files found in {checkpoint_file / 'checkpoints'}"
    )
checkpoint = checkpoint_candidates[-1]

rnn = QualityLitModule.load_from_checkpoint(str(checkpoint))
# Switch to inference mode: disables dropout and other train-only randomness.
rnn.eval()
print(f"ok: loaded {checkpoint.name}")

In [None]:
# Training split loaded for inspection, with `transform`, `normalize` and
# `reverse` all disabled. seq_len=0 presumably yields one whole video per item
# rather than fixed-length windows -- TODO confirm in VideoQualityDataset.
# Previously observed per-item sizes: torch.Size([472]), torch.Size([329]).
dataset = VideoQualityDataset(
    dataset_name="US_VIDEOS_tran_0500",
    data_dir=path,
    train=True,
    transform=False,
    normalize=False,
    reverse=False,
    seq_len=0,
)

n_videos = len(dataset)
print(n_videos)

In [None]:
# Report the number of items, then the shape of each item's second element
# (the per-frame quality tensor, written as "Quality" in the CSV export below).
n_items = len(dataset)
print(n_items)
for item_idx in range(n_items):
    quality = dataset[item_idx][1]
    print(quality.shape)

In [None]:
# Inspect the first video only: per-frame quality (`y`) and class index (`p`)
# with its label name, for at most MAX_FRAMES_TO_SHOW frames.
MAX_FRAMES_TO_SHOW = 50
label_names = FetalBrainPlanesDataset.labels

if len(dataset) == 0:
    print("dataset is empty - nothing to inspect")
else:
    # Explicit first-item access replaces the old `for ...: ... break` loop.
    idx = 0
    x, y, p = dataset[idx]
    video_name = dataset.clips.Video[idx]
    print(idx, video_name, y[0].item(), p[0].item())
    print(len(y))
    print(len(p))
    for frame_idx in range(min(len(x), MAX_FRAMES_TO_SHOW)):
        # Convert to a plain int once; used both for printing and as a list index.
        class_idx = p[frame_idx].item()
        print(
            idx,
            video_name,
            frame_idx,
            y[frame_idx].item(),
            class_idx,
            label_names[class_idx],
        )

In [None]:
# Export one CSV row per frame (quality score and class) for both the train
# and the test split of the video-quality dataset.
label_names = FetalBrainPlanesDataset.labels

dataset_train = VideoQualityDataset(
    data_dir=path,
    dataset_name="US_VIDEOS_tran_0500",
    train=True,
    seq_len=0,
)
dataset_test = VideoQualityDataset(
    data_dir=path,
    dataset_name="US_VIDEOS_tran_0500",
    train=False,
    seq_len=0,
)


def _write_quality_rows(writer, ds, label_names, video_offset=0):
    """Write one CSV row per frame of every item in ``ds``.

    Row layout: video index (shifted by ``video_offset`` so that train and test
    indices don't collide), video name taken from ``ds.clips.Video`` (always
    the SAME dataset the frames come from), frame index, quality ``y``, class
    index ``p`` and the class name looked up in ``label_names``.
    """
    for video_idx in range(len(ds)):
        x, y, p = ds[video_idx]
        video_name = ds.clips.Video[video_idx]
        for frame_idx in range(len(x)):
            class_idx = p[frame_idx].item()
            writer.writerow(
                [
                    video_idx + video_offset,
                    video_name,
                    frame_idx,
                    y[frame_idx].item(),
                    class_idx,
                    label_names[class_idx],
                ]
            )


# Write into the project data dir explicitly rather than a CWD-relative path.
QUALITY_CSV = path / "quality.csv"
with open(QUALITY_CSV, "w", newline="", encoding="utf-8") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["Video_idx", "Video_name", "Frame_idx", "Quality", "Class_idx", "Class_name"])
    _write_quality_rows(writer, dataset_train, label_names)
    # Test rows continue the video numbering after the train split and take
    # their names from dataset_test.
    _write_quality_rows(writer, dataset_test, label_names, video_offset=len(dataset_train))

In [None]:
# Sequential loader over `dataset`: one item per batch, no shuffling, loaded
# in the main process (num_workers=0) into pinned host memory.
dl = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)

In [None]:
# Demo: project every timestep of a (batch, seq, hidden) tensor through a
# (hidden -> 1) weight by flattening to 2-D, multiplying, and folding back.
# With an all-ones weight, each output entry is the sum of its hidden vector.
output = torch.tensor(
    [
        [[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]],
        [[4, 5, 6, 7], [5, 6, 7, 8], [6, 7, 8, 9]],
    ]
)
print(output.shape)

batch_size, seq_len, hidden_size = output.shape
# (batch, seq, hidden) -> (batch * seq, hidden): one row per timestep.
output = output.reshape(-1, hidden_size)
print(output.shape)

# (1, hidden) all-ones weight; its transpose is (hidden, 1).
fn = torch.ones(1, hidden_size, dtype=torch.int64)
print(fn.T.shape)
output = output @ fn.T
print(output.shape)

# (batch * seq, 1) -> (batch, seq)
output = output.reshape(batch_size, seq_len)
print(output.shape)

output

In [None]:
# nn.Linear acts on the last dimension, so applying it directly to a
# (batch, seq, hidden) tensor must match flatten -> Linear -> unflatten.
# This cell checks that equivalence on a small example.
lin = torch.nn.Linear(4, 2)

x = torch.tensor(((1, 1, 1, 1),), dtype=torch.float32)
print(x.shape)
print(lin(x).shape)

x = torch.tensor(
    (
        (
            (1, 2, 3, 4),
            (2, 3, 4, 5),
            (3, 4, 5, 6),
        ),
        (
            (4, 5, 6, 7),
            (5, 6, 7, 8),
            (6, 7, 8, 9),
        ),
    ),
    dtype=torch.float32,
)
print(x.shape)

# Linear applied directly to the 3-D tensor.
y1 = lin(x)
print(y1.shape)

# Flatten to (batch * seq, hidden), apply Linear, fold back to (batch, seq, out).
batch_size, seq_len, hidden_size = x.shape
y2 = lin(x.contiguous().view(-1, hidden_size)).view(batch_size, seq_len, lin.out_features)
print(y2.shape)
assert torch.allclose(y1, y2), "batched and flattened Linear outputs differ"

# Last timestep of every sequence (displayed as the cell output).
x[:, -1, :]