### M-LDA

## import data

In [1]:
import pandas as pd
import numpy as np 

In [5]:
df = pd.read_csv("data/data_summed.csv")
diagnose_df = pd.read_csv("data/diagnosis.csv", sep=";")




In [20]:
unique_diag_lists = df.groupby("Patient ID")["Aktionsdiagnosekode"].unique()

code = "Aktionsdiagnosekode"
text = "Aktionsdiagnosetekst"
token_dict = {}

for idx, row in diagnose_df.iterrows():
    code_val = row["Aktionsdiagnosekode"]
    text_val = row["Aktionsdiagnosetekst"]

    text1 = str(code_val) if pd.notnull(code_val) else ""
    text2 = str(text_val) if pd.notnull(text_val) else ""
    combined_text = text1 + " " + text2
    token_dict[code_val] = combined_text


# Convert diagnosis codes to full text tokens
diagnosis_per_patient = {
    pid: [token_dict.get(str(code)) for code in codes if str(code) in token_dict]
    for pid, codes in unique_diag_lists.items()
}

age_bins = [0, 5, 18, 25, 35, 45, 55, 65, 75, 80, 85, 90, 99,  float("inf")]
age_labels = [f"Age_Group_{str(i).zfill(2)}" for i in range(13)]
df["Age_Group"] = pd.cut(
    df["alder"],
    bins=age_bins,
    labels=age_labels,
    right=True,
    include_lowest=True
)

# Clean and format demographic values
df["gender"] = df["gender"].astype(str).str.strip()
df["civilStand"] = df["civilStand"].astype(str).str.strip()
#df["Patient kommune"] = df["Patient kommune"].astype(str).str.strip()
df["Age_Group"] = df["Age_Group"].astype(str).str.strip()

# Construct context token list
def demographic_tokens(row):
    tokens = []
    if row["gender"]:
        tokens.append(f"Sex_{row['gender']}")
    if row["civilStand"]:
        tokens.append(row["civilStand"])
    #if row["Patient kommune"]:
        #tokens.append(f"Kommune_{row['Patient kommune']}")
    if row["Age_Group"]:
        tokens.append(row["Age_Group"])
    return tokens

df["context"] = df.apply(demographic_tokens, axis=1)

# Drop rows with no context tokens
df = df[df["context"].map(lambda x: len(x) > 0)]

# Group by Patient ID and take the first available context per patient
context_dict = (
    df.groupby("Patient ID")["context"]
    .first()
    .to_dict()
)

# --- Step 4: Merge diagnoses and context into a single DataFrame ---

# Create patient_df from diagnoses dictionary
patient_df = pd.DataFrame([
    {"patient_id": pid, "diagnoses": diag}
    for pid, diag in diagnosis_per_patient.items()
])

# Create context_df from context_dict
context_df = pd.DataFrame.from_dict(
    {pid: {"context": tokens} for pid, tokens in context_dict.items()},
    orient="index"
)

# Make sure indices match in type and formatting
patient_df["patient_id"] = patient_df["patient_id"].astype(str).str.strip()
context_df.index = context_df.index.astype(str).str.strip()

# Merge into one final dataframe
patient_df = patient_df.set_index("patient_id").join(context_df, how="left").reset_index()

# Optional: drop rows where diagnoses or context are missing
patient_df = patient_df.dropna(subset=["diagnoses", "context"])

patient_df = patient_df.sample(10000, random_state=42)


### Create Vocabularies for the model

In [21]:
from itertools import chain

# Diagnosis vocabulary
diag_vocab = {token: idx for idx, token in enumerate(set(chain.from_iterable(patient_df["diagnoses"])))}

# Context vocabulary
context_vocab = {token: idx for idx, token in enumerate(set(chain.from_iterable(patient_df["context"])))}

import torch

def tokens_to_indices(tokens, vocab):
    return torch.tensor([vocab[token] for token in tokens if token in vocab], dtype=torch.long)


patient_records_indexed = [tokens_to_indices(tokens, diag_vocab) for tokens in patient_df["diagnoses"]]
context_dict_indexed    = [tokens_to_indices(tokens, context_vocab) for tokens in patient_df["context"]]


