In [None]:
# Colab only: attach Google Drive so files stored there can be read and written.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [None]:
import os
import numpy as np
import librosa
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, Dropout,
    BatchNormalization, Cropping2D, ZeroPadding2D
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
import matplotlib.pyplot as plt


# Root folder for the cached log-magnitude spectrograms (.npy), with one
# sub-folder per signal type.
preprocessed_data_folder = '/content/drive/MyDrive/Unet_practice/preprocessed_data1024'
clean_preprocessed_folder = os.path.join(preprocessed_data_folder, 'clean')
noisy_preprocessed_folder = os.path.join(preprocessed_data_folder, 'noisy')

# Fixed divisors applied to the noisy (X) and clean (Y) features in load_npy.
# NOTE(review): these are hard-coded, not measured from the data -- confirm
# that 10.0 really bounds the log1p magnitudes of the training set.
X_train_max = 10.0
Y_train_max = 10.0


# STFT / audio parameters shared by preprocessing and the model input shape.
n_fft = 1024
hop_length = 256
win_length = 1024
max_duration = 4  # seconds
sr = 16000
max_length = sr * max_duration  # 64000 samples

# Spectrogram size every sample is forced to: (frequency bins, time frames).
fixed_height = n_fft // 2 + 1  # 1024//2 +1 = 513
# NOTE(review): this is the frame count of an un-centred STFT. compute_log_stft
# calls librosa.stft without `center=`, and librosa documents center=True as the
# default, which would give 1 + max_length // hop_length = 251 frames. In that
# case every spectrogram is cut down to fixed_width by librosa.util.fix_length
# in the preprocessing loops -- confirm the truncation is intended.
fixed_width = (max_length - win_length) // hop_length + 1  # (64000 - 1024)//256 +1 = 247

print(f"Fixed Height (Frequency Dimension): {fixed_height}")  #  513
print(f"Fixed Width (Time Steps): {fixed_width}")  # 247



def compute_log_stft(audio_path, sr=16000, n_fft=1024, hop_length=256, win_length=1024, max_duration=4):
    """
    Load an audio file, force it to a fixed duration and return log(1 + |STFT|).

    Parameters:
        audio_path (str): Path to the audio file.
        sr (int): Sampling rate the audio is resampled to on load.
        n_fft (int): FFT size passed to the STFT.
        hop_length (int): Hop length (in samples) between STFT frames.
        win_length (int): Window length (in samples) for the STFT.
        max_duration (int): Target duration in seconds; longer audio is
            truncated, shorter audio is zero-padded at the end.

    Returns:
        numpy.ndarray: 2-D array holding log1p of the STFT magnitude.
    """
    # Mono signal resampled to the requested rate.
    samples, sr = librosa.load(audio_path, sr=sr, mono=True)

    # Bring the signal to exactly sr * max_duration samples.
    target_len = sr * max_duration
    shortfall = target_len - len(samples)
    if shortfall < 0:
        samples = samples[:target_len]
    else:
        samples = np.pad(samples, (0, shortfall), mode='constant')

    spectrum = librosa.stft(samples, n_fft=n_fft, hop_length=hop_length, win_length=win_length)

    # log1p keeps silent bins at exactly 0 and compresses the dynamic range.
    return np.log1p(np.abs(spectrum))


# Folders holding the raw recordings that will be converted to spectrograms.
clean_audio_folder = '/content/drive/MyDrive/Unet_practice/audio/clean_train'
noisy_audio_folder = '/content/drive/MyDrive/Unet_practice/audio/noisy_train'

# Make sure both output folders exist before anything is written to them.
for output_folder in (clean_preprocessed_folder, noisy_preprocessed_folder):
    os.makedirs(output_folder, exist_ok=True)

# Only .mp3 files are picked up; sorting gives a stable processing order.
clean_audio_files = sorted(name for name in os.listdir(clean_audio_folder) if name.endswith('.mp3'))
noisy_audio_files = sorted(name for name in os.listdir(noisy_audio_folder) if name.endswith('.mp3'))

print(f"Clean audio files: {len(clean_audio_files)}")
print(f"Noisy audio files: {len(noisy_audio_files)}")



