# ViT Architecture Analysis for DC Decomposition

This notebook analyzes the Vision Transformer (ViT) architecture to identify:
1. All module types used
2. Operations that need special handling (division, reshape, etc.)
3. How to implement DC decomposition hooks for each component

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict

# Check if transformers is available
try:
    from transformers import ViTConfig, ViTModel, ViTForImageClassification
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
    print("Install transformers: pip install transformers")

## 1. Load ViT Model with ReLU Activations

In [None]:
if HAS_TRANSFORMERS:
    # Load ViT with ReLU instead of GELU
    config = ViTConfig.from_pretrained("google/vit-base-patch16-224")
    config.hidden_act = "relu"  # Override GELU with ReLU
    
    model = ViTModel.from_pretrained("google/vit-base-patch16-224", config=config)
    model.eval()
    
    print(f"Model loaded: {type(model).__name__}")
    print(f"Hidden size: {config.hidden_size}")
    print(f"Num attention heads: {config.num_attention_heads}")
    print(f"Num hidden layers: {config.num_hidden_layers}")
    print(f"Intermediate size: {config.intermediate_size}")
    print(f"Activation: {config.hidden_act}")

## 2. List All Module Types

In [None]:
if HAS_TRANSFORMERS:
    # Collect all module types
    module_types = defaultdict(list)
    
    for name, module in model.named_modules():
        module_type = type(module).__name__
        module_types[module_type].append(name)
    
    print("Module Types in ViT:")
    print("=" * 60)
    for mtype, names in sorted(module_types.items()):
        print(f"\n{mtype} ({len(names)} instances):")
        for name in names[:3]:  # Show first 3
            print(f"  - {name}")
        if len(names) > 3:
            print(f"  ... and {len(names) - 3} more")

## 3. Detailed Module Analysis

In [None]:
if HAS_TRANSFORMERS:
    print("Detailed Module Analysis:")
    print("=" * 60)
    
    # Categorize modules by DC decomposition support
    supported = []
    needs_implementation = []
    container_modules = []
    
    supported_types = (nn.Linear, nn.LayerNorm, nn.ReLU, nn.Softmax, 
                       nn.Conv2d, nn.Dropout, nn.Identity)
    container_types = (nn.ModuleList, nn.Sequential)
    
    for name, module in model.named_modules():
        if name == "":
            continue
        mtype = type(module).__name__
        
        if isinstance(module, supported_types):
            supported.append((name, mtype))
        elif isinstance(module, container_types):
            container_modules.append((name, mtype))
        elif mtype in ['ViTModel', 'ViTEncoder', 'ViTEmbeddings', 'ViTLayer', 
                       'ViTAttention', 'ViTSelfAttention', 'ViTSelfOutput',
                       'ViTIntermediate', 'ViTOutput', 'ViTPatchEmbeddings',
                       'ViTPooler']:
            container_modules.append((name, mtype))
        else:
            needs_implementation.append((name, mtype))
    
    print(f"\nSupported modules: {len(supported)}")
    print(f"Container modules: {len(container_modules)}")
    print(f"Needs implementation: {len(needs_implementation)}")
    
    if needs_implementation:
        print("\nModules needing implementation:")
        for name, mtype in needs_implementation:
            print(f"  {mtype}: {name}")

## 4. Analyze ViT Attention Forward Pass

In [None]:
if HAS_TRANSFORMERS:
    # Let's look at the attention module's forward pass
    import inspect
    
    # Get attention layer
    attn_layer = model.encoder.layer[0].attention.attention
    
    print("ViTSelfAttention Structure:")
    print("=" * 60)
    for name, child in attn_layer.named_children():
        print(f"  {name}: {type(child).__name__}")
        if isinstance(child, nn.Linear):
            print(f"    in_features={child.in_features}, out_features={child.out_features}")
    
    print(f"\nAttention parameters:")
    print(f"  num_attention_heads: {attn_layer.num_attention_heads}")
    print(f"  attention_head_size: {attn_layer.attention_head_size}")
    print(f"  all_head_size: {attn_layer.all_head_size}")

