In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import dgl
from dgl.nn.pytorch import edge_softmax
import dgl.function as fn
import dgl.data
from dgl.nn.pytorch import GATConv

def l0_train(logAlpha, beta=0.66, gamma=-0.1, zeta=1.1, eps=1e-20):
    """Draw a stochastic hard-concrete gate for every entry of ``logAlpha``.

    Training-time variant. It first samples uniform noise and turns it into a
    logistic sample. It then forms a relaxed Bernoulli value with temperature
    ``beta``, stretches that value linearly onto ``(gamma, zeta)``, and clips
    the result to ``[0, 1]``. Because ``gamma < 0 < 1 < zeta`` by default, the
    clipping can produce gates that are exactly 0 or exactly 1.

    Args:
        logAlpha: Tensor of gate log-parameters. The noise tensor has the same
            size and is cast to ``logAlpha``'s tensor type.
        beta: Relaxation temperature.
        gamma: Lower end of the stretch interval.
        zeta: Upper end of the stretch interval.
        eps: Small constant added to the noise so that ``log`` of it stays finite.

    Returns:
        Tensor of gate values in ``[0, 1]`` with the shape of ``logAlpha``.
    """
    noise = torch.rand(logAlpha.size()).type_as(logAlpha) + eps
    # Logistic sample: log(u) - log(1 - u), computed as a single log of the ratio.
    logistic = torch.log(noise / (1 - noise))
    relaxed = torch.sigmoid((logistic + logAlpha) / beta)
    stretched = relaxed * (zeta - gamma) + gamma
    return F.hardtanh(stretched, 0, 1)

def l0_test(logAlpha, beta=0.66, gamma=-0.1, zeta=1.1):
    """Deterministic hard-concrete gate for evaluation.

    Applies the same stretch-and-clip as the training gate, but without
    sampling noise. The value ``sigmoid(logAlpha / beta)`` is mapped linearly
    onto ``(gamma, zeta)`` and then clipped to ``[0, 1]``.

    Args:
        logAlpha: Tensor of gate log-parameters.
        beta: Temperature dividing ``logAlpha`` before the sigmoid.
        gamma: Lower end of the stretch interval.
        zeta: Upper end of the stretch interval.

    Returns:
        Tensor of gate values in ``[0, 1]`` with the shape of ``logAlpha``.
    """
    gate = torch.sigmoid(logAlpha / beta)
    return F.hardtanh(gate * (zeta - gamma) + gamma, 0, 1)

