# In [4]:  (Jupyter notebook cell prompt, kept as a comment so the file parses as Python)
# ============================================================
# DEEPFAKE DETECTION - LOCAL DESKTOP/VS CODE VERSION
# ============================================================
# This version is adapted for running on your local machine with GPU
# No Google Drive mounting needed!

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from PIL import Image
import numpy as np
import open_clip
import copy
import sys
from tqdm.auto import tqdm
from datetime import datetime

sys.stdout.flush()

# ---------------- LOCAL PATHS CONFIG ----------------
# Root folder of the dataset on this machine -- edit to match your setup.
# Kept as a raw string so the backslashes are not read as escape sequences.
BASE_DIR = r"C:\Users\Admin\deepfake"

# Each entry is (image directory, unused placeholder, class label);
# label 0 is the "real" folder and label 1 is the "fake" folder.
train_dirs = [
    (os.path.join(BASE_DIR, f"train/{class_name}"), None, class_label)
    for class_label, class_name in enumerate(("real", "fake"))
]
valid_dirs = [
    (os.path.join(BASE_DIR, f"valid/{class_name}"), None, class_label)
    for class_label, class_name in enumerate(("real", "fake"))
]

# Remaining locations, all relative to BASE_DIR.
test_image_dir, test_video_dir, models_dir = (
    os.path.join(BASE_DIR, sub_path)
    for sub_path in ("test/images", "test/videos", "models")
)

# Checkpoints and logs are written here; create the folder on first run.
os.makedirs(models_dir, exist_ok=True)

# ---------------- TRAINING CONFIG ----------------
sequence_length = 4    # number of frames grouped into one sample
batch_size = 8         # Increase if you have good GPU (GTX 1080 or better)
num_epochs = 25        # upper bound; early stopping may end training sooner
learning_rate = 2e-5
patience = 4           # epochs without improvement tolerated by early stopping
min_delta = 0.005      # smallest score gain that early stopping counts as an improvement

# GPU Setup
# Use CUDA when PyTorch can see a device, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
    print("⚠️ No GPU detected! Training will be VERY slow on CPU.")
    print("   Consider using Google Colab if you don't have a GPU.")
else:
    # Report name, total memory (in units of 1e9 bytes) and CUDA version of device 0.
    print(f"🔥 GPU detected: {torch.cuda.get_device_name(0)}")
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"   GPU Memory: {gpu_memory_gb:.2f} GB")
    print(f"   CUDA Version: {torch.version.cuda}")

# ---------------- PATH VALIDATION ----------------
# Check that every training/validation folder is present before any heavy work
# starts, and report how many image files each one contains. The script exits
# with status 1 if any folder is missing.
print("\n📁 Checking directories...")
required_dirs = {
    "Train Real": train_dirs[0][0],
    "Train Fake": train_dirs[1][0],
    "Valid Real": valid_dirs[0][0],
    "Valid Fake": valid_dirs[1][0],
}

all_exist = True
for name, path in required_dirs.items():
    # isdir rather than exists: a regular file at this path would pass an
    # existence check and then make os.listdir raise NotADirectoryError.
    exists = os.path.isdir(path)
    status = "✅" if exists else "❌"
    print(f"   {status} {name}: {path}")
    if not exists:
        all_exist = False
        continue
    # Count entries whose name ends in a known image extension (case-insensitive).
    file_count = sum(
        1 for f in os.listdir(path)
        if f.lower().endswith(('.jpg', '.png', '.jpeg'))
    )
    print(f"      └─ {file_count} images")

if not all_exist:
    print("\n❌ ERROR: Some required directories are missing!")
    print(f"\nPlease create the directory structure at: {BASE_DIR}")
    print("   Deepfake/")
    print("   ├── train/real/")
    print("   ├── train/fake/")
    print("   ├── valid/real/")
    print("   └── valid/fake/")
    sys.exit(1)

# ---------------- LOAD CLIP MODEL ----------------
# Build the "ViT-B-16" CLIP model with the "openai" pretrained tag on the
# selected device. Of the three values returned, the second is discarded and
# the third (`preprocess`) is the image preprocessing pipeline used later for
# the transforms.
print("\n📦 Loading CLIP model...")
try:
    clip_model_base, _, preprocess = open_clip.create_model_and_transforms(
        "ViT-B-16", pretrained="openai", device=device
    )
    print("✅ CLIP model loaded successfully")
except Exception as e:
    # Any failure while creating the model is fatal: report it and exit with
    # status 1. NOTE(review): the install hint only helps when the package is
    # the problem; other causes (e.g. a failed weight download) are possible.
    print(f"❌ Error loading CLIP: {e}")
    print("Run: pip install open_clip_torch")
    sys.exit(1)

# ---------------- TRANSFORMS ----------------
# Training: random augmentations on the loaded image, followed by the last two
# steps of the CLIP preprocessing pipeline (presumably tensor conversion and
# normalisation -- confirm against the open_clip version in use).
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),
        *preprocess.transforms[-2:],
    ]
)

# Validation / test: the unmodified CLIP preprocessing pipeline.
test_transform = preprocess

# ---------------- DATASET ----------------
class SequenceDataset(Dataset):
    """Fixed-length image sequences built from labelled directories.

    The image files of each directory are sorted by name and cut into
    consecutive, non-overlapping chunks of ``seq_len`` files. Each chunk
    becomes one sample carrying the directory's label; files left over after
    the last full chunk are dropped.

    Args:
        datasets: Iterable of ``(image_dir, _, label)`` tuples; the middle
            element is ignored.
        transform: Optional callable applied to every loaded RGB image. The
            frames are stacked with ``torch.stack``, so it must return
            tensors (the fallback frame for unreadable files is a
            ``3x224x224`` zero tensor, so the transform presumably yields
            that shape -- confirm).
        seq_len: Number of frames per sample.
        debug: When true, print the class distribution after scanning.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """

    def __init__(self, datasets, transform=None, seq_len=8, debug=False):
        if seq_len < 1:
            # Previously surfaced as an obscure "range() arg 3 must not be zero".
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")

        self.samples = []
        self.transform = transform
        self.seq_len = seq_len
        self.debug = debug

        real_count = 0
        fake_count = 0

        for img_dir, _, label in datasets:
            # isdir rather than exists: a regular file at this path would make
            # os.listdir raise instead of being reported and skipped.
            if not os.path.isdir(img_dir):
                print(f"⚠️ Missing directory: {img_dir}")
                continue

            img_files = sorted([
                f for f in os.listdir(img_dir)
                if os.path.isfile(os.path.join(img_dir, f)) and f.lower().endswith(('.jpg', '.png', '.jpeg'))
            ])

            if len(img_files) < seq_len:
                print(f"⚠️ Not enough images in {img_dir}: {len(img_files)} < {seq_len}")
                continue

            # Step by seq_len so chunks never overlap; the range bound drops a
            # trailing partial chunk. Every name was just verified with isfile,
            # so no second existence check is needed here.
            for i in range(0, len(img_files) - seq_len + 1, seq_len):
                seq_paths = [os.path.join(img_dir, name) for name in img_files[i:i + seq_len]]
                self.samples.append((seq_paths, label))
                # Label 0 is tallied as "real"; every other label as "fake".
                if label == 0:
                    real_count += 1
                else:
                    fake_count += 1

        if self.debug:
            print(f"\n📊 Dataset Distribution:")
            print(f"   Real sequences: {real_count}")
            print(f"   Fake sequences: {fake_count}")
            print(f"   Total sequences: {len(self.samples)}")
            if real_count > 0 and fake_count > 0:
                ratio = max(real_count, fake_count) / min(real_count, fake_count)
                print(f"   Class imbalance ratio: {ratio:.2f}:1")
                if ratio > 3:
                    print(f"   ⚠️ WARNING: Severe class imbalance detected!")

    def __len__(self):
        """Return the number of sequences found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sequence ``idx``.

        Returns:
            ``(frames, label)`` where ``frames`` is the stack of the
            sequence's transformed images and ``label`` is a ``torch.long``
            scalar tensor.
        """
        frame_paths, label = self.samples[idx]
        frames = []
        for p in frame_paths:
            try:
                # The context manager closes the file handle as soon as the
                # pixels have been decoded by convert(); without it handles
                # stay open until garbage collection.
                with Image.open(p) as raw:
                    img = raw.convert("RGB")
                if self.transform:
                    img = self.transform(img)
                frames.append(img)
            except Exception as e:
                # Deliberate best effort: an unreadable frame is replaced by a
                # black frame so one bad file does not abort the whole epoch.
                print(f"Error loading {p}: {e}")
                frames.append(torch.zeros(3, 224, 224))

        frames = torch.stack(frames)
        return frames, torch.tensor(label, dtype=torch.long)

# ---------------- MODEL ----------------
class CLIP_EfficientNet_LSTM(nn.Module):
    """Sequence classifier on top of two frozen image backbones.

    Every frame is encoded by the CLIP visual tower and by the feature
    extractor of EfficientNet-B3. The two feature vectors are concatenated per
    frame, run through an LSTM over the time axis, pooled over time
    (mean + max) and classified by a small MLP head.

    Only the LSTM and the head are trainable: the backbone parameters have
    ``requires_grad=False`` and the backbones are kept in eval mode at all
    times (see :meth:`train`).

    Args:
        clip_model: Object exposing a ``visual`` ``nn.Module``. Its output is
            assumed to have 512 features per image -- TODO confirm when a
            different CLIP variant is used.
        num_classes: Number of output logits.
        hidden_dim: LSTM hidden size per direction.
        lstm_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM is bidirectional.
        device: Device the backbones are moved to.
    """

    def __init__(self, clip_model, num_classes=2, hidden_dim=512, lstm_layers=2, bidirectional=True, device='cpu'):
        super(CLIP_EfficientNet_LSTM, self).__init__()
        self.device = device

        # Frozen CLIP image encoder.
        self.clip = clip_model.visual.to(device)
        for p in self.clip.parameters():
            p.requires_grad = False

        # Frozen EfficientNet-B3: drop the last two child modules of the
        # torchvision model (presumably its pooling and classifier stages) and
        # pool the remaining feature maps to one vector per image.
        effnet = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1)
        effnet = effnet.to(device).eval()
        self.efficientnet = nn.Sequential(*list(effnet.children())[:-2])
        self.efficientnet_avgpool = nn.AdaptiveAvgPool2d(1)
        for p in self.efficientnet.parameters():
            p.requires_grad = False

        clip_latent_dim = 512
        effnet_latent_dim = 1536
        fused_dim = clip_latent_dim + effnet_latent_dim

        self.lstm = nn.LSTM(
            input_size=fused_dim,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # Inter-layer dropout only exists with more than one layer;
            # passing it for a single layer just triggers a PyTorch warning.
            dropout=0.3 if lstm_layers > 1 else 0.0
        )
        lstm_out_dim = hidden_dim * (2 if bidirectional else 1)

        self.fc = nn.Sequential(
            nn.LayerNorm(lstm_out_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            nn.Linear(lstm_out_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        # Start with the frozen backbones in inference mode.
        self.clip.eval()
        self.efficientnet.eval()

    def train(self, mode=True):
        """Switch the trainable parts to ``mode`` while keeping the backbones in eval.

        ``nn.Module.train`` recurses into every submodule, so a plain
        ``model.train()`` would put the frozen backbones back into training
        mode. Their normalisation layers would then update running statistics
        and any dropout would become active, making the "frozen" features
        differ between training and validation. Re-applying ``eval()`` here
        prevents that.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super(CLIP_EfficientNet_LSTM, self).train(mode)
        self.clip.eval()
        self.efficientnet.eval()
        return self

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)``.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        B, T, C, H, W = x.shape
        # Fold time into the batch axis so both backbones see plain image
        # batches; reshape (unlike view) also accepts non-contiguous input.
        x = x.reshape(B * T, C, H, W)

        # The backbones are frozen, so no autograd graph is needed for them.
        with torch.no_grad():
            clip_features = self.clip(x)
            effnet_features = self.efficientnet_avgpool(self.efficientnet(x))
            effnet_features = effnet_features.view(B * T, -1)

        # Restore the time axis and concatenate both feature sets per frame.
        clip_features = clip_features.view(B, T, -1)
        effnet_features = effnet_features.view(B, T, -1)
        fused = torch.cat((clip_features, effnet_features), dim=2)

        lstm_out, _ = self.lstm(fused)
        # Pool over time: element-wise sum of the mean and the max.
        x_mean = torch.mean(lstm_out, dim=1)
        x_max, _ = torch.max(lstm_out, dim=1)
        x = x_mean + x_max

        return self.fc(x)

# ---------------- EARLY STOPPING ----------------
class EarlyStopping:
    """Signal when a monitored score has stopped improving.

    Call the instance once per epoch with the latest score. A score counts as
    an improvement only when it beats the best score seen so far by more than
    ``min_delta`` (higher is better for ``mode='max'``, lower otherwise).
    After ``patience`` consecutive calls without improvement the call returns
    ``True``.
    """

    def __init__(self, patience=4, min_delta=0.005, mode='max'):
        self.patience, self.min_delta, self.mode = patience, min_delta, mode
        self.best_score = None   # no score recorded yet
        self.counter = 0         # consecutive calls without improvement
        self.early_stop = False  # set once patience is exhausted; never reset

    def __call__(self, score):
        """Record ``score`` and return ``True`` when training should stop."""
        reference = self.best_score
        if reference is None:
            # The very first score only establishes the baseline.
            self.best_score = score
            return False

        beats_best = (
            score > (reference + self.min_delta)
            if self.mode == 'max'
            else score < (reference - self.min_delta)
        )

        if not beats_best:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return self.early_stop

        self.best_score = score
        self.counter = 0
        return False

# ---------------- LOAD DATA ----------------
print("\n📂 Loading datasets...")
train_dataset = SequenceDataset(train_dirs, transform=train_transform, seq_len=sequence_length, debug=True)
valid_dataset = SequenceDataset(valid_dirs, transform=test_transform, seq_len=sequence_length, debug=True)

if len(train_dataset) == 0:
    print("❌ ERROR: No training data found!")
    sys.exit(1)

# Weighted sampling: each sample is drawn with a weight inversely proportional
# to the size of its class, so both classes appear about equally often.
labels = [label for _, label in train_dataset.samples]
class_counts = [labels.count(0), labels.count(1)]

if class_counts[0] == 0 or class_counts[1] == 0:
    print("❌ ERROR: One class has no samples!")
    sys.exit(1)

total_samples = len(labels)
num_classes = 2
class_weights = [total_samples / (num_classes * count) for count in class_counts]
sample_weights = [class_weights[l] for l in labels]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Worker processes: on Windows new processes are spawned and re-import the main
# module. This file has no `if __name__ == "__main__":` guard (and may be run
# from a notebook), so worker start-up would fail or hang there; load in the
# main process instead. Other platforms keep 4 workers.
num_workers = 0 if os.name == 'nt' else 4
# Pinned host memory only helps host-to-GPU copies; on CPU it just warns.
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler,
                          num_workers=num_workers, pin_memory=torch.cuda.is_available())
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, pin_memory=torch.cuda.is_available())

print(f"\n✅ Dataset loaded successfully")
print(f"   Train batches: {len(train_loader)} | Val batches: {len(valid_loader)}")
print(f"   Num workers: {num_workers}")

# ---------------- TRAINING SETUP ----------------
print("\n🏗️ Building model...")
model = CLIP_EfficientNet_LSTM(clip_model_base, device=device).to(device)

# NOTE(review): the same class weights are used both by the sampler above and
# by the loss below, so class imbalance is compensated twice -- confirm that
# this is intended.
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=0.1)

# Only parameters that still require gradients are handed to the optimizer.
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=learning_rate,
    weight_decay=1e-4
)

# Halve the learning rate after 2 epochs without improvement of the monitored
# metric (mode='max': higher is better). The deprecated `verbose` argument is
# not passed; the current learning rate is logged every epoch instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=2
)

early_stopping = EarlyStopping(patience=patience, min_delta=min_delta, mode='max')

# Best-so-far bookkeeping; the initial weights serve as the fallback checkpoint.
best_val_acc = 0.0
best_epoch = 0
best_model_wts = copy.deepcopy(model.state_dict())

# Training log file
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = os.path.join(models_dir, f"training_log_{timestamp}.txt")

def log_print(message):
    """Print ``message`` and append it, plus a newline, to the training log.

    The log path is read from the module-level ``log_file``. The file is
    opened and closed on every call, so the log is complete up to the last
    message even if the run is interrupted.

    Args:
        message: Text to print and record.
    """
    print(message)
    # Explicit UTF-8: the messages contain emoji and box-drawing characters,
    # which the locale default encoding (typically a legacy code page on
    # Windows) cannot encode, raising UnicodeEncodeError on write.
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(message + '\n')

# Write the run configuration to the console and the log, one line per call.
for _config_line in (
    "\n🚀 Training Configuration:",
    f"   Device: {device}",
    f"   Max Epochs: {num_epochs}",
    f"   Batch Size: {batch_size}",
    f"   Learning Rate: {learning_rate}",
    f"   Patience: {patience}",
    f"   Log file: {log_file}",
    "=" * 80,
):
    log_print(_config_line)

# ---------------- TRAINING LOOP ----------------
try:
    # Each epoch runs one pass over the training loader followed by one pass
    # over the validation loader. The loop ends after num_epochs epochs, when
    # early_stopping fires, or when the user interrupts with Ctrl+C.
    for epoch in range(num_epochs):
        log_print(f"\n📍 EPOCH {epoch+1}/{num_epochs}")
        log_print("-" * 80)

        # ============ TRAINING ============
        model.train()
        total_loss, correct, total = 0, 0, 0
        # Per-class tallies indexed by label: index 0 is reported as "Real",
        # index 1 as "Fake" in the log lines below.
        class_correct = [0, 0]
        class_total = [0, 0]

        train_pbar = tqdm(train_loader, desc="Training", leave=False)

        for batch_idx, (frames, labels_) in enumerate(train_pbar):
            frames, labels_ = frames.to(device), labels_.to(device)

            optimizer.zero_grad()
            outputs = model(frames)
            loss = criterion(outputs, labels_)
            loss.backward()
            # Clip the total gradient norm to 5.0 before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            total_loss += loss.item()
            # Predicted class = index of the largest output along dim 1.
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels_).sum().item()
            total += labels_.size(0)

            # Update the per-class counters one sample at a time.
            for i in range(len(labels_)):
                label = labels_[i].item()
                class_total[label] += 1
                if preds[i] == labels_[i]:
                    class_correct[label] += 1

            # Show the loss of this batch and the running accuracy of the epoch.
            current_acc = 100 * correct / total
            train_pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{current_acc:.2f}%'
            })

        # Epoch-level training metrics; accuracies are percentages and the
        # loss is averaged over batches.
        train_acc = 100 * correct / total
        avg_loss = total_loss / len(train_loader)
        train_real_acc = 100 * class_correct[0] / class_total[0] if class_total[0] > 0 else 0
        train_fake_acc = 100 * class_correct[1] / class_total[1] if class_total[1] > 0 else 0

        # ============ VALIDATION ============
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        val_class_correct = [0, 0]
        val_class_total = [0, 0]
        # Softmax probability rows of every validation sample, used for the
        # average-confidence figure.
        all_probs = []

        with torch.no_grad():
            for frames, labels_ in tqdm(valid_loader, desc="Validating", leave=False):
                frames, labels_ = frames.to(device), labels_.to(device)
                outputs = model(frames)
                # Same criterion as in training (class weights and label smoothing included).
                loss = criterion(outputs, labels_)

                val_loss += loss.item()
                probs = torch.softmax(outputs, dim=1)
                _, preds = torch.max(outputs, 1)

                val_total += labels_.size(0)
                val_correct += (preds == labels_).sum().item()

                for i in range(len(labels_)):
                    label = labels_[i].item()
                    val_class_total[label] += 1
                    if preds[i] == labels_[i]:
                        val_class_correct[label] += 1

                all_probs.extend(probs.cpu().numpy())

        # Every ratio is guarded so an empty validation set yields 0 instead
        # of a ZeroDivisionError.
        val_acc = 100 * val_correct / val_total if val_total > 0 else 0
        avg_val_loss = val_loss / len(valid_loader) if len(valid_loader) > 0 else 0
        val_real_acc = 100 * val_class_correct[0] / val_class_total[0] if val_class_total[0] > 0 else 0
        val_fake_acc = 100 * val_class_correct[1] / val_class_total[1] if val_class_total[1] > 0 else 0
        # Mean of the highest class probability per sample.
        avg_confidence = np.mean([max(p) for p in all_probs]) if all_probs else 0

        # The learning-rate scheduler is driven by validation accuracy.
        scheduler.step(val_acc)
        current_lr = optimizer.param_groups[0]['lr']

        # ============ LOGGING ============
        log_print(f"\n📊 Epoch {epoch+1} Results:")
        log_print(f"   Train Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}%")
        log_print(f"     ├─ Real: {train_real_acc:.2f}% | Fake: {train_fake_acc:.2f}%")
        log_print(f"   Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        log_print(f"     ├─ Real: {val_real_acc:.2f}% | Fake: {val_fake_acc:.2f}%")
        log_print(f"     └─ Avg Confidence: {avg_confidence:.4f}")
        log_print(f"   Learning Rate: {current_lr:.2e}")

        # ============ SAVE BEST MODEL ============
        # Any strict increase in validation accuracy replaces the stored
        # weights (kept in memory; nothing is written to disk here). Note that
        # early stopping below additionally requires a gain above min_delta.
        is_best = val_acc > best_val_acc

        if is_best:
            improvement = val_acc - best_val_acc
            best_val_acc = val_acc
            best_epoch = epoch + 1
            best_model_wts = copy.deepcopy(model.state_dict())
            log_print(f"   ✨ NEW BEST! (+{improvement:.2f}%)")

        # ============ EARLY STOPPING ============
        if early_stopping(val_acc):
            log_print(f"\n⏹️ EARLY STOPPING TRIGGERED!")
            log_print(f"   No improvement for {patience} epochs")
            log_print(f"   Best Val Acc: {best_val_acc:.2f}% (Epoch {best_epoch})")
            break

        if early_stopping.counter > 0:
            log_print(f"   ⏳ Patience: {early_stopping.counter}/{patience}")

        # Flag epochs whose mean top-class probability is below 0.6.
        if avg_confidence < 0.6:
            log_print(f"   ⚠️ WARNING: Low average confidence ({avg_confidence:.2f})")

        log_print("=" * 80)
        sys.stdout.flush()

except KeyboardInterrupt:
    # Ctrl+C is not fatal: execution continues after this block, where the
    # best weights recorded so far are saved.
    log_print("\n⚠️ Training interrupted by user!")
    log_print("   Saving current best model...")

# ---------------- SAVE MODEL ----------------
# Restore the best weights seen during training and write them to a
# timestamped checkpoint together with the run metadata.
log_print("\n💾 Saving best model...")
model.load_state_dict(best_model_wts)
save_path = os.path.join(models_dir, f"clip_effnet_lstm_best_{timestamp}.pth")

# Architecture settings stored with the checkpoint; the numeric values are the
# literals used when this file builds the model with its default arguments.
checkpoint_config = dict(
    sequence_length=sequence_length,
    hidden_dim=512,
    lstm_layers=2,
    bidirectional=True,
)
checkpoint = dict(
    model_state_dict=model.state_dict(),
    best_val_acc=best_val_acc,
    best_epoch=best_epoch,
    config=checkpoint_config,
    timestamp=timestamp,
)
torch.save(checkpoint, save_path)

log_print(f"✅ Model saved to: {save_path}")
log_print(f"🏆 Best Validation Accuracy: {best_val_acc:.2f}% (Epoch {best_epoch})")

# ---------------- TESTING ----------------
# Optional smoke test: classify a handful of still images with the best model.
# isdir rather than exists, so a stray file with this name cannot crash os.listdir.
if os.path.isdir(test_image_dir):
    log_print(f"\n🧪 Testing on images from: {test_image_dir}")
    model.eval()
    classes = ["Real", "Fake"]
    max_test_images = 20  # only this many images (in name order) are evaluated

    img_files = sorted(f for f in os.listdir(test_image_dir)
                       if f.lower().endswith(('.jpg', '.jpeg', '.png')))

    if img_files:
        log_print(f"Found {len(img_files)} test images\n")
        if len(img_files) > max_test_images:
            # Make the truncation visible instead of silently skipping files.
            log_print(f"(only the first {max_test_images} are evaluated)\n")

        for img_file in img_files[:max_test_images]:
            try:
                img_path = os.path.join(test_image_dir, img_file)
                # Close the file handle as soon as the pixels are decoded.
                with Image.open(img_path) as raw_img:
                    img = raw_img.convert("RGB")
                img_tensor = test_transform(img)
                # The model expects a sequence, so the single image is
                # repeated sequence_length times and given a batch axis of 1.
                seq_tensor = img_tensor.unsqueeze(0).repeat(sequence_length, 1, 1, 1).unsqueeze(0).to(device)

                with torch.no_grad():
                    outputs = model(seq_tensor)
                    probs = torch.softmax(outputs, dim=1)
                    _, preds = torch.max(outputs, 1)
                    confidence = probs[0, preds.item()].item()

                # Traffic-light marker for the probability of the predicted class.
                if confidence > 0.75:
                    conf_emoji = "🟢"
                elif confidence > 0.60:
                    conf_emoji = "🟡"
                else:
                    conf_emoji = "🔴"

                result = f"{conf_emoji} {img_file} → {classes[preds.item()]} (conf: {confidence:.4f})"
                log_print(result)

            except Exception as e:
                # Best effort: report the failing file and continue with the next one.
                log_print(f"Error processing {img_file}: {e}")
    else:
        log_print(f"⚠️ No test images found in: {test_image_dir}")

log_print("\n✅ Training complete!")
log_print(f"📄 Full log saved to: {log_file}")

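# Recorded notebook output (not code):
#   SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 2-3: truncated \UXXXXXXXX escape (2415137539.py, line 25)
# This error occurs when BASE_DIR is written as a plain string literal, because "C:\Users" then
# starts a \U escape sequence; the raw-string form r"..." used for BASE_DIR in this file avoids it.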