In [None]:
import xgboost as xgb
import numpy as np
import ipywidgets as widgets
from IPython.display import display, HTML


# Serialized XGBoost model; a relative path, so it is resolved against the
# notebook's current working directory.
model_path = 'xgboost_model.json'
# Create an empty booster, then populate it from the saved model file. The
# resulting `model` is used by the predict-button handler.
model = xgb.Booster()
model.load_model(model_path)

# (lower, upper) bounds used to min-max scale each continuous audio feature
# before it is passed to the model; see normalize().
min_max_values = dict(
    popularity=(0, 100),
    acousticness=(0, 1),
    danceability=(0, 1),
    duration_ms=(1000, 900000),  # 1 second .. 15 minutes, in milliseconds
    energy=(0, 1),
    instrumentalness=(0, 1),
    liveness=(0, 1),
    loudness=(-30, 0),
    speechiness=(0, 1),
    tempo=(50, 300),
    valence=(0, 1),
)


def normalize(value, min_val, max_val):
    """Min-max scale *value* relative to the interval [min_val, max_val].

    Returns 0 at ``min_val`` and 1 at ``max_val``; values outside the
    interval map outside [0, 1] (no clipping). Raises ZeroDivisionError
    when ``min_val == max_val`` for plain Python numbers.
    """
    span = max_val - min_val
    offset = value - min_val
    return offset / span


# Integer code for each key name: C=0 through B=11 in chromatic order.
# Insertion order is kept because the Key dropdown lists these names in order.
key_map = {
    name: code
    for code, name in enumerate(['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'])
}
# Integer code for the musical mode: Major -> 1, Minor -> 0.
mode_map = dict(Major=1, Minor=0)


# Model class index -> human-readable genre label used to decode predictions.
genre_map = dict(enumerate([
    "Hip-Hop",
    "Anime",
    "Blues",
    "Classical",
    "Country",
    "Electronic",
    "Jazz",
    "Rock",
]))

# Shared look for every input control: show long descriptions in full.
style = {'description_width': 'initial'}
# Only real ipywidgets Layout traits are used here; CSS-only properties such
# as `transition` are not Layout traits and would just be ignored.
layout = widgets.Layout(width='400px')


def _feature_slider(feature, description, value):
    """Return a FloatSlider for *feature* bounded by its normalization range.

    The bounds are read from ``min_max_values`` so the slider can only produce
    values that ``normalize`` maps into [0, 1]. Keeping a single source of truth
    prevents the UI range and the scaling range from drifting apart.
    """
    low, high = min_max_values[feature]
    return widgets.FloatSlider(min=low, max=high, description=description, value=value,
                               style=style, layout=layout, continuous_update=False)


popularity = _feature_slider('popularity', '🎶 Popularity', 50)
acousticness = _feature_slider('acousticness', '🎧 Acousticness', 0.5)
danceability = _feature_slider('danceability', '💃 Danceability', 0.5)
duration_ms = _feature_slider('duration_ms', '⏲️ Duration (ms)', 250000)
energy = _feature_slider('energy', '⚡ Energy', 0.5)
instrumentalness = _feature_slider('instrumentalness', '🎻 Instrumentalness', 0.5)
liveness = _feature_slider('liveness', '🎤 Liveness', 0.5)
loudness = _feature_slider('loudness', '🔊 Loudness', -30)
speechiness = _feature_slider('speechiness', '🗣️ Speechiness', 0.5)
tempo = _feature_slider('tempo', '🎼 Tempo', 120)
valence = _feature_slider('valence', '😊 Valence', 0.5)

# Categorical inputs; the options are the names encoded by key_map / mode_map.
key = widgets.Dropdown(options=list(key_map), description='🎹 Key', value='C', style=style, layout=layout)
mode = widgets.Dropdown(options=list(mode_map), description='🎶 Mode', value='Major', style=style, layout=layout)

