In [1]:
import math
import pickle

import numpy as np
import pandas as pd
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import StratifiedKFold


## Read data

In [2]:
# Feature matrix: row i is the feature vector of item i (rows are indexed by item id
# throughout this notebook, e.g. features[item]).
features = np.loadtxt('../../data/usr_his_sample/features.txt')

In [3]:
# Sanity check: (number of feature rows, feature dimension).
features.shape

(9817, 784)

In [4]:
# Integer row indices into `features` of the items to cluster (passed to find_clusters below).
# NOTE(review): the saved output shows 9816 items but `features` has 9817 rows — confirm
# whether one feature row is intentionally excluded.
item_list = np.loadtxt('../../data/usr_his_sample/item_list.txt').astype(int)
len(item_list)

9816

In [5]:
# User histories: each user is a collection of item indices into `features`.
# NOTE: pickle.load can execute arbitrary code while unpickling — only load this
# file if it comes from a trusted source.
with open('../../data/usr_his_sample/usr_list_5-20.txt', 'rb') as fp:
    usr_list = pickle.load(fp)
len(usr_list)

44709

## Functions

In [6]:
class Cluster:
    """Parameters of one Gaussian mixture component.

    Attributes:
        mean: component mean (passed as `mean` to scipy's multivariate_normal).
        cov: component covariance (passed as `cov`).
        weight: factor scaling this component's contribution; defaults to 1.
    """

    def __init__(self, mean, cov, weight=1):
        self.mean, self.cov, self.weight = mean, cov, weight

In [7]:
#public
def cal_logp_item_cluster(feature_item, cluster):
    """Log-density of `feature_item` under N(cluster.mean, cluster.cov).

    `feature_item` may be a single vector or, as scipy allows, a 2-D array of
    vectors, in which case one log-density per row is returned.
    """
    mu = cluster.mean
    sigma = cluster.cov
    return multivariate_normal.logpdf(feature_item, mu, sigma)

#not public
def cal_logp_user_cluster(user, features, cluster):
    """Total log-density of all of a user's items under one cluster's Gaussian.

    Same result as summing cal_logp_item_cluster over the user's items, but all
    item vectors are evaluated in a single multivariate_normal.logpdf call, so
    the covariance is decomposed once per call instead of once per item (that
    decomposition is the costly step for a large full covariance).

    Args:
        user: iterable of integer row indices into `features`.
        features: array whose row i is the feature vector of item i.
        cluster: object with `mean` and `cov` attributes.

    Returns:
        Sum of the per-item log-densities; 0.0 for a user with no items.
    """
    # list() first so sets / generators work too; intp keeps an empty index valid.
    idx = np.asarray(list(user), dtype=np.intp)
    if idx.size == 0:
        return np.float64(0.0)  # matches np.sum([]) of the per-item formulation
    logps = multivariate_normal.logpdf(features[idx], mean=cluster.mean, cov=cluster.cov)
    # logpdf squeezes a single-row input to a scalar; np.sum handles both shapes.
    return np.sum(logps)

#public
def cal_ptrans_user(user, features, clusters):
    """Weight-scaled sum of a user's per-cluster log-likelihoods.

    Returns sum over k of clusters[k].weight * cal_logp_user_cluster(user, features, clusters[k]).

    NOTE(review): this is a weighted sum of log-likelihoods, not the mixture
    log-likelihood log(sum_k w_k * p_k(user)); when the weights sum to 1 it is
    a Jensen lower bound on it. Confirm this is the intended score.
    """
    terms = []
    for cl in clusters:
        user_ll = cal_logp_user_cluster(user, features, cl)
        terms.append(cl.weight * user_ll)
    return np.sum(terms)

