# Stage 1: Infrastructure Setup

In [1]:
import torch
import torch.nn as nn
import math

### Stage 1: Infrastructure Setup


In [2]:
class DampedKFAC(nn.Module):
    # Assuming 'd,d' were placeholders for feature dimensions
    def __init__(self, in_features: int, out_features: int, ema_decay: float = 0.95, damping: tuple[float, float] = (1e-7, 1e-5)):
        super().__init__()
        self.ema_decay = ema_decay
        self.damping = damping # Tuple (damping_A, damping_B)
        
        # Using register_buffer for A_accum and B_accum
        self.register_buffer('A_accum', torch.zeros(in_features, in_features))
        self.register_buffer('B_accum', torch.zeros(out_features, out_features))
        self.A_initialized = False
        self.B_initialized = False

    def update(self, current_A_cov: torch.Tensor, current_B_cov: torch.Tensor = None):
        """
        Update A_accum or B_accum with new covariance.
        If current_B_cov is None, only A_accum is updated.
        If current_A_cov is None, only B_accum is updated.
        """
        with torch.no_grad():
            if current_A_cov is not None:
                if not self.A_initialized:
                    self.A_accum.copy_(current_A_cov)
                    self.A_initialized = True
                else:
                    self.A_accum.mul_(self.ema_decay).add_(current_A_cov, alpha=(1 - self.ema_decay))
            
            if current_B_cov is not None:
                if not self.B_initialized:
                    self.B_accum.copy_(current_B_cov)
                    self.B_initialized = True
                else:
                    self.B_accum.mul_(self.ema_decay).add_(current_B_cov, alpha=(1 - self.ema_decay))

    def get_factors(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns the current A_accum and B_accum factors."""
        return self.A_accum, self.B_accum

    def get_inverse_factors(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Computes and returns damped inverses (A_inv, B_inv)."""
        with torch.no_grad():
            device = self.A_accum.device
            
            A_damped = self.A_accum + self.damping[0] * torch.eye(self.A_accum.size(0), device=device)
            B_damped = self.B_accum + self.damping[1] * torch.eye(self.B_accum.size(0), device=device)
            
            try:
                A_inv = torch.linalg.pinv(A_damped)
            except Exception as e:
                # print(f"Error inverting A_damped: {e}. Returning identity.")
                A_inv = torch.eye(self.A_accum.size(0), device=device)
            
            try:
                B_inv = torch.linalg.pinv(B_damped)
            except Exception as e:
                # print(f"Error inverting B_damped: {e}. Returning identity.")
                B_inv = torch.eye(self.B_accum.size(0), device=device)
                
            return A_inv, B_inv

    def __repr__(self):
        return (f"DampedKFAC(in={self.A_accum.shape[0]}, out={self.B_accum.shape[0]}, "
                f"ema_decay={self.ema_decay}, damping={self.damping}, "
                f"A_init={self.A_initialized}, B_init={self.B_initialized})")

In [3]:
# 1. Curvature Estimation Framework Implementation
# 1.1 K-FAC Layer Specialization

# Stage 1: Infrastructure Setup

class KFACLayer(nn.Module):
    """
    Handles K-FAC's activation and gradient covariance estimation for a given module.
    Uses an instance of DampedKFAC to manage EMA updates and factor inversions.
    """
    def __init__(self, module: nn.Linear, layer_type: str, ema_decay: float = 0.95, kfac_damping: tuple[float, float] = (1e-7, 1e-5)):
        super().__init__()
        self.layer_type = layer_type
        # ema_decay and kfac_damping are now for the DampedKFAC instance
        
        try:
            device = next(module.parameters()).device
        except StopIteration: 
            device = torch.device("cpu")

        in_features = module.in_features
        out_features = module.out_features

        # Instantiate DampedKFAC
        # Make sure DampedKFAC is defined before this class is used.
        self.kfac_manager = DampedKFAC(in_features, out_features, ema_decay, kfac_damping)
        self.kfac_manager.to(device)


    def forward_hook(self, module: nn.Module, input_data: tuple[torch.Tensor], output_data: torch.Tensor):
        """
        Hook to compute activation covariance (A_factor) using E[aa^T] estimator.
        input_data is a tuple, input_data[0] is the actual input tensor.
        Uses DampedKFAC to update A_accum.
        Incorporates suggested reshaping for 3D/4D inputs.
        """
        x = input_data[0].detach()
        
        # Suggested reshaping
        if x.ndim == 3:  # (batch_size, seq_len, dim)
            act_reshaped = x.reshape(-1, x.shape[-1])  # (batch_size*seq_len, dim)
        elif x.ndim == 4:  # Conv2d inputs (batch_size, channels, H, W)
            # Reshape to (batch_size*H*W, channels) if module is Conv2d-like
            # For nn.Linear, this path shouldn't be hit if input is already (B, Cin)
            # This part needs care depending on whether KFACLayer is attached to Conv2d
            # Original KFACLayer was for nn.Linear. If KFACLayer is generalized,
            # input features for Conv2d (e.g. unfolded patches) would be different.
            # For now, assuming if it's Linear, ndim won't be 4 unless an error.
            # If it IS a Conv2D's activations flattened for a subsequent Linear, this is complex.
            act_reshaped = x.reshape(-1, x.shape[1]) # (batch_size*H*W, channels) 
                                                   # This assumes x.shape[1] is 'in_features'
        elif x.ndim > 2: # General case for >2D not explicitly handled by 3D/4D
            act_reshaped = x.reshape(-1, x.shape[-1])
        else: # ndim <= 2
            act_reshaped = x 
        
        if act_reshaped.shape[0] == 0: return
        
        current_A_factor_dim = self.kfac_manager.A_accum.shape[0]
        if act_reshaped.shape[1] != current_A_factor_dim:
            print(f"Warning: Activation dimension mismatch for KFACLayer {self.layer_type}. Expected {current_A_factor_dim}, got {act_reshaped.shape[1]}. Skipping A_factor update.")
            return

        # Covariance of activations: E[aa^T] estimated by (act.T @ act) / N
        # Snippet: cov = (x.T @ x) / x.size(0)
        # Using act_reshaped for x
        cov_a = (act_reshaped.T @ act_reshaped) / act_reshaped.shape[0]
        
        # Update using DampedKFAC. Current_B_cov is None as this is forward pass.
        self.kfac_manager.update(current_A_cov=cov_a, current_B_cov=None)

    def backward_hook(self, module: nn.Module, grad_input: tuple[torch.Tensor], grad_output_data: tuple[torch.Tensor]):
        """
        Hook to compute gradient covariance (B_factor) using E[gg^T] estimator.
        grad_output_data is a tuple, grad_output_data[0] is the gradient w.r.t. module's output.
        Uses DampedKFAC to update B_accum.
        """
        g = grad_output_data[0].detach()
        
        # Reshaping for gradients (similar to activations)
        if g.ndim > 2:
            grad_reshaped = g.reshape(-1, g.shape[-1])
        else:
            grad_reshaped = g

        if grad_reshaped.shape[0] == 0: return
        
        current_B_factor_dim = self.kfac_manager.B_accum.shape[0]
        if grad_reshaped.shape[1] != current_B_factor_dim:
            print(f"Warning: Gradient dimension mismatch for KFACLayer {self.layer_type}. Expected {current_B_factor_dim}, got {grad_reshaped.shape[1]}. Skipping B_factor update.")
            return

        # Covariance of gradients: E[gg^T] estimated by (grad.T @ grad) / N
        cov_g = (grad_reshaped.T @ grad_reshaped) / grad_reshaped.shape[0]

        # Update using DampedKFAC. Current_A_cov is None as this is backward pass.
        self.kfac_manager.update(current_A_cov=None, current_B_cov=cov_g)

    # Expose factors for external use if needed (e.g. by FisherEigen or other parts)
    @property
    def A_factor(self):
        return self.kfac_manager.A_accum

    @property
    def B_factor(self):
        return self.kfac_manager.B_accum

    def get_inverse_factors(self):
        return self.kfac_manager.get_inverse_factors()

In [4]:
# 1.2 Transformer Layer Instrumentation

def instrument_layer(transformer_block: nn.Module) -> list[KFACLayer]:
    """
    Instruments specific layers in a transformer_block with KFACLayer handlers.
    Assumes transformer_block has 'attn.q_proj', 'attn.k_proj', 'attn.v_proj', and 'mlp' attributes
    which are nn.Linear layers.
    Returns a list of created KFACLayer handlers.
    """
    kfac_handlers = []
    
    # Helper to instrument a single linear layer
    def _instrument_linear(module: nn.Linear, name: str, handlers_list: list):
        if isinstance(module, nn.Linear):
            if not hasattr(module, 'kfac_handler'): # Avoid double instrumentation
                # print(f"Instrumenting layer {name} for KFAC.")
                # Pass ema_decay and potentially kfac_damping if KFACLayer expects it
                # For this example, assume KFACLayer default ema_decay (0.95) is fine,
                # and default kfac_damping in KFACLayer is also fine.
                # If specific values are needed here, they should be passed.
                kfac_layer_instance = KFACLayer(module, name) # Uses KFACLayer defaults
                module.register_forward_hook(kfac_layer_instance.forward_hook)
                module.register_full_backward_hook(kfac_layer_instance.backward_hook)
                module.kfac_handler = kfac_layer_instance 
                handlers_list.append(kfac_layer_instance)
        else:
            print(f"Warning: Expected nn.Linear for {name}, got {type(module)}. Skipping KFAC instrumentation.")

    # Self-attention projections
    if hasattr(transformer_block, 'attn'):
        attn_module = transformer_block.attn
        if hasattr(attn_module, 'q_proj'):
            _instrument_linear(attn_module.q_proj, f"{transformer_block.__class__.__name__}.attn.q_proj", kfac_handlers)
        if hasattr(attn_module, 'k_proj'):
            _instrument_linear(attn_module.k_proj, f"{transformer_block.__class__.__name__}.attn.k_proj", kfac_handlers)
        if hasattr(attn_module, 'v_proj'):
            _instrument_linear(attn_module.v_proj, f"{transformer_block.__class__.__name__}.attn.v_proj", kfac_handlers)
        # Optional: Output projection (often included in KFAC for attention)
        # if hasattr(attn_module, 'o_proj'):
        #     _instrument_linear(attn_module.o_proj, f"{transformer_block.__class__.__name__}.attn.o_proj", kfac_handlers)
    else:
        print(f"Warning: Transformer block {transformer_block.__class__.__name__} has no 'attn' attribute.")
        
    # FFN matrix (assuming transformer_block.mlp is an nn.Linear layer)
    if hasattr(transformer_block, 'mlp'):
        _instrument_linear(transformer_block.mlp, f"{transformer_block.__class__.__name__}.mlp", kfac_handlers)
    else:
         print(f"Warning: Transformer block {transformer_block.__class__.__name__} has no 'mlp' attribute for FFN KFAC.")
         
    return kfac_handlers

### 2. Low-Rank Adaptation Core


In [5]:
# 2.1 Rank-Flexible LoRA Module 

class FlexLoRA(nn.Module):
    """
    Implements Low-Rank Adaptation (LoRA) for an nn.Linear layer.
    The base layer's weights are frozen, and LoRA matrices (A and B) are trained.
    Includes attributes for dynamic rank adjustment.
    """
    def __init__(self, base_layer: nn.Linear, initial_rank: int = 8, min_rank: int = 4, max_rank: int = 16):
        super().__init__()
        if not isinstance(base_layer, nn.Linear):
            raise ValueError(f"FlexLoRA currently only supports nn.Linear, got {type(base_layer)}")

        self.base_layer = base_layer
        self.initial_rank = initial_rank
        self.current_rank = initial_rank # Initialize current_rank with initial_rank
        self.min_rank = min_rank
        self.max_rank = max_rank
        
        self.base_layer.weight.requires_grad = False
        if self.base_layer.bias is not None:
            self.base_layer.bias.requires_grad = False

        device = base_layer.weight.device
        # LoRA parameters dimensions are based on current_rank (which is initial_rank at creation)
        self.lora_A = nn.Parameter(torch.randn(base_layer.in_features, self.current_rank, device=device))
        self.lora_B = nn.Parameter(torch.zeros(self.current_rank, base_layer.out_features, device=device))
        
        self.bypass = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base_out = self.base_layer(x)
        
        if self.current_rank > 0: # Check current_rank for LoRA path
            lora_adapt = (x @ self.lora_A) @ self.lora_B
            return self.bypass(base_out) + lora_adapt
        return base_out
    
    def __repr__(self):
        return (f"{self.__class__.__name__}(initial_rank={self.initial_rank}, current_rank={self.current_rank}, "
                f"min_rank={self.min_rank}, max_rank={self.max_rank}, "
                f"base_layer={self.base_layer.__class__.__name__}({self.base_layer.in_features}x{self.base_layer.out_features}))")

In [6]:
# 2.2 Layer Modification Protocol

def lora_inject(model: nn.Module, layers_to_adapt_substrings: list[str], 
                rank_range: tuple[int, int] = (8, 64)) -> nn.Module:
    """
    Replaces specified nn.Linear layers in the model with FlexLoRA wrappers.
    Args:
        model: The model to modify.
        layers_to_adapt_substrings: List of substrings. If a module's name contains
                                    any of these, it will be adapted.
        rank_range: Tuple of (min_rank, max_rank). min_rank is used for initial_rank.
    Returns:
        The modified model.
    """
    min_r, max_r = rank_range
    # Use the lower bound of the rank range as the initial rank for LoRA
    # and pass min_r, max_r to FlexLoRA for its own min_rank, max_rank settings.
    initial_r = min_r 

    for name, module in model.named_modules():
        parent_name_parts = name.split('.')[:-1]
        child_name = name.split('.')[-1]
        
        parent_module = model
        if parent_name_parts:
            try:
                parent_module = model.get_submodule(".".join(parent_name_parts))
            except AttributeError:
                continue

        actual_module_to_check = getattr(parent_module, child_name, None)

        if isinstance(actual_module_to_check, nn.Linear):
            if any(sub in name for sub in layers_to_adapt_substrings):
                if isinstance(actual_module_to_check, FlexLoRA):
                    continue

                original_layer = actual_module_to_check
                
                lora_layer = FlexLoRA(original_layer, 
                                        initial_rank=initial_r, 
                                        min_rank=min_r, 
                                        max_rank=max_r)
                setattr(parent_module, child_name, lora_layer)
    return model

### 3. Neural Reprojection Engine


In [7]:
# 3.1 Eigen Decomposition Module

class FisherEigen(nn.Module):
    """
    Handles Fisher block construction (via Kronecker product of KFAC's A and B factors)
    and its eigendecomposition using LOBPCG.
    """
    def __init__(self):
        super().__init__()
        # The Fisher block buffer will be created dynamically in update_fisher
        self.register_buffer('fisher_block', torch.empty(0), persistent=False)

    def update_fisher(self, A: torch.Tensor, B: torch.Tensor):
        """
        Updates the Fisher block using A and B factors.
        Fisher ~ kron(B, A)
        Assumes A and B are symmetric, or uses (X + X.T)/2 for safety.
        """
        # Ensure factors are symmetric (KFAC factors should be, but enforce for safety)
        A_sym = (A + A.T) / 2 if A.numel() > 0 else A
        B_sym = (B + B.T) / 2 if B.numel() > 0 else B
        
        if A_sym.numel() == 0 or B_sym.numel() == 0:
            # print("Warning: A or B factor is empty. Fisher block will be empty.")
            self.fisher_block = torch.empty(0, device=A.device if A.numel() > 0 else (B.device if B.numel() > 0 else 'cpu'))
            return

        # Explicit Kronecker product. This can be very memory intensive.
        self.fisher_block = torch.kron(B_sym, A_sym)

    def decompose(self, k: float = 0.2) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """
        Decomposes the computed Fisher block to find top-k eigenvectors using LOBPCG.
        Args:
            k: Fraction of total eigenvectors to return (corresponding to largest eigenvalues).
        Returns:
            Tuple of (eigenvalues, eigenvectors). Eigenvectors are column vectors.
            Returns (None, None) if decomposition is not possible or fails.
        """
        if self.fisher_block is None or self.fisher_block.numel() == 0:
            # print("Fisher block not computed or is empty. Cannot decompose.")
            return None, None

        dim = self.fisher_block.shape[0]
        if dim == 0:
            return None, None

        # k for lobpcg is the number of eigenpairs to find
        num_eigenvectors_to_find = int(dim * k)
        # Ensure k_lobpcg is valid: 0 < k_lobpcg < dim for non-trivial cases, k_lobpcg <= dim
        num_eigenvectors_to_find = max(1, min(num_eigenvectors_to_find, dim -1 if dim > 1 else 1))
        if dim == 1 and num_eigenvectors_to_find > 1 : num_eigenvectors_to_find = 1 # Max 1 for 1x1 matrix

        if num_eigenvectors_to_find <= 0:
             return None, None
        
        try:
            # LOBPCG for largest eigenvalues.
            # It requires the matrix to be symmetric (which self.fisher_block should be by construction).
            eigenvalues, eigenvectors = torch.lobpcg(
                A=self.fisher_block, # The matrix A for which to solve A*X = lambda*X*B (B is identity here)
                k=num_eigenvectors_to_find,
                largest=True, # Get largest eigenvalues
                method="ortho" # Ensures eigenvectors are orthogonalized
            )
            # lobpcg returns eigenvalues in ascending order if largest=False, 
            # and typically in no specific order or descending for largest=True.
            # We want largest, so let's sort to be sure.
            sorted_indices = torch.argsort(eigenvalues, descending=True)
            eigenvalues = eigenvalues[sorted_indices]
            eigenvectors = eigenvectors[:, sorted_indices]
            
            return eigenvalues, eigenvectors

        except Exception as e:
            # Fallback to eigh if LOBPCG fails (e.g. due to matrix properties or k)
            # print(f"LOBPCG decomposition failed for a {dim}x{dim} matrix (k={num_eigenvectors_to_find}): {e}. Trying eigh.")
            try:
                eigenvalues_all, eigenvectors_all = torch.linalg.eigh(self.fisher_block)
                top_eigenvalues = eigenvalues_all[-num_eigenvectors_to_find:]
                top_eigenvectors = eigenvectors_all[:, -num_eigenvectors_to_find:]
                # Return in descending order of eigenvalue magnitude
                return top_eigenvalues.flip(dims=[0]), top_eigenvectors.fliplr()
            except Exception as e_eigh:
                print(f"LOBPCG and eigh decomposition failed for a {dim}x{dim} matrix (k={num_eigenvectors_to_find}): {e_eigh}")
                return None, None

In [8]:
# 3.2 Subspace Preservation System

class SubspaceBuffer(nn.Module):
    """
    Stores and manages projection masks and eigenvectors for multiple blocks/layers.
    Uses a fixed buffer_size for eigenvector storage dimensions.
    """
    def __init__(self, num_blocks: int, buffer_size: int = 1024, device: str ='cpu'):
        super().__init__()
        self.num_blocks = num_blocks
        self.buffer_size = buffer_size # Acts as max dimension and max number of vectors
        self.device = torch.device(device)

        self.subspace_module_list = nn.ModuleList() # MODIFIED HERE
        for _ in range(num_blocks):
            # Each buffer entry is a BufferDict
            buffer_entry = nn.Module() # Using nn.Module to host buffers for easy device moving
            buffer_entry.register_buffer(
                'eigen_vectors', 
                torch.zeros(buffer_size, buffer_size, device=self.device) # (max_feature_dim, max_num_eigenvectors)
            )
            buffer_entry.register_buffer(
                'projection_mask', # Indicates how many of the 'buffer_size' eigenvector slots are active
                torch.zeros(buffer_size, device=self.device, dtype=torch.float32) # Using float for 0/1 values
            )
            self.subspace_module_list.append(buffer_entry) # MODIFIED HERE
        
    def update_buffer(self, block_idx: int, new_eigenvectors: torch.Tensor):
        """
        Updates the buffer for a specific block with new eigenvectors.
        Args:
            block_idx: Index of the layer/block.
            new_eigenvectors: Tensor of shape (full_vector_dim, num_new_vecs).
                              Assumes full_vector_dim <= buffer_size and num_new_vecs <= buffer_size.
        """
        if not (0 <= block_idx < self.num_blocks):
            raise IndexError(f"block_idx {block_idx} out of range for {self.num_blocks} buffers.")

        current_buffer_entry = self.subspace_module_list[block_idx] # MODIFIED HERE
        
        if new_eigenvectors is None or new_eigenvectors.numel() == 0:
            # print(f"Warning: No new eigenvectors provided for block {block_idx}. Buffer not updated.")
            current_buffer_entry.projection_mask.fill_(0.) # Clear mask if no new vectors
            return

        actual_feature_dim, num_new_vecs = new_eigenvectors.shape

        if actual_feature_dim > self.buffer_size:
            # print(f"Warning: Eigenvector feature dimension ({actual_feature_dim}) for block {block_idx} "
            #       f"exceeds buffer_size ({self.buffer_size}). Truncating features.")
            new_eigenvectors = new_eigenvectors[:self.buffer_size, :]
            actual_feature_dim = self.buffer_size
        
        slots_to_fill = min(num_new_vecs, self.buffer_size)
        
        # Clear old values and set new ones
        current_buffer_entry.eigen_vectors.fill_(0.)
        current_buffer_entry.projection_mask.fill_(0.)

        if slots_to_fill > 0:
            current_buffer_entry.eigen_vectors[:actual_feature_dim, :slots_to_fill] = new_eigenvectors[:, :slots_to_fill]
            current_buffer_entry.projection_mask[:slots_to_fill] = 1.0 # Mark as active

    def get_subspace(self, block_idx: int) -> torch.Tensor | None:
        """
        Retrieves the active eigenvectors for a given block.
        Returns:
            Tensor of shape (feature_dim_of_stored_vecs, num_active_vecs), or None.
            Feature_dim_of_stored_vecs could be up to buffer_size.
        """
        if not (0 <= block_idx < self.num_blocks):
            raise IndexError(f"block_idx {block_idx} out of range for {self.num_blocks} buffers.")
        
        current_buffer_entry = self.subspace_module_list[block_idx] # MODIFIED HERE
        active_indices = current_buffer_entry.projection_mask.bool() # Convert 0/1 to boolean mask
        
        if not active_indices.any():
            return None # No active eigenvectors

        # Retrieve only up to the point where features might have been stored.
        # This assumes eigenvectors were stored contiguously from feature dim 0.
        # A more robust way would be to also store the actual_feature_dim, but this matches the buffer_size logic.
        active_vectors = current_buffer_entry.eigen_vectors[:, active_indices]
        return active_vectors

In [9]:
# --- Dummy Model for Illustration (adapted for instrument_layer) ---
class AttentionModule(nn.Module): # Helper module for attention projections
    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim) # Output projection

    def forward(self, q, k, v): # Simplified forward
        return self.o_proj(self.q_proj(q) + self.k_proj(k) + self.v_proj(v))


class SimpleTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, mlp_ratio=4): # n_heads not used in this simplified version
        super().__init__()
        self.attn = AttentionModule(dim) # Contains q_proj, k_proj, v_proj
        
        # FFN layers
        self.mlp_fc1 = nn.Linear(dim, dim * mlp_ratio)
        self.mlp_gelu = nn.GELU()
        self.mlp_fc2 = nn.Linear(dim * mlp_ratio, dim)
        
        # Expose one FFN layer as 'self.mlp' for instrument_layer as per its current design
        # This assumes KFAC is applied to the first linear layer of the MLP.
        self.mlp = self.mlp_fc1 
    
    def forward(self, x):
        # Simplified forward pass
        # Dummy attention: uses x for q, k, v.
        attn_out = self.attn(x, x, x) 
        
        # MLP path
        mlp_hidden = self.mlp_gelu(self.mlp(attn_out)) # self.mlp is mlp_fc1
        mlp_out = self.mlp_fc2(mlp_hidden)
        return attn_out + mlp_out # Residual connection

class DummyModel(nn.Module):
    def __init__(self, num_blocks=2, dim=64, n_heads=4):
        super().__init__()
        self.blocks = nn.ModuleList([SimpleTransformerBlock(dim, n_heads) for _ in range(num_blocks)])
        self.out_head = nn.Linear(dim, 10) # Example output

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.out_head(x)

In [10]:
# --- KFAC Example (using new KFACLayer and instrument_layer) ---
print("\\n--- KFAC Example ---")
model_kfac = DummyModel(num_blocks=1, dim=32) # Using 1 block for simplicity

all_kfac_handlers = []
for i, block in enumerate(model_kfac.blocks):
    print(f"Instrumenting KFAC for block {i}...")
    block_kfac_handlers = instrument_layer(block) # instrument_layer targets one block
    all_kfac_handlers.extend(block_kfac_handlers)

print(f"Instrumented {len(all_kfac_handlers)} layers in total for KFAC.")

# Simulate a forward/backward pass
dummy_data = torch.randn(2, 5, 32) # Batch, Seq, Dim
optimizer_kfac = torch.optim.Adam(model_kfac.parameters(), lr=1e-3)

# First pass (initializes factors if using mean_vector outer product)
output_kfac = model_kfac(dummy_data)
loss_kfac = output_kfac.mean()
loss_kfac.backward()
optimizer_kfac.step()
optimizer_kfac.zero_grad()

# Second pass to see updated factors
output_kfac_2 = model_kfac(dummy_data)
loss_kfac_2 = output_kfac_2.mean()
loss_kfac_2.backward()
optimizer_kfac.step()


if all_kfac_handlers:
    first_handler = all_kfac_handlers[0]
    print(f"KFAC A_factor for '{first_handler.layer_type}' (sample):\\n", first_handler.A_factor[:min(5,first_handler.A_factor.shape[0]),:min(5,first_handler.A_factor.shape[1])])
    print(f"KFAC B_factor for '{first_handler.layer_type}' (sample):\\n", first_handler.B_factor[:min(5,first_handler.B_factor.shape[0]),:min(5,first_handler.B_factor.shape[1])])
else:
    print("No KFAC handlers created.")

\n--- KFAC Example ---
Instrumenting KFAC for block 0...
Instrumented 4 layers in total for KFAC.
KFAC A_factor for 'SimpleTransformerBlock.attn.q_proj' (sample):\n tensor([[ 0.4425, -0.1333,  0.1536, -0.0331,  0.4627],
        [-0.1333,  1.3624, -0.3087, -0.0966, -0.4073],
        [ 0.1536, -0.3087,  1.2780, -0.1461, -0.0248],
        [-0.0331, -0.0966, -0.1461,  0.6619, -0.2659],
        [ 0.4627, -0.4073, -0.0248, -0.2659,  1.8390]])
KFAC B_factor for 'SimpleTransformerBlock.attn.q_proj' (sample):\n tensor([[ 4.8930e-07,  5.4784e-07, -1.1847e-06,  1.3906e-06, -4.3782e-07],
        [ 5.4784e-07,  6.2973e-07, -1.3418e-06,  1.5643e-06, -4.9425e-07],
        [-1.1847e-06, -1.3418e-06,  2.9071e-06, -3.3941e-06,  1.0758e-06],
        [ 1.3906e-06,  1.5643e-06, -3.3941e-06,  3.9753e-06, -1.2569e-06],
        [-4.3782e-07, -4.9425e-07,  1.0758e-06, -1.2569e-06,  4.0789e-07]])


In [11]:
# --- LoRA Example (using new FlexLoRA and lora_inject) ---
print("\\n--- LoRA Example ---")
model_lora = DummyModel(num_blocks=1, dim=32)
# Adapt layers containing 'attn' or 'mlp' in their name.
# Note: 'mlp' in SimpleTransformerBlock is self.mlp_fc1. If we want to adapt both mlp_fc1 and mlp_fc2,
# the substring matching needs to be more specific or lora_inject needs to traverse deeper.
# For this example, 'blocks.0.mlp' will match the first FFN layer if 'mlp' is in substrings.
# 'blocks.0.attn.q_proj' will match if 'attn' or 'q_proj' are in substrings.
lora_inject(model_lora, layers_to_adapt_substrings=['attn.q_proj', 'attn.k_proj', 'attn.v_proj', 'mlp'], rank_range=(4, 16))
print("Model with LoRA injected:")
# print(model_lora) # To see FlexLoRA modules, can be verbose

trainable_params = sum(p.numel() for p in model_lora.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model_lora.parameters())
print(f"Trainable params with LoRA: {trainable_params}, Total params: {total_params}")

output_lora = model_lora(dummy_data) # dummy_data from KFAC example
loss_lora = output_lora.mean()
optimizer_lora = torch.optim.Adam(filter(lambda p: p.requires_grad, model_lora.parameters()), lr=1e-3)
optimizer_lora.zero_grad()
loss_lora.backward()
optimizer_lora.step()
print("LoRA model training step completed.")

\n--- LoRA Example ---
Model with LoRA injected:
Trainable params with LoRA: 3434, Total params: 14954
LoRA model training step completed.


In [12]:
# --- FisherEigen & SubspaceBuffer Example ---
print("\\n--- FisherEigen & SubspaceBuffer Example ---")
if all_kfac_handlers: # Requires KFAC to have run and populated factors
    # Example: use the first KFAC handler's factors
    kfac_A = all_kfac_handlers[0].A_factor
    kfac_B = all_kfac_handlers[0].B_factor
    
    if torch.all(kfac_A == 0) or torch.all(kfac_B == 0):
        print("Warning: KFAC factors from the first handler are zero. Synthesizing dummy factors for FisherEigen test.")
        # Synthesize dummy factors if they are zero (e.g., if only one pass done or specific KFAC math leads to it)
        dim_A = kfac_A.shape[0] if kfac_A.numel() > 0 else 32
        dim_B = kfac_B.shape[0] if kfac_B.numel() > 0 else 32
        kfac_A = torch.eye(dim_A, device=kfac_A.device) * 0.1 + torch.rand(dim_A,dim_A, device=kfac_A.device)*0.01
        kfac_B = torch.eye(dim_B, device=kfac_B.device) * 0.1 + torch.rand(dim_B,dim_B, device=kfac_B.device)*0.01
        kfac_A = (kfac_A + kfac_A.T)/2 # ensure symmetric
        kfac_B = (kfac_B + kfac_B.T)/2


    fisher_engine = FisherEigen()
    fisher_engine.update_fisher(A=kfac_A, B=kfac_B)
    print(f"Fisher block computed, shape: {fisher_engine.fisher_block.shape if fisher_engine.fisher_block is not None else 'None'}")

    if fisher_engine.fisher_block is not None and fisher_engine.fisher_block.numel() > 0:
        eig_vals, eig_vecs = fisher_engine.decompose(k=0.1) # Request top 10%
        
        if eig_vecs is not None:
            print(f"Decomposed Fisher: Got {eig_vecs.shape[1]} eigenvectors with shape {eig_vecs.shape}")
            print("Top eigenvalues (sample):\\n", eig_vals[:min(5, len(eig_vals) if eig_vals is not None else 0)])

            # Subspace Buffer example
            # Suppose we track subspaces for a few "blocks" (e.g. layers from KFAC)
            num_tracked_fisher_blocks = 1 # For this example, just one from the first KFAC handler
            subspace_buffer_size = 1024 # Max dim / max num vectors for this buffer example
            
            # Ensure eigenvector dimensions match subspace buffer expectations
            # Fisher block dim = dim_A * dim_B. Eigenvector dim is also this.
            # For SubspaceBuffer, this full_vector_dim must be <= buffer_size.
            
            current_fisher_block_dim = fisher_engine.fisher_block.shape[0]
            if current_fisher_block_dim > subspace_buffer_size:
                 print(f"Warning: Fisher block dim ({current_fisher_block_dim}) > SubspaceBuffer buffer_size ({subspace_buffer_size}). "
                       "Eigenvectors will be truncated by SubspaceBuffer.")

            subspace_manager = SubspaceBuffer(num_blocks=num_tracked_fisher_blocks, 
                                              buffer_size=subspace_buffer_size,
                                              device=eig_vecs.device)
            subspace_manager.update_buffer(block_idx=0, new_eigenvectors=eig_vecs)
            retrieved_vecs = subspace_manager.get_subspace(block_idx=0)
            if retrieved_vecs is not None:
                print(f"Retrieved {retrieved_vecs.shape[1]} vectors (shape {retrieved_vecs.shape}) from SubspaceBuffer for block 0.")
            else:
                print("No vectors retrieved from SubspaceBuffer for block 0 (likely empty or issue).")
        else:
            print("FisherEigen decomposition failed to produce eigenvectors.")
    else:
        print("Skipping FisherEigen decomposition as Fisher block is empty or not computed.")
else:
    print("Skipping FisherEigen & SubspaceBuffer example as KFAC handlers are not available.")

print("\\n--- Infrastructure Script End ---")

\n--- FisherEigen & SubspaceBuffer Example ---
Fisher block computed, shape: torch.Size([1024, 1024])
Decomposed Fisher: Got 102 eigenvectors with shape torch.Size([1024, 102])
Top eigenvalues (sample):\n tensor([0.0005, 0.0004, 0.0004, 0.0003, 0.0003])
Retrieved 102 vectors (shape torch.Size([1024, 102])) from SubspaceBuffer for block 0.
\n--- Infrastructure Script End ---


# Stage 2: Training Pipeline

In [13]:
import torch
import torch.nn as nn

### 1. Forward Pass Implementation


In [14]:
def forward_hook(layer, inputs, outputs):
    """
    Hook to track activation statistics for K-FAC and record covariances.
    Attaches A_factor directly to the layer.
    """
    # Track activation statistics for K-FAC
    x = inputs[0].detach()
    # Ensure x is 2D: (batch_size_effective, in_features)
    if x.ndim > 2:
        x_reshaped = x.reshape(-1, x.shape[-1])
    else:
        x_reshaped = x
    
    if x_reshaped.shape[0] == 0: return

    # Initialize A_factor if not present or different device/dtype
    if not hasattr(layer, 'A_factor') or layer.A_factor is None or \
       layer.A_factor.shape[0] != x_reshaped.shape[1] or \
       layer.A_factor.device != x_reshaped.device or layer.A_factor.dtype != x_reshaped.dtype:
        layer.A_factor = torch.zeros(x_reshaped.shape[1], x_reshaped.shape[1], device=x_reshaped.device, dtype=x_reshaped.dtype)

    # KFAC A_factor update: EMA of (x.T @ x) / N
    # This is M2 (second moment matrix), not covariance if mean is non-zero.
    # KFAC uses E[aa^T]
    current_A_factor_val = (x_reshaped.T @ x_reshaped) / x_reshaped.size(0)
    layer.A_factor = 0.95 * layer.A_factor + 0.05 * current_A_factor_val
    
    # Record input/output covariances (though not directly used in subsequent GRIT steps provided)
    # Ensure layer has these attributes initialized correctly.
    if not hasattr(layer, 'input_cov') or layer.input_cov is None:
        layer.input_cov = torch.zeros_like(layer.A_factor)
    if not hasattr(layer, 'output_cov') or layer.output_cov is None:
        out_features = outputs.shape[-1]
        layer.output_cov = torch.zeros(out_features, out_features, device=outputs.device, dtype=outputs.dtype)

    if x_reshaped.size(0) > 1:
        layer.input_cov = (x_reshaped.T @ x_reshaped) / (x_reshaped.size(0) -1) # Sample covariance
    
    # Process outputs for output_cov
    if outputs.ndim > 2:
        outputs_reshaped = outputs.reshape(-1, outputs.shape[-1])
    else:
        outputs_reshaped = outputs
    
    if outputs_reshaped.size(0) > 1:
        layer.output_cov = (outputs_reshaped.T @ outputs_reshaped) / (outputs_reshaped.size(0) -1)

def instrument_forward(model):
    """Registers forward_hook on Linear and Conv2d layers."""
    for layer in model.modules():
        if isinstance(layer, (nn.Linear, nn.Conv2d)): 
            # Initialize dummy factors if they don't exist, to avoid issues on first pass
            # Hooks will overwrite these. This is mainly for layers not hit by backward pass early.
            if not hasattr(layer, 'A_factor'):
                # Determine in_features for Linear vs Conv2d
                if isinstance(layer, nn.Linear):
                    in_dim = layer.in_features
                elif isinstance(layer, nn.Conv2d):
                    in_dim = layer.in_channels * layer.kernel_size[0] * layer.kernel_size[1] # Simplified view
                else:
                    in_dim = 0 # Should not happen due to outer check

                if in_dim > 0:
                    try:
                        device = next(layer.parameters()).device
                    except StopIteration:
                        device = torch.device("cpu")
                    layer.A_factor = torch.zeros(in_dim, in_dim, device=device)

            layer.register_forward_hook(forward_hook)

### 2. Backward Pass Implementation


