# Fluid Weights: Perpetual Plasticity for Transformers

**A Novel Learning System for Continuous Weight Adaptation**

This notebook demonstrates a system where transformer weights adapt continuously during inference, without explicit loss functions or training phases.

## Key Innovations

1. **Attention-Guided Plasticity (AGP)**: Uses attention patterns to guide weight updates
2. **Temporal Surprise Minimization (TSM)**: Reduces prediction "surprise" without explicit targets
3. **Contextual Homeostasis (CH)**: Maintains stable activation statistics
4. **Hybrid Update Rules**: Combines Oja, BCM, energy-based, and predictive-coding learning

---

## 1. Setup

In [None]:
# Install dependencies
# First command: the model/runtime stack; second command: plotting and analysis helpers.
# The leading `!` hands the line to the notebook's shell; `-q` keeps pip's output short.
!pip install -q torch transformers accelerate bitsandbytes sentencepiece
!pip install -q matplotlib pandas seaborn tqdm

In [None]:
# The fluid_weights package can be cloned or uploaded; in this notebook it is
# written inline by the `%%writefile` cells that follow, so this cell only has
# to make sure the target directory exists.

import os
import sys

# Create package directory (exist_ok=True makes re-running this cell harmless)
os.makedirs('fluid_weights', exist_ok=True)

In [None]:
%%writefile fluid_weights/__init__.py
"""
Fluid Weights: Perpetual Plasticity for Transformer Models

Package entry point: re-exports names from the ``core``, ``update_rules``
and ``stability`` submodules and defines ``__version__``.
"""

# Adapter, model wrapper and configuration types.
from .core import FluidLoRA, FluidTransformer, FluidConfig, PlasticityMode
# Local (gradient-free) weight-update rules.
from .update_rules import (
    UpdateRule, HebbianUpdate, OjaUpdate, BCMUpdate,
    PredictiveCodingUpdate, EnergyBasedUpdate, HybridUpdate
)
# Constraints applied to proposed updates before they are written to the weights.
from .stability import (
    StabilityMechanism, ElasticWeightConsolidation,
    SpectralNormConstraint, GradientClipping, AdaptiveRateControl
)

__version__ = "1.0.0"

In [None]:
%%writefile fluid_weights/update_rules.py
"""
Update Rules for Fluid Weight Learning
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, Tuple
from dataclasses import dataclass
import math


@dataclass
class UpdateContext:
    """Inputs handed to an ``UpdateRule`` for a single plasticity step.

    The rules in this module flatten ``x``, ``h`` and ``y`` to two dimensions
    with ``reshape(-1, last_dim)`` and return updates for ``A`` and ``B``.
    """
    x: torch.Tensor  # layer input
    h: torch.Tensor  # intermediate (low-rank) activation between A and B
    y: torch.Tensor  # layer output
    A: torch.Tensor  # current first-factor weights; rules treat it as (x features, h features)
    B: torch.Tensor  # current second-factor weights; rules treat it as (h features, y features)
    attention_weights: Optional[torch.Tensor] = None  # optional attention map; not read by the rules in this module
    layer_idx: int = 0  # index of the layer that produced this context
    step: int = 0  # caller's step counter at the time the context was built


class UpdateRule(ABC):
    """Abstract base class for weight-update rules.

    A rule receives an ``UpdateContext`` and returns the pair
    ``(delta_A, delta_B)`` to be added to the two factor matrices.
    """

    def __init__(self, learning_rate: float = 1e-5, **kwargs):
        # Extra keyword options are stored untouched for subclasses to consult.
        self.config = kwargs
        self.learning_rate = learning_rate

    @abstractmethod
    def compute_update(self, ctx: UpdateContext) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(delta_A, delta_B)`` update for ``ctx``."""


class HebbianUpdate(UpdateRule):
    """Plain Hebbian rule: the update is the correlation of pre- and post-activity.

    With ``normalize=True`` each correlation matrix is divided by the product
    of the Frobenius norms of the two activity matrices it was built from.
    """

    def __init__(self, learning_rate: float = 1e-5, normalize: bool = True, **kwargs):
        super().__init__(learning_rate, **kwargs)
        self.normalize = normalize

    def compute_update(self, ctx: UpdateContext) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(delta_A, delta_B)`` scaled by the learning rate."""
        # Collapse all leading dimensions so that each row is one sample.
        inputs = ctx.x.reshape(-1, ctx.x.shape[-1])
        hidden = ctx.h.reshape(-1, ctx.h.shape[-1])
        outputs = ctx.y.reshape(-1, ctx.y.shape[-1])

        corr_A = inputs.T @ hidden
        corr_B = hidden.T @ outputs

        if not self.normalize:
            return self.learning_rate * corr_A, self.learning_rate * corr_B

        # The small constant keeps the division finite for all-zero activity.
        in_norm = torch.norm(inputs) + 1e-8
        hid_norm = torch.norm(hidden) + 1e-8
        out_norm = torch.norm(outputs) + 1e-8
        corr_A = corr_A / (in_norm * hid_norm)
        corr_B = corr_B / (hid_norm * out_norm)
        return self.learning_rate * corr_A, self.learning_rate * corr_B


class OjaUpdate(UpdateRule):
    """Oja-style rule: a Hebbian term minus a weight-proportional decay.

    The decay for each matrix is the current weights scaled by the mean
    squared norm of the post-synaptic activity, multiplied by
    ``stabilization_strength``.
    """

    def __init__(self, learning_rate: float = 1e-5, stabilization_strength: float = 1.0, **kwargs):
        super().__init__(learning_rate, **kwargs)
        self.stabilization_strength = stabilization_strength

    def compute_update(self, ctx: UpdateContext) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(delta_A, delta_B)`` scaled by the learning rate."""
        inputs = ctx.x.reshape(-1, ctx.x.shape[-1])
        hidden = ctx.h.reshape(-1, ctx.h.shape[-1])
        outputs = ctx.y.reshape(-1, ctx.y.shape[-1])

        # Mean (over samples) of the squared activity norm on the output side
        # of each matrix.
        hidden_energy = torch.mean(torch.sum(hidden ** 2, dim=-1))
        output_energy = torch.mean(torch.sum(outputs ** 2, dim=-1))

        decay_A = hidden_energy * ctx.A
        decay_B = output_energy * ctx.B

        # Sample-averaged Hebbian correlation for each matrix.
        growth_A = (inputs.T @ hidden) / (inputs.shape[0] + 1e-8)
        growth_B = (hidden.T @ outputs) / (hidden.shape[0] + 1e-8)

        delta_A = growth_A - self.stabilization_strength * decay_A
        delta_B = growth_B - self.stabilization_strength * decay_B
        return self.learning_rate * delta_A, self.learning_rate * delta_B


