-
Notifications
You must be signed in to change notification settings - Fork 2
/
helpers.py
117 lines (102 loc) · 4.18 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import pandas as pd
import numpy as np
import os
import json
import urllib.request
from urllib.request import Request
from pandas.io.json import json_normalize
from urllib.parse import quote
import time
import sys
# To authenticate to Google Cloud and download a ready to use model
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import torch
# To get lyrics
import lyricsgenius
# Data visualization
import matplotlib.pyplot as plt
# Need to keep querying the "Get Token" button here to obtain a Spotify API token: https://developer.spotify.com/console/get-audio-features-track/?id=06AKEBrKUckW0KREUWRnvT
def _spotify_get_json(url, spotify_api_key):
    """Send an authenticated GET request to the Spotify Web API and return the decoded JSON body."""
    request = Request(url)
    request.add_header('Accept', 'application/json')
    request.add_header('Content-Type', 'application/json')
    request.add_header('Authorization', 'Bearer ' + spotify_api_key)
    # The context manager closes the HTTP response even when JSON decoding
    # raises, so connections are not leaked across repeated lookups.
    with urllib.request.urlopen(request) as response:
        return json.load(response)


def get_spotify_valence(song_title, artist_name, spotify_api_key):
    """
    Return the song's valence (a 0-1 real number), or None if the song is not found.

    Args:
        song_title: Title of the song to search for.
        artist_name: Name of the artist performing the song.
        spotify_api_key: Bearer token sent in the Authorization header.

    Returns:
        The "valence" field of the audio features of the first search hit,
        or None when the search returns no tracks.

    Raises:
        Whatever urllib.request.urlopen raises on a failed request (for
        example an expired token); errors are not caught here.
    """
    # Percent-encode the user-supplied text so it is safe inside the query string
    song = quote(song_title)
    artist = quote(artist_name)

    # Get song URI from artist and song (https://developer.spotify.com/documentation/web-api/reference/search/search/)
    search_url = ('https://api.spotify.com/v1/search?q=track:' + song
                  + '%20artist:' + artist + '&type=track&limit=1')
    search_result = _spotify_get_json(search_url, spotify_api_key)

    items = search_result["tracks"]["items"]
    if not items:
        print("Song {} not found".format(song_title))
        return None

    # Only the first hit is used (the search is issued with limit=1)
    track_id = items[0]["id"]
    audio_features = _spotify_get_json(
        'https://api.spotify.com/v1/audio-features/' + track_id, spotify_api_key)
    valence = audio_features["valence"]
    print("Found valence: {:.2f} of the song: {} - {}".format(valence, song_title, artist_name))
    return valence
def load_model(model, shared_file_id, path, use_trained_model=True, colab=False):
    """
    Load weights into ``model`` from ``path``, optionally downloading them first.

    Args:
        model: Object exposing ``load_state_dict``; it is updated in place.
        shared_file_id: Google Drive file id of the ready-to-use model. Only
            used when ``use_trained_model`` is False.
        path: Local file the state dict is read from (and, when downloading,
            written to beforehand).
        use_trained_model: When True, load the file already present at
            ``path``; when False, fetch the file from Google Drive first.
        colab: When True, authenticate through the Colab helper; otherwise
            through pydrive's local web server flow.

    Returns:
        None.
    """
    # NOTE(review): torch.load deserializes the file; only point this at
    # checkpoints from a trusted source.
    if use_trained_model:
        # Use your trained model
        state = torch.load(path)
        model.load_state_dict(state)
        print('Loaded the trained model Successfully')
        return

    # Use existing model: authenticate against Google Drive first
    if colab:
        # When run in colab need a small modification to connect to Drive.
        auth.authenticate_user()
        drive_auth = GoogleAuth()
        drive_auth.credentials = GoogleCredentials.get_application_default()
    else:
        # When run in console: the local webserver handles authentication.
        drive_auth = GoogleAuth()
        drive_auth.LocalWebserverAuth()

    # Fetch the shared file by id and store it at `path`
    drive_client = GoogleDrive(drive_auth)
    remote_file = drive_client.CreateFile({'id': shared_file_id})
    remote_file.GetContentFile(path)
    print('Downloaded model Successfully')

    state = torch.load(path)
    model.load_state_dict(state)
    print('Loaded the ready-to-use model Successfully')
def get_lyrics_from_genius(song, artist, token):
    """
    Fetch the lyrics of a song from Genius.

    Args:
        song: Title of the song to search for.
        artist: Name of the artist, passed to the search alongside the title.
        token: Genius API access token.

    Returns:
        The lyrics text of the matched song (the client is created with
        remove_section_headers=True), or None when the search finds no song.
    """
    genius = lyricsgenius.Genius(token, remove_section_headers=True)
    song_obj = genius.search_song(song, artist)
    # lyricsgenius documents search_song as returning None when nothing
    # matches; return None instead of failing with AttributeError on `.lyrics`.
    if song_obj is None:
        return None
    return song_obj.lyrics
def plot_training(num_epochs, val_loss, accuracy, pos_accuracy, neg_accuracy, neutral_accuracy):
    """
    Plot the per-epoch training metrics on a single figure and show it.

    Args:
        num_epochs: Number of epochs; x ticks are placed at 0..num_epochs-1
            and labelled 1..num_epochs.
        val_loss: Sequence of validation-loss values to plot.
        accuracy: Sequence of overall accuracy values to plot.
        pos_accuracy: Sequence of accuracy values for the positive class.
        neg_accuracy: Sequence of accuracy values for the negative class.
        neutral_accuracy: Sequence of accuracy values for the neutral class.
    """
    # Curves, legend labels and colours are listed in the same order.
    curves = (val_loss, accuracy, pos_accuracy, neg_accuracy, neutral_accuracy)
    legend_labels = ['val_loss', 'accuracy', 'pos_accuracy', 'neg_accuracy', 'neutral_accuracy']
    palette = ['black', 'blue', 'green', 'red', 'orange']

    plt.title("Training Evaluaiton per Epoch")
    plt.xlabel("Epoch")
    plt.ylabel("Metrics")

    axes = plt.gca()
    axes.set_prop_cycle(color=palette)

    # Set x axis: tick positions are zero-based, tick labels are one-based
    tick_positions = range(num_epochs)
    tick_labels = range(1, num_epochs + 1)
    plt.xticks(tick_positions, tick_labels)

    for curve in curves:
        plt.plot(curve)

    plt.legend(legend_labels, loc='lower left')
    plt.show()
def plot_histogram(distribution, bins_num, xlabel, ylabel, title):
    """
    Print the mean of ``distribution`` and show its histogram.

    Args:
        distribution: Values to average and to bin.
        bins_num: Passed to ``plt.hist`` as ``bins``.
        xlabel: Label of the x axis; also used as the prefix of the printed mean.
        ylabel: Label of the y axis.
        title: Title of the figure.
    """
    average = np.mean(distribution)
    print('{} avg. value: {:.2f}'.format(xlabel, average))

    plt.hist(distribution, bins=bins_num)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()