In [2]:
import cvxpy as cp
import numpy as np
from scipy.special import rel_entr
import pickle

In [3]:
def solve_Q_new(P: np.ndarray):
  '''
  Compute optimal Q given 3d array P 
  with dimensions coressponding to x1, x2, and y respectively
  '''
  Py = P.sum(axis=0).sum(axis=0)
  Px1 = P.sum(axis=1).sum(axis=1)
  Px2 = P.sum(axis=0).sum(axis=1)
  Px2y = P.sum(axis=0)
  Px1y = P.sum(axis=1)
  Px1y_given_x2 = P/P.sum(axis=(0,2),keepdims=True)
 
  Q = [cp.Variable((P.shape[0], P.shape[1]), nonneg=True) for i in range(P.shape[2])]
  Q_x1x2 = [cp.Variable((P.shape[0], P.shape[1]), nonneg=True) for i in range(P.shape[2])]

  # Constraints that conditional distributions sum to 1
  sum_to_one_Q = cp.sum([cp.sum(q) for q in Q]) == 1

  # Brute force constraints # 
  # [A]: p(x1, y) == q(x1, y) 
  # [B]: p(x2, y) == q(x2, y)

  # Adding [A] constraints
  A_cstrs = []
  for x1 in range(P.shape[0]):
      for y in range(P.shape[2]):
        vars = []
        for x2 in range(P.shape[1]):
          vars.append(Q[y][x1, x2])
        A_cstrs.append(cp.sum(vars) == Px1y[x1,y])
  
  # Adding [B] constraints
  B_cstrs = []
  for x2 in range(P.shape[1]):
      for y in range(P.shape[2]):
        vars = []
        for x1 in range(P.shape[0]):
          vars.append(Q[y][x1, x2])
        B_cstrs.append(cp.sum(vars) == Px2y[x2,y])

  # KL divergence
  Q_pdt_dist_cstrs = [cp.sum(Q) / P.shape[2] == Q_x1x2[i] for i in range(P.shape[2])]


  # objective
  obj = cp.sum([cp.sum(cp.rel_entr(Q[i], Q_x1x2[i])) for i in range(P.shape[2])])
  # print(obj.shape)
  all_constrs = [sum_to_one_Q] + A_cstrs + B_cstrs + Q_pdt_dist_cstrs
  prob = cp.Problem(cp.Minimize(obj), all_constrs)
  prob.solve(verbose=False, max_iter=50000)

  # print(prob.status)
  # print(prob.value)
  # for j in range(P.shape[1]):
  #  print(Q[j].value)

  return np.stack([q.value for q in Q],axis=2)

In [5]:
def gen_binary_data(num_data):
  # 00  0
  # 01  0
  # 10  0
  # 11  1

  x1 = np.random.randint(0, 2, (num_data, 1))
  x2 = np.random.randint(0, 2, (num_data, 1))
  data = {
      'and': (x1, x2, 1 * np.logical_and(x1, x2)),
      'or': (x1, x2, 1 * np.logical_or(x1, x2)),
      'xor': (x1, x2, 1 * np.logical_xor(x1, x2)),
      'unique1': (x1, x2, x1),
      'redundant': (x1, x1, x1),
      'redundant_and_unique1': (np.concatenate([x1, x2], axis=1), x2, 1 * np.logical_and(x1, x2)),
      'redundant_or_unique1': (np.concatenate([x1, x2], axis=1), x2, 1 * np.logical_or(x1, x2)),
      'redundant_xor_unique1': (np.concatenate([x1, x2], axis=1), x2, 1 * np.logical_xor(x1, x2)),
  }
  return data

def convert_data_to_distribution(x1: np.ndarray, x2: np.ndarray, y: np.ndarray):
  assert x1.size == x2.size
  assert x1.size == y.size

  numel = x1.size
  
  x1_discrete, x1_raw_to_discrete = extract_categorical_from_data(x1.squeeze())
  x2_discrete, x2_raw_to_discrete = extract_categorical_from_data(x2.squeeze())
  y_discrete, y_raw_to_discrete = extract_categorical_from_data(y.squeeze())

  joint_distribution = np.zeros((len(x1_raw_to_discrete), len(x2_raw_to_discrete), len(y_raw_to_discrete)))
  for i in range(numel):
    joint_distribution[x1_discrete[i], x2_discrete[i], y_discrete[i]] += 1
  joint_distribution /= np.sum(joint_distribution)

  return joint_distribution, (x1_raw_to_discrete, x2_raw_to_discrete, y_raw_to_discrete)

