In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os
import copy
from PIL import Image 

In [None]:
# Select the compute device: the first CUDA GPU (cuda:0) when CUDA is
# available, otherwise the CPU. This cell does not limit GPU memory.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
import torch

# Per-GPU CUDA configuration.
# num_gpus is defined unconditionally (0 without CUDA) so the report loop at the
# bottom also runs cleanly on CPU-only machines.
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0

GPU_MEMORY_FRACTION = 0.8  # cap this process at 80% of each GPU's memory

if num_gpus > 0:
    # Let cuDNN benchmark convolution algorithms (inputs here have a fixed size).
    # This is a global flag, so it is set once rather than once per GPU.
    torch.backends.cudnn.benchmark = True
    for i in range(num_gpus):
        torch.cuda.set_per_process_memory_fraction(GPU_MEMORY_FRACTION, i)
        # Scope the cache flush to GPU i with a context manager instead of
        # torch.cuda.set_device(i): set_device would leave the *last* GPU as the
        # process-wide current device, so a bare 'cuda' device would no longer
        # mean cuda:0.
        with torch.cuda.device(i):
            torch.cuda.empty_cache()

print("Available GPUs:")
for i in range(num_gpus):
    total_gb = torch.cuda.get_device_properties(i).total_memory / 1024**3
    print(f"GPU {i}: {torch.cuda.get_device_name(i)} with {total_gb:.2f} GB total memory")


In [None]:
# Define data directories.
# The dataset root can be overridden with the DEEPFAKE_DATA_DIR environment
# variable instead of editing this machine-specific default path.
data_root = os.environ.get('DEEPFAKE_DATA_DIR', '/home/ronie/Programs/Projects/2- Deep_Fake')
test_dir = os.path.join(data_root, 'Test')
train_dir = os.path.join(data_root, 'Train')
val_dir = os.path.join(data_root, 'Validation')


In [None]:
# Define parameters
input_shape = (224, 224)  # (height, width) every image is resized to
batch_size = 32
num_classes = 2  # Binary classification: fake vs. real
epochs = 5

# Seed torch's global RNG so weight initialisation, DataLoader shuffling and the
# random horizontal flip repeat across Restart & Run All. GPU kernels (e.g. with
# cudnn.benchmark enabled) may still be nondeterministic.
seed = 42
torch.manual_seed(seed)


In [None]:
# Per-split preprocessing pipelines.
# Every split is resized to `input_shape`, converted to a tensor and normalized
# with the same per-channel mean/std; only the training split adds a random
# horizontal flip. Validation and test get separate but identical pipelines.
def _make_transform(augment):
    """Return a fresh Compose; `augment=True` inserts RandomHorizontalFlip after Resize."""
    steps = [transforms.Resize(input_shape)]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
    return transforms.Compose(steps)


data_transforms = {split: _make_transform(split == 'train') for split in ('train', 'val', 'test')}


In [None]:
# Load data: one ImageFolder per split, each with that split's transform pipeline.
split_dirs = {'train': train_dir, 'val': val_dir, 'test': test_dir}
image_datasets = {
    split: datasets.ImageFolder(root, data_transforms[split])
    for split, root in split_dirs.items()
}

# Each ImageFolder builds its class -> index mapping from its own sub-folders, so
# a missing, extra or misnamed class folder in val/test would silently shift the
# labels relative to training. Fail fast instead.
class_names = image_datasets['train'].classes
for split in ('val', 'test'):
    if image_datasets[split].classes != class_names:
        raise ValueError(
            f"{split} classes {image_datasets[split].classes} do not match "
            f"train classes {class_names}"
        )

num_workers = 4  # DataLoader worker processes per split
dataloaders = {
    split: DataLoader(
        image_datasets[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),  # shuffle only the training split
        num_workers=num_workers,
    )
    for split in split_dirs
}

dataset_sizes = {split: len(ds) for split, ds in image_datasets.items()}


In [None]:

