In [1]:
from sys import getsizeof

from common.data.amigos.transform import text_transform, audio_transform, video_transform, eeg_transform
from sampler import do_sample, TimedMoviepyAMIGOSSampler, TimedFastAMIGOSSampler
from common.data.video.utils import extract_frames
import cv2

In [None]:
%%time
%%capture --no-display
# ~ 4 minutes for my current subset of data.
# Run the sampling step with TimedFastAMIGOSSampler. Both the sampler and
# do_sample receive `path + "sampled/"` (presumably the output directory --
# TODO confirm against the local `sampler` module).
# NOTE(review): the meaning of the positional 5, 100 and True is defined by
# TimedFastAMIGOSSampler / do_sample, which are not visible from this notebook;
# consider passing them as keyword arguments.
path = "../../../resources/AMIGOS/"
sampler = TimedFastAMIGOSSampler(path + "sampled/", 5)
do_sample(path, path + "sampled/", sampler, 100, True)

#### With ProcessPoolExecutor
CPU times: user 6.16 s, sys: 4.12 s, total: 10.3 s <br>
Wall time: 4min 24s

#### With ThreadPoolExecutor
CPU times: user 6.08 s, sys: 2.99 s, total: 9.07 s <br>
Wall time: 4min 17s

In [None]:
%%time
%%capture --no-display
# Same sampling step as the TimedFastAMIGOSSampler cell, but using
# TimedMoviepyAMIGOSSampler and without the trailing positional True.
# NOTE(review): the meaning of 5 and 100 is defined in the local `sampler`
# module, which is not visible here -- TODO confirm.
path = "../../../resources/AMIGOS/"
sampler = TimedMoviepyAMIGOSSampler(path + "sampled/", 5)
do_sample(path, path + "sampled/", sampler, 100)

In [None]:
%%time
%%capture --no-display
# Does not work too well.
# TimedMoviepyAMIGOSSampler again, this time with the trailing positional True
# passed to do_sample (the only difference from the previous cell).
# NOTE(review): what that flag enables is defined by do_sample, which is not
# visible here -- TODO confirm and record what "not too well" refers to.
path = "../../../resources/AMIGOS/"
sampler = TimedMoviepyAMIGOSSampler(path + "sampled/", 5)
do_sample(path, path + "sampled/", sampler, 100, True)

In [None]:
# CPU times: user 17.7 s, sys: 10.5 s, total: 28.1 s
# Wall time: 9min 54s

# Transforms
## Video Transform

In [None]:
# Build the video transform and read the frames of one sampled clip.
# NOTE(review): `fps_map=(15, 8)` and `means=None` are interpreted by
# video_transform, which is not visible here -- TODO document their meaning.
t = video_transform(fps_map=(15, 8), means=None)
path = "../../../resources/AMIGOS/sampled/P01_9/0.mp4"
# The VideoCapture is created inline and never released in this cell;
# presumably extract_frames takes care of that -- TODO confirm.
frames = extract_frames(cv2.VideoCapture(path))

In [None]:
# Apply the video transform `t` (built in the previous cell) to the frames.
# NOTE(review): `train` / `return_both` semantics come from video_transform and
# are not visible here.
res = t(frames, train=False, return_both=False)
# Shape of the first element of the transformed result.
res[0].shape

In [None]:

import torchvision.transforms.v2 as v2

# Convert one transformed frame (index 22 of `res` from the video-transform
# cell) to a PIL image. Leaving the image as the cell's last expression lets
# Jupyter render it inline; PIL's `.show()` instead writes a temporary file and
# launches an external OS viewer, which displays nothing on a remote or
# headless kernel.
v2.ToPILImage()(res[22])

## Text Transform

In [None]:
# Build the text transform and run it on a sample sentence.
# Note: this rebinds the notebook-global `t` (previously the video transform).
t = text_transform()
text = "This is some bullshit"
# NOTE(review): `train` / `return_both` semantics come from text_transform and
# are not visible here.
t(text, train=False, return_both=False)

## Audio Transform

In [None]:
import torchaudio

# Load one sampled audio clip; torchaudio.load returns the waveform tensor and
# its sample rate (the variable is named `wavelength` but holds the waveform).
path = "../../../resources/AMIGOS/sampled/P01_9/0.wav"
wavelength, sampling_rate = torchaudio.load(path)
# NOTE(review): the tuple is presumably (source rate, target rate), since the
# same 10000 is later passed to plot_waveform as the sample rate of `res` --
# TODO confirm against audio_transform.
t = audio_transform((sampling_rate, 10000))

# Rebinds the notebook-globals `t` and `res` (previously text/video results).
res = t(wavelength, train=False, return_both=False)

In [None]:
# Sample rate returned by torchaudio.load for the clip loaded above.
sampling_rate

In [None]:
# Shape of `res`. This only refers to the audio result if the audio-transform
# cell was the last one to assign `res` (the video section reuses the name).
res.shape

In [None]:
from matplotlib import pyplot as plt
import torch


