In [1]:
from collections import defaultdict

import enum
import pickle
import traceback

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from mishax import ast_patcher
from mishax import safe_greenlet

import odeformer
from odeformer.model import SymbolicTransformerRegressor


In [2]:
class Site(enum.StrEnum):
    """Instrumentation sites within an ODEFormer forward pass."""
    # Attention sites
    QUERY, KEY, VALUE, ATTN_SCORES, ATTN_PROBS, ATTN_OUTPUT, ATTN_MLP_OUTPUT, POST_ATTN_RESIDUAL = (
        enum.auto(), enum.auto(), enum.auto(), enum.auto(), enum.auto(), enum.auto(), enum.auto(), enum.auto()
    )

    # Layer norm sites
    PRE_ATTN_LAYERNORM, PRE_MLP_LAYERNORM = enum.auto(), enum.auto()

    # MLP sites
    MLP_INPUT, MLP_HIDDEN, MLP_OUTPUT, POST_MLP_RESIDUAL = enum.auto(), enum.auto(), enum.auto(), enum.auto()

    # Cross attention (decoder only)
    CROSS_ATTN_SCORES, CROSS_ATTN_PROBS, CROSS_ATTN_OUTPUT = enum.auto(), enum.auto(), enum.auto()

class ModulePathMapper:
    """Maps modules to their hierarchical paths within the model."""
    def __init__(self, model):
        self.path_map = {}
        self.model = model
        self._build_path_map()

    def _build_path_map(self):
        """Constructs the module-to-path mapping."""
        model = getattr(self.model, 'model', self.model)
        
        for section in ['encoder', 'decoder']:
            module = getattr(model, section, None)
            if module:
                for name, sub_module in module.named_modules():
                    self.path_map[id(sub_module)] = f"{section}.{name if name else 'outer'}"

    def get_layer_path(self, module: nn.Module, accessing_component: str = None) -> str:
        """Returns the full hierarchical path including accessed component if provided."""
        base_path = self.path_map.get(id(module))
        return f"{base_path}.{accessing_component}" if base_path and accessing_component else base_path

# Module-level ModulePathMapper consulted by _tag() to turn a module into a path.
# It stays None (and _tag reports path=None) until
# collect_activations_during_fit() assigns a mapper for the model being run.
_path_mapper = None

def _tag(module: nn.Module, site: Site, value: torch.Tensor, accessing: str = None) -> torch.Tensor:
    """Reports ``value`` at ``site`` to the parent greenlet and returns what to use.

    The parent receives a ``(site, value, path)`` tuple, where ``path`` is
    looked up through the module-level ``_path_mapper`` (``None`` if no mapper
    is set). Whatever the parent switches back is used in place of ``value``
    unless it is ``None``. Any exception is printed and ``value`` is returned
    unchanged.
    """
    try:
        collector = safe_greenlet.getparent()
        if collector is None:
            # No parent greenlet reported: pass the value straight through.
            return value

        # Resolve the full path (module path + accessed component), if possible.
        layer_path = (
            _path_mapper.get_layer_path(module, accessing)
            if _path_mapper is not None
            else None
        )

        replacement = collector.switch((site, value, layer_path))
        if replacement is None:
            return value
        return replacement
    except Exception as e:
        print(f"Error in tag at {site}: {e}")
        return value

def collect_activations_during_fit(model, times, trajectories):
    """Runs ``model.fit(times, trajectories)`` under activation collection.

    Rebinds the module-level ``_path_mapper`` to a mapper built for ``model``
    and returns whatever ``collect_activations`` returns for the fit call.
    """
    global _path_mapper
    _path_mapper = ModulePathMapper(model)

    def _run_fit():
        return model.fit(times, trajectories)

    return collect_activations(_run_fit)

