# Etapa 1 - Treinamento de Modelo Personalized PageRank

In [1]:
import pickle
import json
import dgl
import torch
import numpy as np
import os
from scipy.sparse import dok_matrix

from tqdm.notebook import tqdm

In [2]:
from utils.incremental_encoder import IncrementalEncoder

## Preparação de Dados

### Calcula rede 
- Se já existir o arquivo `dados-processados/network.pickle` a rede não é calculada novamente

In [3]:
# Load the processed dataset. The loop further down unpacks each entry as a
# (playlist_id, playlist_name, tracklist) triple.
# NOTE(review): pickle.load can execute arbitrary code; only use it on files
# produced locally by this project.
with open("../dados-processados/dataset.pickle", 'rb') as _file:
    dataset = pickle.load(_file)

tracks_encoder = IncrementalEncoder()
tracks_encoder.load("../dados-processados/encoding_tracks.json")

network_path = "../dados-processados/network.pickle"

if os.path.exists(network_path):
    # Reuse the previously computed network instead of rebuilding it, as the
    # section notes promise. Delete the file to force a recomputation.
    with open(network_path, 'rb') as _file:
        network = pickle.load(_file)
else:
    # Square sparse matrix indexed by encoded track id on both axes.
    # NOTE(review): assumes last_index is strictly greater than every encoded
    # track id appearing in the dataset -- confirm against IncrementalEncoder.
    network = dok_matrix(
        (tracks_encoder.last_index, tracks_encoder.last_index), dtype=int
    )

    # Count transitions: each pair of consecutive tracks in a playlist adds 1
    # to the cell [current track, next track].
    for playlist_id, playlist_name, tracklist in tqdm(dataset):
        for current_track_id, next_track_id in zip(tracklist, tracklist[1:]):
            network[current_track_id, next_track_id] += 1

  0%|          | 0/1000000 [00:00<?, ?it/s]

In [5]:
# Persist the transition matrix to disk so it can be reloaded later.
with open("../dados-processados/network.pickle", mode='wb') as _file:
    pickle.dump(obj=network, file=_file)

### Conversão para Rede DGL

In [6]:
# Build the DGL graph from the sparse transition matrix.
dgl_network = dgl.from_scipy(
    sp_mat = network
)

# Source and destination node ids of every edge, in DGL's edge order.
init_nodes, final_nodes = dgl_network.edges()

# Look up all edge weights in one vectorised CSR indexing operation instead of
# a Python loop with one dok_matrix access per edge. Keying the lookup on the
# (source, destination) pairs keeps the weights aligned with DGL's edge order
# regardless of how DGL ordered the edges internally.
weights = np.asarray(
    network.tocsr()[init_nodes.cpu().numpy(), final_nodes.cpu().numpy()]
).ravel()

# dtype=float yields a float64 tensor, as in the original per-edge version.
dgl_network.edata['weights'] = torch.tensor(weights, dtype=float)

# Persist the weighted graph.
with open("../dados-processados/dgl_network.pickle", 'wb') as _file:
    pickle.dump(dgl_network, _file)

## Implementação do Algoritmo

In [7]:
from models.personalized_pagerank import PersonalizedPageRank

In [14]:
# Instantiate the Personalized PageRank model on the weighted DGL graph.
# NOTE(review): walk_depth and n_rounds presumably control the random-walk
# length and the number of walk rounds -- confirm in
# models/personalized_pagerank.py.
model = PersonalizedPageRank(
    dgl_network=dgl_network,
    walk_depth=50,
    n_rounds=1,
)

In [15]:
# Smoke test: request predictions for a single seed list of three encoded
# track ids; the notebook displays the returned value as the cell output.
model.predict([[43, 328232, 22]])

[[1886,
  5326,
  8436,
  251065,
  39,
  714177,
  447066,
  1023331,
  41,
  1,
  28677,
  243440,
  18950,
  6019,
  25781,
  6975,
  2612,
  10472,
  270150,
  3744,
  339709,
  570,
  201169,
  46145,
  13919,
  332,
  5025,
  2470,
  2518,
  60312,
  6480,
  2076,
  5261,
  4646,
  3151,
  6441,
  25249,
  22804,
  11880,
  30,
  39660,
  4,
  9,
  2977,
  2013,
  538,
  2,
  2474,
  44351,
  13117,
  23,
  52770,
  1763,
  38586,
  705,
  6424,
  2108,
  6344,
  8554,
  36849,
  40511,
  40491,
  1649,
  5024,
  107457,
  166103,
  166104,
  1334,
  29861,
  14082,
  7607,
  70,
  71933,
  71910,
  4838,
  1159,
  3080,
  11812,
  10833,
  1337,
  143738,
  7323,
  215872,
  714178,
  714179,
  712343,
  28681,
  2611,
  11137,
  28846,
  7866,
  7904,
  28,
  3923,
  41845,
  22146,
  2251,
  171790,
  18955,
  18947,
  178727,
  193993,
  18948,
  18917,
  2314,
  2425,
  11560,
  45680,
  17384,
  42584,
  3140,
  2328,
  2024,
  13366,
  43033,
  8165,
  113452,
  323,
  417