In [None]:
import csv
import itertools
import pathlib
import shutil
from math import ceil, sqrt
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import rootutils
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from skimage.metrics import structural_similarity
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision.io import read_image
from tqdm.notebook import tqdm

# from IPython.display import set_matplotlib_formats
# set_matplotlib_formats('png')

# Resolve the project root using the ".project-root" marker file, starting the search
# from the current working directory. pythonpath=True is passed so the project-local
# `src.*` imports below can resolve (presumably by putting the root on sys.path —
# see the rootutils documentation); those imports must therefore stay after this call.
root = rootutils.setup_root(search_from=".", indicator=".project-root", pythonpath=True)

from src.data.components.dataset import (
    FetalBrainPlanesDataset,
    USVideosDataset,
    USVideosFrameDataset,
)
from src.data.components.transforms import Affine, HorizontalFlip
from src.data.utils.utils import show_numpy_images, show_pytorch_images
from src.models.fetal_module import FetalLitModule
from src.models.quality_module import QualityLitModule

# Base directory holding the datasets used by the cells below.
path = root / "data"
# Bare expression: displays the resolved project root as the cell output.
root

In [None]:
# Directory of the training run to evaluate. Despite its name, `checkpoint_file` is the
# run directory, not a checkpoint file; the name is kept because other cells may use it.
checkpoint_file = root / "logs" / "train" / "runs" / "2023-03-07_05-55-06"

# Pick the last checkpoint in sorted path order.
# NOTE(review): the sort is lexicographic, so this is the highest epoch only when the
# epoch number in the file name is zero-padded — confirm against the checkpoint naming.
checkpoints = sorted(checkpoint_file.glob("checkpoints/epoch_*.ckpt"))
if not checkpoints:
    # Fail with an actionable message instead of a bare IndexError from `[-1]`.
    raise FileNotFoundError(
        f"No 'epoch_*.ckpt' checkpoint found in {checkpoint_file / 'checkpoints'}"
    )
checkpoint = checkpoints[-1]

rnn = QualityLitModule.load_from_checkpoint(str(checkpoint))
# disable randomness, dropout, etc...
rnn.eval()
print("ok")

In [None]:
class VideoQualityDataset(Dataset):
    def __init__(
        self,
        data_dir: str,
        dataset_name: str = "US_VIDEOS",
        train: bool = True,
        seq_len: int = 32,
        seq_step: int = None,
        reverse: bool = False,
        transform: bool = False,
        normalize: bool = False,
        target_transform: Callable | None = None,
        label_transform: Callable | None = None,
    ):
        self.train = train
        self.dataset_dir = Path(data_dir) / dataset_name / "data"
        self.data_dir = self.dataset_dir / ("train" if self.train else "test")
        self.seq_len = seq_len
        self.seq_step = seq_step
        self.reverse = reverse
        self.transform = transform
        self.clips = self.load_clips()
        self.normalize = normalize
        self.std_mean = torch.load(f"{self.dataset_dir}/std_mean.pt")

        self.target_transform = target_transform
        self.label_transform = label_transform

    def load_clips(self):
        clips = []
        print(f"video_paths: {len(list(self.data_dir.iterdir()))}")
        for video_path in sorted(self.data_dir.iterdir()):
            print(f"video_path: {video_path.name}")
            transforms = [transform_path.name for transform_path in sorted(video_path.iterdir())]

            transform_path = sorted(video_path.iterdir())[0]
            logits, quality, _ = torch.load(transform_path)

            seq_len = self.seq_len or len(quality)
            seq_step = self.seq_step or max(1, ceil(seq_len / 2))

            for from_idx in range(0, len(quality) - seq_len + 1, seq_step):
                to_idx = from_idx + seq_len
                clips.append((video_path.name, transforms, from_idx, to_idx, False))
                if self.train and self.reverse:
                    clips.append((video_path.name, transforms, from_idx, to_idx, True))

        return pd.DataFrame(clips, columns=["Video", "Transforms", "From", "To", "Flip"])

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError(f"list index {idx} out of range")

        if isinstance(idx, torch.Tensor):
            idx = idx.item()

        transforms = self.clips.Transforms[idx]
        transform_idx = torch.randint(0, len(transforms), ()) if (self.train and self.transform) else 0

        video = self.data_dir / self.clips.Video[idx] / transforms[transform_idx]
        logits, quality, preds = torch.load(video)

        from_idx = self.clips.From[idx]
        to_idx = self.clips.To[idx]
        x = logits[from_idx:to_idx]
        y = quality[from_idx:to_idx]
        p = preds[from_idx:to_idx]

        if self.clips.Flip[idx]:
            x = torch.flip(x, dims=[0])
            y = torch.flip(y, dims=[0])
            p = torch.flip(p, dims=[0])

        if self.normalize is not None:
            x = (x - self.std_mean[1]) / self.std_mean[0]

        if self.target_transform:
            y = self.target_transform(y)
        if self.label_transform:
            p = self.label_transform(p)

        return x, y, p


# Test-split clips of 128 steps each. The remaining options (seq_step, reverse,
# transform, normalize) are left at the class defaults.
dataset = VideoQualityDataset(
    data_dir=path,
    dataset_name="US_VIDEOS_tran_250_playful-haze-2111",
    train=False,
    seq_len=128,
)
# Sizes noted from an earlier run (presumably per-video lengths — TODO confirm):
# torch.Size([472])
# torch.Size([329])

# Window bounds of every clip, in index order.
for clip_from, clip_to in zip(dataset.clips.From, dataset.clips.To):
    print(f"{clip_from} {clip_to}")

# Number of clips, followed by the shape of each clip's second element (the `y` slice).
print(len(dataset))
for i in range(len(dataset)):
    clip_y = dataset[i][1]
    print(clip_y.shape)

In [None]:
# Re-run of the shape check: clip count first, then the shape of each clip's `y` slice.
n_clips = len(dataset)
print(n_clips)
for i in range(n_clips):
    clip_y = dataset[i][1]
    print(clip_y.shape)

In [None]:
# Sequential (unshuffled) batches of 8 clips, loaded in the main process.
dl = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0, pin_memory=True)

# Smoke test: no collate_fn is given, so the default collation of the (x, y, p)
# samples is used; print the number of elements in each batch.
for d in dl:
    print(len(d))