In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Concatenate, UpSampling2D, BatchNormalization, Activation
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, Callback
from tensorflow.keras.utils import Sequence
from tensorflow.keras.optimizers import Adam
import h5py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import librosa
import librosa.display
import os
import soundfile as sf

# Check GPU availability: report how many GPU devices TensorFlow can see.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

In [None]:
# --- CONFIGURATION ---
# Path to the training data file (produced by the preprocessing step)
DATA_PATH = '../media_files/preprocessed_audio/data_2d.h5'

# Path to the audio file used for evaluation (visualization) after each epoch
# YOU MUST CHANGE THIS PATH TO POINT TO A REAL AUDIO FILE
EVALUATION_PATH = '../media_files/test_audio/test_song.wav' 

# Directory where the visualization images are saved (created if missing)
OUTPUT_VISUAL_DIR = 'training_visualizations'
os.makedirs(OUTPUT_VISUAL_DIR, exist_ok=True)

# Audio parameters (must match the preprocessing file)
BATCH_SIZE = 16      # batch size handed to the data generators
SR = 44100           # sample rate passed to librosa.load
WINDOW_SIZE = 128    # number of STFT frames per chunk fed to the model
N_FFT = 2048         # FFT size passed to librosa.stft
HOP_LENGTH = 512     # hop length passed to librosa.stft

In [None]:
class AudioH5Generator(Sequence):
    """Keras ``Sequence`` that serves (X, y) batches read from an HDF5 file.

    Samples are read from the ``X_train`` and ``Y_train`` datasets of
    ``h5_path``. The file is opened and closed again for every batch, so no
    file handle is kept on the instance between batches.

    Args:
        h5_path: Path to the HDF5 file containing ``X_train`` and ``Y_train``.
        list_IDs: Sequence of integer row indices this generator may serve.
        batch_size: Number of samples per batch; must be a positive integer.
        shuffle: If True, the sample order is reshuffled at construction and
            at the end of every epoch.
        **kwargs: Forwarded to the Keras base class constructor (for example
            ``workers`` / ``use_multiprocessing`` on Keras 3).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, h5_path, list_IDs, batch_size=16, shuffle=True, **kwargs):
        # Keras 3 expects the base-class constructor to be called; on older
        # versions this is a harmless no-op when no kwargs are given.
        super().__init__(**kwargs)
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.h5_path = h5_path
        self.list_IDs = list_IDs
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Build (and optionally shuffle) the initial index order.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as a tuple ``(X, y)`` of numpy arrays."""
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        return self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):
        """Reset the sample order, reshuffling it when ``shuffle`` is enabled."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Read the samples with the given row IDs from the HDF5 file."""
        with h5py.File(self.h5_path, 'r') as f:
            X_dset = f['X_train']
            Y_dset = f['Y_train']
            # Read row by row: this works for IDs in any order, which a
            # shuffled batch requires.
            X_batch = [X_dset[int(ID)] for ID in list_IDs_temp]
            y_batch = [Y_dset[int(ID)] for ID in list_IDs_temp]
        return np.array(X_batch), np.array(y_batch)

In [None]:
# 1. Inspect the dataset: total sample count and per-sample shape
with h5py.File(DATA_PATH, 'r') as f:
    dataset_shape = f['X_train'].shape

total_samples = dataset_shape[0]
input_shape = dataset_shape[1:]  # original note gives (1024, 128, 1)
print(f"Tổng số mẫu dữ liệu: {total_samples}")
print(f"Kích thước đầu vào: {input_shape}")

# 2. Train/validation split over row indices (80% train, 20% validation)
all_indices = np.arange(total_samples)
train_ids, val_ids = train_test_split(
    all_indices,
    test_size=0.2,
    random_state=42,
)

print(f"Số mẫu Train: {len(train_ids)}")
print(f"Số mẫu Val: {len(val_ids)}")

