## Read in Data

In [1]:
import numpy as np
import csv

# Predict via the median number of plays.

train_file = 'train.csv'
test_file  = 'test.csv'
soln_file  = 'global_median.csv'

artist_file = "artists.csv"
profile_file = "profiles.csv"

# Load the training data.
train_data = {}
with open(train_file, 'r') as train_fh:
    train_csv = csv.reader(train_fh, delimiter=',', quotechar='"')
    next(train_csv, None)
    for row in train_csv:
        user   = row[0]
        artist = row[1]
        plays  = int(row[2])
    
        if not user in train_data:
            train_data[user] = {}
        
        train_data[user][artist] = plays


artist_data = {}        
with open(artist_file, 'r') as artist_fh:
    artist_csv = csv.reader(artist_fh, delimiter=',', quotechar='"')
    next(artist_csv, None)
    for row in artist_csv:
        artist   = row[0]
        name = row[1]
          
        artist_data[artist] = name

## Convert training data to user by artist sparse matrix

In [2]:
user_dict = {user : i for i, user in enumerate(set(train_data.keys()))}
artist_dict = {artist : i for i, artist in enumerate(set(artist_data.keys()))}

In [80]:
len(user_dict)

233286

In [83]:
from scipy.sparse import csr_matrix
data, row, col = [], [] , []

for user, artists in train_data.iteritems():
    
    for artist,plays in artists.iteritems():
        row.append(user_dict[user])
        col.append(artist_dict[artist])
        data.append(plays)
        
play_sp = csr_matrix((data, (row, col)), shape=(len(user_dict), len(artist_dict)))      

In [84]:
play_sp

<233286x2000 sparse matrix of type '<type 'numpy.int64'>'
	with 4154804 stored elements in Compressed Sparse Row format>

In [94]:
user_total = play_sp.sum(axis=1)
user_total_sp=csr_matrix(user_total)
user_total_sp.shape

(233286, 1)

In [86]:
#Truncated SVD -- reduce dimension of the play matrix 
from sklearn.decomposition import TruncatedSVD

svd=TruncatedSVD(n_components=20)
play_svd = svd.fit_transform(play_sp)
print np.sum(svd.explained_variance_ratio_)
print svd.explained_variance_ratio_

0.340847168361
[ 0.07628592  0.0390799   0.02300728  0.01844125  0.01611094  0.01477318
  0.01404703  0.0136785   0.01338818  0.0122241   0.01084059  0.01093041
  0.01080837  0.01057224  0.0099792   0.00981099  0.0096101   0.00921232
  0.00914396  0.00890275]


In [92]:
play_svd_sp=csr_matrix(play_svd)

## Convert profile data to sparse matrix

In [139]:
import pandas as pd
profile_data = pd.read_csv("profiles.csv")

In [140]:
profile_data.head(10)

Unnamed: 0,user,sex,age,country
0,fa40b43298ba3f8aa52e8e8863faf2e2171e0b5d,f,25.0,Sweden
1,5909125332c108365a26ccf0ee62636eee08215c,m,29.0,Iceland
2,d1867cbda35e0d48e9a8390d9f5e079c9d99ea96,m,30.0,United States
3,63268cce0d68127729890c1691f62d5be5abd87c,m,21.0,Germany
4,02871cd952d607ba69b64e2e107773012c708113,m,24.0,Netherlands
5,0938eb3d1b449b480c4e2431c457f6ead7063a34,m,22.0,United States
6,e4c6b36e65db3d48474dd538fe74d2dbb5a2e79e,f,,United States
7,b97479f9a563a5c43b423a976f51fd509e1ec5ba,f,,Poland
8,3bb020df0ff376dfdded4d5e63e2d35a50b3c535,m,,United States
9,f3fb86c0f024f640cae3fb479f3a27e0dd499891,,16.0,Ukraine


In [141]:
profile_data.describe()

Unnamed: 0,age
count,188444.0
mean,24.5174
std,21.853296
min,-1337.0
25%,20.0
50%,23.0
75%,27.0
max,1002.0


