# IDEAW Training on Google Colab

This notebook trains IDEAW audio watermarking models using Colab's free GPU.

**Before running:**
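1. Enable GPU: Runtime → Change runtime type → GPU
2. Upload your audio data to `audio-watermarking-demo/Dataset` in Google Drive
3. Make sure the repository is cloned to `MyDrive/audio-watermarking-demo` (the path used throughout this notebook)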

## 1. Setup Environment

In [1]:
from google.colab import drive
drive.mount('/content/drive')


%cd /content/drive/MyDrive/audio-watermarking-demo


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/audio-watermarking-demo


In [2]:
!git status


^C


In [None]:
!git reset --soft HEAD~5


In [19]:
!git add colab_notebooks/IDEAW_Training_Template.ipynb

In [40]:
!git config --global user.name "Abdullah Yassir"
!git config --global user.email "abdullahyassir2222@gmail.com"


In [20]:
!git commit -m "training complete"

!git push origin main


[main fc373d4] training complete
 1 file changed, 1 insertion(+), 1 deletion(-)
 rewrite colab_notebooks/IDEAW_Training_Template.ipynb (94%)
Enumerating objects: 7, done.
Counting objects: 100% (7/7), done.
Delta compression using up to 2 threads
Compressing objects: 100% (4/4), done.
Writing objects: 100% (4/4), 2.43 KiB | 191.00 KiB/s, done.
Total 4 (delta 2), reused 0 (delta 0), pack-reused 0
remote: Resolving deltas: 100% (2/2), completed with 2 local objects.
To https://github.com/Abdullahyassir007/audio-watermarking-demo.git
   eaee1ed..fc373d4  main -> main


In [28]:
# List all untracked files recursively, excluding those ignored by .gitignore
!git ls-files --others --exclude-standard

In [None]:
!git add colab_notebooks/IDEAW_Training_Template.ipynb

# Commit with a message
!git commit -m "Running training loop"

# Push to GitHub
!git push origin main

[main 486d4a1] Running training loop
 1 file changed, 1 insertion(+), 1 deletion(-)
 rewrite colab_notebooks/IDEAW_Training_Template.ipynb (97%)
Enumerating objects: 7, done.
Counting objects: 100% (7/7), done.
Delta compression using up to 2 threads
Compressing objects: 100% (4/4), done.
Writing objects: 100% (4/4), 5.23 KiB | 382.00 KiB/s, done.
Total 4 (delta 2), reused 0 (delta 0), pack-reused 0
remote: Resolving deltas: 100% (2/2), completed with 2 local objects.
remote: This repository moved. Please use the new location:
remote:   https://github.com/Abdullahyassir007/audio-watermarking-demo.git
To https://github.com/AbdullahYassir007/audio-watermarking-demo.git
   8dccca0..486d4a1  main -> main


In [3]:
# !git checkout -- colab_notebooks/IDEAW_Training_Template.ipynb
!git pull origin main

remote: Enumerating objects: 19, done.
remote: Counting objects: 100% (19/19), done.
remote: Compressing objects:  60%

In [None]:
# Abort the in-progress rebase
!git rebase --abort

# Discard local commits and match the remote branch
# (git reset records the previous HEAD in ORIG_HEAD)
!git reset --hard origin/main

# Re-apply just the notebook and config from the discarded local commit
!git checkout ORIG_HEAD -- colab_notebooks/IDEAW_Training_Template.ipynb
!git checkout ORIG_HEAD -- research/IDEAW/config.yaml

# Commit these changes
!git add colab_notebooks/IDEAW_Training_Template.ipynb research/IDEAW/config.yaml
!git commit -m "Update Colab notebook and config for batch size 2"

# Push
!git push origin main


HEAD is now at e0e1c82 Fix IDEAW PyTorch 2.x compatibility - STFT/iSTFT complex tensor handling
On branch main
Your branch is up to date with 'origin/main'.

nothing to commit, working tree clean
Everything up-to-date


In [2]:


# Set up paths
DRIVE_PATH = '/content/drive/MyDrive/audio-watermarking-demo'
DATA_PATH = f'{DRIVE_PATH}/Dataset'
CHECKPOINT_PATH = f'{DRIVE_PATH}/checkpoints'
RESULTS_PATH = f'{DRIVE_PATH}/results'

# Create directories
import os
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
os.makedirs(RESULTS_PATH, exist_ok=True)

print("✓ Google Drive mounted")
print(f"✓ Data path: {DATA_PATH}")
print(f"✓ Checkpoint path: {CHECKPOINT_PATH}")
print(f"✓ Results path: {RESULTS_PATH}")

✓ Google Drive mounted
✓ Data path: /content/drive/MyDrive/audio-watermarking-demo/Dataset
✓ Checkpoint path: /content/drive/MyDrive/audio-watermarking-demo/checkpoints
✓ Results path: /content/drive/MyDrive/audio-watermarking-demo/results


In [None]:
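# Install only the missing packages; keep Colab's preinstalled PyTorch
!pip install -q librosa==0.10.1 pydub PyYAML soundfile tqdm resampy

# Restart the runtime so the newly installed versions are imported
# (this kills the kernel process; Colab then restarts it)
import os
os.kill(os.getpid(), 9)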



In [3]:
# Install dependencies from IDEAW requirements
!pip install -q -r research/IDEAW/requirements_colab.txt
!pip install -q FrEIA

print("✓ Dependencies installed")

✓ Dependencies installed


In [4]:
# Check GPU availability
import torch

print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    device = 'cuda'
else:
    print("⚠️ No GPU available, using CPU")
    device = 'cpu'

print(f"\n✓ Using device: {device}")

GPU Available: False
⚠️ No GPU available, using CPU

✓ Using device: cpu


In [5]:
# Verify installation
import torch
import librosa
import scipy
import numpy as np
import yaml

print("=" * 50)
print("ENVIRONMENT CHECK")
print("=" * 50)
print(f"✓ PyTorch: {torch.__version__}")
print(f"✓ Librosa: {librosa.__version__}")
print(f"✓ Scipy: {scipy.__version__}")
print(f"✓ Numpy: {np.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
print("=" * 50)


ENVIRONMENT CHECK
✓ PyTorch: 2.8.0+cu126
✓ Librosa: 0.10.1
✓ Scipy: 1.11.4
✓ Numpy: 1.26.4
✓ CUDA available: False


## 2. Load IDEAW Model

In [None]:
# # Import IDEAW
# import sys
# sys.path.insert(0, '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW')

# from models.ideaw import IDEAW

# # Configuration
# config_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW/config.yaml'
# model_config_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW/models/config.yaml'

# # Initialize model
# ideaw = IDEAW(model_config_path, device)
# print("✓ IDEAW model initialized")

# # Count parameters
# total_params = sum(p.numel() for p in ideaw.parameters())
# trainable_params = sum(p.numel() for p in ideaw.parameters() if p.requires_grad)
# print(f"Total parameters: {total_params:,}")
# print(f"Trainable parameters: {trainable_params:,}")

✓ IDEAW model initialized
Total parameters: 8,425,023
Trainable parameters: 8,425,023


## 3. Prepare Data

In [6]:
# ============================================
# PREPARE DATA FOR IDEAW TRAINING
# ============================================
import os
import pickle
import librosa
import numpy as np
from tqdm import tqdm

# Paths
DRIVE_PATH = '/content/drive/MyDrive/audio-watermarking-demo'
RAW_DATA_PATH = f'{DRIVE_PATH}/Dataset'
PROCESSED_DATA_PATH = '/content/processed_data'
CHECKPOINT_PATH = f'{DRIVE_PATH}/checkpoints'
RESULTS_PATH = f'{DRIVE_PATH}/results'

# Parameters
MAX_FILES = 50  # Quick test with 50 files (set to None for all)
SAMPLE_RATE = 16000
SEGMENT_SAMPLES = 16000  # 1 second

# Create directories
os.makedirs(PROCESSED_DATA_PATH, exist_ok=True)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
os.makedirs(f'{CHECKPOINT_PATH}/stage_I', exist_ok=True)
os.makedirs(f'{CHECKPOINT_PATH}/stage_II', exist_ok=True)
os.makedirs(RESULTS_PATH, exist_ok=True)

print("="*50)
print("DATA PREPARATION")
print("="*50)

# Find audio files
if not os.path.exists(RAW_DATA_PATH):
    print(f"❌ Data not found at {RAW_DATA_PATH}")
else:
    audio_extensions = ['.mp3', '.wav', '.flac', '.m4a']
    audio_files = []

    for root, dirs, files in os.walk(RAW_DATA_PATH):
        for file in files:
            if any(file.lower().endswith(ext) for ext in audio_extensions):
                audio_files.append(os.path.join(root, file))

    print(f"\n✓ Found {len(audio_files)} audio files")

    # Limit for testing
    if MAX_FILES and len(audio_files) > MAX_FILES:
        audio_files = audio_files[:MAX_FILES]
        print(f"✓ Using {MAX_FILES} files for quick test")

    if len(audio_files) > 0:
        print(f"\nProcessing {len(audio_files)} files...")
        print(f"Target: 16kHz, 1-second segments")

        data = []

        for audio_path in tqdm(audio_files):
            try:
                # Load and resample
                audio, sr = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)

                # Split into 1-second segments (a trailing partial segment is dropped)
                num_segments = len(audio) // SEGMENT_SAMPLES

                for i in range(num_segments):
                    start = i * SEGMENT_SAMPLES
                    end = start + SEGMENT_SAMPLES
                    segment = audio[start:end]

                    if len(segment) == SEGMENT_SAMPLES:
                        data.append(segment)

            except Exception as e:
                print(f"\n⚠️  Error loading {os.path.basename(audio_path)}: {e}")
                continue

        print(f"\n✓ Processed {len(audio_files)} files")
        print(f"✓ Created {len(data)} segments")

        if len(data) > 0:
            # Save pickle
            pickle_path = os.path.join(PROCESSED_DATA_PATH, 'audio.pkl')
            with open(pickle_path, 'wb') as f:
                pickle.dump(data, f)

            size_mb = os.path.getsize(pickle_path) / (1024 * 1024)

            print(f"\n✓ Pickle saved: {pickle_path}")
            print(f"✓ Segments: {len(data)}")
            print(f"✓ Duration: {len(data) * SEGMENT_SAMPLES / SAMPLE_RATE / 60:.1f} minutes")
            print(f"✓ Size: {size_mb:.1f} MB")

            print("\n" + "="*50)
            print("✅ DATA READY FOR TRAINING")
            print("="*50)

            PICKLE_PATH = pickle_path
        else:
            print("❌ No segments created")
    else:
        print("❌ No audio files found")

DATA PREPARATION

✓ Found 2699 audio files
✓ Using 50 files for quick test

Processing 50 files...
Target: 16kHz, 1-second segments


100%|██████████| 50/50 [00:30<00:00,  1.64it/s]


✓ Processed 50 files
✓ Created 396 segments

✓ Pickle saved: /content/processed_data/audio.pkl
✓ Segments: 396
✓ Duration: 6.6 minutes
✓ Size: 24.2 MB

✅ DATA READY FOR TRAINING
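The preparation loop cuts each resampled file into non-overlapping `SEGMENT_SAMPLES`-long windows and drops any trailing remainder, so with 1-second segments the segment count equals the total duration in whole seconds (396 segments is 6.6 minutes). A minimal stdlib sketch of that rule, using a plain list in place of the librosa array; `split_into_segments` and `total_minutes` are hypothetical helper names, not part of IDEAW:

```python
def split_into_segments(samples, segment_len):
    """Split a sample sequence into non-overlapping windows of segment_len.

    Any trailing partial window is dropped, matching the notebook's loop.
    """
    num_segments = len(samples) // segment_len
    return [samples[i * segment_len:(i + 1) * segment_len] for i in range(num_segments)]


def total_minutes(num_segments, segment_seconds=1.0):
    """Total audio duration represented by num_segments fixed-length segments."""
    return num_segments * segment_seconds / 60


# 2.5 seconds of audio at 16 kHz yields 2 full 1-second segments
audio = [0.0] * int(2.5 * 16000)
segments = split_into_segments(audio, 16000)
print(len(segments))                        # 2
print(f"{total_minutes(396):.1f} minutes")  # 6.6 minutes
```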





## 4. Training Configuration

In [7]:
# Override batch size and worker count in the training config.
# Note: this rewrites config.yaml on Drive, so the change persists across sessions.
import yaml

config_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW/config.yaml'

# Read config
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Change batch size
config['train']['batch_size'] = 1  # Smallest possible batch, to fit in memory
config['train']['num_workers'] = 0  # Disable multiprocessing

# Save config (sort_keys=False keeps the original key order)
with open(config_path, 'w') as f:
    yaml.dump(config, f, sort_keys=False)

print(f"✓ Updated config: batch_size = {config['train']['batch_size']}")


✓ Updated config: batch_size = 1
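Writing the override back into `config.yaml` on Drive means it persists for every later session and shows up in `git status`. One alternative is to merge overrides into the loaded dict and save the result to a separate file. A minimal sketch assuming the config is plain nested dicts; `apply_overrides` is a hypothetical helper:

```python
import copy


def apply_overrides(config, overrides):
    """Return a deep copy of config with nested overrides merged in.

    Nested dicts are merged key by key; any other value replaces the original.
    The input config is left unchanged.
    """
    merged = copy.deepcopy(config)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = apply_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged


base = {"train": {"batch_size": 8, "num_workers": 2, "lr1": "1e-4"}}
local = apply_overrides(base, {"train": {"batch_size": 1, "num_workers": 0}})
print(local["train"])               # {'batch_size': 1, 'num_workers': 0, 'lr1': '1e-4'}
print(base["train"]["batch_size"])  # 8 (original untouched)
```

You could then write `local` with `yaml.safe_dump(local, f, sort_keys=False)` to a separate file such as `/content/config_local.yaml` (a hypothetical path) and point `args.train_config` at it, assuming the Solver loads the training config from that argument.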


In [8]:
# Training hyperparameters
BATCH_SIZE = 1  # Only printed below; training uses batch_size from config.yaml
NUM_ITERATIONS = 100  # Quick test (use 10000+ for full training)
SAVE_EVERY = 40

print("Training Configuration:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Iterations: {NUM_ITERATIONS}")
print(f"  Device: {device}")
print(f"  Save every: {SAVE_EVERY} iterations")
print(f"  Pickle path: {PICKLE_PATH}")

Training Configuration:
  Batch size: 1
  Iterations: 100
  Device: cpu
  Save every: 40 iterations
  Pickle path: /content/processed_data/audio.pkl
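With `NUM_ITERATIONS = 100` and `SAVE_EVERY = 40`, it helps to know which steps actually write a checkpoint. Assuming the Solver saves whenever a 1-indexed step is a multiple of `save_steps` (an assumption; check `solver.py` for the exact rule), the save points and the iterations left unsaved at the end can be computed as follows. Both helper names are hypothetical:

```python
def save_steps(num_iterations, save_every):
    """Steps (1-indexed) at which a checkpoint is written, assuming a
    save happens whenever step % save_every == 0."""
    return list(range(save_every, num_iterations + 1, save_every))


def unsaved_tail(num_iterations, save_every):
    """Iterations completed after the last periodic save."""
    return num_iterations % save_every


print(save_steps(100, 40))    # [40, 80]
print(unsaved_tail(100, 40))  # 20
```

Under this assumption, the last 20 iterations are kept only if the Solver also saves when training finishes; choosing a `SAVE_EVERY` that divides `NUM_ITERATIONS` sidesteps the question.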


## 4.5 Create IDEAW-Plus Improvements

### Cell 1: Configuration Flag

In [9]:
# ============================================
# IDEAW-PLUS CONFIGURATION
# ============================================

# Set to True to use IDEAW-Plus, False for baseline IDEAW
USE_IDEAW_PLUS = True

print("="*50)
if USE_IDEAW_PLUS:
    print("üöÄ USING IDEAW-PLUS (with improvements)")
    print("  ‚ú® Attention mechanism")
    print("  ‚ú® Residual connections")
    print("  ‚ú® Perceptual loss")
else:
    print("üìä USING BASELINE IDEAW (for comparison)")
print("="*50)

üöÄ USING IDEAW-PLUS (with improvements)
  ‚ú® Attention mechanism
  ‚ú® Residual connections
  ‚ú® Perceptual loss


### Cell 2: Define the Improvements

In [10]:
# ============================================
# IDEAW-PLUS: ALL THREE IMPROVEMENTS
# ============================================
import sys
sys.path.insert(0, '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW')

import torch
import torch.nn as nn
from models.innBlock import InnBlock
from models.dense import DenseBlock

# ============================================
# IMPROVEMENT #1: ATTENTION MECHANISM
# ============================================
class ChannelAttention(nn.Module):
    """Channel attention for focusing on important frequency bands"""
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class AttentionInnBlock(InnBlock):
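    """InnBlock with channel attention applied to x1 before the coupling step.

    Note: attention is applied only in the forward direction, so the reverse
    pass recovers the attention-scaled x1 rather than the original x1.
    """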
    def __init__(self, config_path):
        super().__init__(config_path)
        self.attention = ChannelAttention(self.channel, reduction=4)

    def forward(self, x1, x2, rev=False):
        if not rev:
            x1 = self.attention(x1)
        return super().forward(x1, x2, rev)

# ============================================
# IMPROVEMENT #2: RESIDUAL CONNECTIONS
# ============================================
class ResDenseBlock(DenseBlock):
    """DenseBlock with residual connections for better gradient flow"""
    def __init__(self, config_path, channel_in, channel_out):
        super().__init__(config_path, channel_in, channel_out)
        # Add projection if dimensions don't match
        if channel_in != channel_out:
            self.projection = nn.Conv2d(channel_in, channel_out, 1)
        else:
            self.projection = None

    def forward(self, x):
        identity = x
        out = super().forward(x)

        # Apply projection if needed
        if self.projection is not None:
            identity = self.projection(identity)

        # Residual connection
        return out + identity

# ============================================
# IMPROVEMENT #3: PERCEPTUAL LOSS
# ============================================
class PerceptualLoss(nn.Module):
    """STFT-based perceptual loss for better audio quality"""
    def __init__(self, n_fft=1024, hop_length=256):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length

    def forward(self, pred, target):
        # Compute STFT (window created directly on the input's device)
        window = torch.hann_window(self.n_fft, device=pred.device)

        pred_stft = torch.stft(pred, n_fft=self.n_fft, hop_length=self.hop_length,
                               window=window, return_complex=True)
        target_stft = torch.stft(target, n_fft=self.n_fft, hop_length=self.hop_length,
                                 window=window, return_complex=True)

        # Magnitude loss (primary term)
        pred_mag = torch.abs(pred_stft)
        target_mag = torch.abs(target_stft)
        mag_loss = nn.functional.l1_loss(pred_mag, target_mag)

        # Phase loss (weighted lower). Note: L1 on raw angles ignores the 2*pi
        # wrap-around, so phases just inside +pi and -pi count as far apart.
        pred_phase = torch.angle(pred_stft)
        target_phase = torch.angle(target_stft)
        phase_loss = nn.functional.l1_loss(pred_phase, target_phase)

        return mag_loss + 0.1 * phase_loss

print("✓ Improvement #1: AttentionInnBlock created")
print("✓ Improvement #2: ResDenseBlock created")
print("✓ Improvement #3: PerceptualLoss created")




✓ Improvement #1: AttentionInnBlock created
✓ Improvement #2: ResDenseBlock created
✓ Improvement #3: PerceptualLoss created
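The phase term in `PerceptualLoss` takes an L1 distance between raw `torch.angle` values, which lie in (-π, π]. Two nearly identical phases on either side of the ±π boundary therefore look almost 2π apart. A common alternative (not what this notebook currently does) is to wrap the difference back into [-π, π) before taking its magnitude. A small stdlib sketch of the effect, with hypothetical helper names:

```python
import math


def raw_phase_error(a, b):
    """Absolute difference of two angles, ignoring wrap-around (as L1 on angles does)."""
    return abs(a - b)


def wrapped_phase_error(a, b):
    """Absolute angular distance, wrapping the difference into [-pi, pi)."""
    diff = (a - b + math.pi) % (2 * math.pi) - math.pi
    return abs(diff)


a = math.pi - 0.01   # just below +pi
b = -math.pi + 0.01  # just above -pi: only 0.02 rad away on the circle
print(f"{raw_phase_error(a, b):.2f}")      # 6.26
print(f"{wrapped_phase_error(a, b):.2f}")  # 0.02
```

Inside the loss, the same wrap could be written as `torch.remainder(pred_phase - target_phase + math.pi, 2 * math.pi) - math.pi` before averaging the absolute values.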



### Cell 3: Create MIHNET-Plus (with Attention)

In [12]:
# ============================================
# MIHNET-PLUS: ATTENTION-ENHANCED MIHNET
# ============================================

class Mihnet_Plus_s1(nn.Module):
    """MIHNET Stage 1 with attention"""
    def __init__(self, config_path, num_inn):
        super().__init__()
        self.innbs = nn.ModuleList([
            AttentionInnBlock(config_path) for _ in range(num_inn)
        ])

    def forward(self, a, m, rev=False):
        if not rev:
            for innb in self.innbs:
                a, m = innb(a, m)
        else:
            for innb in reversed(self.innbs):
                a, m = innb(a, m, rev=True)
        return a, m

class Mihnet_Plus_s2(nn.Module):
    """MIHNET Stage 2 with attention"""
    def __init__(self, config_path, num_inn):
        super().__init__()
        self.innbs = nn.ModuleList([
            AttentionInnBlock(config_path) for _ in range(num_inn)
        ])

    def forward(self, a, m, rev=False):
        if not rev:
            for innb in self.innbs:
                a, m = innb(a, m)
        else:
            for innb in reversed(self.innbs):
                a, m = innb(a, m, rev=True)
        return a, m

print("✓ Mihnet_Plus classes created")


✓ Mihnet_Plus classes created
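`ChannelAttention` is a squeeze-and-excitation style block: global average pooling, then two bias-free linear layers mapping C → C/r → C. Its parameter count is therefore 2·C·⌊C/r⌋ per block, which is the per-block figure the attention-overhead printout multiplies by the number of InnBlocks. A quick sketch of that arithmetic; the channel counts below are illustrative, not read from the IDEAW config:

```python
def channel_attention_params(channels, reduction=4):
    """Parameters in ChannelAttention: two bias-free Linear layers,
    channels -> channels // reduction -> channels."""
    hidden = channels // reduction
    return channels * hidden + hidden * channels


def attention_overhead(channels, num_blocks, reduction=4):
    """Total attention parameters across num_blocks AttentionInnBlocks."""
    return channel_attention_params(channels, reduction) * num_blocks


print(channel_attention_params(16))  # 128
print(attention_overhead(16, 8))     # 1024
print(channel_attention_params(2))   # 0: bottleneck width 2 // 4 is zero
```

The last line is worth checking against the real model: when `channels < reduction`, the bottleneck width is 0, so it may be worth confirming that `self.channel` in `InnBlock` is at least `reduction=4`.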


### Cell 4: Create IDEAW-Plus Model (Complete)

In [13]:
# ============================================
# IDEAW-PLUS: COMPLETE MODEL WITH ALL IMPROVEMENTS
# ============================================
import yaml
from models.ideaw import IDEAW
from models.componentNet import Discriminator, BalanceBlock
from models.attackLayer import AttackLayer

class IDEAW_Plus(IDEAW):
    """
    IDEAW-Plus: enhanced IDEAW with:
    1. Attention mechanism in the InnBlocks (via Mihnet_Plus_s1/s2)
    2. Residual connections (ResDenseBlock is defined above but is not
       yet wired into this model)
    3. Perceptual loss (defined above; takes effect only if the training
       loss calls it)
    """
    def __init__(self, config_path, device):
        # Don't call super().__init__() - we'll rebuild with our components
        nn.Module.__init__(self)
        self.load_config(config_path)

        # Use attention-enhanced MIHNETs (Improvement #1)
        self.hinet_1 = Mihnet_Plus_s1(config_path, self.num_inn_1)
        self.hinet_2 = Mihnet_Plus_s2(config_path, self.num_inn_2)

        # Original components (unchanged)
        self.msg_fc = nn.Linear(self.num_bit, self.num_point)
        self.msg_fc_back = nn.Linear(self.num_point, self.num_bit)
        self.lcode_fc = nn.Linear(self.num_lc_bit, int(self.num_point / self.chunk_ratio))
        self.lcode_fc_back = nn.Linear(int(self.num_point / self.chunk_ratio), self.num_lc_bit)
        self.discriminator = Discriminator(config_path)
        self.attack_layer = AttackLayer(config_path, device)
        self.balance_block = BalanceBlock(config_path)

    # All other methods (stft, istft, embed_msg, etc.) are inherited from IDEAW
    # No need to redefine them!

print("✓ IDEAW_Plus model created (inherits from IDEAW)")

✓ IDEAW_Plus model created (inherits from IDEAW)



### Cell 5: Initialize Model (Baseline or Plus)

In [19]:
# ============================================
# INITIALIZE MODEL BASED ON FLAG
# ============================================
from models.ideaw import IDEAW

model_config_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW/models/config.yaml'

if USE_IDEAW_PLUS:
    # Use IDEAW-Plus with all improvements
    model = IDEAW_Plus(model_config_path, device).to(device)
    model_name = "IDEAW-Plus"
else:
    # Use baseline IDEAW
    model = IDEAW(model_config_path, device).to(device)
    model_name = "IDEAW (Baseline)"

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n{'='*50}")
print(f"MODEL: {model_name}")
print(f"{'='*50}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: ~{total_params * 4 / 1e6:.1f} MB")

if USE_IDEAW_PLUS:
    attention_params = sum(p.numel() for p in model.hinet_1.innbs[0].attention.parameters())
    num_blocks = model.num_inn_1 + model.num_inn_2
    overhead = attention_params * num_blocks
    print(f"\nAttention overhead:")
    print(f"  Per block: {attention_params:,} params")
    print(f"  Total: {overhead:,} params ({overhead/total_params*100:.2f}%)")

print(f"{'='*50}\n")



In [15]:
# ============================================
# TEST FORWARD PASS
# ============================================
import yaml

with open(model_config_path) as f:
    test_config = yaml.load(f, Loader=yaml.FullLoader)

num_point = test_config['IDEAW']['num_point']
num_bit = test_config['IDEAW']['num_bit']
num_lc_bit = test_config['IDEAW']['num_lc_bit']

# Create test data
test_audio = torch.randn(1, num_point).to(device) * 0.1
test_msg = (torch.randint(0, 2, (1, num_bit), dtype=torch.float32) * 2 - 1).to(device)
test_lcode = (torch.randint(0, 2, (1, num_lc_bit), dtype=torch.float32) * 2 - 1).to(device)

print(f"Test data shapes:")
print(f"  Audio: {test_audio.shape}")
print(f"  Message: {test_msg.shape}")
print(f"  Lcode: {test_lcode.shape}")

# Forward pass
with torch.no_grad():
    outputs = model(test_audio, test_msg, test_lcode, False, False)

    print(f"\n✅ Forward pass successful!")
    print(f"  Watermarked audio shape: {outputs[2].shape}")
    print(f"  Audio range: [{outputs[2].min():.4f}, {outputs[2].max():.4f}]")

    # Check message extraction (an untrained model should score near 50%, i.e. chance)
    msg_acc = (torch.sign(outputs[5]) == test_msg).float().mean().item()
    lcode_acc = (torch.sign(outputs[6]) == test_lcode).float().mean().item()
    print(f"  Message accuracy: {msg_acc*100:.1f}%")
    print(f"  Lcode accuracy: {lcode_acc*100:.1f}%")

print(f"\n✅ {model_name} ready for training!")

Test data shapes:
  Audio: torch.Size([1, 16000])
  Message: torch.Size([1, 46])
  Lcode: torch.Size([1, 10])

✅ Forward pass successful!
  Watermarked audio shape: torch.Size([1, 16000])
  Audio range: [-0.0000, 0.0000]
  Message accuracy: 43.5%
  Lcode accuracy: 40.0%

✅ IDEAW-Plus ready for training!
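The forward-pass test encodes bits as ±1 (`bit * 2 - 1`) and decodes them by taking the sign of the extracted values, so accuracy is the fraction of positions where the signs agree. For an untrained model the extracted values are essentially unrelated to the message, so scores near 50% (like the 43.5% and 40.0% above, measured on only 46 and 10 bits) are expected. A stdlib sketch of the same encode, decode, and score logic, with hypothetical helper names:

```python
import random


def to_bipolar(bits):
    """Map {0, 1} bits to {-1.0, +1.0}, as in `bit * 2 - 1`."""
    return [b * 2.0 - 1.0 for b in bits]


def sign(x):
    """Sign of x as -1, 0, or 1 (like torch.sign: exact zero gives 0)."""
    return (x > 0) - (x < 0)


def bit_accuracy(extracted, message):
    """Fraction of positions where sign(extracted) equals the +/-1 message."""
    matches = sum(1 for e, m in zip(extracted, message) if sign(e) == m)
    return matches / len(message)


message = to_bipolar([1, 0, 1, 1])                     # [1.0, -1.0, 1.0, 1.0]
print(bit_accuracy([0.8, -0.3, 0.1, -0.2], message))   # 0.75

# Random guesses score about 0.5 on long messages
rng = random.Random(0)
msg = to_bipolar([rng.randint(0, 1) for _ in range(10000)])
guess = [rng.uniform(-1, 1) for _ in range(10000)]
print(bit_accuracy(guess, msg))
```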


In [16]:
# ============================================
# INITIALIZE PERCEPTUAL LOSS
# ============================================

if USE_IDEAW_PLUS:
    perceptual_loss_fn = PerceptualLoss(n_fft=1024, hop_length=256).to(device)
    print("✓ Perceptual loss initialized")
    print("  Note: the Solver does not call this automatically; add it to the training loss to use it")
else:
    perceptual_loss_fn = None
    print("✓ Using standard loss (no perceptual loss)")

✓ Perceptual loss initialized
  Note: the Solver does not call this automatically; add it to the training loss to use it


## 5. Train Model

In [20]:
# Initialize solver - use Drive path
import sys
import os
import argparse

# Change to IDEAW directory on Drive
IDEAW_PATH = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW'
os.chdir(IDEAW_PATH)
sys.path.insert(0, IDEAW_PATH)

print(f"✓ Working directory: {os.getcwd()}")

from solver import Solver

# Create args object
args = argparse.Namespace(
    device=device,
    pickle_path=PICKLE_PATH,
    train_config='./config.yaml',
    store_model_path=f'{CHECKPOINT_PATH}/',
    load_model=False,  # Set to True to resume training
    load_model_path=f'{CHECKPOINT_PATH}/stage_I/',
    summary_steps=10,
    save_steps=SAVE_EVERY
)

config_data_path = './data/config.yaml'
config_model_path = './models/config.yaml'

print("Initializing solver...")

# Build the solver (it constructs its own default model, which is replaced below)
solver = Solver(config_data_path, config_model_path, args)

# REPLACE solver's model with our model (baseline or plus)
solver.model = model
solver.model.to(device)

# Reinitialize optimizers for the new model using the solver's config.
# float() parses numeric strings such as "1e-4" (PyYAML loads these as str)
# without eval()'s ability to execute arbitrary code.
lr1 = float(solver.config_t["train"]["lr1"])
lr2 = float(solver.config_t["train"]["lr2"])
beta1 = solver.config_t["train"]["beta1"]
beta2 = solver.config_t["train"]["beta2"]
eps = float(solver.config_t["train"]["eps"])
weight_decay = float(solver.config_t["train"]["weight_decay"])

# Rebuild optimizers with our model
param_hinet1 = list(filter(lambda p: p.requires_grad, solver.model.hinet_1.parameters()))
param_hinet2 = list(filter(lambda p: p.requires_grad, solver.model.hinet_2.parameters()))
param_discr = list(filter(lambda p: p.requires_grad, solver.model.discriminator.parameters()))
param_att = list(filter(lambda p: p.requires_grad, solver.model.attack_layer.parameters()))
param_balance = list(filter(lambda p: p.requires_grad, solver.model.balance_block.parameters()))

solver.optim_I = torch.optim.Adam(
    param_hinet1 + param_hinet2,
    lr=lr1,
    betas=(beta1, beta2),
    eps=eps,
    weight_decay=weight_decay,
)
solver.optim_II = torch.optim.Adam(
    param_discr + param_att + param_balance,
    lr=lr2,
    betas=(beta1, beta2),
    eps=eps,
    weight_decay=weight_decay,
)

print(f"✓ Solver initialized with {model_name}")
print(f"✓ Optimizers rebuilt for {model_name}")
print("\nStarting training...")
print("="*50)

✓ Working directory: /content/drive/MyDrive/audio-watermarking-demo/research/IDEAW
Initializing solver...
[IDEAW]infinite dataloader built
[IDEAW]model built
[IDEAW]total parameter count: 8425023
[IDEAW]optimizers built
✓ Solver initialized with IDEAW-Plus
✓ Optimizers rebuilt for IDEAW-Plus

Starting training...
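The optimizer cell has to convert `lr1`, `lr2`, `eps`, and `weight_decay` because PyYAML follows YAML 1.1, whose float pattern requires a decimal point: a value written as `1e-4` loads as the string `'1e-4'`, while `1.0e-4` loads as a float. `float()` accepts both forms without executing arbitrary expressions the way `eval()` can. A small sketch; `as_float` is a hypothetical helper and the config values are illustrative:

```python
def as_float(value):
    """Convert a config value to float.

    Accepts numbers and numeric strings such as "1e-4" (which PyYAML
    loads as str). Raises ValueError for non-numeric strings instead of
    evaluating them as code the way eval() would.
    """
    if isinstance(value, bool):
        raise ValueError(f"expected a number, got bool: {value!r}")
    return float(value)


config_train = {"lr1": "1e-4", "lr2": 0.0002, "eps": "1e-8"}
print(as_float(config_train["lr1"]))  # 0.0001
print(as_float(config_train["lr2"]))  # 0.0002
print(as_float(config_train["eps"]))  # 1e-08
```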


In [None]:
# Training loop
import time

start_time = time.time()

try:
    solver.train(NUM_ITERATIONS)

    training_time = time.time() - start_time
    print("\n" + "="*50)
    print("✅ TRAINING COMPLETE")
    print("="*50)
    print(f"Time: {training_time/60:.1f} minutes")
    print(f"Checkpoints saved to: {CHECKPOINT_PATH}")

except KeyboardInterrupt:
    print("\n⚠️  Training interrupted")
    print(f"Checkpoints from earlier save steps (every {SAVE_EVERY} iterations) are kept; progress since the last save is lost.")

except Exception as e:
    print(f"\n❌ Error: {e}")
    import traceback
    traceback.print_exc()

[IDEAW]starting training...




In [15]:
# Simpler test: check that the checkpoint loads and the model structure matches
import sys
import os
import torch

sys.path.insert(0, '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW')

from models.ideaw import IDEAW

# Initialize model
# Note: this builds the baseline IDEAW. A checkpoint trained with IDEAW-Plus
# contains extra attention weights, so load that into IDEAW_Plus instead.
model_config_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW/models/config.yaml'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

ideaw = IDEAW(model_config_path, device)
ideaw = ideaw.to(device)
print("✓ IDEAW model initialized")

# Load checkpoint
checkpoint_path = '/content/drive/MyDrive/audio-watermarking-demo/checkpoints/stage_I/ideaw.ckpt'

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from: {checkpoint_path}")
    # map_location lets a GPU-trained checkpoint load on a CPU-only runtime
    checkpoint = torch.load(checkpoint_path, map_location=device)
    ideaw.load_state_dict(checkpoint)
    ideaw.eval()
    print("✓ Checkpoint loaded successfully")

    # Count parameters
    total_params = sum(p.numel() for p in ideaw.parameters())
    print(f"✓ Model parameters: {total_params:,}")

    print("\n✅ CHECKPOINT TEST PASSED!")
    print("The model checkpoint is valid and can be loaded.")
    print("\nTo properly test watermarking:")
    print("1. Use the standalone_demo.py script")
    print("2. Or continue training to improve accuracy")

else:
    print(f"❌ Checkpoint not found at {checkpoint_path}")


✓ IDEAW model initialized
Loading checkpoint from: /content/drive/MyDrive/audio-watermarking-demo/checkpoints/stage_I/ideaw.ckpt
✓ Checkpoint loaded successfully
✓ Model parameters: 8,425,023

✅ CHECKPOINT TEST PASSED!
The model checkpoint is valid and can be loaded.

To properly test watermarking:
1. Use the standalone_demo.py script
2. Or continue training to improve accuracy
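A checkpoint trained with `IDEAW_Plus` includes `attention` weights that the baseline `IDEAW` does not have, and a strict `load_state_dict` rejects such mismatches. PyTorch's `load_state_dict(..., strict=False)` also returns the missing and unexpected keys, but comparing key sets up front makes the cause obvious before anything is loaded. A stdlib sketch on plain key lists; the key names are made up for illustration, and on real models you would pass `model.state_dict().keys()` and `checkpoint.keys()`:

```python
def compare_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, the same split that a strict
    load_state_dict reports: missing = in the model but not the checkpoint,
    unexpected = in the checkpoint but not the model."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key names: a baseline model vs. a checkpoint with attention weights
baseline_keys = ["hinet_1.innbs.0.f.conv1.weight", "msg_fc.weight"]
plus_ckpt_keys = baseline_keys + ["hinet_1.innbs.0.attention.fc.0.weight"]

missing, unexpected = compare_state_dict_keys(baseline_keys, plus_ckpt_keys)
print(missing)     # []
print(unexpected)  # ['hinet_1.innbs.0.attention.fc.0.weight']
```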


## 6. Visualize Training Results

In [12]:
# Plot training curves
import matplotlib.pyplot as plt
import pandas as pd

log_file = f'{RESULTS_PATH}/training_log.csv'

if os.path.exists(log_file):
    df = pd.read_csv(log_file)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss
    axes[0, 0].plot(df['epoch'], df['loss'])
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training Loss')
    axes[0, 0].grid(True)

    # SNR
    axes[0, 1].plot(df['epoch'], df['snr'])
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('SNR (dB)')
    axes[0, 1].set_title('Signal-to-Noise Ratio')
    axes[0, 1].grid(True)

    # Accuracy
    axes[1, 0].plot(df['epoch'], df['accuracy'])
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Accuracy (%)')
    axes[1, 0].set_title('Watermark Accuracy')
    axes[1, 0].grid(True)

    # Learning rate
    if 'learning_rate' in df.columns:
        axes[1, 1].plot(df['epoch'], df['learning_rate'])
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Learning Rate')
        axes[1, 1].set_title('Learning Rate Schedule')
        axes[1, 1].set_yscale('log')
        axes[1, 1].grid(True)

    plt.tight_layout()
    plt.savefig(f'{RESULTS_PATH}/training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("✓ Training curves saved to:", f'{RESULTS_PATH}/training_curves.png')

    # Print final metrics
    print("\nFinal Metrics:")
    print(f"  Loss: {df['loss'].iloc[-1]:.4f}")
    print(f"  SNR: {df['snr'].iloc[-1]:.2f} dB")
    print(f"  Accuracy: {df['accuracy'].iloc[-1]:.2f}%")
else:
    print("⚠️ No training log found")

⚠️ No training log found
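The plotting cell expects `training_log.csv` with `epoch`, `loss`, `snr`, and `accuracy` columns (plus an optional `learning_rate`). No cell in this notebook writes that file directly, so whether the Solver produces it, and with these column names, is worth confirming. A small stdlib check (hypothetical helper `missing_columns`; the sample rows are invented) reports which expected columns a log lacks before plotting:

```python
import csv
import io

REQUIRED = ("epoch", "loss", "snr", "accuracy")


def missing_columns(csv_file, required=REQUIRED):
    """Return the required column names absent from a CSV file's header row."""
    reader = csv.reader(csv_file)
    header = next(reader, [])
    return [name for name in required if name not in header]


sample_log = io.StringIO("epoch,loss,snr\n1,0.52,18.3\n2,0.47,19.1\n")
print(missing_columns(sample_log))  # ['accuracy']
```

On the real log this would be `with open(log_file, newline="") as f: print(missing_columns(f))`.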


## 7. Test Trained Model

In [None]:
# Load the trained checkpoint (same file as the checkpoint test above;
# change this path if your checkpoint lives elsewhere)
trained_checkpoint = f'{CHECKPOINT_PATH}/stage_I/ideaw.ckpt'

if os.path.exists(trained_checkpoint):
    print("Loading trained model...")
    checkpoint = torch.load(trained_checkpoint, map_location=device)
    # Accept either a raw state_dict or a dict wrapping one under 'model_state_dict'
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    ideaw.load_state_dict(state_dict)
    ideaw.eval()
    print("✓ Trained model loaded")

    # Test on sample audio
    import librosa
    import numpy as np

    # Load test audio
    test_audio_path = f'{DATA_PATH}/val/test_audio.wav'  # Update with your test file

    if os.path.exists(test_audio_path):
        audio, sr = librosa.load(test_audio_path, sr=16000)
        # The model works on fixed-length segments of num_point samples
        if len(audio) < ideaw.num_point:
            audio = np.pad(audio, (0, ideaw.num_point - len(audio)))
        audio = audio[:ideaw.num_point]
        audio_tensor = torch.FloatTensor(audio).unsqueeze(0).to(device)

        # Random message and location code in {-1, +1}, matching the forward-pass test
        message = (torch.randint(0, 2, (1, ideaw.num_bit), dtype=torch.float32) * 2 - 1).to(device)
        lcode = (torch.randint(0, 2, (1, ideaw.num_lc_bit), dtype=torch.float32) * 2 - 1).to(device)

        with torch.no_grad():
            # Embed
            audio_wmd1, _ = ideaw.embed_msg(audio_tensor, message)
            audio_wmd2, _ = ideaw.embed_lcode(audio_wmd1, lcode)

            # Extract
            mid_stft, lcode_extracted = ideaw.extract_lcode(audio_wmd2)
            message_extracted = ideaw.extract_msg(mid_stft)

            # Bit accuracy: decode each bit by its sign
            msg_acc = (torch.sign(message_extracted) == message).float().mean().item() * 100
            lcode_acc = (torch.sign(lcode_extracted) == lcode).float().mean().item() * 100

            print(f"\nTest Results:")
            print(f"  Message accuracy: {msg_acc:.2f}%")
            print(f"  Location code accuracy: {lcode_acc:.2f}%")
    else:
        print(f"⚠️ Test audio not found at {test_audio_path}")
else:
    print(f"⚠️ Checkpoint not found at {trained_checkpoint}")
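The model operates on fixed-length inputs of `num_point` samples (16,000, i.e. one second at 16 kHz, per the forward-pass test), while a test file can be any length. One straightforward way to evaluate a whole file is to cut it into `num_point`-sample chunks, zero-pad the last one, and embed and extract each chunk separately. A stdlib sketch of the chunking step (hypothetical helper `fixed_length_chunks`); each chunk would become one `(1, num_point)` tensor:

```python
def fixed_length_chunks(samples, num_point, pad_value=0.0):
    """Split samples into chunks of exactly num_point values.

    The final partial chunk (if any) is right-padded with pad_value,
    so no audio is discarded. An empty input yields no chunks.
    """
    chunks = []
    for start in range(0, len(samples), num_point):
        chunk = list(samples[start:start + num_point])
        chunk.extend([pad_value] * (num_point - len(chunk)))
        chunks.append(chunk)
    return chunks


chunks = fixed_length_chunks([0.1] * 40000, 16000)
print(len(chunks))            # 3
print(chunks[-1][7999:8001])  # [0.1, 0.0]
```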

## 8. Download Results

In [None]:
# Zip checkpoints and results into /content (local VM disk), so the archives
# do not land in the repository folder on Drive (the current working directory)
!zip -r /content/checkpoints.zip {CHECKPOINT_PATH}
!zip -r /content/results.zip {RESULTS_PATH}

print("✓ Files zipped")
print("\nYou can download:")
print("  1. /content/checkpoints.zip - Trained model weights")
print("  2. /content/results.zip - Training logs and plots")
print("\nOr access them directly from Google Drive at:")
print(f"  {DRIVE_PATH}")

In [None]:
# Optional: Download directly from Colab
from google.colab import files

# Uncomment to download
# files.download('/content/checkpoints.zip')
# files.download('/content/results.zip')
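If the `zip` command-line tool is unavailable, or you prefer to stay in Python, the standard library's `shutil.make_archive` builds an equivalent archive. This sketch zips a directory and lists the archive contents with `zipfile`; the temporary directory and placeholder file stand in for `CHECKPOINT_PATH` and a real checkpoint, and `zip_directory` is a hypothetical helper:

```python
import os
import shutil
import tempfile
import zipfile


def zip_directory(src_dir, archive_base):
    """Create archive_base + '.zip' containing src_dir's contents; return its path."""
    return shutil.make_archive(archive_base, "zip", root_dir=src_dir)


with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "checkpoints", "stage_I")
    os.makedirs(src)
    with open(os.path.join(src, "ideaw.ckpt"), "wb") as f:
        f.write(b"placeholder")  # stand-in for a real checkpoint file

    archive = zip_directory(os.path.join(tmp, "checkpoints"), os.path.join(tmp, "checkpoints"))
    with zipfile.ZipFile(archive) as zf:
        names = zf.namelist()
    print(os.path.basename(archive))      # checkpoints.zip
    print("stage_I/ideaw.ckpt" in names)  # True
```

In the notebook this would be `zip_directory(CHECKPOINT_PATH, '/content/checkpoints')`, since `make_archive` appends the `.zip` extension itself.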

## 9. Push Code Updates to GitHub (Optional)

In [None]:
# If you made code changes in Colab, push them back to GitHub

# Configure git (once per Colab session; the VM's ~/.gitconfig does not persist)
!git config --global user.email "your.email@example.com"
!git config --global user.name "Your Name"

# Check what changed
!git status

# Add, commit, and push (uncomment to use)
# !git add .
# !git commit -m "Updated training code from Colab"
# !git push

print("\nNote: pushing from Colab requires authenticating with a GitHub personal access token")
print("Generate a token at: https://github.com/settings/tokens")

## 10. Pull Latest Code Updates (Optional)

In [None]:
# If you updated code on your local machine, pull latest changes
!git pull origin main

print("✓ Code updated from GitHub")

## 11. Keep Session Alive (Optional)

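Run this JavaScript in your browser's developer console to reduce idle disconnections. It clicks the Connect button every 60 seconds; Colab's page markup changes over time, so the selector may need updating: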

```javascript
function KeepAlive() {
    console.log("Keeping session alive...");
    document.querySelector("colab-connect-button").click();
}
setInterval(KeepAlive, 60000);
```