In [7]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard
import matplotlib
matplotlib.use('Agg')  # Set backend to avoid display issues
import matplotlib.pyplot as plt
import urllib.request
import tarfile
import zipfile
from pathlib import Path
import glob
import argparse
import re
import shutil
import sys
import librosa
import soundfile as sf

# Detect the hosting notebook environment. Colab is recognized by its
# `google.colab` module being loaded; Kaggle is only probed when not on Colab.
IN_COLAB = 'google.colab' in sys.modules
IN_KAGGLE = False if IN_COLAB else ('kaggle_secrets' in sys.modules)

if not IN_COLAB:
    # Any non-Colab environment is treated as Kaggle; outputs live under
    # /kaggle/working.
    BASE_DIR = Path('/kaggle')
    WORKING_DIR = BASE_DIR / 'working'
    MODEL_DIR = WORKING_DIR / 'models'
    LOG_DIR = WORKING_DIR / 'logs'
else:
    # Colab: everything is rooted at /content.
    BASE_DIR = Path('/content')
    MODEL_DIR = BASE_DIR / 'models'
    LOG_DIR = BASE_DIR / 'logs'
    WORKING_DIR = BASE_DIR / 'working'

# Make sure every output directory exists (parents included).
for _output_dir in (WORKING_DIR, MODEL_DIR, LOG_DIR):
    _output_dir.mkdir(exist_ok=True, parents=True)
del _output_dir

# Audio / STFT configuration used during preprocessing.
SAMPLE_RATE = 16000  # Hz; audio is resampled to this rate when loaded
FRAME_LENGTH = 2048  # STFT window size (n_fft), in samples
HOP_LENGTH = 512     # STFT hop between successive frames, in samples

def setup_arg_parser():
    """Parse the training command-line options from ``sys.argv``.

    ``parse_known_args`` is used, so unrecognized arguments are ignored
    instead of causing an error.

    Returns:
        argparse.Namespace with ``resume_from`` (str or None),
        ``keep_checkpoints`` (int), ``epochs`` (int) and ``batch_size`` (int).
    """
    parser = argparse.ArgumentParser(description='Train an audio denoising autoencoder')
    option_specs = (
        ('--resume-from', str, None, 'Checkpoint file to resume training from'),
        ('--keep-checkpoints', int, 3, 'Number of recent checkpoints to keep'),
        ('--epochs', int, 100, 'Number of epochs to train'),
        ('--batch-size', int, 32, 'Batch size for training'),
    )
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    known_args, _unrecognized = parser.parse_known_args()
    return known_args

def setup_gpus():
    """Choose a ``tf.distribute`` strategy based on the visible GPUs.

    Returns:
        ``False`` when no GPU is visible (CPU only); a ``MirroredStrategy``
        when more than one GPU is visible; otherwise the default strategy
        from ``tf.distribute.get_strategy()``.
    """
    gpu_devices = tf.config.list_physical_devices('GPU')
    gpu_count = len(gpu_devices)

    if gpu_count == 0:
        print("No GPUs found. Running on CPU.")
        return False

    print(f"Found {gpu_count} GPU(s):")
    for device in gpu_devices:
        print(f"  - {device.name}")

    if gpu_count == 1:
        print("Using default strategy (single GPU)")
        return tf.distribute.get_strategy()

    # Several GPUs: replicate the model on each one.
    mirrored = tf.distribute.MirroredStrategy()
    print(f"Using MirroredStrategy with {mirrored.num_replicas_in_sync} devices")
    return mirrored

