In [34]:
import dgl
import torch
import math
import numpy as np
import pandas as pd
import random

from model import BatchSampler, SAGEModel, SAGENet
from sage_ns import Config, build_graph, train_test_split
from torch.utils.data import IterableDataset, DataLoader
import torch.nn as nn

In [35]:
conf = Config()
conf.__dict__

{'input_path': '/data/zsj/sage/input/friend.txt',
 'output_path': '/data/zsj/sage/output/user.pkl',
 'device': device(type='cuda', index=0),
 'test_ratio': 0.05,
 'hard_neg_ratio': 0.2,
 'batch_size': 1024,
 'batch_size_test': 4096,
 'batch_size_export': 256,
 'fanouts': [10, 20],
 'num_workers': 2,
 'feat_dim_dict': {'id': 64},
 'lr': 0.005,
 'num_epochs': 5,
 'top_k': 200,
 'redisClient': Redis<ConnectionPool<Connection<host=sg-proxy01.starmaker.co,port=22122,db=0>>>,
 'redisPrefix': 'friend:friend_sage_ns:',
 'redisTTL': 259200}

In [3]:
g, user_ids = build_graph(conf.input_path)
g, test_data = train_test_split(g, conf.test_ratio)
test_data = torch.stack(test_data, dim=1)

Graph(num_nodes=3984554, num_edges=48796703,
      ndata_schemes={'id': Scheme(shape=(), dtype=torch.int64), 'neg_weight': Scheme(shape=(), dtype=torch.float32)}
      edata_schemes={})
test cnt 2439835


In [4]:
batch_sampler = BatchSampler(g, conf.fanouts, conf.hard_neg_ratio)    

dataloader = DataLoader(torch.arange(g.number_of_nodes()), batch_size=conf.batch_size, collate_fn=batch_sampler.collate_fn_train, num_workers=conf.num_workers)
dataloader_test = DataLoader(test_data, batch_size=conf.batch_size_test, collate_fn=batch_sampler.collate_fn_test, num_workers=conf.num_workers)
    

In [5]:
for pos_graph, neg_graph, blocks in dataloader:
    break
pos_graph, neg_graph, blocks

(Graph(num_nodes=2968, num_edges=1012,
       ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
       edata_schemes={}),
 Graph(num_nodes=2968, num_edges=1012,
       ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
       edata_schemes={}),
 [Block(num_src_nodes=250844, num_dst_nodes=25041, num_edges=498383),
  Block(num_src_nodes=25041, num_dst_nodes=2968, num_edges=28831)])

In [6]:
model = SAGEModel(g, conf.feat_dim_dict).to(conf.device)
for i in range(len(blocks)):
    blocks[i] = blocks[i].to(conf.device)
pos_graph = pos_graph.to(conf.device)
neg_graph = neg_graph.to(conf.device)

pos_score, neg_score = model(pos_graph, neg_graph, blocks)
loss = (neg_score - pos_score + 1).clamp(min=0).mean()
loss

node feat neg_weight not exist


tensor(1.0065, device='cuda:0', grad_fn=<MeanBackward0>)

In [8]:
model.get_node_emb(blocks)

tensor([[-0.1546,  0.0800, -0.1331,  ...,  0.1527, -0.0207,  0.0440],
        [-0.2103, -0.0127, -0.2501,  ...,  0.3298,  0.0597,  0.0445],
        [-0.1576,  0.0217, -0.1581,  ...,  0.3003,  0.0330,  0.1513],
        ...,
        [-0.2657,  0.0674, -0.2567,  ...,  0.2201,  0.1649,  0.0973],
        [-0.0863, -0.0223, -0.0319,  ...,  0.1964, -0.0589,  0.1184],
        [-0.2429,  0.1176, -0.2359,  ...,  0.1478,  0.0416,  0.0847]],
       device='cuda:0', grad_fn=<DivBackward0>)

# faiss

In [1]:
import pickle
import redis
import faiss
from collections import Counter
from sage_ns import Config, build_graph, train_test_split, save_to_redis


conf = Config()
conf.__dict__

