In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader

class GATLayer(MessagePassing):
    def __init__(self, in_channels, out_channels, heads=1, concat=True, negative_slope=0.2, dropout=0, **kwargs):
        super(GATLayer, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
        self.att = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        nn.init.xavier_uniform_(self.att)

    def forward(self, x, edge_index):
        x = torch.matmul(x, self.weight).view(-1, self.heads, self.out_channels)
        return self.propagate(edge_index, x=x)

    def message(self, edge_index_i, x_i, x_j):
        alpha = (torch.cat([x_i, x_j - x_i], dim=-1) * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = F.softmax(alpha, dim=1) 
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * alpha.view(-1, self.heads, 1)

    def update(self, aggr_out):
        if self.concat:
            return aggr_out.view(-1, self.heads * self.out_channels)
        else:
            return aggr_out.mean(dim=1)

class GATModel(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels, num_layers=3, num_classes=121, heads=1, dropout=0.5):
        super(GATModel, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(GATLayer(in_channels, hidden_dim, heads=heads, concat=True, dropout=dropout))
        for _ in range(num_layers - 2):
            self.layers.append(GATLayer(heads * hidden_dim, hidden_dim, heads=heads, concat=True, dropout=dropout))
        self.layers.append(GATLayer(heads * hidden_dim, out_channels, heads=1, concat=False, dropout=dropout))

        self.classifier = nn.Linear(out_channels, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for layer in self.layers:
            x = F.elu(layer(x, edge_index))
        # x = torch.mean(x, dim=0)  # Global pooling
        return x

def weighted_cross_entropy(logits, labels, class_weights):
    ce_loss = F.cross_entropy(logits, labels, reduction='none')
    weighted_ce_loss = ce_loss * class_weights[labels]
    return torch.mean(weighted_ce_loss)

path = 'C:/Users/Satvik/Desktop/gnn'

train_dataset = PPI(path, split='train')
test_dataset = PPI(path, split='test')

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

model = GATModel(in_channels=50, hidden_dim=256, out_channels=256, num_layers=3, num_classes=121, heads=1, dropout=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

def train(model, optimizer, train_loader, epochs=100):
    model.train()
    class_weights = torch.ones(121) 
    for epoch in range(epochs):
        for data in train_loader:
            optimizer.zero_grad()
            out = model(data)
            batch_target = data.y[data.batch].argmax(dim=-1) 
            loss = weighted_cross_entropy(out, batch_target, class_weights)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

train(model, optimizer, train_loader, epochs=50)

Epoch 1, Loss: 0.27302277088165283
Epoch 2, Loss: 0.25745782256126404
Epoch 3, Loss: 0.30528122186660767
Epoch 4, Loss: 0.3001825213432312
Epoch 5, Loss: 0.2424461990594864
Epoch 6, Loss: 0.3024771213531494
Epoch 7, Loss: 0.22772887349128723
Epoch 8, Loss: 0.29497867822647095
Epoch 9, Loss: 0.5283387303352356
Epoch 10, Loss: 7.728038311004639
Epoch 11, Loss: 0.24155022203922272
Epoch 12, Loss: 0.31539788842201233
Epoch 13, Loss: 0.24897168576717377
Epoch 14, Loss: 0.2713432013988495
Epoch 15, Loss: 0.38756266236305237
Epoch 16, Loss: 0.2740943431854248
Epoch 17, Loss: 0.23190602660179138
Epoch 18, Loss: 0.30856752395629883
Epoch 19, Loss: 0.2556227445602417
Epoch 20, Loss: 0.25707709789276123
Epoch 21, Loss: 0.2275291383266449
Epoch 22, Loss: 0.27755415439605713
Epoch 23, Loss: 0.2987552285194397
Epoch 24, Loss: 0.4721508026123047
Epoch 25, Loss: 0.24020107090473175
Epoch 26, Loss: 1.1369359493255615
Epoch 27, Loss: 0.26932552456855774
Epoch 28, Loss: 0.26065313816070557
Epoch 29, Loss

In [2]:
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data in test_loader:
            out = model(data)
            pred = out.argmax(dim=1) 
            target = data.y[data.batch].argmax(dim=-1)  
            correct += (pred == target).sum().item()
            total += len(pred)

    accuracy = correct / total
    print(f'Test Accuracy: {accuracy}')

test(model, test_loader)

Test Accuracy: 0.9994569152787834
