In [2]:
"""
GE-FL: Generalized Expectation with Feature Labels
Reference: "Learning from Labeled Features using Generalized Expectation Criteria"
           Gregory Druck, Gideon Mann, Andrew McCallum. SIGIR 2008.

This script:
 - loads a small subset of 20newsgroups (ibm vs mac)
 - extracts a bag-of-words feature matrix
 - constructs candidate features and a simulated 'oracle' feature labeler
 - trains a multinomial logistic regression model by minimizing:
       Loss = sum_k KL(p_hat_k || p_tilde_k(theta))  +  (1/(2*sigma^2)) * ||theta||^2
   where p_tilde_k(theta) is the model predicted class distribution over all
   unlabeled instances that contain feature k, and p_hat_k is a reference distribution
   for feature k (from labeled features).
 - evaluates the trained model on a held-out labeled test set (for demonstration).
"""

import numpy as np
from scipy.optimize import minimize
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_selection import mutual_info_classif
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)


# ------------------------
# Utilities
# ------------------------
def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)
    expz = np.exp(z)
    return expz / expz.sum(axis=1, keepdims=True)


def build_schapire_distributions(feature_labels, n_classes, qmaj=0.9):
    """
    Create reference distributions for each labeled feature using Schapire heuristic.
    """
    p_hat_dict = {}
    for k, assoc in feature_labels.items():
        assoc = list(assoc)
        n_assoc = len(assoc)
        p = np.ones(n_classes) * ((1.0 - qmaj) / max(1, (n_classes - n_assoc)))
        if n_assoc > 0:
            p[assoc] = qmaj / n_assoc
        p_hat_dict[k] = p
    return p_hat_dict


# ------------------------
# GE-FL objective + gradient
# ------------------------
class GEFLTrainer:
    def __init__(self, X_unlabeled, n_classes, feature_p_hats, sigma=1.0, add_bias=True):
        self.X = X_unlabeled.tocsr(copy=False)
        self.n_samples, self.n_features = self.X.shape
        self.n_classes = n_classes
        self.feature_p_hats = feature_p_hats
        self.sigma = float(sigma)
        self.add_bias = add_bias

        # Precompute rows where each labeled feature occurs
        self.feat_rows = {k: self.X[:, k].nonzero()[0] for k in feature_p_hats.keys()}

    def pack_params(self, W, b):
        return np.concatenate([W.ravel(), b.ravel()]) if self.add_bias else W.ravel()

    def unpack_params(self, params):
        if self.add_bias:
            size_w = self.n_classes * self.n_features
            W = params[:size_w].reshape(self.n_classes, self.n_features)
            b = params[size_w:size_w + self.n_classes]
            return W, b
        else:
            W = params.reshape(self.n_classes, self.n_features)
            b = np.zeros(self.n_classes)
            return W, b

    def predict_proba_unlabeled(self, W, b):
        logits = self.X.dot(W.T)
        if self.add_bias:
            logits = logits + b.reshape(1, -1)
        return softmax(logits)

    def objective_and_grad(self, params):
        W, b = self.unpack_params(params)
        P = self.predict_proba_unlabeled(W, b)

        loss = 0.0
        grad_W = np.zeros_like(W)
        grad_b = np.zeros_like(b)

        for k, p_hat in self.feature_p_hats.items():
            rows = self.feat_rows.get(k, np.array([], dtype=int))
            Ck = len(rows)
            if Ck == 0:
                continue

            Q = P[rows]                      # (Ck, n_classes)
            p_tilde = Q.mean(axis=0)         # model expectation for this feature
            eps = 1e-12
            p_tilde = np.clip(p_tilde, eps, 1.0)
            p_tilde = p_tilde / p_tilde.sum()

            p_hat = np.asarray(p_hat, dtype=float)
            p_hat = np.clip(p_hat, eps, 1.0)
            p_hat = p_hat / p_hat.sum()

            # KL divergence
            kl = np.sum(p_hat * (np.log(p_hat) - np.log(p_tilde)))
            loss += kl

            # Gradient
            alpha = p_hat / p_tilde
            s = Q.dot(alpha)                       # (Ck,)
            A = Q * alpha[np.newaxis, :]           # (Ck, n_classes)
            A_final = - (1.0 / float(Ck)) * (A - (Q * s[:, np.newaxis]))

            X_rows = self.X[rows].toarray()        # force dense (Ck, n_features)
            grad_W += (A_final.T @ X_rows).astype(np.float64)
            grad_b += A_final.sum(axis=0)
        # Gaussian prior
        prior_term = 0.5 * (np.sum(W * W) + np.sum(b * b)) / (self.sigma ** 2)
        loss += prior_term
        grad_W += W / (self.sigma ** 2)
        grad_b += b / (self.sigma ** 2)

        grad = np.concatenate([grad_W.ravel(), grad_b.ravel()]) if self.add_bias else grad_W.ravel()
        return loss, grad


