In [1]:
import torch
import numpy as np
from scipy.stats import spearmanr
import torch.nn as nn
import torch.optim as optim

In [2]:
import math

class GradientMaskOptimizer:
    """
    Learn a sparse soft mask over gradient coordinates.

    The mask is ``sigmoid(S)`` for a learnable logit vector ``S``. Training
    maximizes, per test gradient, the Pearson correlation between the original
    test/train inner products and the inner products of the masked gradients,
    while an L1 penalty on the mask pushes entries toward zero. An entry is
    counted as "active" when its mask value exceeds 0.5.
    """

    def __init__(self, dimension, lambda_reg=0.001, lr=0.01, min_active_params=10, max_active_params=None,
                 device='cuda:3' if torch.cuda.is_available() else 'cpu'):
        """
        Initialize the gradient mask optimizer.

        Args:
            dimension: Dimensionality of the model parameters (d)
            lambda_reg: Regularization parameter for sparsity
            lr: Learning rate for optimizer
            min_active_params: Minimum number of parameters that must remain active
            max_active_params: Maximum number of parameters allowed (default: 2*min_active_params)
            device: Device to run computations on
        """
        self.device = device
        self.lambda_reg = lambda_reg
        self.min_active_params = min_active_params
        self.max_active_params = max_active_params if max_active_params is not None else 2 * min_active_params
        self.dimension = dimension

        # Initialize S with small random values so sigmoid(S) starts near 0.5
        self.S = nn.Parameter(torch.randn(dimension, device=device) * 0.01)

        # Use Adam optimizer
        self.optimizer = optim.Adam([self.S], lr=lr)

    def sigmoid(self):
        """Return sigmoid(S) as the mask"""
        return torch.sigmoid(self.S)

    def compute_inner_products(self, test_grads, train_grads, apply_mask=False):
        """
        Compute inner products between test and training gradients, optionally with mask.

        Args:
            test_grads: Test gradients tensor of shape [n_test, dimension]
            train_grads: Training gradients tensor of shape [n_train, dimension]
            apply_mask: Whether to apply mask to gradients before computing inner products

        Returns:
            Tensor of inner products of shape [n_test, n_train]
        """
        if apply_mask:
            mask = self.sigmoid()
            # The mask multiplies both sides, so each coordinate is weighted by mask**2
            test_grads = test_grads * mask
            train_grads = train_grads * mask

        # (n_test, dimension) x (dimension, n_train) -> (n_test, n_train)
        return torch.matmul(test_grads, train_grads.T)

    def correlation_loss(self, original_ips, masked_ips):
        """
        Compute average correlation across all test gradients.
        We maximize correlation by minimizing the negative correlation.

        Args:
            original_ips: Tensor of original inner products of shape [n_test, n_train]
            masked_ips: Tensor of masked inner products of shape [n_test, n_train]

        Returns:
            Average negative correlation across all test gradients
        """
        # Mean center both sets of inner products along train dimension
        orig_centered = original_ips - original_ips.mean(dim=1, keepdim=True)
        masked_centered = masked_ips - masked_ips.mean(dim=1, keepdim=True)

        # Pearson correlation per test gradient; 1e-8 guards against a zero denominator
        numerator = torch.sum(orig_centered * masked_centered, dim=1)
        denominator = torch.sqrt(torch.sum(orig_centered**2, dim=1) * torch.sum(masked_centered**2, dim=1) + 1e-8)
        correlations = numerator / denominator

        # Return negative mean correlation (to minimize)
        return -correlations.mean()

    def sparsity_loss(self):
        """
        L1 regularization to encourage sparsity in the mask,
        with a safeguard to ensure a minimum number of active parameters

        Returns:
            Scalar tensor: effective lambda times the sum of mask values
        """
        mask = self.sigmoid()

        # Count active parameters (above threshold)
        active_params = torch.sum(mask > 0.5).item()

        # If we're approaching minimum parameters, reduce regularization
        if active_params <= self.min_active_params * 2:
            # Scale lambda down linearly with the active count, but never below 10%
            reduction_factor = max(0.1, active_params / (self.min_active_params * 2))
            effective_lambda = self.lambda_reg * reduction_factor
        else:
            effective_lambda = self.lambda_reg

        return effective_lambda * torch.sum(mask)

    def train_step(self, test_grads, train_grads):
        """
        Perform one optimization step

        Args:
            test_grads: Test gradients tensor of shape [n_test, dimension]
            train_grads: Training gradients tensor of shape [n_train, dimension]

        Returns:
            Dict with 'total_loss', 'correlation_loss' and 'sparsity_loss' (all
            measured before the parameter update) and 'mask_sparsity' (fraction of
            mask entries below 0.5, measured after the update)
        """
        self.optimizer.zero_grad()

        # Compute original and masked inner products
        original_ips = self.compute_inner_products(test_grads, train_grads, apply_mask=False)
        masked_ips = self.compute_inner_products(test_grads, train_grads, apply_mask=True)

        # Compute correlation loss (negative correlation to maximize it)
        corr_loss = self.correlation_loss(original_ips, masked_ips)

        # Compute sparsity loss
        sparse_loss = self.sparsity_loss()

        # Total loss: minimize -correlation + sparsity
        total_loss = corr_loss + sparse_loss

        # Compute gradients and update parameters
        total_loss.backward()
        self.optimizer.step()

        return {
            'total_loss': total_loss.item(),
            'correlation_loss': corr_loss.item(),
            'sparsity_loss': sparse_loss.item(),
            'mask_sparsity': (self.sigmoid() < 0.5).float().mean().item()
        }

    def _as_gradient_matrix(self, grads):
        """Return grads as a 2-D tensor on self.device (stacks sequences, lifts 1-D tensors)."""
        if isinstance(grads, (list, tuple)):
            grads = torch.stack(list(grads))
        elif isinstance(grads, torch.Tensor) and grads.ndim == 1:
            grads = grads.unsqueeze(0)
        return grads.to(self.device)

    def _count_active(self):
        """Return how many mask entries currently exceed 0.5."""
        with torch.no_grad():
            return int(torch.sum(self.sigmoid() > 0.5).item())

    def _force_min_active(self, k):
        """Set the k largest logits to 5.0 and rebuild Adam around the new parameter."""
        with torch.no_grad():
            _, top_k_indices = torch.topk(self.sigmoid(), k=k)
            new_S_data = self.S.data.clone()
            # 5.0 is large enough that sigmoid is close to 1
            new_S_data[top_k_indices] = 5.0

        lr = self.optimizer.param_groups[0]['lr']
        # A fresh optimizer discards the Adam moments accumulated so far
        self.S = nn.Parameter(new_S_data)
        self.optimizer = optim.Adam([self.S], lr=lr)

    def _restore_logits(self, logits):
        """Copy saved logits into S in place so self.optimizer stays bound to S."""
        with torch.no_grad():
            self.S.copy_(logits)

    def train(self, train_grads, test_grads, num_epochs=500, log_every=50, correlation_threshold=0.7):
        """
        Train the mask optimizer for multiple epochs

        Args:
            train_grads: Training gradients tensor of shape [n_train, dimension] or list of gradients
            test_grads: Test gradients tensor of shape [n_test, dimension] or list of gradients
            num_epochs: Maximum number of training epochs
            log_every: Log progress every N epochs
            correlation_threshold: Stop training when correlation drops below this value

        Returns:
            The dict produced by evaluate_mask for the mask selected at the end
        """
        train_grads = self._as_gradient_matrix(train_grads)
        test_grads = self._as_gradient_matrix(test_grads)

        # A mask cannot have more active entries than it has entries; clamping keeps
        # torch.topk from raising when min_active_params exceeds the dimension.
        min_active = min(self.min_active_params, self.S.numel())

        best_correlation = -float('inf')
        best_mask = None

        # Track masks that meet sparsity constraints
        candidate_masks = []

        for epoch in range(num_epochs):
            metrics = self.train_step(test_grads, train_grads)

            # Safety check: never let the mask collapse below the minimum size
            if self._count_active() < min_active:
                self._force_min_active(min_active)
                print(f"  Warning: Forced {min_active} parameters to be active")

            if epoch % log_every != 0 and epoch != num_epochs - 1:
                continue

            # Evaluate and log progress
            print(f"Epoch {epoch}:")
            print(f"  Total Loss: {metrics['total_loss']:.4f}")
            print(f"  Correlation Loss: {metrics['correlation_loss']:.4f} (negative, lower is better)")
            print(f"  Sparsity Loss: {metrics['sparsity_loss']:.4f}")
            print(f"  Mask Sparsity: {metrics['mask_sparsity']:.4f}")

            # Evaluate current mask performance
            eval_metrics = self.evaluate_mask(test_grads, train_grads)
            print()

            # Score the logits that are about to be snapshotted: `metrics` describes
            # the mask before this epoch's update (and before any forced reset), so
            # the count and correlation are taken from the fresh evaluation instead.
            active_count = eval_metrics['active_count']
            correlation_value = eval_metrics['pearson_correlation']
            avg_rank_correlation = eval_metrics.get('avg_rank_correlation', float('nan'))

            # Add to candidates if within the min and max parameter count constraints
            if min_active <= active_count <= self.max_active_params:
                snapshot = self.S.data.clone()
                candidate_masks.append({
                    'correlation': correlation_value,
                    'active_count': active_count,
                    'mask': snapshot,
                    'epoch': epoch,
                    'avg_rank_correlation': avg_rank_correlation
                })
                print(f"  Added to candidate masks (active params: {active_count}, correlation: {correlation_value:.4f})")

                # Store best mask - highest correlation among masks meeting the constraint
                if correlation_value > best_correlation:
                    best_correlation = correlation_value
                    best_mask = snapshot
                    print(f"  New best mask! (active params: {active_count}, correlation: {correlation_value:.4f})")

            # Early stopping based on correlation threshold
            if not math.isnan(avg_rank_correlation) and avg_rank_correlation < correlation_threshold:
                print(f"Early stopping at epoch {epoch} - correlation {avg_rank_correlation:.4f} below threshold {correlation_threshold:.4f}")
                break

        # Select the best mask based on correlation if we found one that meets constraints
        if best_mask is not None:
            print("Using best mask found during training")
            self._restore_logits(best_mask)
        elif len(candidate_masks) > 0:
            # Find mask with highest correlation among candidates
            best_candidate = max(candidate_masks, key=lambda x: x['correlation'])
            print(f"Using best candidate mask from epoch {best_candidate['epoch']} "
                  f"(active params: {best_candidate['active_count']}, "
                  f"correlation: {best_candidate['correlation']:.4f})")
            self._restore_logits(best_candidate['mask'])
        else:
            print("Warning: No mask met the sparsity constraints. Using final mask.")

        # Final evaluation
        print("\nFinal mask evaluation:")
        eval_metrics = self.evaluate_mask(test_grads, train_grads)

        # Report if we met the correlation threshold
        if 'avg_rank_correlation' in eval_metrics and not math.isnan(eval_metrics['avg_rank_correlation']):
            print(f"Final correlation: {eval_metrics['avg_rank_correlation']:.4f}")
            if eval_metrics['avg_rank_correlation'] >= correlation_threshold:
                print(f"✓ Successfully maintained correlation above threshold ({correlation_threshold:.4f})")
            else:
                print(f"✗ Final correlation below threshold ({correlation_threshold:.4f})")

        return eval_metrics

    def evaluate_mask(self, test_grads, train_grads):
        """
        Evaluate the current mask's performance in preserving rankings

        Args:
            test_grads: Test gradients tensor of shape [n_test, dimension]
            train_grads: Training gradients tensor of shape [n_train, dimension]

        Returns:
            Dict with 'avg_rank_correlation' (mean Spearman correlation over test
            gradients, NaN when no entry is active), 'active_count',
            'percent_active' and 'pearson_correlation' (mean Pearson correlation
            of the current soft mask, i.e. the negated correlation_loss)
        """
        with torch.no_grad():
            # Get mask stats
            mask = self.sigmoid()
            active_mask = (mask > 0.5).float()
            active_count = active_mask.sum().int().item()
            percent_active = active_mask.mean().item() * 100

            # Compute original and masked inner products
            original_ips = self.compute_inner_products(test_grads, train_grads, apply_mask=False)
            masked_ips = self.compute_inner_products(test_grads, train_grads, apply_mask=True)

            # Same quantity the training objective optimizes, for the current logits
            pearson_correlation = -self.correlation_loss(original_ips, masked_ips).item()

            # Convert to CPU for Spearman calculation
            original_ips_np = original_ips.cpu().numpy()
            masked_ips_np = masked_ips.cpu().numpy()

            # Compute Spearman rank correlation for each test gradient
            all_correlations = []
            if active_count > 0:
                for original_row, masked_row in zip(original_ips_np, masked_ips_np):
                    rank_corr, _ = spearmanr(original_row, masked_row)
                    all_correlations.append(rank_corr)

            # Compute average correlation
            avg_correlation = float('nan')
            if all_correlations:
                avg_correlation = sum(all_correlations) / len(all_correlations)
                print(f"  Average Spearman Rank Correlation: {avg_correlation:.4f}")

            print(f"  Mask: {active_count}/{mask.numel()} parameters active ({percent_active:.2f}%)")

            return {
                'avg_rank_correlation': avg_correlation,
                'active_count': active_count,
                'percent_active': percent_active,
                'pearson_correlation': pearson_correlation
            }

    def get_important_indices(self, threshold=0.5, min_count=None):
        """
        Get indices of important parameters based on mask value

        Args:
            threshold: Value threshold for selecting parameters
            min_count: Minimum number of parameters to select (uses top-k if needed)

        Returns:
            1-D integer tensor of selected indices. When the top-k fallback is used,
            at most mask.numel() indices are returned.
        """
        with torch.no_grad():
            mask = self.sigmoid()
            indices = torch.where(mask > threshold)[0]

            # If we got fewer than min_count parameters, take the top-k instead
            if min_count is not None and len(indices) < min_count:
                # topk raises if k exceeds the number of entries, so clamp it
                k = min(min_count, mask.numel())
                print(f"Warning: Only {len(indices)} parameters above threshold. Selecting top-{k} instead.")
                _, indices = torch.topk(mask, k=k)

        return indices