In [15]:
def backward_hook(layer, grad_input, grad_output):
    """
    Hook to compute gradient covariances (B_factor) and update Kronecker factor.
    Attaches B_factor and kron_factor directly to the layer.
    """
    grad = grad_output[0].detach()
    # Ensure grad is 2D: (batch_size_effective, out_features)
    if grad.ndim > 2:
        grad_reshaped = grad.reshape(-1, grad.shape[-1])
    else:
        grad_reshaped = grad

    if grad_reshaped.shape[0] == 0: return

    # Initialize B_factor if not present or different device/dtype
    if not hasattr(layer, 'B_factor') or layer.B_factor is None or \
       layer.B_factor.shape[0] != grad_reshaped.shape[1] or \
       layer.B_factor.device != grad_reshaped.device or layer.B_factor.dtype != grad_reshaped.dtype:
        layer.B_factor = torch.zeros(grad_reshaped.shape[1], grad_reshaped.shape[1], device=grad_reshaped.device, dtype=grad_reshaped.dtype)

    # KFAC B_factor update: EMA of (grad.T @ grad) / N
    # This is M2 (second moment matrix) for gradients. KFAC uses E[gg^T]
    current_B_factor_val = (grad_reshaped.T @ grad_reshaped) / grad_reshaped.size(0)
    layer.B_factor = 0.95 * layer.B_factor + 0.05 * current_B_factor_val
    
    # Update Kronecker factors
    if hasattr(layer, 'A_factor') and layer.A_factor is not None and \
       layer.A_factor.numel() > 0 and layer.B_factor.numel() > 0:
        try:
            # Ensure A_factor and B_factor are 2D square matrices before kron
            if layer.A_factor.ndim == 2 and layer.B_factor.ndim == 2 and \
               layer.A_factor.shape[0] == layer.A_factor.shape[1] and \
               layer.B_factor.shape[0] == layer.B_factor.shape[1]:
                layer.kron_factor = torch.kron(layer.A_factor, layer.B_factor)
            else:
                # print(f"Warning: A_factor or B_factor for layer is not a 2D square matrix. Skipping kron_factor update.")
                # print(f"A_factor shape: {layer.A_factor.shape}, B_factor shape: {layer.B_factor.shape}")
                if not hasattr(layer, 'kron_factor'): layer.kron_factor = None
        except Exception as e:
            # print(f"Error computing Kronecker product for layer: {e}. kron_factor not updated.")
            if not hasattr(layer, 'kron_factor'): layer.kron_factor = None
    else:
        if not hasattr(layer, 'kron_factor'): layer.kron_factor = None


def instrument_backward(model):
    """Registers backward_hook on Linear and Conv2d layers."""
    for layer in model.modules():
        if isinstance(layer, (nn.Linear, nn.Conv2d)): 
            if not hasattr(layer, 'B_factor'): # Initialize if not present
                 if isinstance(layer, nn.Linear):
                    out_dim = layer.out_features
                 elif isinstance(layer, nn.Conv2d):
                    out_dim = layer.out_channels
                 else:
                    out_dim = 0
                 if out_dim > 0:
                    try:
                        device = next(layer.parameters()).device
                    except StopIteration:
                        device = torch.device("cpu")
                    layer.B_factor = torch.zeros(out_dim, out_dim, device=device)
            
            # Using register_full_backward_hook for broader compatibility
            layer.register_full_backward_hook(backward_hook)

### 3. Natural Gradient Calculation


In [16]:
def compute_lora_fisher(lora_layer: FlexLoRA, damping: float = 1e-6) -> torch.Tensor | None:
    """
    Computes the Fisher Information Matrix for concatenated LoRA parameters [vec(A); vec(B)].
    Returns F + damping * I.
    Returns None if gradients are not available or shapes are problematic.
    """
    if not hasattr(lora_layer, 'lora_A') or not hasattr(lora_layer, 'lora_B') or \
       lora_layer.lora_A is None or lora_layer.lora_B is None:
        # print(f"compute_lora_fisher: lora_A or lora_B not found in {type(lora_layer).__name__}")
        return None
        
    A = lora_layer.lora_A
    B = lora_layer.lora_B

    if A.grad is None or B.grad is None:
        # print(f"compute_lora_fisher: Gradients for lora_A or lora_B are None in {type(lora_layer).__name__}")
        return None
    
    # Ensure A and B have expected dimensions (rank > 0)
    if A.ndim < 2 or B.ndim < 2 or A.shape[1] == 0 or B.shape[0] == 0: # rank is A.shape[1] or B.shape[0]
        # print(f"compute_lora_fisher: lora_A or lora_B has zero rank or unexpected ndim. A_shape: {A.shape}, B_shape: {B.shape}")
        return None

    # grad_A = A.grad.reshape(-1,1) # Not used in F computation directly
    # grad_B = B.grad.reshape(-1,1) # Not used in F computation directly
    
    # Fisher for concatenated [vec(A); vec(B)]
    # F = torch.kron(B @ B.T, torch.eye(A.shape[0])) + \
    #     torch.kron(torch.eye(B.shape[1]), A.T @ A)
    # This Fisher is for FIM of W_eff = W_base + AB if A,B are small perturbations.
    # Or it is an approximation of the FIM for parameters A and B.
    # Let's assume A.shape = (in_features, rank), B.shape = (rank, out_features)
    
    in_features_A = A.shape[0]
    rank_A = A.shape[1]
    rank_B = B.shape[0]
    out_features_B = B.shape[1]

    if rank_A != rank_B:
        # print(f"compute_lora_fisher: Rank mismatch between lora_A ({rank_A}) and lora_B ({rank_B}). Cannot compute Fisher.")
        return None
    if rank_A == 0 : # current_rank can be 0
        # print(f"compute_lora_fisher: Rank is 0. Cannot compute Fisher.")
        return None

    try:
        # Term 1: kron(B @ B.T, I_in)  where B is (rank, out_features)
        # B @ B.T gives (rank, rank)
        # I_in is (in_features, in_features)
        term1_kron1 = B @ B.T
        term1_kron2 = torch.eye(in_features_A, device=A.device, dtype=A.dtype)
        F_term1 = torch.kron(term1_kron1, term1_kron2) # Shape: (rank*in, rank*in) - This seems for vec(A)

        # Term 2: kron(I_out, A.T @ A) where A is (in_features, rank)
        # A.T @ A gives (rank, rank)
        # I_out is (out_features, out_features)
        term2_kron1 = torch.eye(out_features_B, device=B.device, dtype=B.dtype)
        term2_kron2 = A.T @ A 
        F_term2 = torch.kron(term2_kron1, term2_kron2) # Shape: (out*rank, out*rank) - This seems for vec(B)
        
        # The paper's structure for F_LoRA = [F_A, F_AB; F_BA, F_B] might be more appropriate.
        # The formula sums two Kronecker products directly.
        # This implies the FIM is block diagonal for vec(A) and vec(B), or this is a simplification.
        # If F is for [vec(A); vec(B)], its size should be (in*rank + rank*out, in*rank + rank*out).
        # The provided formula F = kron(B B^T, I_A_in) + kron(I_B_out, A^T A)
        # assumes that the Fisher is block-diagonal and these are the blocks for A and B.
        # Or, that vec(A) and vec(B) are independent.
        # Let's assume this is an approximation of F_A and F_B that are then used separately.
        # OR, F_A = B B.T kronecker I_in_A  and F_B = I_out_B kronecker A.T A
        # This means F_A is (in_features * rank, in_features * rank)
        # And F_B is (out_features * rank, out_features * rank)
        # If so, the Fisher matrix F for the concatenated parameters [vec(A); vec(B)]
        # would be a block diagonal matrix: diag(F_A, F_B)
        # F_A_dim = in_features_A * rank_A
        # F_B_dim = out_features_B * rank_A # rank_A == rank_B
        
        # If the formula is for the full FIM of [vec(A); vec(B)] and it's a sum,
        # then F_term1 and F_term2 must have the same dimensions. This is not generally true.
        # (rank*in_features_A) vs (rank*out_features_B)
        # This strongly suggests the formula is for a simplified or block-diagonal Fisher.

        # Perhaps F_A = torch.kron(B @ B.T, torch.eye(A.shape[0]))  (for vec(A))
        # And F_B_approx = torch.kron(torch.eye(B.shape[1]), A.T @ A) (for vec(B))
        # And the overall update is done separately for A and B using these.
        # F is intended as a single matrix of a particular structure.
        
        # Given the direct sum, they MUST be the same size. This only happens if in_features_A == out_features_B.
        # This is very restrictive.
        
        # Block diagonal Fisher:
        # F_A_block = torch.kron(B.data @ B.data.T, torch.eye(in_features_A, device=A.device, dtype=A.dtype))
        # F_B_block = torch.kron(torch.eye(out_features_B, device=B.device, dtype=B.dtype), A.data.T @ A.data)

        # Total dimension for F for [vec(A); vec(B)]
        total_param_dim = (in_features_A * rank_A) + (rank_A * out_features_B)
        F_matrix = torch.zeros(total_param_dim, total_param_dim, device=A.device, dtype=A.dtype)

        # Populate F_A_block for vec(A) parameters
        dim_A_flat = in_features_A * rank_A
        # Using .detach() for safer access to tensor data without autograd history
        F_A_block_calc = torch.kron(B.detach() @ B.detach().T, torch.eye(in_features_A, device=A.device, dtype=A.dtype))
        F_matrix[:dim_A_flat, :dim_A_flat] = F_A_block_calc

        # Populate F_B_block for vec(B) parameters
        dim_B_flat = rank_A * out_features_B
        # Using .detach() for safer access to tensor data without autograd history
        F_B_block_calc = torch.kron(torch.eye(out_features_B, device=B.device, dtype=B.dtype), A.detach().T @ A.detach())
        F_matrix[dim_A_flat:, dim_A_flat:] = F_B_block_calc
        
        # Add damping to the combined block-diagonal Fisher
        return F_matrix + damping * torch.eye(total_param_dim, device=A.device, dtype=A.dtype)

    except Exception as e:
        print(f"Error in compute_lora_fisher for {type(lora_layer).__name__}: {e}")
        return None

In [17]:
# 3.2 LoRA-Constrained Update (adapted for FlexLoRA)

def natural_gradient_step(model: nn.Module, lr: float = 0.001, damping_lora_fisher: float = 1e-6): # Added damping_lora_fisher
    """
    Performs a natural gradient update step for LoRA parameters using LoRA-specific Fisher.
    """
    for module_name, module in model.named_modules():
        if isinstance(module, FlexLoRA):
            if module.current_rank == 0: # Skip if LoRA rank is zero
                # print(f"Skipping NG for {module_name}: LoRA rank is 0.")
                continue
            if not hasattr(module, 'lora_A') or not hasattr(module, 'lora_B') or \
               module.lora_A.grad is None or module.lora_B.grad is None:
                # print(f"Skipping NG for {module_name}: missing LoRA params or grads.")
                continue

            # Compute LoRA-specific Fisher
            F_lora = compute_lora_fisher(module, damping=damping_lora_fisher)

            if F_lora is None or F_lora.numel() == 0:
                # print(f"Skipping NG for {module_name}: LoRA Fisher computation failed or resulted in empty tensor. Falling back to SGD.")
                with torch.no_grad():
                    if module.lora_A.grad is not None:
                        module.lora_A.data -= lr * module.lora_A.grad
                        module.lora_A.grad.zero_()
                    if module.lora_B.grad is not None:
                        module.lora_B.data -= lr * module.lora_B.grad
                        module.lora_B.grad.zero_()
                continue
            
            grad_A_flat = module.lora_A.grad.flatten()
            grad_B_flat = module.lora_B.grad.flatten()
            grad_cat = torch.cat([grad_A_flat, grad_B_flat]) # Shape: (in*rank + rank*out)
            
            if F_lora.shape[0] != grad_cat.shape[0] or F_lora.shape[1] != grad_cat.shape[0]:
                print(f"Warning: Shape mismatch for F_lora_inv @ grad_cat in LoRA layer {module_name}. "
                      f"F_lora shape: {F_lora.shape}, grad_cat shape: {grad_cat.shape}. "
                      f"Falling back to SGD for this LoRA layer.")
                with torch.no_grad():
                    if module.lora_A.grad is not None:
                        module.lora_A.data -= lr * module.lora_A.grad
                        module.lora_A.grad.zero_()
                    if module.lora_B.grad is not None:
                        module.lora_B.data -= lr * module.lora_B.grad
                        module.lora_B.grad.zero_()
                continue

            try:
                # Invert Fisher: F_lora_inv = F_lora.inverse() or torch.linalg.pinv(F_lora)
                # Using pinv for stability
                F_lora_inv = torch.linalg.pinv(F_lora)
                
                with torch.no_grad():
                    ng_update = F_lora_inv @ grad_cat
                
                    num_elements_A = module.lora_A.numel()
                    module.lora_A.data -= lr * ng_update[:num_elements_A].reshape_as(module.lora_A)
                    module.lora_B.data -= lr * ng_update[num_elements_A:].reshape_as(module.lora_B)
                
                    if module.lora_A.grad is not None: module.lora_A.grad.zero_()
                    if module.lora_B.grad is not None: module.lora_B.grad.zero_()
            except RuntimeError as e:
                print(f"RuntimeError during natural gradient update for {module_name}: {e}. Falling back to SGD.")
                with torch.no_grad(): # Fallback to SGD
                    if module.lora_A.grad is not None:
                        module.lora_A.data -= lr * module.lora_A.grad
                        module.lora_A.grad.zero_()
                    if module.lora_B.grad is not None:
                        module.lora_B.data -= lr * module.lora_B.grad
                        module.lora_B.grad.zero_()

### 4. Neural Reprojection (adapted for FlexLoRA)


In [18]:
def neural_reprojection(
    lora_layer: FlexLoRA, 
    rho_eigen_sum: float = 0.9, 
    damping_lora_fisher: float = 1e-6 # Damping for LoRA Fisher used here
    ):
    """
    Performs neural reprojection on LoRA parameters.
    Uses a LoRA-specific Fisher matrix to find the principal subspace for reprojection.
    WARNING: The projection math (delta_A = lora_A @ U_k @ U_k.T) is dimensionally problematic
    and is SKIPPED. This function currently only calculates k and U_k.
    """
    if not isinstance(lora_layer, FlexLoRA):
        return
    if lora_layer.current_rank == 0:
        return
    
    # Compute LoRA-specific Fisher matrix (F_lora for [vec(A); vec(B)])
    F_lora = compute_lora_fisher(lora_layer, damping=damping_lora_fisher)

    if F_lora is None or F_lora.numel() == 0 or F_lora.shape[0] == 0:
        # print(f"Skipping reprojection for {type(lora_layer).__name__}: LoRA Fisher not available or empty.")
        return

    try:
        # SVD of F_lora. S contains singular values (eigenvalues if F_lora is PSD)
        # U contains left singular vectors (eigenvectors if F_lora is symmetric)
        # F_lora should be symmetric by construction (if A, B data used, or if it's from KFAC factors that are symm)
        # The compute_lora_fisher builds it block-diagonally from symmetric blocks.
        U_f, S_eigenvalues_f, Vh_f = torch.linalg.svd(F_lora)
    except Exception as e:
        # print(f"SVD failed for F_lora in neural_reprojection for {type(lora_layer).__name__}: {e}. Skipping.")
        return
    
    if S_eigenvalues_f.numel() == 0:
        # print(f"Warning: No singular values from SVD of F_lora in {type(lora_layer).__name__}. Skipping reprojection.")
        return

    total_eigenvalue_sum = torch.sum(S_eigenvalues_f)
    if total_eigenvalue_sum <= 1e-9: # Avoid division by zero or issues with non-positive sums
        # print(f"Warning: Total eigenvalue sum of F_lora is not positive ({total_eigenvalue_sum}) for {type(lora_layer).__name__}. Cannot determine k. Skipping.")
        return
        
    cumulative_eigenvalue_sum = torch.cumsum(S_eigenvalues_f, dim=0)
    k_candidates = torch.where(cumulative_eigenvalue_sum >= rho_eigen_sum * total_eigenvalue_sum)[0]
    
    if len(k_candidates) > 0:
        k = k_candidates[0].item() + 1 
    else:
        k = S_eigenvalues_f.shape[0] 

    k = max(1, min(k, S_eigenvalues_f.shape[0]))

    # print(f"Calculated k for reprojection in {type(lora_layer).__name__}: {k} (from {S_eigenvalues_f.shape[0]} total F_lora eigenvalues, rho={rho_eigen_sum})")
    
    # U_k are the top k eigenvectors of F_lora.
    # These eigenvectors correspond to the concatenated [vec(A); vec(B)] space.
    U_k_lora_fisher = U_f[:, :k] # Shape: ( (in*rank + rank*out), k_eigen_reproj )
    
    # Original projection logic was commented out due to dimensional issues:
    # delta_A = layer.lora_A @ U_k @ U_k.T # This U_k would need to be for A's space
    # delta_B = U_k.T @ layer.lora_B # This U_k would need to be for B's space
    
    # If U_k_lora_fisher is for the concatenated space, applying it directly to lora_A and lora_B separately is not straightforward.
    # The reprojection should be: new_params_flat = U_k @ U_k.T @ current_params_flat
    # Then unflatten and update lora_A and lora_B.
    
    # Flatten current LoRA parameters
    lora_A_flat = lora_layer.lora_A.data.flatten()
    lora_B_flat = lora_layer.lora_B.data.flatten()
    lora_params_flat = torch.cat([lora_A_flat, lora_B_flat])

    if U_k_lora_fisher.shape[0] != lora_params_flat.shape[0]:
        # print(f"Warning: Dimension mismatch in neural_reprojection for {type(lora_layer).__name__}. "
        #       f"U_k_lora_fisher dimension {U_k_lora_fisher.shape[0]} vs LoRA params dimension {lora_params_flat.shape[0]}. Skipping update.")
        return

    # Project onto the subspace spanned by U_k_lora_fisher
    projected_lora_params_flat = U_k_lora_fisher @ (U_k_lora_fisher.T @ lora_params_flat)
    
    # Reshape and update LoRA parameters
    with torch.no_grad():
        num_elements_A = lora_layer.lora_A.numel()
        new_lora_A_flat = projected_lora_params_flat[:num_elements_A]
        new_lora_B_flat = projected_lora_params_flat[num_elements_A:]
        
        lora_layer.lora_A.data.copy_(new_lora_A_flat.reshape_as(lora_layer.lora_A.data))
        lora_layer.lora_B.data.copy_(new_lora_B_flat.reshape_as(lora_layer.lora_B.data))
        
    # print(f"Neural reprojection applied to {type(lora_layer).__name__} using U_k from LoRA Fisher (k={k}).")
    pass # Actual update logic is now included above.

