In [28]:
import os

import torch
import torchaudio

from VGG import VGG
from dataset import ShipEarDataset

In [29]:
# Annotation CSV, given as a path relative to the notebook's working directory.
ANNOTATIONS_FILE = "../label_process/label.csv"

# Root directory of the audio files. The original value is an absolute path on
# one specific Windows machine, so it can be overridden per environment via the
# SHIPEAR_AUDIO_DIR environment variable without editing this cell. When the
# variable is unset the original path is used, so existing setups are unaffected.
AUDIO_DIR = os.environ.get("SHIPEAR_AUDIO_DIR", r"E:\数据集\ShipEar\shipsEar_AUDIOS")

# Sample rate handed to both the mel-spectrogram transform and the dataset.
SAMPLE_RATE = 22050
# 44100 samples at 22050 Hz is 2 seconds of audio per item.
# NOTE(review): presumably the fixed clip length the dataset cuts/pads to — confirm
# against ShipEarDataset.
NUM_SAMPLES = 44100

# Human-readable class labels. `predict` indexes this list with both the
# predicted class index and the target, so the order must match the label ids
# used in the annotations (presumably the same order used at training time).
class_mapping = [
    "Fishboat, Trawler, Mussel boat, Tugboat, Dredger",
    "Motorboat, Pilot ship, Sailboat",
    "Passengers",
    "Ocean liner, RORO",
    "Natural ambient noise"
]

In [30]:
def predict(model, input, target, class_mapping):
    """Classify one input and return the predicted and expected labels.

    The model is switched to evaluation mode (and left in that mode), then
    called once with gradient tracking disabled. Only the first row of the
    model output is inspected, so the input is presumably a batch of one.

    Args:
        model: Object providing ``eval()`` and callable on ``input``.
        input: Value forwarded unchanged to ``model``.
        target: Index of the ground-truth class within ``class_mapping``.
        class_mapping: Indexable collection mapping class index to label.

    Returns:
        A ``(predicted, expected)`` tuple of entries from ``class_mapping``.
    """
    model.eval()
    with torch.no_grad():
        scores = model(input)
    # Highest-scoring entry of the first output row picks the predicted class.
    best_index = scores[0].argmax(dim=0)
    return class_mapping[best_index], class_mapping[target]

In [34]:
# Rebuild the network and restore the saved weights.
vgg = VGG()
# map_location="cpu" lets a checkpoint saved from a CUDA device load on a
# CPU-only machine; the dataset below is created on "cpu" as well.
state_dict = torch.load("shipear.pt", map_location="cpu", weights_only=True)
vgg.load_state_dict(state_dict)

# Mel-spectrogram front end applied by the dataset to each clip.
# NOTE(review): these settings presumably have to match the ones used when
# "shipear.pt" was trained — confirm against the training notebook/script.
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=2048,
    hop_length=512,
    n_mels=128
)

shipear = ShipEarDataset(ANNOTATIONS_FILE,
                         AUDIO_DIR,
                         mel_spectrogram,
                         SAMPLE_RATE,
                         NUM_SAMPLES,
                         "cpu")

# Fetch the sample once: every `shipear[i]` call runs the dataset's
# __getitem__, so indexing twice would load and transform the same clip twice.
sample_index = 1
sample = shipear[sample_index]
# The names `input`/`target` are kept as-is (even though `input` shadows the
# builtin) in case other cells refer to them.
input, target = sample[0], sample[1]
print(input.shape)
# Add a leading batch dimension out of place, so the tensor returned by the
# dataset is not modified.
input = input.unsqueeze(0)
print(input.shape)

predicted, expected = predict(vgg, input, target,
                              class_mapping)

print(f"Predicted: '{predicted}', expected: '{expected}'")

torch.Size([1, 128, 87])
torch.Size([1, 1, 128, 87])
Predicted: 'Motorboat, Pilot ship, Sailboat', expected: 'Passengers'
