# In [None]:  (notebook cell marker left over from the .ipynb export)
import os
import random
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
from scipy.fftpack import dct
import pickle

# Select the compute device: CUDA when torch reports it as available, CPU otherwise.
# The feature extractors below read this module-level name.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Feature Extraction Functions
def extract_dct_features(image_tensor, block_size=8):
    """Compute block-wise 2-D DCT coefficients of an image's grayscale version.

    The image is converted to grayscale with OpenCV and cut into
    non-overlapping ``block_size`` x ``block_size`` tiles; partial tiles at
    the bottom/right edges are ignored. Each tile is transformed with SciPy's
    ``dct`` (``norm='ortho'``) along both axes and flattened.

    Args:
        image_tensor: RGB image tensor laid out as [C, H, W]. A 4-D batch is
            accepted too, but only its first image is used.
        block_size: Side length of each square tile.

    Returns:
        Float tensor on the module-level ``device`` shaped
        [1, block_size * block_size, H // block_size, W // block_size].
    """
    if image_tensor.dim() == 4:
        # Batched input: only the first image is processed.
        image_tensor = image_tensor[0]

    # OpenCV expects height x width x channels.
    hwc_image = image_tensor.permute(1, 2, 0).cpu().numpy()
    gray = cv2.cvtColor(hwc_image, cv2.COLOR_RGB2GRAY)

    height, width = gray.shape
    n_rows = height // block_size
    n_cols = width // block_size
    coefficients = np.zeros((n_rows, n_cols, block_size * block_size))

    # Walk only over complete tiles, addressed by tile row/column.
    for row in range(n_rows):
        top = row * block_size
        for col in range(n_cols):
            left = col * block_size
            tile = gray[top:top + block_size, left:left + block_size]
            # Separable 2-D DCT: transform along one axis, then the other.
            tile_dct = dct(dct(tile.T, norm='ortho').T, norm='ortho')
            coefficients[row, col, :] = tile_dct.flatten()

    # Coefficient index becomes the channel axis; a leading batch axis is added.
    as_tensor = torch.from_numpy(coefficients).float().to(device)
    return as_tensor.permute(2, 0, 1).unsqueeze(0)

def extract_fft_features(image_tensor):
    """Compute the centred log-magnitude 2-D FFT spectrum of every channel.

    Args:
        image_tensor: Image tensor laid out as [C, H, W]. A 4-D batch is
            accepted too, but only its first image is used.

    Returns:
        Float tensor on the module-level ``device`` shaped [1, C, H, W] holding
        ``log(|fftshift(fft2(channel))| + 1)`` for each channel.
    """
    if image_tensor.dim() == 4:
        # Batched input: only the first image is processed.
        image_tensor = image_tensor[0]

    channels = image_tensor.cpu().numpy()
    # Same shape and dtype as the input, so results are cast on assignment.
    log_magnitude = np.zeros_like(channels)

    for channel_idx, channel in enumerate(channels):
        # Shift the zero-frequency component to the centre of the spectrum.
        centred_spectrum = np.fft.fftshift(np.fft.fft2(channel))
        # +1 keeps the logarithm finite where the magnitude is zero.
        log_magnitude[channel_idx] = np.log(np.abs(centred_spectrum) + 1)

    return torch.from_numpy(log_magnitude).float().to(device).unsqueeze(0)

def extract_entropy_features(image_tensor, window_size=8):
    """Compute a per-window Shannon-entropy map of an image's grayscale version.

    The image is converted to grayscale with OpenCV, quantised to 256
    intensity levels spanning the image's own [min, max] range, and cut into
    non-overlapping ``window_size`` x ``window_size`` windows (partial windows
    at the bottom/right edges are ignored). The entropy (in bits) of each
    window's level histogram is returned.

    Quantising over the image's own range is what makes the histogram
    meaningful for float inputs: a fixed [0, 256) histogram range only suits
    8-bit data, whereas mean/std-normalised floats would fall into a handful
    of bins and negative values would be dropped altogether. For 8-bit inputs
    the rescaling never merges distinct intensities, so the entropy is the
    same as with a plain 256-bin histogram.

    Args:
        image_tensor: RGB image tensor laid out as [C, H, W]. A 4-D batch is
            accepted too, but only its first image is used.
        window_size: Side length of each square window.

    Returns:
        Float tensor on the module-level ``device`` shaped
        [1, 1, H // window_size, W // window_size].
    """
    # Ensure we have a single image (C,H,W)
    if image_tensor.dim() == 4:  # If batched, take first image
        image_tensor = image_tensor[0]

    # Convert to numpy array in HWC format, then to grayscale
    image_np = image_tensor.permute(1, 2, 0).cpu().numpy()
    image_gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    h, w = image_gray.shape
    rows, cols = h // window_size, w // window_size
    entropy_features = np.zeros((rows, cols))

    # Map intensities onto integer levels 0..255 over the image's value range.
    lo = float(image_gray.min())
    hi = float(image_gray.max())
    if hi > lo:
        levels = ((image_gray - lo) * (255.0 / (hi - lo))).astype(np.int64)
        np.clip(levels, 0, 255, out=levels)
    else:
        # Constant image: every pixel shares one level (entropy is 0 everywhere).
        levels = np.zeros(image_gray.shape, dtype=np.int64)

    for r in range(rows):
        top = r * window_size
        for c in range(cols):
            left = c * window_size
            window = levels[top:top + window_size, left:left + window_size]

            # Normalise by the number of pixels actually counted so the
            # probabilities always sum to 1.
            counts = np.bincount(window.ravel(), minlength=256)
            probs = counts[counts > 0] / window.size
            entropy_features[r, c] = -np.sum(probs * np.log2(probs))

    # Return as tensor with proper dimensions [1, 1, H, W]
    entropy_tensor = torch.from_numpy(entropy_features).float().to(device)
    return entropy_tensor.unsqueeze(0).unsqueeze(0)

