-
Notifications
You must be signed in to change notification settings - Fork 1
/
Model
28 lines (23 loc) · 945 Bytes
/
Model
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from torch import nn,optim
from torch.nn import functional as F
from torch_geometric.nn import GATConv
#### Model architecture ####
class GAT_network(torch.nn.Module):
    """Two graph-attention layers followed by a three-layer MLP head.

    Node features go through two single-head ``GATConv`` layers. The
    embeddings of every ``nodes_per_graph`` consecutive nodes are then
    flattened into one row, and three linear layers reduce each row to a
    single output value.

    NOTE(review): the flattening presumably relies on every graph in a batch
    contributing exactly ``nodes_per_graph`` nodes, stored graph after graph;
    confirm against the data pipeline.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of the first attention layer.
        out_channels: Output size of the second attention layer.
        nodes_per_graph: Number of consecutive nodes flattened into one row.
            The node count passed to ``forward`` must be a multiple of it.
        negative_slope: ``negative_slope`` passed to both ``GATConv`` layers.

    The defaults reproduce the original fixed architecture, so
    ``GAT_network()`` builds the same layers as before.
    """

    def __init__(
        self,
        *,
        in_channels=58,
        hidden_channels=96,
        out_channels=48,
        nodes_per_graph=40,
        negative_slope=0.2,
    ):
        super().__init__()
        if nodes_per_graph <= 0:
            raise ValueError(
                f"nodes_per_graph must be positive, got {nodes_per_graph}"
            )
        self.nodes_per_graph = nodes_per_graph

        # Attribute names are kept unchanged so existing checkpoints
        # (state_dict keys) still load.
        self.conv1 = GATConv(
            in_channels, hidden_channels, heads=1, negative_slope=negative_slope
        )
        self.conv2 = GATConv(
            hidden_channels, out_channels, heads=1, negative_slope=negative_slope
        )
        self.l1 = nn.Linear(nodes_per_graph * out_channels, 256)
        self.l2 = nn.Linear(256, 64)
        self.l3 = nn.Linear(64, 1)

    def forward(self, data):
        """Run the network on ``data`` and return one value per flattened row.

        Args:
            data: Object exposing ``x`` (node feature tensor) and
                ``edge_index`` (edge tensor handed to the ``GATConv`` layers).

        Returns:
            Tensor with ``x.size(0) // nodes_per_graph`` rows and one column.

        Raises:
            ValueError: If the number of nodes is not a multiple of
                ``nodes_per_graph``.
        """
        x = data.x
        edge_index = data.edge_index

        # Validate before any layer runs: a partial group of nodes would
        # otherwise only surface later as an opaque shape mismatch.
        num_nodes = x.size(0)
        if num_nodes % self.nodes_per_graph != 0:
            raise ValueError(
                f"number of nodes ({num_nodes}) is not a multiple of "
                f"nodes_per_graph ({self.nodes_per_graph})"
            )
        num_graphs = num_nodes // self.nodes_per_graph

        x = F.dropout(F.relu(self.conv1(x, edge_index)), training=self.training)
        x = F.relu(self.conv2(x, edge_index))

        # Flatten each group of node embeddings into a single row. The row
        # width is spelled out (instead of -1) so an empty batch also works.
        x = x.reshape(num_graphs, self.nodes_per_graph * x.size(-1))

        x = F.dropout(F.relu(self.l1(x)), training=self.training)
        x = F.dropout(F.relu(self.l2(x)), training=self.training)
        return self.l3(x)