#public
def cal_ptrans_item_user(item, user, features, clusters):
    """Sum over clusters of log p(item | cluster) + log p(user's items | cluster).

    NOTE(review): unlike cal_ptrans_user, cluster weights are not applied here;
    the per-cluster log-likelihoods are simply added. Confirm this is intended.
    """
    per_cluster = []
    for cl in clusters:
        item_ll = cal_logp_item_cluster(features[item], cl)
        user_ll = cal_logp_user_cluster(user, features, cl)
        per_cluster.append(item_ll + user_ll)
    return np.sum(per_cluster)

In [8]:
def find_clusters(items, features, n_cluster, seed=None):
    '''
    Fit a full-covariance Gaussian mixture to the feature vectors of `items`.

    The mixture's initial means are the per-group averages of a random
    assignment of the items to `n_cluster` groups.

    Parameters
    ----------
    items : sequence of int
        Row indices into `features` selecting the items to cluster
        (a list or a numpy array).
    features : 2-D array
        Feature matrix; row i is the feature vector of item i.
    n_cluster : int
        Number of mixture components.
    seed : int or None, optional
        Seed for the random initial assignment. With None (the default) the
        global numpy RNG is used, as before, so results vary between runs
        unless np.random.seed was called. Pass an int for a reproducible
        initialisation.

    Returns
    -------
    GaussianMixture
        The fitted model; its means_, covariances_ and weights_ describe the
        clusters.

    Raises
    ------
    ValueError
        If `items` is empty.
    '''
    # An array is required for the boolean-mask selection below; plain lists would fail.
    items = np.asarray(items, dtype=int)
    X = features[items]
    if X.shape[0] == 0:
        raise ValueError('find_clusters: no items to cluster')

    rng = np.random if seed is None else np.random.RandomState(seed)
    label_init = rng.randint(0, n_cluster, len(items))

    # NOTE(review): with a uniformly random labelling every group mean tends
    # towards the global mean, so this initialisation barely separates the
    # components; consider dropping means_init and relying on init_params.
    means_init = np.empty((n_cluster, X.shape[1]))
    for k in range(n_cluster):
        members = X[label_init == k]
        if len(members):
            means_init[k] = members.mean(axis=0)
        else:
            # An empty group would give a NaN mean and break the fit; start it at a random item.
            means_init[k] = X[rng.randint(0, len(X))]

    gmm = GaussianMixture(n_components=n_cluster, covariance_type='full', max_iter=100,
                          random_state=0, means_init=means_init)
    gmm.fit(X)
    return gmm

In [9]:
def validate(gmm, usr_list, features, verbose=True):
    '''
    Mean per-user score of `usr_list` under a fitted Gaussian mixture.

    Each fitted component is wrapped as a Cluster, and every user is scored
    with cal_ptrans_user(usr, features, clusters).

    Parameters
    ----------
    gmm : fitted GaussianMixture
        Only its means_, covariances_ and weights_ are read.
    usr_list : iterable of users
        Each user is an iterable of row indices into `features`.
    features : 2-D array
        Feature matrix indexed by item.
    verbose : bool, optional
        If True (the default, matching the original behaviour), print each
        user's score as it is computed. Pass False for long user lists to
        avoid flooding the notebook output.

    Returns
    -------
    float
        Mean of the per-user scores, or nan when `usr_list` is empty.
    '''
    means, covs, weights = gmm.means_, gmm.covariances_, gmm.weights_
    clusters = [Cluster(means[i], covs[i], weights[i]) for i in range(len(means))]

    p_users = []
    for usr in usr_list:
        p_user = cal_ptrans_user(usr, features, clusters)
        p_users.append(p_user)
        if verbose:
            # Function-call form prints identically under Python 2 and 3.
            print(p_user)
    if not p_users:
        # Avoids numpy's "Mean of empty slice" RuntimeWarning; the result is nan either way.
        return float('nan')
    return np.mean(p_users)

## Run

In [10]:
# Fit a 32-component full-covariance GMM on the features of every item in item_list.
# NOTE(review): the initial means come from a random labelling drawn from the global
# numpy RNG, which is never seeded in this notebook, so the fit is not reproducible.
gmm = find_clusters(item_list, features, n_cluster=32)
gmm

