In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from tqdm import tqdm
import wandb
import torch.nn.functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


In [2]:
class MelSpectrogramDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.classes = os.listdir(root_dir)
        self.data = []
        self.labels = []
        
        for idx, genre in enumerate(self.classes):
            genre_path = os.path.join(root_dir, genre)
            for file in os.listdir(genre_path):
                self.data.append(os.path.join(genre_path, file))
                self.labels.append(idx)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]
        img = Image.open(img_path).convert('RGB') 
        img = np.array(img) 
        img = img.transpose((2, 0, 1))  
        img = torch.from_numpy(img).float()  
        label = self.labels[idx]
        return img, label


In [3]:
train_dir = '/kaggle/input/mel-spectrogram-for-gtzn-dataset/melspectrograms/train'
val_dir = '/kaggle/input/mel-spectrogram-for-gtzn-dataset/melspectrograms/validation'
test_dir = '/kaggle/input/mel-spectrogram-for-gtzn-dataset/melspectrograms/test'

train_dataset = MelSpectrogramDataset(train_dir)
val_dataset = MelSpectrogramDataset(val_dir)
test_dataset = MelSpectrogramDataset(test_dir)


In [4]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [5]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_channels)
        
    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return F.relu(out)

In [6]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  
            ResidualBlock(64),
            
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  
            ResidualBlock(128),
            
            nn.AdaptiveAvgPool2d((6, 6)),  
            nn.Flatten(),
            
            nn.Linear(128*6*6, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        return self.net(x)

model = CNN()


In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

CNN(
  (net): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (6): ResidualBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1)

In [8]:
optimizer = SGD(model.parameters(), 
               lr=0.01, 
               momentum=0.9, 
               weight_decay=1e-4)

criterion = nn.CrossEntropyLoss()

scheduler = CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-4)

In [9]:
for epoch in range(30):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/30')
    for X_batch, y_batch in pbar:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    print(f'Loss: {running_loss / len(train_loader)}')

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    y_pred = []
    y_true = []
    with torch.no_grad():
        pbar_val = tqdm(val_loader, desc='Validation')
        for X_batch, y_batch in pbar_val:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == y_batch).sum().item()
            y_pred.extend(predicted.cpu().numpy())
            y_true.extend(y_batch.cpu().numpy())
        
        accuracy = correct / len(val_dataset)
        precision = precision_score(y_true, y_pred, average='macro')
        recall = recall_score(y_true, y_pred, average='macro')
        f1 = f1_score(y_true, y_pred, average='macro')


        print(f'Validation Loss: {val_loss / len(val_loader)}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}')
    
    scheduler.step()


Epoch 1/30: 100%|██████████| 250/250 [02:16<00:00,  1.84it/s]


Loss: 1.6679214091300965


Validation: 100%|██████████| 32/32 [00:11<00:00,  2.91it/s]


Validation Loss: 1.3599246330559254, Accuracy: 0.5075, Precision: 0.4843, Recall: 0.5074, F1-Score: 0.4704


Epoch 2/30: 100%|██████████| 250/250 [01:36<00:00,  2.60it/s]


Loss: 1.2378449139595031


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.81it/s]


Validation Loss: 1.4758527658414096, Accuracy: 0.4905, Precision: 0.5964, Recall: 0.4902, F1-Score: 0.4400


Epoch 3/30: 100%|██████████| 250/250 [01:36<00:00,  2.60it/s]


Loss: 1.0647089302539825


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.50it/s]


Validation Loss: 0.9984149541705847, Accuracy: 0.6436, Precision: 0.6702, Recall: 0.6437, F1-Score: 0.6380


Epoch 4/30: 100%|██████████| 250/250 [01:36<00:00,  2.60it/s]


Loss: 0.9223783519268036


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.57it/s]


Validation Loss: 1.0751531394198537, Accuracy: 0.6306, Precision: 0.6657, Recall: 0.6308, F1-Score: 0.6253


