# gsimclr.py
import os.path as osp
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
# from core.encoders import *
# from torch_geometric.datasets import TUDataset
from aug import TUDataset_aug as TUDataset
from torch_geometric.data import DataLoader
import sys
import json
from torch import optim
from cortex_DIM.nn_modules.mi_networks import MIFCNet, MI1x1ConvNet
from losses import *
from gin import Encoder
from evaluate_embedding import evaluate_embedding
from model import *
from arguments import arg_parse
from torch_geometric.transforms import Constant
import pdb
class GcnInfomax(nn.Module):
    """Encoder trained with a local/global discriminator loss plus an optional prior term.

    The encoder output is split into ``y`` and ``M``; ``y`` is fed through the
    ``global_d`` head and ``M`` through the ``local_d`` head, and the two are
    compared by ``local_global_loss_`` using the ``'JSD'`` measure.

    NOTE: construction reads the module-level names ``args`` (for
    ``args.prior``) and ``dataset_num_features``.  In this file both are only
    assigned inside the ``if __name__ == '__main__':`` block, so they must
    exist before the class is instantiated.

    Args:
        hidden_dim: hidden width passed to ``Encoder``.
        num_gc_layers: number of layers passed to ``Encoder``; the embedding
            width is ``hidden_dim * num_gc_layers``.
        alpha: stored on the instance; not used by the code in this class.
        beta: stored on the instance; not used by the code in this class.
        gamma: weight applied to the prior term in ``forward``.
    """

    def __init__(self, hidden_dim, num_gc_layers, alpha=0.5, beta=1., gamma=.1):
        super(GcnInfomax, self).__init__()

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.prior = args.prior

        self.embedding_dim = hidden_dim * num_gc_layers
        self.encoder = Encoder(dataset_num_features, hidden_dim, num_gc_layers)

        self.local_d = FF(self.embedding_dim)
        self.global_d = FF(self.embedding_dim)

        if self.prior:
            self.prior_d = PriorDiscriminator(self.embedding_dim)

        self.init_emb()

    def init_emb(self):
        """Xavier-initialise every ``nn.Linear`` weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

    def forward(self, x, edge_index, batch, num_graphs):
        """Return the training loss for one batch.

        Args:
            x: node features, or ``None`` to use a vector of ones with one
                entry per element of ``batch``.
            edge_index: edge index tensor forwarded to the encoder and the loss.
            batch: tensor forwarded to the encoder and the loss; its first
                dimension sizes the placeholder features.
            num_graphs: accepted for interface compatibility; unused here.

        Returns:
            ``local_global_loss_`` of the two discriminator outputs, plus the
            ``gamma``-weighted prior term when ``self.prior`` is truthy.
        """
        if x is None:
            # Build the placeholder on the same device as `batch` so this
            # method does not depend on a script-level `device` global.
            x = torch.ones(batch.shape[0], device=batch.device)

        y, M = self.encoder(x, edge_index, batch)

        g_enc = self.global_d(y)
        l_enc = self.local_d(M)

        measure = 'JSD'
        local_global_loss = local_global_loss_(l_enc, g_enc, edge_index, batch, measure)

        if self.prior:
            # Discriminator objective between uniform noise and `y`.
            prior = torch.rand_like(y)
            term_a = torch.log(self.prior_d(prior)).mean()
            term_b = torch.log(1.0 - self.prior_d(y)).mean()
            PRIOR = -(term_a + term_b) * self.gamma
        else:
            PRIOR = 0

        return local_global_loss + PRIOR
class simclr(nn.Module):
    """Encoder followed by a two-layer projection head, with a contrastive loss.

    NOTE: construction reads the module-level names ``args`` (for
    ``args.prior``) and ``dataset_num_features``.  In this file both are only
    assigned inside the ``if __name__ == '__main__':`` block, so they must
    exist before the class is instantiated.

    Args:
        hidden_dim: hidden width passed to ``Encoder``.
        num_gc_layers: number of layers passed to ``Encoder``; the embedding
            width is ``hidden_dim * num_gc_layers``.
        alpha, beta, gamma: stored on the instance; not used by the code in
            this class.
    """

    def __init__(self, hidden_dim, num_gc_layers, alpha=0.5, beta=1., gamma=.1):
        super(simclr, self).__init__()

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.prior = args.prior

        self.embedding_dim = hidden_dim * num_gc_layers
        self.encoder = Encoder(dataset_num_features, hidden_dim, num_gc_layers)

        self.proj_head = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embedding_dim, self.embedding_dim),
        )

        self.init_emb()

    def init_emb(self):
        """Xavier-initialise every ``nn.Linear`` weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

    def forward(self, x, edge_index, batch, num_graphs):
        """Encode a batch and return the projected first encoder output.

        Args:
            x: node features, or ``None`` to use a vector of ones with one
                entry per element of ``batch``.
            edge_index: edge index tensor forwarded to the encoder.
            batch: tensor forwarded to the encoder; its first dimension sizes
                the placeholder features.
            num_graphs: accepted for interface compatibility; unused here.

        Returns:
            ``proj_head`` applied to the first value returned by the encoder
            (presumably one embedding per graph -- confirm against ``Encoder``).
        """
        if x is None:
            # Build the placeholder on the same device as `batch` so this
            # method does not depend on a script-level `device` global.
            x = torch.ones(batch.shape[0], device=batch.device)

        y, M = self.encoder(x, edge_index, batch)
        y = self.proj_head(y)
        return y

    def loss_cal(self, x, x_aug, temperature=0.2, eps=1e-8):
        """Contrastive loss between two sets of embeddings.

        Row ``i`` of ``x`` and row ``i`` of ``x_aug`` form the positive pair;
        every other row of ``x_aug`` is a negative for ``x[i]``.  With ``s``
        the cosine similarity, the per-row loss is
        ``-log(exp(s_ii / T) / sum_{j != i} exp(s_ij / T))`` and the result is
        its mean over rows.

        Args:
            x: 2-D tensor of embeddings, one row per sample.
            x_aug: 2-D tensor of embeddings of the augmented samples.
            temperature: softmax temperature ``T`` (previously hard-coded 0.2).
            eps: lower bound on the product of norms, so a zero-norm embedding
                yields similarity 0 instead of NaN.

        Returns:
            Scalar tensor with the mean loss.  With a single row there are no
            negatives, the denominator is 0 and the result is ``-inf``.
        """
        batch_size, _ = x.size()
        x_abs = x.norm(dim=1)
        x_aug_abs = x_aug.norm(dim=1)

        # Pairwise cosine similarity; clamp guards against division by zero.
        norm_prod = torch.clamp(torch.einsum('i,j->ij', x_abs, x_aug_abs), min=eps)
        sim_matrix = torch.einsum('ik,jk->ij', x, x_aug) / norm_prod
        sim_matrix = torch.exp(sim_matrix / temperature)

        diag = torch.arange(batch_size, device=sim_matrix.device)
        pos_sim = sim_matrix[diag, diag]
        loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
        loss = -torch.log(loss).mean()
        return loss