### Model

In [24]:
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam

def model(x_tokens, d_tokens, num_groups, num_diag_tokens, num_demo_tokens):
    """
    Multi-Channel LDA model:
    - x_tokens: list of diagnosis token indices for each patient [list of [N_x_p]]
    - d_tokens: list of demographic token indices for each patient [list of [N_d_p]]
    """
    num_patients = len(x_tokens)

    # Hyperparameters
    alpha = torch.ones(num_groups)
    beta_x = 0.1 * torch.ones(num_diag_tokens)
    beta_d = 0.1 * torch.ones(num_demo_tokens)

    # GROUP-LEVEL TOPIC DISTRIBUTIONS
    with pyro.plate("health_groups", num_groups):
        phi_x = pyro.sample("phi_x", dist.Dirichlet(beta_x))
        phi_d = pyro.sample("phi_d", dist.Dirichlet(beta_d))

    # PATIENT LOOP
    for p in pyro.plate("patients", num_patients):

        # Patient-specific group distribution
        theta_p = pyro.sample(f"theta_{p}", dist.Dirichlet(alpha))

        # Diagnosis tokens
        for n in pyro.plate(f"x_tokens_{p}", len(x_tokens[p])):
            
            z_x = pyro.sample(f"z_x_{p}_{n}", dist.Categorical(theta_p), infer={"enumerate": "parallel", "is_auxiliary": True})

            pyro.sample(f"w_x_{p}_{n}", dist.Categorical(phi_x[z_x]), obs=x_tokens[p][n])

        # Demographic tokens
        for n in pyro.plate(f"d_tokens_{p}", len(d_tokens[p])):
            z_d = pyro.sample(f"z_d_{p}_{n}", dist.Categorical(theta_p), infer={"enumerate": "parallel", "is_auxiliary": True})

            pyro.sample(f"w_d_{p}_{n}", dist.Categorical(phi_d[z_d]), obs=d_tokens[p][n])

In [25]:
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam
import pyro.poutine as poutine
from pyro.distributions import constraints
from pyro.infer.autoguide import AutoDelta

pyro.clear_param_store()

def guide(x_tokens, d_tokens, num_groups, num_diag_tokens, num_demo_tokens):
    num_patients = len(x_tokens)

    # Global topic-word posteriors
    phi_x_posterior = pyro.param("phi_x_posterior", torch.ones(num_groups, num_diag_tokens),
                                 constraint=constraints.positive)
    phi_d_posterior = pyro.param("phi_d_posterior", torch.ones(num_groups, num_demo_tokens),
                                 constraint=constraints.positive)

    with pyro.plate("health_groups", num_groups):
        pyro.sample("phi_x", dist.Dirichlet(phi_x_posterior))
        pyro.sample("phi_d", dist.Dirichlet(phi_d_posterior))

    # Patient-level topic mixtures
    for p in pyro.plate("patients", num_patients):
        theta_posterior = pyro.param(f"theta_posterior_{p}", torch.ones(num_groups),
                                     constraint=constraints.simplex)
        pyro.sample(f"theta_{p}", dist.Delta(theta_posterior, event_dim=1))


optimizer = ClippedAdam({"lr": 0.01, "clip_norm": 10.0})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# Model input parameters
num_groups = 20
num_diag_tokens = len(diag_vocab)
num_demo_tokens = len(context_vocab)

# Training loop
for step in range(10):
    loss = svi.step(patient_records_indexed, context_dict_indexed, num_groups, num_diag_tokens, num_demo_tokens)
    if step % 1 == 0:
        print(f"[Step {step}] ELBO Loss: {loss:.2f}")




[Step 0] ELBO Loss: 17883.94
[Step 1] ELBO Loss: 19891.57
[Step 2] ELBO Loss: 18960.64
[Step 3] ELBO Loss: 17315.06
[Step 4] ELBO Loss: 11546.84
[Step 5] ELBO Loss: 18144.46
[Step 6] ELBO Loss: 15563.53
[Step 7] ELBO Loss: 16912.64
[Step 8] ELBO Loss: 17770.27
[Step 9] ELBO Loss: 18229.55
