# POC: pruned-AM iterative forecasting with d-/r-/n-hll guidance

In [None]:
import torch
import os

# Check available GPUs
print("Available GPUs:")
for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"  GPU {i}: {props.name}")
    print(f"    Compute Capability: sm_{props.major}{props.minor}")
    print(f"    Memory: {props.total_memory / 1024**3:.2f} GB")

# Select RTX 3060 (adjust index based on output above)
# Option A: Hide Quadro, only show RTX
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Assuming RTX is at index 1
DEVICE = torch.device("cuda:0")  # Now index 0 refers to the RTX

# Option B: Explicitly select by index (if you don't set CUDA_VISIBLE_DEVICES)
# DEVICE = torch.device("cuda:1")  # Direct access to RTX at original index 1

print(f"\nUsing device: {DEVICE}")
if DEVICE.type == "cuda":
    print(f"Device name: {torch.cuda.get_device_name(DEVICE)}")
    print(f"Memory: {torch.cuda.get_device_properties(DEVICE).total_memory / 1024**3:.2f} GB")

In [None]:
"""
POC: pruned-AM iterative forecasting with d-/r-/n-hll guidance
python am_forecast_poc.py
"""
import random, hashlib, json
torch.manual_seed(42)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
VOCAB  = 1_000          # toy vocabulary size
MAX_EDGES = 5_000       # hard prune ceiling
K = 3                   # tensor slices (τ, ρ, Δ)
TOPK = 20               # flashlight circle

# ---------- 1. toy corpus -> HLLSet covers ----------
def fake_cover(n):
    """return sparse indices of n random tokens"""
    return torch.unique(torch.randint(0, VOCAB, (n,)))

corpus = [fake_cover(random.randint(10, 50)) for _ in range(100)]

In [None]:
# ---------- 2. sparse AM builder with pruning ----------
from collections import defaultdict

class PrunedAM:
    """CSR-like adjacency via torch sparse COO"""
    def __init__(self):
        self.edges = defaultdict(float)  # (u,v) → weight

    def add_edge(self, u, v, w):
        self.edges[(u, v)] += w

    def csr(self):
        if not self.edges:
            return torch.sparse_coo_tensor(
                torch.empty((2, 0), dtype=torch.long, device=DEVICE),
                torch.empty(0, dtype=torch.float32, device=DEVICE),  # Force float32
                size=(VOCAB, VOCAB),
                device=DEVICE
            )
        row, col, val = [], [], []
        for (u, v), w in self.edges.items():
            row.append(u)
            col.append(v)
            val.append(w)
        
        indices = torch.tensor([row, col], dtype=torch.long, device=DEVICE)
        values = torch.tensor(val, dtype=torch.float32, device=DEVICE)  # Force float32
        
        return torch.sparse_coo_tensor(
            indices, 
            values, 
            size=(VOCAB, VOCAB),
            dtype=torch.float32,  # Force float32
            device=DEVICE
        ).coalesce()
    
am = PrunedAM()
for cover in corpus:
    for u in cover:
        for v in cover:
            if u != v: am.add_edge(u, v, 1.0)   # τ-lattice
Wτ = am.csr()
Wρ = Wτ * 0.3                                    # ρ-lattice (scaled)
W_prev = Wτ.clone()

