In [1]:
from dataclasses import dataclass
import os
import config
import json


@dataclass
class Song:
    """One song record loaded from the raw song JSON files.

    Fields are positional in this order (cell 2 constructs Songs positionally).
    """
    # Lyrics text; newlines are replaced by spaces in the normalisation cell.
    lyrics: str
    # Song title (the raw JSON key is 'song').
    name: str
    # Artist name.
    artist: str
    # Raw metadata dict as loaded; the normalisation cell replaces it with the
    # list of its values ordered by key, hence both types are possible.
    meta: "dict | list"

In [2]:
# Re-read the raw song JSON files when `remake` is True, or whenever `data`
# does not exist yet (fresh kernel) -- otherwise `data` would be undefined and
# every later cell would fail with a NameError.
remake = False
if remake or "data" not in globals():
    data = []
    # Sort the listing so `data` has the same order in every session: the
    # resume offsets saved by the encoding loop index into this order.
    # NOTE(review): checkpoints created with the old unsorted listing may not
    # match this order -- restart such encodings from 0 if in doubt.
    for file_name in sorted(os.listdir(config.SONG_DATA_PATH)):
        # JSON text is UTF-8; don't depend on the platform default encoding.
        with open(os.path.join(config.SONG_DATA_PATH, file_name), 'r', encoding='utf-8') as f:
            payload = json.load(f)
        for el in payload['data']:
            data.append(Song(el['lyrics'], el['song'], el['artist'], el['meta']))

In [3]:
# Sanity check: number of Song records in `data` (displayed as cell output).
len(data)

124132

In [4]:
def clean(text: str) -> str:
    """Flatten lyrics onto a single line.

    Every '\\n' becomes a single space; nothing else (including '\\r') is touched.
    """
    pieces = text.split('\n')
    return ' '.join(pieces)

In [5]:
from sentence_transformers import SentenceTransformer


In [6]:
# Load each candidate sentence-embedding model onto the GPU, keeping
# (name, model) pairs; the name is later reused to build the output
# directory and checkpoint file names. Load order matters to the next cells.
transformers = [
    (model_name, SentenceTransformer(model_name, device='cuda'))
    for model_name in (
        "paraphrase-multilingual-MiniLM-L12-v2",
        "all-MiniLM-L12-v2",
        "distiluse-base-multilingual-cased-v2",
        "all-mpnet-base-v2",
        "all-distilroberta-v1",
        "multi-qa-mpnet-base-dot-v1",
    )
]

In [7]:
# Keep a reference to the list as loaded (an alias, not a copy) so the next
# cell can build a re-ordered `transformers` from it.
told = transformers

In [8]:
# Order in which the models are encoded by the main loop. Models are looked up
# by name instead of by position in `told`, so this cell stays correct if the
# loading cell is edited, and re-running it (even after `told` was re-bound to
# an already re-ordered list) always yields the same order.
_models_by_name = dict(told)
transformers = [
    (model_name, _models_by_name[model_name])
    for model_name in (
        "all-MiniLM-L12-v2",
        "distiluse-base-multilingual-cased-v2",
        "paraphrase-multilingual-MiniLM-L12-v2",
        "all-distilroberta-v1",
        "all-mpnet-base-v2",
        "multi-qa-mpnet-base-dot-v1",
    )
]

In [9]:
def dict_to_list_fixed(d):
    """Return the values of *d* as a list ordered by their (sorted) keys.

    Gives a fixed, insertion-order-independent layout for metadata dicts.
    """
    return [d[key] for key in sorted(d)]


In [10]:
# Normalise every song in place: flatten newlines in the lyrics and turn the
# meta dict into the list of its values ordered by key. The isinstance guard
# keeps this cell idempotent -- on a re-run `meta` is already a list and
# calling dict_to_list_fixed on it would raise AttributeError.
for song in data:
    song.lyrics = clean(song.lyrics)
    if isinstance(song.meta, dict):
        song.meta = dict_to_list_fixed(song.meta)

In [11]:
import time

# Shared state for the add_data producer: `cur` is the buffer of encoded
# records it appends to, `runs` counts producers that have not finished yet.
runs = 0
cur = []


