In [7]:
import torch
import torch.nn as nn
import librosa
import torchaudio
import torchvision.models as models
import numpy as np
import os

class Res2NetAudioClassifier(nn.Module):
    """Audio classifier that feeds single-channel spectrograms to a ResNet-50.

    Despite the class name, the backbone is torchvision's ``resnet50``, not a
    Res2Net. The backbone is stored under the attribute ``res2net`` so that
    ``state_dict`` keys keep the ``res2net.`` prefix used by existing
    checkpoints of this class.

    Args:
        num_classes: Number of logits produced by the final linear layer.
        pretrained: When True (the default, matching the previous behaviour)
            the backbone starts from ImageNet weights. Pass False to skip the
            weight download, e.g. when a full checkpoint is loaded afterwards.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()

        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``; IMAGENET1K_V1 is what ``pretrained=True`` resolved to.
        # Fall back to the legacy keyword when the weights enum is missing.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            backbone = models.resnet50(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )
        else:
            backbone = models.resnet50(pretrained=pretrained)

        # The stock stem expects 3 input channels; spectrograms have 1.
        rgb_conv = backbone.conv1
        mono_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Summing the RGB kernels gives the same response as feeding the
            # mono plane to all three channels of the pretrained stem, so the
            # pretrained filters are kept instead of being re-randomised.
            with torch.no_grad():
                mono_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        backbone.conv1 = mono_conv

        # Replace the 1000-way ImageNet head with a freshly initialised one.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)

        self.res2net = backbone

    def forward(self, x):
        """Return the backbone's logits for input batch ``x``."""
        return self.res2net(x)

def preprocess_audio(file_path, sample_rate=16000, n_mels=64, max_duration=5.0):
    """Load an audio file and turn it into a log-mel spectrogram tensor.

    Args:
        file_path: Path of the audio file, handed to ``librosa.load``.
        sample_rate: Sampling rate requested from ``librosa.load`` (``sr``).
        n_mels: Number of mel bands in the spectrogram.
        max_duration: Upper bound, in seconds, on how much audio is read.
            Shorter clips are NOT padded, so the number of time frames in the
            result varies with the clip length.

    Returns:
        A ``torch.float32`` tensor with one leading dimension added in front
        of the spectrogram returned by librosa. Values are in dB relative to
        the clip's own maximum (``ref=np.max``).

    Raises:
        ValueError: If the file decodes to zero samples.
    """
    waveform, sr = librosa.load(file_path, sr=sample_rate, duration=max_duration)

    # Fail loudly here rather than deep inside librosa (or, worse, by feeding
    # a degenerate spectrogram to the classifier).
    if waveform.size == 0:
        raise ValueError(f"No audio samples could be read from {file_path!r}")

    mel_power = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=n_mels)
    mel_db = librosa.power_to_db(mel_power, ref=np.max)

    # Leading dimension of size 1 is the channel axis expected by the model stem.
    return torch.tensor(mel_db, dtype=torch.float32).unsqueeze(0)


def load_model(device, model_path=None):
    """Build the classifier, optionally restore a checkpoint, and set eval mode.

    Args:
        device: Device the model is moved to and onto which checkpoint
            tensors are mapped.
        model_path: Optional path to a file containing a ``state_dict`` for
            ``Res2NetAudioClassifier``. When omitted (or empty) no weights are
            restored and a warning is emitted, because the classification
            head is then untrained.

    Returns:
        The model, on ``device``, in evaluation mode.
    """
    import warnings

    model = Res2NetAudioClassifier(num_classes=2).to(device)
    if model_path:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources; on
        # PyTorch versions that support it, consider weights_only=True.
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        warnings.warn(
            "load_model() was called without model_path: the classification "
            "head is randomly initialised, so predictions are not meaningful.",
            stacklevel=2,
        )
    # Inference mode for dropout / batch-norm layers.
    model.eval()
    return model

def predict(model, file_path, device):
    """Classify one audio file and return a human-readable label.

    The file is converted with ``preprocess_audio``, given a leading batch
    dimension, moved to ``device`` and passed through ``model`` with gradient
    tracking disabled.

    Args:
        model: Callable producing one row of class scores for the batch.
        file_path: Path of the audio file to classify.
        device: Device the input tensor is moved to before inference.

    Returns:
        ``"AI Generated"`` when the highest-scoring class index is 1,
        otherwise ``"Real"``.
    """
    features = preprocess_audio(file_path)
    batch = features.unsqueeze(0).to(device)  # batch of size one

    with torch.no_grad():
        logits = model(batch)
        class_index = torch.max(logits, dim=1).indices.item()

    if class_index == 1:
        return "AI Generated"
    return "Real"


if __name__ == "__main__":
    # Seed the RNG so the layers that are created with random weights are
    # initialised identically on every run; without this the printed label
    # can change between executions of the cell.
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): no checkpoint path is passed, so load_model() restores no
    # trained weights and the `fc` head is untrained. The printed label is
    # therefore not a meaningful detection result. Pass model_path=... once
    # a trained checkpoint is available.
    model = load_model(device)

    # NOTE(review): machine-specific absolute path; adjust for your setup.
    file_path = 'D:/Projects/AudioForgery/HAD/HAD_train/conbine/HAD_train_fake_00000001.wav'
    result = predict(model, file_path, device)
    print(f"Prediction: {result}")


Prediction: AI Generated


In [39]:

if __name__ == "__main__":
    # This cell builds its own model. Seed the RNG first so the randomly
    # initialised layers are reproducible from run to run instead of yielding
    # a different label each time the cell is executed.
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): no checkpoint path is passed, so load_model() restores no
    # trained weights and the `fc` head is untrained. The printed label is
    # therefore not a meaningful detection result.
    model = load_model(device)

    # NOTE(review): machine-specific absolute path; adjust for your setup.
    file_path ='D:/Projects/AudioForgery/HumanVoice/OSR_us_000_0010_8k.wav'
    result = predict(model, file_path, device)
    print(f"Prediction: {result}")


Prediction: Real
