# Multi-channel latent Dirichlet allocation (MCLDA) model in Pyro

In [2]:
import pandas as pd

In [3]:
# Prerequisite: run Preprocess_data.py once beforehand (presumably it produces
# the parquet file read below — confirm).
# Load the preprocessed case data: CaseRigshospitalet_optimized_withDistance.parquet
df = pd.read_parquet('data/CaseRigshospitalet_optimized_withDistance.parquet')

## MCLDA model

In [None]:
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam

def _observe_tokens(plate_name, site_name, tokens, probs):
    """Observe one patient's tokens of one channel as i.i.d. Categorical(probs) draws.

    Args:
        plate_name: name of the vectorised token plate for this patient/channel.
        site_name: name of the observed sample site.
        tokens: 1-D integer tensor (or list) of token indices.
        probs: probability vector over the channel's vocabulary (theta_p @ phi).
    """
    tokens = torch.as_tensor(tokens, dtype=torch.long)
    if len(tokens) == 0:
        # Nothing to observe; the original per-token loop simply ran zero times.
        return
    with pyro.plate(plate_name, len(tokens)):
        pyro.sample(site_name, dist.Categorical(probs=probs), obs=tokens)


def model(x_tokens, d_tokens, num_groups, num_diag_tokens, num_demo_tokens):
    """
    Multi-Channel LDA model with the per-token group assignments summed out.

    Generative story (K = num_groups):
      - each health group k: phi_x[k] ~ Dirichlet(0.1) over diagnosis tokens,
        phi_d[k] ~ Dirichlet(0.1) over demographic tokens;
      - each patient p: theta_p ~ Dirichlet(1) over the K groups;
      - each token of patient p: z ~ Categorical(theta_p), w ~ Categorical(phi[z]),
        with phi_x for diagnosis tokens and phi_d for demographic tokens.

    The discrete assignment z is marginalised analytically:
        p(w | theta_p, phi) = sum_k theta_p[k] * phi[k, w],
    i.e. w ~ Categorical(theta_p @ phi). The model therefore has only
    continuous latent sites ("phi_x", "phi_d", "theta_{p}"). Standard
    autoguides and Trace_ELBO handle these without enumeration.

    Args:
        x_tokens: list of diagnosis token indices for each patient [list of [N_x_p]],
            values in [0, num_diag_tokens).
        d_tokens: list of demographic token indices for each patient [list of [N_d_p]],
            values in [0, num_demo_tokens); same number of patients as x_tokens.
        num_groups: number of latent health groups K.
        num_diag_tokens: diagnosis vocabulary size.
        num_demo_tokens: demographic vocabulary size.

    Raises:
        ValueError: if x_tokens and d_tokens describe different numbers of patients.
    """
    num_patients = len(x_tokens)
    if len(d_tokens) != num_patients:
        raise ValueError(
            f"x_tokens has {num_patients} patients but d_tokens has {len(d_tokens)}"
        )

    # Hyperparameters
    alpha = torch.ones(num_groups)
    beta_x = 0.1 * torch.ones(num_diag_tokens)
    beta_d = 0.1 * torch.ones(num_demo_tokens)

    # GROUP-LEVEL TOPIC DISTRIBUTIONS: shapes (num_groups, vocab_size)
    with pyro.plate("health_groups", num_groups):
        phi_x = pyro.sample("phi_x", dist.Dirichlet(beta_x))
        phi_d = pyro.sample("phi_d", dist.Dirichlet(beta_d))

    # PATIENTS: a plain Python loop, not a sequential pyro.plate. Autoguides
    # raise NotImplementedError for sites inside a sequential plate. Each
    # patient still gets its own uniquely named "theta_{p}" site.
    for p in range(num_patients):
        theta_p = pyro.sample(f"theta_{p}", dist.Dirichlet(alpha))

        # Diagnosis tokens: mixture of group distributions weighted by theta_p.
        _observe_tokens(f"x_tokens_{p}", f"w_x_{p}", x_tokens[p], theta_p @ phi_x)

        # Demographic tokens: same mixture weights, separate vocabulary.
        _observe_tokens(f"d_tokens_{p}", f"w_d_{p}", d_tokens[p], theta_p @ phi_d)


In [None]:
# Illustrative toy configuration (the training cell below redefines all of
# these): 10 latent health groups, a 500-code diagnosis vocabulary and a
# 50-token demographic vocabulary.
num_groups, num_diag_tokens, num_demo_tokens = 10, 500, 50

# Two simulated patients; each entry is a 1-D tensor of token indices.
x_tokens = [torch.tensor(ids) for ids in ([12, 34, 101], [2, 5])]
d_tokens = [torch.tensor(ids) for ids in ([4, 1], [0, 3, 2])]


In [None]:
import pyro
import pyro.distributions as dist
import pyro.infer.autoguide as autoguide
import pyro.poutine as poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam
import torch

# Fit `model` (defined in the MCLDA model cell above) with SVI:
# point (MAP) estimates for phi_x / phi_d via AutoDelta, and a mean-field
# diagonal-Normal posterior over every patient's theta_{p}.

# Example synthetic token data (replace with your real data)
x_tokens = [torch.tensor([1, 5, 2]), torch.tensor([3, 4]), torch.tensor([7, 8, 1, 2])]
d_tokens = [torch.tensor([0, 1]), torch.tensor([2]), torch.tensor([3, 4, 5])]

# Constants
num_groups = 5
num_diag_tokens = 20
num_demo_tokens = 10

num_steps = 500
log_every = 50

# Categorical observations must be integer (long) indices.
x_tokens = [tokens.long() for tokens in x_tokens]
d_tokens = [tokens.long() for tokens in d_tokens]


def model_wrapper():
    """Run `model` on this cell's data so SVI can be stepped without arguments."""
    return model(x_tokens, d_tokens, num_groups, num_diag_tokens, num_demo_tokens)


# Autoguide parameters live in Pyro's global param store. Clear it so re-running
# this cell starts fresh instead of resuming from the previous run's values.
pyro.clear_param_store()

# Guide setup: one AutoDiagonalNormal over all theta sites is the same
# variational family as one diagonal Normal per patient, with a single module.
theta_sites = [f"theta_{p}" for p in range(len(x_tokens))]
guide = autoguide.AutoGuideList(model_wrapper)
guide.add(autoguide.AutoDelta(poutine.block(model_wrapper, expose=['phi_x', 'phi_d'])))
guide.add(autoguide.AutoDiagonalNormal(poutine.block(model_wrapper, expose=theta_sites)))

# Optimizer and ELBO
optimizer = ClippedAdam({"lr": 0.01})
elbo = Trace_ELBO()
svi = SVI(model_wrapper, guide, optimizer, loss=elbo)

# Training loop: svi.step() returns the loss, i.e. an estimate of -ELBO.
losses = []
for step in range(num_steps):
    loss = svi.step()
    losses.append(loss)
    if step % log_every == 0 or step == num_steps - 1:
        print(f"Step {step} - loss (-ELBO): {loss:.2f}")
