In [2]:
import math
import numpy as np
import json

# --- Core logic: element-access tracking for a naive attention computation ---
class TrackedTensor:
    """NumPy array wrapper that remembers which elements have been touched.

    Indexing for a read (``t[idx]``) or a write (``t[idx] = v``) first sets the
    matching entries of the boolean mask ``hit`` to True, so once an algorithm
    has run, ``hit`` is the access footprint of that algorithm on this tensor.

    Args:
        name: Display label for the tensor.
        shape: Tuple of dimension sizes.
        init: ``"random"`` fills ``data`` with ``np.random.randn`` samples (the
            global NumPy RNG); ``"zeros"`` fills it with zeros.

    Raises:
        ValueError: If ``init`` is neither ``"random"`` nor ``"zeros"``.
    """

    def __init__(self, name, shape, init="random"):
        if init not in ("random", "zeros"):
            raise ValueError("init must be 'random' or 'zeros'")
        self.name = name
        self.shape = shape
        self.data = np.random.randn(*shape) if init == "random" else np.zeros(shape)
        # Same shape as ``data``; True wherever an element was read or written.
        self.hit = np.zeros(shape, dtype=bool)

    def _mark(self, idx):
        """Flag every element selected by ``idx`` as accessed."""
        self.hit[idx] = True

    def __getitem__(self, idx):
        self._mark(idx)
        return self.data[idx]

    def __setitem__(self, idx, value):
        self._mark(idx)
        self.data[idx] = value

    def clear(self):
        """Forget all recorded accesses; the stored values are not modified."""
        self.hit[:] = False

class attention_tracker:
    """Instrumented, element-by-element single-head attention ``O = softmax(Q K^T) V``.

    Every operand is a ``TrackedTensor``, so after ``run`` / ``run_slice`` each
    tensor's ``hit`` mask shows exactly which elements were read or written to
    produce the requested output elements.

    NOTE(review): the scores are raw dot products ``Q[b, lq, :] . K[b, lk, :]``;
    no ``1/sqrt(D)`` scaling is applied as in standard scaled dot-product
    attention. This changes the numeric values, not the access footprint.

    Args:
        B: Batch size.
        L_q: Number of query positions.
        L_k: Number of key/value positions.
        D: Feature dimension shared by Q, K, V and O.
    """

    def __init__(self, B, L_q, L_k, D):
        self.Q = TrackedTensor("Q", (B, L_q, D))
        self.K = TrackedTensor("K", (B, L_k, D))
        self.S = TrackedTensor("S", (B, L_q, L_k))  # raw scores Q . K
        self.P = TrackedTensor("P", (B, L_q, L_k))  # row-wise softmax of S
        self.V = TrackedTensor("V", (B, L_k, D))
        self.O = TrackedTensor("O", (B, L_q, D))
        self.B, self.L_q, self.L_k, self.D = B, L_q, L_k, D

    def _compute_scores_row(self, b, lq):
        """Write ``S[b, lq, lk] = sum_k Q[b, lq, k] * K[b, lk, k]`` for every key ``lk``."""
        for lk in range(self.L_k):
            s = 0.0
            for k in range(self.D):
                s += self.Q[b, lq, k] * self.K[b, lk, k]
            self.S[b, lq, lk] = s

    def _softmax_row(self, b, lq):
        """Write ``P[b, lq, :] = softmax(S[b, lq, :])``.

        Raises ValueError (from ``max`` of an empty sequence) when ``L_k == 0``.
        """
        # Subtracting the row maximum keeps exp() from overflowing; it cancels in the ratio.
        m = max(self.S[b, lq, lk] for lk in range(self.L_k))
        denom = 0.0
        for lk in range(self.L_k):
            denom += math.exp(self.S[b, lq, lk] - m)
        for lk in range(self.L_k):
            self.P[b, lq, lk] = math.exp(self.S[b, lq, lk] - m) / denom

    def _compute_output(self, b, lq, d):
        """Write ``O[b, lq, d] = sum_lk P[b, lq, lk] * V[b, lk, d]`` (accumulated in place in O)."""
        self.O[b, lq, d] = 0.0
        for lk in range(self.L_k):
            self.O[b, lq, d] += self.P[b, lq, lk] * self.V[b, lk, d]

    def run(self, b, lq, d):
        """Compute the single output element ``O[b, lq, d]`` from scratch.

        The score row ``S[b, lq, :]`` and probability row ``P[b, lq, :]`` are
        (re)computed first, so the recorded footprint is everything needed for
        that one element.
        """
        self._compute_scores_row(b, lq)
        self._softmax_row(b, lq)
        self._compute_output(b, lq, d)

    def run_slice(self, b_start, b_end, lq_start, lq_end, d_start, d_end):
        """Compute ``O[b_start:b_end, lq_start:lq_end, d_start:d_end]``.

        ``S[b, lq, :]`` and ``P[b, lq, :]`` do not depend on ``d``, so they are
        computed once per ``(b, lq)`` row instead of once per output element.
        Resulting values and ``hit`` masks are the same as calling ``run`` for
        every index in the slice.
        """
        d_indices = range(d_start, d_end)
        if not d_indices:
            # No output element requested: leave S/P (and their hit masks) untouched.
            return
        for b in range(b_start, b_end):
            for lq in range(lq_start, lq_end):
                self._compute_scores_row(b, lq)
                self._softmax_row(b, lq)
                for d in d_indices:
                    self._compute_output(b, lq, d)


