In [1]:
# from os import path, scandir
# videos = {f.name: [] for f in scandir(path.join("./datasets/ffpp/real/", f'raw/videos')) if f.is_dir()}
# videos

In [2]:
# Imports and the FFPP clip dataset. Note: no accelerator or trackers are initialized in this cell (`init_accelerator` is imported but not called).
import bisect
import json
import logging
import pickle
import random
from os import makedirs,path,scandir

import cv2
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from yacs.config import CfgNode as CN

# from src.datasets import FFPP,RPPG
from main import get_config,init_accelerator,set_seed,RPPG



class FFPP(Dataset):
    """Clip-level video dataset for real/manipulated face videos (presumably FaceForensics++).

    Every video listed in ``<root_dir>/splits/<split>.json`` is cut into consecutive,
    non-overlapping windows of ``clip_duration`` seconds. Dataset item ``i`` is one such window,
    decoded into at most ``num_frames`` evenly spaced RGB frames. ``__getitem__`` returns
    ``(frames, label, mask)``; ``label`` is 0 for the 'REAL' type and 1 for every other type.
    """

    @staticmethod
    def get_default_config():
        """Return the default yacs config node for this dataset."""
        C = CN()
        C.name = 'train'
        C.root_dir = './datasets/ffpp/'
        C.detection_level = 'video'
        C.train_ratio = 0.95
        C.types = ['REAL', 'DF']
        C.compressions = ['raw']
        C.dataset = "FFPP"
        return C

    def __init__(self, config,num_frames,clip_duration, transform=None, accelerator=None, split='train'):
        """
        Args:
            config: config node providing ``name``, ``root_dir``, ``detection_level``, ``types``
                and ``compressions`` (see ``get_default_config``).
            num_frames: number of frames sampled per clip; clips yielding fewer are zero-padded.
            clip_duration: clip length in seconds.
            transform: optional callable applied to the stacked frame tensor of each clip.
            accelerator: optional object exposing ``is_local_main_process`` and ``print``
                (presumably a HuggingFace Accelerate ``Accelerator``). When None, the dataset
                behaves as a single main process.
            split: split file name; ``<root_dir>/splits/<split>.json`` is loaded.
        """
        self.TYPE_DIRS = {
            'REAL': 'real/',
            # 'DFD' : 'data/original_sequences/actors/',
            'DF'  : 'DF/',
            'FS'  : 'FS/',
            'F2F' : 'F2F/',
            'NT'  : 'NT/',
            # 'FSH' : 'data/manipulated_sequences/FaceShifter/',
            # 'DFD-FAKE' : 'data/manipulated_sequences/DeepFakeDetection/',
        }
        self.name = config.name
        self.root = path.expanduser(config.root_dir)
        self.detection_level = config.detection_level
        self.types = config.types
        self.compressions = config.compressions
        self.num_frames = num_frames
        self.clip_duration = clip_duration
        self.split = split
        self.transform = transform

        # (df_type, compression, video_name, clip_count) per selected video
        self.video_list = []

        # running total of clip counts, aligned with video_list
        self.stack_video_clips = []

        self._build_video_table(accelerator)
        self._build_video_list(accelerator)

    @staticmethod
    def _scan_videos(subdir):
        """Read fps / frame-count metadata for every ``*.avi`` file directly inside ``subdir``.

        Returns ``{video_name_without_extension: {"fps", "frames", "duration", "path"}}``.
        Files whose frame rate cannot be read (fps <= 0) are skipped with a warning instead of
        crashing on the duration division.
        """
        video_metas = {}
        with scandir(subdir) as entries:
            for entry in entries:
                if not entry.name.endswith('.avi'):
                    continue
                cap = cv2.VideoCapture(entry.path)
                try:
                    # int() truncates fractional rates (e.g. 29.97 -> 29); kept as-is so cached
                    # tables stay consistent with earlier runs.
                    fps = int(cap.get(cv2.CAP_PROP_FPS))
                    frames = round(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                finally:
                    cap.release()
                if fps <= 0:
                    logging.warning(f'Skipping {entry.path}: could not read a valid frame rate.')
                    continue
                video_metas[entry.name[:-len('.avi')]] = {
                    "fps": fps,
                    "frames": frames,
                    "duration": frames / fps,
                    "path": entry.path,
                }
        return video_metas

    def _build_video_table(self, accelerator):
        """Fill ``self.video_table[df_type][comp]`` with per-video metadata.

        Tables are cached as pickles in ``./.cache/dfd-clip/videos/<df_type>-<comp>.pkl``. Only the
        local main process (or any process when ``accelerator`` is None) writes the cache.
        NOTE(review): the cache path does not include ``root_dir``, so datasets with different
        roots but the same type/compression share one cache file.
        """
        self.video_table = {}
        # A missing accelerator is treated as a single, main process.
        is_main = accelerator is None or accelerator.is_local_main_process

        progress_bar = tqdm(self.types, disable=not is_main)
        for df_type in progress_bar:
            self.video_table[df_type] = {}
            for comp in self.compressions:
                progress_bar.set_description(f"{df_type}: {comp}/videos")
                video_cache = path.expanduser(f'./.cache/dfd-clip/videos/{df_type}-{comp}.pkl')
                if path.isfile(video_cache):
                    # The cache is produced by this method; never point it at untrusted files,
                    # since unpickling can execute arbitrary code.
                    with open(video_cache, 'rb') as f:
                        self.video_table[df_type][comp] = pickle.load(f)
                    continue

                subdir = path.join(self.root, self.TYPE_DIRS[df_type], f'{comp}/videos')
                video_metas = self._scan_videos(subdir)

                if is_main:
                    makedirs(path.dirname(video_cache), exist_ok=True)
                    with open(video_cache, 'wb') as f:
                        pickle.dump(video_metas, f)

                self.video_table[df_type][comp] = video_metas

    def _build_video_list(self, accelerator):
        """Select the videos of ``self.split`` and precompute cumulative clip counts.

        Fills ``self.video_list`` with ``(df_type, comp, video_name, clips)`` tuples, where
        ``clips`` is the number of whole ``clip_duration`` windows in the video. Fills
        ``self.stack_video_clips`` with the running total of ``clips``: entry i is the number of
        clips in videos 0..i. Videos named by the split but missing from ``video_table`` are
        reported and skipped.
        """
        self.video_list = []
        log = print if accelerator is None else accelerator.print

        with open(path.join(self.root, 'splits', f'{self.split}.json')) as f:
            idxs = json.load(f)

        # Each split entry is a sequence of video ids (presumably a pair). 'REAL' videos are the
        # ids themselves; other types are named by joining the ids with '_' in either order.
        real_names = [name for entry in idxs for name in entry]
        fake_names = (['_'.join(entry) for entry in idxs]
                      + ['_'.join(reversed(entry)) for entry in idxs])

        for df_type in self.types:
            names = real_names if df_type == 'REAL' else fake_names
            for comp in self.compressions:
                table = self.video_table[df_type][comp]
                for name in names:
                    if name in table:
                        clips = int(table[name]["duration"] // self.clip_duration)
                        self.video_list.append((df_type, comp, name, clips))
                    else:
                        log(f'Warning: video {path.join(self.root, self.TYPE_DIRS[df_type], comp, "videos", name)} does not present in the processed dataset.')

        self.stack_video_clips = []
        total_clips = 0
        for *_, clips in self.video_list:
            total_clips += clips
            self.stack_video_clips.append(total_clips)

    def __len__(self):
        """Total number of clips; 0 when no video of the split was found."""
        return self.stack_video_clips[-1] if self.stack_video_clips else 0

    def __getitem__(self,idx):
        """Return ``(frames, label, mask)`` for clip ``idx`` (see ``get_dict``)."""
        result = self.get_dict(idx)
        return result["frames"],result["label"],result["mask"]

    def _locate(self, idx):
        """Map a clip index to ``(position in video_list, clip start offset in seconds)``."""
        # First video whose cumulative clip count exceeds idx; videos with 0 clips are skipped.
        video_idx = bisect.bisect_right(self.stack_video_clips, idx)
        clips_before = self.stack_video_clips[video_idx - 1] if video_idx > 0 else 0
        return video_idx, (idx - clips_before) * self.clip_duration

    def _read_clip_frames(self, video_meta, offset_duration):
        """Decode up to ``num_frames`` evenly spaced RGB frames from one clip of a video.

        Frames are taken from the ``fps * clip_duration`` frames starting ``offset_duration``
        seconds into the video, including the first and last of them. Each frame is
        channel-first; the frames are stacked along dim 0.

        Raises:
            RuntimeError: if the decoder runs out of frames before the clip ends.
        """
        fps = video_meta["fps"]
        start_frame = int(fps * offset_duration)
        clip_samples = int(fps * self.clip_duration)
        # Spacing that puts the first and last clip frames among the samples. With a single
        # frame requested only the first frame is taken (avoids dividing by zero).
        stride = (clip_samples - 1) / (self.num_frames - 1) if self.num_frames > 1 else 0

        frames = []
        cap = cv2.VideoCapture(video_meta["path"])
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            next_sample = 0
            for sample_idx in range(clip_samples):
                ret, frame = cap.read()
                if not ret:
                    raise RuntimeError(
                        f'could not read frame {start_frame + sample_idx} of {video_meta["path"]}')
                if sample_idx == next_sample:
                    # OpenCV decodes BGR (H, W, C); convert to RGB (C, H, W).
                    frames.append(torch.from_numpy(
                        cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).transpose((2, 0, 1))))
                    next_sample = int(round(len(frames) * stride))
        finally:
            cap.release()
        return torch.stack(frames)

    def _pad_and_mask(self, frames):
        """Zero-pad ``frames`` along dim 0 up to ``num_frames`` and build the validity mask.

        Returns ``(frames, mask)``. ``mask`` is a bool tensor that is True for decoded frames and
        False for padding. Padding uses the dtype/device of ``frames``, so it also works after a
        ``transform`` changed them.
        """
        n_valid = len(frames)
        mask = torch.arange(max(self.num_frames, n_valid)) < n_valid
        if n_valid < self.num_frames:
            padding = frames.new_zeros((self.num_frames - n_valid, *frames.shape[1:]))
            frames = torch.cat((frames, padding))
        return frames, mask

    def get_dict(self,idx):
        """Load clip ``idx`` as ``{"frames": ..., "label": ..., "mask": ...}``.

        Any exception raised while loading a clip is logged. The load is then retried with a
        uniformly random index, so the returned sample is not guaranteed to be clip ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``.
        """
        if not 0 <= idx < len(self):
            # Fail loudly instead of silently substituting a random clip; this also lets the
            # legacy ``for item in dataset`` iteration protocol terminate.
            raise IndexError(f'clip index {idx} out of range for {len(self)} clips')
        while True:
            try:
                video_idx, offset_duration = self._locate(idx)
                df_type, comp, video_name, _ = self.video_list[video_idx]
                video_meta = self.video_table[df_type][comp][video_name]

                frames = self._read_clip_frames(video_meta, offset_duration)
                if self.transform:
                    frames = self.transform(frames)
                frames, mask = self._pad_and_mask(frames)

                return {
                    "frames": frames,
                    "label": 0 if df_type == "REAL" else 1,
                    "mask": mask,
                }
            except Exception as e:
                logging.error(f"Error occur: {e}")
                idx = random.randrange(0,len(self))

# Load the experiment config; the inspection cells below read `c.model` and `c.trainer.metrics`.
c = get_config("./configs/mix.yml")

# NOTE: the clip-checking cell further down iterates over `x`, which is only created by the
# commented-out line below. It needs an `accelerator` object (presumably from `init_accelerator`,
# imported above — TODO confirm its signature), so on a fresh kernel that cell fails as-is.
# x = FFPP(c.data.train[0],c.data.num_frames,c.data.clip_duration,lambda x: x,accelerator)


In [3]:
# Inspect the model section of the loaded config.
c.model

CfgNode({'architecture': 'ViT-B/16', 'decode_stride': 2, 'dropout': 0.0, 'out_dim': [180, 2], 'losses': [['log_softmax', 'kl_div'], ['auc_roc']]})

In [16]:
# Only the cell's last expression is displayed, so this first line has no visible effect here.
c.trainer.metrics
# Check whether the first `types` entry of the second metric config contains 'options'.
'options' in c.trainer.metrics[1].types[0]

False

In [None]:
# typ,cmp,idx,_ = x.video_list[5]
# x.video_table[typ][cmp][idx]


In [None]:
# frames,label,mask = x[7]
# (len(frames),label,len(mask))

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt
# plt.figure(figsize=(30,150))
# plt.imshow(np.stack(frames[:30].numpy().transpose((0,2,3,1)),axis=1).reshape((150,-1,3)))

In [73]:
# Sanity-check every clip of `x` (an FFPP instance). Its construction line is commented out in
# the setup cell, so create `x` first on a fresh kernel.
# FFPP.get_dict catches loading errors itself: it reports them with logging.error and retries
# with a random index, so `x[i]` does not raise for a broken clip. This cell therefore also
# captures the ERROR log records emitted while loading each index.
class _ErrorRecordCollector(logging.Handler):
    """Logging handler that stores every ERROR-or-higher record it receives."""

    def __init__(self):
        super().__init__(level=logging.ERROR)
        self.records = []

    def emit(self, record):
        self.records.append(record)


failed_indices = []
_collector = _ErrorRecordCollector()
logging.getLogger().addHandler(_collector)
try:
    for i in tqdm(range(len(x))):
        _collector.records.clear()
        try:
            x[i]
        except Exception as e:
            failed_indices.append(i)
            print(f"Error Occur at {i}:{e}")
            continue
        if _collector.records:
            # The first record belongs to index i; later ones come from random retry indices.
            failed_indices.append(i)
            print(f"Error Occur at {i}:{_collector.records[0].getMessage()}")
finally:
    logging.getLogger().removeHandler(_collector)
failed_indices


100%|██████████| 2059/2059 [05:21<00:00,  6.40it/s]
