In [5]:
import sys
from pathlib import Path
# Put the parent of this notebook's folder on the import path (presumably so
# the project-local `src` package imported in this cell can be found).
# `__file__` is not defined inside a Jupyter kernel and raises NameError there,
# so fall back to the kernel's working directory, which is assumed to be the
# folder containing this notebook -- TODO confirm for non-standard launchers.
try:
    notebook_dir = Path(__file__).resolve().parent
except NameError:
    notebook_dir = Path.cwd().resolve()

project_root = str(notebook_dir.parent)
if project_root not in sys.path:  # keep the cell idempotent across re-runs
    sys.path.append(project_root)
from synthweave.utils.datasets import get_datamodule, get_dataset
from src.pipe import ImagePreprocessor, AudioPreprocessor

In [12]:
# Keyword arguments shared by every split of the dataset.
ds_kwargs = dict(
    data_dir="./processed_data/DeepSpeak_v1_1",
    preprocessed=True,
    sample_mode="sequence",
)

# Build the three splits with identical settings, in train / dev / test order.
train, dev, test = (
    get_dataset("DeepSpeak_v1_1", split=split_name, **ds_kwargs)
    for split_name in ("train", "dev", "test")
)

In [3]:
# vid_proc = ImagePreprocessor(window_len=4, step=1, head_pose_dir='../../../models/head_pose')
# aud_proc = AudioPreprocessor(window_len=4, step=1)

# ds_kwargs = {
#     'video_processor': vid_proc, 'audio_processor': aud_proc, 'mode': 'full'
# }

# ds = get_dataset("DeepSpeak_v1_1", split="train", **ds_kwargs)

### DATALOADER

In [9]:
# Options forwarded to the datamodule factory. Trailing comments list the
# alternative values noted by the original author.
# NOTE(review): `ds_kwargs` also carries its own "sample_mode" entry; which of
# the two settings takes precedence is not visible from this notebook.
loader_options = {
    "batch_size": 32,
    "dataset_kwargs": ds_kwargs,
    "sample_mode": "single",  # or "sequence"
    "clip_mode": "id",  # or "idx"
    "clip_to": 2,  # an int, or "min"
    "clip_selector": "random",  # or "first"
}
dm = get_datamodule("DeepSpeak_v1_1", **loader_options)

# With these settings a batch presumably holds at most 2 randomly selected
# single-window samples per id -- TODO confirm against get_datamodule.

# Prepare the datamodule for the "fit" stage before requesting loaders.
dm.setup("fit")

In [10]:
# Request the training loader from the datamodule (set up for "fit" earlier).
train_loader = dm.train_dataloader()

In [11]:
# Pull one batch from the training loader for inspection; every run of this
# cell creates a fresh iterator over the loader.
sample = next(iter(train_loader))

### EXAMPLE BATCH

In [12]:
# Shape of the batch's "video" entry; the leading dimension in the recorded
# output (32) matches the configured batch_size.
sample["video"].shape

torch.Size([32, 3, 112, 112])

In [13]:
# Shape of the batch's "audio" entry (recorded output: [32, 1, 64000]).
sample["audio"].shape

torch.Size([32, 1, 64000])

In [14]:
# Integer-encoded "label" metadata for each item in the batch.
sample["metadata"]["label"]

tensor([1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1,
        1, 0, 1, 1, 1, 1, 0, 1])

In [15]:
# Map the encoded labels back to their original string values using the
# training dataset's "label" encoder (presumably a scikit-learn-style
# LabelEncoder -- confirm in the dataset implementation).
dm.train_dataset.encoders["label"].inverse_transform(sample["metadata"]["label"])

array(['1', '1', '1', '1', '1', '1', '0', '1', '1', '1', '0', '0', '1',
       '1', '1', '1', '0', '1', '1', '1', '1', '0', '1', '1', '1', '0',
       '1', '1', '1', '1', '0', '1'], dtype='<U1')

In [16]:
# Integer-encoded "av" metadata for each item in the batch; decoded values are
# two-character codes such as '01' (presumably per-modality audio/video flags
# -- TODO confirm the meaning and order of the two characters).
sample["metadata"]["av"]

tensor([1, 2, 1, 1, 2, 2, 0, 2, 1, 1, 0, 0, 2, 2, 2, 2, 0, 2, 1, 2, 2, 0, 1, 1,
        2, 0, 2, 1, 1, 1, 0, 2])

In [17]:
# Map the encoded "av" values back to their original two-character strings
# using the training dataset's "av" encoder.
dm.train_dataset.encoders["av"].inverse_transform(sample["metadata"]["av"])

array(['01', '10', '01', '01', '10', '10', '00', '10', '01', '01', '00',
       '00', '10', '10', '10', '10', '00', '10', '01', '10', '10', '00',
       '01', '01', '10', '00', '10', '01', '01', '01', '00', '10'],
      dtype='<U2')