# model.py — unrolled decentralized-optimization networks (PG-EXTRA and Prox-DGD
# iterations unrolled as learnable PyTorch / PyTorch-Geometric layers).
import torch
from torch import nn
import torch.optim as optim
from torch_geometric.utils import from_networkx
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
from torch_geometric.nn.conv import MessagePassing
from torch_sparse import SparseTensor, matmul
import torch.nn.functional as F
class MetropolisConv(MessagePassing):
    """Graph mixing layer: multiplies per-node features by the sparse mixing
    matrix carried in ``pyg_data`` (presumably Metropolis-Hastings weights —
    confirm against the data-loading code), using sum aggregation.
    """
    def __init__(self):
        super(MetropolisConv, self).__init__(aggr='add') # "Add" aggregation.
    def forward(self, x, pyg_data):
        # x: (B, N, D) — batch of per-node feature vectors.
        (B, N, D)=x.shape
        # NOTE(review): `pyg_data.edge_weight` is passed as `edge_index`. This
        # only works if it actually holds a SparseTensor adjacency, in which
        # case PyG dispatches to message_and_aggregate below — verify that the
        # data pipeline really stores the weighted adjacency under this name.
        out = self.propagate(x=x.view(-1,D), edge_index=pyg_data.edge_weight, node_dim=-1)
        return out.view(B,N,D)
    def message_and_aggregate(self, adj_t, x):
        # Fused message + aggregation: sparse-dense matmul with 'add' reduce.
        return matmul(adj_t, x, reduce=self.aggr)
class Net_PGEXTRA(torch.nn.Module):
    """Unrolled PG-EXTRA network with a learnable step size and soft-threshold
    level per layer, for decentralized LASSO-type problems of the form
    min_x sum_i 0.5 * ||A_i x - b_i||^2 + lam * ||x||_1.
    """

    def __init__(self, step_size, num_layers):
        """step_size: initial per-layer step size; num_layers: unrolled depth."""
        super(Net_PGEXTRA, self).__init__()
        # One learnable step size and one learnable threshold per iteration.
        self.step_size = nn.Parameter(torch.ones(num_layers) * step_size)
        self.lam = nn.Parameter(torch.ones(num_layers) * step_size * 10)
        self.num_layers = num_layers
        self.conv = MetropolisConv()

    def tgrad_qp(self, A, b, x):
        """Per-node gradient of 0.5 * ||A x - b||^2.

        A: (batch, nodes, k, n), b: (batch, nodes, k), x: (batch, nodes, n).
        Returns (batch, nodes, n) with grad_i = A_i^T (A_i x_i - b_i).
        """
        x_ = torch.unsqueeze(x, dim=-1)   # (..., n, 1)
        b_ = torch.unsqueeze(b, dim=-1)   # (..., k, 1)
        A_t = A.transpose(2, 3)           # (..., n, k)
        grad_A = A_t @ (A @ x_ - b_)
        return torch.squeeze(grad_A, dim=-1)

    def act(self, x, ii):
        """Soft-thresholding (prox of the l1 norm) with the layer-ii threshold."""
        tau = self.lam[ii]  # * self.step_size[ii]
        return F.relu(x - tau) - F.relu(-x - tau)

    def forward(self, W, A, b, pyg_data, max_iter):
        """Run max_iter unrolled PG-EXTRA iterations.

        Returns (stacked per-iteration iterates, final iterate, full history).
        NOTE: requires max_iter >= 2 — with max_iter == 1 the loop never runs,
        so x_2 is undefined and torch.stack([]) would fail.
        """
        (batch_size, num_of_nodes, _, dim) = A.shape
        # Start from zero on the same device/dtype as the data (fixes a
        # CPU/GPU mismatch when A lives on the GPU).
        init_x = torch.zeros((batch_size, num_of_nodes, dim),
                             device=A.device, dtype=A.dtype)
        ret_z = []
        k = 1
        x_0 = init_x
        # First half-step: mix, take a gradient step, then prox.
        x_12 = self.conv(x_0, pyg_data) - self.step_size[0] * self.tgrad_qp(A, b, x_0)
        x_1 = self.act(x_12, 0)
        x_hist = [init_x, x_1]
        while k < max_iter:
            # PG-EXTRA recursion:
            # x^{k+3/2} = W x^{k+1} + x^{k+1/2} - (W + I)/2 x^k
            #             - a_k (grad(x^{k+1}) - grad(x^k))
            x_32 = self.conv(x_1, pyg_data) + x_12 - (self.conv(x_0, pyg_data) + x_0) / 2 - \
                   self.step_size[k] * (self.tgrad_qp(A, b, x_1) - self.tgrad_qp(A, b, x_0))
            x_2 = self.act(x_32, k)
            ret_z.append(x_2)
            x_0 = x_1
            x_1 = x_2
            x_12 = x_32
            k = k + 1
            x_hist.append(x_2)
        ret_z = torch.stack(ret_z)
        return ret_z, x_2, x_hist
