In [1]:
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
import torch.nn as nn
from tqdm.notebook import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
class DeepFake(nn.Module):
    def __init__(self):
        super().__init__()
        self.inception = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
        self.fc1 = nn.Linear(1000, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p = 0.2)
        self.fc2 = nn.Linear(256, 1)
        
    def forward(self, x : torch.Tensor):
        output, aux = self.inception(x)
        output = self.fc1(output)
#         print('fc1 shape', output.shape)
        output = self.relu(output)
        output = self.dropout(output)
        output = self.fc2(output)
#         print('fc2 shape', output.shape)
        output = torch.sigmoid(output)
#         print('final shape', output)
        return output

In [4]:
from collections import defaultdict

def def_val():
  return 0

outer_path = '/kaggle/input/celeb-df-ga/processed-dataset-ga/data/'
all_labels = defaultdict(def_val)
subfolders = ['Celeb-real','Celeb-synthesis','YouTube-real']
for subfolder in subfolders:
  for file in os.listdir('/kaggle/input/celeb-df-ga/processed-dataset-ga/data/' + subfolder):
    if subfolder=='Celeb-real' or subfolder=='YouTube-real':
      all_labels[outer_path + subfolder + '/' + file] = 0
    else:
      all_labels[outer_path + subfolder + '/' + file] = 1

In [5]:
all_video_ids = []
all_video_ids.extend(os.listdir('/kaggle/input/celeb-df-ga/processed-dataset-ga/data/Celeb-real'))
all_video_ids.extend(os.listdir('/kaggle/input/celeb-df-ga/processed-dataset-ga/data/Celeb-synthesis'))
all_video_ids.extend(os.listdir('/kaggle/input/celeb-df-ga/processed-dataset-ga/data/YouTube-real'))

test_df = pd.read_csv('/kaggle/input/test-files/test.txt', sep=' ', header=None, names=['label','video_id'])

test_video_ids = test_df['video_id']
test_video_ids = [outer_path + id[:-4] for id in test_video_ids]

train_video_ids = list(set(all_video_ids).difference(set(test_video_ids)))
print(len(train_video_ids),len(test_video_ids),len(all_video_ids))

1203 100 1203


In [6]:
train_video_ids = list(set(all_labels.keys()) - set(test_video_ids))

In [7]:
preprocess = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [8]:
X_train_paths = []
for video in train_video_ids:
    frames = os.listdir(video)
    X_train_paths.extend([str(video) + '/' + i for i in frames])
    
y_train = []
for frames in X_train_paths:
    video_id = frames.split('/')
    video_path = '/'.join(video_id[:-1])
    y_train.append(all_labels[video_path])
# y_train = y_train[:1000]
y_train = torch.FloatTensor(y_train)
y_train = y_train.unsqueeze(1)

# X_train_paths = X_train_paths[:1000]


X_test_lens = []
X_test_paths = []
for video in test_video_ids:
    frames = os.listdir(video)
    X_test_lens.append(len(frames))
    X_test_paths.extend([str(video) + '/' + i for i in frames])
# X_test = []
# for path in X_test_paths[:302]:
#     input_image = Image.open(path)
#     input_tensor = preprocess(input_image)
#     X_test.append(input_tensor)
# X_test = torch.stack(X_test)
# print(X_test_lens)

y_test = []
for frames in X_test_paths:
    video_id = frames.split('/')
    video_path = '/'.join(video_id[:-1])
    y_test.append(all_labels[video_path])
y_test = torch.FloatTensor(y_test)
y_test = y_test.unsqueeze(1)

print(sum(X_test_lens[:10]))

250


In [9]:
y_train = y_train.to(device)
# X_test = X_test.to(device)
y_test = y_test.to(device)

In [10]:
class DFDataloader(Dataset):
  """
  This is a custom dataset class. It can get more complex than this, but simplified so you can understand what's happening here without
  getting bogged down by the preprocessing
  """
  def __init__(self, X_train_paths, y):
    self.y = y
    self.X_train_paths = X_train_paths
    if len(self.X_train_paths) != len(self.y):
      print(len(self.X_train_paths), len(self.y))
      raise Exception("The length of X does not match the length of Y")

  def __len__(self):
    return len(self.y)

  def __getitem__(self, index):
    # note that this isn't randomly selecting. It's a simple get a single item that represents an x and y
    input_image = Image.open(self.X_train_paths[index])
    input_tensor = preprocess(input_image)
    _x = input_tensor
    _y = self.y[index]
    return _x, _y

In [11]:
model = DeepFake()
model.to(device)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

DeepFake(
  (inception): Inception3(
    (Conv2d_1a_3x3): BasicConv2d(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_2a_3x3): BasicConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_2b_3x3): BasicConv2d(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_3b_1x1): BasicConv2d(
      (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_4a_3x3): Basi

In [12]:
torch.sum(y_test)

tensor(1550., device='cuda:0')

In [13]:
to_train = ['fc2.bias', 'fc2.weight', 'fc1.bias', 'fc1.weight']
for name, param in model.named_parameters():
    param.requires_grad = True if name in to_train else False

In [14]:
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params = model.parameters(), lr = 0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

epochs = 30
for epoch in range(epochs):
    model.train()
    losses = []
    train_dataloader = iter(DataLoader(DFDataloader(X_train_paths, y_train), batch_size=256, shuffle=True))    
    for X_train_dl, y_train_dl in tqdm(train_dataloader):
        X_train_dl = X_train_dl.to(device)
        y_train_dl = y_train_dl.to(device)
        y_pred = model(X_train_dl)
        loss = loss_fn(y_pred, y_train_dl)
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    avg_loss = sum(losses)/len(losses)
    print('Avg Loss: ', avg_loss)
    
    calculated_mean_scores = []
    actual_scores = []
    with torch.no_grad():
        test_dataloader = iter(DataLoader(DFDataloader(X_test_paths, y_test), batch_size=25, shuffle=False))
        for X_test_dl, y_test_dl in test_dataloader:
            X_test_dl = X_test_dl.to(device)
            y_test_dl = y_test_dl.to(device)
            y_pred_test = model(X_test_dl)
#             print(y_test_dl)
            loss_test = loss_fn(y_pred_test, y_test_dl)
            video_pred_scores = y_pred_test.squeeze(1)
            calculated_mean_scores.append(torch.mean(video_pred_scores).item())
            actual_scores.append(y_test_dl[0].squeeze().item())
#         print(calculated_mean_scores)
#         print(actual_scores)
        calculated_mean_scores = [1 if score>=0.5 else 0 for score in calculated_mean_scores]
        acc = torch.sum(torch.eq(torch.FloatTensor(calculated_mean_scores),torch.FloatTensor(actual_scores))).item()/len(actual_scores)
        print('accuracy:', acc)
        scheduler.step(acc)

  0%|          | 0/108 [00:00<?, ?it/s]

Avg Loss:  0.5585768962347949


KeyboardInterrupt: 