In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from tqdm import tqdm
import wandb
import torch.nn.functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


In [2]:
class MelSpectrogramDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.classes = os.listdir(root_dir)
        self.data = []
        self.labels = []
        
        for idx, genre in enumerate(self.classes):
            genre_path = os.path.join(root_dir, genre)
            for file in os.listdir(genre_path):
                self.data.append(os.path.join(genre_path, file))
                self.labels.append(idx)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]
        img = Image.open(img_path).convert('RGB') 
        img = np.array(img) 
        img = img.transpose((2, 0, 1))  
        img = torch.from_numpy(img).float()  
        label = self.labels[idx]
        return img, label


In [3]:
train_dir = '/kaggle/input/mel-spectrogram-for-gtzn-dataset/melspectrograms/train'
val_dir = '/kaggle/input/mel-spectrogram-for-gtzn-dataset/melspectrograms/validation'
test_dir = '/kaggle/input/mel-spectrogram-for-gtzn-dataset/melspectrograms/test'

train_dataset = MelSpectrogramDataset(train_dir)
val_dataset = MelSpectrogramDataset(val_dir)
test_dataset = MelSpectrogramDataset(test_dir)


In [4]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [5]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_channels)
        
    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return F.relu(out)

In [6]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  
            ResidualBlock(64),
            
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  
            ResidualBlock(128),
            
            nn.AdaptiveAvgPool2d((6, 6)),  
            nn.Flatten(),
            
            nn.Linear(128*6*6, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        return self.net(x)

model = CNN()


In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

CNN(
  (net): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (6): ResidualBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1)

In [8]:
optimizer = SGD(model.parameters(), 
               lr=0.01, 
               momentum=0.9, 
               weight_decay=1e-4)

criterion = nn.CrossEntropyLoss()

scheduler = CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-4)

In [9]:
for epoch in range(30):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/30')
    for X_batch, y_batch in pbar:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    print(f'Loss: {running_loss / len(train_loader)}')

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    y_pred = []
    y_true = []
    with torch.no_grad():
        pbar_val = tqdm(val_loader, desc='Validation')
        for X_batch, y_batch in pbar_val:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == y_batch).sum().item()
            y_pred.extend(predicted.cpu().numpy())
            y_true.extend(y_batch.cpu().numpy())
        
        accuracy = correct / len(val_dataset)
        precision = precision_score(y_true, y_pred, average='macro')
        recall = recall_score(y_true, y_pred, average='macro')
        f1 = f1_score(y_true, y_pred, average='macro')


        print(f'Validation Loss: {val_loss / len(val_loader)}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}')
    
    scheduler.step()


Epoch 1/30: 100%|██████████| 250/250 [01:48<00:00,  2.29it/s]


Loss: 1.6254043917655945


Validation: 100%|██████████| 32/32 [00:07<00:00,  4.22it/s]


Validation Loss: 1.5874359868466854, Accuracy: 0.4134, Precision: 0.4420, Recall: 0.4136, F1-Score: 0.3846


Epoch 2/30: 100%|██████████| 250/250 [01:36<00:00,  2.60it/s]


Loss: 1.2523079633712768


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.62it/s]


Validation Loss: 1.6168007333762944, Accuracy: 0.5045, Precision: 0.5275, Recall: 0.5044, F1-Score: 0.4842


Epoch 3/30: 100%|██████████| 250/250 [01:36<00:00,  2.59it/s]


Loss: 1.0819310219287872


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.58it/s]


Validation Loss: 1.2592457719147205, Accuracy: 0.5475, Precision: 0.6158, Recall: 0.5478, F1-Score: 0.5223


Epoch 4/30: 100%|██████████| 250/250 [01:36<00:00,  2.59it/s]


Loss: 0.9166520788669587


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.50it/s]


Validation Loss: 0.9892066321335733, Accuracy: 0.6406, Precision: 0.6723, Recall: 0.6407, F1-Score: 0.6199


Epoch 5/30: 100%|██████████| 250/250 [01:38<00:00,  2.55it/s]


Loss: 0.8197016817331314


Validation: 100%|██████████| 32/32 [00:05<00:00,  6.22it/s]


Validation Loss: 1.130441400455311, Accuracy: 0.6246, Precision: 0.7129, Recall: 0.6249, F1-Score: 0.6363


Epoch 6/30: 100%|██████████| 250/250 [01:37<00:00,  2.57it/s]


Loss: 0.7012608894705772


Validation: 100%|██████████| 32/32 [00:05<00:00,  6.30it/s]


Validation Loss: 0.7459244169294834, Accuracy: 0.7528, Precision: 0.7750, Recall: 0.7529, F1-Score: 0.7541