GaussianMixture(covariance_type='full', init_params='kmeans', max_iter=100,
        means_init=array([[ 0.,  0., ...,  0.,  0.],
       [ 0.,  0., ...,  0.,  0.],
       ...,
       [ 0.,  0., ...,  0.,  0.],
       [ 0.,  0., ...,  0.,  0.]]),
        n_components=32, n_init=1, precisions_init=None, random_state=0,
        reg_covar=1e-06, tol=0.001, verbose=0, verbose_interval=10,
        warm_start=False, weights_init=None)

In [11]:
# Wrap each fitted GMM component as a Cluster(mean, covariance, weight).
means = gmm.means_
covs = gmm.covariances_
weights = gmm.weights_
clusters = []
for i in range(len(means)):
    clusters.append(Cluster(means[i], covs[i], weights[i]))

In [12]:
# Score users from index 7328 onward — presumably resuming after an earlier run over
# usr_list got through users 0-7327 (TODO confirm). The saved output ends in
# KeyboardInterrupt, so no mean was produced. The mean of two partial runs is not the
# overall mean unless it is weighted by the number of users in each run.
validate(gmm, usr_list[7328:], features) # 7328

-3528.61476162
34218.9765639
-13969.8031379
46962.7788152
30982.1453779
27319.7196003
35945.0237487
-7832.28814256
43112.4859409
-16428.693983
41975.6221036
-8183.15476085
21810.2373705
12969.4291319
27668.4935832
20653.7806303
17740.7605117
21082.5199705
27760.9452636
33776.1509375
30518.9836216
-2963.94415283
37827.9629531
3401.10588197
23480.7008712
19470.8150014
22488.3404504
26294.0441408
24382.2032169
21499.8454553
15278.277653
24391.5903791
-15320.3724533
65765.9732817
35839.6387383
41122.5206242
27894.7198991
15533.1933079
28202.3374108
-2027.46685319
28965.351007
19732.2076846
23771.8311184
-125023.734204
27564.1237485
4356.22182915
-2314.21035147
35541.2368956
-5009.07317868
27428.5313265
30005.2040026
11197.3142404
12266.9909036
-1224.28608591
17593.1739473
15547.1238215
1323.39430185
21119.3819984
6182.51339512
21615.532076
-99362.5560728
15082.188853
21498.9090388
1955.98118823
40302.8677381
12093.6345051
17533.7353981
-8520.53433969
34130.7040967
22976.8998431
56437.82978

KeyboardInterrupt: 

In [None]:
# NOTE(review): despite the "0 - 7327" tag, this call scores ALL users in usr_list,
# including the 7328+ range that the usr_list[7328:] cell scores separately. If only
# users 0-7327 were intended, pass usr_list[:7328] instead.
validate(gmm, usr_list, features) # 0 - 7327

-31850.1404828
-6143.73515657
29354.1554098
8518.13375911
31340.7814866
17742.9248339
54224.4822351
18224.2031331
10261.7468958
39544.7328248
14894.4643181
26253.3539523
23294.7801602
-13082.7006791
-38718.5621533
47640.4770991
31433.0202305
-10969.7994361
17345.593453
32134.1321142
22564.0813923
62341.9529762
23980.3357243
-44767.3648859
17463.9870255
20884.7806018
39136.4296124
23611.6957584
25746.3410875
30481.1903191
41970.3876774
18780.0802783
7024.48831537
29015.3352597
-681.140858564
31225.7903152
9375.82406626
16568.0848705
20212.9285236
26045.5452173
28192.4330529
37954.2484437
21929.8527316
-10860.1443307
22149.428705
29062.7957352
21360.5226062
41830.123574
22153.499827
-43681.9559218
25715.2826334
13548.4877847
44606.9258568
34540.4542459
9171.4057144
25302.3561005
49397.2717086
32236.708253
8914.45145442
2978.90163102
36309.057612
-82800.4721003
31561.6247359
26099.5212693
-33226.2250969
-13600.683571
50874.4376866
23160.3056734
31542.0635102
20148.9895759
53897.4191823
17

