# Consciousness-Aware SNN Training on Colab A100

Use this notebook to fine-tune a sentence encoder (`roberta-large` by default, referred to as SBERT in the code) together with SentencePiece-driven prosody features on a single A100 GPU. The cells install dependencies, stage data from GCS, build the PyTorch model, and run training with bfloat16 autocast and gradient accumulation. Update the configuration placeholders before launching a run.


In [None]:
# Verify GPU availability
# Lists the NVIDIA GPUs visible to this runtime. run_training() selects CUDA when available and otherwise falls back to CPU.
!nvidia-smi


In [None]:
# Install runtime dependencies (PyTorch GPU build, Transformers, SentencePiece, etc.)
# torch/torchvision/torchaudio wheels are pulled from the CUDA 12.4 wheel index.
# NOTE(review): upgrading torch in a live Colab runtime may require a runtime restart before the new build is imported — confirm.
%pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
%pip install --upgrade transformers accelerate sentencepiece datasets scikit-learn pandas tiktoken


## Authenticate and Configure GCS Access
Run the next cell to authenticate with Google Cloud (required for `gsutil`). Make sure the authenticated account can access the `aura_tpu_data` bucket, or update the placeholders in the configuration cell to point to your own bucket.


In [None]:
# Interactive Google sign-in so the `gsutil` commands in later cells can access the bucket.
from google.colab import auth
auth.authenticate_user()


In [None]:
import os
import pathlib

# TODO: Update these values for your project/bucket layout
PROJECT_ID = "auragcloudtpu"
BUCKET_NAME = "aura_tpu_data"
DATA_OBJECT = "data/json/emotions.jsonl"            # object path inside the bucket
SP_MODEL_OBJECT = "models/spm/spiece.model"         # object path inside the bucket
LOCAL_WORKDIR = pathlib.Path("/content/aura_tpu")   # local staging dir (data, SP model, checkpoints)
LOCAL_WORKDIR.mkdir(parents=True, exist_ok=True)

# Export the settings so shell (`!`) cells can also read them as environment variables.
os.environ.update(
    PROJECT_ID=PROJECT_ID,
    BUCKET_NAME=BUCKET_NAME,
    DATA_OBJECT=DATA_OBJECT,
    SP_MODEL_OBJECT=SP_MODEL_OBJECT,
)
print(f"Configured workdir: {LOCAL_WORKDIR}")


In [None]:
# Copy data and SentencePiece model from GCS
# `$LOCAL_WORKDIR`, `$BUCKET_NAME`, `$DATA_OBJECT`, `$SP_MODEL_OBJECT` come from the configuration cell
# (IPython expands `$name` from Python variables; the same names are also exported as environment variables there).
# The JSONL lands at LOCAL_WORKDIR/<basename of DATA_OBJECT>, which is the path run_training() reads.
!mkdir -p $LOCAL_WORKDIR/models/spm
!gsutil -m cp gs://$BUCKET_NAME/$DATA_OBJECT $LOCAL_WORKDIR/
!gsutil -m cp gs://$BUCKET_NAME/$SP_MODEL_OBJECT $LOCAL_WORKDIR/models/spm/spiece.model


In [None]:
import json
from dataclasses import dataclass
from typing import Dict, List

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from sklearn.model_selection import train_test_split
import sentencepiece as spm

# Plutchik's eight primary emotions; the order defines the emotion-head output index.
PLUTCHIK_LABELS = ['joy', 'trust', 'fear', 'surprise', 'sadness', 'disgust', 'anger', 'anticipation']
# COMPASS intent classes; the order defines the intent-head output index.
COMPASS_INTENTS = ['inform', 'negotiate', 'question', 'clarify', 'social', 'express', 'command', 'request']
# Dataset intent labels -> COMPASS intents. Unlisted labels pass through unchanged and
# are replaced by 'inform' if they are not themselves a COMPASS intent.
INTENT_MAPPING = {
    'share_news': 'inform', 'ask_help': 'request', 'clarify': 'clarify',
    'complain': 'express', 'thank': 'social', 'propose': 'negotiate',
}
# Surface forms (after dropping the SentencePiece '▁' marker and stripping whitespace)
# that count as punctuation. Entries must not carry whitespace: they are compared against
# stripped text, so a value like '? ' could never match.
PUNCTUATION_TOKENS = {'.', '!', '?', ',', ';', ':', '...', '!!', '??'}

