In [46]:
import pandas as pd

# Pull every rating from the ORM as (user_id, work_id, choice) rows.
RATING_FIELDS = ('user_id', 'work_id', 'choice')
df = pd.DataFrame(list(Rating.objects.values_list(*RATING_FIELDS)), columns=RATING_FIELDS)

In [47]:
from collections import defaultdict
# Ordinal rating scale: the position of a choice in this list is the numeric
# rating used for training (0 = 'dislike' ... 5 = 'favorite').
rating_values = [
    'dislike',
    'wontsee',
    'neutral',
    'willsee',
    'like',
    'favorite',
]

In [48]:
# Numeric rating = position of the choice in `rating_values`; an unexpected
# choice raises ValueError instead of silently becoming NaN.
df['rating'] = df['choice'].map(rating_values.index)

In [49]:
# Peek at the first five ratings.
df.iloc[:5]

Unnamed: 0,user_id,work_id,choice,rating
0,169,1,like,4
1,1,1146,like,4
2,177,8083,like,4
3,181,1864,willsee,3
4,188,9,like,4


In [50]:
# Number of ratings per user, broadcast onto every row of `df`.
# Assigning the column (rather than `df = df.join(count_user, on='user_id')`) keeps
# the cell re-runnable: a second join fails because `count_user` already exists.
count_user = df.groupby('user_id').size().to_frame('count_user')
df['count_user'] = df['user_id'].map(count_user['count_user'])

In [51]:
# Number of ratings per work, broadcast onto every row of `df`.
# Assigning the column (rather than `df = df.join(count_work, on='work_id')`) keeps
# the cell re-runnable: a second join fails because `count_work` already exists.
count_work = df.groupby('work_id').size().to_frame('count_work')
df['count_work'] = df['work_id'].map(count_work['count_work'])

In [52]:
# Total number of ratings before the minimum-activity filter.
df.shape[0]

345491

In [53]:
# Keep only ratings whose user and whose work each have at least MIN_RATINGS ratings.
# NOTE: count_user / count_work were computed on the unfiltered data, so this is a
# single pass, not an iterative k-core: after filtering, some users or works may be
# left with fewer than MIN_RATINGS ratings.
# `.copy()` turns the filtered slice into an independent frame, so the columns added
# in later cells do not trigger pandas' SettingWithCopyWarning.
MIN_RATINGS = 10
df = df.loc[(df.count_work >= MIN_RATINGS) & (df.count_user >= MIN_RATINGS)].copy()

In [54]:
from sklearn.model_selection import train_test_split

In [55]:
# Dense 0-based index for every remaining user, in order of first appearance.
# enumerate() sizes the mapping from the data itself. The former
# `zip(..., range(3000))` silently dropped every user past the 3000th, and those
# users would then be encoded as NaN.
encode_user = {user_id: index for index, user_id in enumerate(df['user_id'].unique())}

In [56]:
# Dense 0-based index for every remaining work, in order of first appearance.
# enumerate() sizes the mapping from the data itself. The former
# `zip(..., range(20000))` silently dropped every work past the 20000th, and those
# works would then be encoded as NaN.
encode_work = {work_id: index for index, work_id in enumerate(df['work_id'].unique())}

In [57]:
# Attach the dense user/work indices.
# `.assign` returns a new frame instead of writing into what may be a filtered slice
# of an earlier frame (SettingWithCopyWarning). The assertion catches ids missing
# from the encoders, which map() would otherwise turn into NaN.
df = df.assign(
    encoded_user_id=df['user_id'].map(encode_user),
    encoded_work_id=df['work_id'].map(encode_work),
)
assert df[['encoded_user_id', 'encoded_work_id']].notna().all().all(), 'unencoded user/work ids'

In [58]:
# Size of each dense index space (largest index + 1).
n_encoded_users = df['encoded_user_id'].max() + 1
n_encoded_works = df['encoded_work_id'].max() + 1
n_encoded_users, n_encoded_works

(1869, 3409)

In [59]:
# Train with every rating
#df[['encoded_user_id', 'encoded_work_id', 'rating']].to_csv('/tmp/train.csv', header=False, index=False)
#df[['encoded_user_id', 'encoded_work_id', 'rating']].to_csv('/tmp/val.csv', header=False, index=False)

In [60]:
# Test

In [61]:
import random

# Hold out N_TEST_USERS users as a test set, drawn among "light" users: those with
# at most 100 ratings according to count_user (computed before the min-ratings filter).
# For a reproducible draw, the candidates are sorted and a dedicated seeded RNG is used.
# Without sorting, the order of `unique()` follows the row order of the ORM query,
# which may not be stable.
TEST_SEED = 42
N_TEST_USERS = 187
candidate_test_users = sorted(df.loc[df.count_user <= 100, 'user_id'].unique())
test_users = random.Random(TEST_SEED).sample(candidate_test_users, N_TEST_USERS)

In [62]:
#trainval_users, test_users = train_test_split(df['user_id'].unique(), test_size=0.1)

In [63]:
# How many users were held out for testing.
n_test_users = len(test_users)
n_test_users

187