# ------------------------
# Example pipeline
# ------------------------
def example_run():
    categories = ['comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware']
    data = fetch_20newsgroups(subset='all', categories=categories, remove=('headers', 'footers', 'quotes'))
    y_all = data.target
    X_text = data.data

    X_train_text, X_test_text, y_train, y_test = train_test_split(
        X_text, y_all, test_size=0.3, random_state=42, stratify=y_all
    )

    vect = CountVectorizer(min_df=5, stop_words='english', max_features=5000)
    X_train_counts = vect.fit_transform(X_train_text)
    X_test_counts = vect.transform(X_test_text)
    n_classes = len(np.unique(y_all))

    # Oracle feature selection (mutual information)
    mi = mutual_info_classif(X_train_counts, y_train, discrete_features=True)
    top_k = 100
    top_idx = np.argsort(mi)[::-1][:top_k]

    feature_labels = {}
    Xc = X_train_counts.tocsr()
    for k in top_idx:
        rows = Xc[:, k].nonzero()[0]
        if len(rows) == 0:
            continue
        labels_in_rows = y_train[rows]
        counts = np.bincount(labels_in_rows, minlength=n_classes)
        max_class = np.argmax(counts)
        assoc = [max_class]
        for c in range(n_classes):
            if c != max_class and counts[c] >= 0.5 * counts[max_class]:
                assoc.append(c)
        if counts[max_class] > 0:
            feature_labels[k] = assoc

    print(f"Selected {len(feature_labels)} labeled features (oracle-simulated).")

    # Build reference distributions
    p_hat_dict = build_schapire_distributions(feature_labels, n_classes, qmaj=0.9)

    # Train GE-FL
    trainer = GEFLTrainer(X_train_counts, n_classes, p_hat_dict, sigma=1.0, add_bias=True)

    rng = np.random.RandomState(0)
    W0 = 0.01 * rng.randn(n_classes, X_train_counts.shape[1])
    b0 = np.zeros(n_classes)
    p0 = trainer.pack_params(W0, b0)

    def fun_and_grad(p):
        loss, grad = trainer.objective_and_grad(p)
        return loss, grad

    print("Starting L-BFGS optimization...")
    res = minimize(fun_and_grad, p0, method='L-BFGS-B', jac=True,
                   options={'maxiter': 50, 'disp': True})  # reduce maxiter for demo
    W_opt, b_opt = trainer.unpack_params(res.x)
    print("Optimization done. Final objective:", res.fun)

    # Evaluate
    logits_test = X_test_counts.dot(W_opt.T) + b_opt.reshape(1, -1)
    P_test = softmax(logits_test)
    y_pred = np.argmax(P_test, axis=1)
    print("Test accuracy (GE-FL):", accuracy_score(y_test, y_pred))
    print(classification_report(y_test, y_pred, target_names=categories))


if __name__ == "__main__":
    example_run()

Selected 100 labeled features (oracle-simulated).
Starting L-BFGS optimization...


  res = minimize(fun_and_grad, p0, method='L-BFGS-B', jac=True,


Optimization done. Final objective: 5.302976412435929
Test accuracy (GE-FL): 0.791095890410959
                          precision    recall  f1-score   support

comp.sys.ibm.pc.hardware       0.73      0.94      0.82       295
   comp.sys.mac.hardware       0.91      0.64      0.75       289

                accuracy                           0.79       584
               macro avg       0.82      0.79      0.79       584
            weighted avg       0.82      0.79      0.79       584

