In [1]:
import torch
import dgl
import dgl.data
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import  roc_auc_score
from dgl.nn.pytorch.glob import SumPooling
from dataLoader import getData
import numpy as np
# Set print options to print the full tensor
# torch.set_printoptions(threshold=float("inf"))


In [2]:
class GNAN(nn.Module):
    """Additive graph model built from one small MLP pair per input feature.

    For every input feature ``k`` the module owns two networks that both map
    a scalar to ``out_dim`` values:

    * ``feature_transform[k]`` is applied to the k-th feature of each node.
    * ``distance_transform[k]`` is applied to each (normalised) node-pair
      distance.

    ``forward`` multiplies the two element-wise and sums the result over all
    node pairs and all features, giving a single vector of ``out_dim`` values
    for the graph that was passed in.

    Args:
        in_dim: number of input features; one MLP pair is created per feature.
        hidden_dim: width of the hidden layers (unused when ``num_layers == 1``).
        num_layers: number of ``nn.Linear`` layers in each MLP.
        out_dim: size of the output of each MLP and of the model output.
        feat_trasform_bias: whether the feature MLPs use bias terms
            (the parameter name is kept as-is for backward compatibility).
        dist_transform_bias: whether the distance MLPs use bias terms.
        dropout: dropout probability applied after each hidden ReLU of the
            feature MLPs (distance MLPs have no dropout).
    """

    def __init__(
        self,
        in_dim,
        hidden_dim,
        num_layers,
        out_dim,
        feat_trasform_bias=True,
        dist_transform_bias=False,
        dropout=0.0,
    ):
        super().__init__()
        self.out_dim = out_dim

        self.distance_transform = nn.ModuleList()
        self.feature_transform = nn.ModuleList()

        for _ in range(in_dim):
            distance_transform_layers = []
            feature_transform_layers = []
            if num_layers == 1:
                # Single linear map, no hidden layer.
                distance_transform_layers.append(
                    nn.Linear(1, out_dim, bias=dist_transform_bias)
                )
                feature_transform_layers.append(
                    nn.Linear(1, out_dim, bias=feat_trasform_bias)
                )
            else:
                # Input layer: scalar -> hidden_dim.
                distance_transform_layers.append(
                    nn.Linear(1, hidden_dim, bias=dist_transform_bias)
                )
                distance_transform_layers.append(nn.ReLU())

                feature_transform_layers.append(
                    nn.Linear(1, hidden_dim, bias=feat_trasform_bias)
                )
                feature_transform_layers.append(nn.ReLU())
                feature_transform_layers.append(nn.Dropout(p=dropout))

                # Hidden layers: hidden_dim -> hidden_dim.
                for _ in range(1, num_layers - 1):
                    distance_transform_layers.append(
                        nn.Linear(hidden_dim, hidden_dim, bias=dist_transform_bias)
                    )
                    distance_transform_layers.append(nn.ReLU())

                    feature_transform_layers.append(
                        nn.Linear(hidden_dim, hidden_dim, bias=feat_trasform_bias)
                    )
                    feature_transform_layers.append(nn.ReLU())
                    feature_transform_layers.append(nn.Dropout(p=dropout))

                # Output layer: hidden_dim -> out_dim.
                distance_transform_layers.append(
                    nn.Linear(hidden_dim, out_dim, bias=dist_transform_bias)
                )
                feature_transform_layers.append(
                    nn.Linear(hidden_dim, out_dim, bias=feat_trasform_bias)
                )

            # Register the networks for EVERY depth. These appends used to sit
            # inside the ``else`` branch, which left both module lists empty
            # when ``num_layers == 1`` and made ``forward`` fail on indexing.
            self.distance_transform.append(
                nn.Sequential(*distance_transform_layers)
            )
            self.feature_transform.append(nn.Sequential(*feature_transform_layers))

    def forward(self, graph, feats):
        """Compute the ``out_dim`` outputs for one graph.

        params:

            graph: object whose ``ndata`` holds ``"distance_matrix"`` and
                ``"normalization_distance_matrix"``. Both are reshaped as
                (N, N) here, where N is the number of rows of ``feats``; the
                first is divided element-wise by the second.
            feats: shape (N, d) where d is the feature dim

        returns:

            tensor of shape (out_dim,): for each output channel, the sum over
            features k and node pairs (i, j) of
            ``feature_transform[k](x_j) * distance_transform[k](dist_ij)``.
        """
        distance_matrix = graph.ndata["distance_matrix"]  # (N, N)
        normalization_distance_matrix = graph.ndata[
            "normalization_distance_matrix"
        ]  # (N, N)

        num_nodes, feat_dim = feats.shape

        # Normalise once, then flatten so each pair distance is one MLP input.
        distance_matrix = torch.div(distance_matrix, normalization_distance_matrix)
        distance_matrix = distance_matrix.view(-1, 1)  # (N*N, 1)

        per_feature = []
        for k in range(feat_dim):
            # x_k is the kth feature of all nodes
            x_k = feats[:, k].reshape(-1, 1)  # (N, 1)
            f_k = self.feature_transform[k](x_k)  # (N, out), row j <-> node j
            m_k = self.distance_transform[k](distance_matrix).view(
                num_nodes, num_nodes, -1
            )  # (N, N, out), entry [i, j] <-> pair (i, j)

            # sum_{i,j} f_k[j] * m_k[i, j] == sum_j f_k[j] * (sum_i m_k[i, j]).
            # Reducing over i first avoids materialising (d, N, N, out) buffers
            # and keeps every tensor on the device/dtype of the inputs.
            per_feature.append((f_k * m_k.sum(dim=0)).sum(dim=0))  # (out,)

        if not per_feature:
            # No features: the sum over an empty set is zero.
            return feats.new_zeros(self.out_dim)

        h = torch.stack(per_feature).sum(dim=0)  # (out,)
        return h


