In [1]:
import os
import random
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [4]:
!unzip data.zip -d /content/genres/

Archive:  data.zip
   creating: /content/genres/genres_original/
   creating: /content/genres/genres_original/blues/
  inflating: /content/genres/genres_original/blues/blues.00000.wav  
  inflating: /content/genres/genres_original/blues/blues.00001.wav  
  inflating: /content/genres/genres_original/blues/blues.00002.wav  
  inflating: /content/genres/genres_original/blues/blues.00003.wav  
  inflating: /content/genres/genres_original/blues/blues.00004.wav  
  inflating: /content/genres/genres_original/blues/blues.00005.wav  
  inflating: /content/genres/genres_original/blues/blues.00006.wav  
  inflating: /content/genres/genres_original/blues/blues.00007.wav  
  inflating: /content/genres/genres_original/blues/blues.00008.wav  
  inflating: /content/genres/genres_original/blues/blues.00009.wav  
  inflating: /content/genres/genres_original/blues/blues.00010.wav  
  inflating: /content/genres/genres_original/blues/blues.00011.wav  
  inflating: /content/genres/genres_original/blues/blue

In [28]:
DATA_DIR = "/content/genres/genres_original"   # <-- dataset folder path (genres/{blues,rock,...})
GENRES = ['blues','classical','country','disco','hiphop','jazz','metal','pop','reggae','rock']

SR = 22050           # sampling rate
DURATION = 30        # seconds to use per clip
N_MFCC = 40
HOP_LENGTH = 512
MAX_FRAMES = int(np.ceil(SR * DURATION / HOP_LENGTH))  # around 1293 for 30s
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_SIZE = 32
NUM_EPOCHS = 25
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5
PATIENCE = 6        # early stopping patience on val loss
SEED = 42

# reproducibility
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE.type == 'cuda':
    torch.cuda.manual_seed_all(SEED)


In [29]:
for root, dirs, files in os.walk(DATA_DIR):
    for file in files:
        if file.endswith(".wav"):
            path = os.path.join(root, file)
            new_path = path.replace(".wav", "_fixed.wav")
            os.system(f"ffmpeg -y -i '{path}' -acodec pcm_s16le -ar 44100 '{new_path}'")


In [31]:
def augment_audio(y, sr):
    # random small augmentations: pitch shift, time stretch, noise
    # apply each augmentation with some probability
    if random.random() < 0.3:
        # pitch shift by -2..2 semitones
        n_steps = random.uniform(-2, 2)
        try:
            y = librosa.effects.pitch_shift(y, sr, n_steps)
        except Exception:
            pass
    if random.random() < 0.3:
        # time stretch 0.9..1.1 (keep length approx; we'll trim/pad later)
        rate = random.uniform(0.9, 1.1)
        try:
            y = librosa.effects.time_stretch(y, rate)
        except Exception:
            pass
    if random.random() < 0.3:
        # add small gaussian noise
        noise_amp = 0.005 * np.random.uniform() * np.max(y)
        y = y + noise_amp * np.random.normal(size=y.shape[0])
    return y

In [32]:
def load_and_extract_mfcc(path, augment=False):
    # load full or truncated to DURATION
    try:
        y, sr = librosa.load(path, sr=SR, duration=DURATION)
    except Exception as e:
        raise e
    # if too short, pad
    if len(y) < SR * DURATION:
        pad_len = SR * DURATION - len(y)
        y = np.pad(y, (0, int(pad_len)), mode='constant')
    # augmentation
    if augment:
        y = augment_audio(y, sr)
        # after augment we may have different length -> ensure trim/pad
        if len(y) < SR * DURATION:
            y = np.pad(y, (0, int(SR * DURATION - len(y))), mode='constant')
        if len(y) > SR * DURATION:
            y = y[:SR * DURATION]
    # compute MFCC (shape: n_mfcc x t)
    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=N_MFCC, hop_length=HOP_LENGTH)
    # transpose -> (time_steps, n_mfcc)
    mfcc = mfcc.T.astype(np.float32)
    # pad or truncate to MAX_FRAMES
    if mfcc.shape[0] < MAX_FRAMES:
        pad_width = MAX_FRAMES - mfcc.shape[0]
        mfcc = np.pad(mfcc, ((0, pad_width), (0,0)), mode='constant')
    elif mfcc.shape[0] > MAX_FRAMES:
        mfcc = mfcc[:MAX_FRAMES, :]
    return mfcc  # shape: (MAX_FRAMES, N_MFCC)