# Dataset class for real/fake images
class RealFakeDataset(Dataset):
    """Image dataset built from one directory of real and one of fake images.

    Real images get label 0 and fake images label 1. Files are listed in
    sorted order so that a given index always refers to the same image
    (``os.listdir`` returns entries in arbitrary order), and extensions are
    matched case-insensitively.

    Args:
        real_dir: Directory containing the real images.
        fake_dir: Directory containing the fake images.
        transform: Optional callable applied to each loaded RGB PIL image.
        extensions: File-name suffix (or iterable of suffixes) to include.
    """

    def __init__(self, real_dir, fake_dir, transform=None, extensions=('.jpg',)):
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)

        self.real_paths = self._list_images(real_dir, suffixes)
        self.fake_paths = self._list_images(fake_dir, suffixes)
        self.data = [(path, 0) for path in self.real_paths] + [(path, 1) for path in self.fake_paths]
        self.transform = transform

    @staticmethod
    def _list_images(directory, suffixes):
        """Return the sorted paths of files in *directory* ending with one of *suffixes*."""
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.lower().endswith(suffixes)
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at *idx*."""
        img_path, label = self.data[idx]
        # Close the file handle as soon as the pixel data has been converted.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

# Transform for images
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create full dataset
full_dataset = RealFakeDataset(
    r"/home/jayapal/.cache/kagglehub/datasets/iaddydas/deepfake/versions/1/REAL/REAL",
    r"/home/jayapal/.cache/kagglehub/datasets/iaddydas/deepfake/versions/1/FAKE/FAKE",
    transform=transform
)

# Create subset (10%): a random selection of dataset indices
x = 0.1  # Subset percentage
subset_size = int(x * len(full_dataset))
indices = torch.randperm(len(full_dataset))[:subset_size]
subset = Subset(full_dataset, indices)

# Precompute features for subset
print("Precomputing features for subset...")
precomputed_data = []

for idx in tqdm(range(len(subset)), desc="Processing images"):
    img, label = subset[idx]
    # Keep the image on the CPU: every extractor converts its input to a CPU
    # NumPy array first, so moving it to the GPU beforehand only adds a
    # device round trip per image.
    img = img.unsqueeze(0)  # Add batch dimension [1, C, H, W]

    # Extract features (each extractor returns a tensor on `device`)
    dct_feat = extract_dct_features(img)
    fft_feat = extract_fft_features(img)
    entropy_feat = extract_entropy_features(img)

    # Everything is stored on the CPU so the pickle loads on any machine.
    precomputed_data.append({
        'image': img.cpu(),
        'dct_features': dct_feat.cpu(),
        'fft_features': fft_feat.cpu(),
        'entropy_features': entropy_feat.cpu(),
        'label': label
    })

# Save precomputed dataset
print("Saving precomputed dataset...")
with open('CIFAKE_Precomputed_Dataset.pkl', 'wb') as f:
    pickle.dump(precomputed_data, f)

# Visualize up to 4 random samples (random.sample raises if asked for more
# items than exist, so cap the count for very small subsets).
num_vis = min(4, len(precomputed_data))
print(f"\nVisualizing {num_vis} random samples...")

if num_vis > 0:
    plt.figure(figsize=(20, 15))
    random_indices = random.sample(range(len(precomputed_data)), num_vis)

    # Constants that undo the Normalize step of `transform` for display.
    norm_std = torch.tensor([0.229, 0.224, 0.225])
    norm_mean = torch.tensor([0.485, 0.456, 0.406])

    for i, idx in enumerate(random_indices):
        sample = precomputed_data[idx]

        # Original image
        plt.subplot(4, 4, i*4 + 1)
        img = sample['image'].squeeze(0).permute(1, 2, 0)
        img = img * norm_std + norm_mean
        plt.imshow(img.clamp(0, 1))
        plt.title(f"Original Image\nLabel: {'Fake' if sample['label'] else 'Real'}")
        plt.axis('off')

        # DCT features (averaged over the coefficient channels)
        plt.subplot(4, 4, i*4 + 2)
        dct_vis = sample['dct_features'].squeeze(0).mean(dim=0)
        plt.imshow(dct_vis, cmap='viridis')
        plt.title('DCT Features')
        plt.axis('off')

        # FFT features (averaged over the colour channels)
        plt.subplot(4, 4, i*4 + 3)
        fft_vis = sample['fft_features'].squeeze(0).mean(dim=0)
        plt.imshow(fft_vis, cmap='viridis')
        plt.title('FFT Features')
        plt.axis('off')

        # Entropy features
        plt.subplot(4, 4, i*4 + 4)
        entropy_vis = sample['entropy_features'].squeeze(0).squeeze(0)
        plt.imshow(entropy_vis, cmap='viridis')
        plt.title('Entropy Features')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

print("Precomputation and visualization complete!")







import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns

# Check if GPU is available
# Select CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load precomputed dataset
# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# files that were produced by a trusted run of the precomputation step.
print("Loading precomputed dataset...")
with open('CIFAKE_Precomputed_Dataset.pkl', 'rb') as f:
    precomputed_data = pickle.load(f)

# Create dataset class for precomputed features
class PrecomputedFeatureDataset(Dataset):
    """Dataset over a list of precomputed per-image feature dictionaries.

    Each entry of *data* is a dict with the keys ``'image'``,
    ``'dct_features'``, ``'fft_features'``, ``'entropy_features'`` and
    ``'label'``.

    Tensors stored with a leading singleton batch axis ([1, C, H, W]) are
    returned without it ([C, H, W]). Otherwise the DataLoader's default
    collation would stack them into 5-D batches ([B, 1, C, H, W]), which
    ``nn.Conv2d`` does not accept. Tensors that are already 3-D are returned
    unchanged.
    """

    _TENSOR_KEYS = ('image', 'dct_features', 'fft_features', 'entropy_features')

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _drop_batch_dim(tensor):
        """Remove a leading singleton batch axis from a 4-D tensor, if present."""
        if tensor.dim() == 4 and tensor.size(0) == 1:
            return tensor.squeeze(0)
        return tensor

    def __getitem__(self, idx):
        """Return ``(image, dct, fft, entropy, label)`` for the sample at *idx*."""
        sample = self.data[idx]
        tensors = tuple(self._drop_batch_dim(sample[key]) for key in self._TENSOR_KEYS)
        return (*tensors, sample['label'])

# Create dataset and split
# Wrap the loaded samples and split them 80% / 20% into train and test parts.
dataset = PrecomputedFeatureDataset(precomputed_data)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create data loaders (only the training loader shuffles)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

# ResNet-style basic block
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an additive skip connection.

    The skip path is the identity unless the block changes the spatial stride
    or the channel count, in which case it is a 1x1 convolution followed by
    batch normalisation so both paths have matching shapes.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution (and of the projection skip).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            # Project the input so it can be added to the main path's output.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        skip_path = self.shortcut(x)
        return F.relu(main_path + skip_path)

# NCFIM Model
class NCFIMModel(nn.Module):
    """Real/fake classifier that fuses four feature streams via an NCFIM matrix.

    The streams are: a small CNN over the raw image, and one conv processor
    each for precomputed DCT, FFT and entropy features. All streams produce
    64-channel maps. Pairwise distances between the streams form the NCFIM
    matrix; its row sums weight the streams in a fused map, which is then
    classified by a ResNet-style backbone.

    The streams do not naturally share a spatial size (for a 256x256 input the
    FFT map stays 256x256 while the raw-image CNN downsamples by 8), so every
    stream is resized to the raw-image CNN's output size before fusion.
    """

    def __init__(self):
        super().__init__()

        # Raw image CNN: three stride-2 convolutions
        self.raw_image_cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        # Feature processors: map each precomputed feature type to 64 channels
        self.dct_processor = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.fft_processor = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.entropy_processor = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        # ResNet backbone
        self.stage1 = nn.Sequential(
            ResidualBlock(64, 64),
            ResidualBlock(64, 64)
        )

        self.stage2 = nn.Sequential(
            ResidualBlock(64, 128, stride=2),
            ResidualBlock(128, 128)
        )

        self.stage3 = nn.Sequential(
            ResidualBlock(128, 256, stride=2),
            ResidualBlock(256, 256),
            ResidualBlock(256, 256)
        )

        self.stage4 = nn.Sequential(
            ResidualBlock(256, 512, stride=2),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512)
        )

        # Final classification
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 2)
        )

    @staticmethod
    def _match_spatial_size(feature_list, target_size):
        """Return the feature maps resized to ``target_size`` (H, W).

        Maps that already have the target size are passed through untouched;
        the others are resized with adaptive average pooling.
        """
        target = tuple(int(s) for s in target_size)
        return [
            feat if tuple(feat.shape[-2:]) == target else F.adaptive_avg_pool2d(feat, target)
            for feat in feature_list
        ]

    def compute_ncfim(self, feature_list):
        """Build the pairwise stream-distance matrix.

        Args:
            feature_list: Feature maps of identical shape [B, C, H, W].

        Returns:
            Tensor [B, N, N] (N = number of streams) whose off-diagonal entry
            (i, j) is ``tanh(mean((F_i - F_j) ** 2))`` per sample; the diagonal
            is zero and the matrix is symmetric.
        """
        num_streams = len(feature_list)
        batch_size = feature_list[0].size(0)
        ncfim = torch.zeros(batch_size, num_streams, num_streams, device=feature_list[0].device)

        # The distance is symmetric, so compute each unordered pair once.
        for i in range(num_streams):
            for j in range(i + 1, num_streams):
                squared_diff = (feature_list[i] - feature_list[j]) ** 2
                distance = torch.tanh(torch.mean(squared_diff, dim=(1, 2, 3)))
                ncfim[:, i, j] = distance
                ncfim[:, j, i] = distance

        return ncfim

    def forward(self, x):
        """Classify a batch.

        Args:
            x: Tuple ``(image, dct_features, fft_features, entropy_features)``.

        Returns:
            Tuple ``(logits, ncfim, feature_list)`` where ``logits`` comes from
            the final classifier, ``ncfim`` is the pairwise stream matrix, and
            ``feature_list`` holds the size-aligned stream maps in the order
            DCT, FFT, entropy, CNN.
        """
        image, dct_features, fft_features, entropy_features = x

        # Process each feature stream
        processed_dct = self.dct_processor(dct_features)
        processed_fft = self.fft_processor(fft_features)
        processed_entropy = self.entropy_processor(entropy_features)

        # CNN features
        cnn_features = self.raw_image_cnn(image)

        # Bring all streams onto the CNN stream's spatial grid so they can be
        # subtracted from and added to one another.
        feature_list = self._match_spatial_size(
            [processed_dct, processed_fft, processed_entropy, cnn_features],
            cnn_features.shape[-2:],
        )

        # Compute NCFIM
        ncfim = self.compute_ncfim(feature_list)

        # Fusion weight of a stream = sum of its distances to the other streams
        weights = torch.sum(ncfim, dim=2)

        # Fuse features using NCFIM weights
        F_final = torch.zeros_like(feature_list[0])
        for i, feat in enumerate(feature_list):
            F_final = F_final + weights[:, i].view(-1, 1, 1, 1) * feat

        # Pass through ResNet backbone
        x = self.stage1(F_final)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        # Global pooling and classification
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x, ncfim, feature_list

# Initialize model, loss function, and optimizer
model = NCFIMModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 40
# Per-epoch history; plotted once training has finished
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

print("Starting training...")
for epoch in range(num_epochs):
    # Training
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', leave=False)
    for images, dct_features, fft_features, entropy_features, labels in train_bar:
        # Move data to GPU
        images = images.to(device)
        dct_features = dct_features.to(device)
        fft_features = fft_features.to(device)
        entropy_features = entropy_features.to(device)
        labels = labels.to(device)
        
        # Forward pass (the model takes all four inputs as one tuple)
        optimizer.zero_grad()
        outputs, ncfim, feature_list = model((images, dct_features, fft_features, entropy_features))
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics: accumulate the batch loss and count correct argmax predictions
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        # NOTE(review): the displayed loss divides by train_bar.n + 1; confirm that
        # tqdm keeps `.n` in step with the number of batches processed so far.
        train_bar.set_postfix({
            'loss': running_loss / (train_bar.n + 1),
            'acc': 100. * correct / total
        })
    
    # Epoch-level training metrics: mean batch loss and accuracy in percent
    train_loss = running_loss / len(train_loader)
    train_accuracy = 100. * correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    
    # Testing (the test split is evaluated after every epoch)
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        test_bar = tqdm(test_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Test]', leave=False)
        for images, dct_features, fft_features, entropy_features, labels in test_bar:
            # Move data to GPU
            images = images.to(device)
            dct_features = dct_features.to(device)
            fft_features = fft_features.to(device)
            entropy_features = entropy_features.to(device)
            labels = labels.to(device)
            
            # Forward pass
            outputs, ncfim, feature_list = model((images, dct_features, fft_features, entropy_features))
            loss = criterion(outputs, labels)
            
            # Statistics
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Keep every prediction/label for this epoch's confusion matrix
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            test_bar.set_postfix({
                'loss': test_loss / (test_bar.n + 1),
                'acc': 100. * correct / total
            })
    
    test_loss = test_loss / len(test_loader)
    test_accuracy = 100. * correct / total
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
    
    # Print epoch summary
    print(f'\nEpoch {epoch+1}/{num_epochs}:')
    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.2f}%')
    print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_accuracy:.2f}%')
    
    # Visualize NCFIM matrix and feature fusion
    if epoch % 5 == 0:  # Every 5 epochs
        plt.figure(figsize=(15, 5))
        
        # NCFIM matrix of the first sample in the last test batch
        # (`ncfim` still holds the value from the final test iteration)
        plt.subplot(1, 3, 1)
        ncfim_vis = ncfim[0].cpu().numpy()
        plt.imshow(ncfim_vis, cmap='coolwarm', vmin=-1, vmax=1)
        plt.colorbar()
        plt.title('NCFIM Matrix')
        plt.xticks([0,1,2,3], ['DCT','FFT','Entropy','CNN'])
        plt.yticks([0,1,2,3], ['DCT','FFT','Entropy','CNN'])
        
        # Feature weights: row sums of that same sample's NCFIM matrix
        plt.subplot(1, 3, 2)
        weights = torch.sum(ncfim, dim=2)[0].cpu().numpy()
        plt.bar(['DCT','FFT','Entropy','CNN'], weights)
        plt.title('Feature Weights')
        plt.ylabel('NCFIM Weight')
        
        # Confusion matrix over the whole test split for this epoch
        plt.subplot(1, 3, 3)
        cm = confusion_matrix(all_labels, all_preds)
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        
        plt.tight_layout()
        plt.show()

# Plot training curves
plt.figure(figsize=(12, 5))

# Loss plot
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.title('Training and Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Accuracy plot
plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Train Accuracy')
plt.plot(test_accuracies, label='Test Accuracy')
plt.title('Training and Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.tight_layout()
plt.show()

print("Training complete!")





from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, matthews_corrcoef, cohen_kappa_score, log_loss, balanced_accuracy_score
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Load the saved model
model_path = 'real_fake_fusion_model.pth'
# NOTE(review): MultiDomainFeatureFusionModel is not defined in the visible part
# of this file; it has to be defined or imported before this statement runs.
model = MultiDomainFeatureFusionModel().to(device)
# map_location places the checkpoint's tensors on the current device, so a
# checkpoint saved on a GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Function to get all predictions and true labels
def get_all_predictions(model, loader):
    """Run *model* over *loader* and collect inputs, predictions and labels.

    Batches may be ``(inputs, labels)`` pairs or longer sequences whose last
    element is the label tensor (for example the five-element batches of the
    precomputed-feature loaders). With several input tensors the model is
    called with them as one tuple. If the model returns a tuple or list, its
    first element is taken as the class logits.

    Args:
        model: Callable returning logits of shape [B, num_classes], or a
            tuple/list whose first element is that tensor.
        loader: Iterable of batches as described above.

    Returns:
        Tuple ``(inputs, preds, probs, labels)`` of CPU tensors concatenated
        over all batches. ``inputs`` holds the first input tensor of every
        batch; ``probs`` are the softmax probabilities.
    """
    all_preds = []
    all_probs = []
    all_labels = []
    all_inputs = []

    with torch.no_grad():
        for batch in tqdm(loader, desc='Evaluating'):
            # Everything before the last element is model input.
            *features, labels = batch
            features = [tensor.to(device) for tensor in features]
            labels = labels.to(device)

            model_input = features[0] if len(features) == 1 else tuple(features)
            outputs = model(model_input)
            if isinstance(outputs, (tuple, list)):
                # Multi-output model: the logits come first.
                outputs = outputs[0]

            # Get probabilities and predictions
            probs = F.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            all_inputs.append(features[0].cpu())
            all_probs.append(probs.cpu())
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    return torch.cat(all_inputs), torch.cat(all_preds), torch.cat(all_probs), torch.cat(all_labels)

# Get predictions for both datasets
# Run the model once over each split and keep inputs, predictions, probabilities and labels.
print("Evaluating train dataset:")
train_inputs, train_preds, train_probs, train_labels = get_all_predictions(model, train_loader)
print("Evaluating test dataset:")
test_inputs, test_preds, test_probs, test_labels = get_all_predictions(model, test_loader)

# Convert to numpy for both datasets (the metric helpers below take arrays)
train_y_true = train_labels.numpy()
train_y_pred = train_preds.numpy()
train_y_probs = train_probs.numpy()[:, 1]  # Probability of class 1 (fake)

test_y_true = test_labels.numpy()
test_y_pred = test_preds.numpy()
test_y_probs = test_probs.numpy()[:, 1]  # Probability of class 1 (fake)

# 1. Classification Report for both datasets
print("\nClassification Report (Training Set):")
print(classification_report(train_y_true, train_y_pred, target_names=['Real', 'Fake']))

print("\nClassification Report (Test Set):")
print(classification_report(test_y_true, test_y_pred, target_names=['Real', 'Fake']))

# 2. Confusion Matrix using matplotlib instead of seaborn
def plot_confusion_matrix(y_true, y_pred, dataset_name):
    """Draw the Real/Fake confusion matrix for one dataset.

    The matrix is always computed for the labels 0 (Real) and 1 (Fake), so it
    stays 2x2 and lines up with the two tick labels even when one of the
    classes does not occur in the data.

    Args:
        y_true: True binary labels (0 = Real, 1 = Fake).
        y_pred: Predicted binary labels.
        dataset_name: Name shown in the figure title.
    """
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(f'Confusion Matrix - {dataset_name}')
    plt.colorbar()

    classes = ['Real', 'Fake']
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)

    # Add text annotations; switch to white text on dark (high-count) cells
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.xlabel('Predicted label')
    plt.ylabel('True label')
    plt.tight_layout()
    plt.show()

# One confusion-matrix figure per split.
plot_confusion_matrix(train_y_true, train_y_pred, "Training Set")
plot_confusion_matrix(test_y_true, test_y_pred, "Test Set")

# 3. ROC Curve and AUC for both datasets
def plot_roc_curve(y_true, y_probs, dataset_name):
    """Plot the ROC curve of one dataset, with its AUC shown in the legend.

    Args:
        y_true: True binary labels.
        y_probs: Predicted scores for the positive class.
        dataset_name: Name shown in the figure title.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_probs)
    area = auc(false_pos_rate, true_pos_rate)

    _, ax = plt.subplots(figsize=(8, 6))
    ax.plot(false_pos_rate, true_pos_rate, color='darkorange', lw=2,
            label=f'ROC curve (AUC = {area:.2f})')
    # Diagonal reference line from (0, 0) to (1, 1).
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(f'Receiver Operating Characteristic - {dataset_name}')
    ax.legend(loc="lower right")
    plt.show()

