# MVP Model Inference

The purpose of this notebook is to build utility functions for generating predictions from a pre-trained setlist model.

In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import numpy as np
from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences

module_path = os.path.abspath(os.path.join('../'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
import src

Using TensorFlow backend.


## Load in Setlist Data

In [2]:
# load X data: the test-set input sequences stored under the seqlen-150 processed-data directory
# NOTE(review): load_pickle_object presumably unpickles this file -- only point it at trusted artifacts

X_test = src.util.load_pickle_object('../data/processed/mvp-setlist-modeling/seqlen-150/X_test.pkl')

In [3]:
X_test.shape

(3529, 150)

##### The testing data consists of ~3.5k sequences of 150 encoded songs

In [4]:
X_test[-1]

array([576,   9, 181,  57, 481,  15, 118, 325, 727, 451, 320, 817, 699,
       272,   7, 148, 274,   8, 729,  46, 118, 855,  96, 143, 239, 745,
       533, 194, 576,   9, 545, 733,  15, 707, 186, 624, 305, 861, 384,
       148,  67,   7, 827,   8, 148, 733, 674, 729,  46, 545, 384, 451,
       320, 817, 746, 727, 533,   9, 120, 571, 481,  96, 178, 686, 745,
       855, 647, 291, 624,  15, 167,   7, 691,   8, 272, 213, 178,  96,
       481, 551, 181, 194, 727, 576,   9, 380,  92, 733,  15, 241, 861,
         7, 274,   8, 533, 855, 226, 807, 786, 792, 733, 674, 686, 577,
       861, 274,   9, 256, 451, 320, 817, 369, 727, 647, 148, 384, 554,
       124, 305, 347,   7, 120, 675, 827,   8,   7, 167,   8, 272, 861,
       807, 686, 533, 241, 551, 181, 120,   9, 451, 320, 817, 213, 792,
       745, 727, 325, 274,   7, 148, 305])

In [5]:
print(X_test[-1])

[576   9 181  57 481  15 118 325 727 451 320 817 699 272   7 148 274   8
 729  46 118 855  96 143 239 745 533 194 576   9 545 733  15 707 186 624
 305 861 384 148  67   7 827   8 148 733 674 729  46 545 384 451 320 817
 746 727 533   9 120 571 481  96 178 686 745 855 647 291 624  15 167   7
 691   8 272 213 178  96 481 551 181 194 727 576   9 380  92 733  15 241
 861   7 274   8 533 855 226 807 786 792 733 674 686 577 861 274   9 256
 451 320 817 369 727 647 148 384 554 124 305 347   7 120 675 827   8   7
 167   8 272 861 807 686 533 241 551 181 120   9 451 320 817 213 792 745
 727 325 274   7 148 305]


## Load Model and Encodings

In [6]:
# load the pre-trained setlist model from its saved .hdf5 file
# (the filename appears to encode the training setup: seqlen 150, 100 LSTM units, 0.5 dropouts -- TODO confirm)
model = load_model('../models/mvp-setlist-modeling/model.nn_arch_2-150-seqlen-100-lstmunits-0.5-b_dropout-0.5-a_dropout.hdf5')

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Instructions for updating:
Use tf.cast instead.


In [7]:
# load encoding mappings
# idx_to_song: encoded integer index -> song title (used below to decode sequences and predictions)
# song_to_idx: presumably the inverse mapping (title -> index); not used by any cell in this notebook -- TODO confirm
idx_to_song = src.util.load_pickle_object('../data/processed/mvp-setlist-modeling/seqlen-150/idx_to_song.pkl')
song_to_idx = src.util.load_pickle_object('../data/processed/mvp-setlist-modeling/seqlen-150/song_to_idx.pkl')

### Test a prediction

In [8]:
# seed sequence: the 69th-from-last row of X_test (not the final row, X_test[-1])
# TODO confirm that this row is the show intended as the seed
test_seq = X_test[-69]

In [9]:
test_seq

array([245, 126,   7, 826,   8, 108, 531, 373, 538, 121, 720, 618, 647,
       482, 280, 624, 675,   9, 424, 167, 730, 364, 382, 364, 609, 533,
       398,  32,   7, 256,   8, 344, 181, 843, 855, 643, 247, 282, 576,
       291,   9, 321, 742, 541, 163, 741, 341, 710, 302, 360, 343,   4,
       596, 202,  62, 184, 743, 414, 546,  10, 861, 369,  13, 674,   7,
       474,   8, 787, 181, 535, 538, 619, 545, 691, 582,   9, 757, 373,
       727,  57, 729,  46, 729, 397, 299, 746,   7, 419,   8,  96, 577,
       690, 734, 364, 185, 432, 282, 124,   9, 247, 594, 861, 159, 861,
       663, 643,  15, 675,   7, 292,   8, 121, 451,  13, 531, 817, 720,
       618, 855, 657,  32, 226, 126,   9,  29, 167, 674, 792, 238, 627,
       245, 672, 307, 576,   7,  17, 274,   8, 473, 473, 407,  96, 279,
       545, 325, 690, 346, 647, 299,   9])

In [10]:
[idx_to_song[idx] for idx in test_seq]

['Frankenstein',
 'Chalk Dust Torture',
 '<ENCORE>',
 'While My Guitar Gently Weeps',
 '<SET1>',
 'Buried Alive',
 'Poor Heart',
 'Julius',
 'Punch You in the Eye',
 'Cars Trucks Buses',
 'The Horse',
 'Silent in the Morning',
 'Split Open and Melt',
 'NICU',
 'Gumbo',
 'Slave to the Traffic Light',
 'Sweet Adeline',
 '<SET2>',
 'Makisupa Policeman',
 'David Bowie',
 'The Mango Song',
 "It's Ice",
 'Kung',
 "It's Ice",
 'Shaggy Dog',
 'Possum',
 'Lifeboy',
 'Amazing Grace',
 '<ENCORE>',
 'Funky Bitch',
 '<SET1>',
 'Icculus',
 'Divided Sky',
 'Wilson',
 'Ya Mar',
 'Sparkle',
 'Free',
 'Guyute',
 'Run Like an Antelope',
 'Harpua',
 '<SET2>',
 'I Am the Sea',
 'The Real Me',
 'Quadrophenia',
 'Cut My Hair',
 'The Punk Meets the Godfather',
 "I'm One",
 'The Dirty Jobs',
 'Helpless Dancer',
 'Is It In My Head?',
 "I've Had Enough",
 '5:15',
 'Sea and Sand',
 'Drowned',
 'Bell Boy',
 'Doctor Jimmy',
 'The Rock',
 'Love',
 "Reign O'er Me",
 '<SET3>',
 'You Enjoy Myself',
 'Jesus Just Left Ch

In [11]:
print(np.array([test_seq]))

[[245 126   7 826   8 108 531 373 538 121 720 618 647 482 280 624 675   9
  424 167 730 364 382 364 609 533 398  32   7 256   8 344 181 843 855 643
  247 282 576 291   9 321 742 541 163 741 341 710 302 360 343   4 596 202
   62 184 743 414 546  10 861 369  13 674   7 474   8 787 181 535 538 619
  545 691 582   9 757 373 727  57 729  46 729 397 299 746   7 419   8  96
  577 690 734 364 185 432 282 124   9 247 594 861 159 861 663 643  15 675
    7 292   8 121 451  13 531 817 720 618 855 657  32 226 126   9  29 167
  674 792 238 627 245 672 307 576   7  17 274   8 473 473 407  96 279 545
  325 690 346 647 299   9]]


In [12]:
np.array([test_seq]).shape

(1, 150)

In [13]:
# make prediction by feeding in a sequence of 150 encoded songs;
# wrapping test_seq in a list gives the model a batch of shape (1, 150)
next_song = model.predict_classes(np.array([test_seq]))

# lookup song name; .item() unwraps the single predicted class into a plain scalar for the lookup
idx_to_song[next_song.item()]

'Also Sprach Zarathustra'

### Generate a setlist

In [14]:
def generate_full_setlist(model, seed_setlist, n_songs, seq_len=150):
    """
    Generate the next `n_songs` songs by predicting one song at a time.

    Each predicted song is appended to the running history so that it becomes
    part of the input window for the following prediction.

    Args:
        model - trained model exposing `predict_classes` (as returned by `load_model`)
        seed_setlist (ndarray) - 1-D array of encoded songs used as the starting history
        n_songs (int) - number of songs to generate
        seq_len (int) - length of the input window fed to the model; defaults to 150,
            the window length this function previously hard-coded

    Returns:
        setlist (list) - generated song names, decoded with the global `idx_to_song`
    """

    setlist = []

    for _ in range(n_songs):
        # keep only the most recent `seq_len` songs (truncating='pre' drops from the front)
        seq = pad_sequences([seed_setlist], maxlen=seq_len, truncating='pre')[0]
        # predict next song; the model takes a batch, so wrap the window in a list
        next_song = model.predict_classes(np.array([seq])).item()
        # un-encode the song
        next_song_clean = idx_to_song[next_song]
        # append to generated list
        setlist.append(next_song_clean)
        # feed the encoded prediction back into the history for the next step
        seed_setlist = np.append(seed_setlist, next_song)

    return setlist

In [15]:
generate_full_setlist(model, test_seq, 25)

['Also Sprach Zarathustra',
 'David Bowie',
 'The Horse',
 'Silent in the Morning',
 'Reba',
 'You Enjoy Myself',
 'Hold Your Head Up',
 "Cracklin' Rosie",
 'Hold Your Head Up',
 'Harry Hood',
 '<ENCORE>',
 'Bold As Love',
 '<SET1>',
 'Runaway Jim',
 'Foam',
 'Bouncing Around the Room',
 'Split Open and Melt',
 'If I Could',
 'Scent of a Mule',
 'Stash',
 'The Squirming Coil',
 '<SET2>',
 'Also Sprach Zarathustra',
 'David Bowie',
 'The Horse']

In [16]:
test_seq

array([245, 126,   7, 826,   8, 108, 531, 373, 538, 121, 720, 618, 647,
       482, 280, 624, 675,   9, 424, 167, 730, 364, 382, 364, 609, 533,
       398,  32,   7, 256,   8, 344, 181, 843, 855, 643, 247, 282, 576,
       291,   9, 321, 742, 541, 163, 741, 341, 710, 302, 360, 343,   4,
       596, 202,  62, 184, 743, 414, 546,  10, 861, 369,  13, 674,   7,
       474,   8, 787, 181, 535, 538, 619, 545, 691, 582,   9, 757, 373,
       727,  57, 729,  46, 729, 397, 299, 746,   7, 419,   8,  96, 577,
       690, 734, 364, 185, 432, 282, 124,   9, 247, 594, 861, 159, 861,
       663, 643,  15, 675,   7, 292,   8, 121, 451,  13, 531, 817, 720,
       618, 855, 657,  32, 226, 126,   9,  29, 167, 674, 792, 238, 627,
       245, 672, 307, 576,   7,  17, 274,   8, 473, 473, 407,  96, 279,
       545, 325, 690, 346, 647, 299,   9])

In [17]:
def generate_full_setlist2(model, seed_setlist, set_start_idx=8, seq_len=150, max_songs=100):
    '''
    Generate the remainder of a setlist given the previous songs.

    Songs are predicted one at a time and fed back into the input window.
    Generation stops when the model predicts `set_start_idx` on any prediction
    after the first, or once `max_songs` songs have been generated.

    Args:
        model (.hdf5) - a Phish prediction tensorflow model exposing `predict_classes`
        seed_setlist (ndarray) - encoded array of shape (150,)
        set_start_idx (int) - encoded index treated as the start of a new show;
            defaults to 8, which decodes to '<SET1>' in this notebook
        seq_len (int) - length of the input window fed to the model (default 150)
        max_songs (int or None) - upper bound on generated songs so the loop
            always terminates even if the model never predicts `set_start_idx`;
            pass None to disable the cap

    Returns:
        setlist (list) - generated sequence of encoded songs to complete the show

    '''

    setlist = []

    # generate remainder of setlist
    while max_songs is None or len(setlist) < max_songs:
        # keep only the most recent `seq_len` songs (truncating='pre' drops from the front)
        seq = pad_sequences([seed_setlist], maxlen=seq_len, truncating='pre')[0]
        # predict next song
        next_song = model.predict_classes(np.array([seq])).item()
        # a predicted show start ends generation, unless it is the very first
        # prediction (nothing generated yet), in which case it is kept
        if next_song == set_start_idx and setlist:
            break
        # append to generated list
        setlist.append(next_song)
        # update seed_setlist to re-run for the next song
        seed_setlist = np.append(seed_setlist, next_song)

    return setlist

In [18]:
[idx_to_song[idx] for idx in generate_full_setlist2(model, test_seq)]

['Also Sprach Zarathustra',
 'David Bowie',
 'The Horse',
 'Silent in the Morning',
 'Reba',
 'You Enjoy Myself',
 'Hold Your Head Up',
 "Cracklin' Rosie",
 'Hold Your Head Up',
 'Harry Hood',
 '<ENCORE>',
 'Bold As Love']

In [23]:
test_seq.shape

(150,)

In [None]:
# predict_classes needs an input batch; feed the (1, 150) batch built from test_seq
yhat = model.predict_classes(np.array([test_seq]))

In [18]:
# NOTE(review): reverse_mapping is not defined in any cell above, so this cell fails on a fresh kernel;
# idx_to_song (loaded earlier) appears to play the same index -> title role -- TODO confirm
reverse_mapping[129]

'Character Zero'

In [54]:
np.fromiter(reverse_mapping.keys(), dtype=int)

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
       143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
       156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
       169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 18

In [43]:
# the final sequence
# NOTE(review): sequences_array is not defined in any cell above, so this cell fails on a fresh kernel
sequences_array[-1]

array([256, 843, 569, 124,   8, 731, 666,   3, 585, 386, 659, 126, 822,
       529, 810,  44, 576,   9, 196, 222, 604, 790, 292, 517, 292,  10,
       443,  45, 591, 619, 589, 401, 561, 674,   7, 727, 129,   8, 649,
       790, 247, 833, 218, 814, 556, 256, 585,   9, 637, 642, 319, 171,
        29,  57, 746,   7, 810,  92,   8, 861, 785,   1, 492, 210, 779,
       607, 589,   9, 602, 443, 624, 533, 586, 807,   7, 464,   8, 708,
       538,  83, 174, 465, 181, 659, 126,   9, 232, 451, 817, 255, 786,
       119, 259, 591,   7, 619, 431, 382,  67, 627, 787])

In [52]:
len(sequences_array[-1][1:])

100

In [51]:
len(X_train[0])

100

In [59]:
seq = sequences_array[-1][1:]

In [61]:
seq.ndim

1

In [63]:
np.array([seq]).ndim

2

In [94]:
model.predict_classes(np.array([seq])).item()

8

In [70]:
reverse_mapping[787]

'Tweezer Reprise'

In [107]:
def generate_full_setlist(model, seed_setlist, n_songs):
    """
    Roll the model forward `n_songs` steps from a seed history and return the
    predicted song names.

    Only the most recent 100 encoded songs are fed to the model at each step.
    Predictions are decoded through the global `reverse_mapping`.

    NOTE: this redefines `generate_full_setlist` from an earlier cell, which
    uses a 150-song window and decodes through `idx_to_song`.
    """
    history = seed_setlist
    decoded_songs = []

    for _step in range(n_songs):
        # last 100 songs of the running history, wrapped as a one-row batch
        window = pad_sequences([history], maxlen=100, truncating='pre')[0]
        batch = np.array([window])

        # encoded index of the predicted next song
        predicted_idx = model.predict_classes(batch).item()

        # record the decoded title, then extend the history with the raw index
        decoded_songs.append(reverse_mapping[predicted_idx])
        history = np.append(history, predicted_idx)

    return decoded_songs
        
    

In [110]:
# drop the first element of the last row so the seed is 100 songs long
# (len(sequences_array[-1][1:]) is shown as 100 in an earlier cell)
seq = sequences_array[-1][1:]

# generate the next 25 songs from that seed
generate_full_setlist(model, seq, 25)

['<SET1>',
 'The Landlady',
 'Blaze On',
 'Free',
 'Breath and Burning',
 'Sugar Shack',
 'Things People Do',
 'Devotion To a Dream',
 'Sugar Shack',
 'Lawn Boy',
 'More',
 '<SET2>',
 'Blaze On',
 'Fuego',
 'Taste',
 'Carini',
 'Dirt',
 'Your Pet Cat',
 'Harry Hood',
 '<ENCORE>',
 'I Am the Walrus',
 '<SET1>',
 'Sample in a Jar',
 'The Moma Dance',
 'Rift']

In [90]:
pad_sequences([seq], maxlen=100, truncating='pre')[0]

array([843, 569, 124,   8, 731, 666,   3, 585, 386, 659, 126, 822, 529,
       810,  44, 576,   9, 196, 222, 604, 790, 292, 517, 292,  10, 443,
        45, 591, 619, 589, 401, 561, 674,   7, 727, 129,   8, 649, 790,
       247, 833, 218, 814, 556, 256, 585,   9, 637, 642, 319, 171,  29,
        57, 746,   7, 810,  92,   8, 861, 785,   1, 492, 210, 779, 607,
       589,   9, 602, 443, 624, 533, 586, 807,   7, 464,   8, 708, 538,
        83, 174, 465, 181, 659, 126,   9, 232, 451, 817, 255, 786, 119,
       259, 591,   7, 619, 431, 382,  67, 627, 787], dtype=int32)

In [83]:
np.array([seq])

(1, 100)

In [104]:
seq = sequences_array[-1][1:]
seq

array([843, 569, 124,   8, 731, 666,   3, 585, 386, 659, 126, 822, 529,
       810,  44, 576,   9, 196, 222, 604, 790, 292, 517, 292,  10, 443,
        45, 591, 619, 589, 401, 561, 674,   7, 727, 129,   8, 649, 790,
       247, 833, 218, 814, 556, 256, 585,   9, 637, 642, 319, 171,  29,
        57, 746,   7, 810,  92,   8, 861, 785,   1, 492, 210, 779, 607,
       589,   9, 602, 443, 624, 533, 586, 807,   7, 464,   8, 708, 538,
        83, 174, 465, 181, 659, 126,   9, 232, 451, 817, 255, 786, 119,
       259, 591,   7, 619, 431, 382,  67, 627, 787])

In [106]:
np.append(seq, 4444)

array([ 843,  569,  124,    8,  731,  666,    3,  585,  386,  659,  126,
        822,  529,  810,   44,  576,    9,  196,  222,  604,  790,  292,
        517,  292,   10,  443,   45,  591,  619,  589,  401,  561,  674,
          7,  727,  129,    8,  649,  790,  247,  833,  218,  814,  556,
        256,  585,    9,  637,  642,  319,  171,   29,   57,  746,    7,
        810,   92,    8,  861,  785,    1,  492,  210,  779,  607,  589,
          9,  602,  443,  624,  533,  586,  807,    7,  464,    8,  708,
        538,   83,  174,  465,  181,  659,  126,    9,  232,  451,  817,
        255,  786,  119,  259,  591,    7,  619,  431,  382,   67,  627,
        787, 4444])