def extract_categorical_from_data(x):
  supp = set(x)
  raw_to_discrete = dict()
  for i in supp:
    raw_to_discrete[i] = len(raw_to_discrete)
  discrete_data = [raw_to_discrete[x_] for x_ in x]

  return discrete_data, raw_to_discrete 

def MI(P: np.ndarray):
  ''' P has 2 dimensions '''
  margin_1 = P.sum(axis=1)
  margin_2 = P.sum(axis=0)
  outer = np.outer(margin_1, margin_2)

  return np.sum(rel_entr(P, outer))
  # return np.sum(P * np.log(P/outer))

def CoI(P:np.ndarray):
  ''' P has 3 dimensions, in order X1, X2, Y '''
  # MI(Y; X1)
  A = P.sum(axis=1)

  # MI(Y; X2)
  B = P.sum(axis=0)

  # MI(Y; (X1, X2))
  C = P.transpose([2, 0, 1]).reshape((P.shape[2], P.shape[0]*P.shape[1]))

  return MI(A) + MI(B) - MI(C)

def CI(P, Q):
  assert P.shape == Q.shape
  P_ = P.transpose([2, 0, 1]).reshape((P.shape[2], P.shape[0]*P.shape[1]))
  Q_ = Q.transpose([2, 0, 1]).reshape((Q.shape[2], Q.shape[0]*Q.shape[1]))
  return MI(P_) - MI(Q_)

def UI(P, cond_id=0):
  ''' P has 3 dimensions, in order X1, X2, Y 
  We condition on X1 if cond_id = 0, if 1, then X2.
  '''
  sum = 0.

  if cond_id == 0:
    J= P.sum(axis=(1,2)) # marginal of x1
    for i in range(P.shape[0]):
      sum += MI(P[i,:,:]/P[i,:,:].sum()) * J[i]
  elif cond_id == 1:
    J= P.sum(axis=(0,2)) # marginal of x1
    for i in range(P.shape[1]):
      sum += MI(P[:,i,:]/P[:,i,:].sum()) * J[i]
  else:
    assert False

  return sum

def test(P):
  Q = solve_Q_new(P)
  redundancy = CoI(Q)
  print('Redundancy', redundancy)
  unique_1 = UI(Q, cond_id=1)
  print('Unique', unique_1)
  unique_2 = UI(Q, cond_id=0)
  print('Unique', unique_2)
  synergy = CI(P, Q)
  print('Synergy', synergy)
  return {'redundancy':redundancy, 'unique1':unique_1, 'unique2':unique_2, 'synergy':synergy}

In [15]:
P = np.zeros((2,2,2))
P[:,:,0] = np.eye(2) * 0.25
P[:,:,1] = np.array([[0., 0.25], [0.25, 0.]])
test(P)

Redundancy 3.3705091871802423e-09
Unique 1.1187109610023148e-16
Unique 1.1187109610023148e-16
Synergy 0.6931471771894356


{'redundancy': np.float64(3.3705091871802423e-09),
 'unique1': np.float64(1.1187109610023148e-16),
 'unique2': np.float64(1.1187109610023148e-16),
 'synergy': np.float64(0.6931471771894356)}

In [18]:
P

array([[[0.25, 0.  ],
        [0.  , 0.25]],

       [[0.  , 0.25],
        [0.25, 0.  ]]])

In [8]:
data = gen_binary_data(100000)
P, maps = convert_data_to_distribution(*data['xor'])
test(P)

Redundancy 9.98812088620106e-06
Unique 2.2337941696321481e-10
Unique 1.0791333733425263e-05
Synergy 0.6931253926816073


{'redundancy': np.float64(9.98812088620106e-06),
 'unique1': np.float64(2.2337941696321481e-10),
 'unique2': np.float64(1.0791333733425263e-05),
 'synergy': np.float64(0.6931253926816073)}

In [10]:
P.shape

(2, 2, 2)

In [11]:
data = gen_binary_data(1000000)
P, maps = convert_data_to_distribution(*data['and'])
test(P)

Redundancy 0.2150743943150804
Unique 1.169119897466959e-08
Unique 0.0004256690572395108
Synergy 0.3458022230173096