@dataclass
class TrainConfig:
    """Paths and hyper-parameters for one fine-tuning run.

    Only the two paths are required; every other field has a default, and
    ``run_training`` overrides several of them explicitly.
    """

    # -- required inputs ------------------------------------------------------
    data_path: str                                   # JSONL file of labelled records
    sp_model_path: str                               # SentencePiece .model file

    # -- encoder / tokenisation -----------------------------------------------
    sbert_model_name: str = 'roberta-base'           # Hugging Face model id
    max_length: int = 128                            # tokenizer truncation/padding length

    # -- optimisation ---------------------------------------------------------
    batch_size: int = 16
    epochs: int = 5
    learning_rate: float = 3e-5
    weight_decay: float = 0.02
    label_smoothing: float = 0.05                    # uniform mass mixed into soft targets
    diversity_coef: float = 0.05                     # weight of the gate-entropy loss term
    grad_accum_steps: int = 2                        # micro-batches per optimizer step

    # -- runtime --------------------------------------------------------------
    num_workers: int = 2                             # DataLoader worker processes
    checkpoint_dir: str = '/content/aura_tpu/checkpoints'

class ProductionSentencePieceLoader:
    """Wrap a SentencePiece model and emit fixed-length id / piece sequences."""

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.processor = spm.SentencePieceProcessor()
        self.processor.load(model_path)
        self.vocab_size = self.processor.get_piece_size()
        print(f"Loaded SentencePiece model: {model_path} (vocab={self.vocab_size})")

    def _padding_id(self) -> int:
        """Return the id used for padding: the model's pad id, or its unk id if padding is disabled."""
        pad_id = self.processor.pad_id()
        # SentencePiece reports pad_id() == -1 when the model was trained without a
        # <pad> piece (the trainer default). -1 is not a valid nn.Embedding index and
        # would crash the embedding lookup, so fall back to the unknown-token id.
        return pad_id if pad_id >= 0 else self.processor.unk_id()

    def encode(self, text: str, max_len: int = 128):
        """Encode ``text`` into exactly ``max_len`` ids and pieces.

        Longer encodings are truncated. Shorter ones are right-padded with the
        padding id (ids) and the literal ``'<pad>'`` (pieces).

        Returns:
            ``(ids, pieces)``: a list of ints and a list of strs, both of length ``max_len``.
        """
        # Encoded twice: for unknown characters the str pieces keep the surface text,
        # whereas mapping ids back to pieces would only yield the unk symbol.
        ids = self.processor.encode(text, out_type=int)[:max_len]
        pieces = self.processor.encode(text, out_type=str)[:max_len]
        pad = max_len - len(ids)
        if pad > 0:
            ids += [self._padding_id()] * pad
            pieces += ['<pad>'] * pad
        return ids, pieces

