In [2]:
import io
import os
import tempfile

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import librosa
import librosa.display
import torch
import mlflow.pytorch
from torchvision import transforms

In [3]:
from dotenv import load_dotenv
# Let python-dotenv populate the environment before the settings are read.
load_dotenv()

# Connection settings for the remote MLflow artifact store, each read from an
# environment variable. A value is None when its variable is not set.
DEFAULT_SETTINGS = dict(
    minio_endpoint=os.environ.get("REMOTE_MLFLOW_STORAGE_URI"),
    minio_music_net_bucket_name=os.environ.get("REMOTE_MLFLOW_BUCKET_NAME"),
    minio_access_key=os.environ.get("AWS_ACCESS_KEY_ID"),
    minio_secret_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
)

In [4]:

# Genre label -> integer class index. The position of each genre in the tuple
# below IS its class index, so the order must not be changed.
MAPPING_DICT_MUSIC_NET = {
    genre: class_index
    for class_index, genre in enumerate((
        'blues', 'chanson', 'classical', 'country', 'dance', 'dub',
        'electro', 'folk', 'funk', 'hard rock', 'hip-hop', 'house',
        'jazz', 'metal', 'pop', 'rap', 'reggae', 'rock',
    ))
}


def create_preprocessed_spectrogram(audio_path, sr=22050, n_mels=128, fmax=8000, img_size=(224, 224), start_time=20, segment_duration=20):
    """
    Render a segment of an audio file as a normalized mel-spectrogram image tensor.

    The segment is converted to a dB-scaled mel spectrogram, drawn with
    matplotlib (axes hidden), rasterized to an RGB image, resized and turned
    into a tensor ready to be fed to the model.

    Parameters:
        audio_path: Path of the audio file to load.
        sr: Sample rate the audio is resampled to when loading.
        n_mels: Number of mel bands of the spectrogram.
        fmax: Highest frequency (Hz) used for the mel filter bank and the plot.
        img_size: Size passed to ``PIL.Image.resize``, i.e. ``(width, height)``.
        start_time: Offset in seconds at which the segment starts.
        segment_duration: Length of the segment in seconds.

    Returns:
        A float tensor of shape ``(1, 3, img_size[1], img_size[0])`` normalized
        with mean 0.5 / std 0.5 per channel, or ``None`` if any step fails
        (the error is printed, not raised).
    """
    fig = None
    try:
        # Load only the requested segment of the audio file
        y, sr = librosa.load(audio_path, sr=sr, offset=start_time, duration=segment_duration)

        # An empty signal (e.g. start_time at or beyond the end of the track)
        # cannot produce a meaningful spectrogram; fail with a clear message.
        if len(y) == 0:
            raise ValueError(
                f"no audio samples found from {start_time}s for {segment_duration}s"
            )

        # Generate the dB-scaled mel spectrogram
        S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels, fmax=fmax)
        S_DB = librosa.power_to_db(S, ref=np.max)

        # Plot the spectrogram; keep a handle on the figure so it can always be closed
        fig = plt.figure(figsize=(10, 4))
        plt.axis('off')
        librosa.display.specshow(S_DB, sr=sr, x_axis=None, y_axis=None, fmax=fmax)

        # Rasterize into memory instead of a temporary file: nothing is left
        # on disk if a later step fails, and no open file has to be deleted.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
        buffer.seek(0)

        # Decode the PNG, drop the alpha channel and resize to the model input size
        img = Image.open(buffer).convert('RGB')
        img = img.resize(img_size, Image.Resampling.LANCZOS)

        # Transform the image to a normalized tensor
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        img_tensor = transform(img).unsqueeze(0)  # Add batch dimension

        return img_tensor
    except Exception as e:
        # Best-effort contract: report the problem and let the caller handle None
        print(f"Error processing {audio_path}: {e}")
        return None
    finally:
        # Close the figure on every path so failed calls do not accumulate open figures
        if fig is not None:
            plt.close(fig)
    