In [19]:
class SVTracker(nn.Module):
    """
    Tracks the history of singular values for a FlexLoRA layer's effective matrix (lora_A @ lora_B).
    Assumes the associated FlexLoRA layer has `lora_A`, `lora_B`, and `max_rank` attributes.
    """
    def __init__(self, flex_lora_layer, window_size: int = 100):
        super().__init__()
        if not hasattr(flex_lora_layer, 'lora_A') or not hasattr(flex_lora_layer, 'lora_B'):
            raise ValueError("Provided layer must have 'lora_A' and 'lora_B' parameters.")
        if not hasattr(flex_lora_layer, 'max_rank'):
            raise ValueError("Provided FlexLoRA layer must have a 'max_rank' attribute.")
            
        self.flex_lora_layer = flex_lora_layer
        self.window_size = window_size
        
        self.register_buffer('sv_history', 
                             torch.zeros(window_size, flex_lora_layer.max_rank, 
                                         device=flex_lora_layer.lora_A.device))

    def forward(self):
        """Calculates SVD of lora_A @ lora_B and updates history."""
        with torch.no_grad():
            lora_A_data = self.flex_lora_layer.lora_A.data
            lora_B_data = self.flex_lora_layer.lora_B.data
            effective_matrix = lora_A_data @ lora_B_data
            
            try:
                # SVD: U, S, Vh = svd(A). S are singular values.
                S = torch.linalg.svdvals(effective_matrix)
            except Exception as e:
                # print(f"SVD failed in SVTracker: {e}. Skipping history update.")
                return

            # Pad or truncate S to match self.flex_lora_layer.max_rank for consistent history storage
            num_singular_values = S.shape[0]
            max_r = self.flex_lora_layer.max_rank
            
            s_padded = torch.zeros(max_r, device=S.device)
            if num_singular_values > 0:
                s_padded[:min(num_singular_values, max_r)] = S[:min(num_singular_values, max_r)]
            
            self.sv_history = torch.roll(self.sv_history, shifts=1, dims=0)
            self.sv_history[0, :] = s_padded.detach()

In [20]:
class RankScheduler:
    def __init__(self, initial_rank: int, min_rank: int, max_rank: int, # Added min/max rank from FlexLoRA
                 warmup_steps: int = 1000, ema_alpha: float = 0.1): # Added ema_alpha for sv_ema
        self.current_rank = initial_rank
        self.min_rank = min_rank # Store min_rank
        self.max_rank = max_rank # Store max_rank
        self.warmup_steps = warmup_steps
        self.step_counter = 0
        self.ema_alpha = ema_alpha # For smoothing sv_ratio
        self.sv_ema = None # EMA of singular value ratio

    def update(self, sv_ratio: float): # sv_ratio is sigma_1/sigma_r from rank_adjustment's perspective
        self.step_counter += 1
        
        if self.sv_ema is None:
            self.sv_ema = sv_ratio
        else:
            self.sv_ema = self.ema_alpha * sv_ratio + (1 - self.ema_alpha) * self.sv_ema

        if self.step_counter < self.warmup_steps:
            return self.current_rank # Keep current rank during warmup

        # Logic:
        # if self.step_counter > 0.8: # Typo: self.step_counter > 0.8 ? Should be sv_ema
        # Assuming the condition is based on self.sv_ema
        if self.sv_ema is not None: # Check if sv_ema is initialized
            if self.sv_ema < 0.5: # If ratio is small, indicates potential for rank reduction
                # Decrease rank, but not below min_rank
                self.current_rank = max(self.min_rank, self.current_rank - 2) 
            elif self.sv_ema > 0.8: # If ratio is large, potential for rank increase
                # Increase rank, but not above max_rank
                self.current_rank = min(self.max_rank, self.current_rank + 2)
        
        return self.current_rank
    
    def get_rank(self):
        return self.current_rank

In [21]:
def rank_adjustment(
    layer: FlexLoRA, 
    sv_tracker: SVTracker, 
    rank_scheduler: RankScheduler, # Pass the scheduler instance for this layer
    # rank_eta, rank_tau are removed as scheduler now handles logic
    sv_history_min_samples: int = 5 
    ):
    """
    Adjusts the rank of a FlexLoRA layer using an external RankScheduler
    based on its singular value history (sigma_1/sigma_r ratio from mean_sv).
    """
    if not all(hasattr(layer, attr) for attr in 
               ['current_rank', 'min_rank', 'max_rank', 'lora_A', 'lora_B']):
        raise ValueError("FlexLoRA layer is missing required attributes for rank adjustment.")

    if sv_tracker.sv_history[:, 0].count_nonzero() < sv_history_min_samples :
        # print(f"Warning: Not enough singular value history for layer. Skipping rank adjustment. Need {sv_history_min_samples} samples.")
        return

    with torch.no_grad():
        mean_sv = sv_tracker.sv_history.mean(dim=0) 
        
        current_rank_val = layer.current_rank # Use a different name to avoid conflict with scheduler's current_rank
        if current_rank_val == 0 or current_rank_val > layer.max_rank: # current_rank_val == 0 means LoRA is inactive or uninitialized for adjustment
            return
        
        if mean_sv.numel() == 0 or current_rank_val > mean_sv.shape[0]:
            # print("Warning: Mean singular values not available or current rank exceeds available SVs. Skipping rank adjustment.")
            return

        sigma_1 = mean_sv[0].item()
        sigma_r = mean_sv[current_rank_val - 1].item() if current_rank_val > 0 else 0.0

        if sigma_1 < 1e-9 or sigma_r < 1e-9: # Avoid division by zero or unstable ratios
            # print(f"Warning: sigma_1 ({sigma_1}) or sigma_r ({sigma_r}) is near zero. Skipping rank update based on sv_ratio.")
            return 
        
        sv_ratio_metric = sigma_1 / sigma_r
        
        new_rank = rank_scheduler.update(sv_ratio_metric) 
        new_rank = int(max(layer.min_rank, min(new_rank, layer.max_rank)))


        if new_rank == current_rank_val:
            return
        
        old_lora_A = layer.lora_A.data
        old_lora_B = layer.lora_B.data
        device = old_lora_A.device

        # Preserve weights using SVD of W_lora = old_lora_A @ old_lora_B
        # This is generally more robust than padding/truncating A and B directly.
        W_lora = old_lora_A @ old_lora_B
        try:
            U, S, Vh = torch.linalg.svd(W_lora)
        except Exception as e:
            print(f"SVD failed during rank adjustment for layer {type(layer).__name__}: {e}. Rank not changed.")
            return

        # Determine the effective rank for SVD reconstruction.
        # This rank cannot exceed the number of singular values, S.shape[0] (min of W_lora dimensions).
        # layer.max_rank should ideally be configured to be <= S.shape[0].
        effective_rank_for_svd = new_rank
        if new_rank > S.shape[0]:
            print(f"Warning: Target new_rank {new_rank} exceeds available singular values {S.shape[0]} for layer {type(layer).__name__}. "
                  f"Capping rank to {S.shape[0]}. Consider adjusting layer.max_rank.")
            effective_rank_for_svd = S.shape[0]
            new_rank = effective_rank_for_svd # Update new_rank to the capped value for consistency

        if new_rank == current_rank_val and effective_rank_for_svd == new_rank : # Check again if capping resulted in no change
             return


        # Reconstruct LoRA matrices A' and B' such that A'B' approximates W_lora.
        # A_new = U[:, :k] @ diag(sqrt(S[:k]))
        # B_new = diag(sqrt(S[:k])) @ Vh[:k, :]
        # This distributes singular values (sqrt) to both A and B.
        
        # Handle case where effective_rank_for_svd is 0 (e.g. if min_rank is 0 and it's chosen)
        if effective_rank_for_svd == 0:
            new_lora_A_data = torch.empty((old_lora_A.shape[0], 0), device=device, dtype=old_lora_A.dtype)
            new_lora_B_data = torch.empty((0, old_lora_B.shape[1]), device=device, dtype=old_lora_B.dtype)
        else:
            U_k = U[:, :effective_rank_for_svd]
            S_k = S[:effective_rank_for_svd]
            Vh_k = Vh[:effective_rank_for_svd, :]

            # Singular values (S_k) are non-negative. torch.sqrt(0) is 0.
            sqrt_S_k = torch.sqrt(S_k)
            diag_sqrt_S_k = torch.diag(sqrt_S_k)

            new_lora_A_data = U_k @ diag_sqrt_S_k
            new_lora_B_data = diag_sqrt_S_k @ Vh_k
        
        layer.lora_A = nn.Parameter(new_lora_A_data.to(device))
        layer.lora_B = nn.Parameter(new_lora_B_data.to(device))
        layer.current_rank = new_rank # new_rank might have been capped by S.shape[0]
        rank_scheduler.current_rank = new_rank # Also update scheduler's internal current_rank

### 5. Parameter Fusion

In [22]:
# 5.1 Momentum Smoothing (Using existing robust version from Training.py)
class MomentumFusion:
    def __init__(self, beta: float = 0.9):
        self.beta = beta
        self.momentum_buffer = {} # Using dict for FlexLoRA parameters by name

    def __call__(self, model: nn.Module):
        with torch.no_grad():
            for name, module in model.named_modules():
                if isinstance(module, FlexLoRA):
                    if module.current_rank == 0: continue # Skip if no LoRA adaptation

                    params_to_fuse = {}
                    if hasattr(module, 'lora_A') and module.lora_A is not None:
                        params_to_fuse["lora_A"] = module.lora_A
                    if hasattr(module, 'lora_B') and module.lora_B is not None:
                        params_to_fuse["lora_B"] = module.lora_B
                    
                    for p_name, param in params_to_fuse.items():
                        if param.requires_grad: # Only fuse trainable LoRA params
                            full_param_name = f"{name}.{p_name}"
                            if full_param_name not in self.momentum_buffer:
                                self.momentum_buffer[full_param_name] = torch.zeros_like(param.data)
                            
                            buf = self.momentum_buffer[full_param_name]
                            buf.mul_(self.beta).add_(param.data, alpha=1 - self.beta)
                            param.data.copy_(buf)

### Dynamic Rank Adjustment System

In [23]:
# 5.2 Spectral Norm Enforcement (Using existing robust version from Training.py)
def enforce_spectral_constraints(model: nn.Module, max_singular_value: float = 1.0):
    """
    Enforces spectral norm constraints on the base weights of FlexLoRA layers.
    """
    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, FlexLoRA):
                # Apply to the original base layer's weight
                if hasattr(module.base_layer, 'weight') and module.base_layer.weight is not None:
                    base_weight = module.base_layer.weight 
                    
                    U, S, V_T = torch.linalg.svd(base_weight.data, full_matrices=False)
                    S_clamped = torch.clamp(S, max=max_singular_value)
                    
                    # Reconstruct the weight matrix
                    new_weight = U @ torch.diag(S_clamped) @ V_T
                    base_weight.data.copy_(new_weight)

In [24]:
# Training Loop Integration
def grit_train_epoch(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer, # Optimizer for non-LoRA params
    fusion_module: MomentumFusion,
    sv_trackers: dict,                 # NEW: Dict mapping FlexLoRA layer names to SVTracker instances
    rank_schedulers: dict,             # NEW: Dict mapping FlexLoRA layer names to RankScheduler instances
    device: torch.device,
    current_epoch: int,
    lr_natural_gradient: float = 0.001, # LR for natural_gradient_step
    damping_lora_fisher: float = 1e-6, # Damping for LoRA Fisher in NG and Reprojection
    reprojection_rho_eigen_sum: float = 0.9, # RENAMED from reprojection_keep_ratio, new default example
    spectral_constraint_sv_max: float = 1.0,
    rank_adjust_sv_history_min_samples: int = 5 # NEW: Min samples for rank adjustment decisions
    ):
    
    model.train()
    
    # Instrument model hooks (called once per epoch)
    # Note: This means factors A and B are EMA over the entire epoch.
    # KFAC typically updates factors more frequently or uses per-batch stats.
    instrument_forward(model)  # Attaches A_factor
    instrument_backward(model) # Attaches B_factor and kron_factor
    
    total_loss = 0.0
    for batch_idx, batch in enumerate(dataloader):
        # Adapt this to your specific batch structure.
        input_ids = batch[0].to(device)
        targets = batch[1].to(device)
        # attention_mask = batch.get('attention_mask', None) # If used
        # if attention_mask is not None: attention_mask = attention_mask.to(device)

        optimizer.zero_grad() # Zero gradients for all parameters

        # Standard forward pass
        # KFAC forward hooks run here (from instrument_forward), updating layer.A_factor
        # outputs = model(input_ids, attention_mask=attention_mask)
        outputs = model(input_ids) 
        
        loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        
        # Standard backward pass
        # KFAC backward hooks run here (from instrument_backward), updating layer.B_factor and layer.kron_factor
        loss.backward()       # Compute gradients for all trainable parameters

        # --- GRIT Specific Steps ---
        # 1. Natural Gradient Step (for LoRA parameters)
        # This function internally handles LoRA grads (updates and zeros them if successful or falls back to SGD).
        natural_gradient_step(model, lr=lr_natural_gradient, damping_lora_fisher=damping_lora_fisher)
        
        # Standard optimizer step for any other parameters (e.g., embeddings, non-LoRA layers, classification head)
        # LoRA parameters whose grads were handled by natural_gradient_step (and zeroed) won't be affected.
        optimizer.step() # Optimizer step AFTER NG step.

        # 2. Dynamic Rank Allocation for FlexLoRA layers (NEW)
        # This step adjusts ranks and potentially re-initializes LoRA parameters for the *next* iteration.
        # It occurs after gradient updates for the current iteration.
        # Assumes SVTracker and RankScheduler instances are provided for relevant layers.
        # Assumes rank_adjustment and SVTracker.forward are defined elsewhere.
        with torch.no_grad(): # Rank adjustment should not affect gradient computation
            for name, module_candidate in model.named_modules():
                if isinstance(module_candidate, FlexLoRA):
                    # Ensure the layer is active and has corresponding tracker/scheduler
                    if module_candidate.current_rank > 0 and \
                       name in sv_trackers and \
                       name in rank_schedulers:
                        
                        current_sv_tracker = sv_trackers[name]
                        current_rank_scheduler = rank_schedulers[name]
                        
                        # Call the SVTracker's forward method.
                        # This method is assumed to update the tracker's internal state,
                        # possibly using data collected by hooks during the model's forward/backward pass,
                        # or by analyzing current weights/gradients.
                        current_sv_tracker.forward() 
                        
                        # Perform rank adjustment
                        rank_adjustment(
                            layer=module_candidate, 
                            sv_tracker=current_sv_tracker, 
                            rank_scheduler=current_rank_scheduler, 
                            sv_history_min_samples=rank_adjust_sv_history_min_samples
                        )
                
        # 3. Neural Reprojection (applied to FlexLoRA layers) - (was 2)
        # This happens every step.
        for module_candidate in model.modules():
            if isinstance(module_candidate, FlexLoRA):
                # The neural_reprojection function expects a FlexLoRA layer
                # and uses its base_layer.kron_factor.
                neural_reprojection(module_candidate, rho_eigen_sum=reprojection_rho_eigen_sum, damping_lora_fisher=damping_lora_fisher)
                
        # 4. Parameter Fusion (LoRA momentum) - (was 3)
        fusion_module(model)
        
        # 5. Spectral Constraints (on base weights of LoRA layers) - (was 4)
        enforce_spectral_constraints(model, max_singular_value=spectral_constraint_sv_max)

        total_loss += loss.item()
        
        if batch_idx % 50 == 0 : # Simple logging
            print(f"Epoch {current_epoch}, Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {current_epoch} Average Loss: {avg_loss:.4f}")
    return avg_loss

In [25]:
# --- Dummy Model and Data for Illustration ---
class SimpleModelWithLoRA(nn.Module):
    def __init__(self, vocab_size=100, dim=32, lora_rank_val=4): # Renamed lora_rank to lora_rank_val
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.linear1 = nn.Linear(dim, dim * 2) # Non-LoRA
        self.relu = nn.ReLU()
        # This nn.Linear will be replaced by FlexLoRA
        self.lora_target_linear = nn.Linear(dim * 2, dim) 
        self.output_head = nn.Linear(dim, vocab_size) # Non-LoRA
        
        # Apply LoRA to self.lora_target_linear
        # The lora_inject_model function is from Grit.Infrastructure
        # It replaces the nn.Linear module named 'lora_target_linear' with a FlexLoRA module.
        # The FlexLoRA module will then contain the original nn.Linear as its `base_layer`.
        lora_config = {
            # Substring matching for names like "lora_target_linear"
            'layers_to_adapt_substrings': ['lora_target_linear'], 
            'rank_range': (lora_rank_val, lora_rank_val), # Fixed rank for simplicity
            # 'lora_alpha' is not explicitly in FlexLoRA init in Infrastructure.py, but can be added if needed by FlexLoRA
        }
        # lora_inject_model is expected to modify the model in-place.
        # It searches for modules whose names contain substrings from 'layers_to_adapt_substrings'
        # and replaces them with FlexLoRA instances.
        lora_inject(self, **lora_config)


    def forward(self, x):
        x = self.embedding(x)
        x = x.mean(dim=1) 
        x = self.relu(self.linear1(x))
        x = self.lora_target_linear(x) # This is now the FlexLoRA wrapped layer
        x = self.output_head(x)
        return x

In [26]:
# --- Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

vocab_size_const = 100
dim_const = 32
lora_rank_const = 4 # For FlexLoRA
batch_size_const = 8
seq_len_const = 10

# Instantiate model
model_instance = SimpleModelWithLoRA(vocab_size_const, dim_const, lora_rank_const).to(device)
print("Model structure after LoRA injection:")
print(model_instance)

Using device: cuda
Model structure after LoRA injection:
SimpleModelWithLoRA(
  (embedding): Embedding(100, 32)
  (linear1): Linear(in_features=32, out_features=64, bias=True)
  (relu): ReLU()
  (lora_target_linear): FlexLoRA(initial_rank=4, current_rank=4, min_rank=4, max_rank=4, base_layer=Linear(64x32))
  (output_head): Linear(in_features=32, out_features=100, bias=True)
)


In [27]:
# --- Verify LoRA injection and find FlexLoRA layers for later checks ---
found_flexlora = False
for name, module in model_instance.named_modules():
    if isinstance(module, FlexLoRA):
        print(f"Found FlexLoRA layer: {name} with rank {module.current_rank}")
        found_flexlora = True
        # Expected: `lora_target_linear` should be FlexLoRA
        # Its `base_layer` should be the original nn.Linear(dim * 2, dim)
        if hasattr(module, 'base_layer') and isinstance(module.base_layer, nn.Linear):
            print(f"  Base layer of {name}: {module.base_layer}")
        else:
            print(f"  Error: FlexLoRA layer {name} does not have a valid nn.Linear base_layer.")