class BCMUpdate(UpdateRule):
    """BCM rule: Hebbian learning gated by a sliding activity threshold.

    Per-unit thresholds ``theta_h`` and ``theta_y`` track an exponential
    moving average of the mean squared activity. Post-synaptic activity is
    multiplied by its distance from the threshold before being correlated
    with the pre-synaptic activity.
    """

    def __init__(self, learning_rate: float = 1e-5, threshold_decay: float = 0.99, **kwargs):
        super().__init__(learning_rate, **kwargs)
        self.threshold_decay = threshold_decay
        self.theta_h = None
        self.theta_y = None

    def _slide_threshold(self, previous, fresh):
        """Blend ``fresh`` into ``previous``; start over if there is no usable history."""
        fresh = fresh.detach()
        if previous is None or previous.shape != fresh.shape:
            return fresh.clone()
        return self.threshold_decay * previous + (1 - self.threshold_decay) * fresh

    def compute_update(self, ctx: UpdateContext) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update both thresholds, then return ``(delta_A, delta_B)``."""
        inputs = ctx.x.reshape(-1, ctx.x.shape[-1])
        hidden = ctx.h.reshape(-1, ctx.h.shape[-1])
        outputs = ctx.y.reshape(-1, ctx.y.shape[-1])

        self.theta_h = self._slide_threshold(self.theta_h, torch.mean(hidden ** 2, dim=0))
        self.theta_y = self._slide_threshold(self.theta_y, torch.mean(outputs ** 2, dim=0))

        # Activity above the threshold yields a positive factor, below it a
        # negative one.
        hidden_gated = hidden * (hidden - self.theta_h.unsqueeze(0))
        outputs_gated = outputs * (outputs - self.theta_y.unsqueeze(0))

        delta_A = inputs.T @ hidden_gated / (inputs.shape[0] + 1e-8)
        delta_B = hidden.T @ outputs_gated / (hidden.shape[0] + 1e-8)
        return self.learning_rate * delta_A, self.learning_rate * delta_B


class PredictiveCodingUpdate(UpdateRule):
    """Predictive coding: learn from the error of top-down reconstructions.

    The hidden activity is reconstructed from the output through ``B`` and
    the input is reconstructed from the hidden activity through ``A``. The
    reconstruction errors are then correlated with the activity that
    produced them.

    Like the other rules in this module, this one expects ``A`` laid out as
    (x features, h features) and ``B`` as (h features, y features), and it
    returns updates with exactly those shapes.
    """

    def __init__(self, learning_rate: float = 1e-5, **kwargs):
        super().__init__(learning_rate, **kwargs)

    def compute_update(self, ctx: UpdateContext) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(delta_A, delta_B)`` scaled by the learning rate."""
        x_flat = ctx.x.reshape(-1, ctx.x.shape[-1])
        h_flat = ctx.h.reshape(-1, ctx.h.shape[-1])
        y_flat = ctx.y.reshape(-1, ctx.y.shape[-1])

        # Reconstruct h from y: (N, y) @ (y, h) -> (N, h). Column r is the
        # projection of y onto row r of B, so it is normalised by that row's
        # norm; the normaliser has shape (1, h) and broadcasts over samples.
        h_pred = y_flat @ ctx.B.T
        h_pred = h_pred / (torch.norm(ctx.B, dim=1).unsqueeze(0) + 1e-8)
        eps_h = h_flat - h_pred

        # Reconstruct x from h: (N, h) @ (h, x) -> (N, x), normalised per
        # input feature by the norm of the matching row of A, shape (1, x).
        x_pred = h_flat @ ctx.A.T
        x_pred = x_pred / (torch.norm(ctx.A, dim=1).unsqueeze(0) + 1e-8)
        eps_x = x_flat - x_pred

        # (x, N) @ (N, h) -> same shape as A; (h, N) @ (N, y) -> same shape as B.
        delta_A = eps_x.T @ h_flat / (x_flat.shape[0] + 1e-8)
        delta_B = eps_h.T @ y_flat / (h_flat.shape[0] + 1e-8)

        return self.learning_rate * delta_A, self.learning_rate * delta_B


class EnergyBasedUpdate(UpdateRule):
    """Energy-based rule: push activity back towards its running mean.

    Running means of the hidden and output activity are kept as exponential
    moving averages. The update is the negative correlation between the
    pre-synaptic activity and the deviation of the post-synaptic activity
    from its running mean.
    """

    def __init__(self, learning_rate: float = 1e-5, running_mean_decay: float = 0.99, **kwargs):
        super().__init__(learning_rate, **kwargs)
        self.running_mean_decay = running_mean_decay
        self.h_running_mean = None
        self.y_running_mean = None

    def _blend_mean(self, previous, fresh):
        """Blend ``fresh`` into ``previous``; start over if there is no usable history."""
        fresh = fresh.detach()
        if previous is None or previous.shape != fresh.shape:
            return fresh.clone()
        return self.running_mean_decay * previous + (1 - self.running_mean_decay) * fresh

    def compute_update(self, ctx: UpdateContext) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update both running means, then return ``(delta_A, delta_B)``."""
        x_flat = ctx.x.reshape(-1, ctx.x.shape[-1])
        h_flat = ctx.h.reshape(-1, ctx.h.shape[-1])
        y_flat = ctx.y.reshape(-1, ctx.y.shape[-1])

        # Each statistic is validated on its own, so a change in the width of
        # y alone re-initialises y_running_mean instead of failing to
        # broadcast further down.
        self.h_running_mean = self._blend_mean(self.h_running_mean, torch.mean(h_flat, dim=0))
        self.y_running_mean = self._blend_mean(self.y_running_mean, torch.mean(y_flat, dim=0))

        h_error = h_flat - self.h_running_mean.unsqueeze(0)
        y_error = y_flat - self.y_running_mean.unsqueeze(0)

        delta_A = -x_flat.T @ h_error / (x_flat.shape[0] + 1e-8)
        delta_B = -h_flat.T @ y_error / (h_flat.shape[0] + 1e-8)

        return self.learning_rate * delta_A, self.learning_rate * delta_B


class HybridUpdate(UpdateRule):
    """Weighted sum of several update rules.

    ``rules`` maps a label to a ``(rule, weight)`` pair. When it is omitted,
    a default mix of Oja (0.5), BCM (0.3) and energy-based (0.2) rules is
    built, each with a learning rate of 1.0 so that this class's own
    learning rate sets the overall scale.
    """

    def __init__(self, learning_rate: float = 1e-5, rules: Dict = None, **kwargs):
        super().__init__(learning_rate, **kwargs)
        self.rules = rules if rules is not None else {
            'oja': (OjaUpdate(learning_rate=1.0), 0.5),
            'bcm': (BCMUpdate(learning_rate=1.0), 0.3),
            'energy': (EnergyBasedUpdate(learning_rate=1.0), 0.2),
        }

    def compute_update(self, ctx: UpdateContext) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the weighted sum of all member updates, scaled by the learning rate."""
        combined_A = torch.zeros_like(ctx.A)
        combined_B = torch.zeros_like(ctx.B)

        # Only the (rule, weight) pairs matter here; the labels are for humans.
        for member, share in self.rules.values():
            part_A, part_B = member.compute_update(ctx)
            combined_A += share * part_A
            combined_B += share * part_B

        return self.learning_rate * combined_A, self.learning_rate * combined_B

In [None]:
%%writefile fluid_weights/stability.py
"""
Stability Mechanisms for Fluid Weight Learning
"""

import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Optional, Tuple, List
from dataclasses import dataclass


@dataclass
class StabilityMetrics:
    """Summary numbers returned by ``StabilityMechanism.check_stability``."""
    weight_norm: float = 0.0  # base-class check fills this with norm(A) + norm(B)
    spectral_norm: float = 0.0  # NOTE(review): not populated by the code in this module
    drift_from_origin: float = 0.0  # NOTE(review): not populated by the code in this module
    oscillation_score: float = 0.0  # NOTE(review): not populated by the code in this module


