In [121]:
import torch
from argparse import Namespace
import warnings
import os

warnings.filterwarnings("ignore")
torch.manual_seed(42)
torch.set_float32_matmul_precision("medium")

print("=" * 80)
print("PLFD Deepfake Detection - Demo Training")
print("=" * 80)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
print("=" * 80)

PLFD Deepfake Detection - Demo Training
PyTorch version: 2.5.1
CUDA available: False


# Phoneme-Level Deepfake Detection Training Demo

This notebook demonstrates training the PLFD model for deepfake detection.

## Configuration

**Setup Requirements:**
1. Phoneme model checkpoint: `Best Epoch 42 Validation 0.407.ckpt` (in project root)
2. Vocab files: `vocab_phoneme/` directory with 9 language JSON files
3. HuggingFace token for dataset access (recommended)

**Dataset:**
This notebook uses the ASVspoof 2019 LA dataset from HuggingFace.
- Will download automatically (~1.6GB on first run)
- Create a `.env` file with your HuggingFace token: `HF_TOKEN=your_token_here`

In [122]:
# ============================================================================
# CONFIGURATION
# ============================================================================

# HuggingFace token (load from .env file for security)
# Create a .env file with: HF_TOKEN=your_token_here
from dotenv import load_dotenv
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN", "")  # Will be empty if .env not found

# Training settings
NUM_EPOCHS = 4
BATCH_SIZE = 3
NUM_TRAIN_SAMPLES = 20  # Small for demo

print(f"Data source: Real ASVspoof 2019 LA dataset")
print(f"Training epochs: {NUM_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Training samples: {NUM_TRAIN_SAMPLES}")
print(f"HF Token loaded: {'✅ Yes' if HF_TOKEN else '❌ No (.env file not found)'}")

Data source: Real ASVspoof 2019 LA dataset
Training epochs: 4
Batch size: 3
Training samples: 20
HF Token loaded: ✅ Yes


## Setup Paths

In [123]:
# Auto-detect paths (works locally and on RunPod)
project_root = os.path.abspath(".")
pretrained_path = os.path.join(project_root, "Best Epoch 42 Validation 0.407.ckpt")
vocab_path = os.path.join(project_root, "vocab_phoneme")

print(f"Project root: {project_root}")
print(f"Checkpoint: {pretrained_path}")
print(f"Checkpoint exists: {os.path.exists(pretrained_path)}")
print(f"Vocab path: {vocab_path}")
print(f"Vocab exists: {os.path.exists(vocab_path)}")

if not os.path.exists(pretrained_path):
    print("\n⚠️  ERROR: Checkpoint not found!")
    print("Download from: https://drive.google.com/file/d/1SbqynkUQxxlhazklZz9OgcVK7Fl2aT-z/view?usp=drive_link")

Project root: /Users/arjunjindal/Desktop/PLFD-ADD
Checkpoint: /Users/arjunjindal/Desktop/PLFD-ADD/Best Epoch 42 Validation 0.407.ckpt
Checkpoint exists: True
Vocab path: /Users/arjunjindal/Desktop/PLFD-ADD/vocab_phoneme
Vocab exists: True


## Load Phoneme Recognition Model

In [124]:
from phoneme_GAT.phoneme_model import BaseModule, load_phoneme_model, optim_param

network_param = Namespace(
    network_name="WavLM",
    pretrained_path=pretrained_path,
    freeze=True,
    freeze_transformer=True,
    eos_token="</s>",
    bos_token="<s>",
    unk_token="<unk>",
    pad_token="<pad>",
    word_delimiter_token="|",
    vocab_size=200,
)

total_num_phonemes = 687  # 198 or 687

print("Loading phoneme recognition model...")
phoneme_model = load_phoneme_model(
    network_name=network_param.network_name,
    pretrained_path=network_param.pretrained_path,
    total_num_phonemes=total_num_phonemes,
)

assert len(phoneme_model.tokenizer.total_phonemes) == total_num_phonemes
print(f"✓ Phoneme model loaded ({total_num_phonemes} phonemes)")