def _tensor_payload(tracker):
    """Collect shape, hit mask, colour and axis names for each attention tensor.

    K is exported as K^T (its last two axes swapped) so that it lines up
    visually with the ``Q @ K^T`` product in the rendered scene.

    Returns:
        dict mapping tensor name -> JSON-serialisable dict with keys
        ``"shape"``, ``"hits"``, ``"color"`` and ``"dims"``.
    """
    colors = {'Q': 0x00f0ff, 'K': 0xff0055, 'S': 0xbbff00, 'P': 0xbbff00, 'V': 0xffaa00, 'O': 0xb000ff}
    dim_names = {
        'Q': ['B', 'L_q', 'D'],
        'K': ['B', 'L_k', 'D'],
        'S': ['B', 'L_q', 'L_k'],
        'P': ['B', 'L_q', 'L_k'],
        'V': ['B', 'L_k', 'D'],
        'O': ['B', 'L_q', 'D'],
    }

    payload = {}
    for t in (tracker.Q, tracker.K, tracker.S, tracker.P, tracker.V, tracker.O):
        hits = t.hit
        # int() so NumPy integer dims (e.g. np.int64) stay JSON-serialisable.
        shape = [int(n) for n in t.shape]
        dims = list(dim_names[t.name])
        if t.name == 'K':
            # Plot K as K^T: swap the last two axes of mask, shape and labels together.
            hits = np.transpose(hits, (0, 2, 1))
            shape = [shape[0], shape[2], shape[1]]
            dims = [dims[0], dims[2], dims[1]]
        payload[t.name] = {
            "shape": shape,
            "hits": hits.tolist(),
            "color": colors[t.name],
            "dims": dims,
        }
    return payload


