# IDEAW Training on Google Colab

This notebook trains IDEAW audio watermarking models using Colab's free GPU.

**Before running:**
1. Enable GPU: Runtime → Change runtime type → GPU
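2. Upload your data to Google Drive (this notebook expects `Dataset/train` and `Dataset/val` inside the project folder)
3. Clone your repository to `MyDrive/audio-watermarking-demo` in Google Drive, or update the paths below to match its location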

## 1. Setup Environment

In [3]:
from google.colab import drive
drive.mount('/content/drive')


%cd /content/drive/MyDrive/audio-watermarking-demo


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/audio-watermarking-demo


In [2]:
!git status


On branch main
Your branch is up to date with 'origin/main'.

Changes not staged for commit:
  (use "git add <file>..." to update what will be committed)
  (use "git restore <file>..." to discard changes in working directory)
	modified:   colab_notebooks/IDEAW_Training_Template.ipynb

no changes added to commit (use "git add" and/or "git commit -a")


In [35]:
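# Undo the last 5 local commits but keep their changes staged (rewrites local history; use with care)
!git reset --soft HEAD~5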


In [36]:
# Stage the updated notebook
!git add colab_notebooks/IDEAW_Training_Template.ipynb

# Commit with a message
!git commit -m "Updated IDEAW training notebook from Colab"

# Push to GitHub
!git push origin main

[main 6e69718] Updated IDEAW training notebook from Colab
 1 file changed, 1 insertion(+), 567 deletions(-)
 rewrite colab_notebooks/IDEAW_Training_Template.ipynb (100%)
Enumerating objects: 7, done.
Counting objects: 100% (7/7), done.
Delta compression using up to 2 threads
Compressing objects: 100% (4/4), done.
Writing objects: 100% (4/4), 6.74 KiB | 690.00 KiB/s, done.
Total 4 (delta 1), reused 0 (delta 0), pack-reused 0
remote: Resolving deltas: 100% (1/1), completed with 1 local object.
remote: This repository moved. Please use the new location:
remote:   https://github.com/Abdullahyassir007/audio-watermarking-demo.git
To https://github.com/AbdullahYassir007/audio-watermarking-demo.git
   25b3e40..6e69718  main -> main


In [37]:
# !git checkout -- colab_notebooks/IDEAW_Training_Template.ipynb
!git pull origin main

From https://github.com/AbdullahYassir007/audio-watermarking-demo
 * branch            main       -> FETCH_HEAD
Already up to date.


In [4]:


# Set up paths
DRIVE_PATH = '/content/drive/MyDrive/audio-watermarking-demo'
DATA_PATH = f'{DRIVE_PATH}/Dataset'
CHECKPOINT_PATH = f'{DRIVE_PATH}/checkpoints'
RESULTS_PATH = f'{DRIVE_PATH}/results'

# Create directories
import os
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
os.makedirs(RESULTS_PATH, exist_ok=True)

print("✓ Google Drive mounted")
print(f"✓ Data path: {DATA_PATH}")
print(f"✓ Checkpoint path: {CHECKPOINT_PATH}")
print(f"✓ Results path: {RESULTS_PATH}")

✓ Google Drive mounted
✓ Data path: /content/drive/MyDrive/audio-watermarking-demo/Dataset
✓ Checkpoint path: /content/drive/MyDrive/audio-watermarking-demo/checkpoints
✓ Results path: /content/drive/MyDrive/audio-watermarking-demo/results


In [5]:
# Install dependencies from IDEAW requirements
!pip install -q -r research/IDEAW/requirements_colab.txt
!pip install -q FrEIA

print("✓ Dependencies installed")

✓ Dependencies installed


In [6]:
# Check GPU availability
import torch

print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    device = 'cuda'
else:
    print("⚠️ No GPU available, using CPU")
    device = 'cpu'

print(f"\n✓ Using device: {device}")

GPU Available: True
GPU Name: Tesla T4
GPU Memory: 15.83 GB

✓ Using device: cuda


In [None]:
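# Install only the missing packages and keep Colab's preinstalled PyTorch
!pip install -q librosa==0.10.1 pydub PyYAML soundfile tqdm resampy

# Restart the runtime: killing the kernel process makes Colab restart it,
# so the newly installed packages are picked up
import os
os.kill(os.getpid(), 9)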



In [7]:
# Verify installation
import torch
import librosa
import scipy
import numpy as np
import yaml

print("=" * 50)
print("ENVIRONMENT CHECK")
print("=" * 50)
print(f"✓ PyTorch: {torch.__version__}")
print(f"✓ Librosa: {librosa.__version__}")
print(f"✓ Scipy: {scipy.__version__}")
print(f"✓ Numpy: {np.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
print("=" * 50)


ENVIRONMENT CHECK
✓ PyTorch: 2.8.0+cu126
✓ Librosa: 0.10.1
✓ Scipy: 1.11.4
✓ Numpy: 1.26.4
✓ CUDA available: True
✓ GPU: Tesla T4


In [4]:
# Fix numpy compatibility issue
!pip uninstall -y numpy
!pip install numpy==1.26.2

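# Reinstall packages that depend on numpy.
# --force-reinstall also reinstalls their dependencies at the latest versions,
# so numpy is pinned on each line to stop it from being upgraded to 2.x again.
!pip install --force-reinstall --no-cache-dir librosa==0.10.1 numpy==1.26.2
!pip install --force-reinstall --no-cache-dir scipy==1.11.4 numpy==1.26.2
!pip install --force-reinstall --no-cache-dir resampy==0.4.2 numpy==1.26.2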

# Restart the runtime after this
print("✓ Packages reinstalled. Please restart runtime:")
print("  Runtime → Restart runtime")
print("Then re-run your cells from the beginning.")


Found existing installation: numpy 1.26.4
Uninstalling numpy-1.26.4:
  Successfully uninstalled numpy-1.26.4
Collecting numpy==1.26.2
  Using cached numpy-1.26.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.9 MB)
Installing collected packages: numpy
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pytensor 2.35.1 requires numpy>=2.0, but you have numpy 1.26.2 which is incompatible.
moviepy 1.0.3 requires decorator<5.0,>=4.0.2, but you have decorator 5.2.1 which is incompatible.
cuml-cu12 25.10.0 requires numba<0.62.0a0,>=0.60.0, but you have numba 0.62.1 which is incompatible.
opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.2 which is incompatible.
jax 0.7.2 requires numpy>=2.0, bu

Collecting librosa==0.10.1
  Downloading librosa-0.10.1-py3-none-any.whl.metadata (8.3 kB)
Collecting audioread>=2.1.9 (from librosa==0.10.1)
  Downloading audioread-3.1.0-py3-none-any.whl.metadata (9.0 kB)
Collecting numpy!=1.22.0,!=1.22.1,!=1.22.2,>=1.20.3 (from librosa==0.10.1)
  Downloading numpy-2.3.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)
Collecting scipy>=1.2.0 (from librosa==0.10.1)
  Downloading scipy-1.16.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (62 kB)
Collecting scikit-learn>=0.20.0 (from librosa==0.10.1)
  Downloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting joblib>=0.14 (from librosa==0.10.1)


Collecting scipy==1.11.4
  Downloading scipy-1.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting numpy<1.28.0,>=1.21.6 (from scipy==1.11.4)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Downloading scipy-1.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (35.8 MB)
Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)
Installing collected p

Collecting resampy==0.4.2
  Downloading resampy-0.4.2-py3-none-any.whl.metadata (2.8 kB)
Collecting numpy>=1.17 (from resampy==0.4.2)
  Downloading numpy-2.3.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)
Collecting numba>=0.53 (from resampy==0.4.2)
  Downloading numba-0.62.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)
Collecting llvmlite<0.46,>=0.45.0dev0 (from numba>=0.53->resampy==0.4.2)
  Downloading llvmlite-0.45.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (4.9 kB)
Downloading resampy-0.4.2-py3-none-any.whl (3.1 MB)
Downloading numba-0.62.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.8 MB)

✓ Packages reinstalled. Please restart runtime:
  Runtime → Restart runtime
Then re-run your cells from the beginning.
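The log above shows `--force-reinstall` pulling `numpy-2.3.4` back in as a dependency of librosa and resampy, so after restarting it is worth confirming which versions actually ended up installed. A minimal check using only the standard library's `importlib.metadata` (the `EXPECTED` table just restates the pins from the cell above):

```python
from importlib.metadata import PackageNotFoundError, version

# Versions this notebook expects after the reinstall cell (copied from its pins).
EXPECTED = {
    "numpy": "1.26.2",
    "librosa": "0.10.1",
    "scipy": "1.11.4",
    "resampy": "0.4.2",
}


def check_pins(expected):
    """Return {package: (installed_version_or_None, matches_pin)}."""
    report = {}
    for name, pinned in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        report[name] = (installed, installed == pinned)
    return report


for pkg, (installed, ok) in check_pins(EXPECTED).items():
    status = "✓" if ok else "✗"
    print(f"{status} {pkg}: installed={installed}, expected={EXPECTED[pkg]}")
```

A `✗` on numpy after a restart means a later install upgraded it again.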


## 2. Load IDEAW Model

In [8]:
# Import IDEAW
import sys
sys.path.insert(0, '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW')

from models.ideaw import IDEAW

# Configuration
config_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW/config.yaml'
model_config_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW/models/config.yaml'

# Initialize model
ideaw = IDEAW(model_config_path, device)
print("✓ IDEAW model initialized")

# Count parameters
total_params = sum(p.numel() for p in ideaw.parameters())
trainable_params = sum(p.numel() for p in ideaw.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

✓ IDEAW model initialized
Total parameters: 8,425,023
Trainable parameters: 8,425,023
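The parameter count gives a quick estimate of checkpoint size: in float32 each parameter takes 4 bytes, so the weights alone come to about 8,425,023 × 4 ≈ 33.7 MB. If the Solver also saves Adam optimizer state (an assumption; this notebook does not show which optimizer it uses), Adam's two moment buffers add two more copies of that size. A small sketch of the arithmetic:

```python
def param_memory_mb(num_params, bytes_per_param=4, copies=1):
    """Approximate storage in MB (10**6 bytes) for `copies` tensors the size of the model."""
    return num_params * bytes_per_param * copies / 1e6


TOTAL_PARAMS = 8_425_023  # from the output above

print(f"Weights only (float32): {param_memory_mb(TOTAL_PARAMS):.1f} MB")
# Weights + Adam's exp_avg and exp_avg_sq buffers (assumes Adam is the optimizer)
print(f"Weights + Adam moments: {param_memory_mb(TOTAL_PARAMS, copies=3):.1f} MB")
```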


## 3. Prepare Data

In [None]:
# Copy data from Drive to local storage (faster training)
import shutil

LOCAL_DATA_PATH = '/content/data'

if os.path.exists(DATA_PATH):
    print("Copying data from Drive to local storage...")
    if os.path.exists(LOCAL_DATA_PATH):
        shutil.rmtree(LOCAL_DATA_PATH)
    shutil.copytree(DATA_PATH, LOCAL_DATA_PATH)
    print(f"✓ Data copied to {LOCAL_DATA_PATH}")

    # Count files
    train_files = len(os.listdir(f'{LOCAL_DATA_PATH}/train')) if os.path.exists(f'{LOCAL_DATA_PATH}/train') else 0
    val_files = len(os.listdir(f'{LOCAL_DATA_PATH}/val')) if os.path.exists(f'{LOCAL_DATA_PATH}/val') else 0
    print(f"Training files: {train_files}")
    print(f"Validation files: {val_files}")
else:
    print(f"⚠️ Data not found at {DATA_PATH}")
    print("Please upload your training data to Google Drive first.")
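`os.listdir` counts every entry, including subfolders and stray non-audio files. A slightly stricter count walks each split recursively and keeps only audio files. The extension set below is an assumption; adjust it to whatever your `Dataset` folder actually contains:

```python
import os

# Extensions treated as audio (assumption; edit to match your dataset).
AUDIO_EXTS = {".wav", ".flac", ".mp3"}
LOCAL_DATA_PATH = "/content/data"  # same local copy as in the cell above


def count_audio_files(root):
    """Recursively count files under `root` whose extension is in AUDIO_EXTS."""
    if not os.path.isdir(root):
        return 0
    total = 0
    for _dirpath, _dirnames, filenames in os.walk(root):
        total += sum(
            1 for f in filenames if os.path.splitext(f)[1].lower() in AUDIO_EXTS
        )
    return total


for split in ("train", "val"):
    n = count_audio_files(os.path.join(LOCAL_DATA_PATH, split))
    print(f"{split}: {n} audio files")
```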

## 4. Training Configuration

In [None]:
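# Training hyperparameters
# NOTE: BATCH_SIZE and LEARNING_RATE are only printed here; they are not passed
# to the Solver below, which presumably reads them from config.yaml.
BATCH_SIZE = 16
NUM_EPOCHS = 100
LEARNING_RATE = 1e-5
SAVE_EVERY = 10  # Save checkpoint every N epochs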

print("Training Configuration:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Device: {device}")
print(f"  Save frequency: Every {SAVE_EVERY} epochs")
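These settings translate into a concrete amount of work. The sketch below assumes a PyTorch-style data loader (the partial last batch is kept unless `drop_last=True`) and that the Solver saves when `epoch % SAVE_EVERY == 0` with 1-indexed epochs; both are assumptions about code this notebook does not show, and the 1000-clip dataset size is hypothetical:

```python
import math


def steps_per_epoch(num_clips, batch_size, drop_last=False):
    """Number of optimizer steps in one epoch for a dataset of `num_clips` items."""
    if drop_last:
        return num_clips // batch_size
    return math.ceil(num_clips / batch_size)


def checkpoint_epochs(num_epochs, save_every):
    """1-indexed epochs at which a checkpoint is written when saving on epoch % save_every == 0."""
    return [e for e in range(1, num_epochs + 1) if e % save_every == 0]


# Hypothetical 1000 training clips with BATCH_SIZE = 16, NUM_EPOCHS = 100, SAVE_EVERY = 10
print(f"Steps per epoch: {steps_per_epoch(1000, 16)}")
print(f"Total steps: {steps_per_epoch(1000, 16) * 100}")
print(f"Checkpoint epochs: {checkpoint_epochs(100, 10)}")
```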

## 5. Train Model

In [None]:
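# Initialize solver
# Assumes the Solver class is defined in research/IDEAW/solver.py (already on
# sys.path from the model cell); adjust the import if it lives elsewhere.
from solver import Solver

solver = Solver(
    config_path=config_path,
    model=ideaw,
    device=device
)

print("✓ Solver initialized")
print("\nRun the next cell to start training.")
print("=" * 50)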

In [None]:
# Training loop
import time

start_time = time.time()

try:
    solver.train(
        save_path=CHECKPOINT_PATH,
        log_path=RESULTS_PATH,
        num_epochs=NUM_EPOCHS,
        save_every=SAVE_EVERY
    )

    training_time = time.time() - start_time
    print("\n" + "=" * 50)
    print("✓ Training completed!")
    print(f"Total training time: {training_time / 3600:.2f} hours")

except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user")
    print(f"Checkpoints written before the interruption are in {CHECKPOINT_PATH}")

except Exception as e:
    print(f"\n❌ Training error: {e}")
    import traceback
    traceback.print_exc()
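Colab sessions can disconnect mid-run, so it helps to locate the most recent periodic checkpoint before resuming. This sketch assumes a hypothetical `epoch_<N>.pth` naming scheme (`best_model.pth` is the only checkpoint name this notebook actually references), so change the pattern to whatever the Solver writes:

```python
import os
import re

# Hypothetical naming scheme: periodic checkpoints saved as "epoch_<N>.pth".
CKPT_PATTERN = re.compile(r"^epoch_(\d+)\.pth$")


def latest_checkpoint(ckpt_dir):
    """Return (epoch, path) for the highest-numbered checkpoint, or None if there is none."""
    if not os.path.isdir(ckpt_dir):
        return None
    best = None
    for name in os.listdir(ckpt_dir):
        match = CKPT_PATTERN.match(name)
        if match:
            epoch = int(match.group(1))  # compare numerically, not as strings
            if best is None or epoch > best[0]:
                best = (epoch, os.path.join(ckpt_dir, name))
    return best


# Example: latest_checkpoint(CHECKPOINT_PATH)
```

Comparing the parsed epoch numbers, rather than sorting filenames, avoids `epoch_100.pth` sorting before `epoch_20.pth`.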

## 6. Visualize Training Results

In [None]:
# Plot training curves
import matplotlib.pyplot as plt
import pandas as pd

log_file = f'{RESULTS_PATH}/training_log.csv'

if os.path.exists(log_file):
    df = pd.read_csv(log_file)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss
    axes[0, 0].plot(df['epoch'], df['loss'])
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training Loss')
    axes[0, 0].grid(True)

    # SNR
    axes[0, 1].plot(df['epoch'], df['snr'])
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('SNR (dB)')
    axes[0, 1].set_title('Signal-to-Noise Ratio')
    axes[0, 1].grid(True)

    # Accuracy
    axes[1, 0].plot(df['epoch'], df['accuracy'])
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Accuracy (%)')
    axes[1, 0].set_title('Watermark Accuracy')
    axes[1, 0].grid(True)

    # Learning rate (only if the log records it)
    if 'learning_rate' in df.columns:
        axes[1, 1].plot(df['epoch'], df['learning_rate'])
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Learning Rate')
        axes[1, 1].set_title('Learning Rate Schedule')
        axes[1, 1].set_yscale('log')
        axes[1, 1].grid(True)
    else:
        axes[1, 1].axis('off')  # hide the unused fourth panel

    plt.tight_layout()
    plt.savefig(f'{RESULTS_PATH}/training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("✓ Training curves saved to:", f'{RESULTS_PATH}/training_curves.png')

    # Print final metrics
    print("\nFinal Metrics:")
    print(f"  Loss: {df['loss'].iloc[-1]:.4f}")
    print(f"  SNR: {df['snr'].iloc[-1]:.2f} dB")
    print(f"  Accuracy: {df['accuracy'].iloc[-1]:.2f}%")
else:
    print("⚠️ No training log found")
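The SNR curve measures how much the watermark perturbs the audio: higher values mean a smaller, less audible change. A common definition is `SNR = 10 · log10(Σx² / Σ(x − x̂)²)`, with `x` the original and `x̂` the watermarked signal. Whether the Solver logs exactly this quantity is an assumption, since its code is not shown here:

```python
import math


def snr_db(original, watermarked):
    """SNR in dB: 10*log10(signal energy / energy of the embedding residual)."""
    if len(original) != len(watermarked):
        raise ValueError("signals must have the same length")
    signal_power = sum(x * x for x in original)
    noise_power = sum((x - y) ** 2 for x, y in zip(original, watermarked))
    if noise_power == 0:
        return math.inf  # identical signals: no perturbation at all
    return 10 * math.log10(signal_power / noise_power)


print(f"{snr_db([1.0, 1.0], [1.1, 0.9]):.1f} dB")  # prints 20.0 dB
```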

## 7. Test Trained Model

In [None]:
# Load best checkpoint
best_checkpoint = f'{CHECKPOINT_PATH}/best_model.pth'

if os.path.exists(best_checkpoint):
    print("Loading best model...")
    checkpoint = torch.load(best_checkpoint, map_location=device)
    ideaw.load_state_dict(checkpoint['model_state_dict'])
    ideaw.eval()
    print("✓ Best model loaded")

    # Test on sample audio
    import librosa
    import numpy as np

    # Load test audio
    test_audio_path = f'{LOCAL_DATA_PATH}/val/test_audio.wav'  # Update with your test file

    if os.path.exists(test_audio_path):
        audio, sr = librosa.load(test_audio_path, sr=16000)
        audio_tensor = torch.FloatTensor(audio).unsqueeze(0).to(device)

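        # Generate a random message and location code; the lengths (16 and 10 bits)
        # must match the sizes the model was built with (see models/config.yaml)
        message = torch.randint(0, 2, (1, 16), dtype=torch.float32).to(device)
        lcode = torch.randint(0, 2, (1, 10), dtype=torch.float32).to(device)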

        with torch.no_grad():
            # Embed
            audio_wmd1, _ = ideaw.embed_msg(audio_tensor, message)
            audio_wmd2, _ = ideaw.embed_lcode(audio_wmd1, lcode)

            # Extract
            mid_stft, lcode_extracted = ideaw.extract_lcode(audio_wmd2)
            message_extracted = ideaw.extract_msg(mid_stft)

            # Calculate accuracy
            msg_acc = ((message_extracted > 0.5).float() == message).float().mean().item() * 100
            lcode_acc = ((lcode_extracted > 0.5).float() == lcode).float().mean().item() * 100

            print(f"\nTest Results:")
            print(f"  Message accuracy: {msg_acc:.2f}%")
            print(f"  Location code accuracy: {lcode_acc:.2f}%")
    else:
        print(f"⚠️ Test audio not found at {test_audio_path}")
else:
    print(f"⚠️ Checkpoint not found at {best_checkpoint}")
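The accuracy lines above threshold each extracted value at 0.5 and compare it with the embedded bit. The same logic in plain Python is handy for checking a single example by hand; it assumes, as the cell above does, that scores above 0.5 mean bit 1. If the decoder outputs a different range, the threshold has to change accordingly:

```python
def bit_accuracy(scores, bits, threshold=0.5):
    """Percentage of positions where (score > threshold) matches the reference bit."""
    if len(scores) != len(bits):
        raise ValueError("scores and bits must have the same length")
    correct = sum(int(s > threshold) == int(b) for s, b in zip(scores, bits))
    return 100.0 * correct / len(bits)


print(bit_accuracy([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 0]))  # prints 75.0
```

The bit error rate is simply `100 - bit_accuracy(...)`.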

## 8. Download Results

In [None]:
# Zip checkpoints and results into /content (outside the repository folder,
# so the archives are not picked up by git later)
!cd "{DRIVE_PATH}" && zip -qr /content/checkpoints.zip checkpoints
!cd "{DRIVE_PATH}" && zip -qr /content/results.zip results

print("✓ Files zipped")
print("\nYou can download:")
print("  1. /content/checkpoints.zip - Trained model weights")
print("  2. /content/results.zip - Training logs and plots")
print("\nOr access them directly from Google Drive at:")
print(f"  {DRIVE_PATH}")

In [None]:
# Optional: Download directly from Colab
from google.colab import files

# Uncomment to download
# files.download('/content/checkpoints.zip')
# files.download('/content/results.zip')
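If you prefer not to shell out to `zip`, the standard library's `shutil.make_archive` builds the same kind of archive. In this sketch the `archive_folder` helper name and the `/content` output directory are choices made here, not part of the repository; writing outside the Drive project folder keeps the archive out of any later `git add`:

```python
import os
import shutil


def archive_folder(parent_dir, folder_name, out_dir="/content"):
    """Zip parent_dir/folder_name into out_dir/folder_name.zip, storing paths relative to parent_dir."""
    base_name = os.path.join(out_dir, folder_name)
    # root_dir is the archive's top level; base_dir is the folder stored inside it.
    return shutil.make_archive(base_name, "zip", root_dir=parent_dir, base_dir=folder_name)


# Example: archive_folder(DRIVE_PATH, "checkpoints"); archive_folder(DRIVE_PATH, "results")
```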

## 9. Push Code Updates to GitHub (Optional)

In [None]:
# If you made code changes in Colab, push them back to GitHub

# Configure git (first time only)
!git config --global user.email "your.email@example.com"
!git config --global user.name "Your Name"

# Check what changed
!git status

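# Add, commit, and push (uncomment to use).
# Note: `git add .` also stages checkpoints/, results/, and any zip files inside
# this folder unless .gitignore excludes them; prefer adding specific files.
# !git add .
# !git commit -m "Updated training code from Colab"
# !git push

print("\nNote: pushing requires authenticating with a GitHub personal access token")
print("Generate a token at: https://github.com/settings/tokens")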
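Because `checkpoints/` and `results/` live inside the repository folder, a broad `git add .` can sweep them into a commit. One way to stage only source changes is to parse `git status --porcelain` and skip the artifact folders. This is a sketch; the excluded prefixes are the two folders this notebook creates, and anything else you generate must be added to the tuple:

```python
import subprocess

# Folders this notebook fills with generated artifacts rather than source code.
EXCLUDED_PREFIXES = ("checkpoints/", "results/")


def changed_source_files(porcelain_output, excluded=EXCLUDED_PREFIXES):
    """Paths from `git status --porcelain` output, skipping generated artifact folders."""
    paths = []
    for line in porcelain_output.splitlines():
        if len(line) < 4:
            continue
        path = line[3:]  # drop the two status characters and the separating space
        # Renames are reported as "old -> new"; keep the new path.
        if " -> " in path:
            path = path.split(" -> ", 1)[1]
        if not path.startswith(excluded):
            paths.append(path)
    return paths


def git_status_porcelain(repo_dir="."):
    """Run `git status --porcelain` in repo_dir and return its stdout."""
    result = subprocess.run(
        ["git", "status", "--porcelain"],
        cwd=repo_dir, capture_output=True, text=True, check=True,
    )
    return result.stdout
```

For example, `subprocess.run(["git", "add", "--", *changed_source_files(git_status_porcelain(DRIVE_PATH))], cwd=DRIVE_PATH, check=True)` stages only the returned paths.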

## 10. Pull Latest Code Updates (Optional)

In [None]:
# If you updated code on your local machine, pull latest changes
!git pull origin main

print("✓ Code updated from GitHub")

## 11. Keep Session Alive (Optional)

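Run this JavaScript in your browser's developer console (not in a notebook cell) to reduce disconnections from idle timeouts. It does not extend Colab's maximum session length.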

```javascript
function KeepAlive() {
    console.log("Keeping session alive...");
    document.querySelector("colab-connect-button").click();
}
setInterval(KeepAlive, 60000);
```