In [166]:
country_dict = {'Afghanistan': 'Asia',
 'Albania': 'Europe',
 'Algeria': 'Africa',
 'Andorra': 'Europe',
 'Angola': 'Africa',
 'Antigua and Barbuda': 'North America',
 'Argentina': 'South America',
 'Armenia': 'Asia',
 'Australia': 'Oceania',
 'Austria': 'Europe',
 'Azerbaijan': 'Asia',
 'Bahamas': 'North America',
 'Bahrain': 'Asia',
 'Bangladesh': 'Asia',
 'Barbados': 'North America',
 'Belarus': 'Europe',
 'Belgium': 'Europe',
 'Belize': 'North America',
 'Benin': 'Africa',
 'Bhutan': 'Asia',
 'Bolivia': 'South America',
 'Bosnia and Herzegovina': 'Europe',
 'Botswana': 'Africa',
 'Brazil': 'South America',
 'Brunei Darussalam': 'Asia',
 'Bulgaria': 'Europe',
 'Burkina Faso': 'Africa',
 'Burundi': 'Africa',
 'Cambodia': 'Asia',
 'Cameroon': 'Africa',
 'Canada': 'North America',
 'Cape Verde': 'Africa',
 'Central African Republic': 'Africa',
 'Chad': 'Africa',
 'Chile': 'South America',
 'Colombia': 'South America',
 'Comoros': 'Africa',
 'Costa Rica': 'North America',
 'Croatia': 'Europe',
 'Cuba': 'North America',
 'Cyprus': 'Asia',
 'Czech Republic': 'Europe',
 "C\xc3\xb4te d'Ivoire": 'Africa',
 'Democratic Republic of the Congo': 'Africa',
 'Denmark': 'Europe',
 'Djibouti': 'Africa',
 'Dominica': 'North America',
 'Dominican Republic': 'North America',
 'East Timor': 'Asia',
 'Ecuador': 'South America',
 'Egypt': 'Africa',
 'El Salvador': 'North America',
 'Equatorial Guinea': 'Africa',
 'Eritrea': 'Africa',
 'Estonia': 'Europe',
 'Ethiopia': 'Africa',
 'Federated States of Micronesia': 'Oceania',
 'Fiji': 'Oceania',
 'Finland': 'Europe',
 'France': 'Europe',
 'Gabon': 'Africa',
 'Georgia': 'Asia',
 'Germany': 'Europe',
 'Ghana': 'Africa',
 'Greece': 'Europe',
 'Grenada': 'North America',
 'Guatemala': 'North America',
 'Guinea': 'Africa',
 'Guinea-Bissau': 'Africa',
 'Guyana': 'South America',
 'Haiti': 'North America',
 'Honduras': 'North America',
 'Hungary': 'Europe',
 'Iceland': 'Europe',
 'India': 'Asia',
 'Indonesia': 'Asia',
 'Iran': 'Asia',
 'Iraq': 'Asia',
 'Israel': 'Asia',
 'Italy': 'Europe',
 'Jamaica': 'North America',
 'Japan': 'Asia',
 'Jordan': 'Asia',
 'Kazakhstan': 'Asia',
 'Kenya': 'Africa',
 'Kingdom of the Netherlands': 'Europe',
 'Kiribati': 'Oceania',
 'Kuwait': 'Asia',
 'Kyrgyzstan': 'Asia',
 'Laos': 'Asia',
 'Latvia': 'Europe',
 'Lebanon': 'Asia',
 'Lesotho': 'Africa',
 'Liberia': 'Africa',
 'Libya': 'Africa',
 'Liechtenstein': 'Europe',
 'Lithuania': 'Europe',
 'Luxembourg': 'Europe',
 'Macedonia': 'Europe',
 'Madagascar': 'Africa',
 'Malawi': 'Africa',
 'Malaysia': 'Asia',
 'Maldives': 'Asia',
 'Mali': 'Africa',
 'Malta': 'Europe',
 'Marshall Islands': 'Oceania',
 'Mauritania': 'Africa',
 'Mauritius': 'Africa',
 'Mexico': 'North America',
 'Moldova': 'Europe',
 'Monaco': 'Europe',
 'Mongolia': 'Asia',
 'Montenegro': 'Europe',
 'Morocco': 'Africa',
 'Mozambique': 'Africa',
 'Myanmar': 'Asia',
 'Namibia': 'Africa',
 'Nauru': 'Oceania',
 'Nepal': 'Asia',
 'New Zealand': 'Oceania',
 'Nicaragua': 'North America',
 'Niger': 'Africa',
 'Nigeria': 'Africa',
 'North Korea': 'Asia',
 'Norway': 'Europe',
 'Oman': 'Asia',
 'Pakistan': 'Asia',
 'Palau': 'Oceania',
 'Panama': 'North America',
 'Papua New Guinea': 'Oceania',
 'Paraguay': 'South America',
 "People's Republic of China": 'Asia',
 'Peru': 'South America',
 'Philippines': 'Asia',
 'Poland': 'Europe',
 'Portugal': 'Europe',
 'Qatar': 'Asia',
 'Republic of Ireland': 'Europe',
 'Republic of the Congo': 'Africa',
 'Romania': 'Europe',
 'Russia': 'Europe',
 'Rwanda': 'Africa',
 'Saint Kitts and Nevis': 'North America',
 'Saint Lucia': 'North America',
 'Saint Vincent and the Grenadines': 'North America',
 'Samoa': 'Oceania',
 'San Marino': 'Europe',
 'Saudi Arabia': 'Asia',
 'Senegal': 'Africa',
 'Serbia': 'Europe',
 'Seychelles': 'Africa',
 'Sierra Leone': 'Africa',
 'Singapore': 'Asia',
 'Slovakia': 'Europe',
 'Slovenia': 'Europe',
 'Solomon Islands': 'Oceania',
 'Somalia': 'Africa',
 'South Africa': 'Africa',
 'South Korea': 'Asia',
 'Spain': 'Europe',
 'Sri Lanka': 'Asia',
 'Sudan': 'Africa',
 'Suriname': 'South America',
 'Swaziland': 'Africa',
 'Sweden': 'Europe',
 'Switzerland': 'Europe',
 'Syria': 'Asia',
 'S\xc3\xa3o Tom\xc3\xa9 and Pr\xc3\xadncipe': 'Africa',
 'Tajikistan': 'Asia',
 'Tanzania': 'Africa',
 'Thailand': 'Asia',
 'The Gambia': 'Africa',
 'Togo': 'Africa',
 'Tonga': 'Oceania',
 'Trinidad and Tobago': 'North America',
 'Tunisia': 'Africa',
 'Turkey': 'Asia',
 'Turkmenistan': 'Asia',
 'Tuvalu': 'Oceania',
 'Uganda': 'Africa',
 'Ukraine': 'Europe',
 'United Arab Emirates': 'Asia',
 'United Kingdom': 'Europe',
 'United States': 'North America',
 'Uruguay': 'South America',
 'Uzbekistan': 'Asia',
 'Vanuatu': 'Oceania',
 'Vatican City': 'Europe',
 'Venezuela': 'South America',
 'Vietnam': 'Asia',
 'Yemen': 'Asia',
 'Zambia': 'Africa',
 'Zimbabwe': 'Africa'}

