In [None]:
import csv
import itertools
import pathlib
import shutil
from math import ceil, sqrt
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import pyrootutils
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from skimage.metrics import structural_similarity
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision.io import read_image
from tqdm.notebook import tqdm

# Resolve the repository root by searching from the current directory for the
# ".project-root" marker file.
# NOTE(review): pythonpath=True presumably puts the root on sys.path so that the
# `src.*` imports below resolve — confirm against the pyrootutils documentation.
root = pyrootutils.setup_root(search_from=".", indicator=".project-root", pythonpath=True)

# Project-local imports; they are placed after setup_root() on purpose.
from src.data.components.dataset import FetalBrainPlanesDataset, USVideosDataset
from src.data.components.transforms import Affine, HorizontalFlip
from src.data.utils.utils import group_split, show_numpy_images, show_pytorch_images
from src.models.fetal_module import FetalLitModule

# Every dataset in this notebook is looked up under <root>/data.
data_dir = root / "data"
root  # bare last expression: shown as the cell output

In [None]:
class VideoQualityDataset(Dataset):
    """Fixed-length windows cut from per-video ``(logits, quality)`` sequences.

    Every entry of ``<data_dir>/<dataset_name>/data/<train|test>`` is read with
    ``torch.load`` and unpacked as a ``(logits, quality)`` pair. One clip is
    indexed for each start offset at which a full window fits, so consecutive
    clips of the same video overlap by ``window_size - 1`` steps.

    Args:
        data_dir: Root data directory (anything ``Path`` accepts).
        dataset_name: Sub-directory of ``data_dir`` that holds the dataset.
        train: Read the ``train`` split when true, otherwise ``test``.
        window_size: Clip length. A falsy value (``0`` or ``None``) yields one
            clip per video that spans the whole sequence. Videos shorter than
            the window contribute no clips.
        transform: Optional callable applied to the logits slice.
        target_transform: Optional callable applied to the quality slice.
    """

    def __init__(
        self,
        data_dir: str,
        dataset_name: str = "US_VIDEOS",
        train: bool = True,
        window_size: Optional[int] = 32,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        self.data_dir = Path(data_dir) / dataset_name / "data" / ("train" if train else "test")
        self.window_size = window_size
        self.clips = self.load_clips()
        self.transform = transform
        self.target_transform = target_transform

    def load_clips(self):
        """Build the clip index.

        Returns:
            A DataFrame with one row per clip and the columns ``Video`` (file
            name), ``From`` (inclusive start) and ``To`` (exclusive end).
        """
        clips = []
        # Sorted so that the clip order does not depend on directory listing order.
        for video in sorted(self.data_dir.iterdir()):
            _, quality = torch.load(video)
            # Falsy window size -> a single window covering the whole video.
            window_size = self.window_size or len(quality)
            for start in range(len(quality) - window_size + 1):
                clips.append((video.name, start, start + window_size))

        return pd.DataFrame(clips, columns=["Video", "From", "To"])

    def __len__(self):
        """Return the number of indexed clips."""
        return len(self.clips)

    def __getitem__(self, idx):
        """Return the ``(logits, quality)`` slices of clip ``idx``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_clips = len(self)
        if not -n_clips <= idx < n_clips:
            raise IndexError("list index out of range")
        if idx < 0:
            idx += n_clips

        # Positional lookups: label-based ``Series[idx]`` raises KeyError for
        # negative values instead of indexing from the end.
        video = self.data_dir / self.clips["Video"].iloc[idx]
        from_idx = int(self.clips["From"].iloc[idx])
        to_idx = int(self.clips["To"].iloc[idx])

        # The file is re-read on every access; nothing is cached between calls.
        logits, quality = torch.load(video)
        x = logits[from_idx:to_idx]
        y = quality[from_idx:to_idx]

        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)

        return x, y


# Index the "test" split. window_size=0 is falsy, so each video contributes a
# single clip covering its full length.
dataset = VideoQualityDataset(data_dir, "US_VIDEOS", train=False, window_size=0)
print(len(dataset))

# Inspect the first sample: x is the logits slice, y the quality slice.
x, y = dataset[0]
print(f"{x.shape} {x.dtype}")
print(f"{y.shape} {y.dtype}")

In [None]:
# Fetch every sample once and discard it: any clip whose file fails to load or
# slice raises here, and tqdm shows how long a full pass takes.
for i in tqdm(range(len(dataset))):
    dataset[i]

In [None]:
# Hold out 10% for validation, grouping by source video name.
# NOTE(review): presumably group_split keeps all clips of one video on the same
# side of the split — confirm in src.data.utils.utils.
data_train, data_val = group_split(
    dataset=dataset,
    groups=dataset.clips.Video,
    test_size=0.1,
    random_state=42,
)
print(len(data_train), len(data_val), sep="\n")

In [None]:
# NOTE(review): `dataset` was built with window_size=0 (whole-video clips); the
# default collation stacks samples, so this presumably only works when the
# videos share a length — confirm.
dataloader = DataLoader(dataset, batch_size=2)

# Peek at the first batch only to check the collated shapes.
for batch in dataloader:
    x, y = batch
    print(x.shape, y.shape, sep="\n")
    break

In [None]:
# Toy experiment: flatten a (2, 2, 3) tensor to (4, 3), project each row to a
# single value with a Linear layer, then fold the result back to (2, 2).
y = torch.tensor(
    [
        [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]],
        [[3.0, 3.0, 3.0], [4.0, 4.0, 4.0]],
    ]
)
print(y.shape, y, sep="\n")

# Merge the two leading dimensions: (2, 2, 3) -> (4, 3).
y = y.reshape(-1, 3)
print(y.shape, y, sep="\n")

# Bias-free layer with all-ones weights, i.e. a plain sum over the 3 features.
lin = torch.nn.Linear(3, 1, bias=False)
torch.nn.init.ones_(lin.weight)
y = lin(y)
print(y)

# Fold the (4, 1) output back to (2, 2); shown as the cell output.
y.reshape(2, 2)