{'redundancy': np.float64(0.2150743943150804),
 'unique1': np.float64(1.169119897466959e-08),
 'unique2': np.float64(0.0004256690572395108),
 'synergy': np.float64(0.3458022230173096)}

In [13]:
P.shape

(2, 2, 2)

In [19]:
P = np.random.uniform(size=(5,4,3))
P = P / np.sum(P)
test(P)

Redundancy 0.005723754879699304
Unique 0.001233640614499781
Unique 0.028903272546529815
Synergy 0.07845121016645373


{'redundancy': np.float64(0.005723754879699304),
 'unique1': np.float64(0.001233640614499781),
 'unique2': np.float64(0.028903272546529815),
 'synergy': np.float64(0.07845121016645373)}

In [21]:
P.shape

(5, 4, 3)

In [25]:
HR_flag = np.array([1, 0, 1, 0])    # 1=HR high, 0=normal
BP_flag = np.array([1, 1, 0, 0])    # 1=BP low, 0=normal
Event   = np.array([1, 0, 0, 0])    # 1=shock event, 0=no event
data = (BP_flag, HR_flag, Event)
P, maps = convert_data_to_distribution(*data)
test(P)

Redundancy 0.2157615011398003
Unique 2.9101687188058173e-08
Unique 2.9101687186610088e-08
Synergy 0.34657358527563387


{'redundancy': np.float64(0.2157615011398003),
 'unique1': np.float64(2.9101687188058173e-08),
 'unique2': np.float64(2.9101687186610088e-08),
 'synergy': np.float64(0.34657358527563387)}

In [26]:
HR_flag = np.array([1, 0, 1, 0])    # 1=HR high, 0=normal
BP_flag = np.array([1, 1, 0, 0])    # 1=BP low, 0=normal
Event   = np.array([1, 0, 0, 0])    # 1=shock event, 0=no event
data = (HR_flag, BP_flag, Event)
P, maps = convert_data_to_distribution(*data)
test(P)

Redundancy 0.2157615011398003
Unique 2.9101687188058173e-08
Unique 2.9101687186610088e-08
Synergy 0.34657358527563387


{'redundancy': np.float64(0.2157615011398003),
 'unique1': np.float64(2.9101687188058173e-08),
 'unique2': np.float64(2.9101687186610088e-08),
 'synergy': np.float64(0.34657358527563387)}

In [1]:
import math, random
from collections import Counter

# Functions to sample synthetic data for each structure
def sample_chain(n=10000, p_noise12=0.1, p_noise2Y=0.1):
    """Chain: X1 -> X2 -> Y (with optional noise in X1->X2 and X2->Y links)."""
    samples = []
    for _ in range(n):
        x1 = random.randint(0, 1)               # X1 (root cause)
        # X2 is X1 (possibly flipped with noise probability p_noise12)
        x2 = x1 if random.random() > p_noise12 else 1 - x1
        # Y is X2 (possibly flipped with noise probability p_noise2Y)
        y  = x2 if random.random() > p_noise2Y else 1 - x2
        samples.append((x1, x2, y))
    return samples

def sample_fork(n=10000, p_noise2X1=0.1, p_noise2Y=0.1):
    """Fork: X2 -> X1 and X2 -> Y (X2 is common cause of X1 and Y)."""
    samples = []
    for _ in range(n):
        x2 = random.randint(0, 1)               # X2 (common cause)
        # X1 and Y are copies of X2 (with noise on each link)
        x1 = x2 if random.random() > p_noise2X1 else 1 - x2
        y  = x2 if random.random() > p_noise2Y else 1 - x2
        samples.append((x1, x2, y))
    return samples

def sample_vstructure(n=10000, mode="xor", p_noise=0.0):
    """
    V-structure: X1 -> Y <- X2 with X1, X2 independent.
    mode "xor": Y = X1 XOR X2 (pure synergy).
    mode "or":  Y = X1 OR X2 (partial synergy).
    """
    samples = []
    for _ in range(n):
        x1 = random.randint(0, 1)
        x2 = random.randint(0, 1)
        # Compute Y based on independent contributions of X1 and X2
        if mode == "xor":
            y_true = x1 ^ x2           # XOR yields Y=1 only if an odd number of inputs are 1
        elif mode == "or":
            y_true = 1 if (x1 == 1 or x2 == 1) else 0  # OR yields Y=1 if any input is 1
        else:
            y_true = x1 ^ x2           # default to XOR
        # Optionally flip Y with noise
        y = y_true if random.random() > p_noise else 1 - y_true
        samples.append((x1, x2, y))
    return samples