class StabilityMechanism(ABC):
    """Abstract base class for constraints applied to a proposed weight update."""

    @abstractmethod
    def constrain_update(self, delta_A, delta_B, A, B) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the constrained ``(delta_A, delta_B)`` for current weights ``A`` and ``B``."""

    def check_stability(self, A, B) -> StabilityMetrics:
        """Report the sum of the norms of ``A`` and ``B`` as ``weight_norm``."""
        norm_A = torch.norm(A).item()
        norm_B = torch.norm(B).item()
        return StabilityMetrics(weight_norm=norm_A + norm_B)


class ElasticWeightConsolidation(StabilityMechanism):
    """Pull weights back towards a reference copy, more strongly where they matter.

    The reference weights are captured the first time the mechanism sees
    ``A`` and ``B``. Importance is an exponential moving average of the mean
    squared hidden activity, broadcast along the rows of ``A`` and the
    columns of ``B``.
    """

    def __init__(self, consolidation_strength: float = 0.1, importance_decay: float = 0.999):
        self.consolidation_strength = consolidation_strength
        self.importance_decay = importance_decay
        self.A_original = None
        self.B_original = None
        self.F_A = None
        self.F_B = None

    def initialize(self, A, B):
        """Snapshot the reference weights and reset both importance maps to zero."""
        self.A_original, self.B_original = A.detach().clone(), B.detach().clone()
        self.F_A, self.F_B = torch.zeros_like(A), torch.zeros_like(B)

    def update_fisher(self, A, B, h):
        """Fold the current hidden activity ``h`` into the importance maps."""
        if self.F_A is None:
            self.initialize(A, B)

        # Mean squared activity per hidden unit.
        per_unit = torch.mean(h.reshape(-1, h.shape[-1]) ** 2, dim=0)

        # Outer products with ones tile the per-unit importance to the shapes
        # of A (one copy per row) and B (one copy per column).
        tiled_for_A = torch.outer(torch.ones(A.shape[0], device=A.device), per_unit)
        tiled_for_B = torch.outer(per_unit, torch.ones(B.shape[1], device=B.device))

        self.F_A = self.importance_decay * self.F_A + (1 - self.importance_decay) * tiled_for_A
        self.F_B = self.importance_decay * self.F_B + (1 - self.importance_decay) * tiled_for_B

    def constrain_update(self, delta_A, delta_B, A, B):
        """Subtract the importance-weighted drift from the reference weights."""
        if self.A_original is None:
            self.initialize(A, B)

        pull_A = self.consolidation_strength * self.F_A * (A - self.A_original)
        pull_B = self.consolidation_strength * self.F_B * (B - self.B_original)
        return delta_A - pull_A, delta_B - pull_B


class SpectralNormConstraint(StabilityMechanism):
    """Shrink an update when it would push the spectral norm past a limit.

    The spectral norm of the candidate weights (current weights plus update)
    is estimated with three rounds of power iteration from a random start,
    so the estimate varies slightly between calls.

    NOTE(review): the update is scaled by ``max_spectral_norm / sigma``; this
    damps the step but does not by itself guarantee that the resulting
    weights satisfy the bound.
    """

    def __init__(self, max_spectral_norm: float = 2.0):
        self.max_spectral_norm = max_spectral_norm

    def _spectral_norm(self, W):
        """Estimate the largest singular value of the 2-D tensor ``W``."""
        # Cast the start vector to W's dtype: a default float32 vector makes
        # `W @ u` raise for half, bfloat16 or double weights.
        u = torch.randn(W.shape[1], device=W.device).to(W.dtype)
        u = u / torch.norm(u)
        for _ in range(3):
            v = W @ u
            v = v / (torch.norm(v) + 1e-8)
            u = W.T @ v
            u = u / (torch.norm(u) + 1e-8)
        return torch.norm(W @ u).item()

    def constrain_update(self, delta_A, delta_B, A, B):
        """Return ``(delta_A, delta_B)``, each scaled down if its candidate exceeds the limit."""
        A_new = A + delta_A
        B_new = B + delta_B
        sigma_A = self._spectral_norm(A_new)
        sigma_B = self._spectral_norm(B_new)
        if sigma_A > self.max_spectral_norm:
            delta_A = delta_A * self.max_spectral_norm / sigma_A
        if sigma_B > self.max_spectral_norm:
            delta_B = delta_B * self.max_spectral_norm / sigma_B
        return delta_A, delta_B


class GradientClipping(StabilityMechanism):
    """Rescale the update pair so its joint norm does not exceed ``max_norm``."""

    def __init__(self, max_norm: float = 0.1):
        self.max_norm = max_norm

    def constrain_update(self, delta_A, delta_B, A, B):
        """Return the updates unchanged, or both scaled by the same factor."""
        # Joint norm of the two updates treated as one long vector.
        sq_A = torch.norm(delta_A) ** 2
        sq_B = torch.norm(delta_B) ** 2
        total_norm = (sq_A + sq_B) ** 0.5

        if total_norm > self.max_norm:
            shrink = self.max_norm / (total_norm + 1e-8)
            return delta_A * shrink, delta_B * shrink
        return delta_A, delta_B


class AdaptiveRateControl(StabilityMechanism):
    """Scale updates by a rate that shrinks while recent updates are large.

    The mean update magnitude over the last 100 calls is compared with
    ``target_stability``. Above the target the rate is multiplied by
    ``decay_factor``; otherwise it grows by 1% per call, capped at 1.0.
    """

    def __init__(self, target_stability: float = 0.1, decay_factor: float = 0.9):
        self.target_stability = target_stability
        self.decay_factor = decay_factor
        self.effective_rate = 1.0
        self.change_history = []

    def constrain_update(self, delta_A, delta_B, A, B):
        """Record the incoming update size, adjust the rate, and apply it."""
        magnitude = torch.norm(delta_A).item() + torch.norm(delta_B).item()

        history = self.change_history
        history.append(magnitude)
        if len(history) > 100:
            # Keep a sliding window of the 100 most recent magnitudes.
            del history[0]

        mean_magnitude = sum(history) / len(history)
        if mean_magnitude > self.target_stability:
            self.effective_rate *= self.decay_factor
        else:
            self.effective_rate = min(1.0, self.effective_rate * 1.01)

        rate = self.effective_rate
        return delta_A * rate, delta_B * rate


class CompositeStability(StabilityMechanism):
    """Apply a list of stability mechanisms one after another."""

    def __init__(self, mechanisms: List[StabilityMechanism]):
        self.mechanisms = mechanisms

    def constrain_update(self, delta_A, delta_B, A, B):
        """Pass the update through every mechanism in list order."""
        constrained = (delta_A, delta_B)
        for mechanism in self.mechanisms:
            constrained = mechanism.constrain_update(constrained[0], constrained[1], A, B)
        constrained_A, constrained_B = constrained
        return constrained_A, constrained_B

    def check_stability(self, A, B):
        """Return metrics whose ``weight_norm`` is the largest reported by any mechanism."""
        summary = StabilityMetrics()
        for mechanism in self.mechanisms:
            reported = mechanism.check_stability(A, B)
            summary.weight_norm = max(summary.weight_norm, reported.weight_norm)
        return summary

In [None]:
%%writefile fluid_weights/core.py
"""
Core Fluid Weight Learning System

NOVEL CONTRIBUTIONS:
1. Attention-Guided Plasticity (AGP)
2. Temporal Surprise Minimization (TSM)
3. Contextual Homeostasis (CH)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, Optional, Tuple, List
from dataclasses import dataclass, field
from enum import Enum
import math
import re

from .update_rules import UpdateRule, UpdateContext, HybridUpdate, OjaUpdate, BCMUpdate, EnergyBasedUpdate, PredictiveCodingUpdate
from .stability import StabilityMechanism, CompositeStability, ElasticWeightConsolidation, SpectralNormConstraint, GradientClipping, AdaptiveRateControl


class PlasticityMode(Enum):
    """Whether a ``FluidLoRA`` adapter rewrites its weights during forward passes."""
    FROZEN = "frozen"  # forward pass only; no weight updates are applied
    FLUID = "fluid"  # every forward pass may also apply a weight update


