In [1]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score, confusion_matrix
from torch.utils.data import DataLoader
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
from tqdm import tqdm

# Import the dataset utilities from your file

from utils.dataset_new import load_dataset_and_create_loaders


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class MobileNetV3DeepfakeDetector(nn.Module):
    """MobileNetV3-Large backbone with a custom head that emits one score per sample.

    The stock torchvision classifier is swapped for a three-layer MLP
    (in_features -> 128 -> 64 -> 1). No sigmoid is applied, so the output is a
    raw logit.

    Args:
        pretrained: When True, build the backbone with
            ``MobileNet_V3_Large_Weights.DEFAULT``; otherwise build it without
            pre-trained weights.
        freeze_features: When True, turn off gradients for every parameter of
            ``model.features`` (the replacement classifier stays trainable).
    """

    def __init__(self, pretrained=True, freeze_features=False):
        super().__init__()

        # ``weights=None`` is what torchvision uses when no weights are requested.
        backbone_weights = MobileNet_V3_Large_Weights.DEFAULT if pretrained else None
        self.model = mobilenet_v3_large(weights=backbone_weights)

        # Optionally keep the convolutional feature extractor fixed.
        if freeze_features:
            self.model.features.requires_grad_(False)

        # Width of the tensor entering the original classifier; the new head
        # consumes the same features.
        head_inputs = self.model.classifier[0].in_features
        head_layers = [
            nn.Linear(head_inputs, 128),
            nn.Hardswish(),
            nn.Dropout(p=0.2),
            nn.Linear(128, 64),
            nn.Hardswish(),
            nn.Dropout(p=0.1),
            nn.Linear(64, 1),  # Binary classification (fake or real)
        ]
        self.model.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output unchanged."""
        return self.model(x)


In [3]:
class EarlyStopping:
    """Signal when validation loss has stopped improving.

    Call the instance once per epoch with the current validation loss. A loss
    counts as an improvement only if it is lower than the best loss seen so far
    by more than ``min_delta``. After ``patience`` consecutive calls without an
    improvement the ``early_stop`` flag is set; it is never cleared afterwards.

    Args:
        patience: Number of consecutive non-improving calls that triggers the stop.
        min_delta: Margin by which the loss must beat the best loss to count.
        verbose: When True, print the counter and the trigger message.
    """

    def __init__(self, patience=5, min_delta=0, verbose=False):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and return the current ``early_stop`` flag."""
        threshold = self.best_loss - self.min_delta
        if val_loss < threshold:
            # Improvement: remember it and restart the patience window.
            self.best_loss = val_loss
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.verbose:
            print(f'Early stopping counter: {self.counter} out of {self.patience}')

        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                print('Early stopping triggered')

        return self.early_stop

In [4]:
from sklearn.metrics import roc_curve


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Train the model for one epoch.

    The model output is treated as a raw logit (the notebook pairs it with
    ``nn.BCEWithLogitsLoss``), so a sample is predicted positive when its logit
    is above 0, which is exactly when ``sigmoid(logit) > 0.5``.

    Args:
        model: Module producing one logit per sample.
        loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: Loss callable invoked as ``criterion(logits, float_labels)``;
            assumed to return the batch mean, since it is re-weighted by the
            batch size below.
        optimizer: Optimizer updating the model parameters.
        device: Device every batch is moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        accuracy in percent.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # Use tqdm for progress bar
    progress_bar = tqdm(loader, desc="Training", leave=False)

    for inputs, labels in progress_bar:
        inputs = inputs.to(device)
        # BCE-style losses need float targets shaped like the logits.
        labels = labels.to(device).float().reshape(-1)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # reshape(-1) rather than squeeze(): squeeze() turns a batch of one
        # into a 0-d tensor that no longer matches the (1,) label tensor.
        outputs = model(inputs).reshape(-1)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Update statistics
        running_loss += loss.item() * inputs.size(0)

        # Threshold the logit at 0 (probability 0.5), not at 0.5.
        predicted = (outputs.detach() > 0).float()
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100 * correct / total:.2f}%'
        })

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")

    # Divide by the samples actually seen; this also works when the dataset has
    # no usable len() or the loader drops the last partial batch.
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc

def validate(model, loader, criterion, device):
    """Evaluate the model on a validation loader without updating weights.

    The model output is treated as a raw logit (the notebook pairs it with
    ``nn.BCEWithLogitsLoss``), so a sample is predicted positive when its logit
    is above 0, which is exactly when ``sigmoid(logit) > 0.5``.

    Args:
        model: Module producing one logit per sample.
        loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: Loss callable invoked as ``criterion(logits, float_labels)``;
            assumed to return the batch mean, since it is re-weighted by the
            batch size below.
        device: Device every batch is moved to.

    Returns:
        Tuple ``(val_loss, val_acc, all_preds, all_labels)``: sample-weighted
        mean loss, accuracy in percent, and two flat lists with one entry per
        sample. ``all_preds`` holds the raw model outputs (logits), not
        probabilities.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            # BCE-style losses need float targets shaped like the logits.
            labels = labels.to(device).float().reshape(-1)

            # reshape(-1) rather than squeeze(): squeeze() turns a batch of one
            # into a 0-d tensor, which breaks both the loss and extend() below.
            outputs = model(inputs).reshape(-1)
            loss = criterion(outputs, labels)

            # Update statistics
            running_loss += loss.item() * inputs.size(0)

            # Threshold the logit at 0 (probability 0.5), not at 0.5.
            predicted = (outputs > 0).float()
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # Store predictions and labels for metrics
            all_preds.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("validate: loader yielded no samples")

    # Divide by the samples actually seen; this also works when the dataset has
    # no usable len().
    val_loss = running_loss / total
    val_acc = 100 * correct / total

    return val_loss, val_acc, all_preds, all_labels

def test(model, loader, criterion, device):
    """Evaluate the model on a test loader and compute classification metrics.

    The model output is treated as a raw logit (the notebook pairs it with
    ``nn.BCEWithLogitsLoss``). Hard decisions, used for accuracy and the
    confusion matrix, threshold the logit at 0, which is exactly when
    ``sigmoid(logit) > 0.5``. The ROC and precision-recall computations only
    rank the scores, so they are fed the raw logits.

    Args:
        model: Module producing one logit per sample.
        loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: Loss callable invoked as ``criterion(logits, float_labels)``;
            assumed to return the batch mean, since it is re-weighted by the
            batch size below.
        device: Device every batch is moved to.

    Returns:
        Dict with keys ``test_loss``, ``test_acc`` (percent), ``roc_auc``,
        ``ap``, ``fpr``, ``tpr``, ``precision``, ``recall`` and
        ``confusion_matrix`` (always 2x2, rows are true class 0/1, columns are
        predicted class 0/1).

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Testing"):
            inputs = inputs.to(device)
            # BCE-style losses need float targets shaped like the logits.
            labels = labels.to(device).float().reshape(-1)

            # reshape(-1) rather than squeeze(): squeeze() turns a batch of one
            # into a 0-d tensor, which breaks both the loss and extend() below.
            outputs = model(inputs).reshape(-1)
            loss = criterion(outputs, labels)

            # Update statistics
            running_loss += loss.item() * inputs.size(0)

            # Threshold the logit at 0 (probability 0.5), not at 0.5.
            predicted = (outputs > 0).float()
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # Store raw scores and labels for metrics
            all_preds.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("test: loader yielded no samples")

    # Divide by the samples actually seen; this also works when the dataset has
    # no usable len().
    test_loss = running_loss / total
    test_acc = 100 * correct / total

    # Convert to numpy arrays for sklearn
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Calculate ROC curve and AUC
    fpr, tpr, _ = roc_curve(all_labels, all_preds)
    roc_auc = auc(fpr, tpr)

    # Calculate Precision-Recall curve and AP
    precision, recall, _ = precision_recall_curve(all_labels, all_preds)
    ap = average_precision_score(all_labels, all_preds)

    # Same decision rule as the accuracy above. labels=[0, 1] pins the matrix
    # to 2x2 even if one class is missing from this test set.
    binary_preds = (all_preds > 0).astype(int)
    conf_matrix = confusion_matrix(all_labels.astype(int), binary_preds, labels=[0, 1])

    results = {
        'test_loss': test_loss,
        'test_acc': test_acc,
        'roc_auc': roc_auc,
        'ap': ap,
        'fpr': fpr,
        'tpr': tpr,
        'precision': precision,
        'recall': recall,
        'confusion_matrix': conf_matrix
    }

    return results

def plot_metrics(results, save_path=None):
    """Render the ROC curve, precision-recall curve and confusion matrix in one row.

    Args:
        results: Mapping with the keys ``fpr``, ``tpr``, ``roc_auc``,
            ``recall``, ``precision``, ``ap`` and ``confusion_matrix`` (a 2-D
            integer array; the tick labels assume two classes).
        save_path: Optional file path; when truthy the figure is written there
            before it is shown.
    """
    fig, (roc_ax, pr_ax, cm_ax) = plt.subplots(1, 3, figsize=(18, 6))

    # Panel 1: ROC curve with the chance diagonal for reference.
    roc_label = f'ROC curve (AUC = {results["roc_auc"]:.2f})'
    roc_ax.plot(results['fpr'], results['tpr'], color='blue', lw=2, label=roc_label)
    roc_ax.plot([0, 1], [0, 1], color='gray', linestyle='--')
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('Receiver Operating Characteristic')
    roc_ax.legend(loc="lower right")

    # Panel 2: precision-recall curve.
    pr_label = f'PR curve (AP = {results["ap"]:.2f})'
    pr_ax.plot(results['recall'], results['precision'], color='green', lw=2, label=pr_label)
    pr_ax.set_xlim([0.0, 1.0])
    pr_ax.set_ylim([0.0, 1.05])
    pr_ax.set_xlabel('Recall')
    pr_ax.set_ylabel('Precision')
    pr_ax.set_title('Precision-Recall Curve')
    pr_ax.legend(loc="lower left")

    # Panel 3: confusion matrix as a heat map with per-cell counts.
    cm = results['confusion_matrix']
    cm_ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    cm_ax.set_title('Confusion Matrix')
    class_ticks = np.arange(2)
    class_names = ['Real', 'Fake']
    cm_ax.set_xticks(class_ticks)
    cm_ax.set_yticks(class_ticks)
    cm_ax.set_xticklabels(class_names)
    cm_ax.set_yticklabels(class_names)

    # Cells darker than half the maximum get white text so the count stays readable.
    half_max = cm.max() / 2.
    for row, col in np.ndindex(cm.shape[0], cm.shape[1]):
        text_color = "white" if cm[row, col] > half_max else "black"
        cm_ax.text(col, row, format(cm[row, col], 'd'),
                   ha="center", va="center", color=text_color)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path)
        print(f"Metrics plot saved to {save_path}")

    plt.show()

In [6]:
# Set device


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load dataset (use max_samples=5000 for quick development, remove for full dataset)
# Dataset identifier; a later cell passes it to load_dataset_and_create_loaders.
dataset_name = "xingjunm/WildDeepfake"

# Example call of the convenience function (commented out here; the live call is made in a later cell):
# train_loader, val_loader, test_loader = load_dataset_and_create_loaders(
#     dataset_name,
#     streaming=True,
#     max_samples=1000,  # Optional limit
#     batch_size=16
# )
# # Create MobileNetV3-Large model
# model = MobileNetV3DeepfakeDetector(pretrained=True, freeze_features=False)

Using device: cuda


In [7]:
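# NOTE: legacy draft kept for reference only. `train_model` is not defined in this
# notebook; the active training loop is the later cell that calls train_one_epoch/validate.
# # Train the model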
# trained_model = train_model(
#     model,
#     train_loader,
#     val_loader,
#     num_epochs=20,
#     learning_rate=0.0001,
#     device=device,
#     mixed_precision=True  # Use mixed precision for faster training
# )

# # Save the final model
# torch.save(trained_model.state_dict(), 'mobilenet_deepfake_final.pt')

# # Evaluate on test set
# model.eval()
# test_correct = 0
# test_total = 0

# with torch.no_grad():
#     for images, labels in test_loader:
#         images, labels = images.to(device), labels.to(device)
        
#         outputs = model(images)
#         predicted = (torch.sigmoid(outputs.squeeze()) > 0.5).float()
#         test_total += labels.size(0)
#         test_correct += (predicted == labels).sum().item()

# test_acc = 100 * test_correct / test_total
# print(f'Test Accuracy: {test_acc:.2f}%')

In [8]:
# Run configuration. These are notebook-level globals read by the cells below.
pretrained = True  # passed to MobileNetV3DeepfakeDetector: load the default torchvision weights
streaming = True  # forwarded to load_dataset_and_create_loaders
max_samples = 1000  # For quick development, remove for full dataset
batch_size = 16  # forwarded to load_dataset_and_create_loaders
num_workers = 4  # forwarded to load_dataset_and_create_loaders
learning_rate = 0.0001  # AdamW learning rate used by the training cell
weight_decay = 1e-4  # AdamW weight decay used by the training cell
epochs = 20  # maximum number of training epochs
patience = 5  # EarlyStopping patience (epochs without validation-loss improvement)
freeze_features = False  # passed to MobileNetV3DeepfakeDetector: keep the backbone trainable
seed = 42  # seeds torch and numpy below and is forwarded to the loader helper
save_path = 'models/mobilenet_deepfake.pth'  # checkpoint path; the plot file names are derived from it

# Set seeds for reproducibility
torch.manual_seed(seed)
np.random.seed(seed)

# Determine device (GPU or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load dataset and create loaders
# NOTE: dataset_name is not set in this cell; it comes from the earlier setup cell.
print(f"Loading dataset: {dataset_name}")
train_loader, val_loader, test_loader = load_dataset_and_create_loaders(
    dataset_name,
    streaming=streaming,
    max_samples=max_samples,
    batch_size=batch_size,
    num_workers=num_workers,
    seed=seed
)

Using device: cuda
Loading dataset: xingjunm/WildDeepfake
Loading streaming dataset: xingjunm/WildDeepfake
Processed 100 examples: 80 train, 10 val, 10 test
Processed 200 examples: 160 train, 20 val, 20 test
Processed 300 examples: 240 train, 30 val, 30 test
Processed 400 examples: 320 train, 40 val, 40 test
Processed 500 examples: 400 train, 50 val, 50 test
Processed 600 examples: 480 train, 60 val, 60 test
Processed 700 examples: 560 train, 70 val, 70 test
Processed 800 examples: 640 train, 80 val, 80 test
Processed 900 examples: 720 train, 90 val, 90 test
Processed 1000 examples: 800 train, 100 val, 100 test
Finished processing. Total: 1000
Train: 800, Val: 100, Test: 100


In [None]:
# Create model
print("Initializing MobileNetV3 Large model...")
model = MobileNetV3DeepfakeDetector(pretrained=pretrained, freeze_features=freeze_features)
model = model.to(device)

# Define loss function and optimizer
# BCEWithLogitsLoss applies the sigmoid internally, so the model emits raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Halve the learning rate after 3 epochs without validation-loss improvement.
# The scheduler's `verbose` argument is deprecated in recent PyTorch releases,
# so learning-rate changes are logged explicitly inside the loop instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Initialize early stopping
early_stopping = EarlyStopping(patience=patience, verbose=True)

# Initialize history dictionary
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

# Initialize best validation loss
best_val_loss = float('inf')

# torch.save does not create missing parent directories, so create the
# checkpoint folder up front.
checkpoint_dir = os.path.dirname(save_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

# Training loop
print(f"Starting training for {epochs} epochs...")
start_time = time.time()

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")

    # Train for one epoch
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc, _, _ = validate(model, val_loader, criterion, device)

    # Update learning rate and report when the scheduler actually changes it
    previous_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    if current_lr != previous_lr:
        print(f"Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}")

    # Print epoch results
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Update history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Save the best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
        print(f"Model saved to {save_path}")

    # Check early stopping
    if early_stopping(val_loss):
        print(f"Early stopping triggered after {epoch+1} epochs")
        break

# Print training time
end_time = time.time()
print(f"Training completed in {(end_time - start_time) / 60:.2f} minutes")

# Load best model for testing
if best_val_loss < float('inf'):
    print(f"Loading best model from {save_path}")
    # map_location keeps the load working when the checkpoint was written on a
    # different device. torch.load unpickles the file, so only load checkpoints
    # produced by this notebook.
    model.load_state_dict(torch.load(save_path, map_location=device))
else:
    # No epoch produced a finite improvement, so nothing was saved in this run;
    # avoid silently loading a stale checkpoint from an earlier run.
    print("No checkpoint was saved during this run; testing the current weights")

# Test the model
print("Testing the model...")
test_results = test(model, test_loader, criterion, device)

# Print test results
print(f"Test Loss: {test_results['test_loss']:.4f}")
print(f"Test Accuracy: {test_results['test_acc']:.2f}%")
print(f"ROC AUC: {test_results['roc_auc']:.4f}")
print(f"Average Precision: {test_results['ap']:.4f}")



Initializing MobileNetV3 Large model...


NameError: name 'MobileNetV3DeepfakeDetector' is not defined

In [None]:
# Test-set metrics figure (ROC, PR, confusion matrix), saved next to the checkpoint
plot_metrics(test_results, save_path=save_path.replace('.pth', '_metrics.png'))

# Training history figure: loss on the left, accuracy on the right
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Loss per epoch
loss_ax.plot(history['train_loss'], label='Train Loss')
loss_ax.plot(history['val_loss'], label='Val Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()

# Accuracy per epoch
acc_ax.plot(history['train_acc'], label='Train Acc')
acc_ax.plot(history['val_acc'], label='Val Acc')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.set_title('Training and Validation Accuracy')
acc_ax.legend()

fig.tight_layout()
fig.savefig(save_path.replace('.pth', '_history.png'))
plt.show()