54914.9808573
12136.5128693
-16172.610647
10985.6125229
-15335.0959347
28886.6738279
27115.9960161
29399.6733234
41165.332871
33103.0629948
40760.2582459
41966.9547911
-2845.90631541
20241.3973221
17407.8804204
37152.5120846
60704.2228076
36652.9762164
-9308.98380424
64937.1472704
24191.3469292
27000.8934273
19053.9538372
22911.8596813
57038.5049578
-41965.5492724
-8232.23980094
15980.9463998
-24561.6837141
-12212.8122201
39470.4094227
8894.09086542
-5126.09597768
19374.7990913
-9568.93769824
29601.0743696
40825.9985464
21939.051077
28444.1651266
44424.7336243
-12220.2344628
21641.1700903
22932.4524895
26430.499424
19710.1684564
-74681.0764205
-35041.3164546
-29684.0305278
20906.9303448
19742.3272412
-4052.46175895
-1576.04235731
20692.9424554
22339.8565453
11104.4427981
7769.35319521
20875.1466716
8421.6496699
38813.4772626
9967.75702885
32501.5091949
-16071.4224172
22273.4970954
29921.3039558
26510.5942937
-2886.36701219
25763.0323671
-1421.83432397
-23701.4395368
-32060.1381079
2049

2491.56111587
-144164.609272
20548.0002695
-31027.6600923
-109005.710317
-54215.6206901
26362.8345803
-7254.44911768
6203.67436548
23465.3139007
22738.5001115
28432.8564044
29017.4128853
36476.4538802
37234.6382138
25814.7720086
50817.5913575
34140.4927083
20541.1300066
6762.01153533
-11795.2146574
11271.1911971
26260.9186231
25488.0196842
23522.4741749
30689.8890734
16346.8212709
17471.263119
37551.1293656
32212.7675099
32028.0720822
24863.6653683
7664.60167558
35163.6058057
-11488.3167404
-3610.11307902
27740.3075603
20627.2084128
-38256.1533838
51165.93783
20011.4806345
31190.5198061
-1290.71374596
6825.83051759
14540.5326248
20659.8648547
19757.6833928
5105.36097649
5264.08417494
33482.3285132
24268.5825668
18449.0100731
44486.0236623
26946.8888156
37297.9758237
26384.881483
15343.7380645
51479.9596067
26968.3562292
19672.8896478
22824.5436417
-66155.545414
39909.7345981
-14987.1415419
24308.7751023
11667.1818509
17988.5792692
-20585.3794886
28897.7309285
7948.01975761
34875.385698

9446.35064622
-75295.8842383
6515.77975167
6771.49784161
32983.9338316
34834.6615125
15434.9621112
32773.3353256
-1497.62081173
30705.3017362
24522.4342069
24035.7000286
-8905.50807811
23102.3076357
35929.471849
-15769.0277032
7668.91441519
195.285271474
-2116.49283841
20249.0995966
22648.2478047
53664.6905239
-18273.3843826
48781.7137609
46381.2510391
34315.0969515
-3273.55146435
35354.8784896
6129.56918701
53698.4673494
28409.9678194
36487.9303433
41898.299688
16346.9984798
9020.14805775
20992.13132
22810.4674063
13636.1937878
-50995.6114833
45506.1878939
667.098053153
703.829558219
7373.17698191
47989.20099
22041.5357272
27462.6325262
72168.4188612
25923.6235894
27656.0872457
40513.8254377
35324.472306
34746.3070672
46346.5493847
17682.8355028
26461.2147019
30443.384295
9470.97659972
-13266.0673893
48348.6310292
21313.6842754
24210.1982859
10675.7182661
19417.4722874
23107.4343652
25765.5065769
29021.867624
25496.8377011
23437.9310732
6805.80832408
25019.8922169
-28707.4037537
-2489