def add_data(tf, _songs, max_pending=1000, poll_interval=.5):
    """Encode each song with *tf* and append the resulting record to the global ``cur``.

    Looks designed to run as a producer thread: after every append it sleeps
    while ``cur`` holds ``max_pending`` or more records (presumably until a
    consumer drains it) and it decrements the global ``runs`` when done.
    NOTE(review): not called anywhere in this notebook; ``runs -= 1`` is not
    atomic, so concurrent producers would need a lock.

    Args:
        tf: model with ``encode(str)`` returning an object that has ``tolist()``.
        _songs: iterable of objects with ``lyrics``, ``name``, ``artist``, ``meta``.
        max_pending: back-pressure threshold on ``len(cur)`` (default 1000).
        poll_interval: seconds to sleep between back-pressure checks.
    """
    global runs, cur
    try:
        for song in _songs:
            cur.append({
                "lyrics": tf.encode(song.lyrics).tolist(),
                "song": song.name,
                "artist": song.artist,
                "meta": song.meta,
            })
            # Back-pressure: wait until the buffer drops below the threshold.
            while len(cur) >= max_pending:
                time.sleep(poll_interval)
    finally:
        # Decrement even if encode() raised, so anything waiting for `runs`
        # to reach 0 is not left hanging on a dead producer.
        runs -= 1


In [12]:
import numpy as np
# Just the (cleaned) lyric strings, index-aligned with `data`; the encoding
# loop batches over this list and looks up the matching Song by position.
lyrics = [song.lyrics for song in data]

In [13]:
import tqdm
from threading import Thread
import torch
from torch.utils.data import DataLoader
# The models were loaded with device='cuda'; confirm a GPU is visible.
print(torch.cuda.is_available())

# Songs per output JSON file / per encode() call.
BATCH_SIZE = 1000
# Directory holding one resume checkpoint per model.
CHECKPOINT_DIR = "last"


def _read_checkpoint(path):
    """Return ``(file_index, song_offset)`` stored in *path*, creating it as ``0 0`` if missing."""
    if not os.path.exists(path):
        _write_checkpoint(path, 0, 0)
    with open(path, 'r') as f:
        file_index, offset = map(int, f.readline().split())
    return file_index, offset


def _write_checkpoint(path, file_index, offset):
    """Atomically replace *path* with ``"<file_index> <offset>"``.

    The text is written to a temp file and swapped in with os.replace, so an
    interruption can never leave a truncated/empty checkpoint that would make
    the next resume fail while parsing it.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, 'w') as f:
        f.write(f"{file_index} {offset}\n")
    os.replace(tmp_path, path)


def _encode_with(model_name, model):
    """Encode all `lyrics` with *model*, writing BATCH_SIZE songs per JSON file.

    Files go to ``config.SONG_DATA_PATH + "-" + model_name`` and are named via
    ``config.SONG_FILENAME``. A checkpoint is saved after each file, so an
    interrupted run resumes at the next unwritten batch.
    """
    out_dir = config.SONG_DATA_PATH + "-" + model_name
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    # Same checkpoint file name as before, so existing progress is reused.
    last_file = os.path.join(CHECKPOINT_DIR, f"lastEncode{model_name.capitalize()}.txt")
    file_index, start = _read_checkpoint(last_file)

    for offset in tqdm.tqdm(range(start, len(lyrics), BATCH_SIZE), desc=model_name):
        batch = lyrics[offset:offset + BATCH_SIZE]
        encoded = model.encode(batch)
        # `lyrics` is index-aligned with `data`, so the same slice gives the songs.
        records = [{
            "lyrics": vector.tolist(),
            "song": song.name,
            "artist": song.artist,
            "meta": song.meta,
        } for vector, song in zip(encoded, data[offset:offset + BATCH_SIZE])]
        with open(os.path.join(out_dir, config.SONG_FILENAME.format(file_index)), 'w') as f:
            json.dump(records, f)
        file_index += 1
        # Offsets are absolute positions in `lyrics`, so resuming is exact.
        _write_checkpoint(last_file, file_index, offset + BATCH_SIZE)


for model_name, model in transformers:
    _encode_with(model_name, model)



True


0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
125it [2:10:24, 62.59s/it]
