# In [None]:
import glob
import json
import os
import random
import subprocess
import tempfile
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_max_pool
from torch_geometric.data import Data, DataLoader, Batch
from torch_geometric.utils import to_dense_batch
import cv2
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import librosa

# Pick the compute device once at import time; the module-level `device` is
# passed explicitly to the detector in main().
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    # Debugging aids: synchronous kernel launches and device-side assertions,
    # so CUDA errors are reported at the failing call.
    # NOTE(review): these are assigned after torch has been imported and has
    # already queried CUDA, so they may not affect the current process —
    # confirm, or export them in the environment before launching Python.
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["TORCH_USE_CUDA_DSA"] = "1"
else:
    # Hide GPUs from anything that reads this variable afterwards (e.g. child
    # processes inherit os.environ).
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Globally silence every warning (this also hides library deprecation notices).
warnings.filterwarnings('ignore')

def custom_collate(batch):
    """Merge a list of PyG ``Data`` graphs into one ``Batch`` object."""
    graph_list = batch
    merged = Batch.from_data_list(graph_list)
    return merged
    
def validate_edge_index(data: "Data") -> "Data":
    """Drop edges whose endpoints are not valid node indices of ``data``.

    An edge is kept only when both endpoints lie in ``[0, data.x.size(0))``.
    Negative indices are dropped too: they would otherwise wrap around and
    silently address the last nodes. When ``edge_attr`` exists and has one
    row per edge, it is filtered with the same mask.

    Args:
        data: graph object with ``x`` and ``edge_index`` (shape ``(2, E)``)
            attributes and an optional ``edge_attr``.

    Returns:
        The same object, modified in place when any edge was removed.
    """
    edge_index = getattr(data, 'edge_index', None)
    if edge_index is None or edge_index.numel() == 0:
        return data
    num_nodes = data.x.size(0)
    keep = ((edge_index >= 0) & (edge_index < num_nodes)).all(dim=0)
    if not keep.all():
        data.edge_index = edge_index[:, keep]
        edge_attr = getattr(data, 'edge_attr', None)
        if edge_attr is not None and edge_attr.size(0) == keep.size(0):
            data.edge_attr = edge_attr[keep]
    return data