import random
def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also asks cuDNN for deterministic behaviour.  ``cudnn.benchmark`` is
    switched off explicitly because autotuning may select different
    algorithms from run to run, which would defeat the purpose of seeding.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
if __name__ == '__main__':
    # Script entry point: train the `simclr` model on a TU dataset with the
    # requested augmentation, periodically evaluate the encoder embeddings,
    # and append the collected accuracies to a log file.

    # Only `os.path` is imported at file level; `os` itself is needed to
    # create the log directory.
    import os

    args = arg_parse()
    setup_seed(args.seed)

    accuracies = {'val': [], 'test': []}
    epochs = 20
    log_interval = 10
    batch_size = 128
    lr = args.lr
    DS = args.DS
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', DS)

    dataset = TUDataset(path, name=DS, aug=args.aug).shuffle()
    dataset_eval = TUDataset(path, name=DS, aug='none').shuffle()
    print(len(dataset))

    # `dataset_num_features` and `device` (like `args`) are module-level names
    # read by the model classes defined in this file, so they must keep
    # these exact names.
    try:
        dataset_num_features = dataset.get_num_feature()
    except Exception:
        # Best-effort fallback to a single feature when the dataset cannot
        # report its feature count.
        dataset_num_features = 1
    print(dataset_num_features)

    dataloader = DataLoader(dataset, batch_size=batch_size)
    dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = simclr(args.hidden_dim, args.num_gc_layers).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    print('================')
    print('lr: {}'.format(lr))
    print('num_features: {}'.format(dataset_num_features))
    print('hidden_dim: {}'.format(args.hidden_dim))
    print('num_gc_layers: {}'.format(args.num_gc_layers))
    print('================')

    # NOTE(review): these pre-training embeddings are not consumed; the
    # baseline evaluation that used them (evaluate_embedding on the untrained
    # encoder, appended to `accuracies`) is disabled.
    model.eval()
    emb, y = model.encoder.get_embeddings(dataloader_eval)

    # For these augmentations the augmented batch is compacted below: nodes
    # that no longer appear in any edge are dropped and the rest re-indexed.
    reindexed_augs = ('dnodes', 'subgraph', 'random2', 'random3', 'random4')

    for epoch in range(1, epochs + 1):
        loss_all = 0
        model.train()
        for data in dataloader:
            data, data_aug = data
            optimizer.zero_grad()

            node_num, _ = data.x.size()
            data = data.to(device)
            x = model(data.x, data.edge_index, data.batch, data.num_graphs)

            if args.aug in reindexed_augs:
                edge_idx = data_aug.edge_index.numpy()
                src = edge_idx[0].tolist()
                dst = edge_idx[1].tolist()

                # Set lookup instead of scanning the edge arrays per node.
                connected = set(src) | set(dst)
                idx_not_missing = [n for n in range(node_num) if n in connected]

                data_aug.x = data_aug.x[idx_not_missing]
                data_aug.batch = data.batch[idx_not_missing]

                # Map old node ids to their position in the compacted list
                # and rewrite the edges, skipping self-loops.
                idx_dict = {old: new for new, old in enumerate(idx_not_missing)}
                edge_idx = [[idx_dict[s], idx_dict[d]] for s, d in zip(src, dst) if s != d]
                data_aug.edge_index = torch.tensor(edge_idx).transpose_(0, 1)

            data_aug = data_aug.to(device)
            x_aug = model(data_aug.x, data_aug.edge_index, data_aug.batch, data_aug.num_graphs)

            loss = model.loss_cal(x, x_aug)
            print(loss)
            loss_all += loss.item() * data.num_graphs
            loss.backward()
            optimizer.step()

        # `loss_all` is weighted by graphs per batch, so normalise by the
        # number of graphs rather than the number of batches.
        print('Epoch {}, Loss {}'.format(epoch, loss_all / len(dataset)))

        if epoch % log_interval == 0:
            model.eval()
            emb, y = model.encoder.get_embeddings(dataloader_eval)
            acc_val, acc = evaluate_embedding(emb, y)
            accuracies['val'].append(acc_val)
            accuracies['test'].append(acc)

    tpe = ('local' if args.local else '') + ('prior' if args.prior else '')
    # Make sure the log directory exists so results are not lost after training.
    os.makedirs('logs', exist_ok=True)
    with open('logs/log_' + args.DS + '_' + args.aug, 'a+') as f:
        s = json.dumps(accuracies)
        f.write('{},{},{},{},{},{},{}\n'.format(args.DS, tpe, args.num_gc_layers, epochs, log_interval, lr, s))
        f.write('\n')