class Net_Prox_DGD(torch.nn.Module):
    """Unrolled proximal decentralized gradient descent (Prox-DGD) network with
    a learnable step size and soft-threshold level per layer, for decentralized
    LASSO-type problems min_x sum_i 0.5 * ||A_i x - b_i||^2 + lam * ||x||_1.
    """

    def __init__(self, step_size, num_layers):
        """step_size: initial per-layer step size; num_layers: unrolled depth."""
        super(Net_Prox_DGD, self).__init__()
        # One learnable step size and one learnable threshold per iteration.
        self.step_size = nn.Parameter(torch.ones(num_layers) * step_size)
        self.lam = nn.Parameter(torch.ones(num_layers) * step_size * 10)
        self.num_layers = num_layers
        self.conv = MetropolisConv()

    def tgrad_qp(self, A, b, x):
        """Per-node gradient of 0.5 * ||A x - b||^2.

        A: (batch, nodes, k, n), b: (batch, nodes, k), x: (batch, nodes, n).
        Returns (batch, nodes, n) with grad_i = A_i^T (A_i x_i - b_i).
        """
        x_ = torch.unsqueeze(x, dim=-1)   # (..., n, 1)
        b_ = torch.unsqueeze(b, dim=-1)   # (..., k, 1)
        A_t = A.transpose(2, 3)           # (..., n, k)
        grad_A = A_t @ (A @ x_ - b_)
        return torch.squeeze(grad_A, dim=-1)

    def act(self, x, ii):
        """Soft-thresholding (prox of the l1 norm) with the layer-ii threshold."""
        tau = self.lam[ii]  # * self.step_size[ii]
        return F.relu(x - tau) - F.relu(-x - tau)

    def forward(self, W, A, b, pyg_data, max_iter):
        """Run max_iter unrolled Prox-DGD iterations.

        Returns (stacked per-iteration iterates, final iterate, full history).
        NOTE: requires max_iter >= 2 — with max_iter == 1 the loop never runs,
        so x_2 is undefined and torch.stack([]) would fail.
        """
        (batch_size, num_of_nodes, _, dim) = A.shape
        # Start from zero on the same device/dtype as the data (fixes a
        # CPU/GPU mismatch when A lives on the GPU).
        init_x = torch.zeros((batch_size, num_of_nodes, dim),
                             device=A.device, dtype=A.dtype)
        ret_z = []
        k = 1
        x_0 = init_x
        x_12 = self.conv(x_0, pyg_data) - self.step_size[0] * self.tgrad_qp(A, b, x_0)
        x_1 = self.act(x_12, 0)
        x_hist = [init_x, x_1]
        while k < max_iter:
            # Prox-DGD recursion: mix, gradient step, then prox.
            # (Unlike PG-EXTRA, only the current iterate x_1 is needed, so the
            # original dead stores of x_0 / x_12 inside the loop are dropped.)
            x_32 = self.conv(x_1, pyg_data) - self.step_size[k] * self.tgrad_qp(A, b, x_1)
            x_2 = self.act(x_32, k)
            ret_z.append(x_2)
            x_1 = x_2
            k = k + 1
            x_hist.append(x_2)
        ret_z = torch.stack(ret_z)
        return ret_z, x_2, x_hist