# One ROC figure per split, using the predicted probability of class 1.
plot_roc_curve(train_y_true, train_y_probs, "Training Set")
plot_roc_curve(test_y_true, test_y_probs, "Test Set")

# 4. Precision-Recall Curve for both datasets
def plot_precision_recall(y_true, y_probs, dataset_name):
    """Plot the precision-recall curve of one dataset, with AP in the legend.

    Args:
        y_true: True binary labels.
        y_probs: Predicted scores for the positive class.
        dataset_name: Name shown in the figure title.
    """
    precision_values, recall_values, _ = precision_recall_curve(y_true, y_probs)
    ap_score = average_precision_score(y_true, y_probs)

    _, ax = plt.subplots(figsize=(8, 6))
    ax.plot(recall_values, precision_values, color='blue', lw=2,
            label=f'Precision-Recall (AP = {ap_score:.2f})')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title(f'Precision-Recall Curve - {dataset_name}')
    ax.legend(loc="upper right")
    plt.show()

# One precision-recall figure per split.
plot_precision_recall(train_y_true, train_y_probs, "Training Set")
plot_precision_recall(test_y_true, test_y_probs, "Test Set")

# 5. Advanced Metrics Table for both datasets
def calculate_advanced_metrics(y_true, y_pred, y_probs):
    """Collect scalar classification metrics into a one-column DataFrame.

    Args:
        y_true: True binary labels.
        y_pred: Predicted binary labels.
        y_probs: Predicted scores for the positive class.

    Returns:
        DataFrame indexed by metric name with a single ``'Value'`` column.
    """
    # Metrics computed from hard label predictions, in display order.
    label_based = (
        ('Accuracy', accuracy_score),
        ('Balanced Accuracy', balanced_accuracy_score),
        ('Precision', precision_score),
        ('Recall', recall_score),
        ('F1 Score', f1_score),
        ('MCC', matthews_corrcoef),
        ("Cohen's Kappa", cohen_kappa_score),
    )
    scores = {name: scorer(y_true, y_pred) for name, scorer in label_based}

    # Metrics computed from the predicted scores.
    scores['Log Loss'] = log_loss(y_true, y_probs)
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_probs)
    scores['ROC AUC'] = auc(false_pos_rate, true_pos_rate)
    scores['PR AUC'] = average_precision_score(y_true, y_probs)

    return pd.DataFrame.from_dict(scores, orient='index', columns=['Value'])