# Function to compute mutual informations and PID R, U1, U2, S
def compute_pid(measurements):
    """
    Compute mutual information quantities and PID (R, U1, U2, S) 
    using the minimum mutual information heuristic for redundancy.
    Returns a dict of computed values.
    """
    N = len(measurements)
    # Frequency counts
    count_xyz = Counter(measurements)
    count_x1y = Counter((x1, y) for x1, x2, y in measurements)
    count_x2y = Counter((x2, y) for x1, x2, y in measurements)
    count_x1  = Counter(x1 for x1, x2, y in measurements)
    count_x2  = Counter(x2 for x1, x2, y in measurements)
    count_y   = Counter(y  for x1, x2, y in measurements)
    count_x1x2 = Counter((x1, x2) for x1, x2, y in measurements)

    # Helper to compute entropy from a Counter
    def entropy(counts):
        H = 0.0
        for _, c in counts.items():
            p = c / N
            if p > 0:
                H -= p * math.log(p, 2)
        return H

    # Entropies needed for triple mutual information
    H_x1   = entropy(count_x1)
    H_x2   = entropy(count_x2)
    H_y    = entropy(count_y)
    H_x1x2 = entropy(count_x1x2)
    H_x1y  = entropy(count_x1y)
    H_x2y  = entropy(count_x2y)
    H_x1x2y = entropy(count_xyz)

    # Mutual informations
    I_x1_y   = H_x1 + H_y - H_x1y                 # I(X1;Y)
    I_x2_y   = H_x2 + H_y - H_x2y                 # I(X2;Y)
    I_x1x2_y = H_x1x2 + H_y - H_x1x2y             # I(X1,X2;Y) 
    I_triple = H_x1 + H_x2 + H_y - H_x1x2 - H_x1y - H_x2y + H_x1x2y  # I(X1;X2;Y)

    # Standard PID decomposition (Williams & Beer minimal overlap heuristic)
    R  = min(I_x1_y, I_x2_y)            # assume redundancy = smaller individual MI
    U1 = I_x1_y - R                    # unique info in X1
    U2 = I_x2_y - R                    # unique info in X2
    S  = I_x1x2_y - (R + U1 + U2)      # synergy = total info minus accounted parts

    return {
        'I(X1;Y)': I_x1_y, 
        'I(X2;Y)': I_x2_y, 
        'I(X1,X2;Y)': I_x1x2_y, 
        'I(X1;X2;Y)': I_triple, 
        'R': R, 'U1': U1, 'U2': U2, 'S': S
    }

# Generate example data for each causal structure
random.seed(42)  # for reproducibility
data_chain = sample_chain(n=10000, p_noise12=0.1, p_noise2Y=0.1)
data_fork  = sample_fork(n=10000, p_noise2X1=0.1, p_noise2Y=0.1)
data_v_xor = sample_vstructure(n=10000, mode="xor", p_noise=0.0)
data_v_or  = sample_vstructure(n=10000, mode="or", p_noise=0.0)

# Compute PID values for each
pid_chain = compute_pid(data_chain)
pid_fork  = compute_pid(data_fork)
pid_v_xor = compute_pid(data_v_xor)
pid_v_or  = compute_pid(data_v_or)

# Print results for comparison
print("Chain structure PID:", pid_chain)
print("Fork structure PID: ", pid_fork)
print("V-structure (XOR) PID:", pid_v_xor)
print("V-structure (OR) PID: ", pid_v_or)

