In [2]:
import os
import random

import numpy as np
import torch
import torchvision
from PIL import Image
from sklearn.model_selection import train_test_split as TTS
from tqdm import tqdm

In [13]:
device = "cpu"

class frame():
  '''
  One labelled image, loaded from disk and transformed eagerly at construction.

  Attributes:
    imgpath: path the image was read from.
    label: class directory name ("fake" or "real" in this notebook).
    cat: sub-dataset name (e.g. "Deepfakes", "youtube").
    X: output of `transform` applied to the RGB-converted image.
    y: binary target, 1 when `label == "fake"`, otherwise 0.
  '''

  def __init__(self, imgpath, label, category, transform):
    self.imgpath = imgpath
    self.label = label
    self.cat = category

    # Use a context manager so the file handle is released as soon as the
    # RGB copy exists. NOTE: the transform runs exactly once, here, so any
    # random augmentation inside it is drawn once per image, not per epoch.
    with Image.open(imgpath) as img:
      self.X = transform(img.convert('RGB'))  # .to(device)

    # The target must come from the label ("fake"/"real"). The category holds
    # sub-dataset names such as "Deepfakes", so testing it against "fake"
    # silently labelled every sample as real (y = 0).
    self.y = self._label_to_target(label)

  @staticmethod
  def _label_to_target(label):
    '''Map a class label to the binary target: 1 for "fake", 0 for anything else.'''
    return 1 if label == "fake" else 0

  def describe(self):
    '''Return (imgpath, label, category) for inspection.'''
    return self.imgpath, self.label, self.cat

  def values(self):
    '''Return the (X, y) training pair.'''
    return self.X, self.y

In [14]:
SIZE = 224

# Random augmentations used for training images only. frame() applies its
# transform once when the image is loaded, so these are sampled a single time
# per image rather than re-drawn on every pass over the data.
# (A Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) step is currently disabled.)
_train_augmentations = [
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(10),
    torchvision.transforms.ColorJitter(
        brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
    ),
]

# Training: to tensor -> augment -> resize to SIZE x SIZE.
train_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
    + _train_augmentations
    + [torchvision.transforms.Resize((SIZE, SIZE), antialias=True)]
)

# Evaluation: deterministic, to tensor -> resize only.
test_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((SIZE, SIZE), antialias=True),
    ]
)

def _build_frames(path, names, label, category, transform):
  '''Create one `frame` per file name in `names` (relative to `path`).'''
  return [frame(imgpath = os.path.join(path, name),
                label = label,
                category = category,
                transform = transform)
          for name in tqdm(names)]


def load(root, label, category, ftype, train_transform, test_transform, test_size = 0.2, sample = None, random_state = None):
  '''
  Load images from root/label/category/ftype as frame objects and train/test split them.

  Args:
    root: dataset root directory.
    label: class directory name (e.g. "fake" or "real").
    category: sub-dataset directory name (e.g. "Deepfakes").
    ftype: leaf directory holding the image files (e.g. "faces").
    train_transform: transform applied to images in the training split.
    test_transform: transform applied to images in the test split.
    test_size: forwarded to sklearn's train_test_split.
    sample: if truthy, number of files to draw at random (capped at the
      number available); otherwise every file in the directory is used.
    random_state: forwarded to train_test_split for a reproducible split.
      File subsampling uses Python's `random` module, so seed that separately.

  Returns:
    (train_frames, test_frames): two lists of frame objects.

  Raises:
    ValueError: if the directory contains no entries.
  '''
  path = os.path.join(root, label, category, ftype)

  # os.listdir order is filesystem-dependent; sort so a seeded run is reproducible.
  names = sorted(os.listdir(path))
  if not names:
    raise ValueError(f"no files found in {path}")

  # random.sample raises when asked for more items than exist, so cap the request.
  sample_size = min(sample, len(names)) if sample else len(names)
  chosen = random.sample(names, sample_size)

  train_names, test_names = TTS(chosen, test_size = test_size, random_state = random_state)

  train_frames = _build_frames(path, train_names, label, category, train_transform)
  test_frames = _build_frames(path, test_names, label, category, test_transform)
  return train_frames, test_frames