In [3]:
# Unpack the loaders and the dataset sizes returned by getData().
train_loader, valid_loader, test_loader, num_feats, num_class = getData()

# One sub-network pair per input feature; a single output logit, which is
# what BCEWithLogitsLoss below expects (num_class is not used here).
model = GNAN(
    in_dim=num_feats,
    hidden_dim=64,
    num_layers=3,
    out_dim=1,
)

# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
loss_fn = nn.BCEWithLogitsLoss()

# Adam with a small L2 penalty.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Number of passes over the training loader.
num_epochs = 1000


In [4]:
def get_accuracy(outputs, labels):
    """Count how many predictions in a batch match their labels.

    Despite the name, this returns a COUNT of correct predictions, not a
    ratio; callers divide by the number of samples themselves.

    Args:
        outputs: raw model outputs (logits). A 2-D tensor with more than one
            column is treated as multi-class and delegated to
            ``get_multiclass_accuracy``; anything else is treated as binary
            logits, one per sample.
        labels: target labels. For the binary case any shape with one entry
            per sample is accepted (e.g. ``(B,)`` or ``(B, 1)``).

    Returns:
        The number of correct predictions (a Python int for the binary case;
        whatever ``get_multiclass_accuracy`` returns otherwise).
    """
    if outputs.dim() == 2 and outputs.shape[-1] > 1:
        return get_multiclass_accuracy(outputs, labels)

    # Threshold the sigmoid probability at 0.5 to get a hard 0/1 decision.
    predicted_positive = torch.sigmoid(outputs).reshape(-1) > 0.5
    # Flatten labels too: comparing (B, 1) labels against (B,) predictions
    # would broadcast to (B, B) and silently over-count matches.
    return (labels.reshape(-1) == predicted_positive).sum().item()


def get_multiclass_accuracy(outputs, labels):
    """Count the rows of ``outputs`` whose highest-scoring class is the label.

    Args:
        outputs: 2-D tensor of per-class scores, one row per sample. It must
            have at least ``labels.max() + 1`` columns.
        labels: tensor of class indices, one per sample.

    Returns:
        A 0-dim tensor holding the number of correct predictions (a count,
        not a ratio).
    """
    highest_label = labels.max().item()
    assert outputs.size(1) >= highest_label + 1
    class_probs = torch.softmax(outputs, dim=-1)
    predicted_class = torch.argmax(class_probs, dim=-1)
    return (predicted_class == labels).sum()