# 3. Create the batch generators (validation order is kept fixed)
training_generator = AudioH5Generator(
    DATA_PATH,
    train_ids,
    batch_size=BATCH_SIZE,
)
validation_generator = AudioH5Generator(
    DATA_PATH,
    val_ids,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
def _double_conv(x, filters):
    """Apply Conv(3x3, relu) -> BatchNorm -> Conv(3x3, relu) with `filters` channels."""
    x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    return x


def build_unet(input_shape):
    """Build a 4-level U-Net that outputs a single-channel sigmoid mask.

    The encoder has four stages (16, 32, 64, 128 filters), each followed by a
    2x2 max-pooling; the bottleneck uses 256 filters; the decoder mirrors the
    encoder with 2x2 up-sampling, a skip connection to the matching encoder
    stage, one 3x3 convolution and batch normalization per stage.

    Args:
        input_shape: Shape of one input sample, ``(height, width, channels)``.
            Known (non-``None``) spatial dimensions must be divisible by 16,
            because the input is halved four times and the up-sampled
            features must line up with the skip connections.

    Returns:
        An uncompiled ``Model`` mapping the input to a mask with values in
        [0, 1] and the same spatial size as the input.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 16.
    """
    encoder_filters = (16, 32, 64, 128)
    downsample_factor = 2 ** len(encoder_filters)

    # Fail early with a clear message instead of a shape mismatch raised
    # deep inside a Concatenate layer.
    for axis, dim in enumerate(tuple(input_shape)[:2]):
        if dim is not None and dim % downsample_factor != 0:
            raise ValueError(
                f"input_shape[{axis}]={dim} must be divisible by "
                f"{downsample_factor} for the skip connections to align"
            )

    inputs = Input(input_shape)

    # --- ENCODER ---
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = _double_conv(x, filters)
        skips.append(x)  # kept for the matching decoder stage
        x = MaxPooling2D((2, 2))(x)

    # --- BOTTLENECK ---
    x = _double_conv(x, 256)

    # --- DECODER ---
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = UpSampling2D((2, 2))(x)
        x = Concatenate()([x, skip])
        x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
        x = BatchNormalization()(x)

    # Output (sigmoid so the mask lies in 0-1)
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model

In [None]:
def preprocess_evaluation_audio(file_path):
    """Load an audio file and turn it into model-ready spectrogram chunks.

    Returns a tuple ``(chunks, phase, max_val)``: the stacked normalized
    magnitude chunks with a trailing channel axis, the STFT phase, and the
    value used for normalization. Returns ``(None, None, None)`` when the
    file is missing or cannot be read.
    """
    failure = (None, None, None)

    if not os.path.exists(file_path):
        print(f"Không tìm thấy file: {file_path}")
        return failure

    try:
        waveform, _ = librosa.load(file_path, sr=SR, mono=True)
    except Exception as e:
        print(f"Lỗi đọc file evaluation: {e}")
        return failure

    # STFT, split into magnitude and phase
    spectrum = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH)
    magnitude, phase = librosa.magphase(spectrum)

    # Square-root compression (original note: matches training), then keep
    # the first 1024 frequency bins
    compressed = np.sqrt(magnitude)[:1024, :]

    # Normalize by the peak value, falling back to 1 for an all-zero input
    peak = np.max(compressed)
    max_val = 1 if peak == 0 else peak
    normalized = compressed / max_val

    # Zero-pad the time axis up to a whole number of windows
    remainder = normalized.shape[1] % WINDOW_SIZE
    if remainder != 0:
        normalized = np.pad(
            normalized,
            ((0, 0), (0, WINDOW_SIZE - remainder)),
            mode='constant',
        )

    # Slice into consecutive windows and add a channel axis to each
    num_chunks = normalized.shape[1] // WINDOW_SIZE
    chunks = [
        normalized[:, start:start + WINDOW_SIZE][..., np.newaxis]
        for start in range(0, num_chunks * WINDOW_SIZE, WINDOW_SIZE)
    ]

    return np.array(chunks), phase, max_val