@dataclass
class FluidConfig:
    """Settings read by ``FluidLoRA`` and ``FluidTransformer``.

    The per-field notes describe how the code in this module uses each value.
    """
    # --- Adapter ---
    rank: int = 16  # inner dimension of the A and B factor matrices
    alpha: float = 32.0  # adapter output is multiplied by alpha / rank
    dropout: float = 0.0  # dropout probability on the low-rank activation; 0 disables dropout
    mode: PlasticityMode = PlasticityMode.FLUID  # FLUID applies updates during forward, FROZEN does not
    # --- Plasticity ---
    learning_rate: float = 1e-5  # scale applied to every update term
    update_every_n_tokens: int = 1  # an update runs only when the adapter's forward-call counter is a multiple of this
    # --- Stability mechanisms (each flag adds one mechanism to the chain) ---
    use_ewc: bool = True
    ewc_strength: float = 0.05  # consolidation_strength passed to ElasticWeightConsolidation
    use_spectral_norm: bool = True
    max_spectral_norm: float = 2.0  # limit passed to SpectralNormConstraint
    use_gradient_clipping: bool = True
    max_gradient_norm: float = 0.1  # max_norm passed to GradientClipping
    use_adaptive_rate: bool = True
    weight_decay_to_origin: float = 0.001  # NOTE(review): not referenced by the code visible in this notebook
    # --- Additional update terms ---
    use_attention_guided_plasticity: bool = True
    attention_plasticity_strength: float = 0.3  # weight of the attention-guided term
    use_temporal_surprise: bool = True
    surprise_window: int = 32  # upper bound on the number of rows kept in the temporal buffer
    use_contextual_homeostasis: bool = True
    homeostasis_strength: float = 0.1  # weight of the homeostasis term
    homeostasis_burnin: int = 100  # steps of statistics gathering before homeostasis targets are fixed
    # --- Monitoring ---
    track_metrics: bool = True
    log_every_n_steps: int = 100  # a metrics entry is recorded when the step counter is a multiple of this


@dataclass
class FluidState:
    """Mutable bookkeeping owned by one ``FluidLoRA`` adapter."""
    step: int = 0  # incremented once per adapter forward call
    total_tokens: int = 0  # running count of input positions processed
    is_burned_in: bool = False  # set once the homeostasis targets have been captured
    activation_mean: Optional[torch.Tensor] = None  # moving average of the low-rank activation mean
    activation_var: Optional[torch.Tensor] = None  # moving average of the low-rank activation variance
    target_mean: Optional[torch.Tensor] = None  # homeostasis set-point for the mean, copied from activation_mean
    target_var: Optional[torch.Tensor] = None  # homeostasis set-point for the variance, copied from activation_var
    temporal_buffer: Optional[torch.Tensor] = None  # recent low-rank activations used by the temporal-surprise term
    metrics_history: List[Dict] = field(default_factory=list)  # logged dicts with 'step', 'A_norm', 'B_norm', 'delta_norm'