def build_audio_denoising_autoencoder(strategy=None, input_shape=(128, 128, 1), learning_rate=1e-4):
    """
    Build and compile a 2D CNN denoising autoencoder for spectrograms.

    Encoder: three stages of Conv2D(3x3, relu) + BatchNorm + 2x2 MaxPool with
    32, 64 and 128 filters, then a 256-filter Conv2D + BatchNorm bottleneck.
    Decoder: three stride-2 Conv2DTranspose(3x3, relu) + BatchNorm stages with
    128, 64 and 32 filters, followed by a 1x1 linear Conv2D to one channel.

    Args:
        strategy: Optional ``tf.distribute`` strategy. When truthy, the model
            and its optimizer are created inside ``strategy.scope()``;
            otherwise they are created in the current default context.
        input_shape: ``(height, width, channels)`` of one example. Known
            height/width must be divisible by 8 (three 2x down/up-sampling
            stages), otherwise the output size would differ from the input.
            ``None`` dimensions are allowed and not checked.
        learning_rate: Adam learning rate, used identically whether or not a
            strategy is given.

    Returns:
        A compiled Keras ``Model`` (MSE loss, MAE metric).

    Raises:
        ValueError: if a known spatial dimension is not divisible by 8.
    """
    import contextlib

    for dim in input_shape[:2]:
        if dim is not None and dim % 8 != 0:
            raise ValueError(
                f"Spatial dimensions of input_shape must be divisible by 8, got {input_shape}"
            )

    # Same architecture and optimizer on every path; only the variable
    # placement scope differs.
    scope = strategy.scope() if strategy else contextlib.nullcontext()
    with scope:
        inputs = layers.Input(shape=input_shape)
        x = inputs

        # Encoder: conv -> batchnorm -> 2x downsample, three times.
        for filters in (32, 64, 128):
            x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.MaxPooling2D((2, 2))(x)

        # Bottleneck
        x = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)
        x = layers.BatchNormalization()(x)

        # Decoder: stride-2 transposed conv -> batchnorm, three times.
        for filters in (128, 64, 32):
            x = layers.Conv2DTranspose(filters, (3, 3), activation='relu', padding='same', strides=(2, 2))(x)
            x = layers.BatchNormalization()(x)

        # Output layer: single-channel linear projection.
        outputs = layers.Conv2D(1, (1, 1), activation='linear')(x)

        model = models.Model(inputs, outputs)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss='mse',
            metrics=['mae'],
        )

    return model

def download_vctk_dataset():
    """
    Locate the VCTK corpus in the fixed Kaggle input directory.

    Nothing is downloaded: the function only checks that the corpus is
    already present at ``/kaggle/input/vctk-corpus/VCTK-Corpus``.

    Returns:
        pathlib.Path of the corpus directory.

    Raises:
        ValueError: if the directory does not exist.
    """
    corpus_dir = Path('/kaggle/input/vctk-corpus/VCTK-Corpus')

    if corpus_dir.exists():
        print(f"Using VCTK dataset from: {corpus_dir}")
        return corpus_dir

    raise ValueError(f"VCTK dataset not found at {corpus_dir}. Please ensure the dataset is uploaded.")
    
