# Active Learning and GraNd (Gradient Normed) Sampling

*Authors: Sahika Betul Yayli, MD*

## Introduction to Active Learning

Active Learning is a powerful technique in machine learning that helps improve model performance while reducing the need for excessive manual labeling. Instead of training a model on a fully labeled dataset, Active Learning allows the model to decide which data points should be labeled, focusing on the most informative samples.

This approach makes the labeling process more efficient by reducing the number of labels needed while still maintaining (or even improving) performance. It’s particularly useful in cases where labeling data is expensive or time-consuming, such as medical imaging, speech recognition, and other domains requiring expert annotation.
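
As a rough illustration of the loop described above, the sketch below runs one pool-based selection round in plain Python. The names `select_for_labeling`, `active_learning_round`, `score_fn`, and `budget` are hypothetical, and the toy scores stand in for whatever informativeness measure the model provides.

```python
import heapq
from typing import Callable, Hashable, Iterable, List, Set, Tuple


def select_for_labeling(
    unlabeled: Iterable[Hashable],
    score_fn: Callable[[Hashable], float],
    budget: int,
) -> List[Hashable]:
    """Return up to `budget` items with the highest informativeness score."""
    return heapq.nlargest(budget, unlabeled, key=score_fn)


def active_learning_round(
    labeled: Set[Hashable],
    unlabeled: Set[Hashable],
    score_fn: Callable[[Hashable], float],
    budget: int,
) -> Tuple[Set[Hashable], Set[Hashable]]:
    """Move the selected items from the unlabeled pool to the labeled set."""
    chosen = set(select_for_labeling(unlabeled, score_fn, budget))
    return labeled | chosen, unlabeled - chosen


# Toy usage: item IDs with made-up scores standing in for model-derived ones.
toy_scores = {"a": 0.1, "b": 0.9, "c": 0.4, "d": 0.7}
labeled, unlabeled = active_learning_round(set(), set(toy_scores), toy_scores.get, budget=2)
```

In a real workflow, the items moved into `labeled` would be sent to an annotator, the model would be retrained, and the round would repeat until the labeling budget is spent.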

## The GraNd (Gradient Normed) Method

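GraNd is a sampling technique used in active learning. It assigns a relevance score to each sample based on the norm of the gradient of that sample's loss with respect to the model parameters (averaged over several models or checkpoints when more than one is available). Samples with a large gradient norm would change the model the most if trained on, so they are treated as the most informative. To keep the computation cheap, the method typically uses only the gradients of the parameters in the penultimate layer of the network (just before the final linear layer that yields the class logits).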
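
In symbols, and assuming the usual formulation of the score, this can be written as follows. The notation is introduced here for illustration: $f_\theta$ is the network, $\ell$ the loss, and $\theta_L$ the parameters of the chosen layer.

$$
\mathrm{GraNd}(x, y) = \mathbb{E}_{\theta}\left[ \left\lVert \nabla_{\theta_L}\, \ell\big(f_\theta(x),\, y\big) \right\rVert_2 \right]
$$

The expectation is taken over model weights, for example several initializations or checkpoints. With a single trained model, as in the code in this notebook, it reduces to one gradient norm per sample.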

The penultimate layer is chosen for several reasons:
1. It represents high-level features learned by the model.
2. Gradients at this layer are more informative about the model's learning process than the final layer.
3. The final layer often includes a softmax activation, which can mask the true uncertainty of the model.
4. Using the penultimate layer makes the method more generalizable across different model architectures.
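
To make the link between a layer's input features and its gradients concrete, the sketch below checks a standard identity for a linear classification layer trained with cross-entropy. The gradient with respect to the weight matrix is the outer product of `softmax(z) - onehot(y)` and the layer input `h`, so its norm is the product of the two vector norms. The layer sizes and tensor names are illustrative assumptions, not taken from the model used later.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_features, num_classes = 16, 5
layer = torch.nn.Linear(num_features, num_classes)

h = torch.randn(1, num_features)   # input features to the linear layer (one sample)
y = torch.tensor([2])              # class label for that sample

# Gradient norm via autograd, w.r.t. the weight matrix only.
loss = F.cross_entropy(layer(h), y)
(grad_w,) = torch.autograd.grad(loss, layer.weight)
autograd_norm = grad_w.norm(2)

# Closed form: dL/dW = (softmax(z) - onehot(y)) outer h,
# so the norm factorizes into a prediction-error term and a feature term.
with torch.no_grad():
    p = F.softmax(layer(h), dim=1)
    residual = p - F.one_hot(y, num_classes).float()
    closed_form_norm = residual.norm(2) * h.norm(2)
```

Under these assumptions, the score grows both with how wrong the prediction is and with the magnitude of the features feeding the layer.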

Let's implement the GraNd method using PyTorch:

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Example model definition (e.g., a simple CNN)
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.fc1 = nn.Linear(32 * 26 * 26, 10)  # Example dimensions

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax internally,
        # so an extra softmax here would distort the loss and its gradients.
        return self.fc1(x)

# Model and optimizer setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimpleCNN().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Sample data
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28).to(device)  # Dummy MNIST-like data
targets = torch.randint(0, 10, (batch_size,)).to(device)

# GraNd score computation function
def compute_grand_scores(model, data, targets):
    model.eval()
    scores = []
    
    for i in range(data.size(0)):
        sample = data[i].unsqueeze(0)  # Process as individual sample
        target = targets[i].unsqueeze(0)
        
        # Forward pass and loss computation
        prediction = model(sample)
        loss = loss_fn(prediction, target)
        
        # Compute gradients for a single layer only. fc1 is the only fully connected
        # layer in this toy model; in a deeper network, use the penultimate layer's weights.
        grads = torch.autograd.grad(loss, model.fc1.weight)[0]
        
        # Calculate gradient norm (e.g., L2 norm)
        grad_norm = grads.norm(2).item()
        scores.append(grad_norm)
    
    return scores

# Compute GraNd scores
grand_scores = compute_grand_scores(model, data, targets)

# Sort to select samples that could provide the most information
top_k_indices = sorted(range(len(grand_scores)), key=lambda i: grand_scores[i], reverse=True)[:10]
print("Indices of the most informative samples:", top_k_indices)


## Explanation of the Code

1. We define a simple CNN model with a convolutional layer and a fully connected layer.
2. The `compute_grand_scores` function calculates the GraNd score for each sample:
   - It computes the loss for each individual sample.
   - It calculates the gradients with respect to the weights of `fc1`, which stands in for the penultimate layer in this small example.
   - It computes the L2 norm of these gradients as the GraNd score.
3. We then sort the samples based on their GraNd scores and select the top k most informative samples.
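
In an active learning setting the candidate samples are unlabeled, so true targets like the ones used above are usually not available. One common workaround, shown here as an assumption rather than as part of the method described above, is to substitute the model's own predicted class for the missing label. The model `TinyNet` and the function `grand_scores_pseudo_label` are hypothetical names used only for this sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class TinyNet(nn.Module):
    """Hypothetical stand-in classifier: hidden layer followed by an output layer."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(28 * 28, 32)   # penultimate layer
        self.out = nn.Linear(32, 10)           # final layer producing logits

    def forward(self, x):
        return self.out(F.relu(self.hidden(x.flatten(1))))


def grand_scores_pseudo_label(model, data, layer):
    """Per-sample gradient norm w.r.t. `layer`'s parameters, using predicted labels."""
    model.eval()
    params = [p for p in layer.parameters() if p.requires_grad]
    scores = []
    for sample in data:
        logits = model(sample.unsqueeze(0))
        pseudo_label = logits.argmax(dim=1)          # stands in for the unknown label
        loss = F.cross_entropy(logits, pseudo_label)
        grads = torch.autograd.grad(loss, params)
        # Combine weight and bias gradients into a single L2 norm.
        scores.append(torch.cat([g.flatten() for g in grads]).norm(2).item())
    return scores


model = TinyNet()
pool = torch.randn(8, 1, 28, 28)                      # dummy unlabeled pool
scores = grand_scores_pseudo_label(model, pool, model.hidden)
ranking = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
```

With pseudo-labels the score is low when the model is confident in its prediction and higher when the predicted distribution is spread out, so it behaves more like an uncertainty measure than the label-based score.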

## Applying GraNd Sampling to nnU-Net

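The same idea carries over to segmentation models. Below is a basic sketch of GraNd for a trained nnU-Net (`nnunetv2`) model, using the voxel-wise cross-entropy as the loss: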

In [None]:
import torch
import torch.nn.functional as F
import nibabel as nib
import numpy as np
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# Load the nnU-Net model
predictor = nnUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=True,
    perform_everything_on_device=True,
    device=torch.device('cuda', 0),
    verbose=False,
    allow_tqdm=True
)

predictor.initialize_from_trained_model_folder(
    'path/to/your/model/folder',
    use_folds=(0,),
    checkpoint_name='checkpoint_final.pth'
)

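# Hook to capture penultimate layer activations
activations = {}
# Last decoder stage, i.e. the block that feeds the final segmentation head.
# NOTE: these attribute names assume nnU-Net's default PlainConvUNet architecture.
penultimate_layer = predictor.network.decoder.stages[-1]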

def get_activation(name):
    """Stores the output of the specified layer."""
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook

penultimate_layer.register_forward_hook(get_activation("penultimate"))

# Function to load and preprocess 3D medical images
def load_nifti_as_tensor(file_path, normalize=True):
    """Loads a NIfTI file and converts it to a PyTorch tensor of shape (1, 1, H, W, D).

    Set normalize=False for label maps so that class indices are preserved.
    """
    nifti_img = nib.load(file_path)
    image = nifti_img.get_fdata()

    if normalize:
        # Normalize intensities (zero mean, unit variance)
        image = (image - np.mean(image)) / (np.std(image) + 1e-8)

    # Convert to PyTorch tensor (B, C, H, W, D)
    return torch.tensor(image, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

# Load sample images and labels (volumes must share the same shape to be concatenated)
data_files = ["path/to/image1_0000.nii.gz", "path/to/image2_0000.nii.gz"]
data = torch.cat([load_nifti_as_tensor(f) for f in data_files])

target_files = ["path/to/image1.nii.gz", "path/to/image2.nii.gz"]
# Label maps: no normalization, drop the channel axis, integer class indices -> (B, H, W, D)
targets = torch.cat([load_nifti_as_tensor(f, normalize=False) for f in target_files]).squeeze(1).long()

# GraNd Score Computation
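def compute_grand_scores(model, data, targets):
    """Computes GraNd scores: L2 norm of the loss gradient w.r.t. the penultimate layer's parameters."""
    model.eval()
    device = next(model.parameters()).device
    # Parameters of the layer selected above (penultimate_layer)
    params = [p for p in penultimate_layer.parameters() if p.requires_grad]
    scores = []

    for i in range(data.size(0)):
        sample = data[i].unsqueeze(0).to(device)  # (1, C, H, W, D)
        target = targets[i].unsqueeze(0).to(device)  # (1, H, W, D)
        if target.dim() == sample.dim():
            target = target[:, 0]  # Drop a leftover channel axis

        # Forward pass with gradients enabled (the hook also stores the activation)
        logits = model(sample)
        if isinstance(logits, (list, tuple)):
            logits = logits[0]  # Deep supervision: keep the full-resolution output

        # Voxel-wise cross-entropy between predicted logits and the label map
        loss = F.cross_entropy(logits, target)

        # Gradient of the loss w.r.t. the selected parameters, reduced to one L2 norm
        grads = torch.autograd.grad(loss, params)
        grad_norm = torch.cat([g.flatten() for g in grads]).norm(2).item()

        scores.append(grad_norm)

    return scores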

# Compute GraNd scores
grand_scores = compute_grand_scores(predictor.network, data, targets)

# Select the top-k most informative samples
top_k_indices = sorted(range(len(grand_scores)), key=lambda i: grand_scores[i], reverse=True)[:10]
print("Indices of the most informative samples:", top_k_indices)


## Conclusion

The GraNd method provides a way to select the most informative samples for active learning. By measuring the norm of the loss gradient at a layer close to the output, such as the penultimate layer, it estimates how strongly each sample would change the model if it were used for training. Prioritizing high-scoring samples for annotation can reduce the amount of labeled data needed to reach a given level of performance.