In [3]:
import os
import sys
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from GradComp.GPT2LMHeadModel import GCGPT2LMHeadModel
from _GradComp.layers.linear import GCLinear

# Load the checkpointed model onto the first GPU
checkpoint = "./checkpoints/wd=0.0_lr=5e-5/0"
model = GCGPT2LMHeadModel.from_pretrained(checkpoint).cuda("cuda:0")

# Collect the qualified name of every GCLinear module in the model
layers = []
for module_name, module in model.named_modules():
    if isinstance(module, GCLinear):
        layers.append(module_name)

# Drop the last GCLinear module.
# NOTE(review): presumably this is the LM head, excluded from the analysis — confirm.
layers = layers[:-1]

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Hyperparameters for the per-layer mask search
LAMBDA_REG = 5
LEARNING_RATE = 1e-3
MIN_PARAMS = 4096
MAX_PARAMS = 5000
CORRELATION_THRESHOLD = 0.6
MAX_EPOCHS = 2000
LOG_INTERVAL = 100

# Directory holding the precomputed per-layer gradient files
GRAD_DIR = '/scratch/pbb/Project/Sparse-Influence'

# Define output directory
output_dir = f'./localize/mask_{MIN_PARAMS}'
os.makedirs(output_dir, exist_ok=True)

for i, layer in enumerate(layers):
    print(f"\n=== Processing Layer {layer} ({i + 1}/{len(layers)}) ===")

    # Load gradients
    train_grads_layer = torch.load(f'{GRAD_DIR}/train_grad_{layer}.pt')
    test_grads_layer = torch.load(f'{GRAD_DIR}/test_grad_{layer}.pt')

    # Get dimensions
    dimension = train_grads_layer[0].shape[0]
    num_train = len(train_grads_layer)
    num_test = len(test_grads_layer)

    print(f"Dimension: {dimension}, Training samples: {num_train}, Test samples: {num_test}")

    # Initialize optimizer
    optimizer = GradientMaskOptimizer(
        dimension=dimension,
        lambda_reg=LAMBDA_REG,
        lr=LEARNING_RATE,
        min_active_params=MIN_PARAMS,
        max_active_params=MAX_PARAMS
    )

    # Train the optimizer
    eval_metrics = optimizer.train(
        train_grads=train_grads_layer,
        test_grads=test_grads_layer,
        num_epochs=MAX_EPOCHS,
        log_every=LOG_INTERVAL,
        correlation_threshold=CORRELATION_THRESHOLD
    )

    # Get and save important indices
    important_indices = optimizer.get_important_indices(threshold=0.5)
    output_file = os.path.join(output_dir, f'{layer}.pt')
    torch.save(important_indices, output_file)

    sparsity = len(important_indices) / dimension * 100
    print(f"Found {len(important_indices)} important parameters ({sparsity:.2f}% of {dimension})")

    # Vectorized evaluation of reduced parameters on all test gradients
    if isinstance(test_grads_layer, list):
        test_grads_tensor = torch.stack(test_grads_layer)
    else:
        test_grads_tensor = test_grads_layer

    if isinstance(train_grads_layer, list):
        train_grads_tensor = torch.stack(train_grads_layer)
    else:
        train_grads_tensor = train_grads_layer

    # Binary mask over the full gradient dimension. The indices come from the
    # optimizer's device, which may differ from where the gradients were loaded,
    # so move them next to the mask before indexing.
    full_mask = torch.zeros(dimension, device=test_grads_tensor.device)
    full_mask[important_indices.to(full_mask.device)] = 1.0

    # Masking only the test side is enough: the mask is 0/1, so zeroed test
    # coordinates already drop the matching train coordinates from the product.
    masked_test_grads = test_grads_tensor * full_mask

    # One matmul per variant gives every test gradient's inner products at once:
    # (n_test, dimension) x (dimension, n_train) -> (n_test, n_train)
    train_grads_T = train_grads_tensor.T
    original_ips_all = torch.matmul(test_grads_tensor, train_grads_T).cpu().numpy()
    reduced_ips_all = torch.matmul(masked_test_grads, train_grads_T).cpu().numpy()

    # Compute rank correlation for each test gradient
    for test_idx, (original_ips, reduced_ips) in enumerate(zip(original_ips_all, reduced_ips_all), start=1):
        try:
            rank_corr, _ = spearmanr(original_ips, reduced_ips)
            print(f"Test Gradient {test_idx}: Spearman Rank Correlation with {len(important_indices)}/{dimension} parameters: {rank_corr:.4f}")
        except Exception as e:
            # Best-effort reporting: one bad row should not abort the whole layer
            print(f"Test Gradient {test_idx}: Could not compute correlation. Error: {str(e)}")

    print(f"=== Completed Layer {layer} ===\n")


=== Processing Layer transformer.h.0.attn.c_attn (1/48) ===
Dimension: 1771776, Training samples: 200, Test samples: 20
Epoch 0:
  Total Loss: 4429452.0000
  Correlation Loss: -0.9969 (negative, lower is better)
  Sparsity Loss: 4429453.0000
  Mask Sparsity: 0.5396
  Average Spearman Rank Correlation: 1.0000
  Mask: 815677/1771776 parameters active (46.04%)

Epoch 100:
  Total Loss: 2109419.0000
  Correlation Loss: -0.9962 (negative, lower is better)
  Sparsity Loss: 2109420.0000
  Mask Sparsity: 0.9977
  Average Spearman Rank Correlation: 0.9998
  Mask: 4096/1771776 parameters active (0.23%)

  Added to candidate masks (active params: 4096, correlation: 0.9962)
  New best mask! (active params: 4096, correlation: 0.9962)
Epoch 200:
  Total Loss: 1999872.7500
  Correlation Loss: -0.9952 (negative, lower is better)
  Sparsity Loss: 1999873.7500
  Mask Sparsity: 0.9977
  Average Spearman Rank Correlation: 0.9997
  Mask: 4096/1771776 parameters active (0.23%)

  Added to candidate masks (