Epoch 5/30: 100%|██████████| 250/250 [01:36<00:00,  2.59it/s]


Loss: 0.794201960682869


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.89it/s]


Validation Loss: 1.3796155116797308, Accuracy: 0.5986, Precision: 0.6681, Recall: 0.5987, F1-Score: 0.5976


Epoch 6/30: 100%|██████████| 250/250 [01:35<00:00,  2.60it/s]


Loss: 0.7066230158805847


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.62it/s]


Validation Loss: 0.9111773823387921, Accuracy: 0.6857, Precision: 0.7153, Recall: 0.6858, F1-Score: 0.6820


Epoch 7/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.6141626895666122


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.42it/s]


Validation Loss: 1.063080510823056, Accuracy: 0.6797, Precision: 0.7224, Recall: 0.6798, F1-Score: 0.6750


Epoch 8/30: 100%|██████████| 250/250 [01:35<00:00,  2.61it/s]


Loss: 0.5570068153738975


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.75it/s]


Validation Loss: 0.6877329854760319, Accuracy: 0.7638, Precision: 0.7792, Recall: 0.7638, F1-Score: 0.7630


Epoch 9/30: 100%|██████████| 250/250 [01:39<00:00,  2.52it/s]


Loss: 0.500680845439434


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.75it/s]


Validation Loss: 0.7868614930193871, Accuracy: 0.7427, Precision: 0.7836, Recall: 0.7430, F1-Score: 0.7390


Epoch 10/30: 100%|██████████| 250/250 [01:35<00:00,  2.61it/s]


Loss: 0.42808925932645797


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.60it/s]


Validation Loss: 0.7615653612883762, Accuracy: 0.7598, Precision: 0.7918, Recall: 0.7600, F1-Score: 0.7591


Epoch 11/30: 100%|██████████| 250/250 [01:35<00:00,  2.61it/s]


Loss: 0.38020256960391996


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.68it/s]


Validation Loss: 0.8022909048013389, Accuracy: 0.7497, Precision: 0.8015, Recall: 0.7499, F1-Score: 0.7568


Epoch 12/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.3476615046262741


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.90it/s]


Validation Loss: 0.556255278934259, Accuracy: 0.8138, Precision: 0.8226, Recall: 0.8136, F1-Score: 0.8115


Epoch 13/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.29910043981671336


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.65it/s]


Validation Loss: 0.6640895986929536, Accuracy: 0.7978, Precision: 0.8091, Recall: 0.7978, F1-Score: 0.7971


Epoch 14/30: 100%|██████████| 250/250 [01:38<00:00,  2.53it/s]


Loss: 0.277275317043066


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.74it/s]


Validation Loss: 0.4241259920527227, Accuracy: 0.8659, Precision: 0.8671, Recall: 0.8659, F1-Score: 0.8658


Epoch 15/30: 100%|██████████| 250/250 [01:35<00:00,  2.61it/s]


Loss: 0.23621797236800193


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.59it/s]


Validation Loss: 0.49288696935400367, Accuracy: 0.8358, Precision: 0.8442, Recall: 0.8358, F1-Score: 0.8356


Epoch 16/30: 100%|██████████| 250/250 [01:35<00:00,  2.61it/s]


Loss: 0.20071407140791417


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.68it/s]


Validation Loss: 0.48551314079668373, Accuracy: 0.8509, Precision: 0.8592, Recall: 0.8509, F1-Score: 0.8519


Epoch 17/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.17299634033441544


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.79it/s]


Validation Loss: 0.44254217739216983, Accuracy: 0.8549, Precision: 0.8662, Recall: 0.8548, F1-Score: 0.8556


Epoch 18/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.15671872089803218


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.91it/s]


Validation Loss: 0.4887123436492402, Accuracy: 0.8478, Precision: 0.8557, Recall: 0.8479, F1-Score: 0.8460


Epoch 19/30: 100%|██████████| 250/250 [01:36<00:00,  2.60it/s]