In [15]:
# Root of the FF++ extract. Set the FFPP_ROOT environment variable to load from
# another location instead of editing this machine-specific absolute path.
ROOT_PATH = os.environ.get(
    "FFPP_ROOT",
    "C:/Users/chuag/OneDrive - Nanyang Technological University/Desktop/BCG 4.2/FYP/code/data/FF++",
)

# Seed every RNG the loading step touches: `random` (file subsampling in load),
# numpy (sklearn's train_test_split when no random_state is given) and torch
# (torchvision's random augmentations, applied once at load time).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

TEST_SIZE = 0.2
FTYPE = "faces"

# subset name -> (label, category directory, number of images to sample)
FFPP_SUBSETS = {
    "DFD": ("fake", "DeepFakeDetection", 1000),
    "DF": ("fake", "Deepfakes", 1000),
    "F2F": ("fake", "Face2Face", 1000),
    "SHIFTER": ("fake", "FaceShifter", 1000),
    "SWAP": ("fake", "FaceSwap", 1000),
    "NT": ("fake", "NeuralTextures", 1000),
    "YT": ("real", "youtube", 2000),
    "ACTORS": ("real", "actors", 1000),
}

# Dicts keep insertion order, so subsets load in the order listed above.
splits = {
    name: load(ROOT_PATH, label, category, FTYPE, train_transform, test_transform, TEST_SIZE, n)
    for name, (label, category, n) in FFPP_SUBSETS.items()
}

# Per-subset (train, test) lists, kept under their original names for use by other cells.
DFD_train, DFD_test = splits["DFD"]
DF_train, DF_test = splits["DF"]
F2F_train, F2F_test = splits["F2F"]
SHIFTER_train, SHIFTER_test = splits["SHIFTER"]
SWAP_train, SWAP_test = splits["SWAP"]
NT_train, NT_test = splits["NT"]
YT_train, YT_test = splits["YT"]
ACTORS_train, ACTORS_test = splits["ACTORS"]

# Concatenate subset by subset, in FFPP_SUBSETS order.
training = [f for train_part, _ in splits.values() for f in train_part]
testing = [f for _, test_part in splits.values() for f in test_part]

  0%|          | 0/800 [00:00<?, ?it/s]

100%|██████████| 800/800 [00:54<00:00, 14.55it/s]
100%|██████████| 200/200 [00:02<00:00, 70.80it/s]
100%|██████████| 800/800 [00:56<00:00, 14.27it/s]
100%|██████████| 200/200 [00:03<00:00, 59.35it/s]
100%|██████████| 800/800 [01:01<00:00, 12.98it/s]
100%|██████████| 200/200 [00:03<00:00, 53.23it/s]
100%|██████████| 800/800 [01:09<00:00, 11.49it/s]
100%|██████████| 200/200 [00:03<00:00, 56.82it/s]
100%|██████████| 800/800 [01:15<00:00, 10.63it/s]
100%|██████████| 200/200 [00:03<00:00, 57.30it/s]
100%|██████████| 800/800 [01:06<00:00, 12.06it/s]
100%|██████████| 200/200 [00:05<00:00, 39.67it/s]
100%|██████████| 1600/1600 [02:19<00:00, 11.43it/s]
100%|██████████| 400/400 [00:07<00:00, 53.06it/s]
100%|██████████| 800/800 [01:14<00:00, 10.76it/s]
100%|██████████| 200/200 [00:03<00:00, 54.26it/s]


In [16]:
# Interleave frames from the different subsets (they were concatenated subset
# by subset). Shuffles in place: training first, then testing.
for split in (training, testing):
  random.shuffle(split)