In [None]:
import numpy as np 
from typing import Tuple
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from typing import Tuple
import matplotlib.lines as mlines
from dataclasses import dataclass, field
from typing import Tuple
# Alias for one tensor's shape: a variable-length tuple of ints.
Shape = Tuple[int, ...]

In [2]:

@dataclass(frozen=True)
class ShapeBundle:
    """Immutable, ordered collection of tensor shapes.

    A bundle is nothing more than a static list of shapes, one per tensor.
    Whatever iterable-of-iterables is handed to the constructor is stored
    as a tuple of tuples.
    """

    # One entry per tensor, in order. Declared as Tuple[Shape, ...] to
    # express the "static list" requirement.
    tensors: Tuple[Shape, ...]

    def __post_init__(self):
        """Coerce ``tensors`` to a tuple of tuples.

        This covers callers that accidentally pass a list of lists.
        """
        normalized = tuple(map(tuple, self.tensors))
        # The dataclass is frozen, so regular attribute assignment is
        # rejected; go through object.__setattr__ instead.
        object.__setattr__(self, "tensors", normalized)

    @classmethod
    def from_single(cls, shape: Shape):
        """Build a bundle that holds exactly one tensor shape."""
        return cls(tensors=(shape,))

    @property
    def num_tensors(self) -> int:
        """Number of tensor shapes stored in the bundle."""
        return len(self.tensors)

    def __repr__(self):
        """Render the bundle as its bare nested-tuple form."""
        return str(self.tensors)


class BundleOps:
    """
    Structural operations on bundles.

    Every operation here acts on the LIST of tensors and never on the
    dimensions inside an individual shape, so bracket boundaries are always
    respected. This is simply because different tensors may have different
    start pointers: fusing ((a,b), (c,d)) into ((a,b,c,d)) makes no sense,
    as eventually the types are meant to be upgraded to layout + mempointer.
    """

    @staticmethod
    def concat(b1: ShapeBundle, b2: ShapeBundle) -> ShapeBundle:
        """
        Join two bundles end to end, keeping every shape intact.

        Example: ((a,b), (c,d)) + ((e,f)) -> ((a,b), (c,d), (e,f))
        """
        combined = b1.tensors + b2.tensors
        return ShapeBundle(combined)

    @staticmethod
    def split(b: ShapeBundle, split_idx: int) -> Tuple[ShapeBundle, ShapeBundle]:
        """
        Cut a bundle in two at a tensor index.

        The first ``split_idx`` tensors go to the left bundle and the rest
        to the right one. Both halves must be non-empty, so ``split_idx``
        has to lie strictly between 0 and the number of tensors.

        Example: ((a,b), (c,d), (e,f)) with split_idx=1
                 -> ((a,b)) and ((c,d), (e,f))

        Raises:
            ValueError: if ``split_idx`` is not strictly inside the bundle.
        """
        index_is_valid = 0 < split_idx < b.num_tensors
        if not index_is_valid:
            raise ValueError(f"Split index {split_idx} out of bounds for bundle with {b.num_tensors} tensors.")

        head, tail = b.tensors[:split_idx], b.tensors[split_idx:]
        return ShapeBundle(head), ShapeBundle(tail)

# --- Test / Demonstration ---

