# IDEAW Training on Google Colab

This notebook trains IDEAW audio watermarking models using Colab's free GPU.

**Before running:**
1. Enable GPU: Runtime → Change runtime type → GPU
2. Upload your data to Google Drive under `MyDrive/IDEAW_Research/data/`, with `train/` and `val/` subfolders
3. Update the GitHub URL below with your repository

## 1. Setup Environment

In [None]:
# Attach Google Drive and define the project folders used by later cells.
import os

from google.colab import drive

drive.mount('/content/drive')

# All persistent inputs and outputs live under one Drive folder.
DRIVE_PATH = '/content/drive/MyDrive/IDEAW_Research'
DATA_PATH = os.path.join(DRIVE_PATH, 'data')
CHECKPOINT_PATH = os.path.join(DRIVE_PATH, 'checkpoints')
RESULTS_PATH = os.path.join(DRIVE_PATH, 'results')

# Only the output folders are created here; the data folder is expected to
# be uploaded by the user.
for output_dir in (CHECKPOINT_PATH, RESULTS_PATH):
    os.makedirs(output_dir, exist_ok=True)

print("✓ Google Drive mounted")
for label, location in (
    ("Data", DATA_PATH),
    ("Checkpoint", CHECKPOINT_PATH),
    ("Results", RESULTS_PATH),
):
    print(f"✓ {label} path: {location}")

In [None]:
# Clone your repository
GITHUB_URL = "https://github.com/Abdullahyassir007/audio-watermarking-demo.git"

# Remove existing directory if present
!rm -rf audio-watermarking-demo

# Clone
!git clone {GITHUB_URL}
%cd audio-watermarking-demo

print("✓ Repository cloned")
print(f"✓ Working directory: {!pwd}")

In [None]:
# Install dependencies from IDEAW requirements
!pip install -q -r research/IDEAW/requirements.txt
!pip install -q FrEIA

print("✓ Dependencies installed")

In [None]:
# Choose the compute device and report GPU details when one is present.
import torch

has_gpu = torch.cuda.is_available()
print(f"GPU Available: {has_gpu}")

# `device` is read by the model, configuration and test cells further down.
device = 'cuda' if has_gpu else 'cpu'

if not has_gpu:
    print("⚠️ No GPU available, using CPU")
else:
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    gpu_properties = torch.cuda.get_device_properties(0)
    print(f"GPU Memory: {gpu_properties.total_memory / 1e9:.2f} GB")

print(f"\n✓ Using device: {device}")

## 2. Load IDEAW Model

In [None]:
# Make the IDEAW sources importable, build the model, and report its size.
import sys

ideaw_root = '/content/audio-watermarking-demo/research/IDEAW'
sys.path.insert(0, ideaw_root)

from models.ideaw import IDEAW

# Two YAML files: one at the IDEAW root and one next to the model code.
config_path = f'{ideaw_root}/config.yaml'
model_config_path = f'{ideaw_root}/models/config.yaml'

# Build the model from the model-level config on the selected device.
ideaw = IDEAW(model_config_path, device)
print("✓ IDEAW model initialized")

# Tally all parameters and the trainable subset in a single pass.
total_params = 0
trainable_params = 0
for param in ideaw.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 3. Prepare Data

In [None]:
# Mirror the Drive dataset onto the Colab VM disk and report how many
# entries each split contains.
import shutil

LOCAL_DATA_PATH = '/content/data'

if not os.path.exists(DATA_PATH):
    print(f"⚠️ Data not found at {DATA_PATH}")
    print("Please upload your training data to Google Drive first.")
else:
    print("Copying data from Drive to local storage...")
    # copytree needs a fresh destination, so drop any previous local copy.
    if os.path.exists(LOCAL_DATA_PATH):
        shutil.rmtree(LOCAL_DATA_PATH)
    shutil.copytree(DATA_PATH, LOCAL_DATA_PATH)
    print(f"✓ Data copied to {LOCAL_DATA_PATH}")

    # Count directory entries per split; a missing split counts as zero.
    split_counts = {}
    for split_name in ('train', 'val'):
        split_dir = f'{LOCAL_DATA_PATH}/{split_name}'
        if os.path.exists(split_dir):
            split_counts[split_name] = len(os.listdir(split_dir))
        else:
            split_counts[split_name] = 0

    train_files = split_counts['train']
    val_files = split_counts['val']
    print(f"Training files: {train_files}")
    print(f"Validation files: {val_files}")

## 4. Training Configuration

In [None]:
# Hyperparameters for the training run.
BATCH_SIZE = 16
NUM_EPOCHS = 100
LEARNING_RATE = 1e-5
SAVE_EVERY = 10  # checkpoint interval, in epochs