if not found_flexlora:
    print("Error: No FlexLoRA layers were injected into the model. Check lora_inject_model and config.")
    exit() # Stop if LoRA injection failed, as GRIT steps depend on it.


# Dummy Dataloader
dummy_input_data = torch.randint(0, vocab_size_const, (batch_size_const * 5, seq_len_const))
dummy_labels_data = torch.randint(0, vocab_size_const, (batch_size_const * 5,)) # For CrossEntropyLoss
dummy_dataset_instance = torch.utils.data.TensorDataset(dummy_input_data, dummy_labels_data)
dataloader_instance = torch.utils.data.DataLoader(dummy_dataset_instance, batch_size=batch_size_const)

# Loss and Optimizer
# Optimizer for non-LoRA params and potentially base model if not frozen by FlexLoRA.
# FlexLoRA freezes base layer, so this optimizer trains embedding, linear1, output_head.
# LoRA params are updated by natural_gradient_step (or SGD fallback within it).
criterion_instance = nn.CrossEntropyLoss()
optimizer_instance = torch.optim.AdamW(
    [p for p in model_instance.parameters() if p.requires_grad], # Should pick up non-LoRA params
    lr=1e-3
)

# GRIT components
momentum_fuser_instance = MomentumFusion(beta=0.9)

Found FlexLoRA layer: lora_target_linear with rank 4
  Base layer of lora_target_linear: Linear(in_features=64, out_features=32, bias=True)


In [28]:
# --- Setup for SVTrackers and RankSchedulers ---
# and model_instance is already created and populated with FlexLoRA layers.
# FlexLoRA modules are expected to have 'lora_A', 'lora_B', 'current_rank', and 'max_rank' attributes.
# The 'min_rank' attribute on FlexLoRA modules is optional; if not present, a default is used for RankScheduler.

sv_trackers = {}
rank_schedulers = {}

# Configuration constants for SVTracker and RankScheduler
# These values are illustrative and can be adjusted or moved to a central configuration.
SV_TRACKER_WINDOW_SIZE_CONST = 100
RANK_SCHEDULER_WARMUP_STEPS_CONST = 1000 # As per RankScheduler example
RANK_SCHEDULER_EMA_ALPHA_CONST = 0.1    # As per RankScheduler example
DEFAULT_MIN_RANK_FOR_SCHEDULER = 1      # Default min_rank if not on FlexLoRA module
# Parameter for rank_adjustment logic (called inside grit_train_epoch)
RANK_ADJUST_SV_HISTORY_MIN_SAMPLES_CONST = 5 

print("--- Initializing SVTrackers and RankSchedulers for FlexLoRA layers ---")
flexlora_layers_configured_count = 0
for name, module in model_instance.named_modules():
    if isinstance(module, FlexLoRA): # FlexLoRA class must be in scope
        # Validate that the module has the necessary attributes for SVTracker and RankScheduler
        if not all(hasattr(module, attr) for attr in ['lora_A', 'lora_B']):
            print(f"  Skipping {name}: FlexLoRA layer is missing 'lora_A' or 'lora_B' attributes required by SVTracker.")
            continue
        if not hasattr(module, 'max_rank'):
            print(f"  Skipping {name}: FlexLoRA layer is missing 'max_rank' attribute required by SVTracker and RankScheduler.")
            continue
        if not hasattr(module, 'current_rank'): # Used as initial_rank for RankScheduler
            print(f"  Skipping {name}: FlexLoRA layer is missing 'current_rank' attribute for RankScheduler.")
            continue

        # Determine min_rank for RankScheduler
        # Uses module.min_rank if available, otherwise defaults to DEFAULT_MIN_RANK_FOR_SCHEDULER
        min_rank_val = getattr(module, 'min_rank', DEFAULT_MIN_RANK_FOR_SCHEDULER)

        # Ensure rank configuration is valid: min_rank <= current_rank (initial) <= max_rank
        if not (min_rank_val <= module.current_rank <= module.max_rank):
            print(f"  Skipping {name}: Invalid rank configuration. "
                  f"Condition min_rank ({min_rank_val}) <= initial_rank ({module.current_rank}) "
                  f"<= max_rank ({module.max_rank}) is not met.")
            continue
        
        try:
            sv_trackers[name] = SVTracker(flex_lora_layer=module, window_size=SV_TRACKER_WINDOW_SIZE_CONST)
            rank_schedulers[name] = RankScheduler(
                initial_rank=module.current_rank,
                min_rank=min_rank_val,
                max_rank=module.max_rank,
                warmup_steps=RANK_SCHEDULER_WARMUP_STEPS_CONST,
                ema_alpha=RANK_SCHEDULER_EMA_ALPHA_CONST
            )
            print(f"  Initialized SVTracker and RankScheduler for FlexLoRA layer: {name} "
                  f"(initial_rank={module.current_rank}, min_rank={min_rank_val}, max_rank={module.max_rank})")
            flexlora_layers_configured_count += 1
        except ValueError as e:
            print(f"  Error initializing SVTracker or RankScheduler for {name}: {e}")
        except Exception as e: # Catch any other unexpected errors during init
            print(f"  Unexpected error initializing for {name}: {e}")


if flexlora_layers_configured_count == 0:
    print("Warning: No FlexLoRA layers were successfully configured with SVTrackers and RankSchedulers. "
          "Dynamic rank adjustment might not function as expected.")
else:
    print(f"Successfully initialized SVTrackers and RankSchedulers for {flexlora_layers_configured_count} FlexLoRA layer(s).")


# --- Training Loop ---
num_epochs_const = 2
print(f"--- Starting GRIT Training for {num_epochs_const} epochs ---")
for epoch_num in range(num_epochs_const):
    print(f"--- Epoch {epoch_num + 1} ---")
    avg_epoch_loss_val = grit_train_epoch(
        model=model_instance,
        dataloader=dataloader_instance,
        criterion=criterion_instance,
        optimizer=optimizer_instance,
        fusion_module=momentum_fuser_instance,
        device=device,
        current_epoch=epoch_num + 1,
        lr_natural_gradient=5e-4, 
        damping_lora_fisher=1e-5, # Or your desired value for LoRA Fisher damping
        reprojection_rho_eigen_sum=0.9, 
        spectral_constraint_sv_max=1.5,
        # --- New parameters for dynamic rank adjustment ---
        sv_trackers=sv_trackers,
        rank_schedulers=rank_schedulers,
        rank_adjust_sv_history_min_samples=RANK_ADJUST_SV_HISTORY_MIN_SAMPLES_CONST
    )
    print(f"--- Epoch {epoch_num + 1} Completed. Avg Loss: {avg_epoch_loss_val:.4f} ---")

print("Illustrative GRIT training based on pipeline finished.")
# Further steps: Save model, evaluate, etc.

--- Initializing SVTrackers and RankSchedulers for FlexLoRA layers ---
  Initialized SVTracker and RankScheduler for FlexLoRA layer: lora_target_linear (initial_rank=4, min_rank=4, max_rank=4)
Successfully initialized SVTrackers and RankSchedulers for 1 FlexLoRA layer(s).
--- Starting GRIT Training for 2 epochs ---
--- Epoch 1 ---
Epoch 1, Batch 0/5, Loss: 4.6314
Epoch 1 Average Loss: 4.6271
--- Epoch 1 Completed. Avg Loss: 4.6271 ---
--- Epoch 2 ---
Epoch 2, Batch 0/5, Loss: 4.5951
Epoch 2 Average Loss: 4.5996
--- Epoch 2 Completed. Avg Loss: 4.5996 ---
Illustrative GRIT training based on pipeline finished.


# GRIT Stage 3: Specialized Components

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Assume FlexLoRA is available from Grit.Infrastructure
# from Grit.Infrastructure import FlexLoRA 

### 1. Dynamic Rank Adjustment System

In [30]:
# class SVTracker(nn.Module):
#     """
#     Tracks the history of singular values for a FlexLoRA layer's effective matrix (lora_A @ lora_B).
#     Assumes the associated FlexLoRA layer has `lora_A`, `lora_B`, and `max_rank` attributes.
#     """
#     def __init__(self, flex_lora_layer, window_size: int = 100):
#         super().__init__()
#         if not hasattr(flex_lora_layer, 'lora_A') or not hasattr(flex_lora_layer, 'lora_B'):
#             raise ValueError("Provided layer must have 'lora_A' and 'lora_B' parameters.")
#         if not hasattr(flex_lora_layer, 'max_rank'):
#             raise ValueError("Provided FlexLoRA layer must have a 'max_rank' attribute.")
            
#         self.flex_lora_layer = flex_lora_layer
#         self.window_size = window_size
        
#         self.register_buffer('sv_history', 
#                              torch.zeros(window_size, flex_lora_layer.max_rank, 
#                                          device=flex_lora_layer.lora_A.device))

#     def forward(self):
#         """Calculates SVD of lora_A @ lora_B and updates history."""
#         with torch.no_grad():
#             lora_A_data = self.flex_lora_layer.lora_A.data
#             lora_B_data = self.flex_lora_layer.lora_B.data
#             effective_matrix = lora_A_data @ lora_B_data
            
#             try:
#                 # SVD: U, S, Vh = svd(A). S are singular values.
#                 S = torch.linalg.svdvals(effective_matrix)
#             except Exception as e:
#                 # print(f"SVD failed in SVTracker: {e}. Skipping history update.")
#                 return

#             # Pad or truncate S to match self.flex_lora_layer.max_rank for consistent history storage
#             num_singular_values = S.shape[0]
#             max_r = self.flex_lora_layer.max_rank
            
#             s_padded = torch.zeros(max_r, device=S.device)
#             if num_singular_values > 0:
#                 s_padded[:min(num_singular_values, max_r)] = S[:min(num_singular_values, max_r)]
            
#             self.sv_history = torch.roll(self.sv_history, shifts=1, dims=0)
#             self.sv_history[0, :] = s_padded.detach()

In [31]:
# class RankScheduler:
#     def __init__(self, initial_rank: int, min_rank: int, max_rank: int, # Added min/max rank from FlexLoRA
#                  warmup_steps: int = 1000, ema_alpha: float = 0.1): # Added ema_alpha for sv_ema
#         self.current_rank = initial_rank
#         self.min_rank = min_rank # Store min_rank
#         self.max_rank = max_rank # Store max_rank
#         self.warmup_steps = warmup_steps
#         self.step_counter = 0
#         self.ema_alpha = ema_alpha # For smoothing sv_ratio
#         self.sv_ema = None # EMA of singular value ratio

#     def update(self, sv_ratio: float): # sv_ratio is sigma_1/sigma_r from rank_adjustment's perspective
#         self.step_counter += 1
        
#         if self.sv_ema is None:
#             self.sv_ema = sv_ratio
#         else:
#             self.sv_ema = self.ema_alpha * sv_ratio + (1 - self.ema_alpha) * self.sv_ema

#         if self.step_counter < self.warmup_steps:
#             return self.current_rank # Keep current rank during warmup

#         # Logic:
#         # if self.step_counter > 0.8: # Typo: self.step_counter > 0.8 ? Should be sv_ema
#         # Assuming the condition is based on self.sv_ema
#         if self.sv_ema is not None: # Check if sv_ema is initialized
#             if self.sv_ema < 0.5: # If ratio is small, indicates potential for rank reduction
#                 # Decrease rank, but not below min_rank
#                 self.current_rank = max(self.min_rank, self.current_rank - 2) 
#             elif self.sv_ema > 0.8: # If ratio is large, potential for rank increase
#                 # Increase rank, but not above max_rank
#                 self.current_rank = min(self.max_rank, self.current_rank + 2)
        
#         return self.current_rank
    
#     def get_rank(self):
#         return self.current_rank

In [32]:
# def rank_adjustment(
#     layer: FlexLoRA, 
#     sv_tracker: SVTracker, 
#     rank_scheduler: RankScheduler, # Pass the scheduler instance for this layer
#     # rank_eta, rank_tau are removed as scheduler now handles logic
#     sv_history_min_samples: int = 5 
#     ):
#     """
#     Adjusts the rank of a FlexLoRA layer using an external RankScheduler
#     based on its singular value history (sigma_1/sigma_r ratio from mean_sv).
#     """
#     if not all(hasattr(layer, attr) for attr in 
#                ['current_rank', 'min_rank', 'max_rank', 'lora_A', 'lora_B']):
#         raise ValueError("FlexLoRA layer is missing required attributes for rank adjustment.")

#     if sv_tracker.sv_history[:, 0].count_nonzero() < sv_history_min_samples :
#         # print(f"Warning: Not enough singular value history for layer. Skipping rank adjustment. Need {sv_history_min_samples} samples.")
#         return

#     with torch.no_grad():
#         mean_sv = sv_tracker.sv_history.mean(dim=0) 
        
#         current_rank_val = layer.current_rank # Use a different name to avoid conflict with scheduler's current_rank
#         if current_rank_val == 0 or current_rank_val > layer.max_rank: # current_rank_val == 0 means LoRA is inactive or uninitialized for adjustment
#             return
        
#         if mean_sv.numel() == 0 or current_rank_val > mean_sv.shape[0]:
#             # print("Warning: Mean singular values not available or current rank exceeds available SVs. Skipping rank adjustment.")
#             return

#         sigma_1 = mean_sv[0].item()
#         sigma_r = mean_sv[current_rank_val - 1].item() if current_rank_val > 0 else 0.0

#         if sigma_1 < 1e-9 or sigma_r < 1e-9: # Avoid division by zero or unstable ratios
#             # print(f"Warning: sigma_1 ({sigma_1}) or sigma_r ({sigma_r}) is near zero. Skipping rank update based on sv_ratio.")
#             return 
        
#         sv_ratio_metric = sigma_1 / sigma_r
        
#         new_rank = rank_scheduler.update(sv_ratio_metric) 
#         new_rank = int(max(layer.min_rank, min(new_rank, layer.max_rank)))


#         if new_rank == current_rank_val:
#             return
        
#         old_lora_A = layer.lora_A.data
#         old_lora_B = layer.lora_B.data
#         device = old_lora_A.device

#         # Preserve weights using SVD of W_lora = old_lora_A @ old_lora_B
#         # This is generally more robust than padding/truncating A and B directly.
#         W_lora = old_lora_A @ old_lora_B
#         try:
#             U, S, Vh = torch.linalg.svd(W_lora)
#         except Exception as e:
#             print(f"SVD failed during rank adjustment for layer {type(layer).__name__}: {e}. Rank not changed.")
#             return

#         # Determine the effective rank for SVD reconstruction.
#         # This rank cannot exceed the number of singular values, S.shape[0] (min of W_lora dimensions).
#         # layer.max_rank should ideally be configured to be <= S.shape[0].
#         effective_rank_for_svd = new_rank
#         if new_rank > S.shape[0]:
#             print(f"Warning: Target new_rank {new_rank} exceeds available singular values {S.shape[0]} for layer {type(layer).__name__}. "
#                   f"Capping rank to {S.shape[0]}. Consider adjusting layer.max_rank.")
#             effective_rank_for_svd = S.shape[0]
#             new_rank = effective_rank_for_svd # Update new_rank to the capped value for consistency

#         if new_rank == current_rank_val and effective_rank_for_svd == new_rank : # Check again if capping resulted in no change
#              return


#         # Reconstruct LoRA matrices A' and B' such that A'B' approximates W_lora.
#         # A_new = U[:, :k] @ diag(sqrt(S[:k]))
#         # B_new = diag(sqrt(S[:k])) @ Vh[:k, :]
#         # This distributes singular values (sqrt) to both A and B.
        
#         # Handle case where effective_rank_for_svd is 0 (e.g. if min_rank is 0 and it's chosen)
#         if effective_rank_for_svd == 0:
#             new_lora_A_data = torch.empty((old_lora_A.shape[0], 0), device=device, dtype=old_lora_A.dtype)
#             new_lora_B_data = torch.empty((0, old_lora_B.shape[1]), device=device, dtype=old_lora_B.dtype)
#         else:
#             U_k = U[:, :effective_rank_for_svd]
#             S_k = S[:effective_rank_for_svd]
#             Vh_k = Vh[:effective_rank_for_svd, :]

#             # Singular values (S_k) are non-negative. torch.sqrt(0) is 0.
#             sqrt_S_k = torch.sqrt(S_k)
#             diag_sqrt_S_k = torch.diag(sqrt_S_k)

#             new_lora_A_data = U_k @ diag_sqrt_S_k
#             new_lora_B_data = diag_sqrt_S_k @ Vh_k
        
#         layer.lora_A = nn.Parameter(new_lora_A_data.to(device))
#         layer.lora_B = nn.Parameter(new_lora_B_data.to(device))
#         layer.current_rank = new_rank # new_rank might have been capped by S.shape[0]
#         rank_scheduler.current_rank = new_rank # Also update scheduler's internal current_rank

### 2. Stability Assurance Mechanisms


In [33]:
def fisher_clip(
    gradient: torch.Tensor, 
    inv_fisher_factor_A: torch.Tensor, # Inverse of activation covariance (A_inv)
    inv_fisher_factor_B: torch.Tensor, # Inverse of output gradient covariance (B_inv)
    clip_norm: float = 1.0
    ) -> torch.Tensor:
    """
    Clips a gradient in the Fisher-approximated natural gradient space.
    Conceptually, natural_grad = F_inv @ grad.flatten(), where F_inv = kron(inv_B, inv_A).
    This function computes it efficiently as vec(inv_B @ grad @ inv_A.T).
    """
    if gradient is None:
        return None
        
    with torch.no_grad():
        natural_grad_matrix = inv_fisher_factor_B @ gradient @ inv_fisher_factor_A
        natural_grad_flat = natural_grad_matrix.flatten()
        
        grad_norm_fisher_space = torch.norm(natural_grad_flat)
        
        if grad_norm_fisher_space > clip_norm:
            scale = clip_norm / (grad_norm_fisher_space + 1e-7) # Add epsilon for stability
            return gradient * scale
        return gradient

In [34]:
# class DampedKFAC(nn.Module):
#     # Assuming 'd,d' were placeholders for feature dimensions
#     def __init__(self, in_features: int, out_features: int, ema_decay: float = 0.95, damping: tuple[float, float] = (1e-7, 1e-5)):
#         super().__init__()
#         self.ema_decay = ema_decay
#         self.damping = damping # Tuple (damping_A, damping_B)
        
#         # Using register_buffer for A_accum and B_accum
#         self.register_buffer('A_accum', torch.zeros(in_features, in_features))
#         self.register_buffer('B_accum', torch.zeros(out_features, out_features))
#         self.A_initialized = False
#         self.B_initialized = False

