This module utilizes the following dataset:
https://www.kaggle.com/datasets/undefinenull/million-song-dataset-spotify-lastfm

You can either download it with the code block below or download it manually as a .zip file. If you use the code, run `pip install kagglehub` first and uncomment the download lines.

Almost all code here was generated using AI, but it was a collaborative process
and required a TON of debugging.

In [25]:
import kagglehub

# Download the latest version (uncomment both lines to run)
# path = kagglehub.dataset_download("undefinenull/million-song-dataset-spotify-lastfm")
# print("Path to dataset files:", path)

Go to the path printed above, copy 'Music Info.csv' from that folder into the same directory as this notebook, and rename it to 'music.csv'.
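If you would rather do the copy step from Python, here is a minimal sketch. It assumes 'Music Info.csv' sits directly inside the folder that `kagglehub.dataset_download` returns; `copy_music_csv` is a hypothetical helper name, not part of kagglehub.

```python
import os
import shutil


def copy_music_csv(dataset_dir, dest="music.csv"):
    """Copy 'Music Info.csv' from the download folder to dest (hypothetical helper).

    Assumes the CSV sits directly inside dataset_dir; adjust the path if the
    download nests it in a subfolder.
    """
    src = os.path.join(dataset_dir, "Music Info.csv")
    if not os.path.isfile(src):
        raise FileNotFoundError(f"Could not find {src}")
    shutil.copy(src, dest)  # copies and renames in one step
    return dest


# Usage, after running the download cell with its lines uncommented:
# copy_music_csv(path)
```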

In [26]:
# Import modules
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

This section loads the data and builds the tag co-occurrence matrix used to create the NetworkX graph below.

In [None]:
# Load the dataset
music_df = pd.read_csv('music.csv')

# Preprocess the tag data: lowercase, trim whitespace, split into lists
music_df['tags'] = music_df['tags'].astype(str).str.lower().str.strip()
tag_lists = music_df['tags'].str.split(', ')
unique_tags = list(set(tag for tags in tag_lists for tag in tags))

# astype(str) turns missing tags into the string 'nan', so drop it from the tag list
unique_tags = [tag for tag in unique_tags if tag != 'nan']

# Build the tag co-occurrence matrix, using the music dataframe
tag_index = {tag: i for i, tag in enumerate(unique_tags)}
tag_cooccurrence_matrix = [[0] * len(unique_tags) for _ in range(len(unique_tags))]

# Count every ordered pair of tags on the same track (the matrix is symmetric)
for tags in tag_lists:
    for i, tag1 in enumerate(tags):
        for j, tag2 in enumerate(tags):
            if i != j:
                tag_cooccurrence_matrix[tag_index[tag1]][tag_index[tag2]] += 1

# Convert the co-occurrence matrix to a DataFrame
cooccurrence_df = pd.DataFrame(tag_cooccurrence_matrix, index=unique_tags, columns=unique_tags)

# Convert to long format for easier use with NetworkX
cooccurrence_df = cooccurrence_df.reset_index().rename(columns={'index': 'tag1'})
cooccurrence_df_long = cooccurrence_df.melt(id_vars='tag1', var_name='tag2', value_name='weight')

# Filter out self-loops and zero-weight edges
cooccurrence_df_filtered = cooccurrence_df_long[
    (cooccurrence_df_long['tag1'] != cooccurrence_df_long['tag2']) & (cooccurrence_df_long['weight'] > 0)
    & (cooccurrence_df_long['tag1'] < cooccurrence_df_long['tag2'])  # Keep only one direction
].copy() # Make a copy to avoid modifying original

# Create a graph from the filtered co-occurrence data
G = nx.from_pandas_edgelist(cooccurrence_df_filtered, 'tag1', 'tag2', edge_attr='weight')
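The nested loops above fill a dense tag-by-tag matrix and then melt it back into an edge list. A minimal sketch of an alternative, assuming the same `tag_lists` input, counts each unordered pair directly with `collections.Counter` and `itertools.combinations`. Sorting each track's tags means every pair comes out as `(tag1, tag2)` with `tag1 < tag2`, the same one-direction edges the filter above keeps. One difference: this version removes duplicate tags within a track before counting, whereas the original loop would count a repeated tag more than once. The helper names are hypothetical.

```python
from collections import Counter
from itertools import combinations

import networkx as nx


def tag_edge_counts(tag_lists):
    """Count how many tracks carry each unordered pair of tags (hypothetical helper)."""
    counts = Counter()
    for tags in tag_lists:
        # Drop empty strings, the 'nan' placeholder, and duplicates;
        # sorting gives each pair a single (smaller, larger) order
        clean = sorted({t for t in tags if t and t != "nan"})
        counts.update(combinations(clean, 2))
    return counts


def build_tag_graph(tag_lists):
    """Build an undirected graph whose edge weights are co-occurrence counts."""
    G = nx.Graph()
    for (tag1, tag2), weight in tag_edge_counts(tag_lists).items():
        G.add_edge(tag1, tag2, weight=weight)
    return G
```

With this sketch, `G = build_tag_graph(tag_lists)` would stand in for the matrix, melt, and filter steps.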

The section below outputs a numbered list of the available tags. In the next section you will enter the number of the tag you want.

In [28]:
# List the tags the user can choose from
print("Available tags:")
for i, tag in enumerate(unique_tags):
    print(f"{i + 1}. {tag}")

Available tags:
1. screamo
2. german
3. post_rock
4. indie
5. japanese
6. heavy_metal
7. house
8. pop_rock
9. new_wave
10. russian
11. post_hardcore
12. instrumental
13. chillout
14. hard_rock
15. synthpop
16. alternative
17. ambient
18. soundtrack
19. folk
20. 80s
21. progressive_metal
22. avant_garde
23. country
24. punk_rock
25. downtempo
26. guitar
27. soul
28. metal
29. psychedelic_rock
30. metalcore
31. techno
32. britpop
33. classical
34. j_pop
35. cover
36. rnb
37. noise
38. british
39. indie_rock
40. nu_metal
41. dark_ambient
42. trance
43. electronic
44. emo
45. doom_metal
46. classic_rock
47. new_age
48. 00s
49. punk
50. french
51. experimental
52. swedish
53. dance
54. jazz
55. thrash_metal
56. alternative_rock
57. american
58. rock
59. death_metal
60. mellow
61. beautiful
62. hardcore
63. piano
64. gothic
65. blues
66. reggae
67. progressive_rock
68. polish
69. acoustic
70. grunge
71. black_metal
72. pop
73. rap
74. industrial
75. 60s
76. lounge
77. symphonic_metal
78. lov

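Choose a tag from the list above and enter its number. A table will appear below listing the tags most similar to the one you selected. Based on the code above, the weight counts how often the two tags appear together on the same track; higher weights indicate greater similarity.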

In [29]:
while True:
    try:
        central_tag_index = int(input("Enter the number of the central tag: ")) - 1
        if 0 <= central_tag_index < len(unique_tags):
            central_tag = unique_tags[central_tag_index]
            break
        else:
            print("Invalid tag number. Please enter a number from the list.")
    except ValueError:
        print("Invalid input. Please enter a number.")

# Find the closest tags to the central tag
central_tag_connections = cooccurrence_df_filtered[
    (cooccurrence_df_filtered['tag1'] == central_tag) | (cooccurrence_df_filtered['tag2'] == central_tag)
    ].copy() # Make a copy
central_tag_connections = central_tag_connections.sort_values(by='weight', ascending=False)

print(f"\nClosest tags to '{central_tag}':")
# to_markdown requires the optional 'tabulate' package (pip install tabulate)
print(central_tag_connections.to_markdown(index=False, numalign="left", stralign="left"))


Closest tags to 'techno':
| tag1                | tag2     | weight   |
|:--------------------|:---------|:---------|
| electronic          | techno   | 192395   |
| electro             | techno   | 131735   |
| dance               | techno   | 122482   |
| house               | techno   | 106789   |
| french              | techno   | 65665    |
| techno              | trance   | 40038    |
| 90s                 | techno   | 21240    |
| idm                 | techno   | 19215    |
| ambient             | techno   | 19135    |
| pop                 | techno   | 16310    |
| chillout            | techno   | 14400    |
| industrial          | techno   | 14390    |
| downtempo           | techno   | 13982    |
| instrumental        | techno   | 13735    |
| techno              | trip_hop | 11624    |
| psychedelic         | techno   | 11190    |
| german              | techno   | 9361     |
| hardcore            | techno   | 8198     |
| metalcore           | techno   | 7935     |
| rock 
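To look up the closest tags without the interactive prompt, the graph `G` built earlier already holds the same weights. Here is a minimal sketch, where `closest_tags` is a hypothetical helper: it sorts the edges touching one tag by their `weight` attribute, which should give the same ordering as the table above for the same data.

```python
import networkx as nx


def closest_tags(G, tag, n=10):
    """Return the n neighbors of `tag` with the highest weight (hypothetical helper)."""
    if tag not in G:
        raise KeyError(f"Unknown tag: {tag!r}")
    # G.edges(tag, data="weight") yields (tag, neighbor, weight) triples
    ranked = sorted(G.edges(tag, data="weight"), key=lambda e: e[2], reverse=True)
    return [(neighbor, weight) for _, neighbor, weight in ranked[:n]]


# Usage with the graph built earlier:
# closest_tags(G, "techno", n=5)
```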

The code below will likely take more than 15 seconds to render on your computer, possibly much longer.

The output is a NetworkX graph image that shows how strongly an individual node (tag) is connected to the other nodes. With some exceptions, the closer a node (in blue) is to the chosen node (in red), the higher the similarity between the two genres. The graph is also saved as '<tag>_graph.png' (for example, 'techno_graph.png') for easier viewing.

In [None]:
# Create a subgraph containing only the central tag and its neighbors
neighbors = list(G.neighbors(central_tag))
neighbors.append(central_tag)  # Include the central tag itself
neighbors_subgraph = G.subgraph(neighbors).copy() # Make a copy

# --- Calculate weights relative to the central tag ---
# Every other node in the subgraph is a neighbor of the central tag, so the
# lookup normally succeeds; the fallback only guards against a missing edge.
central_weights = {}
for n in neighbors_subgraph.nodes():
    if n != central_tag:
        try:
            central_weights[n] = G[central_tag][n]['weight']
        except KeyError:
            central_weights[n] = 0  # no direct edge to the central tag

# Normalize the weights relative to the central tag and invert them, so the
# strongest neighbor gets a factor of 0 and weaker neighbors approach 1
max_central_weight = max(central_weights.values()) if central_weights else 1  # Avoid division by zero
normalized_central_weights = {
    n: 1 - (central_weights[n] / max_central_weight) for n in central_weights
}

# Use the normalized weights to position the nodes
plt.figure(figsize=(95, 95))
pos = nx.spring_layout(neighbors_subgraph, k=0.8, weight='weight')  # Use original weight
for node in pos:
    if node != central_tag and node in normalized_central_weights:
        # Scale the neighbor's offset from the central tag by its normalized weight
        pos[node] = (
            pos[central_tag][0] + (pos[node][0] - pos[central_tag][0]) * normalized_central_weights[node],
            pos[central_tag][1] + (pos[node][1] - pos[central_tag][1]) * normalized_central_weights[node],
        )

node_colors = ['red' if node == central_tag else 'blue' for node in neighbors_subgraph.nodes()]
nx.draw(neighbors_subgraph, pos, with_labels=True, node_size=10000, node_color=node_colors, font_size=10,
        font_weight='bold')
plt.title(f"Tag Network around '{central_tag.capitalize()}'", fontsize=220)
plt.savefig(f'{central_tag}_graph.png', dpi='figure')
plt.show()
print(f'Saved graph to {central_tag}_graph.png')
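The repositioning loop above multiplies each neighbor's offset from the red node by 1 - w / w_max. As a result, the most strongly connected tag gets a factor of 0 and is drawn directly on top of the chosen tag. One possible fix is sketched below. It assumes a hypothetical `min_scale` floor (0.1 is an arbitrary choice) and maps the factor into [min_scale, 1] instead of [0, 1]. With `min_scale=0` it reproduces the original behavior.

```python
def pull_toward_center(pos, center, weights, min_scale=0.1):
    """Move each node toward `center` according to its weight (hypothetical helper).

    pos: dict of node -> (x, y); weights: dict of node -> raw weight to center.
    min_scale is an assumed floor so the strongest neighbor keeps part of its
    offset instead of landing exactly on the central node.
    """
    cx, cy = pos[center]
    max_w = max(weights.values(), default=0) or 1  # avoid division by zero
    new_pos = dict(pos)  # leave the input layout untouched
    for node, w in weights.items():
        if node == center or node not in pos:
            continue
        # Same inverted normalization as the notebook, rescaled into [min_scale, 1]
        scale = min_scale + (1 - min_scale) * (1 - w / max_w)
        x, y = pos[node]
        new_pos[node] = (cx + (x - cx) * scale, cy + (y - cy) * scale)
    return new_pos
```

In the cell above, `pos = pull_toward_center(pos, central_tag, central_weights)` could replace the `for node in pos:` loop, since the function normalizes the raw weights itself.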

This section, which will likely take more than 15 seconds to run, draws a graph of all tag relationships that gives a more general comparison between genres. Anecdotally, it looks right to me, to be honest. The image is saved as all_tag_graph.png so you can zoom in more easily.

In [None]:
# --- Graph of all relationships (for article) ---
plt.figure(figsize=(95, 95))
pos = nx.spring_layout(G, k=0.6, weight='weight')
nx.draw(G, pos, with_labels=True, node_size=10000, node_color='skyblue', font_size=10,
        font_weight='bold')
plt.title("All Tag Relationships", fontsize=220)
plt.savefig('all_tag_graph.png', dpi='figure')
plt.show()
print('Saved graph to all_tag_graph.png')
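Two optional tweaks may help with the full graph. First, `nx.spring_layout` accepts a `seed` argument, so fixing it should give the same layout on every run. Second, dropping weak edges before drawing reduces the number of edges matplotlib has to render, which may cut clutter and render time. A minimal sketch follows. `strong_edge_subgraph` is a hypothetical helper, and the threshold in the usage line is an arbitrary example, not a tuned value.

```python
import networkx as nx


def strong_edge_subgraph(G, min_weight):
    """Return a new graph keeping only edges with weight >= min_weight (hypothetical helper)."""
    H = nx.Graph()
    H.add_nodes_from(G.nodes())  # keep every tag, even if all its edges are dropped
    H.add_edges_from(
        (u, v, d) for u, v, d in G.edges(data=True) if d.get("weight", 0) >= min_weight
    )
    return H


# Usage (the threshold is an arbitrary example):
# H = strong_edge_subgraph(G, min_weight=1000)
# pos = nx.spring_layout(H, k=0.6, weight='weight', seed=42)  # fixed seed, repeatable layout
```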