In [1]:
import torch
import dgl
import dgl.data
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, roc_auc_score
from dgl.nn.pytorch.glob import SumPooling
from dataLoader import getData
import numpy as np


In [2]:
class GNAN(nn.Module):
    """A single GNAN layer built from per-feature 1 -> out_dim linear maps.

    For every input feature k there are two learnable ``nn.Linear(1, out_dim)``
    maps: ``feature_transform[k]`` is applied to the scalar feature x[j, k] of
    each node j, and ``distance_transform[k]`` is applied to every entry of the
    distance matrix D. The output for node i (elementwise over the out_dim axis) is

        h[i] = sum_k sum_j distance_transform[k](D[i, j]) / Z[i, j]
                           * feature_transform[k](x[j, k])

    where D is ``graph.ndata["distance_matrix"]`` and Z is
    ``graph.ndata["normalization_distance_matrix"]``.

    Args:
        in_dim: number of input features d (one pair of linear maps per feature).
        out_dim: output width per node.
        bias: whether the per-feature ``nn.Linear`` maps have a bias term.
        dropout: dropout probability applied to the layer output (0.0 disables it).
    """

    def __init__(
        self,
        in_dim,
        out_dim,  # output width per node
        bias=True,
        dropout=0.0,
    ):
        super().__init__()
        self.out_dim = out_dim

        self.distance_transform = nn.ModuleList()
        self.feature_transform = nn.ModuleList()

        # Create the maps interleaved (distance k, feature k, ...) so that seeded
        # weight initialisation matches earlier versions of this layer.
        for _ in range(in_dim):
            self.distance_transform.append(nn.Linear(1, out_dim, bias=bias))
            self.feature_transform.append(nn.Linear(1, out_dim, bias=bias))

        # nn.Dropout has no parameters, so this does not change the state_dict.
        self.dropout = nn.Dropout(dropout)

    def forward(self, graph, feats):
        """Compute node representations.

        Args:
            graph: object exposing ``graph.ndata["distance_matrix"]`` and
                ``graph.ndata["normalization_distance_matrix"]``. Both are assumed
                to be (N, N) — TODO confirm how this holds for batched DGL graphs.
            feats: (N, d) node features; d must not exceed ``in_dim`` because
                feature k uses the k-th pair of linear maps.

        Returns:
            (N, out_dim) tensor of node representations.
        """
        distance_matrix = graph.ndata["distance_matrix"].unsqueeze(-1)  # (N, N, 1)
        normalization = graph.ndata["normalization_distance_matrix"].unsqueeze(-1)
        feat_dim = feats.shape[1]

        # torch.stack inherits the device and dtype of the layer outputs, unlike
        # pre-allocated torch.empty buffers, which are always CPU / default dtype.
        f_matrix = torch.stack(
            [self.feature_transform[k](feats[:, k : k + 1]) for k in range(feat_dim)]
        )  # (d, N, out)
        m_matrix = torch.stack(
            [self.distance_transform[k](distance_matrix) for k in range(feat_dim)]
        )  # (d, N, N, out)
        m_matrix = m_matrix / normalization

        # h[i, o] = sum_k sum_j m[k, i, j, o] * f[k, j, o]
        h = torch.einsum("kijo,kjo->io", m_matrix, f_matrix)  # (N, out)
        return self.dropout(h)


In [3]:
class GNANMODULE(nn.Module):
    """Stack of GNAN layers followed by sum pooling over each graph's nodes.

    Layer widths go in_dim -> hiddem_dim -> ... -> hiddem_dim -> out_dim, with a
    ReLU after every layer except the last, then a graph-level ``SumPooling``.

    Args:
        in_dim: number of input node features.
        out_dim: width of the pooled graph-level output.
        hiddem_dim: width of the hidden layers (spelling kept for backward
            compatibility with existing keyword callers).
        num_layers: number of GNAN layers; must be >= 1.

    Raises:
        ValueError: if ``num_layers`` < 1.
    """

    def __init__(self, in_dim, out_dim, hiddem_dim=24, num_layers=3):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.in_dim = in_dim
        self.hidden_dim = hiddem_dim
        self.out_dim = out_dim
        self.num_layers = num_layers
        self.pool = SumPooling()

        # Consecutive pairs of this list give each layer's (in, out) widths; a
        # single-layer model therefore maps in_dim straight to out_dim.
        widths = [in_dim] + [hiddem_dim] * (num_layers - 1) + [out_dim]
        self.gnan = nn.ModuleList(
            GNAN(d_in, d_out) for d_in, d_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, graph, feats):
        """Run all GNAN layers, then sum-pool node representations per graph.

        Args:
            graph: the (possibly batched) graph passed to every layer and to the pooling.
            feats: node features for the first layer.

        Returns:
            Pooled output from ``SumPooling`` with ``out_dim`` columns.
        """
        h = feats
        last = self.num_layers - 1
        for i, layer in enumerate(self.gnan):
            h = layer(graph, h)
            if i < last:  # no activation on the output layer (logits)
                h = F.relu(h)
        return self.pool(graph, h)

In [4]:
SEED = 0
LEARNING_RATE = 0.001

torch.manual_seed(SEED)  # reproducible weight initialisation across runs

train_loader, valid_loader, test_loader, num_feats, num_class = getData()

# Binary graph classification: the model emits a single logit per graph
# (out_dim=1), so num_class is not used to size the output.
model = GNANMODULE(num_feats, 1)
loss_fn = nn.BCEWithLogitsLoss()  # applies the sigmoid internally
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training settings
num_epochs = 100


