In [26]:
from tqdm import tqdm

import torch
from utils import GraphLoader
from model_4 import GammaModel

In [2]:
# Device used for the graph, the model and every inference pass below.
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on a machine without a usable GPU.
DVC = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
# Location of the serialized heterogeneous graph handed to GraphLoader.
GRAPH_PATH = 'DataPreprocess/full_graph/interaction_graph_nf.bin'
gl = GraphLoader(GRAPH_PATH)

In [4]:
# load_graph yields three objects: the heterograph, the per-node-type feature
# dict, and the edge-type -> id mapping (the mapping is echoed in the output).
graph_bundle = gl.load_graph(device=DVC)
hg, node_feats, edge2ids = graph_bundle

################# Basic Information of The Graph #################
Edge ('contributor', 'contributor_follow_contributor', 'contributor') 2286407
Edge ('contributor', 'contributor_propose_issue', 'issue') 692554
Edge ('contributor', 'contributor_propose_pr', 'pr') 379498
Edge ('contributor', 'contributor_star_repo', 'repository') 947423
Edge ('contributor', 'contributor_watch_repo', 'repository') 150292
Edge ('issue', 'issue_belong_to_repo', 'repository') 692554
Edge ('pr', 'pr_belong_to_repo', 'repository') 379498
Edge ('repository', 'repo_committed_by_contributor', 'contributor') 161241
Node contributor 394474
Node issue 692554
Node pr 379496
Node repository 50000
Total number of nodes 1516524
Total number of edges 5689467
################# End of the Graph Information  #################
{('contributor', 'contributor_follow_contributor', 'contributor'): 0, ('contributor', 'contributor_propose_issue', 'issue'): 1, ('contributor', 'contributor_propose_pr', 'pr'): 2, ('contributor', 'con

In [5]:
# Width of one embedding half; contributor/repository inputs to the final
# stage are two such halves concatenated (see the torch.cat cell further down).
EMB_FEATURES = 64

# Input feature width per node type, read off the loaded feature tensors.
node_feat_dim_dict = {}
for ntype in hg.ntypes:
    node_feat_dim_dict[ntype] = node_feats[ntype].shape[1]

# Same widths, except contributor/repository are overridden to 2 * EMB_FEATURES.
node_feat_dim_dict2 = dict(node_feat_dim_dict)
node_feat_dim_dict2.update(
    contributor=2 * EMB_FEATURES,
    repository=2 * EMB_FEATURES,
)

In [6]:
# Edge types that make up the contributor/issue/pr/repository interaction view.
INTERACTION_ETYPES = [
    'contributor_propose_issue',
    'contributor_propose_pr',
    'contributor_star_repo',
    'contributor_watch_repo',
    'issue_belong_to_repo',
    'pr_belong_to_repo',
    'repo_committed_by_contributor',
]
# The social view contains only the follow relation between contributors.
SOCIAL_ETYPES = ['contributor_follow_contributor']

g_inter = hg.edge_type_subgraph(INTERACTION_ETYPES)
g_social = hg.edge_type_subgraph(SOCIAL_ETYPES)

In [7]:
from parameter_namespace import (
    GeneralParameterNamespace,
    HetGATParameterNamespace,
    HetGCNParameterNamespace,
)

# Instantiate the three parameter namespaces. None of them is referenced again
# in this notebook's visible cells.
# NOTE(review): presumably constructing them has side effects GammaModel relies
# on — confirm before removing.
pn = GeneralParameterNamespace('gamma')
gat_pn = HetGATParameterNamespace()
gcn_pn = HetGCNParameterNamespace()

In [8]:
# Build the GammaModel over the interaction and social subgraphs, giving it
# both the raw per-type feature widths and the widened (2 * EMB_FEATURES) ones.
# NOTE(review): node_feat_dim_dict is passed as the third positional argument
# AND by keyword. If GammaModel's third parameter is itself named
# node_feat_dim_dict this raises TypeError (multiple values for argument);
# confirm the intended third argument against the signature in model_4.
model = GammaModel(
    g_inter,
    g_social,
    node_feat_dim_dict,
    node_feat_dim_dict=node_feat_dim_dict,
    node_feat_dim_dict2=node_feat_dim_dict2
)

In [9]:
# Restore the trained weights.
# map_location='cpu' makes the load independent of the device the checkpoint
# was saved from (e.g. a GPU index that does not exist on this machine); the
# model is moved to DVC right afterwards.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load('model.bin', map_location='cpu')
model.load_state_dict(state_dict)
model = model.to(DVC)
# This notebook only runs inference: switch dropout / normalisation layers
# (if the model has any) to evaluation behaviour so the embeddings are
# deterministic.
model.eval()

In [10]:
# Batch size passed as `batch_size` to every .inference() call in this notebook.
BATCH_SIZE = 512

In [11]:
# First stage: run the interaction sub-model over the interaction graph using
# the raw node features; the result is indexed by node type below.
x = model.m_inter.inference(
    model.g_inter, node_feats, batch_size=BATCH_SIZE, device=DVC
)

100%|██████████| 2962/2962 [00:30<00:00, 98.45it/s] 
100%|██████████| 2962/2962 [00:29<00:00, 100.09it/s]
100%|██████████| 2962/2962 [00:28<00:00, 102.80it/s]


In [12]:
# The social graph only has contributor nodes, so keep just that entry of the
# interaction output and make sure it lives on DVC.
x_ = dict(contributor=x['contributor'].to(DVC))

In [13]:
# Second stage: run the social sub-model over the follow graph, seeded with
# the contributor output of the interaction stage.
X_ = model.m_social.inference(
    model.g_social, x_, batch_size=BATCH_SIZE, device=DVC
)

100%|██████████| 771/771 [00:01<00:00, 408.43it/s]
100%|██████████| 771/771 [00:01<00:00, 402.36it/s]
100%|██████████| 771/771 [00:01<00:00, 407.84it/s]


In [14]:
# Pull the contributor entry out of the model.m_social.inference() result.
X = X_['contributor']

In [15]:
# Project the raw contributor / repository features through their heads.
# Run under no_grad like the later HUR/UVR cells: this is pure inference, and
# tracking gradients here would keep an autograd graph alive for every node
# (and propagate requires_grad into u2 / v2 further down).
with torch.no_grad():
    u = model.u_head(node_feats['contributor'])
    v = model.v_head(node_feats['repository'])

In [16]:
# Project the social-stage contributor output through mock_x_head.
# Read from X_['contributor'] (the same tensor X was assigned from) rather than
# from X itself, so re-running this cell does not apply the head twice.
# no_grad: inference only, no autograd graph needed.
with torch.no_grad():
    X = model.mock_x_head(X_['contributor'].to(DVC))

In [17]:
# Insert a singleton axis at position 1 on each tensor before feeding them to
# model.HUR / model.UVR. u and v themselves are left untouched (they are
# concatenated in 2-D form later); X is rebound to its unsqueezed version.
X = X.unsqueeze(1)
uu = u.unsqueeze(1)
vv = v.unsqueeze(1)

In [18]:
# Number of rows sent through HUR / UVR per call.
STEP = 1000

# Apply model.HUR to uu in STEP-row slices, always against the full X, and
# stitch the per-slice results back together along the first axis.
hur_chunks = []
with torch.no_grad():
    for start in tqdm(range(0, uu.shape[0], STEP)):
        hur_chunks.append(model.HUR(uu[start:start + STEP], X, X))
    f_uus = torch.cat(hur_chunks)

  0%|          | 0/395 [00:00<?, ?it/s]

In [19]:
def _uvr_in_chunks(queries, memory):
    """Run model.UVR on STEP-row slices of `queries` against the full `memory`.

    The per-slice outputs are concatenated along the first axis and the
    singleton axis 1 is squeezed away before returning.
    """
    pieces = []
    for start in tqdm(range(0, queries.shape[0], STEP)):
        pieces.append(model.UVR(queries[start:start + STEP], memory, memory))
    return torch.cat(pieces).squeeze(1)


with torch.no_grad():
    # One row per repository: vv slices attending over f_uus.
    e_uv = _uvr_in_chunks(vv, f_uus)
    # One row per contributor: f_uus slices attending over vv.
    e_vu = _uvr_in_chunks(f_uus, vv)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/395 [00:00<?, ?it/s]

In [20]:
# Widen each side by appending its cross term along the feature axis:
# contributors get e_vu, repositories get e_uv.
u2 = torch.cat([u, e_vu], dim=1)
v2 = torch.cat([v, e_uv], dim=1)

In [21]:
# Shallow copy of the original feature dict with the contributor and
# repository entries swapped for their widened versions; node_feats itself is
# not modified.
node_feats2 = {
    **node_feats,
    'contributor': u2,
    'repository': v2,
}

In [22]:
# Short alias for model.m_fin, whose inference() produces the final embeddings.
mdl = model.m_fin 

In [23]:
# Final stage: run m_fin over the interaction graph with the widened
# contributor / repository features.
final = mdl.inference(
    model.g_inter,
    node_feats2,
    batch_size=BATCH_SIZE,
    device=DVC,
)

  0%|          | 0/2962 [00:00<?, ?it/s]

100%|██████████| 2962/2962 [00:31<00:00, 94.15it/s]
100%|██████████| 2962/2962 [00:31<00:00, 94.71it/s]
100%|██████████| 2962/2962 [00:31<00:00, 94.69it/s]


In [24]:
# Sanity check: embedding shape per node type (displayed as the cell output).
dict((ntype, emb.shape) for ntype, emb in final.items())

{'contributor': torch.Size([394474, 64]),
 'issue': torch.Size([692554, 64]),
 'pr': torch.Size([379496, 64]),
 'repository': torch.Size([50000, 64])}

In [25]:
# Persist the per-node-type embedding dict for downstream use.
EMBEDDING_PATH = 'final_node_embedding.bin'
torch.save(final, EMBEDDING_PATH)