Loading phoneme recognition model...
Now, load vocab json files from  /Users/arjunjindal/Desktop/PLFD-ADD/vocab_phoneme Please make sure the vocab files are correct
Load WavLM model!!!!!!!


Some weights of WavLMForCTC were not initialized from the model checkpoint at microsoft/wavlm-base and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([687, 768])
✓ Phoneme model loaded (687 phonemes)


## Test Phoneme Model

In [125]:
from phoneme_GAT.modules import Phoneme_GAT_lit, Phoneme_GAT

print("Creating audio model...")
audio_model = Phoneme_GAT(
    backbone='wavlm',
    use_raw=0,
    use_GAT=1,
    n_edges=10,
)

# Test with random audio
x = torch.randn(3, 1, 48000)
num_frames = torch.full((x.shape[0],), 48000 // 320 - 1)
res = audio_model(x, num_frames=num_frames)

print("\n✓ Audio model created successfully!")
print("\nOutput shapes:")
for key, value in res.items():
    print(f"  {key:20s}: {str(value.shape):20s}")

Creating audio model...
Now, load vocab json files from  /Users/arjunjindal/Desktop/PLFD-ADD/vocab_phoneme Please make sure the vocab files are correct
Load WavLM model!!!!!!!


Some weights of WavLMForCTC were not initialized from the model checkpoint at microsoft/wavlm-base and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([687, 768])

✓ Audio model created successfully!

Output shapes:
  logit               : torch.Size([3])     
  hidden_states       : torch.Size([3, 768])
  phoneme_feat        : torch.Size([3, 149, 768])
  encoder_feat        : torch.Size([3, 149, 768])
  phoneme_cls_logit   : torch.Size([3, 687])
  phoneme_cls_label   : torch.Size([3])     
  aug_logit           : torch.Size([3])     
  aug_frame_logit     : torch.Size([3])     
  aug_labels          : torch.Size([3])     


## Create PyTorch Lightning Module

In [126]:
cfg = Namespace(
    PhonemeGAT=Namespace(
        backbone="wavlm",
        use_raw=False,
        use_GAT=True,
        n_edges=10,
        use_aug=True,
        use_pool=True,
        use_clip=True,
    )
)

print("Creating Lightning module...")
audio_model_lit = Phoneme_GAT_lit(cfg=cfg)

# Test forward pass
batch = {
    "label": torch.randint(0, 2, (3,)),
    "audio": torch.randn(3, 1, 48000),
    "sample_rate": 16000,
}

batch_res = audio_model_lit._shared_pred(batch=batch, batch_idx=0, stage="train")
print("\n✓ Lightning module working!")
print("\nPrediction output shapes:")
for key, value in batch_res.items():
    print(f"  {key:20s}: {str(value.shape):20s}")

Creating Lightning module...
Now, load vocab json files from  /Users/arjunjindal/Desktop/PLFD-ADD/vocab_phoneme Please make sure the vocab files are correct
Load WavLM model!!!!!!!


Some weights of WavLMForCTC were not initialized from the model checkpoint at microsoft/wavlm-base and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([687, 768])

✓ Lightning module working!

Prediction output shapes:
  logit               : torch.Size([3])     
  hidden_states       : torch.Size([3, 768])
  phoneme_feat        : torch.Size([3, 149, 768])
  encoder_feat        : torch.Size([3, 149, 768])
  phoneme_cls_logit   : torch.Size([3, 687])
  phoneme_cls_label   : torch.Size([3])     
  aug_logit           : torch.Size([3])     
  aug_frame_logit     : torch.Size([3])     
  aug_labels          : torch.Size([3])     


## Load ASVspoof 2019 LA Dataset

This cell loads the real ASVspoof 2019 LA dataset from HuggingFace.
- Dataset will be downloaded on first run (~1.6GB)
- Requires HuggingFace token for access

In [127]:
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, Dataset as HFDataset
import torchaudio
import shutil
from pathlib import Path

print("=" * 80)
print("Loading ASVspoof 2019 LA dataset from HuggingFace...")
print("=" * 80)

# Find the parquet files in your existing cache
print("\nLocating downloaded parquet files...")
hub_cache = Path.home() / ".cache" / "huggingface" / "hub" / "datasets--Bisher--ASVspoof_2019_LA"
snapshot_dir = hub_cache / "snapshots" / "aea92dd83a9c56e070c0b1e9f02e7c0d96216a4c" / "data"

if not snapshot_dir.exists():
    print("❌ Cached files not found. Please download the dataset first.")
    raise FileNotFoundError(f"Cache directory not found: {snapshot_dir}")

print(f"✓ Found cache directory: {snapshot_dir}")

# Load directly from parquet files using pandas (bypasses all cache issues)
print("\nLoading from parquet files using pandas...")
train_parquet_link = snapshot_dir / "train-00000-of-00001.parquet"

# Resolve the symlink to get the actual blob file
train_parquet = train_parquet_link.resolve()
print(f"  Resolved parquet file: {train_parquet.name}")

# Load with pandas (more reliable)
import pandas as pd
print("  Reading parquet file...")
df = pd.read_parquet(train_parquet)
print(f"✓ Loaded {len(df)} samples from parquet")

# Select subset for demo
df_subset = df.head(NUM_TRAIN_SAMPLES)
print(f"✓ Selected {len(df_subset)} samples for training")

# Convert to HuggingFace dataset
train_data = HFDataset.from_pandas(df_subset, preserve_index=False)
print(f"✓ Converted to HuggingFace dataset")

# Inspect first sample
print("\nInspecting first sample...")
first_sample = train_data[0]
print(f"  Keys: {list(first_sample.keys())}")

# Check what label column is called
label_key = None
for key in first_sample.keys():
    if 'label' in key.lower() or 'key' in key.lower():
        label_key = key
        print(f"  Found label column: '{label_key}' = {first_sample[label_key]}")
        break

if 'audio' in first_sample:
    audio_data = first_sample['audio']
    print(f"  Audio type: {type(audio_data)}")
    if isinstance(audio_data, dict):
        print(f"  Audio dict keys: {list(audio_data.keys())}")
        for key in audio_data.keys():
            val = audio_data[key]
            print(f"    '{key}': type={type(val)}, value={str(val)[:100]}...")
    else:
        print(f"  Audio value type: {type(audio_data)}")
        if hasattr(audio_data, 'shape'):
            print(f"  Audio shape: {audio_data.shape}")
        else:
            print(f"  Audio value: {str(audio_data)[:200]}...")

class ASVspoofDataset(Dataset):
    """Dataset wrapper for ASVspoof 2019 LA"""
    
    def __init__(self, hf_dataset):
        self.dataset = hf_dataset
        
        # Auto-detect label column name
        sample_keys = list(hf_dataset[0].keys())
        self.label_key = None
        for key in sample_keys:
            if 'label' in key.lower() or 'key' in key.lower():
                self.label_key = key
                break
        
        if self.label_key is None:
            print(f"  Warning: No label column found. Available columns: {sample_keys}")
            print(f"  Using first column as label: '{sample_keys[0]}'")
            self.label_key = sample_keys[0]
        
        print(f"\n✓ ASVspoofDataset initialized with {len(self.dataset)} samples")
        print(f"  Using label column: '{self.label_key}'")
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        item = self.dataset[idx]
        
        # Extract audio array - handle multiple possible formats
        audio_array = None
        sample_rate = 16000
        
        if 'audio' in item:
            audio_data = item['audio']
            
            if isinstance(audio_data, dict):
                # Audio is a dict - check for different possible keys
                if 'array' in audio_data:
                    audio_array = audio_data['array']
                    sample_rate = audio_data.get('sampling_rate', 16000)
                elif 'bytes' in audio_data:
                    # Encoded audio bytes - need to decode
                    import io
                    import soundfile as sf
                    audio_bytes = audio_data['bytes']
                    audio_array, sample_rate = sf.read(io.BytesIO(audio_bytes))
                elif 'path' in audio_data:
                    # File path - need to load
                    import soundfile as sf
                    audio_array, sample_rate = sf.read(audio_data['path'])
                else:
                    raise ValueError(f"Sample {idx}: Audio dict has unexpected keys: {list(audio_data.keys())}")
            
            elif hasattr(audio_data, 'shape'):
                # Audio is directly a numpy array
                audio_array = audio_data
            
            elif isinstance(audio_data, (list, tuple)):
                # Audio is a list/tuple
                import numpy as np
                audio_array = np.array(audio_data)
            
            else:
                raise ValueError(f"Sample {idx}: Unexpected audio type: {type(audio_data)}")
        
        else:
            raise ValueError(f"Sample {idx}: No 'audio' key in item. Keys: {list(item.keys())}")
        
        if audio_array is None:
            raise ValueError(f"Sample {idx}: Could not extract audio array")
        
        # Convert to torch tensor and ensure 2D [channels, samples]
        audio = torch.tensor(audio_array, dtype=torch.float32)
        if audio.ndim == 1:
            audio = audio.unsqueeze(0)  # Add channel dimension
        elif audio.ndim > 2:
            raise ValueError(f"Sample {idx}: Unexpected audio dimensions: {audio.shape}")
        
        # Resample if needed (ASVspoof is 16kHz, keep it at 16kHz)
        # Model expects 48000 samples at 16kHz = 3 seconds
        target_length = 48000
        
        # Pad or trim to target length
        if audio.shape[1] < target_length:
            # Pad with zeros
            audio = torch.nn.functional.pad(audio, (0, target_length - audio.shape[1]))
        elif audio.shape[1] > target_length:
            # Trim to target length
            audio = audio[:, :target_length]
        
        # Get label - handle different possible values
        label_value = item[self.label_key]
        if isinstance(label_value, str):
            # String label: bonafide=0, spoof/fake=1
            label = 0 if 'bonafide' in label_value.lower() else 1
        else:
            # Already numeric
            label = int(label_value)
        
        return {
            "audio": audio,
            "label": label,
            "sample_rate": sample_rate,
        }

# Create dataset
test_dataset = ASVspoofDataset(train_data)

# Test first item
print("\nTesting dataset __getitem__...")
sample = test_dataset[0]
print(f"  Audio shape: {sample['audio'].shape}")
print(f"  Label: {sample['label']}")
print(f"  Sample rate: {sample['sample_rate']}")

# Create dataloader
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)