def install():
    """Installs patches for instrumentation.

    Builds an ``ast_patcher.ModuleASTPatcher`` for ``odeformer.model.transformer``
    whose replacement statements wrap selected intermediate values in ``tag(...)``
    calls, then calls ``patcher.install()``.

    Returns:
        The patcher object. It is returned even when ``patcher.install()``
        raised: the error is printed with a traceback and not re-raised.
    """
    print("Installing patches...", end=' ', flush=True)
    
    # Import line passed to PatchSettings(prefix=...) so that the replacement
    # statements below can refer to `Site` and `tag`.
    # NOTE(review): presumably prepended to the patched module's source, and it
    # relies on `__name__` being importable from there — confirm against mishax.
    PREFIX = f"from {__name__} import Site, _tag as tag"
    
    # Each keyword names a class in the patched module; its list is laid out as
    # consecutive (statement to find, replacement statement) pairs.
    patcher = ast_patcher.ModuleASTPatcher(
        odeformer.model.transformer,
        ast_patcher.PatchSettings(prefix=PREFIX),
        MultiHeadAttention=[
            # Raw attention scores.
            "scores = torch.matmul(q, k.transpose(2, 3))",
            "scores = tag(self, Site.ATTN_SCORES, torch.matmul(q, k.transpose(2, 3)), accessing='scores')",
            
            # Softmax-normalised attention weights.
            "weights = F.softmax(scores.float(), dim=-1).type_as(scores)",
            "weights = tag(self, Site.ATTN_PROBS, F.softmax(scores.float(), dim=-1).type_as(scores), accessing='weights')",

            # Weighted sum over the values.
            "context = torch.matmul(weights, v)",
            "context = tag(self, Site.ATTN_OUTPUT, torch.matmul(weights, v), accessing='context')",
        ],
        TransformerModel=[
            # Both the cross-attention call and the self-attention call are
            # tagged with Site.ATTN_MLP_OUTPUT; they are told apart only by the
            # `accessing` suffix ('cross_attention{i}' vs 'attention_layer{i}').
            """
            attn = self.encoder_attn[i](
                    tensor, src_mask, kv=src_enc, use_cache=use_cache
                )
            """,
            """
            attn = tag(
                        self, Site.ATTN_MLP_OUTPUT, 
                        self.encoder_attn[i](
                            tensor, src_mask, kv=src_enc, use_cache=use_cache
                        ),
                        accessing=f'cross_attention{i}'
                    )
            """,
            
            "attn = self.attentions[i](tensor, attn_mask, use_cache=use_cache)",
            "attn = tag(self, Site.ATTN_MLP_OUTPUT, self.attentions[i](tensor, attn_mask, use_cache=use_cache), accessing=f'attention_layer{i}')"
        ],
        TransformerFFN=[
            # Input handed to the first linear layer.
            "x = self.lin1(input)",
            "x = self.lin1(tag(self, Site.MLP_INPUT, input, accessing='input'))",
            
            # Output of the second linear layer.
            "x = self.lin2(x)",
            "x = tag(self, Site.MLP_OUTPUT, self.lin2(x), accessing='output')",
        ]
    )

    try:
        patcher.install()
        print("Patches installed successfully")
    except Exception as e:
        # Best effort: report the failure but still hand the patcher back.
        print(f"Error installing patches: {e}")
        import traceback
        traceback.print_exc()
    
    return patcher

def _process_activations(activations):
    """Processes collected activations into a structured format."""
    processed = {}
    for site, name_data in activations.items():
        processed[site] = {}
        for name, tensors in name_data.items():
            grouped = defaultdict(list)
            for tensor in tensors:
                grouped[tuple(tensor.shape)].append(tensor)
            processed[site][name] = {
                shape: torch.stack(tensors) 
                for shape, tensors in grouped.items()
            }
    return processed

