**Adaptive pruning optimization and performance evaluation**

1. Install and import the necessary dependencies

In [None]:
!pip install fvcore


Collecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.6 (from fvcore)
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting iopath>=0.1.7 (from fvcore)
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting portalocker (from iopath>=0.1.7->fvcore)
  Downloading portalocker-3.0.0-py3-none-any.whl.metadata (8.5 kB)
Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Downloading portalocker-3.0.0-py3-none-any.whl (19 kB)
Building wheels for collected packages: fvcore, iopath
  Building wheel for fvcore (setup.py) ... [?25l[?25hdone
  Created wheel for fvcore: filename=fvcore-0.1.5.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from fvcore.nn import FlopCountAnalysis, flop_count_table
import torch.profiler as profiler
from torch.utils.data import DataLoader, TensorDataset
from copy import deepcopy


2. Token saliency computation module

In [None]:
class TokenSaliency(nn.Module):
    """
    Score each visual token by the magnitude of its embedding.

    Only the ``"norm"`` scoring method is implemented; any other value is
    rejected when ``forward`` is called (not at construction time).
    """
    def __init__(self, method="norm"):
        super().__init__()
        self.method = method

    def forward(self, tokens):
        """
        Args:
            tokens: Tensor of shape (B, N, D), where
                    B = Batch size,
                    N = Number of tokens,
                    D = Dimension of each token.
        Returns:
            Tensor of shape (B, N) holding one saliency score per token.
        Raises:
            ValueError: If ``self.method`` is anything other than "norm".
        """
        # Reject unknown scoring methods up front so the happy path stays flat.
        if self.method != "norm":
            raise ValueError(f"Unsupported method: {self.method}")

        # Norm taken over the embedding (last) dimension.
        return tokens.norm(dim=-1)


3. Adaptive Token pruning module

In [None]:
class AdaptiveTokenPruning(nn.Module):
    """
    Threshold-based token selector driven by per-token L2 norms.

    A token is kept when its saliency is strictly greater than
    ``saliency_threshold``.
    """
    def __init__(self, saliency_threshold=0.5):
        super().__init__()
        self.saliency_threshold = saliency_threshold

    def forward(self, x):
        """
        Build the pruning mask for ``x``.

        Returns:
            Boolean tensor with the last dimension of ``x`` removed; True marks
            tokens whose saliency exceeds the threshold.
        """
        return self.compute_saliency(x) > self.saliency_threshold

    def compute_saliency(self, x):
        """
        Return the L2 norm of ``x`` taken over its last (embedding) dimension.
        """
        return x.norm(p=2, dim=-1)



4. Pruned Transformer Encoder

In [None]:
class PrunedTransformerEncoder(nn.Module):
    """
    Stack of encoder layers that drops low-saliency tokens before each layer.

    Before every layer the per-token saliency is recomputed, tokens at or
    below the threshold are removed per sample, and the ragged result is
    zero-padded back into a dense batch.
    """
    def __init__(self, encoder_layer, num_layers, saliency_threshold=0.5):
        super().__init__()
        # Independent copies so the layers do not share parameters.
        layer_copies = [deepcopy(encoder_layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(layer_copies)
        self.token_pruning = AdaptiveTokenPruning(saliency_threshold=saliency_threshold)

    def forward(self, src):
        """
        Run all layers, pruning tokens ahead of each one.

        Args:
            src: Input tensor of shape (batch_size, seq_len, d_model).
        Returns:
            Tensor of shape (batch_size, max_kept, d_model), where max_kept is
            the largest per-sample number of surviving tokens; shorter samples
            are zero-padded before the last layer is applied.
        """
        batch_size, seq_len, _ = src.shape
        # Start with every token marked as active.
        keep_tokens = torch.ones((batch_size, seq_len), device=src.device).bool()

        for layer_idx, layer in enumerate(self.layers):
            # A token survives only if it was active before AND is salient now;
            # padding positions carry False and therefore never come back.
            scores = self.token_pruning.compute_saliency(src)
            keep_tokens = keep_tokens & (scores > self.token_pruning.saliency_threshold)

            # Per-sample indices of surviving tokens (ragged across the batch).
            kept_index_lists = [
                keep_tokens[b].nonzero(as_tuple=True)[0] for b in range(batch_size)
            ]
            kept_rows = [src[b, idx] for b, idx in enumerate(kept_index_lists)]
            kept_flags = [keep_tokens[b, idx] for b, idx in enumerate(kept_index_lists)]

            # Re-pack into dense tensors; padding is zeros for src, False for the mask.
            src = torch.nn.utils.rnn.pad_sequence(kept_rows, batch_first=True)
            keep_tokens = torch.nn.utils.rnn.pad_sequence(kept_flags, batch_first=True)

            # Debug trace of how many tokens each sample retained.
            print(f"Layer {layer_idx}: Active tokens per batch = {[len(row) for row in kept_rows]}")

            # NOTE(review): no padding mask is handed to the layer, so the
            # zero-padded positions are processed like real tokens.
            src = layer(src)

        return src



5. FLOPs evaluation tool

In [None]:
from fvcore.nn import FlopCountAnalysis, flop_count_table

def calculate_dynamic_flops(model, x, keep_tokens):
    """
    Print a FLOP table for ``model`` on an input shortened to the active length.

    Args:
        model: The pruned Transformer model.
        x: Input tensor of shape (batch_size, seq_len, d_model).
        keep_tokens: Boolean tensor indicating active tokens for the pruned model.

    Note:
        The input is cut to its first K positions, where K is the largest
        per-sample count of active tokens; the kept positions themselves are
        not gathered.
    """
    # Largest number of active tokens found in any sample of the batch.
    per_sample_counts = keep_tokens.sum(dim=1)
    max_active = per_sample_counts.max().item()

    # Shorten the sequence axis to that length before counting.
    cropped = x[:, :max_active, :]
    analysis = FlopCountAnalysis(model, cropped)
    print(flop_count_table(analysis))




6. Memory usage evaluation tool

In [None]:
def profile_memory_and_time_safe(model, input_tensor):
    """
    Profile one forward pass and print the ten most expensive operators.

    CUDA activity is only requested when a CUDA device is available; on
    CPU-only hosts the table is sorted by CPU time instead, because the CUDA
    time column would be empty there.

    Args:
        model: PyTorch model to profile.
        input_tensor: Tensor input to pass through the model.

    Returns:
        None. The profiler table is printed. A RuntimeError raised while
        profiling is caught and reported on stdout rather than propagated.
    """
    activities = [torch.profiler.ProfilerActivity.CPU]
    sort_key = "cpu_time_total"
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
        sort_key = "cuda_time_total"

    try:
        with torch.profiler.profile(
            activities=activities,
            record_shapes=True,
            profile_memory=True,
            with_stack=False,  # Disable stack tracing to reduce possible conflicts
        ) as prof:
            model(input_tensor)  # Perform model forward propagation
        print(prof.key_averages().table(sort_by=sort_key, row_limit=10))
    except RuntimeError as e:
        # Best-effort helper: report the failure instead of aborting the notebook.
        print(f"Profiler failed: {e}")


7. Model accuracy evaluation

In [None]:
def evaluate_model_accuracy(model, train_data, train_labels, test_data, test_labels):
    """
    Train a model for five epochs on a toy dataset and report test accuracy.

    The model output is averaged over dimension 1 (the token axis) before the
    loss and the arg-max are computed. Batches are moved to the device the
    model's parameters live on, so the function works for CPU and GPU models.

    Args:
        model: PyTorch model to train and evaluate (modified in place; left in
            eval mode on return).
        train_data, train_labels, test_data, test_labels: Dataset tensors.
            Labels may carry extra singleton dimensions; they are flattened
            to 1D.

    Returns:
        Test accuracy as a percentage (0.0 when the test set is empty). The
        value is also printed.
    """
    # Flatten labels to 1D. Unlike squeeze(), reshape(-1) keeps a length-1
    # dimension when there is only a single sample.
    train_labels = train_labels.reshape(-1)
    test_labels = test_labels.reshape(-1)

    # Follow the model's placement instead of assuming a GPU.
    device = next(model.parameters()).device

    model.train()

    # Dataset and DataLoader
    train_loader = DataLoader(TensorDataset(train_data, train_labels), batch_size=16, shuffle=True)
    test_loader = DataLoader(TensorDataset(test_data, test_labels), batch_size=16)

    # Optimizer and Loss
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    # Training loop
    for _ in range(5):
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs).mean(dim=1)  # pool per-token outputs
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

    # Evaluation loop
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs).mean(dim=1)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty test set instead of dividing by zero.
    accuracy = 100 * correct / total if total else 0.0
    print(f"Accuracy: {accuracy:.2f}%")
    return accuracy


In [None]:
from fvcore.nn import FlopCountAnalysis, flop_count_table

def evaluate_pruned_model(baseline_model, pruned_model, test_data):
    """
    Print FLOP tables for the baseline and the pruned model on the same input.

    Args:
        baseline_model: The baseline Transformer model.
        pruned_model: The pruned Transformer model.
        test_data: Sample input tensor for efficiency evaluation.

    Both models are analysed on the full ``test_data``. The pruned model
    removes tokens inside its own forward pass, so cropping the input first
    would prune a second time and under-report its cost.
    """
    print("=== Baseline Model Efficiency ===")
    flops_baseline = FlopCountAnalysis(baseline_model, test_data)
    print(flop_count_table(flops_baseline))

    print("\n=== Pruned Model Efficiency ===")
    with torch.no_grad():
        # NOTE(review): relies on the analysis following the shapes actually
        # produced during the pruned forward pass — confirm with fvcore docs.
        flops_pruned = FlopCountAnalysis(pruned_model, test_data)
        print(flop_count_table(flops_pruned))


8. Prepare the dataset

In [None]:
def prepare_data(num_train=1000, num_test=200, seq_len=128, d_model=512,
                 num_classes=2, device=None):
    """
    Prepare simulated toy dataset for training and testing.

    Args:
        num_train: Number of training samples.
        num_test: Number of test samples.
        seq_len: Tokens per sample.
        d_model: Feature dimension of each token.
        num_classes: Labels are drawn uniformly from ``[0, num_classes)``.
        device: Target device for the tensors. When None, CUDA is used if
            available and the CPU otherwise.

    Returns:
        train_data, train_labels, test_data, test_labels
        Data tensors have shape (samples, seq_len, d_model) with values in
        [0, 1); label tensors are 1D ``torch.long``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Sampling order (train data, train labels, test labels, test data) is kept
    # unchanged so seeded runs produce the same tensors as before.
    train_data = torch.rand(num_train, seq_len, d_model).to(device)
    train_labels = torch.randint(0, num_classes, (num_train,), dtype=torch.long).to(device)
    test_labels = torch.randint(0, num_classes, (num_test,), dtype=torch.long).to(device)
    test_data = torch.rand(num_test, seq_len, d_model).to(device)
    return train_data, train_labels, test_data, test_labels


In [None]:
# Build the toy dataset once and print the training tensor shapes as a sanity check.
train_data, train_labels, test_data, test_labels = prepare_data()
print(train_data.shape, train_labels.shape)


torch.Size([1000, 128, 512]) torch.Size([1000])


9. Define the model

In [None]:
# Keep the SimpleClassifierHead class
class SimpleClassifierHead(nn.Module):
    """
    Single linear projection mapping transformer features to class scores.
    """
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Apply the linear layer over the last dimension of ``x``."""
        logits = self.fc(x)
        return logits


# TransformerEncoderLayerWithPruning class
class TransformerEncoderLayerWithPruning(nn.TransformerEncoderLayer):
    """
    A customized TransformerEncoderLayer that supports dynamic token skipping.

    When ``keep_tokens`` is supplied to ``forward``, the kept tokens of every
    sample are packed to the front of a zero-padded tensor before the standard
    encoder-layer computation runs.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, keep_tokens=None):
        """
        Args:
            src: Input tensor of shape (batch_size, seq_len, d_model). The
                pruning step indexes dimension 0 as the batch axis.
            src_mask: Forwarded unchanged to the parent layer.
            src_key_padding_mask: Forwarded unchanged to the parent layer.
            keep_tokens: Optional boolean tensor of shape (batch_size, seq_len);
                True marks tokens to keep. When None, no pruning is done.

        Returns:
            Output of the parent layer. With ``keep_tokens`` the sequence
            length equals the largest per-sample number of kept tokens.

        Note:
            The masks are not re-indexed after pruning, so callers that pass
            them together with ``keep_tokens`` must size them for the pruned
            length.
        """
        if keep_tokens is not None:
            # Dynamically crop the tensor shape, keeping only tokens marked True
            batch_size, _, d_model = src.shape
            max_active_tokens = keep_tokens.sum(dim=1).max().item()  # Maximum number of active tokens

            # Match the input's dtype as well as its device; a default float32
            # buffer would clash with half/double model weights downstream.
            pruned_src = torch.zeros(
                batch_size, max_active_tokens, d_model,
                dtype=src.dtype, device=src.device,
            )

            for batch_idx in range(batch_size):
                active_token_indices = keep_tokens[batch_idx].nonzero(as_tuple=True)[0]
                # Kept tokens go to the front; the tail stays zero padding.
                pruned_src[batch_idx, :len(active_token_indices)] = src[batch_idx, active_token_indices]

            src = pruned_src  # Update to the clipped tensor

        # Hand the (possibly pruned) tensor to the parent implementation.
        return super().forward(src, src_mask, src_key_padding_mask)





# create_models function
def create_models():
    """
    Create baseline and pruned Transformer models, each with a classification head.

    Both encoders use ``batch_first=True`` because the inputs built in this
    file are laid out as (batch, seq_len, d_model). With PyTorch's default
    (``batch_first=False``) dimension 0 would be treated as the sequence axis
    and attention would mix different samples instead of tokens.

    Returns:
        baseline_model, pruned_model (both ``nn.Sequential`` of encoder + head,
        placed on the CUDA device)
    """
    num_classes = 2  # binary classification

    # Baseline model
    baseline_encoder = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
    baseline_transformer = nn.TransformerEncoder(baseline_encoder, num_layers=2).cuda()
    baseline_model = nn.Sequential(
        baseline_transformer,
        SimpleClassifierHead(input_dim=512, num_classes=num_classes).cuda()
    )

    # Pruned model
    pruned_encoder = TransformerEncoderLayerWithPruning(d_model=512, nhead=8, batch_first=True)
    pruned_transformer = PrunedTransformerEncoder(pruned_encoder, num_layers=2, saliency_threshold=13.0).cuda()
    pruned_model = nn.Sequential(
        pruned_transformer,
        SimpleClassifierHead(input_dim=512, num_classes=num_classes).cuda()
    )

    return baseline_model, pruned_model



In [None]:
# Instantiate both models and print their module trees for inspection.
baseline_model, pruned_model = create_models()
print(baseline_model)
print(pruned_model)


Sequential(
  (0): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (1): SimpleClassifierHead(
    (fc): Linear(in_features=512, out_features=2, bias=True)
  )
)
Sequential(
  (0): PrunedTransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayerWithPruning(
        (self_attn): MultiheadAttention(
          (out_proj): NonD

In [None]:
# Random sample batch of shape (16, 128, 512), moved to the GPU.
input_tensor = torch.rand(16, 128, 512).cuda()


In [None]:
# Run only the first stage of the pruned model (index 0 of the Sequential)
# and report the shape that comes out of it.
pruned_outputs = pruned_model[0](input_tensor)
print(f"Output shape after pruning: {pruned_outputs.shape}")


Layer 0: Active tokens per batch = [79, 80, 65, 79, 86, 76, 80, 70, 82, 79, 68, 71, 73, 78, 83, 69]
Layer 1: Active tokens per batch = [79, 80, 65, 79, 86, 76, 80, 70, 82, 79, 68, 71, 73, 78, 83, 69]
Output shape after pruning: torch.Size([16, 86, 512])


In [None]:
# Inspect the spread of saliency scores on the raw input; useful for choosing
# a threshold that actually separates tokens.
pruner = pruned_model[0].token_pruning
saliency_scores = pruner.compute_saliency(input_tensor)
lowest, highest = saliency_scores.min().item(), saliency_scores.max().item()
print(f"Saliency scores range: {lowest} - {highest}")


Saliency scores range: 12.195259094238281 - 13.956788063049316


In [None]:
# Recreate the keep mask the pruner would compute for the raw input and show
# it for the first sample of the batch.
token_pruner = pruned_model[0].token_pruning
keep_tokens = token_pruner.compute_saliency(input_tensor) > token_pruner.saliency_threshold
first_sample_mask = keep_tokens[0].cpu().numpy()
print(f"Keep tokens mask (sample batch): {first_sample_mask}")


Keep tokens mask (sample batch): [False  True False  True False False  True  True  True  True  True False
  True  True  True  True False  True  True  True False False  True  True
 False False False  True  True  True  True False  True  True False  True
 False  True False  True  True  True False False  True False  True  True
 False False  True  True  True False  True False  True False False False
  True  True  True False  True  True  True False  True  True  True  True
  True  True  True  True False  True  True False  True False  True False
 False False  True False False  True False  True False  True  True False
  True  True  True  True False  True  True  True False False False  True
  True  True  True False  True  True  True  True  True False False False
  True  True False  True False False  True  True]


10. Evaluate model accuracy

In [None]:
def compare_models_accuracy(baseline_model, pruned_model, train_data, train_labels, test_data, test_labels):
    """
    Train and score the baseline model, then the pruned model, on the same data.

    Each model is announced with a banner line and handed to
    ``evaluate_model_accuracy`` with identical train/test tensors.
    """
    runs = (
        ("\n=== Baseline Model Accuracy ===", baseline_model),
        ("\n=== Pruned Model Accuracy ===", pruned_model),
    )
    for banner, candidate in runs:
        print(banner)
        evaluate_model_accuracy(candidate, train_data, train_labels, test_data, test_labels)


In [None]:
# Train and evaluate both models on the same toy data.
compare_models_accuracy(baseline_model, pruned_model, train_data, train_labels, test_data, test_labels)



=== Baseline Model Accuracy ===
Accuracy: 47.50%

=== Pruned Model Accuracy ===
Layer 0: Active tokens per batch = [71, 84, 77, 70, 71, 82, 64, 80, 75, 74, 71, 76, 80, 77, 75, 82]
Layer 1: Active tokens per batch = [71, 84, 77, 70, 71, 82, 64, 80, 75, 74, 71, 76, 80, 77, 75, 82]
Layer 0: Active tokens per batch = [82, 75, 76, 71, 71, 75, 79, 76, 76, 77, 80, 79, 78, 71, 91, 80]
Layer 1: Active tokens per batch = [82, 75, 76, 71, 71, 75, 79, 76, 76, 77, 80, 79, 78, 71, 91, 80]
Layer 0: Active tokens per batch = [78, 77, 69, 72, 71, 70, 83, 74, 81, 77, 67, 73, 75, 71, 78, 73]
Layer 1: Active tokens per batch = [78, 77, 69, 72, 71, 70, 83, 74, 81, 77, 67, 73, 75, 71, 78, 73]
Layer 0: Active tokens per batch = [70, 68, 78, 78, 82, 88, 75, 69, 79, 74, 79, 70, 71, 84, 79, 67]
Layer 1: Active tokens per batch = [70, 68, 78, 78, 82, 88, 75, 69, 79, 74, 79, 70, 71, 84, 79, 67]
Layer 0: Active tokens per batch = [65, 84, 70, 79, 75, 76, 79, 72, 75, 74, 78, 81, 62, 82, 74, 75]
Layer 1: Active tok

11. FLOPs and memory performance

In [None]:
from torch.profiler import profile, ProfilerActivity
from fvcore.nn import FlopCountAnalysis, flop_count_table

def calculate_dynamic_flops_and_profile(pruned_model, input_tensor):
    """
    Calculate dynamic FLOPs and memory usage for the pruned model.

    Profiling is delegated to ``profile_memory_and_time_safe`` so the pruned
    model is measured with the same profiler settings (memory tracking, shape
    recording, top-10 table) that ``compare_efficiency`` uses for the baseline.

    Args:
        pruned_model: Model with dynamic token pruning.
        input_tensor: Input tensor.
    """
    # Dynamic computing FLOPs
    print("\n=== Pruned Model ===")
    flops_pruned = FlopCountAnalysis(pruned_model, input_tensor)
    print(flop_count_table(flops_pruned))

    # Dynamic profile memory and time, using the shared helper so baseline and
    # pruned numbers are directly comparable.
    profile_memory_and_time_safe(pruned_model, input_tensor)

def compare_efficiency(baseline_model, pruned_model):
    """
    Print FLOP tables and profiler output for both models on one random batch.

    Args:
        baseline_model, pruned_model: Models to compare.
    """
    # Simulated input: batch size=16, tokens=128, dim=512
    batch, tokens, dim = 16, 128, 512
    input_tensor = torch.rand(batch, tokens, dim).cuda()

    # Baseline: FLOP table first, then the memory/time profile.
    print("\n=== Baseline Model ===")
    baseline_analysis = FlopCountAnalysis(baseline_model, input_tensor)
    print(flop_count_table(baseline_analysis))
    profile_memory_and_time_safe(baseline_model, input_tensor)

    # Pruned model: same batch, measured by the dedicated helper.
    calculate_dynamic_flops_and_profile(pruned_model, input_tensor)



In [None]:
# Print FLOP tables and profiler output for both models.
compare_efficiency(baseline_model, pruned_model)



=== Baseline Model ===
| module                  | #parameters or shape   | #flops    |
|:------------------------|:-----------------------|:----------|
| model                   | 6.306M                 | 12.908G   |
|  0.layers               |  6.305M                |  12.906G  |
|   0.layers.0            |   3.152M               |   6.453G  |
|    0.layers.0.self_attn |    1.051M              |    2.147G |
|    0.layers.0.linear1   |    1.051M              |    2.147G |
|    0.layers.0.linear2   |    1.049M              |    2.147G |
|    0.layers.0.norm1     |    1.024K              |    5.243M |
|    0.layers.0.norm2     |    1.024K              |    5.243M |
|   0.layers.1            |   3.152M               |   6.453G  |
|    0.layers.1.self_attn |    1.051M              |    2.147G |
|    0.layers.1.linear1   |    1.051M              |    2.147G |
|    0.layers.1.linear2   |    1.049M              |    2.147G |
|    0.layers.1.norm1     |    1.024K              |    5.243M |
|