def prepare_audio_data(data_path, target_sr=SAMPLE_RATE, max_files=500, spec_height=128, spec_width=128, batch_size=32):
    """
    Build noisy/clean spectrogram pairs from the VCTK corpus and cache them.

    Every ``*.wav`` file under ``<data_path>/wav48/<speaker>/`` is collected.
    The list is shuffled *before* being truncated to ``max_files``, so the
    subset is drawn from the whole corpus rather than only the first speaker
    directories. It is then split 80/20 into training and validation files.

    For each file:
    - the waveform is loaded at ``target_sr``;
    - a noisy copy is made by adding Gaussian noise (std 0.05);
    - both signals become dB-scaled STFT magnitudes
      (``n_fft=FRAME_LENGTH``, ``hop_length=HOP_LENGTH``);
    - each spectrogram is min-max normalized to [0, 1], transposed to
      (frames, frequency bins), and cropped / zero-padded to
      ``(spec_height, spec_width)``.

    Files that fail to process are reported and skipped.

    Args:
        data_path: Root of the VCTK corpus; must contain a ``wav48`` directory.
        target_sr: Sample rate passed to ``librosa.load``.
        max_files: Maximum number of wav files used across both splits.
        spec_height: Number of STFT frames (time axis) kept per example.
        spec_width: Number of frequency bins kept per example, starting at
            the lowest bin.
        batch_size: Accepted for backward compatibility; it does not affect
            the returned arrays.

    Returns:
        ``(train_noisy, train_clean, val_noisy, val_clean)``, each shaped
        ``(N, spec_height, spec_width, 1)``. The same arrays are also saved as
        ``*_specs.npy`` files in ``WORKING_DIR``.

    Raises:
        ValueError: if the training or validation split ends up with no
            successfully processed files.
    """
    wav_path = Path(data_path) / 'wav48'

    def resize_spectrogram(spec, target_height, target_width):
        """Crop and/or zero-pad a 2D array to exactly (target_height, target_width)."""
        if spec.ndim > 2:
            spec = spec.squeeze()
        spec = spec[:target_height, :target_width]
        pad_height = target_height - spec.shape[0]
        pad_width = target_width - spec.shape[1]
        if pad_height or pad_width:
            spec = np.pad(
                spec,
                ((0, pad_height), (0, pad_width)),
                mode='constant',
                constant_values=0,
            )
        return spec

    def normalized_db_spectrogram(signal):
        """Return the [0, 1]-normalized dB STFT magnitude of a 1D signal."""
        magnitude = np.abs(librosa.stft(signal, n_fft=FRAME_LENGTH, hop_length=HOP_LENGTH))
        spec_db = librosa.amplitude_to_db(magnitude, ref=np.max)
        low, high = spec_db.min(), spec_db.max()
        if high <= low:
            # Constant spectrogram (e.g. an all-silent file): the min-max
            # denominator would be zero and yield NaNs, so map it to zeros.
            return np.zeros_like(spec_db)
        return (spec_db - low) / (high - low)

    def spectrogram_generator(file_list, noise_factor=0.05):
        """Yield (noisy, clean) resized spectrograms, skipping files that fail."""
        for file_path in file_list:
            try:
                audio, _ = librosa.load(str(file_path), sr=target_sr)
                noisy_audio = audio + noise_factor * np.random.normal(0, 1, len(audio))
                clean_spec = normalized_db_spectrogram(audio)
                noisy_spec = normalized_db_spectrogram(noisy_audio)
                pair = (
                    resize_spectrogram(noisy_spec.T, spec_height, spec_width),
                    resize_spectrogram(clean_spec.T, spec_height, spec_width),
                )
            except Exception as e:
                # Best effort: one bad file should not abort preprocessing.
                print(f"Error processing {file_path}: {e}")
                continue
            yield pair

    def stack_split(file_list, split_name):
        """Process a file list into (noisy, clean) arrays shaped (N, H, W, 1)."""
        pairs = list(spectrogram_generator(file_list))
        if not pairs:
            raise ValueError(
                f"No usable audio files in the {split_name} split "
                f"(files in split: {len(file_list)}, max_files={max_files})."
            )
        noisy, clean = zip(*pairs)
        return np.stack(noisy)[..., np.newaxis], np.stack(clean)[..., np.newaxis]

    # Sorted listings make the file order independent of the filesystem's
    # iteration order.
    wav_files = [
        wav_file
        for speaker_dir in sorted(wav_path.iterdir())
        if speaker_dir.is_dir()
        for wav_file in sorted(speaker_dir.glob('*.wav'))
    ]
    print(f"Total audio files found: {len(wav_files)}")

    # Shuffle first, then limit, so the subset spans many speakers.
    np.random.shuffle(wav_files)
    wav_files = wav_files[:max_files]

    # 80/20 train/validation split.
    split_idx = int(len(wav_files) * 0.8)
    train_files, val_files = wav_files[:split_idx], wav_files[split_idx:]

    train_noisy, train_clean = stack_split(train_files, 'training')
    val_noisy, val_clean = stack_split(val_files, 'validation')

    print(f"Training spectrograms shape: {train_noisy.shape}")
    print(f"Validation spectrograms shape: {val_noisy.shape}")

    # Cache the preprocessed arrays so later runs can skip this step.
    np.save(WORKING_DIR / 'train_noisy_specs.npy', train_noisy)
    np.save(WORKING_DIR / 'train_clean_specs.npy', train_clean)
    np.save(WORKING_DIR / 'val_noisy_specs.npy', val_noisy)
    np.save(WORKING_DIR / 'val_clean_specs.npy', val_clean)

    return train_noisy, train_clean, val_noisy, val_clean

