In [1]:
# !pip install scikit-surprise
# 안되면 anaconda prompt에서 conda install -c conda-forge scikit-surprise

## surprise 라이브러리 사용

In [2]:
import surprise

In [3]:
print(surprise.__version__)

1.1.1


In [4]:
from surprise import SVD 
## SGD(Stochastic 통계적인 Gradient 경사 하강법)
from surprise import Dataset, Reader 
from surprise import accuracy 
from surprise.model_selection import train_test_split 

## 1. 데이터셋 생성(사용자, 아이템, 평점) --> 훈련/검증 데이터 분류

In [5]:
data = Dataset.load_builtin('ml-100k')
data

<surprise.dataset.DatasetAutoFolds at 0x1758edefcf8>

In [6]:
trainset, testset = train_test_split(data, test_size=0.25, random_state=0)

In [7]:
trainset

<surprise.trainset.Trainset at 0x1758ee00a58>

In [8]:
testset[1] # 882번 사람이 291번 영화에 4.0 평점을 부여함

('882', '291', 4.0)

In [9]:
testset[:5]

[('120', '282', 4.0),
 ('882', '291', 4.0),
 ('535', '507', 5.0),
 ('697', '244', 5.0),
 ('751', '385', 4.0)]

In [10]:
algo = SVD()

In [11]:
algo.fit(trainset)

<surprise.prediction_algorithms.matrix_factorization.SVD at 0x1759200a518>

## prediction의 예측값 출력

In [12]:
prediction = algo.test(testset)[1]
prediction # 하나의 영화에 대한 예측값은 Prediction 객체

Prediction(uid='882', iid='291', r_ui=4.0, est=3.9549514108285253, details={'was_impossible': False})

In [13]:
prediction

Prediction(uid='882', iid='291', r_ui=4.0, est=3.9549514108285253, details={'was_impossible': False})

In [14]:
'userid: ' + prediction.uid + ', iid: ' + prediction.iid + ' ==> est: ' + str(prediction.est)

'userid: 882, iid: 291 ==> est: 3.9549514108285253'

## predictions의 모든 예측값 출력

In [15]:
predictions = algo.test(testset)
predictions

