In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_max_pool
from torch_geometric.data import Data, DataLoader
import cv2
import numpy as np
import os
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import librosa
import warnings
import random
import glob
import json
from torch.optim.lr_scheduler import CosineAnnealingLR

warnings.filterwarnings('ignore')

class AudioVisualFeatureExtractor:
    """Extract features from audio and visual modalities - FIXED VERSION"""
    
    def __init__(self, device='cuda'):
        self.device = device
        self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        self.scaler_video = StandardScaler()
        self.scaler_audio = StandardScaler()
        self.is_fitted = False
        
    def extract_video_features(self, video_path, max_frames=30):
        """Extract facial and optical flow features from video - FIXED"""
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Warning: Could not open video {video_path}")
            return np.zeros((max_frames, 1024))
        
        facial_features = []
        flow_features = []
        prev_frame = None
        frame_count = 0
        
        while frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
                
            frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            
            faces = self.face_cascade.detectMultiScale(frame_gray, 1.1, 4)
            if len(faces) > 0:
                x, y, w, h = faces[0]
                face_roi = frame_gray[y:y+h, x:x+w]
                if face_roi.size > 0:
                    face_resized = cv2.resize(face_roi, (32, 32))
                    face_features = face_resized.flatten()
                    if len(face_features) < 512:
                        face_features = np.pad(face_features, (0, 512 - len(face_features)), 'constant')
                    else:
                        face_features = face_features[:512]
                    facial_features.append(face_features)
                else:
                    facial_features.append(np.zeros(512))
            else:
                facial_features.append(np.zeros(512))
            
            if prev_frame is not None:
                flow = cv2.calcOpticalFlowFarneback(prev_frame, frame_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
                flow_flat = flow.flatten()
                if len(flow_flat) >= 512:
                    flow_features.append(flow_flat[:512])
                else:
                    padded_flow = np.pad(flow_flat, (0, 512 - len(flow_flat)), 'constant')
                    flow_features.append(padded_flow)
            else:
                flow_features.append(np.zeros(512))
                
            prev_frame = frame_gray
            frame_count += 1
        
        cap.release()
        
        if len(facial_features) == 0:
            facial_features = [np.zeros(512)]
        if len(flow_features) == 0:
            flow_features = [np.zeros(512)]
            
        facial_features = np.array(facial_features)
        flow_features = np.array(flow_features)
        
        if len(facial_features) < max_frames:
            pad_shape = ((0, max_frames - len(facial_features)), (0, 0))
            facial_features = np.pad(facial_features, pad_shape, 'constant')
            flow_features = np.pad(flow_features, pad_shape, 'constant')
        else:
            facial_features = facial_features[:max_frames]
            flow_features = flow_features[:max_frames]
        
        video_features = np.concatenate([facial_features, flow_features], axis=1)
        return video_features
    
    def extract_audio_features(self, audio_path, max_length=5):
        """Extract audio features - FIXED with debugging"""
        try:
            if not os.path.exists(audio_path):
                print(f"File not found: {audio_path}")
                return np.zeros(28)
            
            if audio_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
                temp_audio = os.path.join(os.environ.get('TEMP', os.getcwd()), "temp_audio.wav")
                if os.name == 'nt':
                    null_device = 'NUL'
                else:
                    null_device = '/dev/null'
                cmd = f'ffmpeg -i "{audio_path}" -vn -acodec pcm_s16le -ar 16000 -ac 1 "{temp_audio}" -y'
                result = os.system(f'{cmd} > {null_device} 2>&1')
                if result != 0:
                    print(f"ffmpeg failed with exit code {result} for {audio_path}. Command: {cmd}")
                    return np.zeros(28)
                if not os.path.exists(temp_audio):
                    print(f"Temp file {temp_audio} not created for {audio_path}")
                    return np.zeros(28)
                y, sr = librosa.load(temp_audio, sr=16000, duration=max_length)
                os.remove(temp_audio)
            else:
                y, sr = librosa.load(audio_path, sr=16000, duration=max_length)
            
            if len(y) < sr * 0.5:
                y = np.pad(y, (0, int(sr * 0.5) - len(y)), 'constant')
            
            if random.random() < 0.3:
                noise = np.random.normal(0, 0.005, len(y))
                y = y + noise
            
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
            spectral_centroids = librosa.feature.spectral_centroid(y=y, sr=sr)
            spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            
            features = np.hstack([
                np.mean(mfccs, axis=1),
                np.mean(spectral_centroids),
                np.mean(spectral_rolloff),
                np.mean(chroma, axis=1)
            ])
            
            return features
        except Exception as e:
            print(f"Error processing audio {audio_path}: {e}")
            return np.zeros(28)
    
    def fit_scalers(self, video_audio_pairs):
        """Fit scalers on entire dataset - FIXED"""
        all_video_features = []
        all_audio_features = []
        
        for i, (video_path, audio_path) in enumerate(video_audio_pairs[:100]):
            try:
                video_feat = self.extract_video_features(video_path)
                audio_feat = self.extract_audio_features(audio_path)
                all_video_features.append(video_feat.reshape(-1, video_feat.shape[-1]))
                all_audio_features.append(audio_feat.reshape(1, -1))
            except Exception as e:
                continue
        
        if all_video_features and all_audio_features:
            all_video_features = np.vstack(all_video_features)
            all_audio_features = np.vstack(all_audio_features)
            self.scaler_video.fit(all_video_features)
            self.scaler_audio.fit(all_audio_features)
            self.is_fitted = True
            print("✅ Scalers fitted successfully")
        else:
            print("⚠️ Could not fit scalers - using identity scaling")
            self.is_fitted = False

class GraphConstructor:
    """Construct graphs from audio-visual features - FIXED with enhanced validation"""
    
    def __init__(self, similarity_threshold=0.5):
        self.similarity_threshold = similarity_threshold
    
    def create_graph(self, video_features, audio_features):
        all_features = []
        node_types = []
        
        # Validate and debug video features
        if video_features.size == 0 or not np.any(video_features):
            print("Warning: video_features is empty or all zeros")
            video_features = np.zeros((1, 1024))
        for i, frame_feat in enumerate(video_features):
            if frame_feat.shape[0] != 1024:
                print(f"Warning: Invalid frame feature shape at index {i}: {frame_feat.shape}")
                frame_feat = np.pad(frame_feat, (0, 1024 - frame_feat.shape[0]), 'constant')
            all_features.append(frame_feat)
            node_types.append(0)
        
        # Validate and debug audio features
        if len(audio_features) > 0 and np.any(audio_features):
            if audio_features.shape[0] != 28:
                print(f"Warning: Invalid audio feature shape: {audio_features.shape}")
                audio_features = np.pad(audio_features, (0, 28 - audio_features.shape[0]), 'constant')
            audio_expanded = np.tile(audio_features, (1024 // len(audio_features)) + 1)[:1024]
            all_features.append(audio_expanded)
            node_types.append(1)
        
        if len(all_features) < 2:
            print("Warning: Insufficient features, using dummy data")
            all_features = [np.random.randn(1024) * 0.01, np.random.randn(1024) * 0.01]
            node_types = [0, 1]
        
        all_features = np.array(all_features)
        if all_features.shape[1] != 1024:
            print(f"Warning: all_features has invalid feature dimension: {all_features.shape}")
            all_features = np.pad(all_features, ((0, 0), (0, 1024 - all_features.shape[1])), 'constant')
        print(f"all_features shape: {all_features.shape}")
        n_nodes = len(all_features)
        
        edge_index = []
        edge_attr = []
        
        for i in range(n_nodes):
            for j in range(i + 1, n_nodes):
                similarity = self.compute_similarity(all_features[i], all_features[j])
                if similarity > self.similarity_threshold or node_types[i] != node_types[j]:
                    edge_index.extend([[i, j], [j, i]])
                    edge_attr.extend([similarity, similarity])
        
        if len(edge_index) == 0:
            print("Warning: No edges created, adding default edges")
            for i in range(1, n_nodes):
                edge_index.extend([[0, i], [i, 0]])
                edge_attr.extend([0.3, 0.3])
        
        # Validate edge_index
        edge_index = np.array(edge_index)
        if edge_index.size > 0 and (edge_index.max() >= n_nodes or edge_index.min() < 0):
            print(f"Warning: Invalid edge_index detected. Max: {edge_index.max()}, Nodes: {n_nodes}")
            edge_index = edge_index[edge_index < n_nodes]
        
        x = torch.FloatTensor(all_features)
        edge_index = torch.LongTensor(edge_index).t().contiguous() if edge_index.size > 0 else torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.FloatTensor(edge_attr) if edge_attr else torch.zeros(0)
        node_types = torch.LongTensor(node_types)
        batch = torch.zeros(n_nodes, dtype=torch.long)
        
        print(f"edge_index shape: {edge_index.shape}, x shape: {x.shape}, node_types shape: {node_types.shape}")
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, node_types=node_types, batch=batch)

    def compute_similarity(self, feat1, feat2):
        norm1 = np.linalg.norm(feat1)
        norm2 = np.linalg.norm(feat2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return np.dot(feat1, feat2) / (norm1 * norm2)

class MultiModalGNN(nn.Module):
    """Graph Neural Network for multimodal deepfake detection"""
    
    def __init__(self, input_dim=1024, hidden_dim=256, num_classes=2, dropout=0.4):
        super(MultiModalGNN, self).__init__()
        self.node_type_embedding = nn.Embedding(2, 64)
        self.video_proj = nn.Linear(input_dim, hidden_dim)
        self.audio_proj = nn.Linear(input_dim, hidden_dim)
        self.gat_video = GATConv(hidden_dim + 64, hidden_dim, heads=4, dropout=dropout, concat=False)
        self.gat_audio = GATConv(hidden_dim + 64, hidden_dim, heads=4, dropout=dropout, concat=False)
        self.gcn = GCNConv(hidden_dim, hidden_dim // 2)
        self.proj = nn.Linear(hidden_dim // 2, hidden_dim // 2)
        self.cross_attention = nn.MultiheadAttention(hidden_dim // 2, num_heads=4, dropout=dropout, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )
        
        # Validate initial parameters
        for name, param in self.named_parameters():
            if torch.isnan(param).any() or torch.isinf(param).any():
                print(f"Warning: Invalid parameter detected in {name}: {param}")
    
    def forward(self, data):
        try:
            x, edge_index, batch, node_types = data.x, data.edge_index, data.batch, data.node_types
            print(f"Forward pass - x shape: {x.shape}, edge_index shape: {edge_index.shape}, node_types shape: {node_types.shape}")
            
            # Validate node_types
            if not torch.all((node_types >= 0) & (node_types < 2)):
                print(f"Warning: Invalid node_types values: {node_types}")
                node_types = torch.clamp(node_types, 0, 1)
            
            type_emb = self.node_type_embedding(node_types)
            
            video_mask = node_types == 0
            audio_mask = node_types == 1
            
            if video_mask.sum() > x.size(0) or audio_mask.sum() > x.size(0):
                print(f"Warning: Mask size exceeds x size - video_mask: {video_mask.sum()}, audio_mask: {audio_mask.sum()}, x size: {x.size(0)}")
                video_mask = video_mask[:x.size(0)]
                audio_mask = audio_mask[:x.size(0)]
            
            x_projected = torch.zeros(x.size(0), self.video_proj.out_features, device=x.device)
            
            if video_mask.any():
                x_projected[video_mask] = F.relu(self.video_proj(x[video_mask]))
            if audio_mask.any():
                x_projected[audio_mask] = F.relu(self.audio_proj(x[audio_mask]))
            
            x_with_type = torch.cat([x_projected, type_emb], dim=1)
            
            x_gat = torch.zeros(x.size(0), self.gat_video.out_channels, device=x.device)
            
            if video_mask.any():
                x_gat[video_mask] = F.relu(self.gat_video(x_with_type[video_mask], edge_index))
            if audio_mask.any():
                x_gat[audio_mask] = F.relu(self.gat_audio(x_with_type[audio_mask], edge_index))
            
            x = F.relu(self.gcn(x_gat, edge_index))
            x = self.proj(x)
            
            num_graphs = batch.max().item() + 1
            total_nodes = x.size(0)
            expected_feature_dim = 128
            num_nodes_per_graph = (total_nodes + num_graphs - 1) // num_graphs
            
            padding_size = (num_graphs * num_nodes_per_graph) - total_nodes
            if padding_size > 0:
                padding = torch.zeros(padding_size, expected_feature_dim, device=x.device)
                x = torch.cat([x, padding], dim=0)
            
            total_nodes = x.size(0)
            num_nodes = total_nodes // num_graphs
            
            x_reshaped = x.view(num_graphs, num_nodes, expected_feature_dim).transpose(0, 1)
            
            attn_output, _ = self.cross_attention(x_reshaped, x_reshaped, x_reshaped)
            
            x = attn_output.transpose(0, 1).reshape(-1, expected_feature_dim)
            
            if padding_size > 0:
                batch_padded = torch.cat([batch, torch.full((padding_size,), -1, device=batch.device)])
                mask = batch_padded >= 0
                x_masked = x[mask]
                batch_masked = batch_padded[mask]
                graph_repr = torch.cat([global_mean_pool(x_masked, batch_masked), global_max_pool(x_masked, batch_masked)], dim=1)
            else:
                graph_repr = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch)], dim=1)
            
            return self.classifier(graph_repr)
        except Exception as e:
            print(f"Error in forward pass: {e}")
            raise

class FocalLoss(nn.Module):
    """Focal loss for imbalanced datasets"""
    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
    
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return loss.mean()

def extract_label_from_entry(entry, video_path):
    """Extract label from metadata entry with multiple strategies"""
    if 'n_fakes' in entry:
        n_fakes = entry['n_fakes']
        if isinstance(n_fakes, int):
            return 1 if n_fakes > 0 else 0
    
    if 'fake_periods' in entry:
        fake_periods = entry['fake_periods']
        if isinstance(fake_periods, list):
            return 1 if len(fake_periods) > 0 else 0
    
    if 'label' in entry:
        label = entry['label']
        if isinstance(label, str):
            return 1 if label.lower() in ['fake', 'deepfake', '1'] else 0
        elif isinstance(label, int):
            return label
    
    if 'is_fake' in entry:
        return 1 if entry['is_fake'] else 0
    
    if 'original' in entry:
        return 0 if entry['original'] else 1
    
    print(f"⚠️ No valid label field found for {video_path}")
    return None

def create_filename_mapping(metadata):
    """Create a mapping of all possible filename variations to metadata entries"""
    filename_map = {}
    
    for entry in metadata:
        if not isinstance(entry, dict) or 'file' not in entry:
            continue
            
        metadata_file = entry['file']
        variations = [
            metadata_file,
            os.path.basename(metadata_file),
            os.path.splitext(os.path.basename(metadata_file))[0],
            metadata_file.replace('\\', '/'),
            metadata_file.replace('/', '\\'),
        ]
        
        for variation in variations:
            if variation not in filename_map:
                filename_map[variation] = []
            filename_map[variation].append(entry)
    
    return filename_map

def find_label_with_mapping(video_path, filename_map, dataset_path):
    """Find label using the filename mapping with multiple strategies"""
    video_rel_path = os.path.relpath(video_path, dataset_path)
    video_name = os.path.basename(video_path)
    video_name_no_ext = os.path.splitext(video_name)[0]
    
    search_keys = [
        video_rel_path,
        video_rel_path.replace('\\', '/'),
        video_rel_path.replace('/', '\\'),
        video_name,
        video_name_no_ext,
    ]
    
    for key in search_keys:
        if key in filename_map:
            entry = filename_map[key][0]
            label = extract_label_from_entry(entry, video_path)
            if label is not None:
                return label
    
    return None

def load_lavdf_dataset_improved(dataset_path, use_subset='train', max_clips=1000):
    """Improved dataset loading with better debugging and matching"""
    metadata_path = os.path.join(dataset_path, 'metadata.json')
    if not os.path.exists(metadata_path):
        print("❌ metadata.json not found!")
        return [], []
    
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    
    filename_map = create_filename_mapping(metadata)
    
    subset_folder = os.path.join(dataset_path, use_subset)
    if not os.path.exists(subset_folder):
        print(f"❌ Subset folder '{use_subset}' not found")
        return [], []
    
    video_extensions = ['*.mp4', '*.avi', '*.mov', '*.mkv']
    all_videos = []
    for ext in video_extensions:
        all_videos.extend(glob.glob(os.path.join(subset_folder, '**', ext), recursive=True))
    
    if len(all_videos) > max_clips:
        random.shuffle(all_videos)
        all_videos = all_videos[:max_clips]
    
    video_audio_pairs = []
    labels = []
    matched_count = 0
    unmatched_files = []
    
    for video_path in all_videos:
        label = find_label_with_mapping(video_path, filename_map, dataset_path)
        
        if label is not None:
            video_audio_pairs.append((video_path, video_path))
            labels.append(label)
            matched_count += 1
        else:
            unmatched_files.append(video_path)
    
    print(f"✅ Successfully matched: {matched_count}/{len(all_videos)} files")
    if labels:
        real_count = labels.count(0)
        fake_count = labels.count(1)
        total = real_count + fake_count
        print(f"📊 Label distribution: Real={real_count} ({real_count/total*100:.1f}%), Fake={fake_count} ({fake_count/total*100:.1f}%)")
    
    return video_audio_pairs, labels

def load_all_subsets(dataset_path, max_clips=3000):
    """Load all subsets and combine - targeting 3000 total with fair distribution"""
    all_pairs = []
    all_labels = []
    clips_per_subset = max_clips // 3
    
    for subset in ['train', 'test', 'dev']:
        subset_path = os.path.join(dataset_path, subset)
        if os.path.exists(subset_path):
            pairs, labels = load_lavdf_dataset_improved(dataset_path, use_subset=subset, max_clips=clips_per_subset)
            if pairs and labels:
                all_pairs.extend(pairs)
                all_labels.extend(labels)
    
    if len(all_pairs) > max_clips:
        all_pairs = all_pairs[:max_clips]
        all_labels = all_labels[:max_clips]
    
    return all_pairs, all_labels

class DeepfakeDetector:
    """Main detector class"""
    
    def __init__(self, device='cuda'):
        self.device = device
        self.feature_extractor = AudioVisualFeatureExtractor(device)
        self.graph_constructor = GraphConstructor(similarity_threshold=0.7)
        try:
            self.model = MultiModalGNN()
            # Move to device with validation
            self.model = self.model.to(device)
            print(f"Model successfully moved to {device}")
        except Exception as e:
            print(f"Error initializing model on {device}: {e}")
            raise

    def prepare_data(self, video_audio_pairs, labels, feature_dir=None):
        """Prepare graph data from video-audio pairs or precomputed features"""
        graphs = []
        for i, ((video_path, audio_path), label) in enumerate(zip(video_audio_pairs, labels)):
            try:
                if feature_dir and os.path.exists(os.path.join(feature_dir, f"video_{i}.npy")):
                    video_features, audio_features = self.feature_extractor.load_features(feature_dir, i)
                else:
                    video_features = self.feature_extractor.extract_video_features(video_path)
                    audio_features = self.feature_extractor.extract_audio_features(audio_path)
                
                graph = self.graph_constructor.create_graph(video_features, audio_features)
                graph.y = torch.tensor([label], dtype=torch.long)
                graphs.append(graph)
            except Exception as e:
                print(f"Error processing {video_path}: {e}")
                continue
        return graphs
    
    def train(self, train_graphs, val_graphs, epochs=50, batch_size=4, accum_steps=2):
        """Train the model with gradient accumulation"""
        train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)
        
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=1e-3)
        criterion = FocalLoss(gamma=2.0, alpha=0.25)
        scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
        best_val_acc = 0.0
        
        for epoch in range(epochs):
            self.model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0
            
            for i, batch in enumerate(train_loader):
                try:
                    batch = batch.to(self.device)
                    optimizer.zero_grad()
                    outputs = self.model(batch)
                    loss = criterion(outputs, batch.y) / accum_steps
                    loss.backward()
                    
                    if (i + 1) % accum_steps == 0:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                        optimizer.step()
                        optimizer.zero_grad()
                    
                    train_loss += loss.item() * accum_steps
                    _, predicted = torch.max(outputs.data, 1)
                    train_total += batch.y.size(0)
                    train_correct += (predicted == batch.y).sum().item()
                except Exception as e:
                    print(f"Error in training batch {i}: {e}")
                    continue
            
            train_acc = 100 * train_correct / train_total
            
            self.model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0
            
            with torch.no_grad():
                for batch in val_loader:
                    try:
                        batch = batch.to(self.device)
                        outputs = self.model(batch)
                        loss = criterion(outputs, batch.y)
                        val_loss += loss.item()
                        _, predicted = torch.max(outputs.data, 1)
                        val_total += batch.y.size(0)
                        val_correct += (predicted == batch.y).sum().item()
                    except Exception as e:
                        print(f"Error in validation batch: {e}")
                        continue
            
            val_acc = 100 * val_correct / val_total
            scheduler.step()
            
            print(f'Epoch [{epoch+1}/{epochs}]')
            print(f'Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%')
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(self.model.state_dict(), 'best_model.pth')
                print(f'💾 New best model saved with accuracy: {val_acc:.2f}%')
            if val_acc >= 80.0:
                print(f"🎯 Target accuracy reached! Final accuracy: {val_acc:.2f}%")
                break
        
        return best_val_acc
    
    def evaluate(self, test_graphs):
        """Evaluate the model"""
        self.model.load_state_dict(torch.load('best_model.pth'))
        self.model.eval()
        test_loader = DataLoader(test_graphs, batch_size=4, shuffle=False)
        
        all_predictions = []
        all_labels = []
        
        with torch.no_grad():
            for batch in test_loader:
                try:
                    batch = batch.to(self.device)
                    outputs = self.model(batch)
                    _, predicted = torch.max(outputs.data, 1)
                    all_predictions.extend(predicted.cpu().numpy())
                    all_labels.extend(batch.y.cpu().numpy())
                except Exception as e:
                    print(f"Error in evaluation batch: {e}")
                    continue
        
        accuracy = accuracy_score(all_labels, all_predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predictions, average='binary')
        
        print(f"📊 Test Results: Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
        return accuracy, precision, recall, f1

def debug_metadata_structure(dataset_path):
    """Debug the metadata structure"""
    metadata_path = os.path.join(dataset_path, 'metadata.json')
    if not os.path.exists(metadata_path):
        print("❌ metadata.json not found!")
        return None
    
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    
    if metadata:
        sample_keys = set()
        for entry in metadata[:10]:
            if isinstance(entry, dict):
                sample_keys.update(entry.keys())
        print(f"📊 Metadata entries: {len(metadata)}, Keys: {sample_keys}")
    
    return metadata

def main():
    print("🚀 DEEPFAKE DETECTOR - FIXED VERSION")
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"🖥️ Using device: {device}")
    
    dataset_path = r"C:\archive\LAV-DF"
    feature_dir = r"C:\Users\ARNAV\features"  
    if not os.path.exists(dataset_path):
        print(f"Dataset path does not exist: {dataset_path}")
        return
    if not os.path.exists(feature_dir):
        os.makedirs(feature_dir, exist_ok=True)
    
    print("\n📂 LOADING DATASET...")
    video_audio_pairs, labels = load_all_subsets(dataset_path, max_clips=3000)
    
    if len(video_audio_pairs) == 0:
        print("❌ No valid data loaded. Please check your dataset path and metadata.")
        return
        
    detector = DeepfakeDetector(device=device)
    detector.feature_extractor.fit_scalers(video_audio_pairs)
    real_count = labels.count(0)
    fake_count = labels.count(1)
    print(f"\n📊 FINAL DATASET STATISTICS: Total samples: {len(video_audio_pairs)}, Real: {real_count} ({real_count/(real_count+fake_count)*100:.1f}%), Fake: {fake_count} ({fake_count/(real_count+fake_count)*100:.1f}%)")
    
    if real_count == 0 or fake_count == 0:
        print("❌ Dataset is severely imbalanced. Check metadata parsing.")
        return
    
    print(f"\n⚙️ PRECOMPUTING FEATURES...")
    if not os.path.exists(feature_dir):
        detector.feature_extractor.save_features(video_audio_pairs, feature_dir)
    
    graphs = detector.prepare_data(video_audio_pairs, labels, feature_dir=feature_dir)
    if len(graphs) < 10:
        print(f"⚠️ Only {len(graphs)} graphs created")
        return
    
    train_graphs, temp_graphs = train_test_split(graphs, test_size=0.3, random_state=42)
    val_graphs, test_graphs = train_test_split(temp_graphs, test_size=0.5, random_state=42)
    print(f"Train: {len(train_graphs)}, Val: {len(val_graphs)}, Test: {len(test_graphs)}")
    
    best_acc = detector.train(train_graphs, val_graphs, epochs=50, batch_size=4, accum_steps=2)
    print(f"Best validation accuracy: {best_acc:.2f}%")
    
    if test_graphs:
        accuracy, precision, recall, f1 = detector.evaluate(test_graphs)
        print(f"Final Test Accuracy: {accuracy:.4f}")

if __name__ == "__main__":
    main()

🚀 DEEPFAKE DETECTOR - FIXED VERSION
🖥️ Using device: cuda

📂 LOADING DATASET...
✅ Successfully matched: 1000/1000 files
📊 Label distribution: Real=252 (25.2%), Fake=748 (74.8%)
✅ Successfully matched: 1000/1000 files
📊 Label distribution: Real=273 (27.3%), Fake=727 (72.7%)
✅ Successfully matched: 1000/1000 files
📊 Label distribution: Real=258 (25.8%), Fake=742 (74.2%)
Error initializing model on cuda: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
print('Done')