class ResidualBlock(nn.Module):
    """Basic residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus a skip connection.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both 3x3 convolutions.
        stride: stride of the first convolution. A stride other than 1, or a
            channel change, adds a 1x1 conv + BN projection on the skip path so
            the two branches have the same shape for the addition.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # bias=False (as for conv1): the BatchNorm that follows subtracts the
        # per-channel mean, so a conv bias right before it has no effect.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += identity
        return self.relu(out)


class DeepfakeDetectionModel(nn.Module):
    """ResNet-style classifier with a single residual stage.

    Structure: 7x7/2 conv stem -> BN -> ReLU -> 3x3/2 max-pool -> `num_blocks`
    ResidualBlocks -> global average pool -> linear head producing
    `num_classes` logits.

    Args:
        num_blocks: number of ResidualBlocks in the residual stage (0 allowed).
        num_classes: number of output logits.
    """

    def __init__(self, num_blocks, num_classes=2):
        super().__init__()
        self.in_channels = 64  # width entering the next block; _make_layer updates it
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.residual_layers = self._make_layer(64, num_blocks)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Size the head from the final stage width instead of hard-coding 64.
        self.fc = nn.Linear(self.in_channels, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride=1):
        """Stack `num_blocks` ResidualBlocks that output `out_channels` channels.

        Only the first block uses `stride`; the remaining blocks use stride 1, as
        in a standard ResNet stage. (Passing `stride` to every block would
        downsample once per block whenever stride > 1.) Sets
        self.in_channels to `out_channels` when at least one block is built.
        """
        layers = []
        for block_idx in range(num_blocks):
            block_stride = stride if block_idx == 0 else 1
            layers.append(ResidualBlock(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.residual_layers(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

# Model hyper-parameters.
# `input_shape` and `num_classes` are deliberately NOT reassigned here. The
# transforms cell resizes with the (224, 224) `input_shape` from the parameters
# cell, and overwriting it with (3, 224, 224) would break that cell on re-run.
num_residual_blocks = 30

# Build the model on the `device` chosen in the device-selection cell (cuda:0 or
# CPU) instead of redefining it: a bare 'cuda' means the *current* CUDA device,
# which is not necessarily cuda:0.
model = DeepfakeDetectionModel(num_blocks=num_residual_blocks, num_classes=num_classes).to(device)

# Print the model summary
print(model)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# The network is randomly initialised, so every layer is trained. Freezing all
# but `fc` would only make sense for a pretrained backbone.
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()



In [None]:
# Disabled alternative, kept as an inert string literal (never executed).
# It would replace the custom network with a torchvision ResNet-50 loaded with
# pretrained weights, give it a new `num_classes`-way linear head, and rebuild
# the loss and optimizer over all parameters.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of the `weights=` argument; confirm against the installed version
# before re-enabling.
'''# Define the ResNet model
def build_resnet(num_classes):
    model = models.resnet50(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
    return model

model = build_resnet(num_classes).to(device)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
'''

In [None]:
# Running-maximum accuracy trackers. Each call records one accuracy value
# (presumably a fraction, since it is printed multiplied by 100) and prints the
# best value seen so far.
class MaxValidationAccuracyCallback:
    """Remember and report the highest validation accuracy passed in so far."""

    def __init__(self):
        self.max_val_acc = 0

    def __call__(self, val_acc):
        # max() keeps the stored value on ties, like a strict `>` update would.
        self.max_val_acc = max(self.max_val_acc, val_acc)
        print(f"Max validation accuracy so far is {self.max_val_acc * 100:.2f}")


class MaxTestAccuracyCallback:
    """Remember and report the highest test accuracy passed in so far."""

    def __init__(self):
        self.max_test_acc = 0

    def __call__(self, test_acc):
        self.max_test_acc = max(self.max_test_acc, test_acc)
        print(f"Max test accuracy so far is {self.max_test_acc * 100:.2f}")


max_val_acc_callback, max_test_acc_callback = MaxValidationAccuracyCallback(), MaxTestAccuracyCallback()

In [None]:
# Training loop
def _run_phase(model, criterion, optimizer, phase):
    """Run one pass over dataloaders[phase] and return (mean loss, accuracy) as floats.

    With phase == 'train' the model is put in training mode and each batch does
    forward + backward + optimizer step. Any other phase runs in eval mode with
    gradients disabled. Batches are moved to the global `device`.

    Raises:
        ValueError: if the phase's dataloader yields no samples.
    """
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is equivalent to eval()

    running_loss = 0.0
    running_corrects = 0
    num_samples = 0

    for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        # Assumes criterion averages over the batch (the CrossEntropyLoss
        # default), so re-weight by batch size before the epoch average.
        batch_len = inputs.size(0)
        running_loss += loss.item() * batch_len
        running_corrects += (preds == labels).sum().item()
        num_samples += batch_len

    if num_samples == 0:
        raise ValueError(f"'{phase}' dataloader yielded no samples")

    return running_loss / num_samples, running_corrects / num_samples


def train_model(model, criterion, optimizer, num_epochs=epochs):
    """Train `model`, then restore the weights with the best validation accuracy.

    Each epoch runs a training pass over dataloaders['train'] followed by an
    evaluation pass over dataloaders['val']. After each validation pass the
    accuracy goes to `max_val_acc_callback`.

    Args:
        model: network already placed on `device`.
        criterion: loss applied to (logits, labels).
        optimizer: optimizer over the parameters to update.
        num_epochs: number of epochs (default: the value of `epochs` when this
            function was defined).

    Returns:
        (model with best-validation weights loaded, train accuracies,
        validation accuracies, train losses, validation losses). The four
        histories are lists with one Python float per epoch.

    Raises:
        ValueError: if the train or val dataloader yields no samples.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    history_acc = []
    history_val_acc = []
    history_loss = []
    history_val_loss = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ('train', 'val'):
            epoch_loss, epoch_acc = _run_phase(model, criterion, optimizer, phase)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'train':
                history_acc.append(epoch_acc)
                history_loss.append(epoch_loss)
                continue

            # Snapshot the weights whenever validation accuracy improves.
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            max_val_acc_callback(epoch_acc)
            history_val_acc.append(epoch_acc)
            history_val_loss.append(epoch_loss)

        print()

    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model, history_acc, history_val_acc, history_loss, history_val_loss