def get_production_model():
    """
    Load the production PyTorch model from the MinIO-backed MLflow artifact store.

    The S3 credentials and endpoint from ``DEFAULT_SETTINGS`` are exported as
    process environment variables (``AWS_ACCESS_KEY_ID``,
    ``AWS_SECRET_ACCESS_KEY``, ``MLFLOW_S3_ENDPOINT_URL``) before the model is
    fetched from ``s3://<bucket>/data/``.

    Returns:
        The model returned by ``mlflow.pytorch.load_model``, mapped to CUDA
        when available and to the CPU otherwise.

    Raises:
        ValueError: If a required setting is missing or empty.
    """
    required_keys = (
        'minio_endpoint',
        'minio_music_net_bucket_name',
        'minio_access_key',
        'minio_secret_key',
    )
    # Fail early with an actionable message: assigning None to os.environ raises
    # an opaque TypeError, and a missing bucket would yield "s3://None/data/".
    missing = [key for key in required_keys if not DEFAULT_SETTINGS.get(key)]
    if missing:
        raise ValueError(
            "Missing storage settings (check the environment / .env file): "
            + ", ".join(missing)
        )

    minio_url = f"s3://{DEFAULT_SETTINGS['minio_music_net_bucket_name']}/data/"

    os.environ["AWS_ACCESS_KEY_ID"] = DEFAULT_SETTINGS['minio_access_key']
    os.environ["AWS_SECRET_ACCESS_KEY"] = DEFAULT_SETTINGS['minio_secret_key']

    # Only prepend a scheme when the configured endpoint does not already carry
    # one, so "https://host" does not become "https://https://host".
    endpoint = DEFAULT_SETTINGS['minio_endpoint']
    if "://" not in endpoint:
        endpoint = f"https://{endpoint}"
    os.environ["MLFLOW_S3_ENDPOINT_URL"] = endpoint

    map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return mlflow.pytorch.load_model(minio_url, map_location=map_location)


def predicted_item_to_class_name(predicted_item):
    """
    Translate a predicted class index into its genre name.

    The lookup is derived from ``MAPPING_DICT_MUSIC_NET`` so that it agrees
    with ``predict_with_production_music_net`` instead of returning
    placeholder labels.

    Parameters:
        predicted_item: Class index (anything ``int()`` accepts).

    Returns:
        The genre name mapped to that index.

    Raises:
        IndexError: If no genre is mapped to the given index.
    """
    idx_to_class = {index: genre for genre, index in MAPPING_DICT_MUSIC_NET.items()}
    try:
        return idx_to_class[int(predicted_item)]
    except KeyError:
        # Keep the exception type of the former list lookup for out-of-range indices
        raise IndexError(f"No genre is mapped to class index {predicted_item!r}") from None


def predict_with_production_music_net(model, img_tensor):
    """
    Run the model on a spectrogram tensor and return the predicted genre name.

    Both the model and the input are moved to CUDA when it is available,
    otherwise to the CPU. The index of the highest score along dimension 1 is
    translated to a genre through the inverse of ``MAPPING_DICT_MUSIC_NET``.

    Parameters:
        model: Module to evaluate; it is switched to eval mode.
        img_tensor: Input tensor; the prediction is read with ``.item()``,
            so it must yield exactly one predicted index.

    Returns:
        The genre name associated with the predicted class index.
    """
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    batch = img_tensor.to(device)
    model.to(device)
    model.eval()

    # Inference only: no gradients need to be tracked
    with torch.no_grad():
        scores = model(batch)
        best = torch.max(scores, 1).indices

    # Invert genre -> index into index -> genre for the lookup
    genre_by_index = dict(
        (index, genre) for genre, index in MAPPING_DICT_MUSIC_NET.items()
    )
    return genre_by_index[best.item()]