Chain structure PID: {'I(X1;Y)': 0.3093620837512343, 'I(X2;Y)': 0.5275337297323008, 'I(X1,X2;Y)': 0.5275630150785158, 'I(X1;X2;Y)': 0.30933279840501937, 'R': 0.3093620837512343, 'U1': 0.0, 'U2': 0.21817164598106653, 'S': 2.9285346214935615e-05}
Fork structure PID:  {'I(X1;Y)': 0.323052832152251, 'I(X2;Y)': 0.5338501504233331, 'I(X1,X2;Y)': 0.5342164420634503, 'I(X1;X2;Y)': 0.32268654051213375, 'R': 0.323052832152251, 'U1': 0.0, 'U2': 0.21079731827108206, 'S': 0.00036629164011725557}
V-structure (XOR) PID: {'I(X1;Y)': 1.3968909156059084e-05, 'I(X2;Y)': 3.5760985284127855e-09, 'I(X1,X2;Y)': 0.9998153271549208, 'I(X1;X2;Y)': -0.9998013546696662, 'R': 3.5760985284127855e-09, 'U1': 1.3965333057530671e-05, 'U2': 0.0, 'S': 0.9998013582457648}
V-structure (OR) PID:  {'I(X1;Y)': 0.3037191732646174, 'I(X2;Y)': 0.31425939867150343, 'I(X1,X2;Y)': 0.8125436340078087, 'I(X1;X2;Y)': -0.19456506207168767, 'R': 0.3037191732646174, 'U1': 0.0, 'U2': 0.010540225406886039, 'S': 0.4982842353363053}


In [2]:
def adjust_pid_for_causality(pid_values):
    """Adjust R, S based on the sign of triple interaction information."""
    I3 = pid_values['I(X1;X2;Y)']
    R, S = pid_values['R'], pid_values['S']
    # If positive co-info (redundancy case), increase R (at least I3) and reduce S
    if I3 > 1e-6:
        R_adj = max(R, I3)      # ensure redundancy covers the overlap indicated by I3
        S_adj = 0.0            # minimize synergy
    # If negative co-info (synergy case), increase S (at least |I3|) and reduce R
    elif I3 < -1e-6:
        R_adj = 0.0
        S_adj = max(S, -I3)    # ensure synergy at least the magnitude of negative I3
    else:
        # If I3 ~ 0, structure is ambiguous or purely additive – keep as is
        R_adj, S_adj = R, S
    # Recompute unique infos with adjusted R (keeping total info same)
    I_x1y, I_x2y, I_tot = pid_values['I(X1;Y)'], pid_values['I(X2;Y)'], pid_values['I(X1,X2;Y)']
    U1_adj = max(0.0, I_x1y - R_adj)  # cannot be negative
    U2_adj = max(0.0, I_x2y - R_adj)
    # (If R was lowered significantly, some info might become unique.)
    return {'R': R_adj, 'U1': U1_adj, 'U2': U2_adj, 'S': S_adj}

# Adjust PID for each scenario
print("Adjusted Chain PID:", adjust_pid_for_causality(pid_chain))
print("Adjusted Fork PID: ", adjust_pid_for_causality(pid_fork))
print("Adjusted V-XOR PID:", adjust_pid_for_causality(pid_v_xor))
print("Adjusted V-OR PID: ", adjust_pid_for_causality(pid_v_or))

Adjusted Chain PID: {'R': 0.3093620837512343, 'U1': 0.0, 'U2': 0.21817164598106653, 'S': 0.0}
Adjusted Fork PID:  {'R': 0.323052832152251, 'U1': 0.0, 'U2': 0.21079731827108206, 'S': 0.0}
Adjusted V-XOR PID: {'R': 0.0, 'U1': 1.3968909156059084e-05, 'U2': 3.5760985284127855e-09, 'S': 0.9998013582457648}
Adjusted V-OR PID:  {'R': 0.0, 'U1': 0.3037191732646174, 'U2': 0.31425939867150343, 'S': 0.4982842353363053}


In [1]:
import numpy as np
from sklearn.cluster import KMeans
# Assume BATCH estimator is available from the document's codebase
# Replace with actual implementation if accessible
class BATCH:
    def estimate(self, X1, X2, Y):
        # Placeholder: returns standard PID values (R, U1, U2, S)
        # In practice, use the neural network-based estimator from the document
        I_X1_Y = self.mutual_info(X1, Y)
        I_X2_Y = self.mutual_info(X2, Y)
        I_X1_X2_Y = self.mutual_info(np.column_stack((X1, X2)), Y)
        R = min(I_X1_Y, I_X2_Y)  # Simplified approximation
        U1 = max(0, I_X1_Y - R)
        U2 = max(0, I_X2_Y - R)
        S = I_X1_X2_Y - I_X1_Y - I_X2_Y + R
        return {'R': R, 'U1': U1, 'U2': U2, 'S': S}
    
    def mutual_info(self, X, Y):
        # Placeholder for mutual information estimation
        # Use histogram-based or neural estimation in practice
        return np.random.uniform(0.1, 1.0)  # Dummy value

