In [None]:
import torch
from argparse import Namespace
import warnings
import os

# Silence library warnings so the demo output stays readable.
warnings.filterwarnings("ignore")
# Fixed seed so random data and weight init are reproducible across runs.
torch.manual_seed(42)
# Trade some float32 matmul precision for speed where the backend supports it.
torch.set_float32_matmul_precision("medium")

banner_rule = "=" * 80
cuda_ok = torch.cuda.is_available()

env_report = [
    banner_rule,
    "PLFD Deepfake Detection - Demo Training",
    banner_rule,
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {cuda_ok}",
]
if cuda_ok:
    env_report += [
        f"CUDA device: {torch.cuda.get_device_name(0)}",
        f"CUDA version: {torch.version.cuda}",
    ]
env_report.append(banner_rule)

# One print per line, joined with newlines (same output as separate prints).
print("\n".join(env_report))

# Phoneme-Level Deepfake Detection Training Demo

This notebook demonstrates training the PLFD model for deepfake detection.

## Configuration

**Setup Requirements:**
1. Phoneme model checkpoint: `Best Epoch 42 Validation 0.407.ckpt` (in project root)
2. Vocab files: `vocab_phoneme/` directory with 9 language JSON files
3. HuggingFace token for dataset access (optional; only needed with `USE_REAL_DATA = True` when the dataset requires authentication)

**Choose Your Data Source:**
- `USE_REAL_DATA = False` → Dummy random data (fast, for testing)
- `USE_REAL_DATA = True` → ASVspoof 2019 LA dataset from HuggingFace

In [None]:
# ============================================================================
# CONFIGURATION
# ============================================================================

# Choose data source
USE_REAL_DATA = False  # Set to True to use real ASVspoof data

# HuggingFace token (optional, only needed for private/gated datasets).
# Read it from the environment (e.g. `export HF_TOKEN=...`) instead of
# hardcoding it: a token stored in a notebook leaks through version control
# and every shared copy of the file.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

# Training settings
NUM_EPOCHS = 4
BATCH_SIZE = 3
NUM_TRAIN_SAMPLES = 20  # Small for demo

print(f"Data source: {'Real ASVspoof data' if USE_REAL_DATA else 'Dummy synthetic data'}")
print(f"Training epochs: {NUM_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Training samples: {NUM_TRAIN_SAMPLES}")
# Report only whether a token is present; never echo the secret itself.
print(f"HuggingFace token: {'set' if HF_TOKEN else 'not set'}")

## Setup Paths

In [None]:
# Paths are resolved relative to the current working directory, so run the
# notebook from the project root (works locally and on RunPod).
project_root = os.path.abspath(".")
pretrained_path = os.path.join(project_root, "Best Epoch 42 Validation 0.407.ckpt")
vocab_path = os.path.join(project_root, "vocab_phoneme")

checkpoint_exists = os.path.exists(pretrained_path)
vocab_exists = os.path.exists(vocab_path)

print(f"Project root: {project_root}")
print(f"Checkpoint: {pretrained_path}")
print(f"Checkpoint exists: {checkpoint_exists}")
print(f"Vocab path: {vocab_path}")
print(f"Vocab exists: {vocab_exists}")

# The vocab directory is a listed setup requirement; warn so a missing copy is
# noticed before model loading.
if not vocab_exists:
    print("\n⚠️  WARNING: vocab_phoneme/ directory not found in the project root!")

# Fail fast: the model-loading cell passes this path to load_phoneme_model, so a
# missing file would otherwise only surface later, inside model loading.
if not checkpoint_exists:
    raise FileNotFoundError(
        f"Phoneme checkpoint not found at {pretrained_path}. Download it from: "
        "https://drive.google.com/file/d/1SbqynkUQxxlhazklZz9OgcVK7Fl2aT-z/view?usp=drive_link"
    )

## Load Phoneme Recognition Model

In [None]:
from phoneme_GAT.phoneme_model import BaseModule, load_phoneme_model, optim_param

# Configuration of the pretrained WavLM phoneme recognizer. Only `network_name`
# and `pretrained_path` are read in this cell; the other fields are unused here.
network_param = Namespace(
    network_name="WavLM",
    pretrained_path=pretrained_path,
    freeze=True,
    freeze_transformer=True,
    eos_token="</s>",
    bos_token="<s>",
    unk_token="<unk>",
    pad_token="<pad>",
    word_delimiter_token="|",
    vocab_size=200,
)

# Phoneme inventory size to load (198 or 687).
total_num_phonemes = 687

print("Loading phoneme recognition model...")
loader_kwargs = {
    "network_name": network_param.network_name,
    "pretrained_path": network_param.pretrained_path,
    "total_num_phonemes": total_num_phonemes,
}
phoneme_model = load_phoneme_model(**loader_kwargs)

# Sanity check: the loaded tokenizer must expose exactly the requested inventory.
loaded_phoneme_count = len(phoneme_model.tokenizer.total_phonemes)
assert loaded_phoneme_count == total_num_phonemes
print(f"✓ Phoneme model loaded ({total_num_phonemes} phonemes)")

## Test the Phoneme_GAT Audio Model

In [None]:
from phoneme_GAT.modules import Phoneme_GAT_lit, Phoneme_GAT

print("Creating audio model...")
audio_model = Phoneme_GAT(
    backbone='wavlm',
    use_raw=0,
    use_GAT=1,
    n_edges=10,
)

# Smoke test with random audio: 3 clips of 48000 samples each (3 s at the
# 16 kHz sample rate used elsewhere in this notebook).
smoke_batch_size = 3
smoke_num_samples = 48000
x = torch.randn(smoke_batch_size, 1, smoke_num_samples)
# Per-clip frame count handed to the model. 320 is presumably the backbone's
# frame hop in samples — TODO confirm against Phoneme_GAT.
num_frames = torch.full((x.shape[0],), smoke_num_samples // 320 - 1)

# Only output shapes are inspected, so skip autograd bookkeeping.
with torch.no_grad():
    res = audio_model(x, num_frames=num_frames)

print("\n✓ Audio model created successfully!")
print("\nOutput shapes:")
for key, value in res.items():
    # Non-tensor entries have no .shape; show their type instead of crashing.
    shape = getattr(value, "shape", type(value).__name__)
    print(f"  {key:20s}: {str(shape):20s}")

## Create PyTorch Lightning Module

In [None]:
cfg = Namespace(
    PhonemeGAT=Namespace(
        backbone="wavlm",
        use_raw=False,
        use_GAT=True,
        n_edges=10,
        use_aug=True,
        use_pool=True,
        use_clip=True,
    )
)

print("Creating Lightning module...")
audio_model_lit = Phoneme_GAT_lit(cfg=cfg)

# Smoke-test the "train" prediction path on a random batch of 3 clips.
batch = {
    "label": torch.randint(0, 2, (3,)),
    "audio": torch.randn(3, 1, 48000),
    "sample_rate": 16000,
}

# Only output shapes are inspected, so don't build a throwaway autograd graph.
with torch.no_grad():
    batch_res = audio_model_lit._shared_pred(batch=batch, batch_idx=0, stage="train")

print("\n✓ Lightning module working!")
print("\nPrediction output shapes:")
for key, value in batch_res.items():
    # Non-tensor entries have no .shape; show their type instead of crashing.
    shape = getattr(value, "shape", type(value).__name__)
    print(f"  {key:20s}: {str(shape):20s}")

## Create Dataset

This cell creates either:
- **Dummy data**: Random synthetic audio for quick testing
- **Real data**: ASVspoof 2019 LA dataset from HuggingFace (requires download)

In [None]:
from torch.utils.data import Dataset, DataLoader

# Every clip is padded/trimmed to this many samples (3 s at 16 kHz), matching the
# input length used by the smoke tests above.
CLIP_NUM_SAMPLES = 48000
TARGET_SAMPLE_RATE = 16000
# Synthetic-data size used when NUM_TRAIN_SAMPLES asks for "all" samples, since
# there is no "full" synthetic dataset.
DEFAULT_DUMMY_SAMPLES = 20


def _fit_length(audio, num_samples=CLIP_NUM_SAMPLES):
    """Right-pad with zeros or trim the last axis of `audio` to `num_samples`."""
    length = audio.shape[-1]
    if length < num_samples:
        return torch.nn.functional.pad(audio, (0, num_samples - length))
    return audio[..., :num_samples]


class RealDataset(Dataset):
    """Wrap a HuggingFace ASVspoof split as fixed-length clips with binary labels.

    Each item is {"audio": (1, CLIP_NUM_SAMPLES) float32 tensor for mono input,
    "label": 0 for "bonafide" and 1 otherwise, "sample_rate": TARGET_SAMPLE_RATE}.
    """

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset
        # Validate the schema up front so a mismatch is caught by the loading
        # fallback below instead of crashing in the middle of training.
        missing = {"audio", "label"} - set(hf_dataset.column_names)
        if missing:
            raise KeyError(f"Dataset is missing expected column(s): {sorted(missing)}")
        label_feature = hf_dataset.features["label"]
        # ClassLabel features store integer ids plus a `names` list; string
        # features store the class name directly.
        self._label_names = getattr(label_feature, "names", None)
        if self._label_names is None and getattr(label_feature, "dtype", None) != "string":
            raise TypeError(f"Cannot map 'label' feature {label_feature!r} to class names")

    def __len__(self):
        return len(self.dataset)

    def _load_waveform(self, audio_field):
        """Decode one audio field to a float32 tensor at TARGET_SAMPLE_RATE."""
        if isinstance(audio_field, dict) and "array" in audio_field:
            waveform = torch.as_tensor(audio_field["array"], dtype=torch.float32)
            sample_rate = audio_field.get("sampling_rate", TARGET_SAMPLE_RATE)
        elif hasattr(audio_field, "get_all_samples"):
            # Newer `datasets` releases may decode audio to a torchcodec decoder
            # object rather than a dict — TODO confirm for the installed version.
            samples = audio_field.get_all_samples()
            waveform = samples.data.to(torch.float32)
            sample_rate = samples.sample_rate
        else:
            # Fail loudly instead of substituting noise: random audio paired with
            # real labels would silently corrupt training.
            raise TypeError(f"Unsupported audio field type: {type(audio_field).__name__}")
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        if sample_rate != TARGET_SAMPLE_RATE:
            import torchaudio
            waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
        return waveform

    def _to_binary_label(self, raw_label):
        """Map a string name or ClassLabel id to 0 (bonafide) / 1 (anything else)."""
        name = raw_label if isinstance(raw_label, str) else self._label_names[int(raw_label)]
        return 0 if name == "bonafide" else 1

    def __getitem__(self, idx):
        item = self.dataset[idx]
        return {
            "audio": _fit_length(self._load_waveform(item["audio"])),
            "label": self._to_binary_label(item["label"]),
            "sample_rate": TARGET_SAMPLE_RATE,
        }


class DummyDataset(Dataset):
    """Fixed set of random clips with random binary labels, generated once up front."""

    def __init__(self, num_samples=20):
        self.samples = [
            {
                "audio": torch.randn(1, CLIP_NUM_SAMPLES),
                "label": torch.randint(0, 2, (1,)).item(),
                "sample_rate": TARGET_SAMPLE_RATE,
            }
            for _ in range(num_samples)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# Work on a copy of the config flag: a failed download must not overwrite
# USE_REAL_DATA, so re-running this cell after fixing the problem retries.
use_real_data = USE_REAL_DATA
# NUM_TRAIN_SAMPLES of None or <= 0 means "use the whole split"; range(-1)
# would otherwise select nothing.
limit_samples = NUM_TRAIN_SAMPLES is not None and NUM_TRAIN_SAMPLES > 0

if use_real_data:
    print("Loading real ASVspoof 2019 LA dataset from HuggingFace...")
    print("This may take a few minutes on first run (downloads ~1.6GB)")

    try:
        from datasets import load_dataset

        # Login to HuggingFace if token provided; don't also persist the token
        # into git's credential store.
        if HF_TOKEN:
            from huggingface_hub import login
            login(token=HF_TOKEN, add_to_git_credential=False)

        dataset = load_dataset("Bisher/ASVspoof_2019_LA")
        full_split = dataset['train']
        if limit_samples:
            train_data = full_split.select(range(min(NUM_TRAIN_SAMPLES, len(full_split))))
        else:
            train_data = full_split

        test_dataset = RealDataset(train_data)
        print(f"✓ Loaded {len(test_dataset)} real audio samples")

    except Exception as e:
        # Deliberate best-effort: keep the demo runnable without network/auth.
        print(f"⚠️  Error loading real data: {e}")
        print("Falling back to dummy data...")
        use_real_data = False

if not use_real_data:
    print("Using dummy synthetic data for quick testing...")
    test_dataset = DummyDataset(
        num_samples=NUM_TRAIN_SAMPLES if limit_samples else DEFAULT_DUMMY_SAMPLES
    )
    print(f"✓ Created {len(test_dataset)} dummy samples")

# Create dataloader (this loader is used for training; it holds the training samples)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # 0 for compatibility
)

print(f"\n✓ DataLoader ready: {len(test_dataloader)} batches")

## Setup Training

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from callbacks import EER_Callback, BinaryAUC_Callback, BinaryACC_Callback

# Pick the hardware backend: a single GPU when CUDA is present, otherwise the
# CPU with Lightning's automatic device count.
has_cuda = torch.cuda.is_available()
accelerator, devices = ("gpu", 1) if has_cuda else ("cpu", "auto")
print("✓ Using GPU acceleration" if has_cuda else "✓ Using CPU")

csv_logger = CSVLogger(save_dir="./logs", version=None)

# All metric callbacks share the same label / prediction keys.
metric_callbacks = [
    callback_cls(batch_key="label", output_key="logit")
    for callback_cls in (BinaryACC_Callback, BinaryAUC_Callback, EER_Callback)
]

trainer = Trainer(
    logger=csv_logger,
    max_epochs=NUM_EPOCHS,
    accelerator=accelerator,
    devices=devices,
    callbacks=metric_callbacks,
    enable_progress_bar=True,
)

config_summary = [
    "\nTraining configuration:",
    f"  Accelerator: {accelerator}",
    f"  Max epochs: {NUM_EPOCHS}",
    f"  Log directory: {trainer.logger.log_dir}",
]
print("\n".join(config_summary))

## Train Model

In [None]:
section_rule = "=" * 80
print(section_rule, "Starting training...", section_rule, sep="\n")

# Despite its name, test_dataloader holds the training samples built in the
# dataset cell.
trainer.fit(audio_model_lit, test_dataloader)

print("", section_rule, "✓ Training completed!", section_rule, sep="\n")

## Test Model

In [None]:
print("=" * 80)
print("Testing model...")
print("=" * 80)

# NOTE: this evaluates on the same samples used for training, so the metrics
# only confirm the pipeline runs end to end; they are not a held-out estimate.
# Evaluation uses its own non-shuffled loader (the training loader shuffles,
# which Lightning recommends against for test loops).
eval_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)
results = trainer.test(audio_model_lit, eval_dataloader)

log_dir = trainer.logger.log_dir
print("\n" + "=" * 80)
print("✓ DEMO COMPLETED SUCCESSFULLY!")
print("=" * 80)
print(f"\nResults saved to: {log_dir}")
print(f"Metrics CSV: {os.path.join(log_dir, 'metrics.csv')}")
print("\nTest Results:")
# trainer.test returns one metrics dict per test dataloader; guard against an
# empty list instead of failing with IndexError.
test_metrics = results[0] if results else {}
if not test_metrics:
    print("  (no test metrics were logged)")
for key, value in test_metrics.items():
    print(f"  {key:20s}: {float(value):.4f}")
print("=" * 80)

## Summary

This notebook demonstrated:

1. ✅ Loading the pretrained phoneme recognition model
2. ✅ Creating the Phoneme_GAT deepfake detection model
3. ✅ Setting up PyTorch Lightning training
4. ✅ Training on dummy or real ASVspoof data
5. ✅ Evaluating the model (on the same samples it was trained on — a pipeline smoke test, not a held-out evaluation)

### Next Steps

**For RunPod deployment:**
1. Upload this notebook and all code to RunPod
2. Run `bash setup_runpod.sh` to install dependencies
3. Set `USE_REAL_DATA = True` to use full dataset
4. Increase `NUM_EPOCHS` and `NUM_TRAIN_SAMPLES` for production training

**For local use:**
- Metrics are saved in the logs directory
- View training progress: `cat logs/lightning_logs/version_X/metrics.csv`
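- A checkpoint from the last training epoch is saved automatically by Lightning's default checkpointing (under `logs/lightning_logs/version_X/checkpoints/`); no validation metric is monitored, so it is not a "best" checkpoint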

**Configuration for full training:**
```python
USE_REAL_DATA = True
NUM_EPOCHS = 20
BATCH_SIZE = 16  # Adjust based on GPU memory
NUM_TRAIN_SAMPLES = -1  # Use full dataset
```