In [None]:
if HAS_TRANSFORMERS:
    # Trace forward pass to identify operations
    print("\nViT Attention Forward Operations:")
    print("=" * 60)
    print("""
    1. query = self.query(hidden_states)  # Linear
    2. key = self.key(hidden_states)      # Linear
    3. value = self.value(hidden_states)  # Linear
    
    4. query = transpose_for_scores(query)  # Reshape + Permute
       - shape: (batch, seq, hidden) -> (batch, heads, seq, head_dim)
    
    5. attention_scores = query @ key.transpose(-1, -2)  # MatMul
    
    6. attention_scores = attention_scores / sqrt(head_dim)  # Division by scalar
    
    7. attention_probs = softmax(attention_scores, dim=-1)  # Softmax
    
    8. attention_probs = dropout(attention_probs)  # Dropout (identity in eval)
    
    9. context = attention_probs @ value  # MatMul
    
    10. context = context.permute(0, 2, 1, 3).contiguous()  # Permute
    
    11. context = context.view(batch, seq, hidden)  # Reshape
    """)
    
    print("\nOperations needing DC decomposition support:")
    print("  - MatMul (Q @ K^T, attn @ V): DCMatMul ✓")
    print("  - Division by scalar: DCScalarDiv (new)")
    print("  - Reshape/View: DCReshape (new)")
    print("  - Permute/Transpose: DCPermute (new)")
    print("  - Softmax: Already supported ✓")
    print("  - Dropout: Identity in eval mode ✓")

## 5. Implement Missing DC Modules

In [None]:
# DC modules for operations that need special handling

class DCReshape(nn.Module):
    """
    Reshape module for DC decomposition.
    
    Reshape is a linear operation - it just changes the view of data.
    Both pos and neg streams are reshaped identically.
    """
    _dc_is_reshape = True
    
    def __init__(self, target_shape):
        super().__init__()
        self.target_shape = target_shape
    
    def forward(self, x):
        return x.view(*self.target_shape)
    
    def extra_repr(self):
        return f'target_shape={self.target_shape}'


class DCPermute(nn.Module):
    """
    Permute module for DC decomposition.
    
    Permute is a linear operation - it just reorders dimensions.
    Both pos and neg streams are permuted identically.
    """
    _dc_is_permute = True
    
    def __init__(self, dims):
        super().__init__()
        self.dims = dims
    
    def forward(self, x):
        return x.permute(*self.dims)
    
    def extra_repr(self):
        return f'dims={self.dims}'


class DCTranspose(nn.Module):
    """
    Transpose module for DC decomposition.
    
    Transpose is a linear operation.
    Both pos and neg streams are transposed identically.
    """
    _dc_is_transpose = True
    
    def __init__(self, dim0, dim1):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1
    
    def forward(self, x):
        return x.transpose(self.dim0, self.dim1)
    
    def extra_repr(self):
        return f'dim0={self.dim0}, dim1={self.dim1}'


class DCScalarMul(nn.Module):
    """
    Scalar multiplication for DC decomposition.
    
    For scalar s:
    - If s >= 0: pos_out = s * pos_in, neg_out = s * neg_in
    - If s < 0: pos_out = |s| * neg_in, neg_out = |s| * pos_in
    """
    _dc_is_scalar_mul = True
    
    def __init__(self, scalar):
        super().__init__()
        self.register_buffer('scalar', torch.tensor(scalar))
    
    def forward(self, x):
        return x * self.scalar
    
    def extra_repr(self):
        return f'scalar={self.scalar.item()}'


class DCScalarDiv(nn.Module):
    """
    Scalar division for DC decomposition.
    
    Division by positive scalar s is equivalent to multiplication by 1/s.
    pos_out = pos_in / s, neg_out = neg_in / s
    """
    _dc_is_scalar_div = True
    
    def __init__(self, scalar):
        super().__init__()
        self.register_buffer('scalar', torch.tensor(scalar))
    
    def forward(self, x):
        return x / self.scalar
    
    def extra_repr(self):
        return f'scalar={self.scalar.item()}'