Epoch 7/30: 100%|██████████| 250/250 [01:37<00:00,  2.56it/s]


Loss: 0.612633834183216


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.57it/s]


Validation Loss: 1.036275491118431, Accuracy: 0.6657, Precision: 0.7347, Recall: 0.6652, F1-Score: 0.6573


Epoch 8/30: 100%|██████████| 250/250 [01:36<00:00,  2.58it/s]


Loss: 0.5516052901148796


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.51it/s]


Validation Loss: 0.7680572171229869, Accuracy: 0.7387, Precision: 0.7579, Recall: 0.7387, F1-Score: 0.7349


Epoch 9/30: 100%|██████████| 250/250 [01:36<00:00,  2.58it/s]


Loss: 0.48981555193662646


Validation: 100%|██████████| 32/32 [00:05<00:00,  6.32it/s]


Validation Loss: 0.7514859489165246, Accuracy: 0.7487, Precision: 0.7773, Recall: 0.7489, F1-Score: 0.7466


Epoch 10/30: 100%|██████████| 250/250 [01:37<00:00,  2.56it/s]


Loss: 0.4469645252227783


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.57it/s]


Validation Loss: 0.5860658176243305, Accuracy: 0.8078, Precision: 0.8196, Recall: 0.8077, F1-Score: 0.8084


Epoch 11/30: 100%|██████████| 250/250 [01:37<00:00,  2.57it/s]


Loss: 0.38048500287532805


Validation: 100%|██████████| 32/32 [00:05<00:00,  6.38it/s]


Validation Loss: 0.5898168344283476, Accuracy: 0.8068, Precision: 0.8106, Recall: 0.8067, F1-Score: 0.8034


Epoch 12/30: 100%|██████████| 250/250 [01:37<00:00,  2.57it/s]


Loss: 0.319423245549202


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.48it/s]


Validation Loss: 0.502879403764382, Accuracy: 0.8268, Precision: 0.8344, Recall: 0.8268, F1-Score: 0.8274


Epoch 13/30: 100%|██████████| 250/250 [01:37<00:00,  2.57it/s]


Loss: 0.2949705655872822


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.48it/s]


Validation Loss: 0.628327960614115, Accuracy: 0.8138, Precision: 0.8265, Recall: 0.8138, F1-Score: 0.8109


Epoch 14/30: 100%|██████████| 250/250 [01:36<00:00,  2.58it/s]


Loss: 0.2644238718599081


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.59it/s]


Validation Loss: 0.6078123381303158, Accuracy: 0.8108, Precision: 0.8332, Recall: 0.8109, F1-Score: 0.8121


Epoch 15/30: 100%|██████████| 250/250 [01:36<00:00,  2.59it/s]


Loss: 0.213011667445302


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.46it/s]


Validation Loss: 0.5303825719020097, Accuracy: 0.8358, Precision: 0.8506, Recall: 0.8359, F1-Score: 0.8367


Epoch 16/30: 100%|██████████| 250/250 [01:37<00:00,  2.58it/s]


Loss: 0.19636470860242844


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.57it/s]


Validation Loss: 0.44353839533869177, Accuracy: 0.8779, Precision: 0.8846, Recall: 0.8779, F1-Score: 0.8791


Epoch 17/30: 100%|██████████| 250/250 [01:36<00:00,  2.58it/s]


Loss: 0.16615451061725617


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.71it/s]


Validation Loss: 0.45989810503670014, Accuracy: 0.8579, Precision: 0.8670, Recall: 0.8577, F1-Score: 0.8577


Epoch 18/30: 100%|██████████| 250/250 [01:36<00:00,  2.59it/s]


Loss: 0.14206852726638317


Validation: 100%|██████████| 32/32 [00:05<00:00,  6.38it/s]


Validation Loss: 0.4458844919281546, Accuracy: 0.8769, Precision: 0.8851, Recall: 0.8770, F1-Score: 0.8774


Epoch 19/30: 100%|██████████| 250/250 [01:36<00:00,  2.58it/s]


Loss: 0.12963933090865612


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.46it/s]


Validation Loss: 0.3830244508353644, Accuracy: 0.8869, Precision: 0.8904, Recall: 0.8869, F1-Score: 0.8877


Epoch 20/30: 100%|██████████| 250/250 [01:36<00:00,  2.59it/s]


Loss: 0.10691709638386965


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.73it/s]


Validation Loss: 0.4505244336905889, Accuracy: 0.8839, Precision: 0.8936, Recall: 0.8839, F1-Score: 0.8860


Epoch 21/30: 100%|██████████| 250/250 [01:36<00:00,  2.59it/s]