class ProductionSentencePieceProsodyExtractor:
    """Derive simple per-piece prosody cues from SentencePiece surface pieces."""

    def __init__(self, punct_tokens: set = None, pad_piece: str = '<pad>'):
        """
        Args:
            punct_tokens: surface strings that count as punctuation; ``None`` selects
                ``PUNCTUATION_TOKENS``. An explicitly empty set is honoured.
            pad_piece: padding piece to ignore (all three features stay 0 there).
        """
        self.punct_tokens = PUNCTUATION_TOKENS if punct_tokens is None else punct_tokens
        self.pad_piece = pad_piece

    def prosody_features(self, pieces: List[str]):
        """Return ``(word_boundary, punctuation, subword_len)`` float32 arrays of ``len(pieces)``.

        - word_boundary: 1.0 where the piece starts with the SentencePiece '▁' marker.
        - punctuation: 1.0 where the stripped surface text is in ``punct_tokens``.
        - subword_len: stripped surface length, divided by the longest one (range [0, 1]).

        Padding pieces contribute 0 to every feature, so they do not inflate the
        length normalisation. Empty input yields three empty arrays.
        """
        n = len(pieces)
        word_boundary = np.zeros(n, dtype=np.float32)
        punctuation = np.zeros(n, dtype=np.float32)
        subword_len = np.zeros(n, dtype=np.float32)
        for i, token in enumerate(pieces):
            if token == self.pad_piece:
                continue  # padding carries no prosodic signal
            surface = token.replace('▁', ' ').strip()
            word_boundary[i] = 1.0 if token.startswith('▁') else 0.0
            punctuation[i] = 1.0 if surface in self.punct_tokens else 0.0
            subword_len[i] = float(len(surface))
        # Guard .max(): it raises on an empty array.
        peak = subword_len.max() if n else 0.0
        if peak > 0:
            subword_len = subword_len / peak
        return word_boundary, punctuation, subword_len

class EmotionIntentDataset(Dataset):
    """Turn raw JSONL records into tensors for the encoder and the TorchSNN heads.

    Each item holds:
    - HF tokenizer ids and mask;
    - SentencePiece ids plus per-piece prosody features;
    - a soft Plutchik target and a one-hot COMPASS intent;
    - four scalar style targets.
    """

    # Secondary emotion labels folded onto a primary Plutchik label (+0.25 mass).
    _SECONDARY_MAP = {
        'optimism': 'anticipation', 'admiration': 'trust', 'anxiety': 'fear', 'hope': 'anticipation',
        'excitement': 'joy', 'contentment': 'joy', 'grief': 'sadness', 'despair': 'sadness',
        'contempt': 'disgust', 'outrage': 'anger', 'fury': 'anger', 'resentment': 'anger',
    }

    def __init__(self, records: List[Dict], tokenizer: AutoTokenizer, sp_loader: ProductionSentencePieceLoader,
                 prosody: ProductionSentencePieceProsodyExtractor, max_length: int = 128):
        self.records = records
        self.tokenizer = tokenizer
        self.sp_loader = sp_loader
        self.prosody = prosody
        # Sequence length for both the HF tokenizer and the SentencePiece features.
        self.max_length = max_length

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx: int):
        rec = self.records[idx]
        # Replace non-breaking spaces so both tokenizers see ordinary spaces.
        text = (rec.get('text') or '').replace('\u00A0', ' ')
        tok = self.tokenizer(text, max_length=self.max_length, truncation=True, padding='max_length', return_tensors='pt')
        input_ids = tok['input_ids'].squeeze(0)
        attention_mask = tok['attention_mask'].squeeze(0)

        # Use the configured length (not a hard-coded 128) so the SentencePiece features
        # follow the same max_length setting as the HF tokenizer.
        sp_ids, sp_pieces = self.sp_loader.encode(text, max_len=self.max_length)
        wb, pn, sl = self.prosody.prosody_features(sp_pieces)

        # `or {}` also tolerates an explicit null in the record.
        emo = rec.get('plutchik') or {}
        plutchik = np.zeros(len(PLUTCHIK_LABELS), dtype=np.float32)
        prim = emo.get('primary', 'joy')
        intensity = float(emo.get('intensity', 0.5))
        if prim in PLUTCHIK_LABELS:
            plutchik[PLUTCHIK_LABELS.index(prim)] = intensity
        secondary = emo.get('secondary')
        if secondary in self._SECONDARY_MAP:
            plutchik[PLUTCHIK_LABELS.index(self._SECONDARY_MAP[secondary])] += 0.25
        # Normalise to a distribution; the epsilon keeps an all-zero vector finite.
        plutchik = plutchik / (plutchik.sum() + 1e-6)

        raw_intent = rec.get('intent', 'inform')
        mapped_intent = INTENT_MAPPING.get(raw_intent, raw_intent)
        if mapped_intent not in COMPASS_INTENTS:
            mapped_intent = 'inform'
        intent = np.zeros(len(COMPASS_INTENTS), dtype=np.float32)
        intent[COMPASS_INTENTS.index(mapped_intent)] = 1.0

        style = rec.get('style') or {}
        beta = float(style.get('beta', 0.5))
        phi = float(style.get('phi', 0.5))
        # Intensities at or below 0.6 are damped by 0.7 to form the urgency target.
        urgency = float(intensity if intensity > 0.6 else intensity * 0.7)
        # NOTE(review): certainty and politeness are both derived from `phi` — confirm intended.
        certainty = float(phi if phi > 0 else 0.5)

        sample = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'sp_token_ids': torch.tensor(sp_ids, dtype=torch.long),
            'sp_wb': torch.tensor(wb, dtype=torch.float32),
            'sp_punct': torch.tensor(pn, dtype=torch.float32),
            'sp_sublen': torch.tensor(sl, dtype=torch.float32),
            'plutchik': torch.tensor(plutchik, dtype=torch.float32),
            'intent': torch.tensor(intent, dtype=torch.float32),
            'urgency': torch.tensor([urgency], dtype=torch.float32),
            'certainty': torch.tensor([certainty], dtype=torch.float32),
            'formality': torch.tensor([beta], dtype=torch.float32),
            'politeness': torch.tensor([phi], dtype=torch.float32)
        }
        return sample

