In [None]:
# import warnings
# warnings.filterwarnings("ignore")

import gc 
import numpy as np
import pandas as pd
import itertools
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    confusion_matrix,
    classification_report,
    f1_score
)

from imblearn.over_sampling import RandomOverSampler
import accelerate
import evaluate
from datasets import Dataset, Image, ClassLabel
from transformers import (
    TrainingArguments,
    Trainer,
    ViTImageProcessor,
    ViTForImageClassification,
    DefaultDataCollator
)

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomRotation,
    RandomResizedCrop,
    RandomHorizontalFlip,
    RandomAdjustSharpness,
    Resize,
    ToTensor
)

import os
# Quiet TensorFlow's C++ logging (level '2' is intended to suppress INFO and WARNING messages).
# NOTE(review): this only takes effect if TensorFlow is loaded after this line runs;
# consider moving it above the library imports in this cell.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

In [5]:
# Allow PIL to decode truncated / partially written image files instead of raising,
# so a few damaged files in the dataset do not abort data loading.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
from pathlib import Path
from tqdm import tqdm

# Root of an ImageFolder-style dataset: <DATASET_DIR>/.../<class name>/<image file>.
# Set the DEEPFAKE_DATASET_DIR environment variable to point elsewhere instead of
# editing this absolute path.
DATASET_DIR = Path(os.environ.get(
    "DEEPFAKE_DATASET_DIR",
    "/home/uppercase/Workspace/Projects/Deepfake_Detection/dataset",
))
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff"}

# Parallel lists: file_names[i] is an image path, labels[i] its class name.
file_names = []
labels = []

# A Path is not iterable by itself: walk the tree recursively, keep only image
# files (skips stray files such as .DS_Store), and sort for a stable order.
image_paths = sorted(
    path for path in DATASET_DIR.rglob("*")
    if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
)

for path in tqdm(image_paths, desc="Indexing images"):
    # The class label is the name of the folder that directly contains the image.
    labels.append(path.parent.name)
    file_names.append(str(path))

assert file_names, f"No image files found under {DATASET_DIR}"
print(len(file_names), len(labels))

df = pd.DataFrame({"image": file_names, "label": labels})

In [None]:
# Preview the first rows of the (image path, label) table.
df.head()

In [None]:
# Distinct class names taken from the folder structure; they must match the entries
# of `labels_list` defined below, or the label -> id mapping will raise KeyError.
df['label'].unique()

In [None]:
# Balance the classes by randomly duplicating rows of the under-represented class.
# NOTE(review): oversampling happens before the train/test split in a later cell, so
# copies of the same image can end up in both splits and inflate the test metrics.
# Consider splitting first and oversampling only the training portion.
features = df.drop(columns=['label'])  # drop() returns a new frame; keep the result
ros = RandomOverSampler(random_state=83)

features_resampled, labels_resampled = ros.fit_resample(features, df['label'])

# Re-attach the labels by position: both resampled outputs share the same row order.
df = features_resampled.assign(label=np.asarray(labels_resampled))

print(df.shape)


In [None]:
# Preview the table after oversampling (same columns, more rows).
df.head()

In [None]:
# Build a Hugging Face Dataset from the table. Casting "image" to Image() makes each
# stored path decode to a PIL image when accessed. preserve_index=False stops a
# non-default pandas index from being stored as an extra `__index_level_0__` column.
dataset = Dataset.from_pandas(df, preserve_index=False).cast_column("image", Image())

In [None]:
# Sanity check: decode and display the first image in the dataset.
dataset[0]['image']

In [None]:
# Peek at the first few raw (string) labels. Because the file paths were sorted,
# these likely all come from the same class folder.
labels_subset = labels[:5]
print(labels_subset)

In [None]:
labels_list = ['Real', 'Fake']

# Two-way mapping between class names and integer ids; a name's id is its
# position in `labels_list`.
label2id = {name: idx for idx, name in enumerate(labels_list)}
id2label = {idx: name for name, idx in label2id.items()}

print(label2id)
print(id2label)

In [None]:
# Dataset feature describing the classes. Its names come from labels_list, so class
# ids follow the same order as label2id; the class count is inferred from the names.
Class_labels = ClassLabel(names=labels_list)