def check_for_existing_data():
    """Load cached spectrogram arrays from ``WORKING_DIR`` if all are present.

    Returns:
        ``(train_noisy, train_clean, val_noisy, val_clean)`` loaded with
        ``np.load`` when all four ``*_specs.npy`` files exist; otherwise a
        tuple of four ``None`` values.
    """
    cache_names = (
        'train_noisy_specs.npy',
        'train_clean_specs.npy',
        'val_noisy_specs.npy',
        'val_clean_specs.npy',
    )
    cache_files = [WORKING_DIR / name for name in cache_names]

    if not all(cache_file.exists() for cache_file in cache_files):
        return None, None, None, None

    print("Found preprocessed spectrograms")
    return tuple(np.load(cache_file) for cache_file in cache_files)

def train_model(model, train_noisy, train_clean, val_noisy, val_clean, args, initial_epoch=None):
    """
    Train the denoiser, then save the final model, history and training curves.

    Args:
        model: Compiled Keras model.
        train_noisy, train_clean: Training inputs and targets.
        val_noisy, val_clean: Validation inputs and targets.
        args: Parsed options. Reads ``epochs``, ``batch_size``,
            ``keep_checkpoints``, ``resume_from`` and, if present,
            ``initial_epoch``.
        initial_epoch: Epoch index to resume from (epochs already completed),
            passed to ``Model.fit``. When None, ``args.initial_epoch`` is used
            if set, otherwise 0. ``args.epochs`` is the final epoch index, as
            in ``Model.fit``.

    Returns:
        ``(model, history)`` where ``history`` is the object returned by ``fit``.

    Files written:
    - best-so-far weight checkpoints in ``MODEL_DIR``;
    - TensorBoard logs in ``LOG_DIR``;
    - the final model (``.keras`` and ``.h5``);
    - ``training_history.npy`` and ``training_curves.png`` in ``MODEL_DIR``.
    """
    if initial_epoch is None:
        initial_epoch = getattr(args, 'initial_epoch', 0) or 0
    epochs = args.epochs
    batch_size = args.batch_size

    # With several GPUs, use at least 32 samples per GPU and make the global
    # batch a multiple of the GPU count so it divides evenly across replicas.
    gpus = len(tf.config.list_physical_devices('GPU'))
    if gpus > 1:
        batch_size = max(batch_size, 32 * gpus)
        batch_size -= batch_size % gpus

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        filepath=MODEL_DIR / "audio_denoiser.{epoch:02d}-{val_loss:.4f}.weights.h5",
        save_weights_only=True,
        save_best_only=True,
        monitor='val_loss'
    )

    tensorboard_callback = TensorBoard(
        log_dir=LOG_DIR,
        histogram_freq=1,
        write_graph=True,
        update_freq='epoch'
    )

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=15,
        restore_best_weights=True
    )

    # Prunes old checkpoint files periodically during training.
    checkpoint_manager_callback = CheckpointManagerCallback(
        args.keep_checkpoints,
        args.resume_from
    )

    history = model.fit(
        train_noisy, train_clean,
        epochs=epochs,
        initial_epoch=initial_epoch,
        batch_size=batch_size,
        validation_data=(val_noisy, val_clean),
        callbacks=[
            checkpoint_callback,
            early_stopping,
            tensorboard_callback,
            checkpoint_manager_callback,
        ],
    )

    # Save final model in both the native Keras and legacy HDF5 formats.
    model.save(MODEL_DIR / "audio_denoiser_final.keras")
    model.save(MODEL_DIR / "audio_denoiser_final.h5")

    # history.history is a dict, so np.save stores it as a pickled object array.
    np.save(MODEL_DIR / 'training_history.npy', history.history)

    _plot_training_curves(history.history, MODEL_DIR / 'training_curves.png')

    return model, history