-885.253910772
23212.69509
19949.6233025
62332.13222
-2838.40447519
34082.475985
34976.3464115
23482.5588731
34766.1809665
53433.8660406
27263.4382249
16165.5998011
2740.57742439
-26719.7007222
34954.7551281
34824.3955918
19600.7222791
-26318.3216995
36536.4006815
54914.1135069
35802.4023095
35138.0777521
7.50899098538
25444.4630754
30336.9976596
31185.143534
49767.6707041
24122.2383498
53650.4028484
25155.001585
17460.1277149
27278.8344103
5510.42288123
32850.5337855
30814.4301477
-50488.7016156
16201.355908
-1797.91694926
-115816.726237
7346.6743281
29131.0876631
24854.7822958
13812.0143114
4993.78657166
26602.5410613
8609.73622585
18475.611868
28362.4714953
8367.57723313
38762.4328535
30730.1757182
21839.2393272
35366.2987251
-32695.3262597
29769.874489
30737.8986838
6583.23777544
22280.5615664
31108.6760158
16292.0520728
45893.7251229
-7160.84972044
60618.0009783
30791.9352548
27381.7590562
-36767.6823408
-8842.60544595
32972.0028276
21568.8877005
15673.684649
36193.0034838
-22678.

-30279.0401509
30344.8052952
70765.8745424
21396.7078012
24575.3365024
-32312.8987881
62697.2466857
12452.1185702
31579.7027318
32397.3601275
24168.2673403
20954.6996918
12381.8727871
21864.8723734
65631.58664
21808.6003163
-1572.93644736
11918.2709797
-44254.420108
-5639.72294117
29334.4092585
59273.8239131
7040.16063536
16443.0075308
48774.6946594
-32890.9926741
36652.4926892
20743.8901417
-27839.5512073
57838.4951594
30725.1296447
11551.5069795
27371.3219327
880.715748444
19335.3325353
-754.175771634
19089.8338093
57399.3172444
12082.1153088
30100.7401154
1114.13946534
-23638.8972094
32154.2386012
17015.0861554
-29172.0018335
23879.9156438
32433.854416
17504.4255051
-1918.06319109
-38810.3482435
1714.76021813
10037.3990552
32162.0066456
15959.4363943
43329.7491907
12975.7833388
23676.9775589
43038.7303032
54104.8804097
233.735132073
42243.8561493
25088.8608502
2743.41538786
24108.1639521
32889.237671
18432.5045647
-20838.3499136
24657.3920206
-18006.7193967
41700.4905362
-80630.5857

19067.7428597
24785.1199604
-151653.502541
8681.90189496
24927.5871069
23561.7615488
45323.9251574
25610.2127676
22729.3888082
-39173.4325954
6789.43628288
17852.4409586
31326.3135332
1412.78782306
23706.2736355
17098.6299526
51353.5389366
18515.8021117
24603.0371585
12072.5538261
19647.8832328
13191.4514168
20772.8304047
11654.913992
-3817.95404456
31237.9475245
41849.0409223
32434.474207
-23758.3726699
32327.2942968
33814.4244254
24053.1208142
16782.2479187
38467.9774201
21052.3465459
24698.0268206
41507.5123732
30798.7719059
1764.66018133
27390.8569685
-68301.652315
15820.7356361
23157.62637
50547.9507007
22423.1646713
25402.2870268
21229.6898642
-4253.94865189
42346.685193
25403.2935996
35357.9668538
30591.5434453
30555.4841966
-3721.84999934
-87166.7224357
37360.3082869
38246.8569052
3505.48083151
23753.8548575
27676.4852438
7113.66236676
15691.4995605
39780.1843212
47230.5344487
28529.4274424
3510.2982579
19933.4749211
21999.213745
27883.8770197
54540.074107
22511.2603941
27691.1