In [5]:
# Binary classification accuracy function
def get_accuracy(outputs, labels):
    """Count how many predictions in a batch match their labels.

    ``outputs`` are raw logits of a single-logit binary classifier. They are
    flattened, passed through a sigmoid and thresholded at 0.5 (strictly
    greater means the positive class). ``labels`` are compared elementwise
    against the resulting boolean predictions.

    Returns:
        Number of correct predictions as a Python number (a count, not a fraction).
    """
    is_positive = torch.sigmoid(outputs).view(-1).gt(0.5)
    matches = labels == is_positive
    return matches.sum().item()


# Train epoch function
def train_epoch(model, dloader, loss_fn, optimizer, classify=True, compute_auc=False):
    """Run one optimisation pass over ``dloader``.

    Args:
        model: ``nn.Module`` called as ``model(graph, graph.ndata["feat"])``.
        dloader: iterable of ``(graph, label)`` batches.
        loss_fn: called on the logits (flattened when they are (B, 1)) and the
            labels cast to float, e.g. ``nn.BCEWithLogitsLoss()``.
        optimizer: stepped once per batch.
        classify: if True, also count correct predictions via ``get_accuracy``.
        compute_auc: if True, also compute ROC-AUC over all batches with
            ``roc_auc_score`` on sigmoid probabilities.

    Returns:
        ``(mean_loss, accuracy, auc)`` when ``classify`` is True, where ``auc`` is
        -1 unless ``compute_auc``; otherwise ``(mean_loss, -1)``.

    Raises:
        ValueError: if ``dloader`` yields no batches.
    """
    # An evaluation pass may have left the model in eval mode; switch back so
    # mode-dependent layers (e.g. dropout) behave as in training.
    model.train()

    running_loss = 0.0
    running_acc = 0.0
    n_samples = 0
    n_batches = 0
    probas_per_batch = []
    labels_per_batch = []

    for graph, label in dloader:
        labels = label.flatten()  # binary labels, assumed already 0/1

        optimizer.zero_grad()
        outputs = model(graph, graph.ndata["feat"])

        # A single-logit head returns (B, 1); flatten it to match the (B,) labels.
        if outputs.dim() == 2 and outputs.shape[-1] == 1:
            loss = loss_fn(outputs.flatten(), labels.float())
        else:
            loss = loss_fn(outputs, labels.float())

        loss.backward()
        optimizer.step()

        n_batches += 1
        n_samples += len(labels)
        running_loss += loss.item()

        if classify:
            running_acc += get_accuracy(outputs, labels)

        if compute_auc:
            # Collect per-batch arrays and concatenate once at the end.
            probas_per_batch.append(torch.sigmoid(outputs).view(-1).detach().cpu().numpy())
            labels_per_batch.append(labels.detach().cpu().numpy())

    if n_batches == 0:
        raise ValueError("train_epoch: dloader yielded no batches")

    mean_loss = running_loss / n_batches
    if not classify:
        return mean_loss, -1

    auc = -1
    if compute_auc:
        auc = roc_auc_score(np.concatenate(labels_per_batch), np.concatenate(probas_per_batch))
    return mean_loss, running_acc / n_samples, auc


# Test epoch function
def test_epoch(model, dloader, loss_fn, classify=True, compute_auc=False):
    running_loss = 0.0
    all_probas = np.array([])
    all_labels = np.array([])
    n_samples = 0

    if classify:
        running_acc = 0.0

    model.eval()
    for graph, label in dloader:
        labels = label.flatten()  # Binary labels, assuming they are already 0 or 1
        # Forward pass
        outputs = model(graph, graph.ndata["feat"])

        n_samples += len(labels)

        # Compute the loss
        if outputs.dim() == 2 and outputs.shape[-1] == 1:
            loss = loss_fn(outputs.flatten(), labels.float())
        else:
            loss = loss_fn(outputs, labels.float())

        running_loss += loss.item()

        # Compute AUC and accuracy if needed
        if classify:
            running_acc += get_accuracy(outputs, labels)

        if compute_auc:
            probas = torch.sigmoid(outputs).view(-1)
            all_probas = np.concatenate((all_probas, probas.detach().cpu().numpy()))
            all_labels = np.concatenate((all_labels, labels.detach().cpu().numpy()))

    # Calculate AUC if required
    if compute_auc:
        auc = roc_auc_score(all_labels, all_probas)

    if classify:
        if compute_auc:
            return running_loss / len(dloader), running_acc / n_samples, auc
        else:
            return running_loss / len(dloader), running_acc / n_samples, -1
    else:
        return running_loss / len(dloader), -1


In [None]:
for epoch in range(num_epochs):
    loss, tran_acc, _ = train_epoch(
        model, train_loader, loss_fn, optimizer, classify=True, compute_auc=False
    )

    # compute_auc=True so the reported valid_auc is an actual ROC-AUC rather
    # than the -1 placeholder test_epoch returns when AUC is not requested.
    valid_loss, accuracy, auc = test_epoch(
        model, valid_loader, loss_fn, classify=True, compute_auc=True
    )
    print(
        f"epoch {epoch + 1}/{num_epochs} | loss : {loss:.4f} | train_acc : {tran_acc:.4f} "
        f"| valid_loss : {valid_loss:.4f} | valid_accuracy : {accuracy:.4f} | valid_auc : {auc:.4f}"
    )