[Prediction(uid='120', iid='282', r_ui=4.0, est=3.7959475963045897, details={'was_impossible': False}),
 Prediction(uid='882', iid='291', r_ui=4.0, est=3.9549514108285253, details={'was_impossible': False}),
 Prediction(uid='535', iid='507', r_ui=5.0, est=4.07967667715781, details={'was_impossible': False}),
 Prediction(uid='697', iid='244', r_ui=5.0, est=3.636609624656571, details={'was_impossible': False}),
 Prediction(uid='751', iid='385', r_ui=4.0, est=3.322062256767763, details={'was_impossible': False}),
 Prediction(uid='219', iid='82', r_ui=1.0, est=3.614798016694083, details={'was_impossible': False}),
 Prediction(uid='279', iid='571', r_ui=4.0, est=2.482103682192764, details={'was_impossible': False}),
 Prediction(uid='429', iid='568', r_ui=3.0, est=3.118785996715596, details={'was_impossible': False}),
 Prediction(uid='456', iid='100', r_ui=3.0, est=3.996889614491409, details={'was_impossible': False}),
 Prediction(uid='249', iid='23', r_ui=4.0, est=4.728036138776976, details

In [16]:
count = len(predictions)
for i in range(count):
    print('userid: ' + predictions[i].uid + ', iid: ' + predictions[i].iid + ' ==> est: ' + str(predictions[i].est))

userid: 120, iid: 282 ==> est: 3.7959475963045897
userid: 882, iid: 291 ==> est: 3.9549514108285253
userid: 535, iid: 507 ==> est: 4.07967667715781
userid: 697, iid: 244 ==> est: 3.636609624656571
userid: 751, iid: 385 ==> est: 3.322062256767763
userid: 219, iid: 82 ==> est: 3.614798016694083
userid: 279, iid: 571 ==> est: 2.482103682192764
userid: 429, iid: 568 ==> est: 3.118785996715596
userid: 456, iid: 100 ==> est: 3.996889614491409
userid: 249, iid: 23 ==> est: 4.728036138776976
userid: 493, iid: 183 ==> est: 4.277892335829548
userid: 325, iid: 469 ==> est: 3.514769323414671
userid: 631, iid: 682 ==> est: 2.9526246040882747
userid: 276, iid: 121 ==> est: 3.446303993992043
userid: 269, iid: 405 ==> est: 1.7910492220882657
userid: 159, iid: 1095 ==> est: 2.816007289083463
userid: 385, iid: 965 ==> est: 3.414846828985236
userid: 21, iid: 358 ==> est: 2.174934137117569
userid: 181, iid: 1359 ==> est: 1.7227001580523411
userid: 561, iid: 124 ==> est: 3.744933943082304
userid: 658, iid:

userid: 358, iid: 357 ==> est: 4.111772044029601
userid: 405, iid: 1421 ==> est: 1.378866498732323
userid: 425, iid: 11 ==> est: 3.8265529061132817
userid: 901, iid: 287 ==> est: 4.1359837901879954
userid: 23, iid: 662 ==> est: 3.4147392622625663
userid: 541, iid: 121 ==> est: 3.7847548984201542
userid: 601, iid: 208 ==> est: 3.3165751343016914
userid: 378, iid: 572 ==> est: 3.0432628427363952
userid: 299, iid: 150 ==> est: 3.4622976537860155
userid: 160, iid: 228 ==> est: 3.9014796991875205
userid: 860, iid: 257 ==> est: 3.37157638579043
userid: 802, iid: 217 ==> est: 3.3002012269708167
userid: 798, iid: 953 ==> est: 3.255012609565337
userid: 451, iid: 690 ==> est: 2.8306156962966256
userid: 919, iid: 315 ==> est: 4.15425762138948
userid: 890, iid: 286 ==> est: 4.57252488959317
userid: 308, iid: 198 ==> est: 4.096909839750881
userid: 560, iid: 12 ==> est: 4.4506787238870205
userid: 766, iid: 135 ==> est: 4.233435597468365
userid: 237, iid: 58 ==> est: 3.4981992447817953
userid: 361, i

userid: 71, iid: 744 ==> est: 3.340904709223897
userid: 269, iid: 212 ==> est: 3.334852434822706
userid: 734, iid: 318 ==> est: 4.166129655420642
userid: 450, iid: 705 ==> est: 4.376299463940197
userid: 174, iid: 160 ==> est: 3.5472601803165245
userid: 122, iid: 470 ==> est: 4.02937616949451
userid: 16, iid: 99 ==> est: 4.472865756894874
userid: 432, iid: 108 ==> est: 3.257898772254169
userid: 655, iid: 294 ==> est: 2.265543254801176
userid: 27, iid: 288 ==> est: 3.0225710601968268
userid: 327, iid: 86 ==> est: 3.96695191304302
userid: 58, iid: 204 ==> est: 3.6462711100253635
userid: 454, iid: 607 ==> est: 3.3970849708003175
userid: 381, iid: 514 ==> est: 4.144643213587491
userid: 738, iid: 343 ==> est: 3.086608250478119
userid: 753, iid: 215 ==> est: 3.4193733778552065
userid: 399, iid: 1139 ==> est: 2.361113702850614
userid: 554, iid: 125 ==> est: 3.6725623723227416
userid: 387, iid: 189 ==> est: 4.326497410237293
userid: 927, iid: 1095 ==> est: 2.9415483199103893
userid: 21, iid: 26

userid: 620, iid: 28 ==> est: 4.266589262238276
userid: 776, iid: 234 ==> est: 3.5522919772905452
userid: 249, iid: 930 ==> est: 3.1475929062295633
userid: 13, iid: 758 ==> est: 1.7181223744928982
userid: 21, iid: 990 ==> est: 2.43010992161984
userid: 429, iid: 768 ==> est: 3.100732953494247
userid: 318, iid: 631 ==> est: 3.6701647724619977
userid: 59, iid: 313 ==> est: 4.637949876847741
userid: 752, iid: 678 ==> est: 2.7811685373962165
userid: 916, iid: 959 ==> est: 3.384461104585467
userid: 798, iid: 660 ==> est: 3.9370772449876332
userid: 48, iid: 50 ==> est: 4.545234818949134
userid: 94, iid: 268 ==> est: 4.338364955056246
userid: 870, iid: 111 ==> est: 3.388941071186113
userid: 655, iid: 153 ==> est: 3.1642249700728784
userid: 688, iid: 309 ==> est: 4.020177292815483
userid: 236, iid: 179 ==> est: 3.301187226636316
userid: 748, iid: 118 ==> est: 3.1037585602240148
userid: 72, iid: 553 ==> est: 3.368879406609528
userid: 174, iid: 396 ==> est: 3.0864165939448815
userid: 363, iid: 45

userid: 334, iid: 191 ==> est: 4.166584145415955
userid: 30, iid: 231 ==> est: 3.0099386415109826
userid: 843, iid: 135 ==> est: 3.228340178323175
userid: 896, iid: 320 ==> est: 3.4592325620560924
userid: 823, iid: 433 ==> est: 3.8043189853276473
userid: 334, iid: 855 ==> est: 3.829059151681647
userid: 193, iid: 352 ==> est: 2.0327735145083294
userid: 218, iid: 273 ==> est: 3.5485431138525017
userid: 383, iid: 197 ==> est: 4.399328451142648
userid: 92, iid: 428 ==> est: 3.71296836194424
userid: 275, iid: 102 ==> est: 2.6557435733645542
userid: 747, iid: 672 ==> est: 3.219262770024675
userid: 3, iid: 336 ==> est: 2.4972542222810783
userid: 557, iid: 96 ==> est: 4.026423366056207
userid: 42, iid: 48 ==> est: 3.64519002181001
userid: 737, iid: 475 ==> est: 3.767107292043857
userid: 655, iid: 966 ==> est: 3.2173417799612576
userid: 189, iid: 485 ==> est: 3.7723586369855746
userid: 243, iid: 25 ==> est: 3.0007157144912817
userid: 880, iid: 864 ==> est: 3.208919537280458
userid: 424, iid: 30

userid: 367, iid: 234 ==> est: 4.634015067457049
userid: 628, iid: 8 ==> est: 5
userid: 110, iid: 575 ==> est: 2.3217838935387682
userid: 446, iid: 888 ==> est: 3.01327906859924
userid: 606, iid: 833 ==> est: 3.354120507735727
userid: 85, iid: 1166 ==> est: 3.3594233970406058
userid: 429, iid: 128 ==> est: 3.429345261761355
userid: 795, iid: 1041 ==> est: 2.8425698752680977
userid: 580, iid: 300 ==> est: 3.6404067800083535
userid: 64, iid: 718 ==> est: 3.5381104083839543
userid: 653, iid: 576 ==> est: 2.273044964240741
userid: 940, iid: 213 ==> est: 3.2186872872995265
userid: 94, iid: 386 ==> est: 3.2044400316302757
userid: 276, iid: 794 ==> est: 3.702843397558769
userid: 151, iid: 724 ==> est: 4.077276851167389
userid: 585, iid: 30 ==> est: 3.9732499901202507
userid: 907, iid: 86 ==> est: 4.786196604472548
userid: 847, iid: 261 ==> est: 1.986087162976733
userid: 263, iid: 50 ==> est: 4.446510493583169
userid: 186, iid: 546 ==> est: 3.0505415204807607
userid: 327, iid: 582 ==> est: 3.5

userid: 627, iid: 123 ==> est: 2.673433226970249
userid: 623, iid: 258 ==> est: 4.019073452771777
userid: 217, iid: 562 ==> est: 2.6528052984270567
userid: 919, iid: 988 ==> est: 2.5746821301033345
userid: 936, iid: 1 ==> est: 4.226795863180956
userid: 791, iid: 286 ==> est: 3.777959354800279
userid: 472, iid: 97 ==> est: 4.714044315645851
userid: 486, iid: 628 ==> est: 3.4418842218483077
userid: 233, iid: 135 ==> est: 4.358342983458018
userid: 314, iid: 274 ==> est: 4.413744439391673
userid: 624, iid: 310 ==> est: 3.9041335051332267
userid: 199, iid: 408 ==> est: 4.003185705040752
userid: 838, iid: 318 ==> est: 4.8237394676261225
userid: 92, iid: 474 ==> est: 4.2560679442949905
userid: 486, iid: 696 ==> est: 2.9187398164719385
userid: 91, iid: 651 ==> est: 4.514152223075463
userid: 126, iid: 289 ==> est: 3.2650613483640627
userid: 870, iid: 327 ==> est: 3.47623242209342
userid: 372, iid: 176 ==> est: 4.660523163736046
userid: 829, iid: 10 ==> est: 3.952691919852149
userid: 236, iid: 1

userid: 862, iid: 117 ==> est: 4.282384536335987
userid: 711, iid: 1170 ==> est: 3.702880598896247
userid: 472, iid: 417 ==> est: 4.265616324993927
userid: 269, iid: 387 ==> est: 2.9344431533470887
userid: 160, iid: 32 ==> est: 4.072471260750563
userid: 242, iid: 1137 ==> est: 4.358796709347182
userid: 833, iid: 168 ==> est: 3.7227232958658147
userid: 101, iid: 1 ==> est: 3.5745834259787945
userid: 249, iid: 789 ==> est: 4.155107535582703
userid: 314, iid: 410 ==> est: 3.422861110076325
userid: 650, iid: 642 ==> est: 2.968845805651205
userid: 297, iid: 514 ==> est: 3.672916811877294
userid: 804, iid: 969 ==> est: 4.105321260374246
userid: 500, iid: 319 ==> est: 2.731249768732139
userid: 256, iid: 294 ==> est: 3.979341071349506
userid: 108, iid: 290 ==> est: 3.359207193871104
userid: 200, iid: 323 ==> est: 3.7007629592564335
userid: 918, iid: 658 ==> est: 2.8659475165961528
userid: 322, iid: 192 ==> est: 3.97839626453526
userid: 761, iid: 222 ==> est: 3.6535263747465634
userid: 639, iid

userid: 264, iid: 26 ==> est: 3.921225532638646
userid: 303, iid: 191 ==> est: 4.753961788119978
userid: 56, iid: 227 ==> est: 3.325718726501087
userid: 456, iid: 216 ==> est: 3.6781447521497292
userid: 711, iid: 167 ==> est: 3.235749369736096
userid: 727, iid: 431 ==> est: 3.549482432483598
userid: 454, iid: 118 ==> est: 3.13472368757205
userid: 344, iid: 845 ==> est: 3.3531972219611585
userid: 416, iid: 468 ==> est: 3.855592498338903
userid: 60, iid: 209 ==> est: 4.369660119165162
userid: 893, iid: 849 ==> est: 3.1147888432057105
userid: 519, iid: 751 ==> est: 4.418731478825567
userid: 788, iid: 68 ==> est: 3.332909028660842
userid: 228, iid: 313 ==> est: 3.783089681207642
userid: 43, iid: 140 ==> est: 3.1458308121816367
userid: 519, iid: 352 ==> est: 2.922421628016395
userid: 222, iid: 1226 ==> est: 2.9187317503500774
userid: 663, iid: 693 ==> est: 3.5176800043850007
userid: 65, iid: 476 ==> est: 3.144558780998077
userid: 557, iid: 289 ==> est: 3.3559272794701185
userid: 62, iid: 15

userid: 342, iid: 25 ==> est: 3.2739048637555075
userid: 524, iid: 495 ==> est: 3.7726460511448616
userid: 862, iid: 197 ==> est: 4.857402454238013
userid: 509, iid: 181 ==> est: 3.3870448805598032
userid: 705, iid: 1035 ==> est: 3.1745045412879556
userid: 64, iid: 392 ==> est: 3.5505247245975644
userid: 398, iid: 186 ==> est: 4.028780610651345
userid: 747, iid: 238 ==> est: 4.340453645003487
userid: 11, iid: 726 ==> est: 2.844992747396157
userid: 862, iid: 97 ==> est: 4.434180467847278
userid: 497, iid: 1042 ==> est: 2.990955097735918
userid: 610, iid: 276 ==> est: 3.68523813897359
userid: 907, iid: 696 ==> est: 4.110348080970314
userid: 207, iid: 428 ==> est: 3.3895744590760475
userid: 875, iid: 176 ==> est: 4.501239548747246
userid: 425, iid: 307 ==> est: 3.0533465122618764
userid: 60, iid: 73 ==> est: 3.712511762713209
userid: 622, iid: 120 ==> est: 2.3561771230287385
userid: 514, iid: 50 ==> est: 4.550283069209722
userid: 903, iid: 129 ==> est: 4.363134798947403
userid: 500, iid: 

userid: 363, iid: 408 ==> est: 4.024488085883538
userid: 804, iid: 215 ==> est: 4.0411614460828185
userid: 355, iid: 310 ==> est: 4.144092802673971
userid: 629, iid: 132 ==> est: 4.704647311819855
userid: 456, iid: 1198 ==> est: 3.5065301573967775
userid: 747, iid: 952 ==> est: 3.5657890931253795
userid: 683, iid: 346 ==> est: 3.3789986315628373
userid: 605, iid: 127 ==> est: 4.478534911964505
userid: 716, iid: 227 ==> est: 3.5650154285310784
userid: 574, iid: 258 ==> est: 3.532049699971539
userid: 85, iid: 186 ==> est: 3.3884422149913123
userid: 56, iid: 946 ==> est: 3.402364010962852
userid: 655, iid: 295 ==> est: 2.7861320706299244
userid: 823, iid: 229 ==> est: 3.4752181485992475
userid: 614, iid: 476 ==> est: 2.548879640956798
userid: 385, iid: 1128 ==> est: 3.559297599838999
userid: 395, iid: 1028 ==> est: 3.3055758358540044
userid: 437, iid: 286 ==> est: 3.9143437461316823
userid: 348, iid: 546 ==> est: 3.7562020065681194
userid: 184, iid: 70 ==> est: 3.601194763726628
userid: 4

userid: 347, iid: 660 ==> est: 3.648280360377281
userid: 642, iid: 926 ==> est: 3.3434068188895
userid: 41, iid: 156 ==> est: 3.92866008391883
userid: 904, iid: 1041 ==> est: 3.4617950848555634
userid: 406, iid: 274 ==> est: 3.2258788290274754
userid: 82, iid: 458 ==> est: 2.6063367038898955
userid: 798, iid: 480 ==> est: 4.1092082059090345
userid: 13, iid: 792 ==> est: 3.625831785163151
userid: 683, iid: 690 ==> est: 3.215424126521707
userid: 301, iid: 521 ==> est: 3.8524350142791843
userid: 798, iid: 929 ==> est: 2.85730319613462
userid: 333, iid: 255 ==> est: 3.67381636360999
userid: 777, iid: 205 ==> est: 4.268078486535113
userid: 279, iid: 597 ==> est: 3.1897171109647156
userid: 881, iid: 420 ==> est: 3.0898573706318135
userid: 95, iid: 209 ==> est: 3.7571029345765754
userid: 285, iid: 151 ==> est: 3.6830154436006337
userid: 527, iid: 425 ==> est: 3.8161643238236698
userid: 130, iid: 1089 ==> est: 2.8677160186963158
userid: 42, iid: 273 ==> est: 3.4274764131843107
userid: 189, iid

userid: 250, iid: 288 ==> est: 2.9260436010680904
userid: 381, iid: 432 ==> est: 3.9802938297401966
userid: 575, iid: 603 ==> est: 3.5680444650216314
userid: 707, iid: 694 ==> est: 4.036580803541164
userid: 405, iid: 72 ==> est: 1.882936041171695
userid: 109, iid: 410 ==> est: 3.5238817898793533
userid: 298, iid: 357 ==> est: 4.265237484722504
userid: 624, iid: 240 ==> est: 2.665845938223514
userid: 567, iid: 39 ==> est: 3.52745828859406
userid: 342, iid: 607 ==> est: 3.573209171586729
userid: 137, iid: 289 ==> est: 4.060669181977013
userid: 65, iid: 238 ==> est: 3.7184488649153793
userid: 333, iid: 127 ==> est: 4.491917645585454
userid: 201, iid: 302 ==> est: 3.709300909244103
userid: 222, iid: 29 ==> est: 2.275490411323979
userid: 13, iid: 610 ==> est: 3.169758241694263
userid: 331, iid: 214 ==> est: 3.3857718342010314
userid: 894, iid: 1658 ==> est: 3.435825532430282
userid: 736, iid: 253 ==> est: 3.098317116554852
userid: 898, iid: 343 ==> est: 3.082246112982255
userid: 145, iid: 5

In [17]:
result = [(one.uid, one.iid, str(one.est)) for one in predictions]

In [18]:
len(result)

25000

## 오차의 평균값 구하기

In [19]:
accuracy.rmse(predictions) # 오차의 평균값

RMSE: 0.9477


0.9476613425962855

## 교차 검증

In [20]:
from surprise.model_selection import cross_validate

In [21]:
cross_validate(algo, data, measures = ['RMSE', 'MAE'], cv=5, verbose=True)

Evaluating RMSE, MAE of algorithm SVD on 5 split(s).

                  Fold 1  Fold 2  Fold 3  Fold 4  Fold 5  Mean    Std     
RMSE (testset)    0.9349  0.9370  0.9475  0.9346  0.9330  0.9374  0.0052  
MAE (testset)     0.7363  0.7371  0.7481  0.7364  0.7374  0.7391  0.0045  
Fit time          3.98    4.23    4.06    3.74    3.73    3.95    0.19    
Test time         0.12    0.12    0.15    0.14    0.11    0.13    0.02    


{'test_rmse': array([0.93488404, 0.93701793, 0.94746031, 0.93459361, 0.93304881]),
 'test_mae': array([0.73628203, 0.73710922, 0.74809569, 0.73640654, 0.73742474]),
 'fit_time': (3.9813618659973145,
  4.233675479888916,
  4.0571653842926025,
  3.7360167503356934,
  3.727041721343994),
 'test_time': (0.12201666831970215,
  0.12165546417236328,
  0.1515672206878662,
  0.14162254333496094,
  0.10671377182006836)}

## surprise를 사용한 추천시스템 구현

### 평점을 예측해서 추천해주기

In [22]:
import pandas as pd

In [23]:
movies = pd.read_csv('./data_list/movies.csv')
movies

Unnamed: 0,movieId,title,genres
0,1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy
1,2,Jumanji (1995),Adventure|Children|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama|Romance
4,5,Father of the Bride Part II (1995),Comedy
...,...,...,...
9737,193581,Black Butler: Book of the Atlantic (2017),Action|Animation|Comedy|Fantasy
9738,193583,No Game No Life: Zero (2017),Animation|Comedy|Fantasy
9739,193585,Flint (2017),Drama
9740,193587,Bungo Stray Dogs: Dead Apple (2018),Action|Animation


In [24]:
ratings = pd.read_csv('./data_list/ratings.csv')
ratings

Unnamed: 0,userId,movieId,rating,timestamp
0,1,1,4.0,964982703
1,1,3,4.0,964981247
2,1,6,4.0,964982224
3,1,47,5.0,964983815
4,1,50,5.0,964982931
...,...,...,...,...
100831,610,166534,4.0,1493848402
100832,610,168248,5.0,1493850091
100833,610,168250,5.0,1494273047
100834,610,168252,5.0,1493846352


In [25]:
movie_list_uid9 = ratings[ratings['userId'] == 9]
movie_list_uid9

Unnamed: 0,userId,movieId,rating,timestamp
1073,9,41,3.0,1044656650
1074,9,187,3.0,1044657119
1075,9,223,4.0,1044656650
1076,9,371,3.0,1044656716
1077,9,627,3.0,1044657102
1078,9,922,4.0,1044657026
1079,9,923,5.0,1044657026
1080,9,1037,2.0,1044656650
1081,9,1095,4.0,1044657088
1082,9,1198,5.0,1044656716


In [26]:
movie_list_uid9.shape # uid가 9인 사람은 46개의 영화를 봄

(46, 4)

In [27]:
# 보지 않은 영화 중에서 추천을 해주려고 함 

In [28]:
# 1) 전체 영화 수
movie_count = movies.shape[0]
movie_count

9742

In [29]:
# 2) 평점 매긴 영화 수
rat_uid9_count = movie_list_uid9.shape[0]
rat_uid9_count

46

In [30]:
# 3) 보지 않은 영화 수 (추천 대상 영화 수)
movie_count - rat_uid9_count

9696

In [31]:
movie_list_uid9[movie_list_uid9['movieId'] == 42] # 42번은 평점의 기록이 없음

Unnamed: 0,userId,movieId,rating,timestamp


In [32]:
# 42번 영화는 어떤 영화였을까?
movies[movies['movieId'] == 42]

Unnamed: 0,movieId,title,genres
38,42,Dead Presidents (1995),Action|Crime|Drama


In [33]:
uid = str(9)
iid = str(42)

In [34]:
pred = algo.predict(uid, iid, verbose=True)
pred

user: 9          item: 42         r_ui = None   est = 4.16   {'was_impossible': False}


Prediction(uid='9', iid='42', r_ui=None, est=4.155300757181799, details={'was_impossible': False})

In [35]:
import numpy as np

In [36]:
# 액션 장르 중 최대 평점이 예측되는 영화 찾기

In [37]:
action_movies = movies[movies['genres'] == 'Action|Crime|Drama']
len(action_movies)

50

In [38]:
result_uid9 = [(algo.predict('9', str(one), verbose=True)) for one in action_movies['movieId']]
result_uid9[:1]

user: 9          item: 42         r_ui = None   est = 4.16   {'was_impossible': False}
user: 9          item: 384        r_ui = None   est = 3.31   {'was_impossible': False}
user: 9          item: 390        r_ui = None   est = 3.78   {'was_impossible': False}
user: 9          item: 493        r_ui = None   est = 4.73   {'was_impossible': False}
user: 9          item: 694        r_ui = None   est = 4.71   {'was_impossible': False}
user: 9          item: 875        r_ui = None   est = 3.11   {'was_impossible': False}
user: 9          item: 2194       r_ui = None   est = 3.87   {'was_impossible': False}
user: 9          item: 2281       r_ui = None   est = 3.87   {'was_impossible': False}
user: 9          item: 3430       r_ui = None   est = 3.87   {'was_impossible': False}
user: 9          item: 3442       r_ui = None   est = 3.87   {'was_impossible': False}
user: 9          item: 4065       r_ui = None   est = 3.87   {'was_impossible': False}
user: 9          item: 4203       r_ui = No

[Prediction(uid='9', iid='42', r_ui=None, est=4.155300757181799, details={'was_impossible': False})]

In [39]:
result_uid9_action_est = [pred.est for pred in result_uid9]
len(result_uid9_action_est)

50

In [40]:
result_uid9_action_est[:5], result_uid9_action_est[46:]

([4.155300757181799,
  3.308443692828777,
  3.7767443813145456,
  4.730617316298199,
  4.71082417550521],
 [3.8739979556748074,
  3.8739979556748074,
  3.8739979556748074,
  3.8739979556748074])

In [41]:
max_num = np.argmax(result_uid9_action_est)
max_num

3

In [42]:
max_value = result_uid9_action_est[4]
max_value

4.71082417550521

In [43]:
[one for one in result_uid9_action_est if one == max_value]

[4.71082417550521]

In [44]:
max_idx = result_uid9_action_est.count(max_value)
max_idx

1

In [45]:
result_uid9[max_num]

Prediction(uid='9', iid='493', r_ui=None, est=4.730617316298199, details={'was_impossible': False})

In [46]:
max_iid = result_uid9[max_num][1]
max_iid

'493'

In [48]:
## 영화id는 384번
movies[movies['movieId'] == 493]

Unnamed: 0,movieId,title,genres
430,493,Menace II Society (1993),Action|Crime|Drama


In [49]:
# csv로 fit를 시킬 경우
# csv를 읽어서 fit시킬 수 잇는 형태의 객체로 만들어주어야 하고, DatasetAutoFolds
# train, testset으로 split

In [50]:
reader = Reader(rating_scale=(0.5, 5.0)) # 0.5단위, 최대값은 5.0의 범위 사용

In [51]:
data2 = Dataset.load_from_df(ratings[['userId', 'movieId', 'rating']], reader)
data2

<surprise.dataset.DatasetAutoFolds at 0x1758dcdd4a8>

In [59]:
trainset2, testset2 = train_test_split(data2, test_size=0.25, random_state=0)
trainset2

<surprise.trainset.Trainset at 0x1759200acc0>

In [60]:
testset2[:1]

[(63, 2000, 3.0)]

In [61]:
algo.fit(trainset2)

<surprise.prediction_algorithms.matrix_factorization.SVD at 0x1759200a518>

In [62]:
prediction3 = algo.test(testset2)
prediction3[:3]

[Prediction(uid=63, iid=2000, r_ui=3.0, est=3.8335972584915807, details={'was_impossible': False}),
 Prediction(uid=31, iid=788, r_ui=2.0, est=3.234025249255524, details={'was_impossible': False}),
 Prediction(uid=159, iid=6373, r_ui=4.0, est=2.8636176642618865, details={'was_impossible': False})]

In [63]:
accuracy.rmse(prediction3)

RMSE: 0.8704


0.8704112167995456

In [64]:
cross_validate(algo, data2, measures = ['RMSE', 'MAE'], cv=5, verbose=True)

Evaluating RMSE, MAE of algorithm SVD on 5 split(s).

                  Fold 1  Fold 2  Fold 3  Fold 4  Fold 5  Mean    Std     
RMSE (testset)    0.8827  0.8689  0.8742  0.8662  0.8769  0.8738  0.0059  
MAE (testset)     0.6794  0.6691  0.6728  0.6658  0.6713  0.6717  0.0045  
Fit time          3.90    3.96    4.00    3.97    3.94    3.95    0.03    
Test time         0.11    0.18    0.11    0.11    0.17    0.14    0.03    


{'test_rmse': array([0.88268926, 0.86892093, 0.87424647, 0.86615701, 0.8769211 ]),
 'test_mae': array([0.67936985, 0.66909053, 0.67278582, 0.66575924, 0.67130311]),
 'fit_time': (3.901559352874756,
  3.963409185409546,
  3.996314525604248,
  3.967418670654297,
  3.942453622817993),
 'test_time': (0.11272192001342773,
  0.18051719665527344,
  0.1077117919921875,
  0.10971283912658691,
  0.1735365390777588)}