Loss: 0.14560644014924765


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.68it/s]


Validation Loss: 0.43688215338625014, Accuracy: 0.8659, Precision: 0.8716, Recall: 0.8659, F1-Score: 0.8668


Epoch 20/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.11680282637476921


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.67it/s]


Validation Loss: 0.39365234185243025, Accuracy: 0.8899, Precision: 0.8934, Recall: 0.8899, F1-Score: 0.8901


Epoch 21/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.1020952968634665


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.79it/s]


Validation Loss: 0.3671807498903945, Accuracy: 0.8879, Precision: 0.8895, Recall: 0.8879, F1-Score: 0.8880


Epoch 22/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.08246282280609012


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.72it/s]


Validation Loss: 0.3959082613000646, Accuracy: 0.8879, Precision: 0.8921, Recall: 0.8879, F1-Score: 0.8890


Epoch 23/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.08155417448282241


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.65it/s]


Validation Loss: 0.3905926216393709, Accuracy: 0.8869, Precision: 0.8900, Recall: 0.8869, F1-Score: 0.8873


Epoch 24/30: 100%|██████████| 250/250 [01:35<00:00,  2.61it/s]


Loss: 0.08047393995523452


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.66it/s]


Validation Loss: 0.3596175365964882, Accuracy: 0.8929, Precision: 0.8946, Recall: 0.8929, F1-Score: 0.8929


Epoch 25/30: 100%|██████████| 250/250 [01:35<00:00,  2.61it/s]


Loss: 0.06490857395529746


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.80it/s]


Validation Loss: 0.3891144836379681, Accuracy: 0.8889, Precision: 0.8905, Recall: 0.8889, F1-Score: 0.8886


Epoch 26/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.059890842294320464


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.65it/s]


Validation Loss: 0.373807315423619, Accuracy: 0.8969, Precision: 0.8987, Recall: 0.8969, F1-Score: 0.8967


Epoch 27/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.05686163594480604


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.85it/s]


Validation Loss: 0.367835360346362, Accuracy: 0.8959, Precision: 0.8981, Recall: 0.8959, F1-Score: 0.8962


Epoch 28/30: 100%|██████████| 250/250 [01:35<00:00,  2.61it/s]


Loss: 0.05689476793259382


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.77it/s]


Validation Loss: 0.3619605773128569, Accuracy: 0.8989, Precision: 0.8993, Recall: 0.8989, F1-Score: 0.8987


Epoch 29/30: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Loss: 0.05522747428063303


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.68it/s]


Validation Loss: 0.3697805123811122, Accuracy: 0.8979, Precision: 0.8993, Recall: 0.8979, F1-Score: 0.8979


Epoch 30/30: 100%|██████████| 250/250 [01:35<00:00,  2.63it/s]


Loss: 0.05534874832630157


Validation: 100%|██████████| 32/32 [00:05<00:00,  5.74it/s]

Validation Loss: 0.3683874978451058, Accuracy: 0.8989, Precision: 0.9000, Recall: 0.8989, F1-Score: 0.8988





In [10]:
model.eval()
test_loss = 0.0
correct = 0
with torch.no_grad():
    pbar_test = tqdm(test_loader, desc='Test')
    for X_batch, y_batch in pbar_test:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == y_batch).sum().item()
        if pbar_test.n > 0:
            pbar_test.set_postfix({'Loss': f'{test_loss / (pbar_test.n + 1):.4f}', 'Acc': f'{correct / len(test_dataset):.4f}'})
        else:
            pbar_test.set_postfix({'Loss': f'{test_loss:.4f}', 'Acc': f'N/A'})
test_accuracy = correct / len(test_dataset)
print(f'Test Loss: {test_loss / len(test_loader)}, Accuracy: {test_accuracy:.4f}')


Test: 100%|██████████| 32/32 [00:11<00:00,  2.87it/s, Loss=0.3033, Acc=0.9079]

Test Loss: 0.3032660461612977, Accuracy: 0.9079