def collect_activations(model_fn):
    """Collects activations during a model function execution.

    Installs the instrumentation patches, runs ``model_fn`` inside a
    ``SafeGreenlet`` and records every tensor handed over by the patched code.

    Args:
        model_fn: Zero-argument callable that runs the model.

    Returns:
        A ``(activations, result)`` tuple. ``activations`` is the output of
        ``_process_activations`` (site -> path -> {shape: stacked tensor});
        its top-level keys are whatever ``site`` objects the patched code sent.
        ``result`` is the last value returned by ``glet.switch``.
        NOTE(review): presumably this is ``model_fn``'s return value once the
        greenlet has finished — confirm against greenlet semantics. If the loop
        stops on an error, ``result`` is just the last value received.
    """
    print("\nStarting activation collection")
    # site -> path -> list of tensors, created on first access.
    activations = defaultdict(lambda: defaultdict(list))
    
    # install() is called on every collection, and the returned patcher is
    # additionally entered as a context manager for the duration of the run.
    patcher = install()
    with patcher():
        def run_in_greenlet():
            # Runs model_fn; failures are printed and then re-raised.
            try:
                print("Starting model execution in greenlet...")
                return model_fn()
            except Exception as e:
                print(f"Error in greenlet execution: {e}")
                traceback.print_exc()
                raise

        glet = safe_greenlet.SafeGreenlet(run_in_greenlet)
        with glet:
            # First switch starts model_fn; control comes back with a result.
            result = glet.switch()
            # Keep going for as long as the greenlet object is truthy.
            while glet:
                try:
                    # Expected to be the (site, value, path) tuple sent by the tag hook.
                    site, value, name = result
                    if torch.is_tensor(value):
                        # Store a detached CPU copy.
                        activations[site][name].append(value.detach().cpu())
                    # Hand the value back unchanged and wait for the next one.
                    result = glet.switch(value)
                except StopIteration:
                    break
                except Exception as e:
                    # Any other failure (including a result that does not unpack
                    # into three items) stops collection without re-raising.
                    print(f"Error during activation collection: {e}")
                    traceback.print_exc()
                    break
    
    print(f"Collection complete. Found sites: {list(activations.keys())}")
    return _process_activations(activations), result

In [3]:
def print_activations(activations, with_shape=False):
    """Prints every site and the activation names recorded under it.

    Args:
        activations: Mapping ``site -> name -> {shape: stacked tensor}``.
        with_shape: When true, also print one line per shape with the number
            of stacked tensors (the size of the tensor's first dimension).
    """
    for site, by_name in activations.items():
        print(f"{site}:")
        for name, by_shape in by_name.items():
            print(f"\tName: {name}")
            if not with_shape:
                continue
            for shape, stacked in by_shape.items():
                count = stacked.shape[0]
                plural = 's' if count > 1 else ''
                print(f"\t\tShape {shape}: {count} tensor{plural}")
        print()

