# Semi-Supervised Learning with PyMC

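This notebook demonstrates a Bayesian semi-supervised learning approach in PyMC: a calibrated linear classifier is fit to a small labeled set, while an entropy-minimization penalty on a much larger unlabeled set pushes the decision boundary toward low-density regions. Inference uses ADVI (variational inference).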

In [None]:
import pymc as pm
import numpy as np
import arviz as az
import pytensor.tensor as pt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

## Step 1: Generate Data

In [None]:
# Synthetic binary classification problem: 2000 rows, two informative features,
# an 80/20 class imbalance, and 10% of labels flipped at random.
X, y = make_classification(
    n_samples=2000,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    weights=[0.8, 0.2],
    flip_y=0.1,
    random_state=42,
)

# Keep labels for only 100 rows; the remaining rows form the unlabeled pool
# and their labels are discarded.
X_labeled, X_unlabeled, y_labeled, _ = train_test_split(
    X,
    y,
    train_size=100,
    random_state=42,
)

print(f"Labeled: {len(X_labeled)} | Unlabeled: {len(X_unlabeled)}")

## Step 2: Scalable PyMC Model

In [None]:
# Semi-supervised model: a Bernoulli likelihood on the labeled rows plus an
# entropy penalty on the unlabeled rows, fitted with mean-field ADVI.
with pm.Model() as model:
    # --- Input data containers ---
    # Registered through pm.Data (rather than baked in as constants) so the
    # arrays can be swapped later, e.g. for mini-batching.
    X_l = pm.Data("X_labeled", X_labeled)
    y_l = pm.Data("y_labeled", y_labeled)
    X_u = pm.Data("X_unlabeled", X_unlabeled)

    # --- Parameters ---
    # Pipeline: weight the input columns -> combine -> calibrate.
    # w: Dirichlet weights over the two input columns (non-negative, sum to 1).
    w = pm.Dirichlet("weights", a=np.ones(2))
    # a, b: global calibration slope (positive) and bias applied to the
    # weighted combination.
    a = pm.HalfNormal("slope", sigma=1.0)
    b = pm.Normal("bias", mu=0.0, sigma=1.0)
    # NOTE(review): with w >= 0 and a > 0 the effective coefficient on every
    # input column is non-negative. Confirm each column is expected to be
    # positively associated with the positive class.

    # --- Forward pass ---
    # 1. Linear combination of the input columns. The combination happens
    #    before the sigmoid, i.e. in logit space rather than probability space.
    ens_logit_l = pm.math.dot(X_l, w)
    ens_logit_u = pm.math.dot(X_u, w)

    # 2. Calibration (Platt scaling): p = sigmoid(a * logit + b).
    p_labeled = pm.math.sigmoid(a * ens_logit_l + b)
    p_unlabeled = pm.math.sigmoid(a * ens_logit_u + b)

    # --- Likelihood (labeled data only) ---
    obs = pm.Bernoulli("y_obs", p=p_labeled, observed=y_l)

    # --- SSL: entropy minimization on the unlabeled data ---
    # Per-row binary entropy of the predicted probability; it is largest at
    # p = 0.5. The 1e-6 offset keeps the log arguments away from zero.
    entropy = -(p_unlabeled * pt.log(p_unlabeled + 1e-6) +
               (1 - p_unlabeled) * pt.log(1 - p_unlabeled + 1e-6))

    # lambda_ssl controls how much the unlabeled structure is trusted relative
    # to the labeled likelihood. The penalty enters the model log-probability
    # as -lambda_ssl * (total entropy), so confident predictions are rewarded.
    # NOTE(review): the entropy is summed, not averaged, so the strength of
    # this term grows with the number of unlabeled rows. Confirm 0.5 is the
    # intended balance for this labeled/unlabeled ratio.
    lambda_ssl = 0.5
    pm.Potential("ssl_regularization", -lambda_ssl * entropy.sum())

    # --- Inference: ADVI for scalability ---
    # The fit is seeded so that Restart & Run All reproduces the same
    # approximation; the data generation and split are already seeded.
    print("Fitting with ADVI (Variational Inference)...")
    mean_field = pm.fit(n=20000, method="advi", random_seed=42)

## Step 3: Sampling & Parameter Summary

In [None]:
# Draw samples from the fitted variational approximation. The draw is seeded
# so the printed summaries are repeatable for a given fitted approximation.
trace = mean_field.sample(1000, random_seed=42)

# Posterior means are read straight from the trace; sample_posterior_predictive
# is not needed for parameter summaries.
for label, var_name in (("Weights", "weights"), ("Slope", "slope"), ("Bias", "bias")):
    print(f"\nLearned {label} (Mean):")
    print(trace.posterior[var_name].mean(dim=["draw", "chain"]).values)