In [1]:
"""
3-way PLSI (User-Document-Word) — EM implementation in NumPy
Author: ChatGPT

This file implements:
- A triadic PLSI model where P(u,d,w) = sum_z P(z) P(u|z) P(d|z) P(w|z)
- EM updates with optional Laplace smoothing
- A vanilla (document-word) PLSI for comparison
- Synthetic data generator, training, held-out perplexity evaluation
- Simple user-personalized document retrieval scoring

Usage: run the file. All dependencies are standard Python + numpy.
"""
import numpy as np
from collections import defaultdict
import math

np.random.seed(0)

# ---------------------------
# Utilities
# ---------------------------

def normalize_rows(mat):
    """Return a copy of the 2-D array ``mat`` with every row scaled to sum to one.

    Rows whose total is exactly zero stay all-zero instead of becoming NaN;
    ``mat`` itself is not modified.
    """
    row_totals = mat.sum(axis=1, keepdims=True)
    # Swap zero totals for 1 so those rows divide to zeros, not NaNs.
    np.putmask(row_totals, row_totals == 0, 1)
    return mat / row_totals

# ---------------------------
# Triadic PLSI model
# ---------------------------
class TriadicPLSI:
    """Triadic PLSI: P(u, d, w) = sum_z P(z) P(u|z) P(d|z) P(w|z), fitted by EM.

    Attributes:
        U, D, W, Z: number of users, documents, words and topics.
        laplace:    pseudo-count added to every expected count in the M-step.
        Pz:         (Z,) topic prior P(z).
        Pu_z:       (Z, U) array; row z is P(u|z).
        Pd_z:       (Z, D) array; row z is P(d|z).
        Pw_z:       (Z, W) array; row z is P(w|z).
    """

    def __init__(self, n_users, n_docs, n_words, n_topics, laplace=1e-6):
        """Create a model with random, normalised parameters.

        Initial values are drawn from NumPy's global RNG (np.random).
        """
        self.U = n_users
        self.D = n_docs
        self.W = n_words
        self.Z = n_topics
        self.laplace = laplace
        # initialize parameters randomly
        self.Pz = np.random.rand(self.Z)
        self.Pz /= self.Pz.sum()
        self.Pu_z = np.random.rand(self.Z, self.U)
        self.Pd_z = np.random.rand(self.Z, self.D)
        self.Pw_z = np.random.rand(self.Z, self.W)
        # normalize conditionals so each row is a distribution for one topic
        self.Pu_z = normalize_rows(self.Pu_z)
        self.Pd_z = normalize_rows(self.Pd_z)
        self.Pw_z = normalize_rows(self.Pw_z)

    # -- internal helpers ---------------------------------------------------

    @staticmethod
    def _as_arrays(triplets, counts):
        """Split triplets into u, d, w index arrays plus a float weight array.

        Raises:
            ValueError: if ``triplets`` is not a sequence of 3-tuples, or
                ``counts`` does not have exactly one entry per triplet.
        """
        idx = np.asarray(triplets, dtype=int)
        if idx.size == 0:
            idx = idx.reshape(0, 3)
        if idx.ndim != 2 or idx.shape[1] != 3:
            raise ValueError("triplets must be a sequence of (u, d, w) index tuples")
        if counts is None:
            weights = np.ones(idx.shape[0])
        else:
            weights = np.asarray(counts, dtype=float).reshape(-1)
        if weights.shape[0] != idx.shape[0]:
            raise ValueError(
                f"counts has {weights.shape[0]} entries but there are "
                f"{idx.shape[0]} triplets")
        return idx[:, 0], idx[:, 1], idx[:, 2], weights

    def _joint(self, u, d, w):
        """Return the (N, Z) array P(z) P(u|z) P(d|z) P(w|z) for each observation."""
        return self.Pz * self.Pu_z[:, u].T * self.Pd_z[:, d].T * self.Pw_z[:, w].T

    @staticmethod
    def _weighted_ll(joint, weights):
        """Return sum_i weights[i] * log(sum_z joint[i, z]), flooring probs at 1e-300."""
        prob = joint.sum(axis=1)
        prob = np.where(prob <= 0, 1e-300, prob)  # numerical stability: avoid log(0)
        return float(np.dot(weights, np.log(prob)))

    def _smoothed_conditional(self, idx, resp, size):
        """M-step for one conditional: scatter expected counts, smooth, normalise.

        ``resp`` is the (N, Z) array of count-weighted posteriors; the result
        is a (Z, size) array whose row z is P(item | z).
        """
        acc = np.zeros((size, self.Z))
        # np.add.at accumulates correctly when an index repeats across observations.
        np.add.at(acc, idx, resp)
        # Laplace smoothing avoids zero probabilities for unseen items.
        return normalize_rows(acc.T + self.laplace)

    # -- public API -----------------------------------------------------------

    def fit(self, triplets, counts=None, n_iters=50, verbose=True, tol=1e-5):
        """Fit the parameters with EM.

        triplets: list of (u,d,w) tuples (0-based indices)
        counts: optional list/array of same length with counts (defaults to 1 each)
        n_iters: maximum number of EM iterations
        verbose: print the training log-likelihood after every iteration
        tol: stop when the log-likelihood changes by less than this (absolute)

        Raises:
            ValueError: if the total count is not positive, or ``counts``
                does not match ``triplets``.
        """
        u, d, w, weights = self._as_arrays(triplets, counts)
        if weights.sum() <= 0:
            raise ValueError("fit requires at least one observation with a positive count")

        prev_ll = -np.inf
        for it in range(n_iters):
            # E-step: posterior p(z | u,d,w) for all observations at once.
            joint = self._joint(u, d, w)
            norm = joint.sum(axis=1, keepdims=True)
            underflow = norm == 0
            # Observations whose joint underflows to 0 get a uniform posterior (no NaNs).
            post = np.where(underflow, 1.0 / self.Z,
                            joint / np.where(underflow, 1.0, norm))

            # M-step: expected counts n(u,d,w) * p(z|u,d,w).
            resp = weights[:, None] * post
            ez = resp.sum(axis=0)  # expected count per topic, shape (Z,)
            self.Pz = ez / ez.sum()
            self.Pu_z = self._smoothed_conditional(u, resp, self.U)
            self.Pd_z = self._smoothed_conditional(d, resp, self.D)
            self.Pw_z = self._smoothed_conditional(w, resp, self.W)

            # Training log-likelihood under the updated parameters, for monitoring.
            ll = self._weighted_ll(self._joint(u, d, w), weights)
            if verbose:
                print(f"Iter {it+1:3d} | Log-likelihood: {ll:.4f}")

            # check convergence on log-likelihood
            if abs(ll - prev_ll) < tol:
                if verbose:
                    print("Converged.")
                break
            prev_ll = ll

    def log_likelihood(self, triplets, counts=None):
        """Return sum_i counts[i] * log P(u_i, d_i, w_i) (counts default to 1)."""
        u, d, w, weights = self._as_arrays(triplets, counts)
        return self._weighted_ll(self._joint(u, d, w), weights)

    def perplexity(self, triplets, counts=None):
        """Return exp(-log_likelihood / total_count); ``counts`` may be a list or array.

        Raises:
            ValueError: if the total count is not positive (perplexity undefined).
        """
        _, _, _, weights = self._as_arrays(triplets, counts)
        total = weights.sum()
        if total <= 0:
            raise ValueError("perplexity is undefined when the total count is not positive")
        ll = self.log_likelihood(triplets, weights)
        return math.exp(-ll / total)

    def user_topic_distribution(self, u):
        """Return P(z|u), proportional to P(z) P(u|z); uniform if that is all zero."""
        unnorm = self.Pz * self.Pu_z[:, u]
        s = unnorm.sum()
        if s == 0:
            return np.ones(self.Z) / self.Z
        return unnorm / s

    def score_docs_for_user(self, u):
        """Return P(d|u) = sum_z P(d|z) P(z|u) for every document, shape (D,)."""
        pz_given_u = self.user_topic_distribution(u)
        return pz_given_u @ self.Pd_z