def view_layer_paths(activations):
    """Displays activations organized by full layer paths.

    Paths are grouped first by section (``encoder`` / ``decoder``, taken from
    the path prefix) and then by component:

    * ``encoder_attn`` (decoder only): path contains ``encoder_attn`` or its
      last component starts with ``cross_attention``.
    * ``attention``: path mentions one of the attention linear layers
      (``q_lin``, ``k_lin``, ``v_lin``, ``out_lin``), or its last component is
      ``scores`` / ``weights`` / ``context`` or starts with ``attention_layer``.
    * ``ffn``: path mentions ``lin1`` / ``lin2`` or its last component is
      ``input`` / ``output``.

    Paths matching none of these, and falsy paths such as ``None``, are skipped.

    Args:
        activations: Mapping ``site -> path -> {shape: stacked tensor}``.
    """
    prefix = "Activations by Layer Path:"
    print(f"\n{prefix}\n{'-' * len(prefix)}")

    sections = {
        'encoder': {'attention': set(), 'ffn': set()},
        'decoder': {'attention': set(), 'ffn': set(), 'encoder_attn': set()}
    }

    attention_modules = ('q_lin', 'k_lin', 'v_lin', 'out_lin')
    ffn_modules = ('lin1', 'lin2')
    attention_leaves = {'scores', 'weights', 'context'}
    ffn_leaves = {'input', 'output'}

    def categorize_path(path):
        """Returns ``(section, component)`` for ``path`` or ``None`` if unmatched."""
        # Use the leading component rather than a substring test so that a
        # decoder path which merely mentions "encoder." is not misfiled.
        if path.startswith('encoder.'):
            section = 'encoder'
        elif path.startswith('decoder.'):
            section = 'decoder'
        else:
            return None

        leaf = path.rsplit('.', 1)[-1]

        # Cross-attention must be checked before plain attention because its
        # paths also end in 'scores' / 'weights' / 'context'.
        if section == 'decoder' and ('encoder_attn' in path or leaf.startswith('cross_attention')):
            return section, 'encoder_attn'
        if (
            any(token in path for token in attention_modules)
            or leaf in attention_leaves
            or leaf.startswith('attention_layer')
        ):
            return section, 'attention'
        if any(token in path for token in ffn_modules) or leaf in ffn_leaves:
            return section, 'ffn'
        return None

    for site_data in activations.values():
        for path in filter(None, site_data.keys()):
            category = categorize_path(path)
            if category is not None:
                section, component = category
                sections[section][component].add(path)

    for section, components in sections.items():
        print(f"{section.upper()}:")
        for component_type, paths in components.items():
            if not paths:
                continue

            print(f"\t{component_type}:")
            for path in sorted(paths):
                print(f"\t\t{path}:")
                # The same path may have been recorded under several sites.
                for site_data in activations.values():
                    if path in site_data:
                        for shape, tensor in site_data[path].items():
                            print(f"\t\t\tShape {shape}: {tensor.shape[0]} activations")
            print()
        print()

In [4]:
# Install the instrumentation patches up front; the returned patcher is discarded.
# collect_activations() calls install() again on every run.
_ = install()

Installing patches... Patches installed successfully


In [5]:
from dataclasses import dataclass, field
from typing import List, Dict, Set
import re
import torch
import pickle

# Initialize the Symbolic Transformer Regressor model with pretrained weights
# (`from_pretrained=True`).
# This is a notebook-level global: collect() below reads `model` directly
# instead of taking it as an argument.
model = SymbolicTransformerRegressor(from_pretrained=True)

