## Download

In [1]:
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="m-a-a-p/RPECL",
    filename="first_training_set/2d_alibi_Epoch-511.pt"
)
print("Downloaded checkpoint to:", model_path)

Downloaded checkpoint to: C:\Users\Potat\.cache\huggingface\hub\models--m-a-a-p--RPECL\snapshots\fc8bc535ee7435450e6fcb3cc7a4a716a0ca3559\first_training_set\2d_alibi_Epoch-511.pt


In [2]:
import torch
from models.Myna import Myna

device = "cuda" if torch.cuda.is_available() else "cpu"

chunk_size = 256

model = Myna(image_size=(128, chunk_size),
    channels=1,
    patch_size=(16, 16),
    latent_space=128,
    d_model=384,
    depth=12,
    heads=6,
    mlp_dim=1536,
    mask_ratio=0.0,
    use_cls=True,
    use_sinusoidal=False,
    use_y_emb=False,
    use_rope_x=False,
    use_rope_y=False,
    rope_base=False,
    use_alibi_x=True,
    use_alibi_y=True)

model.load_state_dict(torch.load(model_path, map_location=device, weights_only=False))
model.to(device)
model.mask_ratio = 0.0
model.eval()

Myna(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
    (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=256, out_features=384, bias=True)
    (3): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
  )
  (transformer): Transformer(
    (alibi_2d): Alibi2DBias()
    (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-11): 12 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attend): Softmax(dim=-1)
          (to_qkv): Linear(in_features=384, out_features=1152, bias=False)
          (to_out): Linear(in_features=384, out_features=384, bias=False)
        )
        (1): FeedForward(
          (net): Sequential(
            (0): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=384, out_features=1536, bias=True)
            (2): GEL

In [15]:
import os
from datasets import tqdm
from training.inference import load_and_parse_audio

def run_batch(model, chunks, remainder):
    chunks = chunks.unsqueeze(1)
    remainder = remainder.unsqueeze(0).unsqueeze(1)

    max_batch_size = 4

    num_chunks = int(chunks.shape[0] / max_batch_size)

    if num_chunks == 0:
        data_minibatches = [remainder]
    else:
        data_minibatches = torch.chunk(chunks, num_chunks, dim=0)
        data_minibatches+= (remainder,)

    all_preds = []

    for i, batch in enumerate(data_minibatches):
        data_minibatch = batch
        data_minibatch = data_minibatch.to("cuda")
        outputs = model(data_minibatch)
        all_preds.extend(outputs.detach().to("cpu"))
        torch.cuda.empty_cache()

    if len(all_preds) == 0:
        print(f"Empty Song")
        return None

    latents = torch.stack(all_preds, dim=0)
    averages = latents.mean(dim=0).numpy()
    torch.cuda.empty_cache()
    return averages

def inference_on_folder(model, path, chunking=True, chunk_size=256):
    all_songs = os.listdir(path)

    latents = []
    names = []

    for song in tqdm(all_songs):
        song_path = os.path.join(path, song)
        chunks, remainder = load_and_parse_audio(song_path, convert=True, chunking=chunking, chunk_size=chunk_size)
        if chunks is None:
            continue

        latent = run_batch(model, chunks, remainder)

        if latent is None:
            continue

        latents.append(latent)
        names.append(song)

    return latents, names

In [8]:
# Im inferencing on my spotify liked songs which I've downloaded onto a local folder.
# There are around 7,000 songs spanning ~600 genres snd ~3000 artists
# The genre breakdown is skewed towards Japanese Indie, Funk/Breakbeat, EDM, and House,
# But songs many songs from just about every genre are in this list

path_to_folder = "E:\\SongsDataset\\songs\\"
chunk_size = 256
latents, names = inference_on_folder(model, path_to_folder, chunking=True, chunk_size=chunk_size)
torch.save(latents, f"latents-{chunk_size}.pt")

  0%|          | 0/7056 [00:00<?, ?it/s]

In [9]:
import torch.nn.functional as F

def cosine_sim(track_data, interested_point):
    track_norm = F.normalize(track_data, p=2, dim=1)
    point_norm = F.normalize(interested_point, p=2, dim=1)
    similarity_matrix = track_norm @ point_norm.T
    return similarity_matrix

def get_k_most_similar(data, labels, point, k=100):
    track_data = data

    similarity_matrix = cosine_sim(track_data, point)
    ids = np.array(labels)

    # Sort similarities per song
    similarity_matrix = similarity_matrix.mean(dim=1)
    sorted_x, indices = torch.sort(similarity_matrix, dim=0, descending=True)

    # Plot for each selected song
    sim_values = sorted_x.cpu().numpy()
    sim_labels = ids[indices.cpu().numpy()]

    top_k = k
    if len(sim_values) > (top_k):
        sim_labels_trunc = sim_labels[:top_k]
    else:
        sim_labels_trunc = sim_labels

    sim_labels_trunc = [str(x) for x in sim_labels_trunc]

    return sim_labels_trunc

In [18]:
import numpy as np
latents = torch.load(f"latents-{chunk_size}.pt", weights_only=False)
latents = torch.from_numpy(np.array(latents))

In [11]:
index_of_interest = 1000
point_of_interest, name = latents[index_of_interest].unsqueeze(0), names[index_of_interest]
most_similar_songs = get_k_most_similar(latents, names, point_of_interest, k=16)
most_similar_songs

['Clean Tears - Inverse Relation - Radio Mix (New Original Version).mp3',
 '暁Records - Necromantic.mp3',
 'Clean Tears, Hatsune Miku - Blue Layers (feat. 初音ミク).mp3',
 'tokiwa, Megumi Takahashi - Continue (feat. Megumi Takahashi).mp3',
 'ShibayanRecords - Spring Rouge.mp3',
 'DJ Dean - Play It Hard - Club Mix.mp3',
 'DJ Noriken, aran - Comet Coaster.mp3',
 'Clean Tears, Hatsune Miku - Blue layers (Extended Mix) (feat. 初音ミク).mp3',
 'Aura Qualic - Time of my Life (Original Mix).mp3',
 "Clean Tears, Hatsune Miku - don't escape (Miku Version) (feat. 初音ミク).mp3",
 'Hikaru Utada, PUNPEE - Simple And Clean - Ray Of Hope MIX.mp3',
 'Clean Tears, Hatsune Miku - Cyclic (feat. 初音ミク).mp3',
 'Team Grimoire - Sheriruth.mp3',
 'Dropgun - Little Drop.mp3',
 'DJ Dan - Fascinated - Radio Mix.mp3',
 'BeXta, Meza - In My Mind.mp3']

In [12]:
index_of_interest = 10
point_of_interest, name = latents[index_of_interest].unsqueeze(0), names[index_of_interest]
most_similar_songs = get_k_most_similar(latents, names, point_of_interest, k=16)
most_similar_songs

['01sail, Ámina - With My Heart (feat. Ámina).mp3',
 'Edison Chen - 戰爭 - Feat. 陳奐仁, 胡蓓蔚 & Mc仁.mp3',
 'KOJI 1200 - Blow Ya Mind - I LOVE AMERICA.mp3',
 'I DONT KNOW HOW BUT THEY FOUND ME - Do It All The Time.mp3',
 'The Blessed Madonna, Uffie - Serotonin Moonbeams.mp3',
 'Atwood - Careless.mp3',
 'Kevin Ross - Look My Way.mp3',
 "Rui En - Can't Stand It.mp3",
 'DA PUMP - if....mp3',
 'konoco - uranosoko.mp3',
 'Yitai Wang, 刘至佳 - 危险派对.mp3',
 '2 Mello - Future Unwritten.mp3',
 'Gorillaz, De La Soul - Feel Good Inc..mp3',
 'IVE - LOVE DIVE.mp3',
 'marQ - Farewell.mp3',
 "Ms. Lauryn Hill - Can't Take My Eyes Off of You - (I Love You Baby).mp3"]

In [13]:
index_of_interest = 777
point_of_interest, name = latents[index_of_interest].unsqueeze(0), names[index_of_interest]
most_similar_songs = get_k_most_similar(latents, names, point_of_interest, k=16)
most_similar_songs

['Calmera - Smile.mp3',
 'Shin Rizumu - LADY.mp3',
 'カネコアヤノ - かみつきたい.mp3',
 'Miyuki Hatakeyama - 海が欲しいのに.mp3',
 'Calmera - Golden Hour.mp3',
 'Sakura Fujiwara - Soup.mp3',
 'Calmera - 上にいきたくないデパート.mp3',
 'The Dip - Slow Sipper.mp3',
 'Minuano - Endless Season.mp3',
 'Shin Rizumu - ショートヘア.mp3',
 'Jack Black Polka Band - Everybody Polka.mp3',
 'mouse on the keys - spectres de mouse.mp3',
 'Rollercoaster - Cheer up Mr. Kim.mp3',
 'Katsuo Ohno - 名探偵コナン メイン・テーマ - 暗殺者ヴァージョン.mp3',
 "Nelly Furtado - I'm Like A Bird.mp3",
 '40mP, Hatsune Miku - Kiss the villain.mp3']

In [14]:
import random

index_of_interest = random.randint(0, len(names) - 1)
point_of_interest, name = latents[index_of_interest].unsqueeze(0), names[index_of_interest]
print(f"Song of Interest: {name}\nAt Index {index_of_interest}")
most_similar_songs = get_k_most_similar(latents, names, point_of_interest, k=16)
most_similar_songs

Song of Interest: 岡崎律子 - 小さな祈り.mp3
At Index 6848


['岡崎律子 - 小さな祈り.mp3',
 "The 8-Bit Big Band, Martina DaSilva - No More What Ifs (From 'Persona 5').mp3",
 "Alan Menken - Kingdom Dance - From 'Tangled'Score.mp3",
 'Lamp - 1998.mp3',
 '梶浦 由記 - the first town.mp3',
 'Japanese Breakfast - Boyish.mp3',
 'wave to earth - seasons.mp3',
 'Kim Minwoo - 사랑일뿐야.mp3',
 'Sakura Fujiwara - 春の歌.mp3',
 'Lord Huron - The Night We Met.mp3',
 '藤澤慶昌 - 戦い.mp3',
 'Fishmans - DAYDREAM.mp3',
 'Predawn - 炭酸.mp3',
 'Masakatsu Takagi - Amamizu.mp3',
 'wave to earth - ride.mp3',
 'Fuyumi Abe - cinema.mp3']