In [None]:
import gradio as gr
from Pipelines.AudioPipeline import AudioPipeline
import torchaudio
import torch
from DetectionModels.AudioCNN import CNNnetwork

# Audio/feature configuration for the deepfake detector.
SAMPLE_RATE = 20000  # Hz; handed to the mel transform and to AudioPipeline
NUM_OF_MELS = 128  # number of mel filterbank bins
BATCH_SIZE = 128  # not referenced by the inference code visible here

# NOTE(review): the clip length is scaled by 16000 rather than SAMPLE_RATE
# (20000), so NUM_SAMPLES is 62027, i.e. ~3.10 s at SAMPLE_RATE instead of
# ~3.88 s. Left as-is because the CNN presumably expects exactly this input
# length -- confirm against the training setup before changing it.
_CLIP_SECONDS = 3.876695758374233
_CLIP_REFERENCE_RATE = 16000
NUM_SAMPLES = int(_CLIP_SECONDS * _CLIP_REFERENCE_RATE)

# STFT framing: a 16 ms analysis window (320 samples) with a 4 ms hop (80 samples).
WIN_LENGTH = int(SAMPLE_RATE * 0.016)
HOP_LENGTH = int(SAMPLE_RATE * 0.004)

# Front end that turns a raw waveform into a dB-scaled mel spectrogram.
# n_fft is twice the analysis window; torch's STFT zero-pads the shorter
# Hamming window out to n_fft points.
mel_spectrogram_transform = torch.nn.Sequential(
    torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=2 * WIN_LENGTH,
        win_length=WIN_LENGTH,
        hop_length=HOP_LENGTH,
        window_fn=torch.hamming_window,
        n_mels=NUM_OF_MELS,
    ),
    # top_db=80 clamps the output to at most 80 dB below its maximum.
    torchaudio.transforms.AmplitudeToDB(top_db=80),
)

# Build the CNN classifier and restore its trained weights.
model = CNNnetwork()
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict then copies the tensors into the model's existing
# parameters. weights_only=True restricts unpickling to tensors and plain
# containers, the safe way to read a checkpoint file.
state_dict = torch.load("Audio_CNN.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict=state_dict)
# Put the network in inference mode so any dropout / batch-norm layers behave
# deterministically during prediction (harmless if AudioPipeline also calls eval()).
model.eval()

# Combine the model and the log-mel front end into the end-to-end pipeline used by the UI.
pipeline = AudioPipeline(model, mel_spectrogram_transform, SAMPLE_RATE, NUM_SAMPLES)

def gradio_predict(audio_path):
    """Run the deepfake-detection pipeline on behalf of the Gradio UI.

    Args:
        audio_path: Path of the uploaded clip, as supplied by a
            ``gr.Audio(type="filepath")`` input, or ``None`` when nothing
            was uploaded.

    Returns:
        A 3-tuple ``(prediction, signal, explanation)`` taken from the
        ``"prediction"``, ``"signal"`` and ``"explination"`` entries of
        ``pipeline.run(audio_path)``. When ``audio_path`` is ``None``, returns
        ``("No Audio", None, None)`` instead and does not touch the pipeline.
    """
    if audio_path is not None:
        outcome = pipeline.run(audio_path)
        # The key spelling "explination" must match what AudioPipeline produces.
        return tuple(outcome[key] for key in ("prediction", "signal", "explination"))
    return "No Audio", None, None

# Soft Gradio theme with blue primary accents over a slate secondary palette.
theme = gr.themes.Soft(secondary_hue="slate", primary_hue="blue")

# Two-column layout: upload / analyze / result controls on the left, and
# tabs with the signal plot and the explanation images on the right.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown(
        """
        # Deepfake Audio Detection
        **Upload an audio file to check whether your audio is real or not.**
        """
    )

    with gr.Row():
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### 1. Input")
            # type="filepath" makes Gradio pass gradio_predict a path on disk.
            audio_input = gr.Audio(
                type="filepath",
                label="Upload File",
                sources=["upload"],
            )

            gr.Markdown("### 2. Analyze")
            submit_btn = gr.Button("Detect Deepfake", variant="primary", size="lg")

            gr.Markdown("### 3. Result")
            lbl_output = gr.Label(label="Prediction", num_top_classes=2)

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Original Audio"):
                    plot_output = gr.Image(label="Original Signal", type="pil", show_label=False)

                with gr.TabItem("Explanations"):
                    gr.Markdown("An explanation of your result")
                    gallery_output = gr.Gallery(
                        label="Explanations", columns=2, object_fit="contain", height="auto", show_label=False
                    )

    # gradio_predict returns (prediction, signal, explanation), which maps
    # positionally onto these three output components.
    submit_btn.click(
        fn=gradio_predict,
        inputs=audio_input,
        outputs=[lbl_output, plot_output, gallery_output],
    )

if __name__ == "__main__":
    # inbrowser=True opens the app in a new tab of the default browser;
    # debug=True blocks the main thread and surfaces errors in the console.
    # Add share=True to also get a publicly shareable link.
    demo.launch(inbrowser=True, debug=True)