<a href="https://colab.research.google.com/github/OliBomby/CM3P/blob/main/colab/CM3P_embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Beatmap Embedding Extraction with CM3P

This notebook lets you interactively generate a beatmap embedding dataset to visualize in [the osu! Map Explorer](https://olibomby.github.io/CM3P/).

### Instructions for running:

* __Configure__: Enable `add_to_dataset` to merge your embeddings with an existing dataset downloaded from Hugging Face Hub (selected by `repo_id`); disable it to skip the merge.
* __Execute the code cell__: Press ▶️ to the left of the cell to execute it.
* __Upload Beatmap Files__: Choose one or more beatmap files to extract embeddings for. You can upload .osu, .mp3, .ogg, .osz, or a .zip archive containing beatmaps.
* The script will extract the embeddings and download a new parquet file to your PC with the embedding dataset.


In [None]:
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
import zipfile
from google.colab import files

cwd = Path(os.getcwd())
if cwd.name != 'CM3P':
    print("Installing CM3P")
    !git clone https://github.com/OliBomby/CM3P.git
    os.chdir('CM3P')
    !pip install slider git+https://github.com/OliBomby/slider.git@gedagedigedagedaoh
    !pip install hydra-core
else:
    print("CM3P already installed")

# Collect the user's uploads into ./beatmaps, creating the directory if needed.
beatmaps_dir = Path('./beatmaps')
beatmaps_dir.mkdir(parents=True, exist_ok=True)

# Extensions are matched against the lower-cased file name so uploads such as
# 'Song.MP3' or 'Map.OSU' are accepted instead of being skipped.
audio_exts = ('.mp3', '.ogg')
stored_exts = ('.osu', '.osz') + audio_exts

print("Please upload your beatmap files (.osu, .mp3, .ogg, .osz, or .zip archive containing beatmaps).")
uploaded = files.upload()

for fn, data in uploaded.items():
    lower_fn = fn.lower()
    if lower_fn.endswith(stored_exts):
        # Beatmaps, beatmap archives and audio are stored as-is.
        (beatmaps_dir / fn).write_bytes(data)
        print(f"Processed '{fn}': Moved to '{beatmaps_dir}'.")
    elif lower_fn.endswith('.zip'):
        # Write the zip file temporarily, then extract
        temp_zip_path = Path(f'./{fn}')
        temp_zip_path.write_bytes(data)
        try:
            with zipfile.ZipFile(temp_zip_path, 'r') as zip_ref:
                zip_ref.extractall(beatmaps_dir)
        except zipfile.BadZipFile:
            # One corrupt archive should not abort the remaining uploads.
            print(f"Skipped '{fn}': Not a valid zip archive.")
        else:
            print(f"Processed '{fn}': Extracted contents to '{beatmaps_dir}'.")
        finally:
            # Remove the temporary zip file even when extraction failed.
            temp_zip_path.unlink()
    else:
        print(f"Skipped '{fn}': Unrecognized file type. Please upload .osu, .mp3, .ogg, .osz, or .zip files.")

# Loose .osu files need their audio alongside them; ask for it when none was uploaded.
has_osu = any(fn.lower().endswith('.osu') for fn in uploaded)
has_audio = any(fn.lower().endswith(audio_exts) for fn in uploaded)
if has_osu and not has_audio:
    print("Please upload your audio files (.mp3 or .ogg).")
    uploaded = files.upload()

    for fn, data in uploaded.items():
        if fn.lower().endswith(audio_exts):
            (beatmaps_dir / fn).write_bytes(data)
            print(f"Processed '{fn}': Moved to '{beatmaps_dir}'.")
        else:
            print(f"Skipped '{fn}': Unrecognized file type. Please upload .mp3 or .ogg files.")

print("File upload and processing complete.")

# Notebook form fields: whether to merge with an existing dataset, and which Hub dataset repo holds it.
add_to_dataset = True # @param {type:"boolean"}
repo_id = 'OliBomby/CM3P-Embeddings-244K' # @param {type:"string"}
file_name = 'beatmap_embeddings.parquet'
output_file = 'beatmap_embeddings_new.parquet'

# When merging is requested, reuse a local copy of the base dataset or fetch it from the Hub.
if add_to_dataset and Path(file_name).exists():
    print(f"'{file_name}' already exists locally.")
elif add_to_dataset:
    print(f"'{file_name}' not found locally. Downloading from Hugging Face...")
    hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file_name, local_dir='.')
    print(f"'{file_name}' downloaded successfully.")

print("Starting embedding extraction...")
if add_to_dataset:
    !python extract_beatmap_embeddings.py --beatmap-paths ./beatmaps --merge-with {file_name} --output {output_file}
else:
    !python extract_beatmap_embeddings.py --beatmap-paths ./beatmaps --output {output_file}

if Path(output_file).exists():
    print(f"Downloading '{output_file}'...")
    files.download(output_file)
    print("Download complete.")
else:
    print(f"Error: '{output_file}' not found. Please ensure the embedding extraction process ran successfully.")
