In [4]:
import pandas as pd
import torch
from torch.utils.data import Dataset

import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from pytorchvideo.data.encoded_video import EncodedVideo
from transformers import BertTokenizer
from torchvision.transforms import Compose, Lambda, Resize, Normalize, ColorJitter

from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample
)

In [13]:
class DeepFakeDataset(Dataset):
    """Map-style dataset pairing each annotated video with a text sample and a label.

    Item ``index`` is built from row ``index`` of two CSV files:

    * ``video_path_file`` must contain ``video_path`` and ``label`` columns.
    * ``text_csv_file`` must contain a ``text`` column.

    The two files are matched purely by row position (there is no join key),
    so they must list samples in the same order.
    NOTE(review): confirm the two annotation files are generated in lockstep.

    Each video is split into ``num_frames`` equal-length time segments. Every
    segment is decoded with ``EncodedVideo.get_clip`` and passed through
    ``video_transforms``. The results are stacked along a new leading dimension.

    Args:
        video_path_file: Path (or file-like object) of the video annotation CSV.
        text_csv_file: Path (or file-like object) of the text CSV.
        text_transforms: Optional callable applied to the raw text value.
        video_transforms: Optional callable applied to each segment's
            ``'video'`` tensor; ``None`` leaves segments unchanged.
        num_frames: Number of equal-length segments sampled per video (>= 1).
        sampling_rate: Stored on the instance but not currently used.
        frames_per_second: Stored on the instance but not currently used.
        verbose: If True, print per-segment shape diagnostics while loading.

    Raises:
        ValueError: If ``num_frames`` is smaller than 1.
    """

    def __init__(self, video_path_file, text_csv_file, text_transforms=None, video_transforms=None,
                 num_frames=8, sampling_rate=8, frames_per_second=30, verbose=False):
        if num_frames < 1:
            # The per-segment step size is duration / num_frames.
            raise ValueError(f'num_frames must be >= 1, got {num_frames}')
        self.video_annotation = pd.read_csv(video_path_file)
        self.text_df = pd.read_csv(text_csv_file)
        self.text_transforms = text_transforms
        self.video_transforms = video_transforms
        self.num_frames = num_frames
        self.sampling_rate = sampling_rate
        self.frames_per_second = frames_per_second
        self.verbose = verbose

    def __len__(self):
        """Number of samples, i.e. rows in the video annotation CSV."""
        return len(self.video_annotation)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            dict with keys:
            ``'video'``: tensor of ``num_frames`` transformed segments stacked on dim 0;
            ``'text'``: the text after ``text_transforms`` (raw text if none);
            ``'label'``: the label as a ``torch.long`` scalar tensor.

        Raises:
            RuntimeError: If the video cannot be decoded or transformed. The
                underlying exception is chained as ``__cause__``.
        """
        annotation = self.video_annotation.iloc[index]
        video_path = annotation['video_path']
        label = annotation['label']
        text = self.text_df.iloc[index]['text']

        video_data = self._load_video(video_path)

        # Apply text transforms
        if self.text_transforms:
            text_data = self.text_transforms(text)
        else:
            text_data = text

        return {
            'video': video_data,
            'text': text_data,
            'label': torch.tensor(label, dtype=torch.long),
        }

    def _load_video(self, video_path):
        """Decode ``video_path`` into ``num_frames`` transformed segments stacked on dim 0.

        Raises:
            RuntimeError: Wrapping any decoding/transform failure. Previously such
                failures were only printed, and the method then crashed with an
                unrelated UnboundLocalError.
        """
        video = None
        try:
            video = EncodedVideo.from_path(video_path)

            # Split the full duration into num_frames equal windows.
            duration = video.duration
            step = duration / self.num_frames
            self._log(f'Video length: {duration}')

            segments = []
            for i in range(self.num_frames):
                start_sec = i * step
                end_sec = start_sec + step
                clip = video.get_clip(start_sec=start_sec, end_sec=end_sec)
                frames = clip.get('video') if clip is not None else None
                if frames is None:
                    # Presumably get_clip reports no video tensor when no frames fall
                    # in the window; fail explicitly instead of inside a transform.
                    raise ValueError(f'no frames decoded in [{start_sec}, {end_sec}) s')
                self._log(f'clip shape: {frames.shape}')
                if self.video_transforms is not None:
                    frames = self.video_transforms(frames)
                self._log(f'Transformed Clip: {frames.shape}')
                segments.append(frames)

            video_data = torch.stack(segments)
            self._log(f'Video_data shape: {video_data.shape}')
            self._log('--------------------------------------------\n')
            return video_data
        except Exception as e:
            raise RuntimeError(f'Error Processing video {video_path}: {e}') from e
        finally:
            # Close the decoder if it exposes close(); assumed to release the
            # underlying file, and done on both success and failure.
            close = getattr(video, 'close', None)
            if callable(close):
                close()

    def _log(self, message):
        """Print ``message`` only when the dataset was created with ``verbose=True``."""
        if self.verbose:
            print(message)

In [14]:
# Create instances of text and video transforms.

# Text: a bare tokenizer call returns variable-length Python lists. The
# DataLoader's default collate cannot batch those once texts differ in length,
# so pad/truncate every sample to a fixed length and return tensors instead.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
max_text_length = 128  # tokens per sample after padding/truncation


def text_transforms(text):
    """Tokenize one text sample into fixed-length tensors.

    Returns a dict of the tokenizer's output fields (e.g. ``input_ids``,
    ``attention_mask``), each a 1-D tensor of length ``max_text_length``.
    """
    encoding = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=max_text_length,
        return_tensors='pt',
    )
    # return_tensors='pt' adds a leading batch dim of size 1; drop it so the
    # DataLoader's collate step adds the real batch dimension.
    return {key: tensor.squeeze(0) for key, tensor in encoding.items()}


# Video: per-segment preprocessing applied by DeepFakeDataset to each clip.
side_size = 256
mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]
crop_size = 256
num_frames = 30  # number of equal-length time segments sampled per video

video_transforms = Compose([
    UniformTemporalSubsample(1),  # keep a single frame from each segment
    Lambda(lambda x: x / 255.0),  # assumes decoded pixel values are in [0, 255]
    NormalizeVideo(mean, std),
    ShortSideScale(size=side_size),
    CenterCropVideo(crop_size=(crop_size, crop_size)),
])

# Create an instance of the dataset.
video_path_file = '../annotations/video_train_path.csv'
text_csv_file = '../annotations/text_train.csv'

dataset = DeepFakeDataset(
    video_path_file=video_path_file,
    text_csv_file=text_csv_file,
    text_transforms=text_transforms,
    video_transforms=video_transforms,
    num_frames=num_frames,
)

In [None]:
# Create a dataloader.
# The shuffling generator is seeded so batch order is reproducible across
# kernel restarts ("Restart & Run All" gives the same batches).
# NOTE(review): with num_workers > 0 the dataset is sent to a worker process.
# Under the "spawn" start method (macOS/Windows default), a notebook-defined
# class and the Lambda in video_transforms cannot be pickled; use
# num_workers=0 there.
batch_size = 8
shuffle_seed = 0
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    generator=torch.Generator().manual_seed(shuffle_seed),
)