In [168]:
continent = []
for i in profile_data['country']:
    if i in country_dict.keys():
        continent.append(country_dict[i])
    else:
        continent.append(i)


In [169]:
profile_data['continent'] = continent

In [170]:
print "The number of females:" , np.sum(profile_data.sex=="f")
print "The number of males:" , np.sum(profile_data.sex=="m")
probf=np.sum(profile_data.sex=="f")/np.float((np.sum(profile_data.sex=="m")+np.sum(profile_data.sex=="f")))
print "Proportional of females:", probf

The number of females: 59391
The number of males: 154360
Proportional of females: 0.27785133169


In [171]:
profile_data.isnull().sum()

user          0
sex           0
age           0
country       0
female        0
age2          0
countryidx    0
continent     0
dtype: int64

In [172]:
profile_data.shape

(233286, 8)

In [145]:
profile_data['sex']=profile_data['sex'].fillna("missing")

In [146]:
def imputegender(row):
   
    if row == "missing":
        return np.random.choice([0,1],p=[1-probf,probf])
    elif row == "f":
        return 1
    elif row == "m":
        return 0
    
profile_data['female']= profile_data['sex'].apply(lambda x: imputegender(x))
profile_data.head(10)

Unnamed: 0,user,sex,age,country,female
0,fa40b43298ba3f8aa52e8e8863faf2e2171e0b5d,f,25.0,Sweden,1
1,5909125332c108365a26ccf0ee62636eee08215c,m,29.0,Iceland,0
2,d1867cbda35e0d48e9a8390d9f5e079c9d99ea96,m,30.0,United States,0
3,63268cce0d68127729890c1691f62d5be5abd87c,m,21.0,Germany,0
4,02871cd952d607ba69b64e2e107773012c708113,m,24.0,Netherlands,0
5,0938eb3d1b449b480c4e2431c457f6ead7063a34,m,22.0,United States,0
6,e4c6b36e65db3d48474dd538fe74d2dbb5a2e79e,f,,United States,1
7,b97479f9a563a5c43b423a976f51fd509e1ec5ba,f,,Poland,1
8,3bb020df0ff376dfdded4d5e63e2d35a50b3c535,m,,United States,0
9,f3fb86c0f024f640cae3fb479f3a27e0dd499891,missing,16.0,Ukraine,0