class FluidLoRA(nn.Module):
    """Low-rank adapter that rewrites its own weights during forward passes.

    The adapter computes ``y = dropout(x @ A) @ B * (alpha / rank)`` with
    ``A`` of shape (in_features, rank) and ``B`` of shape
    (rank, out_features). While ``config.mode`` is ``PlasticityMode.FLUID``,
    each forward call may also apply a gradient-free update to ``A`` and
    ``B``, assembled from:

    * a weighted mix of Oja, BCM, energy-based and predictive-coding rules,
    * an optional attention-guided term,
    * an optional temporal-surprise term,
    * an optional homeostasis term,

    and then passed through the configured stability mechanisms.

    The adapter keeps the parameter dtype it was created with. Inputs of a
    different floating dtype are cast on entry and the output is cast back.
    """

    def __init__(self, in_features: int, out_features: int, config: FluidConfig, layer_type: str = "linear", layer_idx: int = 0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.config = config
        self.layer_type = layer_type
        self.layer_idx = layer_idx

        self.A = nn.Parameter(torch.zeros(in_features, config.rank))
        self.B = nn.Parameter(torch.zeros(config.rank, out_features))
        self.scaling = config.alpha / config.rank

        # B starts at zero, so a fresh adapter contributes nothing to the output.
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        nn.init.zeros_(self.B)

        self.dropout = nn.Dropout(config.dropout) if config.dropout > 0 else nn.Identity()
        self._init_update_rules()
        self._init_stability()
        self.state = FluidState()

    def _init_update_rules(self):
        """Build the base hybrid rule; sub-rules use lr=1.0 so the config's rate sets the scale."""
        rules = {
            'oja': (OjaUpdate(learning_rate=1.0), 0.4),
            'bcm': (BCMUpdate(learning_rate=1.0), 0.2),
            'energy': (EnergyBasedUpdate(learning_rate=1.0), 0.2),
            'predictive': (PredictiveCodingUpdate(learning_rate=1.0), 0.2),
        }
        self.base_update = HybridUpdate(learning_rate=self.config.learning_rate, rules=rules)

    def _init_stability(self):
        """Assemble the stability chain in a fixed order from the config flags."""
        mechanisms = []
        if self.config.use_ewc:
            mechanisms.append(ElasticWeightConsolidation(consolidation_strength=self.config.ewc_strength))
        if self.config.use_spectral_norm:
            mechanisms.append(SpectralNormConstraint(max_spectral_norm=self.config.max_spectral_norm))
        if self.config.use_gradient_clipping:
            mechanisms.append(GradientClipping(max_norm=self.config.max_gradient_norm))
        if self.config.use_adaptive_rate:
            mechanisms.append(AdaptiveRateControl())
        self.stability = CompositeStability(mechanisms)

    def forward(self, x: torch.Tensor, attention_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the adapter output for ``x`` and, in FLUID mode, update the weights.

        Args:
            x: input whose last dimension is ``in_features``.
            attention_weights: optional attention map for the attention-guided term.

        Returns:
            The scaled low-rank output, in the dtype of ``x``.
        """
        # The adapter may be float32 while the host model runs in half
        # precision; matmul requires matching dtypes.
        in_dtype = x.dtype
        if in_dtype != self.A.dtype:
            x = x.to(self.A.dtype)

        h = x @ self.A
        h = self.dropout(h)
        y = h @ self.B
        y = y * self.scaling

        if self.config.mode == PlasticityMode.FLUID:
            self._apply_fluid_update(x.detach(), h.detach(), y.detach(), attention_weights)

        self.state.step += 1
        # Count positions as "everything except the feature dimension", which
        # is batch * sequence for 3-D input and stays correct for 2-D input.
        self.state.total_tokens += x.numel() // max(x.shape[-1], 1)
        return y if y.dtype == in_dtype else y.to(in_dtype)

    def _apply_fluid_update(self, x, h, y, attention_weights=None):
        """Compute, constrain and apply one weight update from detached activations."""
        if self.state.step % self.config.update_every_n_tokens != 0:
            return

        ctx = UpdateContext(x=x, h=h, y=y, A=self.A.data, B=self.B.data, attention_weights=attention_weights, layer_idx=self.layer_idx, step=self.state.step)
        delta_A, delta_B = self.base_update.compute_update(ctx)

        if self.config.use_attention_guided_plasticity and attention_weights is not None:
            agp_A, agp_B = self._attention_guided_update(x, h, y, attention_weights)
            delta_A = delta_A + self.config.attention_plasticity_strength * agp_A
            delta_B = delta_B + self.config.attention_plasticity_strength * agp_B

        if self.config.use_temporal_surprise:
            ts_A, ts_B = self._temporal_surprise_update(x, h, y)
            delta_A = delta_A + ts_A
            delta_B = delta_B + ts_B

        if self.config.use_contextual_homeostasis:
            ch_A, ch_B = self._homeostasis_update(x, h, y)
            delta_A = delta_A + self.config.homeostasis_strength * ch_A
            delta_B = delta_B + self.config.homeostasis_strength * ch_B

        delta_A, delta_B = self.stability.constrain_update(delta_A, delta_B, self.A.data, self.B.data)

        with torch.no_grad():
            self.A.data.add_(delta_A)
            self.B.data.add_(delta_B)

        # Importance is refreshed after the step, using the new weights.
        if self.config.use_ewc:
            for mech in self.stability.mechanisms:
                if isinstance(mech, ElasticWeightConsolidation):
                    mech.update_fisher(self.A.data, self.B.data, h)

        if self.config.track_metrics and self.state.step % self.config.log_every_n_steps == 0:
            self.state.metrics_history.append({
                'step': self.state.step,
                'A_norm': torch.norm(self.A.data).item(),
                'B_norm': torch.norm(self.B.data).item(),
                'delta_norm': torch.norm(delta_A).item() + torch.norm(delta_B).item(),
            })

    def _zero_update(self):
        """Return an all-zero ``(delta_A, delta_B)`` pair."""
        return torch.zeros_like(self.A.data), torch.zeros_like(self.B.data)

    @staticmethod
    def _head_averaged_attention(attn, x, dtype):
        """Return the (seq, seq) mean of ``attn`` over its first two dims, or None if unusable.

        The map is usable only when it is 4-D, ``x`` is 3-D, and the last two
        dimensions of ``attn`` both equal the sequence length of ``x``. A map
        captured from a call with a different sequence length (for example
        during cached generation) is rejected rather than fed to ``einsum``.
        """
        if attn.dim() != 4 or x.dim() != 3:
            return None
        seq_len = x.shape[1]
        if attn.shape[-2] != seq_len or attn.shape[-1] != seq_len:
            return None
        return attn.to(dtype).mean(dim=(0, 1))

    def _attention_guided_update(self, x, h, y, attn):
        """Return the attention-guided ``(delta_A, delta_B)`` for this layer type.

        Query/key layers move ``A`` towards attention-mixed hidden activity
        and scale the Hebbian ``B`` term by an entropy-based confidence.
        Value layers correlate hidden activity with attention-mixed outputs.
        Other layer types use plain sample-averaged Hebbian terms and do not
        read ``attn``. Query/key/value layers return zeros when the attention
        map does not match the shape of ``x``.
        """
        x_flat = x.reshape(-1, x.shape[-1])
        h_flat = h.reshape(-1, h.shape[-1])
        y_flat = y.reshape(-1, y.shape[-1])
        lr = self.config.learning_rate

        if self.layer_type not in ('query', 'key', 'value'):
            delta_A = x_flat.T @ h_flat / (x_flat.shape[0] + 1e-8) * lr
            delta_B = h_flat.T @ y_flat / (h_flat.shape[0] + 1e-8) * lr
            return delta_A, delta_B

        attn_avg = self._head_averaged_attention(attn, x, h.dtype)
        if attn_avg is None:
            return self._zero_update()

        batch_size, seq_len = x.shape[:2]

        if self.layer_type in ['query', 'key']:
            h_seq = h.reshape(batch_size, seq_len, -1)
            h_attended = torch.einsum('ij,bjr->bir', attn_avg, h_seq).reshape(-1, h.shape[-1])
            delta_A = x_flat.T @ (h_attended - h_flat) / (x_flat.shape[0] + 1e-8) * lr
            # Low-entropy (peaked) attention gives a confidence close to 1.
            attn_entropy = -torch.sum(attn_avg * torch.log(attn_avg + 1e-8), dim=-1)
            confidence = 1.0 / (1.0 + attn_entropy.mean())
            delta_B = confidence * h_flat.T @ y_flat / (h_flat.shape[0] + 1e-8) * lr
        else:  # 'value'
            y_seq = y.reshape(batch_size, seq_len, -1)
            y_attended = torch.einsum('ij,bjd->bid', attn_avg, y_seq).reshape(-1, y.shape[-1])
            delta_B = h_flat.T @ y_attended / (h_flat.shape[0] + 1e-8) * lr
            delta_A = x_flat.T @ h_flat / (x_flat.shape[0] + 1e-8) * lr * 0.5

        return delta_A, delta_B

    def _temporal_surprise_update(self, x, h, y):
        """Return an update that opposes deviations from recently seen activity.

        A buffer of recent low-rank activations (at most
        ``config.surprise_window`` rows) provides the expected mean and
        spread. Returns zeros for an empty batch, or while the buffer holds
        fewer than two rows and no spread can be estimated.
        """
        h_flat = h.reshape(-1, h.shape[-1])
        x_flat = x.reshape(-1, x.shape[-1])
        y_flat = y.reshape(-1, y.shape[-1])

        n_rows = h_flat.shape[0]
        if n_rows == 0:
            return self._zero_update()

        # Grow the buffer when a larger batch arrives, up to the window size.
        # A freshly allocated buffer is overwritten in full just below, so
        # its zero rows never enter the statistics. Without growth, a first
        # call with a single row would pin the buffer at one row for good.
        wanted_rows = min(self.config.surprise_window, n_rows)
        buffer = self.state.temporal_buffer
        if buffer is None or buffer.shape[0] < wanted_rows or buffer.shape[-1] != h_flat.shape[-1]:
            buffer = torch.zeros(wanted_rows, h_flat.shape[-1], device=h.device, dtype=h_flat.dtype)

        buffer_size = buffer.shape[0]
        n_samples = min(n_rows, buffer_size)
        if n_samples < buffer_size:
            # Shift old rows towards the front to make room at the end.
            buffer = torch.roll(buffer, -n_samples, dims=0)
        buffer[-n_samples:] = h_flat[:n_samples].detach()
        self.state.temporal_buffer = buffer

        if buffer_size < 2:
            # The sample standard deviation of one row is NaN, and NaN would
            # propagate into the weights.
            return self._zero_update()

        h_expected = buffer.mean(dim=0)
        h_std = buffer.std(dim=0) + 1e-8
        h_current = h_flat.mean(dim=0)
        surprise = (h_current - h_expected) / h_std
        surprise_weight = torch.sigmoid(torch.norm(surprise) - 1.0)

        h_error = surprise.unsqueeze(0)
        delta_A = -x_flat.T @ h_error.expand(x_flat.shape[0], -1) / (x_flat.shape[0] + 1e-8) * self.config.learning_rate * surprise_weight

        # The observed y already includes the adapter scaling, so the
        # expectation must include it too; otherwise the "error" is non-zero
        # even when the current activity equals the buffered activity.
        y_expected = (buffer @ self.B.data) * self.scaling
        y_error = (y_flat.mean(dim=0) - y_expected.mean(dim=0)).unsqueeze(0)
        delta_B = -h_flat.T @ y_error.expand(h_flat.shape[0], -1) / (h_flat.shape[0] + 1e-8) * self.config.learning_rate * surprise_weight

        return delta_A, delta_B

    def _homeostasis_update(self, x, h, y):
        """Return an update that steers activation statistics back to their targets.

        Running statistics are updated on every call. Until targets exist the
        method returns zeros; targets are captured on the first call at or
        after step ``homeostasis_burnin - 1``.
        """
        h_flat = h.reshape(-1, h.shape[-1])
        x_flat = x.reshape(-1, x.shape[-1])

        if h_flat.shape[0] == 0:
            return self._zero_update()

        current_mean = h_flat.mean(dim=0)
        if h_flat.shape[0] > 1:
            current_var = h_flat.var(dim=0)
        elif self.state.activation_var is not None:
            # One row carries no variance information (the sample variance
            # would be NaN); fall back to the running estimate.
            current_var = self.state.activation_var
        else:
            current_var = torch.zeros_like(current_mean)

        decay = 0.99
        if self.state.activation_mean is None:
            self.state.activation_mean = current_mean.detach()
            self.state.activation_var = current_var.detach()
        else:
            self.state.activation_mean = decay * self.state.activation_mean + (1 - decay) * current_mean.detach()
            self.state.activation_var = decay * self.state.activation_var + (1 - decay) * current_var.detach()

        if self.state.target_mean is None or self.state.target_var is None:
            # Capture on ">=" rather than on one exact step: the step counter
            # also advances in FROZEN mode and on skipped updates, so the
            # exact step may never be observed here.
            if self.state.step >= self.config.homeostasis_burnin - 1:
                self.state.target_mean = self.state.activation_mean.clone()
                self.state.target_var = self.state.activation_var.clone()
                self.state.is_burned_in = True
            return self._zero_update()

        mean_error = current_mean - self.state.target_mean
        var_ratio = torch.sqrt(self.state.target_var / (current_var + 1e-8))
        total_correction = -mean_error.unsqueeze(0) + 0.1 * (var_ratio - 1.0).unsqueeze(0) * h_flat.mean(dim=0, keepdim=True)

        delta_A = x_flat.T @ total_correction.expand(x_flat.shape[0], -1) / (x_flat.shape[0] + 1e-8) * self.config.learning_rate
        # (rank, N) @ (N, rank) @ (rank, out) -> (rank, out), the shape of B.
        delta_B = h_flat.T @ (-mean_error.unsqueeze(0)).expand(h_flat.shape[0], -1) @ self.B.data * self.config.learning_rate * 0.1

        return delta_A, delta_B

    def save_state(self):
        """Return a snapshot holding copies of ``A`` and ``B`` plus the step counter."""
        return {'A': self.A.data.clone(), 'B': self.B.data.clone(), 'step': self.state.step}

    def load_state(self, saved):
        """Restore ``A``, ``B`` and the step counter from a ``save_state`` snapshot.

        Values are copied into the existing parameters, so later in-place
        updates cannot modify the snapshot and it can be restored repeatedly.
        """
        with torch.no_grad():
            self.A.copy_(saved['A'])
            self.B.copy_(saved['B'])
        self.state.step = saved['step']


class FluidTransformer(nn.Module):
    """Wrap a model and attach a ``FluidLoRA`` adapter to selected linear layers.

    Each ``nn.Linear`` whose final name component is in ``target_modules``
    gets its ``forward`` replaced by one that adds the adapter output to the
    original output.

    NOTE(review): adapters are kept in a plain dict, so their parameters are
    not registered as parameters of this wrapper (``parameters()``,
    ``state_dict()`` and ``.to()`` will not see them).
    """

    def __init__(self, model: nn.Module, config: FluidConfig = None, target_modules: List[str] = None):
        super().__init__()
        self.model = model
        self.config = config or FluidConfig()
        self.target_modules = target_modules or ['q_proj', 'k_proj', 'v_proj', 'o_proj']
        self.fluid_loras: Dict[str, FluidLoRA] = {}
        self._attention_weights: Dict[int, torch.Tensor] = {}
        self._patch_model()
        self._register_attention_hooks()

    def _patch_model(self):
        """Create one adapter per matching linear layer and patch that layer's forward."""
        for name, module in self.model.named_modules():
            module_name = name.split('.')[-1]
            if module_name not in self.target_modules:
                continue
            if not isinstance(module, nn.Linear):
                continue

            layer_type = self._infer_layer_type(module_name)
            layer_idx = self._infer_layer_idx(name)

            fluid_lora = FluidLoRA(
                in_features=module.in_features,
                out_features=module.out_features,
                config=self.config,
                layer_type=layer_type,
                layer_idx=layer_idx,
            ).to(module.weight.device)

            lora_name = name.replace('.', '_')
            self.fluid_loras[lora_name] = fluid_lora
            self._patch_module_forward(name, module, fluid_lora, layer_idx)

        print(f"Patched {len(self.fluid_loras)} modules with FluidLoRA")

    def _infer_layer_type(self, module_name):
        """Map a module name to 'query', 'key', 'value', 'output' or 'linear'."""
        if 'q_proj' in module_name or 'query' in module_name:
            return 'query'
        elif 'k_proj' in module_name or 'key' in module_name:
            return 'key'
        elif 'v_proj' in module_name or 'value' in module_name:
            return 'value'
        elif 'o_proj' in module_name or 'out' in module_name:
            return 'output'
        return 'linear'

    def _infer_layer_idx(self, name):
        """Return the first dot-delimited integer in ``name``, or 0 if there is none."""
        match = re.search(r'\.(\d+)\.', name)
        return int(match.group(1)) if match else 0

    def _patch_module_forward(self, name, module, fluid_lora, layer_idx):
        """Replace ``module.forward`` with a version that adds the adapter output."""
        original_forward = module.forward

        def new_forward(x, *args, **kwargs):
            # Extra arguments are forwarded so callers that pass more than
            # the input tensor keep working.
            out = original_forward(x, *args, **kwargs)
            attn = self._attention_weights.get(layer_idx)

            # The adapter keeps its own dtype; cast on the way in and out so
            # a half-precision host model can use a float32 adapter.
            lora_dtype = fluid_lora.A.dtype
            lora_in = x if x.dtype == lora_dtype else x.to(lora_dtype)
            lora_out = fluid_lora(lora_in, attention_weights=attn)
            if lora_out.dtype != out.dtype:
                lora_out = lora_out.to(out.dtype)
            return out + lora_out

        module.forward = new_forward

    def _register_attention_hooks(self):
        """Capture attention maps from modules whose name contains 'attention'.

        NOTE(review): the match is on the substring 'attention'. Architectures
        that name their attention modules differently (e.g. 'self_attn') get
        no hooks, and the attention-guided term then never runs -- confirm
        against the target model. Also, a layer's map is stored only after
        its attention module has finished, so projections inside that module
        see the map from an earlier call, if any.
        """
        def make_hook(layer_idx):
            def hook(module, input, output):
                if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
                    self._attention_weights[layer_idx] = output[1].detach()
            return hook
        for name, module in self.model.named_modules():
            if 'attention' in name.lower() and hasattr(module, 'forward'):
                layer_idx = self._infer_layer_idx(name)
                module.register_forward_hook(make_hook(layer_idx))

    def forward(self, *args, **kwargs):
        """Run the wrapped model with ``output_attentions=True`` after clearing cached maps."""
        self._attention_weights.clear()
        kwargs['output_attentions'] = True
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Delegate to the wrapped model's ``generate``."""
        return self.model.generate(*args, **kwargs)

    def set_plasticity_mode(self, mode: PlasticityMode):
        """Set the plasticity mode on every adapter's config."""
        for lora in self.fluid_loras.values():
            lora.config.mode = mode

    def freeze(self):
        """Stop weight updates on all adapters."""
        self.set_plasticity_mode(PlasticityMode.FROZEN)

    def unfreeze(self):
        """Resume weight updates on all adapters."""
        self.set_plasticity_mode(PlasticityMode.FLUID)

    def save_fluid_state(self):
        """Return ``{adapter name: snapshot}`` for every adapter."""
        return {name: lora.save_state() for name, lora in self.fluid_loras.items()}

    def load_fluid_state(self, saved):
        """Restore every adapter that has an entry in ``saved``; unknown names are ignored."""
        for name, state in saved.items():
            if name in self.fluid_loras:
                self.fluid_loras[name].load_state(state)

    def get_metrics(self):
        """Return all logged metric entries, each tagged with its adapter name.

        Entries are copies, so tagging them does not alter the adapters'
        own ``metrics_history``.
        """
        metrics = []
        for name, lora in self.fluid_loras.items():
            for entry in lora.state.metrics_history:
                metrics.append({**entry, 'module': name})
        return metrics

In [None]:
# Reload the package so that edits made in the %%writefile cells take effect.
import importlib
import sys
import fluid_weights

# importlib.reload() re-executes only the module it is given. Submodules that
# are already in sys.modules would be reused unchanged by the package's
# `from .core import ...` lines, so reload them first -- dependencies before
# the module that imports them -- and the package itself last.
for _submodule in ('update_rules', 'stability', 'core'):
    _loaded = sys.modules.get(f'fluid_weights.{_submodule}')
    if _loaded is not None:
        importlib.reload(_loaded)
importlib.reload(fluid_weights)

from fluid_weights import FluidConfig, FluidTransformer, PlasticityMode

## 2. Load Model

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Choose model based on available GPU memory
# For A100 40GB, we can use Mistral-7B or Llama-2-7B
MODEL_NAME = "mistralai/Mistral-7B-v0.1"  # or "meta-llama/Llama-2-7b-hf"

print(f"Loading {MODEL_NAME}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token

# Weights are loaded in float16 and placed automatically across available devices.
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# to run locally -- confirm it is actually required for the chosen model.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Reports the device of the first parameter only; with device_map="auto" other
# parameters may be placed elsewhere.
print(f"Model loaded on {next(model.parameters()).device}")

## 3. Configure Fluid Weights

In [None]:
# Configure the fluid learning system.
# Every value below is passed explicitly so the experiment's settings are
# visible in one place.
config = FluidConfig(
    # LoRA settings (adapter output is scaled by alpha / rank)
    rank=16,
    alpha=32.0,

    # Plasticity settings
    learning_rate=1e-5,  # Very small for stability
    update_every_n_tokens=1,

    # Stability (each flag enables one mechanism in the constraint chain)
    use_ewc=True,
    ewc_strength=0.05,
    use_spectral_norm=True,
    max_spectral_norm=2.0,
    use_gradient_clipping=True,
    max_gradient_norm=0.1,
    use_adaptive_rate=True,

    # Novel mechanisms (AGP, TSM and CH from the introduction)
    use_attention_guided_plasticity=True,
    attention_plasticity_strength=0.3,
    use_temporal_surprise=True,
    surprise_window=32,
    use_contextual_homeostasis=True,
    homeostasis_strength=0.1,
    homeostasis_burnin=100,

    # Monitoring
    track_metrics=True,
    log_every_n_steps=50,
)

# Echo the key settings as a quick sanity check.
print("Config created:")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Rank: {config.rank}")
print(f"  Novel mechanisms: AGP={config.use_attention_guided_plasticity}, TSM={config.use_temporal_surprise}, CH={config.use_contextual_homeostasis}")

In [None]:
# Wrap the model with FluidTransformer.
# Construction patches the matching linear layers in place and prints how many
# were patched.
fluid_model = FluidTransformer(
    model,
    config=config,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
)

print(f"\nFluidLoRA modules created: {len(fluid_model.fluid_loras)}")
print("\nFirst few modules:")
# Show the first four adapters; the enumerate index `i` is not used in the loop body.
for i, (name, lora) in enumerate(list(fluid_model.fluid_loras.items())[:4]):
    print(f"  {name}: {lora.in_features} -> {lora.out_features} (rank={lora.config.rank})")

## 4. Baseline Test (Frozen Weights)

In [None]:
def generate_response(model, tokenizer, prompt, max_new_tokens=100):
    """Sample a continuation of ``prompt`` and return the decoded text.

    Args:
        model: Wrapper exposing ``generate`` and an inner ``.model`` whose
            ``device`` the tokenized prompt is moved to.
        tokenizer: Callable tokenizer that also provides ``decode`` and
            ``eos_token_id``.
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
        NOTE(review): callers strip ``len(prompt)`` leading characters, which
        assumes the decoded text starts with the prompt - confirm.
    """
    target_device = model.model.device
    encoded = tokenizer(prompt, return_tensors="pt").to(target_device)

    # Fixed sampling recipe; EOS doubles as the padding id.
    sampling_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "pad_token_id": tokenizer.eos_token_id,
    }

    # Autograd tracking is not needed for generation.
    with torch.no_grad():
        sequences = model.generate(**encoded, **sampling_kwargs)

    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

In [None]:
# Baseline run: freeze first so these generations serve as the reference
# point before any fluid adaptation is enabled.
fluid_model.freeze()

divider = "=" * 60
print(divider)
print("BASELINE TEST (Frozen Weights)")
print(divider)

# General-purpose prompts; reused later to check base capabilities.
test_prompts = [
    "The capital of France is",
    "Write a haiku about programming:",
    "Explain quantum computing in simple terms:",
]

baseline_responses = []
for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")
    full_text = generate_response(fluid_model, tokenizer, prompt, max_new_tokens=50)
    baseline_responses.append(full_text)
    # Show only what follows the first len(prompt) characters.
    continuation = full_text[len(prompt):]
    print(f"Response: {continuation}")

## 5. Fluid Learning Experiment

In [None]:
# Enable fluid plasticity
# NOTE(review): presumably the inverse of the fluid_model.freeze() call used
# for the baseline run - confirm against FluidTransformer.
fluid_model.unfreeze()

print("Plasticity ENABLED - weights will now adapt during inference")

In [None]:
# Training data: Expose the model to a specific topic/style
# We'll use a series of prompts about a fictional topic to see if the model adapts
#
# Each passage is later tokenized and pushed through a forward pass of
# fluid_model (no labels or loss are supplied). The post-adaptation prompts
# ask about facts stated here: the capital, the ruler, and the food.

adaptation_texts = [
    "The kingdom of Zephyria is located in the northern mountains.",
    "In Zephyria, the people celebrate the Festival of Winds every spring.",
    "The capital of Zephyria is called Aeropolis.",
    "Queen Ventara has ruled Zephyria for 40 years.",
    "Zephyrian currency is called the Gust.",
    "The national animal of Zephyria is the Silver Eagle.",
    "Zephyrians are known for their skill in wind magic.",
    "The Great Library of Aeropolis contains ancient scrolls of air wisdom.",
    "Traditional Zephyrian food includes cloud bread and sky berries.",
    "The Zephyrian language has over 100 words for different types of wind.",
]

print(f"Adaptation data: {len(adaptation_texts)} passages about Zephyria")

In [None]:
from tqdm import tqdm

# Process adaptation texts multiple times to allow learning
num_epochs = 5

print(f"\nRunning {num_epochs} epochs of fluid adaptation...")
print("(Weights are updating in real-time during forward passes)\n")

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    for text in tqdm(adaptation_texts):
        # Tokenize one passage and move it to the base model's device
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # The forward pass is run only for the fluid updates it triggers.
        # Its return value is never read, so it is discarded rather than
        # bound to a name that would keep the last result alive.
        with torch.no_grad():
            fluid_model(**inputs)

print("\nAdaptation complete!")

In [None]:
# Pull the metric records collected by the fluid model and summarize them
metrics = fluid_model.get_metrics()
print(f"\nCollected {len(metrics)} metric records")

if metrics:
    import pandas as pd

    # Statistics to compute for each column, per module
    summary_spec = {
        'A_norm': ['mean', 'std'],
        'B_norm': ['mean', 'std'],
        'delta_norm': ['mean', 'max'],
    }

    df = pd.DataFrame(metrics)
    per_module_summary = df.groupby('module').agg(summary_spec)

    print("\nMetrics summary:")
    # Limit the printout to the first ten modules
    print(per_module_summary.head(10))

## 6. Test Adaptation

In [None]:
# Post-adaptation evaluation.
# Freeze again, as the baseline run did, so both tests are generated under
# the same frozen setting.
fluid_model.freeze()

divider = "=" * 60
print(divider)
print("POST-ADAPTATION TEST")
print(divider)

# Prompts that query facts stated in the adaptation passages
test_prompts_zephyria = [
    "The capital of Zephyria is",
    "Tell me about the Kingdom of Zephyria:",
    "What do Zephyrians eat?",
    "Who rules Zephyria?",
]

for prompt in test_prompts_zephyria:
    print(f"\nPrompt: {prompt}")
    full_text = generate_response(fluid_model, tokenizer, prompt, max_new_tokens=50)
    # Show only what follows the first len(prompt) characters.
    continuation = full_text[len(prompt):]
    print(f"Response: {continuation}")

In [None]:
# Re-run the general-purpose baseline prompts after adaptation so their
# outputs can be compared by eye with the baseline responses.
divider = "=" * 60
print(f"\n{divider}")
print("BASE CAPABILITY PRESERVATION TEST")
print(divider)

for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")
    full_text = generate_response(fluid_model, tokenizer, prompt, max_new_tokens=50)
    # Show only what follows the first len(prompt) characters.
    continuation = full_text[len(prompt):]
    print(f"Response: {continuation}")

## 7. Visualize Adaptation

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Draw a 2x2 dashboard from the collected metric records, if there are any.
if not metrics:
    print("No metrics available for visualization")
else:
    df = pd.DataFrame(metrics)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Top row: norm traces of the first five modules, one panel per matrix.
    first_modules = df['module'].unique()[:5]
    norm_panels = [
        (axes[0, 0], 'A_norm', 'A Norm', 'LoRA A Matrix Norms Over Time'),
        (axes[0, 1], 'B_norm', 'B Norm', 'LoRA B Matrix Norms Over Time'),
    ]
    for ax, column, ylabel, title in norm_panels:
        for module in first_modules:
            module_df = df[df['module'] == module]
            ax.plot(module_df['step'], module_df[column], label=f'{module[:30]}...')
        ax.set_xlabel('Step')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend(fontsize=8)

    # Bottom-left: distribution of update magnitudes (log-scaled counts)
    hist_ax = axes[1, 0]
    hist_ax.hist(df['delta_norm'], bins=50, alpha=0.7)
    hist_ax.set_xlabel('Update Magnitude')
    hist_ax.set_ylabel('Frequency')
    hist_ax.set_title('Distribution of Update Magnitudes')
    hist_ax.set_yscale('log')

    # Bottom-right: update magnitude averaged over all modules at each step
    trend_ax = axes[1, 1]
    mean_delta_per_step = df.groupby('step')['delta_norm'].mean()
    trend_ax.plot(mean_delta_per_step)
    trend_ax.set_xlabel('Step')
    trend_ax.set_ylabel('Mean Update Magnitude')
    trend_ax.set_title('Average Update Magnitude Over Time')

    plt.tight_layout()
    plt.savefig('fluid_weights_metrics.png', dpi=150)
    plt.show()

    print("\nMetrics visualization saved to 'fluid_weights_metrics.png'")

## 8. Save and Load Adapted State

In [None]:
# Save the adapted state
# save_fluid_state() returns a sized collection whose length is reported below
# as the number of modules (NOTE(review): presumably one entry per FluidLoRA
# module - confirm against FluidTransformer).
adapted_state = fluid_model.save_fluid_state()

print(f"Saved state for {len(adapted_state)} modules")

# Save to file
# Serialized with torch.save to a relative path, i.e. the current working
# directory of the notebook kernel.
torch.save(adapted_state, 'fluid_weights_adapted.pt')
print("Saved to 'fluid_weights_adapted.pt'")

In [None]:
# To reload later:
# loaded_state = torch.load('fluid_weights_adapted.pt')
# fluid_model.load_fluid_state(loaded_state)
# print("State restored!")

## 9. Ablation Studies

In [None]:
# Ablation grid: every variant shares the same learning rate and differs only
# in which of the three novel mechanisms is switched on.
# Tuple order: (attention-guided plasticity, temporal surprise, contextual homeostasis)
_ablation_flags = {
    'full': (True, True, True),
    'no_agp': (False, True, True),
    'no_tsm': (True, False, True),
    'no_ch': (True, True, False),
    'base_only': (False, False, False),
}

ablation_configs = {
    label: FluidConfig(
        learning_rate=1e-5,
        use_attention_guided_plasticity=agp_on,
        use_temporal_surprise=tsm_on,
        use_contextual_homeostasis=ch_on,
    )
    for label, (agp_on, tsm_on, ch_on) in _ablation_flags.items()
}

print("Ablation configurations defined:")
for name in ablation_configs:
    print(f"  - {name}")

# The configs are only defined here; nothing in this cell runs them.
print("\nTo run ablation studies, reload model and use different configs")

## 10. Stability Analysis

In [None]:
# Report stability diagnostics for the first five FluidLoRA modules
divider = "=" * 60
print(divider)
print("STABILITY ANALYSIS")
print(divider)

preview_modules = list(fluid_model.fluid_loras.items())[:5]
for name, lora in preview_modules:
    # check_stability receives the raw A and B tensors of this module
    report = lora.stability.check_stability(lora.A.data, lora.B.data)
    print(f"\n{name}:")
    print(f"  Weight norm: {report.weight_norm:.4f}")
    print(f"  Spectral norm: {report.spectral_norm:.4f}")
    print(f"  Drift from origin: {report.drift_from_origin:.4f}")

In [None]:
# Summarize homeostasis: how many modules finished burn-in, and how far the
# running activation statistics sit from their targets.
divider = "=" * 60
print("\n" + divider)
print("HOMEOSTASIS STATUS")
print(divider)

all_loras = fluid_model.fluid_loras
burned_in = len([lora for lora in all_loras.values() if lora.state.is_burned_in])
print(f"\nModules with completed burn-in: {burned_in}/{len(all_loras)}")

for name, lora in list(all_loras.items())[:3]:
    state = lora.state
    # Modules without a recorded target have nothing to compare against.
    if state.target_mean is None:
        continue

    mean_drift = torch.norm(state.activation_mean - state.target_mean).item()
    var_drift = torch.norm(state.activation_var - state.target_var).item()
    print(f"\n{name}:")
    print(f"  Mean drift: {mean_drift:.6f}")
    print(f"  Variance drift: {var_drift:.6f}")

## Summary

This notebook demonstrates the **Fluid Weights** system for perpetual plasticity in transformers.

### Key Findings:

1. **Weights can adapt during inference** without explicit loss functions
2. **Stability mechanisms prevent divergence** (EWC, spectral norm constraint, gradient clipping, adaptive rate control)
3. **Novel mechanisms enhance learning**:
   - Attention-Guided Plasticity uses attention patterns as learning signals
   - Temporal Surprise creates implicit prediction objectives
   - Contextual Homeostasis maintains stable activation statistics

### Limitations:

- Learning is subtle and requires multiple exposures
- Hyperparameters require tuning for different tasks
- Long-term stability over millions of tokens needs more testing

### Next Steps:

1. Test on longer conversations
2. Tune hyperparameters for faster adaptation
3. Add more sophisticated evaluation metrics
4. Test on real-world adaptation scenarios (user style, domain adaptation)

In [None]:
# Closing messages, each preceded by a blank line
for closing_line in (
    "\nFluid Weights Demo Complete!",
    "\nThis is exploratory research. Results may vary.",
):
    print(closing_line)