# IDEAW Training on Google Colab

This notebook trains IDEAW audio watermarking models using Colab's free GPU.

**Before running:**
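1. Enable GPU: Runtime → Change runtime type → GPU
2. Upload your dataset to Google Drive (this notebook reads it from `MyDrive/audio-watermarking-demo/Dataset`)
3. Clone your GitHub repository to `MyDrive/audio-watermarking-demo` (the path used below), or update the paths to match your repository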

## 1. Setup Environment

In [1]:
from google.colab import drive
drive.mount('/content/drive')


%cd /content/drive/MyDrive/audio-watermarking-demo


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/audio-watermarking-demo


In [13]:
!git status


^C


In [None]:
# Caution: moves the branch back 5 commits; their changes stay staged
!git reset --soft HEAD~5


In [24]:
!git add colab_notebooks/IDEAW_Training_Template.ipynb

In [40]:
!git config --global user.name "Abdullah Yassir"
!git config --global user.email "abdullahyassir2222@gmail.com"


In [41]:
!git commit -m "directories added"

!git push origin main


On branch main
Your branch is up to date with 'origin/main'.

Changes not staged for commit:
  (use "git add <file>..." to update what will be committed)
  (use "git restore <file>..." to discard changes in working directory)
	[31mmodified:   colab_notebooks/IDEAW_Training_Template.ipynb[m

no changes added to commit (use "git add" and/or "git commit -a")
Everything up-to-date


In [28]:
# List all untracked files recursively, excluding those ignored by .gitignore
!git ls-files --others --exclude-standard

In [None]:
# Stage the notebook
!git add colab_notebooks/IDEAW_Training_Template.ipynb

# Commit with a message
!git commit -m "Running training loop"

# Push to GitHub
!git push origin main

[main 486d4a1] Running training loop
 1 file changed, 1 insertion(+), 1 deletion(-)
 rewrite colab_notebooks/IDEAW_Training_Template.ipynb (97%)
Enumerating objects: 7, done.
Counting objects: 100% (7/7), done.
Delta compression using up to 2 threads
Compressing objects: 100% (4/4), done.
Writing objects: 100% (4/4), 5.23 KiB | 382.00 KiB/s, done.
Total 4 (delta 2), reused 0 (delta 0), pack-reused 0
remote: Resolving deltas: 100% (2/2), completed with 2 local objects.[K
remote: This repository moved. Please use the new location:[K
remote:   https://github.com/Abdullahyassir007/audio-watermarking-demo.git[K
To https://github.com/AbdullahYassir007/audio-watermarking-demo.git
   8dccca0..486d4a1  main -> main


In [16]:
# !git checkout -- colab_notebooks/IDEAW_Training_Template.ipynb
!git pull origin main

From https://github.com/Abdullahyassir007/audio-watermarking-demo
 * branch            main       -> FETCH_HEAD
Already up to date.


In [None]:
# Abort the in-progress rebase
!git rebase --abort

# Take the remote version of main (git reset records the previous HEAD as ORIG_HEAD)
!git reset --hard origin/main

# Re-apply only the local notebook and config changes from the pre-reset commit
!git checkout ORIG_HEAD -- colab_notebooks/IDEAW_Training_Template.ipynb
!git checkout ORIG_HEAD -- research/IDEAW/config.yaml

# Commit these changes
!git add colab_notebooks/IDEAW_Training_Template.ipynb research/IDEAW/config.yaml
!git commit -m "Update Colab notebook and config for batch size 2"

# Push
!git push origin main


HEAD is now at e0e1c82 Fix IDEAW PyTorch 2.x compatibility - STFT/iSTFT complex tensor handling
On branch main
Your branch is up to date with 'origin/main'.

nothing to commit, working tree clean
Everything up-to-date


In [3]:


# Set up paths
DRIVE_PATH = '/content/drive/MyDrive/audio-watermarking-demo'
DATA_PATH = f'{DRIVE_PATH}/Dataset'
CHECKPOINT_PATH = f'{DRIVE_PATH}/checkpoints'
RESULTS_PATH = f'{DRIVE_PATH}/results'

# Create directories
import os
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
os.makedirs(RESULTS_PATH, exist_ok=True)

print("✓ Google Drive mounted")
print(f"✓ Data path: {DATA_PATH}")
print(f"✓ Checkpoint path: {CHECKPOINT_PATH}")
print(f"✓ Results path: {RESULTS_PATH}")

✓ Google Drive mounted
✓ Data path: /content/drive/MyDrive/audio-watermarking-demo/Dataset
✓ Checkpoint path: /content/drive/MyDrive/audio-watermarking-demo/checkpoints
✓ Results path: /content/drive/MyDrive/audio-watermarking-demo/results


In [None]:
# # Just install the missing packages, use Colab's existing PyTorch
# !pip install -q librosa==0.10.1 pydub PyYAML soundfile tqdm resampy

# # Restart runtime
# import os
# os.kill(os.getpid(), 9)



In [3]:
# # Install dependencies from IDEAW requirements
# !pip install -q -r research/IDEAW/requirements_colab.txt
# !pip install -q FrEIA

# print("✓ Dependencies installed")

✓ Dependencies installed


In [4]:
# Check GPU availability
import torch

print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    device = 'cuda'
else:
    print("⚠️ No GPU available, using CPU")
    device = 'cpu'

print(f"\n✓ Using device: {device}")

GPU Available: True
GPU Name: Tesla T4
GPU Memory: 15.83 GB

✓ Using device: cuda


In [5]:
# Verify installation
import torch
import librosa
import scipy
import numpy as np
import yaml

print("=" * 50)
print("ENVIRONMENT CHECK")
print("=" * 50)
print(f"✓ PyTorch: {torch.__version__}")
print(f"✓ Librosa: {librosa.__version__}")
print(f"✓ Scipy: {scipy.__version__}")
print(f"✓ Numpy: {np.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
print("=" * 50)


ENVIRONMENT CHECK
✓ PyTorch: 2.8.0+cu126
✓ Librosa: 0.10.1
✓ Scipy: 1.11.4
✓ Numpy: 1.26.4
✓ CUDA available: True
✓ GPU: Tesla T4


## 2. Load IDEAW Model

In [None]:
# # Import IDEAW
# import sys
# sys.path.insert(0, '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW')

# from models.ideaw import IDEAW

# # Configuration
# config_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW/config.yaml'
# model_config_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW/models/config.yaml'

# # Initialize model
# ideaw = IDEAW(model_config_path, device)
# print("✓ IDEAW model initialized")

# # Count parameters
# total_params = sum(p.numel() for p in ideaw.parameters())
# trainable_params = sum(p.numel() for p in ideaw.parameters() if p.requires_grad)
# print(f"Total parameters: {total_params:,}")
# print(f"Trainable parameters: {trainable_params:,}")

✓ IDEAW model initialized
Total parameters: 8,425,023
Trainable parameters: 8,425,023


## 3. Prepare Data

In [6]:
# ============================================
# PREPARE DATA FOR IDEAW TRAINING
# ============================================
import os
import pickle
import librosa
import numpy as np
from tqdm import tqdm

# Paths
DRIVE_PATH = '/content/drive/MyDrive/audio-watermarking-demo'
RAW_DATA_PATH = f'{DRIVE_PATH}/Dataset'
PROCESSED_DATA_PATH = '/content/processed_data'
CHECKPOINT_PATH = f'{DRIVE_PATH}/checkpoints'
RESULTS_PATH = f'{DRIVE_PATH}/results'

# Parameters
MAX_FILES = 50  # Quick test with 50 files (set to None for all)
SAMPLE_RATE = 16000
SEGMENT_SAMPLES = 16000  # 1 second

# Create directories
os.makedirs(PROCESSED_DATA_PATH, exist_ok=True)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
os.makedirs(f'{CHECKPOINT_PATH}/stage_I', exist_ok=True)
os.makedirs(f'{CHECKPOINT_PATH}/stage_II', exist_ok=True)
os.makedirs(RESULTS_PATH, exist_ok=True)

print("="*50)
print("DATA PREPARATION")
print("="*50)

# Find audio files
if not os.path.exists(RAW_DATA_PATH):
    print(f"❌ Data not found at {RAW_DATA_PATH}")
else:
    audio_extensions = ['.mp3', '.wav', '.flac', '.m4a']
    audio_files = []

    for root, dirs, files in os.walk(RAW_DATA_PATH):
        for file in files:
            if any(file.lower().endswith(ext) for ext in audio_extensions):
                audio_files.append(os.path.join(root, file))

    # Sort so the MAX_FILES subset is the same on every run
    audio_files.sort()
    print(f"\n✓ Found {len(audio_files)} audio files")

    # Limit for testing
    if MAX_FILES and len(audio_files) > MAX_FILES:
        audio_files = audio_files[:MAX_FILES]
        print(f"✓ Using {MAX_FILES} files for quick test")

    if len(audio_files) > 0:
        print(f"\nProcessing {len(audio_files)} files...")
        print(f"Target: 16kHz, 1-second segments")

        data = []

        for audio_path in tqdm(audio_files):
            try:
                # Load and resample
                audio, sr = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)

                # Split into 1-second segments
                num_segments = len(audio) // SEGMENT_SAMPLES

                for i in range(num_segments):
                    start = i * SEGMENT_SAMPLES
                    end = start + SEGMENT_SAMPLES
                    segment = audio[start:end]

                    if len(segment) == SEGMENT_SAMPLES:
                        data.append(segment)

            except Exception as e:
                print(f"\n⚠️  Error loading {os.path.basename(audio_path)}: {e}")
                continue

        print(f"\n✓ Processed {len(audio_files)} files")
        print(f"✓ Created {len(data)} segments")

        if len(data) > 0:
            # Save pickle
            pickle_path = os.path.join(PROCESSED_DATA_PATH, 'audio.pkl')
            with open(pickle_path, 'wb') as f:
                pickle.dump(data, f)

            size_mb = os.path.getsize(pickle_path) / (1024 * 1024)

            print(f"\n✓ Pickle saved: {pickle_path}")
            print(f"✓ Segments: {len(data)}")
            print(f"✓ Duration: {len(data) * SEGMENT_SAMPLES / SAMPLE_RATE / 60:.1f} minutes")
            print(f"✓ Size: {size_mb:.1f} MB")

            print("\n" + "="*50)
            print("✅ DATA READY FOR TRAINING")
            print("="*50)

            PICKLE_PATH = pickle_path
        else:
            print("❌ No segments created")
    else:
        print("❌ No audio files found")

DATA PREPARATION

✓ Found 2699 audio files
✓ Using 50 files for quick test

Processing 50 files...
Target: 16kHz, 1-second segments


100%|██████████| 50/50 [00:03<00:00, 16.60it/s]


✓ Processed 50 files
✓ Created 396 segments

✓ Pickle saved: /content/processed_data/audio.pkl
✓ Segments: 396
✓ Duration: 6.6 minutes
✓ Size: 24.2 MB

✅ DATA READY FOR TRAINING
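The per-file loop above keeps every complete 1-second window and drops the leftover samples. The same split can be written as a single NumPy reshape; a minimal sketch, where `split_into_segments` is a hypothetical helper name rather than part of IDEAW:

```python
import numpy as np

SAMPLE_RATE = 16000
SEGMENT_SAMPLES = 16000  # 1 second at 16 kHz, as in the data-preparation cell


def split_into_segments(audio: np.ndarray, segment_samples: int = SEGMENT_SAMPLES) -> list:
    """Split a mono signal into non-overlapping segments of equal length.

    Trailing samples that do not fill a whole segment are dropped, which matches
    the behaviour of the loop in the data-preparation cell.
    """
    num_segments = len(audio) // segment_samples
    trimmed = audio[: num_segments * segment_samples]
    # reshape returns a (num_segments, segment_samples) view; copying each row
    # gives every segment its own buffer instead of a view into the whole file
    return [row.copy() for row in trimmed.reshape(num_segments, segment_samples)]


# Example: 2.5 s of audio yields two full 1-second segments
segments = split_into_segments(np.zeros(int(2.5 * SAMPLE_RATE), dtype=np.float32))
print(len(segments), segments[0].shape)  # 2 (16000,)
```

Inside the loop, `data.extend(split_into_segments(audio))` would replace the inner `for i in range(num_segments)` block.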





## 4. Training Configuration

In [7]:
# Override training settings in the config file.
# Note: this rewrites research/IDEAW/config.yaml on Drive in place.
import yaml

config_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW/config.yaml'

# Read config
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Keep the batch as small as possible
config['train']['batch_size'] = 1  # smallest possible batch
config['train']['num_workers'] = 0  # disable multiprocessing (load data in the main process)

# Save config (sort_keys=False keeps the original key order)
with open(config_path, 'w') as f:
    yaml.dump(config, f, sort_keys=False)

print(f"✓ Updated config: batch_size = {config['train']['batch_size']}")


✓ Updated config: batch_size = 1
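Writing back to `research/IDEAW/config.yaml` changes a file tracked by git, which is why `config.yaml` appears in the commits earlier in this notebook. One alternative is to write the overrides to a session-local copy. This is a sketch under the assumption that the solver only uses `args.train_config` as a path to read; `write_config_with_overrides` and `/content/config_colab.yaml` are hypothetical names:

```python
import copy

import yaml


def write_config_with_overrides(src_path: str, dst_path: str, overrides: dict) -> dict:
    """Load a YAML config, apply per-section overrides, and save it to a new file.

    The source file is left untouched, so the tracked config.yaml in the
    repository does not change between Colab sessions.
    """
    with open(src_path, 'r') as f:
        config = yaml.safe_load(f)

    merged = copy.deepcopy(config)
    for section, values in overrides.items():
        # update only the given keys inside each top-level section
        merged.setdefault(section, {}).update(values)

    with open(dst_path, 'w') as f:
        # sort_keys=False keeps the original key order, which keeps diffs readable
        yaml.safe_dump(merged, f, sort_keys=False)
    return merged
```

With this helper, the cell above would become `write_config_with_overrides(config_path, '/content/config_colab.yaml', {'train': {'batch_size': 1, 'num_workers': 0}})`, and the solver cell would pass `train_config='/content/config_colab.yaml'`.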


In [8]:
# Training hyperparameters
BATCH_SIZE = config['train']['batch_size']  # read back from config.yaml, which the solver uses
NUM_ITERATIONS = 100  # Quick test (use 10000+ for full training)
SAVE_EVERY = 40

print("Training Configuration:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Iterations: {NUM_ITERATIONS}")
print(f"  Device: {device}")
print(f"  Save every: {SAVE_EVERY} iterations")
print(f"  Pickle path: {PICKLE_PATH}")

Training Configuration:
  Batch size: 1
  Iterations: 100
  Device: cuda
  Save every: 40 iterations
  Pickle path: /content/processed_data/audio.pkl
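Before handing `PICKLE_PATH` to the solver, it can help to confirm that the pickle holds what the preparation step intended: a non-empty list of 1-D arrays with `SEGMENT_SAMPLES` samples each and no NaN or Inf values. A minimal sketch; `check_segments_pickle` is a hypothetical helper name:

```python
import pickle

import numpy as np


def check_segments_pickle(pickle_path: str, segment_samples: int = 16000) -> int:
    """Load the pickled segment list, validate every entry, and return the count."""
    with open(pickle_path, 'rb') as f:
        segments = pickle.load(f)

    if not isinstance(segments, list) or len(segments) == 0:
        raise ValueError(f'{pickle_path} does not contain a non-empty list')

    for i, seg in enumerate(segments):
        seg = np.asarray(seg)
        # every segment must be a 1-D array of exactly segment_samples samples
        if seg.ndim != 1 or seg.shape[0] != segment_samples:
            raise ValueError(f'segment {i} has shape {seg.shape}, expected ({segment_samples},)')
        if not np.all(np.isfinite(seg)):
            raise ValueError(f'segment {i} contains NaN or Inf values')
    return len(segments)
```

For the run above, `check_segments_pickle(PICKLE_PATH)` should return 396.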


## 5. Train Model

In [9]:
# Verify IDEAW files exist
import os

ideaw_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW'

print("="*50)
print("VERIFYING IDEAW FILES")
print("="*50)

# Check if directory exists
if not os.path.exists(ideaw_path):
    print(f"❌ IDEAW directory not found: {ideaw_path}")
else:
    print(f"✓ IDEAW directory exists: {ideaw_path}")

    # List all files in IDEAW
    print("\nFiles in IDEAW:")
    for item in os.listdir(ideaw_path):
        item_path = os.path.join(ideaw_path, item)
        if os.path.isdir(item_path):
            print(f"  📁 {item}/")
        else:
            print(f"  📄 {item}")

    # Check critical files
    critical_files = [
        'solver.py',
        'config.yaml',
        'metrics.py',
        'data/dataset.py',
        'data/config.yaml',
        'models/ideaw.py',
        'models/config.yaml'
    ]

    print("\nCritical files check:")
    all_exist = True
    for file in critical_files:
        file_path = os.path.join(ideaw_path, file)
        if os.path.exists(file_path):
            print(f"  ✓ {file}")
        else:
            print(f"  ❌ {file} - MISSING!")
            all_exist = False

    if all_exist:
        print("\n✅ All critical files present")
    else:
        print("\n❌ Some files are missing!")
        print("You may need to re-clone the repository")

print("="*50)


VERIFYING IDEAW FILES
✓ IDEAW directory exists: /content/drive/MyDrive/audio-watermarking-demo/research/IDEAW

Files in IDEAW:
  📄 requirements.txt
  📄 embed_extract.py
  📄 solver.py
  📄 LICENSE
  📄 train.sh
  📄 train.py
  📄 README.md
  📁 models/
  📁 _DataParallel_version/
  📄 requirements_colab.txt
  📁 __pycache__/
  📁 data/
  📄 .gitignore
  📄 metrics.py
  📄 config.yaml
  📁 output/
  📁 tmp/

Critical files check:
  ✓ solver.py
  ✓ config.yaml
  ✓ metrics.py
  ✓ data/dataset.py
  ✓ data/config.yaml
  ✓ models/ideaw.py
  ✓ models/config.yaml

✅ All critical files present


In [10]:
# Initialize solver - use Drive path
import sys
import os
import argparse

# Change to IDEAW directory on Drive
IDEAW_PATH = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW'
os.chdir(IDEAW_PATH)
sys.path.insert(0, IDEAW_PATH)

print(f"✓ Working directory: {os.getcwd()}")

from solver import Solver

# Create args object
args = argparse.Namespace(
    device=device,
    pickle_path=PICKLE_PATH,
    train_config='./config.yaml',
    store_model_path=f'{CHECKPOINT_PATH}/',
    load_model=True,  # resume from an existing checkpoint (set False for a first run)
    load_model_path=f'{CHECKPOINT_PATH}/stage_I/',  # checkpoint directory to resume from
    summary_steps=10,
    save_steps=SAVE_EVERY
)


config_data_path = './data/config.yaml'
config_model_path = './models/config.yaml'

print("Initializing solver...")
solver = Solver(config_data_path, config_model_path, args)

print("✓ Solver initialized")
print("\nStarting training...")
print("="*50)

✓ Working directory: /content/drive/MyDrive/audio-watermarking-demo/research/IDEAW
Initializing solver...
[IDEAW]infinite dataloader built
[IDEAW]model built
[IDEAW]total parameter count: 8425023
[IDEAW]optimizers built
[IDEAW]load model from /content/drive/MyDrive/audio-watermarking-demo/checkpoints/stage_I/
✓ Solver initialized

Starting training...


In [11]:
os.makedirs('./output', exist_ok=True)
os.makedirs('./tmp', exist_ok=True)

print("✓ Created output and tmp directories")
print("⚠️  Restart training from checkpoint")

✓ Created output and tmp directories
⚠️  Restart training from checkpoint


In [11]:
# Training loop
import time

start_time = time.time()

try:
    solver.train(NUM_ITERATIONS)

    training_time = time.time() - start_time
    print("\n" + "="*50)
    print("✅ TRAINING COMPLETE")
    print("="*50)
    print(f"Time: {training_time/60:.1f} minutes")
    print(f"Checkpoints saved to: {CHECKPOINT_PATH}")

except KeyboardInterrupt:
    print("\n⚠️  Training interrupted")
    print(f"Only checkpoints written before the interrupt are in {CHECKPOINT_PATH}")

except Exception as e:
    print(f"\n❌ Error: {e}")
    import traceback
    traceback.print_exc()

[IDEAW]starting training...
[IDEAW]:[10/100] Robustness=False shift=False loss_percept=0.063857 loss_integ=1.115202 loss_discr=1.350542 loss_ident=0.657412 SNR=-3.619636 acc_msg=0.521739 acc_lcode=0.800000
[IDEAW]:[20/100] Robustness=False shift=False loss_percept=0.048959 loss_integ=1.177563 loss_discr=1.358223 loss_ident=0.664613 SNR=3.226653 acc_msg=0.478261 acc_lcode=0.900000
[IDEAW]:[30/100] Robustness=False shift=True loss_percept=0.031871 loss_integ=1.132025 loss_discr=1.365473 loss_ident=0.671432 SNR=8.031628 acc_msg=0.521739 acc_lcode=0.700000
[IDEAW]:[40/100] Robustness=False shift=True loss_percept=0.075020 loss_integ=1.176194 loss_discr=1.360777 loss_ident=0.656165 SNR=15.951397 acc_msg=0.500000 acc_lcode=0.800000
[IDEAW]:[50/100] Robustness=False shift=True loss_percept=0.075395 loss_integ=1.190669 loss_discr=1.346068 loss_ident=0.654293 SNR=0.570390 acc_msg=0.478261 acc_lcode=0.700000


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  k = math.sqrt(p_s / (10 ** (self.snr / 10) * p_n))


[IDEAW]:[60/100] Robustness=True shift=True loss_percept=0.047829 loss_integ=1.377067 loss_discr=1.357157 loss_ident=0.662954 SNR=10.746403 acc_msg=0.434783 acc_lcode=0.500000
[IDEAW]:[70/100] Robustness=True shift=True loss_percept=0.041373 loss_integ=1.007230 loss_discr=1.358699 loss_ident=0.664071 SNR=14.993973 acc_msg=0.543478 acc_lcode=1.000000
[IDEAW]:[80/100] Robustness=True shift=True loss_percept=0.045261 loss_integ=1.164460 loss_discr=1.353836 loss_ident=0.659120 SNR=11.130965 acc_msg=0.478261 acc_lcode=0.800000
[IDEAW]:[90/100] Robustness=True shift=True loss_percept=0.029037 loss_integ=1.274569 loss_discr=1.371923 loss_ident=0.679292 SNR=14.167888 acc_msg=0.413043 acc_lcode=0.900000
[IDEAW]:[100/100] Robustness=True shift=True loss_percept=0.045201 loss_integ=1.151209 loss_discr=1.355892 loss_ident=0.662642 SNR=5.919544 acc_msg=0.478261 acc_lcode=1.000000

✅ TRAINING COMPLETE
Time: 3.8 minutes
Checkpoints saved to: /content/drive/MyDrive/audio-watermarking-demo/checkpoint
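The warning printed between steps 50 and 60 (its first line is cut off in this output) points at `k = math.sqrt(p_s / (10 ** (self.snr / 10) * p_n))`. `math.sqrt` needs a Python float, so a tensor that requires grad gets converted, which is what the warning is about. The rule picks `k` so that `p_s / (k**2 * p_n)` equals `10 ** (snr / 10)`. A sketch of the same rule kept in tensor form, assuming `p_s` and `p_n` are mean signal and noise powers; `add_noise_at_snr` is a hypothetical stand-in, not the IDEAW attack code:

```python
import torch


def add_noise_at_snr(signal: torch.Tensor, snr_db: float) -> torch.Tensor:
    """Add white Gaussian noise scaled so the result has the requested SNR (dB).

    Uses the scaling rule from the warning, k = sqrt(p_s / (10 ** (snr / 10) * p_n)),
    but with torch.sqrt, so no tensor is converted to a Python float.
    """
    noise = torch.randn_like(signal)
    p_s = signal.pow(2).mean()  # mean signal power
    p_n = noise.pow(2).mean()   # mean power of the unscaled noise
    k = torch.sqrt(p_s / (10 ** (snr_db / 10) * p_n))
    # To treat the noise level as a constant instead, use k = k.detach()
    return signal + k * noise
```

Whether gradients should flow through `k` is a design choice. Detaching it treats the noise level as a constant, which is what the float conversion did implicitly.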

In [15]:
# Simpler test - just check if checkpoint loads and model structure is correct
import sys
import os
import torch

sys.path.insert(0, '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW')

from models.ideaw import IDEAW

# Initialize model
model_config_path = '/content/drive/MyDrive/audio-watermarking-demo/research/IDEAW/models/config.yaml'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

ideaw = IDEAW(model_config_path, device)
ideaw = ideaw.to(device)
print("✓ IDEAW model initialized")

# Load checkpoint
checkpoint_path = '/content/drive/MyDrive/audio-watermarking-demo/checkpoints/stage_I/ideaw.ckpt'

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from: {checkpoint_path}")
    # map_location lets the GPU-trained checkpoint load on a CPU-only runtime too
    checkpoint = torch.load(checkpoint_path, map_location=device)
    ideaw.load_state_dict(checkpoint)
    ideaw.eval()
    print("✓ Checkpoint loaded successfully")

    # Count parameters
    total_params = sum(p.numel() for p in ideaw.parameters())
    print(f"✓ Model parameters: {total_params:,}")

    print("\n✅ CHECKPOINT TEST PASSED!")
    print("The model checkpoint is valid and can be loaded.")
    print("\nTo properly test watermarking:")
    print("1. Use the standalone_demo.py script")
    print("2. Or continue training to improve accuracy")

else:
    print(f"‚ùå Checkpoint not found at {checkpoint_path}")


✓ IDEAW model initialized
Loading checkpoint from: /content/drive/MyDrive/audio-watermarking-demo/checkpoints/stage_I/ideaw.ckpt
✓ Checkpoint loaded successfully
✓ Model parameters: 8,425,023

✅ CHECKPOINT TEST PASSED!
The model checkpoint is valid and can be loaded.

To properly test watermarking:
1. Use the standalone_demo.py script
2. Or continue training to improve accuracy
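If a checkpoint ever fails to load (for example after editing `models/config.yaml`), a list of the mismatched keys is more informative than the default exception. A sketch using `load_state_dict(strict=False)`, which returns the missing and unexpected keys; it assumes the checkpoint is a plain state_dict, as `stage_I/ideaw.ckpt` is in the cell above, and `load_and_report` is a hypothetical helper name:

```python
import torch
import torch.nn as nn


def load_and_report(model: nn.Module, checkpoint_path: str, device: str = 'cpu'):
    """Load a plain state_dict checkpoint and print any key mismatches.

    map_location places the tensors on `device`, so a checkpoint saved on the Colab
    GPU can also be opened on a CPU-only runtime. strict=False reports missing and
    unexpected keys instead of raising on them.
    """
    state_dict = torch.load(checkpoint_path, map_location=device)
    result = model.load_state_dict(state_dict, strict=False)
    print(f'missing keys:    {result.missing_keys}')
    print(f'unexpected keys: {result.unexpected_keys}')
    return result
```

For example, `load_and_report(ideaw, checkpoint_path, device)`. Note that `strict=False` only relaxes key matching; a tensor with the wrong shape still raises an error.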


## 6. Visualize Training Results

In [12]:
# Plot training curves
import matplotlib.pyplot as plt
import pandas as pd

log_file = f'{RESULTS_PATH}/training_log.csv'  # expects epoch/loss/snr/accuracy columns; the solver run above did not write this file

if os.path.exists(log_file):
    df = pd.read_csv(log_file)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss
    axes[0, 0].plot(df['epoch'], df['loss'])
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training Loss')
    axes[0, 0].grid(True)

    # SNR
    axes[0, 1].plot(df['epoch'], df['snr'])
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('SNR (dB)')
    axes[0, 1].set_title('Signal-to-Noise Ratio')
    axes[0, 1].grid(True)

    # Accuracy
    axes[1, 0].plot(df['epoch'], df['accuracy'])
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Accuracy (%)')
    axes[1, 0].set_title('Watermark Accuracy')
    axes[1, 0].grid(True)

    # Learning rate
    if 'learning_rate' in df.columns:
        axes[1, 1].plot(df['epoch'], df['learning_rate'])
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Learning Rate')
        axes[1, 1].set_title('Learning Rate Schedule')
        axes[1, 1].set_yscale('log')
        axes[1, 1].grid(True)

    plt.tight_layout()
    plt.savefig(f'{RESULTS_PATH}/training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("✓ Training curves saved to:", f'{RESULTS_PATH}/training_curves.png')

    # Print final metrics
    print("\nFinal Metrics:")
    print(f"  Loss: {df['loss'].iloc[-1]:.4f}")
    print(f"  SNR: {df['snr'].iloc[-1]:.2f} dB")
    print(f"  Accuracy: {df['accuracy'].iloc[-1]:.2f}%")
else:
    print("⚠️ No training log found")

⚠️ No training log found
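The solver in this run printed its metrics to the cell output instead of writing `training_log.csv`. A sketch that parses those printed `[IDEAW]:[step/total] key=value ...` lines into records, assuming the format shown in the training output above; `parse_ideaw_log` is a hypothetical helper:

```python
import re

# Matches progress lines such as
# [IDEAW]:[10/100] Robustness=False shift=False loss_percept=0.063857 ... acc_lcode=0.800000
STEP_RE = re.compile(r'\[IDEAW\]:\[(\d+)/(\d+)\]\s+(.*)')
FIELD_RE = re.compile(r'(\w+)=(\S+)')


def parse_ideaw_log(text: str) -> list:
    """Turn the solver's printed progress lines into a list of dicts."""
    records = []
    for line in text.splitlines():
        match = STEP_RE.search(line)
        if not match:
            continue  # skip warnings and other output
        record = {'step': int(match.group(1))}
        for key, value in FIELD_RE.findall(match.group(3)):
            # Robustness/shift are booleans; all other fields are numbers
            record[key] = (value == 'True') if value in ('True', 'False') else float(value)
        records.append(record)
    return records


sample = """[IDEAW]starting training...
[IDEAW]:[10/100] Robustness=False shift=False loss_percept=0.063857 loss_integ=1.115202 loss_discr=1.350542 loss_ident=0.657412 SNR=-3.619636 acc_msg=0.521739 acc_lcode=0.800000"""
rows = parse_ideaw_log(sample)
print(rows[0]['step'], rows[0]['SNR'], rows[0]['shift'])  # 10 -3.619636 False
```

One way to get the text is IPython's `%%capture train_out` magic on the training cell (which hides the live output), then `parse_ideaw_log(train_out.stdout)`. `pd.DataFrame(records).plot(x='step', y=['SNR', 'acc_msg', 'acc_lcode'])` then draws curves comparable to the plots this cell expects.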


## 7. Test Trained Model

In [None]:
# Load the trained Stage I checkpoint (a plain state_dict, as in the checkpoint test above)
best_checkpoint = f'{CHECKPOINT_PATH}/stage_I/ideaw.ckpt'

if os.path.exists(best_checkpoint):
    print("Loading trained model...")
    checkpoint = torch.load(best_checkpoint, map_location=device)
    ideaw.load_state_dict(checkpoint)
    ideaw.eval()
    print("✓ Trained model loaded")

    # Test on sample audio
    import librosa
    import numpy as np

    # Load test audio
    test_audio_path = f'{RAW_DATA_PATH}/val/test_audio.wav'  # Update with your test file

    if os.path.exists(test_audio_path):
        audio, sr = librosa.load(test_audio_path, sr=SAMPLE_RATE)
        # Assumption: test on one 1-second segment, the same length as the training data
        audio = audio[:SEGMENT_SAMPLES]
        audio_tensor = torch.FloatTensor(audio).unsqueeze(0).to(device)

        # Random message and location code. The training log is consistent with
        # 46-bit messages (acc_msg moves in steps of 1/46) and 10-bit location
        # codes (acc_lcode moves in steps of 0.1); confirm in models/config.yaml
        message = torch.randint(0, 2, (1, 46), dtype=torch.float32).to(device)
        lcode = torch.randint(0, 2, (1, 10), dtype=torch.float32).to(device)

        with torch.no_grad():
            # Embed
            audio_wmd1, _ = ideaw.embed_msg(audio_tensor, message)
            audio_wmd2, _ = ideaw.embed_lcode(audio_wmd1, lcode)

            # Extract
            mid_stft, lcode_extracted = ideaw.extract_lcode(audio_wmd2)
            message_extracted = ideaw.extract_msg(mid_stft)

            # Calculate accuracy
            msg_acc = ((message_extracted > 0.5).float() == message).float().mean().item() * 100
            lcode_acc = ((lcode_extracted > 0.5).float() == lcode).float().mean().item() * 100

            print(f"\nTest Results:")
            print(f"  Message accuracy: {msg_acc:.2f}%")
            print(f"  Location code accuracy: {lcode_acc:.2f}%")
    else:
        print(f"⚠️ Test audio not found at {test_audio_path}")
else:
    print(f"⚠️ Checkpoint not found at {best_checkpoint}")
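The test cell reports bit accuracy only. The SNR between the original and watermarked audio can be checked alongside it; a minimal NumPy sketch using the common definition `10 * log10(sum(x**2) / sum((x - x_w)**2))`, which may not match exactly how the solver computes its logged `SNR`. The helper names are hypothetical:

```python
import numpy as np


def snr_db(original: np.ndarray, watermarked: np.ndarray) -> float:
    """SNR of the watermarked signal relative to the original, in dB."""
    original = np.asarray(original, dtype=np.float64)
    residual = np.asarray(watermarked, dtype=np.float64) - original  # the embedded watermark
    return float(10 * np.log10(np.sum(original ** 2) / np.sum(residual ** 2)))


def bit_accuracy(scores: np.ndarray, bits: np.ndarray, threshold: float = 0.5) -> float:
    """Fraction of bits recovered after thresholding the extractor output."""
    decoded = (np.asarray(scores) > threshold).astype(np.int64)
    return float(np.mean(decoded == np.asarray(bits).astype(np.int64)))
```

Assuming the input and output audio tensors have the same shape, `snr_db(audio_tensor.cpu().numpy(), audio_wmd2.cpu().numpy())` and `bit_accuracy(message_extracted.cpu().numpy(), message.cpu().numpy())` would use the tensors from the cell above.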

## 8. Download Results

In [None]:
# Zip checkpoints and results into /content (the working directory is now inside the repo on Drive)
!zip -r /content/checkpoints.zip {CHECKPOINT_PATH}
!zip -r /content/results.zip {RESULTS_PATH}

print("✓ Files zipped")
print("\nYou can download:")
print("  1. /content/checkpoints.zip - Trained model weights")
print("  2. /content/results.zip - Training logs and plots")
print("\nOr access them directly from Google Drive at:")
print(f"  {DRIVE_PATH}")

In [None]:
# Optional: Download directly from Colab
from google.colab import files

# Uncomment to download
# files.download('/content/checkpoints.zip')
# files.download('/content/results.zip')
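`zip -r` stores the full `content/drive/MyDrive/...` directory chain inside the archive. A Python alternative with `shutil.make_archive` stores paths relative to the folder being zipped; a sketch with a hypothetical `zip_directory` helper:

```python
import os
import shutil


def zip_directory(src_dir: str, archive_base: str) -> str:
    """Zip src_dir into archive_base + '.zip' and return the archive path.

    root_dir=src_dir makes the paths inside the archive relative to src_dir.
    """
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(src_dir)
    return shutil.make_archive(archive_base, 'zip', root_dir=src_dir)
```

For example, `zip_directory(CHECKPOINT_PATH, '/content/checkpoints')` would write `/content/checkpoints.zip` with entries such as `stage_I/ideaw.ckpt`.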

## 9. Push Code Updates to GitHub (Optional)

In [None]:
# If you made code changes in Colab, push them back to GitHub

# Configure git (first time only)
!git config --global user.email "your.email@example.com"
!git config --global user.name "Your Name"

# Check what changed
!git status

# Add, commit, and push (uncomment to use)
# !git add .
# !git commit -m "Updated training code from Colab"
# !git push

print("\nNote: You'll need to authenticate with GitHub token if pushing")
print("Generate token at: https://github.com/settings/tokens")

## 10. Pull Latest Code Updates (Optional)

In [None]:
# If you updated code on your local machine, pull latest changes
!git pull origin main

print("✓ Code updated from GitHub")

## 11. Keep Session Alive (Optional)

Run this JavaScript in your browser console to prevent disconnection:

```javascript
function KeepAlive() {
    console.log("Keeping session alive...");
    document.querySelector("colab-connect-button").click();
}
setInterval(KeepAlive, 60000);
```