In [147]:
female_age = round(profile_data[profile_data.female==1].age.mean())
male_age = round(profile_data[profile_data.female==0].age.mean())

print "Females avg. age:" , round(female_age)
print "Males avg. age:" , round(male_age)
# probf=np.sum(profile_data.sex=="f")/np.float((np.sum(profile_data.sex=="m")+np.sum(profile_data.sex=="f")))
# print "Proportional of females:", probf

Females avg. age: 23.0
Males avg. age: 25.0


In [148]:

def imputeage(row):
   
    if row['age'] == "missing":
        if row['female']==1:
            return female_age
        else:
            return male_age
    else:
        return row['age']

profile_data['age']=profile_data['age'].fillna("missing")
profile_data['age2']= profile_data.apply(lambda x: imputeage(x),axis=1)
profile_data.head(10)

Unnamed: 0,user,sex,age,country,female,age2
0,fa40b43298ba3f8aa52e8e8863faf2e2171e0b5d,f,25,Sweden,1,25
1,5909125332c108365a26ccf0ee62636eee08215c,m,29,Iceland,0,29
2,d1867cbda35e0d48e9a8390d9f5e079c9d99ea96,m,30,United States,0,30
3,63268cce0d68127729890c1691f62d5be5abd87c,m,21,Germany,0,21
4,02871cd952d607ba69b64e2e107773012c708113,m,24,Netherlands,0,24
5,0938eb3d1b449b480c4e2431c457f6ead7063a34,m,22,United States,0,22
6,e4c6b36e65db3d48474dd538fe74d2dbb5a2e79e,f,missing,United States,1,23
7,b97479f9a563a5c43b423a976f51fd509e1ec5ba,f,missing,Poland,1,23
8,3bb020df0ff376dfdded4d5e63e2d35a50b3c535,m,missing,United States,0,25
9,f3fb86c0f024f640cae3fb479f3a27e0dd499891,missing,16,Ukraine,0,16


In [173]:
profile_data.head()

Unnamed: 0,user,sex,age,country,female,age2,countryidx,continent
0,fa40b43298ba3f8aa52e8e8863faf2e2171e0b5d,f,25,Sweden,1,25,0,Europe
1,5909125332c108365a26ccf0ee62636eee08215c,m,29,Iceland,0,29,1,Europe
2,d1867cbda35e0d48e9a8390d9f5e079c9d99ea96,m,30,United States,0,30,2,North America
3,63268cce0d68127729890c1691f62d5be5abd87c,m,21,Germany,0,21,3,Europe
4,02871cd952d607ba69b64e2e107773012c708113,m,24,Netherlands,0,24,4,Netherlands


