In [1]:
import os
import librosa
import librosa.display
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

In [2]:
def pad_or_trim_mel_spec(mel_spec, max_length):
    # 获取当前梅尔谱图的时间维度长度
    length = mel_spec.shape[1]
    if length > max_length:
        # 裁剪
        mel_spec = mel_spec[:, :max_length]
    elif length < max_length:
        # 填充
        padding = max_length - length
        mel_spec = np.pad(mel_spec, ((0, 0), (0, padding)), mode='constant')
    return mel_spec

class GTZANDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)
    
    def __getitem__(self, index):
        file_path = self.file_paths[index]
        label = self.labels[index]

        y, sr = librosa.load(file_path, sr=22050)
        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)
        mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
        max_length = 1300
        mel_spec = pad_or_trim_mel_spec(mel_spec, max_length)

        if self.transform:
            mel_spec = self.transform(mel_spec)
        else:
            mel_spec = torch.tensor(mel_spec).unsqueeze(0)

        return mel_spec, label

In [3]:
data_dir = './autodl-fs/Data/genres_original/'

genres = os.listdir(data_dir)

In [4]:
genres

['blues',
 'classical',
 'country',
 'disco',
 'hiphop',
 'jazz',
 'metal',
 'pop',
 'reggae',
 'rock']

In [5]:
file_paths = []
labels = []

for i , genre in enumerate(genres):
    genre_dir = os.path.join(data_dir, genre)
    cnt = 0
    for file in os.listdir(genre_dir):
        if file.endswith('.wav') and file != 'jazz.00054.wav':
            file_paths.append(os.path.join(genre_dir, file))
            labels.append(i)
            cnt += 1
    print("{} cnt: {}".format(genre, cnt))
print(file_paths)
print(labels)

blues cnt: 100
classical cnt: 100
country cnt: 100
disco cnt: 100
hiphop cnt: 100
jazz cnt: 99
metal cnt: 100
pop cnt: 100
reggae cnt: 100
rock cnt: 100
['./autodl-fs/Data/genres_original/blues/blues.00000.wav', './autodl-fs/Data/genres_original/blues/blues.00001.wav', './autodl-fs/Data/genres_original/blues/blues.00002.wav', './autodl-fs/Data/genres_original/blues/blues.00003.wav', './autodl-fs/Data/genres_original/blues/blues.00004.wav', './autodl-fs/Data/genres_original/blues/blues.00005.wav', './autodl-fs/Data/genres_original/blues/blues.00006.wav', './autodl-fs/Data/genres_original/blues/blues.00007.wav', './autodl-fs/Data/genres_original/blues/blues.00008.wav', './autodl-fs/Data/genres_original/blues/blues.00009.wav', './autodl-fs/Data/genres_original/blues/blues.00010.wav', './autodl-fs/Data/genres_original/blues/blues.00011.wav', './autodl-fs/Data/genres_original/blues/blues.00012.wav', './autodl-fs/Data/genres_original/blues/blues.00013.wav', './autodl-fs/Data/genres_original/

In [6]:
len(file_paths)

999

In [7]:
train_files, test_files, train_labels, test_labels = train_test_split(file_paths, labels, test_size=0.2, stratify=labels, random_state=42)

In [8]:
class SpecAugment:
    def __call__(self, spec):
        spec = torch.tensor(spec)
        
        if spec.dim() == 2:
            spec = spec.unsqueeze(0)
        
        freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=15)
        time_mask = torchaudio.transforms.TimeMasking(time_mask_param=35)
        spec = freq_mask(spec)
        spec = time_mask(spec)
        return spec
    
