In [None]:
'''
requirements.txt

tensorflow==2.14
tensorflow-addons
tf2onnx
onnxruntime
scikit-image
'''

# Deep Voice Classifier using Vision Transformer

## Flow
1. Load wav Data
2. Audio Splitter (keep vocals only)
3. Convert to Mel Spectrogram
4. Set Hyperparameters
5. Define Vision Transformer
6. Model Training & Evaluation
7. Export Model to ONNX

## Dataset
- Train
    - REAL
    - FAKE
- Valid
    - REAL
    - FAKE
- Test
    - REAL
    - FAKE

## Import Libraries

In [None]:
## Common Libraries
import os
# os.chdir('') # Current Directory
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
import soundfile as sf
import warnings
warnings.filterwarnings('ignore')

from tqdm.notebook import tqdm
from scipy.signal import butter, lfilter
from skimage.transform import resize
from sklearn.metrics import *

## Torch Audio Libraries
import torch
import torchaudio
import torchaudio.transforms as transforms
from torchaudio.transforms import MelSpectrogram

## Audio Splitter
from lib import dataset, nets, spec_utils, utils

## Vision Transformer Libraries
import tensorflow as tf
import tensorflow_addons as tfa
import keras

## Model Export
import tf2onnx

## Set seeds for reproducibility
def set_seed(seed=530):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy,
            TensorFlow and PyTorch.
    """
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    # torch is imported and used by the audio-splitter stage, so seed it as
    # well; otherwise torch-side randomness stays unseeded.
    torch.manual_seed(seed)

# Apply the default seed (530) as soon as the helper is defined.
set_seed()

## Set Data Path

In [None]:
# Root folders of the three dataset splits; each one is passed to
# data_preprocessor(), which iterates over the class sub-folders inside it.
train_path = './data_wav/Train'
valid_path = './data_wav/Valid'
test_path = './data_wav/Test'

## Data Preprocessing

In [None]:
# Audio Splitter Class
class Separator(object):
    """Apply a mask-predicting separation model to a spectrogram, crop by crop.

    The spectrogram is padded along its last axis, cut into fixed-width crops,
    pushed through ``model.predict_mask`` in batches, and the predicted masks
    are stitched back together and applied to the input magnitude.

    Args:
        model: Network exposing ``offset``, ``eval()`` and
            ``predict_mask(tensor) -> (mask_y, mask_v)``.
        device: Torch device the crops are moved to before inference.
        batchsize: Number of crops sent to the model per forward pass.
        cropsize: Width (along the last axis) of each crop.
    """

    def __init__(self, model, device=None, batchsize=1, cropsize=256):
        self.model = model
        self.offset = model.offset
        self.device = device
        self.batchsize = batchsize
        self.cropsize = cropsize

    def _postprocess(self, X_spec, mask_y, mask_v):
        """Apply both masks to the magnitude of ``X_spec``, keeping its phase.

        Returns:
            Tuple ``(y_spec, v_spec)`` of complex spectrograms.
        """
        magnitude = np.abs(X_spec)
        phase_factor = np.exp(1.j * np.angle(X_spec))
        y_spec = magnitude * mask_y * phase_factor
        v_spec = magnitude * mask_v * phase_factor
        return y_spec, v_spec

    def _separate(self, X_spec_pad, roi_size):
        """Predict the two masks for an already padded spectrogram.

        Args:
            X_spec_pad: Padded spectrogram; crops are taken along axis 2.
            roi_size: Stride between consecutive crops.

        Returns:
            Tuple ``(mask_y, mask_v)`` concatenated along axis 2.
        """
        patches = (X_spec_pad.shape[2] - 2 * self.offset) // roi_size
        # Consecutive crops start roi_size apart but are cropsize wide.
        X_dataset = np.asarray([
            X_spec_pad[:, :, i * roi_size:i * roi_size + self.cropsize]
            for i in range(patches)
        ])

        self.model.eval()
        with torch.no_grad():
            mask_y_list = []
            mask_v_list = []
            for i in range(0, patches, self.batchsize):
                X_batch = X_dataset[i: i + self.batchsize]
                X_batch = torch.from_numpy(X_batch).to(self.device)

                mask_y, mask_v = self.model.predict_mask(torch.abs(X_batch))

                # Concatenating over the batch axis lays the crops of this
                # batch side by side along the frame axis.
                mask_y = mask_y.detach().cpu().numpy()
                mask_y_list.append(np.concatenate(mask_y, axis=2))

                mask_v = mask_v.detach().cpu().numpy()
                mask_v_list.append(np.concatenate(mask_v, axis=2))

            mask_y = np.concatenate(mask_y_list, axis=2)
            mask_v = np.concatenate(mask_v_list, axis=2)

        return mask_y, mask_v

    def separate(self, X_spec):
        """Split ``X_spec`` into the two masked spectrograms.

        Args:
            X_spec: Complex spectrogram whose axis 2 is the frame axis.

        Returns:
            Tuple ``(y_spec, v_spec)`` with the same frame count as ``X_spec``.
        """
        n_frame = X_spec.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(n_frame, self.cropsize, self.offset)
        X_spec_pad = np.pad(X_spec, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')

        # Normalise by the peak magnitude. A completely silent input has a
        # peak of 0; dividing would turn every value into NaN, so skip it.
        peak = np.abs(X_spec).max()
        if peak > 0:
            X_spec_pad /= peak

        mask_y, mask_v = self._separate(X_spec_pad, roi_size)
        # Drop the frames that only exist because of the right padding.
        mask_y = mask_y[:, :, :n_frame]
        mask_v = mask_v[:, :, :n_frame]

        y_spec, v_spec = self._postprocess(X_spec, mask_y, mask_v)

        return y_spec, v_spec

# Audio Splitter pre-trained model path
SPLITTER_MODEL_PATH = './models/baseline.pth'

# Choose the inference device: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    _device_spec = 'cuda:0'
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    _device_spec = 'mps'
else:
    _device_spec = 'cpu'
device = torch.device(_device_spec)

# audio split from raw wave file

# Loaded splitter networks, keyed by (weights path, device, n_fft, hop_length),
# so the checkpoint is read from disk once instead of once per audio file.
_SPLITTER_MODEL_CACHE = {}


def _get_splitter_model(n_fft, hop_length):
    """Return the splitter network, building and loading it on first use."""
    key = (SPLITTER_MODEL_PATH, str(device), n_fft, hop_length)
    model = _SPLITTER_MODEL_CACHE.get(key)
    if model is None:
        model = nets.CascadedNet(n_fft, hop_length, 32, 128)
        model.load_state_dict(torch.load(SPLITTER_MODEL_PATH, map_location='cpu'))
        model.to(device)
        _SPLITTER_MODEL_CACHE[key] = model
    return model


def audio_splitter(waveform, sr):
    """Run the pre-trained splitter and return the second separated signal.

    Args:
        waveform: Torch tensor, either 1-D or 2-D with one row per channel.
            Single-channel input is duplicated to two channels first.
        sr: Sample rate of ``waveform``. Accepted for interface symmetry; the
            separation itself does not use it.

    Returns:
        2-D torch tensor built from the ``v_spec`` output of
        ``Separator.separate`` (the callers treat this as the vocal track).
    """
    n_fft = 2048
    hop_length = 1024
    batchsize = 4
    cropsize = 256

    model = _get_splitter_model(n_fft, hop_length)

    if waveform.ndim == 1:
        # mono to stereo
        waveform = waveform.unsqueeze(0).repeat(2, 1)
    elif waveform.ndim == 2 and waveform.size(0) == 1:
        # mono to stereo
        waveform = waveform.repeat(2, 1)

    X_spec = spec_utils.wave_to_spectrogram(waveform.numpy(), hop_length, n_fft)
    sp = Separator(
        model=model,
        device=device,
        batchsize=batchsize,
        cropsize=cropsize
    )

    # Only the second output of the separator is kept.
    _, v_spec = sp.separate(X_spec)
    wave = spec_utils.spectrogram_to_wave(v_spec, hop_length=hop_length)

    # Ensure wave is 2D
    if wave.ndim == 1:
        wave = np.expand_dims(wave, axis=0)

    return torch.tensor(wave)

# Convert to Mel-spectrogram & resize data
def waveform_to_mel_spectrogram(waveform, sr):
    """Turn a waveform into a 128x128x1 Mel-spectrogram array.

    Args:
        waveform: Torch tensor with channels on axis 0; multi-channel input
            is averaged down to one channel.
        sr: Sample rate handed to the Mel transform.

    Returns:
        NumPy array of shape (128, 128, 1).
    """
    # Convert to mono
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    to_mel = MelSpectrogram(sample_rate=sr, n_fft=2048, hop_length=512, n_mels=128)
    mel = to_mel(waveform).squeeze().numpy()

    # Reduce to a 2-D array if a leading axis survived the squeeze
    if mel.ndim == 3:
        mel = mel[0]

    # Resize to 128x128
    mel = resize(mel, (128, 128))

    # # Normalisation (currently disabled)
    # if mel.max() != mel.min():  # only normalise when max and min differ
    #     mel = (mel - mel.min()) / (mel.max() - mel.min())
    #     mel = mel * 255  # scale into the 0..255 range

    # Append a channel axis -> (128, 128, 1)
    return mel[..., np.newaxis]

# Audio Data preprocessor
def _one_second_chunks(waveform, sr):
    """Cut a (channels, samples) tensor into windows of exactly ``sr`` samples.

    Audio shorter than one second is zero-padded on the right. When the final
    window would be short, it is replaced by the last ``sr`` samples of the
    signal, so it overlaps the previous window instead of containing padding.
    """
    total = waveform.size(1)
    if total < sr:
        waveform = torch.nn.functional.pad(waveform, (0, sr - total), 'constant')
        total = sr

    chunks = []
    for start in range(0, total, sr):
        chunk = waveform[:, start:start + sr]
        if chunk.size(1) < sr:
            chunk = waveform[:, -sr:]
        chunks.append(chunk)
    return chunks


def data_preprocessor(dataset_path):
    """Build Mel-spectrogram features and binary labels for one dataset split.

    Every class sub-folder of ``dataset_path`` is scanned. Each audio file is
    loaded, passed through ``audio_splitter``, cut into one-second chunks and
    each chunk is converted with ``waveform_to_mel_spectrogram``.

    Args:
        dataset_path: Folder that contains one sub-folder per class.

    Returns:
        Tuple ``(features, labels)`` of equal-length lists. The label is 1 for
        the folder named ``FAKE`` and 0 for any other folder.
    """
    features = []
    labels = []
    # Sorted listings keep the sample order stable across runs and machines.
    for cls in sorted(os.listdir(dataset_path)):
        cls_path = os.path.join(dataset_path, cls)
        # Skip hidden entries (e.g. .DS_Store) and anything that is not a folder.
        if cls.startswith(".") or not os.path.isdir(cls_path):
            continue

        # Assign label
        label = 1 if cls == 'FAKE' else 0

        file_names = [
            name for name in sorted(os.listdir(cls_path))
            if not name.startswith(".") and os.path.isfile(os.path.join(cls_path, name))
        ]
        for file in tqdm(file_names):
            file_path = os.path.join(cls_path, file)
            waveform, sr = torchaudio.load(file_path)

            # Keep only the vocal part using the audio splitter
            vocal_waveform = audio_splitter(waveform, sr)
            if vocal_waveform.dim() == 1:
                vocal_waveform = vocal_waveform.unsqueeze(0)

            for chunk in _one_second_chunks(vocal_waveform, sr):
                features.append(waveform_to_mel_spectrogram(chunk, sr))
                labels.append(label)

    return features, labels

# Preprocess each split and turn the returned lists into NumPy arrays.
X_train, y_train = map(np.array, data_preprocessor(train_path))
X_valid, y_valid = map(np.array, data_preprocessor(valid_path))
X_test, y_test = map(np.array, data_preprocessor(test_path))

print(X_train.shape)
print(X_valid.shape)
print(X_test.shape)

# Count how many samples carry label 0 and label 1 in every split
train_0_count, train_1_count = (np.sum(y_train == class_id) for class_id in (0, 1))
valid_0_count, valid_1_count = (np.sum(y_valid == class_id) for class_id in (0, 1))
test_0_count, test_1_count = (np.sum(y_test == class_id) for class_id in (0, 1))

print(f"Training set: 0s = {train_0_count}, 1s = {train_1_count}")
print(f"Validation set: 0s = {valid_0_count}, 1s = {valid_1_count}")
print(f"Test set: 0s = {test_0_count}, 1s = {test_1_count}")

In [None]:
# Random sample selection and visualisation helper
def plot_samples(X, y, label, num_samples=5):
    """Show randomly chosen samples of one label side by side.

    Args:
        X: Indexable collection of images; each item is squeezed before display.
        y: Labels aligned with ``X``.
        label: Label value to draw samples from.
        num_samples: Maximum number of samples to show. If fewer samples with
            this label exist, all of them are shown.
    """
    # Indices of every sample that carries the requested label
    label_indices = [i for i, lbl in enumerate(y) if lbl == label]

    # random.sample raises when asked for more items than exist, so clamp.
    n_show = min(num_samples, len(label_indices))
    if n_show <= 0:
        print(f'No samples with label {label} to plot.')
        return
    sample_indices = random.sample(label_indices, n_show)

    # squeeze=False keeps `axes` a 2-D array even for a single column;
    # otherwise a lone Axes object is returned and cannot be indexed.
    fig, axes = plt.subplots(1, n_show, figsize=(15, 3), squeeze=False)
    fig.suptitle(f'Samples of Label {label}', fontsize=16)
    for ax, idx in zip(axes[0], sample_indices):
        ax.imshow(X[idx].squeeze(), aspect='auto', origin='lower', cmap='viridis')
        ax.axis('off')
    plt.show()

# Visualise five random training samples for label 0 and for label 1
plot_samples(X_train, y_train, label=0, num_samples=5)
plot_samples(X_train, y_train, label=1, num_samples=5)

## Set Hyperparameter

In [None]:
# --- Data description ---
num_classes = 2
# Per-sample shape taken from the training array (axes 1..3).
input_shape = tuple(X_train.shape[axis] for axis in (1, 2, 3))
image_size = 128  # We'll resize input images to this size

# --- Optimisation ---
learning_rate = 0.0001
weight_decay = 0.001
batch_size = 32
num_epochs = 200

# --- Patch embedding ---
patch_size = 16  # Size of the patches to be extracted from the input images
num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
projection_dim = 128

# --- Transformer encoder ---
num_heads = 32
transformer_layers = 8
# Size of the transformer layers: widen to 2x, then project back
transformer_units = [projection_dim * factor for factor in (2, 1)]

# --- Classification head ---
mlp_head_units = [64, 64]  # Size of the dense layers of the final classifier

## Define Vision Transformer

In [None]:
def mlp(x, hidden_units, dropout_rate):
    """Stack Dense(GELU, L2=0.01) + Dropout pairs on top of ``x``.

    Args:
        x: Input tensor.
        hidden_units: Iterable with the width of each Dense layer, in order.
        dropout_rate: Dropout rate applied after every Dense layer.

    Returns:
        Output tensor of the last Dropout layer (``x`` itself when
        ``hidden_units`` is empty).
    """
    out = x
    for width in hidden_units:
        out = keras.layers.Dense(
            width,
            activation=tf.nn.gelu,
            kernel_regularizer=keras.regularizers.l2(0.01),
        )(out)
        out = keras.layers.Dropout(dropout_rate)(out)
    return out

In [None]:
class Patches(keras.layers.Layer):
    """Turn an image batch into a sequence of patch vectors with a strided conv.

    A Conv2D whose kernel size and stride both equal ``patch_size`` maps every
    non-overlapping patch to a vector of ``patch_size * patch_size`` values;
    the spatial grid is then flattened into one sequence axis.

    Args:
        patch_size: Side length of a square patch.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to the base class so the layer can be rebuilt from its
            config.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.conv = keras.layers.Conv2D(filters=patch_size * patch_size,
                                        kernel_size=patch_size,
                                        strides=patch_size,
                                        padding='valid')

    def call(self, images):
        patches = self.conv(images)
        patch_dims = patches.shape[-1]
        # (batch, rows, cols, dims) -> (batch, rows * cols, dims)
        patches = tf.reshape(patches, [tf.shape(images)[0], -1, patch_dims])
        return patches

    def get_config(self):
        """Return the constructor arguments so a saved model can be reloaded."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

In [None]:
class PatchEncoder(keras.layers.Layer):
    """Project patch vectors and add a learned position embedding.

    Args:
        num_patches: Number of patches per sample (size of the position table).
        projection_dim: Output width of both the projection and the embedding.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to the base class so the layer can be rebuilt from its
            config.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = keras.layers.Dense(units=projection_dim)
        self.position_embedding = keras.layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        # One position index per patch: 0 .. num_patches - 1
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the constructor arguments so a saved model can be reloaded."""
        config = super().get_config()
        config.update({
            "num_patches": self.num_patches,
            "projection_dim": self.projection_dim,
        })
        return config

In [None]:
# Build ViT Model
def create_vit_classifier():
    """Assemble the Vision Transformer binary classifier.

    Reads the module-level hyperparameters (``input_shape``, ``patch_size``,
    ``num_patches``, ``projection_dim``, ``num_heads``, ``transformer_units``,
    ``transformer_layers``, ``mlp_head_units``).

    Returns:
        Uncompiled ``keras.Model`` whose single sigmoid output is the predicted
        probability of the positive class.
    """
    inputs = keras.layers.Input(shape=input_shape)
    # Patch extraction followed by projection + position embedding.
    patch_sequence = Patches(patch_size)(inputs)
    tokens = PatchEncoder(num_patches, projection_dim)(patch_sequence)

    # Pre-norm Transformer encoder blocks.
    for _ in range(transformer_layers):
        # Self-attention sub-block with residual connection.
        normed_tokens = keras.layers.LayerNormalization(epsilon=1e-6)(tokens)
        attended = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.3
        )(normed_tokens, normed_tokens)
        after_attention = keras.layers.Add()([attended, tokens])

        # Feed-forward sub-block with residual connection.
        normed_attention = keras.layers.LayerNormalization(epsilon=1e-6)(after_attention)
        feed_forward = mlp(normed_attention, hidden_units=transformer_units, dropout_rate=0.3)
        tokens = keras.layers.Add()([feed_forward, after_attention])

    # Flatten all patch tokens into one vector per sample:
    # [batch_size, num_patches * projection_dim].
    representation = keras.layers.LayerNormalization(epsilon=1e-6)(tokens)
    representation = keras.layers.Flatten()(representation)
    representation = keras.layers.Dropout(0.2)(representation)

    # Classification head.
    head_features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.2)
    # Sigmoid output: a probability, not a raw logit.
    probability = keras.layers.Dense(1, activation='sigmoid')(head_features)

    return keras.Model(inputs=inputs, outputs=probability)

## Model Training

In [None]:
# Compile, Train, Evaluate the Model
def run_experiment(model):
    """Compile ``model``, train it, and report metrics on the test split.

    Training uses the module-level ``X_train``/``y_train`` and validates on
    ``X_valid``/``y_valid``. The weights with the lowest validation loss are
    checkpointed during training and reloaded into ``model`` before the test
    evaluation, so every reported number comes from the same weights.

    Args:
        model: Uncompiled Keras model with a single sigmoid output.

    Returns:
        Tuple ``(history, best_model)``: the ``fit`` history and the model
        carrying the best-validation-loss weights.
    """
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=learning_rate,
        decay_steps=10000,
        decay_rate=0.96,
        staircase=True
    )

    optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=[
            'accuracy'
        ],
    )

    # Weights-only checkpoint: reloading it needs no deserialisation of the
    # custom Patches / PatchEncoder layers.
    checkpoint_filepath = "./tmp/best_model.weights.h5"
    os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True)
    # ModelCheckpoint has no `patience` option; that belongs to EarlyStopping.
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True,
    )

    early_stopping_callback = keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=30,
        restore_best_weights=True,
    )

    history = model.fit(
        x=X_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_data=(X_valid, y_valid),
        # The checkpoint callback must be registered, otherwise nothing is
        # ever written to checkpoint_filepath.
        callbacks=[checkpoint_callback, early_stopping_callback],
    )

    # Reload the best weights; skipped if no checkpoint was written
    # (for example when val_loss never improved because it was NaN).
    if os.path.exists(checkpoint_filepath):
        model.load_weights(checkpoint_filepath)
    best_model = model

    _, accuracy = best_model.evaluate(X_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")

    # Predict on the test set. Note the 0.8 decision threshold here, while the
    # 'accuracy' metric above uses Keras' default of 0.5.
    y_pred = (best_model.predict(X_test) > 0.8).astype("int32")

    # Generate classification report
    report = classification_report(y_test, y_pred, target_names=["REAL", "FAKE"])
    print("Classification Report:\n", report)

    return history, best_model

In [None]:
# Build the ViT and train it; `model` is the second value returned by
# run_experiment and is what the following cells evaluate and export.
vit_classifier = create_vit_classifier()
history, model = run_experiment(vit_classifier)

### Evaluate Model - test dataset

In [None]:
# Re-evaluate the model returned by run_experiment on the test split.
# (Alternative: reload a saved model from disk instead.)
# best_model = keras.models.load_model("./tmp/best_model.h5")
_, accuracy = model.evaluate(X_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")

# Predict on the test set; a sample is labelled FAKE (1) only when the
# predicted probability exceeds 0.8.
y_pred = (model.predict(X_test) > 0.8).astype("int32")

# Generate classification report
report = classification_report(y_test, y_pred, target_names=["REAL", "FAKE"])
print("Classification Report:\n", report)

## Export Model to ONNX

In [None]:
# Function to save the model to ONNX
def save_model_to_onnx(model, input_shape, output_path, opset_version=13):
    """Convert a Keras model to ONNX and write it to ``<output_path>.onnx``.

    Args:
        model: Keras model to convert.
        input_shape: Shape of one input sample (without the batch axis).
        output_path: Destination path without the ``.onnx`` extension.
        opset_version: ONNX opset passed to the converter.
    """
    # Leading None leaves the batch dimension dynamic.
    batched_shape = (None,) + tuple(input_shape)
    input_signature = (tf.TensorSpec(batched_shape, tf.float32, name="input"),)
    output_model_path = "{}.onnx".format(output_path)

    model_proto, _unused = tf2onnx.convert.from_keras(
        model, input_signature=input_signature, opset=opset_version
    )
    serialized = model_proto.SerializeToString()
    with open(output_model_path, "wb") as onnx_file:
        onnx_file.write(serialized)
    print(f"Model saved to {output_model_path}")

# Save the trained model to ONNX; the helper appends ".onnx", so the file
# written is "Mel_1s_vit.onnx" in the current working directory.
save_model_to_onnx(model, input_shape, "Mel_1s_vit", opset_version=13)