# Font size is a ButtonStyle attribute, not a Layout trait. `border_radius` is
# not a Layout trait either, so it has been dropped.
button = widgets.Button(description='🎯 Predict Genre', button_style='info',
                        layout=widgets.Layout(width='200px', height='40px'),
                        style={'font_size': '16px'})
# Area where the prediction result is rendered.
output = widgets.Output()

def on_button_clicked(b):
    """Predict the genre from the current control values and render it in `output`.

    Used as the click handler of `button`; *b* is the clicked Button and is not
    used. Continuous features are min-max scaled with `min_max_values`. The
    `key` and `mode` dropdowns are integer-encoded via `key_map` / `mode_map`.
    """
    # Continuous features and the slider supplying each one. The keys match
    # `min_max_values`, so every value is scaled with its own range.
    sliders = {
        'popularity': popularity,
        'acousticness': acousticness,
        'danceability': danceability,
        'duration_ms': duration_ms,
        'energy': energy,
        'instrumentalness': instrumentalness,
        'liveness': liveness,
        'loudness': loudness,
        'speechiness': speechiness,
        'tempo': tempo,
        'valence': valence,
    }

    with output:
        output.clear_output()

        input_data = {
            name: normalize(widget.value, *min_max_values[name])
            for name, widget in sliders.items()
        }
        input_data['key'] = key_map[key.value]
        input_data['mode'] = mode_map[mode.value]

        # Column order and names handed to the booster.
        # NOTE(review): must match the feature order used at training time,
        # which is not visible here - confirm.
        feature_order = ['popularity', 'acousticness', 'danceability', 'duration_ms', 'energy',
                         'instrumentalness', 'key', 'liveness', 'loudness', 'mode',
                         'speechiness', 'tempo', 'valence']
        input_vector = np.array([[input_data[feature] for feature in feature_order]])

        dmatrix = xgb.DMatrix(input_vector, feature_names=feature_order)

        prediction = np.asarray(model.predict(dmatrix))
        if prediction.ndim > 1:
            # One score per class (e.g. a multi:softprob model): take the best class.
            class_index = int(np.argmax(prediction[0]))
        else:
            # One value per row (e.g. a multi:softmax model): it already is the
            # class label. argmax over this 1-element array would always give 0.
            class_index = int(prediction[0])
        predicted_genre = genre_map.get(class_index, f"Unknown (class {class_index})")

        display(HTML(f"""
            <div style='background-color:#e0f7fa; border-radius:15px; padding:20px;
                        box-shadow: 0px 4px 6px rgba(0,0,0,0.1); text-align:center;'>
                <h2 style='color:#00796b; font-family: Arial, sans-serif;'>Predicted Genre:
                    <span style='color:#d32f2f;'>{predicted_genre}</span></h2>
            </div>
        """))

# Run the prediction handler whenever the button is clicked.
button.on_click(on_button_clicked)

# Title banner shown above the controls.
header = widgets.HTML("""
        <div style='background-color:#f3e5f5; padding:20px; border-radius:15px;'>
            <h1 style='color:#673ab7; text-align:center; font-family:Verdana;'>Music Genre Predictor 🎶</h1>
            <p style='text-align:center; color:#6a1b9a;'>Insert the audio parameters to predict the genre</p>
        </div>
    """)
# Horizontal rule separating the inputs from the action button.
divider = widgets.HTML("<hr style='border: 1px solid #ccc;'/>")
# Row that centers the predict button.
button_row = widgets.HBox([button], layout=widgets.Layout(justify_content='center'))
# Input controls, top to bottom.
controls = [
    popularity, acousticness, danceability, duration_ms, energy, instrumentalness,
    liveness, loudness, speechiness, tempo, valence, key, mode,
]

ui = widgets.VBox([header, *controls, divider, button_row, output])
display(ui)


VBox(children=(HTML(value="\n        <div style='background-color:#f3e5f5; padding:20px; border-radius:15px;'>…