@dataclass
class Keys:
    """
    Stores configuration settings for filtering activations.

    Attributes:
        encoders (Set[int]): Set of encoder layer indices to keep.
        decoders (Set[int]): Set of decoder layer indices to keep.
        encoder_attn (Set[str]): Activation types for which entries whose path
                                contains 'encoder_attn' are kept.
                                Can contain any of: ['attn_scores', 'attn_probs', 'attn_output'].
        cross_attention (bool): Whether to keep entries whose path contains
                                'cross_attention' under 'attn_mlp_output'.

        to_collect (Dict[str, bool]): Dictionary mapping activation types to boolean values
                                      indicating whether to collect them. Missing
                                      types default to False.

    `encoders`, `decoders` and `encoder_attn` accept any iterable (or a single
    int / str) and are normalised to sets in `__post_init__`.
    """
    encoders: Set[int] = field(default_factory=set)
    decoders: Set[int] = field(default_factory=set)
    encoder_attn: Set[str] = field(default_factory=set)
    cross_attention: bool = False

    # Filled with every known flag (default False) in __post_init__.
    to_collect: Dict[str, bool] = field(default_factory=dict)

    # Not annotated on purpose, so the dataclass machinery does not treat it as a field.
    _ATTN_KEYS = frozenset({'attn_scores', 'attn_probs', 'attn_output'})

    @staticmethod
    def _default_flags() -> Dict[str, bool]:
        """Returns a fresh dict with every known activation type switched off."""
        return {
            'attn_scores': False,
            'attn_probs': False,
            'attn_output': False,
            'attn_mlp_output': False,
            'mlp_input': False,
            'mlp_output': False
        }

    @staticmethod
    def _as_set(values) -> set:
        """Normalises None, a scalar, or any iterable into a set."""
        if values is None:
            return set()
        if isinstance(values, (str, int)):
            # A bare string would otherwise be split into characters.
            return {values}
        return set(values)

    def __post_init__(self):
        """
        Normalises the index/type collections to sets and ensures that
        `to_collect` always contains all possible keys with default values.
        This prevents missing keys from causing errors when accessing `to_collect`.
        """
        # Callers commonly pass lists (e.g. `decoders=[0, 1]`); `encoder_attn`
        # must be a real set because it is used in a set difference.
        self.encoders = self._as_set(self.encoders)
        self.decoders = self._as_set(self.decoders)
        self.encoder_attn = self._as_set(self.encoder_attn)

        # Merge user-defined `to_collect` values with default values
        self.to_collect = {**self._default_flags(), **(self.to_collect or {})}

    def _encoders_attn_to_remove(self) -> Set[str]:
        """
        Determines which attention types should be removed from activations.

        Returns:
            Set[str]: A set of attention keys to remove.
        """
        return set(self._ATTN_KEYS) - self.encoder_attn

    def _layer_selected(self, path: str) -> bool:
        """
        Determines whether a given activation path corresponds to a selected layer.

        Args:
            path (str): The activation path, e.g., 'encoder.attentions.3.scores'.

        Returns:
            bool: True if the first number in the path is listed in `encoders`
            (paths starting with 'encoder') or `decoders` (all other paths);
            False when the path contains no number.
        """
        match = re.search(r"\d+", path)  # first run of digits in the path
        if not match:
            return False

        indices = self.encoders if path.startswith("encoder") else self.decoders
        return int(match.group()) in indices

    def filter(self, activations: Dict[str, Dict[str, torch.Tensor]]) -> Dict[str, Dict[str, torch.Tensor]]:
        """
        Filters activations based on the `to_collect` settings and encoder/decoder indices.

        Args:
            activations (Dict[str, Dict[str, torch.Tensor]]):
                Dictionary keyed by activation type, then by layer path. The
                per-path values are passed through untouched, so they may be
                tensors or nested containers:
                {
                    "attn_scores": { "encoder.attentions.0.scores": ..., ... },
                    "attn_probs": { "encoder.attentions.1.weights": ..., ... },
                    ...
                }
                Top-level keys only need to hash/compare equal to the
                `to_collect` strings.

        Returns:
            Dict[str, Dict[str, torch.Tensor]]: A new filtered dictionary; the
            input is not modified.
        """
        drop_encoder_attn = self._encoders_attn_to_remove()
        filtered = {}

        for key, by_path in activations.items():
            # Keep only activation types that are enabled in `to_collect`
            if not self.to_collect.get(key, False):
                continue

            kept = by_path

            # Drop 'encoder_attn' entries for attention types NOT listed in `encoder_attn`
            if key in drop_encoder_attn:
                kept = {k: v for k, v in kept.items() if 'encoder_attn' not in k}

            # If `cross_attention` is disabled, drop those entries from 'attn_mlp_output'
            if key == 'attn_mlp_output' and not self.cross_attention:
                kept = {k: v for k, v in kept.items() if 'cross_attention' not in k}

            # Finally keep only the requested encoder/decoder layers
            filtered[key] = {k: v for k, v in kept.items() if self._layer_selected(k)}

        return filtered


