# Triplet Attention: Implementation Test

This notebook verifies that the Triplet Attention injection works and that changing the `linkage` parameter affects the model output.

In [1]:
import sys
import os
# Add the parent directory to the import path so the `src` package can be imported
sys.path.append(os.path.abspath('../'))

import torch
from src.models.model_injection import create_triplet_model

## 1. Load Model

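We load Llama-3.1-8B-Instruct with TripletAttention injected at layers [16, 24, 30]. `create_triplet_model` returns the model, its tokenizer, and a controller (`LinkageController` in the log below). The following cells use the controller to switch linkage modes and run probes.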

In [2]:
model, tokenizer, controller = create_triplet_model("meta-llama/Meta-Llama-3.1-8B-Instruct")

Loading base model: meta-llama/Meta-Llama-3.1-8B-Instruct...

Injecting TripletAttention into layers: [16, 24, 30]...
Seeding Triplet weights from Llama layer...
  Dims: H=4096, Heads=32, KV_Heads=8, Rep=4
  Weights seeded successfully.
 - Layer 16: Successfully wrapped original attention.
Seeding Triplet weights from Llama layer...
  Dims: H=4096, Heads=32, KV_Heads=8, Rep=4
  Weights seeded successfully.
 - Layer 24: Successfully wrapped original attention.
Seeding Triplet weights from Llama layer...
  Dims: H=4096, Heads=32, KV_Heads=8, Rep=4
  Weights seeded successfully.
 - Layer 30: Successfully wrapped original attention.
LinkageController initialized. Found 6 TripletAttention layers.
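The `Dims` line reflects the grouped-query attention (GQA) layout of Llama-3.1-8B. A hidden size of 4096 split across 32 query heads gives a head dimension of 128. The 32 query heads share 8 key/value heads, so `Rep = 32 / 8 = 4` query heads read from each KV head.

The sketch below works through that arithmetic. Note the following assumptions:
- `AttnDims` is a hypothetical helper and is not part of `src.models.model_injection`.
- The query-to-KV mapping assumes the contiguous grouping used by Hugging Face's `repeat_kv`.
- This notebook does not show how the seeding routine actually uses these numbers.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttnDims:
    """Hypothetical helper: attention shape arithmetic for a GQA layer."""
    hidden: int    # H: model hidden size
    heads: int     # number of query heads
    kv_heads: int  # number of key/value heads

    @property
    def head_dim(self) -> int:
        return self.hidden // self.heads

    @property
    def rep(self) -> int:
        # How many query heads share each KV head
        return self.heads // self.kv_heads

    @property
    def kv_dim(self) -> int:
        # Output width of k_proj / v_proj
        return self.kv_heads * self.head_dim

    def kv_head_for(self, q_head: int) -> int:
        # Contiguous grouping (assumed): query heads 0-3 -> KV head 0, 4-7 -> KV head 1, ...
        return q_head // self.rep


def check_dims(d: AttnDims) -> None:
    # Divisibility checks any weight-seeding step would need before reshaping weights
    assert d.hidden % d.heads == 0, "hidden size must split evenly across heads"
    assert d.heads % d.kv_heads == 0, "query heads must split evenly across KV heads"


llama_8b = AttnDims(hidden=4096, heads=32, kv_heads=8)
check_dims(llama_8b)
print(llama_8b.head_dim, llama_8b.rep, llama_8b.kv_dim)  # 128 4 1024
print([llama_8b.kv_head_for(h) for h in range(8)])       # [0, 0, 0, 0, 1, 1, 1, 1]
```

This is why `k_proj` and `v_proj` in Llama-3.1-8B project 4096 dimensions down to `kv_dim = 1024`, while `q_proj` keeps all 4096.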


In [3]:
print("=== DIAGNOSTIC TEST SEQUENCE ===\n")
# Step 1: Verify injection
print(f"Triplet layers found: {len(controller.triplet_layers)}")
print(f"Initial call counts: {controller.triplet_call_counts()}\n")
# Step 2: Run the logit-delta probe on a prompt that invites self-reference
# (the same prompt is reused for generation below)
prompt = "Who are you? Describe yourself in one paragraph."
print("Running logit delta probe...")
probe_results = controller.diagnose_mechanism(prompt)
# Step 3: Verify the injected modules were called
print(f"\nCall counts after probe: {controller.triplet_call_counts()}\n")
# Step 4: Interpret the probe; only test generation if linkage moves the logits
max_diff = probe_results["max_abs_logit_diff"]
if max_diff > 0.01:
    print("✓ STRONG EFFECT DETECTED - Testing generation with sampling...\n")
    action = "continue"
elif max_diff > 0.001:
    print("⚠ WEAK EFFECT DETECTED - Testing generation with sampling...\n")
    action = "continue"
else:
    print("✗ NO EFFECT DETECTED - Injection may be bypassed or signal too weak\n")
    action = "debug"
if action == "continue":
    # Step 5: Generate from the same prompt in both linkage modes
    
    print("=== GENERATING: LINKED MODE (linkage=1.0) ===")
    linked, meta_l = controller.generate_with_linkage(
        prompt,
        max_tokens=50,
        linkage_mode="full",
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )
    print(f"\n{linked}\n")
    
    print("=== GENERATING: OBSERVER MODE (linkage=0.0) ===")
    observer, meta_o = controller.generate_with_linkage(
        prompt,
        max_tokens=50,
        linkage_mode="observer",
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )
    print(f"\n{observer}\n")
    
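    print("=== COMPARISON ===")
    # Crude keyword check for first-person self-reference.
    # The phrase is matched against lowercased text, so it must itself be lowercase.
    # Keeping the check out of the f-string also avoids backslash escapes inside
    # f-string expressions, which are a SyntaxError before Python 3.12.
    def has_self_reference(text):
        return "I am" in text or "I'm" in text or "as an ai" in text.lower()
    print(f"Linked self-reference: {has_self_reference(linked)}")
    print(f"Observer self-reference: {has_self_reference(observer)}")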
    
    # Visual comparison
    print("\nLINKED starts with:", linked[:100] + "...")
    print("OBSERVER starts with:", observer[:100] + "...")
else:
    print("Next step: Debug injection with module tree inspection")

=== DIAGNOSTIC TEST SEQUENCE ===

Triplet layers found: 6
Initial call counts: [None, 0, None, 0, None, 0]

Running logit delta probe...


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


--- PROBE RESULTS ---
Max |Δlogit|: 2.089355
Mean |Δlogit|: 0.325744
KL(P1||P0): 0.110647
Top-1 same: True (' I' vs ' I')

Call counts after probe: [None, 2, None, 2, None, 2]

✓ STRONG EFFECT DETECTED - Testing generation with sampling...

=== GENERATING: LINKED MODE (linkage=1.0) ===


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



I am a 22-year-old girl who loves to write. I have a small family of three people. I have a small home in my hometown. I am a student and a part-time worker. I am a kind person with a good heart

=== GENERATING: OBSERVER MODE (linkage=0.0) ===

I'll respond with a description of myself, and we'll go back and forth.
I'm a 29-year-old writer and editor living in a small town in the Midwest. I'm a bit of a bookworm, always with my nose buried

=== COMPARISON ===
Linked self-reference: True
Observer self-reference: True

LINKED starts with: I am a 22-year-old girl who loves to write. I have a small family of three people. I have a small ho...
OBSERVER starts with: I'll respond with a description of myself, and we'll go back and forth.
I'm a 29-year-old writer and...
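Step 3 of the diagnostic cell checks the call counts by eye. The controller reports six `TripletAttention` entries, although only three layers were injected. Every other entry reports `None`. Whatever those entries are, they carry no count, so they cannot confirm that a layer ran.

A small helper can turn the before/after snapshots into an explicit pass/fail that skips untracked entries. `verify_calls` is hypothetical. It assumes only that `controller.triplet_call_counts()` returns a list of integers or `None`, as the printed output suggests.

```python
from typing import Optional, Sequence


def verify_calls(before: Sequence[Optional[int]],
                 after: Sequence[Optional[int]]) -> dict:
    """Hypothetical check: compare call-count snapshots taken around a forward pass.

    Entries that are None in either snapshot are treated as untracked and
    skipped; every tracked entry must have increased.
    """
    if len(before) != len(after):
        raise ValueError("snapshots must list the same modules")
    tracked, not_called = [], []
    for idx, (b, a) in enumerate(zip(before, after)):
        if b is None or a is None:
            continue  # this entry does not report a call count
        tracked.append(idx)
        if a <= b:
            not_called.append(idx)
    return {
        "tracked": tracked,
        "not_called": not_called,
        # Fail if nothing was tracked at all, or if any tracked module was skipped
        "ok": bool(tracked) and not not_called,
    }


# The two snapshots printed by the diagnostic cell above
result = verify_calls([None, 0, None, 0, None, 0], [None, 2, None, 2, None, 2])
print(result)  # {'tracked': [1, 3, 5], 'not_called': [], 'ok': True}
```

To use it in the notebook, save the list printed in step 1 (for example, as `counts_before`). Then pass it to `verify_calls` together with the post-probe snapshot.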


## 2. Verification

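Compare the outputs. If the model is sensitive to the injected layers, the responses should differ in tone, directness, or self-reference. In this run, both responses are written in the first person, and the keyword check reports self-reference for both. The comparison alone therefore does not separate the two modes.

Both generations also sample (`do_sample=True`, `temperature=0.7`, `top_p=0.9`) without a fixed seed. Some difference between them is expected even if linkage had no effect. The logit-delta probe below gives a more direct measure.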
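Both generations sample with `temperature=0.7` and `top_p=0.9`, so their outputs vary from run to run regardless of linkage. One way to isolate the effect of linkage is to reseed the random generator before each call (in PyTorch, `torch.manual_seed`). The two runs then draw the same random numbers, so any differences come from the logits rather than from the sampler.

The pure-Python sketch below shows the idea on a fixed logit vector:
1. Scale the logits by the temperature.
2. Filter with nucleus (top-p) sampling.
3. Draw from a seeded `random.Random`.

`sample_top_p` and `sample_sequence` are hypothetical and simplified. They are not the Hugging Face `generate` implementation, and in real generation the logits change at every step.

```python
import math
import random
from typing import List, Sequence


def sample_top_p(logits: Sequence[float], temperature: float, top_p: float,
                 rng: random.Random) -> int:
    """Draw one token index using temperature scaling and nucleus (top-p) filtering."""
    # Temperature-scaled softmax (subtract the max for numerical stability)
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the smallest set of most probable tokens whose mass reaches top_p
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # rng.choices normalises the weights of the kept tokens itself
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


def sample_sequence(logits: Sequence[float], steps: int, seed: int,
                    temperature: float = 0.7, top_p: float = 0.9) -> List[int]:
    rng = random.Random(seed)  # fresh, seeded generator for each run
    return [sample_top_p(logits, temperature, top_p, rng) for _ in range(steps)]


logits = [2.0, 1.5, 0.3, -1.0]
run_a = sample_sequence(logits, steps=10, seed=0)
run_b = sample_sequence(logits, steps=10, seed=0)
print(run_a == run_b)  # True: same logits + same seed -> same samples
```

With these logits and `top_p=0.9`, only the two most probable tokens survive the nucleus filter. Any token outside them can never be sampled, whatever the seed.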

In [4]:
# Run the mechanism check (logit-delta probe):
# compares the next-token logits with linkage=1.0 against linkage=0.0
probe_results = controller.diagnose_mechanism("What do you think about time?")

--- PROBE RESULTS ---
Max |Δlogit|: 1.797852
Mean |Δlogit|: 0.281102
KL(P1||P0): 0.053084
Top-1 same: True (' Do' vs ' Do')