def map_label2id(example, mapping=None):
    """Replace string class names in ``example['label']`` with integer ids.

    Works for a single row (``label`` is a str) and for the batches that
    ``Dataset.map(..., batched=True)`` passes (``label`` is a list of str).

    Args:
        example: dict-like row or batch with a 'label' entry; updated in place.
        mapping: name -> id lookup. Defaults to the notebook-level ``label2id``.

    Returns:
        The same ``example`` object with its 'label' entry converted.

    Raises:
        KeyError: if a label name is missing from the mapping.
    """
    lookup = label2id if mapping is None else mapping
    value = example['label']
    # In batched mode the column arrives as a list; a list cannot be used as a dict key.
    if isinstance(value, (list, tuple)):
        example['label'] = [lookup[name] for name in value]
    else:
        example['label'] = lookup[value]
    return example

# Convert string class names to integer ids, then declare the column as a ClassLabel
# (stratify_by_column below requires a ClassLabel column).
# NOTE(review): this cell rebinds `dataset`, so re-running it without first re-running
# the Dataset-construction cell will fail.
dataset = dataset.map(map_label2id, batched=True)
dataset = dataset.cast_column("label", Class_labels)

# Stratified 60/40 train/test split. The seed matches the oversampler's random_state
# so the split is reproducible across kernel restarts.
dataset = dataset.train_test_split(
    test_size=0.4,
    shuffle=True,
    seed=83,
    stratify_by_column="label",
)

train_data = dataset['train']
test_data = dataset['test']

In [None]:
# Hub checkpoint to fine-tune. It is bound to `model_str` because the model-loading
# cell reads that name; `model` itself is later rebound to the loaded network.
model_str = "dima806/deepfake_vs_real_image_detection"

processor = ViTImageProcessor.from_pretrained(model_str)

image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size['height']
print(image_mean, image_std, size)

normalize = Normalize(mean=image_mean, std=image_std)

# Training pipeline:
#   1. Resize to the processor's square input size.
#   2. Augment with a random rotation in [-90, 90] degrees and random sharpening.
#   3. Convert to a tensor and normalize with the checkpoint's statistics.
_train_transforms = Compose([
    Resize((size, size)),
    RandomRotation(90),
    RandomAdjustSharpness(2),
    ToTensor(),
    normalize,
])

# Evaluation pipeline: deterministic resize and normalization, no augmentation.
_val_transforms = Compose([
    Resize((size, size)),
    ToTensor(),
    normalize,
])

def _pixel_values_from(images, transform):
    """Convert every PIL image to RGB and run it through ``transform``."""
    return [transform(img.convert("RGB")) for img in images]


def train_transforms(examples):
    """On-the-fly transform for training batches: add augmented 'pixel_values'."""
    examples['pixel_values'] = _pixel_values_from(examples['image'], _train_transforms)
    return examples


def val_transforms(examples):
    """On-the-fly transform for evaluation batches: add plain 'pixel_values'."""
    examples['pixel_values'] = _pixel_values_from(examples['image'], _val_transforms)
    return examples

In [None]:
# `with_transform` returns a NEW dataset rather than modifying in place, so the result
# must be reassigned or the Trainer will never see 'pixel_values'. The transforms run
# lazily whenever rows are accessed; re-running this cell just replaces them.
train_data = train_data.with_transform(train_transforms)
test_data = test_data.with_transform(val_transforms)

In [None]:
def collate_fn(examples):
    """Stack per-example tensors into one batch dict for the Trainer.

    Each example must carry a 'pixel_values' tensor (all the same shape) and an
    integer 'label'. Labels are returned under the plural key 'labels'.
    """
    batch_pixels = [ex['pixel_values'] for ex in examples]
    batch_labels = [ex['label'] for ex in examples]
    return {
        'pixel_values': torch.stack(batch_pixels),
        'labels': torch.tensor(batch_labels),
    }

In [None]:
# Load the classifier head sized to our label set. `model_str` (the Hub checkpoint id)
# must be defined by an earlier cell.
model = ViTForImageClassification.from_pretrained(
    model_str,
    num_labels=len(labels_list),
)

# Attach the notebook's class-name <-> id mappings to the model config.
model.config.id2label = id2label
model.config.label2id = label2id

# Trainable parameter count, in millions.
trainable_params_millions = model.num_parameters(only_trainable=True) / 1e6
print(trainable_params_millions)