{'input_path': '/data/zsj/sage/input/friend.txt',
 'output_path': '/data/zsj/sage/output/user.pkl',
 'device': device(type='cuda', index=0),
 'test_ratio': 0.05,
 'hard_neg_ratio': 0.1,
 'batch_size': 1024,
 'batch_size_test': 4096,
 'batch_size_export': 256,
 'fanouts': [10, 20],
 'num_workers': 2,
 'feat_dim_dict': {'id': 64},
 'lr': 0.005,
 'num_epochs': 5,
 'top_k': 200,
 'redisClient': Redis<ConnectionPool<Connection<host=sg-proxy01.starmaker.co,port=22122,db=0>>>,
 'redisPrefix': 'friend:friend_sage_ns:',
 'redisTTL': 259200}

In [2]:
with open('/data/zsj/sage/output/user.pkl', 'rb') as f:
    h_item, user_ids = pickle.load(f)

In [3]:
%%time
save_to_redis(h_item, user_ids, conf)

batch 0 done...
batch 1 done...
batch 2 done...
batch 3 done...
batch 4 done...
batch 5 done...
batch 6 done...
batch 7 done...
batch 8 done...
batch 9 done...
batch 10 done...
batch 11 done...
batch 12 done...
batch 13 done...
batch 14 done...
batch 15 done...
batch 16 done...
batch 17 done...
batch 18 done...
batch 19 done...
batch 20 done...
batch 21 done...
batch 22 done...
batch 23 done...
batch 24 done...
batch 25 done...
batch 26 done...
batch 27 done...
batch 28 done...
batch 29 done...
batch 30 done...
batch 31 done...
batch 32 done...
batch 33 done...
batch 34 done...
batch 35 done...
batch 36 done...
batch 37 done...
batch 38 done...
batch 39 done...
batch 40 done...
batch 41 done...
batch 42 done...
batch 43 done...
batch 44 done...
batch 45 done...
batch 46 done...
batch 47 done...
batch 48 done...
batch 49 done...
batch 50 done...
batch 51 done...
batch 52 done...
batch 53 done...
batch 54 done...
batch 55 done...
batch 56 done...
batch 57 done...
batch 58 done...
batch 5

KeyboardInterrupt: 

In [None]:
pipeline = conf.redisClient.pipeline(transaction=False)
D, I = index_with_id_gpu.search(h_item, conf.top_k)
for i in range(len(user_ids)):
    uid = user_ids[i]       
    redisKey = conf.redisPrefix + str(uid)
    sim_res = dict(zip(I[i].tolist(), D[i].tolist()))
    redisVal = json.dumps(sim_res)
#         print(redisKey, redisVal)
    pipeline.set(redisKey, redisVal)
    pipeline.expire(redisKey, conf.redisTTL)
pipeline.execute()

In [4]:
user_ids[30005]

1688849867053591

In [5]:
len(user_ids)

3984554