print(f"\n✓ DataLoader created: {len(test_dataloader)} batches")

# Test a batch
print("\nTesting batch loading...")
test_batch = next(iter(test_dataloader))
print(f"  Batch audio shape: {test_batch['audio'].shape}")
print(f"  Batch labels shape: {test_batch['label'].shape}")
print(f"  Batch labels: {test_batch['label']}")

print("\n" + "=" * 80)
print("✓ Dataset loading complete!")
print("=" * 80)

Loading ASVspoof 2019 LA dataset from HuggingFace...

Locating downloaded parquet files...
✓ Found cache directory: /Users/arjunjindal/.cache/huggingface/hub/datasets--Bisher--ASVspoof_2019_LA/snapshots/aea92dd83a9c56e070c0b1e9f02e7c0d96216a4c/data

Loading from parquet files using pandas...
  Resolved parquet file: b4eea1063bbcfa0c1cef1b69a96ad8b787c32f662005562b899cd4b461739619
  Reading parquet file...
✓ Loaded 25380 samples from parquet
✓ Selected 20 samples for training
✓ Converted to HuggingFace dataset

Inspecting first sample...
  Keys: ['speaker_id', 'audio_file_name', 'audio', 'system_id', 'key']
  Found label column: 'key' = 0
  Audio type: <class 'dict'>
  Audio dict keys: ['bytes', 'path']
    'bytes': type=<class 'bytes'>, value=b'fLaC\x00\x00\x00"\x04\x80\x04\x80\x00\x00\x1a\x00\x07\xe2\x03\xe8\x00\xf0\x00\x00\xd8!\x91\x97\xff...
    'path': type=<class 'str'>, value=LA_T_1138215.flac...

