In [16]:
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import scipy.sparse
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch_geometric.utils import from_scipy_sparse_matrix

from torch_geometric.nn import GCNConv
from torch_geometric.nn import SAGEConv

import matplotlib.pyplot as plt

In [17]:

class GCN_MLC(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden_channels=16):
        super(GCN_MLC, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)  # Additional layer for deeper learning

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)  # Dropout for regularization
        x = self.conv2(x, edge_index)
        return x

class SAGE_MLC(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden_channels=16):
        super(SAGE_MLC, self).__init__()
        self.conv1 = SAGEConv(num_features, hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.fc(x)
        return x


In [18]:
import matplotlib.pyplot as plt


def load_data(device):
    # Load data, assuming the paths are correct
    X = torch.load('results/X_32.pt')
    Y = torch.load('results/Y.pt')
    A = scipy.sparse.load_npz('results/A/A_final.npz')

    edge_index, edge_weight = from_scipy_sparse_matrix(A)
    edge_index = edge_index.to(device)
    edge_weight = edge_weight.to(device)

    # Convert X and Y to torch tensors if they are numpy arrays
    if isinstance(X, np.ndarray):
        X = torch.tensor(X, dtype=torch.float).to(device)
    else:
        X = X.to(device)

    if isinstance(Y, np.ndarray):
        Y = torch.tensor(Y, dtype=torch.float).to(device)  # Ensuring Y is also a float for BCEWithLogitsLoss
    else:
        Y = Y.to(device)

    return X, Y, edge_index, edge_weight

def prepare_masks(num_nodes):

    train_index, temp_index = train_test_split(np.arange(num_nodes), test_size=0.3, random_state=42)
    val_index, test_index = train_test_split(temp_index, test_size=0.6667, random_state=42)  # Adjusted test_size to split remaining 30% into 20% and 10%

    # Create boolean masks for train, validation, and test datasets
    train_mask = torch.zeros(num_nodes, dtype=torch.bool).scatter_(0, torch.tensor(train_index), True)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool).scatter_(0, torch.tensor(val_index), True)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool).scatter_(0, torch.tensor(test_index), True)

    return train_mask, val_mask, test_mask


def train(model, data, train_mask, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = criterion(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

def evaluate(model, data, mask):
    model.eval()
    with torch.no_grad():
        preds = torch.sigmoid(model(data)[mask])
        preds = (preds > 0.5).float()
        correct = (preds == data.y[mask]).float()
        accuracy = correct.mean()
    return accuracy



def plot_metrics(losses, val_accs):
    fig, ax1 = plt.subplots(figsize=(10, 6))  # Set up the figure and one axis for the loss

    # Plotting the training loss on the primary y-axis
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss', color='tab:blue')
    ax1.plot(losses, label='Training Loss', color='tab:blue')
    ax1.tick_params(axis='y', labelcolor='tab:blue')
    ax1.legend(loc='upper left')

    # Creating a second y-axis for validation accuracy
    ax2 = ax1.twinx()  # Instantiate a second axes that shares the same x-axis
    ax2.set_ylabel('Validation Accuracy', color='tab:red')
    ax2.plot(val_accs, label='Validation Accuracy', color='tab:red')
    ax2.tick_params(axis='y', labelcolor='tab:red')
    ax2.legend(loc='upper right')

    # Adding a title and a grid
    plt.title('Training Loss and Validation Accuracy over Epochs')
    fig.tight_layout()  # Adjust the layout to make room for the second y-axis

    plt.show()




In [19]:
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    X, Y, edge_index, edge_weight = load_data(device)
    data = Data(x=X, y=Y, edge_index=edge_index, edge_attr=edge_weight).to(device)
    num_features = X.size(1)
    num_classes = Y.size(1)

    model = SAGE_MLC(num_features, num_classes, hidden_channels=16).to(device)
    # model = GCN_MLC(num_features, num_classes, hidden_channels=16).to(device)
    optimizer = Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = torch.nn.BCEWithLogitsLoss()

    train_mask, val_mask, test_mask = prepare_masks(Y.size(0))

    losses, val_accs = [], []
    for epoch in range(1, 500):
        loss = train(model, data, train_mask, optimizer, criterion)
        val_acc = evaluate(model, data, val_mask)
        losses.append(loss)
        val_accs.append(val_acc)
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}')

    test_acc = evaluate(model, data, test_mask)
    print(f'Test Accuracy (0.6 train): {test_acc:.4f}')
    plot_metrics(losses, val_accs)
    

if __name__ == "__main__":
    main()

Epoch: 001, Loss: 65.8174, Val Acc: 0.4813
Epoch: 002, Loss: 43.9873, Val Acc: 0.5645
Epoch: 003, Loss: 30.4944, Val Acc: 0.6323
Epoch: 004, Loss: 20.8445, Val Acc: 0.7075
Epoch: 005, Loss: 13.6070, Val Acc: 0.7751
Epoch: 006, Loss: 10.4740, Val Acc: 0.7950
Epoch: 007, Loss: 8.4477, Val Acc: 0.8101
Epoch: 008, Loss: 8.1600, Val Acc: 0.7809
Epoch: 009, Loss: 8.7966, Val Acc: 0.7854
Epoch: 010, Loss: 8.1523, Val Acc: 0.8133
Epoch: 011, Loss: 7.4477, Val Acc: 0.8168
Epoch: 012, Loss: 7.0656, Val Acc: 0.8213
Epoch: 013, Loss: 6.6923, Val Acc: 0.8222
Epoch: 014, Loss: 6.3530, Val Acc: 0.8208
Epoch: 015, Loss: 6.1306, Val Acc: 0.8219
Epoch: 016, Loss: 6.0093, Val Acc: 0.8213
Epoch: 017, Loss: 5.8362, Val Acc: 0.8216
Epoch: 018, Loss: 5.5674, Val Acc: 0.8190
Epoch: 019, Loss: 5.2645, Val Acc: 0.8133
Epoch: 020, Loss: 5.0230, Val Acc: 0.8058
Epoch: 021, Loss: 4.8259, Val Acc: 0.7990
Epoch: 022, Loss: 4.5967, Val Acc: 0.7900
Epoch: 023, Loss: 4.3184, Val Acc: 0.8058
Epoch: 024, Loss: 3.9285, Va

KeyboardInterrupt: 