# Data generation functions
def generate_chain(n_samples=1000):
    X1 = np.random.normal(0, 1, n_samples)
    X2 = X1 + np.random.normal(0, 0.1, n_samples)  # X2 depends on X1
    Y = X2 + np.random.normal(0, 0.1, n_samples)   # Y depends on X2
    return X1, X2, Y

def generate_fork(n_samples=1000):
    Z = np.random.normal(0, 1, n_samples)
    X1 = Z + np.random.normal(0, 0.1, n_samples)
    X2 = Z + np.random.normal(0, 0.1, n_samples)
    Y = Z + np.random.normal(0, 0.1, n_samples)
    return X1, X2, Y

def generate_v_structure(n_samples=1000):
    X1 = np.random.normal(0, 1, n_samples)
    X2 = np.random.normal(0, 1, n_samples)
    Y = X1 * X2 + np.random.normal(0, 0.1, n_samples)  # Synergistic relationship
    return X1, X2, Y

# Function to estimate mutual information (simplified)
def estimate_mi(X, Y, bins=10):
    # Discretize for simplicity; use kernel density or neural methods in practice
    X_discrete = np.digitize(X, np.linspace(X.min(), X.max(), bins))
    Y_discrete = np.digitize(Y, np.linspace(Y.min(), Y.max(), bins))
    joint = np.histogram2d(X_discrete, Y_discrete, bins=bins)[0] + 1e-10
    px = joint.sum(axis=1) / joint.sum()
    py = joint.sum(axis=0) / joint.sum()
    joint /= joint.sum()
    mi = np.sum(joint * np.log2(joint / (px[:, None] * py[None, :])))
    return max(0, mi)

# Compute new PID quantities
def compute_new_pid(X1, X2, Y, pid):
    epsilon = 0.01
    # Estimate conditional mutual informations
    I_X1_Y_given_X2 = estimate_mi(X1, Y) - estimate_mi(X2, Y)  # Approximation
    I_X2_Y_given_X1 = estimate_mi(X2, Y) - estimate_mi(X1, Y)  # Approximation
    I_X1_X2 = estimate_mi(X1, X2)
    
    # Adjust for negative values
    I_X1_Y_given_X2 = max(0, I_X1_Y_given_X2)
    I_X2_Y_given_X1 = max(0, I_X2_Y_given_X1)
    
    R_prime = pid['R']
    U1_prime = pid['U1'] * (I_X1_Y_given_X2 / max(I_X1_Y_given_X2, epsilon))
    U2_prime = pid['U2'] * (I_X2_Y_given_X1 / max(I_X2_Y_given_X1, epsilon))
    S_prime = pid['S'] * (1 - I_X1_X2 / max(I_X1_X2, epsilon))
    
    return {'R': R_prime, 'U1': U1_prime, 'U2': U2_prime, 'S': S_prime}

# Main execution
structures = {
    'Chain': generate_chain(),
    'Fork': generate_fork(),
    'V-Structure': generate_v_structure()
}

estimator = BATCH()

for name, (X1, X2, Y) in structures.items():
    # Discretize Y for BATCH (assumes discrete labels)
    Y_discrete = KMeans(n_clusters=10).fit(Y.reshape(-1, 1)).labels_
    
    # Compute standard PID
    pid_standard = estimator.estimate(X1, X2, Y_discrete)
    print(f"\nStandard PID for {name}:")
    print(pid_standard)
    
    # Compute new PID
    pid_new = compute_new_pid(X1, X2, Y, pid_standard)
    print(f"New PID for {name}:")
    print(pid_new)



Standard PID for Chain:
{'R': 0.2604879092178548, 'U1': 0.4740577543299681, 'U2': 0, 'S': -0.2537549731827138}
New PID for Chain:
{'R': 0.2604879092178548, 'U1': 0.0, 'U2': np.float64(0.0), 'S': np.float64(-0.0)}

Standard PID for Fork:
{'R': 0.15018877034547537, 'U1': 0.8274647506866232, 'U2': 0, 'S': -0.8716422848647978}
New PID for Fork:
{'R': 0.15018877034547537, 'U1': np.float64(0.6668949196706878), 'U2': 0.0, 'S': np.float64(-0.0)}