In [33]:
class GTZANDataset(Dataset):
    def __init__(self, filepaths, labels, augment=False):
        self.filepaths = filepaths
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        path = self.filepaths[idx]
        label = self.labels[idx]
        try:
            mfcc = load_and_extract_mfcc(path, augment=self.augment)
        except Exception as e:
            print(f"⚠️ Skipping file: {path} ({e})")
            # return a random valid sample instead
            idx = random.randint(0, len(self.filepaths)-1)
            return self.__getitem__(idx)
        mfcc = mfcc.T  # (N_MFCC, MAX_FRAMES)
        return torch.from_numpy(mfcc), torch.tensor(label, dtype=torch.long)

In [34]:
class ConvLSTMGenre(nn.Module):
    def __init__(self, n_mfcc=N_MFCC, hidden_size=128, n_classes=10, n_lstm_layers=2, dropout=0.3):
        super().__init__()
        # CNN extractor on frequency axis per time-slice using Conv1d across time
        # Input: (batch, n_mfcc, time)
        self.conv1 = nn.Conv1d(in_channels=n_mfcc, out_channels=128, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(kernel_size=2)  # halves time dimension

        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.pool2 = nn.MaxPool1d(kernel_size=2)  # halves again

        # after conv/pool, time dimension reduced: MAX_FRAMES / 4 approx
        # LSTM expects (batch, seq_len, feat); we'll transpose appropriately
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=256, hidden_size=hidden_size, num_layers=n_lstm_layers,
                            batch_first=True, bidirectional=True, dropout=dropout if n_lstm_layers>1 else 0.0)

        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size*2, 128)  # *2 for bidirectional
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        # x: (batch, n_mfcc, time)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)   # (batch, 128, time/2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)   # (batch, 256, time/4)

        # transpose to (batch, time_seq, feat)
        x = x.permute(0, 2, 1)  # (batch, seq_len, 256)

        # LSTM
        out, (hn, cn) = self.lstm(x)  # out: (batch, seq_len, hidden*2)
        # use mean pooling over time
        out = torch.mean(out, dim=1)  # (batch, hidden*2)

        out = self.dropout(F.relu(self.fc1(out)))
        out = self.fc2(out)
        return out

In [35]:
def gather_files(data_dir, genres=GENRES):
    files = []
    labels = []
    for i, g in enumerate(genres):
        folder = os.path.join(data_dir, g)
        if not os.path.isdir(folder):
            print(f"Warning: {folder} not found.")
            continue
        for f in os.listdir(folder):
            if f.lower().endswith(('.wav', '.au', '.mp3', '.aiff', '.aif')):
                files.append(os.path.join(folder, f))
                labels.append(g)
    return files, labels

all_files, all_labels_str = gather_files(DATA_DIR, GENRES)
if len(all_files) == 0:
    raise RuntimeError(f"No audio files found in {DATA_DIR}. Check path and file extensions.")


In [36]:
# Label encode
le = LabelEncoder()
all_labels = le.fit_transform(all_labels_str)

# Train / val / test split (stratified)
X_train, X_temp, y_train, y_temp = train_test_split(all_files, all_labels, test_size=0.25, random_state=SEED, stratify=all_labels)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=SEED, stratify=y_temp)

print("Dataset sizes:", len(X_train), len(X_val), len(X_test))

Dataset sizes: 2248 375 375


In [37]:
# Datasets & loaders
train_ds = GTZANDataset(X_train, y_train, augment=True)
val_ds   = GTZANDataset(X_val, y_val, augment=False)
test_ds  = GTZANDataset(X_test, y_test, augment=False)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

In [38]:
model = ConvLSTMGenre(n_mfcc=N_MFCC, hidden_size=128, n_classes=len(le.classes_)).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Early stopping
best_val_loss = float('inf')
epochs_no_improve = 0
best_model_path = "best_music_genre_model.pth"

