In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from random import randint
from tqdm import tqdm
import re
import os

In [None]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType,StructField, StringType, IntegerType 
from pyspark.sql.types import ArrayType, DoubleType, BooleanType
from pyspark.sql.functions import col,array_contains

In [None]:
# Start (or reuse) a Spark session named 'MusicGen' on the local[*] master.
spark = (
    SparkSession.builder
    .appName('MusicGen')
    .master("local[*]")
    .getOrCreate()
)
sc = spark.sparkContext

# Load the training parquet file and show its columns.
train = spark.read.format("parquet").load("train.parquet")
train.printSchema()

# Keep the row count around; it is printed here and reused by later cells.
n_rows = train.count()
print("N rows: ", n_rows)

In [None]:
WANTED_ROWS = 100_000
SAMPLE_SEED = 42  # fixed seed so Restart & Run All draws the same sample

# Sampling without replacement needs a fraction in [0, 1]: clamp it so a
# dataset smaller than WANTED_ROWS is taken whole, and avoid dividing by zero
# when the dataset is empty.
frac = min(1.0, WANTED_ROWS / n_rows) if n_rows else 0.0
print(frac)

# `sample` is approximate: the result holds roughly (not exactly) WANTED_ROWS rows.
sampled = train.sample(fraction=frac, seed=SAMPLE_SEED).toPandas()
sampled.head()

In [None]:
!pip install nltk 
import nltk
nltk.download('punkt')

In [None]:
from nltk.util import pad_sequence
from nltk.lm.preprocessing import padded_everygram_pipeline
from nltk.lm import MLE

In [None]:
# Lower-case every sampled lyric, then split it into word tokens.
tokenized_data = list(
    map(nltk.word_tokenize, (lyric.lower() for lyric in sampled["lyrics"]))
)
tokenized_data[0]

In [None]:
# Pad each token list on both sides with '<s>' / '</s>' markers (n=3).
std_data = []
for lyric_tokens in tokenized_data:
    padded = pad_sequence(
        lyric_tokens,
        n=3,
        pad_left=True,
        left_pad_symbol='<s>',
        pad_right=True,
        right_pad_symbol='</s>',
    )
    std_data.append(list(padded))

In [None]:
# padded_everygram_pipeline pads every sequence itself, so it is fed the raw
# token lists. Passing the already padded `std_data` would pad each lyric
# twice and add spurious all-padding n-grams to the model's counts.
# Both returned objects are lazy iterators and can only be consumed once.
training, vocab = padded_everygram_pipeline(3, tokenized_data)

In [None]:
# Order-3 (trigram) MLE language model, fitted on the n-grams and vocabulary
# produced by the pipeline cell.
model = MLE(order=3)
model.fit(text=training, vocabulary_text=vocab)

In [None]:
def complete_text(model, previous_text, n_tokens=10):
    """Continue ``previous_text`` with tokens generated by ``model``.

    Args:
        model: Language model exposing
            ``generate(num_words, random_seed=..., text_seed=...)``.
        previous_text: Prompt; it is lower-cased and word-tokenized before
            being used as the generation seed.
        n_tokens: Number of tokens requested from the model.

    Returns:
        The generated tokens joined by single spaces, with the ``'<s>'`` and
        ``'</s>'`` padding symbols removed (so fewer than ``n_tokens`` words
        may come back). The prompt itself is not included.
    """
    seed_tokens = nltk.word_tokenize(previous_text.lower())
    generated = model.generate(n_tokens, random_seed=1, text_seed=seed_tokens)
    # `generate` hands back a bare string (not a list) when a single word is
    # requested; wrap it so the filter below walks tokens, not characters.
    if isinstance(generated, str):
        generated = [generated]
    kept_tokens = [token for token in generated if token not in ('<s>', '</s>')]
    return ' '.join(kept_tokens)


In [None]:
# Demo: ask the trained model to continue a short prompt.
complete_text(model, previous_text='The stars remind of', n_tokens=10)

In [None]:
import pickle

MODEL_PATH = 'baseline-model.pickle'

# Reuse the saved baseline model when it exists; otherwise persist the model
# trained above so the notebook also runs from a clean checkout.
if os.path.exists(MODEL_PATH):
    # NOTE: pickle.load can execute arbitrary code embedded in the file, so
    # only load a pickle that this notebook produced itself.
    with open(MODEL_PATH, 'rb') as f:
        model = pickle.load(f)
else:
    with open(MODEL_PATH, 'wb') as f:
        pickle.dump(model, f)

In [None]:
# Load test.parquet and show its columns.
test = (
    spark.read
    .format("parquet")
    .load("test.parquet")
)
test.printSchema()

In [None]:
N = 10
TOKEN_REGEX = r'\b\w+\b'


def got_row_right(row, bd):
    """Check whether the model predicts the next token of a random lyric excerpt.

    A window of up to ``N`` consecutive tokens is picked at a random position
    in the row's lyrics, the model generates one token from that context, and
    the result is compared with the token that really follows the window.

    Args:
        row: Mapping-like record with a ``'lyrics'`` text field.
        bd: Object whose ``.value`` is the language model (the Spark broadcast
            created from the model elsewhere in the notebook).

    Returns:
        ``True`` if the generated token equals the actual next token, else
        ``False``. Rows with missing lyrics, or with fewer than two tokens
        (nothing left to predict), also yield ``False``.
    """
    lyrics = row['lyrics']
    if not lyrics:
        return False
    model = bd.value
    tokens = nltk.word_tokenize(lyrics.lower())

    # Use N context tokens when the lyric is long enough, otherwise whatever
    # fits while still leaving one token to predict; this keeps `randint` from
    # being called with an empty range on short lyrics.
    context_size = min(N, len(tokens) - 1)
    if context_size < 1:
        return False

    # Pick one random excerpt of the lyrics.
    start = randint(0, len(tokens) - context_size - 1)
    context = tokens[start:start + context_size]
    target = tokens[start + context_size]

    predicted = model.generate(1, random_seed=1, text_seed=context)
    # A single-word request yields a bare string; compare it directly instead
    # of iterating over it (which would split the word into characters).
    if not isinstance(predicted, str):
        predicted = next(iter(predicted), None)
    return predicted == target

In [None]:
# Wrap the model in a Spark broadcast variable; RDD tasks read it via `bd.value`.
bd = sc.broadcast(value=model)

In [None]:
acc_list = []
# RDD.map calls its function with one argument (the row), so the broadcast
# handle has to be bound explicitly for the two-argument `got_row_right`.
res = test.rdd.map(lambda row: got_row_right(row, bd))

for _ in range(10):
    res_sample = res.sample(False, fraction=0.1)
    # Tally both outcomes in a single pass. The mapped function picks a random
    # excerpt per call, so counting hits and misses with two separate actions
    # would evaluate the rows twice and could yield inconsistent totals.
    outcome_counts = res_sample.map(bool).countByValue()
    right = outcome_counts.get(True, 0)
    wrong = outcome_counts.get(False, 0)
    total = right + wrong
    # An empty sample has no defined accuracy; report NaN instead of crashing.
    acc = right / total if total else float('nan')
    print(f"Right: {right}, Wrong: {wrong}, acc: {acc}")
    acc_list.append(acc)