# NOTE(review): in this notebook only NUM_EPOCHS and SAVE_EVERY are passed to
# solver.train(); BATCH_SIZE and LEARNING_RATE are only printed here. Confirm
# whether the solver reads them from its YAML config instead.
config_summary = [
    "Training Configuration:",
    f"  Batch size: {BATCH_SIZE}",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Learning rate: {LEARNING_RATE}",
    f"  Device: {device}",
    f"  Save frequency: Every {SAVE_EVERY} epochs",
]
print("\n".join(config_summary))

## 5. Train Model

In [None]:
# Initialize solver
# `Solver` was referenced without ever being imported in this notebook, so
# this cell raised NameError.
# NOTE(review): assumes the class is defined in research/IDEAW/solver.py (that
# directory is put on sys.path by the model-loading cell). Confirm the module
# name, and that the keyword arguments below match Solver.__init__.
from solver import Solver

solver = Solver(
    config_path=config_path,
    model=ideaw,
    device=device
)

print("✓ Solver initialized")
print("\nStarting training...")
print("=" * 50)

In [None]:
# Run training and report the elapsed wall-clock time.
import time
from tqdm import tqdm  # imported by the original cell; not referenced here

# Arguments forwarded to the solver: Drive folders for outputs plus the
# epoch budget and checkpoint interval from the configuration cell.
train_options = dict(
    save_path=CHECKPOINT_PATH,
    log_path=RESULTS_PATH,
    num_epochs=NUM_EPOCHS,
    save_every=SAVE_EVERY,
)
separator = "=" * 50

start_time = time.time()

try:
    solver.train(**train_options)

    training_time = time.time() - start_time
    print("\n" + separator)
    print("✓ Training completed!")
    print(f"Total training time: {training_time / 3600:.2f} hours")

except KeyboardInterrupt:
    # Manual stop from the Colab UI: report it without a traceback.
    print("\n⚠️ Training interrupted by user")
    print("Checkpoints have been saved.")

except Exception as e:
    # Keep the notebook alive but show the full traceback for debugging.
    print(f"\n❌ Training error: {e}")
    import traceback
    traceback.print_exc()

## 6. Visualize Training Results

In [None]:
# Plot the metrics recorded in training_log.csv and print the final values.
import matplotlib.pyplot as plt
import pandas as pd

log_file = f'{RESULTS_PATH}/training_log.csv'

# (column, y-axis label, panel title) for the three always-drawn panels.
# The log is expected to contain 'epoch' plus these columns; a missing column
# still raises KeyError, as it did before.
METRIC_PANELS = [
    ('loss', 'Loss', 'Training Loss'),
    ('snr', 'SNR (dB)', 'Signal-to-Noise Ratio'),
    ('accuracy', 'Accuracy (%)', 'Watermark Accuracy'),
]

df = None
if os.path.exists(log_file):
    try:
        df = pd.read_csv(log_file)
    except pd.errors.EmptyDataError:
        # A zero-byte log (training stopped before the header was written).
        df = pd.DataFrame()

if df is None:
    print("⚠️ No training log found")
elif df.empty:
    # Without rows there is nothing to plot and `.iloc[-1]` would raise.
    print(f"⚠️ Training log at {log_file} contains no rows yet")
else:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # axes.flat walks the grid row by row: loss, SNR, accuracy.
    for ax, (column, ylabel, title) in zip(axes.flat, METRIC_PANELS):
        ax.plot(df['epoch'], df[column])
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True)

    # The learning-rate panel is optional; hide it rather than leaving an
    # empty set of axes in the saved figure.
    lr_ax = axes[1, 1]
    if 'learning_rate' in df.columns:
        lr_ax.plot(df['epoch'], df['learning_rate'])
        lr_ax.set_xlabel('Epoch')
        lr_ax.set_ylabel('Learning Rate')
        lr_ax.set_title('Learning Rate Schedule')
        lr_ax.set_yscale('log')
        lr_ax.grid(True)
    else:
        lr_ax.set_visible(False)

    plt.tight_layout()
    curves_path = f'{RESULTS_PATH}/training_curves.png'
    plt.savefig(curves_path, dpi=300, bbox_inches='tight')
    plt.show()

    print("✓ Training curves saved to:", curves_path)

    # Print final metrics (last row of the log)
    print("\nFinal Metrics:")
    print(f"  Loss: {df['loss'].iloc[-1]:.4f}")
    print(f"  SNR: {df['snr'].iloc[-1]:.2f} dB")
    print(f"  Accuracy: {df['accuracy'].iloc[-1]:.2f}%")

## 7. Test Trained Model

In [None]:
# Load the best checkpoint and run one embed → extract round trip on a
# sample clip, reporting bit accuracy for the message and the location code.
best_checkpoint = f'{CHECKPOINT_PATH}/best_model.pth'