In [24]:
# ---------- 3. relational tensor ----------
class RelationalTensor:
    """K slices, each VOCAB×VOCAB sparse"""
    def __init__(self, K, vocab_size, budget_per_slice):
        self.K = K
        self.vocab = vocab_size
        self.budget = budget_per_slice
        # Explicitly use float32 for all slices
        self.slices = [
            torch.sparse_coo_tensor(
                size=(vocab_size, vocab_size), 
                dtype=torch.float32,  # Force float32
                device=DEVICE
            )
            for _ in range(K)
        ]

    def overwrite_slice(self, k, sparse_mat):
        # Ensure float32 dtype
        sparse_mat = sparse_mat.to(dtype=torch.float32).coalesce()
        
        # Check budget
        nnz = sparse_mat._nnz()
        
        if nnz > self.budget:
            # Prune to budget by keeping top-k values
            vals, idx = torch.topk(sparse_mat.values(), self.budget)
            sparse_mat = torch.sparse_coo_tensor(
                sparse_mat.indices()[:, idx], 
                vals,
                size=sparse_mat.shape,
                dtype=torch.float32,  # Force float32
                device=DEVICE
            ).coalesce()
        
        self.slices[k] = sparse_mat

    def contract(self, k, belief_vec):
        """belief_vec → W[k] @ belief_vec"""
        Wk = self.slices[k]
        
        # Ensure belief_vec is float32
        belief_vec = belief_vec.to(dtype=torch.float32)
        
        # Sparse matrix-vector multiplication
        result = torch.sparse.mm(Wk, belief_vec.unsqueeze(1)).squeeze(1)
        
        return result.to(dtype=torch.float32)  # Ensure output is float32

    def get_stats(self):
        """Return statistics about tensor state"""
        stats = {
            'K': self.K,
            'vocab': self.vocab,
            'budget_per_slice': self.budget,
            'slices': []
        }
        
        for k, slice_tensor in enumerate(self.slices):
            slice_stats = {
                'index': k,
                'nnz': slice_tensor._nnz(),
                'shape': slice_tensor.shape,
                'dtype': str(slice_tensor.dtype),
                'density': slice_tensor._nnz() / (self.vocab * self.vocab) if self.vocab > 0 else 0,
                'memory_mb': (slice_tensor._nnz() * (8 + 8 + 4)) / 1024**2  # indices + values
            }
            stats['slices'].append(slice_stats)
        
        stats['total_nnz'] = sum(s['nnz'] for s in stats['slices'])
        stats['total_memory_mb'] = sum(s['memory_mb'] for s in stats['slices'])
        
        return stats
    
RT = RelationalTensor(K, VOCAB, MAX_EDGES)

In [25]:
# Test the tensor operations
print("\n=== Relational Tensor Stats ===")
stats = RT.get_stats()
print(f"Total non-zero entries: {stats['total_nnz']}")
print(f"Total memory: {stats['total_memory_mb']:.2f} MB")
print("\nPer-slice breakdown:")
for s in stats['slices']:
    print(f"  Slice {s['index']}: {s['nnz']} edges, density={s['density']:.6f}")


=== Relational Tensor Stats ===
Total non-zero entries: 0
Total memory: 0.00 MB

Per-slice breakdown:
  Slice 0: 0 edges, density=0.000000
  Slice 1: 0 edges, density=0.000000
  Slice 2: 0 edges, density=0.000000


In [26]:
# ---------- 4. d-/r-/n-hll decomposition ----------
def hll_delta(hll_t, hll_t1):
    """return sparse indices for d, r, n"""
    set_t  = set(hll_t.cpu().tolist())
    set_t1 = set(hll_t1.cpu().tolist())
    d = torch.tensor(list(set_t - set_t1), dtype=torch.long, device=DEVICE)
    r = torch.tensor(list(set_t & set_t1), dtype=torch.long, device=DEVICE)
    n = torch.tensor(list(set_t1 - set_t), dtype=torch.long, device=DEVICE)
    return d, r, n

# ---------- 5. iterative forecast ----------
def forecast(prompt_hll, max_iter=10, tol=1e-3):
    """
    Iterative belief propagation over relational tensor.
    
    Args:
        prompt_hll: Tensor of node indices in prompt HLL
        max_iter: Maximum iterations
        tol: Convergence tolerance
    
    Returns:
        (top_indices, belief_vector)
    """
    # Initialize belief vector (ensure float32)
    p = torch.zeros(VOCAB, dtype=torch.float32, device=DEVICE)
    p[prompt_hll] = 1.0
    p = p / p.sum()
    
    for itr in range(max_iter):
        p_old = p.clone()
        
        # Contract through all K slices
        for k in range(K):
            p = RT.contract(k, p)
            
            # Normalize to prevent explosion/vanishing
            p_sum = p.sum()
            if p_sum > 0:
                p = p / p_sum
        
        # Stabilization check
        diff = torch.norm(p - p_old, 1).item()
        if diff < tol:
            print(f"  Converged at iteration {itr+1}, diff={diff:.6f}")
            break
    else:
        print(f"  Max iterations ({max_iter}) reached")
    
    # Get top-k predictions
    top_idx = torch.topk(p, min(TOPK, VOCAB)).indices
    
    return top_idx, p