# ---------------------------
# Vanilla PLSI (Document-Word)
# ---------------------------
class VanillaPLSI:
    """Classic document-word PLSI: P(d, w) = sum_z P(z) P(d|z) P(w|z), fitted by EM.

    Attributes:
        D, W, Z:  number of documents, words and topics.
        laplace:  pseudo-count added to every expected count in the M-step.
        Pz:       (Z,) topic prior P(z).
        Pd_z:     (Z, D) array; row z is P(d|z).
        Pw_z:     (Z, W) array; row z is P(w|z).
    """

    def __init__(self, n_docs, n_words, n_topics, laplace=1e-6):
        """Create a model with random, normalised parameters (drawn from np.random)."""
        self.D = n_docs
        self.W = n_words
        self.Z = n_topics
        self.laplace = laplace
        self.Pz = np.random.rand(self.Z)
        self.Pz /= self.Pz.sum()
        self.Pd_z = np.random.rand(self.Z, self.D)
        self.Pw_z = np.random.rand(self.Z, self.W)
        self.Pd_z = normalize_rows(self.Pd_z)
        self.Pw_z = normalize_rows(self.Pw_z)

    # -- internal helpers ---------------------------------------------------

    @staticmethod
    def _as_arrays(pairs, counts):
        """Split pairs into d, w index arrays plus a float weight array.

        Raises:
            ValueError: if ``pairs`` is not a sequence of 2-tuples, or
                ``counts`` does not have exactly one entry per pair.
        """
        idx = np.asarray(pairs, dtype=int)
        if idx.size == 0:
            idx = idx.reshape(0, 2)
        if idx.ndim != 2 or idx.shape[1] != 2:
            raise ValueError("pairs must be a sequence of (d, w) index tuples")
        if counts is None:
            weights = np.ones(idx.shape[0])
        else:
            weights = np.asarray(counts, dtype=float).reshape(-1)
        if weights.shape[0] != idx.shape[0]:
            raise ValueError(
                f"counts has {weights.shape[0]} entries but there are "
                f"{idx.shape[0]} pairs")
        return idx[:, 0], idx[:, 1], weights

    def _joint(self, d, w):
        """Return the (N, Z) array P(z) P(d|z) P(w|z) for each observation."""
        return self.Pz * self.Pd_z[:, d].T * self.Pw_z[:, w].T

    @staticmethod
    def _weighted_ll(joint, weights):
        """Return sum_i weights[i] * log(sum_z joint[i, z]), flooring probs at 1e-300."""
        prob = joint.sum(axis=1)
        prob = np.where(prob <= 0, 1e-300, prob)  # avoid log(0)
        return float(np.dot(weights, np.log(prob)))

    def _smoothed_conditional(self, idx, resp, size):
        """M-step for one conditional: scatter expected counts, smooth, normalise.

        ``resp`` is the (N, Z) array of count-weighted posteriors; the result
        is a (Z, size) array whose row z is P(item | z).
        """
        acc = np.zeros((size, self.Z))
        # np.add.at accumulates correctly when an index repeats across observations.
        np.add.at(acc, idx, resp)
        return normalize_rows(acc.T + self.laplace)

    # -- public API -----------------------------------------------------------

    def fit(self, pairs, counts=None, n_iters=50, verbose=True, tol=1e-5):
        """Fit the parameters with EM.

        pairs: list of (d, w) tuples (0-based indices)
        counts: optional list/array of same length with counts (defaults to 1 each)
        n_iters: maximum number of EM iterations
        verbose: print the training log-likelihood after every iteration
        tol: stop when the log-likelihood changes by less than this (absolute)

        Raises:
            ValueError: if the total count is not positive, or ``counts``
                does not match ``pairs``.
        """
        d, w, weights = self._as_arrays(pairs, counts)
        if weights.sum() <= 0:
            raise ValueError("fit requires at least one observation with a positive count")

        prev_ll = -np.inf
        for it in range(n_iters):
            # E-step: posterior p(z | d,w); underflowed rows become uniform.
            joint = self._joint(d, w)
            norm = joint.sum(axis=1, keepdims=True)
            underflow = norm == 0
            post = np.where(underflow, 1.0 / self.Z,
                            joint / np.where(underflow, 1.0, norm))

            # M-step from expected counts n(d,w) * p(z|d,w).
            resp = weights[:, None] * post
            ez = resp.sum(axis=0)
            self.Pz = ez / ez.sum()
            self.Pd_z = self._smoothed_conditional(d, resp, self.D)
            self.Pw_z = self._smoothed_conditional(w, resp, self.W)

            ll = self._weighted_ll(self._joint(d, w), weights)
            if verbose:
                print(f"Vanilla Iter {it+1:3d} | LL: {ll:.4f}")
            if abs(ll - prev_ll) < tol:
                if verbose:
                    print("Vanilla converged")
                break
            prev_ll = ll

    def log_likelihood(self, pairs, counts=None):
        """Return sum_i counts[i] * log P(d_i, w_i) (counts default to 1)."""
        d, w, weights = self._as_arrays(pairs, counts)
        return self._weighted_ll(self._joint(d, w), weights)

    def perplexity(self, pairs, counts=None):
        """Return exp(-log_likelihood / total_count); ``counts`` may be a list or array.

        Raises:
            ValueError: if the total count is not positive (perplexity undefined).
        """
        _, _, weights = self._as_arrays(pairs, counts)
        total = weights.sum()
        if total <= 0:
            raise ValueError("perplexity is undefined when the total count is not positive")
        ll = self.log_likelihood(pairs, weights)
        return math.exp(-ll / total)

