In [1]:
import utils as ut
import scipy as sc
import numpy as np
from tqdm.notebook import tqdm
import scipy.sparse as sparse
from joblib import Parallel, delayed


def hadamard_matmal(r, X, B, cols, ind):
    """Column sums of the sub-block of ``B`` indexed by user ``r``'s items.

    With ``idx = cols[ind[r]:ind[r + 1]]`` (``ind`` is a CSR ``indptr`` and
    ``cols`` holds the matching column indices in the same row-major order),
    this returns ``out[k] = sum_i B[idx[i], idx[k]]``. For a binary (0/1)
    matrix X whose row ``r`` has exactly the nonzeros ``idx``, that equals
    ``(X @ B)[r, idx]``, i.e. ``X @ B`` evaluated only on X's sparsity pattern.

    Parameters
    ----------
    r : int
        Row (user) index into ``ind``.
    X : object
        Unused; accepted only so existing callers keep working.
    B : ndarray
        2-D array indexable by item ids on both axes.
    cols : 1-D int array
        Column indices of the stored entries, aligned with ``ind``.
    ind : 1-D int array
        CSR row pointer (``indptr``).

    Returns
    -------
    ndarray of shape ``(len(idx),)`` (empty when row ``r`` has no entries).
    """
    idx = cols[ind[r]:ind[r + 1]]
    # np.ix_ gathers only the len(idx) x len(idx) sub-block; the chained form
    # B[idx][:, idx] would first copy len(idx) complete rows of B.
    return B[np.ix_(idx, idx)].sum(axis=0)

def _wdlae_apply(M, X, XtX_lam, weight, nonzeros, users, n_jobs):
    """Apply the CG system operator to the dense matrix ``M``.

    Returns ``X.T @ S + XtX_lam @ M`` as a plain ndarray, where ``S`` is a
    sparse matrix with X's sparsity pattern holding ``(weight - 1)`` times the
    per-user values produced by ``hadamard_matmal`` (which equal ``X @ M`` on
    that pattern when X is binary -- assumption, TODO confirm X is 0/1).
    """
    rows, cols = nonzeros
    # One task per user; results come back in user order, which matches the
    # row-major order of `nonzeros` (assumes X has no explicitly stored zeros).
    per_user = Parallel(n_jobs=n_jobs)(
        delayed(hadamard_matmal)(user, X, M, cols, X.indptr) for user in range(users)
    )
    # Pass the shape explicitly: inferring it from the max indices breaks when
    # trailing users/items have no interactions.
    S = (weight - 1) * sparse.csr_matrix(
        (np.concatenate(per_user), (rows, cols)), shape=X.shape
    )
    # sparse + dense yields an np.matrix; hand back a plain ndarray.
    return np.asarray(X.T.dot(S) + XtX_lam.dot(M))