# Convert every clean recording to a cached log-magnitude STFT (.npy).
# Files whose .npy already exists are skipped, so the cell can be re-run.
clean_total = len(clean_audio_files)
for clean_position, clean_name in enumerate(clean_audio_files, start=1):
    clean_source_path = os.path.join(clean_audio_folder, clean_name)
    clean_target_path = os.path.join(
        clean_preprocessed_folder, f"{os.path.splitext(clean_name)[0]}.npy"
    )

    # Already converted on a previous run.
    if os.path.exists(clean_target_path):
        print(f"Clean file {clean_position}/{clean_total} 已經有了，跳過處理。")
        continue

    try:
        clean_spectrogram = compute_log_stft(
            clean_source_path,
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            max_duration=max_duration,
        )

        # Force the time axis to fixed_width (pads or truncates along axis 1).
        if clean_spectrogram.shape != (fixed_height, fixed_width):
            print(f"调整 clean_log_magnitude 的shape {clean_spectrogram.shape} 到 ({fixed_height}, {fixed_width})")
            clean_spectrogram = librosa.util.fix_length(clean_spectrogram, size=fixed_width, axis=1)

        np.save(clean_target_path, clean_spectrogram)
        print(f"Clean file {clean_position}/{clean_total} 處理完且已經保存了")

    except Exception as e:
        # Best effort: report the failing file and carry on with the rest.
        print(f"處理 {clean_name} 時出錯：{e}")

print("所有clean files 轉乘STFT了")

# Same conversion for the noisy recordings: cache each one as a
# log-magnitude STFT (.npy) and skip files that were already converted.
noisy_total = len(noisy_audio_files)
for noisy_position, noisy_name in enumerate(noisy_audio_files, start=1):
    noisy_source_path = os.path.join(noisy_audio_folder, noisy_name)
    noisy_target_path = os.path.join(
        noisy_preprocessed_folder, f"{os.path.splitext(noisy_name)[0]}.npy"
    )

    # Already converted on a previous run.
    if os.path.exists(noisy_target_path):
        print(f"Noisy file {noisy_position}/{noisy_total} 已經有了，跳過處理。")
        continue

    try:
        noisy_spectrogram = compute_log_stft(
            noisy_source_path,
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            max_duration=max_duration,
        )

        # Force the time axis to fixed_width (pads or truncates along axis 1).
        if noisy_spectrogram.shape != (fixed_height, fixed_width):
            print(f"调整 noisy_log_magnitude 的形状从 {noisy_spectrogram.shape} 到 ({fixed_height}, {fixed_width})")
            noisy_spectrogram = librosa.util.fix_length(noisy_spectrogram, size=fixed_width, axis=1)

        np.save(noisy_target_path, noisy_spectrogram)
        print(f"Noisy file {noisy_position}/{noisy_total}  處理完且已經保存了")

    except Exception as e:
        # Best effort: report the failing file and carry on with the rest.
        print(f"處理 {noisy_name} 時出錯：{e}")

print("所有noisy files 轉乘STFT了")


# Cached spectrograms available on disk, in a stable (sorted) order.
clean_npy_files = sorted(f for f in os.listdir(clean_preprocessed_folder) if f.endswith('.npy'))
noisy_npy_files = sorted(f for f in os.listdir(noisy_preprocessed_folder) if f.endswith('.npy'))

print(f"Clean preprocessed files: {len(clean_npy_files)}")
print(f"Noisy preprocessed files: {len(noisy_npy_files)}")


# Map each clean file's numeric identifier to its path. The identifier is the
# digits of the first underscore-separated part of the file name that contains
# at least one digit.
clean_mapping = {}
for f in clean_npy_files:
    base_name = os.path.splitext(f)[0]
    digit_part = next(
        (part for part in base_name.split('_') if any(char.isdigit() for char in part)),
        None,
    )
    if digit_part is None:
        # No digits anywhere in the name: this file cannot be paired.
        print(f"沒法提取{f}")
        continue

    identifier = ''.join(filter(str.isdigit, digit_part))
    if identifier in clean_mapping:
        # Two clean files collapsing onto one identifier would otherwise be
        # replaced silently, so every noisy file with that identifier would be
        # paired with whichever clean file happened to come last.
        print(
            f"WARNING: identifier {identifier!r} of {f} is already used by "
            f"{os.path.basename(clean_mapping[identifier])}; keeping {f}"
        )
    clean_mapping[identifier] = os.path.join(clean_preprocessed_folder, f)