def _plot_training_curves(history_dict, output_path):
    """Save side-by-side train/validation loss and MAE curves to ``output_path``."""
    panels = (
        ('loss', 'Model Loss', 'Loss'),
        ('mae', 'Mean Absolute Error', 'MAE'),
    )
    fig = plt.figure(figsize=(12, 4))
    try:
        for position, (metric, title, ylabel) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            # .get(): if fit ran zero epochs the history may be empty.
            plt.plot(history_dict.get(metric, []))
            plt.plot(history_dict.get(f'val_{metric}', []))
            plt.title(title)
            plt.xlabel('Epoch')
            plt.ylabel(ylabel)
            plt.legend(['Train', 'Validation'], loc='upper right')
        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)

def test_model(model, sample=None, noise_std=0.1):
    """
    Run the denoiser on one spectrogram and save an original/noisy/denoised figure.

    Args:
        model: Keras model exposing ``predict``.
        sample: Optional clean spectrogram to test on, shaped ``(H, W)``,
            ``(H, W, 1)`` or ``(N, H, W, 1)``; only the first example is used.
            When None, a uniform random ``(1, 128, 128, 1)`` array is used.
        noise_std: Standard deviation of the Gaussian noise added to the sample.

    The figure is written to ``MODEL_DIR / 'test_denoising_results.png'``.
    """
    if sample is None:
        test_spec = np.random.random((1, 128, 128, 1))
    else:
        test_spec = np.asarray(sample)
        if test_spec.ndim == 2:
            test_spec = test_spec[np.newaxis, :, :, np.newaxis]
        elif test_spec.ndim == 3:
            test_spec = test_spec[np.newaxis]
        test_spec = test_spec[:1]

    # Add noise
    noisy_spec = test_spec + noise_std * np.random.normal(0, 1, test_spec.shape)

    # Predict denoised spectrogram
    denoised_spec = model.predict(noisy_spec)

    panels = (('Original', test_spec), ('Noisy', noisy_spec), ('Denoised', denoised_spec))
    fig = plt.figure(figsize=(15, 5))
    try:
        for position, (title, spec) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(spec[0, :, :, 0], cmap='viridis')
            plt.title(title)
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(MODEL_DIR / 'test_denoising_results.png')
    finally:
        # Close the figure so it is not kept alive by pyplot.
        plt.close(fig)
    print("Test results saved to test_denoising_results.png")

def check_for_checkpoint(initial_checkpoint=None, strategy=None):
    """
    Return ``(model, initial_epoch)``, resuming from a checkpoint when possible.

    If ``initial_checkpoint`` is given, ``load_model_from_checkpoint`` is tried.
    Any error is printed, and the function then falls back to a brand-new model.
    Without a usable checkpoint, ``create_and_save_model`` builds the model and
    the starting epoch is 0.
    """
    if not initial_checkpoint:
        return create_and_save_model(strategy), 0

    try:
        restored_model, start_epoch = load_model_from_checkpoint(initial_checkpoint, strategy)
        return restored_model, start_epoch
    except Exception as e:
        print(f"Error loading checkpoint: {e}. Starting from scratch.")

    return create_and_save_model(strategy), 0