7724.17261351
36158.4478618
21802.295713
26246.8105647
23687.7071875
37791.8041522
32736.7492514
38543.2407773
33520.490843
17703.0535686
42378.6279867
14698.7726871
45417.8135008
-11452.9336789
36356.2926856
39350.1882289
14285.1195075
43431.1443753
10345.999548
38126.3623051
-69383.6622231
15250.6013682
12144.5676424
27090.7250905
25191.4209116
51021.1016142
47288.3477551
21868.3024811
42182.266685
33202.5304606
47790.9892116
9054.84461563
14918.6592317
5113.75111236
47725.0483318
25354.8522554
55171.212806
8532.17041947
-36435.2169693
21861.633876
32306.120803
24683.3951179
24545.1339352
34810.0303389
-27874.0098076
20488.171784
28753.9758193
24848.6643886
26369.0901625
58440.725201
42710.0848793
54844.766845
-107572.886645
2616.38749052
4074.92829757
26119.6193848
25912.0222001
-57451.8174634
30118.7827427
33080.875095
51460.7874564
-8381.7113566
20881.5580418
23992.3389154
-534.468106867
44835.040329
16629.5139688
-6264.36538766
8946.23978265
-46344.5293272
36918.6043263
-26549.23

12089.0055744
-8440.95264529
25027.8184207
-81027.0678534
-22177.3234634
8208.98464282
16999.9874186
6774.42600005
68409.6109718
-730.436096244
19111.531143
66164.5065719
7904.1571982
8251.81765758
25297.3411169
5204.1296529
25769.7773258
-2419.95361021
15476.4318846
5517.68631178
-43954.9773506
23724.4362407
60414.045331
26638.1017306
4442.47877977
16005.1251649
22186.6839041
27186.9435446
294.574311843
-12036.5485861
-25975.6409752
32490.210557
7386.5749511
19868.3493185
23243.7751838
21945.217416
-41572.1610927
19087.4171109
7558.56458915
24526.5536873
-142681.313966
27194.6250064
448.696446193
30220.9291934
13130.5955436
20391.634852
-12507.0898341
-566.577737644
32710.7743382
33067.0345117
-22402.8385335
3566.07052633
39209.9732847
25886.6387384
43078.3513803
-21930.1309354
20710.7203107
23933.7335182
23989.1616718
28728.7111043
20639.0541981
29070.7216487
46309.2848028
18269.3591569
16583.9831127
16694.5979593
45085.6290619
497.776911983
-11656.2222437
-34732.4534761
27103.045437

33186.6391466
29544.835027
8299.927159
-67020.3751741
32655.9869299
-12238.4624741
30104.4820704
19613.0581802
-73367.2848797
32154.1673811
5287.58911028
60488.9282655
35966.5385153
21321.6393872
-113854.658183
8196.03791956
29462.0540412
26141.5954455
35437.0141032
27890.110337
30489.6860078
19195.580953
27544.1279024
973.015151164
26556.0935858
28259.734338
37622.2432117
30552.0270015
17412.3595037
24387.2327314
-124876.993793
26629.0059467
35397.5150278
24408.2047579
48217.4625867
-9480.41460544
34490.7377115
45047.6672995
52730.4106265
45089.3799806
19321.7511688
-1683.2746379
24440.5731549
32220.1987872
27796.5348655
42398.1851139
-13820.4120406
-107534.678399
15611.3551826
11844.0793276
58830.0007437
26344.4958087
9270.23718241
-17457.0926276
23167.0449402
54078.8816363
58672.7690169
-25665.2975627
3180.689347
9390.7421639
22865.6724096
25168.3407104
-38941.3899258
24926.2680454
19641.0911653
-16598.0595168
12094.5105531
20045.4002018
14028.2778471
43271.4130194
3258.31747272
431