print(f"clean files num: {len(clean_mapping)}")

# Pair every noisy spectrogram with the clean spectrogram that shares its
# numeric identifier. The two lists are index-aligned.
paired_noisy_paths = []
paired_clean_paths = []

# Noisy files for which no clean counterpart could be found.
missing_clean_files = []

for f in noisy_npy_files:
    base_name = os.path.splitext(f)[0]
    # For noisy files the identifier is the digits of the first
    # underscore-separated part of the name.
    identifier = ''.join(filter(str.isdigit, base_name.split('_')[0]))

    # str.split never returns an empty list, so the extracted identifier (not
    # the list of parts) is what has to be tested to detect a name without a
    # usable number.
    if not identifier:
        print(f"沒法提取 {f}")
        missing_clean_files.append(f)
        continue

    if identifier in clean_mapping:
        paired_noisy_paths.append(os.path.join(noisy_preprocessed_folder, f))
        paired_clean_paths.append(clean_mapping[identifier])
    else:
        missing_clean_files.append(f)

# Report noisy files that could not be paired.
if missing_clean_files:
    print("these files didnt match ")
    for f in missing_clean_files:
        print(f)
else:
    print("ALL MATCHED!")

print(f"有效match num: {len(paired_noisy_paths)}")

# Sanity check: the two path lists must stay index-aligned.
assert len(paired_noisy_paths) == len(paired_clean_paths), "clean 的數量和noisy 不一樣多"

# Show the first few pairs so the matching can be checked by eye.
print(" first five files for checking out!")
for i in range(min(5, len(paired_noisy_paths))):
    noisy_file = os.path.basename(paired_noisy_paths[i])
    clean_file = os.path.basename(paired_clean_paths[i])
    print(f"noisy: {noisy_file} <--> clean: {clean_file}")

def load_npy(noisy_path, clean_path):
    """
    Load one (noisy, clean) spectrogram pair from .npy files.

    Meant to be called through tf.py_function, so both arguments are eager
    string tensors holding file paths.

    Returns:
        tuple: (noisy, clean) float32 arrays with a trailing channel axis,
        divided by X_train_max and Y_train_max respectively. If anything
        fails, the error is printed and two all-zero arrays of shape
        (fixed_height, fixed_width, 1) are returned instead.
    """
    try:
        # Tensor -> bytes -> str
        noisy_path = noisy_path.numpy().decode('utf-8')
        clean_path = clean_path.numpy().decode('utf-8')

        noisy = np.load(noisy_path)
        clean = np.load(clean_path)

        # Pad or truncate the time axis when a cached file has another width.
        expected_shape = (fixed_height, fixed_width)
        if noisy.shape != expected_shape:
            print(f"adjust noisy's shape from  {noisy.shape} to ({fixed_height}, {fixed_width})")
            noisy = librosa.util.fix_length(noisy, size=fixed_width, axis=1)
        if clean.shape != expected_shape:
            print(f"adjust clean's shape from {clean.shape} to ({fixed_height}, {fixed_width})")
            clean = librosa.util.fix_length(clean, size=fixed_width, axis=1)

        # Append the channel axis expected by the Conv2D layers.
        noisy = np.expand_dims(noisy.astype(np.float32), axis=-1)
        clean = np.expand_dims(clean.astype(np.float32), axis=-1)

        # Scale by the fixed normalisation constants.
        noisy /= X_train_max
        clean /= Y_train_max

        return noisy, clean

    except Exception as e:
        print(f"load error: {noisy_path} 或 {clean_path} -> {e}")
        # Fall back to an all-zero pair so the input pipeline keeps running.
        fallback_shape = (fixed_height, fixed_width, 1)
        return (
            np.zeros(fallback_shape, dtype=np.float32),
            np.zeros(fallback_shape, dtype=np.float32),
        )