In [5]:
def predict_genre(audio_path: str, start_time: int = 20, segment_duration: int = 20):
    """
    Predicts the genre of a music segment using a pre-trained MusicNet model.

    - **audio_path**: str - The path to the audio file.
    - **start_time**: int - The start time of the segment, in seconds (default 20).
    - **segment_duration**: int - The duration of the segment, in seconds (default 20).
    - **return**: dict - ``{"genre": <name>}`` with the predicted genre, or
      ``{"genre": None}`` when the model cannot be loaded, the audio cannot be
      preprocessed, or the prediction fails (the reason is printed).
    """
    # Stop at the first failing step: continuing would only hit an unbound
    # model or a None tensor and print a misleading secondary error.
    try:
        # Load the production model
        model = get_production_model()
    except Exception as e:
        print(f"Error loading the production model: {e}")
        return {"genre": None}
    if model is None:
        print("Failed to load the production model.")
        return {"genre": None}

    try:
        # Create a preprocessed spectrogram for the requested segment
        img_tensor = create_preprocessed_spectrogram(
            audio_path, start_time=start_time, segment_duration=segment_duration
        )
    except Exception as e:
        print(f"Error creating the preprocessed spectrogram: {e}")
        return {"genre": None}
    if img_tensor is None:
        print("Failed to create the preprocessed spectrogram.")
        return {"genre": None}

    try:
        # Predict the genre
        genre = predict_with_production_music_net(model, img_tensor)
    except Exception as e:
        print(f"Error predicting the genre: {e}")
        return {"genre": None}

    return {"genre": genre}

In [7]:
# Sample track used to try the predictor. The default location is specific to
# one machine, so it can be overridden with the SAMPLE_AUDIO_PATH environment
# variable (for instance from the .env file) without editing the notebook.
audio_path = os.getenv(
    "SAMPLE_AUDIO_PATH",
    "/home/kin/Documents/music_similarity/preprocessing/MegaSet/Parov Stelar/Parov Stelar - That Swing (2009)/14. A Night In Torino.mp3",
)
predict_genre(audio_path)

Downloading artifacts: 100%|██████████| 10/10 [00:00<00:00, 241.73it/s] 


{'genre': 'dub'}

In [None]:
### Create a new user and set it as admin, then delete the default admin user

# import os
# import time
# import logging

# import mlflow
# from mlflow.server.auth.client import AuthServiceClient

# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)

# # Wait for the MLflow server to start
# time.sleep(20)

# MLFLOW_TRACKIN_FULL_URL = os.getenv("MLFLOW_TRACKIN_FULL_URL")
# NEW_USER_USERNAME = os.getenv("NEW_USER_USERNAME")
# NEW_USER_PASSWORD = os.getenv("NEW_USER_PASSWORD")

# # Set the tracking URI
# mlflow.set_tracking_uri(MLFLOW_TRACKIN_FULL_URL)

# # Set environment variables for authentication
# os.environ['MLFLOW_TRACKING_USERNAME'] = "admin"
# os.environ['MLFLOW_TRACKING_PASSWORD'] = "password"

# # Initialize the AuthServiceClient with admin credentials
# client = AuthServiceClient(MLFLOW_TRACKIN_FULL_URL)

# # Create a new user
# try:
#     client.create_user(NEW_USER_USERNAME, NEW_USER_PASSWORD)
#     logger.info("User created successfully.")
# except Exception as e:
#     logger.error(f"Failed to create user: {e}")

# # Set the new user as admin
# try:
#     client.update_user_admin(NEW_USER_USERNAME, True)
#     logger.info("User set as admin successfully.")
# except Exception as e:
#     logger.error(f"Failed to set user as admin: {e}")

# # Delete the default admin user
# try:
#     client.delete_user("admin")
#     logger.info("Default admin user deleted successfully.")
# except Exception as e:
#     logger.error(f"Failed to delete default admin user: {e}")