/
MVGRL_graph.py
145 lines (119 loc) · 4.69 KB
/
MVGRL_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import os.path as osp
import GCL.losses as L
import GCL.augmentors as A
from torch import nn
from tqdm import tqdm
from torch.optim import Adam
from GCL.eval import get_split, SVMEvaluator
from GCL.models import DualBranchContrast
from torch_geometric.nn import GCNConv, global_add_pool
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
class GConv(nn.Module):
    """Stack of GCNConv layers with a shared PReLU and a sum-pooled readout.

    Args:
        input_dim: node feature size consumed by the first layer.
        hidden_dim: output size of every GCNConv layer.
        num_layers: number of GCNConv layers; must be >= 1.

    Raises:
        ValueError: if ``num_layers`` < 1 (the readout would have nothing to
            concatenate).
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super(GConv, self).__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')
        self.layers = nn.ModuleList()
        # A single PReLU module (hidden_dim learnable slopes) is reused after
        # every layer, so its parameters are shared across layers.
        self.activation = nn.PReLU(hidden_dim)
        for i in range(num_layers):
            in_dim = input_dim if i == 0 else hidden_dim
            self.layers.append(GCNConv(in_dim, hidden_dim))

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Compute node embeddings and a multi-layer graph embedding.

        Args:
            x: node features.
            edge_index: connectivity passed to every GCNConv layer.
            batch: node-to-graph assignment used by ``global_add_pool``.
            edge_weight: optional per-edge weights forwarded to every GCNConv
                layer; ``None`` keeps the unweighted behaviour.

        Returns:
            ``(z, g)``: ``z`` is the node output of the last layer; ``g`` is
            the concatenation (dim 1) of the sum-pooled graph embedding of
            every layer's output.
        """
        z = x
        layer_outputs = []
        for conv in self.layers:
            z = self.activation(conv(z, edge_index, edge_weight))
            layer_outputs.append(z)
        g = torch.cat([global_add_pool(h, batch) for h in layer_outputs], dim=1)
        return z, g