# load
def tf_load_npy(noisy_path, clean_path):
    """
    tf.data-compatible wrapper around load_npy.

    Runs load_npy through tf.py_function and restores the static shape
    (fixed_height, fixed_width, 1) on both outputs, since py_function
    results carry no shape information.

    Returns:
        tuple: (noisy, clean) float32 tensors.
    """
    loaded = tf.py_function(
        func=load_npy,
        inp=[noisy_path, clean_path],
        Tout=[tf.float32, tf.float32],
    )

    static_shape = [fixed_height, fixed_width, 1]
    for tensor in loaded:
        tensor.set_shape(static_shape)

    noisy, clean = loaded
    return noisy, clean

# Build the training / validation input pipelines.
#
# The split is done on the file paths, BEFORE batching. take()/skip() on an
# already batched dataset count batches, not samples, so using the sample
# count `split` there hands every batch to training and leaves validation
# empty. Splitting a dataset that is reshuffled on every epoch would also let
# samples move between the two sets.
split = int(0.9 * len(paired_noisy_paths))

# Fixed-seed permutation: the split is random but reproducible, and each
# noisy path stays aligned with its clean path.
split_order = np.random.default_rng(42).permutation(len(paired_noisy_paths))
shuffled_noisy_paths = [paired_noisy_paths[i] for i in split_order]
shuffled_clean_paths = [paired_clean_paths[i] for i in split_order]


