# Gradio

**Autores**: 

José Antonio Ruiz Heredia (josrui05@ucm.es) 

Néstor Marín 

**Descripción**: 

En este código generamos una interfaz de usuario (UI) con Gradio para probar el modelo con nuevas muestras de audio.

**Antes de ejecutar este archivo, asegúrese de haber instalado los prerrequisitos ejecutando el script "Prerrequisitos.py".**

Librerías

In [126]:
import functools

import torch
import torchaudio
import torchaudio.transforms as T
from torchaudio.transforms import Resample
import gradio as gr
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
import numpy as np
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

Parámetros

In [127]:
# Spectrogram type used to build the model input (preprocess_audio only implements 'mel').
ESPEC = 'mel'
# Sampling rate (Hz) that audio is resampled to before feature extraction.
SAMPLE_RATE = 16000
# Maximum clip duration in seconds (used to derive num_samples).
MAX_DURATION = 8
# FFT size of the STFT underlying the mel spectrogram.
N_FFT = 1024
# NOTE(review): PRETRAINED, N, TIME_MASK_PARAM and FREQ_MASK_PARAM are not referenced
# by the inference code in this notebook; presumably kept to mirror the training
# configuration — confirm before removing.
PRETRAINED = True
N = 30
TIME_MASK_PARAM = 10
FREQ_MASK_PARAM = 5
# Hop length (in samples) between consecutive STFT frames.
HOP_LENGTH = 512

# Number of mel bands, i.e. the height of the spectrogram "image".
N_MELS = 224
N_MFCC = N_MELS
N_LFCC = N_MELS
# NOTE(review): RESIZE and IMG_SIZE are not referenced by the inference code here.
RESIZE = False
IMG_SIZE = 224
# Number of output classes of the classifier head.
NUM_CLASSES = 3

# Maximum clip length in samples (MAX_DURATION seconds at SAMPLE_RATE).
num_samples = MAX_DURATION * SAMPLE_RATE
# Length (in samples) that shorter waveforms are zero-padded up to before the spectrogram.
# NOTE(review): differs from num_samples (128000); presumably matches the training
# pipeline — confirm.
MAX_PADDING = 118989

Clasificación

In [128]:
# Preprocessing: audio file -> model input tensor
def preprocess_audio(audio):
    """Turn an audio file into a 3-channel log-mel "image" tensor for the classifier.

    Pipeline: load the file, resample to ``SAMPLE_RATE`` if needed, zero-pad on
    the right up to ``MAX_PADDING`` samples (longer clips are left untouched),
    compute a mel spectrogram, convert it to decibels, min-max normalise it to
    [0, 1], and replicate the first audio channel three times.

    Args:
        audio: The uploaded file. Either an object exposing its path through a
            ``.name`` attribute (e.g. a temporary-file wrapper) or a path string.

    Returns:
        A float tensor of shape ``(3, N_MELS, frames)`` with values in [0, 1].
    """
    # Accept both file-like objects (``.name``) and plain path strings.
    path = getattr(audio, "name", audio)
    waveform, orig_sample_rate = torchaudio.load(path)

    # The file's native rate is kept in its own local so it does not shadow the
    # module-level SAMPLE_RATE; the mel filterbank below must be built for the
    # rate *after* resampling, otherwise non-16 kHz uploads get a wrong mel scale.
    if orig_sample_rate != SAMPLE_RATE:
        resampler = Resample(orig_freq=orig_sample_rate, new_freq=SAMPLE_RATE)
        waveform = resampler(waveform)

    # Right-pad short clips with zeros (only the last, time, dimension is padded).
    padding = max(MAX_PADDING - waveform.shape[1], 0)
    waveform = torch.nn.functional.pad(waveform, (0, padding))

    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=N_FFT,
        win_length=None,
        hop_length=HOP_LENGTH,
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm="slaney",
        n_mels=N_MELS,
        mel_scale="htk",
    )
    melspec = mel_spectrogram(waveform)

    # Power spectrogram -> decibels, clipped to an 80 dB dynamic range.
    log_mel = torchaudio.transforms.AmplitudeToDB(top_db=80)(melspec)

    # Min-max normalise to [0, 1]; a constant spectrogram becomes all zeros
    # instead of dividing by zero.
    low = log_mel.min()
    value_range = log_mel.max() - low
    log_mel_norm = log_mel - low
    if value_range != 0:
        log_mel_norm = log_mel_norm / value_range

    # Keep only the first audio channel and replicate it into 3 identical
    # channels for the image backbone.
    return log_mel_norm[0].repeat(3, 1, 1)
    



