Skip to content

Commit

Permalink
data fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
RexYing committed May 12, 2018
1 parent a64dfc3 commit b9d2b60
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 16 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,8 @@ runs/

data/REDDIT-MULTI-12K/
data/COLLAB
data/PROTEINS
data/PROTEINS_full
data/NCI1


38 changes: 30 additions & 8 deletions gen/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ def gen_ba(n_range, m_range, num_graphs, feature_generator=None):
feature_generator.gen_node_features(G)
return graphs

def gen_er(n_range, p, num_graphs, feature_generator=None):
    ''' Generate a list of Erdos-Renyi random graphs with node features.
    Args:
        n_range: candidate node counts; one is drawn per graph with np.random.choice.
        p: edge probability handed to nx.erdos_renyi_graph.
        num_graphs: number of graphs to generate.
        feature_generator: object whose gen_node_features(G) is called on every
            generated graph. When None, ConstFeatureGen(0) is used.
    Returns:
        The list of generated graphs, in the order their sizes were drawn.
    '''
    # Draw all graph sizes up front, then build one graph per drawn size.
    node_counts = np.random.choice(n_range, num_graphs)
    graphs = [nx.erdos_renyi_graph(num_nodes, p) for num_nodes in node_counts]

    feat_gen = ConstFeatureGen(0) if feature_generator is None else feature_generator
    for graph in graphs:
        feat_gen.gen_node_features(graph)
    return graphs

def gen_2community_ba(n_range, m_range, num_graphs, inter_prob, feature_generators):
''' Each community is a BA graph.
Args:
Expand Down Expand Up @@ -53,7 +64,7 @@ def gen_2community_ba(n_range, m_range, num_graphs, inter_prob, feature_generato
graphs.append(G)
return graphs

def gen_2hier(num_graphs, num_clusters, n_range, m_range, inter_prob1, inter_prob2, feat_gen):
def gen_2hier(num_graphs, num_clusters, n, m_range, inter_prob1, inter_prob2, feat_gen):
''' Each community is a BA graph.
Args:
inter_prob1: probability of one node connecting to any node in the other community within
Expand All @@ -66,18 +77,29 @@ def gen_2hier(num_graphs, num_clusters, n_range, m_range, inter_prob1, inter_pro
for i in range(num_graphs):
clusters2 = []
for j in range(len(num_clusters)):
clusters = gen_ba(n_range, m_range, num_clusters[j], feat_gen[0])
clusters = gen_er(range(n, n+1), 0.5, num_clusters[j], feat_gen[0])
G = nx.disjoint_union_all(clusters)
for u1 in range(G.number_of_nodes()):
for u2 in range(G.number_of_nodes()):
if np.random.rand() < inter_prob1 and not G.has_edge(u1, u2):
G.add_edge(u1, u2)
if np.random.rand() < inter_prob1:
target = np.random.choice(G.number_of_nodes() - n)
# move one cluster after to make sure it's not an intra-cluster edge
if target // n >= u1 // n:
target += n
G.add_edge(u1, target)
clusters2.append(G)
G = nx.disjoint_union_all(clusters2)
cluster_sizes_cum = np.cumsum([cluster2.number_of_nodes() for cluster2 in clusters2])
curr_cluster = 0
for u1 in range(G.number_of_nodes()):
for u2 in range(G.number_of_nodes()):
if np.random.rand() < inter_prob2 and not G.has_edge(u1, u2):
G.add_edge(u1, u2)
if u1 >= cluster_sizes_cum[curr_cluster]:
curr_cluster += 1
if np.random.rand() < inter_prob2:
target = np.random.choice(G.number_of_nodes() -
clusters2[curr_cluster].number_of_nodes())
# move one cluster after to make sure it's not an intra-cluster edge
if curr_cluster == 0 or target >= cluster_sizes_cum[curr_cluster - 1]:
target += cluster_sizes_cum[curr_cluster]
G.add_edge(u1, target)
graphs.append(G)

return graphs
Expand Down
2 changes: 1 addition & 1 deletion graph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class GraphSampler(torch.utils.data.Dataset):
''' Sample graphs and nodes in graph
'''
def __init__(self, G_list, features='default', normalize=True):
def __init__(self, G_list, features='struct', normalize=True):
self.adj_all = []
self.len_all = []
self.feature_all = []
Expand Down
9 changes: 8 additions & 1 deletion load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,19 @@ def read_graphfile(datadir, dataname, max_nodes=None):
except IOError:
print('No node attributes')

label_has_zero = False
filename_graphs=prefix + '_graph_labels.txt'
graph_labels=[]
with open(filename_graphs) as f:
for line in f:
line=line.strip("\n")
graph_labels.append(int(line) - 1)
val = int(line)
if val == 0:
label_has_zero = True
graph_labels.append(val - 1)
graph_labels = np.array(graph_labels)
if label_has_zero:
graph_labels += 1

filename_adj=prefix + '_A.txt'
adj_list={i:[] for i in range(1,len(graph_labels)+1)}
Expand Down
12 changes: 6 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,15 @@ def log_graph(adj, batch_num_nodes, writer, epoch, batch_idx):
fig = plt.figure(figsize=(8,6), dpi=300)

for i in range(len(batch_idx)):
plt.subplot(2, 2, i+1)
ax = plt.subplot(2, 2, i+1)
num_nodes = batch_num_nodes[batch_idx[i]]
adj_matrix = adj[batch_idx[i], :num_nodes, :num_nodes].cpu().data.numpy()
G = nx.from_numpy_matrix(adj_matrix)
nx.draw(G, pos=nx.spring_layout(G), with_labels=True, node_color='#336699',
edge_color='grey', width=0.5, node_size=300,
alpha=0.7)
ax.xaxis.set_visible(False)

plt.xticks([])
plt.tight_layout()
fig.canvas.draw()

Expand Down Expand Up @@ -294,9 +294,9 @@ def syn_community2hier(args, writer=None):

# data
feat_gen = [featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float))]
graphs1 = datagen.gen_2hier(1000, [2,4], range(40,50), range(4,5), 0.2, 0.05, feat_gen)
graphs2 = datagen.gen_2hier(1000, [3,3], range(40,50), range(4,5), 0.2, 0.05, feat_gen)
graphs3 = datagen.gen_2community_ba(range(120, 150), range(4,7), 1000, 0.25, feat_gen)
graphs1 = datagen.gen_2hier(1000, [2,4], 10, range(4,5), 0.1, 0.03, feat_gen)
graphs2 = datagen.gen_2hier(1000, [3,3], 10, range(4,5), 0.1, 0.03, feat_gen)
graphs3 = datagen.gen_2community_ba(range(28, 33), range(4,7), 1000, 0.25, feat_gen)

for G in graphs1:
G.graph['label'] = 0
Expand Down Expand Up @@ -458,7 +458,7 @@ def arg_parse():
lr=0.001,
clip=2.0,
batch_size=20,
num_epochs=70,
num_epochs=1000,
train_ratio=0.8,
test_ratio=0.1,
num_workers=1,
Expand Down

0 comments on commit b9d2b60

Please sign in to comment.