def load_emotion_dataset(jsonl_path: str) -> List[Dict]:
    """Parse a UTF-8 JSON-Lines file into a list of records, skipping blank lines."""
    records = []
    with open(jsonl_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records



In [None]:
class TorchSNN(nn.Module):
    """Prosody-aware prediction heads on top of a pooled sentence embedding.

    Token level: 128-d SentencePiece piece embeddings are concatenated with 7
    linguistic features and fed to separate pause and stress predictors.

    Sentence level: statistics of those predictions drive the ``pitch`` and
    ``energy`` MLPs. These are combined with ``sbert_emb`` to produce:
    - emotion logits (8) and intent logits (8);
    - 4 sigmoid modifiers;
    - a softmax gate over ``num_experts``;
    - a 256-d output vector.

    NOTE(review): ``gate`` is returned but not applied to ``output`` here;
    any expert routing presumably happens downstream — confirm.
    """

    def __init__(self, sp_vocab: int, sbert_dim: int, num_experts: int = 4):
        super().__init__()
        self.sp_embed = nn.Embedding(sp_vocab, 128)
        # Token predictors: 128-d piece embedding + 7 linguistic features.
        self.pause_dense = nn.Sequential(nn.Linear(128 + 7, 64), nn.GELU(), nn.Linear(64, 1))
        self.stress_dense = nn.Sequential(nn.Linear(128 + 7, 64), nn.GELU(), nn.Linear(64, 1))
        # Sentence-level prosody from 3 pooled statistics each.
        self.pitch_mlp = nn.Sequential(nn.Linear(3, 64), nn.GELU())
        self.energy_mlp = nn.Sequential(nn.Linear(3, 64), nn.GELU())
        self.emotion_head = nn.Sequential(nn.Linear(sbert_dim + 64 + 64, 128), nn.ReLU(), nn.Linear(128, 8))
        self.intent_hidden = nn.Sequential(nn.Linear(sbert_dim + 128 + 64, 128), nn.ReLU())
        self.intent_head = nn.Linear(128, 8)
        self.modifier_head = nn.Linear(128, 4)
        self.gate_head = nn.Linear(sbert_dim + 128 + 128 + 64 + 64, num_experts)
        self.output_head = nn.Linear(sbert_dim + 128 + 128 + 64 + 64, 256)

    @staticmethod
    def _prev(x: torch.Tensor) -> torch.Tensor:
        """Value at the previous position along dim 1; 0 at the first position (no wrap-around)."""
        return torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)

    @staticmethod
    def _next(x: torch.Tensor) -> torch.Tensor:
        """Value at the next position along dim 1; 0 at the last position (no wrap-around)."""
        return torch.cat([x[:, 1:], torch.zeros_like(x[:, :1])], dim=1)

    def forward(self, sbert_emb: torch.Tensor, batch: Dict[str, torch.Tensor]):
        """Compute all heads for a batch.

        Args:
            sbert_emb: pooled sentence embedding, presumably ``[B, sbert_dim]``.
            batch: must contain ``'sp_token_ids'`` (long ``[B, L]``) and the float
                ``[B, L]`` features ``'sp_wb'``, ``'sp_punct'``, ``'sp_sublen'``.

        Returns:
            dict with ``'emo_logits'``, ``'intent_logits'``, ``'modifiers'``, ``'gate'``,
            ``'prosody'`` (``pause``, ``stress``, ``pitch``, ``energy``) and ``'output'``.
        """
        sp_embed = self.sp_embed(batch['sp_token_ids'])  # [B, L, 128]
        wb = batch['sp_wb']
        pn = batch['sp_punct']
        sl = batch['sp_sublen']
        # Rescales only when some length exceeds 1; already-normalised input passes through.
        sl_norm = sl / sl.amax(dim=1, keepdim=True).clamp(min=1.0)

        # Per-token features: word-boundary and punctuation flags for the current,
        # previous and next piece, plus the normalised piece length. Neighbours are
        # zero-filled at the sequence edges so the first and last tokens never see
        # each other (torch.roll would wrap them around).
        ling = torch.stack([
            wb, self._prev(wb), self._next(wb),
            pn, self._prev(pn), self._next(pn),
            sl_norm,
        ], dim=-1)  # [B, L, 7]
        token_in = torch.cat([sp_embed, ling], dim=-1)
        pause_probs = torch.sigmoid(self.pause_dense(token_in)).squeeze(-1)
        stress_probs = torch.sigmoid(self.stress_dense(token_in)).squeeze(-1)

        # Aggregate sentence-level prosody features.
        pitch = self.pitch_mlp(torch.stack([
            stress_probs.std(dim=1),
            stress_probs.mean(dim=1),
            wb.mean(dim=1),
        ], dim=-1))
        energy = self.energy_mlp(torch.stack([
            stress_probs.sum(dim=1),
            pause_probs.sum(dim=1),
            pn.sum(dim=1),
        ], dim=-1))

        emotion_h = torch.cat([sbert_emb, pitch, energy], dim=-1)
        # Hidden layer of emotion_head (Linear -> ReLU), computed once and reused below.
        emo_hidden = F.relu(self.emotion_head[0](emotion_h))
        emo_logits = self.emotion_head[2](emo_hidden)

        intent_hidden = self.intent_hidden(torch.cat([sbert_emb, emo_hidden, pitch], dim=-1))
        intent_logits = self.intent_head(intent_hidden)
        modifiers = torch.sigmoid(self.modifier_head(intent_hidden))

        composite = torch.cat([sbert_emb, intent_hidden, emo_hidden, pitch, energy], dim=-1)
        gate_weights = torch.softmax(self.gate_head(composite), dim=-1)
        routed = self.output_head(composite)
        return {
            'emo_logits': emo_logits,
            'intent_logits': intent_logits,
            'modifiers': modifiers,
            'gate': gate_weights,
            'prosody': {
                'pause': pause_probs,
                'stress': stress_probs,
                'pitch': pitch,
                'energy': energy
            },
            'output': routed
        }


