In [27]:
CONN_STRING="postgresql://postgres:password1@localhost/discogs"
%load_ext sql
%sql $CONN_STRING

The sql extension is already loaded. To reload it, use:
  %reload_ext sql


'Connected: postgres@discogs'

# Splitting the data into training and test subsets

In supervised learning we evaluate the model on data it has not seen during training, so we split the dataset into a training set and a test set. MADlib's `train_test_split` function does this for us.

In [22]:
%%sql

DROP TABLE IF EXISTS dataset_train, dataset_test;

SELECT madlib.train_test_split(
    'dataset_sample',   -- input table
    'dataset',          -- output table
    0.6,                -- train split
    0.4,                -- test split
    'genre',            -- grouping column: stratify the split by genre to preserve label ratios
    NULL,               -- subset of columns to output (if NULL then all)
    FALSE,              -- with replacement
    TRUE                -- generate train/test tables
);

 * postgresql://postgres:***@localhost/discogs
Done.
1 rows affected.


train_test_split


We can check the sizes of the train and test sets.

In [23]:
%%sql

SELECT
    (SELECT COUNT(*) FROM dataset_train) AS train_count,
    (SELECT COUNT(*) FROM dataset_test)  AS test_count;

 * postgresql://postgres:***@localhost/discogs
1 rows affected.


train_count,test_count
3000,2010


We can also check the label breakdown of each set.

In [24]:
%sql SELECT genre, COUNT(*) FROM dataset_train GROUP BY genre;

 * postgresql://postgres:***@localhost/discogs
15 rows affected.


genre,count
latin,200
funk soul,200
nonmusic,200
rock,200
childrens,200
brass military,200
pop,200
jazz,200
hip hop,200
stage screen,200


In [25]:
%sql SELECT genre, COUNT(*) FROM dataset_test GROUP BY genre;

 * postgresql://postgres:***@localhost/discogs
15 rows affected.


genre,count
latin,134
funk soul,134
nonmusic,134
rock,134
childrens,134
brass military,134
pop,134
jazz,134
hip hop,134
stage screen,134
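
The equal per-genre counts above (200 in training, 134 in test) are what a stratified 60/40 split would produce if every genre had 334 rows. The following is a minimal Python sketch of that idea, not MADlib's implementation; the function name `stratified_split` and the toy rows are assumptions made purely for illustration.

```python
import random
from collections import defaultdict

def stratified_split(rows, label_of, train_frac, seed=0):
    """Split rows into (train, test) so each label keeps about train_frac in train."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for row in rows:
        by_label[label_of(row)].append(row)

    train, test = [], []
    for group in by_label.values():
        rng.shuffle(group)                      # sample without replacement
        cut = round(len(group) * train_frac)    # rows of this label sent to train
        train.extend(group[:cut])
        test.extend(group[cut:])
    return train, test

# Toy data (hypothetical): 3 genres with 334 rows each.
rows = [(i, genre) for genre in ("rock", "jazz", "pop") for i in range(334)]
train, test = stratified_split(rows, label_of=lambda r: r[1], train_frac=0.6)
print(len(train), len(test))  # 600 402, i.e. 200 and 134 rows per genre
```

Because the split is done inside each genre, both subsets keep the same label proportions as the input.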


# Supervised Learning

## Training the model

We use a random forest model because it is versatile and needs little feature preprocessing (for example, no scaling or normalization).

In [26]:
%%sql

DROP TABLE IF EXISTS train_output, train_output_group, train_output_summary;

SELECT madlib.forest_train(
    'dataset_train',         -- source table
    'train_output',          -- output model table
    'release_id',            -- id column
    'genre',                 -- response
    '*',                     -- features
    NULL,                    -- exclude columns
    NULL,                    -- grouping columns
    20::integer,             -- number of trees
    2::integer,              -- number of random features
    TRUE::boolean,           -- variable importance
    1::integer,              -- num_permutations
    8::integer,              -- max depth
    3::integer,              -- min split
    1::integer,              -- min bucket
    10::integer,             -- number of splits per continuous variable
    '',                      -- surrogate params
    TRUE::boolean,           -- verbose
    0.1::float               -- sample ratio
);

SELECT * FROM train_output_summary;

 * postgresql://postgres:***@localhost/discogs
Done.
1 rows affected.
1 rows affected.


method,is_classification,source_table,model_table,id_col_name,dependent_varname,independent_varnames,cat_features,con_features,grouping_cols,num_trees,num_random_features,max_tree_depth,min_split,min_bucket,num_splits,verbose,importance,num_permutations,num_all_groups,num_failed_groups,total_rows_processed,total_rows_skipped,dependent_var_levels,dependent_var_type,independent_var_types,null_proxy
forest_train,True,dataset_train,train_output,release_id,genre,"profile,name,title,url,country,__madlib_id__,artist_id,realname","profile,name,title,url,country,__madlib_id__,artist_id,realname",,,20,2,8,3,1,10,True,True,1,1,0,3000,0,"blues,brass military,childrens,classical,electronic,folk world country,funk soul,hip hop,jazz,latin,nonmusic,pop,reggae,rock,stage screen",text,"text, text, text, text, text, bigint, integer, text",


## Making Predictions

We make predictions on the test set.

In [18]:
%%sql

DROP TABLE IF EXISTS prediction_results;

SELECT madlib.forest_predict('train_output',          --table containing the tree model
                             'dataset_test',          --table containing test data
                             'prediction_results',    --new table to store predictions
                             'response');             --predict the labels directly (as opposed to predicting probabilities)

 * postgresql://postgres:***@localhost/discogs
Done.
1 rows affected.


forest_predict
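
As a rough illustration of the difference between predicting a label and predicting probabilities, a random forest classifier is commonly described as taking a majority vote over its trees. The sketch below assumes that description and uses made-up votes; `forest_response` and `forest_probabilities` are hypothetical helpers, and MADlib's internal computation may differ.

```python
from collections import Counter

def forest_response(tree_votes):
    """Return the label predicted by the largest number of trees."""
    return Counter(tree_votes).most_common(1)[0][0]

def forest_probabilities(tree_votes):
    """Return the fraction of trees voting for each label."""
    counts = Counter(tree_votes)
    return {label: n / len(tree_votes) for label, n in counts.items()}

# Hypothetical votes from 20 trees for a single test row.
votes = ["rock"] * 9 + ["pop"] * 6 + ["jazz"] * 5
print(forest_response(votes))        # rock
print(forest_probabilities(votes))   # {'rock': 0.45, 'pop': 0.3, 'jazz': 0.25}
```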


To compute the accuracy, we first build a table that pairs each predicted genre with the true genre.

In [19]:
%%sql

DROP TABLE IF EXISTS test_predictions;

CREATE TABLE test_predictions AS
SELECT pr.release_id, estimated_genre, genre
FROM prediction_results pr
JOIN dataset_test t ON t.release_id = pr.release_id;

SELECT * FROM test_predictions LIMIT 10;

 * postgresql://postgres:***@localhost/discogs
Done.
2010 rows affected.
10 rows affected.


release_id,estimated_genre,genre
411361,stage screen,blues
414582,stage screen,blues
620637,folk world country,blues
620637,folk world country,blues
641784,blues,blues
641784,blues,blues
605829,stage screen,blues
605829,stage screen,blues
404534,stage screen,blues
404534,stage screen,blues


## Measuring Quality

We can compute the overall accuracy as the fraction of test rows whose predicted genre matches the true genre.

In [20]:
%%sql

SELECT AVG(sub.eq) FROM (
    SELECT (CASE WHEN genre = estimated_genre THEN 1 ELSE 0 END) AS eq
    FROM test_predictions
) as sub

 * postgresql://postgres:***@localhost/discogs
1 rows affected.


avg
0.0865671641791044
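
To put this number in context: 0.0866 of the 2010 test rows corresponds to 174 correct predictions, and with 15 equally sized genres a random guess would be right about 1/15 ≈ 0.067 of the time. The same computation can be sketched in Python; the `accuracy` helper is an illustrative name, and the pairs are copied from the ten-row preview of `test_predictions` shown earlier.

```python
def accuracy(pairs):
    """Fraction of (estimated, true) pairs in which both labels agree."""
    pairs = list(pairs)
    correct = sum(1 for estimated, true in pairs if estimated == true)
    return correct / len(pairs)

# (estimated_genre, genre) pairs from the ten-row preview of test_predictions.
preview = [
    ("stage screen", "blues"), ("stage screen", "blues"),
    ("folk world country", "blues"), ("folk world country", "blues"),
    ("blues", "blues"), ("blues", "blues"),
    ("stage screen", "blues"), ("stage screen", "blues"),
    ("stage screen", "blues"), ("stage screen", "blues"),
]
print(accuracy(preview))       # 0.2 on these ten rows only
print(round(174 / 2010, 4))    # 0.0866, the overall accuracy reported above
print(round(1 / 15, 4))        # 0.0667, chance level for 15 balanced genres
```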


We can also examine the accuracy on individual genres.

In [21]:
%%sql

SELECT genre, AVG(sub.eq) FROM (
    SELECT (CASE WHEN genre = estimated_genre THEN 1 ELSE 0 END) AS eq, genre
    FROM test_predictions
) as sub
GROUP BY genre

 * postgresql://postgres:***@localhost/discogs
15 rows affected.


genre,avg
latin,0.0
funk soul,0.0
nonmusic,0.0074626865671641
rock,0.0
childrens,0.1044776119402985
brass military,0.0
pop,0.0
jazz,0.0
hip hop,0.0
stage screen,0.8955223880597014
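
Grouped by the true genre, this accuracy is the recall of each class. With 134 test rows per genre, the values above correspond to 120 correct rows for stage screen, 14 for childrens, and 1 for nonmusic. A minimal Python sketch of the same grouping follows; `per_label_accuracy` is a hypothetical helper and the pairs are toy data, not rows from the notebook.

```python
from collections import defaultdict

def per_label_accuracy(pairs):
    """Map each true label to the fraction of its rows predicted correctly."""
    hits, totals = defaultdict(int), defaultdict(int)
    for estimated, true in pairs:
        totals[true] += 1
        hits[true] += int(estimated == true)
    return {label: hits[label] / totals[label] for label in totals}

# Toy (estimated, true) pairs for illustration.
pairs = [
    ("stage screen", "stage screen"),
    ("stage screen", "rock"),
    ("stage screen", "rock"),
    ("childrens", "childrens"),
    ("stage screen", "childrens"),
]
print(per_label_accuracy(pairs))
# {'stage screen': 1.0, 'rock': 0.0, 'childrens': 0.5}

# Cross-check of the counts quoted above against the reported ratios.
print(round(120 / 134, 4), round(14 / 134, 4), round(1 / 134, 4))
# 0.8955 0.1045 0.0075
```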


# Unsupervised Learning

In this part we take the artist/genre features and cluster the artists by their feature vectors.

In [57]:
%sql SELECT * FROM dataset_sample LIMIT 3

 * postgresql://postgres:***@localhost/discogs
3 rows affected.


__madlib_id__,artist_id,count_sum_genre_blues,count_sum_genre_brass military,count_sum_genre_childrens,count_sum_genre_classical,count_sum_genre_electronic,count_sum_genre_folk world country,count_sum_genre_funk soul,count_sum_genre_hip hop,count_sum_genre_jazz,count_sum_genre_latin,count_sum_genre_nonmusic,count_sum_genre_pop,count_sum_genre_reggae,count_sum_genre_rock,count_sum_genre_stage screen,genre,release_id
1,832951,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,classical,530219
2,39574,0.0,0.0,0.0,0.8076923076923076,0.1538461538461538,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0384615384615384,classical,613405
3,239236,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,classical,530293


However, the k-means implementation in MADlib expects a single feature-vector column in which each cell holds an array of real numbers. We use the `cols2vec` function to build it.

In [60]:
%%sql

DROP TABLE IF EXISTS dataset_vec, dataset_vec_summary;

SELECT madlib.cols2vec(
    'dataset_sample',                                -- source table
    'dataset_vec',                                   -- target table
    '*',                                             -- columns to include in vectors
    'artist_id, genre, release_id, __madlib_id__',   -- columns to exclude from vectors
    'artist_id, genre'                               -- columns to add to the final table (apart from the vector column)
);

SELECT * FROM dataset_vec ORDER BY artist_id LIMIT 10;

 * postgresql://postgres:***@localhost/discogs
Done.
1 rows affected.
10 rows affected.


artist_id,genre,feature_vector
78,latin,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0.68627450980392156863'), Decimal('0E-20'), Decimal('0.01960784313725490196'), Decimal('0E-20'), Decimal('0.27450980392156862745'), Decimal('0.01960784313725490196'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]"
96,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]"
232,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]"
246,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]"
301,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]"
318,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]"
323,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]"
368,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]"
402,stage screen,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0.88888888888888888889'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0.11111111111111111111')]"
402,stage screen,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0.88888888888888888889'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0.11111111111111111111')]"
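
Conceptually, `cols2vec` packs the selected columns of each row into one array column and carries the listed extra columns along. A rough Python equivalent for a single row is sketched below, assuming the row is a dict; `cols_to_vec` is a hypothetical helper rather than a MADlib function, and the row is an abbreviated version of the second row of the `dataset_sample` preview (only three of the fifteen genre columns are kept).

```python
def cols_to_vec(row, exclude, keep):
    """Pack every non-excluded column into a list and carry `keep` columns along."""
    vector = [float(value) for name, value in row.items() if name not in exclude]
    packed = {name: row[name] for name in keep}
    packed["feature_vector"] = vector
    return packed

# Abbreviated row: three genre columns instead of fifteen.
row = {
    "__madlib_id__": 2,
    "artist_id": 39574,
    "count_sum_genre_classical": 0.8076923076923076,
    "count_sum_genre_electronic": 0.1538461538461538,
    "count_sum_genre_stage screen": 0.0384615384615384,
    "genre": "classical",
    "release_id": 613405,
}
packed = cols_to_vec(
    row,
    exclude={"artist_id", "genre", "release_id", "__madlib_id__"},
    keep=["artist_id", "genre"],
)
print(packed)
```

The identifier and label columns are excluded from the vector so that only the numeric genre features influence the distances used for clustering.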


In [61]:
%%sql

DROP TABLE IF EXISTS km_result;
-- Run k-means with k-means++ seeding
CREATE TABLE km_result AS
SELECT * FROM madlib.kmeanspp(
    'dataset_vec',                  -- input data
    'feature_vector',               -- feature column
    15,                             -- k - number of clusters
    'madlib.squared_dist_norm2',    -- distance function
    'madlib.avg'                    -- cluster centroid computation function
);

SELECT * FROM km_result;

 * postgresql://postgres:***@localhost/discogs
Done.
1 rows affected.
1 rows affected.


centroids,cluster_variance,objective_fn,frac_reassigned,num_iterations
"[[0.0, 0.0, 0.0177777777777778, 0.0, 0.00333730158730159, 0.0, 0.000952380952380952, 0.00111111111111111, 0.0, 0.0, 0.956166666666668, 0.00390873015873016, 0.0, 0.0111904761904762, 0.00555555555555555], [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0846686836378498, 0.0, 0.00078812012609922, 0.0259148083004257, 0.000497760079641613, 0.00174216027874564, 0.0, 0.0, 0.883427573628399, 0.00296089394883946, 0.0], [0.0010387811634349, 0.0, 0.0, 0.0, 0.076902529961329, 0.0, 0.00226707907993765, 0.910114850296981, 2.87290278096989e-05, 0.00031328320802005, 0.00171382751021202, 0.0, 0.00651386152547143, 0.00110705822680478, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.901960784313726, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0465686274509804, 0.0, 0.0, 0.0, 0.0, 0.0514705882352941, 0.0], [0.0, 0.0, 0.0, 0.0929070929070929, 0.0468531468531469, 0.0, 0.0, 0.0, 0.075924075924076, 0.0, 0.0, 0.00839160839160839, 0.0, 0.0533799533799534, 0.722544122544122], [0.00326797385620915, 0.0, 0.0206971677559913, 0.0, 0.208346776454213, 0.0, 0.70270344286576, 0.00704482038285008, 0.029013053937377, 0.0162856507176994, 0.0, 0.00135844407578837, 0.00232162309368192, 0.0088212404922495, 0.000139806368180071], [0.00195236843898119, 0.0, 0.0111252234452912, 0.0179511625111547, 0.778848283617718, 0.0178364154197857, 0.0236198578362628, 0.0196532635471085, 0.0338114819693611, 0.0183263242439759, 0.00655942086775581, 0.0154207565556011, 0.0108525660225924, 0.0289133195434213, 0.0151295559809914], [0.0, 0.0217948717948718, 0.0, 0.874103415963882, 0.0680856983406178, 0.00400641025641026, 0.0014792899408284, 0.0, 0.00812376725838265, 0.00213675213675214, 0.00591715976331361, 0.0, 0.0, 0.00979149056072133, 0.00456114398422091], [0.000478927203065134, 0.0, 0.0114942528735632, 0.00836930705084863, 0.047234628906708, 0.0122448638819974, 0.0171000585473507, 0.00212009379925422, 0.846221050917203, 0.0284670878346562, 
0.00117889773062187, 0.00443479975140821, 0.0013262599469496, 0.010524086679515, 0.00880568487685982], [0.00547445255474453, 0.0565693430656934, 0.00886339937434828, 0.00121654501216545, 0.0692786675583427, 0.0319343065693431, 0.00256042209854101, 0.00121654501216545, 0.0232607852847459, 0.0076794403892944, 0.00534267978746935, 0.704695793527375, 0.0018377594864627, 0.0800698602793078, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0385709504685408, 0.944561579651941, 0.00602409638554217, 0.0, 0.0108433734939759, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0388156411990083, 0.0, 0.00718390804597701, 0.0, 0.0328769438810007, 0.917292089249492, 0.0, 0.0, 0.00383141762452107, 0.0, 0.0], [0.0420873082428731, 0.0, 0.00985663082437276, 0.00187537652612583, 0.0722304608292443, 0.0293492795883996, 0.00220011999868473, 0.00340273986628026, 0.033444391449881, 0.0045322658225884, 0.00223964535914217, 0.0127230332350491, 0.000710179348588951, 0.783556454213429, 0.0017921146953405]]","[7.64215881519273, 0.0, 13.5153379175503, 10.5676296476127, 0.0, 7.36793540945791, 18.6029503829504, 30.8070044726693, 109.790600954795, 19.82527515442, 27.5398691959023, 45.2835082938832, 4.70028753904507, 9.12437559382553, 54.3337518566022]",359.100685233906,0.0001996007984031,9
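
Two of the returned values can be cross-checked with a little arithmetic. In this output, `objective_fn` equals the sum of the per-cluster entries in `cluster_variance`, and `frac_reassigned` equals 1/5010, which would mean a single point changed cluster in the last iteration if `dataset_vec` holds 5010 rows (3000 + 2010, an assumption based on the split sizes above). The check below copies the numbers from the result row.

```python
import math

# Values copied from the km_result row above.
cluster_variance = [
    7.64215881519273, 0.0, 13.5153379175503, 10.5676296476127, 0.0,
    7.36793540945791, 18.6029503829504, 30.8070044726693, 109.790600954795,
    19.82527515442, 27.5398691959023, 45.2835082938832, 4.70028753904507,
    9.12437559382553, 54.3337518566022,
]
objective_fn = 359.100685233906
frac_reassigned = 0.0001996007984031

# The objective is the total within-cluster squared distance.
print(math.isclose(sum(cluster_variance), objective_fn, rel_tol=1e-9))   # True

# One reassigned point out of an assumed 5010 rows.
print(math.isclose(frac_reassigned, 1 / 5010, rel_tol=1e-9))             # True
```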


We may also want to see which cluster each vector was assigned to. For this we use the `closest_column` function.

In [68]:
%%sql

SELECT data.*,  (madlib.closest_column(centroids, feature_vector)).column_id AS cluster_id
FROM dataset_vec AS data, km_result
ORDER BY data.artist_id
LIMIT 20;

 * postgresql://postgres:***@localhost/discogs
20 rows affected.


artist_id,genre,feature_vector,cluster_id
78,latin,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0.68627450980392156863'), Decimal('0E-20'), Decimal('0.01960784313725490196'), Decimal('0E-20'), Decimal('0.27450980392156862745'), Decimal('0.01960784313725490196'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]",8
96,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]",8
232,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]",8
246,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]",8
301,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]",8
318,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]",8
323,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]",8
368,electronic,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('1.00000000000000000000'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20')]",8
402,stage screen,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0.88888888888888888889'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0.11111111111111111111')]",8
402,stage screen,"[Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0.88888888888888888889'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0E-20'), Decimal('0.11111111111111111111')]",8
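
The cluster ids appear to be 0-based: the electronic-heavy vectors above all receive id 8, and the centroid at 0-based position 8 in `km_result` is the one dominated by its fifth (electronic) component, at about 0.78. The assignment itself can be sketched in Python as a nearest-centroid search, assuming squared Euclidean distance (the metric passed to `kmeanspp` above); `closest_centroid` is a hypothetical helper and the centroids are three-dimensional toy values.

```python
def squared_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def closest_centroid(centroids, vector):
    """Return (0-based index, squared distance) of the nearest centroid."""
    distances = [squared_dist(c, vector) for c in centroids]
    best = min(range(len(centroids)), key=distances.__getitem__)
    return best, distances[best]

# Toy centroids (hypothetical), one per row.
centroids = [
    [0.9, 0.1, 0.0],
    [0.1, 0.8, 0.1],
    [0.0, 0.1, 0.9],
]
cluster_id, distance = closest_centroid(centroids, [0.0, 1.0, 0.0])
print(cluster_id)  # 1
```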