# Metric tables for both splits; kept at module level because the comparison
# and summary helpers further down read them.
train_advanced_metrics = calculate_advanced_metrics(train_y_true, train_y_pred, train_y_probs)
print("\nAdvanced Metrics (Training Set):")
print(train_advanced_metrics)

test_advanced_metrics = calculate_advanced_metrics(test_y_true, test_y_pred, test_y_probs)
print("\nAdvanced Metrics (Test Set):")
print(test_advanced_metrics)

# 6. Class-wise Metrics for both datasets
def class_wise_metrics(y_true, y_pred):
    """Return one-vs-rest counts and rates for the Real and Fake classes.

    Labels are binary: 0 = Real, 1 = Fake. Each class is treated in turn as
    the "positive" class:

    * Fake as positive: TP/FP/TN/FN are the usual tp/fp/tn/fn.
    * Real as positive: the roles flip, i.e. TP=tn, FP=fn, TN=tp, FN=fp.

    The counts are taken directly from the label arrays, so the function also
    works when only one class occurs in the data. Rates with a zero
    denominator are reported as 0.

    Args:
        y_true: True binary labels.
        y_pred: Predicted binary labels.

    Returns:
        DataFrame with rows for 'Real' (index 0) and 'Fake' (index 1) and the
        columns Class, TP, FP, TN, FN, Sensitivity, Specificity, PPV, NPV.
    """
    actual_fake = np.asarray(y_true).astype(bool)
    predicted_fake = np.asarray(y_pred).astype(bool)

    # Confusion counts with Fake (label 1) as the positive class.
    tp = int(np.sum(actual_fake & predicted_fake))
    tn = int(np.sum(~actual_fake & ~predicted_fake))
    fp = int(np.sum(~actual_fake & predicted_fake))
    fn = int(np.sum(actual_fake & ~predicted_fake))

    def safe_division(numerator, denominator):
        return numerator / denominator if denominator != 0 else 0.0

    def one_vs_rest(name, c_tp, c_fp, c_tn, c_fn):
        """Build the metric row for a class given its own TP/FP/TN/FN."""
        return {
            'Class': name,
            'TP': c_tp,
            'FP': c_fp,
            'TN': c_tn,
            'FN': c_fn,
            'Sensitivity': safe_division(c_tp, c_tp + c_fn),
            'Specificity': safe_division(c_tn, c_tn + c_fp),
            'PPV': safe_division(c_tp, c_tp + c_fp),
            'NPV': safe_division(c_tn, c_tn + c_fn),
        }

    rows = [
        one_vs_rest('Real', tn, fn, tp, fp),
        one_vs_rest('Fake', tp, fp, tn, fn),
    ]
    return pd.DataFrame(rows)