In [27]:
# ---------- 6. on-line Hebb update + prune ----------
def ingest_and_forecast(new_cover, teacher_cover):
    """update AM, tensor, prune with d/r/n guide, forecast"""
    global W_prev
    
    # ---- build new AM fragment ----
    new_am = PrunedAM()
    for u in new_cover:
        for v in new_cover:
            if u != v: 
                new_am.add_edge(u, v, 1.0)
    
    Wτ_new = new_am.csr()  # Already float32 from PrunedAM.csr()

    # ---- d/r/n guided prune ----
    d, r, n = hll_delta(new_cover, teacher_cover)
    
    # Create set of nodes to keep (r and n)
    keep_nodes = torch.cat([r, n]) if len(r) > 0 or len(n) > 0 else torch.tensor([], dtype=torch.long, device=DEVICE)
    
    if len(keep_nodes) > 0:
        # Convert to set for fast lookup
        keep_set = set(keep_nodes.cpu().tolist())
        
        # Filter edges: keep only if source OR target is in keep_nodes
        Wτ_new = Wτ_new.coalesce()
        indices = Wτ_new.indices()
        values = Wτ_new.values()
        
        # Create mask: True if row OR col is in keep_set
        row_mask = torch.tensor([idx.item() in keep_set for idx in indices[0]], 
                                dtype=torch.bool, device=DEVICE)
        col_mask = torch.tensor([idx.item() in keep_set for idx in indices[1]], 
                                dtype=torch.bool, device=DEVICE)
        edge_mask = row_mask | col_mask
        
        # Apply mask
        filtered_indices = indices[:, edge_mask]
        filtered_values = values[edge_mask]
        
        Wτ_new = torch.sparse_coo_tensor(
            filtered_indices,
            filtered_values,
            size=(VOCAB, VOCAB),
            dtype=torch.float32,  # Force float32
            device=DEVICE
        ).coalesce()
    else:
        # No nodes to keep, create empty sparse tensor
        Wτ_new = torch.sparse_coo_tensor(
            torch.empty((2, 0), dtype=torch.long, device=DEVICE),
            torch.empty(0, dtype=torch.float32, device=DEVICE),  # Force float32
            size=(VOCAB, VOCAB),
            device=DEVICE
        )

    # ---- tensor hot-swap ----
    RT.overwrite_slice(0, Wτ_new)
    RT.overwrite_slice(1, Wτ_new * 0.3)
    
    # Ensure W_prev is float32
    W_prev = W_prev.to(dtype=torch.float32)
    delta = (Wτ_new.to_dense() - W_prev.to_dense()).to_sparse().coalesce()
    
    RT.overwrite_slice(2, delta)
    W_prev = Wτ_new

    # ---- forecast ----
    response, belief = forecast(new_cover)
    return response, belief

In [28]:
# ---------- 7. demo run ----------
if __name__ == "__main__":
    user_prompt = fake_cover(30)
    teacher     = fake_cover(35)        # simulated host revision
    resp, belief = ingest_and_forecast(user_prompt, teacher)
    print("Response tokens (top-20):", resp.cpu().tolist())
    print("Belief vector sparsity:", (belief>0).sum().item(), "/", VOCAB)

  Converged at iteration 2, diff=0.000000
Response tokens (top-20): [19, 18, 16, 17, 1, 0, 2, 3, 11, 10, 8, 9, 13, 12, 14, 15, 7, 6, 4, 5]
Belief vector sparsity: 0 / 1000
