In [None]:
pip install --upgrade fsspec

In [None]:
import os
import torch
import pandas as pd
import random
import shutil
import numpy as np
from datasets import load_dataset, load_metric
from skimage import io
from torch.utils.data import Dataset
from PIL import ImageDraw, ImageFont, Image
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer

# Dataset building (skip if already done)

In [None]:
# Absolute path of the folder holding the raw .jpg images. The two copy cells below list this
# folder and copy each file whose name appears in train.csv / test.csv (`complete_image_id`).
path = "/home/jovyan/hfactory_magic_folders/tooling_for_the_data_scientist/deepfakes_detection/images"

In [None]:
train_df = pd.read_csv('/home/jovyan/project/deepfakes-detection-3-mousketeers/data/train.csv')
test_df = pd.read_csv('/home/jovyan/project/deepfakes-detection-3-mousketeers/data/test.csv')

# File name on disk for each row: "<image_id>.jpg".
# Vectorised string concatenation avoids a row-wise `apply(axis=1)`, which makes one Python call
# per row; `astype(str)` also tolerates ids that pandas parsed as numbers.
train_df['complete_image_id'] = train_df['image_id'].astype(str) + '.jpg'
test_df['complete_image_id'] = test_df['image_id'].astype(str) + '.jpg'

In [None]:
destination_train_fake = '/home/jovyan/project/deepfakes-detection-3-mousketeers/folder_normalized/train/fake/'
destination_train_not_fake = '/home/jovyan/project/deepfakes-detection-3-mousketeers/folder_normalized/train/not_fake/'

# shutil.copyfile does not create missing parent directories, so make sure both class folders exist.
os.makedirs(destination_train_fake, exist_ok=True)
os.makedirs(destination_train_not_fake, exist_ok=True)

# The lookup below assumes exactly one label per file name.
assert train_df['complete_image_id'].is_unique, "duplicate image ids in train.csv"

# Build the file name -> label map once: dict membership is O(1), whereas rebuilding a list of all
# ids and re-filtering the DataFrame for every file in the folder is quadratic.
train_labels = dict(zip(train_df['complete_image_id'], train_df['label'].astype(int)))

for file in os.listdir(path):
    if file not in train_labels:
        continue  # file is not listed in train.csv
    # label 1 -> "fake" folder, any other label -> "not_fake" folder.
    destination = destination_train_fake if train_labels[file] == 1 else destination_train_not_fake
    shutil.copyfile(os.path.join(path, file), os.path.join(destination, file), follow_symlinks=True)

In [None]:
destination_test = '/home/jovyan/project/deepfakes-detection-3-mousketeers/folder_normalized/test/'
# shutil.copyfile does not create missing parent directories.
os.makedirs(destination_test, exist_ok=True)

# Build the id set once: set membership is O(1), versus scanning a freshly built list per file.
test_image_ids = set(test_df['complete_image_id'])
for file in os.listdir(path):
    if file in test_image_ids:
        shutil.copyfile(os.path.join(path, file), os.path.join(destination_test, file), follow_symlinks=True)

# Preprocessing / Training

In [None]:
# Load the class-folder layout written by the copy cell above (train/fake, train/not_fake);
# the `imagefolder` builder turns the sub-folder names into the "label" feature.
dataset = load_dataset("imagefolder", data_dir="/home/jovyan/project/deepfakes-detection-3-mousketeers/folder_normalized/train", split="train")
# Hold out 20% of the labelled images for evaluation. A fixed seed makes the split, and therefore
# the evaluation metrics, reproducible across kernel restarts.
ds = dataset.train_test_split(test_size=0.2, seed=42)

### View a few examples (optional)