class FC(nn.Module):
    """Three Linear+ReLU layers plus a linear shortcut from the input.

    Args:
        input_dim: size of the incoming features.
        output_dim: size of every hidden layer and of the output.
    """

    def __init__(self, input_dim, output_dim):
        super(FC, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU()
        )
        # Direct projection of the input, summed with the MLP output.
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``fc(x) + linear(x)``."""
        return self.fc(x) + self.linear(x)


class Encoder(torch.nn.Module):
    """Two-view encoder: augment twice, encode each view, then project.

    Args:
        gcn1, gcn2: graph encoders for view 1 / view 2, each called as
            ``gcn(x, edge_index, batch, edge_weight=...)`` and returning
            ``(node_emb, graph_emb)``.
        mlp1: projection head applied to both views' node embeddings.
        mlp2: projection head applied to both views' graph embeddings.
        aug1, aug2: augmentors called as ``aug(x, edge_index)`` and returning
            ``(x, edge_index, edge_weight)``.
    """

    def __init__(self, gcn1, gcn2, mlp1, mlp2, aug1, aug2):
        super(Encoder, self).__init__()
        self.gcn1 = gcn1
        self.gcn2 = gcn2
        self.mlp1 = mlp1
        self.mlp2 = mlp2
        self.aug1 = aug1
        self.aug2 = aug2

    def forward(self, x, edge_index, batch):
        """Return projected ``(h1, h2, g1, g2)`` node/graph embeddings of both views."""
        x1, edge_index1, edge_weight1 = self.aug1(x, edge_index)
        x2, edge_index2, edge_weight2 = self.aug2(x, edge_index)
        # Forward the augmentors' edge weights: a diffusion view's weights
        # are part of the view itself and were previously discarded here.
        z1, g1 = self.gcn1(x1, edge_index1, batch, edge_weight=edge_weight1)
        z2, g2 = self.gcn2(x2, edge_index2, batch, edge_weight=edge_weight2)
        h1, h2 = [self.mlp1(h) for h in (z1, z2)]
        g1, g2 = [self.mlp2(g) for g in (g1, g2)]
        return h1, h2, g1, g2
def train(encoder_model, contrast_model, dataloader, optimizer, device=None):
    """Run one training epoch and return the summed batch loss.

    Args:
        encoder_model: module called as ``encoder_model(x, edge_index, batch)``
            returning ``(h1, h2, g1, g2)``.
        contrast_model: callable taking keywords ``h1, h2, g1, g2, batch`` and
            returning a scalar loss tensor.
        dataloader: iterable of batches exposing ``x``, ``edge_index``,
            ``batch`` and ``to(device)``.
        optimizer: optimizer stepping the encoder's parameters.
        device: where each batch is moved. Defaults to the device of the
            encoder's first parameter (CPU if it has no parameters) instead of
            assuming CUDA.

    Returns:
        The sum (not the mean) of ``loss.item()`` over all batches.
    """
    if device is None:
        first_param = next(encoder_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    encoder_model.train()
    epoch_loss = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()
        if data.x is None:
            # No node features: substitute a constant 1-dim feature per node.
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)
        h1, h2, g1, g2 = encoder_model(data.x, data.edge_index, data.batch)
        loss = contrast_model(h1=h1, h2=h2, g1=g1, g2=g2, batch=data.batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss
def test(encoder_model, dataloader, device=None):
    """Embed every graph and score the embeddings with a linear SVM.

    Each graph's embedding is ``g1 + g2`` (sum of the two views' projected
    graph embeddings). The embeddings and labels are split 80/10/10 by
    ``get_split`` and passed to ``SVMEvaluator(linear=True)``.

    Args:
        encoder_model: module called as ``encoder_model(x, edge_index, batch)``
            returning ``(h1, h2, g1, g2)``.
        dataloader: iterable of batches exposing ``x``, ``edge_index``,
            ``batch``, ``y`` and ``to(device)``.
        device: where each batch is moved. Defaults to the device of the
            encoder's first parameter (CPU if it has no parameters).

    Returns:
        Whatever ``SVMEvaluator(linear=True)(x, y, split)`` returns.
    """
    if device is None:
        first_param = next(encoder_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    encoder_model.eval()
    embeddings = []
    labels = []
    # Pure inference: don't build (and keep alive) an autograd graph for the
    # whole dataset through the collected embeddings.
    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)
            if data.x is None:
                # No node features: substitute a constant 1-dim feature per node.
                num_nodes = data.batch.size(0)
                data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)
            _, _, g1, g2 = encoder_model(data.x, data.edge_index, data.batch)
            embeddings.append(g1 + g2)
            labels.append(data.y)
    x = torch.cat(embeddings, dim=0)
    y = torch.cat(labels, dim=0)
    split = get_split(num_samples=x.size(0), train_ratio=0.8, test_ratio=0.1)
    return SVMEvaluator(linear=True)(x, y, split)
def main():
    """Train the MVGRL encoder on TUDataset PTC_MR, then report SVM F1 scores.

    Trains for a fixed number of epochs with a G2L JSD contrastive objective,
    then evaluates the learned graph embeddings once with ``test``.
    """
    # Prefer CUDA, but fall back to CPU instead of failing on CPU-only hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    path = osp.join(osp.expanduser('~'), 'datasets')
    dataset = TUDataset(path, name='PTC_MR')
    # The contrastive loss is computed per mini-batch; reshuffling each epoch
    # varies which graphs are contrasted together.
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    # Keep the input width >= 1 for datasets without node features.
    input_dim = max(dataset.num_features, 1)

    aug1 = A.Identity()
    aug2 = A.PPRDiffusion(alpha=0.2, use_cache=False)
    gcn1 = GConv(input_dim=input_dim, hidden_dim=512, num_layers=2).to(device)
    gcn2 = GConv(input_dim=input_dim, hidden_dim=512, num_layers=2).to(device)
    mlp1 = FC(input_dim=512, output_dim=512)
    # Graph embeddings concatenate both GConv layers' readouts: 2 * 512.
    mlp2 = FC(input_dim=512 * 2, output_dim=512)
    encoder_model = Encoder(gcn1=gcn1, gcn2=gcn2, mlp1=mlp1, mlp2=mlp2, aug1=aug1, aug2=aug2).to(device)
    contrast_model = DualBranchContrast(loss=L.JSD(), mode='G2L').to(device)

    optimizer = Adam(encoder_model.parameters(), lr=0.01)

    num_epochs = 100
    with tqdm(total=num_epochs, desc='(T)') as pbar:
        for _ in range(num_epochs):
            loss = train(encoder_model, contrast_model, dataloader, optimizer)
            pbar.set_postfix({'loss': loss})
            pbar.update()

    test_result = test(encoder_model, dataloader)
    print(f'(E): Best test F1Mi={test_result["micro_f1"]:.4f}, F1Ma={test_result["macro_f1"]:.4f}')
if __name__ == '__main__':
    # Run training and evaluation only when executed as a script, not on import.
    main()