# CONSCIOUSNESS-AWARE PROSODY–EMOTION SNN SYSTEM

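This notebook targets Google Colab with a TPU v6 (Trillium) runtime. It implements a multi-cell pipeline: environment setup, imports, the GoEmotions→Plutchik mapping, prosody/emotion modules, a synthetic dataset, the integrated model, SBERT fine-tuning and inference, SentencePiece-based prosody, the Intent Compass, and finally training on a user-supplied `emotions.jsonl`.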


In [None]:
"""
CONSCIOUSNESS-AWARE PROSODY–EMOTION SNN SYSTEM
Google Colab TPU v6 Implementation
"""

# Environment probe (Colab): report the JAX build and the accelerators it can
# see, so a mis-provisioned runtime is noticed before any model code runs.
import os
import jax

_visible_devices = jax.devices()

print(f"JAX version: {jax.__version__}")
print(f"Available devices: {_visible_devices}")
print(f"TPU cores: {jax.device_count()}")

# Only index into the device list when it is non-empty.
if _visible_devices:
    print(f"First device: {_visible_devices[0]}")

# In Colab, uncomment to install packages
# !pip -q install flax optax datasets transformers spacy wandb
# !python -m spacy download en_core_web_sm



In [None]:
import jax
import jax.numpy as jnp
from jax import random, jit
from flax import linen as nn
from flax.training import train_state, checkpoints
import optax
from typing import Dict, Optional

print("✓ Imports ready")



In [2]:
# GoEmotions → Plutchik mapping
GOEMOTION_LABELS = [
    'admiration','amusement','anger','annoyance','approval','caring','confusion','curiosity','desire','disappointment',
    'disapproval','disgust','embarrassment','excitement','fear','gratitude','grief','joy','love','nervousness','optimism',
    'pride','realization','relief','remorse','sadness','surprise','neutral'
]
PLUTCHIK_LABELS = ['joy','trust','fear','surprise','sadness','disgust','anger','anticipation']

# Projection matrix [28, 8]: row r is GOEMOTION_LABELS[r], column c is
# PLUTCHIK_LABELS[c]. Columns:  joy trust fear surp sad  disg ang  antic
GOEMOTION_TO_PLUTCHIK = jnp.array([
    [0,   1,   0,   0,   0,   0,   0,   0  ],  # admiration
    [1,   0,   0,   0,   0,   0,   0,   0  ],  # amusement
    [0,   0,   0,   0,   0,   0,   1,   0  ],  # anger
    [0,   0,   0,   0,   0,   0,   0.7, 0  ],  # annoyance
    [0,   1,   0,   0,   0,   0,   0,   0  ],  # approval
    [0,   0.8, 0,   0,   0,   0,   0,   0  ],  # caring
    [0,   0,   0.5, 0.3, 0,   0,   0,   0  ],  # confusion
    [0,   0,   0,   0.5, 0,   0,   0,   0.5],  # curiosity
    [0,   0,   0,   0,   0,   0,   0,   1  ],  # desire
    [0,   0,   0,   0,   1,   0,   0,   0  ],  # disappointment
    [0,   0,   0,   0,   0,   0.8, 0,   0  ],  # disapproval
    [0,   0,   0,   0,   0,   1,   0,   0  ],  # disgust
    [0,   0,   0,   0,   0.3, 0.7, 0,   0  ],  # embarrassment
    [1,   0,   0,   0,   0,   0,   0,   0  ],  # excitement
    [0,   0,   1,   0,   0,   0,   0,   0  ],  # fear
    [0.8, 0.2, 0,   0,   0,   0,   0,   0  ],  # gratitude
    [0,   0,   0,   0,   1,   0,   0,   0  ],  # grief
    [1,   0,   0,   0,   0,   0,   0,   0  ],  # joy
    [0.6, 0.4, 0,   0,   0,   0,   0,   0  ],  # love
    [0,   0,   0.8, 0,   0,   0,   0,   0  ],  # nervousness
    [0.5, 0,   0,   0,   0,   0,   0,   0.5],  # optimism
    [1,   0,   0,   0,   0,   0,   0,   0  ],  # pride
    [0,   0,   0,   1,   0,   0,   0,   0  ],  # realization
    [0.8, 0,   0,   0,   0,   0,   0,   0  ],  # relief
    [0,   0,   0,   0,   1,   0,   0,   0  ],  # remorse
    [0,   0,   0,   0,   1,   0,   0,   0  ],  # sadness
    [0,   0,   0,   1,   0,   0,   0,   0  ],  # surprise
    [0,   0,   0,   0,   0,   0,   0,   0  ],  # neutral (contributes nothing)
], dtype=jnp.float32)


@jit
def map_to_plutchik(goemotion_probs: jnp.ndarray) -> jnp.ndarray:
    """Project GoEmotions scores onto Plutchik's 8 primaries and renormalize.

    Args:
        goemotion_probs: [..., 28] scores in GOEMOTION_LABELS order.

    Returns:
        [..., 8] weights in PLUTCHIK_LABELS order that sum to 1 along the last
        axis. Inputs with no Plutchik mass (e.g. only 'neutral') come back as
        all zeros, because the divisor is floored at 1e-6.
    """
    projected = goemotion_probs @ GOEMOTION_TO_PLUTCHIK
    total = jnp.sum(projected, axis=-1, keepdims=True)
    return projected / jnp.maximum(total, 1e-6)

print("✓ Mapping ready")



NameError: name 'jnp' is not defined

In [None]:
class ProsodyExtractorJAX(nn.Module):
    """Derive heuristic, text-only prosody proxies from token-level features.

    No audio is involved: pitch, energy, duration, rhythm and pauses are
    learned/heuristic stand-ins computed from embeddings, POS and syntax
    features.

    Attributes:
        hidden_dim: width of the pooled pitch and energy projections.
    """
    hidden_dim: int = 64

    @nn.compact
    def __call__(self,
                 token_embeddings: jnp.ndarray,  # [b, t, d]
                 pos_tags: jnp.ndarray,          # [b, t, p]
                 syntax_features: jnp.ndarray    # [b, t, s]
                 ) -> Dict[str, jnp.ndarray]:
        # L1 norm of each token embedding serves as a "word length" proxy.
        word_lengths = jnp.abs(token_embeddings).sum(axis=-1)              # [b, t]

        # Dense layers are created in a fixed order (Dense_0 .. Dense_4);
        # keep it so the auto-generated parameter names stay stable.
        pauses = nn.sigmoid(nn.Dense(1)(syntax_features)).squeeze(-1)      # [b, t]
        stress = nn.sigmoid(nn.Dense(1)(pos_tags)).squeeze(-1)             # [b, t]

        per_token = jnp.stack([word_lengths, stress, pauses], axis=-1)     # [b, t, 3]
        duration = nn.Dense(1)(per_token)                                  # [b, t, 1]

        pitch_inputs = jnp.stack(
            [stress.mean(axis=1), word_lengths.std(axis=1)], axis=-1)      # [b, 2]
        pitch = nn.gelu(nn.Dense(self.hidden_dim)(pitch_inputs))

        energy_inputs = jnp.stack(
            [stress.sum(axis=1), pauses.sum(axis=1)], axis=-1)             # [b, 2]
        energy = nn.gelu(nn.Dense(self.hidden_dim)(energy_inputs))

        # Mean/max length ratio (in [0, 1] since lengths are non-negative),
        # turned into an angle and encoded on the unit circle.
        length_ratio = word_lengths.mean(axis=1) / (word_lengths.max(axis=1) + 1e-6)
        rhythm_phase = 2 * jnp.pi * length_ratio
        rhythm = jnp.stack([jnp.cos(rhythm_phase), jnp.sin(rhythm_phase)], axis=-1)  # [b, 2]

        return {
            'pitch': pitch,
            'energy': energy,
            'duration': duration,
            'rhythm': rhythm,
            'pauses': pauses,
        }

class PlutchikEmotionEncoderJAX(nn.Module):
    """Fuse text, prosody and optional personality into emotion estimates.

    Predicts 28 independent GoEmotions scores (sigmoid) and projects them onto
    Plutchik's 8 primaries with `map_to_plutchik`.

    Attributes:
        emotion_dim: declared Plutchik size. NOTE(review): not read by the
            module; the output width comes from `map_to_plutchik`.
        hidden_dim: width of every intermediate projection.
    """
    emotion_dim: int = 8
    hidden_dim: int = 64

    @nn.compact
    def __call__(self,
                 text_embedding: jnp.ndarray,           # [b,d]
                 prosody_features: Dict[str, jnp.ndarray],
                 personality_traits: Optional[jnp.ndarray] = None  # [b,5]
                 ) -> Dict[str, jnp.ndarray]:
        semantic = nn.gelu(nn.Dense(self.hidden_dim)(text_embedding))

        prosody_vec = jnp.concatenate(
            [prosody_features[key] for key in ('pitch', 'energy', 'rhythm')], axis=-1)
        prosody_h = nn.gelu(nn.Dense(self.hidden_dim)(prosody_vec))

        # The trait projection only exists when traits are supplied.
        if personality_traits is None:
            trait_h = jnp.zeros_like(semantic)
        else:
            trait_h = nn.gelu(nn.Dense(self.hidden_dim)(personality_traits))

        # Multiplicative text×prosody interaction plus an additive trait bias.
        fused = nn.gelu(nn.Dense(self.hidden_dim)(semantic * prosody_h + trait_h))

        goemotion = nn.sigmoid(nn.Dense(28)(fused))
        return {
            'plutchik': map_to_plutchik(goemotion),
            'goemotion': goemotion,
            'embeddings': fused,
        }

print("✓ Prosody & Emotion modules ready")



In [None]:
# Minimal synthetic dataset loader (Colab: replace with HF datasets if desired)
import numpy as np

def make_synthetic_batch(batch_size: int = 8, seq_len: int = 64, embed_dim: int = 64,
                         seed: int = 0):
    """Build one random batch shaped like the model's token-level inputs.

    Args:
        batch_size: number of examples (b).
        seq_len: tokens per example (t).
        embed_dim: width of the token embeddings.
        seed: PRNG seed. The default (0) keeps the fixed-seed behaviour; pass
            different seeds to get different batches.

    Returns:
        dict with 'token_embeddings' [b, t, embed_dim], 'pos_tags' [b, t, 10],
        'syntax_features' [b, t, 3], 'goemotion_labels' [b, 28] drawn from
        U[0, 1), and 'plutchik_labels' [b, 8] obtained via map_to_plutchik
        (weak supervision).
    """
    # One independent subkey per draw: JAX keys must not be reused, because
    # draws from the same key are not statistically independent.
    k_tok, k_pos, k_syn, k_lab = random.split(random.PRNGKey(seed), 4)
    token_embeddings = random.normal(k_tok, (batch_size, seq_len, embed_dim))
    pos_tags = random.normal(k_pos, (batch_size, seq_len, 10))
    syntax_features = random.normal(k_syn, (batch_size, seq_len, 3))
    # uniform() already samples from [0, 1), so no clipping is needed.
    goemotion_labels = random.uniform(k_lab, (batch_size, 28))
    # Normalize to probabilities for Plutchik labels (weak supervision)
    plutchik_labels = map_to_plutchik(goemotion_labels)
    return {
        'token_embeddings': token_embeddings,
        'pos_tags': pos_tags,
        'syntax_features': syntax_features,
        'goemotion_labels': goemotion_labels,
        'plutchik_labels': plutchik_labels
    }

print("✓ Synthetic data function ready")



In [None]:
class ConsciousnessAwareSNNModel(nn.Module):
    """Token-level pipeline: prosody -> Plutchik emotions -> expert gate.

    Returns a dict with:
      * 'output': GELU projection of the mean-pooled tokens, [b, hidden_dim];
      * 'emotions': PlutchikEmotionEncoderJAX outputs;
      * 'prosody': ProsodyExtractorJAX outputs;
      * 'gate_weights': softmax over `num_experts`, [b, num_experts].

    NOTE(review): `embed_dim` and `training` are not read, and the gate weights
    are returned but not applied to any expert.
    """
    embed_dim: int = 64
    num_experts: int = 4
    hidden_dim: int = 128

    @nn.compact
    def __call__(self,
                 token_embeddings: jnp.ndarray,
                 pos_tags: jnp.ndarray,
                 syntax_features: jnp.ndarray,
                 personality_traits: Optional[jnp.ndarray] = None,
                 training: bool = True):
        # Unpacking doubles as a rank check: token_embeddings must be 3-D.
        _batch, _steps, _width = token_embeddings.shape
        sentence_vec = jnp.mean(token_embeddings, axis=1)  # [b, d]

        # Child modules are created in this order; keep it stable so the
        # auto-generated parameter names do not change.
        prosody = ProsodyExtractorJAX(hidden_dim=64)(token_embeddings, pos_tags, syntax_features)
        emotions = PlutchikEmotionEncoderJAX(hidden_dim=64)(sentence_vec, prosody, personality_traits)

        gate_features = jnp.concatenate(
            [
                jnp.mean(prosody['pitch'], axis=-1, keepdims=True),
                jnp.mean(prosody['energy'], axis=-1, keepdims=True),
                emotions['plutchik'],
            ],
            axis=-1,
        )  # [b, 10]
        gate_weights = nn.softmax(nn.Dense(self.num_experts)(gate_features), axis=-1)

        output = nn.gelu(nn.Dense(self.hidden_dim)(sentence_vec))
        return {'output': output, 'emotions': emotions, 'prosody': prosody, 'gate_weights': gate_weights}

print("✓ Integrated model ready")



In [None]:
# SBERT setup (Colab: uncomment installs)
# !pip -q install sentence-transformers
from sentence_transformers import SentenceTransformer

SBERT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # fast 384-dim
sbert = SentenceTransformer(SBERT_MODEL_NAME)
print(f"Loaded SBERT: {SBERT_MODEL_NAME}")


def sbert_encode(texts):
    """Encode *texts* with the module-level `sbert` model.

    Returns a jnp array of shape [batch, dim]. Embeddings are left
    un-normalized (normalize_embeddings=False).
    """
    vectors = sbert.encode(texts, normalize_embeddings=False, convert_to_numpy=True)
    return jnp.asarray(vectors)



In [None]:
class SBERTConsciousnessAwareSNN(nn.Module):
    """Consciousness-aware model driven purely by SBERT sentence embeddings.

    Prosody proxies are learned projections of the SBERT vector (rhythm is
    simply its first two components), emotions come from
    PlutchikEmotionEncoderJAX, and a softmax gate over `num_experts` is
    computed from pitch/energy levels plus the Plutchik scores.

    NOTE(review): `sbert_dim` and `training` are not read; the gate is
    returned but not applied.
    """
    sbert_dim: int = 384
    num_experts: int = 4
    hidden_dim: int = 128

    @nn.compact
    def __call__(self,
                 sbert_embeddings: jnp.ndarray,                 # [b, sbert_dim]
                 personality_traits: Optional[jnp.ndarray] = None,
                 training: bool = True):
        # Layer creation order matters for parameter names; do not reorder.
        rich = nn.gelu(nn.Dense(self.hidden_dim)(sbert_embeddings))  # [b, hidden]

        pitch = nn.gelu(nn.Dense(64)(rich))
        energy = nn.gelu(nn.Dense(64)(rich))
        if sbert_embeddings.shape[1] < 2:
            rhythm = jnp.zeros((sbert_embeddings.shape[0], 2))
        else:
            rhythm = sbert_embeddings[:, :2]
        prosody = {'pitch': pitch, 'energy': energy, 'rhythm': rhythm}

        emotions = PlutchikEmotionEncoderJAX(hidden_dim=64)(sbert_embeddings, prosody, personality_traits)

        pitch_level = jnp.mean(pitch, axis=-1, keepdims=True)
        energy_level = jnp.mean(energy, axis=-1, keepdims=True)
        gate_in = jnp.concatenate([pitch_level, energy_level, emotions['plutchik']], axis=-1)  # [b, 10]
        gate = nn.softmax(nn.Dense(self.num_experts)(gate_in), axis=-1)

        output = nn.gelu(nn.Dense(self.hidden_dim)(rich))
        return {'output': output, 'emotions': emotions, 'prosody': prosody, 'gate_weights': gate}

print("✓ SBERT-integrated model ready")



In [None]:
# SBERT inference example
texts = [
    "I am so excited and grateful for this amazing opportunity!",
    "This is disappointing and makes me a bit sad."
]
embs = sbert_encode(texts)

sbert_model = SBERTConsciousnessAwareSNN(sbert_dim=embs.shape[1])
# NOTE: parameters are freshly initialized (untrained), so the scores printed
# below demonstrate the plumbing, not real emotion predictions.
params = sbert_model.init({'params': random.PRNGKey(0)}, embs, None, False)
outputs = sbert_model.apply(params, embs, None, False)

for txt, scores in zip(texts, outputs['emotions']['plutchik']):
    print(f"\nText: {txt}")
    print("Plutchik:")
    for name, score in zip(PLUTCHIK_LABELS, scores):
        print(f"  {name:12s}: {float(score):.3f}")



In [None]:
# Fine-tune SBERT on GoEmotions → Plutchik (argmax pseudo-label)
# !pip -q install datasets sentence-transformers
from datasets import load_dataset
from sentence_transformers import InputExample, losses
from sentence_transformers import SentencesDataset
from torch.utils.data import DataLoader
import numpy as np
import torch

# Load dataset (small subset for demo)
geo = load_dataset("go_emotions", "simplified")
train_split = geo["train"].select(range(5000))

# Multi-hot GoEmotions matrix (columns in GOEMOTION_LABELS order) built on the
# host, then mapped to Plutchik with ONE jitted call instead of one small
# device round-trip per example.
ge_multi_hot = np.zeros((len(train_split), len(GOEMOTION_LABELS)), dtype=np.float32)
for row, label_ids in enumerate(train_split["labels"]):
    ge_multi_hot[row, label_ids] = 1.0
pl_dists = np.asarray(map_to_plutchik(jnp.asarray(ge_multi_hot)))

# Build InputExamples with the dominant Plutchik index as the class label
examples = []
n_no_plutchik = 0
for text, dist in zip(train_split["text"], pl_dists):
    # 'neutral' maps to an all-zero Plutchik row, and argmax of all zeros is 0,
    # which would silently label every neutral-only comment as 'joy'. Skip them.
    if not dist.any():
        n_no_plutchik += 1
        continue
    # SoftmaxLoss classifies sentence PAIRS: its forward pass unpacks exactly
    # two embeddings per example, so a single-text InputExample fails. Feeding
    # the same sentence as both members trains on a single-sentence label.
    examples.append(InputExample(texts=[text, text], label=int(dist.argmax())))
print(f"Kept {len(examples)} labelled examples; skipped {n_no_plutchik} without any Plutchik mass (neutral-only)")

# SentenceTransformer expects a dataset and DataLoader
train_dataset = SentencesDataset(examples, sbert)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=64)

# Softmax loss over 8 Plutchik classes. The classifier layer lives inside
# `train_loss`; only the encoder weights in `sbert` are reused afterwards.
train_loss = losses.SoftmaxLoss(model=sbert, sentence_embedding_dimension=sbert.get_sentence_embedding_dimension(), num_labels=8)

# Training (10% of one epoch as linear warmup)
warmup_steps = int(len(train_dataloader) * 1 * 0.1)
print(f"Training SBERT head with SoftmaxLoss over 8 classes, steps/epoch={len(train_dataloader)}")
sbert.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    warmup_steps=warmup_steps,
    show_progress_bar=True
)
print("✓ SBERT fine-tuning complete")



In [None]:
# Use fine-tuned SBERT in the pipeline
# NOTE(review): only the SBERT encoder was fine-tuned. `params` is still the
# randomly initialized SBERTConsciousnessAwareSNN head from the inference
# example, so these Plutchik scores come from an untrained head.
texts_ft = [
    "I feel grateful and joyful today.",
    "I am worried and a bit afraid of the outcome."
]
embs_ft = sbert_encode(texts_ft)
outputs_ft = sbert_model.apply(params, embs_ft, None, False)

for txt, scores in zip(texts_ft, outputs_ft['emotions']['plutchik']):
    print(f"\n[FT] Text: {txt}")
    for name, score in zip(PLUTCHIK_LABELS, scores):
        print(f"  {name:12s}: {float(score):.3f}")



In [None]:
# Bio-inspired multi-heads: STDP and NLMS
class STDPHead(nn.Module):
    """Surrogate STDP head producing spike rates from a hidden state.

    `__call__` projects `hidden` with one Dense layer to `spike_dim` membrane
    values and squashes them with a logistic sigmoid shifted by `v_th`, giving
    spike rates in (0, 1). `stdp_loss` is a Hebbian-style objective on a pair
    of such rate tensors.

    Attributes:
        spike_dim: number of output units.
        v_th: threshold subtracted before the sigmoid (rate is 0.5 at v == v_th).
    """
    spike_dim: int = 64
    v_th: float = 0.0

    @nn.compact
    def __call__(self, hidden: jnp.ndarray) -> jnp.ndarray:
        v = nn.Dense(self.spike_dim)(hidden)
        # Smooth surrogate spike rate: the standard logistic sigmoid centered on
        # the threshold (not the "fast sigmoid" x / (1 + |x|)).
        spikes = jax.nn.sigmoid(v - self.v_th)
        return spikes  # [b, spike_dim] for a 2-D `hidden`

    @staticmethod
    def stdp_loss(pre: jnp.ndarray, post: jnp.ndarray, target_sign: jnp.ndarray) -> jnp.ndarray:
        """Signed, negated pre×post correlation.

        corr = mean(pre * post) over the last axis, and the loss is
        mean(-target_sign * corr). Minimizing it raises correlation where
        target_sign > 0 (Hebbian) and lowers it where target_sign < 0
        (anti-Hebbian). `target_sign` must broadcast against `corr`, whose
        shape is that of `pre * post` without the last axis.
        """
        corr = jnp.mean(pre * post, axis=-1)
        return jnp.mean(-target_sign * corr)

class NLMSHead(nn.Module):
    """Linear prediction head paired with an NLMS-style normalized loss.

    `__call__` is a single Dense layer. `nlms_loss` divides each example's
    squared error by that example's input power ||x||² (+eps), in the spirit of
    normalized least-mean-squares adaptation.
    """
    out_dim: int = 8  # align with Plutchik classes by default

    @nn.compact
    def __call__(self, hidden: jnp.ndarray) -> jnp.ndarray:
        return nn.Dense(self.out_dim)(hidden)

    @staticmethod
    def nlms_loss(pred: jnp.ndarray, target: jnp.ndarray, x: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
        """Mean over examples of ||pred - target||² / (||x||² + eps)."""
        residual = pred - target
        input_power = jnp.sum(x * x, axis=-1, keepdims=True) + eps
        residual_power = jnp.sum(residual * residual, axis=-1, keepdims=True)
        return jnp.mean(residual_power / input_power)

print("✓ STDP and NLMS heads ready")



In [None]:
# SentencePiece adapter for SBERT (upload your .model to Colab first)
# !pip -q install sentencepiece
import sentencepiece as spm
import torch

class SentencePieceTokenizer:
    """Minimal HF-style tokenizer facade over a SentencePiece model.

    Calling the instance returns {'input_ids', 'attention_mask'}, either as
    torch.long tensors (return_tensors='pt') or as nested Python lists.
    """

    def __init__(self, sp_model_path: str, max_seq_length: int = 128):
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(sp_model_path)
        # SentencePiece reports -1 for special ids the model was trained
        # without (pad is disabled by default). A negative id cannot be embedded
        # or passed to id_to_piece, so padding falls back to id 0. This is safe
        # because attention_mask is built from sequence lengths, not by
        # comparing ids against the pad id.
        pad_id = self.sp.pad_id()
        self.pad_token_id = pad_id if pad_id >= 0 else 0
        self.cls_token_id = self.sp.bos_id()
        self.sep_token_id = self.sp.eos_id()
        self.unk_token_id = self.sp.unk_id()
        self.max_seq_length = max_seq_length
        self.vocab_size = self.sp.get_piece_size()

    def __call__(self, texts, padding=True, truncation=True, max_length=None, return_tensors='pt'):
        """Tokenize one string or a list of strings.

        Args:
            texts: a string or list of strings.
            padding: right-pad every sequence to the longest one in the batch.
            truncation: cap each sequence (special tokens included) at
                `max_length` (or `self.max_seq_length`).
            max_length: optional per-call override of the length cap.
            return_tensors: 'pt' for torch.long tensors; anything else returns
                nested lists.

        Returns:
            dict with 'input_ids' and 'attention_mask' (1 = real token,
            0 = padding).
        """
        if isinstance(texts, str):
            texts = [texts]
        max_len = max_length or self.max_seq_length

        # Wrap only with special tokens the model actually defines (id >= 0).
        prefix = [self.cls_token_id] if self.cls_token_id >= 0 else []
        suffix = [self.sep_token_id] if self.sep_token_id >= 0 else []
        body_budget = max(0, max_len - len(prefix) - len(suffix))

        ids_list = []
        for txt in texts:
            ids = self.sp.encode(txt, out_type=int)
            if truncation:
                ids = ids[:body_budget]
            ids_list.append(prefix + ids + suffix)

        # Real lengths are recorded before padding so the mask never confuses
        # a genuine token that shares the pad id with padding.
        lengths = [len(ids) for ids in ids_list]
        if padding and ids_list:
            pad_len = max(lengths)
            ids_list = [ids + [self.pad_token_id] * (pad_len - len(ids)) for ids in ids_list]
        attn = [[1] * n + [0] * (len(ids) - n) for ids, n in zip(ids_list, lengths)]

        if return_tensors == 'pt':
            return {
                'input_ids': torch.tensor(ids_list, dtype=torch.long),
                'attention_mask': torch.tensor(attn, dtype=torch.long)
            }
        return {'input_ids': ids_list, 'attention_mask': attn}

print("✓ SentencePiece tokenizer wrapper ready")



In [None]:
# Prosody extraction from SentencePiece tokens
class SentencePieceProsodyExtractor(nn.Module):
    """Heuristic prosody proxies from SentencePiece ids and token embeddings.

    Attributes:
        hidden_dim: width of the pooled pitch/energy projections.

    NOTE(review): `token_ids < 256` as a word-boundary signal and `punct_ids`
    are placeholders; real values depend on the .model vocabulary. Padding
    positions are not masked out of the pooled statistics.
    """
    hidden_dim: int = 64

    @nn.compact
    def __call__(self,
                 token_ids: jnp.ndarray,           # [b, t]
                 token_embeddings: jnp.ndarray     # [b, t, d]
                 ) -> Dict[str, jnp.ndarray]:
        b, t = token_ids.shape  # also enforces 2-D token_ids
        # Simple heuristics: whitespace boundary ≈ small IDs; punct set example
        word_boundary = (token_ids < 256).astype(jnp.float32)
        punct_ids = jnp.array([33, 34, 35, 36, 37, 38, 39])  # placeholder ids
        is_punct = jnp.isin(token_ids, punct_ids).astype(jnp.float32)
        token_len = jnp.sum(jnp.abs(token_embeddings), axis=-1)

        # "Previous token was punctuation": shift right by one and fill with 0.
        # jnp.roll would wrap the LAST token's flag around onto position 0.
        after_punct = jnp.pad(is_punct[:, :-1], ((0, 0), (1, 0)))

        # Dense creation order (pause, stress, pitch, energy, duration) fixes
        # the parameter names; keep it.
        pause_feats = jnp.stack([
            word_boundary,
            is_punct,
            after_punct
        ], axis=-1)
        pause_logits = nn.Dense(1)(pause_feats)
        pause_probs = jax.nn.sigmoid(pause_logits).squeeze(-1)

        stress_feats = jnp.stack([
            word_boundary,
            token_len,
            jnp.abs(jnp.mean(token_embeddings, axis=-1))
        ], axis=-1)
        stress_logits = nn.Dense(1)(stress_feats)
        stress_probs = jax.nn.sigmoid(stress_logits).squeeze(-1)

        avg_stress = jnp.mean(stress_probs, axis=1)
        stress_var = jnp.std(stress_probs, axis=1)
        boundary_var = jnp.std(word_boundary, axis=1)

        pitch = nn.gelu(nn.Dense(self.hidden_dim)(jnp.stack([avg_stress, stress_var, boundary_var], axis=-1)))
        energy = nn.gelu(nn.Dense(self.hidden_dim)(jnp.stack([
            jnp.sum(stress_probs, axis=1), jnp.sum(pause_probs, axis=1), jnp.sum(word_boundary, axis=1)
        ], axis=-1)))

        duration = nn.Dense(1)(jnp.stack([token_len, stress_probs, pause_probs], axis=-1))

        return {
            'pitch': pitch,                                               # [b, hidden_dim]
            'energy': energy,                                             # [b, hidden_dim]
            'duration': duration,                                         # [b, t, 1]
            'rhythm': jnp.stack([boundary_var, stress_var], axis=-1),     # [b, 2]
            'pauses': pause_probs,                                        # [b, t]
            'stress': stress_probs,                                       # [b, t]
            'word_boundaries': word_boundary,                             # [b, t]
        }

print("✓ SentencePiece prosody extractor ready")



In [None]:
# Demo: SentencePiece → SBERT → Prosody → Emotion → Summary
# Globals defined here (texts, sent_embs, prosody, emo, jrandom) are reused by
# the Intent Compass demo cell.
import os, glob
from jax import random as jrandom

# Locate .model under models/spm
# NOTE: glob result order depends on the file system, so with several .model
# files the one picked below is not guaranteed to be the same every run.
sp_search_dir = "models/spm"
sp_model_candidates = glob.glob(os.path.join(sp_search_dir, "**", "*.model"), recursive=True)
# NOTE(review): `assert` is stripped under `python -O`; a raise would be sturdier.
assert len(sp_model_candidates) > 0, f"No .model found under {sp_search_dir}"
sp_model_path = sp_model_candidates[0]
print(f"Using SentencePiece model: {sp_model_path}")

# Build tokenizer
sp_tokenizer = SentencePieceTokenizer(sp_model_path, max_seq_length=128)

# Sample texts
texts = [
    "I'm so happy! This is amazing.",
    "I feel anxious, but trying to stay calm."
]

# Tokenize (PyTorch tensors from adapter; convert to numpy)
sp_batch = sp_tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors='pt')
input_ids = jnp.asarray(sp_batch['input_ids'].numpy())      # [b, t]
# NOTE(review): attention_mask is computed but not passed to the prosody
# extractor, so padding positions contribute to its pooled statistics.
attention_mask = jnp.asarray(sp_batch['attention_mask'].numpy())

# SBERT sentence embeddings (from `sbert`, i.e. the fine-tuned encoder if the
# fine-tuning cell has been run)
sent_embs = sbert_encode(texts)  # [b, dim]

# Token-level embeddings (broadcast sentence embedding across sequence)
# Every position receives the same vector, so embedding-derived per-token
# features (e.g. token length proxies) are constant within a sentence.
b, t = input_ids.shape
emb_dim = int(sent_embs.shape[1])  # NOTE(review): not used below
token_embs = jnp.repeat(sent_embs[:, None, :], t, axis=1)   # [b, t, dim]

# Prosody from SentencePiece tokens (randomly initialized parameters)
pros_extractor = SentencePieceProsodyExtractor(hidden_dim=64)
pros_params = pros_extractor.init({'params': jrandom.PRNGKey(0)}, input_ids, token_embs)
prosody = pros_extractor.apply(pros_params, input_ids, token_embs)

# Emotions from SBERT embeddings + prosody (randomly initialized parameters)
emo_encoder = PlutchikEmotionEncoderJAX(hidden_dim=64)
emo_params = emo_encoder.init({'params': jrandom.PRNGKey(1)}, sent_embs, prosody, None)
emo = emo_encoder.apply(emo_params, sent_embs, prosody, None)

# Print summary
pl_labels = ['joy','trust','fear','surprise','sadness','disgust','anger','anticipation']
for i, text in enumerate(texts):
    print("\nText:", text)
    # Show first 10 token pieces via underlying SP model
    pieces = [sp_tokenizer.sp.id_to_piece(int(x)) for x in sp_batch['input_ids'][i][:10].tolist()]
    print("Pieces:", pieces)
    print("Plutchik:")
    for j, name in enumerate(pl_labels):
        print(f"  {name:12s}: {float(emo['plutchik'][i, j]):.3f}")
    print("Prosody: pitch[0:5]=", prosody['pitch'][i, :5])



In [None]:
# Intent Compass module (primary intents + modifiers)
INTENT_LABELS = ['inform','negotiate','question','clarify','social','express','command','request']
INTENT_MODIFIERS = ['urgency','certainty','formality','politeness']


class IntentCompassJAX(nn.Module):
    """Classify communicative intent and place it on an 8-point compass.

    Fuses four views (SBERT semantics, prosody, Plutchik emotions and optional
    personality traits) with learned softmax attention, then predicts:
      * 'primary_intent': softmax over INTENT_LABELS, [b, 8];
      * 'modifiers': urgency / certainty / formality / politeness, each [b, 1]
        in (0, 1);
      * 'compass_position': probability-weighted unit vector [b, 2], where
        intent k sits at angle k·π/4.

    Attributes:
        hidden_dim: width of every view projection and of the fused state.
    """
    hidden_dim: int = 128

    @nn.compact
    def __call__(self,
                 sbert_embedding: jnp.ndarray,        # [b, d]
                 prosody_features: Dict[str, jnp.ndarray],
                 emotion_probs: jnp.ndarray,          # [b, 8]
                 personality_traits: jnp.ndarray | None = None
                 ) -> Dict[str, jnp.ndarray]:
        width = self.hidden_dim
        # NOTE: Dense layers are created in a fixed order; reordering them
        # would rename parameters.
        sem_h = nn.gelu(nn.Dense(width)(sbert_embedding))

        rhythm = prosody_features['rhythm']
        if rhythm.ndim != 2:
            rhythm = rhythm[:, :2]
        stress_proxy = prosody_features.get('stress', prosody_features['pitch'])
        pause_proxy = prosody_features.get('pauses', prosody_features['energy'])
        prosody_in = jnp.concatenate([
            prosody_features['pitch'],
            prosody_features['energy'],
            rhythm,
            jnp.mean(stress_proxy, axis=-1, keepdims=True),
            jnp.mean(pause_proxy, axis=-1, keepdims=True),
        ], axis=-1)
        pros_h = nn.gelu(nn.Dense(width)(prosody_in))

        emo_h = nn.gelu(nn.Dense(width)(emotion_probs))

        if personality_traits is None:
            pers_h = jnp.zeros_like(sem_h)
        else:
            pers_h = nn.gelu(nn.Dense(width)(personality_traits))

        # Attention over the four views: one scalar score per view.
        views = jnp.stack([sem_h, pros_h, emo_h, pers_h], axis=1)        # [b, 4, h]
        view_weights = nn.softmax(nn.Dense(1)(views), axis=1)             # [b, 4, 1]
        fused = nn.gelu(nn.Dense(width)(jnp.sum(views * view_weights, axis=1)))

        primary = nn.softmax(nn.Dense(8)(fused), axis=-1)

        # Urgency: overall energy plus the 'anger' score (index 6 in PLUTCHIK_LABELS order).
        energy_level = jnp.mean(prosody_features['energy'], axis=-1, keepdims=True)
        urgency = nn.sigmoid(nn.Dense(1)(jnp.concatenate([energy_level, emotion_probs[:, 6:7]], axis=-1)))

        # Certainty: fewer pauses and a more peaked intent distribution.
        if 'pauses' in prosody_features:
            fluency = 1.0 - jnp.mean(prosody_features['pauses'], axis=-1, keepdims=True)
        else:
            fluency = jnp.ones((primary.shape[0],1))
        peak = jnp.max(primary, axis=-1, keepdims=True)
        certainty = nn.sigmoid(nn.Dense(1)(jnp.concatenate([fluency, peak], axis=-1)))

        formality = nn.sigmoid(nn.Dense(1)(fused[:, :2]))
        # Politeness: 'trust' score (index 1) plus one fused feature.
        politeness = nn.sigmoid(nn.Dense(1)(jnp.concatenate([emotion_probs[:, 1:2], fused[:, 2:3]], axis=-1)))

        # Compass: expected unit vector under the intent distribution.
        angles = jnp.array([0, jnp.pi/4, jnp.pi/2, 3*jnp.pi/4, jnp.pi, 5*jnp.pi/4, 3*jnp.pi/2, 7*jnp.pi/4])
        x_coord = jnp.sum(primary * jnp.cos(angles)[None, :], axis=-1)
        y_coord = jnp.sum(primary * jnp.sin(angles)[None, :], axis=-1)
        position = jnp.stack([x_coord, y_coord], axis=-1)

        modifiers = {
            'urgency': urgency,
            'certainty': certainty,
            'formality': formality,
            'politeness': politeness,
        }
        return {
            'primary_intent': primary,
            'modifiers': modifiers,
            'compass_position': position,
        }

print("✓ Intent Compass ready")



In [None]:
# Intent Compass demo using previous SP+SBERT demo values
# (randomly initialized parameters: outputs illustrate shapes, not trained behaviour)
ic = IntentCompassJAX(hidden_dim=128)
ic_params = ic.init({'params': jrandom.PRNGKey(2)}, sent_embs, prosody, emo['plutchik'])
ic_out = ic.apply(ic_params, sent_embs, prosody, emo['plutchik'])

intent_probs = ic_out['primary_intent']
mods = ic_out['modifiers']
positions = ic_out['compass_position']
for row, text in enumerate(texts):
    print(f"\nText: {text}")
    print("Primary intent distribution:")
    for name, prob in zip(INTENT_LABELS, intent_probs[row]):
        print(f"  {name:10s}: {float(prob):.3f}")
    print("Modifiers:")
    for m in INTENT_MODIFIERS:
        print(f"  {m:10s}: {float(mods[m][row, 0]):.3f}")
    cx, cy = positions[row]
    print(f"Compass pos: ({float(cx):.3f}, {float(cy):.3f})")



In [None]:
"""
CELL 0: Mount/Upload emotions.jsonl (Colab)
"""
from google.colab import drive, files
import os

# NOTE(review): the Drive mount is not used by the cells below; the dataset
# comes from the upload dialog.
drive.mount('/content/drive')
print("Uploading emotions.jsonl...")
uploaded = files.upload()  # dict keyed by uploaded file name
if not uploaded:
    # Cancelled or empty upload dialog: fail with a clear message instead of
    # an IndexError from indexing an empty key list.
    raise RuntimeError("No file uploaded: select emotions.jsonl in the upload dialog and re-run this cell.")
dataset_path = next(iter(uploaded))
if len(uploaded) > 1:
    print(f"⚠ {len(uploaded)} files uploaded; using the first one: {dataset_path}")
print(f"✓ File uploaded: {dataset_path} | size={os.path.getsize(dataset_path)/1024/1024:.2f} MB")



In [None]:
"""
CELL 1: Setup & Dependencies (TPU v6)
"""
# Re-imports everything the training cells need, so this part of the notebook
# can be run on its own after CELL 0.
import jax, jax.numpy as jnp
from jax import random, jit
from flax import linen as nn
from flax.training import train_state
import optax, numpy as np, json, os
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from datetime import datetime
import wandb

print("="*60)
print("CONSCIOUSNESS-AWARE SNN: EMOTION-INTENT TRAINING")
print("="*60)
print(f"TPU Cores: {jax.device_count()} | Device: {jax.devices()[0].device_kind} | JAX: {jax.__version__}")

# Unique, timestamped run name (local wall-clock time).
run_name = f"emotion_snn_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# The logged config mirrors values hard-coded in CELL 5/6 (batch 32, 10 epochs,
# peak LR 3e-5); the training code does not read it back.
# NOTE(review): no pmap/sharding is visible in the training cells, so
# "batch_size_per_core" is effectively the global batch size.
wandb.init(project="consciousness-snn", name=run_name,
           config={"batch_size_per_core":32, "epochs":10, "learning_rate":3e-5})
print(f"✓ W&B initialized: {run_name}")



In [None]:
"""
CELL 2: Load emotions.jsonl and split
"""

def load_emotion_dataset(jsonl_path):
    """Read a JSON Lines file into a list of record dicts.

    Blank lines are ignored. Lines that are not valid JSON, or that decode to
    something other than a JSON object, are skipped. Skipped lines are reported
    through ``warnings.warn`` (count plus the first few line numbers), so data
    problems are not silently lost.

    Args:
        jsonl_path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        list of dicts, in file order.
    """
    import warnings

    recs = []
    skipped_lines = []
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                skipped_lines.append(lineno)
                continue
            # Downstream preprocessing calls record.get(...), so only JSON
            # objects are usable records.
            if isinstance(obj, dict):
                recs.append(obj)
            else:
                skipped_lines.append(lineno)
    if skipped_lines:
        preview = ", ".join(str(n) for n in skipped_lines[:5])
        warnings.warn(
            f"load_emotion_dataset: skipped {len(skipped_lines)} malformed line(s) "
            f"in {jsonl_path} (first line numbers: {preview})"
        )
    return recs

# Load every record from the uploaded file (dataset_path comes from CELL 0).
records = load_emotion_dataset(dataset_path)
print(f"✓ Loaded {len(records)} records")

from sklearn.model_selection import train_test_split
# 80/10/10 split: hold out 20%, then halve the held-out part into val and test.
# The fixed random_state keeps the split reproducible across runs.
# NOTE(review): train_test_split can raise ValueError when the file has only a
# handful of records (a split would end up empty).
train_records, temp_records = train_test_split(records, test_size=0.2, random_state=42)
val_records, test_records = train_test_split(temp_records, test_size=0.5, random_state=42)
print(f"Split → Train {len(train_records)} | Val {len(val_records)} | Test {len(test_records)}")



In [None]:
"""
CELL 3: Label mappings & constants
"""
# Plutchik's eight primary emotions; list position is the target/head index.
PLUTCHIK_LABELS = [
    'joy', 'trust', 'fear', 'surprise',
    'sadness', 'disgust', 'anger', 'anticipation',
]
# The eight Intent Compass directions, in intent-head order.
COMPASS_INTENTS = [
    'inform', 'negotiate', 'question', 'clarify',
    'social', 'express', 'command', 'request',
]

# Example mapping from custom intents to compass intents (extend as needed)
INTENT_MAPPING = dict(
    share_news='inform',
    ask_help='request',
    clarify='clarify',
    complain='express',
    thank='social',
    propose='negotiate',
)

# Target prosody profile per tone label.
TONE_TO_PROSODY = {
    'ecstatic': dict(energy=0.95, pitch_var=0.9, tempo=1.3),
    'urgent': dict(energy=0.9, pitch_var=0.8, tempo=1.4),
    'neutral': dict(energy=0.5, pitch_var=0.4, tempo=1.0),
}
print("✓ Mappings ready")



In [None]:
"""
CELL 4: Preprocess all records with SBERT and lightweight linguistic features
"""
# Rebinds `sbert_model`: from here on it is a pretrained SentenceTransformer
# loaded by name (not the fine-tuned `sbert`, and no longer the
# SBERTConsciousnessAwareSNN from the earlier inference example).
sbert_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
print(f"✓ SBERT loaded: {sbert_model.get_sentence_embedding_dimension()}-dim")

import spacy
# Requires the en_core_web_sm package (see the commented install lines in the
# first cell).
nlp = spacy.load("en_core_web_sm")

# On-disk cache of per-record SBERT embeddings (.npy) used by preprocess_record.
cache_dir = "/tmp/sbert_cache"; os.makedirs(cache_dir, exist_ok=True)

def preprocess_record(record, idx):
    """Convert one raw emotions.jsonl record into fixed-size numpy features.

    Args:
        record: dict parsed from emotions.jsonl. Keys read (all optional):
            'text'; 'plutchik' {'primary', 'intensity', 'secondary'};
            'intent'; 'style' {'beta', 'phi'}.
        idx: position of the record within its split. Kept for call-site
            compatibility; the embedding cache is keyed by text content, so it
            is no longer needed.

    Returns:
        dict with 'sbert_embedding' (float32 SBERT vector), 'token_ids' int32
        [128], 'pos_tags' float32 [128, 10], 'syntax_features' float32 [128, 3],
        'plutchik_probs' float32 [8], 'intent_label' one-hot float32 [8], and
        scalar targets 'urgency', 'certainty', 'formality', 'politeness'.
    """
    import hashlib
    import zlib

    text = record.get('text', '')

    # SBERT embedding, cached on disk under a digest of the text. A positional
    # fallback key such as f'rec_{idx}' collides across train/val/test (idx
    # restarts at 0 in every split) and would silently reuse another record's
    # embedding.
    cf = os.path.join(cache_dir, hashlib.sha1(text.encode('utf-8')).hexdigest() + ".npy")
    if os.path.exists(cf):
        emb = np.load(cf)
    else:
        emb = sbert_model.encode(text, convert_to_tensor=False)
        np.save(cf, emb)

    doc = nlp(text)
    tokens = list(doc)[:128]

    # Hashed pseudo-token ids. zlib.crc32 is stable across interpreter runs,
    # whereas the built-in hash() on str is salted per process (PYTHONHASHSEED),
    # which would change every id between sessions.
    token_ids = np.array([zlib.crc32(t.text.encode('utf-8')) % 32000 for t in tokens], dtype=np.int32)
    if len(token_ids) < 128:
        token_ids = np.pad(token_ids, (0, 128 - len(token_ids)))

    pos = np.zeros((128, 10), dtype=np.float32)
    syn = np.zeros((128, 3), dtype=np.float32)
    # spaCy v3 tags conjunctions as CCONJ / SCONJ (Universal Dependencies v2);
    # 'CONJ' is kept for older tag sets. All three share column 7.
    pm = {'NOUN': 0, 'VERB': 1, 'ADJ': 2, 'ADV': 3, 'PRON': 4, 'DET': 5, 'ADP': 6,
          'CONJ': 7, 'CCONJ': 7, 'SCONJ': 7, 'NUM': 8, 'PUNCT': 9}
    for i, tok in enumerate(tokens):
        if tok.pos_ in pm:
            pos[i, pm[tok.pos_]] = 1.0
        syn[i, 0] = min(abs(tok.head.i - tok.i), 10) / 10.0  # dependency arc length, capped at 10
        syn[i, 1] = 1.0 if tok.is_punct else 0.0
        syn[i, 2] = 1.0 if tok.is_stop else 0.0

    # Plutchik target: primary emotion weighted by its intensity, +0.25 for a
    # recognized secondary emotion, then normalized to sum to 1 (all zeros if
    # nothing matched).
    plut = record.get('plutchik', {})
    p = np.zeros(8, dtype=np.float32)
    prim = plut.get('primary', 'joy')
    inten = float(plut.get('intensity', 0.5))
    if prim in PLUTCHIK_LABELS:
        p[PLUTCHIK_LABELS.index(prim)] = inten
    sec = plut.get('secondary')
    sec_map = {'optimism':'anticipation','admiration':'trust','anxiety':'fear','hope':'anticipation','excitement':'joy','contentment':'joy','grief':'sadness','despair':'sadness','contempt':'disgust','outrage':'anger','fury':'anger','resentment':'anger'}
    if sec in sec_map:
        p[PLUTCHIK_LABELS.index(sec_map[sec])] += 0.25
    p = p / (np.sum(p) + 1e-6)

    # Intent: labels that already name a compass direction are used as-is;
    # custom labels go through INTENT_MAPPING, and unknown ones fall back to
    # 'inform'.
    raw_intent = record.get('intent', 'inform')
    mapped = raw_intent if raw_intent in COMPASS_INTENTS else INTENT_MAPPING.get(raw_intent, 'inform')
    intent_oh = np.zeros(8, dtype=np.float32)
    intent_oh[COMPASS_INTENTS.index(mapped)] = 1.0

    # Modifier targets. NOTE(review): record['tone'] / TONE_TO_PROSODY are not
    # used by any target here.
    style = record.get('style', {})
    beta = float(style.get('beta', 0.5))
    phi = float(style.get('phi', 0.5))
    urgency = inten if inten > 0.6 else inten * 0.7
    certainty = phi if phi > 0 else 0.5
    return {
        'sbert_embedding': emb.astype(np.float32),
        'token_ids': token_ids,
        'pos_tags': pos,
        'syntax_features': syn,
        'plutchik_probs': p,
        'intent_label': intent_oh,
        'urgency': urgency,
        'certainty': certainty,
        'formality': beta,
        'politeness': phi,
    }

print("Preprocessing train/val/test...")
# Each split is featurized in turn (one tqdm bar per split, same order as before).
train_processed, val_processed, test_processed = (
    [preprocess_record(rec, position) for position, rec in enumerate(tqdm(split))]
    for split in (train_records, val_records, test_records)
)
print("✓ Preprocessing done")



In [None]:
"""
CELL 5: Model (emotion + intent + modifiers) and training loop
"""
import math


class ConsciousnessAwareSNN(nn.Module):
    """Joint emotion / intent / modifier model over SBERT + token features.

    The output dict exposes probabilities for inspection ('plutchik',
    'primary_intent') AND the matching pre-softmax logits ('plutchik_logits',
    'primary_logits'). Losses must use the logits: optax.softmax_cross_entropy
    applies its own log-softmax, so feeding it probabilities applies softmax
    twice and squashes the gradients.

    NOTE(review): `training`, `sbert_dim` and `token_ids` are not read, and
    'output' does not use 'gate_weights' to mix any experts.
    """
    num_experts: int = 5
    hidden_dim: int = 256
    sbert_dim: int = 384

    @nn.compact
    def __call__(self, sbert_embeddings, token_ids, pos_tags, syntax_features, training=True):
        # Prosody (lightweight)
        pauses = nn.sigmoid(nn.Dense(1)(nn.relu(nn.Dense(32)(syntax_features)))).squeeze(-1)
        stress = nn.sigmoid(nn.Dense(1)(nn.relu(nn.Dense(32)(pos_tags)))).squeeze(-1)
        pitch = nn.relu(nn.Dense(64)(jnp.concatenate([jnp.mean(stress, axis=1, keepdims=True), jnp.max(pauses, axis=1, keepdims=True)], axis=-1)))
        energy = nn.relu(nn.Dense(64)(jnp.concatenate([jnp.std(stress, axis=1, keepdims=True), jnp.sum(pauses, axis=1, keepdims=True)], axis=-1)))
        prosody = {'pitch': pitch, 'energy': energy, 'pauses': pauses, 'stress': stress}
        # Emotion
        emotion_h = nn.relu(nn.Dense(128)(jnp.concatenate([sbert_embeddings, pitch, energy], axis=-1)))
        plutchik_logits = nn.Dense(8)(emotion_h)
        plutchik_probs = nn.softmax(plutchik_logits)
        # Intent
        intent_h = nn.relu(nn.Dense(128)(jnp.concatenate([sbert_embeddings, emotion_h, pitch], axis=-1)))
        primary_logits = nn.Dense(8)(intent_h)
        primary_intent = nn.softmax(primary_logits)
        # Modifiers
        urgency = nn.sigmoid(nn.Dense(1)(intent_h)); certainty = nn.sigmoid(nn.Dense(1)(intent_h))
        formality = nn.sigmoid(nn.Dense(1)(intent_h)); politeness = nn.sigmoid(nn.Dense(1)(intent_h))
        # Experts
        composite = jnp.concatenate([sbert_embeddings, emotion_h, intent_h], axis=-1)
        gate_weights = nn.softmax(nn.Dense(self.num_experts)(composite))
        output = nn.Dense(self.hidden_dim)(composite)
        return {
            'output': output,
            'prosody': prosody,
            'emotions': {'plutchik': plutchik_probs, 'plutchik_logits': plutchik_logits},
            'intent': {
                'primary_intent': primary_intent,
                'primary_logits': primary_logits,
                'modifiers': {'urgency': urgency, 'certainty': certainty, 'formality': formality, 'politeness': politeness}
            },
            'gate_weights': gate_weights
        }

# Init
rng = random.PRNGKey(42)
model = ConsciousnessAwareSNN()
params = model.init({'params': rng}, jnp.ones((2,384)), jnp.ones((2,128),dtype=jnp.int32), jnp.ones((2,128,10)), jnp.ones((2,128,3)), training=False)['params']

# Optimizer.
# * batches() yields ceil(N / 32) batches per epoch (the last may be partial).
# * warmup_cosine_decay_schedule counts the warmup inside its total length, so
#   the total must exceed the warmup, or the cosine phase would have zero or
#   negative length (small datasets).
BATCH_SIZE, NUM_EPOCHS, WARMUP_STEPS = 32, 10, 50
steps = max(math.ceil(len(train_processed) / BATCH_SIZE) * NUM_EPOCHS, WARMUP_STEPS + 1)
schedule = optax.warmup_cosine_decay_schedule(0.0, 3e-5, WARMUP_STEPS, steps, 1e-5)
tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=schedule, weight_decay=0.01))
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@jit
def train_step(state, batch):
    """Apply one AdamW update on the weighted multi-task loss.

    Returns (new_state, metrics), where metrics holds the total loss and its
    emotion / intent / modifier / diversity components.
    """
    def loss_fn(p):
        out = model.apply({'params': p}, batch['sbert_embedding'], batch['token_ids'], batch['pos_tags'], batch['syntax_features'], training=True)
        # Cross-entropy against LOGITS (softmax_cross_entropy applies log-softmax itself).
        el = optax.softmax_cross_entropy(out['emotions']['plutchik_logits'], batch['plutchik_probs']).mean()
        il = optax.softmax_cross_entropy(out['intent']['primary_logits'], batch['intent_label']).mean()
        m = out['intent']['modifiers']
        ml = ((m['urgency']-batch['urgency'])**2 + (m['certainty']-batch['certainty'])**2 + (m['formality']-batch['formality'])**2 + (m['politeness']-batch['politeness'])**2).mean()
        gw = out['gate_weights']
        # Mean per-example entropy of the gate distribution.
        # NOTE(review): adding +0.02*entropy to a minimized loss SHARPENS the
        # gate; if the intent is diverse expert usage, the sign should be
        # negative. Confirm before relying on this term.
        div = -jnp.mean(jnp.sum(gw * jnp.log(gw + 1e-8), axis=-1))
        total = 1.0*el + 1.0*il + 0.5*ml + 0.02*div
        return total, {'loss': total, 'emotion': el, 'intent': il, 'modifiers': ml, 'diversity': -div}
    (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    return state.apply_gradients(grads=grads), metrics

def batches(data, bs=32, shuffle=True):
    """Yield dict batches of jnp arrays from a list of preprocessed records.

    The final batch may be smaller than `bs` (the jitted train_step is retraced
    once for that shape). Scalar targets are reshaped to [batch, 1] to match the
    modifier heads.
    """
    idx = np.arange(len(data))
    if shuffle:
        np.random.shuffle(idx)
    for s in range(0, len(idx), bs):
        d = [data[i] for i in idx[s:s + bs]]
        yield {
            'sbert_embedding': jnp.array([x['sbert_embedding'] for x in d]),
            'token_ids': jnp.array([x['token_ids'] for x in d]),
            'pos_tags': jnp.array([x['pos_tags'] for x in d]),
            'syntax_features': jnp.array([x['syntax_features'] for x in d]),
            'plutchik_probs': jnp.array([x['plutchik_probs'] for x in d]),
            'intent_label': jnp.array([x['intent_label'] for x in d]),
            'urgency': jnp.array([x['urgency'] for x in d]).reshape(-1,1),
            'certainty': jnp.array([x['certainty'] for x in d]).reshape(-1,1),
            'formality': jnp.array([x['formality'] for x in d]).reshape(-1,1),
            'politeness': jnp.array([x['politeness'] for x in d]).reshape(-1,1),
        }

print("✓ Training loop ready")



In [None]:
"""
CELL 6: Train on your dataset (10 epochs demo)
"""
print("\n"+"="*60) ; print("START TRAINING") ; print("="*60)
# NOTE: the epoch count (10) and batch size (32) are hard-coded here; keep them
# in sync with the values used to size the LR schedule in CELL 5.
for epoch in range(10):
    metrics_buf = []
    for step, batch in enumerate(batches(train_processed, bs=32, shuffle=True)):
        state, metrics = train_step(state, batch)
        metrics_buf.append(metrics)
        # Every 10 steps, print the mean loss of the last 10 steps.
        if (step+1)%10==0:
            avg = jnp.mean(jnp.array([m['loss'] for m in metrics_buf[-10:]]))
            print(f"  epoch {epoch+1} step {step+1}: loss={float(avg):.4f}")
    # Mean loss over the whole epoch, logged to W&B.
    # NOTE(review): with an empty training set metrics_buf is empty and this
    # mean is NaN.
    avg_epoch = jnp.mean(jnp.array([m['loss'] for m in metrics_buf]))
    wandb.log({'epoch': epoch+1, 'train_loss': float(avg_epoch)})
print("✓ Training complete")