In [None]:
def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):
    """Build a preview grid with a few random training images per class.

    The grid has one row per label of ``ds['train']``. Each row holds up to
    ``examples_per_class`` images resized to ``size``, each captioned with the label name.

    Args:
        ds: dataset dict whose ``'train'`` split has an ``'image'`` column and a ``'label'``
            column whose feature exposes ``.names`` (a ClassLabel-style feature).
        seed: shuffle seed; the same seed yields the same preview.
        examples_per_class: maximum number of images shown per label. Classes with fewer
            images simply leave the rest of their row empty.
        size: ``(width, height)`` of each thumbnail in pixels.

    Returns:
        An RGB ``PIL.Image.Image`` of ``examples_per_class * width`` by ``len(labels) * height``.
    """
    w, h = size
    train = ds['train']
    labels = train.features['label'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)
    except OSError:
        # Font not installed on this machine: fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    for label_id, label in enumerate(labels):
        # Filter on the label column only, so images are not decoded just to be discarded.
        ds_slice = train.filter(lambda lbl: lbl == label_id, input_columns='label').shuffle(seed=seed)
        # select() raises if asked for more rows than exist, so cap at the class size.
        n_shown = min(examples_per_class, len(ds_slice))

        # Row `label_id`, column `i`.
        for i, example in enumerate(ds_slice.select(range(n_shown))):
            box = (i * w, label_id * h)
            grid.paste(example['image'].resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid

# A fresh random seed on each run, so every execution previews a different sample.
preview_seed = random.randint(0, 1337)
show_examples(ds, seed=preview_seed, examples_per_class=3)

### Model training

In [None]:
labels = ds["train"].features["label"].names
# Two-way mapping between class index and class name, passed to the model config below.
# Ids are stored as strings on both sides.
id2label = {str(idx): name for idx, name in enumerate(labels)}
label2id = {name: idx_str for idx_str, name in id2label.items()}

In [None]:
from transformers import AutoFeatureExtractor

# Preprocessing configuration (image size, normalisation mean/std) of the checkpoint fine-tuned below.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224-in21k",
)

In [None]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

# `feature_extractor.size` is an int in older transformers releases but a dict in newer ones
# ({"height": H, "width": W} or {"shortest_edge": S}). RandomResizedCrop needs an int or an
# (h, w) sequence, so normalise it before building the pipeline.
_size_cfg = feature_extractor.size
if isinstance(_size_cfg, dict):
    if "shortest_edge" in _size_cfg:
        crop_size = _size_cfg["shortest_edge"]
    else:
        crop_size = (_size_cfg["height"], _size_cfg["width"])
else:
    crop_size = _size_cfg

# Normalise with the statistics the checkpoint expects.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
# Random crop + resize (augmentation) -> tensor -> normalisation.
_transforms = Compose([RandomResizedCrop(crop_size), ToTensor(), normalize])

In [None]:
def transforms(examples):
    """Batch transform for ``with_transform``: turn raw images into normalised pixel tensors.

    Converts every image of the batch to RGB and runs it through the module-level
    ``_transforms`` pipeline. The results are stored under ``"pixel_values"`` and the raw
    ``"image"`` column is dropped. Mutates and returns ``examples``.
    """
    pixel_values = []
    for picture in examples["image"]:
        pixel_values.append(_transforms(picture.convert("RGB")))
    examples["pixel_values"] = pixel_values
    del examples["image"]
    return examples

In [None]:
# Attach `transforms` as an on-the-fly transform for both splits of `ds`.
# NOTE(review): the test split gets the same RandomResizedCrop augmentation, so eval metrics vary
# between runs; a deterministic Resize/CenterCrop pipeline for the test split would be more reliable.
# NOTE(review): `transforms` deletes the "image" column, so re-running show_examples on this
# rebound `ds` will likely fail; reload the dataset first.
ds = ds.with_transform(transform=transforms)

In [None]:
from transformers import DefaultDataCollator

# Collates the per-example dicts produced by `transforms` into PyTorch batches
# (return_tensors="pt" is the class default, spelled out here).
data_collator = DefaultDataCollator(return_tensors="pt")

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Image-classification model on top of the same checkpoint as the feature extractor, with one
# output per class and the label maps built above stored in its config.
model = AutoModelForImageClassification.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224-in21k",
    label2id=label2id,
    id2label=id2label,
    num_labels=len(labels),
)

In [None]:
def compute_metrics(eval_pred):
    """Return the accuracy of arg-max predictions for the Trainer's evaluation loop.

    Args:
        eval_pred: object passed in by the Trainer. ``.predictions`` holds the logits
            (possibly wrapped in a tuple) and ``.label_ids`` the integer class ids.

    Returns:
        ``{"accuracy": float}``. The Trainer logs it as ``eval_accuracy``.
    """
    logits = eval_pred.predictions
    if isinstance(logits, tuple):  # extra model outputs may accompany the logits
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == eval_pred.label_ids).mean())}


training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`; kept as-is for
    # compatibility with the version this notebook targets.
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=False,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Track the checkpoint with the best eval accuracy (kept even under save_total_limit) and
    # reload it when training ends, rather than finishing on whatever the last step produced.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Keep the "image" column: `transforms` needs it to build "pixel_values".
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    # NOTE(review): newer transformers releases call this argument `processing_class`.
    tokenizer=feature_extractor,
    # Without this, periodic evaluation reports only the loss.
    compute_metrics=compute_metrics,
)

trainer.train()