def reconstruct_from_mask(predicted_masks, phase, original_length, max_val):
    """Stitch per-chunk predicted masks back into one full-length mask.

    Args:
        predicted_masks: Sequence of per-chunk masks, each shaped
            ``(freq, time, 1)`` or ``(freq, time)``; typically an array of
            shape ``(num_chunks, freq, time, 1)``.
        phase: Unused; kept so existing callers keep working.
        original_length: Number of time frames to keep (padding beyond it
            is cut off).
        max_val: Unused; kept so existing callers keep working.

    Returns:
        2-D array of shape ``(freq + 1, min(total_time, original_length))``:
        the concatenated mask with one zero row appended on the frequency
        axis.
    """
    # 1. Join the chunks along the time axis -> (freq, total_time[, 1])
    vocal_mask = np.concatenate(predicted_masks, axis=1)

    # Drop only the trailing channel axis. A bare squeeze() would also
    # collapse a frequency or time axis of length 1 and break the 2-D
    # slicing below.
    if vocal_mask.ndim == 3:
        vocal_mask = np.squeeze(vocal_mask, axis=-1)

    # 2. Cut off the padded frames
    vocal_mask = vocal_mask[:, :original_length]

    # 3. Restore the dropped last frequency row (zero-pad at the end)
    vocal_mask = np.pad(vocal_mask, ((0, 1), (0, 0)), mode='constant')

    return vocal_mask

In [None]:
class EpochVisualizer(Callback):
    """Keras callback that saves a mask/spectrogram figure after every epoch.

    At the end of each epoch the evaluation audio file is preprocessed, the
    current model predicts a mask for it, and a two-panel figure (predicted
    mask, masked spectrogram in dB) is written to
    ``<output_dir>/epoch_NNN.png``.

    Args:
        audio_path: Path to the audio file used for the visualization.
        output_dir: Directory where the PNG files are written; it is created
            if it does not exist.
    """

    def __init__(self, audio_path, output_dir):
        super().__init__()
        self.audio_path = audio_path
        self.output_dir = output_dir
        # Create the target directory up front so a missing directory does
        # not make savefig fail after an epoch has already been trained.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Predict a mask for the evaluation audio and save the figure.

        Does nothing beyond printing a message when the evaluation audio
        cannot be loaded.
        """
        print(f"\nĐang visualize kết quả epoch {epoch+1}...")

        # 1. Preprocess
        chunks, phase, max_val = preprocess_evaluation_audio(self.audio_path)
        if chunks is None:
            return

        # 2. Predict
        predicted_masks = self.model.predict(chunks, batch_size=16, verbose=0)

        # 3. Reconstruct the full-length mask
        original_length = phase.shape[1]
        mask = reconstruct_from_mask(predicted_masks, phase, original_length, max_val)

        # 4. Recompute the original magnitude and apply the mask to it
        y, _ = librosa.load(self.audio_path, sr=SR)
        S_full, _ = librosa.magphase(librosa.stft(y, n_fft=N_FFT, hop_length=HOP_LENGTH))
        S_vocal = S_full * mask
        S_db = librosa.amplitude_to_db(S_vocal, ref=np.max)

        # 5. Plot with explicit figure/axes and always close the figure,
        #    even if plotting or saving fails, so figures do not pile up.
        fig, (ax_mask, ax_spec) = plt.subplots(1, 2, figsize=(12, 6))
        try:
            mask_img = librosa.display.specshow(
                mask, sr=SR, hop_length=HOP_LENGTH,
                x_axis='time', y_axis='hz', cmap='magma', ax=ax_mask,
            )
            ax_mask.set_title(f"Predicted Vocal Mask - Epoch {epoch+1}")
            # Mask values lie in 0-1 (sigmoid output), so show one decimal;
            # an integer format would label every tick as 0 or 1.
            fig.colorbar(mask_img, ax=ax_mask, format='%.1f')

            spec_img = librosa.display.specshow(
                S_db, sr=SR, hop_length=HOP_LENGTH,
                x_axis='time', y_axis='log', cmap='magma', ax=ax_spec,
            )
            ax_spec.set_title(f"Separated Vocal Spectrogram - Epoch {epoch+1}")
            fig.colorbar(spec_img, ax=ax_spec, format='%+2.0f dB')

            fig.tight_layout()

            # 6. Save the image
            save_path = os.path.join(self.output_dir, f"epoch_{epoch+1:03d}.png")
            fig.savefig(save_path)
        finally:
            plt.close(fig)
        print(f"Đã lưu ảnh visualize tại: {save_path}")