In [64]:
#test_users = [1]

In [65]:
# All ratings of the held-out users, written as headerless
# (encoded user, encoded work, rating) rows.
is_test_user = df['user_id'].isin(test_users)
test = df.loc[is_test_user, ['encoded_user_id', 'encoded_work_id', 'rating']]
test.to_csv('/tmp/test.csv', header=False, index=False)

In [66]:
# Train and validation split

In [67]:
# Random 80/20 train/validation split over the ratings of the non-test users.
# A fixed random_state makes the split, and therefore the reported RMSEs, reproducible.
SPLIT_SEED = 42
train, val = train_test_split(
    df.query('user_id not in @test_users'),
    test_size=0.2,
    random_state=SPLIT_SEED,
)

In [68]:
# Training split as headerless (encoded user, encoded work, rating) rows.
train.to_csv('/tmp/train.csv', columns=['encoded_user_id', 'encoded_work_id', 'rating'], header=False, index=False)

In [69]:
# Validation split as headerless (encoded user, encoded work, rating) rows.
val.to_csv('/tmp/val.csv', columns=['encoded_user_id', 'encoded_work_id', 'rating'], header=False, index=False)

In [70]:
import yaml

# Model settings, written to /tmp/config.yml as block-style YAML and kept in
# `config` for the cells below.
with open('/tmp/config.yml', 'w') as f:
    config = {
        # Number of distinct users / works remaining after filtering.
        'USER_NUM': df['user_id'].nunique(),
        'ITEM_NUM': df['work_id'].nunique(),
        # Size of the full rating scale, not just the levels that happen to occur.
        # Labels are indices into rating_values, so a level with no ratings
        # (e.g. no 'wontsee') would otherwise leave the largest label out of range.
        'NB_CLASSES': len(rating_values),
        # The whole training split; presumably consumed as a batch size,
        # i.e. full-batch training — TODO confirm with the consumer.
        'BATCH_SIZE': len(train),
    }
    yaml.safe_dump(config, f, default_flow_style=False)

In [71]:
# Number of ratings in the train, validation and test splits.
tuple(len(split) for split in (train, val, test))

(250716, 62679, 8040)

In [72]:
from mangaki.algo.als import MangakiALS

In [73]:
# ALS model sized to the user/work counts recorded in `config`.
n_users, n_items = config['USER_NUM'], config['ITEM_NUM']
als = MangakiALS()
als.set_parameters(n_users, n_items)

In [74]:
import numpy as np

# Model inputs: X holds (encoded user, encoded work) index pairs, y the ratings.
RATING_PAIR_COLUMNS = ['encoded_user_id', 'encoded_work_id']


def _to_xy(split):
    """Return NumPy copies of a split's (user, work) index pairs and its ratings."""
    pairs = split[RATING_PAIR_COLUMNS].to_numpy(copy=True)
    ratings = split['rating'].to_numpy(copy=True)
    return pairs, ratings


X_train, y_train = _to_xy(train)
X_val, y_val = _to_xy(val)
als.fit(X_train, y_train)

Computing M: (1869 × 3409)


Chrono: fill and center matrix [4q, 2642ms]
Chrono: factor matrix [4q, 6799ms]


Shapes (1869, 20) (20, 3409)


In [75]:
# Report fit quality on the training data and on the validation split.
# NOTE(review): the second pair of arguments is the *validation* split (X_val, y_val),
# although the printed label reads "Test RMSE"; the held-out test users are not used here.
als.compute_all_errors(X_train, y_train, X_val, y_val)

Train RMSE=0.929643
Test RMSE=1.148684


In [76]:
# Lookup table: raw work id -> work title.
titles = {work_id: title for work_id, title in Work.objects.values_list('id', 'title')}

In [38]:
#titles

