In [1]:
import torch
import dgl
import dgl.data
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import  roc_auc_score
from dgl.nn.pytorch.glob import SumPooling
from dataLoader import getData
import numpy as np
# Set print options to print the full tensor
torch.set_printoptions(threshold=float("inf"))


In [2]:
class NAM(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        num_layers,
        hidden_channels=None,
        bias=True,
        dropout=0.0,
        device="cpu",
    ):
        super().__init__()

        self.device = device
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.fs = nn.ModuleList()

        for _ in range(in_channels):
            if num_layers == 1:
                curr_f = [nn.Linear(1, out_channels, bias=bias)]
            else:
                curr_f = [
                    nn.Linear(1, hidden_channels, bias=bias),
                    nn.ReLU(),
                    nn.Dropout(p=dropout),
                ]
                for _ in range(1, num_layers - 1):
                    curr_f.append(
                        nn.Linear(hidden_channels, hidden_channels, bias=bias)
                    )
                    curr_f.append(nn.ReLU())
                    curr_f.append(nn.Dropout(p=dropout))
                curr_f.append(nn.Linear(hidden_channels, out_channels, bias=bias))
            self.fs.append(nn.Sequential(*curr_f))

    def init_params(self):
        for name, param in self.named_parameters():
            if "weight" in name:
                nn.init.xavier_normal_(param)
            elif "bias" in name:
                nn.init.constant_(param, 0)

    def forward(self, x):
        fx = torch.empty(x.size(0), x.size(1), self.out_channels).to(self.device)
        for feature_index in range(x.size(1)):
            feature_col = x[:, feature_index]
            feature_col = feature_col.view(-1, 1)
            feature_col = self.fs[feature_index](feature_col)
            fx[:, feature_index] = feature_col

        f_sums = fx.sum(dim=1)
        return f_sums


In [3]:
class TensorGNAN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        num_layers,
        hidden_channels=None,
        bias=True,
        dropout=0.0,
        device="cpu",
        limited_m=True,
        normalize_m=True,
        is_graph_task=False,
        readout_n_layers=1,
        final_agg="sum",
    ):
        super().__init__()

        self.device = device
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.limited_m = limited_m
        self.normalize_m = normalize_m
        self.fs = nn.ModuleList()
        self.is_graph_task = is_graph_task
        self.readout_n_layers = readout_n_layers

        self.actual_output_dim_f = (
            1 if is_graph_task and readout_n_layers > 0 else out_channels
        )
        self.actual_output_dim_m = (
            1 if limited_m or (is_graph_task and readout_n_layers > 0) else out_channels
        )

        for _ in range(in_channels):
            if num_layers == 1:
                curr_f = [nn.Linear(1, self.actual_output_dim_f, bias=bias)]
            else:
                curr_f = [
                    nn.Linear(1, hidden_channels, bias=bias),
                    nn.ReLU(),
                    nn.Dropout(p=dropout),
                ]
                for _ in range(1, num_layers - 1):
                    curr_f.append(
                        nn.Linear(hidden_channels, hidden_channels, bias=bias)
                    )
                    curr_f.append(nn.ReLU())
                    curr_f.append(nn.Dropout(p=dropout))
                curr_f.append(
                    nn.Linear(hidden_channels, self.actual_output_dim_f, bias=bias)
                )
            self.fs.append(nn.Sequential(*curr_f))

        m_bias = True
        if is_graph_task:
            m_bias = False
        if num_layers == 1:
            self.m = [nn.Linear(1, self.actual_output_dim_m, bias=m_bias)]

        else:
            self.m = [nn.Linear(1, hidden_channels, bias=m_bias), nn.ReLU()]
            for _ in range(1, num_layers - 1):
                self.m.append(nn.Linear(hidden_channels, hidden_channels, bias=m_bias))
                self.m.append(nn.ReLU())
            self.m.append(
                nn.Linear(hidden_channels, self.actual_output_dim_m, bias=m_bias)
            )
        self.m = nn.Sequential(*self.m)

        if is_graph_task and self.readout_n_layers > 0:
            self.readout_nam = NAM(
                in_channels,
                out_channels,
                readout_n_layers,
                hidden_channels,
                bias,
                dropout,
                device,
            )

        for name, param in self.named_parameters():
            if "weight" in name:
                nn.init.xavier_normal_(param)
            elif "bias" in name:
                nn.init.constant_(param, 0)

    def forward(self, inputs, feats):
        x, node_distances = (
            feats,
            inputs.ndata["distance_matrix"],
        )
        fx = torch.empty(x.size(0), x.size(1), self.actual_output_dim_f).to(self.device) # (N, D, out)
        for feature_index in range(x.size(1)):
            feature_col = x[:, feature_index]
            feature_col = feature_col.view(-1, 1)
            feature_col = self.fs[feature_index](feature_col)
            fx[:, feature_index] = feature_col

        fx_perm = torch.permute(fx, (2, 0, 1))
        if self.normalize_m:
            node_distances = torch.div(node_distances, inputs.ndata["normalization_distance_matrix"])
        m_dist = self.m(node_distances.flatten().view(-1, 1)).view(
            x.size(0), x.size(0), self.actual_output_dim_m
        )
        m_dist_perm = torch.permute(m_dist, (2, 0, 1))

        mf = torch.matmul(m_dist_perm, fx_perm)

        if not self.is_graph_task:
            out = torch.sum(mf, dim=2)

        else:
            hidden = torch.sum(mf, dim=1)
            if self.readout_n_layers > 0:
                out = self.readout_nam(hidden)
            else:
                out = torch.sum(hidden, dim=1).view(1, -1)
        return out.T


