In [1]:
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import DataLoader

from VGG import VGG
from dataset import ShipEarDataset

In [2]:
ANNOTATIONS_FILE = "../label_process/label.csv"
AUDIO_DIR = r"E:\数据集\ShipEar\shipsEar_AUDIOS"
SAMPLE_RATE = 22050
NUM_SAMPLES = 44100

BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 0.001

In [3]:
def create_data_loader(train_data, batch_size):
    train_dataloader = DataLoader(train_data, batch_size=batch_size)
    return train_dataloader

In [4]:
def train(model, data_loader, loss_fn, optimiser, device, epochs):
    for i in range(epochs):
        print(f"Epoch {i+1}")
        for input, target in data_loader:
            input, target = input.to(device), target.to(device)
            prediction = model(input)
            loss = loss_fn(prediction, target)
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
        print(f"loss: {loss.item()}")
        print("-----------------------------------------------")
    print("Finished training!")

In [5]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device}")

mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=2048,
    hop_length=512,
    n_mels=128
)

shipear = ShipEarDataset(
    ANNOTATIONS_FILE,
    AUDIO_DIR,
    mel_spectrogram,
    SAMPLE_RATE,
    NUM_SAMPLES,
    device
)

train_dataloader = create_data_loader(shipear, BATCH_SIZE)

vgg = VGG().to(device)
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(vgg.parameters(), lr=LEARNING_RATE)

train(vgg, train_dataloader, loss_fn, optimiser, device, EPOCHS)
torch.save(vgg.state_dict(), "shipear.pt")
print("Saved.")

Using cuda
Epoch 1
loss: 1.5757983922958374
-----------------------------------------------
Epoch 2
loss: 1.5484259128570557
-----------------------------------------------
Epoch 3
loss: 1.492844820022583
-----------------------------------------------
Epoch 4
loss: 1.4680709838867188
-----------------------------------------------
Epoch 5
loss: 1.431103229522705
-----------------------------------------------
Epoch 6
loss: 1.4545822143554688
-----------------------------------------------
Epoch 7
loss: 1.4080268144607544
-----------------------------------------------
Epoch 8
loss: 1.3925304412841797
-----------------------------------------------
Epoch 9
loss: 1.3796359300613403
-----------------------------------------------
Epoch 10
loss: 1.3775659799575806
-----------------------------------------------
Finished training!
Saved.
