# Modelo 1 - Personalized PageRank

In [1]:
import pickle
import json
import dgl
import torch
import numpy as np
import os
from scipy.sparse import dok_matrix

from tqdm.notebook import tqdm

## Preparação de Dados

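### Calcula a rede de transições
- Cada entrada `[i, j]` da matriz conta quantas vezes a música `j` aparece imediatamente depois da música `i` em alguma playlist.
- Se o arquivo `../dados-processados/network.pickle` já existir, a rede não é recalculada: ela é carregada do disco.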

In [2]:
# Transition-count network: network[i, j] = how many times track j directly follows track i
# in some playlist. The result is cached on disk; delete the cache file to force a rebuild.
NETWORK_CACHE = "../dados-processados/network.pickle"

if os.path.isfile(NETWORK_CACHE):
    # Reuse the network built in a previous run.
    # NOTE: pickle.load can execute arbitrary code - only load caches you produced yourself.
    with open(NETWORK_CACHE, 'rb') as cache_file:
        network = pickle.load(cache_file)
else:
    # Processed playlists: an iterable of (playlist_name, tracklist) pairs, where each
    # tracklist entry keeps the encoded track id at position 1.
    with open("../dados-processados/dataset.pickle", 'rb') as dataset_file:
        dataset = pickle.load(dataset_file)

    # Track encoder; its size fixes the number of nodes in the network.
    with open("../dados-processados/encoding_tracks.json", 'r') as encoder_file:
        tracks_encoder = json.load(encoder_file)

    track_count = len(tracks_encoder)
    network = dok_matrix((track_count, track_count), dtype=int)

    # Walk every pair of consecutive tracks and count the transition between them.
    # (Loop variable is not named `_` to avoid shadowing IPython's last-output cache.)
    for _playlist_name, tracklist in tqdm(dataset):
        for current_entry, next_entry in zip(tracklist, tracklist[1:]):
            network[current_entry[1], next_entry[1]] += 1

    with open(NETWORK_CACHE, 'wb') as cache_file:
        pickle.dump(network, cache_file)

  network = pickle.load(_file)


### Normalizando Pesos de Arestas

Calcula o total de transições que partem de cada música (a soma de cada linha da matriz).

In [3]:
# Total number of outgoing transitions per track (row sums of the count matrix).
# For a scipy sparse *matrix*, `.sum(axis=1)` returns an (n, 1) np.matrix, so `sums[i]` would be a
# 1x1 matrix and `float(sums[i])` relies on a conversion deprecated since NumPy 1.25.
# Flattening to a plain 1-D array makes `sums[i]` a true scalar.
sums = np.asarray(network.sum(axis=1)).ravel()

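Gera uma cópia da rede com arestas normalizadas: cada contagem é dividida pelo total da sua linha, de modo que os pesos de saída de cada música (com ao menos uma transição) somam 1.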

In [4]:
# Row-normalized copy of `network`: each transition count divided by the total of its source row.
# The result is cached on disk; delete the cache file to force a rebuild.
if not os.path.isfile("../dados-processados/normalized_network.pickle"):

    # `sums` holds the per-row totals of `network`. Straight from `network.sum(axis=1)` it is an
    # (n, 1) np.matrix, and float() of a 1x1 matrix is deprecated since NumPy 1.25. Flatten once so
    # every lookup below is a true scalar (a no-op if `sums` is already 1-D).
    row_totals = np.asarray(sums).ravel()

    normalized_network = dok_matrix(network.shape, dtype=float)

    # Only stored entries are visited. Counts are built by `+= 1`, so every visited row has a
    # positive total and no division by zero can occur. Iterating items() avoids a second DOK
    # lookup per key and keeps the same key order as keys().
    for (row_idx, col_idx), count in tqdm(network.items()):
        normalized_network[row_idx, col_idx] = count / float(row_totals[row_idx])

    with open("../dados-processados/normalized_network.pickle", 'wb') as _file:
        pickle.dump(normalized_network, _file)
else:
    # Load the normalized network computed in a previous run.
    # NOTE: pickle.load can execute arbitrary code - only load caches you produced yourself.
    with open("../dados-processados/normalized_network.pickle", 'rb') as _file:
        normalized_network = pickle.load(_file)

  normalized_network = pickle.load(_file)


### Conversão para Rede DGL

In [5]:
# DGL graph built from the normalized network, with each edge's normalized weight stored in
# edata['weights']. The result is cached on disk; delete the cache file to force a rebuild.
if not os.path.isfile("../dados-processados/dgl_network.pickle"):
    dgl_network = dgl.from_scipy(
        sp_mat = normalized_network
    )

    # Weights must follow DGL's own edge order, so they are gathered by (src, dst) pairs taken
    # from the graph. Instead of indexing the DOK matrix once per edge with 0-d torch tensors
    # (a very slow Python loop), convert once to CSR and gather every weight with a single
    # vectorized fancy-index. np.asarray(...).ravel() flattens the (1, E) result to shape (E,).
    init_nodes, final_nodes = dgl_network.edges()
    weight_lookup = normalized_network.tocsr()
    weights = np.asarray(
        weight_lookup[init_nodes.numpy(), final_nodes.numpy()]
    ).ravel()

    # float64 values, matching the float dtype of `normalized_network`.
    dgl_network.edata['weights'] = torch.from_numpy(weights)

    with open("../dados-processados/dgl_network.pickle", 'wb') as _file:
        pickle.dump(dgl_network, _file)
else:
    # Load the DGL graph built in a previous run.
    # NOTE: pickle.load can execute arbitrary code - only load caches you produced yourself.
    with open("../dados-processados/dgl_network.pickle", 'rb') as _file:
        dgl_network = pickle.load(_file)

## Implementação do Algoritmo

In [6]:
from models.personalized_pagerank import PersonalizedPageRank

In [24]:
# Random-walk hyperparameters. The names mirror the constructor's keyword arguments; their exact
# meaning is defined in models.personalized_pagerank (not visible here). Presumably WALK_DEPTH is
# the number of steps per walk and N_ROUNDS the number of walks - TODO confirm.
WALK_DEPTH = 100
N_ROUNDS = 5

model = PersonalizedPageRank(
    dgl_network=dgl_network,
    walk_depth=WALK_DEPTH,
    n_rounds=N_ROUNDS,
)

In [35]:
# Example query: a single playlist given as a list of track ids. These are presumably the encoded
# ids used as node indices in `network` (see encoding_tracks.json) - confirm they are within range.
# NOTE(review): if the model samples walks randomly, this output is not reproducible without a
# seed.
seed_playlist = [43, 328232, 22]

model.predict([seed_playlist])

[[1950,
  623,
  207,
  343,
  84,
  3651,
  638,
  11050,
  365,
  344,
  19,
  79,
  202,
  339,
  338,
  234,
  432,
  1821,
  1194,
  14,
  412,
  820,
  833,
  13182,
  2825,
  3690,
  41,
  13454,
  215,
  617,
  6065,
  6538,
  3764,
  654,
  230,
  471,
  1692,
  2266,
  4548,
  5337,
  779,
  212,
  931,
  213,
  350,
  603,
  11403,
  1459,
  750,
  764,
  61,
  82,
  615,
  2013,
  30,
  1263,
  2572,
  1034,
  2,
  880,
  716,
  983,
  6799,
  1259,
  24758,
  390,
  5707,
  1201,
  2260,
  318,
  2463,
  991,
  746,
  2417,
  3544,
  1038,
  2948,
  1112,
  3423,
  376,
  311,
  3315,
  912,
  25699,
  2849,
  569,
  609,
  1113,
  352,
  2062,
  1273,
  354,
  205,
  663,
  1075,
  351,
  1292,
  628,
  826,
  2002,
  1694,
  3179,
  7292,
  86982,
  3688,
  337,
  3499,
  18082,
  17032,
  83,
  27026,
  11931,
  576,
  1197,
  881,
  545,
  4081,
  1905,
  1910,
  7051,
  2420,
  570,
  9076,
  909,
  221,
  848,
  664,
  367,
  12679,
  3691,
  1140,
  2271,
  1695,
  