In [39]:
def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for X_batch, y_batch in tqdm(loader, desc="Train batches", leave=False):
        X_batch = X_batch.to(DEVICE)        # shape: (batch, n_mfcc, time)
        y_batch = y_batch.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * X_batch.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == y_batch).sum().item()
        total += y_batch.size(0)
    epoch_loss = running_loss / total
    acc = correct / total
    return epoch_loss, acc

In [40]:
def eval_model(model, loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for X_batch, y_batch in tqdm(loader, desc="Eval batches", leave=False):
            X_batch = X_batch.to(DEVICE)
            y_batch = y_batch.to(DEVICE)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            running_loss += loss.item() * X_batch.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == y_batch).sum().item()
            total += y_batch.size(0)
    epoch_loss = running_loss / total
    acc = correct / total
    return epoch_loss, acc

In [41]:
for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\n=== Epoch {epoch}/{NUM_EPOCHS} ===")
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = eval_model(model, val_loader, criterion)
    print(f"Train Loss: {train_loss:.4f}  Acc: {train_acc*100:.2f}%")
    print(f"Val   Loss: {val_loss:.4f}  Acc: {val_acc*100:.2f}%")

    scheduler.step(val_loss)

    # save best
    if val_loss < best_val_loss - 1e-5:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save({
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'label_encoder_classes': le.classes_,
            'val_loss': val_loss
        }, best_model_path)
        print("Saved best model.")
    else:
        epochs_no_improve += 1
        print(f"No improvement for {epochs_no_improve} epoch(s).")

    if epochs_no_improve >= PATIENCE:
        print("Early stopping triggered.")
        break



=== Epoch 1/25 ===


Train batches:   0%|          | 0/71 [00:00<?, ?it/s]Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7eff0bfed8a0><function _MultiProcessingDataLoaderIter.__del__ at 0x7eff0bfed8a0>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():
 
            ^ ^^^^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    ^assert self._parent_pid == os.getpid(), 'c

⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 1.7218  Acc: 39.41%
Val   Loss: 1.2636  Acc: 53.60%
Saved best model.

=== Epoch 2/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 1.2400  Acc: 55.29%
Val   Loss: 1.2791  Acc: 56.00%
No improvement for 1 epoch(s).

=== Epoch 3/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 1.0469  Acc: 63.75%
Val   Loss: 0.9112  Acc: 68.80%
Saved best model.

=== Epoch 4/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.9562  Acc: 67.30%
Val   Loss: 0.9239  Acc: 68.27%
No improvement for 1 epoch(s).

=== Epoch 5/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.7672  Acc: 73.67%
Val   Loss: 0.8300  Acc: 70.67%
Saved best model.

=== Epoch 6/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.7849  Acc: 73.35%
Val   Loss: 0.6857  Acc: 77.33%
Saved best model.

=== Epoch 7/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.6851  Acc: 76.25%
Val   Loss: 0.6810  Acc: 74.13%
Saved best model.

=== Epoch 8/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.6306  Acc: 78.74%
Val   Loss: 0.7577  Acc: 74.93%
No improvement for 1 epoch(s).

=== Epoch 9/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.5597  Acc: 80.47%
Val   Loss: 0.8515  Acc: 73.60%
No improvement for 2 epoch(s).

=== Epoch 10/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.5062  Acc: 82.61%
Val   Loss: 0.4740  Acc: 82.93%
Saved best model.

=== Epoch 11/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.4669  Acc: 84.61%
Val   Loss: 0.4688  Acc: 85.33%
Saved best model.

=== Epoch 12/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.5227  Acc: 83.41%
Val   Loss: 0.4035  Acc: 87.47%
Saved best model.

=== Epoch 13/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.3441  Acc: 89.32%
Val   Loss: 0.4213  Acc: 88.53%
No improvement for 1 epoch(s).

=== Epoch 14/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.3414  Acc: 88.48%
Val   Loss: 0.3417  Acc: 88.00%
Saved best model.

=== Epoch 15/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.3280  Acc: 89.99%
Val   Loss: 0.4037  Acc: 89.60%
No improvement for 1 epoch(s).

=== Epoch 16/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.2981  Acc: 90.17%
Val   Loss: 0.2822  Acc: 93.07%
Saved best model.

