In [None]:
#scan a directory for audio files
#add it to a sql database
#calculate a vector using ntu-spml/distilhubert and store it in another vector database alongside its id in the sql database

In [None]:
import os
import sqlite3
from dataclasses import dataclass
import mutagen
from mutagen.mp3 import MP3
from mutagen.flac import FLAC
from mutagen.id3 import ID3, TIT2, TPE1, TALB, TCON, TDRC, TRCK, TPOS, APIC, error

In [None]:
MUSIC_DIR = "D:/Music"

In [None]:
def is_music_file(file):
    file_endings = [".mp3",".wav",".flac",".ogg"]
    for ending in file_endings:
        if file.endswith(ending):
            return True
    return False

In [None]:
music_files = []

for root, dirs, files in os.walk(MUSIC_DIR):
    for file in files:
        if is_music_file(file):
            music_files.append(os.path.join(root,file))

In [None]:
@dataclass
class MusicFile:
    name:str
    abs_path:str
    bpm:int
    length:int #in seconds
    kbps:int
    genre:str
    artist:str
    album:str
    album_art:bytes
    
    def __str__(self) -> str:
        return f"{self.name} by {self.artist} from {self.album} ({self.genre})"

In [None]:
def get_music_file_info(file_path):
    file_name = os.path.basename(file_path)
    if file_name.endswith(".mp3"):
        audio = MP3(file_path)
    elif file_name.endswith(".flac"):
        audio = FLAC(file_path)
    
    try:
        bpm = audio["TBPM"][0]
    except KeyError:
        bpm = 0
        
    name = audio["TIT2"][0] if "TIT2" in audio else file_name
    length = int(audio.info.length)
    kbps = int(audio.info.bitrate/1000)
    
    try:
        genre = str(audio["TCON"][0]) #TCON(encoding=<Encoding.UTF16: 1>, text=['Dance'])
        #print(genre)
    except KeyError:
        genre = "Unknown"
    except IndexError:
        genre = "Unknown"
    
    artist = audio["TPE1"][0] if "TPE1" in audio else "Unknown"
    album = audio["TALB"][0] if "TALB" in audio else "Unknown"
    try:
        album_art = audio["APIC:"].data
    except KeyError:
        album_art = b"" #empty bytes
    
    return MusicFile(name, file_path, bpm, length, kbps, genre, artist, album, album_art)

In [None]:
songs = []

for idx, file in enumerate(music_files):
    songs.append(get_music_file_info(file))
    print(idx, len(music_files), end="\r")

print(songs[0])

In [None]:
#save to database
conn = sqlite3.connect("music.db")
c = conn.cursor()

c.execute("DROP TABLE IF EXISTS songs")

cmd = """CREATE TABLE IF NOT EXISTS songs (
    id INTEGER PRIMARY KEY,
    name TEXT,
    abs_path TEXT,
    bpm INTEGER,
    length INTEGER,
    kbps INTEGER,
    genre TEXT,
    artist TEXT,
    album TEXT,
    album_art BLOB
);"""

c.execute(cmd)
conn.commit()

In [None]:
error_songs = []
for idx, song in enumerate(songs):
    cmd = """INSERT INTO songs (name, abs_path, bpm, length, kbps, genre, artist, album, album_art) VALUES (?,?,?,?,?,?,?,?,?);"""
    try:
        c.execute(cmd, (song.name, song.abs_path, song.bpm, song.length, song.kbps, song.genre, song.artist, song.album, song.album_art))
        conn.commit()
    except sqlite3.InterfaceError:
        error_songs.append(song)
        break

if len(error_songs) > 0:
    print("Error songs:")
    for song in error_songs:
        print(song)
else:
    print("No errors")

In [None]:
#calculate vectors
from qdrant_client import QdrantClient
from qdrant_client.http.models import VectorParams, Distance, PointStruct

#ai model
import openl3
import soundfile as sf

import json

client = QdrantClient(host="localhost", port=6333)

#create collection
try:
    client.create_collection(collection_name="songs", vectors_config=VectorParams(size=6144,distance=Distance.COSINE))
except:
    print("collection already exists")




def get_vector(file_path) -> list[list[float]]:
    #calculate vector representation of audio file
    
    audio, sampling_rate = sf.read(file_path)
    embeddings, timestamps = openl3.get_audio_embedding(audio, sr=sampling_rate)
    
    #print(embeddings.shape)
    
    return embeddings.tolist()
    

def get_next_id() -> int:
    #get next id from qdrant
    #get current count of vectors
    #return count + 1
    operation = client.count(collection_name="songs")
    return operation.count + 1

def get_audio_obj_from_id(id) -> MusicFile:
    #get audio object from id from sql database
    conn = sqlite3.connect("music.db")
    c = conn.cursor()
    c.execute("SELECT * FROM songs WHERE id=?", (id,))
    result = c.fetchone()
    data = result[1:] #remove id
    return MusicFile(*data)

def add_audio_to_vector_db(id):
    audio = get_audio_obj_from_id(id)
    vector = get_vector(audio.abs_path)
    
    points = []
    for idx, vec in enumerate(vector):
        payload = audio.__dict__
        
        try:
            payload = json.loads(json.dumps(payload, default=lambda o: o.__dict__))
        except:
            #bytes of album art can't be serialized
            payload["album_art"] = ""
        
        pointruct = PointStruct(id = get_next_id() + idx, vector = vec, payload = payload)
        points.append(pointruct)
    
    max_vectors_per_request = 32
    
    for i in range(0, len(points), max_vectors_per_request):
        vectors = points[i:i+max_vectors_per_request]
        operation = client.upsert(
            collection_name="songs",
            wait=True,
            points = vectors
        )
        print(operation)



In [None]:
audio = get_audio_obj_from_id(1)
fp = audio.abs_path
print(fp)
print(get_vector(fp))

In [None]:
conn = sqlite3.connect("music.db")
c = conn.cursor()

cmd = "SELECT MAX(id) FROM songs"
c.execute(cmd)
num_entries = int(c.fetchone()[0])

In [14]:
for i in range(1, num_entries+1):
    add_audio_to_vector_db(i)
    print(i, num_entries, end="\r")

operation_id=146 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=147 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=148 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=149 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=150 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=151 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=152 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=153 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=154 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=155 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=156 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=157 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=158 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=159 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=160 status=<UpdateStatus.COMPLETED: 'completed'>
operation_id=161 status=<UpdateStatus.COMPLETED: 'completed'>
operatio