# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import time
import matplotlib.pyplot as plt 

class SeparableConv2d(nn.Module):
    """Depthwise-separable 2D convolution.

    A depthwise convolution (one filter per input channel, using
    ``kernel_size``/``stride``/``padding``) followed by a 1x1 pointwise
    convolution that maps ``in_channels`` to ``out_channels``.

    Note: ``bias`` only controls the depthwise convolution; the pointwise
    convolution is always created without a bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False):
        super().__init__()
        # groups == in_channels -> every input channel is filtered independently.
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )
        # 1x1 projection mixing channels into the requested output width.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        return self.pointwise(self.conv1(x))

class Block(nn.Module):
    """Residual unit of depthwise-separable convolutions used by XceptionNet.

    Main branch::

        sepconv1 -> bn1 -> ReLU
        -> (reps - 1) x [ReLU -> separable conv -> BN]
        -> sepconv2 (applies ``stride``) -> bn2

    Shortcut branch: the identity when the output shape equals the input
    shape, otherwise a strided 1x1 convolution followed by BatchNorm. The two
    branches are summed.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        reps: Depth control; the main branch contains ``reps + 1`` separable
            convolutions (``sepconv1``, ``reps - 1`` inside ``rep``, ``sepconv2``).
        stride: Stride of ``sepconv2`` and of the shortcut convolution.
        start_with_relu: Accepted for API compatibility; currently unused.
        grow_first: Accepted for API compatibility; currently unused.
    """

    def __init__(self, in_channels, out_channels, reps, stride=1, start_with_relu=True, grow_first=True):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.sepconv1 = SeparableConv2d(in_channels, out_channels, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Projection shortcut only when an identity shortcut would not match
        # the output's channel count or spatial size.
        self.skip = None
        if out_channels != in_channels or stride != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False)
            self.skipbn = nn.BatchNorm2d(out_channels)

        self.rep = nn.Sequential()
        for i in range(reps - 1):
            # Not in-place: the first of these ReLUs receives the output of
            # ``self.relu`` (already modified in place). Autograd's ReLU
            # backward needs that output, so a second in-place ReLU on it
            # makes ``backward()`` fail with an in-place modification error.
            self.rep.add_module('relu' + str(i + 1), nn.ReLU(inplace=False))
            self.rep.add_module('sepconv' + str(i + 1), SeparableConv2d(out_channels, out_channels, 3, stride=1, padding=1))
            self.rep.add_module('bn' + str(i + 1), nn.BatchNorm2d(out_channels))

        self.sepconv2 = SeparableConv2d(out_channels, out_channels, 3, stride=stride, padding=1)
        # Dedicated BatchNorm for sepconv2's output. Previously ``bn1`` was
        # applied here a second time, so two different activations shared one
        # set of running statistics and affine parameters.
        # NOTE: this adds 'bn2.*' keys to the state_dict; checkpoints saved
        # before this change will not load with strict=True.
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, inp):
        if self.skip is not None:
            skip = self.skipbn(self.skip(inp))
        else:
            skip = inp

        x = self.relu(self.bn1(self.sepconv1(inp)))
        x = self.rep(x)
        # NOTE(review): for reps >= 2, ``rep`` ends with BatchNorm and
        # sepconv2 follows with no ReLU in between; confirm whether a ReLU
        # was intended before sepconv2.
        x = self.bn2(self.sepconv2(x))

        # ``x`` is a fresh BatchNorm output, so the in-place add does not
        # touch ``inp`` (which the identity shortcut may alias).
        x += skip
        return x

class XceptionNet(nn.Module):
    """Xception-style image classifier built from entry, middle and exit flows.

    Args:
        num_classes: Number of logits produced by the final linear layer ``fc``.
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        # Entry-flow stem: two plain unpadded 3x3 convolutions, the first strided.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # Entry flow: three stride-2 residual blocks widening 64 -> 728 channels.
        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # Middle flow: eight identical 728-channel, stride-1 blocks.
        middle_blocks = [Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) for _ in range(8)]
        self.middle_flow = nn.Sequential(*middle_blocks)

        # Exit flow
        self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        # Stem: conv -> BN -> ReLU, twice.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(norm(conv(x)))

        # Residual stages: entry blocks, middle flow, then the first exit block.
        for stage in (self.block1, self.block2, self.block3, self.middle_flow, self.block12):
            x = stage(x)

        # Exit-flow separable convolutions: conv -> BN -> ReLU, twice.
        for conv, norm in ((self.conv3, self.bn3), (self.conv4, self.bn4)):
            x = self.relu(norm(conv(x)))

        pooled = self.avgpool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)



class FaceForensicsDataset(Dataset):
    """Binary image dataset read from ``<data_root>/Original`` and ``<data_root>/Deepfake``.

    Every file whose extension is .jpg, .png or .jpeg (case-insensitive) in
    those two folders is indexed. Label 0 = Original (real), 1 = Deepfake
    (fake). Filenames are sorted so that indices are reproducible
    (``os.listdir`` returns entries in arbitrary order), which matters when
    the index list is later split into train/val/test.

    Args:
        data_root: Directory containing the ``Original`` and ``Deepfake`` folders.
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        FileNotFoundError: if no matching image is found in either folder.
    """

    CLASS_FOLDERS = ('Original', 'Deepfake')  # index in this tuple == label
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, data_root, transform=None):
        self.data_root = data_root
        self.transform = transform
        self.image_paths = []
        self.labels = []  # 0 for Original (real), 1 for Deepfake (fake)

        for label_idx, folder_name in enumerate(self.CLASS_FOLDERS):
            current_dir = os.path.join(data_root, folder_name)
            if not os.path.isdir(current_dir):
                continue
            for filename in sorted(os.listdir(current_dir)):
                # Case-insensitive so that e.g. 'IMG_001.JPG' is not silently skipped.
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(current_dir, filename))
                    self.labels.append(label_idx)

        if not self.image_paths:
            raise FileNotFoundError(f"Aucune image trouvée dans {data_root}/Original ou {data_root}/Deepfake. Vérifiez le chemin.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; ``image`` is transformed if a transform is set."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # Close the file handle promptly; convert() returns an independent,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label



def _run_phase(model, data_loader, phase, criterion, optimizer, device):
    """Run one pass over ``data_loader`` for ``phase`` and return ``(mean_loss, accuracy)``.

    In the 'train' phase the model is in training mode and the optimizer is
    stepped after every batch; any other phase runs in eval mode with
    gradient tracking disabled.

    Raises:
        ValueError: if ``data_loader`` yields no samples (the mean is undefined).
    """
    is_train = phase == 'train'
    if is_train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    n_samples = 0
    all_labels = []
    all_preds = []

    for inputs, labels in tqdm(data_loader, desc=f'Phase {phase}'):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)

            if is_train:
                loss.backward()
                optimizer.step()

        batch_size = inputs.size(0)
        # Assumes ``criterion`` averages over the batch (nn.CrossEntropyLoss's
        # default), so re-weight by batch size before averaging over samples.
        running_loss += loss.item() * batch_size
        # Count samples actually seen rather than len(dataset), which is
        # wrong with drop_last or undefined for iterable datasets.
        n_samples += batch_size
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

    if n_samples == 0:
        raise ValueError(f"Aucun échantillon dans le chargeur de données de la phase '{phase}'.")

    return running_loss / n_samples, accuracy_score(all_labels, all_preds)


def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10,
                checkpoint_path='best_xception_deepfake_detector.pth'):
    """Train and validate ``model`` for ``num_epochs`` epochs, recording metrics.

    After each epoch's validation pass, the model's ``state_dict`` is saved
    to ``checkpoint_path`` whenever validation accuracy strictly improves on
    the best value seen so far.

    Args:
        model: Module to train (already on ``device``).
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Number of train+validation epochs.
        checkpoint_path: File the best model weights are written to.

    Returns:
        ``(model, history)`` where ``history`` maps 'train_loss', 'val_loss',
        'train_acc' and 'val_acc' to per-epoch lists.

    Raises:
        ValueError: if either loader yields no samples.
    """
    best_acc = 0.0
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    loaders = {'train': train_loader, 'val': val_loader}

    for epoch in range(num_epochs):
        # perf_counter is monotonic, unlike time.time(), so it suits elapsed-time measurement.
        start_time = time.perf_counter()
        print(f'\n--- Epoch {epoch+1}/{num_epochs} ---')

        for phase in ('train', 'val'):
            epoch_loss, epoch_acc = _run_phase(model, loaders[phase], phase, criterion, optimizer, device)
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            print(f'{phase.capitalize()} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Meilleur modèle sauvegardé avec précision: {best_acc:.4f}")

        elapsed = time.perf_counter() - start_time
        print(f"Temps écoulé pour l'époque {epoch+1}: {elapsed:.2f} secondes")

    return model, history

def evaluate_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and print binary classification metrics.

    Class 1 (Deepfake) is the positive class for precision, recall and F1.

    Args:
        model: Trained classifier (already on ``device``).
        test_loader: Loader yielding ``(inputs, labels)`` batches.
        device: Device the inputs are moved to.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1_score'.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    all_labels = []
    all_preds = []

    print("\n--- Évaluation des performances du modèle ---")
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Évaluation'):
            outputs = model(inputs.to(device))
            _, preds = torch.max(outputs, 1)

            # Labels are only used for metric computation on the host, so they
            # are not copied to ``device`` (``.cpu()`` is a no-op for CPU tensors).
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    if not all_labels:
        raise ValueError("Le chargeur de test ne contient aucun échantillon.")

    positive = 1  # Deepfake
    # zero_division=0 returns the same value sklearn's default ('warn') does
    # when a metric is undefined (e.g. no positive predictions), minus the warning.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='binary', pos_label=positive, zero_division=0)
    recall = recall_score(all_labels, all_preds, average='binary', pos_label=positive, zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='binary', pos_label=positive, zero_division=0)

    print("\n--- Métriques de Performance Finales ---")
    print(f"Accuracy (Précision globale): {accuracy:.4f}")
    print(f"Precision (Deepfake): {precision:.4f}")
    print(f"Recall (Sensibilité Deepfake): {recall:.4f}")
    print(f"F1 Score (Deepfake): {f1:.4f}")

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}



def plot_metrics(history):
    """Plot per-epoch loss (left) and accuracy (right) for training and validation.

    Args:
        history: Mapping with 'train_loss', 'val_loss', 'train_acc' and
            'val_acc' lists holding one value per epoch (the ``history``
            returned by ``train_model``).
    """
    epochs = range(1, len(history['train_loss']) + 1)

    # (train key, val key, train label, val label, title, y-axis label) per subplot.
    panels = (
        ('train_loss', 'val_loss',
         "Perte d'entraînement", 'Perte de validation',
         'Évolution de la Perte (Loss) par Époque', 'Loss'),
        ('train_acc', 'val_acc',
         "Précision d'entraînement", 'Précision de validation',
         'Évolution de la Précision (Accuracy) par Époque', 'Accuracy'),
    )

    plt.figure(figsize=(12, 5))
    for position, (train_key, val_key, train_label, val_label, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, history[train_key], 'b-', label=train_label)
        plt.plot(epochs, history[val_key], 'r-', label=val_label)
        plt.title(title)
        plt.xlabel('Époque')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()



if __name__ == '__main__':
    import copy

    DATA_ROOT = 'path/to/your/FaceForensics_Extracted_Dataset'  # REPLACE THIS PATH
    BATCH_SIZE = 16
    NUM_EPOCHS = 15
    LEARNING_RATE = 0.0001
    INPUT_SIZE = 299
    SPLIT_SEED = 42  # fixed so the train/val/test partition is reproducible across runs
    # Must match the file train_model writes its best checkpoint to.
    CHECKPOINT_PATH = 'best_xception_deepfake_detector.pth'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Utilisation du périphérique : **{device}**")

    # Random augmentation is for training only; val/test use a deterministic pipeline.
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
        'test': transforms.Compose([
            transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
    }

    print("\nChargement des données...")
    try:
        train_base = FaceForensicsDataset(DATA_ROOT, transform=data_transforms['train'])
    except FileNotFoundError as e:
        print(f"Erreur de chargement: {e}")
        print("Veuillez vérifier que le chemin DATA_ROOT pointe vers le dossier contenant 'Original' et 'Deepfake'.")
        # Non-zero status; unlike the site-provided exit(), SystemExit is always available.
        raise SystemExit(1)

    # Subsets from random_split all share ONE underlying dataset, so setting
    # ``.dataset.transform`` on the val/test subsets also replaced the train
    # transform and silently disabled augmentation. Instead, use two dataset
    # objects over the same file list (shallow copy), each with its own
    # transform, indexed by one shared permutation.
    eval_base = copy.copy(train_base)
    eval_base.transform = data_transforms['test']

    n_total = len(train_base)
    train_size = int(0.8 * n_total)
    val_size = int(0.1 * n_total)

    split_generator = torch.Generator().manual_seed(SPLIT_SEED)
    order = torch.randperm(n_total, generator=split_generator).tolist()
    train_dataset = torch.utils.data.Subset(train_base, order[:train_size])
    val_dataset = torch.utils.data.Subset(eval_base, order[train_size:train_size + val_size])
    test_dataset = torch.utils.data.Subset(eval_base, order[train_size + val_size:])

    pin_memory = device.type == 'cuda'  # faster host->GPU copies; pointless on CPU
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=pin_memory)

    print(f"Taille Train: {len(train_dataset)}, Taille Val: {len(val_dataset)}, Taille Test: {len(test_dataset)}")

    # Binary head: 0 = Original, 1 = Deepfake (the dataset's label convention).
    model = XceptionNet(num_classes=2).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    trained_model, training_history = train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=NUM_EPOCHS)

    print("\n--- Affichage des graphiques d'évolution ---")
    plot_metrics(training_history)

    try:
        # map_location lets a checkpoint written on GPU load on a CPU-only machine.
        best_state = torch.load(CHECKPOINT_PATH, map_location=device)
    except FileNotFoundError:
        print("\nAvertissement: Le meilleur modèle n'a pas été trouvé. Utilisation de l'état final du modèle.")
    else:
        trained_model.load_state_dict(best_state)
        print("\nChargement du meilleur modèle sauvegardé pour l'évaluation finale.")

    final_metrics = evaluate_model(trained_model, test_loader, device)

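# Cell output from a previous run:
# ModuleNotFoundError: No module named 'torchvision'
# (install the missing dependency, e.g. `pip install torchvision`, before running this cell)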