In [None]:
def compute_metrics(eval_pred):
    """Compute classification accuracy for the Hugging Face Trainer.

    Args:
        eval_pred: object exposing ``.predictions`` (per-class scores, one row per
            sample) and ``.label_ids`` (integer class ids), as the Trainer's
            evaluation-prediction object does.

    Returns:
        ``{"accuracy": float}`` — fraction of samples whose arg-max class equals
        the label.
    """
    predictions = eval_pred.predictions
    # Some models return a tuple (logits, extra outputs); the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    labels = np.asarray(eval_pred.label_ids)

    predicted_labels = np.argmax(predictions, axis=1)
    # Plain numpy: avoids sklearn/evaluate keyword mismatches and the module-level
    # `accuracy` name that a later cell rebinds to a float.
    acc_score = float(np.mean(predicted_labels == labels))

    return {
        "accuracy": acc_score
    }

In [None]:
# Training configuration. `metric_name` selects which evaluation metric picks the best
# checkpoint when load_best_model_at_end=True.
metric_name = "accuracy"
model_name = "deepfake_detection"

num_train_epochs = 5

args = TrainingArguments(
    output_dir=model_name,
    logging_dir='./logs',
    # NOTE(review): recent transformers releases name this `eval_strategy`; rename it
    # if TrainingArguments rejects this keyword.
    evaluation_strategy="epoch",
    # Must match the evaluation strategy for load_best_model_at_end to work.
    save_strategy="epoch",
    # NOTE(review): very low learning rate — presumably deliberate for a light
    # fine-tune of an already task-specific checkpoint.
    learning_rate=1e-6,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=8,
    num_train_epochs=num_train_epochs,
    weight_decay=0.02,
    warmup_steps=50,
    # Keep the raw 'image' column so the on-the-fly transforms can read it.
    remove_unused_columns=False,
    load_best_model_at_end=True,
    # Without this the best checkpoint is chosen by eval loss, and metric_name is unused.
    metric_for_best_model=metric_name,
    save_total_limit=1,
    report_to=["none"],
)


In [None]:
# Wire the model, data, batching and metrics into the Trainer.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_data,
    eval_dataset=test_data,
    compute_metrics=compute_metrics,
    # NOTE(review): newer transformers versions call this argument `processing_class`.
    tokenizer=processor,
)

In [None]:
# Fine-tune. With the configured strategies, evaluation and checkpointing run once per epoch.
trainer.train()

In [None]:
# Evaluate on the eval split. Because load_best_model_at_end=True, this uses the
# best checkpoint's weights.
trainer.evaluate()

In [None]:
# Run inference on the test split. `outputs` carries raw predictions, label_ids and
# the computed metrics.
outputs = trainer.predict(test_data)
print(outputs.metrics)

In [None]:
# Ground-truth ids, and the highest-scoring class id for each test prediction.
y_true = outputs.label_ids
y_pred = np.argmax(outputs.predictions, axis=-1)


def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues, figsize=(10, 8)):
    """Draw a confusion matrix as an annotated heatmap and show it.

    Args:
        cm: 2-D array of counts; rows are true classes, columns predicted classes.
        classes: tick labels, in the same order as the matrix rows/columns.
        title: figure title.
        cmap: matplotlib colormap for the heatmap.
        figsize: figure size in inches.
    """
    fig, ax = plt.subplots(figsize=figsize)
    heatmap = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    fig.colorbar(heatmap, ax=ax)

    positions = np.arange(len(classes))
    ax.set_xticks(positions)
    ax.set_xticklabels(classes, rotation=90)
    ax.set_yticks(positions)
    ax.set_yticklabels(classes)

    # Annotate each cell: white text on cells above half the max count, black otherwise.
    threshold = cm.max() / 2.0
    n_rows, n_cols = cm.shape
    for row in range(n_rows):
        for col in range(n_cols):
            value = cm[row, col]
            ax.text(
                col,
                row,
                format(value, '.0f'),
                horizontalalignment="center",
                color="white" if value > threshold else "black",
            )

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')

    fig.tight_layout()
    plt.show()

# Headline metrics on the test split (macro F1 weights both classes equally).
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro')

print("Accuracy: {:.4f}".format(accuracy))
print("F1 Score: {:.4f}".format(f1))

# Only draw the heatmap when there are at most 150 classes.
max_heatmap_classes = 150
if len(labels_list) <= max_heatmap_classes:
    cm = confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(cm, labels_list, figsize=(8, 6))

print("\nClassification report:\n")
print(classification_report(y_true, y_pred, target_names=labels_list, digits=4))