#     def update(self, current_A_cov: torch.Tensor, current_B_cov: torch.Tensor = None):
#         """
#         Update A_accum or B_accum with new covariance.
#         If current_B_cov is None, only A_accum is updated.
#         If current_A_cov is None, only B_accum is updated.
#         """
#         with torch.no_grad():
#             if current_A_cov is not None:
#                 if not self.A_initialized:
#                     self.A_accum.copy_(current_A_cov)
#                     self.A_initialized = True
#                 else:
#                     self.A_accum.mul_(self.ema_decay).add_(current_A_cov, alpha=(1 - self.ema_decay))
            
#             if current_B_cov is not None:
#                 if not self.B_initialized:
#                     self.B_accum.copy_(current_B_cov)
#                     self.B_initialized = True
#                 else:
#                     self.B_accum.mul_(self.ema_decay).add_(current_B_cov, alpha=(1 - self.ema_decay))

#     def get_factors(self) -> tuple[torch.Tensor, torch.Tensor]:
#         """Returns the current A_accum and B_accum factors."""
#         return self.A_accum, self.B_accum

#     def get_inverse_factors(self) -> tuple[torch.Tensor, torch.Tensor]:
#         """Computes and returns damped inverses (A_inv, B_inv)."""
#         with torch.no_grad():
#             device = self.A_accum.device
            
#             A_damped = self.A_accum + self.damping[0] * torch.eye(self.A_accum.size(0), device=device)
#             B_damped = self.B_accum + self.damping[1] * torch.eye(self.B_accum.size(0), device=device)
            
#             try:
#                 A_inv = torch.linalg.pinv(A_damped)
#             except Exception as e:
#                 # print(f"Error inverting A_damped: {e}. Returning identity.")
#                 A_inv = torch.eye(self.A_accum.size(0), device=device)
            
#             try:
#                 B_inv = torch.linalg.pinv(B_damped)
#             except Exception as e:
#                 # print(f"Error inverting B_damped: {e}. Returning identity.")
#                 B_inv = torch.eye(self.B_accum.size(0), device=device)
                
#             return A_inv, B_inv

#     def __repr__(self):
#         return (f"DampedKFAC(in={self.A_accum.shape[0]}, out={self.B_accum.shape[0]}, "
#                 f"ema_decay={self.ema_decay}, damping={self.damping}, "
#                 f"A_init={self.A_initialized}, B_init={self.B_initialized})")

In [35]:
# NaturalGradientClipper
class NaturalGradientClipper:
    def __init__(self, max_norm: float = 1.0):
        self.max_norm = max_norm

    def __call__(self, natural_grad: torch.Tensor) -> torch.Tensor:
        norm = torch.norm(natural_grad)
        if norm > self.max_norm:
            return natural_grad * (self.max_norm / (norm + 1e-7)) # Added epsilon for stability
        return natural_grad

### 3. Memory Optimization Framework


In [36]:
def block_svd(matrix: torch.Tensor, block_size: int = 256) -> tuple[list, list, list]:
    """
    Performs SVD on blocks of a matrix. Assumes square blocks.
    Returns lists of U_blocks, S_value_blocks, Vh_blocks (transposed V).
    """
    U_blocks, S_blocks, Vh_blocks = [], [], []
    block_size_rows, block_size_cols = block_size, block_size # Assuming square blocks from signature
    
    for i in range(0, matrix.size(0), block_size_rows):
        for j in range(0, matrix.size(1), block_size_cols):
            block = matrix[i:i+block_size_rows, j:j+block_size_cols].clone()
            if block.numel() == 0: continue
            try:
                U, S, Vh = torch.linalg.svd(block) # linalg.svd returns Vh (V.T)
                U_blocks.append(U)
                S_blocks.append(S)
                Vh_blocks.append(Vh) # Storing Vh as is consistent with linalg.svd
            except Exception as e:
                # print(f"SVD failed for block ({i},{j}) of size {block.shape}: {e}")
                U_blocks.append(torch.empty(0,0, device=matrix.device))
                S_blocks.append(torch.empty(0, device=matrix.device))
                Vh_blocks.append(torch.empty(0,0, device=matrix.device))
    return U_blocks, S_blocks, Vh_blocks

