In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import random_split
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import from_networkx, to_networkx
from scipy.sparse import csr_matrix
import networkx as nx
from tqdm import tqdm

In [2]:
loaded_data = torch.load('../../ppmi.pth')
data_tensor = loaded_data['data'].to(dtype=torch.float32)
class_label = loaded_data['class_label']

In [3]:
def normalize_correlation_matrix(matrix):
    """
    Normalize a correlation matrix to the range [-1, 1].
    
    :param matrix: A numpy array or PyTorch tensor representing the correlation matrix.
    :return: A normalized correlation matrix of the same type (numpy array or PyTorch tensor).
    """
    if isinstance(matrix, np.ndarray):
        max_val = np.max(matrix)
        min_val = np.min(matrix)
        normalized_matrix = 2 * (matrix - min_val) / (max_val - min_val) - 1
    elif isinstance(matrix, torch.Tensor):
        max_val = torch.max(matrix)
        min_val = torch.min(matrix)
        normalized_matrix = 2 * (matrix - min_val) / (max_val - min_val) - 1
    else:
        raise TypeError("Input must be a numpy array or a PyTorch tensor.")
    
    return normalized_matrix

In [5]:
def construct_pyg_graph(correlation_matrix, threshold=0.5):
    num_regions = correlation_matrix.shape[0]
    # Create NetworkX graph to utilize its easy manipulation capabilities
    G = nx.Graph()
    # Add nodes with alternating hemispheres based on AAL116 atlas
    for i in range(num_regions):
        hemisphere = 'left' if i % 2 == 0 else 'right'
        G.add_node(i, hemisphere=hemisphere)
    # Add edges
    for i in range(num_regions):
        for j in range(i + 1, num_regions):
            if abs(correlation_matrix[i, j]) > threshold:
                G.add_edge(i, j, weight=correlation_matrix[i, j])
    # Convert to PyTorch Geometric Data object
    pyg_graph = from_networkx(G)
    return pyg_graph

In [6]:
def adj_to_graphs(data_tensor, batch_size=10, threshold=0.5):
    graphs = []
    num_batches = (data_tensor.shape[0] + batch_size - 1) // batch_size
    for batch_idx in tqdm(range(num_batches), desc="Processing batches"):
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, data_tensor.shape[0])
        print(f"Processing batch {batch_idx+1}/{num_batches}, matrices {start_idx+1} to {end_idx}")
        batch_graphs = []
        for i in range(start_idx, end_idx):
            print(f"Processing matrix {i+1}")
            if torch.isnan(data_tensor[i]).any() or torch.isinf(data_tensor[i]).any():
                print(f"Matrix {i+1} contains NaN or Inf values. Skipping.")
                continue
            try:
                normalized_matrix = normalize_correlation_matrix(data_tensor[i].to(dtype=torch.float32))  # Ensure float32
                graph = construct_pyg_graph(normalized_matrix, threshold)
                batch_graphs.append(graph)
            except Exception as e:
                print(f"Error processing matrix {i+1}: {e}")
        graphs.extend(batch_graphs)
        # Clear batch variables to free memory
        del batch_graphs
        torch.cuda.empty_cache()  # Clear GPU memory if using GPU
    return graphs

In [7]:
graphs = adj_to_graphs(data_tensor)

Processing batch 1/209, matrices 1 to 1
Processing matrix 1
Processing batch 2/209, matrices 2 to 2
Processing matrix 2
Processing batch 3/209, matrices 3 to 3
Processing matrix 3
Processing batch 4/209, matrices 4 to 4
Processing matrix 4
Processing batch 5/209, matrices 5 to 5
Processing matrix 5
Processing batch 6/209, matrices 6 to 6
Processing matrix 6
Processing batch 7/209, matrices 7 to 7
Processing matrix 7
Processing batch 8/209, matrices 8 to 8
Processing matrix 8
Processing batch 9/209, matrices 9 to 9
Processing matrix 9
Processing batch 10/209, matrices 10 to 10
Processing matrix 10
Processing batch 11/209, matrices 11 to 11
Processing matrix 11
Processing batch 12/209, matrices 12 to 12
Processing matrix 12
Processing batch 13/209, matrices 13 to 13
Processing matrix 13
Processing batch 14/209, matrices 14 to 14
Processing matrix 14
Processing batch 15/209, matrices 15 to 15
Processing matrix 15
Processing batch 16/209, matrices 16 to 16
Processing matrix 16
Processing b