def train_epoch(
    model,
    dloader,
    loss_fn,
    optimizer,
    classify=True,
    label_index=0,
    compute_auc=False,
    is_graph_task=True,
    epoch=-1,
):
    """Run one optimisation pass over ``dloader`` and return averaged metrics.

    For every ``(graph, label)`` batch the model is called as
    ``model(graph, graph.ndata["feat"])``, the loss is back-propagated and one
    optimizer step is taken.

    Args:
        model: module to train; it is switched to train mode first.
        dloader: sized iterable yielding ``(graph, label)`` pairs.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar loss.
        optimizer: optimizer over ``model``'s parameters.
        classify: when True, also accumulate the correct-prediction count via
            ``get_accuracy``.
        label_index: column to use when ``label`` has more than one dimension.
        compute_auc: when True, collect sigmoid probabilities and compute
            ROC-AUC over the whole pass.
        is_graph_task: unused here; kept for interface compatibility.
        epoch: unused here; kept for interface compatibility.

    Returns:
        ``(mean_loss, accuracy, auc)`` when ``classify`` is True, where
        ``auc`` is ``-1`` if ``compute_auc`` is False; otherwise
        ``(mean_loss, -1)``. ``mean_loss`` is averaged over batches and
        ``accuracy`` over samples.
    """
    # An earlier evaluation pass may have left the model in eval mode, which
    # would silently disable dropout for the rest of training.
    model.train()

    running_loss = 0.0
    running_acc = 0.0
    n_samples = 0
    all_probas = np.array([])
    all_labels = np.array([])

    # Anomaly detection is kept from the original code: it makes backward()
    # raise on NaN gradients, at the price of slower training.
    with torch.autograd.set_detect_anomaly(True):
        for graph, label in dloader:
            if len(label.shape) > 1:
                # Multi-column labels: train on the selected column only.
                labels = label[:, label_index].reshape(-1).float()
            else:
                labels = label.flatten()
            if -1 in labels:
                # Map {-1, +1} labels onto {0, 1}.
                labels = (labels + 1) / 2

            # Clear gradients left by the previous step; without this they
            # accumulate across every batch and epoch.
            optimizer.zero_grad()

            outputs = model(graph, graph.ndata["feat"])
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            n_samples += len(labels)
            if outputs.dim() == 2 and outputs.shape[-1] == 1:
                loss = loss_fn(outputs.flatten(), labels.float())
            else:
                loss = loss_fn(outputs, labels.float())

            loss.backward()
            optimizer.step()

            if compute_auc:
                probas = torch.sigmoid(outputs).view(-1)
                all_probas = np.concatenate((all_probas, probas.detach().cpu().numpy()))
                all_labels = np.concatenate((all_labels, labels.detach().cpu().numpy()))
            running_loss += loss.item()

            if classify:
                running_acc += get_accuracy(outputs, labels)

    if compute_auc:
        auc = roc_auc_score(all_labels, all_probas)

    if classify:
        if compute_auc:
            return running_loss / len(dloader), running_acc / n_samples, auc
        return running_loss / len(dloader), running_acc / n_samples, -1
    return running_loss / len(dloader), -1


def test_epoch(
    model,
    dloader,
    loss_fn,
    classify=True,
    label_index=0,
    compute_auc=False,
    val_mask=False,
    is_graph_task=True,
):
    with torch.no_grad():
        running_loss = 0.0
        all_probas = np.array([])
        all_labels = np.array([])
        n_samples = 0
        if classify:
            running_acc = 0.0
        model.eval()
        for graph, label in dloader:
            if len(label.shape) > 1:
                labels = label[:, label_index].view(-1, 1).flatten()
                labels = labels.float()
            else:
                labels = label.flatten()
            if -1 in labels:
                labels = (labels + 1) / 2
            # if loss_fn.__class__.__name__ == "CrossEntropyLoss":
            #     labels = labels.long()

            # non_zero_ids = None
            # if model.__class__.__name__ == "GNAM":
            #     if val_mask:
            #         labels = labels[data.val_mask]
            #         non_zero_ids = torch.nonzero(data.val_mask).flatten()
            #     else:
            #         labels = labels[data.test_mask]
            #         non_zero_ids = torch.nonzero(data.test_mask).flatten()

            # forward
            # if non_zero_ids is not None:
            #     outputs = model.forward(inputs, non_zero_ids)
            # else:
            outputs = model.forward(graph, graph.ndata["feat"])
            # if not is_graph_task:
            #     if val_mask:
            #         outputs = outputs[data.val_mask]
            #         labels = labels[data.val_mask]
            # else:

            n_samples += len(labels)
            if outputs.dim() == 2 and outputs.shape[-1] == 1:
                loss = loss_fn(outputs.flatten(), labels.float())
            else:
                loss = loss_fn(outputs, labels.float())
            running_loss += loss.item()

            if classify:
                running_acc += get_accuracy(outputs, labels)
            if compute_auc:
                probas = torch.sigmoid(outputs).view(-1)
                all_probas = np.concatenate((all_probas, probas.detach().cpu().numpy()))
                all_labels = np.concatenate((all_labels, labels.detach().cpu().numpy()))

        if compute_auc:
            auc = roc_auc_score(all_labels, all_probas)
        if classify:
            if compute_auc:
                return running_loss / len(dloader), running_acc / n_samples, auc
            else:
                return running_loss / len(dloader), running_acc / n_samples, -1
        else:
            return running_loss / len(dloader), -1


In [None]:
# Train for num_epochs passes, validating (loss, accuracy, ROC-AUC) after each.
for epoch in range(num_epochs):
    # AUC is not computed on the training pass, so the third value returned
    # is only the -1 placeholder and is discarded.
    loss, tran_acc, _ = train_epoch(
        model,
        train_loader,
        loss_fn,
        optimizer,
        classify=True,
        compute_auc=False,
        epoch=epoch + 1,
    )

    valid_loss, accuracy, auc = test_epoch(model, valid_loader, loss_fn, compute_auc=True)
    # The log line previously printed the training loss under the
    # "valid_loss" label and used ":4f" (a field width) where ".4f"
    # (4 decimal places) was intended.
    print(
        f"Epoch {epoch+1} | train_loss : {loss:.4f} | train_acc : {tran_acc:.4f} | valid_loss : {valid_loss:.4f} | valid_accuracy : {accuracy:.4f} | valid_auc : {auc:.4f}"
    )