def _render_html(data_payload):
    """Return a standalone three.js page that draws ``data_payload`` as voxel grids.

    Each tensor of shape (A, B, C) becomes a grid of cubes with X = last axis,
    Y = middle axis (pointing down) and Z = first axis. Touched elements are
    solid cubes in the tensor's colour; untouched ones are grey wireframes.
    The tensors are laid out left to right as Q, K, S, P, V, O. The page loads
    three.js r128 and OrbitControls from public CDNs.
    """
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Multi-Tensor Voxel Engine</title>
    <style>
        body {{ margin: 0; overflow: hidden; background-color: #0a0a0a; color: white; font-family: monospace; }}
        #ui {{ position: absolute; top: 20px; left: 20px; z-index: 100; background: rgba(0,0,0,0.8); padding: 15px; border: 1px solid #333; border-radius: 8px; pointer-events: none; }}
        .label {{ font-size: 14px; color: #aaa; margin-bottom: 5px; }}

        #labels-container {{ position: absolute; top: 0; left: 0; width: 100%; height: 100%; pointer-events: none; z-index: 50; overflow: hidden; }}

        /* Main Tensor Titles */
        .tensor-label {{ position: absolute; color: white; background: rgba(0,0,0,0.6); padding: 4px 8px; border: 1px solid #555; border-radius: 4px; font-weight: bold; font-size: 14px; transform: translate(-50%, -50%); transition: opacity 0.1s; }}

        /* Canonical Axis Labels */
        .axis-label {{ position: absolute; font-size: 14px; font-weight: bold; transform: translate(-50%, -50%); transition: opacity 0.1s; text-shadow: 1px 1px 0 #000, -1px -1px 0 #000, 1px -1px 0 #000, -1px 1px 0 #000; }}
    </style>
</head>
<body>
    <div id="ui">
        <h2>Attention Slice Footprint</h2>
        <div class="label" style="color:#ff8888;">Red Arrow: Width (X)</div>
        <div class="label" style="color:#88ff88;">Green Arrow: Height (Y)</div>
        <div class="label" style="color:#8888ff;">Blue Arrow: Depth (Z)</div>
    </div>

    <div id="labels-container"></div>

    <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>

    <script>
        const tensorData = {json.dumps(data_payload)};

        const scene = new THREE.Scene();
        const camera = new THREE.PerspectiveCamera(50, window.innerWidth / window.innerHeight, 0.1, 1000);
        const renderer = new THREE.WebGLRenderer({{ antialias: true }});
        renderer.setSize(window.innerWidth, window.innerHeight);
        document.body.appendChild(renderer.domElement);

        const controls = new THREE.OrbitControls(camera, renderer.domElement);
        controls.enableDamping = true;
        controls.dampingFactor = 0.05;

        let maxDim = 0;
        for(const name in tensorData) {{
            const shape = tensorData[name].shape;
            maxDim = Math.max(maxDim, shape[0], shape[1], shape[2]);
        }}

        const spacingX = maxDim * 1.1;
        const spacingZ = maxDim * 1.8;

        // Linear layout on the same plane
        const layout = {{
            'Q': {{ x: 0, z: 0 }},
            'K': {{ x: 1, z: 0 }},
            'S': {{ x: 2, z: 0 }},
            'P': {{ x: 3, z: 0 }},
            'V': {{ x: 4, z: 0 }},
            'O': {{ x: 5, z: 0 }}
        }};

        const labelMarkers = [];
        const labelsContainer = document.getElementById('labels-container');

        function addTrackingLabel(text, color, localPos, parentGroup, className) {{
            const el = document.createElement('div');
            el.className = className;
            el.style.color = color;
            el.innerText = text;
            labelsContainer.appendChild(el);

            const markerObj = new THREE.Object3D();
            markerObj.position.copy(localPos);
            parentGroup.add(markerObj);

            labelMarkers.push({{ element: el, obj: markerObj }});
        }}

        const boxGeo = new THREE.BoxGeometry(0.8, 0.8, 0.8);
        const edgeGeo = new THREE.EdgesGeometry(boxGeo);

        const wireMat = new THREE.LineBasicMaterial({{ color: 0x666666, transparent: true, opacity: 0.6 }});
        const outlineMat = new THREE.LineBasicMaterial({{ color: 0xffffff, transparent: true, opacity: 0.9 }});

        for(const [name, pos] of Object.entries(layout)) {{
            const data = tensorData[name];
            const [A, B, C] = data.shape;
            const hitMat = new THREE.MeshBasicMaterial({{ color: data.color, transparent: true, opacity: 0.9 }});

            const group = new THREE.Group();

            for(let a=0; a<A; a++) {{
                for(let b=0; b<B; b++) {{
                    for(let c=0; c<C; c++) {{
                        const isHit = data.hits[a][b][c];
                        const posX = c;
                        const posY = -b;
                        const posZ = a;

                        if(isHit) {{
                            const mesh = new THREE.Mesh(boxGeo, hitMat);
                            mesh.position.set(posX, posY, posZ);
                            group.add(mesh);

                            const edges = new THREE.LineSegments(edgeGeo, outlineMat);
                            edges.position.set(posX, posY, posZ);
                            group.add(edges);
                        }} else {{
                            const wire = new THREE.LineSegments(edgeGeo, wireMat);
                            wire.position.set(posX, posY, posZ);
                            group.add(wire);
                        }}
                    }}
                }}
            }}

            const axisOrigin = new THREE.Vector3(-0.6, 0.6, -0.6);

            const dirX = new THREE.Vector3(1, 0, 0);
            const lenX = C + 0.5;
            group.add(new THREE.ArrowHelper(dirX, axisOrigin, lenX, 0xff5555, 0.4, 0.2));
            addTrackingLabel(data.dims[2], '#ff8888', new THREE.Vector3(lenX, 0, 0).add(axisOrigin), group, 'axis-label');

            const dirY = new THREE.Vector3(0, -1, 0);
            const lenY = B + 0.5;
            group.add(new THREE.ArrowHelper(dirY, axisOrigin, lenY, 0x55ff55, 0.4, 0.2));
            addTrackingLabel(data.dims[1], '#88ff88', new THREE.Vector3(0, -lenY, 0).add(axisOrigin), group, 'axis-label');

            const dirZ = new THREE.Vector3(0, 0, 1);
            const lenZ = A + 0.5;
            group.add(new THREE.ArrowHelper(dirZ, axisOrigin, lenZ, 0x5555ff, 0.4, 0.2));
            addTrackingLabel(data.dims[0], '#8888ff', new THREE.Vector3(0, 0, lenZ).add(axisOrigin), group, 'axis-label');

            const centerOffsetX = (C - 1) / 2;
            const centerOffsetY = -(B - 1) / 2;
            const centerOffsetZ = (A - 1) / 2;
            group.position.set(-centerOffsetX, -centerOffsetY, -centerOffsetZ);

            const wrapper = new THREE.Group();

            // Shift x by 2.5 so the 6 items (0 to 5) are centered around 0
            const worldPosX = (pos.x - 2.5) * spacingX;
            const worldPosZ = 0;

            wrapper.position.set(worldPosX, 0, worldPosZ);
            wrapper.add(group);
            scene.add(wrapper);

            const titleName = name === 'K' ? 'K^T' : name;
            addTrackingLabel(`${{titleName}} ${{JSON.stringify(data.shape)}}`, '#' + data.color.toString(16).padStart(6, '0'), new THREE.Vector3(0, B/2 + 1.5, 0), wrapper, 'tensor-label');
        }}

        // Pull camera back to see the wider linear layout
        camera.position.set(0, spacingZ * 1.5, spacingX * 4.0);
        controls.target.set(0, 0, 0);

        // Scratch vector reused every frame for label projection (no per-frame allocation).
        const vector = new THREE.Vector3();

        function animate() {{
            requestAnimationFrame(animate);
            controls.update();

            for(const marker of labelMarkers) {{
                marker.obj.getWorldPosition(vector);
                vector.project(camera);

                if(vector.z < 1) {{
                    const x = (vector.x * 0.5 + 0.5) * window.innerWidth;
                    const y = (-(vector.y * 0.5) + 0.5) * window.innerHeight;
                    marker.element.style.transform = `translate(-50%, -50%) translate(${{x}}px,${{y}}px)`;
                    marker.element.style.opacity = "1";
                }} else {{
                    marker.element.style.opacity = "0";
                }}
            }}

            renderer.render(scene, camera);
        }}
        animate();

        window.addEventListener('resize', () => {{
            camera.aspect = window.innerWidth / window.innerHeight;
            camera.updateProjectionMatrix();
            renderer.setSize(window.innerWidth, window.innerHeight);
        }});
    </script>
</body>
</html>"""


def export_threejs_visualizer(tracker, filename="tensor_engine_all.html"):
    """Write an interactive 3-D HTML view of the tracker's access footprint.

    Args:
        tracker: Object exposing ``Q, K, S, P, V, O`` attributes. Each must have
            a ``name``, a 3-D ``shape`` and a boolean ``hit`` mask of that
            shape, as ``attention_tracker`` provides.
        filename: Output HTML path; overwritten if it already exists.

    Viewing the page needs network access, because three.js is loaded from a CDN.
    """
    html = _render_html(_tensor_payload(tracker))
    # The page declares <meta charset="UTF-8">; write it that way regardless of
    # the platform's default text encoding.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(html)
    print(f"Generated linear multi-tensor panorama. K is transposed. Written to {filename}")



In [3]:
# Problem size: batch, query length, key/value length, feature dimension.
B_dim, L_q, L_k, D = 2, 16, 16, 4

# TrackedTensor fills its data from the global NumPy RNG; seed it so tensor values
# are reproducible on Restart & Run All (the hit footprint is value-independent).
np.random.seed(0)
tracker = attention_tracker(B_dim, L_q, L_k, D)

# Trace the computation of the single output element O[0, 0, 0]; every element
# it reads or writes is recorded in the corresponding tensor's hit mask.
tracker.run_slice(
    b_start=0, b_end=1,
    lq_start=0, lq_end=1,
    d_start=0, d_end=1,
)

export_threejs_visualizer(tracker)

Generated linear multi-tensor panorama. K is transposed. Written to tensor_engine_all.html