def WDLAE_cg(pc, X, XtX_lam, weight, nonzeros, path, users, items,
             n_iter=5, tol=1e-8, eval_every=5, n_jobs=36, B0=None):
    """Fit the item-item matrix B with preconditioned conjugate gradient.

    Solves ``A(B) = X.T @ (weight * X)`` where
    ``A(B) = XtX_lam @ B + X.T @ ((weight - 1) * S(B))`` and ``S(B)`` is
    ``X @ B`` restricted to X's nonzeros (see ``_wdlae_apply``). Presumably
    this is a weighted denoising linear autoencoder (WDLAE) -- naming inferred,
    not verified here.

    Parameters
    ----------
    pc : ndarray (items x items)
        Left preconditioner; the sweep cell passes ``inv(XtX_lam)``.
    X : scipy.sparse CSR matrix (users x items)
        Training interactions.
    XtX_lam : ndarray (items x items)
        Regularized Gram matrix.
    weight : number
        Weight applied to observed entries.
    nonzeros : tuple (rows, cols)
        ``X.nonzero()``, in row-major order aligned with ``X.indptr``.
    path : str
        Dataset path forwarded to ``ut.evaluate``.
    users, items : int
        ``X.shape``.
    n_iter : int, default 5
        Maximum number of CG iterations.
    tol : float, default 1e-8
        Absolute Frobenius-norm tolerance on the residual.
    eval_every : int, default 5
        Call ``ut.evaluate`` every this many iterations. It is also called at
        the last iteration and on convergence.
    n_jobs : int, default 36
        Worker count for the joblib gather.
    B0 : ndarray, optional
        Warm start. It is copied, never mutated. Defaults to zeros.

    Returns
    -------
    ndarray (items x items)
        The final iterate B.
    """
    def apply_A(M):
        return _wdlae_apply(M, X, XtX_lam, weight, nonzeros, users, n_jobs)

    XWX = (X.T.dot(weight * X)).toarray()  # right-hand side, ~30 seconds
    if B0 is None:
        # A(0) == 0, so the initial residual is just the right-hand side;
        # skipping the operator saves several minutes on large data.
        B = np.zeros((items, items))
        r = XWX
    else:
        B = np.array(B0, dtype=float)  # copy so the caller's array is untouched
        r = XWX - apply_A(B)
    z = pc.dot(r)
    p = z.copy()
    rz = np.vdot(r, z)
    for it in tqdm(range(n_iter)):
        Ap = apply_A(p)
        # Guard against a zero curvature term, as the original code did.
        alpha = rz / max(1e-32, np.vdot(p, Ap))
        B += alpha * p
        r -= alpha * Ap
        if np.linalg.norm(r) <= tol:
            ut.evaluate(B, path)
            break
        if (it + 1) % eval_every == 0 or it == n_iter - 1:
            ut.evaluate(B, path)
        z = pc.dot(r)
        rz_new = np.vdot(r, z)
        p = z + (rz_new / rz) * p
        rz = rz_new
    return B
    

# Dataset to run on. The MSD variant lives at
# '/efs/users/hsteck/public/datasets/msd_data/pro_sg'.
path = '/efs/users/hsteck/public/datasets/movielens20mio/pro_sg'

# Hyper-parameter grid swept by the next cell.
Ps = [0.01, 0.1, 0.5, 0.9, 0.99]   # p values; enter the model as p/(1-p) on the diagonal
weights = [1, 2, 5, 10, 20]        # values passed as `weight` to WDLAE_cg

X = ut.load_train_data(path)
users, items = X.shape
XtX = X.T.dot(X).toarray()         # dense item-item Gram matrix
nonzeros = X.nonzero()             # (rows, cols) of stored interactions

In [None]:
# Sweep the (p, weight) grid. The regularized Gram matrix and its inverse
# (used as the CG preconditioner) depend only on p, so they are built once
# per p and shared across all weights.
for p in Ps:
    diag_idx = np.diag_indices(items)
    diag_scale = 1 + p / (1 - p)
    XtX_lam = XtX.copy()
    XtX_lam[diag_idx] *= diag_scale
    pc = np.linalg.inv(XtX_lam)

    for weight in weights:
        print('running with (p,weight): ', p, ' , ', weight)
        WDLAE_cg(pc, X, XtX_lam, weight, nonzeros, path, users, items)

running with (p,weight):  0.01  ,  1


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.38376 (0.00221)
Val Recall@20=0.35643 (0.00271)
Val Recall@50=0.47666 (0.00294)
Test Metrics
Test NDCG@100=0.37786 (0.00219)
Test Recall@20=0.35506 (0.00270)
Test Recall@50=0.47530 (0.00296)
running with (p,weight):  0.01  ,  2


  0%|          | 0/5 [00:00<?, ?it/s]