# Class names, indexed by the position of the model's output logit.
_CLASSES = ("original", "splicing", "copy-move")


@functools.lru_cache(maxsize=1)
def _get_processor():
    """Load the ConvNeXt V2 image processor once and reuse it across calls."""
    return AutoImageProcessor.from_pretrained('facebook/convnextv2-femto-1k-224')


# Predict the class of each uploaded audio file
def classify_audio(audio_files):
    """Classify each uploaded audio file as "original", "splicing" or "copy-move".

    Args:
        audio_files: Value of the Gradio ``File`` component
            (``file_count="multiple"``): a list of uploaded files, or ``None``
            when the button is pressed without uploading anything.

    Returns:
        A list with one predicted class name per input file, in input order;
        an empty list when no files were given.
    """
    # Gradio passes None when nothing has been uploaded; iterating it would raise.
    if not audio_files:
        return []

    processor = _get_processor()
    results = []
    for audio_file in audio_files:
        image = preprocess_audio(audio_file)
        # NOTE(review): `image` is already scaled to [0, 1]; the processor's default
        # rescaling may divide by 255 again. Kept unchanged so inference matches the
        # preprocessing the model was presumably trained with — confirm.
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            logits = model(inputs['pixel_values'])
        # Softmax is monotonic, so the argmax of the logits gives the same label.
        class_idx = logits.argmax(dim=1).item()
        results.append(_CLASSES[class_idx])
    return results


class AudioClassifier(torch.nn.Module):
    """Pretrained ConvNeXt V2 image classifier that returns only its logits.

    The Hugging Face checkpoint ``facebook/convnextv2-femto-1k-224`` is loaded
    with a classification head of ``NUM_CLASSES`` outputs;
    ``ignore_mismatched_sizes=True`` lets the head be re-initialised instead of
    failing on the shape mismatch with the original head.
    """

    def __init__(self):
        super().__init__()
        checkpoint = 'facebook/convnextv2-femto-1k-224'
        self.model = AutoModelForImageClassification.from_pretrained(
            checkpoint,
            num_labels=NUM_CLASSES,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        """Run the wrapped model on ``x`` and return the raw logits."""
        outputs = self.model(x)
        return outputs.logits

# Initialise the model (pretrained backbone with a NUM_CLASSES-way head).
model = AudioClassifier()

# Fine-tuned checkpoint produced during training.
MODEL_PATH = 'PicklesAndFinalModel/test_mel_preprocess_data_3_chans_history_torch_baseline_convnext_femto_224_custom_exp_8_mel_test_sin_resize_mel_best_model.pth'

# Load the checkpoint onto the CPU so the demo also runs without a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))

# Keep only entries whose names exist in this model. The model's key set is built
# once, instead of calling model.state_dict() again for every checkpoint entry.
model_keys = set(model.state_dict())
filtered_state_dict = {
    k: v for k, v in state_dict['model_state_dict'].items() if k in model_keys
}

# strict=False would otherwise hide parameters absent from the checkpoint; they keep
# their initial values, so report them instead of failing silently.
load_result = model.load_state_dict(filtered_state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"Warning: {len(load_result.missing_keys)} model parameters were not found "
        f"in the checkpoint and keep their initial values: {load_result.missing_keys}"
    )
model.eval()


# Gradio interface: upload audio files, press the button, read the predicted classes.
with gr.Blocks(theme=gr.themes.Default(), title="Voice Cloning Demo") as demo:
    gr.Markdown("Audio Classifier")
    with gr.Tab("Inference"):
        with gr.Column() as col1:
            upload_file = gr.File(
                label="Select here the audio files",
                file_count="multiple",
            )
            label = gr.Textbox(label="Predicted Class(es)")
            button = gr.Button("Classify")
            # Clicking runs classify_audio on the uploaded files and shows the result.
            button.click(classify_audio, inputs=upload_file, outputs=label)

demo.launch()

Some weights of ConvNextV2ForImageClassification were not initialized from the model checkpoint at facebook/convnextv2-femto-1k-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 384]) in the checkpoint and torch.Size([3, 384]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Running on local URL:  http://127.0.0.1:7893

To create a public link, set `share=True` in `launch()`.