In [174]:
continent_dict ={}
i=0
for continent in profile_data.continent.unique():
    continent_dict[continent]=i
    i=i+1
    
profile_data['continentidx'] = profile_data['continent'].apply(lambda x: continent_dict[x])
profile_data.head()

Unnamed: 0,user,sex,age,country,female,age2,countryidx,continent,continentidx
0,fa40b43298ba3f8aa52e8e8863faf2e2171e0b5d,f,25,Sweden,1,25,0,Europe,0
1,5909125332c108365a26ccf0ee62636eee08215c,m,29,Iceland,0,29,1,Europe,0
2,d1867cbda35e0d48e9a8390d9f5e079c9d99ea96,m,30,United States,0,30,2,North America,1
3,63268cce0d68127729890c1691f62d5be5abd87c,m,21,Germany,0,21,3,Europe,0
4,02871cd952d607ba69b64e2e107773012c708113,m,24,Netherlands,0,24,4,Netherlands,2


In [149]:
countrydict ={}
i=0
for country in profile_data.country.unique():
    countrydict[country]=i
    i=i+1
    
profile_data['countryidx'] = profile_data['country'].apply(lambda x: countrydict[x])
profile_data.head()

Unnamed: 0,user,sex,age,country,female,age2,countryidx
0,fa40b43298ba3f8aa52e8e8863faf2e2171e0b5d,f,25,Sweden,1,25,0
1,5909125332c108365a26ccf0ee62636eee08215c,m,29,Iceland,0,29,1
2,d1867cbda35e0d48e9a8390d9f5e079c9d99ea96,m,30,United States,0,30,2
3,63268cce0d68127729890c1691f62d5be5abd87c,m,21,Germany,0,21,3
4,02871cd952d607ba69b64e2e107773012c708113,m,24,Netherlands,0,24,4


In [150]:
len(countrydict)

239

In [151]:
data, row, col = [], [] , []

for idx, dfrow in profile_data.iterrows():
    row.append(user_dict[dfrow['user']])
    #col.append(countrydict[dfrow['country']])
    col.append(countrydict[dfrow['continent']])
    data.append(1)
        
#country_sp = csr_matrix((data, (row, col)), shape=(len(user_dict), len(countrydict)))      
continent_sp = csr_matrix((data, (row, col)), shape=(len(user_dict), len(continent_dict)))      

In [153]:
import scipy.sparse as sc
other_sp=csr_matrix(profile_data[['age2','female']].values)
#profile_sp=sc.hstack([country_sp,other_sp])
profile_sp=sc.hstack([continent_sp,other_sp])
profile_sp

<233286x241 sparse matrix of type '<type 'numpy.float64'>'
	with 531448 stored elements in COOrdinate format>

In [154]:
main_sp = sc.hstack([play_svd_sp,profile_sp,user_total_sp])

## Model fitting

In [155]:
from sklearn.cluster import KMeans
random_state=123
class_pred = KMeans(n_clusters=10, random_state=random_state).fit_predict(main_sp)

## Prediction

In [164]:
data, row, col = [], [] , []

for user, artists in train_data.iteritems():
    
    for artist,plays in artists.iteritems():
        row.append(user_dict[user])
        col.append(artist_dict[artist])
        data.append(plays)
        
play_sp_col = sc.csc_matrix((data, (row, col)), shape=(len(user_dict), len(artist_dict)))      

In [165]:
play_sp/user_total_sp

ValueError: inconsistent shapes

In [158]:
indices = np.where(class_pred==1)[0]
out1 = M.tocsc()[:,indices]
out2 = M.tocsr()[indices,:]

array([    78,     83,     89, ..., 233217, 233255, 233263])

## standardized the # of plays

