# Deep Fake Detection. ResNet18

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import models
import torchvision.transforms as transforms

from torch.optim import lr_scheduler

from torch.utils.data import random_split, DataLoader
from dataset_handlers.resnet.resnet_feature_dataset import FeatureDataset

from sklearn.model_selection import KFold

import os
from PIL import Image
from matplotlib import pyplot as plt

El modelo resnet18 es un modelo de red neuronal convolucional que fue entrenado para clasificar imágenes en 1000 clases. En este caso, se utilizará el modelo pre-entrenado para clasificar imágenes en 2 clases: real y fake. Se trata de una arquitectura de red residual, que permite entrenar redes más profundas sin que se produzca el problema del desvanecimiento del gradiente. Para más información, consultar el artículo [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).

## Transfer Learning

In [74]:
classifier = nn.Sequential(
    nn.Linear(512, 128),
    nn.Tanh(),
    nn.Dropout(0.5),
    nn.BatchNorm1d(128),
    nn.Linear(128, 64),
    nn.Tanh(),
    nn.Dropout(0.5),
    nn.BatchNorm1d(64),
    nn.Linear(64, 2)
)

In [75]:
dataset = FeatureDataset(
    root_path='data/real_and_fake_restnet',
    transform=None
    )

In [76]:
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

In [78]:
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=True)

### Entrenamiento del modelo

In [79]:
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classifier.to(device)

Sequential(
  (0): Linear(in_features=512, out_features=128, bias=True)
  (1): Tanh()
  (2): Dropout(p=0.5, inplace=False)
  (3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): Linear(in_features=128, out_features=64, bias=True)
  (5): Tanh()
  (6): Dropout(p=0.5, inplace=False)
  (7): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): Linear(in_features=64, out_features=2, bias=True)
)

In [80]:
num_epochs = 20

l1_factor = 0.0001
l2_factor = 0.001

k_folds = 1

In [81]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(classifier.parameters(), lr=0.0001, weight_decay=l2_factor)

In [82]:
def accuracy(data_loader, model):
    acc = 0
    for i, (image, label) in enumerate(data_loader, 1):
        image = image.to(device)
        label = label.to(device)

        output = classifier(image)
        acc += (torch.argmax(output, dim=1) == label).sum().item()

    return acc / len(data_loader.dataset)

In [83]:
def train_model(model, critereon, optimizer, train_loader, test_loader, num_epochs, l1_factor):
    acc_training_set = []
    acc_val_set = []

    for epoch in range(num_epochs):
        for i, (image, label) in enumerate(train_loader, 1):
            image = image.to(device)
            label = label.to(device)

            output = model(image)
            loss = critereon(output, label)

            l1_regularization = torch.tensor(0., requires_grad=False)
            for param in model.parameters():
                l1_regularization += torch.norm(param, 1)

            l1_regularization.requires_grad_(True)
            loss += l1_factor * l1_regularization

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                acc_training_set.append(accuracy(train_loader, model))
                acc_val_set.append(accuracy(test_loader, model))

                print('Epoch: {:2.0f}/{}, Batch: {:2.0f}, Loss: {:.6f}, Acc (train): {:.6f}, Acc (val): {:.6f}'
                    .format(epoch+1, num_epochs, i, loss.item(), acc_training_set[-1], acc_val_set[-1]))

In [84]:
train_model(classifier, criterion, optimizer, train_loader, test_loader, num_epochs, l1_factor)

Epoch:  1/20, Batch: 10, Loss: 0.878209, Acc (train): 0.492034, Acc (val): 0.486553
Epoch:  1/20, Batch: 20, Loss: 0.988793, Acc (train): 0.505515, Acc (val): 0.491443
Epoch:  2/20, Batch: 10, Loss: 0.996351, Acc (train): 0.478554, Acc (val): 0.491443
Epoch:  2/20, Batch: 20, Loss: 1.014334, Acc (train): 0.496936, Acc (val): 0.488998
Epoch:  3/20, Batch: 10, Loss: 0.928588, Acc (train): 0.506127, Acc (val): 0.491443
Epoch:  3/20, Batch: 20, Loss: 1.018288, Acc (train): 0.477941, Acc (val): 0.479218
Epoch:  4/20, Batch: 10, Loss: 0.975922, Acc (train): 0.516544, Acc (val): 0.530562
Epoch:  4/20, Batch: 20, Loss: 0.893362, Acc (train): 0.494485, Acc (val): 0.476773
Epoch:  5/20, Batch: 10, Loss: 0.919093, Acc (train): 0.479167, Acc (val): 0.501222
Epoch:  5/20, Batch: 20, Loss: 0.957953, Acc (train): 0.517157, Acc (val): 0.498778
Epoch:  6/20, Batch: 10, Loss: 0.952879, Acc (train): 0.506127, Acc (val): 0.501222
Epoch:  6/20, Batch: 20, Loss: 0.954302, Acc (train): 0.501225, Acc (val): 0