KeyboardInterrupt: 

In [3]:
def extract_interhemispheric_subgraph(pyg_graph):
    G = to_networkx(pyg_graph)
    subgraph = nx.Graph()
    for u, v, data in G.edges(data=True):
        if G.nodes[u]['hemisphere'] != G.nodes[v]['hemisphere']:
            subgraph.add_edge(u, v, weight=data['weight'])
    pyg_subgraph = from_networkx(subgraph)
    return pyg_subgraph

In [4]:
def emphasize_interhemispheric_edges(pyg_graph, emphasis_factor=2):
    pyg_graph = pyg_graph.clone()  # Clone to avoid modifying the original graph
    G = to_networkx(pyg_graph)
    for u, v, data in G.edges(data=True):
        if G.nodes[u]['hemisphere'] != G.nodes[v]['hemisphere']:
            G[u][v]['weight'] *= emphasis_factor
    emphasized_pyg_graph = from_networkx(G)
    return emphasized_pyg_graph

In [5]:
def process_correlation_matrices(correlation_matrices, threshold=0.5, emphasis_factor=2):
    graphs = []
    interhemispheric_subgraphs = []
    emphasized_graphs = []
    
    for matrix in correlation_matrices:
        pyg_graph = construct_pyg_graph(matrix, threshold)
        subgraph = extract_interhemispheric_subgraph(pyg_graph)
        emphasized_graph = emphasize_interhemispheric_edges(pyg_graph, emphasis_factor)
        
        graphs.append(pyg_graph)
        interhemispheric_subgraphs.append(subgraph)
        emphasized_graphs.append(emphasized_graph)
    
    return graphs, interhemispheric_subgraphs, emphasized_graphs

In [6]:
def load_data():
    # Assuming correlation_matrices is loaded from some source, e.g., torch.load or np.load
    correlation_matrices = np.load('./data/correlation_matrices.npy')
    
    train_size = int(0.7 * len(correlation_matrices))
    val_size = int(0.15 * len(correlation_matrices))
    test_size = len(correlation_matrices) - train_size - val_size
    
    train_matrices, val_matrices, test_matrices = random_split(correlation_matrices, [train_size, val_size, test_size])
    
    train_graphs, _, _ = process_correlation_matrices(train_matrices)
    val_graphs, _, _ = process_correlation_matrices(val_matrices)
    test_graphs, _, _ = process_correlation_matrices(test_matrices)
    
    batch_size = 64
    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True, num_workers=8)
    val_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=True, num_workers=8)
    test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=True, num_workers=8)
    
    return train_loader, val_loader, test_loader

In [None]:
def main():
    args = parse_args()
    set_seeds()
    device = get_device(args.gpu)
    train_loader, val_loader, test_loader = load_data()
    model = get_model(args.model, args.hidden_channels).to(device)
    num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model: {args.model}, parameters: {num_parameters}")
    
    optimizer = Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0
    for epoch in range(1, args.epochs + 1):
        loss = train(model, train_loader, criterion, optimizer, device)
        val_acc, _, _, _ = test(model, val_loader, device)
        test_acc, pre, rec, f1 = test(model, test_loader, device)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
        print(f'Epoch: {epoch:03d}, best Acc: {best_val_acc:.4f}, Test Acc: {test_acc:.4f}, Loss: {loss:.4f}, pre: {pre:.4f}, rec: {rec:.4f}, f1: {f1:.4f}')

if __name__ == "__main__":
    main()