def create_and_save_model(strategy=None):
    """
    Build a fresh autoencoder, print its summary and persist its initial state.

    Writes the following to ``MODEL_DIR``:
    - the architecture as JSON (``audio_denoiser_architecture.json``);
    - the initial weights (``audio_denoiser_initial.weights.h5``);
    - the full model as ``audio_denoiser_initial.keras`` and ``.h5``.

    Returns:
        The newly built, compiled model.
    """
    fresh_model = build_audio_denoising_autoencoder(strategy)
    fresh_model.summary()

    architecture_json = fresh_model.to_json()
    with open(MODEL_DIR / "audio_denoiser_architecture.json", "w") as arch_file:
        arch_file.write(architecture_json)

    # Weights first, then the full model in native Keras and legacy HDF5 formats.
    savers = (
        (fresh_model.save_weights, "audio_denoiser_initial.weights.h5"),
        (fresh_model.save, "audio_denoiser_initial.keras"),
        (fresh_model.save, "audio_denoiser_initial.h5"),
    )
    for save, filename in savers:
        save(MODEL_DIR / filename)

    return fresh_model

def load_model_from_checkpoint(checkpoint_path, strategy=None):
    """
    Rebuild the autoencoder and load its weights from ``checkpoint_path``.

    The starting epoch is parsed from the file's base name: the digits
    between a ``.`` and the following ``-`` (for example
    ``audio_denoiser.12-0.0345.weights.h5`` gives 12). Names without that
    pattern give epoch 0.

    Returns:
        ``(model, initial_epoch)``
    """
    restored = build_audio_denoising_autoencoder(strategy)
    restored.load_weights(checkpoint_path)

    found = re.search(r'\.(\d+)-', os.path.basename(checkpoint_path))
    if found is None:
        return restored, 0
    return restored, int(found.group(1))

def manage_checkpoints(keep_count=3, started_checkpoint=None):
    """
    Delete all but the ``keep_count`` most recent epoch checkpoints in ``MODEL_DIR``.

    Candidates are files matching ``audio_denoiser.*.weights.h5`` whose name
    carries an epoch number (``.<digits>-``); other files are never touched.

    Recency is decided by file modification time, with the epoch number as a
    tie-breaker. As a result, when a resumed run restarts its epoch numbering,
    the newest checkpoints are kept rather than stale high-epoch files from an
    earlier run.

    The checkpoint training started from (matched by base name) is never
    deleted and does not count toward ``keep_count``.

    Args:
        keep_count: Number of checkpoints to keep; negative values act as 0.
        started_checkpoint: Path of the checkpoint training resumed from, or None.
    """
    keep_count = max(int(keep_count), 0)
    checkpoint_pattern = str(MODEL_DIR / "audio_denoiser.*.weights.h5")
    checkpoints = glob.glob(checkpoint_pattern)

    if len(checkpoints) <= keep_count:
        return

    protected_name = os.path.basename(started_checkpoint) if started_checkpoint else None

    candidates = []
    for cp in checkpoints:
        name = os.path.basename(cp)
        if name == protected_name:
            continue

        epoch_match = re.search(r'\.(\d+)-', name)
        if not epoch_match:
            continue

        try:
            mtime = os.path.getmtime(cp)
        except OSError:
            # File disappeared between glob and stat; nothing to manage.
            continue
        candidates.append((mtime, int(epoch_match.group(1)), cp))

    # Newest first: mtime, then epoch, then path for a deterministic order.
    candidates.sort(reverse=True)

    for _, _, cp in candidates[keep_count:]:
        try:
            os.remove(cp)
        except FileNotFoundError:
            # Already removed (e.g. by a concurrent cleanup); the goal is met.
            pass