Standard PID for V-Structure:
{'R': 0.3394560533387343, 'U1': 0, 'U2': 0.24958022751542108, 'S': 0.2623282679730138}
New PID for V-Structure:
{'R': 0.3394560533387343, 'U1': 0.0, 'U2': np.float64(0.24958022751542108), 'S': np.float64(0.0)}


In [3]:
from graphviz import Digraph

dot = Digraph(comment='Temporal RUS-Guided MoE Architecture for Multimodal EHR', format='png')
dot.attr(rankdir='TB', fontsize="10")

# Input modalities
dot.node('A', 'Multimodal EHR Data', shape='box', style='filled', color='lightblue')
dot.node('B1', 'Vital Signs', shape='box')
dot.node('B2', 'Clinical Notes', shape='box')
dot.node('B3', 'Medical Imaging', shape='box')
dot.node('B4', 'ECG', shape='box')

dot.edge('A', 'B1')
dot.edge('A', 'B2')
dot.edge('A', 'B3')
dot.edge('A', 'B4')

# Modality-specific encoders
dot.node('C1', 'Encoder: CNN/RNN\n(Time-Series)', shape='box')
dot.node('C2', 'Encoder: Transformer/LSTM\n(Text)', shape='box')
dot.node('C3', 'Encoder: CNN/ViT\n(Imaging)', shape='box')
dot.node('C4', 'Encoder: 1D-CNN/Recurrent\n(ECG)', shape='box')

dot.edge('B1', 'C1')
dot.edge('B2', 'C2')
dot.edge('B3', 'C3')
dot.edge('B4', 'C4')

# Fusing modality representations
dot.node('D', 'Fused Representations\nper Time Step', shape='box', style='filled', color='lightyellow')
dot.edge('C1', 'D')
dot.edge('C2', 'D')
dot.edge('C3', 'D')
dot.edge('C4', 'D')

# RUS Estimation module and its outputs
dot.node('E', 'RUS Estimation\nModule', shape='box', style='filled', color='lightgrey')
dot.node('F', '(R, U, S) Scores\n(Redundancy,\nUniqueness, Synergy)', shape='box', style='filled', color='lightgrey')
dot.edge('D', 'E')
dot.edge('E', 'F')

# RUS-Guided Gating module
dot.node('G', 'RUS-Guided Gating\nModule', shape='box', style='filled', color='orange')
dot.edge('D', 'G')
dot.edge('F', 'G')

# Top-K Expert Selection
dot.node('H', 'Top-K Expert\nSelection (Sparse Routing)', shape='box', style='filled', color='palegreen')
dot.edge('G', 'H')

# Expert Modules
dot.node('I1', 'Redundancy Experts', shape='box')
dot.node('I2', 'Uniqueness Experts', shape='box')
dot.node('I3', 'Synergy Experts', shape='box')
dot.node('I4', 'Universal Experts', shape='box')

dot.edge('H', 'I1')
dot.edge('H', 'I2')
dot.edge('H', 'I3')
dot.edge('H', 'I4')

# Aggregation of Expert Outputs
dot.node('J', 'Expert Output\nAggregation\n(Weighted Fusion)', shape='box', style='filled', color='lightpink')
dot.edge('I1', 'J')
dot.edge('I2', 'J')
dot.edge('I3', 'J')
dot.edge('I4', 'J')

# Temporal Integration Module
dot.node('K', 'Temporal Integration Module\n(Cross-Time Attention,\nRecurrent Fusion)', shape='box', style='filled', color='wheat')
dot.edge('J', 'K')

# Prediction Head and Output
dot.node('L', 'Prediction Head', shape='box', style='filled', color='lightblue')
dot.edge('K', 'L')
dot.node('M', 'Output Prediction\n(Diagnosis/Outcome)', shape='box', style='filled', color='lightblue')
dot.edge('L', 'M')

# Loss Functions (auxiliary connections)
dot.node('N', 'Main Task Loss\n(Cross-Entropy/MSE)', shape='note', color='grey')
dot.edge('M', 'N', style='dashed')

dot.node('O', 'Auxiliary Losses\n(Uniqueness, Redundancy, Synergy)', shape='note', color='grey')
dot.edge('G', 'O', style='dashed')

dot.node('P', 'Temporal Consistency Loss', shape='note', color='grey')
dot.edge('K', 'P', style='dashed')