train_transform = transforms.Compose([
    SpecAugment(),
    transforms.Normalize((0.5,), (0.5,))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [9]:
train_dataset = GTZANDataset(train_files, train_labels, transform=train_transform)
test_dataset = GTZANDataset(test_files, test_labels, transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

model = models.vgg16(pretrained=True)

model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
model.classifier[6] = nn.Linear(4096, 10)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)



In [10]:
def train(model, train_loader, criterion, optimizer, epochs=10):
    model.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}')
    torch.save(model.state_dict(), 'model.pth')

In [11]:
train(model, train_loader, criterion, optimizer, epochs=50)

Epoch 1/50, Loss: 2.5098
Epoch 2/50, Loss: 2.1017
Epoch 3/50, Loss: 2.1207
Epoch 4/50, Loss: 1.7675
Epoch 5/50, Loss: 1.5928
Epoch 6/50, Loss: 1.3540
Epoch 7/50, Loss: 1.1942
Epoch 8/50, Loss: 1.1210
Epoch 9/50, Loss: 1.0657
Epoch 10/50, Loss: 0.8919
Epoch 11/50, Loss: 0.8682
Epoch 12/50, Loss: 0.8414
Epoch 13/50, Loss: 0.8259
Epoch 14/50, Loss: 0.6939
Epoch 15/50, Loss: 0.5482
Epoch 16/50, Loss: 0.5576
Epoch 17/50, Loss: 0.4414
Epoch 18/50, Loss: 0.5751
Epoch 19/50, Loss: 0.5096
Epoch 20/50, Loss: 0.3284
Epoch 21/50, Loss: 0.3147
Epoch 22/50, Loss: 0.3039
Epoch 23/50, Loss: 0.2704
Epoch 24/50, Loss: 0.2897
Epoch 25/50, Loss: 0.1804
Epoch 26/50, Loss: 0.1756
Epoch 27/50, Loss: 0.2031
Epoch 28/50, Loss: 0.2393
Epoch 29/50, Loss: 0.1937
Epoch 30/50, Loss: 0.1333
Epoch 31/50, Loss: 0.1566
Epoch 32/50, Loss: 0.1615
Epoch 33/50, Loss: 0.2077
Epoch 34/50, Loss: 0.1355
Epoch 35/50, Loss: 0.1343
Epoch 36/50, Loss: 0.1164
Epoch 37/50, Loss: 0.0845
Epoch 38/50, Loss: 0.1010
Epoch 39/50, Loss: 0.

In [12]:
model = models.vgg16(pretrained=False)
model.features[0] = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
model.classifier[6] = nn.Linear(4096, 10)
model.load_state_dict(torch.load('model.pth'))
model.to(device)

def evaluate(model, test_loader):
    model.eval()
    outputs = []
    targets = []
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device)
            output = model(inputs)
            _, preds = torch.max(output, 1)
            outputs.extend(preds.cpu().numpy())
            targets.extend(labels.cpu().numpy())
    print("Classification Report: ")
    print(classification_report(targets, outputs, target_names=genres))
    print("Confusion Matrix: ")
    print(confusion_matrix(targets, outputs))

evaluate(model, test_loader)

  model.load_state_dict(torch.load('model.pth'))


Classification Report: 
              precision    recall  f1-score   support

       blues       0.69      0.55      0.61        20
   classical       0.95      1.00      0.98        20
     country       0.60      0.75      0.67        20
       disco       1.00      0.40      0.57        20
      hiphop       0.75      0.90      0.82        20
        jazz       0.67      0.80      0.73        20
       metal       0.90      0.95      0.93        20
         pop       0.68      0.85      0.76        20
      reggae       0.63      0.85      0.72        20
        rock       0.67      0.30      0.41        20

    accuracy                           0.73       200
   macro avg       0.75      0.73      0.72       200
weighted avg       0.75      0.73      0.72       200

Confusion Matrix: 
[[11  0  1  0  0  4  0  0  4  0]
 [ 0 20  0  0  0  0  0  0  0  0]
 [ 2  0 15  0  0  1  0  1  0  1]
 [ 0  0  2  8  3  1  0  4  1  1]
 [ 0  0  0  0 18  0  0  0  2  0]
 [ 2  0  1  0  0 16  0  0  0  1]