In [4]:
class GNAN(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        bias=True,
        dropout=0.0,
    ):
        super().__init__()
        self.out_dim = out_dim

        self.distance_transform = nn.ModuleList()
        self.feature_transform = nn.ModuleList()

        for _ in range(in_dim):
            self.distance_transform.append(nn.Linear(1, out_dim))
            self.feature_transform.append(nn.Linear(1, out_dim))

    def forward(self, graph, feats):
        """
        params:

            dis_matrix: shape (N,N) where N is number of nodes each element coresponds to node pair distance
            feats: shape (N, d) where d is the feature dim

        """

        distance_matrix = graph.ndata["distance_matrix"]  # (N, N)
        normalization_distance_matrix = graph.ndata[
            "normalization_distance_matrix"
        ]  # N, N
        num_nodes, feat_dim = feats.shape
        f_matrix = torch.empty(feat_dim, num_nodes, self.out_dim)  # F matrix
        m_matrix = torch.empty(feat_dim, num_nodes, num_nodes, self.out_dim)  # M matrix
        # distance_matrix = torch.div(distance_matrix, normalization_distance_matrix)
        distance_matrix = distance_matrix.view(-1, 1)  # (N*N,1)
        for k in range(feat_dim):
            # x_k is the kth feature of all nodes
            x_k = feats[:, k].view(-1, 1)  # shape (N, 1)
            f_matrix[k, :, :] = F.relu(self.feature_transform[k](x_k))  # (N, out)
            m_matrix[k, :, :, :] = F.relu(self.distance_transform[k](distance_matrix).view(
                num_nodes, num_nodes, -1
            ) ) # (N, N, out)
            # print(f"m_matrix[k, :, :, :]shape {m_matrix[k, :, :, :].shape}")
            # m_matrix[k, :, :, :] = torch.div(
            #     m_matrix[k, :, :, :], normalization_distance_matrix.unsqueeze(-1)
            # )
        # torch.set_printoptions(threshold=float("inf"))
        # print(f"m matrix is {m_matrix}")
        m_matrix = torch.div(
            m_matrix, normalization_distance_matrix.unsqueeze(-1).unsqueeze(0)
        )
        # print(f"m_matrix after divsion is{m_matrix}")
        # f matrix shape (d, N, out)
        # m matrix shape (d, N, N, out)
        f_matrix = f_matrix.permute(2, 0, 1).unsqueeze(-1)  # (out,d, N, 1)
        m_matrix = m_matrix.permute(3, 0, 1, 2)  # (out,d, N, N)
        m_f_matrix = torch.matmul(m_matrix, f_matrix)  # (out, d, N, 1)
        h = m_f_matrix.sum(1)  # (out, N, 1)
        h = h.permute(2, 1, 0)  # (1,N,out)
        h = h.squeeze(0)  # (N,out)
        return h


