In [136]:
from torch.utils.data import random_split
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
def load_tu_data(dataset_name='PROTEINS', use_node_attr=True):
    dataset = TUDataset(root='/tmp/' + dataset_name, name=dataset_name, use_node_attr=use_node_attr)

    num_classes = dataset.num_classes
    num_node_features = dataset.num_features

    num_total = len(dataset)
    num_train = int(num_total * 0.6)
    num_val = int(num_total * 0.1)
    num_test = int(num_total * 0.1)
    num_pool = num_total - num_train - num_val - num_test

    indices = torch.randperm(num_total).tolist()
    train_indices = indices[:num_train]
    val_indices = indices[num_train:num_train + num_val]
    test_indices = indices[num_train + num_val:num_train + num_val + num_test]
    pool_indices = indices[num_train + num_val + num_test:]

    return dataset, train_indices, val_indices, test_indices, pool_indices, num_classes, num_node_features


Using device: cuda


Indice alapján Data Loader

In [137]:
def create_data_loader(dataset, indices, batch_size=10, shuffle=True):
    subset = torch.utils.data.Subset(dataset, indices)
    return DataLoader(subset, batch_size=batch_size, shuffle=shuffle)

def update_indices(train_indices, pool_indices, selected_pool_indices):
    """
    Update training and pool indices after selecting some indices from the pool.
    
    Args:
        train_indices (list): Current list of training indices.
        pool_indices (list): Current list of pool indices.
        selected_pool_indices (list): Indices selected from the pool to be moved to training.
    
    Returns:
        tuple: Updated lists of train and pool indices.
    """
    # Ensure all are lists
    train_indices = list(train_indices)
    selected_pool_indices = list(selected_pool_indices)

    # Add selected indices to train indices
    new_train_indices = train_indices + selected_pool_indices

    # Remove the selected indices from pool indices
    new_pool_indices = [idx for idx in pool_indices if idx not in selected_pool_indices]

    return new_train_indices, new_pool_indices



In [138]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool


class GNN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes, hidden_channels=64):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.out = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        # Global mean pooling
        x = global_mean_pool(x, batch)  # Aggregate node features to graph features
        
        x = self.out(x)
        return F.log_softmax(x, dim=1)

Uniform Acquisition function

In [139]:
import numpy as np

def uniform(model, dataset, pool_indices, n_query, T=100, training=True):
    """
    Uniformly random selection of data points from the unlabeled pool.
    
    Args:
    pool_indices (list): List of indices available in the pool.
    n_query (int): Number of queries to make.
    
    Returns:
    list: Indices of the selected data points.
    """
    # Directly use the pool_indices to select data points
    selected_indices = np.random.choice(pool_indices, size=n_query, replace=False)
    
    return selected_indices.tolist()


Active learning loop

