# In [None]:
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.kernel_approximation import RBFSampler
from sklearn.linear_model import Ridge, LogisticRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Load MNIST ('mnist_784': flattened 28x28 images) and keep a stratified
# 10,000-image training subset.
X, y = fetch_openml('mnist_784', version=1, return_X_y=True)
# np.asarray accepts both return types of fetch_openml: a DataFrame (when
# as_frame resolves to True) or a plain ndarray (when it resolves to False).
# DataFrame.to_numpy() would fail in the latter case.
X = np.asarray(X, dtype=np.float32) / 255.0  # scale pixel intensities to [0, 1]
# OpenML delivers the targets as string/categorical labels; convert them to a
# plain integer ndarray so downstream masking and metrics behave uniformly.
y = np.asarray(y).astype(int)

X_train, _, y_train, _ = train_test_split(X, y, train_size=10000, stratify=y, random_state=42)

# Encoder: random Fourier features approximating an RBF kernel of bandwidth
# `gamma`, mapping each image to a `latent_dim`-dimensional feature vector.
gamma = 0.01
latent_dim = 200

rff = RBFSampler(n_components=latent_dim, gamma=gamma, random_state=42)
rff.fit(X_train)
X_train_latent = rff.transform(X_train)

# Decoder: a linear ridge map from latent features back to pixel space, used
# later to turn generated latents into images.
inverse_regressor = Ridge(alpha=1.0).fit(X_train_latent, X_train)

# Train a noise-prediction "denoiser": corrupt the clean latents with Gaussian
# noise (std 0.1) and regress that noise back from the corrupted latents.
# A seeded Generator (instead of the unseeded global np.random state) makes the
# run reproducible, consistent with the random_state=42 used elsewhere.
noise_rng = np.random.default_rng(42)
noise = noise_rng.normal(scale=0.1, size=X_train_latent.shape)
Z_noisy = X_train_latent + noise

# Degree-2 polynomial features + ridge regression.
# NOTE: with latent_dim=200 the expansion has 200 + 200*201/2 = 20,300 columns,
# so for 10,000 float64 rows the design matrix alone is ~1.6 GB. Lower
# latent_dim or the training size if memory is tight.
denoiser = make_pipeline(
    PolynomialFeatures(degree=2, include_bias=False),
    Ridge(alpha=1.0),
)
denoiser.fit(Z_noisy, noise)

# Reverse-process settings: T denoising steps of size eta for every chain.
T = 100
eta = 0.1
samples_per_digit = 100
num_digits = 10

# Seeded so the generated set is reproducible.
sample_rng = np.random.default_rng(1)

# Start chains at the same noise scale the denoiser was trained on, so its
# first predictions are made in-distribution (unit-variance noise is roughly
# 10x larger than anything it has seen).
init_scale = noise.std()

# Keep iterates inside the per-feature range of the training latents; a
# polynomial regressor can diverge when evaluated far outside its training data.
z_min = X_train_latent.min(axis=0)
z_max = X_train_latent.max(axis=0)

# Plain int labels so the boolean mask works whether y_train is numpy or pandas.
train_labels = np.asarray(y_train).astype(int)

Z_gen = []
Y_gen = []

for digit in range(num_digits):
    # Condition on the digit: seed every chain from a noised training latent of
    # that digit, which is exactly the corruption the denoiser learned to undo.
    # Label-independent starting noise would make the assigned label arbitrary.
    digit_latents = X_train_latent[train_labels == digit]
    anchor_idx = sample_rng.integers(len(digit_latents), size=samples_per_digit)
    anchors = digit_latents[anchor_idx]
    z = anchors + sample_rng.normal(scale=init_scale, size=anchors.shape)
    # The denoiser pipeline acts row-wise, so the whole batch for this digit is
    # updated at once rather than one predict() call per sample.
    for _ in range(T):
        z = np.clip(z - eta * denoiser.predict(z), z_min, z_max)
    Z_gen.append(z)
    Y_gen.append(np.full(samples_per_digit, digit))

Z_gen = np.vstack(Z_gen)
X_gen = inverse_regressor.predict(Z_gen)
Y_gen = np.concatenate(Y_gen)

print(f"Generated dataset: X_gen shape: {X_gen.shape}, Y_gen shape: {Y_gen.shape}")

# Score the generated samples with a classifier trained only on real images:
# high accuracy means the generated images resemble their assigned digits.
# saga shuffles the data using random_state, so it is seeded for reproducible
# results (it may still emit a ConvergenceWarning at max_iter=200).
clf = LogisticRegression(max_iter=200, solver="saga", random_state=42)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_gen)

accuracy = accuracy_score(Y_gen, y_pred)
print(f"\nClassifier accuracy on generated data: {accuracy:.3f}")

# Show one generated sample per label. Y_gen is ordered digit-by-digit, so the
# first ten rows would all share one label; use each label's first occurrence.
labels, first_idx = np.unique(Y_gen, return_index=True)
fig, axes = plt.subplots(1, len(labels), figsize=(1.5 * len(labels), 2), squeeze=False)
for ax, label, idx in zip(axes[0], labels, first_idx):
    ax.imshow(X_gen[idx].reshape(28, 28), cmap="gray")
    ax.set_title(f"Label: {label}")
    ax.axis("off")
plt.suptitle("Example Generated Digits")
plt.show()