In [5]:
class GNANMODULE(nn.Module):
    def __init__(self, in_dim, out_dim, hiddem_dim=24, num_layers=3):
        super().__init__()

        self.in_dim = in_dim
        self.hidden_dim = hiddem_dim
        self.out_dim = out_dim
        self.num_layers = num_layers
        self.pool = SumPooling()
        self.gnan = nn.ModuleList()

        for i in range(num_layers):
            if i == 0:
                self.gnan.append(GNAN(in_dim, hiddem_dim))
            elif i == num_layers - 1:
                self.gnan.append(GNAN(hiddem_dim, out_dim))
            else:
                self.gnan.append(GNAN(hiddem_dim, hiddem_dim))

    def forward(self, graph, feats):
        for i in range(self.num_layers):
            if i == 0:
                h = F.relu(self.gnan[i](graph, feats))
                h = F.dropout(h,0.2)
            elif i == self.num_layers - 1:
                h = self.gnan[i](graph, h)
            else:
                h = F.relu(self.gnan[i](graph, h))
                h = F.dropout(h,0.2)

        h = self.pool(graph, h)
        # print("class")
        # print(h)
        # print(h.shape)
        
        return h

In [6]:
train_loader, valid_loader, test_loader, num_feats, num_class = getData()
# num_class = data.num_classes
# Define model, loss function, and optimizer
model = GNANMODULE(num_feats, 1, hiddem_dim=28)
# model = TensorGNAN(num_feats,1,1,28, is_graph_task=True)
loss_fn = nn.BCEWithLogitsLoss()  # Binary Cross Entropy with Logits
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=0.0005
)

# Define training settings
num_epochs = 100
# train_loader

  torch.unique(torch.tensor(labels)).tolist()


In [7]:



def get_accuracy(outputs, labels):
    if outputs.dim() == 2 and outputs.shape[-1] > 1:
        return get_multiclass_accuracy(outputs, labels)
    else:
        y_prob = torch.sigmoid(outputs).view(-1)
        y_prob = y_prob > 0.5
        return (labels == y_prob).sum().item()


def get_multiclass_accuracy(outputs, labels):
    assert outputs.size(1) >= labels.max().item() + 1
    probas = torch.softmax(outputs, dim=-1)
    preds = torch.argmax(probas, dim=-1)
    correct = (preds == labels).sum()
    acc = correct
    return acc


def train_epoch(
    model,
    dloader,
    loss_fn,
    optimizer,
    classify=True,
    label_index=0,
    compute_auc=False,
    is_graph_task=True,
    epoch=-1
):
    with torch.autograd.set_detect_anomaly(True):
        running_loss = 0.0
        n_samples = 0
        all_probas = np.array([])
        all_labels = np.array([])
        if classify:
            running_acc = 0.0
   
        for graph, label in dloader:
    
            if len(label.shape) > 1:
                labels = label[:, label_index].view(-1, 1).flatten()
                labels = labels.float()
            else:
                labels = label.flatten()
            if -1 in labels:
                labels = (labels + 1) / 2
            # if loss_fn.__class__.__name__ == "CrossEntropyLoss":
            #     labels = labels.long()

            # non_zero_ids = None
            # if model.__class__.__name__ == "GNAM":
            #     labels = labels[data.train_mask]
            #     non_zero_ids = torch.nonzero(data.train_mask).flatten()
            # data = data.to(device)
            # labels = labels.to(device)
            # optimizer.zero_grad()
            # if non_zero_ids is not None:
            #     outputs = model.forward(data, non_zero_ids)
            # else:
            outputs = model.forward(graph, graph.ndata["feat"])
            # print("train model output")
            # print(outputs)
            
            # Check for NaN in the outputs
            # if torch.isnan(outputs).any():
            #     print(f"NaN detected in model output at epoch{epoch}")
            #     break
            
            # if not is_graph_task:
            #         labels = labels[data.train_mask]
            #         outputs = outputs[data.train_mask]

            if isinstance(outputs, tuple):
                outputs = outputs[0]

            n_samples += len(labels)
            if outputs.dim() == 2 and outputs.shape[-1] == 1:
                loss = loss_fn(outputs.flatten(), labels.float())
            else:
                loss = loss_fn(outputs, labels)
            # if torch.isnan(loss).any():
            #     print(f"NaN detected in loss at epoch {epoch}")
            #     break
        
            loss.backward()
            optimizer.step()
            if compute_auc:
                probas = torch.sigmoid(outputs).view(-1)
                all_probas = np.concatenate((all_probas, probas.detach().cpu().numpy()))
                all_labels = np.concatenate((all_labels, labels.detach().cpu().numpy()))
            running_loss += loss.item()

            if classify:
                running_acc += get_accuracy(outputs, labels)

        if compute_auc:
            auc = roc_auc_score(all_labels, all_probas)

        if classify:
            if compute_auc:
                return running_loss / len(dloader), running_acc / n_samples, auc
            else:
                return running_loss / len(dloader), running_acc / n_samples, -1
        else:
            return running_loss / len(dloader), -1