=== Epoch 17/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.2971  Acc: 90.30%
Val   Loss: 0.3045  Acc: 90.67%
No improvement for 1 epoch(s).

=== Epoch 18/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.3103  Acc: 89.86%
Val   Loss: 0.2821  Acc: 90.93%
Saved best model.

=== Epoch 19/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.2472  Acc: 92.48%
Val   Loss: 0.3308  Acc: 89.87%
No improvement for 1 epoch(s).

=== Epoch 20/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.2641  Acc: 92.08%
Val   Loss: 0.2463  Acc: 93.60%
Saved best model.

=== Epoch 21/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.2420  Acc: 93.06%
Val   Loss: 0.2302  Acc: 93.87%
Saved best model.

=== Epoch 22/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.1968  Acc: 93.37%
Val   Loss: 0.3052  Acc: 92.27%
No improvement for 1 epoch(s).

=== Epoch 23/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.2026  Acc: 93.24%
Val   Loss: 0.2696  Acc: 92.00%
No improvement for 2 epoch(s).

=== Epoch 24/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()




Train Loss: 0.1805  Acc: 94.84%
Val   Loss: 0.1946  Acc: 95.20%
Saved best model.

=== Epoch 25/25 ===


  y, sr = librosa.load(path, sr=SR, duration=DURATION)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()


                                                             

Train Loss: 0.1341  Acc: 95.91%
Val   Loss: 0.2283  Acc: 93.07%
No improvement for 1 epoch(s).




In [43]:
# -------------------------
# Accuracy check function
# -------------------------
def evaluate_accuracy(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            X_batch = X_batch.to(DEVICE)
            y_batch = y_batch.to(DEVICE)
            outputs = model(X_batch)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == y_batch).sum().item()
            total += y_batch.size(0)
    acc = correct / total
    return acc

# -------------------------
# Check accuracy on validation set
# -------------------------
val_acc = evaluate_accuracy(model, val_loader)
print(f"Validation Accuracy: {val_acc*100:.2f}%")

# -------------------------
# Check accuracy on test set
# -------------------------
test_acc = evaluate_accuracy(model, test_loader)
print(f"Test Accuracy: {test_acc*100:.2f}%")


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7eff0bfed8a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7eff0bfed8a0>^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^    ^^self._shutdown_workers()
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^^
 ^

⚠️ Skipping file: /content/genres/genres_original/jazz/jazz.00054.wav ()
Validation Accuracy: 93.07%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7eff0bfed8a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7eff0bfed8a0>^
^Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    
self._shutdown_workers()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

        if w.is_alive(): 
           ^ ^ ^^^^^^^^^^^^^^^^^^^^^^^

Test Accuracy: 93.07%


In [44]:
# Save final model weights & label encoder
torch.save({
    'model_state': model.state_dict(),
    'label_encoder_classes': le.classes_,
}, "final_music_genre_model.pth")

print("Done. Models saved: ", best_model_path, "and final_music_genre_model.pth")

Done. Models saved:  best_music_genre_model.pth and final_music_genre_model.pth


In [48]:
def predict_genre(model, file_path):
    mfcc = load_and_extract_mfcc(file_path, augment=False).T  # (N_MFCC, MAX_FRAMES)
    X = torch.from_numpy(mfcc).unsqueeze(0).to(DEVICE)  # add batch dim
    model.eval()
    with torch.no_grad():
        output = model(X)
        pred = torch.argmax(output, dim=1).item()
    return le.classes_[pred]

# Example
song_path = "/content/jazz.00001.wav"
print("Predicted genre:", predict_genre(model, song_path))

Predicted genre: jazz


In [49]:
def predict_genre(model, file_path):
    mfcc = load_and_extract_mfcc(file_path, augment=False).T  # (N_MFCC, MAX_FRAMES)
    X = torch.from_numpy(mfcc).unsqueeze(0).to(DEVICE)  # add batch dim
    model.eval()
    with torch.no_grad():
        output = model(X)
        pred = torch.argmax(output, dim=1).item()
    return le.classes_[pred]

# Example
song_path = "/content/pop.00003.wav"
print("Predicted genre:", predict_genre(model, song_path))

Predicted genre: pop