# ---------------------------
# Synthetic data generator
# ---------------------------

def generate_synthetic(U=20, D=50, W=100, Z=5, N=5000, alpha=0.5):
    """
    Generate synthetic triplets from a known triadic model
    Returns: triplets list, counts (all ones), and ground-truth parameters
    """
    # ground truth topic prior
    Pz_true = np.random.dirichlet([alpha]*Z)
    Pu_z_true = np.array([np.random.dirichlet([alpha]*U) for _ in range(Z)])
    Pd_z_true = np.array([np.random.dirichlet([alpha]*D) for _ in range(Z)])
    Pw_z_true = np.array([np.random.dirichlet([alpha]*W) for _ in range(Z)])

    triplets = []
    counts = []
    for n in range(N):
        z = np.random.choice(Z, p=Pz_true)
        u = np.random.choice(U, p=Pu_z_true[z])
        d = np.random.choice(D, p=Pd_z_true[z])
        w = np.random.choice(W, p=Pw_z_true[z])
        triplets.append((u, d, w))
        counts.append(1)

    return triplets, np.array(counts), (Pz_true, Pu_z_true, Pd_z_true, Pw_z_true)

# ---------------------------
# Demo training & evaluation
# ---------------------------
if __name__ == '__main__':
    # Experiment settings are defined once so the data generator and both
    # models always agree on index ranges; previously the sizes were repeated
    # by hand in three places and could silently drift apart.
    N_USERS, N_DOCS, N_WORDS, N_TOPICS = 15, 40, 200, 6
    N_SAMPLES = 4000
    TRAIN_FRACTION = 0.8
    LAPLACE = 1e-6
    N_ITERS = 30

    # create synthetic data; gt holds the true parameters for inspection
    triplets, counts, gt = generate_synthetic(
        U=N_USERS, D=N_DOCS, W=N_WORDS, Z=N_TOPICS, N=N_SAMPLES)
    # simple sequential train-test split (samples are drawn independently)
    split = int(TRAIN_FRACTION * len(triplets))
    train_trip = triplets[:split]
    train_counts = counts[:split]
    test_trip = triplets[split:]
    test_counts = counts[split:]

    # train triadic model
    model = TriadicPLSI(n_users=N_USERS, n_docs=N_DOCS, n_words=N_WORDS,
                        n_topics=N_TOPICS, laplace=LAPLACE)
    print('\nTraining Triadic PLSI...')
    model.fit(train_trip, counts=train_counts, n_iters=N_ITERS, verbose=True)
    print('\nTriadic Perplexity on test set:', model.perplexity(test_trip, counts=test_counts))

    # prepare (d, w) pairs for the vanilla model by dropping the user dimension
    pairs = [(d, w) for (_, d, w) in train_trip]
    pairs_test = [(d, w) for (_, d, w) in test_trip]

    vanilla = VanillaPLSI(n_docs=N_DOCS, n_words=N_WORDS, n_topics=N_TOPICS, laplace=LAPLACE)
    print('\nTraining Vanilla (doc-word) PLSI...')
    vanilla.fit(pairs, counts=train_counts, n_iters=N_ITERS, verbose=True)
    print('\nVanilla Perplexity on test set:', vanilla.perplexity(pairs_test, counts=test_counts))

    # simple retrieval: given user u, rank documents by P(d|u)
    u = 3
    scores = model.score_docs_for_user(u)
    top_docs = np.argsort(-scores)[:10]
    print(f"\nTop docs for user {u} (by model):", top_docs)

    # show top words for each topic
    topk = 8
    for z in range(model.Z):
        top_words = np.argsort(-model.Pw_z[z])[:topk]
        print(f"Topic {z}: top words indices: {top_words}")

    print('\nDone.')