def _make_dataset(noisy_paths, clean_paths, shuffle):
    """Return a batched, prefetched dataset of (noisy, clean) spectrogram pairs."""
    # Explicit string dtype so an empty path list still yields string tensors.
    ds = tf.data.Dataset.from_tensor_slices((
        tf.constant(noisy_paths, dtype=tf.string),
        tf.constant(clean_paths, dtype=tf.string),
    ))
    if shuffle:
        # Shuffle the (cheap) path strings instead of the loaded spectrograms
        # so the buffer can cover the whole set.
        ds = ds.shuffle(buffer_size=max(len(noisy_paths), 1))
    ds = ds.map(tf_load_npy, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(64)
    return ds.prefetch(tf.data.AUTOTUNE)


train_dataset = _make_dataset(shuffled_noisy_paths[:split], shuffled_clean_paths[:split], shuffle=True)
val_dataset = _make_dataset(shuffled_noisy_paths[split:], shuffled_clean_paths[split:], shuffle=False)

# Kept so that code expecting the combined `dataset` name still works.
dataset = train_dataset.concatenate(val_dataset)

print(f"train num: {split}")
print(f"valid num: {len(paired_noisy_paths) - split}")


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, BatchNormalization, Dropout, MaxPooling2D,
    Conv2DTranspose, ZeroPadding2D, Cropping2D, concatenate
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, TensorBoard

# ==========================
# Helper functions
# ==========================
def adjust_shape(tensor_a, tensor_b):
    """
    Adjusts tensor_b to match the spatial shape of tensor_a by cropping or padding.

    Only axes 1 (height) and 2 (width) are compared. Padding and cropping are
    both applied at the end of the axis (bottom / right).

    Args:
        tensor_a: Reference tensor whose height and width are the target.
        tensor_b: Tensor to be zero-padded or cropped.

    Returns:
        tensor_b unchanged when the sizes already match, otherwise the output
        of the ZeroPadding2D / Cropping2D layers applied to it.

    Raises:
        ValueError: If the height or width of either tensor is undefined (None).
    """
    shape_a = tensor_a.shape
    shape_b = tensor_b.shape

    # Validate before doing any arithmetic: subtracting with a None dimension
    # raises TypeError, which would hide the explanatory error below.
    if None in [shape_a[1], shape_a[2], shape_b[1], shape_b[2]]:
        raise ValueError("One of the tensor dimensions is undefined. Please provide input_shape with fixed dimensions.")

    height_diff = shape_a[1] - shape_b[1]
    width_diff = shape_a[2] - shape_b[2]

    # Positive difference: tensor_b is too small, so pad. Negative: crop.
    if height_diff > 0:
        tensor_b = ZeroPadding2D(padding=((0, height_diff), (0, 0)))(tensor_b)
    elif height_diff < 0:
        tensor_b = Cropping2D(cropping=((0, -height_diff), (0, 0)))(tensor_b)

    if width_diff > 0:
        tensor_b = ZeroPadding2D(padding=((0, 0), (0, width_diff)))(tensor_b)
    elif width_diff < 0:
        tensor_b = Cropping2D(cropping=((0, 0), (0, -width_diff)))(tensor_b)

    return tensor_b

def unet(input_shape=(513, 247, 1)):
    """
    Build a U-Net that maps a single-channel spectrogram to a same-sized output.

    Layout: four encoder stages (64/128/256/512 filters), each followed by
    2x2 max pooling, a 1024-filter bottleneck, and four decoder stages that
    upsample with a stride-2 transposed convolution and concatenate the
    matching encoder features. The final layer is a 1x1 convolution with a
    sigmoid activation.

    Args:
        input_shape: (height, width, channels) of the input.

    Returns:
        An uncompiled tf.keras Model.
    """
    def double_conv(x, filters, drop_rate):
        # Conv -> BN -> Dropout -> Conv -> BN, shared by every stage.
        x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = Dropout(drop_rate)(x)
        x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        return BatchNormalization()(x)

    inputs = Input(shape=input_shape)

    # Encoder: keep each stage's features for the skip connections.
    skip_features = []
    x = inputs
    for filters, drop_rate in ((64, 0.1), (128, 0.1), (256, 0.2), (512, 0.2)):
        stage = double_conv(x, filters, drop_rate)
        skip_features.append(stage)
        x = MaxPooling2D(pool_size=(2, 2), padding='same')(stage)

    # Bottleneck
    x = double_conv(x, 1024, 0.3)

    # Decoder: deepest skip connection first.
    for filters, drop_rate in ((512, 0.2), (256, 0.2), (128, 0.1), (64, 0.1)):
        skip = skip_features.pop()
        upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        # 'same' pooling rounds odd sizes up, so the upsampled map can be
        # larger than the skip features; bring it back to the skip's size.
        upsampled = adjust_shape(skip, upsampled)
        merged = concatenate([skip, upsampled], axis=-1)
        x = double_conv(merged, filters, drop_rate)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    return Model(inputs=inputs, outputs=outputs)

def combined_loss(y_true, y_pred):
    """Return MSE + 0.5 * MAE between y_true and y_pred, averaged over all elements."""
    error = y_true - y_pred
    squared_term = tf.reduce_mean(tf.square(error))
    absolute_term = tf.reduce_mean(tf.abs(error))
    return squared_term + 0.5 * absolute_term


# Peek at one batch from each pipeline to confirm the tensor shapes.
for noisy_batch, clean_batch in train_dataset.take(1):
    print(f"Noisy batch shape: {noisy_batch.shape}")
    print(f"Clean batch shape: {clean_batch.shape}")

for noisy_val_batch, clean_val_batch in val_dataset.take(1):
    print(f"Noisy validation batch shape: {noisy_val_batch.shape}")
    print(f"Clean validation batch shape: {clean_val_batch.shape}")


# Derive the model input shape from the preprocessing configuration instead of
# repeating the numbers, so the model always matches what the input pipeline
# produces.
input_shape = (fixed_height, fixed_width, 1)
model = unet(input_shape=input_shape)

optimizer = Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss=combined_loss, metrics=['mae'])


# Every callback below monitors val_loss, so it needs a non-empty val_dataset.
log_dir = "/content/drive/MyDrive/Unet_practice/logs"
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6, verbose=1)
# Keeps only the best model (lowest val_loss) in the .keras format.
checkpoint_callback = ModelCheckpoint(filepath='/content/drive/MyDrive/Unet_practice/20epochtest_model.keras',
                                      save_weights_only=False, save_best_only=True,
                                      monitor='val_loss', mode='min', verbose=1)
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)


# NOTE: with epochs=1 the early-stopping and LR-reduction patience values can
# never be reached; raise `epochs` for a real training run.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=1,
    callbacks=[early_stopping, reduce_lr, checkpoint_callback, tensorboard_callback]
)

print("complete!!!!!!!!!!yeahhhhhhhhh!")


In [None]:
# Persist the trained model (HDF5 file) to Google Drive.
final_model_path = '/content/drive/MyDrive/Unet_practice/final_model.h5'
model.save(final_model_path)
print(f"already save the model to  {final_model_path}")
