In [10]:
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import InputLayer, Conv2D, MaxPooling2D, Flatten, Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils import shuffle
from tqdm.notebook import tqdm

class DSAL:
    """
    Dual-Stream Analytic Learning (DS-AL)
      • Main stream: recursive least squares (RLS) linear mapping.
      • Compensation stream: RLS on the null-space residuals.
    """
    def __init__(self, n_features: int, n_classes: int, lambda_: float = 1e-3):
        self.n_features = n_features
        self.n_classes  = n_classes
        self.lambda_    = lambda_

        # main-stream RLS state
        self.W_main = np.zeros((n_classes, n_features))
        self.P_main = (1.0 / lambda_) * np.eye(n_features)

        # compensation-stream RLS state
        self.W_c = np.zeros((n_classes, n_features))
        self.P_c = (1.0 / lambda_) * np.eye(n_features)

    def update(self, X: np.ndarray, Y: np.ndarray):
        """
        Per-sample RLS update on both streams.
        X: (n_samples, n_features)
        Y: (n_samples, n_classes) one-hot
        """
        for x_i, y_i in tqdm(zip(X, Y), total=len(X), desc="DS-AL update"):
            x = x_i[:, None]  # (d,1)
            y = y_i[:, None]  # (C,1)

            # ——— 1) main-stream RLS ———
            P_x    = self.P_main @ x
            denom  = 1.0 + (x.T @ P_x).item()
            k_main = P_x / denom
            e_main = y - (self.W_main @ x)

            self.W_main += e_main @ k_main.T
            self.P_main  -= k_main @ (x.T @ self.P_main)

            # compute null-space projector of W_main
            W_pinv = np.linalg.pinv(self.W_main)
            P_ker  = np.eye(self.n_features) - W_pinv @ self.W_main
            x_proj = P_ker @ x

            # ——— 2) compensation-stream RLS ———
            res    = y - (self.W_main @ x)
            Pcx    = self.P_c @ x_proj
            denom_c = 1.0 + (x_proj.T @ Pcx).item()
            k_c    = Pcx / denom_c

            self.W_c += res @ k_c.T
            self.P_c  -= k_c @ (x_proj.T @ self.P_c)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        f(x) = W_main x + W_c (P_ker x)
        Returns logits of shape (n_samples, n_classes).
        """
        W_pinv = np.linalg.pinv(self.W_main)
        P_ker  = np.eye(self.n_features) - W_pinv @ self.W_main

        Y_main = X @ self.W_main.T
        Y_comp = (X @ P_ker) @ self.W_c.T
        return Y_main + Y_comp

if __name__ == "__main__":
    # 1) Load & normalize MNIST
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train[..., None] / 255.0
    x_test  = x_test[..., None]  / 255.0

    # 2) One-hot over all 10 digits
    encoder     = OneHotEncoder(sparse_output=False)
    Y_train_all = encoder.fit_transform(y_train[:, None])
    Y_test_all  = encoder.transform(y_test[:, None])

    # 3) Define phases: base = [0,1], then [2,3], [4,5], [6,7], [8,9]
    phases = [[0,1], [2,3], [4,5], [6,7], [8,9]]

    # 4) BP-train CNN on the base phase
    base_classes = phases[0]
    mask0 = np.isin(y_train, base_classes)
    X0, y0 = x_train[mask0], y_train[mask0]
    Y0_small = np.stack([(y0 == c).astype(float) for c in base_classes], axis=1)

    model_cnn = Sequential([
        InputLayer(input_shape=(28,28,1)),
        Conv2D(32, 3, activation='relu'),
        MaxPooling2D(),
        Conv2D(64, 3, activation='relu'),
        MaxPooling2D(),
        Flatten(name="flatten"),
        Dense(128, activation="relu", name="fc1"),
        Dense(len(base_classes), activation="softmax", name="clf")
    ])
    model_cnn.compile(
        optimizer=Adam(),
        loss=CategoricalCrossentropy(),
        metrics=["accuracy"]
    )
    model_cnn.fit(
        X0, Y0_small,
        epochs=5,
        batch_size=64,
        verbose=1
    )

    # 5) Freeze backbone up to the Flatten layer by rebuilding functionally
    inp = Input(shape=(28,28,1))
    x   = inp
    for layer in model_cnn.layers:
        x = layer(x)
        if layer.name == "flatten":
            break
    backbone = Model(inputs=inp, outputs=x)
    d_cnn = backbone.output_shape[1]  # e.g. 7*7*64 = 3136

    # 6) Build random-projection buffer B: ℝ^{d_cnn→d_B}
    d_B = 512
    B   = np.random.randn(d_cnn, d_B)

    # 7) Extract & buffer base features
    X0_feat = backbone.predict(X0, batch_size=256)
    X0_buf  = np.maximum(0, X0_feat @ B)
    Y0_full = Y_train_all[mask0]

    # 8) Closed-form ridge solution for W_main^(0)
    λ  = 1e-3
    R0 = np.linalg.inv(X0_buf.T @ X0_buf + λ * np.eye(d_B))
    W0 = (R0 @ X0_buf.T @ Y0_full).T

    # 9) Instantiate DS-AL and inject initial state
    model = DSAL(n_features=d_B, n_classes=10, lambda_=λ)
    model.W_main = W0.copy()
    model.P_main = R0.copy()
    # W_c and P_c remain at default

    # 10) Prepare buffered feature tasks for phases 1–4
    tasks = []
    for cls_grp in phases[1:]:
        mask = np.isin(y_train, cls_grp)
        Xf   = backbone.predict(x_train[mask], batch_size=256)
        Xb   = np.maximum(0, Xf @ B)
        Yb   = Y_train_all[mask]
        tasks.append((cls_grp, Xb, Yb))

    # 11) Incremental RLS on phases 1–4
    for cls_grp, Xb, Yb in tasks:
        Xb, Yb = shuffle(Xb, Yb, random_state=42)
        print(f"Phase classes {cls_grp}: {Xb.shape[0]} samples")
        model.update(Xb, Yb)

    # 12) Final evaluation on full test set
    X_test_feat = backbone.predict(x_test, batch_size=256)
    X_test_buf  = np.maximum(0, X_test_feat @ B)
    logits      = model.predict(X_test_buf)
    y_pred      = np.argmax(logits, axis=1)
    acc         = (y_pred == y_test).mean()

    print(f"\nFinal MNIST accuracy: {acc*100:.2f}%")


Epoch 1/5




[1m198/198[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 7ms/step - accuracy: 0.9799 - loss: 0.0762
Epoch 2/5
[1m198/198[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - accuracy: 0.9994 - loss: 0.0020
Epoch 3/5
[1m198/198[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - accuracy: 0.9998 - loss: 9.2832e-04
Epoch 4/5
[1m198/198[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - accuracy: 1.0000 - loss: 8.4517e-04
Epoch 5/5
[1m198/198[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - accuracy: 1.0000 - loss: 1.9795e-04
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step
[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step
Phase classes [2, 3]: 12089 samples


DS-AL update:   0%|          | 0/12089 [00:00<?, ?it/s]

Phase classes [4, 5]: 11263 samples


DS-AL update:   0%|          | 0/11263 [00:00<?, ?it/s]

Phase classes [6, 7]: 12183 samples


DS-AL update:   0%|          | 0/12183 [00:00<?, ?it/s]

Phase classes [8, 9]: 11800 samples


DS-AL update:   0%|          | 0/11800 [00:00<?, ?it/s]

[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step

Final MNIST accuracy: 9.87%