if os.path.exists(best_checkpoint):
    print("Loading best model...")
    # map_location remaps the saved tensors onto the active device, so a
    # checkpoint written during a GPU session still loads when this session
    # is CPU-only (without it torch.load fails on CUDA tensors).
    checkpoint = torch.load(best_checkpoint, map_location=device)
    ideaw.load_state_dict(checkpoint['model_state_dict'])
    ideaw.eval()
    print("✓ Best model loaded")

    # Test on sample audio
    import librosa
    import numpy as np

    # Load test audio
    test_audio_path = f'{LOCAL_DATA_PATH}/val/test_audio.wav'  # Update with your test file

    if os.path.exists(test_audio_path):
        # Resample to 16 kHz and add a batch dimension of 1.
        audio, sr = librosa.load(test_audio_path, sr=16000)
        audio_tensor = torch.FloatTensor(audio).unsqueeze(0).to(device)

        # Generate random message and location code as 0/1 float tensors.
        # NOTE(review): the lengths (16 and 10 bits) are hard-coded here;
        # confirm they match the sizes in the model config.
        message = torch.randint(0, 2, (1, 16), dtype=torch.float32).to(device)
        lcode = torch.randint(0, 2, (1, 10), dtype=torch.float32).to(device)

        with torch.no_grad():
            # Embed: message first, then the location code on top of it.
            audio_wmd1, _ = ideaw.embed_msg(audio_tensor, message)
            audio_wmd2, _ = ideaw.embed_lcode(audio_wmd1, lcode)

            # Extract in reverse order: location code, then message.
            mid_stft, lcode_extracted = ideaw.extract_lcode(audio_wmd2)
            message_extracted = ideaw.extract_msg(mid_stft)

            # Bit accuracy in percent.
            # NOTE(review): thresholding at 0.5 presumes the model emits
            # values on a 0..1 scale. Confirm against the bit encoding used
            # during training.
            msg_acc = ((message_extracted > 0.5).float() == message).float().mean().item() * 100
            lcode_acc = ((lcode_extracted > 0.5).float() == lcode).float().mean().item() * 100

            print(f"\nTest Results:")
            print(f"  Message accuracy: {msg_acc:.2f}%")
            print(f"  Location code accuracy: {lcode_acc:.2f}%")
    else:
        print(f"⚠️ Test audio not found at {test_audio_path}")
else:
    print(f"⚠️ Checkpoint not found at {best_checkpoint}")

## 8. Download Results

In [None]:
# Zip checkpoints and results
!zip -r checkpoints.zip {CHECKPOINT_PATH}
!zip -r results.zip {RESULTS_PATH}

print("✓ Files zipped")
print("\nYou can download:")
print("  1. checkpoints.zip - Trained model weights")
print("  2. results.zip - Training logs and plots")
print("\nOr access them directly from Google Drive at:")
print(f"  {DRIVE_PATH}")

In [None]:
# Optional: Download directly from Colab
# The two archive names below are the ones written to the working directory
# by the zip cell, so run that cell first.
from google.colab import files

# Uncomment to download
# files.download('checkpoints.zip')
# files.download('results.zip')

## 9. Push Code Updates to GitHub (Optional)

In [None]:
# If you made code changes in Colab, push them back to GitHub

# Configure git (first time only)
# NOTE: the e-mail and name below are placeholders; replace them with your
# own identity before committing. `--global` applies to every repository on
# this Colab VM.
!git config --global user.email "your.email@example.com"
!git config --global user.name "Your Name"

# Check what changed
!git status

# Add, commit, and push (uncomment to use)
# Review the `git status` output first: `git add .` stages everything in the
# working tree, including any archives written there by the zip cell.
# !git add .
# !git commit -m "Updated training code from Colab"
# !git push

print("\nNote: You'll need to authenticate with GitHub token if pushing")
print("Generate token at: https://github.com/settings/tokens")

## 10. Pull Latest Code Updates (Optional)

In [None]:
# If you updated code on your local machine, pull latest changes
!git pull origin main

print("✓ Code updated from GitHub")

## 11. Keep Session Alive (Optional)

Paste this JavaScript into your browser's developer console (on the Colab tab) to click the connect button every 60 seconds and reduce idle disconnections:

```javascript
// Click Colab's connect button once a minute so the session is not treated
// as idle.
function KeepAlive() {
    console.log("Keeping session alive...");
    // querySelector returns null when no matching element exists (for example
    // if the page markup changes); skip the click instead of throwing a
    // TypeError on every tick.
    const connectButton = document.querySelector("colab-connect-button");
    if (connectButton) {
        connectButton.click();
    } else {
        console.warn("colab-connect-button not found; nothing clicked");
    }
}
// 60000 ms = one call per minute.
setInterval(KeepAlive, 60000);
```