# Per-class tables for both splits; test_class_metrics is read again by the
# final assessment summary.
train_class_metrics = class_wise_metrics(train_y_true, train_y_pred)
print("\nClass-wise Metrics (Training Set):")
print(train_class_metrics)

test_class_metrics = class_wise_metrics(test_y_true, test_y_pred)
print("\nClass-wise Metrics (Test Set):")
print(test_class_metrics)

# 7. Error Analysis for both datasets
def analyze_errors(model, inputs, preds, labels, dataset_name, num_samples=5,
                   mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show the first few misclassified images with true and predicted labels.

    Images are de-normalised before display (``img * std + mean``, clamped to
    [0, 1]) so they appear with their natural colours; the defaults match the
    ``transforms.Normalize`` constants used when the images were prepared.
    Pass ``mean=None`` or ``std=None`` to display the tensors as they are.

    Args:
        model: Unused; kept so existing calls keep working.
        inputs: Image tensor [N, C, H, W]; a singleton axis after N
            ([N, 1, C, H, W]) is tolerated.
        preds: Predicted labels, shape [N].
        labels: True labels, shape [N].
        dataset_name: Name used in messages and the figure title.
        num_samples: Maximum number of misclassified images to show.
        mean: Per-channel mean used for normalisation, or None.
        std: Per-channel standard deviation used for normalisation, or None.
    """
    # Find misclassified samples
    incorrect_indices = torch.where(preds != labels)[0]

    if len(incorrect_indices) == 0:
        print(f"No misclassifications found in the {dataset_name}")
        return

    # Limit to the requested number of samples
    selected_indices = incorrect_indices[:num_samples]
    sample_count = len(selected_indices)

    # Visualize errors
    plt.figure(figsize=(15, 3))
    for i, idx in enumerate(selected_indices):
        img = inputs[idx]
        if img.dim() == 4 and img.size(0) == 1:
            img = img.squeeze(0)
        img = img.permute(1, 2, 0)  # CHW -> HWC for imshow
        if mean is not None and std is not None:
            img = img * torch.as_tensor(std, dtype=img.dtype) + torch.as_tensor(mean, dtype=img.dtype)
            img = img.clamp(0, 1)

        plt.subplot(1, sample_count, i+1)
        plt.imshow(img)
        plt.title(f"True: {'Fake' if labels[idx] else 'Real'}\n"
                  f"Pred: {'Fake' if preds[idx] else 'Real'}")
        plt.axis('off')
    plt.suptitle(f"Error Analysis - {dataset_name}")
    plt.tight_layout()
    plt.show()

# Show misclassified examples from each split.
print("\nError Analysis (Training Set):")
analyze_errors(model, train_inputs, train_preds, train_labels, "Training Set")

print("\nError Analysis (Test Set):")
analyze_errors(model, test_inputs, test_preds, test_labels, "Test Set")

# 8. Threshold Analysis for both datasets
def threshold_analysis(y_true, y_probs, dataset_name):
    """Sweep the decision threshold and plot TPR, FPR, precision and F1.

    101 evenly spaced thresholds in [0, 1] are evaluated; a sample is
    predicted positive when its score is >= the threshold. The confusion
    matrix is always computed for labels 0 and 1, so the sweep also works
    when one class is absent from the data or from the predictions.

    Args:
        y_true: True binary labels (0/1).
        y_probs: NumPy array of predicted scores for the positive class.
        dataset_name: Name shown in the figure title.

    Returns:
        The threshold with the highest F1 score, or 0.5 if F1 is 0 everywhere.
    """
    thresholds = np.linspace(0, 1, 101)
    metrics = []

    for thresh in thresholds:
        y_pred_thresh = (y_probs >= thresh).astype(int)

        # Fixed label set keeps the matrix 2x2 so it always unpacks into four counts
        cm = confusion_matrix(y_true, y_pred_thresh, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()

        # Handle division by zero
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        tpr = tp / (tp + fn) if (tp + fn) > 0 else 0
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
        f1 = 2 * (precision * tpr) / (precision + tpr) if (precision + tpr) > 0 else 0

        metrics.append({
            'Threshold': thresh,
            'TPR': tpr,
            'FPR': fpr,
            'Precision': precision,
            'F1': f1
        })

    df = pd.DataFrame(metrics)

    # Optimal threshold = max F1; determined once and used for plot and return value
    has_optimum = df['F1'].max() > 0
    optimal_threshold = df.loc[df['F1'].idxmax(), 'Threshold'] if has_optimum else 0.5

    # Plot metrics vs threshold
    plt.figure(figsize=(12, 6))
    plt.plot(df['Threshold'], df['TPR'], label='True Positive Rate (Recall)')
    plt.plot(df['Threshold'], df['FPR'], label='False Positive Rate')
    plt.plot(df['Threshold'], df['Precision'], label='Precision')
    plt.plot(df['Threshold'], df['F1'], label='F1 Score')

    if has_optimum:
        plt.axvline(x=optimal_threshold, color='r', linestyle='--',
                    label=f'Optimal Threshold: {optimal_threshold:.2f}')

    plt.xlabel('Threshold')
    plt.ylabel('Metric Value')
    plt.title(f'Threshold Analysis - {dataset_name}')
    plt.legend()
    plt.grid()
    plt.show()

    return optimal_threshold

# F1-optimal thresholds per split; both are reused by the distribution plots,
# and the test-set value also appears in the final summary.
print("\nThreshold Analysis (Training Set):")
train_optimal_threshold = threshold_analysis(train_y_true, train_y_probs, "Training Set")
print(f"Optimal Decision Threshold (Training Set): {train_optimal_threshold:.4f}")

print("\nThreshold Analysis (Test Set):")
test_optimal_threshold = threshold_analysis(test_y_true, test_y_probs, "Test Set")
print(f"Optimal Decision Threshold (Test Set): {test_optimal_threshold:.4f}")

# 9. Prediction Distribution for both datasets
def plot_prediction_distribution(y_true, y_probs, optimal_threshold, dataset_name):
    """Plot per-class histograms of the predicted probability of being fake.

    Args:
        y_true: NumPy array of true binary labels (0 = Real, 1 = Fake).
        y_probs: NumPy array of predicted probabilities for class 1.
        optimal_threshold: Threshold drawn as a green dashed line.
        dataset_name: Name shown in the figure title.
    """
    plt.figure(figsize=(10, 6))

    # 20 equal-width bins across [0, 1]
    bin_edges = np.linspace(0, 1, 21)

    # One density histogram per true class; a class without samples is skipped.
    class_styles = ((0, 'blue', 'Real'), (1, 'red', 'Fake'))
    for class_value, colour, class_name in class_styles:
        class_probs = y_probs[y_true == class_value]
        if len(class_probs) > 0:
            plt.hist(class_probs, bins=bin_edges, alpha=0.5, color=colour,
                     label=class_name, density=True)

    # Reference lines for the default and the supplied threshold
    plt.axvline(x=0.5, color='black', linestyle='--', label='Default Threshold (0.5)')
    plt.axvline(x=optimal_threshold, color='green', linestyle='--',
                label=f'Optimal Threshold ({optimal_threshold:.2f})')

    plt.xlabel('Predicted Probability of Being Fake')
    plt.ylabel('Density')
    plt.title(f'Prediction Distribution by True Class - {dataset_name}')
    plt.legend()
    plt.show()

# One distribution figure per split, each marked with that split's own optimal threshold.
plot_prediction_distribution(train_y_true, train_y_probs, train_optimal_threshold, "Training Set")
plot_prediction_distribution(test_y_true, test_y_probs, test_optimal_threshold, "Test Set")

# 10. Compare Training vs Testing Performance
def compare_datasets_performance():
    """Tabulate and plot training versus testing metrics side by side.

    Reads the module-level ``train_advanced_metrics`` and
    ``test_advanced_metrics`` tables.

    Returns:
        DataFrame with the columns Metric, Training, Testing and Difference
        (testing minus training) for every metric in the tables.
    """
    train_values = train_advanced_metrics['Value']
    test_values = test_advanced_metrics['Value']

    comparison = pd.DataFrame({
        'Metric': train_advanced_metrics.index,
        'Training': train_values,
        'Testing': test_values,
        'Difference': test_values - train_values
    })

    plt.figure(figsize=(12, 8))

    # Only these metrics are drawn as bars.
    metrics_to_plot = ['Accuracy', 'Precision', 'Recall', 'F1 Score', 'ROC AUC']
    plotted = comparison[comparison['Metric'].isin(metrics_to_plot)]

    positions = np.arange(len(metrics_to_plot))
    width = 0.35
    # Each series sits half a bar width to one side of the tick position.
    series_offsets = (('Training', -width/2), ('Testing', width/2))

    for column, offset in series_offsets:
        plt.bar(positions + offset, plotted[column], width, label=column)

    plt.xlabel('Metrics')
    plt.ylabel('Score')
    plt.title('Training vs Testing Performance')
    plt.xticks(positions, metrics_to_plot)
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Write each value just above its bar.
    for column, offset in series_offsets:
        for bar_idx, value in enumerate(plotted[column]):
            plt.text(bar_idx + offset, value + 0.01, f'{value:.2f}', ha='center')

    plt.tight_layout()
    plt.show()

    return comparison

# Plot the comparison and print the full table it returns.
print("\nTraining vs Testing Performance Comparison:")
performance_comparison = compare_datasets_performance()
print(performance_comparison)

# 11. Overall Model Assessment Summary
def print_model_assessment_summary():
    """Print a text summary of accuracy, F1, overfitting and per-class balance.

    Reads the module-level ``train_advanced_metrics``,
    ``test_advanced_metrics``, ``test_class_metrics`` and
    ``test_optimal_threshold``.
    """
    print("\n==== MODEL ASSESSMENT SUMMARY ====")

    train_acc = train_advanced_metrics.loc['Accuracy', 'Value']
    test_acc = test_advanced_metrics.loc['Accuracy', 'Value']
    train_f1 = train_advanced_metrics.loc['F1 Score', 'Value']
    test_f1 = test_advanced_metrics.loc['F1 Score', 'Value']

    # Train-minus-test gaps serve as the overfitting indicators.
    accuracy_gap = train_acc - test_acc
    f1_gap = train_f1 - test_f1

    print("Model: MultiDomainFeatureFusionModel")
    print(f"Training Accuracy: {train_acc:.4f}")
    print(f"Testing Accuracy: {test_acc:.4f}")
    print(f"Accuracy Gap (Train-Test): {accuracy_gap:.4f}")

    print(f"\nTraining F1 Score: {train_f1:.4f}")
    print(f"Testing F1 Score: {test_f1:.4f}")
    print(f"F1 Score Gap (Train-Test): {f1_gap:.4f}")

    # An accuracy gap above 5 points is reported as potential overfitting.
    if accuracy_gap > 0.05:
        print("\nPotential overfitting detected: The model performs significantly better on training data than test data.")
        print("Consider regularization techniques or collecting more diverse training data.")
    else:
        print("\nNo significant overfitting detected: Model performs similarly on training and test data.")

    # Grade the model by the first test-F1 cutoff it exceeds.
    model_quality = "Needs improvement"
    for cutoff, grade in ((0.9, "Excellent"), (0.8, "Good"), (0.7, "Fair")):
        if test_f1 > cutoff:
            model_quality = grade
            break

    print(f"\nOverall Model Quality: {model_quality}")
    print(f"Optimal Threshold (based on F1): {test_optimal_threshold:.4f}")

    # Per-class sensitivity on the test split (row 0 = Real, row 1 = Fake).
    real_sensitivity = test_class_metrics.loc[0, 'Sensitivity']
    fake_sensitivity = test_class_metrics.loc[1, 'Sensitivity']

    print(f"\nReal Image Detection Sensitivity: {real_sensitivity:.4f}")
    print(f"Fake Image Detection Sensitivity: {fake_sensitivity:.4f}")

    if abs(real_sensitivity - fake_sensitivity) > 0.1:
        if real_sensitivity > fake_sensitivity:
            imbalance_class = "Real"
        else:
            imbalance_class = "Fake"
        print(f"Class imbalance detected: Model performs better on '{imbalance_class}' class.")
        print("Consider balanced training techniques or adjusting the decision threshold.")

# Print the final text summary; it reads the module-level metric tables and threshold.
print_model_assessment_summary()



# Shell command (run in a terminal, not as Python) to copy the notebook to the remote host:
# scp -P 6868 "C:\Deepfake\Deepfake_Detection.ipynb" harshith@45.112.150.141:/home/harshith/