In [None]:
def smooth_labels(labels: torch.Tensor, smoothing: float) -> torch.Tensor:
    """Blend ``labels`` with a uniform distribution over the last dimension.

    Returns ``labels`` itself (no copy) when ``smoothing <= 0``; otherwise
    ``labels * (1 - smoothing) + smoothing / num_classes``.
    """
    if smoothing <= 0:
        return labels
    uniform_share = smoothing / labels.shape[-1]
    return labels * (1.0 - smoothing) + uniform_share


def cross_entropy_with_probs(logits: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy between ``logits`` and a soft target distribution ``probs``.

    Computes ``-(probs * log_softmax(logits)).sum(-1)``, averaged over all leading
    dimensions. A KL-divergence formulation gives the same gradient w.r.t.
    ``logits``, but its value is offset by the target entropy. So it does not
    report the cross-entropy this function is named for.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    return -(probs * log_probs).sum(dim=-1).mean()


def collate_batch(batch_list: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Stack per-sample tensors key by key into batched tensors.

    Keys and their order are taken from the first sample; every sample must
    provide same-shaped tensors for each of those keys.
    """
    field_names = batch_list[0].keys()
    return {name: torch.stack([item[name] for item in batch_list]) for name in field_names}


def prepare_dataloaders(cfg: TrainConfig, tokenizer: AutoTokenizer, sp_loader: ProductionSentencePieceLoader):
    """Split the dataset 80/10/10 (random_state=42) and wrap each split in a DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader shuffles.
    """
    records = load_emotion_dataset(cfg.data_path)
    train_records, temp_records = train_test_split(records, test_size=0.2, random_state=42)
    val_records, test_records = train_test_split(temp_records, test_size=0.5, random_state=42)

    prosody = ProductionSentencePieceProsodyExtractor()
    # Page-locked host memory only speeds up host->GPU copies. On a CPU-only runtime
    # it is useless and DataLoader warns about it.
    pin_memory = torch.cuda.is_available()

    def make_loader(split_records: List[Dict], shuffle: bool) -> DataLoader:
        # One dataset + loader per split, sharing tokenizer, SP model and prosody extractor.
        dataset = EmotionIntentDataset(split_records, tokenizer, sp_loader, prosody, max_length=cfg.max_length)
        return DataLoader(dataset, batch_size=cfg.batch_size, shuffle=shuffle,
                          num_workers=cfg.num_workers, pin_memory=pin_memory, collate_fn=collate_batch)

    train_loader = make_loader(train_records, shuffle=True)
    val_loader = make_loader(val_records, shuffle=False)
    test_loader = make_loader(test_records, shuffle=False)
    return train_loader, val_loader, test_loader


def train_epoch(model: TorchSNN, sbert: AutoModel, loader: DataLoader, optim: torch.optim.Optimizer,
                cfg: TrainConfig, device: torch.device) -> float:
    """Train ``model`` and ``sbert`` jointly for one pass over ``loader``.

    Gradient accumulation: each micro-batch loss is divided by
    ``cfg.grad_accum_steps`` (min 1) before ``backward()``. The accumulated
    gradient is therefore an average over the window rather than a sum. The
    optimizer steps once per window, plus once more for a trailing partial window.
    On CUDA the forward pass runs under bfloat16 autocast.

    Returns:
        Mean *unscaled* per-batch loss for the epoch (0.0 for an empty loader).
    """
    model.train()
    sbert.train()
    accum_steps = max(1, cfg.grad_accum_steps)
    use_bf16 = device.type == 'cuda'
    total_loss = 0.0
    step = 0
    optim.zero_grad(set_to_none=True)
    for batch in loader:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_bf16):
            sbert_outputs = sbert(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
            hidden = sbert_outputs.last_hidden_state
            # Mean-pool token states over non-padding positions.
            mask = batch['attention_mask'].float()
            denom = mask.sum(dim=1, keepdim=True).clamp(min=1.0)
            pooled = (hidden * mask.unsqueeze(-1)).sum(dim=1) / denom
            outputs = model(pooled, batch)
            emo_target = smooth_labels(batch['plutchik'], cfg.label_smoothing)
            intent_target = smooth_labels(batch['intent'], cfg.label_smoothing)
            emo_loss = cross_entropy_with_probs(outputs['emo_logits'], emo_target)
            intent_loss = cross_entropy_with_probs(outputs['intent_logits'], intent_target)
            mods = outputs['modifiers']
            mod_loss = F.mse_loss(mods[:, 0:1], batch['urgency']) + \
                       F.mse_loss(mods[:, 1:2], batch['certainty']) + \
                       F.mse_loss(mods[:, 2:3], batch['formality']) + \
                       F.mse_loss(mods[:, 3:4], batch['politeness'])
            # Mean per-sample entropy of the gate distribution.
            # NOTE(review): adding +coef * entropy to a minimised loss pushes gates toward
            # sharp (low-entropy) routing; confirm that is what "diversity" intends.
            gate = outputs['gate']
            gate_div = -(gate * gate.clamp(min=1e-8).log()).sum(dim=-1).mean()
            loss = emo_loss + intent_loss + 0.5 * mod_loss + cfg.diversity_coef * gate_div
        # Scale so the accumulated gradient averages over the accumulation window.
        (loss / accum_steps).backward()
        step += 1
        if step % accum_steps == 0:
            optim.step()
            optim.zero_grad(set_to_none=True)
        total_loss += loss.item()
    if step % accum_steps != 0:
        # Flush gradients left over from a trailing partial window.
        optim.step()
        optim.zero_grad(set_to_none=True)
    return total_loss / step if step else 0.0


def evaluate(model: TorchSNN, sbert: AutoModel, loader: DataLoader, cfg: TrainConfig, device: torch.device) -> float:
    """Return the sample-weighted mean loss of ``model`` + ``sbert`` over ``loader``.

    Uses the same objective as ``train_epoch``, without gradients. Batches are
    weighted by their size, so a short final batch does not skew the mean.
    On CUDA the forward pass runs under bfloat16 autocast.

    Returns:
        Mean loss per sample, or 0.0 for an empty loader.
    """
    model.eval()
    sbert.eval()
    use_bf16 = device.type == 'cuda'
    loss_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_bf16):
                sbert_outputs = sbert(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
                hidden = sbert_outputs.last_hidden_state
                # Mean-pool token states over non-padding positions.
                mask = batch['attention_mask'].float()
                denom = mask.sum(dim=1, keepdim=True).clamp(min=1.0)
                pooled = (hidden * mask.unsqueeze(-1)).sum(dim=1) / denom
                outputs = model(pooled, batch)
                emo_target = smooth_labels(batch['plutchik'], cfg.label_smoothing)
                intent_target = smooth_labels(batch['intent'], cfg.label_smoothing)
                emo_loss = cross_entropy_with_probs(outputs['emo_logits'], emo_target)
                intent_loss = cross_entropy_with_probs(outputs['intent_logits'], intent_target)
                mods = outputs['modifiers']
                mod_loss = F.mse_loss(mods[:, 0:1], batch['urgency']) + \
                           F.mse_loss(mods[:, 1:2], batch['certainty']) + \
                           F.mse_loss(mods[:, 2:3], batch['formality']) + \
                           F.mse_loss(mods[:, 3:4], batch['politeness'])
                gate = outputs['gate']
                gate_div = -(gate * gate.clamp(min=1e-8).log()).sum(dim=-1).mean()
                loss = emo_loss + intent_loss + 0.5 * mod_loss + cfg.diversity_coef * gate_div
            # Each batch loss is a per-sample mean; weight it by batch size.
            batch_size = batch['input_ids'].size(0)
            loss_sum += loss.item() * batch_size
            n_samples += batch_size
    return loss_sum / n_samples if n_samples else 0.0


In [None]:
import pathlib
from pathlib import Path

def run_training():
    """Fine-tune the encoder and TorchSNN heads end-to-end, keeping the best-validation checkpoint.

    A checkpoint is written whenever validation loss improves. Before the final
    test evaluation, the best-validation weights are restored. The reported test
    loss therefore reflects the checkpoint that is returned, not whatever the
    last epoch left in memory.

    Returns:
        ``(cfg, best_ckpt)``: ``best_ckpt`` is the Path of the lowest-val-loss
        checkpoint, or ``None`` if none was saved.
    """
    cfg = TrainConfig(
        data_path=str(LOCAL_WORKDIR / pathlib.Path(DATA_OBJECT).name),
        sp_model_path=str(LOCAL_WORKDIR / "models/spm/spiece.model"),
        sbert_model_name='roberta-large',
        max_length=128,
        batch_size=12,
        epochs=8,
        learning_rate=2e-5,
        weight_decay=0.02,
        label_smoothing=0.05,
        diversity_coef=0.05,
        grad_accum_steps=4,
        num_workers=2,
        checkpoint_dir=str(LOCAL_WORKDIR / "checkpoints")
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(cfg.sbert_model_name, use_fast=True)
    sbert = AutoModel.from_pretrained(cfg.sbert_model_name)
    if hasattr(sbert, 'gradient_checkpointing_enable'):
        # Trade recomputation for activation memory on the large encoder.
        sbert.gradient_checkpointing_enable()
    sbert.to(device)

    sp_loader = ProductionSentencePieceLoader(cfg.sp_model_path)
    train_loader, val_loader, test_loader = prepare_dataloaders(cfg, tokenizer, sp_loader)
    model = TorchSNN(sp_vocab=sp_loader.vocab_size, sbert_dim=sbert.config.hidden_size, num_experts=4).to(device)

    # Encoder and heads are optimised jointly with one AdamW instance.
    optimizer = torch.optim.AdamW(list(model.parameters()) + list(sbert.parameters()),
                                  lr=cfg.learning_rate, weight_decay=cfg.weight_decay)

    ckpt_dir = Path(cfg.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    best_val = float('inf')
    best_ckpt = None

    for epoch in range(cfg.epochs):
        print(f"Epoch {epoch+1}/{cfg.epochs}")
        train_loss = train_epoch(model, sbert, train_loader, optimizer, cfg, device)
        val_loss = evaluate(model, sbert, val_loader, cfg, device)
        print(f"  train_loss={train_loss:.4f} val_loss={val_loss:.4f}")
        if val_loss < best_val:
            best_val = val_loss
            best_ckpt = ckpt_dir / f"ckpt_epoch_{epoch+1:04d}.pt"
            torch.save({
                'model_state': model.state_dict(),
                'sbert_state': sbert.state_dict(),
                'config': cfg.__dict__,
                'val_loss': val_loss
            }, best_ckpt)
            print(f"  Saved checkpoint: {best_ckpt}")

    if best_ckpt is not None:
        # Restore the best-validation weights before testing. Load onto the CPU first;
        # load_state_dict copies into the existing device tensors, so no second full
        # copy of the model is created on the GPU.
        state = torch.load(best_ckpt, map_location='cpu', weights_only=True)
        model.load_state_dict(state['model_state'])
        sbert.load_state_dict(state['sbert_state'])
        del state
        print(f"Restored best checkpoint for testing: {best_ckpt}")

    test_loss = evaluate(model, sbert, test_loader, cfg, device)
    print(f"Test loss: {test_loss:.4f}")
    return cfg, best_ckpt

cfg, best_checkpoint = run_training()


In [None]:
# Upload best checkpoint back to GCS (optional)
import subprocess

# Upload only when run_training() saved at least one checkpoint.
if best_checkpoint is not None:
    # Keep the local file name, placed under gs://<BUCKET_NAME>/checkpoints/.
    target_object = f"checkpoints/{Path(best_checkpoint).name}"
    destination = f"gs://{BUCKET_NAME}/{target_object}"
    # argv list (no shell), so paths are passed to gsutil verbatim.
    cmd = ["gsutil", "cp", str(best_checkpoint), destination]
    print(" ".join(cmd))
    # check=True raises subprocess.CalledProcessError if gsutil exits non-zero.
    subprocess.run(cmd, check=True)


### Next Steps
- Adjust the `TrainConfig` values passed in `run_training` (epochs, gradient accumulation, batch size) to match your A100 budget.
- Set `sbert_model_name` in `run_training` to another Hugging Face checkpoint if needed.
- Monitor GPU memory with `torch.cuda.max_memory_allocated()` if you push to larger batch sizes.
- Extend the notebook with evaluation or export logic specific to your consciousness-aware routing pipeline.