def collect(
    input_path: str, 
    output_path: str, 
    model_args: dict = None,
    keys: Keys = None
):
    """
    Processes symbolic regression data, extracts activations, optionally filters
    them using `Keys`, and saves the processed activations.

    Uses the notebook-level `model` global.

    Args:
        input_path (str): Path to the input `.pkl` file containing solutions.
            Each solution is indexed with 'solution' and 'time_points'.
        output_path (str): Path where the list of per-solution activations is pickled.
        model_args (dict, optional): Arguments passed to `model.set_model_args`.
            Defaults to beam_size=20, beam_temperature=0.1 when empty or omitted.
        keys (Keys, optional): A `Keys` *instance* used to filter activations.
            When omitted, nothing is filtered and every collected activation is saved.

    Raises:
        TypeError: If `keys` is a class (e.g. `Keys` instead of `Keys(...)`).
    """
    # Fail before any expensive model run: calling `.filter` on the class itself
    # would only blow up after the first fit has already completed.
    if isinstance(keys, type):
        raise TypeError("keys must be a Keys instance (e.g. Keys(...)), not a class")

    # Load symbolic regression solutions from the input file.
    # SECURITY: pickle.load can execute arbitrary code; only use trusted files.
    with open(input_path, 'rb') as file:
        solutions = pickle.load(file)

    # Set model hyperparameters if not provided
    model_args = model_args or {
        'beam_size': 20, 
        'beam_temperature': 0.1
    }
    model.set_model_args(model_args)

    collected_act = []

    for solution in solutions:
        trajectory = solution['solution']
        times = solution['time_points']

        # The fit output itself is not needed here, only the activations.
        fit_activations, _ = collect_activations_during_fit(model, times, trajectory)
        if keys is not None:
            fit_activations = keys.filter(fit_activations)

        collected_act.append(fit_activations)

    with open(output_path, 'wb') as file:
        pickle.dump(collected_act, file)



Found pretrained model at odeformer.pt
Loaded pretrained model


In [6]:
# Run collection over every solution in 'lotka_volterra.pkl' and pickle the
# filtered activations to 'act.pkl'.
collect('lotka_volterra.pkl', 'act.pkl', keys=Keys(
    # Keep decoder layers 0 and 1; `encoders` is left empty, so no encoder
    # layer passes the layer filter.
    decoders=[0,1],
    # Entries whose path contains 'encoder_attn' are kept for attn_scores only.
    encoder_attn={'attn_scores'},
    # Keep the 'cross_attention*' entries under attn_mlp_output.
    cross_attention=True,
    # Activation types to keep; types not listed default to False.
    to_collect={
        'attn_scores': True,
        'attn_probs': True,
        'attn_mlp_output': True
    }
))


Starting activation collection
Installing patches... Patches installed successfully
Starting model execution in greenlet...
Collection complete. Found sites: [<Site.ATTN_SCORES: 'attn_scores'>, <Site.ATTN_PROBS: 'attn_probs'>, <Site.ATTN_OUTPUT: 'attn_output'>, <Site.ATTN_MLP_OUTPUT: 'attn_mlp_output'>, <Site.MLP_INPUT: 'mlp_input'>, <Site.MLP_OUTPUT: 'mlp_output'>]

Starting activation collection
Installing patches... Patches installed successfully
Starting model execution in greenlet...
Collection complete. Found sites: [<Site.ATTN_SCORES: 'attn_scores'>, <Site.ATTN_PROBS: 'attn_probs'>, <Site.ATTN_OUTPUT: 'attn_output'>, <Site.ATTN_MLP_OUTPUT: 'attn_mlp_output'>, <Site.MLP_INPUT: 'mlp_input'>, <Site.MLP_OUTPUT: 'mlp_output'>]

Starting activation collection
Installing patches... Patches installed successfully
Starting model execution in greenlet...
Collection complete. Found sites: [<Site.ATTN_SCORES: 'attn_scores'>, <Site.ATTN_PROBS: 'attn_probs'>, <Site.ATTN_OUTPUT: 'attn_output'

In [7]:
# Reload the activations written by collect(); the file holds one dict per solution.
with open('act.pkl', 'rb') as file:
    collected_act = pickle.load(file)

# List the kept activation types and layer paths for the first solution only.
for k, v in collected_act[0].items():
    print(k)
    for k1 in v:
        print(f'\t{k1}')

attn_scores
	decoder.attentions.0.scores
	decoder.encoder_attn.0.scores
	decoder.attentions.1.scores
	decoder.encoder_attn.1.scores
attn_probs
	decoder.attentions.0.weights
	decoder.attentions.1.weights
attn_mlp_output
	decoder.outer.attention_layer0
	decoder.outer.cross_attention0
	decoder.outer.attention_layer1
	decoder.outer.cross_attention1