27773.2525001
-34691.9145898
-38484.437305
15009.0601725
11641.3483986
25996.7142982
31920.5134171
48349.3082926
27781.5577669
-13458.7239846
-39044.4091524
17259.2183252
44345.3044821
-569.008722683
3231.95497732
21745.1866587
-3882.54707716
49915.3718885
32607.652974
12828.4118532
19509.7822279
23995.576578
17409.8370102
2363.63931696
-34973.0560704
16293.4561488
8471.22456717
24761.1978102
23946.5431585
24494.025385
45884.9782103
14019.4347828
19229.0891969
13093.9170769
39045.8774256
24345.1796454
41207.1711295
32444.8605201
20196.5625532
2492.41868198
18723.0250269
-145069.49076
20976.7634342
27900.0978784
-12135.9621596
35651.301505
22441.5619665
25075.82176
-89275.9309243
29030.6906198
34945.3046248
21490.8896225
48572.64136
18296.871081
24316.0593784
-36354.7698923
24570.3316824
9110.54621308
7412.14008113
29216.2058216
53506.3931288
20610.1747163
26242.2911765
-95771.260461
18838.7464625
27792.1520812
12162.4031277
24640.47727
-23952.983485
38262.8063537
63670.0222627
27761.09

26991.1285832
43288.0540242
15595.0116251
20359.0739947
23008.8773568
24492.2334874
48154.6126089
2824.55734109
39576.3317384
22446.1895905
21827.1601481
23529.0894734
30740.8307147
20479.4002815
7719.53805987
23739.4456979
-3760.4583243
27950.6309738
15782.268596
11293.149708
10509.5054344
25262.2704847
19598.160474
69165.9249732
26071.5323093
13244.862355
21669.4217942
-116613.012753
23577.8917604
7402.66288783
23239.6662616
23616.1720434
57.1824141621
62386.7466296
28589.9540622
24049.6933073
22509.4370368
25712.4813431
-1797.59851985
-3608.76919462
28400.3946386
24216.8787767
28176.1731173
29296.5076513
-32566.0278543
8808.11291791
28063.7953474
44906.2044172
38425.3835309
43939.3723677
6673.76076839
1311.81099876
23866.6276078
10602.4447946
5508.43307721
17288.4029082
28038.1348326
18986.5024101
47611.3547284
6081.60628761
28146.4889666
24655.62398
16038.2809757
33756.2198273
26591.6185832
21500.1474968
20267.6958309
24315.578983
-229.460880018
25254.3430077
24731.6693087
17015.03

23190.680739
25040.5498496
18028.6850216
35421.7135792
-41237.5133116
33599.2486317
27259.8047316
-59644.7227873
21326.0220691
-60565.6945089
22118.6871192
9355.86255072
27041.7472994
38328.6894092
39688.7498683
15522.8271528
31236.0754736
51628.8462175
36630.5270112
30744.128863
22119.2001898
19817.8671602
20193.6018005
-34162.0759703
32081.9807323
22674.7254873
19275.3211117
-6894.03010239
14015.4151593
21602.2879615
20946.9570493
41084.7598588
25360.3413211
37576.2833219
52337.5163536
12881.9472151
20206.226313
12706.5973521
33977.9780001
-20892.9405823
22950.7222356
20423.8800754
5281.45475918
15245.4855872
12084.890956
-51950.8695641
24431.8885573
23250.3464242
21314.1280799
31650.2606137
20672.1583989
33229.889001
5144.92345286
-1488.40289437
55098.7322836
24786.031256
30555.1721105
-77908.0014825
34540.7642353
-89712.5991462
-32426.582431
24637.9090103
24329.5421233
27307.5952182
25050.235557
33159.3871444
26510.3152731
9668.86864381
38461.3871464
23209.8231262
40036.0290583
153