In [77]:
# Hard-coded list of works, presumably a best-first ranking produced by an experiment
# not shown in this notebook (TODO: record where it came from).
# NOTE(review): the values look like *encoded* work indices (all below 3409, the
# number of encoded works, and the commented-out check further down decodes them
# with decode_work) — confirm before relying on it.
top = [475, 1400, 1649, 1420, 1249, 1468, 2158, 1711, 985, 122, 361, 2182, 3361, 383, 310, 1242, 366, 1844, 1726, 812, 686, 2672, 195, 767, 1417, 1458, 452, 655, 86, 1207, 1408, 563, 3, 664, 568, 809, 1250, 159, 444, 104, 513, 1383, 413, 1763, 1422, 2064, 1384, 1320, 33, 423, 354, 1533, 1257, 1412, 1746, 234, 309, 116, 103, 806, 777, 1427, 319, 545, 860, 272, 386, 999, 994, 1834, 714, 470, 1385, 3305, 516, 2575, 51, 133, 1268, 1484, 1410, 728, 45, 225, 162, 3295, 554, 1344, 1781, 794, 277, 1155, 503, 1039, 549, 1933, 799, 311, 1414, 1392, 1216, 378, 722, 442, 1034, 804, 2457, 456, 52, 3388, 348, 970, 1054, 510, 364, 144, 675, 622, 1352, 267, 243, 1652, 389, 957, 63, 194, 836, 1382, 2481, 322, 748, 518, 208, 3043, 57, 586, 191, 3204, 142, 250, 3334, 1448, 330, 580, 218, 94, 658, 187, 776, 801, 150, 1980, 3392, 22, 1481, 85, 370, 1948, 178, 1407, 949, 449, 1632, 476, 865, 101, 566, 665, 76, 29, 404, 1374, 737, 53, 2229, 254, 2021, 3286, 179, 285, 154, 854, 506, 2655, 1429, 232, 77, 410, 1454, 182, 175, 3344, 197, 602, 2067, 344, 323, 169, 1868, 190, 3367, 368, 646, 687, 2110, 62, 1455, 457, 1139, 300, 261, 136, 174, 418, 167, 1026, 1413, 213, 640, 613, 207, 324, 491, 307, 7, 1105, 1124, 2830, 110, 265, 497, 1870, 555, 1411, 16, 128, 196, 132, 35, 37, 1245, 1488, 241, 775, 441, 1379, 8, 3400, 572, 557, 1432, 203, 282, 488, 770, 3355, 2856, 2017, 55, 540, 760, 3030, 450, 360, 807, 60, 1504, 1786, 291, 126, 242, 1451, 993, 216, 26, 1873, 25, 525, 381, 1943, 335, 2605, 3368, 1035, 1107, 329, 653, 158, 3042, 591, 263, 743, 1476, 553, 486, 99, 635, 2066, 2205, 824, 467, 346, 694, 345, 1332, 56, 332, 1893, 1554, 201, 259, 937, 629, 186, 36, 1487, 376, 1979, 2272, 721, 281, 2016, 1097, 365, 139, 424, 885, 3401, 501, 2495, 3315, 3195, 180, 815, 247, 573, 526, 877, 46, 121, 204, 280, 487, 356, 156, 305, 1867, 328, 102, 855, 1084, 1935, 231, 429, 130, 111, 107, 271, 321, 342, 88, 3356, 183, 546, 1409, 1251, 685, 331, 10, 351, 727, 32, 547, 1969, 3357, 3369, 428, 31, 1146, 2318, 
725, 97, 397, 761, 1, 914, 1340, 202, 802, 2146, 1280, 2035, 2512, 927, 20, 83, 477, 940, 215, 1179, 1768, 1461, 2274, 1079, 3307, 135, 416, 620, 205, 2100, 1011, 283, 935, 955, 4, 1881, 717, 1584, 1766, 463, 326, 358, 3380, 1225, 6, 333, 607, 228, 42, 1278, 484, 1475, 27, 308, 149, 435, 125, 1758, 320, 2710, 1614, 859, 74, 315, 472, 251, 11, 726, 690, 115, 1416, 304, 1936, 23, 1145, 1585, 3376, 705, 340, 1820, 997, 2023, 783, 230, 161, 1160, 100, 58, 206, 2089, 1955, 399, 108, 1462, 670, 2121, 2888, 958, 1133, 47, 612, 755, 44, 2513, 615, 87, 845, 2585, 512, 627, 1760, 3323, 124, 54, 2020, 1194, 313, 712, 245, 2019, 1201, 790, 1490, 1894, 357, 273, 238, 3026, 455, 1070, 2562, 98, 284, 334, 338, 779, 28, 474, 1581, 2065, 210, 1308, 509, 791, 377, 1378, 211, 105, 1491, 70, 209, 1561, 1515, 41, 2076, 430, 146, 192, 264, 2106, 172, 1635, 185, 226, 438, 803, 224, 853, 3003, 733, 112, 682, 1128, 64, 21, 127, 253, 1956, 991, 1078, 117, 13, 240, 43, 92, 548, 2835, 5, 1848, 293, 923, 84, 1782, 3364, 585, 95, 177, 0, 148, 2108, 113, 1631, 66, 395, 17, 385, 405, 387, 252, 114, 636, 454, 837, 48, 339, 165, 2968, 34, 82, 523, 237, 278, 564, 163, 30, 19, 171, 160, 236, 244, 188, 93, 118, 384, 288, 65, 170, 12, 59, 72, 9, 155]

In [78]:
# Inverse of encode_work: dense work index -> raw work id.
decode_work = dict(zip(encode_work.values(), encode_work.keys()))

In [80]:
#[(id_, titles[decode_work[id_]]) for id_ in top]

In [81]:
# Attach human-readable titles (raw work id -> title); unknown ids become NaN.
train = train.assign(title=train['work_id'].map(titles))

In [82]:
# First five training ratings of the most-rated works.
train.sort_values(by='count_work', ascending=False).iloc[:5]

Unnamed: 0,user_id,work_id,choice,rating,count_user,count_work,encoded_user_id,encoded_work_id,title
24043,314,1,like,4,444,1663,245,0,Death Note
89389,700,1,like,4,129,1663,582,0,Death Note
280499,2012,1,like,4,285,1663,1595,0,Death Note
17739,20,1,neutral,2,273,1663,212,0,Death Note
272889,1933,1,like,4,74,1663,1537,0,Death Note