✓ ASVspoofDataset initialized with 20 samples
  Using label column: 'key'

Testing 

## Setup Training

In [128]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from callbacks import EER_Callback, BinaryAUC_Callback, BinaryACC_Callback

# Auto-detect GPU or CPU
if torch.cuda.is_available():
    accelerator = "gpu"
    devices = 1
    print("✓ Using GPU acceleration")
else:
    accelerator = "cpu"
    devices = "auto"
    print("✓ Using CPU")

trainer = Trainer(
    logger=CSVLogger(save_dir="./logs", version=None),
    max_epochs=NUM_EPOCHS,
    accelerator=accelerator,
    devices=devices,
    callbacks=[
        BinaryACC_Callback(batch_key="label", output_key="logit"),
        BinaryAUC_Callback(batch_key="label", output_key="logit"),
        EER_Callback(batch_key="label", output_key="logit"),
    ],
    enable_progress_bar=True,
)

print(f"\nTraining configuration:")
print(f"  Accelerator: {accelerator}")
print(f"  Max epochs: {NUM_EPOCHS}")
print(f"  Log directory: {trainer.logger.log_dir}")

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


✓ Using CPU

Training configuration:
  Accelerator: cpu
  Max epochs: 4
  Log directory: ./logs/lightning_logs/version_11


