In [25]:
import torch
import torch.nn.functional as F
import scipy.io
from sklearn.metrics import accuracy_score, f1_score
from torch_geometric.data import Data, DataLoader, InMemoryDataset
from torch.utils.data import Subset
from torch.cuda.amp import GradScaler, autocast
from pretrain_gnns.bio.model import GNN
from torch.utils.checkpoint import checkpoint
import numpy as np

# Set the device to CUDA if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Function to load and process the .mat files into the correct format
def load_data(file_path):
    data = scipy.io.loadmat(file_path)
    x = torch.tensor(data['attrb'].todense(), dtype=torch.float32)
    edge_index = torch.tensor(data['network'].nonzero(), dtype=torch.long)

    # Initialize edge_attr correctly
    edge_attr_data = data['network'].data
    num_edges = edge_index.size(1)
    
    # Assuming 9 features per edge
    num_features = 9
    
    if len(edge_attr_data.shape) == 1:
        edge_attr_data = edge_attr_data.reshape(-1, 1)  # Ensure it's a 2D array with shape [num_edges, 1]
    
    if edge_attr_data.shape[1] < num_features:
        # Expand the features if necessary, here using zeros as placeholders for simplicity
        edge_attr = torch.zeros((num_edges, num_features), dtype=torch.float32)
        edge_attr[:, :edge_attr_data.shape[1]] = torch.tensor(edge_attr_data, dtype=torch.float32)
    else:
        edge_attr = torch.tensor(edge_attr_data, dtype=torch.float32)

    y = torch.tensor(data['group'].argmax(axis=1).squeeze(), dtype=torch.long)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

class CustomDataset(InMemoryDataset):
    def __init__(self, data_list):
        super().__init__()
        self.data, self.slices = self.collate(data_list)

    def get(self, idx):
        return self.data.__class__(**{key: self.data[key][idx] for key in self.data.keys()})

# Load the datasets
print("Loading datasets...")
data_list = [load_data('acmv9.mat')]
train_data = CustomDataset(data_list)
test_data = load_data('citationv1.mat')

# Check and reduce the training dataset to half if it's large enough
dataset_size = len(train_data)
if dataset_size > 1:
    indices = np.random.choice(dataset_size, dataset_size // 2, replace=False)
    train_data_reduced = Subset(train_data, indices)
else:
    train_data_reduced = train_data  # Use the full dataset if it's too small to be halved

# Load the pre-trained GAT model
print("Loading pre-trained model...")
model = GNN(num_layer=5, emb_dim=300, gnn_type='gat')  # Maintain the model architecture
model.load_state_dict(torch.load('pretrain_gnns/bio/model_architecture/gat_supervised_masking.pth'))
model = model.to(device)

# Setup the DataLoader with a smaller batch size
train_loader = DataLoader(train_data_reduced, batch_size=1, shuffle=True)

# Setup optimizer, loss, and gradient scaler for mixed precision
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
scaler = GradScaler()

# Function to checkpoint each layer of the model manually
def forward_with_checkpointing(model, x, edge_index, edge_attr):
    # Ensure that the input tensors are correctly shaped
    if len(x.shape) == 1:
        x = x.unsqueeze(-1)  # Reshape x to have shape [num_nodes, 1] if it's 1D

    if len(edge_index.shape) == 1:
        edge_index = edge_index.view(2, -1)  # Reshape edge_index to [2, num_edges]

    if len(edge_attr.shape) == 1:
        edge_attr = edge_attr.unsqueeze(-1)  # Reshape edge_attr to [num_edges, 1]

    # Iterate over each layer in the model
    h = x
    for i, layer in enumerate(model.gnns):
        print(f"Before layer {i}: h.shape={h.shape}, edge_index.shape={edge_index.shape}, edge_attr.shape={edge_attr.shape}")
        h = checkpoint(layer, h, edge_index, edge_attr)
        print(f"After layer {i}: h.shape={h.shape}")
    return h

# Training loop with manual layer checkpointing
accumulation_steps = 10  # Smaller accumulation steps

print("Starting training...")
model.train()
for epoch in range(25):  # Adjust the number of epochs as needed
    optimizer.zero_grad()  # Reset gradients
    for i, batch in enumerate(train_loader):
        with autocast():
            x, edge_index, edge_attr = batch.x.to(device), batch.edge_index.to(device), batch.edge_attr.to(device)
            
            # Forward pass with layer checkpointing
            output = forward_with_checkpointing(model, x, edge_index, edge_attr)
            output = model.graph_pred_linear(output)  # Final prediction layer
            loss = criterion(output, batch.y.to(device))
            loss = loss / accumulation_steps  # Scale loss for accumulation
            
        scaler.scale(loss).backward()

        # Step the optimizer every 'accumulation_steps' batches
        if (i + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            torch.cuda.empty_cache()  # Clear cache to free up memory
    
    print(f'Epoch {epoch + 1}, Loss: {loss.item():.4f}')
    torch.cuda.empty_cache()  # Clear cache to free up memory after each epoch

# Evaluation
print("Evaluating model...")
model.eval()
with torch.no_grad():  # Disabling the dynamic graph for memory efficiency
    x, edge_index, edge_attr, y = test_data.x.to(device), test_data.edge_index.to(device), test_data.edge_attr.to(device), test_data.y.to(device)
    output = model(x, edge_index, edge_attr)
    predictions = torch.argmax(output, dim=1)
    accuracy = accuracy_score(y.cpu(), predictions.cpu())
    micro_f1 = f1_score(y.cpu(), predictions.cpu(), average='micro')
    print(f'Accuracy: {accuracy:.4f}, Micro F1 Score: {micro_f1:.4f}')


Using device: cuda
Loading datasets...
Loading pre-trained model...
Starting training...
Before layer 0: h.shape=torch.Size([6775, 1]), edge_index.shape=torch.Size([2, 15579]), edge_attr.shape=torch.Size([9, 1])


RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 1 but got size 9 for tensor number 1 in the list.