# Merged GAT Layer
class MergedGATLayer(nn.Module):
    """GATv2-style multi-head attention layer with hard-concrete (L0) edge gates.

    Computation for each edge ``u -> v``:

    1. Compute the GATv2 attention logit
       ``att . LeakyReLU(W_l h_u + W_r h_v) + bias_l0``.
    2. Pass the logit through a hard-concrete gate: ``l0_train`` in training
       mode, ``l0_test`` in eval mode.
    3. Normalize the gates with ``edge_softmax`` over each destination node's
       incoming edges.
    4. Use the normalized weights to sum the source-side projections ``W_l h_u``
       into ``v``.

    Heads are concatenated in the output.

    Args:
        in_dim: Size of the input node features.
        out_dim: Output size of each head.
        num_heads: Number of attention heads. The output has
            ``out_dim * num_heads`` features.
        dropout: Dropout probability applied to both projected feature sets.
        alpha: Negative slope of the LeakyReLU in the attention score.
        bias_l0: Initial value of the learnable shift added to the attention
            logits before gating.
        residual: If True, add the input to the output. The input is passed
            through a bias-free linear projection when its width differs from
            ``out_dim * num_heads``.
    """

    def __init__(self, in_dim, out_dim, num_heads, dropout, alpha, bias_l0, residual=False):
        super(MergedGATLayer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.leaky_relu = nn.LeakyReLU(alpha)

        self.residual = residual
        # Only project the residual when input and output widths differ.
        if self.residual and in_dim != (out_dim * num_heads):
            self.res_fc = nn.Linear(in_dim, out_dim * num_heads, bias=False)
        else:
            self.res_fc = None

        # lin_l projects the source side of an edge, lin_r the destination side.
        self.lin_l = nn.Linear(in_dim, num_heads * out_dim, bias=False)
        self.lin_r = nn.Linear(in_dim, num_heads * out_dim, bias=False)

        # Per-head attention vector, broadcast over edges: (1, num_heads, out_dim).
        self.att = nn.Parameter(torch.empty(1, num_heads, out_dim))
        # Learnable scalar shift applied to the attention logits before the L0 gate.
        self.bias_l0 = nn.Parameter(torch.tensor([bias_l0], dtype=torch.float32))
        # Hard-concrete temperature passed to l0_train / l0_test.
        self.beta = 0.66
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-normal initialise the projection weights and the attention vector."""
        nn.init.xavier_normal_(self.lin_l.weight)
        nn.init.xavier_normal_(self.lin_r.weight)
        nn.init.xavier_normal_(self.att)

    def forward(self, g, inputs):
        """Apply the layer to node features ``inputs`` on graph ``g``.

        Args:
            g: DGL graph whose node count equals ``inputs.shape[0]``.
            inputs: 2-D node-feature tensor ``(num_nodes, in_dim)``.

        Returns:
            Tensor of shape ``(num_nodes, num_heads * out_dim)``.

        Raises:
            ValueError: If ``inputs`` is not 2-D.
        """
        if inputs.dim() != 2:
            raise ValueError("Expected 2D input tensor, but got {}D".format(inputs.dim()))

        num_nodes = inputs.shape[0]

        # Source/destination projections, reshaped per head: (N, num_heads, out_dim).
        h_l = F.dropout(self.lin_l(inputs), p=self.dropout, training=self.training)
        h_r = F.dropout(self.lin_r(inputs), p=self.dropout, training=self.training)
        h_l = h_l.view(num_nodes, self.num_heads, self.out_dim)
        h_r = h_r.view(num_nodes, self.num_heads, self.out_dim)

        # local_scope discards the temporary 'h_l'/'h_r'/'a'/'h' features on exit.
        # Without it they would linger on the shared graph, together with their
        # autograd history, between layers and between forward passes.
        with g.local_scope():
            g.ndata['h_l'] = h_l
            g.ndata['h_r'] = h_r

            # Gated attention values per edge: (E, num_heads, 1).
            g.apply_edges(self.edge_attention)

            # NOTE(review): softmax over gate values in [0, 1] gives a gate of 0
            # the weight exp(0)/sum, not zero, so the L0 gate does not prune
            # edges here. Sum-normalisation would preserve sparsity. Confirm
            # which behaviour is intended.
            g.edata['a'] = edge_softmax(g, g.edata['a'])

            # As in GATv2, aggregate the same source-side projection (h_l) that
            # enters the attention score.
            g.update_all(fn.u_mul_e('h_l', 'a', 'm'), fn.sum('m', 'h'))
            h = g.ndata['h'].view(num_nodes, -1)

        if self.residual:
            res_out = self.res_fc(inputs) if self.res_fc is not None else inputs
            # Out-of-place add, so the aggregation output is never mutated.
            h = h + res_out

        return h

    def edge_attention(self, edges):
        """Compute hard-concrete-gated GATv2 attention values for a batch of edges.

        Reads ``'h_l'`` from source nodes and ``'h_r'`` from destination nodes,
        both of shape ``(*, num_heads, out_dim)``.

        Returns:
            dict mapping ``'a'`` to gate values of shape ``(E, num_heads, 1)``
            in ``[0, 1]``.
        """
        h_src = edges.src['h_l']
        h_dst = edges.dst['h_r']
        # GATv2 "dynamic" attention: the non-linearity is applied before the
        # dot product with the attention vector.
        e = self.leaky_relu(h_src + h_dst)
        logits = (e * self.att).sum(dim=-1, keepdim=True) + self.bias_l0
        # Stochastic gate while training, deterministic gate at evaluation time.
        gate_fn = l0_train if self.training else l0_test
        return {'a': gate_fn(logits, beta=self.beta)}


class MergedGAT(nn.Module):
    """Stack of ``MergedGATLayer`` modules for node classification on a fixed graph.

    Layer layout:

    - Input layer: ``in_dim -> num_hidden * num_heads``.
    - ``num_intermediate_layers`` hidden layers of width ``num_hidden * num_heads``.
    - Output layer: a single head of width ``num_classes``.

    Every layer uses a residual connection. ELU and dropout follow every layer
    except the last.

    Args:
        g: DGL graph used for every forward pass.
        in_dim: Input feature size.
        num_hidden: Per-head hidden size.
        num_classes: Number of output classes (width of the output layer).
        num_heads: Number of heads in every layer except the output layer.
        dropout: Dropout probability, used between layers and inside each layer.
        alpha: LeakyReLU negative slope for the attention score.
        bias_l0: Initial L0-gate bias shared by every layer's constructor.
        num_intermediate_layers: Number of hidden layers between the input and
            output layers. The default of 3 keeps the original depth.

    Raises:
        ValueError: If ``num_intermediate_layers`` is negative.
    """

    def __init__(self, g, in_dim, num_hidden, num_classes, num_heads, dropout, alpha, bias_l0,
                 num_intermediate_layers=3):
        super(MergedGAT, self).__init__()
        if num_intermediate_layers < 0:
            raise ValueError(
                "num_intermediate_layers must be >= 0, got {}".format(num_intermediate_layers))
        self.dropout = dropout
        self.num_classes = num_classes
        self.graph = g

        hidden_dim = num_hidden * num_heads

        # Input layer.
        self.layers = nn.ModuleList([
            MergedGATLayer(in_dim, num_hidden, num_heads, dropout, alpha, bias_l0, residual=True)
        ])

        # Hidden layers.
        for _ in range(num_intermediate_layers):
            self.layers.append(
                MergedGATLayer(hidden_dim, num_hidden, num_heads, dropout, alpha, bias_l0, residual=True))

        # Output layer: one head, so its output width is exactly num_classes.
        self.layers.append(
            MergedGATLayer(hidden_dim, num_classes, 1, dropout, alpha, bias_l0, residual=True))

    def forward(self, inputs):
        """Return per-node class logits of shape ``(num_nodes, num_classes)``.

        Args:
            inputs: Node-feature tensor ``(num_nodes, in_dim)`` for ``self.graph``.
        """
        h = inputs
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            h = layer(self.graph, h)
            if i < last:
                h = F.elu(h)
                h = F.dropout(h, p=self.dropout, training=self.training)
        # The single-head output layer already yields (num_nodes, num_classes),
        # so no reshaping or slicing is needed.
        return h



In [7]:
def load_cora_data():
    """Load the Cora citation graph with exactly one self-loop per node.

    Any existing self-loops are removed before fresh ones are added, so no node
    ends up with a duplicate. The per-node coefficient ``in_degree ** -0.5`` is
    stored in ``g.ndata['norm']`` with shape ``(num_nodes, 1)``.

    NOTE(review): ``'norm'`` is a GCN-style normaliser. The GAT model in this
    notebook does not read it. Confirm whether it is still needed.

    Returns:
        tuple: ``(g, num_classes)``.
    """
    cora = dgl.data.CoraGraphDataset()
    graph = dgl.add_self_loop(dgl.remove_self_loop(cora[0]))

    in_deg = graph.in_degrees().float()
    inv_sqrt_deg = in_deg.pow(-0.5)
    # Guard against zero in-degree (inf). After add_self_loop every in-degree is >= 1.
    inv_sqrt_deg[torch.isinf(inv_sqrt_deg)] = 0
    graph.ndata['norm'] = inv_sqrt_deg.unsqueeze(1)

    return graph, cora.num_classes


In [8]:
def _masked_accuracy(logits, labels, mask):
    """Return the fraction of nodes selected by ``mask`` whose argmax prediction equals the label."""
    preds = logits.argmax(dim=1)
    correct = (preds[mask] == labels[mask]).sum()
    return float(correct) / int(mask.sum())


def train_model(g, num_classes, num_epochs=100, seed=None):
    """Train a ``MergedGAT`` on ``g`` with full-batch Adam and print per-epoch metrics.

    ``g.ndata`` must contain ``'feat'``, ``'label'``, ``'train_mask'`` and
    ``'val_mask'``, as provided by ``CoraGraphDataset``.

    Args:
        g: DGL graph with node features, labels and split masks.
        num_classes: Number of target classes.
        num_epochs: Number of training epochs (default 100).
        seed: Optional seed for ``torch.manual_seed``. Setting it makes weight
            initialisation, dropout and L0-gate sampling reproducible. With
            ``None`` the global RNG is left untouched.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Hyperparameters
    num_heads = 8
    num_hidden = 256
    dropout = 0.6
    alpha = 0.2
    bias_l0 = 0.1
    lr = 0.001
    weight_decay = 0.0005

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    # Take the input size from the data instead of hard-coding Cora's 1433.
    in_dim = features.shape[1]

    model = MergedGAT(g, in_dim=in_dim, num_hidden=num_hidden, num_classes=num_classes,
                      num_heads=num_heads, dropout=dropout, alpha=alpha, bias_l0=bias_l0)

    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(num_epochs):
        model.train()
        logits = model(features)
        loss = loss_func(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Training accuracy comes from the training-mode forward pass
        # (dropout active, stochastic L0 gates).
        train_accuracy = _masked_accuracy(logits, labels, train_mask)

        # Evaluate on validation set
        model.eval()
        with torch.no_grad():
            val_logits = model(features)
            val_loss = loss_func(val_logits[val_mask], labels[val_mask])
            val_accuracy = _masked_accuracy(val_logits, labels, val_mask)

        print(f"Epoch {epoch:05d} | Train Loss {loss.item():.4f} | Train Accuracy {train_accuracy:.4f} | Val Loss {val_loss.item():.4f} | Val Accuracy {val_accuracy:.4f}")


In [9]:
# Load Cora (one self-loop per node), then train the model, printing
# per-epoch train/validation loss and accuracy.
# `g` and `num_classes` stay in the notebook namespace for later cells.
g, num_classes = load_cora_data()
train_model(g, num_classes)


  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Epoch 00000 | Train Loss 1.9493 | Train Accuracy 0.1071 | Val Loss 1.9130 | Val Accuracy 0.3280
Epoch 00001 | Train Loss 1.9130 | Train Accuracy 0.3071 | Val Loss 1.8718 | Val Accuracy 0.3880
Epoch 00002 | Train Loss 1.8434 | Train Accuracy 0.4571 | Val Loss 1.8224 | Val Accuracy 0.4280
Epoch 00003 | Train Loss 1.7498 | Train Accuracy 0.5500 | Val Loss 1.7599 | Val Accuracy 0.4540
Epoch 00004 | Train Loss 1.6527 | Train Accuracy 0.5714 | Val Loss 1.6732 | Val Accuracy 0.5420
Epoch 00005 | Train Loss 1.5253 | Train Accuracy 0.6786 | Val Loss 1.5549 | Val Accuracy 0.6280
Epoch 00006 | Train Loss 1.3521 | Train Accuracy 0.7500 | Val Loss 1.3995 | Val Accuracy 0.6920
Epoch 00007 | Train Loss 1.1766 | Train Accuracy 0.8071 | Val Loss 1.2170 | Val Accuracy 0.7400
Epoch 00008 | Train Loss 0.9830 | Train Accuracy 0