Training Triadic PLSI...
Iter   1 | Log-likelihood: -35212.8762
Iter   2 | Log-likelihood: -35048.6161
Iter   3 | Log-likelihood: -34883.5497
Iter   4 | Log-likelihood: -34717.0272
Iter   5 | Log-likelihood: -34563.1429
Iter   6 | Log-likelihood: -34423.8763
Iter   7 | Log-likelihood: -34301.1691
Iter   8 | Log-likelihood: -34195.2452
Iter   9 | Log-likelihood: -34104.6745
Iter  10 | Log-likelihood: -34027.2016
Iter  11 | Log-likelihood: -33964.8974
Iter  12 | Log-likelihood: -33916.7690
Iter  13 | Log-likelihood: -33878.5709
Iter  14 | Log-likelihood: -33847.1656
Iter  15 | Log-likelihood: -33821.2107
Iter  16 | Log-likelihood: -33799.5880
Iter  17 | Log-likelihood: -33780.7731
Iter  18 | Log-likelihood: -33763.4074
Iter  19 | Log-likelihood: -33746.2870
Iter  20 | Log-likelihood: -33727.6949
Iter  21 | Log-likelihood: -33706.0327
Iter  22 | Log-likelihood: -33682.9630
Iter  23 | Log-likelihood: -33661.5395
Iter  24 | Log-likelihood: -33641.1361
Iter  25 | Log-likelihood: -33620.9273