class DCAdd(nn.Module):
    """
    Element-wise addition for DC decomposition.
    
    (a_pos - a_neg) + (b_pos - b_neg) = (a_pos + b_pos) - (a_neg + b_neg)
    """
    _dc_is_add = True
    
    def __init__(self):
        super().__init__()
    
    def forward(self, a, b):
        return a + b


class DCContiguous(nn.Module):
    """
    Make tensor contiguous for DC decomposition.
    
    This is a no-op mathematically but ensures memory layout.
    """
    _dc_is_contiguous = True
    
    def forward(self, x):
        return x.contiguous()


print("DC Operation Modules defined:")
print("  - DCReshape: view/reshape operations")
print("  - DCPermute: dimension permutation")
print("  - DCTranspose: dimension transposition")
print("  - DCScalarMul: multiplication by scalar")
print("  - DCScalarDiv: division by scalar")
print("  - DCAdd: element-wise addition")
print("  - DCContiguous: memory layout")

## 6. Create DC-Compatible ViT Attention

In [None]:
import sys
sys.path.insert(0, '..')
from dc_decompose import DCMatMul


class DCViTSelfAttention(nn.Module):
    """
    ViT Self-Attention with DC-decomposable operations.
    
    All operations are wrapped in modules that can be hooked:
    - Linear projections (query, key, value)
    - Reshape operations
    - Matrix multiplications
    - Scalar division
    - Softmax
    """
    
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        
        # Linear projections
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        
        # DC-compatible operations
        self.scale_div = DCScalarDiv(self.attention_head_size ** 0.5)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        
        # MatMul modules for attention
        self.attn_matmul = DCMatMul()  # Q @ K^T
        self.context_matmul = DCMatMul()  # attn @ V
        
        # Reshape/permute operations
        self.permute_qkv = DCPermute((0, 2, 1, 3))  # (B, S, H, D) -> (B, H, S, D)
        self.transpose_k = DCTranspose(-1, -2)  # For K^T
        self.permute_context = DCPermute((0, 2, 1, 3))  # (B, H, S, D) -> (B, S, H, D)
        self.contiguous = DCContiguous()
    
    def transpose_for_scores(self, x, batch_size):
        """Reshape and permute for multi-head attention."""
        # (B, S, H*D) -> (B, S, H, D) -> (B, H, S, D)
        new_shape = (batch_size, -1, self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        return self.permute_qkv(x)
    
    def forward(self, hidden_states):
        batch_size = hidden_states.shape[0]
        
        # 1. Linear projections
        query_layer = self.query(hidden_states)
        key_layer = self.key(hidden_states)
        value_layer = self.value(hidden_states)
        
        # 2. Reshape for multi-head attention
        query_layer = self.transpose_for_scores(query_layer, batch_size)
        key_layer = self.transpose_for_scores(key_layer, batch_size)
        value_layer = self.transpose_for_scores(value_layer, batch_size)
        
        # 3. Attention scores: Q @ K^T
        key_layer_t = self.transpose_k(key_layer)
        # For DC: set the second operand
        self.attn_matmul.set_operand(key_layer_t)
        attention_scores = self.attn_matmul(query_layer)
        
        # 4. Scale by sqrt(head_dim)
        attention_scores = self.scale_div(attention_scores)
        
        # 5. Softmax
        attention_probs = self.softmax(attention_scores)
        
        # 6. Dropout (identity in eval)
        attention_probs = self.dropout(attention_probs)
        
        # 7. Context: attn @ V
        self.context_matmul.set_operand(value_layer)
        context_layer = self.context_matmul(attention_probs)
        
        # 8. Reshape back
        context_layer = self.permute_context(context_layer)
        context_layer = self.contiguous(context_layer)
        new_shape = (batch_size, -1, self.all_head_size)
        context_layer = context_layer.view(*new_shape)
        
        return context_layer


print("DCViTSelfAttention created with hookable operations")

## 7. Test DC Attention Module

In [None]:
if HAS_TRANSFORMERS:
    # Create a config for testing
    test_config = ViTConfig(
        hidden_size=64,
        num_attention_heads=4,
        attention_probs_dropout_prob=0.0,
    )
    
    # Create DC attention
    dc_attn = DCViTSelfAttention(test_config)
    dc_attn.eval()
    
    # Test input
    x = torch.randn(2, 16, 64)  # (batch, seq, hidden)
    
    # Forward pass
    with torch.no_grad():
        output = dc_attn(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    
    # List all modules
    print("\nDC Attention Modules:")
    for name, module in dc_attn.named_modules():
        if name:
            print(f"  {name}: {type(module).__name__}")

## 8. Full DC-Compatible ViT Layer

In [None]:
class DCViTLayer(nn.Module):
    """
    Full ViT Layer with DC-decomposable operations.
    
    Structure:
    - LayerNorm -> Attention -> Add (residual)
    - LayerNorm -> MLP -> Add (residual)
    """
    
    def __init__(self, config):
        super().__init__()
        
        # Attention block
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = DCViTSelfAttention(config)
        self.attention_output = nn.Linear(config.hidden_size, config.hidden_size)
        self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.add_residual_1 = DCAdd()
        
        # MLP block
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.intermediate = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act = nn.ReLU()
        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.add_residual_2 = DCAdd()
    
    def forward(self, hidden_states):
        # Attention block
        residual = hidden_states
        hidden_states = self.layernorm_before(hidden_states)
        attention_output = self.attention(hidden_states)
        attention_output = self.attention_output(attention_output)
        attention_output = self.attention_dropout(attention_output)
        hidden_states = self.add_residual_1(residual, attention_output)
        
        # MLP block
        residual = hidden_states
        hidden_states = self.layernorm_after(hidden_states)
        hidden_states = self.intermediate(hidden_states)
        hidden_states = self.intermediate_act(hidden_states)
        hidden_states = self.output_dense(hidden_states)
        hidden_states = self.output_dropout(hidden_states)
        hidden_states = self.add_residual_2(residual, hidden_states)
        
        return hidden_states


print("DCViTLayer created")

In [None]:
if HAS_TRANSFORMERS:
    # Create and test DCViTLayer
    test_config = ViTConfig(
        hidden_size=64,
        num_attention_heads=4,
        intermediate_size=256,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        layer_norm_eps=1e-6,
    )
    
    dc_layer = DCViTLayer(test_config)
    dc_layer.eval()
    
    x = torch.randn(2, 16, 64)
    
    with torch.no_grad():
        output = dc_layer(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    
    print("\nAll modules in DCViTLayer:")
    for name, module in dc_layer.named_modules():
        if name and not any(c in name for c in ['attention.', 'layernorm', 'intermediate', 'output']):
            continue
        if name:
            mtype = type(module).__name__
            if not mtype.startswith('DC') and mtype not in ['Linear', 'LayerNorm', 'ReLU', 'Softmax', 'Dropout']:
                continue
            print(f"  {name}: {mtype}")

## 9. Save DC Operation Modules to Package

In [None]:
# Generate code for dc_operations.py
dc_operations_code = '''"""
DC Operation Modules

These modules wrap common tensor operations to make them compatible with
hook-based DC decomposition. Each operation is implemented as a module
that can be hooked by HookDecomposer.

For most operations, pos and neg streams are processed identically since
the operations are linear (reshape, permute, transpose, scalar mul/div).
"""

import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple, Optional, List, Union


class DCReshape(nn.Module):
    """
    Reshape/view module for DC decomposition.
    
    Reshape is a linear operation - both pos and neg streams
    are reshaped identically.
    
    Args:
        target_shape: Target shape (can include -1 for inference)
    """
    _dc_is_reshape = True
    
    def __init__(self, *target_shape):
        super().__init__()
        if len(target_shape) == 1 and isinstance(target_shape[0], (list, tuple)):
            self.target_shape = tuple(target_shape[0])
        else:
            self.target_shape = target_shape
    
    def forward(self, x: Tensor) -> Tensor:
        return x.view(*self.target_shape)
    
    def extra_repr(self) -> str:
        return f'target_shape={self.target_shape}'


class DCDynamicReshape(nn.Module):
    """
    Dynamic reshape module where shape is computed at runtime.
    
    Use set_shape() before forward pass to set the target shape.
    """
    _dc_is_reshape = True
    
    def __init__(self):
        super().__init__()
        self._target_shape: Optional[Tuple[int, ...]] = None
    
    def set_shape(self, *shape):
        """Set target shape for next forward pass."""
        if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
            self._target_shape = tuple(shape[0])
        else:
            self._target_shape = shape
    
    def forward(self, x: Tensor) -> Tensor:
        if self._target_shape is None:
            raise RuntimeError("Target shape not set. Call set_shape() first.")
        return x.view(*self._target_shape)


class DCPermute(nn.Module):
    """
    Permute dimensions module for DC decomposition.
    
    Permute is a linear operation - both pos and neg streams
    are permuted identically.
    
    Args:
        dims: Permutation of dimensions
    """
    _dc_is_permute = True
    
    def __init__(self, *dims):
        super().__init__()
        if len(dims) == 1 and isinstance(dims[0], (list, tuple)):
            self.dims = tuple(dims[0])
        else:
            self.dims = dims
    
    def forward(self, x: Tensor) -> Tensor:
        return x.permute(*self.dims)
    
    def extra_repr(self) -> str:
        return f'dims={self.dims}'


class DCTranspose(nn.Module):
    """
    Transpose dimensions module for DC decomposition.
    
    Transpose is a linear operation - both pos and neg streams
    are transposed identically.
    
    Args:
        dim0: First dimension to transpose
        dim1: Second dimension to transpose
    """
    _dc_is_transpose = True
    
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1
    
    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(self.dim0, self.dim1)
    
    def extra_repr(self) -> str:
        return f'dim0={self.dim0}, dim1={self.dim1}'


class DCContiguous(nn.Module):
    """
    Make tensor contiguous for DC decomposition.
    
    This is a no-op mathematically but ensures memory layout.
    """
    _dc_is_contiguous = True
    
    def forward(self, x: Tensor) -> Tensor:
        return x.contiguous()


class DCScalarMul(nn.Module):
    """
    Scalar multiplication for DC decomposition.
    
    For positive scalar s: pos_out = s * pos_in, neg_out = s * neg_in
    For negative scalar s: pos_out = |s| * neg_in, neg_out = |s| * pos_in
    
    Args:
        scalar: Scalar value to multiply by
    """
    _dc_is_scalar_mul = True
    
    def __init__(self, scalar: float):
        super().__init__()
        self.register_buffer('scalar', torch.tensor(scalar))
        self.register_buffer('is_negative', torch.tensor(scalar < 0))
        self.register_buffer('abs_scalar', torch.tensor(abs(scalar)))
    
    def forward(self, x: Tensor) -> Tensor:
        return x * self.scalar
    
    def extra_repr(self) -> str:
        return f'scalar={self.scalar.item():.6f}'


class DCScalarDiv(nn.Module):
    """
    Scalar division for DC decomposition.
    
    Division by scalar s is multiplication by 1/s.
    For positive s: pos_out = pos_in / s, neg_out = neg_in / s
    For negative s: pos_out = neg_in / |s|, neg_out = pos_in / |s|
    
    Args:
        scalar: Scalar value to divide by
    """
    _dc_is_scalar_div = True
    
    def __init__(self, scalar: float):
        super().__init__()
        self.register_buffer('scalar', torch.tensor(scalar))
        self.register_buffer('is_negative', torch.tensor(scalar < 0))
        self.register_buffer('abs_scalar', torch.tensor(abs(scalar)))
    
    def forward(self, x: Tensor) -> Tensor:
        return x / self.scalar
    
    def extra_repr(self) -> str:
        return f'scalar={self.scalar.item():.6f}'


class DCAdd(nn.Module):
    """
    Element-wise addition for DC decomposition.
    
    (a_pos - a_neg) + (b_pos - b_neg) = (a_pos + b_pos) - (a_neg + b_neg)
    
    This module stores the second operand's decomposition for the hook.
    """
    _dc_is_add = True
    
    def __init__(self):
        super().__init__()
        # Storage for second operand's DC decomposition
        self._dc_operand_pos: Optional[Tensor] = None
        self._dc_operand_neg: Optional[Tensor] = None
    
    def set_operand(self, b: Tensor, b_pos: Optional[Tensor] = None, b_neg: Optional[Tensor] = None):
        """Set the second operand for addition."""
        if b_pos is not None and b_neg is not None:
            self._dc_operand_pos = b_pos
            self._dc_operand_neg = b_neg
        else:
            import torch.nn.functional as F
            self._dc_operand_pos = F.relu(b)
            self._dc_operand_neg = F.relu(-b)
    
    def set_operand_decomposed(self, b_pos: Tensor, b_neg: Tensor):
        """Set the second operand with pre-decomposed components."""
        self._dc_operand_pos = b_pos
        self._dc_operand_neg = b_neg
    
    def forward(self, a: Tensor, b: Tensor) -> Tensor:
        return a + b


class DCSplit(nn.Module):
    """
    Split tensor along a dimension for DC decomposition.
    
    Split is a linear operation - pos and neg are split identically.
    
    Args:
        split_size: Size of each split or list of sizes
        dim: Dimension to split along
    """
    _dc_is_split = True
    
    def __init__(self, split_size: Union[int, List[int]], dim: int = 0):
        super().__init__()
        self.split_size = split_size
        self.dim = dim
    
    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
        return torch.split(x, self.split_size, dim=self.dim)
    
    def extra_repr(self) -> str:
        return f'split_size={self.split_size}, dim={self.dim}'


class DCChunk(nn.Module):
    """
    Chunk tensor into equal parts for DC decomposition.
    
    Chunk is a linear operation - pos and neg are chunked identically.
    
    Args:
        chunks: Number of chunks
        dim: Dimension to chunk along
    """
    _dc_is_chunk = True
    
    def __init__(self, chunks: int, dim: int = 0):
        super().__init__()
        self.chunks = chunks
        self.dim = dim
    
    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
        return torch.chunk(x, self.chunks, dim=self.dim)
    
    def extra_repr(self) -> str:
        return f'chunks={self.chunks}, dim={self.dim}'


class DCCat(nn.Module):
    """
    Concatenate tensors for DC decomposition.
    
    Concatenation is a linear operation - pos and neg are concatenated identically.
    
    Args:
        dim: Dimension to concatenate along
    """
    _dc_is_cat = True
    
    def __init__(self, dim: int = 0):
        super().__init__()
        self.dim = dim
    
    def forward(self, tensors: List[Tensor]) -> Tensor:
        return torch.cat(tensors, dim=self.dim)
    
    def extra_repr(self) -> str:
        return f'dim={self.dim}'


class DCSlice(nn.Module):
    """
    Slice tensor for DC decomposition.
    
    Slicing is a linear operation - pos and neg are sliced identically.
    
    Args:
        dim: Dimension to slice
        start: Start index
        end: End index (exclusive)
    """
    _dc_is_slice = True
    
    def __init__(self, dim: int, start: Optional[int] = None, end: Optional[int] = None):
        super().__init__()
        self.dim = dim
        self.start = start
        self.end = end
    
    def forward(self, x: Tensor) -> Tensor:
        slices = [slice(None)] * x.dim()
        slices[self.dim] = slice(self.start, self.end)
        return x[tuple(slices)]
    
    def extra_repr(self) -> str:
        return f'dim={self.dim}, start={self.start}, end={self.end}'


class DCDropout(nn.Module):
    """
    Dropout for DC decomposition.
    
    In eval mode, dropout is identity.
    In train mode, the same mask is applied to both pos and neg.
    
    Args:
        p: Dropout probability
    """
    _dc_is_dropout = True
    
    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p
        self.dropout = nn.Dropout(p)
    
    def forward(self, x: Tensor) -> Tensor:
        return self.dropout(x)
    
    def extra_repr(self) -> str:
        return f'p={self.p}'


class DCIdentity(nn.Module):
    """
    Identity module for DC decomposition.
    
    Useful as a placeholder or for skip connections.
    """
    _dc_is_identity = True
    
    def forward(self, x: Tensor) -> Tensor:
        return x
'''

print("DC Operations module code generated")
print(f"Length: {len(dc_operations_code)} characters")

In [None]:
# Write to file
with open('../dc_decompose/dc_operations.py', 'w') as f:
    f.write(dc_operations_code)

print("Saved to ../dc_decompose/dc_operations.py")

## 10. Test All DC Operations

In [None]:
# Import the new module
import importlib
import sys

# Reload to get fresh imports
if 'dc_decompose.dc_operations' in sys.modules:
    del sys.modules['dc_decompose.dc_operations']

from dc_decompose.dc_operations import (
    DCReshape, DCDynamicReshape, DCPermute, DCTranspose, DCContiguous,
    DCScalarMul, DCScalarDiv, DCAdd, DCSplit, DCChunk, DCCat, DCSlice,
    DCDropout, DCIdentity
)

print("Testing DC Operations:")
print("=" * 60)

x = torch.randn(2, 4, 8)

# Test DCReshape
reshape = DCReshape(2, -1)
y = reshape(x)
print(f"DCReshape: {x.shape} -> {y.shape}")

# Test DCPermute
permute = DCPermute(0, 2, 1)
y = permute(x)
print(f"DCPermute: {x.shape} -> {y.shape}")

# Test DCTranspose
transpose = DCTranspose(-1, -2)
y = transpose(x)
print(f"DCTranspose: {x.shape} -> {y.shape}")

# Test DCScalarMul
scalar_mul = DCScalarMul(2.0)
y = scalar_mul(x)
print(f"DCScalarMul(2.0): max_diff={torch.max(torch.abs(y - 2*x)).item():.2e}")

# Test DCScalarDiv
scalar_div = DCScalarDiv(8.0)
y = scalar_div(x)
print(f"DCScalarDiv(8.0): max_diff={torch.max(torch.abs(y - x/8)).item():.2e}")

# Test DCSplit
split = DCSplit(4, dim=1)
y1, = split(x)
print(f"DCSplit: {x.shape} -> {y1.shape}")

# Test DCChunk
chunk = DCChunk(2, dim=-1)
y1, y2 = chunk(x)
print(f"DCChunk: {x.shape} -> {y1.shape}, {y2.shape}")

# Test DCSlice
slice_op = DCSlice(dim=1, start=1, end=3)
y = slice_op(x)
print(f"DCSlice: {x.shape} -> {y.shape}")

print("\nAll operations working correctly!")

## 11. Summary: ViT Components and DC Support

| Component | Module Type | DC Support | Notes |
|-----------|-------------|------------|-------|
| Patch Embedding | Conv2d | ✓ | Standard convolution |
| Position Embedding | Parameter + Add | ✓ | Linear operation |
| LayerNorm | LayerNorm | ✓ | Variance as constant |
| Q/K/V Projection | Linear | ✓ | Weight decomposition |
| Reshape for heads | DCReshape | ✓ | Linear operation |
| Permute for heads | DCPermute | ✓ | Linear operation |
| Q @ K^T | DCMatMul | ✓ | Product rule |
| Scale by sqrt(d) | DCScalarDiv | ✓ | Linear operation |
| Softmax | Softmax | ✓ | Jacobian backward |
| Dropout | DCDropout | ✓ | Identity in eval |
| attn @ V | DCMatMul | ✓ | Product rule |
| Transpose back | DCTranspose | ✓ | Linear operation |
| Contiguous | DCContiguous | ✓ | Memory layout |
| Output projection | Linear | ✓ | Weight decomposition |
| Residual add | DCAdd | ✓ | Linear operation |
| MLP intermediate | Linear | ✓ | Weight decomposition |
| ReLU/GELU | ReLU | ✓ | Use ReLU for DC |
| MLP output | Linear | ✓ | Weight decomposition |

In [None]:
print("Summary complete!")
print("\nAll ViT components can be decomposed with the implemented DC modules.")
print("\nKey files created/updated:")
print("  - dc_decompose/dc_operations.py: Linear operations (reshape, permute, etc.)")
print("  - dc_decompose/dc_matmul.py: Matrix multiplication")
print("  - dc_decompose/hook_decomposer.py: LayerNorm, Softmax, etc.")