In [140]:
def active_learning_loop(model, dataset, train_indices, pool_indices, val_indices, test_indices, query_strategy, n_query=10, epochs=100):
    train_loader = create_data_loader(dataset, train_indices)
    val_loader = create_data_loader(dataset, val_indices, shuffle=False)
    test_loader = create_data_loader(dataset, test_indices, shuffle=False)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.NLLLoss()

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        num_batches = 0

        for batch in train_loader:
            batch.to(device)
            optimizer.zero_grad()
            out = model(batch)
            loss = criterion(out, batch.y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            num_batches += 1
             

        if epoch % 10 == 0:
            val_acc = evaluate_model(model, val_loader)
            print(f'Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Val Accuracy: {val_acc:.4f}')
            selected_indices = query_strategy(model, dataset, pool_indices, n_query, T=100, training=True)
            print(f'Selected indices for training: {selected_indices}')
            train_indices, pool_indices = update_indices(train_indices, pool_indices, selected_indices)
            print(f'Updated train indices count: {len(train_indices)}, Pool indices count: {len(pool_indices)}')
            train_loader = create_data_loader(dataset, train_indices)  # Recreate the loader with updated indices

    test_acc = evaluate_model(model, test_loader)
    print(f'Final Test Accuracy: {test_acc:.4f}')


Eval

In [141]:

def evaluate_model(model, loader):
    correct = 0
    total = 0
    for batch in loader:
        batch.to(device) # Ensure labels are also moved
        out = model(batch)
        pred = out.max(dim=1)[1]
        correct += pred.eq(batch.y).sum().item()
        total += batch.num_graphs
    return correct / total

def update_loaders(train_loader, pool_loader, new_data_indices, pool_indices, train_indices):
    # Add new data indices to train indices
    for idx in new_data_indices:
        train_indices.append(pool_indices[idx])

    # Remove the selected indices from pool indices
    new_pool_indices = [idx for i, idx in enumerate(pool_indices) if i not in new_data_indices]

    # Update datasets and loaders
    train_loader.dataset.indices = train_indices
    pool_loader.dataset.indices = new_pool_indices

In [None]:
dataset, train_indices, val_indices, test_indices, pool_indices, num_classes, num_node_features = load_tu_data()
model = GNN(num_node_features, num_classes).to(device)
active_learning_loop(model, dataset, train_indices, pool_indices, val_indices, test_indices, uniform, epochs=200)

Epoch: 1, Train Loss: 49.4214, Val Accuracy: 0.6847
Selected indices for training: [205, 465, 1092, 472, 983, 587, 138, 544, 583, 326]
Updated train indices count: 677, Pool indices count: 214
Epoch: 11, Train Loss: 45.6071, Val Accuracy: 0.6757
Selected indices for training: [419, 834, 686, 844, 130, 407, 1056, 268, 594, 369]
Updated train indices count: 687, Pool indices count: 204
Epoch: 21, Train Loss: 44.9365, Val Accuracy: 0.6847
Selected indices for training: [107, 581, 744, 585, 169, 827, 247, 967, 227, 322]
Updated train indices count: 697, Pool indices count: 194
Epoch: 31, Train Loss: 44.6771, Val Accuracy: 0.7477
Selected indices for training: [1033, 710, 578, 696, 84, 33, 60, 224, 79, 1059]
Updated train indices count: 707, Pool indices count: 184
Epoch: 41, Train Loss: 45.0485, Val Accuracy: 0.7027
Selected indices for training: [968, 129, 476, 1057, 276, 541, 292, 662, 327, 939]
Updated train indices count: 717, Pool indices count: 174
Epoch: 51, Train Loss: 44.3233, Val

Max Entropy Acquisition function

In [None]:
def predictions_from_pool(model, dataset, pool_indices, T=100, training=True):
    """
    Run MC dropout prediction on model using graphs from the pool and return the output.
    """
    # Randomly select indices from the pool
    random_subset = np.random.choice(pool_indices, size=min(2000, len(pool_indices)), replace=False)
    
    # Fetch the actual graph data from the dataset
    subset_loader = DataLoader(dataset[random_subset.tolist()], batch_size=len(random_subset), shuffle=False)
    batch = next(iter(subset_loader))  # Load the batch
    
    # Perform prediction
    outputs = []
    with torch.no_grad():
        for _ in range(T):
            batch.to(device)
            model.train(training)  # Enable/disable dropout
            output = torch.softmax(model(batch), dim=-1)
            outputs.append(output.cpu().numpy())
    outputs = np.stack(outputs)
    print(outputs.shape)
    return outputs, random_subset

def shannon_entropy_function(model, dataset, pool_indices, T=100, E_H=False, training=True):
    """
    Compute the Shannon entropy and optionally E_H if needed for BALD.
    """
    outputs, random_subset = predictions_from_pool(model, dataset, pool_indices, T, training)
    pc = outputs.mean(axis=0)
    H = (-pc * np.log(pc + 1e-10)).sum(axis=-1)  # Prevent log(0)

    if E_H:
        E = -np.mean(np.sum(outputs * np.log(outputs + 1e-10), axis=-1), axis=0)
        return H, E, random_subset
    return H, random_subset

def max_entropy(model, dataset, pool_indices, n_query=10, T=100, training=True):
    """
    Choose pool points that maximize the predictive entropy.
    """
    acquisition, random_subset = shannon_entropy_function(model, dataset, pool_indices, T, training=training)
    idx = (-acquisition).argsort()[:n_query]
    query_idx = random_subset[idx]
    return query_idx


Training with Max Entropy

In [None]:
dataset, train_indices, val_indices, test_indices, pool_indices, num_classes, num_node_features = load_tu_data()
model = GNN(num_node_features, num_classes).to(device)
active_learning_loop(model, dataset, train_indices, pool_indices, val_indices, test_indices, max_entropy, epochs=200)