Loss: 0.09736614891886711


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.53it/s]


Validation Loss: 0.3682223177020205, Accuracy: 0.9079, Precision: 0.9112, Recall: 0.9079, F1-Score: 0.9086


Epoch 22/30: 100%|██████████| 250/250 [01:36<00:00,  2.59it/s]


Loss: 0.0867193050980568


Validation: 100%|██████████| 32/32 [00:05<00:00,  6.38it/s]


Validation Loss: 0.3938693075615447, Accuracy: 0.8949, Precision: 0.8961, Recall: 0.8949, F1-Score: 0.8944


Epoch 23/30: 100%|██████████| 250/250 [01:36<00:00,  2.59it/s]


Loss: 0.07675265031494201


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.53it/s]


Validation Loss: 0.39328783084783936, Accuracy: 0.8979, Precision: 0.8989, Recall: 0.8979, F1-Score: 0.8977


Epoch 24/30: 100%|██████████| 250/250 [01:37<00:00,  2.57it/s]


Loss: 0.06889063449576498


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.52it/s]


Validation Loss: 0.36108799403882585, Accuracy: 0.9059, Precision: 0.9087, Recall: 0.9059, F1-Score: 0.9066


Epoch 25/30: 100%|██████████| 250/250 [01:36<00:00,  2.58it/s]


Loss: 0.06227779110427946


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.48it/s]


Validation Loss: 0.35655959583527874, Accuracy: 0.9039, Precision: 0.9043, Recall: 0.9039, F1-Score: 0.9040


Epoch 26/30: 100%|██████████| 250/250 [01:36<00:00,  2.58it/s]


Loss: 0.057492867033928634


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.46it/s]


Validation Loss: 0.35701088351925137, Accuracy: 0.9109, Precision: 0.9129, Recall: 0.9109, F1-Score: 0.9114


Epoch 27/30: 100%|██████████| 250/250 [01:36<00:00,  2.58it/s]


Loss: 0.0617488148547709


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.59it/s]


Validation Loss: 0.3491798896357068, Accuracy: 0.9089, Precision: 0.9099, Recall: 0.9089, F1-Score: 0.9092


Epoch 28/30: 100%|██████████| 250/250 [01:36<00:00,  2.58it/s]


Loss: 0.05606812455039471


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.61it/s]


Validation Loss: 0.35949763251119293, Accuracy: 0.9039, Precision: 0.9059, Recall: 0.9039, F1-Score: 0.9041


Epoch 29/30: 100%|██████████| 250/250 [01:36<00:00,  2.59it/s]


Loss: 0.052668693479150534


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.62it/s]


Validation Loss: 0.3588326148747001, Accuracy: 0.9079, Precision: 0.9088, Recall: 0.9079, F1-Score: 0.9080


Epoch 30/30: 100%|██████████| 250/250 [01:36<00:00,  2.59it/s]


Loss: 0.05567819238826632


Validation: 100%|██████████| 32/32 [00:04<00:00,  6.51it/s]

Validation Loss: 0.3527868778764969, Accuracy: 0.9119, Precision: 0.9130, Recall: 0.9119, F1-Score: 0.9121





In [10]:
model.eval()
test_loss = 0.0
correct = 0
y_pred = []
y_true = []

with torch.no_grad():
    pbar_test = tqdm(test_loader, desc='Test')
    for X_batch, y_batch in pbar_test:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == y_batch).sum().item()

        y_pred.extend(predicted.cpu().numpy())
        y_true.extend(y_batch.cpu().numpy())

        if pbar_test.n > 0:
            pbar_test.set_postfix({
                'Loss': f'{test_loss / (pbar_test.n + 1):.4f}',
                'Acc': f'{correct / len(test_dataset):.4f}'
            })
        else:
            pbar_test.set_postfix({'Loss': f'{test_loss:.4f}', 'Acc': 'N/A'})

test_accuracy = correct / len(test_dataset)
precision = precision_score(y_true, y_pred, average='macro')
recall = recall_score(y_true, y_pred, average='macro')
f1 = f1_score(y_true, y_pred, average='macro')

print(f'Test Loss: {test_loss / len(test_loader):.4f}, '
      f'Accuracy: {test_accuracy:.4f}, '
      f'Precision: {precision:.4f}, '
      f'Recall: {recall:.4f}, '
      f'F1-Score: {f1:.4f}')


Test: 100%|██████████| 32/32 [00:09<00:00,  3.32it/s, Loss=0.2987, Acc=0.9099]

Test Loss: 0.2987, Accuracy: 0.9099, Precision: 0.9112, Recall: 0.9099, F1-Score: 0.9100



