# Gradio

**Autores**:

José Antonio Ruiz Heredia (josrui05@ucm.es)

Néstor Marín

**Descripción**:

En este código generamos una interfaz de usuario (UI) con Gradio para probar el modelo con nuevas muestras de audio.

**Antes de ejecutar este archivo, asegurarse de haber instalado los prerrequisitos ejecutando el script "Prerrequisitos.py"**

Librerías

In [None]:
import torch
import torchaudio
import torchaudio.transforms as T
from torchaudio.transforms import Resample
import gradio as gr
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
import numpy as np
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch.nn as nn
import timm

Parámetros

In [None]:
# --- Audio / spectrogram settings ---
ESPEC = 'mel'  # presumably the spectrogram flavour used for training; TODO confirm
SAMPLE_RATE = 16000  # target sample rate in Hz
MAX_DURATION = 8  # multiplied by SAMPLE_RATE below to get a sample count
N_FFT = 1024
HOP_LENGTH = 512
N_MELS = 224
N_MFCC = N_MELS
N_LFCC = N_MELS
TIME_MASK_PARAM = 10
FREQ_MASK_PARAM = 5

num_samples = MAX_DURATION * SAMPLE_RATE
# Waveforms shorter than this many samples are zero-padded up to it.
# NOTE(review): differs from num_samples; presumably taken from the training set - confirm.
MAX_PADDING = 118989

# --- Image settings ---
RESIZE = False
IMG_SIZE = 224

# --- Model settings ---
PRETRAINED = True
N = 30
NUM_CLASSES = 3
HIDDEN_UNITS = [256, 128, 64]
DROPOUT_RATE = 0.5

Clasificación

In [None]:
# Preprocessing: audio file -> normalised 3-channel log-mel tensor
def preprocess_audio(audio):
    """Convert an audio file into an ImageNet-normalised log-mel tensor.

    Args:
        audio: A path string, or an object that exposes the file path as
            ``.name`` (e.g. an uploaded-file wrapper).

    Returns:
        Float tensor of shape ``(3, N_MELS, frames)``: the first audio
        channel's log-mel spectrogram, min-max scaled, replicated on three
        channels and normalised with the ImageNet mean/std.
    """
    audio_path = audio if isinstance(audio, str) else audio.name
    waveform, source_rate = torchaudio.load(audio_path)

    # Bring every input to the common SAMPLE_RATE.
    if source_rate != SAMPLE_RATE:
        waveform = Resample(orig_freq=source_rate, new_freq=SAMPLE_RATE)(waveform)

    # Zero-pad short clips up to MAX_PADDING samples; longer clips are left as is.
    missing = max(MAX_PADDING - waveform.shape[1], 0)
    waveform = torch.nn.functional.pad(waveform, (0, missing))

    # The mel filter bank must be built for the rate the waveform has *now*
    # (SAMPLE_RATE), not the rate the file was stored at.
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=N_FFT,
        win_length=None,
        hop_length=HOP_LENGTH,
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm="slaney",
        n_mels=N_MELS,
        mel_scale="htk",
    )
    melspec = mel_spectrogram(waveform)

    # Power spectrogram -> decibels.
    log_mel = torchaudio.transforms.AmplitudeToDB(top_db=80)(melspec)

    # Min-max scale to [0, 1]; a constant spectrogram is only shifted to 0
    # to avoid dividing by zero.
    lowest = log_mel.min()
    value_range = log_mel.max() - lowest
    if value_range != 0:
        scaled = (log_mel - lowest) / value_range
    else:
        scaled = log_mel - lowest

    # Keep the first audio channel and replicate it as three image channels.
    # The *255 ... /255 round trip is kept so values match the original
    # pipeline bit for bit (presumably mirrors training - TODO confirm).
    first_channel = scaled[0] * 255
    rgb = first_channel.repeat(3, 1, 1) / 255

    imagenet_normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return imagenet_normalize(rgb)






In [None]:


class MLPModel(torch.nn.Module):
    """Fully connected classification head.

    Each hidden size contributes Linear -> BatchNorm1d -> ReLU -> Dropout,
    followed by one final Linear that produces the class logits. Layers are
    stored flat in ``self.layers`` in that order, so state_dict keys are
    ``layers.<index>.*``.

    Args:
        input_size: Number of input features.
        hidden_units: Non-empty sequence with the width of each hidden layer.
        dropout_rate: Dropout probability applied after every hidden layer.
        num_classes: Number of output logits. When ``None`` (default) the
            module-level ``NUM_CLASSES`` is used.
    """

    def __init__(self, input_size, hidden_units, dropout_rate, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        self.hidden_units = hidden_units
        self.dropout_rate = dropout_rate
        self.layers = torch.nn.ModuleList()

        # Pair each layer's input width with its output width:
        # (input_size, h0), (h0, h1), ...
        widths = [input_size, *hidden_units]
        for in_features, out_features in zip(widths, widths[1:]):
            self.layers.extend([
                torch.nn.Linear(in_features, out_features),
                torch.nn.BatchNorm1d(out_features),
                torch.nn.ReLU(),
                torch.nn.Dropout(dropout_rate),
            ])

        # Output layer: last hidden width -> class logits.
        self.layers.append(torch.nn.Linear(hidden_units[-1], num_classes))

    def forward(self, x):
        """Apply every layer in order and return the logits."""
        for layer in self.layers:
            x = layer(x)
        return x

class CustomConvNeXt(nn.Module):
    """timm ConvNeXt-V2 femto feature extractor followed by an MLPModel head.

    Args:
        N: Maximum number of trailing backbone parameter tensors (walking
            from the end, skipping names that contain ``'bn'``) left
            trainable. With the default ``0`` the whole backbone stays frozen.
    """

    def __init__(self, N=0):
        super().__init__()
        self.n_layers = N

        # Backbone without its classifier: returns globally average-pooled features.
        self.pretrained_model = timm.create_model(
            'convnextv2_femto.fcmae_ft_in1k',
            pretrained=PRETRAINED,
            num_classes=0,
            global_pool='avg',
        )

        # Freeze the whole backbone first.
        for backbone_param in self.pretrained_model.parameters():
            backbone_param.requires_grad = False

        # Then walk the parameters from the end and re-enable gradients on
        # at most `n_layers` of them, skipping names that contain 'bn'.
        still_to_unfreeze = self.n_layers
        for param_name, backbone_param in reversed(list(self.pretrained_model.named_parameters())):
            if still_to_unfreeze <= 0:
                break
            if 'bn' in param_name:
                continue
            backbone_param.requires_grad = True
            still_to_unfreeze -= 1

        # Classification head fed with the backbone's feature vector.
        self.additional_layer = MLPModel(
            self.pretrained_model.num_features,
            hidden_units=HIDDEN_UNITS,
            dropout_rate=DROPOUT_RATE,
        )

    def forward(self, x):
        """Return the head's logits for the backbone features of ``x``."""
        features = self.pretrained_model(x)
        return self.additional_layer(features)


# Función para predecir la clase del audio

# Build the network and choose the inference device.
model = CustomConvNeXt()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the checkpoint onto the selected device so it also works on hosts
# without a GPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
state_dict = torch.load('/home/gass/audio-forgery-detection/forgery_models/test_mel_preprocess_data_3_chans_history_torch_baseline_convnext_femto_224_custom_exp_8_mel_test_sin_resize_mel_best_model.pth',
                        map_location=torch.device(device))

# Keep only the checkpoint entries whose key exists in the model; the key set
# is computed once instead of once per entry.
model_keys = set(model.state_dict().keys())
filtered_state_dict = {k: v for k, v in state_dict['model_state_dict'].items() if k in model_keys}

# Load the filtered weights. Keys the model expects but the checkpoint lacks
# still raise, since loading stays strict.
model.load_state_dict(filtered_state_dict)
model = model.to(device)
model.eval()

def classify_audio(audio_files):
    """Predict a class label for each uploaded audio file.

    Args:
        audio_files: Iterable of items accepted by ``preprocess_audio``.
            ``None`` or an empty collection (nothing uploaded) is allowed.

    Returns:
        List with one label per input file, in input order; each label is
        "original", "copy-move" or "splicing". Empty list when there is
        nothing to classify.
    """
    # The UI callback can fire before any file is selected.
    if not audio_files:
        return []

    # Label order must match the class indices the model outputs.
    classes = ("original", "copy-move", "splicing")
    results = []
    with torch.no_grad():
        for audio_file in audio_files:
            # Add the batch dimension the model expects and move to its device.
            inputs = torch.unsqueeze(preprocess_audio(audio_file), 0).to(device)
            output = model(inputs)
            probabilities = torch.softmax(output, dim=1).squeeze()
            class_idx = torch.argmax(probabilities).item()
            results.append(classes[class_idx])
    return results


In [None]:





# Gradio UI: upload one or more audio files and display the predicted classes.
with gr.Blocks(theme=gr.themes.Default(), title="Voice Cloning Demo") as demo:
    gr.Markdown("Audio Classifier")
    with gr.Tab("Inference"), gr.Column():
        audio_uploads = gr.File(
            label="Select here the audio files",
            file_count="multiple",
        )
        prediction_box = gr.Textbox(label="Predicted Class(es)")
        classify_button = gr.Button("Classify")

        # Clicking the button runs classify_audio on the uploads and writes
        # its result into the textbox.
        classify_button.click(
            fn=classify_audio,
            inputs=audio_uploads,
            outputs=prediction_box,
        )

demo.launch()

Running on local URL:  http://127.0.0.1:7862

To create a public link, set `share=True` in `launch()`.




torch.Size([1, 3, 224, 233])
torch.Size([1, 3, 224, 233])
torch.Size([1, 3, 224, 233])
torch.Size([1, 3, 224, 233])
torch.Size([1, 3, 224, 233])
torch.Size([1, 3, 224, 233])
torch.Size([1, 3, 224, 233])
torch.Size([1, 3, 224, 233])