def plot_waveform(waveform, sample_rate):
    """Plot every channel of an audio waveform against time.

    Adapted from the PyTorch documentation.

    Args:
        waveform: ``torch.Tensor`` of shape ``(num_channels, num_frames)``.
            A 1-D tensor of shape ``(num_frames,)`` is also accepted and is
            plotted as a single channel.
        sample_rate: Number of frames per second; frame indices are divided
            by it to build the time axis.

    Raises:
        ValueError: If ``waveform`` is neither 1-D nor 2-D.
    """
    # detach()/cpu() keep the conversion working for tensors that track
    # gradients or live on an accelerator; plain CPU tensors are unaffected.
    waveform = waveform.detach().cpu().numpy()
    if waveform.ndim == 1:
        # Promote a mono signal to (1, num_frames) so callers need not unsqueeze.
        waveform = waveform[None, :]
    if waveform.ndim != 2:
        raise ValueError(
            "waveform must have shape (num_channels, num_frames) or (num_frames,), "
            f"got {waveform.shape}"
        )

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        # plt.subplots returns a bare Axes (not a sequence) for a single row.
        axes = [axes]
    for index, (axis, channel) in enumerate(zip(axes, waveform), start=1):
        axis.plot(time_axis, channel, linewidth=1)
        axis.grid(True)
        if num_channels > 1:
            axis.set_ylabel(f"Channel {index}")
    # Label the time axis once, on the bottom subplot.
    axes[-1].set_xlabel("Time (s)")
    figure.suptitle("waveform")

In [None]:
# Plot the transformed audio. 10000 repeats the second value passed to
# audio_transform in the audio cell and must be kept in sync with it.
# NOTE(review): assumes `res` is a 2-D (channels, frames) tensor -- TODO confirm.
plot_waveform(res, 10000)

In [None]:
# Plot only the first channel of the original recording at its native sample
# rate; unsqueeze(0) restores the leading dimension that indexing [0] removed.
plot_waveform(wavelength[0].unsqueeze(0), sampling_rate)

In [None]:
# Shape of the single-channel slice that was passed to plot_waveform above.
wavelength[0].unsqueeze(0).shape

## Test Dataset

In [2]:
from common.data.amigos.dataset import AMIGOSDataset

# Build the dataset from the CSV spec file, passing one transform per modality
# in the order video, audio, text, eeg.
spec_file = "../../../resources/AMIGOS/sampled/AMIGOS_sampled.csv"
# NOTE(review): the meaning of the positional `True` and of the
# (24000, 12000) tuple is defined by AMIGOSDataset / audio_transform, which are
# not visible here -- TODO confirm and consider keyword arguments.
ds = AMIGOSDataset(
    spec_file,
    True,
    video_transform(),
    audio_transform((24000, 12000)),
    text_transform(),
    eeg_transform()
)



In [3]:
# Fetch the record at index 1 from the dataset for inspection.
record = ds[1]

In [9]:
# Length of the `video` field of the fetched record.
len(record.video)

77

In [None]:
# Fetch the record at index 11; `p` is inspected in the cells below.
p = ds[11]

In [None]:
# Index the dataset with -1.
# NOTE(review): whether AMIGOSDataset supports negative indexing as "last
# record" is not visible here -- TODO confirm.
d = ds[-1]

In [17]:
# sys.getsizeof reports only the shallow size of the Python object, not the
# memory of any data buffer it refers to, so the value shown below is not the
# size of the frame data.
# NOTE(review): if `p.video[0]` is a torch tensor, element_size() * nelement()
# gives the storage size instead -- TODO confirm the type.
getsizeof(p.video[0])

80

In [19]:
# Shape of the `video` field of record `p` (fetched from ds[11] above).
p.video.shape

torch.Size([77, 3, 224, 224])

In [21]:
# Display the whole record at index 12.
# NOTE(review): this prints the full repr of every field, which produces a very
# large output; displaying shapes of individual fields would be easier to read.
ds[12]

DatasetRecord(eeg=tensor([[[-4.3739e+01, -1.3664e+00, -3.9712e+00,  ...,  1.2572e+01,
           2.0011e+01,  5.5726e+05],
         [-3.3928e+01,  1.3070e+00, -3.3443e+00,  ...,  1.3200e+01,
           2.0995e+01,  5.5693e+05],
         [-4.5414e+01,  2.1131e-01, -8.5257e-01,  ...,  1.1024e+01,
           1.8541e+01,  5.5703e+05],
         ...,
         [-2.9891e+00,  1.7217e+00, -2.8688e-01,  ...,  9.5457e+00,
           2.4462e+01,  5.5398e+05],
         [-7.1921e+00, -9.6282e-01, -1.3444e-01,  ...,  7.1169e+00,
           2.1349e+01,  5.5340e+05],
         [-1.3094e+01, -3.8069e+00,  3.4489e+00,  ...,  4.9116e-01,
           1.9384e+01,  5.5284e+05]]], dtype=torch.float64), video=[Image([[[-1.6211, -1.6478, -1.6519,  ..., -1.5969, -1.6146, -1.5996],
        [-1.6370, -1.6596, -1.6691,  ..., -1.6174, -1.6205, -1.5808],
        [-1.6243, -1.6448, -1.6553,  ..., -1.6407, -1.6015, -1.5913],
        ...,
        [-1.6328, -1.6464, -1.6547,  ..., -1.6267, -1.6278, -1.6619],
        [-1.64