In [None]:
# Train the whole network end to end.
# Every parameter is made trainable explicitly (the model-setup cell may have
# frozen layers). The previous freeze-everything loop was immediately undone by
# this one, so it has been dropped. A fresh Adam optimizer with a smaller
# learning rate is then built over all parameters.
for param in model.parameters():
    param.requires_grad = True

optimizer = optim.Adam(model.parameters(), lr=0.0001)
model, history_acc, history_val_acc, history_loss, history_val_loss = train_model(
    model, criterion, optimizer, num_epochs=epochs
)


In [None]:
# Evaluate the model on the test set
def evaluate_model(model, dataloader, criterion, callback=None):
    """Score `model` on every batch of `dataloader` without updating it.

    Args:
        model: the network. Batches are moved to the device of its first
            parameter, or to the CPU if it has no parameters.
        dataloader: iterable of (inputs, labels) batches.
        criterion: loss applied to (logits, labels). It is assumed to average
            over the batch (the CrossEntropyLoss default), so each batch is
            re-weighted by its size.
        callback: called once with the accuracy as a Python float. Defaults to
            the notebook's `max_test_acc_callback`.

    Returns:
        (mean loss as a float, accuracy as a 0-d float64 tensor).

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    if callback is None:
        callback = max_test_acc_callback

    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    running_loss = 0.0
    running_corrects = 0
    num_samples = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(eval_device)
            labels = labels.to(eval_device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            batch_len = inputs.size(0)
            running_loss += loss.item() * batch_len
            running_corrects += (preds == labels).sum().item()
            num_samples += batch_len

    if num_samples == 0:
        raise ValueError("dataloader yielded no samples")

    # Average over the samples actually seen; this also holds when a sampler
    # visits only part of the dataset.
    epoch_loss = running_loss / num_samples
    accuracy = running_corrects / num_samples
    print(f'Test Loss: {epoch_loss:.4f} Test Acc: {accuracy:.4f}')

    callback(accuracy)

    return epoch_loss, torch.tensor(accuracy, dtype=torch.float64)

# Score the trained model on the held-out test split.
test_loss, test_acc = evaluate_model(model=model, dataloader=dataloaders['test'], criterion=criterion)


In [None]:
# Example usage for prediction
def predict_image(image_path, model, real_class_index=1):
    """Classify a single image file as real or fake.

    Args:
        image_path: path to an image file readable by PIL.
        model: trained classifier, already on `device`.
        real_class_index: output index that means "real". The default of 1
            matches the original behaviour. NOTE(review): this presumably relies
            on the class folders being ordered Fake, Real; check `class_names`.

    Returns:
        'REAL IMAGE' if the predicted index equals `real_class_index`, else
        'FAKE IMAGE'.
    """
    transform = data_transforms['test']
    # The context manager closes the file handle. convert('RGB') turns
    # grayscale, palette or RGBA images into the 3-channel input that the
    # network's first conv and the 3-value Normalize require.
    with Image.open(image_path) as img:
        batch = transform(img.convert('RGB')).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        output = model(batch)
        _, preds = torch.max(output, 1)
    return 'REAL IMAGE' if preds.item() == real_class_index else 'FAKE IMAGE'

# Classify one sample image from the test split's 'Fake' folder. The path is
# built from `test_dir`, so it follows the configured dataset location.
image_path = os.path.join(test_dir, 'Fake', 'fake_5448.jpg')
print(predict_image(image_path, model))


In [None]:
# Plot the per-epoch training and validation curves returned by train_model.
def plot_history(train_values, val_values, metric):
    """Plot one metric per epoch (x-axis starts at epoch 1) for train and validation.

    Args:
        train_values: one value per epoch for the training phase.
        val_values: one value per epoch for the validation phase.
        metric: display name used in the legend, title and y-axis label.
    """
    epoch_numbers = range(1, len(train_values) + 1)
    fig, ax = plt.subplots()
    ax.plot(epoch_numbers, train_values, label=f'Train {metric}')
    ax.plot(epoch_numbers, val_values, label=f'Validation {metric}')
    ax.set(title=f'Model {metric.lower()}', xlabel='Epoch', ylabel=metric)
    ax.legend()
    plt.show()


plot_history(history_acc, history_val_acc, 'Accuracy')
plot_history(history_loss, history_val_loss, 'Loss')


In [None]:
# Disabled utility, kept as an inert string literal (never executed).
# It would run a dummy (batch_size, 3, 224, 224) random input through `model`,
# build the autograd graph with torchviz's make_dot, and render it to
# "model_architecture" in PDF format. Re-enabling it requires the third-party
# `torchviz` package.
# NOTE(review): if re-enabled, this reassigns the global `batch_size` to 1, and
# any later cell that reads `batch_size` would pick up that value.
'''import torch
from torchviz import make_dot

# Create a batch of dummy inputs
batch_size = 1  # or any batch size you need
dummy_input = torch.randn(batch_size, 3, 224, 224).to(device)

# Forward pass through the model
output = model(dummy_input)

# Create the graph for visualization
graph = make_dot(output, params=dict(model.named_parameters()))

# Save the graph as a PDF
graph.render("model_architecture", format="pdf")

'''