class HalfPrecisionFisher(nn.Module):
    """
    Manages KFAC A and B factors stored in half-precision (float16).
    EMA updates are performed with care for precision.
    """
    def __init__(self, in_features: int, out_features: int, ema_decay: float = 0.95):
        super().__init__()
        self.ema_decay = ema_decay # 0.95 decay
        
        self.register_buffer('A_half', torch.zeros(in_features, in_features, dtype=torch.float16))
        self.register_buffer('B_half', torch.zeros(out_features, out_features, dtype=torch.float16))

    def update(self, current_A_cov: torch.Tensor, current_B_cov: torch.Tensor):
        """Update with raw (float32) covariances from current batch."""
        with torch.no_grad():
            new_A_f32 = self.A_half.float() * self.ema_decay + \
                        current_A_cov.half().float() * (1 - self.ema_decay)
            self.A_half = new_A_f32.half()

            new_B_f32 = self.B_half.float() * self.ema_decay + \
                        current_B_cov.half().float() * (1 - self.ema_decay)
            self.B_half = new_B_f32.half()
    
    def get_factors_f32(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns factors cast to float32 for use."""
        return self.A_half.float(), self.B_half.float()

In [37]:
print("--- GRIT Stage 3: Specialized Components --- (Illustrative Examples with Updated Names)")

# Placeholder FlexLoRA for SVTracker and rank_adjustment demonstration
class MockFlexLoRA(nn.Module):
    def __init__(self, in_f, out_f, initial_rank=8, min_r=4, max_r=16, device='cpu'):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        # For FlexLoRA as updated, it would be:
        self.initial_rank = initial_rank
        self.current_rank = initial_rank
        self.min_rank = min_r
        self.max_rank = max_r
        self.lora_A = nn.Parameter(torch.randn(in_f, self.current_rank, device=device) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(self.current_rank, out_f, device=device))
        self.device = device
    def __repr__(self):
        return f"MockFlexLoRA(current_rank={self.current_rank}, min_rank={self.min_rank}, max_rank={self.max_rank})"

# 1. Dynamic Rank Adjustment Example
print("\n1. Dynamic Rank Adjustment Example")
# Ensure MockFlexLoRA matches the attributes FlexLoRA now has (initial_rank, current_rank, min_rank, max_rank)
mock_layer_for_rank_adj = MockFlexLoRA(in_f=64, out_f=32, initial_rank=8, min_r=4, max_r=12, device='cpu')
sv_tracker_instance = SVTracker(mock_layer_for_rank_adj, window_size=20)

# Simulate some SV history updates by calling SVTracker's forward method
for i in range(20):
    # Ensure lora_A's rank matches current_rank if it were to change dynamically in a real scenario
    if mock_layer_for_rank_adj.lora_A.shape[1] == mock_layer_for_rank_adj.current_rank:
         mock_layer_for_rank_adj.lora_A.data += torch.randn_like(mock_layer_for_rank_adj.lora_A.data) * 0.001
    # (Skipping lora_B changes for simplicity in this mock update)
    sv_tracker_instance.forward() 
    if i % 5 == 0:
        print(f"  SV History updated (step {i}). Current rank: {mock_layer_for_rank_adj.current_rank}")

print(f"  Before rank adjustment: {mock_layer_for_rank_adj}")
if sv_tracker_instance.sv_history.numel() > 0 and mock_layer_for_rank_adj.max_rank > 5:
    # Forcing some singular values to be small in history for testing decay
    # Ensure we don't index out of bounds if max_rank is small
    if sv_tracker_instance.sv_history.shape[1] > 5 :
        sv_tracker_instance.sv_history[:, 5:] *= 0.01 

# (mock_layer_for_rank_adj and sv_tracker_instance are defined above in this cell)

# Create a RankScheduler for the mock layer
mock_rank_scheduler = RankScheduler(
    initial_rank=mock_layer_for_rank_adj.current_rank,
    min_rank=mock_layer_for_rank_adj.min_rank,
    max_rank=mock_layer_for_rank_adj.max_rank,
    warmup_steps=10, # Example value, adjust as needed
    ema_alpha=0.1    # Example value, adjust as needed
)
print(f"  Created MockRankScheduler: initial_rank={mock_rank_scheduler.current_rank}, min_rank={mock_rank_scheduler.min_rank}, max_rank={mock_rank_scheduler.max_rank}")

# Call rank_adjustment with new parameters
rank_adjustment(
    mock_layer_for_rank_adj, 
    sv_tracker_instance, 
    mock_rank_scheduler, # Pass the RankScheduler instance
    sv_history_min_samples=5 # Example value
)
print(f"  After rank adjustment: {mock_layer_for_rank_adj}, new lora_A shape: {mock_layer_for_rank_adj.lora_A.shape}")

--- GRIT Stage 3: Specialized Components --- (Illustrative Examples with Updated Names)

1. Dynamic Rank Adjustment Example
  SV History updated (step 0). Current rank: 8
  SV History updated (step 5). Current rank: 8
  SV History updated (step 10). Current rank: 8
  SV History updated (step 15). Current rank: 8
  Before rank adjustment: MockFlexLoRA(current_rank=8, min_rank=4, max_rank=12)
  Created MockRankScheduler: initial_rank=8, min_rank=4, max_rank=12
  After rank adjustment: MockFlexLoRA(current_rank=8, min_rank=4, max_rank=12), new lora_A shape: torch.Size([64, 8])


In [38]:
# 2. Stability Assurance Example
print("\n2. Stability Assurance Example")
# DampedKFAC (formerly DampedKroneckerFactors)
damped_k_factors = DampedKFAC(in_features=64, out_features=32, damping=(1e-4, 1e-4)) # Pass damping as tuple
dummy_A_cov = torch.rand(64, 64) * 0.1 + torch.eye(64) 
dummy_B_cov = torch.rand(32, 32) * 0.1 + torch.eye(32)
dummy_A_cov = (dummy_A_cov + dummy_A_cov.T)/2
dummy_B_cov = (dummy_B_cov + dummy_B_cov.T)/2
damped_k_factors.update(dummy_A_cov, dummy_B_cov) # Call renamed update method
inv_A, inv_B = damped_k_factors.get_inverse_factors() # Call renamed get_inverse_factors method
print(f"  DampedKFAC: A_inv shape {inv_A.shape}, B_inv shape {inv_B.shape}")

# fisher_clip (formerly fisher_clip_gradient)
dummy_grad = torch.randn(32, 64) 
# fisher_clip expects inv_fisher_factor_A, inv_fisher_factor_B
clipped_grad = fisher_clip(dummy_grad, inv_A, inv_B, clip_norm=0.5)
print(f"  Fisher Clip: Original grad norm {torch.norm(dummy_grad)}, Clipped grad norm {torch.norm(clipped_grad)}")
original_ng_norm = torch.norm(inv_B @ dummy_grad @ inv_A) # Natural grad norm
clipped_ng_norm = torch.norm(inv_B @ clipped_grad @ inv_A)
print(f"  Fisher Clip: Original natural grad norm {original_ng_norm:.4f}, Clipped natural grad norm {clipped_ng_norm:.4f}")


2. Stability Assurance Example
  DampedKFAC: A_inv shape torch.Size([64, 64]), B_inv shape torch.Size([32, 32])
  Fisher Clip: Original grad norm 45.446922302246094, Clipped grad norm 0.48223498463630676
  Fisher Clip: Original natural grad norm 47.1211, Clipped natural grad norm 0.5000


In [39]:
# 3. Memory Optimization Example
print("\n3. Memory Optimization Example")
# HalfPrecisionFisher (formerly HalfPrecisionKroneckerFactors)
hp_fisher_factors = HalfPrecisionFisher(in_features=64, out_features=32, ema_decay=0.95)
hp_fisher_factors.update(dummy_A_cov, dummy_B_cov) # Call renamed update method
hp_A_f32, hp_B_f32 = hp_fisher_factors.get_factors_f32()
print(f"  HalfPrecisionFisher: A_half dtype {hp_fisher_factors.A_half.dtype}, B_half dtype {hp_fisher_factors.B_half.dtype}")
print(f"  Retrieved factors (f32): A shape {hp_A_f32.shape}, B shape {hp_B_f32.shape}")

# Block SVD (updated signature)
large_matrix = torch.randn(512, 256)
# block_svd now takes a single block_size argument, assuming square blocks or applying it to both dimensions
# For non-square, one might pass different sizes, but current takes one.
# Let's test with a size that divides one dim and not other to see behavior with current block_svd.
# Or better, use a size that fits well like 128 for both if matrix is 512x256.
# The implementation uses block_size for both block_size_rows and block_size_cols.
U_blocks, S_val_blocks, Vh_blocks = block_svd(large_matrix, block_size=128) 
print(f"  Block SVD: Performed on {large_matrix.shape} matrix with block_size=128.")
print(f"    Got {len(U_blocks)} U_blocks, {len(S_val_blocks)} S_value_blocks, {len(Vh_blocks)} Vh_blocks.")
if U_blocks and U_blocks[0].numel() > 0 : # Check if first block is not empty
    print(f"    Example: First U_block shape {U_blocks[0].shape}, First S_values_block shape {S_val_blocks[0].shape if S_val_blocks[0].numel() > 0 else 'empty'}")

print("\n--- End of Illustrative Examples (Updated Names) ---")


3. Memory Optimization Example
  HalfPrecisionFisher: A_half dtype torch.float16, B_half dtype torch.float16
  Retrieved factors (f32): A shape torch.Size([64, 64]), B shape torch.Size([32, 32])
  Block SVD: Performed on torch.Size([512, 256]) matrix with block_size=128.
    Got 8 U_blocks, 8 S_value_blocks, 8 Vh_blocks.
    Example: First U_block shape torch.Size([128, 128]), First S_values_block shape torch.Size([128])

--- End of Illustrative Examples (Updated Names) ---


# GRIT Stage 4: Evaluation

In [40]:
import torch
import torch.nn as nn
import re
import numpy as np
import time # For ThroughputBenchmark

# Placeholder for an evaluation function that returns a single performance score (higher is better)
# This would need to be properly defined or imported based on the specific task and model.
def evaluate_model(model: nn.Module, data_loader: torch.utils.data.DataLoader, device: torch.device) -> float:
    """
    Placeholder for model evaluation.
    Returns a single scalar performance metric (e.g., accuracy, F1-score).
    Higher values should indicate better performance.
    """
    print(f"Warning: Using placeholder 'evaluate_model' in ForgettingMonitor. Please implement for actual metrics.")
    # Example: calculate loss, and return negative loss if higher is better
    # model.eval()
    # total_loss = 0
    # num_batches = 0
    # with torch.no_grad():
    #     for batch in data_loader:
    #         # Assuming batch format and loss calculation similar to a standard loop
    #         # This part is highly dependent on your data and model specifics
    #         # inputs, targets = batch[0].to(device), batch[1].to(device) 
    #         # outputs = model(inputs)
    #         # loss = torch.nn.functional.cross_entropy(outputs, targets) # Example loss
    #         # total_loss += loss.item()
    #         num_batches += 1
    # model.train()
    # if num_batches == 0:
    #     return 0.0
    # return - (total_loss / num_batches) # Negative loss, so higher is better
    return 0.5 # Dummy performance score

### 1. Convergence Metrics Implementation

In [41]:
class CurvatureAnalyzer:
    def __init__(self, model, damping=1e-3): # Model is the full model
        self.model = model
        self.damping = damping
        # self.hessian_buffers = {} 

    def compute_layer_curvature(self, layer_with_factors): # layer should have A_factor, B_factor
        # Assumes layer has A_factor and B_factor (e.g., from KFACLayer.kfac_manager.A_accum)
        A_factor, B_factor = None, None
        if hasattr(layer_with_factors, 'kfac_handler') and hasattr(layer_with_factors.kfac_handler, 'A_factor') and hasattr(layer_with_factors.kfac_handler, 'B_factor'):
            A_factor = layer_with_factors.kfac_handler.A_factor
            B_factor = layer_with_factors.kfac_handler.B_factor
        elif hasattr(layer_with_factors, 'A_factor') and hasattr(layer_with_factors, 'B_factor'): # Direct attributes
            A_factor = layer_with_factors.A_factor
            B_factor = layer_with_factors.B_factor
        
        if A_factor is None or B_factor is None or A_factor.numel() == 0 or B_factor.numel() == 0:
            # print(f"compute_layer_curvature: Factors not found or empty for layer {type(layer_with_factors).__name__}")
            return None

        A = A_factor + self.damping * torch.eye(A_factor.size(0), device=A_factor.device, dtype=A_factor.dtype)
        B = B_factor + self.damping * torch.eye(B_factor.size(0), device=B_factor.device, dtype=B_factor.dtype)
        
        if A.shape[0] == 0 or B.shape[0] == 0 : return None # Cannot compute kron with empty dimension
        return torch.kron(A, B) # In KFAC, this usually kron(B,A) for Fisher if A=act, B=grad_out
    
    def track_landscape(self, dataloader=None): # dataloader is not used in this version
        eigenvalues = {}
        for name, module_in_model in self.model.named_modules():
            # Check if the module itself is what we are interested in (e.g. has kfac_handler)
            # This logic depends on how KFAC factors are stored/accessed.
            # Assuming KFAC factors are on the nn.Linear layer via 'kfac_handler' from instrument_layer
            # or directly as A_factor, B_factor if KFACLayer itself is the module.
            
            # Try to get factors:
            # This part is tricky because KFAC factors might be on module.kfac_handler (from KFACLayer)
            # or if module_in_model *is* a KFACLayer instance (less likely from instrument_layer design).
            # Let's assume we check for 'kfac_handler' which holds KFACLayer.
            # The kfac_handler would then have A_factor and B_factor properties.
            
            layer_to_check_factors_on = module_in_model 
            # If FlexLoRA, we might want curvature of base_layer if it has KFAC
            if isinstance(module_in_model, FlexLoRA) and hasattr(module_in_model.base_layer, 'kfac_handler'):
                 layer_to_check_factors_on = module_in_model.base_layer


            # Check if this layer_to_check_factors_on (e.g., an nn.Linear) has a KFAC handler
            # and that handler has valid factors.
            has_valid_factors = False
            if hasattr(layer_to_check_factors_on, 'kfac_handler'):
                handler = layer_to_check_factors_on.kfac_handler
                if hasattr(handler, 'A_factor') and hasattr(handler, 'B_factor') and \
                   handler.A_factor is not None and handler.B_factor is not None and \
                   handler.A_factor.numel() > 0 and handler.B_factor.numel() > 0:
                    has_valid_factors = True
            
            if has_valid_factors:
                try:
                    # Pass the module that has the kfac_handler (e.g. the nn.Linear)
                    F = self.compute_layer_curvature(layer_to_check_factors_on) 
                    if F is not None and F.numel() > 0:
                        # eigvalsh requires symmetric matrix. KFAC A,B should be, so F should be.
                        eigvals = torch.linalg.eigvalsh(F) 
                        eigenvalues[name] = {
                            'max': eigvals[-1].item(), # eigvalsh sorts ascending
                            'min': eigvals[0].item(),
                            'trace': eigvals.sum().item()
                        }
                    else:
                        eigenvalues[name] = {'max': float('nan'), 'min': float('nan'), 'trace': float('nan')}
                except Exception as e:
                    # print(f"Could not compute eigenvalues for layer {name}: {e}")
                    eigenvalues[name] = {'max': float('nan'), 'min': float('nan'), 'trace': float('nan')}
        return eigenvalues

def fisher_rao_distance(P: torch.Tensor, Q: torch.Tensor, epsilon=1e-7) -> float: # Adjusted epsilon 
    # P, Q are data matrices (samples x features)
    if P.ndim != 2 or Q.ndim != 2 or P.shape[1] != Q.shape[1]:
        # print("Fisher-Rao: P and Q must be 2D tensors with the same number of features.")
        return float('nan')
    if P.shape[0] <= 1 or Q.shape[0] <= 1 or P.shape[1] == 0 : # Need >1 sample for cov, and non-zero features
        # print("Fisher-Rao: Need >1 sample per distribution and non-zero features.")
        return float('nan')

    device = P.device
    
    try:
        # Covariance matrices
        cov_p = torch.cov(P.T) + epsilon * torch.eye(P.size(1), device=device, dtype=P.dtype)
        cov_q = torch.cov(Q.T) + epsilon * torch.eye(Q.size(1), device=device, dtype=Q.dtype)
        
        # Check for PSD and invertibility if using matrix_power(0.5) for sqrt
        # Using cholesky decomposition for matrix sqrt: L L^T = M
        # Or, eigenvalue decomposition: M = V diag(lambda) V^T => M^0.5 = V diag(sqrt(lambda)) V^T
        # torch.linalg.matrix_power requires PSD for fractional powers.
        # Let's use direct formula with matrix_power and trust input epsilon handles it.

        # Ensure symmetry for matrix_power if needed (cov should be symmetric)
        # cov_p = (cov_p + cov_p.T) / 2
        # cov_q = (cov_q + cov_q.T) / 2

        sqrt_p = torch.linalg.matrix_power(cov_p, 0.5)
        
        # cross_term = sqrt_p @ cov_q @ sqrt_p
        # Corrected cross_term stabilization from original for matrix_power
        cross_term_inner = sqrt_p @ cov_q @ sqrt_p
        # Stabilize for matrix_power by ensuring it's symmetric and adding epsilon if needed
        cross_term_inner_stabilized = (cross_term_inner + cross_term_inner.T) / 2 
        cross_term_inner_stabilized = cross_term_inner_stabilized + epsilon * torch.eye(cross_term_inner_stabilized.size(0), device=device, dtype=P.dtype)
        
        sqrt_cross_term = torch.linalg.matrix_power(cross_term_inner_stabilized, 0.5)
        
        distance = torch.trace(cov_p + cov_q - 2 * sqrt_cross_term).item()
        
        # Distance should be non-negative. If it's slightly negative due to numerical precision, clamp to 0.
        return max(0.0, distance)
    except Exception as e:
        # print(f"Error in Fisher-Rao distance computation: {e}. Returning NaN.")
        return float('nan')


def subspace_alignment(source_eigenvectors: torch.Tensor, target_eigenvectors: torch.Tensor) -> dict:
    # source_eigenvectors and target_eigenvectors are matrices where columns are eigenvectors
    # Shape: (feature_dim, num_eigenvectors)
    if source_eigenvectors.numel() == 0 or target_eigenvectors.numel() == 0 or \
       source_eigenvectors.ndim != 2 or target_eigenvectors.ndim != 2 or \
       source_eigenvectors.shape[0] != target_eigenvectors.shape[0]: # Feature dim must match
        # print("Subspace Alignment: Invalid input tensors. Check dims and feature size.")
        return {'mean_angle': float('nan'), 'max_angle': float('nan'), 'min_angle': float('nan')}

    try:
        # Orthonormalize bases if they are not already (e.g. if they are just SVD U matrices)
        # Assuming they are already orthonormal bases (e.g. U from SVD)
        # U_s = torch.linalg.qr(source_eigenvectors).Q
        # U_t = torch.linalg.qr(target_eigenvectors).Q
        U_s = source_eigenvectors
        U_t = target_eigenvectors
        
        # k = min(U_s.shape[1], U_t.shape[1]) # Number of vectors to compare
        # torch.diag(U_s.T @ U_t) implies comparing corresponding vectors
        # if number of vectors (subspace dim) is the same.
        # For principal angles between subspaces of potentially different dimensions k_s, k_t:
        # We compute SVD of U_s^T @ U_t. The singular values are cosines of principal angles.
        
        # If U_s and U_t have different number of columns (vectors), U_s.T @ U_t is not square.
        # Taking diag of non-square matrix extracts main diagonal.
        # To get principal angles correctly:
        correlation_matrix = U_s.T @ U_t
        singular_values_of_corr = torch.linalg.svdvals(correlation_matrix) # These are cos(theta_i)
        
        # Ensure values are within [-1, 1] for acos
        cos_angles = torch.clamp(singular_values_of_corr, -1.0 + 1e-7, 1.0 - 1e-7)
        principal_angles_rad = torch.acos(cos_angles)
        
        if principal_angles_rad.numel() == 0: # No angles computed
             return {'mean_angle': float('nan'), 'max_angle': float('nan'), 'min_angle': float('nan')}

        return {
            'mean_angle': torch.mean(principal_angles_rad).item(),
            'max_angle': torch.max(principal_angles_rad).item(),
            'min_angle': torch.min(principal_angles_rad).item()
        }
    except Exception as e:
        # print(f"Error in subspace alignment: {e}. Returning NaNs.")
        return {'mean_angle': float('nan'), 'max_angle': float('nan'), 'min_angle': float('nan')}

### 2. Quality Assessment Framework

In [42]:
class InstructionValidator:
    def __init__(self, criteria_mapping: dict): # str_criterion -> regex_pattern
        self.criteria = criteria_mapping

    # validate_single will be used by the main validate method
    def _validate_single(self, instruction: str, response: str) -> float:
        if not self.criteria:
            return 1.0

        scores = {}
        for criterion, pattern in self.criteria.items():
            try:
                matches = re.findall(pattern, response, re.IGNORECASE)
                scores[criterion] = len(matches) > 0
            except Exception as e:
                print(f"Regex error for criterion '{criterion}' with pattern '{pattern}': {e}")
                scores[criterion] = False

        if not scores: 
            return 0.0
            
        return sum(scores.values()) / len(scores)

    # grit_evaluation_loop calls validate(test_samples). Assuming test_samples is a list.
    def validate(self, test_samples: list[tuple[str, str]]) -> float:
        """
        Validates a list of (instruction, response) pairs.
        """
        if not test_samples:
            return 0.0
        
        all_scores = []
        for instruction, response in test_samples:
            all_scores.append(self._validate_single(instruction, response))
        
        return np.mean(all_scores) if all_scores else 0.0

def domain_shift_score(source_feats: torch.Tensor, target_feats: torch.Tensor) -> float:
    if source_feats.numel() == 0 or target_feats.numel() == 0:
        return float('nan') # Keep check for empty tensors
    source_mean = source_feats.mean(dim=0)
    target_mean = target_feats.mean(dim=0)
    return torch.norm(source_mean - target_mean, p=2).item()


class ForgettingMonitor:
    # Using existing robust version that stores model instance, not state_dict
    def __init__(self, original_model_instance: nn.Module):
        self.original_model = original_model_instance
        # Ensure original_model is in eval mode for consistent evaluation
        self.original_model.eval()


    def compute_forgetting(self, tuned_model: nn.Module, val_loader: torch.utils.data.DataLoader, device: torch.device) -> float:
        # Ensure original_model is on the correct device
        self.original_model.to(device)
        
        # tuned_model should also be in eval mode and on correct device
        tuned_model.eval().to(device)

        original_perf = evaluate_model(self.original_model, val_loader, device)
        tuned_perf = evaluate_model(tuned_model, val_loader, device)
        
        # Restore tuned_model's mode if it was changed (evaluate_model should ideally handle this)
        # For simplicity, assuming evaluate_model doesn't change it, or it's fine.

        if original_perf is None or tuned_perf is None: # Check if evaluate_model failed
             print("Warning: ForgettingMonitor received None performance. Returning NaN.")
             return float('nan')

        if original_perf == 0: # Avoid division by zero
            if tuned_perf < 0: return float('inf') # Performance dropped from zero to negative
            if tuned_perf == 0: return 0.0 # No change from zero
            return -float('inf') # Performance improved from zero (negative forgetting is improvement)
        
        return (original_perf - tuned_perf) / abs(original_perf) # Use abs for safety if perf can be negative

### 3. Efficiency Benchmarking System

In [43]:
def update_density(model: nn.Module) -> float:
    total_params = 0
    updated_lora_param_sum = 0 # Sum of numel for LoRA params considered updated
    
    for name, param in model.named_parameters():
        if param.requires_grad and 'lora' in name: # Consider only trainable LoRA parameters
            if param.grad is not None:
                updates_norm = torch.norm(param.grad.detach(), p=2)
                # Density interpretation: 1.0 if param's gradient norm is significant, 0.0 otherwise
                is_updated_significantly = updates_norm > 1e-6
                if is_updated_significantly:
                    updated_lora_param_sum += param.numel()
            # else: if no grad, not updated
        
        # total_params includes ALL parameters in the model
        total_params += param.numel()

    if total_params == 0:
        return 0.0
    return updated_lora_param_sum / total_params


# MemoryTracker
class MemoryTracker:
    def __init__(self):
        self.peak_mem = 0 # Stores peak memory in bytes
        # Reset CUDA peak stats at the start of tracking for this instance for more predictable one-shot track()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats() 
            self.peak_mem = torch.cuda.max_memory_allocated() # Initial peak

    def track(self) -> float: # Returns peak memory in GB
        if torch.cuda.is_available():
            torch.cuda.synchronize() # Ensure all ops are done before reading memory
            # max_memory_allocated gives peak since last reset.
            # Logic wants to track max over calls IF instance is reused.
            # For single call MemoryTracker().track(), this is fine.
            self.peak_mem = max(self.peak_mem, torch.cuda.max_memory_allocated())
        return self.peak_mem / (1024 ** 3) # Return in GB


# ThroughputBenchmark: Existing implementation is robust and matches core logic.
class ThroughputBenchmark:
    def __init__(self, batch_size: int):
        self.batch_size = batch_size
        self.timings = [] # Stores time per step

    def log_step(self, start_time: float, end_time: float): # Parameters named start_time, end_time
        self.timings.append(end_time - start_time)

    def compute_throughput(self) -> float: # Samples/second
        if not self.timings:
            return 0.0
        avg_time_per_step = np.mean(self.timings)
        if avg_time_per_step == 0: # Avoid division by zero
            return float('inf') if self.batch_size > 0 else 0.0
        return self.batch_size / avg_time_per_step

In [44]:
### Evaluation Protocol Implementation

def grit_evaluation_loop(
    model: nn.Module,
    original_model: nn.Module, # For forgetting monitor
    train_loader: torch.utils.data.DataLoader, 
    test_loader: torch.utils.data.DataLoader,
    domains: list[tuple[torch.Tensor, torch.Tensor]], # List of (source_features, target_features)
    pretrained_subspace: torch.Tensor, # For subspace alignment
    criteria: dict, # For InstructionValidator
    test_samples: list[tuple[str, str]], # For InstructionValidator: [(instr, model_resp), ...]
    device: torch.device,
    curvature_damping: float = 1e-3, # Added for CurvatureAnalyzer
    fisher_rao_epsilon: float = 1e-6 # Added for fisher_rao_distance
    ):

    results = {}
    model.eval() # Ensure model is in evaluation mode

    # --- Convergence Metrics ---
    print("Calculating Convergence Metrics...")
    
    # 1.1 Loss Curvature Tracking
    # The dataloader arg for track_landscape is not used by the CurvatureAnalyzer logic.
    curvature_analyzer = CurvatureAnalyzer(model, damping=curvature_damping)
    # track_landscape dataloader arg is for potential future use, not current.
    curvature_data_per_layer = curvature_analyzer.track_landscape(dataloader=train_loader) 
    
    max_overall_curvature = float('-inf')
    min_overall_curvature = float('inf')
    avg_trace_curvature = 0.0 # Initialize as float
    num_layers_for_curvature = 0

    if curvature_data_per_layer:
        for layer_name, data in curvature_data_per_layer.items():
            if data and not np.isnan(data['max']): # Check for NaN
                max_overall_curvature = max(max_overall_curvature, data['max'])
            if data and not np.isnan(data['min']):
                min_overall_curvature = min(min_overall_curvature, data['min'])
            if data and not np.isnan(data['trace']):
                avg_trace_curvature += data['trace']
                num_layers_for_curvature +=1
        if num_layers_for_curvature > 0 :
             avg_trace_curvature /= num_layers_for_curvature
        else:
             avg_trace_curvature = float('nan') # No valid layers found
    
    # Handle cases where no curvature data was computed
    final_max_curvature = max_overall_curvature if max_overall_curvature != float('-inf') else float('nan')
    final_min_curvature = min_overall_curvature if min_overall_curvature != float('inf') else float('nan')


    # 1.2 Fisher-Rao Distance
    fr_distance = float('nan') # Default to NaN
    try:
        # Robustly get features from dataset object within loader
        train_features = getattr(train_loader.dataset, 'features', torch.empty(0, device=device))
        test_features = getattr(test_loader.dataset, 'features', torch.empty(0, device=device))
        if train_features.numel() > 0 and test_features.numel() > 0:
            fr_distance = fisher_rao_distance(train_features.to(device), test_features.to(device), epsilon=fisher_rao_epsilon)
        else:
            print("Warning: Could not extract features for Fisher-Rao distance. Set to NaN.")
    except AttributeError:
        print("Warning: '.features' attribute not found on datasets for Fisher-Rao. Set to NaN.")
    except Exception as e:
        print(f"Error calculating Fisher-Rao distance: {e}. Set to NaN.")

    # 1.3 Subspace Alignment
    # Assuming model has an attribute 'subspace' (e.g., extracted LoRA parameters as a matrix)
    model_current_subspace = getattr(model, 'subspace', torch.empty(0, device=device))
    sa_metrics = subspace_alignment(model_current_subspace.to(device), pretrained_subspace.to(device))

    results['convergence'] = {
        'max_curvature': final_max_curvature, # Aggregated max
        'min_curvature': final_min_curvature, # Aggregated min 
        'avg_trace_curvature': avg_trace_curvature, # Aggregated avg trace
        'fisher_rao': fr_distance,
        'subspace_alignment': sa_metrics # sa_metrics is already a dict {'mean_angle', 'max_angle', 'min_angle'}
    }

    # --- Quality Assessment ---
    print("Calculating Quality Assessment Metrics...")
    # 2.1 Instruction Following Accuracy
    instruction_validator = InstructionValidator(criteria)
    instr_acc = instruction_validator.validate(test_samples) # test_samples is list of (instr, resp)

    # 2.2 Domain Adaptation Metrics
    domain_adaptation_scores = []
    if domains: # Check if domains list is not empty
        for src_feats, tgt_feats in domains:
            if src_feats.numel() > 0 and tgt_feats.numel() > 0:
                domain_adaptation_scores.append(domain_shift_score(src_feats.to(device), tgt_feats.to(device)))
            else:
                domain_adaptation_scores.append(float('nan')) # Handle empty features per domain
    else: # No domains provided
        domain_adaptation_scores = [float('nan')]


    # 2.3 Catastrophic Forgetting Tests
    forgetting_monitor = ForgettingMonitor(original_model) # original_model passed as arg
    forgetting_rate = forgetting_monitor.compute_forgetting(model, test_loader, device)
    
    results['quality'] = {
        'instruction_acc': instr_acc,
        'domain_scores': domain_adaptation_scores, 
        'forgetting_rate': forgetting_rate
    }

    # --- Efficiency Benchmarks ---
    print("Calculating Efficiency Benchmarks...")
    # 3.1 Parameter Update Density
    # update_density assumes gradients are present (e.g., after a backward pass if called during training)
    # For a pure evaluation script, gradients might be zero or None.
    # We'll call it, but it might return 0 if no grads.
    param_density = update_density(model) 

    # 3.2 Memory Consumption
    mem_tracker = MemoryTracker() # Instantiated here
    peak_memory_gb = mem_tracker.track() 

    # 3.3 Training Throughput
    # This is tricky in a standalone eval loop unless timings were previously logged
    # or this loop itself performs timed operations.
    # ThroughputBenchmark().compute_throughput() implies an existing instance or pre-logged steps.
    # For a dummy call as in existing code:
    throughput_analyzer = ThroughputBenchmark(batch_size=getattr(test_loader, 'batch_size', 1))
    # Example: If this loop were to run some inferences, log them here.
    # For now, it will return 0 or inf if no log_step calls.
    # if hasattr(test_loader, 'batch_size'):
    #    dummy_start_time = time.time()
    #    # Simulate some work, e.g., one pass over test_loader for timing
    #    # with torch.no_grad():
    #    #     for _ in test_loader: pass 
    #    # throughput_analyzer.log_step(dummy_start_time, time.time())
    # else:
    #    print("Warning: test_loader has no batch_size, throughput may be 0 or incorrect.")
    
    # result dict asks for 'samples_per_sec', which is what compute_throughput returns.
    # 'throughput' variable name from snippet, key 'samples_per_sec' in output.
    throughput_val = throughput_analyzer.compute_throughput()


    results['efficiency'] = {
        'update_density': param_density,
        'memory_gb': peak_memory_gb,
        'samples_per_sec': throughput_val
    }
    
    model.train() # Restore model to train mode if it was changed
    print("GRIT Evaluation Loop Finished.")
    return results

In [45]:
print("Running basic example for grit_evaluation_loop (with dummy data)...")

# Dummy model setup (ensure KFAC factors A_factor, B_factor for CurvatureAnalyzer)
class DummyLayerWithKFAC(nn.Module):
    def __init__(self, in_f, out_f, device='cpu'):
        super().__init__()
        self.linear = nn.Linear(in_f, out_f)
        # Simulate KFAC factors being present and initialized (non-empty)
        self.register_buffer('A_factor', torch.eye(in_f, device=device) * 0.1 + torch.rand(in_f, in_f, device=device)*0.01)
        self.register_buffer('B_factor', torch.eye(out_f, device=device) * 0.1 + torch.rand(out_f, out_f, device=device)*0.01)
    def forward(self, x):
        return self.linear(x)

class DummyModel(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.layer1_kfac = DummyLayerWithKFAC(10, 20, device=device)
        self.layer2_no_kfac = nn.Linear(20, 5) 
        self.layer3_kfac = DummyLayerWithKFAC(5,2, device=device)
        # Dummy model subspace for subspace_alignment (e.g., flattened params of some layers)
        # Shape: (param_dim, num_principal_components_or_rank)
        self.subspace = torch.randn(10 * 20, 4, device=device) 
    
    def forward(self, x):
        x = torch.relu(self.layer1_kfac(x))
        x = self.layer2_no_kfac(x)
        x = self.layer3_kfac(x)
        return x

current_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {current_device}")

model_to_eval = DummyModel(device=current_device).to(current_device)
original_model_for_forgetting = DummyModel(device=current_device).to(current_device) # Fresh instance

Running basic example for grit_evaluation_loop (with dummy data)...
Using device: cuda


In [46]:
# Dummy data and loaders
class DummyDatasetWithFeatures(torch.utils.data.Dataset):
    def __init__(self, num_samples, num_features, num_classes, device='cpu'):
        self.num_samples = num_samples
        # Simulate features for Fisher-Rao and domain shift
        self.features = torch.randn(num_samples, num_features, device=device)
        self.labels = torch.randint(0, num_classes, (num_samples,), device=device)
    def __len__(self):
        return self.num_samples
    def __getitem__(self, idx):
        # Dataloader will collate these. For model input, usually just features[idx]
        return self.features[idx], self.labels[idx] 

train_dataset_inst = DummyDatasetWithFeatures(100, 10, 5, device=current_device)
test_dataset_inst = DummyDatasetWithFeatures(50, 10, 5, device=current_device)
# Ensure dummy model's input layer matches feature dim (10)
train_loader_inst = torch.utils.data.DataLoader(train_dataset_inst, batch_size=10)
test_loader_inst = torch.utils.data.DataLoader(test_dataset_inst, batch_size=10)

# Dummy parameters for the grit_evaluation_loop
dummy_domain_data = [
    (torch.randn(50, 10, device=current_device), torch.randn(50, 10, device=current_device)), 
    (torch.randn(30, 10, device=current_device), torch.randn(30, 10, device=current_device))
]
# Pretrained subspace: e.g., SVD of initial LoRA weights or some reference subspace
# Shape should match model.subspace for alignment. (param_dim, num_vectors)
dummy_pretrained_model_subspace = torch.randn(10 * 20, 4, device=current_device) 

dummy_instruction_criteria = {
    "completeness": r"task is complete|done|finished",
    "politeness": r"please|thank you"
}
dummy_test_instruction_samples = [
    ("Instruction: Be polite and complete task.", "Response: The task is complete, thank you!"),
    ("Instruction: Summarize.", "Response: Summary done.")
]

In [47]:
# Simulate some gradients for update_density (e.g. from a backward pass)
# This is artificial for a standalone eval script. In real use, grads would be from training.
for name, param in model_to_eval.named_parameters():
    if 'lora' in name and param.requires_grad: # Assuming some LoRA params might exist if model was LoRA-fied
        param.grad = torch.rand_like(param) * 0.01
    # For non-LoRA params, grads could also exist. update_density only looks at 'lora'.
    # Let's make a KFAC layer's linear weight have "lora" in its name for testing update_density
    # This is a hack for the dummy example.
if hasattr(model_to_eval, 'layer1_kfac') and hasattr(model_to_eval.layer1_kfac, 'linear'):
    # Give it a grad
    model_to_eval.layer1_kfac.linear.weight.grad = torch.rand_like(model_to_eval.layer1_kfac.linear.weight) * 0.01
    # Temporarily rename for update_density to pick up (VERY HACKY FOR DEMO)
    # In a real scenario, LoRA layers would be named appropriately.
    # This part is tricky for a generic dummy example without actual LoRA layers.
    # We'll rely on 'lora' being in the name if we want update_density > 0.
    # For now, let's assume no LoRA params by default in this DummyModel to avoid complex renaming.
    # So, update_density will likely be 0 unless the model is modified to have 'lora' in names.
    pass


# Call the updated evaluation loop
eval_results_output = grit_evaluation_loop(
    model=model_to_eval,
    original_model=original_model_for_forgetting,
    train_loader=train_loader_inst,
    test_loader=test_loader_inst,
    domains=dummy_domain_data,
    pretrained_subspace=dummy_pretrained_model_subspace,
    criteria=dummy_instruction_criteria,
    test_samples=dummy_test_instruction_samples,
    device=current_device,
    curvature_damping=1e-4, # example override
    fisher_rao_epsilon=1e-5 # example override
)

Calculating Convergence Metrics...
Calculating Quality Assessment Metrics...
Calculating Efficiency Benchmarks...
GRIT Evaluation Loop Finished.


In [48]:
print("\\n--- GRIT Evaluation Results (Dummy Run) ---")
import json
# Custom handler for NaN/Inf float values for JSON serialization
def json_default_handler(x):
    if isinstance(x, float) and (np.isnan(x) or np.isinf(x)):
        return str(x)
    if isinstance(x, torch.Tensor): # Should not happen if .item() is used
         return str(x.tolist()) if x.numel() > 1 else str(x.item())
    return x

print(json.dumps(eval_results_output, indent=2, default=json_default_handler))
print("--- End of Dummy Run ---")

\n--- GRIT Evaluation Results (Dummy Run) ---
{
  "convergence": {
    "max_curvature": NaN,
    "min_curvature": NaN,
    "avg_trace_curvature": 0.0,
    "fisher_rao": NaN,
    "subspace_alignment": {
      "mean_angle": 0.00048828125,
      "max_angle": 0.00048828125,
      "min_angle": 0.00048828125
    }
  },
  "quality": {
    "instruction_acc": 0.75,
    "domain_scores": [
      0.7748472690582275,
      0.5999931693077087
    ],
    "forgetting_rate": 0.0
  },
  "efficiency": {
    "update_density": 0.0,
    "memory_gb": 0.0865321159362793,
    "samples_per_sec": 0.0
  }
}
--- End of Dummy Run ---


# GRIT Adapter

In [None]:
class GRITAdapter(nn.Module):
    def __init__(self, model: nn.Module,
                 layers_to_adapt_substrings: list[str],
                 rank_range: tuple[int, int] = (8, 64),
                 kfac_ema_decay: float = 0.95,
                 kfac_damping: tuple[float, float] = (1e-7, 1e-5),
                 fisher_k: float = 0.1, # Fraction for eigen decomposition
                 subspace_buffer_size: int = 1024,
                 # Rank scheduling params
                 rank_scheduler_warmup_steps: int = 1000,
                 rank_scheduler_ema_alpha: float = 0.1,
                 sv_tracker_window_size: int = 100,
                 sv_history_min_samples: int = 5,
                 # Neural reprojection params
                 rho_eigen_sum: float = 0.9,
                 # Parameter fusion params
                 momentum_fusion_beta: float = 0.9,
                 # Update params
                 learning_rate: float = 0.001,
                 damping_lora_fisher: float = 1e-6
                 ):
        super().__init__()
        self.model = model
        self.layers_to_adapt_substrings = layers_to_adapt_substrings
        self.rank_range = rank_range
        self.kfac_ema_decay = kfac_ema_decay
        self.kfac_damping = kfac_damping
        self.fisher_k = fisher_k
        self.subspace_buffer_size = subspace_buffer_size
        self.rank_scheduler_warmup_steps = rank_scheduler_warmup_steps
        self.rank_scheduler_ema_alpha = rank_scheduler_ema_alpha
        self.sv_tracker_window_size = sv_tracker_window_size
        self.sv_history_min_samples = sv_history_min_samples
        self.rho_eigen_sum = rho_eigen_sum
        self.momentum_fusion_beta = momentum_fusion_beta
        self.learning_rate = learning_rate
        self.damping_lora_fisher = damping_lora_fisher

        self.all_kfac_handlers = []
        self.fisher_engines = {} # Store FisherEigen instances per layer/block
        self.subspace_manager = None # Will be initialized after KFAC instrumentation
        self.sv_trackers = {} # For FlexLoRA layers
        self.rank_schedulers = {} # For FlexLoRA layers
        self.momentum_fuser = MomentumFusion(beta=self.momentum_fusion_beta)

        self._initialize_grit()

    def _initialize_grit(self):
        # 1. Inject LoRA layers
        lora_inject(self.model, self.layers_to_adapt_substrings, self.rank_range)

        # 2. Instrument KFAC
        # KFAC instrumentation might depend on the structure of the base model.
        # This part might need to be more generic or configurable.
        
        model_blocks_attr_names = ['trf_blocks', 'blocks'] # Common names for transformer blocks
        actual_blocks_attr_name = None
        for attr_name in model_blocks_attr_names:
            if hasattr(self.model, attr_name) and isinstance(getattr(self.model, attr_name), nn.ModuleList):
                actual_blocks_attr_name = attr_name
                break

        if actual_blocks_attr_name:
            model_blocks = getattr(self.model, actual_blocks_attr_name)
            for i, block in enumerate(model_blocks):
                # Ensure instrument_layer can find layers like 'attn.q_proj' or 'mlp' within the block
                # This might require block to have conventional sub-module names or instrument_layer to be more flexible
                block_kfac_handlers = instrument_layer(block) 
                self.all_kfac_handlers.extend(block_kfac_handlers)
        else:
            print(f"Warning: GRITAdapter KFAC instrumentation could not find a suitable blocks attribute (tried {model_blocks_attr_names}). KFAC might not be fully set up.")

        # Initialize FisherEigen and SubspaceBuffer based on KFAC handlers
        # This assumes KFAC handlers are for nn.Linear layers that were also LoRA adapted.
        # The number of "blocks" for SubspaceBuffer should match num KFAC-instrumented entities.
        num_kfac_instrumented_entities = len(self.all_kfac_handlers)
        if num_kfac_instrumented_entities > 0:
            # device = next(self.model.parameters()).device # Get model device
            # Assuming all KFAC handlers and the model are on the same device
            # Get device from the first KFAC handler's A_factor if available, else model param
            if self.all_kfac_handlers and self.all_kfac_handlers[0].A_factor is not None:
                 device = self.all_kfac_handlers[0].A_factor.device
            else:
                 try:
                     device = next(self.model.parameters()).device
                 except StopIteration:
                     device = torch.device("cpu") # Default if model has no parameters

            self.subspace_manager = SubspaceBuffer(num_blocks=num_kfac_instrumented_entities,
                                                   buffer_size=self.subspace_buffer_size,
                                                   device=device) # Pass device
            for i, handler in enumerate(self.all_kfac_handlers):
                self.fisher_engines[handler.layer_type] = FisherEigen().to(device) # Ensure FisherEigen is on the correct device
        else:
            print("Warning: No KFAC handlers found. FisherEigen and SubspaceBuffer not initialized.")


        # Initialize SVTrackers and RankSchedulers for FlexLoRA layers
        min_r, max_r = self.rank_range
        initial_r = min_r # As per lora_inject
        for name, module in self.model.named_modules():
            if isinstance(module, FlexLoRA):
                self.sv_trackers[name] = SVTracker(module, window_size=self.sv_tracker_window_size)
                self.rank_schedulers[name] = RankScheduler(
                    initial_rank=module.initial_rank, # Use FlexLoRA's actual initial_rank
                    min_rank=module.min_rank,
                    max_rank=module.max_rank,
                    warmup_steps=self.rank_scheduler_warmup_steps,
                    ema_alpha=self.rank_scheduler_ema_alpha
                )

    def _update_kfac_factors_and_fisher(self):
        """
        Updates KFAC factors (implicitly via hooks during forward/backward),
        then updates Fisher blocks and decomposes them.
        This method assumes a forward and backward pass has just occurred.
        """
        if not self.all_kfac_handlers or self.subspace_manager is None:
            # print("KFAC or SubspaceManager not initialized. Skipping Fisher updates.")
            return

        for i, handler in enumerate(self.all_kfac_handlers):
            kfac_A = handler.A_factor
            kfac_B = handler.B_factor
            fisher_engine = self.fisher_engines.get(handler.layer_type)

            if fisher_engine is None:
                # print(f"No FisherEngine for {handler.layer_type}. Skipping.")
                continue
            
            if kfac_A is None or kfac_B is None or kfac_A.numel() == 0 or kfac_B.numel() == 0:
                # print(f"Warning: KFAC factors for {handler.layer_type} are None or empty. Skipping Fisher update for this handler.")
                continue

            fisher_engine.update_fisher(A=kfac_A, B=kfac_B)
            if fisher_engine.fisher_block is not None and fisher_engine.fisher_block.numel() > 0:
                eig_vals, eig_vecs = fisher_engine.decompose(k=self.fisher_k)
                if eig_vecs is not None:
                    self.subspace_manager.update_buffer(block_idx=i, new_eigenvectors=eig_vecs)
                # else:
                    # print(f"Fisher decomposition failed for {handler.layer_type}. Subspace buffer not updated.")
            # else:
                # print(f"Fisher block for {handler.layer_type} is empty. Subspace buffer not updated.")


    def _apply_neural_reprojection(self):
        for name, module in self.model.named_modules():
            if isinstance(module, FlexLoRA):
                if module.current_rank > 0: # Only reproject if LoRA is active
                    # The neural_reprojection function takes the LoRA layer itself.
                    # It uses compute_lora_fisher internally.
                    neural_reprojection(
                        lora_layer=module,
                        rho_eigen_sum=self.rho_eigen_sum,
                        damping_lora_fisher=self.damping_lora_fisher # Pass damping for its internal Fisher
                    )

    def _perform_rank_adjustment(self):
        for name, module in self.model.named_modules():
            if isinstance(module, FlexLoRA):
                sv_tracker = self.sv_trackers.get(name)
                rank_sched = self.rank_schedulers.get(name)
                if sv_tracker and rank_sched:
                    sv_tracker.forward() # Update SV history
                    rank_adjustment(
                        layer=module,
                        sv_tracker=sv_tracker,
                        rank_scheduler=rank_sched,
                        sv_history_min_samples=self.sv_history_min_samples
                    )

    def step(self, closure=None):
        """
        Performs a single GRIT optimization step.
        This should be called after model.backward() and before optimizer.zero_grad()
        if an external optimizer is managing non-GRIT parameters.
        If GRIT manages all trainable parameters (LoRA params), then this
        function handles the update.

        Args:
            closure: An optional closure that re-evaluates the model and returns the loss.
                     This is useful if KFAC factors need to be updated based on a fresh pass,
                     though typically KFAC hooks handle this during the main forward/backward.
        """
        # 1. Update KFAC factors (usually done by hooks) and then Fisher information
        # This assumes forward/backward has just happened.
        self._update_kfac_factors_and_fisher()

        # 2. Natural Gradient Step for LoRA parameters
        # This replaces the standard optimizer step for LoRA parameters.
        natural_gradient_step(self.model, lr=self.learning_rate, damping_lora_fisher=self.damping_lora_fisher)

        # 3. Neural Reprojection
        self._apply_neural_reprojection()

        # 4. Rank Adjustment (includes SV tracking)
        self._perform_rank_adjustment()

        # 5. Parameter Fusion (Momentum Smoothing)
        self.momentum_fuser(self.model) # Apply momentum to LoRA parameters

    def forward(self, *args, **kwargs):
        """Forward pass through the underlying model."""
        return self.model(*args, **kwargs)

    def parameters(self, recurse: bool = True):
        """Returns parameters of the adapted model that are trainable (LoRA params)."""
        return filter(lambda p: p.requires_grad, self.model.parameters(recurse))

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        """Returns named parameters of the adapted model that are trainable."""
        return filter(lambda kv: kv[1].requires_grad, self.model.named_parameters(prefix, recurse))

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Returns a state dict containing only the LoRA parameters."""
        # It's often better to save the state_dict of the original model,
        # and then separately save LoRA parameters if needed, or re-inject on load.
        # However, for a pure adapter, one might want to save only LoRA parts.
        lora_state_dict = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad: # Assuming only LoRA parameters are trainable
                 if any(sub in name for sub in self.layers_to_adapt_substrings): # Further check if it's a LoRA param
                    lora_state_dict[name] = param.data
        return lora_state_dict # This might need refinement based on FlexLoRA naming

    def load_state_dict(self, state_dict, strict=True):
        """Loads LoRA parameters into the model."""
        # This needs to carefully map state_dict keys to the FlexLoRA parameters.
        # It's safer to load the full model and then re-apply GRIT,
        # or to have a dedicated save/load for LoRA parameters.
        # For now, this is a placeholder. A robust implementation would iterate
        # through FlexLoRA modules and load 'lora_A' and 'lora_B'.
        # self.model.load_state_dict(state_dict, strict=False) # Be careful with strict=False

        # A more targeted way for FlexLoRA:
        current_model_state_dict = self.model.state_dict()
        for name, param_data in state_dict.items():
            if name in current_model_state_dict:
                # Check if this parameter belongs to a FlexLoRA module and should be loaded
                # This assumes keys in state_dict directly match model param names (e.g., "blocks.0.attn.q_proj.lora_A")
                module_path = name.split('.')[:-1] # e.g., ['blocks', '0', 'attn', 'q_proj']
                param_short_name = name.split('.')[-1] # e.g., 'lora_A'
                
                try:
                    parent_module_path = ".".join(module_path)
                    parent_module = self.model.get_submodule(parent_module_path)
                    if isinstance(parent_module, FlexLoRA) and (param_short_name == 'lora_A' or param_short_name == 'lora_B'):
                         current_model_state_dict[name].copy_(param_data)
                    # else if not a FlexLoRA or not lora_A/B, it might be other trainable param if any
                except AttributeError:
                    if strict:
                        raise RuntimeError(f"Error loading {name} into GRITAdapter: module path not found.")
            elif strict:
                raise RuntimeError(f"Error loading {name} into GRITAdapter: parameter not found in model.")
        
        # After potentially updating parts of the state dict, load it into the model.
        # This approach is still a bit risky if state_dict contains non-LoRA params.
        # A truly clean way is to ensure state_dict ONLY contains LoRA weights
        # and load them specifically.
        # For simplicity here, if state_dict is purely LoRA, this would be fine:
        self.model.load_state_dict(state_dict, strict=strict)