In [1]:
import numpy as np
import tensorflow as tf
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

# === 1. Synthetic Data ===
X, y = make_classification(n_samples=5000, n_features=20, n_informative=15, 
                           n_redundant=5, n_classes=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
encoder = OneHotEncoder(sparse_output=False)
y_train_oh = encoder.fit_transform(y_train.reshape(-1, 1))
y_test_oh = encoder.transform(y_test.reshape(-1, 1))

# === 2. Weak Base Models ===
def make_weak_model(seed):
    tf.random.set_seed(seed)
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(20,)),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

n_models = 5
base_models = [make_weak_model(seed) for seed in range(n_models)]

# Train each model for only 1 epoch to make them intentionally weak
model_preds_train = []
model_preds_test = []

for model in base_models:
    model.fit(X_train, y_train_oh, epochs=1, verbose=0)
    model_preds_train.append(model.predict(X_train))
    model_preds_test.append(model.predict(X_test))

# Stack predictions: (n_models, num_samples, n_classes)
model_preds_train = np.stack(model_preds_train)
model_preds_test = np.stack(model_preds_test)

# === 3. Reverse Oracle Detection ===
def detect_reverse_oracles(model_preds, y_true, threshold=0.48):
    reverse_flags = []
    for preds in model_preds:
        acc = np.mean(np.argmax(preds, axis=1) == y_true)
        reverse_flags.append(acc < threshold)
    return np.array(reverse_flags)

reverse_flags = detect_reverse_oracles(model_preds_test, y_test)
print("Reverse Oracle Flags:", reverse_flags)

# === 4. Reverse Oracle-Aware Ensemble ===
class ReverseOracleSoftmax(tf.keras.layers.Layer):
    def __init__(self, reverse_flags):
        super().__init__()
        self.reverse_flags = tf.constant(reverse_flags, dtype=tf.bool)

    def call(self, model_outputs):  # (batch, n_models, n_classes)
        flipped = 1.0 - model_outputs
        flipped = flipped / tf.reduce_sum(flipped, axis=-1, keepdims=True)
        return tf.where(self.reverse_flags[tf.newaxis, :, tf.newaxis], flipped, model_outputs)

class WeightedReverseEnsemble(tf.keras.Model):
    def __init__(self, reverse_flags, n_models, n_classes):
        super().__init__()
        self.reverse_layer = ReverseOracleSoftmax(reverse_flags)
        self.weights = tf.Variable(tf.ones([n_models]), trainable=True)
        self.n_classes = n_classes

    def call(self, model_outputs):  # (batch, n_models, n_classes)
        corrected = self.reverse_layer(model_outputs)
        weighted = corrected * tf.nn.softmax(self.weights)[tf.newaxis, :, tf.newaxis]
        return tf.reduce_sum(weighted, axis=1)  # (batch, n_classes)

# === 5. Train Ensemble Weights ===
def train_ensemble(model_preds, y_true, reverse_flags, n_classes=10, epochs=100, lr=0.01):
    ensemble = WeightedReverseEnsemble(reverse_flags, model_preds.shape[1], n_classes)
    optimizer = tf.keras.optimizers.Adam(lr)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

    dataset = tf.data.Dataset.from_tensor_slices((model_preds, y_true)).batch(32)
    for epoch in range(epochs):
        for x_batch, y_batch in dataset:
            with tf.GradientTape() as tape:
                y_pred = ensemble(x_batch)
                loss = loss_fn(y_batch, y_pred)
            grads = tape.gradient(loss, ensemble.trainable_variables)
            optimizer.apply_gradients(zip(grads, ensemble.trainable_variables))
    return ensemble

# Reshape model_preds_test to (num_samples, n_models, n_classes)
model_preds_test_reshaped = np.transpose(model_preds_test, (1, 0, 2))
ensemble = train_ensemble(model_preds_test_reshaped, y_test, reverse_flags)

# Evaluate Ensemble
y_pred_ensemble = ensemble(model_preds_test_reshaped).numpy()
acc = np.mean(np.argmax(y_pred_ensemble, axis=1) == y_test)
print("Reverse-Aware Ensemble Accuracy:", acc)


2025-06-27 00:07:25.484004: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750975645.505239  806947 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750975645.512037  806947 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1750975645.528835  806947 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750975645.528858  806947 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750975645.528860  806947 computation_placer.cc:177] computation placer alr

[1m 62/113[0m [32m━━━━━━━━━━[0m[37m━━━━━━━━━━[0m [1m0s[0m 2ms/step - accuracy: 0.1273 - loss: 4.4875

I0000 00:00:1750975651.354071  807058 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m113/113[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 19ms/step - accuracy: 0.1306 - loss: 4.1803 - val_accuracy: 0.1925 - val_loss: 2.6842
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step
[1m113/113[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 17ms/step - accuracy: 0.0878 - loss: 3.8844 - val_accuracy: 0.1425 - val_loss: 2.7512
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step
[1m113/113[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 16ms/step - accuracy: 0.1093 - loss: 3.6708 - val_accuracy: 0.1400 - val_loss: 2.7925
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step
[1m113/113[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 16ms/step - accuracy: 0.0998 - lo

ValueError: Arguments `target` and `output` must have the same rank (ndim). Received: target.shape=(32,), output.shape=(32, 10)