class AudioVisualFeatureExtractor:
    """Extract fixed-size visual and acoustic features from video/audio files.

    Video features: for each of up to ``max_frames`` frames, a 512-value face
    vector is concatenated with a 512-value optical-flow vector, giving an
    array of shape ``(max_frames, 1024)``.
      * Face vector: the first face found by the Haar cascade, resized to
        32x32 and truncated to 512 pixels.
      * Flow vector: the first 512 values of the Farneback flow field.

    Audio features: a ``audio_feature_dim`` (35) vector of time-averaged
    librosa descriptors: 20 MFCCs, spectral centroid, spectral roll-off,
    12 chroma bins and zero-crossing rate.

    After :meth:`fit_scalers` has run, both outputs are standardised with the
    fitted ``StandardScaler`` objects.

    Args:
        device: stored on the instance; not used by the extraction code.
        max_samples: cap on the number of (video, audio) pairs processed by
            :meth:`fit_scalers` and :meth:`save_features`.
    """

    FRAME_FEATURE_DIM = 512  # length of the per-frame face vector and of the flow vector
    AUDIO_SAMPLE_RATE = 16000
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')

    def __init__(self, device='cpu', max_samples=750):
        self.device = device
        self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        self.scaler_video = StandardScaler()
        self.scaler_audio = StandardScaler()
        self.is_fitted = False
        self.max_samples = max_samples
        self.audio_feature_dim = 35

    @staticmethod
    def _fit_length(vec, length):
        """Zero-pad (at the end) or truncate a 1-D array to exactly ``length`` entries."""
        if vec.shape[0] < length:
            return np.pad(vec, (0, length - vec.shape[0]), 'constant')
        return vec[:length]

    def _face_vector(self, frame_gray):
        """Return the pixel vector of the first detected face, or zeros when none is usable."""
        faces = self.face_cascade.detectMultiScale(frame_gray, 1.1, 4)
        if len(faces) == 0:
            return np.zeros(self.FRAME_FEATURE_DIM)
        x, y, w, h = faces[0]
        face_roi = frame_gray[y:y + h, x:x + w]
        if face_roi.size == 0:
            return np.zeros(self.FRAME_FEATURE_DIM)
        return self._fit_length(cv2.resize(face_roi, (32, 32)).flatten(), self.FRAME_FEATURE_DIM)

    def _flow_vector(self, prev_gray, frame_gray):
        """Return the Farneback flow from ``prev_gray`` to ``frame_gray``, or zeros for the first frame."""
        if prev_gray is None:
            return np.zeros(self.FRAME_FEATURE_DIM)
        flow = cv2.calcOpticalFlowFarneback(prev_gray, frame_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        return self._fit_length(flow.flatten(), self.FRAME_FEATURE_DIM)

    def extract_video_features(self, video_path, max_frames=30):
        """Return a ``(max_frames, 1024)`` array of per-frame face + flow features.

        A video that cannot be opened yields all zeros (not scaled). Shorter
        videos are zero-padded at the end; longer ones are cut at
        ``max_frames`` frames.
        """
        dim = self.FRAME_FEATURE_DIM
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return np.zeros((max_frames, 2 * dim))
            facial_features, flow_features = [], []
            prev_frame = None
            while len(facial_features) < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                facial_features.append(self._face_vector(frame_gray))
                flow_features.append(self._flow_vector(prev_frame, frame_gray))
                prev_frame = frame_gray
        finally:
            # Release the capture handle on every path, including the early return.
            cap.release()

        if not facial_features:
            facial_features = [np.zeros(dim)]
        if not flow_features:
            flow_features = [np.zeros(dim)]
        facial = np.array(facial_features)
        flow = np.array(flow_features)
        if len(facial) < max_frames:
            pad_rows = ((0, max_frames - len(facial)), (0, 0))
            facial = np.pad(facial, pad_rows, 'constant')
            flow = np.pad(flow, pad_rows, 'constant')
        else:
            facial = facial[:max_frames]
            flow = flow[:max_frames]
        video_features = np.concatenate([facial, flow], axis=1)
        if self.is_fitted:
            video_features = self.scaler_video.transform(video_features)
        return video_features

    def _load_audio_track(self, video_path, max_length):
        """Decode a container's audio track to a unique temp WAV with ffmpeg and load it.

        Returns ``(samples, sample_rate)``, or ``(None, None)`` when ffmpeg fails.
        The temp file is always removed afterwards.
        """
        # Unique per call, so concurrent or nested extractions cannot clobber each other's file.
        fd, temp_audio = tempfile.mkstemp(suffix='.wav')
        os.close(fd)
        try:
            cmd = ['ffmpeg', '-y', '-i', video_path, '-vn', '-acodec', 'pcm_s16le',
                   '-ar', str(self.AUDIO_SAMPLE_RATE), '-ac', '1', temp_audio]
            result = subprocess.run(cmd, stdin=subprocess.DEVNULL,
                                    stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            if result.returncode != 0 or not os.path.exists(temp_audio):
                return None, None
            return librosa.load(temp_audio, sr=self.AUDIO_SAMPLE_RATE, duration=max_length)
        finally:
            try:
                os.remove(temp_audio)
            except OSError:
                pass  # best-effort cleanup; a leftover temp file must not mask the result

    def extract_audio_features(self, audio_path, max_length=5, noise_prob=0.3):
        """Return a ``(audio_feature_dim,)`` vector of time-averaged audio descriptors.

        Args:
            audio_path: an audio file, or a video file (by extension) whose audio
                track is decoded with ffmpeg first.
            max_length: seconds of audio to analyse.
            noise_prob: probability of adding Gaussian noise (std 0.005) to the
                waveform as augmentation. NOTE: this draws from ``random`` on every
                call, including evaluation-time extraction; pass 0 for
                deterministic features.

        Returns:
            The feature vector, scaled if the scalers are fitted. An all-zero
            vector is returned when the file is missing, has no audio stream,
            is silent, ffprobe/ffmpeg fail, or any error occurs.
        """
        try:
            if not os.path.exists(audio_path):
                return np.zeros(self.audio_feature_dim)
            # Argument lists (no shell), so paths containing quotes or shell
            # metacharacters are passed through verbatim.
            probe = subprocess.run(
                ['ffprobe', '-i', audio_path, '-show_streams', '-select_streams', 'a', '-loglevel', 'error'],
                capture_output=True, text=True)
            if not probe.stdout:
                return np.zeros(self.audio_feature_dim)
            if audio_path.lower().endswith(self.VIDEO_EXTENSIONS):
                y, sr = self._load_audio_track(audio_path, max_length)
                if y is None:
                    return np.zeros(self.audio_feature_dim)
            else:
                y, sr = librosa.load(audio_path, sr=self.AUDIO_SAMPLE_RATE, duration=max_length)
            if len(y) == 0 or np.all(y == 0):
                return np.zeros(self.audio_feature_dim)
            if len(y) < sr * 0.5:
                # Guarantee at least half a second so every librosa feature has frames.
                y = np.pad(y, (0, int(sr * 0.5) - len(y)), 'constant')
            if random.random() < noise_prob:
                y = y + np.random.normal(0, 0.005, len(y))
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)
            spectral_centroids = librosa.feature.spectral_centroid(y=y, sr=sr)
            spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            zcr = librosa.feature.zero_crossing_rate(y)
            if mfccs.shape[1] == 0 or spectral_centroids.shape[1] == 0:
                return np.zeros(self.audio_feature_dim)
            features = np.hstack([
                np.mean(mfccs, axis=1),
                np.mean(spectral_centroids, axis=1),
                np.mean(spectral_rolloff, axis=1),
                np.mean(chroma, axis=1),
                np.mean(zcr, axis=1),
            ])
            features = self._fit_length(features, self.audio_feature_dim)
            if self.is_fitted:
                features = self.scaler_audio.transform(features.reshape(1, -1)).flatten()
            return features
        except Exception:
            # Deliberate best-effort: any decoding/feature failure degrades to a zero vector.
            return np.zeros(self.audio_feature_dim)

    def fit_scalers(self, video_audio_pairs):
        """Fit the video and audio StandardScalers on up to ``max_samples`` pairs.

        The video scaler is fitted on individual frame rows; the audio scaler
        uses one row per pair. Pairs whose extraction raises are skipped; if
        none succeed, the scalers stay unfitted.
        """
        video_rows, audio_rows = [], []
        for video_path, audio_path in video_audio_pairs[:self.max_samples]:
            try:
                video_feat = self.extract_video_features(video_path)
                audio_feat = self.extract_audio_features(audio_path)
            except Exception:
                continue
            video_rows.append(video_feat)
            audio_rows.append(audio_feat)
        if not video_rows:
            return
        # vstack stacks the per-pair frame matrices row-wise (same as reshape(-1, dim)),
        # and also tolerates pairs with different frame counts.
        self.scaler_video.fit(np.vstack(video_rows))
        self.scaler_audio.fit(np.vstack(audio_rows))
        self.is_fitted = True

    def save_features(self, video_audio_pairs, feature_dir):
        """Extract features for up to ``max_samples`` pairs and save them as .npy files.

        Files are written as ``video_{i}.npy`` and ``audio_{i}.npy`` in
        ``feature_dir``. Arrays of unexpected shape are replaced by zeros, and
        pairs whose extraction or saving raises are skipped.

        Returns:
            The number of pairs saved.
        """
        os.makedirs(feature_dir, exist_ok=True)
        expected_video_shape = (30, 2 * self.FRAME_FEATURE_DIM)
        saved_count = 0
        for i, (video_path, audio_path) in enumerate(video_audio_pairs[:self.max_samples]):
            try:
                video_features = self.extract_video_features(video_path)
                audio_features = self.extract_audio_features(audio_path)
                if video_features.shape != expected_video_shape:
                    video_features = np.zeros(expected_video_shape)
                if audio_features.shape != (self.audio_feature_dim,):
                    audio_features = np.zeros(self.audio_feature_dim)
                np.save(os.path.join(feature_dir, f"video_{i}.npy"), video_features)
                np.save(os.path.join(feature_dir, f"audio_{i}.npy"), audio_features)
                saved_count += 1
            except Exception:
                continue
        return saved_count

class GraphConstructor:
    """Turn per-frame video features and one audio vector into a PyG ``Data`` graph.

    Nodes: one per video frame (type 0), followed by a single audio node
    (type 1). If the video has no frames, two zero nodes (types 0 and 1) are
    appended so that there are at least two nodes.

    Edges (always added in both directions):
      * every pair of nodes of different type, with weight 0.5;
      * nodes of the same type whose cosine similarity exceeds
        ``similarity_threshold``, weighted by that similarity.

    If no edge survives, a single zero-weight edge pair between nodes 0 and 1
    is added.

    Args:
        similarity_threshold: minimum cosine similarity for a same-type edge.
        device: device the returned tensors are moved to.
        feature_dim: node feature length; the audio vector is tiled to it.
        audio_dim: expected audio vector length; other lengths (or all-zero
            vectors) yield a zero audio node.
    """

    def __init__(self, similarity_threshold=0.7, device='cpu', feature_dim=1024, audio_dim=35):
        self.similarity_threshold = similarity_threshold
        self.device = device
        self.feature_dim = feature_dim
        self.audio_dim = audio_dim

    def _finalise_edges(self, edge_index_list, edge_attr_list, n_nodes):
        """Convert edge/weight lists to tensors, dropping edges outside ``[0, n_nodes)``.

        Returns:
            ``(edge_index, edge_attr)``: a long tensor of shape ``(2, E)`` and a
            float32 tensor of shape ``(E,)``. Both are empty when no valid edge
            remains.

        Raises:
            ValueError: if the two lists have different lengths.
        """
        if len(edge_index_list) != len(edge_attr_list):
            raise ValueError("edge_index_list and edge_attr_list must have the same length")
        if not edge_index_list:
            return torch.zeros((2, 0), dtype=torch.long), torch.zeros(0)
        ei = np.asarray(edge_index_list, dtype=np.int64).reshape(-1, 2)
        ea = np.asarray(edge_attr_list, dtype=np.float32).reshape(-1)
        valid = ((ei >= 0) & (ei < n_nodes)).all(axis=1)
        ei, ea = ei[valid], ea[valid]
        if ei.size == 0:
            return torch.zeros((2, 0), dtype=torch.long), torch.zeros(0)
        return torch.from_numpy(ei).t().contiguous(), torch.from_numpy(ea)

    def create_graph(self, video_features: np.ndarray, audio_features: np.ndarray) -> "Data":
        """Build the audio-visual graph for one clip.

        Args:
            video_features: iterable of per-frame vectors, each of length
                ``feature_dim``.
            audio_features: 1-D audio descriptor vector.

        Returns:
            ``Data`` with ``x`` (float32), ``edge_index``, ``edge_attr``,
            ``node_types`` and an all-zero ``batch``, on ``self.device``.
        """
        dim = self.feature_dim
        all_features = list(video_features)
        node_types = [0] * len(all_features)
        if audio_features.size > 0 and np.any(audio_features) and audio_features.shape[0] == self.audio_dim:
            reps = dim // audio_features.size + 1
            all_features.append(np.tile(audio_features, reps)[:dim])
        else:
            all_features.append(np.zeros(dim))
        node_types.append(1)
        if len(all_features) < 2:
            all_features.extend([np.zeros(dim), np.zeros(dim)])
            node_types.extend([0, 1])
        all_features = np.asarray(all_features, dtype=np.float32)
        n_nodes = all_features.shape[0]

        # The pairwise loop already links every video frame with the audio node
        # (cross-type pairs). The original extra "audio edges" pass duplicated
        # each of those edges, double-counting them in message passing.
        edge_index_list, edge_attr_list = [], []
        for i in range(n_nodes):
            for j in range(i + 1, n_nodes):
                if node_types[i] != node_types[j]:
                    weight = 0.5
                else:
                    weight = self._cosine(all_features[i], all_features[j])
                    if weight <= self.similarity_threshold:
                        continue
                edge_index_list.extend([[i, j], [j, i]])
                edge_attr_list.extend([weight, weight])

        edge_index, edge_attr = self._finalise_edges(edge_index_list, edge_attr_list, n_nodes)
        if edge_index.numel() == 0 and n_nodes > 1:
            edge_index, edge_attr = self._finalise_edges([[0, 1], [1, 0]], [0.0, 0.0], n_nodes)

        # _finalise_edges has already removed out-of-range edges, so no further validation is needed.
        dev = torch.device(self.device)
        return Data(
            x=torch.from_numpy(all_features).to(dev),
            edge_index=edge_index.to(dev),
            edge_attr=edge_attr.to(dev),
            node_types=torch.tensor(node_types, dtype=torch.long).to(dev),
            batch=torch.zeros(n_nodes, dtype=torch.long).to(dev),
        )

    @staticmethod
    def _cosine(a, b):
        """Cosine similarity of two vectors; 0.0 when either has zero norm."""
        na, nb = np.linalg.norm(a), np.linalg.norm(b)
        return 0.0 if na == 0 or nb == 0 else float(np.dot(a, b) / (na * nb))

class MultiModalGNN(nn.Module):
    """Graph classifier for the audio-visual graphs produced by ``GraphConstructor``.

    Pipeline:
      1. Per-type linear projection (video / audio) plus a learned node-type embedding.
      2. Per-type GAT layer, then a shared GCN and a linear layer.
      3. Multi-head self-attention over the nodes of each graph.
      4. Concatenated mean/max pooling, then an MLP producing ``num_classes``
         logits per graph.

    Expects ``data`` with ``x`` of shape (num_nodes, input_dim), ``edge_index``,
    ``node_types`` (0 = video, 1 = audio) and, for batches, a sorted ``batch``
    vector (as produced by PyG batching).
    """

    def __init__(self, input_dim: int = 1024, hidden_dim: int = 512, num_classes: int = 2, dropout: float = 0.3, device: str = 'cpu'):
        super().__init__()
        self.device = torch.device(device)
        self.num_classes = num_classes
        self.node_type_embedding = nn.Embedding(2, 128).to(self.device)
        self.video_proj = nn.Linear(input_dim, hidden_dim).to(self.device)
        self.audio_proj = nn.Linear(input_dim, hidden_dim).to(self.device)
        self.gat_video = GATConv(hidden_dim + 128, hidden_dim, heads=8, dropout=dropout, concat=False).to(self.device)
        self.gat_audio = GATConv(hidden_dim + 128, hidden_dim, heads=8, dropout=dropout, concat=False).to(self.device)
        self.gcn = GCNConv(hidden_dim, hidden_dim // 2).to(self.device)
        self.proj = nn.Linear(hidden_dim // 2, hidden_dim // 2).to(self.device)
        self.attn = nn.MultiheadAttention(hidden_dim // 2, num_heads=8, dropout=dropout, batch_first=True).to(self.device)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        ).to(self.device)
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for nn.Linear/Conv layers (zero bias); N(0, 0.02) for embeddings."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(self, data: Data):
        """Return ``(num_graphs, num_classes)`` logits for one graph or a PyG batch.

        Edges whose endpoints fall outside ``[0, num_nodes)`` are ignored. An
        all-zero logits tensor (not attached to autograd) is returned in two
        malformed-input cases:
          * node types outside {0, 1};
          * attached labels outside ``[0, num_classes)``.
        """
        data = data.to(self.device)
        x, ei, node_types = data.x, data.edge_index, data.node_types
        batch = data.batch
        if batch is None:
            # Un-batched graph without a batch vector: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=self.device)
        num_graphs = int(batch.max().item()) + 1 if batch.numel() > 0 else 1
        fallback = torch.zeros((num_graphs, self.num_classes), device=self.device)

        if node_types.numel() == 0 or node_types.min() < 0 or node_types.max() > 1:
            return fallback
        # Check labels only when they are attached; unlabelled inference must still work.
        y = getattr(data, 'y', None)
        if y is not None and y.numel() > 0 and (y.min() < 0 or y.max() >= self.num_classes):
            return fallback
        if ei.numel() > 0:
            ei = ei[:, ((ei >= 0) & (ei < x.size(0))).all(dim=0)]

        video_mask = node_types == 0
        audio_mask = node_types == 1

        # Type-specific projections, then append the learned node-type embedding.
        x_proj = torch.zeros(x.size(0), self.video_proj.out_features, device=self.device)
        if video_mask.any():
            x_proj[video_mask] = F.relu(self.video_proj(x[video_mask]))
        if audio_mask.any():
            x_proj[audio_mask] = F.relu(self.audio_proj(x[audio_mask]))
        x = torch.cat([x_proj, self.node_type_embedding(node_types)], dim=1)

        # Each type-specific GAT runs over the full node set with global indices,
        # using only the edges that target its node type, and keeps only the rows
        # of that type. This avoids local re-indexing: the old audio remap sent
        # every video neighbour to audio node 0, mixing graphs within a batch.
        h = torch.zeros_like(x_proj)
        if video_mask.any():
            video_edges = ei[:, video_mask[ei[0]] & video_mask[ei[1]]]
            h[video_mask] = F.relu(self.gat_video(x, video_edges)[video_mask])
        if audio_mask.any():
            # Audio nodes aggregate from every neighbour that points at them (their video frames).
            audio_edges = ei[:, audio_mask[ei[1]]]
            h[audio_mask] = F.relu(self.gat_audio(x, audio_edges)[audio_mask])

        h = F.relu(self.gcn(h, ei))
        h = self.proj(h)

        # Per-graph self-attention on a padded dense batch. Padding is masked out
        # of the keys, so graphs of different sizes neither mix nor attend to padding.
        dense, node_mask = to_dense_batch(h, batch, batch_size=num_graphs)
        attended, _ = self.attn(dense, dense, dense, key_padding_mask=~node_mask)
        h = attended[node_mask]

        g_repr = torch.cat([
            global_mean_pool(h, batch),
            global_max_pool(h, batch)
        ], dim=1)
        return self.classifier(g_repr)

class FocalLoss(nn.Module):
    """Multi-class focal loss with optional class weighting.

    For each sample ``i`` with true class ``y_i``::

        loss_i = alpha[y_i] * (1 - p_i) ** gamma * (-log p_i)

    where ``p_i`` is the softmax probability of the true class. The result is
    averaged over the batch.

    Args:
        gamma: focusing parameter; 0 reduces to (weighted) cross-entropy.
        alpha: ``None`` (no weighting), a per-class weight tensor/sequence
            indexed by class id, or a scalar applied to every sample.
    """

    def __init__(self, gamma: float = 1.5, alpha: "torch.Tensor | float | None" = None):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        """Compute the mean focal loss of ``inputs`` (N, C) logits against ``targets`` (N,)."""
        ce = F.cross_entropy(inputs, targets, reduction="none")
        # pt must be the unweighted true-class probability. Putting the class
        # weight inside ce first would give pt = p ** w and distort the focusing term.
        pt = torch.exp(-ce)
        loss = (1 - pt) ** self.gamma * ce
        if self.alpha is not None:
            weight = torch.as_tensor(self.alpha, dtype=loss.dtype, device=loss.device)
            loss = loss * (weight if weight.ndim == 0 else weight[targets])
        return loss.mean()

def extract_label_from_entry(entry, video_path):
    """Derive a fake (1) / real (0) label from one metadata entry.

    Keys are consulted in priority order, and the first usable one wins:
      * ``n_fakes`` (int): 1 if greater than 0.
      * ``fake_periods`` (list): 1 if non-empty.
      * ``label``: a string ('fake'/'deepfake'/'1', case-insensitive) maps to
        1, any other string to 0; an int is returned unchanged.
      * ``is_fake``: 1 if truthy.
      * ``original``: 0 if truthy, else 1.

    A key whose value has an unexpected type falls through to the next one.
    ``video_path`` is accepted for API compatibility but is not used.

    Returns:
        The label, or ``None`` when no key yields one.
    """
    if 'n_fakes' in entry and isinstance(entry['n_fakes'], int):
        return int(entry['n_fakes'] > 0)
    if 'fake_periods' in entry and isinstance(entry['fake_periods'], list):
        return int(bool(entry['fake_periods']))
    if 'label' in entry:
        raw = entry['label']
        if isinstance(raw, str):
            return int(raw.lower() in ('fake', 'deepfake', '1'))
        if isinstance(raw, int):
            return raw
    if 'is_fake' in entry:
        return int(bool(entry['is_fake']))
    if 'original' in entry:
        return int(not entry['original'])
    return None

def create_filename_mapping(metadata):
    """Index metadata entries by several spellings of their ``file`` path.

    Each dict entry with a ``file`` key is registered under five keys:
      * the raw path;
      * its basename;
      * the basename without extension;
      * the path with forward slashes;
      * the path with backslashes.

    Keys that coincide get the entry appended once per spelling. Non-dict
    items and dicts without ``file`` are ignored.

    Returns:
        Mapping from each key to the list of matching entries.
    """
    mapping = {}
    for record in metadata:
        if not (isinstance(record, dict) and 'file' in record):
            continue
        path = record['file']
        base = os.path.basename(path)
        spellings = (
            path,
            base,
            os.path.splitext(base)[0],
            path.replace('\\', '/'),
            path.replace('/', '\\'),
        )
        for key in spellings:
            mapping.setdefault(key, []).append(record)
    return mapping

def find_label_with_mapping(video_path, filename_map, dataset_path):
    """Look up the label of ``video_path`` in a map built by ``create_filename_mapping``.

    Candidate keys are tried in order:
      * the path relative to ``dataset_path``;
      * that relative path with forward slashes, then with backslashes;
      * the basename;
      * the basename without extension.

    For each key present in the map, only its first entry is consulted.

    Returns:
        The first non-None label found, else ``None``.
    """
    relative = os.path.relpath(video_path, dataset_path)
    base = os.path.basename(video_path)
    candidates = (
        relative,
        relative.replace('\\', '/'),
        relative.replace('/', '\\'),
        base,
        os.path.splitext(base)[0],
    )
    for key in candidates:
        if key not in filename_map:
            continue
        label = extract_label_from_entry(filename_map[key][0], video_path)
        if label is not None:
            return label
    return None

def load_lavdf_dataset_improved(dataset_path, use_subset='train', max_clips=750):
    """Collect labelled clips from one LAV-DF subset folder.

    Steps:
      1. Read ``<dataset_path>/metadata.json``, presumably a list of entries
         with a ``file`` key (see ``create_filename_mapping``).
      2. Glob videos recursively under ``<dataset_path>/<use_subset>``.
      3. Keep a random sample of at most ``max_clips``.
      4. Pair each video with its metadata label.

    Each pair is ``(video_path, video_path)``, because the audio is read from
    the video file itself. Videos without a label are dropped.

    Returns:
        ``(pairs, labels)``, or ``([], [])`` if the metadata file or the
        subset folder is missing.
    """
    metadata_path = os.path.join(dataset_path, 'metadata.json')
    if not os.path.exists(metadata_path):
        return [], []
    # JSON is UTF-8 by spec; the platform default (e.g. cp1252 on Windows) can
    # fail on non-ASCII file names.
    with open(metadata_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)
    filename_map = create_filename_mapping(metadata)
    subset_folder = os.path.join(dataset_path, use_subset)
    if not os.path.exists(subset_folder):
        return [], []
    all_videos = []
    for ext in ('*.mp4', '*.avi', '*.mov', '*.mkv'):
        all_videos.extend(glob.glob(os.path.join(subset_folder, '**', ext), recursive=True))
    # glob order depends on the filesystem; sort so a seeded shuffle picks the
    # same clips on every run.
    all_videos.sort()
    if len(all_videos) > max_clips:
        random.shuffle(all_videos)
        all_videos = all_videos[:max_clips]
    video_audio_pairs, labels = [], []
    for video_path in all_videos:
        label = find_label_with_mapping(video_path, filename_map, dataset_path)
        if label is not None:
            video_audio_pairs.append((video_path, video_path))
            labels.append(label)
    return video_audio_pairs, labels

def load_all_subsets(dataset_path, max_clips=750):
    """Load the train/test/dev subsets, balance the classes and cap the total.

    Steps:
      1. Each existing subset contributes at most ``max_clips // 3`` clips.
      2. When both classes are present, the minority class (label 0 or 1) is
         oversampled with replacement up to the majority size. A
         single-class result is returned unbalanced so the caller can report
         it; with one class missing, there is nothing to resample from.
      3. If the total exceeds ``max_clips``, a random subset of that size is
         kept.

    NOTE: oversampled duplicates are produced before any train/val/test
    split, so the same clip can land in several splits (evaluation leakage).

    Returns:
        ``(pairs, labels)``.
    """
    all_pairs, all_labels = [], []
    clips_per_subset = max_clips // 3
    for subset in ('train', 'test', 'dev'):
        if not os.path.exists(os.path.join(dataset_path, subset)):
            continue
        pairs, labels = load_lavdf_dataset_improved(dataset_path, use_subset=subset, max_clips=clips_per_subset)
        if pairs and labels:
            all_pairs.extend(pairs)
            all_labels.extend(labels)

    real_pairs = [p for p, l in zip(all_pairs, all_labels) if l == 0]
    fake_pairs = [p for p, l in zip(all_pairs, all_labels) if l == 1]
    if real_pairs and fake_pairs and len(real_pairs) != len(fake_pairs):
        if len(real_pairs) < len(fake_pairs):
            minority, minority_label, deficit = real_pairs, 0, len(fake_pairs) - len(real_pairs)
        else:
            minority, minority_label, deficit = fake_pairs, 1, len(real_pairs) - len(fake_pairs)
        picks = np.random.choice(len(minority), size=deficit, replace=True)
        all_pairs.extend(minority[i] for i in picks)
        all_labels.extend([minority_label] * len(picks))

    if len(all_pairs) > max_clips:
        indices = np.random.choice(len(all_pairs), size=max_clips, replace=False)
        all_pairs = [all_pairs[i] for i in indices]
        all_labels = [all_labels[i] for i in indices]
    return all_pairs, all_labels

class DeepfakeDetector:
    """Tie together feature extraction, graph building, training and evaluation.

    Args:
        device: torch device string for the model and the graphs.
        metrics_path: text file that receives per-epoch and test metrics;
            ``None`` uses ``DEFAULT_METRICS_PATH`` (the original location).
        checkpoint_path: where the best-validation ``state_dict`` is saved and
            later loaded by :meth:`evaluate`.
    """

    DEFAULT_METRICS_PATH = os.path.join("C:\\Users\\ARNAV", "metrics.txt")

    def __init__(self, device: str = "cpu", metrics_path: str = None, checkpoint_path: str = "best_model.pth"):
        self.device = device
        self.metrics_path = metrics_path or self.DEFAULT_METRICS_PATH
        self.checkpoint_path = checkpoint_path
        self.feature_extractor = AudioVisualFeatureExtractor(device)
        self.graph_constructor = GraphConstructor(similarity_threshold=0.7, device=device)
        self.model = MultiModalGNN(device=device).to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, weight_decay=5e-4)
        self.criterion = FocalLoss(gamma=1.5, alpha=None)

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory containing ``path`` if it does not exist yet."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def prepare_data(self, video_audio_pairs, labels, feature_dir=None):
        """Extract features and build one labelled graph per (video, audio) pair.

        ``feature_dir`` is accepted for API compatibility but not read: the
        features are re-extracted here rather than loaded from the
        ``save_features`` cache. Graphs that still have out-of-range edges
        after validation are dropped, so the result can be shorter than
        ``labels``.
        """
        graphs = []
        for (v_path, a_path), label in zip(video_audio_pairs, labels):
            video_features = self.feature_extractor.extract_video_features(v_path)
            audio_features = self.feature_extractor.extract_audio_features(a_path)
            g = validate_edge_index(self.graph_constructor.create_graph(video_features, audio_features))
            g.y = torch.tensor([label], dtype=torch.long)
            if g.edge_index.numel() > 0 and g.edge_index.max().item() >= g.x.size(0):
                continue
            graphs.append(g)
        return graphs

    def _clean_batch(self, batch):
        """Drop edges with endpoints outside ``[0, num_nodes)``; returns the same (mutated) batch."""
        with torch.no_grad():
            num_nodes = batch.x.size(0)
            keep = ((batch.edge_index >= 0) & (batch.edge_index < num_nodes)).all(dim=0)
            if not keep.all():
                batch.edge_index = batch.edge_index[:, keep]
                edge_attr = getattr(batch, 'edge_attr', None)
                if edge_attr is not None and edge_attr.size(0) == keep.size(0):
                    batch.edge_attr = edge_attr[keep]
        return batch

    def _optimizer_step(self):
        """Clip gradients to norm 1.0, apply one optimizer step, then clear the gradients."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()

    def _validate(self, loader):
        """Return ``(accuracy %, samples counted, batches skipped)`` of the model on ``loader``."""
        self.model.eval()
        correct = total = skipped = 0
        with torch.no_grad():
            for batch in loader:
                try:
                    batch = self._clean_batch(batch.to(self.device))
                    outputs = self.model(batch)
                    if outputs.size(0) != batch.y.size(0):
                        skipped += 1
                        continue
                    total += batch.y.size(0)
                    correct += (outputs.argmax(dim=1) == batch.y).sum().item()
                except Exception:
                    skipped += 1
        return 100 * correct / max(1, total), total, skipped

    def train(self, train_graphs, val_graphs, epochs: int = 100, batch_size: int = 8, accum_steps: int = 4,
              warmup_epochs: int = 10, target_val_acc: float = 80.0):
        """Train with focal loss, gradient accumulation and a warmup + cosine LR schedule.

        Class weights for the focal loss are the inverse class frequencies of
        ``train_graphs`` (only when both classes are present). Gradients from
        ``accum_steps`` consecutive batches are summed before each optimizer
        step, and a trailing partial window is stepped at the end of the
        epoch. The best-validation model is saved to ``checkpoint_path``.
        Training stops early once validation accuracy reaches
        ``target_val_acc`` percent.

        Returns:
            Best validation accuracy, in percent.
        """
        labels = [g.y.item() for g in train_graphs]
        real_count, fake_count = labels.count(0), labels.count(1)
        total = real_count + fake_count
        if real_count > 0 and fake_count > 0:
            weight_real = total / (2 * real_count)
            weight_fake = total / (2 * fake_count)
            self.criterion.alpha = torch.tensor([weight_real, weight_fake], device=self.device)
        else:
            self.criterion.alpha = None
        accum_steps = max(1, int(accum_steps))
        train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
        val_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

        # Linear warmup for the first `warmup_epochs` epochs, then cosine annealing.
        def lr_lambda(epoch):
            if epoch < warmup_epochs:
                return (epoch + 1) / warmup_epochs
            return 1.0
        scheduler = CosineAnnealingLR(self.optimizer, T_max=max(1, epochs - warmup_epochs))
        warmup_scheduler = LambdaLR(self.optimizer, lr_lambda)

        best_val_acc = 0.0
        self._ensure_parent_dir(self.metrics_path)
        with open(self.metrics_path, 'w') as f:
            f.write("Epoch,Train Accuracy,Validation Accuracy\n")

        for epoch in range(epochs):
            self.model.train()
            # Zero once per window, not per batch; zeroing every batch discarded
            # the accumulated gradients.
            self.optimizer.zero_grad()
            train_correct = train_total = skipped = 0
            has_pending_grads = False
            for i, batch in enumerate(train_loader):
                try:
                    batch = self._clean_batch(batch.to(self.device))
                    batch.y = batch.y.clamp(0, 1)
                    outputs = self.model(batch)
                    if outputs.size(0) != batch.y.size(0):
                        skipped += 1
                        continue
                    loss = self.criterion(outputs, batch.y) / accum_steps
                    loss.backward()
                    has_pending_grads = True
                    if (i + 1) % accum_steps == 0:
                        self._optimizer_step()
                        has_pending_grads = False
                    train_total += batch.y.size(0)
                    train_correct += (outputs.argmax(dim=1) == batch.y).sum().item()
                except Exception:
                    # Best-effort: a failing batch is skipped, but counted and reported.
                    skipped += 1
            if has_pending_grads:
                self._optimizer_step()
            train_acc = 100 * train_correct / max(1, train_total)
            val_acc, val_total, val_skipped = self._validate(val_loader)
            print(f"Epoch {epoch+1}/{epochs}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")
            if skipped or val_skipped:
                print(f"  Skipped {skipped} training and {val_skipped} validation batches.")
            with open(self.metrics_path, 'a') as f:
                f.write(f"{epoch+1},{train_acc:.2f},{val_acc:.2f}\n")
            if epoch < warmup_epochs:
                warmup_scheduler.step()
            else:
                scheduler.step()
            if val_acc > best_val_acc and val_total > 0:
                best_val_acc = val_acc
                torch.save(self.model.state_dict(), self.checkpoint_path, _use_new_zipfile_serialization=True)
            if val_acc >= target_val_acc:
                print(f"Reached validation accuracy of {val_acc:.2f}%, stopping training.")
                break
        return best_val_acc

    def evaluate(self, test_graphs):
        """Load the best checkpoint and report metrics on ``test_graphs``.

        The metrics are accuracy, precision, recall and F1.

        Returns:
            ``(accuracy, precision, recall, f1)`` as fractions; all zeros if
            the checkpoint cannot be loaded.
        """
        try:
            self.model.load_state_dict(torch.load(self.checkpoint_path, map_location=torch.device('cpu')))
        except (FileNotFoundError, RuntimeError):
            print("Error: Could not load model weights. Returning zero metrics.")
            return 0.0, 0.0, 0.0, 0.0
        self.model.eval()
        test_loader = DataLoader(test_graphs, batch_size=8, shuffle=False, collate_fn=custom_collate)
        all_preds, all_labels = [], []
        skipped = 0
        with torch.no_grad():
            for batch in test_loader:
                try:
                    batch = self._clean_batch(batch.to(self.device))
                    batch.y = batch.y.clamp(0, 1)
                    outputs = self.model(batch)
                    all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
                    all_labels.extend(batch.y.cpu().numpy())
                except Exception:
                    skipped += 1
        if skipped:
            print(f"  Skipped {skipped} test batches.")
        acc = accuracy_score(all_labels, all_preds) if all_labels else 0.0
        prec, rec, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="binary") if all_labels else (0.0, 0.0, 0.0, None)
        print(f"Test Metrics: Accuracy: {acc*100:.2f}%, Precision: {prec*100:.2f}%, Recall: {rec*100:.2f}%, F1-Score: {f1*100:.2f}%")
        self._ensure_parent_dir(self.metrics_path)
        with open(self.metrics_path, 'a') as f:
            f.write(f"Test Metrics: Accuracy: {acc*100:.2f}%, Precision: {prec*100:.2f}%, Recall: {rec*100:.2f}%, F1-Score: {f1*100:.2f}%\n")
        return acc, prec, rec, f1
    
def debug_metadata_structure(dataset_path):
    """Load and return ``<dataset_path>/metadata.json`` for inspection.

    It prints nothing itself; callers must inspect the returned object.

    Returns:
        The parsed JSON, or ``None`` if the file does not exist.
    """
    metadata_path = os.path.join(dataset_path, 'metadata.json')
    if not os.path.exists(metadata_path):
        return None
    # Read as UTF-8 (the JSON encoding), consistent with the dataset loader.
    with open(metadata_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def main():
    """Run the full LAV-DF pipeline.

    Steps: load and balance clips, fit scalers, cache features, build graphs,
    split 70/20/10 (stratified), train, then evaluate on the test split.
    """
    dataset_path = r"C:\archive\LAV-DF"
    feature_dir = r"C:\Users\ARNAV\features"
    if not os.path.exists(dataset_path):
        print(f"Error: Dataset path {dataset_path} does not exist.")
        return
    os.makedirs(feature_dir, exist_ok=True)
    video_audio_pairs, labels = load_all_subsets(dataset_path, max_clips=750)
    if not video_audio_pairs:
        print("Error: No video-audio pairs loaded. Check metadata structure.")
        debug_metadata_structure(dataset_path)
        return
    # Check class coverage before the expensive scaler-fitting pass.
    if labels.count(0) == 0 or labels.count(1) == 0:
        print("Error: Dataset contains only one class. Check metadata structure.")
        debug_metadata_structure(dataset_path)
        return
    detector = DeepfakeDetector(device=device)
    detector.feature_extractor.fit_scalers(video_audio_pairs)
    detector.feature_extractor.save_features(video_audio_pairs, feature_dir)
    graphs = detector.prepare_data(video_audio_pairs, labels, feature_dir=feature_dir)
    if len(graphs) < 10:
        print("Error: Too few graphs generated. Check feature extraction.")
        return
    # Stratify on the labels of the graphs actually built (prepare_data may drop
    # some), and split indices so the second split's stratify list lines up with
    # its elements.
    graph_labels = [int(g.y.item()) for g in graphs]
    train_idx, temp_idx = train_test_split(
        list(range(len(graphs))), test_size=0.3, random_state=42, stratify=graph_labels)
    val_idx, test_idx = train_test_split(
        temp_idx, test_size=0.3333, random_state=42, stratify=[graph_labels[i] for i in temp_idx])
    train_graphs = [graphs[i] for i in train_idx]
    val_graphs = [graphs[i] for i in val_idx]
    test_graphs = [graphs[i] for i in test_idx]
    best_acc = detector.train(train_graphs, val_graphs, epochs=100, batch_size=8, accum_steps=4)
    print(f"Best Validation Accuracy: {best_acc:.2f}%")
    if test_graphs:
        detector.evaluate(test_graphs)

# Run the pipeline only when executed directly (or as a notebook cell), not on import.
if __name__ == "__main__":
    main()