In [1]:
import pandas as pd
import os
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from sklearn.preprocessing import MultiLabelBinarizer
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image, ImageReadMode
import torchvision
from sklearn.metrics import label_ranking_average_precision_score
from PIL import Image
import torchvision.transforms as transforms

In [9]:
def create_convnext_tiny(out_features, device):
    """Build a ConvNeXt-Tiny with a frozen feature extractor and a fresh head.

    Args:
        out_features: Number of output units of the new final linear layer.
        device: Device the model and the new classifier head are moved to.

    Returns:
        A ``(model, weights)`` tuple: the modified model and the torchvision
        weights enum entry it was initialised from.
    """
    pretrained = torchvision.models.ConvNeXt_Tiny_Weights.DEFAULT
    net = torchvision.models.convnext_tiny(weights=pretrained)
    net = net.to(device)

    # Freeze every parameter under `features`; only the replacement
    # classifier built below keeps requires_grad=True.
    for frozen in net.features.parameters():
        frozen.requires_grad = False

    # Channel width shared by the head's LayerNorm2d and Linear layers.
    head_width = 768
    head_layers = [
        torchvision.models.convnext.LayerNorm2d((head_width,), eps=1e-06, elementwise_affine=True),
        nn.Flatten(start_dim=1, end_dim=-1),
        nn.Linear(in_features=head_width, out_features=out_features, bias=True),
    ]
    net.classifier = nn.Sequential(*head_layers).to(device)

    net.name = 'ConvNeXt Tiny'
    print(f'[INFO] Created new {net.name} model.')

    return net, pretrained

In [10]:
def audio_to_spectogram(filepath, output_path='./test/test.png'):
    """Render the mel spectrogram of an audio file to a borderless PNG.

    The audio is loaded at its native sample rate, converted to a 256-band
    mel power spectrogram (n_fft=2560, hop_length=512, 20 Hz - 16 kHz),
    scaled to dB relative to its maximum, and drawn without axes or frame.

    Args:
        filepath: Path of the audio file to load.
        output_path: Where the PNG is written. Defaults to
            ``'./test/test.png'``, the location this function always used
            before the parameter existed. Missing parent directories are
            created.

    Returns:
        None. The only result is the image file written to ``output_path``.
    """
    # sr=None keeps the file's own sample rate instead of resampling.
    # NOTE(review): fmax=16000 is above Nyquist for files sampled below
    # 32 kHz - confirm all inputs are at least that rate.
    y, sr = librosa.load(str(filepath), sr=None)
    mel = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=2560, hop_length=512, n_mels=256, fmin=20, fmax=16000
    )
    mel_db = librosa.power_to_db(mel, ref=np.max)

    # Saving fails if the target directory is missing, so create it up front.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots()
    try:
        # Hide ticks, labels and the frame so the PNG holds only the
        # spectrogram pixels.
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        ax.set_frame_on(False)
        # Draw on the axes created above rather than on pyplot's implicit
        # "current axes".
        librosa.display.specshow(mel_db, sr=sr, ax=ax)
        fig.savefig(output_path, dpi=400, bbox_inches='tight', pad_inches=0)
    finally:
        # Close only this figure, even if saving failed; unrelated open
        # figures are left alone.
        plt.close(fig)

In [11]:
# Render the spectrogram of one sample track; audio_to_spectogram writes the
# image to ./test/test.png by default.
test_file_path = './tracks/1539_3DivHejgz3x3dn1FkjHjWJ.mp3'
audio_to_spectogram(test_file_path)

In [12]:
model_path = './models/model0_5epochs_feb3.pth'

# Choose the device before loading so the checkpoint can be remapped onto it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (and vice versa) instead of failing while restoring CUDA tensors.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
loaded_checkpoint = torch.load(model_path, map_location=device)

# 172 output units: must match the head size the checkpoint was trained with.
model, _ = create_convnext_tiny(out_features=172, device=device)
model.load_state_dict(loaded_checkpoint['model_state_dict'])
model.to(device)
# Switch to inference mode; as the cell's last expression this also displays
# the model architecture.
model.eval()

[INFO] Created new ConvNeXt Tiny model.


ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=

In [13]:
image_path = './test/test.png'
# Open inside a context manager so the file handle is released as soon as the
# pixels have been copied into the RGB image (convert() returns a new image).
with Image.open(image_path) as image_file:
    spectrogram_image = image_file.convert('RGB')

height, width = 1478, 1984

# Preprocessing intended to mirror what was used during training.
# NOTE(review): confirm the resize target and the absence of normalisation
# against the training pipeline.
transform = transforms.Compose([
    transforms.Resize((height, width)),
    transforms.ToTensor(),
])

# Add a leading batch dimension and move the tensor to the model's device.
input_tensor = transform(spectrogram_image).unsqueeze(0).to(device)

# Perform inference without tracking gradients.
with torch.no_grad():
    output = model(input_tensor)

# argmax keeps only the single highest-scoring class index for the one image
# in the batch; .item() turns it into a plain Python int.
prediction = torch.argmax(output, dim=1).item()
print(prediction)
print(type(prediction))

1
<class 'int'>


In [14]:
# Mood-tag vocabulary: 172 labels in ascending alphabetical order. A label's
# position in the list is the class index used to look up a prediction.
# NOTE(review): the order must match the label encoding used at training
# time - confirm before editing.
# No label contains whitespace, so splitting the block on whitespace yields
# exactly one list entry per tag.
class_labels = """
acerbic aggressive angry angst-ridden anxious apocalyptic
atmospheric autumnal belligerent bitter bittersweet bleak
bombastic brash bright brooding calm carefree cathartic
cerebral cheerful cold complex confident confrontational
crunchy cynical dark defiant delicate demonic desperate
detached dramatic driving druggy eccentric ecstatic eerie
elegant energetic epic erotic ethereal exciting exotic
explosive fierce fiery flashy flowing freewheeling fun
gentle gleeful gloomy grim gritty gutsy halloween happy
harsh hostile humorous hyper hypnotic intense introspective
ironic irreverent joyous laid-back light literate lively
lonely lush lyrical malevolent manic martial melancholy
mellow menacing messy monumental mysterious mystical
narrative negative nervous nihilistic nocturnal nostalgic
ominous optimistic organic outrageous paranoid passionate
peaceful playful poignant positive powerful precious
provocative pure quirky rambunctious ramshackle raucous
rebellious reckless reflective relaxed rollicking romantic
rousing rowdy sad sarcastic sardonic savage scary
scarymusic searching sensual sentimental sexual sexy silly
sleazy smooth soft somber soothing sophisticated spacey
sparkling spiritual spooky sprawling springlike street-smart
strong stylish summery sweet technical tense theatrical
thoughtful threatening thrilling thuggish tragic trashy
trippy triumphant uncompromising unsettling uplifting urgent
visceral volatile warm whimsical wistful witty wry yearning
""".split()

In [15]:
# Translate the predicted class index into its label from class_labels.
predicted_label = class_labels[prediction]
print(f'The model predicts class: {predicted_label}')

The model predicts class: aggressive