## Train Model

In [129]:
print("=" * 80)
print("Starting training...")
print("=" * 80)

trainer.fit(audio_model_lit, test_dataloader)

print("\n" + "=" * 80)
print("✓ Training completed!")
print("=" * 80)


  | Name          | Type                    | Params | Mode 
------------------------------------------------------------------
0 | model         | Phoneme_GAT             | 196 M  | train
1 | bce_loss      | BCEWithLogitsLoss       | 0      | train
2 | ce_loss       | CrossEntropyLoss        | 0      | train
3 | contrast_loss | BinaryTokenContrastLoss | 0      | train
4 | clip_head     | Sequential              | 1.2 M  | train
5 | clip_loss     | CLIPLoss1D              | 1      | train
------------------------------------------------------------------
102 M     Trainable params
94.9 M    Non-trainable params
197 M     Total params
790.544   Total estimated model params size (MB)


Starting training...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=4` reached.



✓ Training completed!


## Test Model

In [130]:
print("=" * 80)
print("Testing model...")
print("=" * 80)

results = trainer.test(audio_model_lit, test_dataloader)

print("\n" + "=" * 80)
print("✓ DEMO COMPLETED SUCCESSFULLY!")
print("=" * 80)
print(f"\nResults saved to: {trainer.logger.log_dir}")
print(f"Metrics CSV: {trainer.logger.log_dir}/metrics.csv")
print("\nTest Results:")
for key, value in results[0].items():
    print(f"  {key:20s}: {value:.4f}")
print("=" * 80)

Testing model...


Testing: |          | 0/? [00:00<?, ?it/s]


✓ DEMO COMPLETED SUCCESSFULLY!

Results saved to: ./logs/lightning_logs/version_11
Metrics CSV: ./logs/lightning_logs/version_11/metrics.csv

Test Results:
  test-loss           : 0.7982
  test-cls_loss       : 0.4632
  test-clip_loss      : 0.6701
  test-aug_loss       : 0.0000
  test-acc            : 1.0000
  test-auc            : 0.0000
  test-eer            : 1.0000


## Summary

This notebook demonstrated:

1. ✅ Loading the pretrained phoneme recognition model
2. ✅ Creating the Phoneme_GAT deepfake detection model
3. ✅ Loading the ASVspoof 2019 LA dataset from HuggingFace
4. ✅ Setting up PyTorch Lightning training
5. ✅ Training and evaluating the model

### Next Steps

**For RunPod deployment:**
1. Upload this notebook and all code to RunPod
2. Run `bash setup_runpod.sh` to install dependencies
3. Increase `NUM_EPOCHS` and `NUM_TRAIN_SAMPLES` for production training

**For local use:**
- Metrics are saved in the logs directory
- View training progress: `cat logs/lightning_logs/version_X/metrics.csv`
- Best checkpoint is saved automatically

**Configuration for full training:**
```python
NUM_EPOCHS = 20
BATCH_SIZE = 16  # Adjust based on GPU memory
NUM_TRAIN_SAMPLES = len(dataset['train'])  # Use full dataset
```