# Optional: Batch Priority Routing
dot.node('Q', 'Batch Priority Routing\n(Optional)', shape='box', style='dotted', color='darkgreen')
dot.edge('G', 'Q', style='dotted')
dot.edge('Q', 'H', style='dotted')

# Render and save diagram to a file
dot.render('Temporal_RUS_Guided_MoE_Architecture', view=True)

'Temporal_RUS_Guided_MoE_Architecture.png'

Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)


In [2]:
# Import necessary libraries from the 'diagrams' package
# Note: You might need to install it: pip install diagrams
# Also requires Graphviz to be installed: https://graphviz.org/download/
from diagrams import Diagram, Cluster, Node, Edge

# Define the diagram context
# filename specifies the output file name (default is png)
# show=False prevents the diagram from opening automatically
with Diagram("TRUSMoEModel_LargeScale Architecture", show=False, filename="trus_moe_large_scale_arch", direction="TB"):

    # --- Input Nodes ---
    token_emb = Node("Input Token Embeddings\n(B, M, T, E_in)")
    rus_values = Node("RUS Values\n(U, R, S)")

    # --- Processing Blocks ---
    with Cluster("1. Input Processing"):
        input_proj = Node("Input Projection\n(Linear + Scale)")
        flatten = Node("Flatten\n(B, M*T, d_model)")
        pos_enc = Node("Positional Encoding")
        input_processing_group = [input_proj, flatten, pos_enc] # Group for layout

    with Cluster("2. Encoder Stack (N Layers)"):
        # Represent the first block (can be either type)
        with Cluster("Layer 1 (e.g., Transformer Block)"):
            tf_mhsa1 = Node("MHSA")
            tf_addnorm1_1 = Node("Add & Norm")
            tf_ffn1 = Node("FFN")
            tf_addnorm1_2 = Node("Add & Norm")
            transformer_block_1 = [tf_mhsa1, tf_addnorm1_1, tf_ffn1, tf_addnorm1_2]

        # Represent the second block (can be the other type)
        with Cluster("Layer 2 (e.g., TRUS-MoE Block)"):
            moe_mhsa2 = Node("MHSA")
            moe_addnorm2_1 = Node("Add & Norm")
            moe_layer2 = Node("TemporalRUSMoELayer\n(RUS-Aware Router + Experts)")
            moe_addnorm2_2 = Node("Add & Norm")
            moe_block_2 = [moe_mhsa2, moe_addnorm2_1, moe_layer2, moe_addnorm2_2]

        # Indicate repetition
        stack_ellipsis = Node("...", shape="plaintext")


    with Cluster("3. Output Processing"):
        final_norm = Node("Final LayerNorm")
        aggregate = Node("Aggregation\n(e.g., Mean Pooling over Seq)")
        output_proj = Node("Output Projection\n(Linear)")
        output_processing_group = [final_norm, aggregate, output_proj]

    # --- Output Node ---
    final_logits = Node("Final Logits\n(B, num_classes)")
    aux_outputs = Node("MoE Aux Outputs\n(List from MoE Layers)") # Represent aux outputs

    # --- Define Data Flow ---
    # Input Processing
    token_emb >> input_proj >> flatten >> pos_enc

    # Into Encoder Stack
    pos_enc >> tf_mhsa1 # Connect to the first block's input

    # Flow through Transformer Block 1
    tf_mhsa1 >> tf_addnorm1_1 >> tf_ffn1 >> tf_addnorm1_2

    # Flow from Block 1 to Block 2
    tf_addnorm1_2 >> moe_mhsa2

    # Flow through TRUS-MoE Block 2
    moe_mhsa2 >> moe_addnorm2_1 >> moe_layer2 >> moe_addnorm2_2
    # Show RUS values feeding into the MoE layer specifically
    rus_values >> Edge(color="darkgreen", style="dashed") >> moe_layer2
    # Show Aux outputs coming from MoE layer
    moe_layer2 >> Edge(color="blue", style="dashed") >> aux_outputs

    # Ellipsis indicating more layers
    moe_addnorm2_2 >> stack_ellipsis

    # Out of Stack to Output Processing
    stack_ellipsis >> final_norm # Connect from ellipsis to final processing

    # Output Processing Flow
    final_norm >> aggregate >> output_proj >> final_logits