In [72]:
user_total_df = train_data.groupby('user').sum()
user_total=user_total_df.to_dict()

In [65]:
#user_total =user_total.reset_index()
user_total['plays']['f283c15ed4180e686384dc1de2a5cbf5f95ae269']
train_data['proportion'] = train_data.apply(lambda x: x.plays/np.float(user_total['plays'][x.user]),axis=1)

In [66]:
train_data.head()

Unnamed: 0,user,artist,plays,proportion
0,eb1c57ddc9e0e2d005169d3a1a96e8dd95e3af03,5a8e07d5-d932-4484-a7f7-e700793a9c94,554,0.034378
1,44ce793a6cd9d20f13f4a576a818ef983314bb5d,a3a92047-be1c-4f3e-8960-c4f8570984df,81,0.142857
2,da9cf3f557161d54b76f24db64be9cc76db008e3,eeb1195b-f213-4ce1-b28c-8565211f8e43,708,0.152883
3,8fa49ab25d425edcf05d44bfc1d5aea895287d81,a1419808-65d3-4d40-998c-1a0bac65eabc,265,0.044329
4,b85fcaef67d2669cd99b334b5e8c8705263db2cf,a3cb23fc-acd3-4ce0-8f36-1e5aa6a18432,220,0.141753


In [52]:
user_total.shape

(233286, 1)

In [90]:
user_df=user_total_df.reset_index()
user_df=user_df.rename(columns = {'plays':'total_plays'})
user_df.drop('proportion', axis=1, inplace=True)

tmp_data = pd.merge(train_data,profile_data,on='user')
tmp_data2 = pd.merge(tmp_data, user_df,on='user')
main_data = pd.merge(tmp_data2,artist_data,on='artist')


In [91]:
#main_data['gender'] = [1 if main_data['sex'] == "f" else 0]
main_data.head()

Unnamed: 0,user,artist,plays,proportion,sex,age,country,female,countryidx,total_plays,name
0,eb1c57ddc9e0e2d005169d3a1a96e8dd95e3af03,5a8e07d5-d932-4484-a7f7-e700793a9c94,554,0.034378,m,25,Sweden,0,0,16115,Robyn
1,0ff4166398f035b5fcb8824cc16c8daeb4643911,5a8e07d5-d932-4484-a7f7-e700793a9c94,169,0.082359,f,18,United Kingdom,1,16,2052,Robyn
2,b3f9fa56429c3b7fd348c471452e65747ba9ed50,5a8e07d5-d932-4484-a7f7-e700793a9c94,292,0.009833,m,23,United Kingdom,0,16,29697,Robyn
3,0ffff52af79555e8fe72289c429b2fdfc8ea684b,5a8e07d5-d932-4484-a7f7-e700793a9c94,92,0.012273,m,26,Germany,0,3,7496,Robyn
4,985253be0dc82ffa15a0ad006d0284aa4b7d1e3d,5a8e07d5-d932-4484-a7f7-e700793a9c94,159,0.011976,m,19,Sweden,0,0,13276,Robyn


In [98]:
main_data.to_csv("maindf.csv")

In [29]:
country_df = main_data[['countryidx']]
y_df= main_data[['plays'],['proportion'],['total_plays']]
y_set = y_df.as_matrix()
gender_set = main_data[['female']].as_matrix()

In [92]:
y_prob = main_data[['proportion']].as_matrix()

In [30]:
from sklearn.preprocessing import OneHotEncoder
enc = OneHotEncoder()
cindex = np.matrix(np.arange(239))

enc.fit(cindex.T)
country_set= enc.transform(country_df).toarray()


In [32]:
train_set = np.hstack((country_set,gender_set))

In [93]:
random_state= 123
from sklearn.cross_validation import train_test_split
itrain, itest = train_test_split(xrange(train_set.shape[0]), train_size=0.1,random_state=random_state)
mask=np.ones(train_set.shape[0], dtype='int')
mask[itrain]=1
mask[itest]=0
mask = (mask==1)

X_train = train_set[mask]
Y_train = y_prob[mask]