class CheckpointManagerCallback(tf.keras.callbacks.Callback):
    """Keras callback that periodically prunes checkpoint files during training.

    Whenever the epoch index passed to ``on_epoch_end`` is a multiple of 5,
    it calls ``manage_checkpoints`` with the configured keep count and the
    checkpoint training was started from.
    """

    def __init__(self, keep_count, started_checkpoint):
        super().__init__()
        # Number of checkpoints to keep, forwarded to manage_checkpoints.
        self.keep_count = keep_count
        # Checkpoint path training resumed from (or None), forwarded as-is.
        self.started_checkpoint = started_checkpoint

    def on_epoch_end(self, epoch, logs=None):
        if epoch % 5:
            return
        manage_checkpoints(self.keep_count, self.started_checkpoint)

def main():
    """Run the full pipeline end to end.

    Steps, in order:
    1. Parse options and pick a distribution strategy.
    2. Load cached spectrograms, or build them from the VCTK corpus.
    3. Create a new model or resume one from a checkpoint.
    4. Train the model.
    5. Save a denoising visualization.
    6. Prune old checkpoints.
    """
    print("Starting Audio Denoising Autoencoder setup...")

    # Parse arguments
    args = setup_arg_parser()

    # Setup GPU strategy
    strategy = setup_gpus()

    # Check for existing preprocessed data
    train_noisy, train_clean, val_noisy, val_clean = check_for_existing_data()

    if train_noisy is None:
        # Locate dataset
        data_path = download_vctk_dataset()
        print(f"Dataset available at: {data_path}")

        # Prepare data
        train_noisy, train_clean, val_noisy, val_clean = prepare_audio_data(data_path)

    # Check for checkpoint and create/load model
    model, initial_epoch = check_for_checkpoint(args.resume_from, strategy)
    # Keep the resume epoch on args so training code that reads
    # `args.initial_epoch` can continue the checkpoint's epoch numbering
    # instead of restarting at epoch 1.
    args.initial_epoch = initial_epoch

    # Train model
    model, _ = train_model(
        model,
        train_noisy, train_clean,
        val_noisy, val_clean,
        args
    )

    # Test model
    test_model(model)

    # Final checkpoint management
    manage_checkpoints(args.keep_checkpoints, args.resume_from)

    print("\nSetup complete!")
    print(f"Model files saved to: {MODEL_DIR}")
    print(f"Logs saved to: {LOG_DIR}")

if __name__ == "__main__":
    # Enable memory growth on every visible GPU before training starts.
    # A failure on one device is reported and does not stop the run.
    for gpu_device in tf.config.list_physical_devices('GPU'):
        try:
            tf.config.experimental.set_memory_growth(gpu_device, True)
        except Exception as err:
            print(f"Could not set memory growth for {gpu_device}: {err}")

    main()

Starting Audio Denoising Autoencoder setup...
Found 2 GPU(s):
  - /physical_device:GPU:0
  - /physical_device:GPU:1
Using MirroredStrategy with 2 devices
Using VCTK dataset from: /kaggle/input/vctk-corpus/VCTK-Corpus
Dataset available at: /kaggle/input/vctk-corpus/VCTK-Corpus
Total audio files found: 44242
Training spectrograms shape: (400, 128, 128, 1)
Validation spectrograms shape: (100, 128, 128, 1)


Epoch 1/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 549ms/step - loss: 1.2449 - mae: 0.8630 - val_loss: 0.1578 - val_mae: 0.3054
Epoch 2/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 201ms/step - loss: 0.6986 - mae: 0.6568 - val_loss: 0.1589 - val_mae: 0.3074
Epoch 3/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 197ms/step - loss: 0.4038 - mae: 0.5055 - val_loss: 0.1598 - val_mae: 0.3091
Epoch 4/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 209ms/step - loss: 0.2654 - mae: 0.4124 - val_loss: 0.1609 - val_mae: 0.3114
Epoch 5/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 202ms/step - loss: 0.1841 - mae: 0.3460 - val_loss: 0.1638 - val_mae: 0.3161
Epoch 6/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 204ms/step - loss: 0.1479 - mae: 0.3120 - val_loss: 0.1668 - val_mae: 0.3206
Epoch 7/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 193ms/step - loss: 