def test_epoch(
    model,
    dloader,
    loss_fn,
    classify=True,
    label_index=0,
    compute_auc=False,
    val_mask=False,
    is_graph_task=True,
):
    with torch.no_grad():
        running_loss = 0.0
        all_probas = np.array([])
        all_labels = np.array([])
        n_samples = 0
        if classify:
            running_acc = 0.0
        model.eval()
        for graph, label in dloader:
            if len(label.shape) > 1:
                labels = label[:, label_index].view(-1, 1).flatten()
                labels = labels.float()
            else:
                labels = label.flatten()
            if -1 in labels:
                labels = (labels + 1) / 2
            # if loss_fn.__class__.__name__ == "CrossEntropyLoss":
            #     labels = labels.long()

            # non_zero_ids = None
            # if model.__class__.__name__ == "GNAM":
            #     if val_mask:
            #         labels = labels[data.val_mask]
            #         non_zero_ids = torch.nonzero(data.val_mask).flatten()
            #     else:
            #         labels = labels[data.test_mask]
            #         non_zero_ids = torch.nonzero(data.test_mask).flatten()

            # forward
            # if non_zero_ids is not None:
            #     outputs = model.forward(inputs, non_zero_ids)
            # else:
            outputs = model.forward(graph, graph.ndata["feat"])
            # if not is_graph_task:
            #     if val_mask:
            #         outputs = outputs[data.val_mask]
            #         labels = labels[data.val_mask]
            # else:

            n_samples += len(labels)
            if outputs.dim() == 2 and outputs.shape[-1] == 1:
                loss = loss_fn(outputs.flatten(), labels.float())
            else:
                loss = loss_fn(outputs, labels)
            running_loss += loss.item()

            if classify:
                running_acc += get_accuracy(outputs, labels)
            if compute_auc:
                probas = torch.sigmoid(outputs).view(-1)
                all_probas = np.concatenate((all_probas, probas.detach().cpu().numpy()))
                all_labels = np.concatenate((all_labels, labels.detach().cpu().numpy()))

        if compute_auc:
            auc = roc_auc_score(all_labels, all_probas)
        if classify:
            if compute_auc:
                return running_loss / len(dloader), running_acc / n_samples, auc
            else:
                return running_loss / len(dloader), running_acc / n_samples, -1
        else:
            return running_loss / len(dloader), -1


In [8]:
for epoch in range(num_epochs):
    loss, tran_acc, _ = train_epoch(
        model,
        train_loader,
        loss_fn,
        optimizer,
        classify=True,
        compute_auc=False,
        epoch=epoch + 1,
    )

    valid_loss, accuracy, auc = test_epoch(model, valid_loader, loss_fn)
    print(
        f"loss : {loss:.4f} |loss, tran_acc : {tran_acc:.4f} |valid_loss : {loss:.4f} | valid_accuracy : {accuracy:4f} | valid_auc : {auc:4f}"
    )


  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/a373k/Desktop/DGL_VENV/DGLenv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/home/a373k/Desktop/DGL_VENV/DGLenv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/home/a373k/Desktop/DGL_VENV/DGLenv/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/home/a373k/Desktop/DGL_VENV/DGLenv/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.12/asyncio/base_events.py", line 641, in run_forever
    self._run_once()
  File "/usr/lib/python3.12/asyncio/base_events.py", line 1987, in _run_once
    handle._run()
  File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callbac

RuntimeError: Function 'BinaryCrossEntropyWithLogitsBackward0' returned nan values in its 0th output.