In [1]:
import torch
import torch.nn as nn
relu = nn.ReLU()

x = torch.tensor([-2.0,-1.0,0.0,1.0,2.0])
output = relu(x)
print(output)

tensor([0., 0., 0., 1., 2.])


In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import KarateClub
from torch_geometric.nn import GATConv
import matplotlib.pyplot as plt
import networkx as nx

In [3]:
# Load the dataset
dataset = KarateClub()
data = dataset[0]

In [4]:
# Define the GAT model
class GAT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gat1 = GATConv(in_channels=34, out_channels=8,heads=2)
        self.gat2 = GATConv(in_channels=16, out_channels=4,heads=1)
        
    def forward(self, x, edge_index):
        #x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        #x = F.dropout(x, p=0.6, training=self.training)
        x = self.gat2(x, edge_index)
        #return F.log_softmax(x, dim=1)
        return x

In [5]:
#help(torch.eye)
data.x = torch.eye(data.num_nodes)
data.train_mask = torch.zeros(data.num_nodes,dtype=torch.bool)
data.train_mask[[0,1,33,23]] = True

In [6]:
model = GAT()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

In [7]:
# Training function
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

In [8]:
# Test function
def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    correct = (pred == data.y).sum()
    acc = int(correct) / data.num_nodes
    return acc,pred

In [9]:
# Train the model
for epoch in range(201):
    loss = train()
    acc,pred = test()
    if epoch % 20 == 0:
        print(f'Epoch:{epoch:30d},Loss:{loss:.3f},Accuracy:{acc:.4f}')

Epoch:                             0,Loss:-0.045,Accuracy:0.7353
Epoch:                            20,Loss:-1.115,Accuracy:0.6471
Epoch:                            40,Loss:-2.808,Accuracy:0.6471
Epoch:                            60,Loss:-5.517,Accuracy:0.6765
Epoch:                            80,Loss:-9.337,Accuracy:0.6471
Epoch:                           100,Loss:-14.342,Accuracy:0.6471
Epoch:                           120,Loss:-20.867,Accuracy:0.6471
Epoch:                           140,Loss:-28.887,Accuracy:0.6176
Epoch:                           160,Loss:-38.177,Accuracy:0.6176
Epoch:                           180,Loss:-48.586,Accuracy:0.5882
Epoch:                           200,Loss:-60.089,Accuracy:0.5882


In [20]:
print(type(data))

<class 'torch_geometric.data.data.Data'>


In [24]:
Gobj = nx.Graph(data)
Gobj.number_of_nodes()

8

In [None]:
nx.draw(Gobj,with_labels=True)