In [4]:
batch_size = 10000
for i in range(len(user_ids)//batch_size + 1):
    h_item_batch = h_item[batch_size*i:batch_size*(i+1)]
    user_ids_batch = user_ids[batch_size*i:batch_size*(i+1)]
    print(h_item_batch.shape, user_ids_batch.shape)

(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (10000,)
(10000, 64) (

In [12]:
D, I = index_with_id_gpu.search(h_item_batch, conf.top_k)

In [None]:
for j in range(len(user_ids_batch)):
    uid = user_ids_batch[j]       
    redisKey = conf.redisPrefix + str(uid)
    sim_res = dict(zip(I[j].tolist(), D[j].tolist()))
    redisVal = json.dumps(sim_res)
#         print(redisKey, redisVal)
    pipeline.set(redisKey, redisVal)
    pipeline.expire(redisKey, conf.redisTTL)

In [21]:
for j in range(len(user_ids_batch)):
    uid = user_ids_batch[j]       
    redisKey = conf.redisPrefix + str(uid)
#     sim_res = dict(zip(I[j].tolist(), D[j].tolist()))
#     redisVal = json.dumps(sim_res)
    print(redisKey, I[j])

friend:friend_sage_ns:5629499488157609 [ 5629499488157609 12947848930263481   281474980778350  5629499489094995
 12384898983518446 10696049123296005 10977524096141109  5629499488146030
 10696049118810399 12384898984572178 10977524100972876  2533274798033755
 10133099171629443 10696049123843193 11821949031231359 12666373958999710
 11821949030896950 12947848937887898 11821949031380134 12103424007082446
   562949960047261  2251799815525218  8725724282270941  7599824374704364
  3377699727847689 12666373961492718  7318349401773651 10414574147840518
  2251799819277135 12666373955753067 10133099171665758   844424933542178
  5066549355075319  1125899914619764 11821949028952226  1688849867496307
   562949958363177  1970324844739160  7318349397700734  7318349397713478
 11540474050115845  8162774324854819  7318349402017707  5629499492226516
 12384898980742023 12384898984709507  3659174701704909  8162774325998041
  3940649681763753  1407374883748872  3096224747985505 12947848937173951
  8162774327

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [9]:
# pipeline = conf.redisClient.pipeline(transaction=False)
batch_size = 10000
for i in range(len(user_ids)//batch_size + 1):
    h_item_batch = h_item[batch_size*i:batch_size*(i+1)]
    user_ids_batch = user_ids[batch_size*i:batch_size*(i+1)]
    print(len(h_item_batch), len(user_ids_batch))
#     break
    
#     D, I = index_with_id_gpu.search(h_item_batch, conf.top_k)
#     for j in range(len(user_ids_batch)):
#         uid = user_ids_batch[i]       
#         redisKey = conf.redisPrefix + str(uid)
#         sim_res = dict(zip(I[i].tolist(), D[i].tolist()))
#         redisVal = json.dumps(sim_res)
#         print(redisKey, redisVal)
#         break
#         pipeline.set(redisKey, redisVal)
#         pipeline.expire(redisKey, conf.redisTTL)
#     break
#     pipeline.execute()

10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
10000 10000
1000

In [None]:
def batch_insert_vec(vecs, ids, lc):
    batch_size = 100000
    for i in range(len(vecs)//batch_size + 1):
#         print([batch_size*i , batch_size*(i+1)])
        status, _ = milvus_client.insert(collection_name=collection_name, records=vecs[batch_size*i : batch_size*(i+1)], ids=ids[batch_size*i : batch_size*(i+1)], partition_tag=lc)
        if not status.OK():
            alert("[related news embedding] insert vec failed: {}".format(status))
            raise Exception

    logging.info("%s totally insert %d news vecs.", lc, len(vecs))

In [11]:
emb_size = h_item.shape[1]
resource = faiss.StandardGpuResources()
index = faiss.IndexFlatIP(emb_size)
index_with_id = faiss.IndexIDMap(index)
index_with_id_gpu = faiss.index_cpu_to_gpu(resource, 0, index_with_id)
index_with_id_gpu.add_with_ids(h_item, user_ids)

In [21]:
%%time
D, I = index_with_id_gpu.search(h_item, 200)

CPU times: user 11.8 s, sys: 18.4 s, total: 30.2 s
Wall time: 30 s


In [None]:
pipeline = conf.redisClient.pipeline(transaction=False)
D, I = index_with_id_gpu.search(h_item, conf.top_k)
print(I[1], I[1000], I[1000000])

for i in range(len(user_ids)):
    uid = user_ids[i]       
    redisKey = conf.redisPrefix + str(uid)
    sim_res = dict(zip(I[i].tolist(), D[i].tolist()))
    redisVal = json.dumps(sim_res)
#         print(redisKey, redisVal)
    pipeline.set(redisKey, redisVal)
    pipeline.expire(redisKey, conf.redisTTL)
pipeline.execute()

In [47]:
for i in range(len(user_ids)):
    if user_ids[i] == 11540474051698622:
        print(i)
        break

750725


In [48]:
user_ids[750725]

11540474051698622

In [52]:
num = 750725
index_with_id_gpu.search(h_item[num:num+1], conf.top_k)[1]

array([[11540474051698622, 11821949029961478,  5066549358342054,
        12666373958405600, 10696049124911695, 12384898984786761,
         5629499490867862, 12666373961523150,  7036874425310400,
        10414574147772007,  8162774332274792, 10977524099409245,
         7036874424940204,  6755399447499890, 10414574139322502,
        12103424005897148,  1970324840815512, 12947848937320679,
        11540474051577209,  8725724281426732,  4222124658527845,
        10977524093248467,  5910973794480639,   281474983881918,
         3096224746205684,  2251799818546097, 13229323905667667,
          562949954093741,  2533274798241246,  8162774328468113,
         7036874421473845, 10696049118996951,  2251799820718194,
         3377699723524161, 12384898984778694,  5910973794089727,
         3659174704146998, 11821949028084309,  7599824374740374,
         2251799819441801,  1407374883672619,  1970324840215588,
         8725724284140093,  2251799819650295, 12103424007703542,
        10696049121709490