Val NDCG@100=0.37851 (0.00220)
Val Recall@20=0.35216 (0.00272)
Val Recall@50=0.47564 (0.00297)
Test Metrics
Test NDCG@100=0.37215 (0.00218)
Test Recall@20=0.35062 (0.00271)
Test Recall@50=0.47405 (0.00298)
running with (p,weight):  0.01  ,  5


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.36362 (0.00218)
Val Recall@20=0.33960 (0.00274)
Val Recall@50=0.46576 (0.00301)
Test Metrics
Test NDCG@100=0.35760 (0.00217)
Test Recall@20=0.33581 (0.00273)
Test Recall@50=0.46175 (0.00304)
running with (p,weight):  0.01  ,  10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.34215 (0.00216)
Val Recall@20=0.31901 (0.00273)
Val Recall@50=0.44719 (0.00306)
Test Metrics
Test NDCG@100=0.33749 (0.00214)
Test Recall@20=0.31724 (0.00274)
Test Recall@50=0.44334 (0.00310)
running with (p,weight):  0.01  ,  20


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.30738 (0.00207)
Val Recall@20=0.28470 (0.00270)
Val Recall@50=0.41627 (0.00310)
Test Metrics
Test NDCG@100=0.30413 (0.00205)
Test Recall@20=0.28389 (0.00271)
Test Recall@50=0.41398 (0.00314)
running with (p,weight):  0.1  ,  1


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.40318 (0.00217)
Val Recall@20=0.37375 (0.00268)
Val Recall@50=0.49495 (0.00287)
Test Metrics
Test NDCG@100=0.39683 (0.00216)
Test Recall@20=0.37069 (0.00268)
Test Recall@50=0.49247 (0.00289)
running with (p,weight):  0.1  ,  2


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.39848 (0.00216)
Val Recall@20=0.36964 (0.00268)
Val Recall@50=0.49351 (0.00289)
Test Metrics
Test NDCG@100=0.39099 (0.00215)
Test Recall@20=0.36676 (0.00268)
Test Recall@50=0.49100 (0.00291)
running with (p,weight):  0.1  ,  5


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.38440 (0.00214)
Val Recall@20=0.35787 (0.00268)
Val Recall@50=0.48423 (0.00292)
Test Metrics
Test NDCG@100=0.37743 (0.00213)
Test Recall@20=0.35244 (0.00268)
Test Recall@50=0.48084 (0.00296)
running with (p,weight):  0.1  ,  10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.36522 (0.00212)
Val Recall@20=0.33848 (0.00268)
Val Recall@50=0.46931 (0.00297)
Test Metrics
Test NDCG@100=0.35879 (0.00210)
Test Recall@20=0.33655 (0.00270)
Test Recall@50=0.46458 (0.00301)
running with (p,weight):  0.1  ,  20


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.33272 (0.00203)
Val Recall@20=0.30955 (0.00266)
Val Recall@50=0.44171 (0.00301)
Test Metrics
Test NDCG@100=0.32836 (0.00203)
Test Recall@20=0.30623 (0.00267)
Test Recall@50=0.43714 (0.00305)
running with (p,weight):  0.5  ,  1


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.42114 (0.00215)
Val Recall@20=0.38453 (0.00267)
Val Recall@50=0.51152 (0.00281)
Test Metrics
Test NDCG@100=0.41398 (0.00214)
Test Recall@20=0.38068 (0.00267)
Test Recall@50=0.51140 (0.00284)
running with (p,weight):  0.5  ,  2


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.42045 (0.00215)
Val Recall@20=0.38572 (0.00267)
Val Recall@50=0.51257 (0.00282)
Test Metrics
Test NDCG@100=0.41295 (0.00214)
Test Recall@20=0.38134 (0.00268)
Test Recall@50=0.51319 (0.00285)
running with (p,weight):  0.5  ,  5


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.41242 (0.00213)
Val Recall@20=0.37931 (0.00266)
Val Recall@50=0.50949 (0.00283)
Test Metrics
Test NDCG@100=0.40449 (0.00212)
Test Recall@20=0.37468 (0.00268)
Test Recall@50=0.50728 (0.00287)
running with (p,weight):  0.5  ,  10
Validation Metrics
Val NDCG@100=0.39862 (0.00210)
Val Recall@20=0.36597 (0.00266)
Val Recall@50=0.49824 (0.00286)
Test Metrics
Test NDCG@100=0.39040 (0.00209)
Test Recall@20=0.36171 (0.00267)
Test Recall@50=0.49464 (0.00290)
running with (p,weight):  0.5  ,  20


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.37728 (0.00206)
Val Recall@20=0.34591 (0.00266)
Val Recall@50=0.48002 (0.00291)
Test Metrics
Test NDCG@100=0.36890 (0.00205)
Test Recall@20=0.34042 (0.00267)
Test Recall@50=0.47553 (0.00294)
running with (p,weight):  0.9  ,  1


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.39154 (0.00208)
Val Recall@20=0.35083 (0.00262)
Val Recall@50=0.48152 (0.00282)
Test Metrics
Test NDCG@100=0.38423 (0.00207)
Test Recall@20=0.34737 (0.00263)
Test Recall@50=0.47893 (0.00284)
running with (p,weight):  0.9  ,  2


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.40200 (0.00210)
Val Recall@20=0.36322 (0.00264)
Val Recall@50=0.49103 (0.00281)
Test Metrics
Test NDCG@100=0.39513 (0.00208)
Test Recall@20=0.35902 (0.00264)
Test Recall@50=0.48994 (0.00283)
running with (p,weight):  0.9  ,  5


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.41162 (0.00211)
Val Recall@20=0.37500 (0.00264)
Val Recall@50=0.50248 (0.00281)
Test Metrics
Test NDCG@100=0.40362 (0.00209)
Test Recall@20=0.36875 (0.00266)
Test Recall@50=0.50232 (0.00283)
running with (p,weight):  0.9  ,  10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.41361 (0.00211)
Val Recall@20=0.37811 (0.00265)
Val Recall@50=0.50683 (0.00281)
Test Metrics
Test NDCG@100=0.40470 (0.00209)
Test Recall@20=0.37255 (0.00267)
Test Recall@50=0.50585 (0.00284)
Validation Metrics
Val NDCG@100=0.40909 (0.00210)
Val Recall@20=0.37289 (0.00265)
Val Recall@50=0.50615 (0.00282)
Test Metrics
Test NDCG@100=0.39957 (0.00208)
Test Recall@20=0.36965 (0.00267)
Test Recall@50=0.50261 (0.00286)
running with (p,weight):  0.99  ,  1


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.31967 (0.00199)
Val Recall@20=0.28118 (0.00248)
Val Recall@50=0.39641 (0.00277)
Test Metrics
Test NDCG@100=0.31210 (0.00197)
Test Recall@20=0.27359 (0.00247)
Test Recall@50=0.39177 (0.00279)
running with (p,weight):  0.99  ,  2


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.33302 (0.00201)
Val Recall@20=0.29503 (0.00252)
Val Recall@50=0.41304 (0.00280)
Test Metrics
Test NDCG@100=0.32532 (0.00199)
Test Recall@20=0.28911 (0.00252)
Test Recall@50=0.40782 (0.00281)
running with (p,weight):  0.99  ,  5


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.35527 (0.00204)
Val Recall@20=0.31900 (0.00258)
Val Recall@50=0.43834 (0.00282)
Test Metrics
Test NDCG@100=0.34773 (0.00202)
Test Recall@20=0.31462 (0.00258)
Test Recall@50=0.43350 (0.00284)
running with (p,weight):  0.99  ,  10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.37238 (0.00206)
Val Recall@20=0.33648 (0.00260)
Val Recall@50=0.45764 (0.00281)
Test Metrics
Test NDCG@100=0.36496 (0.00203)
Test Recall@20=0.33310 (0.00260)
Test Recall@50=0.45406 (0.00284)
running with (p,weight):  0.99  ,  20


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Metrics
Val NDCG@100=0.38606 (0.00207)
Val Recall@20=0.35133 (0.00262)
Val Recall@50=0.47410 (0.00281)
Test Metrics