In [34]:
from sklearn.decomposition import PCA

pca=PCA(n_components=10)
pca.fit_transform(X_train)
pca.explained_variance_ratio_

array([ 0.18251138,  0.15761538,  0.08477288,  0.06648769,  0.04249111,
        0.03720076,  0.03357104,  0.03013427,  0.02924813,  0.02724646])

In [46]:
np.sum(pca.explained_variance_ratio_)

0.69127910657610625

In [35]:
pca_train=pca.fit_transform(X_train)

In [99]:
from sklearn.cluster import KMeans
class_pred = KMeans(n_clusters=10, random_state=random_state).fit_predict(pca_train)


In [102]:
class_1 = (class_pred==0)
class_2 = Y_train[class_pred ==1]
class_3 = Y_train[class_pred ==2 ]

In [101]:
print "class 1:", np.mean(class_1),np.std(class_1)
print "variance:", np.mean(class_2),np.std(class_2)
print "variance:" ,np.mean(class_3),np.std(class_3)

class 1: 0.0615278466692 0.0596720147168
variance: 0.0567964537502 0.0567446462991
variance: 0.0570431303064 0.0554216098786


In [103]:
train=main_data[mask]

In [171]:
def classtoplay(class_k, df,artist_list):
    cluster = df[mask]
    artist_count=dict(zip(artist_list,np.zeros(len(artist_list))))
    artist_prop = dict(zip(artist_list,np.zeros(len(artist_list))))
    artist_dist = dict(zip(artist_list,np.zeros(len(artist_list))))
    
    artist_include =set([])
    for index,row in cluster.iterrows():
        artist_count[row['name']] += 1
        artist_prop[row['name']] += row['proportion']
        artist_include.add(row['name'])

    total_prob=0    
    for artist in artist_include:
        artist_dist[artist] = artist_prop[artist] / np.float(artist_count[artist])
        total_prob += artist_dist[artist]
    
    for artist in artist_include:
         artist_dist[artist] = artist_dist[artist] /total_prob
    
    
    return artist_dist


In [172]:
type(train.name.values)

numpy.ndarray

In [173]:
artist_count=dict(zip(artist_list,np.zeros(len(artist_list))))

In [174]:
artist_count2=dict(zip(artist_list,np.ones(len(artist_list))))
np.array(artist_count2.values()) - np.array(artist_count.values())

array([ 1.,  1.,  1., ...,  1.,  1.,  1.])

In [175]:
classtoplay(class_1, main_data,artist_list)

{nan: 0.00043636253507237061,
 'Gigi D\xe2\x80\x99Agostino': 0.00047896986283453926,
 'Queens of the Stone Age': 0.00053126002468739826,
 'Deerhunter': 0.00057538479803613985,
 'Neil Young & Crazy Horse': 0.00034528834046350647,
 'Bo Kaspers orkester': 0.00061915781175633712,
 'Sondre Lerche': 0.00048899477243557587,
 'Nirvana': 0.0004662697884213433,
 'Massive Attack': 0.00051038305185167031,
 'Poison the Well': 0.00040238027215657254,
 'Billie the Vision & The Dancers': 0.00060414350072927924,
 'Goldfrapp': 0.00045957785517809736,
 'Bullet for My Valentine': 0.0005436647612305478,
 'De-Phazz': 0.0006008797410586808,
 'Gustavo Santaolalla': 0.00055428884492492473,
 'Sublime': 0.00054196582687029049,
 'a-ha': 0.00051473231011576494,
 'Burzum': 0.00053135686190069987,
 'Billy Bragg': 0.0004653850036450044,
 'The Crystal Method': 0.00051843233547380627,
 'PMMP': 0.00066879174267239511,
 'Twisted Sister': 0.00040722047240733137,
 'Astor Piazzolla': 0.00049078198095452272,
 'Max\xc3\xafmo 

In [None]:
for index,row in train.head(3).iterrows():
    print row.name

In [169]:
artist_include =set([])
artist_include.add('a')

In [170]:
artist_include

{'a'}