# 15688 Project - Lyrics Generator & Classifier

“Music can change the world because it can change people,” said Bono, the legendary U2 rocker. A beautiful song usually has memorable lyrics, and sometimes those lyrics change people. However, writing good lyrics is not an easy task.

The aim of our project is to create a lyric generation model based on existing lyrics of different music genres (pop, rock, hip hop, etc.) using machine learning algorithms that are common in natural language processing.

**Attention:** <br>
We developed this project as Python scripts rather than in a Jupyter Notebook. <br>
This tutorial walks you through each step. <br>
We highly recommend running the project with the command lines we suggest instead of executing this notebook, in case of unexpected errors.

## Step 1. Data Collection

### Part 1. Singer and song collection
To train the lyric model, the first step is to collect lyrics by genre. We collected the names of male and female artists from [music.163.com](https://music.163.com/#/discover/artist/cat?id=2001) by copying the information on the webpage and saving it as csv files.

After getting the artists' names, we use the [musixmatch](http://api.musixmatch.com/ws/1.1/) API to collect the genres and song names of each artist. The results are exported as csv files so that we can count the most frequent genres among all the songs we collected.
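The genre count mentioned above can be sketched with pandas. This is a minimal illustration only: the rows and `genre_id` values below are made up, and only the column names follow the csv files exported by the collection script.

```python
import pandas as pd

# Hypothetical sample with the same columns as the exported track csv file;
# the genre_id values are placeholders, not real musixmatch ids
songs = pd.DataFrame({
    "artist": ["A", "A", "B", "C", "C", "D"],
    "genre": ["Pop", "Pop", "Rock", "Pop", "Hip Hop/Rap", "Rock"],
    "genre_id": [1, 1, 2, 1, 3, 2],
    "track_name": ["s1", "s2", "s3", "s4", "s5", "s6"],
})

# Count how many tracks fall under each genre, most frequent first
genre_counts = songs["genre"].value_counts()

# Keep the names of the three most frequent genres
top_genres = genre_counts.head(3).index.tolist()
print(genre_counts)
print(top_genres)
```

On the real data, the same two lines would be applied to the dataframe read from the exported csv file.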

In [4]:
'''
15688 final project - lyric generator

data collection

retrieve the artists, genres and tracks and export to csv file

API used: musixmatch Developer
documentation: https://developer.musixmatch.com/documentation

'''
import os
import json
import requests
import pandas as pd

# load api key (strip the trailing newline that text editors usually add)
with open("../musicmatch_api.key",'r') as f:
    api = f.read().strip()

root = "http://api.musixmatch.com/ws/1.1/"

def get_artist(api, pageNum, page_size=100, country = "us"):

    '''
    getting top artists and their genres
    Args:
        api: API key
        pageNum: the page number for paginated results
        page_size: the page size for paginated results. Range is 1 to 100
        country: country of the artist ranking
    Return:
        df: a pandas dataframe containing artists, genres and genre id
        all_genres: a set of all genres related to the artists found
    '''
    result = []
    all_genres = set()
    for i in range(pageNum):
        param = {
            "apikey":api,
            "country": country,  # use the function argument, not the literal string "country"
            "page": i+1,
            "page_size": page_size,
            "format": "json"
        }

        singers = requests.get(root + "chart.artists.get?", params = param)
        response = json.loads(singers.content)
        artist_list = response.get("message").get("body").get("artist_list")
        
        for artist in artist_list:
            name = artist.get("artist").get('artist_name')
            genres = artist.get("artist").get("primary_genres").get("music_genre_list")
            for g in genres:
                genre = g.get("music_genre").get("music_genre_name")
                genre_id = g.get("music_genre").get("music_genre_id")
                all_genres.add(genre)
                result.append({"artist":name, "genre":genre, "genre_id":genre_id})
    
    df = pd.DataFrame(result)
    df = df.loc[:, ["artist", "genre", "genre_id"]]
    return df, all_genres


def get_artist_genre(api, all_artist_list):

    '''
    getting the artists and their genres of given list

    Args:
        api: API key
        all_artist_list: list of all the artists
    Return:
        df: a pandas dataframe containing artists, genres and genre id
        all_genres: a set of all genres related to the artists found

    '''
    result = []
    all_genres = set()
    param = {
            "apikey":api,
            "page":1,
            "page_size":10
        }
    for artist in all_artist_list:
        param["q_artist"] = artist

        search_result = requests.get(root + "artist.search?", params = param)
        response = json.loads(search_result.content)

        artist_list = response.get("message").get("body").get("artist_list")

        if not artist_list:
            continue
        artist_item = artist_list[0]
        
        name = artist_item.get("artist").get('artist_name')
        genres = artist_item.get("artist").get("primary_genres").get("music_genre_list")
        for g in genres:
            genre = g.get("music_genre").get("music_genre_name")
            genre_id = g.get("music_genre").get("music_genre_id")
            all_genres.add(genre)
            result.append({"artist":name, "genre":genre, "genre_id":genre_id})
    
    df = pd.DataFrame(result)

    if not df.empty:
        df = df.loc[:, ["artist", "genre", "genre_id"]]
    else:
        print("result is an empty dataframe")
    return df, all_genres

def get_songs(api, artist_df, page_size = 100):


    '''
    getting track names by artists and genre id

    Args:
        api: API key
        artist_df: dataframe with columns of artist, genre and genre id
        page_size: the page size for paginated results. Range is 1 to 100
    Return:
        df: a pandas dataframe containing artists, genres, genre id and the top
        100 tracks with lyrics under that genre by the artist
        
    '''

    result = []

    for i, row in artist_df.iterrows(): 
        param = {
                "apikey":api,
                "q_artist": row['artist'],
                "f_music_genre_id": row['genre_id'], # filter by genre id
                "f_has_lyrics":"True", # only get tracks with lyrics
                "page": 1,
                "page_size": page_size
            }

        singer = requests.get(root + "track.search?", params = param)
        response = json.loads(singer.content)
        song_list = response.get("message").get("body").get("track_list")
        for song in song_list:
    
            track_name = song.get("track").get("track_name")
            result.append(
                {
                "artist":row["artist"], 
                "genre":row["genre"], 
                "genre_id":row["genre_id"],
                "track_name":track_name
                })

    df = pd.DataFrame(result)
    df = df.loc[:, ["artist", "genre","genre_id", "track_name"]]
    return df


if __name__ == "__main__":
    # Step 1. get artists and their genres
    # load the artists from the csv file (first 50 rows of every column)
    artist_df = pd.read_csv("./csv_files/all_female_artists.csv", header = None)[:50]
    artists_list = []
    for col in artist_df.columns.values:
        artists_list += list(artist_df[col])

    artist_genre_df, all_genres = get_artist_genre(api, artists_list)
    artist_genre_df.to_csv("./csv_files/all_female_artist_genre.csv",index = False)

    #Step 2. get songs by artists and genres
    artist_df = pd.read_csv("./csv_files/all_female_artist_genre.csv")[:1000]
    print(artist_df.shape)
    song_df = get_songs(api, artist_df)
    song_df.to_csv("./csv_files/all_female_artist_genre_track.csv", index = False)

### Part 2. Lyrics collection via *lyricwikia*

With all the song names, we use the [lyricwikia](https://github.com/enricobacis/lyricwikia) Python package to collect the lyrics. The package can be installed with pip.

```bash
pip3 install lyricwikia
```


In [None]:
import lyricwikia as ly
import pandas as pd

# request the lyrics song by song,
# row by row in the dataframe
def getLyrics(songs):
    i = -1
    print("Total songs number:" + str(songs.shape[0]))
    for index, row in songs.iterrows():
        i += 1
        if i%100 == 0:
            print("Processing song [" + str(i) + "]")

        song = row['track_name']
        #print(song)
        artist = row['artist']
        try:
            lyric = ly.get_lyrics(artist, song, linesep='\n', timeout=None)
            songs.loc[index,'lyric'] = lyric
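        except Exception:
            # lyric not found or request failed: leave the cell empty,
            # the row is removed later by dropna() in run()
            continue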
        #print(lyric)
    return songs


def run(oriFile, newFile):
    songs = pd.read_csv(oriFile, encoding = "ISO-8859-1")
    songs = getLyrics(songs)
    songs = songs.dropna()
    #print(songs)
    songs.to_csv(newFile)


You can run the following in the command line:

```bash
python3 get_lyrics.py -h
```
to choose the csv file of tracks and customize the output path of the lyric file.

From the csv file of songs and their genres, we found the top 3 genres are:

* Pop
* Hip Hop/Rap
* Rock

We will train the models on these three genres. Therefore, we extract the lyrics of each genre into its own dataset.

In [None]:
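import pandas as pd

def split_lyrics(csv_path):
    """Return one dataframe per genre, in the order Pop, Rock, Hip Hop/Rap."""
    df = pd.read_csv(csv_path)

    # drop the unnamed index column written by to_csv() in the lyric collection step
    df = df.iloc[:,1:]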

    result = []
    for genre in ['Pop','Rock','Hip Hop/Rap']:
        result.append(df[df['genre'] == genre])
    return result



if __name__ == "__main__":
    df_female = split_lyrics('../csv_files/all_female_artist_lyrics.csv')
    df_male = split_lyrics('../csv_files/all_male_artist_lyrics.csv')

    for d, f in zip(df_female,df_male):
        genre = d.iloc[0,1]
        genre = genre.replace(" ", "_").replace("/", "_")
        df = pd.concat([d,f])
        df.to_csv('../csv_files/lyrics_' + genre +".csv", index = False)


We are now ready to train the models with 3 datasets, one per genre.

The number of lyrics in each genre is:

| genre | lyric number |
|---|---|
| rap | 5039 |
| pop | 21998 |
| rock | 6503 | 

*The singer, song and lyric files are all stored in `../csv_files/` directory*

## Step 2. Data Preprocessing


Later we'll apply two different deep learning methods:
+ an LSTM model to generate lyrics of a chosen genre
+ a CNN model to classify a lyric into a specific genre

Since these two methods need different preprocessing, we divide the data preprocessing into 2 parts.

### Part 1. Preprocessing for LSTM model:


For the LSTM model, it is important to know where a lyric starts and ends. So in the preprocessing stage we manually add a start mark and an end mark to each lyric. We then use the `nltk` package to tokenize the lyrics into words and lemmatize them, and remove all the rare words (words that appear at most 3 times). <br>
The most important method below is `process`, which generates all the features needed by the LSTM model. The returned X holds the word ID sequences of the lyrics, grouped into batches of `batchSize` lyrics. Y is X shifted one word to the left, so that each position of Y holds the word that follows the same position in X. We also need a word-to-ID dict so that we can transform the generated IDs back into words in the test stage.
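As a small illustration of how X and Y relate inside `process`, the following sketch pads a toy batch and builds the shifted targets. The word IDs and the `PAD_ID` value are made up for the example; only the padding and shifting logic mirrors the code in this section.

```python
import numpy as np

PAD_ID = 9  # hypothetical ID of the padding token " "

# Two tokenized lyrics already mapped to word IDs (made-up values);
# 0 stands for the start mark and 1 for the end mark
batch = [[0, 4, 5, 1], [0, 7, 1]]

# Pad every lyric in the batch to the length of the longest one
max_len = max(len(seq) for seq in batch)
X = np.full((len(batch), max_len), PAD_ID, dtype=np.int32)
for row, seq in enumerate(batch):
    X[row, :len(seq)] = seq

# Y[t] = X[t + 1]: the target at each position is the next word.
# The last column has no successor and keeps its original value.
Y = X.copy()
Y[:, :-1] = X[:, 1:]

print(X)
print(Y)
```

Each row of Y is therefore the "next word" label sequence that the LSTM is trained to predict from the same row of X.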

In [1]:
import nltk
from collections import Counter
import numpy as np
import string
import pandas as pd

START_MARK = "["
END_MARK = "]"

def seperate(docs_ls, is_rnn):
    if is_rnn:
        docs_raw = [tokenize(START_MARK+str(doc)+END_MARK) for doc in docs_ls]
    else:
        docs_raw = [tokenize(str(doc)) for doc in docs_ls]
    docs = remove_stopwords(docs_raw)
    print(" ".join(docs[0]))
    return docs

def remove_stopwords(docs):
    stopwords = get_rare_words(docs)
    stopwords = set(stopwords)
    res = [[word for word in doc if word not in stopwords ] for doc in docs]
    return res


def tokenize(text, lemmatizer=nltk.stem.wordnet.WordNetLemmatizer()):
    """ Normalizes case and handles punctuation
    Inputs:
        text: str: raw text
        lemmatizer: an instance of a class implementing the lemmatize() method
                    (the default argument is of type nltk.stem.wordnet.WordNetLemmatizer)
    Outputs:
        list(str): tokenized text
    """
    text = text.strip()
    text = text.lower()
    text = text.replace("'", "")
    text = text.replace("\n", ".\n")
    text = text.replace("\t", " ")
    
    punc = string.punctuation
    for c in punc:
        if c in text:
            text = text.replace(c, ' '+c+' ')
    
    tokens = nltk.word_tokenize(text)
    res = []
    
    for token in tokens:
        try:
            word = lemmatizer.lemmatize(token)
            res.append(str(word))
        except Exception:
            # skip tokens the lemmatizer cannot handle
            continue
    return res

def get_rare_words(tokens_ls):
    """ use the word counts across all lyrics to find the rare words
    Inputs:
        tokens_ls: list(list(str)): tokenized lyrics, one list of tokens per lyric
    Outputs:
        list(str): list of words that appear at most 3 times, sorted alphabetically.
    """
    counter = Counter([])
    for tokens in tokens_ls:
        counter.update(tokens)
    
    rare_tokes = [k for k,v in counter.items() if v<=3]
    rare_tokes.sort()
    return rare_tokes

def process(lyrics, batchSize=10, is_rnn=True):
    """
    Convert the lyrics to vectors of word IDs and build the
    features and labels for the LSTM

    lyrics: list of str. all of the lyrics
    return: (X, Y, vocab_size, wordToID, words)
    """

    lyricDocs = seperate(lyrics, is_rnn)
    print("Totally %d lyrics."%len(lyricDocs))

    allWords = {}
    for lyricDoc in lyricDocs:
        for word in lyricDoc:
            if word not in allWords:
                allWords[word] = 1
            else:
                allWords[word] += 1

    wordPairs = sorted(allWords.items(), key = lambda x: -x[1])
    words, a= zip(*wordPairs)
    words += (" ", )
    wordToID = dict(zip(words, range(len(words)))) #word to ID
    wordTOIDFun = lambda A: wordToID.get(A, len(words))

    lyricVector = [([wordTOIDFun(word) for word in lyricDoc]) for lyricDoc in lyricDocs] 

    batchNum = (len(lyrics) - 1) // batchSize 

    X = []
    Y = []

    for i in range(batchNum):
        batchVec = lyricVector[i*batchSize: (i+1)*batchSize]

        maxLen = max([len(vector) for vector in batchVec])

        temp = np.full((batchSize, maxLen), wordTOIDFun(" "),np.int32)

        for j in range(batchSize):
            temp[j, :len(batchVec[j])] = batchVec[j]

        X.append(temp)

        temp_copy = np.copy(temp)
        temp_copy[:, :-1] = temp[:, 1:]
        Y.append(temp_copy)
    return X, Y, len(words) + 1, wordToID, words


def generate_feature(filename, ouput_path):
    """
    This method is mainly for printing out the result for inspection
    """
    df = pd.read_csv(filename)

    docs = df['lyric'].values.tolist()[:100]
    print(docs[0])
    print()
    print()
    X, Y, size, wordToId, words = process(docs)
    print(size)
    print(X[0][0].shape)

def pretreatment(filename, batchSize):
    df = pd.read_csv(filename)
    docs = df['lyric'].values
    P = np.random.permutation(len(docs))
    print("Shuffling")
    docs = docs[P].tolist()
    print("Processing")
    return process(docs, batchSize)

Since the preprocessing is a little time-consuming, we save the result into a pickle file, which makes it easier to test the LSTM model. We use the following code to save the preprocessed LSTM data.

In [2]:
import pickle
import lyric_processing

def save(input_path, param_saving_path, batch_size):
    X, Y, wordNum, wordToID, words = lyric_processing.pretreatment(input_path,batch_size)
    data = {'X': X, "Y":Y, "wordNum":wordNum, 
        "wordToID": wordToID, "words":words, 'batch_size':batch_size}

    with open(param_saving_path, 'wb') as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

    print("model saved!")

You can run the following in the command line:
```bash
python3 save_data.py -h
```
to choose the input raw data and customize your output file name. 
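Reading the saved parameters back is the mirror image of `save`. The sketch below round-trips a made-up miniature of the dictionary (same keys as in `save`, toy values) through a temporary file; the file name is hypothetical.

```python
import os
import pickle
import tempfile

# Made-up miniature of the dictionary written by save()
data = {
    "X": [[0, 2, 1]],
    "Y": [[2, 1, 1]],
    "wordNum": 4,
    "wordToID": {"[": 0, "]": 1, "love": 2, " ": 3},
    "words": ("[", "]", "love", " "),
    "batch_size": 1,
}

# Hypothetical file name inside a temporary directory
path = os.path.join(tempfile.mkdtemp(), "param-demo.dat")
with open(path, "wb") as f:
    pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

# Load the parameters back, as a training or testing script would
with open(path, "rb") as f:
    loaded = pickle.load(f)

# Invert the word-to-ID dict to turn IDs back into words
id_to_word = {i: w for w, i in loaded["wordToID"].items()}
decoded = [id_to_word[i] for i in loaded["X"][0]]
print(decoded)
```

The inverted dictionary is what lets the generated IDs be turned back into words in the test stage.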

### Part 2. Preprocessing for CNN model. 


Unlike the LSTM model, the CNN model does not need the start and end marks. <br>
We use the same method as in Part 1 to tokenize the lyrics and generate a vocabulary. We then pad each lyric to the longest (or our chosen) length, using the mark `<PAD>`, and save the parameters into a pickle file.
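To show what padding to a fixed length does, here is a NumPy-only sketch that reproduces the default behaviour of Keras `pad_sequences` (padding at the front and, for sequences that are too long, dropping tokens from the front). The helper name `pad_front` and the toy IDs are assumptions made for this example; the project itself calls `pad_sequences`.

```python
import numpy as np

def pad_front(sequences, max_length, pad_id=0):
    """Left-pad (or left-truncate) lists of word IDs to max_length."""
    out = np.full((len(sequences), max_length), pad_id, dtype=np.int32)
    for row, seq in enumerate(sequences):
        if not seq:
            continue  # an empty lyric stays all padding
        kept = seq[-max_length:]        # keep the last max_length tokens
        out[row, -len(kept):] = kept    # right-align, padding stays in front
    return out

# Toy word IDs; 0 is the ID of <PAD> because it is the first vocabulary entry
padded = pad_front([[5, 6], [1, 2, 3, 4, 5]], max_length=4)
print(padded)
```

Because `<PAD>` is placed first in the vocabulary, its ID is 0, which matches the default padding value.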

In [None]:
from lyric_processing import tokenize, remove_stopwords, seperate
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow.contrib.keras as kr
import argparse
import pickle

filename_rock = "../csv_files/lyrics_Rock.csv"
filename_pop  = "../csv_files/lyrics_Pop.csv"
filename_rap  = "../csv_files/lyrics_Hip_Hop_Rap.csv"
doc_num = 5000
max_length = 800
categories = ['pop','rock', 'rap']
cat_to_id = dict(zip(categories, range(len(categories))))


def load_lyrics(filename):
    df = pd.read_csv(filename)
    docs = df['lyric'].values
    return docs

def get_raw_data():
    # --------------- load and select data -------------
    lyric_rock = load_lyrics(filename_rock)
    lyric_pop = load_lyrics(filename_pop)
    lyric_rap = load_lyrics(filename_rap)
    P_rap = np.random.permutation(lyric_rap.shape[0])[:doc_num]
    P_rock = np.random.permutation(lyric_rock.shape[0])[:doc_num]
    P_pop = np.random.permutation(lyric_pop.shape[0])[:doc_num]

    lyric_pop_chosen = lyric_pop[P_pop]
    lyric_rap_chosen = lyric_rap[P_rap]
    lyric_rock_chosen = lyric_rock[P_rock]
    lyrics = np.concatenate((lyric_pop_chosen, lyric_rock_chosen, lyric_rap_chosen))

    y_pop = np.array([cat_to_id['pop'] for _ in lyric_pop_chosen])
    y_rock = np.array([cat_to_id['rock'] for _ in lyric_rock_chosen])
    y_rap = np.array([cat_to_id['rap'] for _ in lyric_rap_chosen])
    y = np.concatenate((y_pop, y_rock, y_rap))

    return lyrics, y


def process(param_saving_path):
    lyrics, y = get_raw_data()
    lyricDocs = seperate(lyrics, False)
    print("Totally %d lyrics."%len(lyricDocs))
    allWords = {}
    for lyricDoc in lyricDocs:
        for word in lyricDoc:
            if word not in allWords:
                allWords[word] = 1
            else:
                allWords[word] += 1

    wordPairs = sorted(allWords.items(), key = lambda x: -x[1])
    words, a= zip(*wordPairs)
    words += (" ", )
    words = ['<PAD>'] + list(words)
    wordToID = dict(zip(words, range(len(words)))) #word to ID
    wordTOIDFun = lambda A: wordToID.get(A, len(words))

    lyricVector = [([wordTOIDFun(word) for word in lyricDoc]) for lyricDoc in lyricDocs]

    x_pad = kr.preprocessing.sequence.pad_sequences(lyricVector, max_length)
    y_pad = kr.utils.to_categorical(y, num_classes=len(cat_to_id))

    data = {'X': x_pad, 'Y': y_pad, 'wordToID': wordToID, 'seq_length': max_length}
    
    with open(param_saving_path, 'wb') as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
    
    print('Finish!')

You can run the following in the command line:
```bash
python3 classification_preprocess.py -h
```
to customize your output parameter pickle file name and path. 

## Step 3. Training Models


### Part 1. LSTM model 

Here we use an LSTM model to generate lyrics for different genres. [This blog](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) gives a clear introduction to LSTMs. <br>
For each genre, we do the preprocessing and save the needed parameters into a pickle file. Then we load that pickle file and train the LSTM model for the genre. <br>
The word IDs are first mapped to dense vectors by an embedding layer before being fed to the LSTM. We use 2 LSTM layers, followed by a softmax that gives each word's probability of appearing next. <br>
In the training stage, we run 100 epochs. Since the dataset is very large and LSTM models are slow to train, we used AWS to train 3 models for the 3 genres. Even on an AWS GPU server, training took a long time: 40+ hours for pop, 18+ hours for rap and 10+ hours for rock. <br>
The models are saved in:
+ rap model:  `./checkpoints/rap/` 
+ rock model: `./checkpoints/rock/`
+ pop model:  `./checkpoints/pop/`
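To make the embedding, stacked LSTM and softmax pipeline described above more concrete, here is a NumPy-only sketch of a single generation step. It is an illustration under stated assumptions, not the TensorFlow implementation used in this project: the sizes are toy values, the weights are random, the gate ordering is arbitrary, and all function names are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, hidden_units, layers = 6, 4, 2  # toy sizes, not the trained ones

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM time step; W maps [x, h_prev] to the four gate pre-activations."""
    z = np.concatenate([x, h_prev]) @ W + b
    i, g, f, o = np.split(z, 4)  # input gate, candidate, forget gate, output gate
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return h, c

# Random parameters: embedding table, one (W, b) pair per layer, softmax layer
embedding = rng.normal(size=(vocab_size, hidden_units))
cells = [(rng.normal(size=(2 * hidden_units, 4 * hidden_units)) * 0.1,
          np.zeros(4 * hidden_units)) for _ in range(layers)]
w_out = rng.normal(size=(hidden_units, vocab_size)) * 0.1
b_out = np.zeros(vocab_size)

def next_word_probs(word_id, states):
    """Embed one word ID, run it through every layer, return next-word probabilities."""
    x = embedding[word_id]
    new_states = []
    for (W, b), (h, c) in zip(cells, states):
        h, c = lstm_step(x, h, c, W, b)
        new_states.append((h, c))
        x = h  # the output of one layer is the input of the next
    logits = x @ w_out + b_out
    exp = np.exp(logits - logits.max())  # numerically stable softmax
    return exp / exp.sum(), new_states

states = [(np.zeros(hidden_units), np.zeros(hidden_units)) for _ in range(layers)]
probs, states = next_word_probs(0, states)
print(probs)
```

Feeding the sampled word and the returned states back into `next_word_probs` repeats the step, which is how a lyric is generated one word at a time.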

In [None]:
import tensorflow as tf
import numpy as np

# batchSize = 10
learningRateBase = 0.001
learningRateDecreaseStep = 100
epochNum = 100                    # train epoch

generateNum = 1

checkpointsPath = "./checkpoints" # checkpoints location

def buildModel(wordNum, gtX, hidden_units = 128, layers = 2):
    """build rnn"""
    with tf.variable_scope("embedding"): #embedding
        embedding = tf.get_variable("embedding", [wordNum, hidden_units], dtype = tf.float32)
        inputbatch = tf.nn.embedding_lookup(embedding, gtX)

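    # build one cell object per layer; reusing a single cell object
    # ([basicCell] * layers) can raise an error in newer TensorFlow 1.x releases
    cells = [tf.contrib.rnn.BasicLSTMCell(hidden_units) for _ in range(layers)]
    stackCell = tf.contrib.rnn.MultiRNNCell(cells)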
    initState = stackCell.zero_state(np.shape(gtX)[0], tf.float32)
    outputs, finalState = tf.nn.dynamic_rnn(stackCell, inputbatch, initial_state = initState)
    outputs = tf.reshape(outputs, [-1, hidden_units])

    with tf.variable_scope("softmax"):
        w = tf.get_variable("w", [hidden_units, wordNum])
        b = tf.get_variable("b", [wordNum])
        logits = tf.matmul(outputs, w) + b

    probs = tf.nn.softmax(logits)
    return logits, probs, stackCell, initState, finalState

def train(X, Y, wordNum, batchSize,reload=True):
    """train model"""
    gtX = tf.placeholder(tf.int32, shape=[batchSize, None])  # input
    gtY = tf.placeholder(tf.int32, shape=[batchSize, None])  # output
    logits, probs, a, b, c = buildModel(wordNum, gtX)
    targets = tf.reshape(gtY, [-1])
    #loss
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([logits], [targets],
                                                              [tf.ones_like(targets, dtype=tf.float32)], wordNum)
    cost = tf.reduce_mean(loss)
    tvars = tf.trainable_variables()
    grads, a = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5)
    learningRate = learningRateBase
    optimizer = tf.train.AdamOptimizer(learningRate)
    trainOP = optimizer.apply_gradients(zip(grads, tvars))
    globalStep = 0

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        if reload:
            checkPoint = tf.train.get_checkpoint_state(checkpointsPath)
            # if have checkPoint, restore checkPoint
            if checkPoint and checkPoint.model_checkpoint_path:
                saver.restore(sess, checkPoint.model_checkpoint_path)
                print("restored %s" % checkPoint.model_checkpoint_path)
            else:
                print("no checkpoint found!")

        for epoch in range(epochNum):
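            if globalStep % learningRateDecreaseStep == 0: #learning rate decrease by epoch
                # NOTE: the optimizer was built with the initial Python float, so this
                # reassignment does not change the rate Adam uses; feed the rate through
                # a placeholder or variable to make the decay take effect
                learningRate = learningRateBase * (0.95 ** epoch)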
            epochSteps = len(X) # equal to batch
            for step, (x, y) in enumerate(zip(X, Y)):
                globalStep = epoch * epochSteps + step
                a, loss = sess.run([trainOP, cost], feed_dict = {gtX:x, gtY:y})
                print("epoch: %d steps:%d/%d loss:%3f" % (epoch,step,epochSteps,loss))
                if globalStep%1000==0:
                    print("save model")
                    # save_path = saver.save(sess, '/output/model.ckpt')
                    save_path = saver.save(sess,checkpointsPath + "/lyric",global_step=epoch)
                    print("Model saved in file: %s" % save_path)

def probsToWord(weights, words):
    """probs to word"""
    t = np.cumsum(weights) #prefix sum
    s = np.sum(weights)
    coff = np.random.rand(1)
    # a word with a larger probability covers a wider interval of the
    # cumulative sum, so it is more likely to be sampled
    index = int(np.searchsorted(t, coff * s))
    return words[index]

def test(wordNum, wordToID, words, model_path=checkpointsPath):
    """generate lyric"""
    gtX = tf.placeholder(tf.int32, shape=[1, None])  # input
    logits, probs, stackCell, initState, finalState = buildModel(wordNum, gtX)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        checkPoint = tf.train.get_checkpoint_state(model_path)
        # if have checkPoint, restore checkPoint
        if checkPoint and checkPoint.model_checkpoint_path:
            print(checkPoint.model_checkpoint_path)
            saver.restore(sess, checkPoint.model_checkpoint_path)
            print("restored %s" % checkPoint.model_checkpoint_path)
            print("\n\n")
        else:
            print("no checkpoint found!")
            exit(0)

        lyrics = []
        for i in range(generateNum):
            state = sess.run(stackCell.zero_state(1, tf.float32))
            x = np.array([[wordToID['[']]]) # init start sign
            probs1, state = sess.run([probs, finalState], feed_dict={gtX: x, initState: state})
            word = probsToWord(probs1, words)
            lyric = ''
            while word != ']' and word != ' ':
                if word == '.':
                    try:
                        if not (lyric[-1]=='.' and lyric[-2] == '.'):
                            lyric += '. '
                    except:
                        pass
                else:
                    lyric += word + ' '
                x = np.array([[wordToID[word]]])
                #print(word)
                probs2, state = sess.run([probs, finalState], feed_dict={gtX: x, initState: state})
                word = probsToWord(probs2, words)
            print("The generated lyrics: \n")
            print(lyric.replace(". ", "\n"))
            lyrics.append(lyric)
        return lyrics

You can also run the following in the command line:
```bash
python3 main.py -h
```
to choose between training and testing the LSTM model.
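The sampling step performed by `probsToWord` during generation can be illustrated in isolation. The sketch below uses a made-up vocabulary and probability vector; the helper name `sample_index` is an assumption for this example, and it follows the same cumulative-sum idea as the original function.

```python
import numpy as np

def sample_index(weights, rng):
    """Draw an index with probability proportional to its weight."""
    cumulative = np.cumsum(weights)                # prefix sums
    threshold = rng.random() * cumulative[-1]      # uniform point in [0, total)
    return int(np.searchsorted(cumulative, threshold))

rng = np.random.default_rng(0)
words = ["love", "night", "baby", "]"]           # toy vocabulary
probs = np.array([0.5, 0.3, 0.15, 0.05])           # toy softmax output

# Draw many words and compare the empirical frequencies with probs
draws = [words[sample_index(probs, rng)] for _ in range(10000)]
freq = {w: draws.count(w) / len(draws) for w in words}
print(freq)
```

Sampling instead of always picking the most likely word is what keeps the generated lyrics from repeating the same phrase.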

### Part 2. CNN model

Here we use the method described in [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882). <br>
The architecture of the model is shown below; the figure is taken from the above article.
![CNN model](https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/Example-of-a-CNN-Filter-and-Polling-Architecture-for-Natural-Language-Processing.png)

In this project, we set the word embedding dimension to 64 (the value in `TCNNConfig`) and the sequence length to 800 (shorter lyrics are padded with `<PAD>` at the front). We use 256 convolution filters of size 5, followed by max-over-time pooling. Then comes a fully connected layer with dropout and ReLU, and finally a softmax does the classification (a 3-class classification over the 3 genres mentioned above).
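The convolution and max-over-time pooling described above can be sketched in plain NumPy to show the shapes involved. The sizes below are toy values and the weights are random; the sketch assumes no padding at the sequence borders ("valid" convolution), so a sequence of length 800 with filters of size 5 would give 796 positions.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_length, embedding_dim = 12, 3   # toy sizes
num_filters, kernel_size = 4, 5

# One embedded lyric: a (seq_length, embedding_dim) matrix of word vectors
x = rng.normal(size=(seq_length, embedding_dim))
filters = rng.normal(size=(kernel_size, embedding_dim, num_filters))
bias = np.zeros(num_filters)

# Slide every filter over all windows of kernel_size consecutive words
positions = seq_length - kernel_size + 1
conv = np.empty((positions, num_filters))
for t in range(positions):
    window = x[t:t + kernel_size]   # (kernel_size, embedding_dim)
    conv[t] = np.tensordot(window, filters, axes=([0, 1], [0, 1])) + bias

# Max-over-time pooling: keep the strongest response of each filter
pooled = conv.max(axis=0)
print(conv.shape, pooled.shape)
```

The pooled vector has one entry per filter regardless of the lyric length, which is what lets a fixed-size fully connected layer follow it.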

In [None]:
import os
import sys
import time
from datetime import timedelta
import tensorflow.contrib.keras as kr
import numpy as np
import tensorflow as tf
from sklearn import metrics
import pickle
import argparse
from classification_preprocess import cat_to_id
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


save_dir = './checkpoints/textcnn'
save_path = os.path.join(save_dir, 'best_validation')
param_saving_path = '../data/param-classify.dat'
tensorboard_dir = './tensorboard/textcnn'
validation_rate = 0.1

class TCNNConfig(object):
    """CNN parameters"""
    embedding_dim = 64  # word vector dimension
    seq_length = 800  # sequence length
    num_classes = 3  # number of classes
    num_filters = 256  # number of convolution filters
    kernel_size = 5  # filter size
    vocab_size = 5000  # vocabulary size (overwritten from the pickle file in train())

    hidden_dim = 128  # number of units in the fully connected layer

    dropout_keep_prob = 0.5  # dropout keeping rate
    learning_rate = 1e-3  # learning rate

    batch_size = 64  # batch size
    num_epochs = 10  # total epoch number

    print_per_batch = 10  # output iterations
    save_per_batch = 10  # save tensorboard iterations


class TextCNN(object):
    """text classification, CNN model"""

    def __init__(self, config):
        self.config = config

        self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
        self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        self.cnn()

    def cnn(self):
        """CNN model"""
        # word embedding
        with tf.device('/cpu:0'):
            embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
            embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

        with tf.name_scope("cnn"):
            # CNN layer
            conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, self.config.kernel_size, name='conv')
            # global max pooling layer
            gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')

        with tf.name_scope("score"):
            # fully connected layer, with dropout and ReLU
            fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)

            # classifier
            self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
            self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # predictor

        with tf.name_scope("optimize"):
            # loss function, cross entropy
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.logits, labels=self.input_y)
            self.loss = tf.reduce_mean(cross_entropy)
            # optimizer
            self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

        with tf.name_scope("accuracy"):
            # accuracy
            correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
            self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

def get_time_dif(start_time):
    """return the time elapsed since start_time"""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))

def batch_iter(x, y, batch_size=64):
    """shuffle the data and yield it in batches of batch_size"""
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1

    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]

    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]

def feed_data(model, x_batch, y_batch, keep_prob):
    feed_dict = {
        model.input_x: x_batch,
        model.input_y: y_batch,
        model.keep_prob: keep_prob
    }
    return feed_dict


def evaluate(model, sess, x_, y_):
    """evaluate the loss and accuracy"""
    data_len = len(x_)
    batch_eval = batch_iter(x_, y_, 128)
    total_loss = 0.0
    total_acc = 0.0
    for x_batch, y_batch in batch_eval:
        batch_len = len(x_batch)
        feed_dict = feed_data(model, x_batch, y_batch, 1.0)
        loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
        total_loss += loss * batch_len
        total_acc += acc * batch_len

    return total_loss / data_len, total_acc / data_len

def train(filename):
    config = TCNNConfig()
    with open(filename, 'rb') as f:
        data = pickle.load(f)

    x = data['X']
    y = data['Y']
    print(len(x))
    P = np.random.permutation(len(x))
    x = x[P]
    y = y[P]

    wordToID = data['wordToID']
    seq_length = data['seq_length']
    config.vocab_size = len(wordToID)
    config.seq_length = seq_length

    model = TextCNN(config)

    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    tf.summary.scalar("loss", model.loss)
    tf.summary.scalar("accuracy", model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)
    
    
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        
    idx = int(x.shape[0] * validation_rate)
    x_train = x[idx:]
    x_val = x[:idx]
    y_train = y[idx:]
    y_val = y[:idx]
    
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    writer.add_graph(session.graph)
    
    print('Training and evaluating...')
    start_time = time.time()
    total_batch = 0  # number of batches processed so far
    best_acc_val = 0.0  # best validation accuracy
    last_improved = 0  # batch index of the last improvement
    require_improvement = 1000  # stop early if there is no improvement after 1000 batches
    
    flag = False
    for epoch in range(config.num_epochs):
        print('Epoch:', epoch + 1)
        batch_train = batch_iter(x_train, y_train, config.batch_size)
        for x_batch, y_batch in batch_train:
            feed_dict = feed_data(model, x_batch, y_batch, config.dropout_keep_prob)
            
            if total_batch % config.save_per_batch == 0:
                # save to tensorboard scalar
                s = session.run(merged_summary, feed_dict=feed_dict)
                writer.add_summary(s, total_batch)


            if total_batch % config.print_per_batch == 0:
                # get the loss and accuracy on training set and validation set
                feed_dict[model.keep_prob] = 1.0
                loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
                loss_val, acc_val = evaluate(model, session, x_val, y_val)  # todo

                if acc_val > best_acc_val:
                    # save the best result
                    best_acc_val = acc_val
                    last_improved = total_batch
                    saver.save(sess=session, save_path=save_path)
                    print("Save model!")
                    improved_str = '*'
                else:
                    improved_str = ''

                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>4.4}, Train Acc: {2:>5.2%},' \
                      + ' Val Loss: {3:>4.4}, Val Acc: {4:>5.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

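            # the logging branch above sets keep_prob to 1.0; restore dropout before the update
            feed_dict[model.keep_prob] = config.dropout_keep_prob
            session.run(model.optim, feed_dict=feed_dict)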
            total_batch += 1

            if total_batch - last_improved > require_improvement:
                # early end
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break 
        if flag:
            break


def test(text, filename, genre, model_path=save_dir):

    config = TCNNConfig()
    with open(filename, 'rb') as f:
        data = pickle.load(f)

    wordToID = data['wordToID']
    seq_length = data['seq_length']
    config.vocab_size = len(wordToID)
    config.seq_length = seq_length

    model = TextCNN(config)

    text_ids = [[wordToID[word] for word in text.split(" ") if word in wordToID]]
    # print(text_ids)
    y = np.array([cat_to_id[genre]])

    x_pad = kr.preprocessing.sequence.pad_sequences(text_ids, seq_length)
    y_pad = kr.utils.to_categorical(y, num_classes=len(cat_to_id)) 

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        checkPoint = tf.train.get_checkpoint_state(model_path)
        # if have checkPoint, restore checkPoint
        if checkPoint and checkPoint.model_checkpoint_path:
            print(checkPoint.model_checkpoint_path)
            saver.restore(session, checkPoint.model_checkpoint_path)
            print("restored %s" % checkPoint.model_checkpoint_path)
            print("\n\n")
        else:
            print("no checkpoint found!")
            exit(0)

        

        print('Testing...')

        feed_dict = feed_data(model, x_pad, y_pad, 1.0)
        y_pred = session.run(model.y_pred_cls, feed_dict=feed_dict)
        return list(cat_to_id)[y_pred[0]]

The training process is shown in `classification.ipynb`, which prints the training details. <br>
You can also run the following from the command line
```bash
python3 classification_model.py -h
```
to choose training or testing mode. 
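
The early-stopping rule inside `train` can be looked at in isolation. The sketch below is a simplified, framework-free version of that rule: the function name `early_stopping` and its arguments are hypothetical, and the validation accuracies are passed in as a plain dictionary instead of being computed by a TensorFlow session.

```python
def early_stopping(val_acc_by_step, total_steps, patience):
    """Simulate the stopping rule of the training loop.

    val_acc_by_step: dict mapping a step index to the validation accuracy
                     measured at that step (hypothetical input for this sketch).
    Returns (steps_run, best_acc, step_of_best).
    """
    best_acc = 0.0
    last_improved = 0
    for step in range(total_steps):
        acc = val_acc_by_step.get(step)
        if acc is not None and acc > best_acc:
            # in the real loop, this is where the checkpoint is saved
            best_acc = acc
            last_improved = step
        # one optimizer update would happen here
        steps_done = step + 1
        if steps_done - last_improved > patience:
            # no improvement for more than `patience` steps: stop early
            return steps_done, best_acc, last_improved
    return total_steps, best_acc, last_improved


accs = {0: 0.50, 10: 0.60, 20: 0.55, 30: 0.58}
print(early_stopping(accs, total_steps=100, patience=15))
```

Because only an improved validation accuracy triggers a save, the checkpoint on disk corresponds to the best validation step, not to the last step that was run.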

The training loss figure is:
![Screen Shot 2018-05-07 at 9.25.54 PM](https://oh1ulkf4j.qnssl.com/Screen Shot 2018-05-07 at 9.25.54 PM.png)

And the training accuracy figure is:
![Screen Shot 2018-05-07 at 9.25.40 PM](https://oh1ulkf4j.qnssl.com/Screen Shot 2018-05-07 at 9.25.40 PM.png)

Both figures were recorded with TensorBoard.

The saved model's accuracy on the test dataset is 77.40%.

The model is saved in `./checkpoints/textcnn/`.

## Step 4. Display Result (It's show time!)

We used an AWS server to train three LSTM lyric generator models, one per genre, and trained the CNN classification model locally. With these saved models, we can now use an LSTM model to generate a lyric in a chosen genre and then use the CNN classification model to check the result.

To make the generated lyrics random, we do not always select the word with the highest probability. Instead, we map the probabilities to intervals and randomly sample a word from them. See the `probsToWord` method in Step 3, Part 1. (Of course, each lyric starts with the start mark `[`.)
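
As a rough illustration of this kind of sampling, here is a minimal sketch using NumPy. It is not the project's `probsToWord` code: the function name `sample_word` and its arguments are assumptions made for this example. Each word gets an interval whose width equals its probability, and a uniform random point picks the interval.

```python
import numpy as np

def sample_word(probs, words, rng=None):
    """Sample one word in proportion to its probability (hypothetical helper)."""
    if rng is None:
        rng = np.random.default_rng()
    probs = np.asarray(probs, dtype=np.float64)
    cdf = np.cumsum(probs)            # right edge of each word's interval
    r = rng.random() * cdf[-1]        # uniform point in [0, total probability mass)
    idx = int(np.searchsorted(cdf, r, side="right"))
    idx = min(idx, len(words) - 1)    # guard against floating-point edge cases
    return words[idx]


vocab = ["love", "night", "baby"]
print(sample_word([0.7, 0.2, 0.1], vocab))
```

Compared with always taking the most probable word, this keeps likely words frequent while still letting less likely words appear, so repeated runs produce different lyrics.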

We recommend running the following from the command line:
```bash
python3 generator.py -g [pop/rock/rap]
```
to generate a lyric in the chosen genre and verify it with our classification model.
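
Before the classifier sees a generated lyric, the `test` function in `classification_model.py` maps each word to its ID, drops words that are not in the vocabulary, and pads the result to `seq_length` with `pad_sequences`. The sketch below reproduces that preprocessing in plain NumPy, assuming the Keras defaults (zeros added at the front, over-long inputs truncated from the front). The helper name `encode_and_pad` and the toy vocabulary are hypothetical.

```python
import numpy as np

def encode_and_pad(text, word_to_id, seq_length):
    """Turn a lyric string into a fixed-length vector of word IDs."""
    # keep only in-vocabulary words, as in the list comprehension in `test`
    ids = [word_to_id[w] for w in text.split(" ") if w in word_to_id]
    # pre-truncation: keep the last `seq_length` tokens
    ids = ids[-seq_length:]
    # pre-padding: zeros at the front, tokens at the end
    padded = np.zeros(seq_length, dtype=np.int32)
    if ids:
        padded[-len(ids):] = ids
    return padded


toy_vocab = {"i": 1, "love": 2, "you": 3}   # hypothetical vocabulary
print(encode_and_pad("i love you baby", toy_vocab, 5))
```

One consequence of this scheme is that an unknown word such as "baby" silently disappears from the input, so a lyric made mostly of rare words reaches the classifier as mostly padding.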

In [1]:
import pickle
from model import test as generate_model
from classification_model import test as classify_model
import tensorflow as tf

pop_model = "./checkpoints/pop"
pop_save = "./generate-param/param-pop-10-test.dat"

rock_model = "./checkpoints/rock"
rock_save = "./generate-param/param-rock-10-test.dat"

rap_model = "./checkpoints/rap"
rap_save = "./generate-param/param-rap-10-test.dat"

classify_model_path = "./checkpoints/textcnn"
classify_save = "./generate-param/param-classify-test.dat"


def run(genre):
    if genre == 'pop':
        model_path = pop_model
        data_path = pop_save
    elif genre == 'rock':
        model_path = rock_model
        data_path = rock_save
    elif genre == 'rap':
        model_path = rap_model
        data_path = rap_save
    else:
        # unknown genre: stop here, since model_path and data_path are undefined
        print("Unexpected input!")
        return

    with open(data_path, 'rb') as f:
        data = pickle.load(f)

    print('generating...')

    lyrics = generate_model(data['wordNum'], 
        data['wordToID'], 
        data['words'], 
        model_path=model_path)

    print('\n\n')
    tf.reset_default_graph()

    predicted = classify_model(lyrics[0], classify_save, genre, model_path=classify_model_path)
    print("\n\nOur classification model predict it to be: ")
    print(predicted)



### Generate a Pop lyric and verify
We now use the pop LSTM model to generate a pop lyric and verify the result with our classification model.

In [5]:
tf.reset_default_graph()
run("pop")

generating...
./checkpoints/pop/lyric-99
INFO:tensorflow:Restoring parameters from ./checkpoints/pop/lyric-99
restored ./checkpoints/pop/lyric-99



The generated lyrics: 

sometimes i wake up in you 
even even though i feel mean it 
hell give it up just for you 

do you wan na play a thing that you ever wan na do 
i really wan na keep comin over the wire 
you wild at me and this is beat everyday 
put your down game take your breath 
push walk away from my lying 
walk em take it , yeah 
and watch ( hey boy ) 
cause you know it dont matter anyway 
we the touch our worst people 
but it the time ( all right ) 
do they know ? ( tell me ) 
( yeah yeah ) 

you know im whole word 
all you want your world to go 
all my love and my world will see me 
baby , i wan na get i 
see there something about my lovin 
and when it doesnt sound 
yes i really wan na 
just like i never knew 
that ill be real you 
baby baby cmon 
i need a good little girl 

get your kick back what you want to do 
let get over

### Generate a Rock lyric and verify
We now use the rock LSTM model to generate a rock lyric and verify the result with our classification model.

In [8]:
tf.reset_default_graph()
run("rock")

generating...
./checkpoints/rock/lyric-98
INFO:tensorflow:Restoring parameters from ./checkpoints/rock/lyric-98
restored ./checkpoints/rock/lyric-98



The generated lyrics: 

shed the lover and the sun 
when i see you walking down 
from the window light 
you dont remember my mind 
my life ha have the holy reason 

all the different thing im soul 
bite em waving for myself 

i have lived the hour 
but it we saved wound 
taking thousand year in 
their house are getting chair 
at the bottom of the 
and in their street 
year got a wisdom 
and the hangman move his head 
and im his wicked tongue 
step aside , you did your model and 
he wa saying 
i like them we cant meet 

seemed my time to you , every day 
but once i wa hoping there wa here to end 
and you can trust to get away 

same before his run across blind 
and i wa my strength is with these 
existence will always be opened 
what is much better 
cause just so good when it seems 

im tired of working 
when your window fade 
im livin i

### Generate a Rap lyric and verify
We now use the rap LSTM model to generate a rap lyric and verify the result with our classification model.

In [7]:
tf.reset_default_graph()
run("rap")

generating...
./checkpoints/rap/lyric-99
INFO:tensorflow:Restoring parameters from ./checkpoints/rap/lyric-99
restored ./checkpoints/rap/lyric-99



The generated lyrics: 

lil kim verse two 
uh , kill that shit , motherfucker 

and it not a dream 
hip - hop rule ? done the crew is 
youre my ego 
hoe grab to the 
la 
40 
my hand in the titanic 
like crack in the spot thats whoever to make em tour 
the weapon been twisted ricky lucky hitler 
and all the time to try , we could tell him how he sits : like a cock 

( wyclef ) city 
can i get the clientele by the 
but now we better straw and continue deep , liftin the shore high 
if i do and like the jam ive usually alive 
what my life is than so military , who the fuck at the contract 
aint me drinking faster than you aroma of fluid 
i can feel my music a favor , dont fuck with what we mad 
and the mission of the improvement and these nigga is track or be playing g to bungalow backwards , nigga 

ru - eee - oh - dammit , ay , nigga , heh 
