<a href="https://colab.research.google.com/github/adimunot21/FPD-OT/blob/main/FPD_OT_Codebase.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Scenario 1: Univariate Supply and Univariate Demand (1D1D)
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import poisson

# ---------------------------
# 1. Define Distributions
# ---------------------------
n = 50  # Number of discrete points (bins) for supply/demand quantities
x_quantities = np.arange(n)  # Support {0, ..., n-1}: possible supply k / demand j

# Supply is a two-component Poisson mixture.
mu_supply = 15   # Mean of the first supply component
mu2_supply = 30  # Mean of the second supply component
w_supply = 0.4   # Mixture weight of the first supply component

# Demand is a two-component Poisson mixture.
mu_demand = 20   # Mean of the first demand component
mu2_demand = 40  # Mean of the second demand component
w_demand = 0.7   # Mixture weight of the first demand component

# Mixture PMFs evaluated on the truncated support. The weights live in their
# own names so that p_supply / p_demand are only ever arrays (re-running these
# lines in a notebook no longer feeds an array back in as the weight).
p_supply = w_supply * poisson.pmf(x_quantities, mu_supply) + (1 - w_supply) * poisson.pmf(x_quantities, mu2_supply)   # P(Supply = k)

p_demand = w_demand * poisson.pmf(x_quantities, mu_demand) + (1 - w_demand) * poisson.pmf(x_quantities, mu2_demand)   # P(Demand = j)

# Renormalize in case the truncation on [0, n-1] omits some mass
p_supply /= p_supply.sum()
p_demand /= p_demand.sum()

# ---------------------------
# 2. Build the Cost Matrix and Regularization Kernel
# ---------------------------
# Define cost parameters
# Unit costs: per unit of supply above demand, per unit of demand above supply,
# and per unit of supply above the threshold level.
cost_oversupply_co, cost_undersupply_cu = 1, 5
supply_threshold, cost_above_threshold = 25, 5

# Index grids: I[k, j] = k (supply), J[k, j] = j (demand).
I, J = np.meshgrid(x_quantities, x_quantities, indexing='ij')

# Asymmetric mismatch penalty on the signed gap supply - demand.
surplus = I - J
mismatch_cost = (
    cost_oversupply_co * np.clip(surplus, 0, None)
    + cost_undersupply_cu * np.clip(-surplus, 0, None)
)

# Extra penalty that depends on the supply index only.
threshold_cost = cost_above_threshold * np.clip(I - supply_threshold, 0, None)

# Total cost M[k, j], plus a copy rescaled so its maximum is 1.
M = mismatch_cost + threshold_cost
M_scaled = M / M.max()

# Entropic regularization strength used in the kernel below.
epsilon = 0.1

# Kernel: uniform reference phi tilted by exp(-cost / epsilon), then
# normalized to total mass one (in standard notation this is K).
phi = np.full((n, n), 1.0 / (n * n))
pi_I = phi * np.exp(-M_scaled / epsilon)
pi_I = pi_I / pi_I.sum()

# ---------------------------
# 3. Sinkhorn Iterations
# ---------------------------
u = np.ones(n)  # Scaling factors for supply margin
v = np.ones(n)  # Scaling factors for demand margin

numIter = 1000
tol = 1e-9

# Alternate the two marginal projections. The stopping rule only monitors the
# change in u between consecutive sweeps.
for _ in range(numIter):
    u_prev = u.copy()
    # Update v to match the demand margin
    v = p_demand / (pi_I.T @ u)
    # Update u to match the supply margin
    u = p_supply / (pi_I @ v)

    if np.linalg.norm(u - u_prev) < tol:
        break

# Coupling P(k, j) = u[k] * pi_I[k, j] * v[j]. Row/column scaling by
# broadcasting gives the same matrix as diag(u) @ pi_I @ diag(v) without
# building two dense diagonal matrices and running two full matmuls.
P_ot = u[:, None] * pi_I * v[None, :]

# ---------------------------
# 4. Compute and Print the Final Transport Cost
# ---------------------------
# The final cost is the expected cost under the coupling, computed with the original (unscaled) cost matrix M.
final_cost = np.sum(P_ot * M)
print(f"Final Expected Transport Cost (E[C(k,j)]): {final_cost:.4f}")

# ---------------------------
# 5. Plot Heatmaps for the Cost Metric and OT Coupling
# ---------------------------
# Side-by-side heatmaps. Both matrices are transposed so the supply index
# runs along the horizontal axis and the demand index along the vertical one.
fig, axs = plt.subplots(1, 2, figsize=(14, 5))
ax_cost, ax_plan = axs

im0 = ax_cost.imshow(M_scaled.T, cmap='viridis', origin='lower', aspect='auto')
ax_cost.set(title="Cost Metric", xlabel="Supply Index (x)", ylabel="Demand Index (y)")
fig.colorbar(im0, ax=ax_cost, fraction=0.046, pad=0.04)

im1 = ax_plan.imshow(P_ot.T, cmap='plasma', origin='lower', aspect='auto')
ax_plan.set(title="OT Coupling", xlabel="Supply Index (x)", ylabel="Demand Index (y)")
fig.colorbar(im1, ax=ax_plan, fraction=0.046, pad=0.04)

fig.suptitle("Cost Metric & OT Coupling")
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

# ---------------------------
# 6. Plot Histograms for the Distributions
# ---------------------------
fig2, axs2 = plt.subplots(1, 2, figsize=(12, 5))

for ax_hist, pmf, bar_colour, side in zip(
    axs2, (p_supply, p_demand), ('blue', 'red'), ('Supply', 'Demand')
):
    ax_hist.bar(x_quantities, pmf, color=bar_colour, alpha=0.7)
    ax_hist.set_title(f"{side} Distribution")
    ax_hist.set_xlabel(f"{side} Quantity")
    ax_hist.set_ylabel("Probability")

fig2.tight_layout()
plt.show()

# ---------------------------
# 7. Slice the OT Plan for a Specific Demand Index and Plot the Histogram
# ---------------------------
# Choose a specific demand index 'j' to analyze the optimal supply distribution P(k | j)
# Demand level whose column of the plan is inspected (valid range 0..n-1).
demand_index_j = 25  # Adjust as desired (must be between 0 and n-1)

if not (0 <= demand_index_j < n):
    print(f"Error: demand_index_j={demand_index_j} is out of bounds (0 to {n-1}).")
else:
    # Column j of the joint plan divided by the demand marginal gives
    # P(k | j); a second division by the column total removes residual drift.
    plan_column = P_ot[:, demand_index_j] / p_demand[demand_index_j]
    conditional_supply_dist = plan_column / plan_column.sum()

    # Summary statistics of the conditional law.
    expected_value = (x_quantities * conditional_supply_dist).sum()
    mode = conditional_supply_dist.argmax()
    cumulative = conditional_supply_dist.cumsum()
    median_idx = cumulative.searchsorted(0.5)  # first index with CDF >= 0.5

    print(f"Expected Supply (E[k | j={demand_index_j}]): {expected_value:.2f}")
    print(f"Mode of Supply for j={demand_index_j}: {mode}")
    print(f"Median of Supply for j={demand_index_j}: {median_idx}")

    # Bar chart of the conditional supply distribution.
    plt.figure(figsize=(7, 5))
    plt.bar(x_quantities, conditional_supply_dist, color='green', alpha=0.7)
    # plt.title(f"Optimal Supply Distribution given Demand y={demand_index_j}")
    # plt.xlabel("Supply Quantity")
    # plt.ylabel("Conditional Probability P(x | y)")
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

# ---------------------------
# NEW: 7. Verify Marginals
# ---------------------------
# Marginals implied by the plan: row sums (supply) and column sums (demand).
induced_p_supply = P_ot.sum(axis=1)
induced_p_demand = P_ot.sum(axis=0)

# (side label, prescribed marginal, induced marginal, bar colour, line format)
marginal_checks = (
    ("Supply", p_supply, induced_p_supply, 'blue', 'ro-'),
    ("Demand", p_demand, induced_p_demand, 'red', 'go-'),
)

fig3, axs3 = plt.subplots(1, 2, figsize=(14, 5))

for ax_chk, (side, prescribed, induced, bar_colour, line_fmt) in zip(axs3, marginal_checks):
    ax_chk.bar(x_quantities, prescribed, color=bar_colour, alpha=0.6, label=f'Original {side}')
    ax_chk.plot(x_quantities, induced, line_fmt, label=f'Induced {side}', markersize=4, linewidth=1)
    ax_chk.set_title(f"Original vs. Induced {side} Marginals")
    ax_chk.set_xlabel(f"{side} Quantity")
    ax_chk.set_ylabel("Probability")
    ax_chk.legend()
    ax_chk.grid(True, linestyle='--', alpha=0.6)

fig3.suptitle("Verification of Marginals")
fig3.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

# Euclidean distance between each prescribed marginal and its induced one.
for side, prescribed, induced, _, _ in marginal_checks:
    print(f"Norm of difference ({side}): {np.linalg.norm(prescribed - induced):.2e}")

# ---------------------------
# NEW: 8. P(undersupply) and P(oversupply)
# ---------------------------
# Mass of the plan strictly above, strictly below and on the diagonal
# (I holds the supply index, J the demand index).
p_undersupply = P_ot[I < J].sum()
p_oversupply = P_ot[I > J].sum()
p_match = P_ot[I == J].sum()

print(f"\nProbability of Undersupply (Supply < Demand): {p_undersupply:.4f}")
print(f"Probability of Oversupply (Supply > Demand): {p_oversupply:.4f}")
print(f"Probability of Match (Supply = Demand): {p_match:.4f}")
print(f"Total Probability Check: {p_undersupply + p_oversupply + p_match:.4f}") # Should be 1

# Bar chart of the three event probabilities.
scenario_names = ['Undersupply', 'Match', 'Oversupply']
scenario_probs = [p_undersupply, p_match, p_oversupply]
plt.figure(figsize=(6, 4))
plt.bar(scenario_names, scenario_probs, color=['orange', 'green', 'purple'])
plt.title("Probabilities of Supply vs. Demand Scenarios")
plt.ylabel("Probability")
plt.ylim(0, 1)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

# ---------------------------
# 9. Conditional Supply Distributions P(k | j) - HISTOGRAMS
# ---------------------------
# Five evenly spaced demand levels; only those with non-negligible marginal
# mass get a panel, and panels are filled in order.
demand_indices_j = np.linspace(5, n - 5, 5, dtype=int)
fig_supply, axs_supply = plt.subplots(2, 3, figsize=(15, 8))
axs_supply = axs_supply.flatten() # Flatten to 1D array for easy iteration

plotted_demands = [j_val for j_val in demand_indices_j if p_demand[j_val] > 1e-9]
for ax, j_val in zip(axs_supply, plotted_demands):
    conditional_supply = P_ot[:, j_val] / p_demand[j_val]
    conditional_supply = conditional_supply / conditional_supply.sum() # Ensure normalization

    ax.bar(x_quantities, conditional_supply, color='green', alpha=0.7)
    ax.set_title(' ')
    ax.set_xlabel(" ")
    ax.set_ylabel(" ")
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.axvline(j_val, color='red', linestyle=':', label=f'Demand = {j_val}') # Mark the conditioning demand
    # ax.legend()
plot_count = len(plotted_demands)

# Drop the panels that were not used.
for spare_ax in axs_supply[plot_count:]:
    fig_supply.delaxes(spare_ax)

fig_supply.suptitle("Optimal Supply Distributions P(k | j) for Various Demands", fontsize=16)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()


# ---------------------------
# 10. Conditional Demand Distributions P(j | k) - HISTOGRAMS
# ---------------------------
# Same layout, conditioning on rows (supply levels) of the plan instead.
supply_indices_k = np.linspace(5, n - 5, 5, dtype=int)
fig_demand, axs_demand = plt.subplots(2, 3, figsize=(15, 8))
axs_demand = axs_demand.flatten() # Flatten to 1D array for easy iteration

plotted_supplies = [k_val for k_val in supply_indices_k if p_supply[k_val] > 1e-9]
for ax, k_val in zip(axs_demand, plotted_supplies):
    conditional_demand = P_ot[k_val, :] / p_supply[k_val]
    conditional_demand = conditional_demand / conditional_demand.sum() # Ensure normalization

    ax.bar(x_quantities, conditional_demand, color='purple', alpha=0.7)
    ax.set_title(' ')
    ax.set_xlabel(" ")
    ax.set_ylabel(" ")
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.axvline(k_val, color='blue', linestyle=':', label=f'Supply = {k_val}') # Mark the conditioning supply
    # ax.legend()
plot_count = len(plotted_supplies)

# Drop the panels that were not used.
for spare_ax in axs_demand[plot_count:]:
    fig_demand.delaxes(spare_ax)

fig_demand.suptitle("Optimal Demand Distributions P(y | x) for Various Supplies", fontsize=16)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()


# ---------------------------
# NEW: 11. Ideal Plan Plot (Diagonal)
# ---------------------------
# Scaled cost next to the normalized kernel pi_I (the "ideal plan" that the
# Sinkhorn iterations rescale). Both arrays are transposed, so supply is on
# the horizontal axis and demand on the vertical axis in both panels.
fig, axs = plt.subplots(1, 2, figsize=(14, 5))

# Cost Matrix (Scaled Cost Metric)
im0 = axs[0].imshow(M_scaled.T, cmap='viridis', origin='lower', aspect='equal')
axs[0].set_title("Cost Metric")
# Axis letters now match the other figures (x = supply, y = demand); the
# original labels had them swapped.
axs[0].set_xlabel("Supply Index (x)")
axs[0].set_ylabel("Demand Index (y)")
fig.colorbar(im0, ax=axs[0], fraction=0.046, pad=0.04)

# Ideal Plan Heatmap (the kernel pi_I, not the Sinkhorn coupling P_ot)
im1 = axs[1].imshow(pi_I.T, cmap='plasma', origin='lower', aspect='equal')
axs[1].set_title("Ideal Plan")
axs[1].set_ylabel("Demand Index (y)")
axs[1].set_xlabel("Supply Index (x)")
fig.colorbar(im1, ax=axs[1], fraction=0.046, pad=0.04)

# Title describes what is actually shown in the right-hand panel.
plt.suptitle("Cost Metric & Ideal Plan")
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

fig = plt.figure(figsize=(9, 9))
from matplotlib.gridspec import GridSpec
# 2x2 grid: narrow left column and short bottom row hold the marginal bar
# charts, the large top-right cell holds the joint heatmap. The bottom-left
# cell is left empty.
gs = GridSpec(2, 2, width_ratios=(1, 4), height_ratios=(4, 1), wspace=0.05, hspace=0.05)

ax_heatmap = fig.add_subplot(gs[0, 1])
ax_hist_supply = fig.add_subplot(gs[1, 1], sharex=ax_heatmap)
ax_hist_demand = fig.add_subplot(gs[0, 0], sharey=ax_heatmap)

# Joint plan (top-right), transposed so supply is horizontal.
im = ax_heatmap.imshow(P_ot.T, cmap='plasma', origin='lower', aspect='auto')
ax_heatmap.set_title("Optimal Joint Probability Distribution")
# The heatmap carries no tick labels of its own.
for tick_label in ax_heatmap.get_xticklabels():
    tick_label.set_visible(False)
for tick_label in ax_heatmap.get_yticklabels():
    tick_label.set_visible(False)

cbar = fig.colorbar(im, ax=ax_heatmap, fraction=0.046, pad=0.04)
cbar.set_label('Joint Probability P(x, y)')

# Supply marginal (bottom-right), bars pointing downwards.
ax_hist_supply.bar(x_quantities, p_supply, color='blue', alpha=0.7)
ax_hist_supply.invert_yaxis()
ax_hist_supply.set_xlabel("Supply Quantity (x)")
ax_hist_supply.set_ylabel("P(x)", labelpad=10)
for tick_label in ax_hist_supply.get_yticklabels():
    tick_label.set_visible(False)
for side in ('top', 'right', 'left'):
    ax_hist_supply.spines[side].set_visible(False)

# Demand marginal (top-left), bars pointing leftwards.
ax_hist_demand.barh(x_quantities, p_demand, color='red', alpha=0.7)
ax_hist_demand.invert_xaxis()
ax_hist_demand.set_ylabel("Demand Quantity (y)")
ax_hist_demand.set_xlabel("P(y)", labelpad=10)
for tick_label in ax_hist_demand.get_xticklabels():
    tick_label.set_visible(False)
for side in ('top', 'right', 'bottom'):
    ax_hist_demand.spines[side].set_visible(False)


# Pin the heatmap limits to the full index range.
ax_heatmap.set_xlim(-0.5, n - 0.5)
ax_heatmap.set_ylim(-0.5, n - 0.5)

plt.show()

In [None]:
#@title Scenario 2: Univariate Supply and Bivariate Demand (1D2D)


import numpy as np
import matplotlib.pyplot as plt
from scipy.special import factorial
from scipy.stats import poisson

# Grid sizes: supply lives on {0..supply_n-1}; each demand coordinate lives on
# {0..demand_max}, so demand_n points per axis.
supply_n = 51
demand_max = supply_n // 2
demand_n = demand_max + 1
D = 25  # passed as max_val when the demand PMFs are built; equals demand_max here
np.random.seed(42)

S = np.arange(supply_n)
Y = np.arange(demand_n)


# Supply: two-component Poisson mixture, weight `mix` on the mu1 component,
# renormalized on the truncated support.
mu1, mu2 = 15, 30
mix = 0.3
p1, p2 = (poisson.pmf(S, mean) for mean in (mu1, mu2))
p_supply = mix * p1 + (1-mix) * p2
p_supply = p_supply / p_supply.sum()


def bivariate_poisson(l1, l2, l12, max_val):
    """Return a correlated bivariate Poisson PMF on {0..max_val} x {0..max_val}.

    Entry ``[i, j]`` is
    ``exp(-(l1 + l2 + l12)) * sum_k l1**(i-k)/(i-k)! * l2**(j-k)/(j-k)! * l12**k/k!``
    with ``k`` running from 0 to ``min(i, j)``. The table is then divided by
    its total so the truncated grid sums to one.

    Args:
        l1, l2: rate parameters of the two coordinate-specific terms.
        l12: rate parameter of the shared term that couples the coordinates.
        max_val: largest value kept on each axis.

    Returns:
        Array of shape ``(max_val + 1, max_val + 1)``.
    """
    n = max_val + 1
    # The factors rate**r / r! only depend on r, so compute each once instead
    # of re-evaluating powers and factorials inside the triple loop.
    terms_1 = [l1**r / factorial(r) for r in range(n)]
    terms_2 = [l2**r / factorial(r) for r in range(n)]
    terms_12 = [l12**r / factorial(r) for r in range(n)]

    coef = np.exp(-(l1 + l2 + l12))
    p = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            p[i, j] = coef * sum(
                terms_1[i - k] * terms_2[j - k] * terms_12[k]
                for k in range(min(i, j) + 1)
            )
    return p / p.sum()


# Demand: equal-weight mixture of two bivariate Poisson components.
# NOTE(review): in the source the right-hand side of the component-"a"
# assignment was missing, which is a SyntaxError that stops the whole file
# from parsing. The values below are PLACEHOLDERS chosen only to give a
# lower-demand mode than component "b" -- replace with the intended rates.
lambda1a, lambda2a, lambda12a = 5, 5, 1
lambda1b, lambda2b, lambda12b = 10, 20, 3
mix_d = 0.5
p_Y_a = bivariate_poisson(lambda1a, lambda2a, lambda12a, D)
p_Y_b = bivariate_poisson(lambda1b, lambda2b, lambda12b, D)
p_Y   = mix_d * p_Y_a + (1-mix_d) * p_Y_b
p_Y  /= p_Y.sum()

# Joint demand PMF: rows are y1 (vertical axis), columns are y2 (horizontal).
plt.figure(figsize=(6,5))
plt.imshow(p_Y, origin='lower', aspect='auto')
plt.colorbar(label='P(y1,y2)')
plt.title("2D Demand Joint PMF")
plt.xlabel("Demand Group 2 (y2)")
plt.ylabel("Demand Group 1 (y1)")
plt.show()

# Univariate supply PMF.
plt.figure(figsize=(6,4))
plt.bar(S, p_supply, color='C1', alpha=0.8)
plt.title("Supply")
plt.xlabel("Supply x")
plt.ylabel("P(x)")
plt.show()


# Unit costs: co per unit of supply above total demand, cu per unit of total
# demand above supply, ca per unit of supply above `threshold`.
co, cu = 2.0, 10.0
threshold, ca = 20, 5.0
Y1, Y2 = np.meshgrid(Y, Y, indexing='ij')
J = Y1 + Y2  # total demand

# Cost tensor C[s, y1, y2], built in one broadcast instead of a Python loop
# over s: the supply level is a column of shape (supply_n, 1, 1).
supply_levels = np.arange(supply_n)[:, None, None]
mismatch = co*np.maximum(0, supply_levels - J) + cu*np.maximum(0, J - supply_levels)
overflow = ca*np.maximum(0, supply_levels - threshold)
C = mismatch + overflow
C_scaled = C / C.max()

# Kernel on the flattened demand grid: uniform reference tilted by
# exp(-cost / epsilon) and normalized to total mass one.
epsilon = 0.1
C_flat = C_scaled.reshape(supply_n, -1)
phi = np.ones_like(C_flat) / (supply_n * demand_n**2)
pi_I = phi * np.exp(-C_flat/epsilon)
pi_I /= pi_I.sum()

# Sinkhorn scaling; stops once u changes by less than 1e-9 between sweeps.
p_Y_flat = p_Y.flatten()
u = np.ones(supply_n)
v = np.ones(demand_n**2)
for _ in range(2000):
    v = p_Y_flat / (pi_I.T @ u)
    u_prev = u.copy()
    u = p_supply / (pi_I @ v)
    if np.linalg.norm(u - u_prev) < 1e-9:
        break

# Row/column scaling by broadcasting; same matrix as
# diag(u) @ pi_I @ diag(v) without the dense diagonal matrices.
P_flat = u[:, None] * pi_I * v[None, :]
P = P_flat.reshape(supply_n, demand_n, demand_n)

# Supply levels at which slices of the cost tensor and the plan are shown.
s_levels = [0, 10, 15, 20, 25, 30, 35, 40, 45, 49]

# Cost slices C_scaled[s] over the (y1, y2) grid.
# NOTE: each panel is colour-scaled independently, and the shared colorbar is
# built from the last panel only.
fig, axs = plt.subplots(2,5,figsize=(16,6), constrained_layout=True)
for ax, s in zip(axs.flatten(), s_levels):
    im = ax.imshow(C_scaled[s], origin='lower', aspect='auto', cmap='magma')
    ax.set_title(f"Cost | x={s}")
    ax.set_xlabel("y2")
    ax.set_ylabel("y1")
cbar = fig.colorbar(im, ax=axs.ravel().tolist(), fraction=0.02, pad=0.01)
cbar.set_label("Scaled Cost")
plt.show()

# Conditional demand law given each supply level. P[s] is the JOINT slice
# P(x=s, y1, y2); it is divided by its own total so that what is plotted
# really is the conditional P(y1, y2 | x=s) named in the title and colorbar.
fig, axs = plt.subplots(2,5,figsize=(16,6), constrained_layout=True)
for ax, s in zip(axs.flatten(), s_levels):
    joint_slice = P[s]
    slice_mass = joint_slice.sum()
    # Leave an all-zero slice untouched rather than dividing by zero.
    cond_demand = joint_slice / slice_mass if slice_mass > 0 else joint_slice
    im = ax.imshow(cond_demand, origin='lower', aspect='auto', cmap='viridis')
    ax.set_title(f"P(y1,y2|x={s})")
    ax.set_xlabel("Demand Group 2 (y2)")
    ax.set_ylabel("Demand Group 1 (y1)")
cbar2 = fig.colorbar(im, ax=axs.ravel().tolist(), fraction=0.02, pad=0.01)
cbar2.set_label("Conditional P")
plt.show()

# Draw five demand pairs from the demand PMF and show the conditional supply
# law for each. Sampled pairs have positive probability, so the division is safe.
flat_inds = np.random.choice(demand_n**2, size=5, p=p_Y_flat)
pairs = [(idx//demand_n, idx%demand_n) for idx in flat_inds]

fig, axs = plt.subplots(1,5,figsize=(18,4), constrained_layout=True)
for ax, (i,j) in zip(axs, pairs):
    cond_sup = P[:,i,j] / p_Y[i,j]
    ax.bar(S, cond_sup, color='C3', alpha=0.7)
    ax.set_title(f"x | y=({i},{j})")
    ax.set_xlabel("x")
    ax.set_ylabel("P")
plt.show()


In [None]:
#@title Scenario 3: Bivariate Supply and Bivariate Demand (2D2D)

# Parameters (all inputs are defined here)
n = 26  # grid size (0..n-1 in each dimension)
# Cost parameters
co, cu = 1,10            # co: unit cost when total supply exceeds total demand; cu: unit cost when total demand exceeds total supply
threshold, ca = 20, 5     # ca is charged per unit of total supply above `threshold`

# Sinkhorn parameter
epsilon = 0.1  # divisor of the scaled cost inside exp(-cost / epsilon)

# Sampling parameters for histograms / couplings
hist_samples = 10         # number of random demand pairs for histograms (part 6)
coupling_samples = 10     # number of random demand pairs for supply coupling plots (part 7)

# ---------------------------
# 1. Domains & 2D Poisson PMFs
# ---------------------------
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import factorial

# Define coordinate arrays (all four names refer to the same arange(n) array)
X1 = X2 = Y1 = Y2 = np.arange(n)

# Mode 1 for supply; each triple is passed as (l1, l2, l12) to bivar_poisson_pmf
lambda1_s1, lambda2_s1, lambda12_s1 = 5,  10,  1
# Mode 2 for supply: larger first-coordinate and shared rates
lambda1_s2, lambda2_s2, lambda12_s2 = 12, 10, 3
# Mixing weight: w_s multiplies Mode 1, (1 - w_s) multiplies Mode 2
w_s = 0.4

# ----- B) Compute each mode’s PMF on the same grid -----
def bivar_poisson_pmf(l1, l2, l12, n):
    """Return an n×n array of a correlated bivariate Poisson PMF.

    Entry ``[i, j]`` is
    ``exp(-(l1 + l2 + l12)) * sum_k l1**(i-k)/(i-k)! * l2**(j-k)/(j-k)! * l12**k/k!``
    for ``k = 0..min(i, j)``; the table is then divided by its total so the
    truncated grid sums to one.

    Args:
        l1, l2: rate parameters of the two coordinate-specific terms.
        l12: rate parameter of the shared (coupling) term.
        n: number of grid points per axis (values 0..n-1).

    Returns:
        Array of shape ``(n, n)``.
    """
    # rate**r / r! depends only on r: tabulate once instead of recomputing
    # powers and factorials in the innermost loop.
    terms_1 = [l1**r / factorial(r) for r in range(n)]
    terms_2 = [l2**r / factorial(r) for r in range(n)]
    terms_12 = [l12**r / factorial(r) for r in range(n)]

    coef = np.exp(-(l1 + l2 + l12))
    P = np.zeros((n,n))
    for i in range(n):
        for j in range(n):
            P[i,j] = coef * sum(
                terms_1[i-k] * terms_2[j-k] * terms_12[k]
                for k in range(min(i,j)+1)
            )
    return P / P.sum()

# Supply law: weighted mix of the two supply modes, renormalized.
p_X_mode1 = bivar_poisson_pmf(lambda1_s1, lambda2_s1, lambda12_s1, n)
p_X_mode2 = bivar_poisson_pmf(lambda1_s2, lambda2_s2, lambda12_s2, n)
p_X = w_s * p_X_mode1 + (1 - w_s) * p_X_mode2
p_X = p_X / p_X.sum()   # guards against numerical drift

# Demand law, built the same way from its own two modes.
w_d = 0.3
lambda1_d1, lambda2_d1, lambda12_d1 = 8,  8,  2
lambda1_d2, lambda2_d2, lambda12_d2 = 12, 20, 5

p_Y_mode1 = bivar_poisson_pmf(lambda1_d1, lambda2_d1, lambda12_d1, n)
p_Y_mode2 = bivar_poisson_pmf(lambda1_d2, lambda2_d2, lambda12_d2, n)
p_Y = w_d * p_Y_mode1 + (1 - w_d) * p_Y_mode2
p_Y = p_Y / p_Y.sum()

# One heatmap per joint PMF; rows (first coordinate) are on the vertical axis.
for pmf_2d, bar_label, heading, col_name, row_name in (
    (p_X, 'P(X1,X2)', 'Supply Joint PMF (2D Poisson)', 'X2', 'X1'),
    (p_Y, 'P(Y1,Y2)', 'Demand Joint PMF (2D Poisson)', 'Y2', 'Y1'),
):
    plt.figure(figsize=(6,5))
    plt.imshow(pmf_2d, origin='lower', aspect='auto')
    plt.colorbar(label=bar_label)
    plt.title(heading)
    plt.xlabel(col_name)
    plt.ylabel(row_name)
    plt.show()

# ---------------------------
# 2. Cost Matrix and OT Setup
# ---------------------------
# Coordinate grids and per-cell totals for supply and demand.
X1g, S2g = np.meshgrid(X1, X2, indexing='ij')
J1g, J2g = np.meshgrid(Y1, Y2, indexing='ij')
Sup_tot = X1g + S2g
Dem_tot = J1g + J2g

# cost shape (n,n,n,n): axes (x1, x2, y1, y2). Supply totals are broadcast
# over the two demand axes and demand totals over the two supply axes.
sup_4d = Sup_tot[..., None, None]
dem_4d = Dem_tot[None, None]
C = np.maximum(0, sup_4d - dem_4d) * co + \
    np.maximum(0, dem_4d - sup_4d) * cu + \
    ca * np.maximum(0, sup_4d - threshold)
C_scaled = C / C.max()

# Flatten for OT: rows index supply pairs, columns index demand pairs.
m = n*n
C_flat = C_scaled.reshape(m, m)
phi = np.ones_like(C_flat) / m**2

# Ideal plan: uniform reference tilted by exp(-cost / epsilon), normalized.
pi_I = phi * np.exp(-C_flat/epsilon)
pi_I /= pi_I.sum()

# Marginals
p_X_flat = p_X.flatten()
p_Y_flat = p_Y.flatten()

# Sinkhorn scaling; stops once u changes by less than 1e-9 between sweeps.
u = np.ones(m)
v = np.ones(m)
for _ in range(1000):
    v = p_Y_flat / (pi_I.T @ u)
    u_prev = u.copy()
    u = p_X_flat / (pi_I @ v)
    if np.linalg.norm(u - u_prev) < 1e-9:
        break

# Row/column scaling by broadcasting; same matrix as
# diag(u) @ pi_I @ diag(v) without two dense m×m diagonal matrices.
P_flat = u[:, None] * pi_I * v[None, :]
P = P_flat.reshape((n,n,n,n))

# ---------------------------
# 3. Plot cost slices at five fixed supply pairs
# ---------------------------
sup_pairs = [(5,5), (7,10), (15,6), (13,13), (23,23)]
# Demand pairs drawn from the demand PMF (flat index -> (row, column)).
rnd_idx = np.random.choice(m, size=hist_samples, p=p_Y_flat)
d_pairs = [divmod(idx, n) for idx in rnd_idx]

fig, axes = plt.subplots(1,5,figsize=(18,4), constrained_layout=True)
for ax, (x1,x2) in zip(axes.flat, sup_pairs):
    im = ax.imshow(C_scaled[x1,x2], origin='lower', aspect='auto', cmap='magma')
    ax.set(title=f"Cost | X=({x1},{x2})", xlabel="Y2", ylabel="Y1")
fig.colorbar(im, ax=axes.ravel().tolist(), fraction=0.02, pad=0.01)
plt.show()

# ---------------------------
# 4. Plot OT coupling slices
# ---------------------------
# Joint slice divided by the supply marginal (plus a small constant that
# avoids division by zero) for each fixed supply pair.
fig, axes = plt.subplots(1,5,figsize=(18,4), constrained_layout=True)
for ax, (x1,x2) in zip(axes.flat, sup_pairs):
    cond = P[x1,x2] / (p_X[x1,x2] + 1e-12)
    im2 = ax.imshow(cond, origin='lower', aspect='auto', cmap='viridis')
    ax.set(title=f"P(Y|X=({x1},{x2}))", xlabel="Y2", ylabel="Y1")
fig.colorbar(im2, ax=axes.ravel().tolist(), fraction=0.02, pad=0.01)
plt.show()

# ---------------------------
# 5. Histograms: Supply_total & Demand_total
# ---------------------------
# Flattened totals x1+x2 and y1+y2, one entry per grid cell (row-major).
sup_tot = Sup_tot.flatten()
dem_tot = Dem_tot.flatten()

# PMF of each total: accumulate every cell's probability into the bin of its
# total. Totals range over 0..2n-2; the arrays are padded to length 2n.
p_sup_tot = np.bincount(sup_tot, weights=p_X.ravel(), minlength=2*n)
p_dem_tot = np.bincount(dem_tot, weights=p_Y.ravel(), minlength=2*n)

total_values = np.arange(2*n)
fig, ax = plt.subplots(1,2,figsize=(12,4), constrained_layout=True)
ax[0].bar(total_values, p_sup_tot, color='C3')
ax[0].set(title='Supply Total PMF', xlabel='X1+X2', ylabel='P')
ax[1].bar(total_values, p_dem_tot, color='C4')
ax[1].set(title='Demand Total PMF', xlabel='Y1+Y2', ylabel='P')
plt.show()

# ---------------------------
# 6. Histograms: Supply_total | random demand pairs
# ---------------------------
# The conditioning pairs are the demand pairs in d_pairs (drawn from the
# demand PMF earlier in this cell), one panel per pair. Iterating the fixed
# supply pairs here, as before, conditioned on the wrong set of points and
# left half of the hist_samples panels empty.

fig, axes = plt.subplots(1, hist_samples, figsize=(4*hist_samples,4), constrained_layout=True)
# subplots returns a bare Axes when hist_samples == 1; make it iterable.
for ax, (i,j) in zip(np.atleast_1d(axes), d_pairs):
    # Sum P[a, b, i, j] into the bin a+b: distribution of the supply total
    # restricted to demand pair (i, j).
    p_cond_sup = np.bincount(
        Sup_tot.ravel(), weights=P[:, :, i, j].ravel(), minlength=2*n
    )
    p_cond_sup /= p_Y[i,j]  # normalize by the demand marginal at (i, j)
    ax.bar(np.arange(2*n), p_cond_sup, alpha=0.7)
    ax.set_title(f"X_tot | Y=({i},{j})")
    ax.set_xlabel('X1+X2'); ax.set_ylabel('P')
plt.show()

# ---------------------------
# 7. Plot 2D supply PMF conditional on random demand couplings
# ---------------------------
# Draw coupling_samples demand pairs from the demand PMF.
rnd_idx = np.random.choice(m, size=coupling_samples, p=p_Y_flat)
d_pairs = [(idx // n, idx % n) for idx in rnd_idx]

# One panel per sampled demand pair. The loop previously walked sup_pairs,
# so the freshly drawn d_pairs were never used and the panels conditioned on
# the fixed supply coordinates instead.
fig, axes = plt.subplots(
    1, coupling_samples, figsize=(3.6 * coupling_samples, 4), constrained_layout=True
)
axes = np.atleast_1d(axes)  # a single panel comes back as a bare Axes
for ax, (y1, y2) in zip(axes, d_pairs):
    # Joint slice over (x1, x2) divided by the demand marginal; the small
    # constant avoids division by zero.
    p_cond_S = P[:, :, y1, y2] / (p_Y[y1, y2] + 1e-12)
    im = ax.imshow(p_cond_S, origin='lower', aspect='auto', cmap='plasma')
    ax.set_title(f"P(X│Y=({y1},{y2}))")
    ax.set_xlabel("X2"); ax.set_ylabel("X1")

fig.colorbar(im, ax=axes.ravel().tolist(), fraction=0.02, pad=0.01)
plt.show()

# ---------------------------
# 8. Plot Derived 2D Marginals from the 4D Plan
# ---------------------------
# P has axes (x1, x2, y1, y2) = (0, 1, 2, 3). Summing out one supply axis and
# one demand axis leaves the joint law of a single supply coordinate with a
# single demand coordinate.

print("\nCalculating and plotting derived 2D marginals from the 4D plan P...")

p_S1D1 = P.sum(axis=(1, 3))  # P(X1, Y1): sum over x2 and y2
p_S1D2 = P.sum(axis=(1, 2))  # P(X1, Y2): sum over x2 and y1
p_S2D1 = P.sum(axis=(0, 3))  # P(X2, Y1): sum over x1 and y2
p_S2D2 = P.sum(axis=(0, 2))  # P(X2, Y2): sum over x1 and y1

fig, axes = plt.subplots(2, 2, figsize=(11, 10), constrained_layout=True)
fig.suptitle("Derived 2D Marginals from Optimal 4D Plan P(X1,X2,Y1,Y2)", fontsize=14)

# (marginal, supply symbol, supply name, demand symbol, demand name); panels
# are filled row by row. Supply is on the vertical axis, demand horizontal.
pairwise_panels = [
    (p_S1D1, "X1", "Supply 1", "Y1", "Demand 1"),
    (p_S1D2, "X1", "Supply 1", "Y2", "Demand 2"),
    (p_S2D1, "X2", "Supply 2", "Y1", "Demand 1"),
    (p_S2D2, "X2", "Supply 2", "Y2", "Demand 2"),
]
for ax, (marginal, sup_sym, sup_name, dem_sym, dem_name) in zip(axes.flat, pairwise_panels):
    im = ax.imshow(marginal, origin='lower', aspect='auto', cmap='viridis')
    ax.set_title(f"Marginal P({sup_sym}, {dem_sym})")
    ax.set_xlabel(f"{dem_sym} ({dem_name})")
    ax.set_ylabel(f"{sup_sym} ({sup_name})")
    fig.colorbar(im, ax=ax, label="Probability")
# ---------------------------
# 9. Induced supply & demand marginals from P
# ---------------------------

# Induced joint of (S1,S2): sum out y1,y2 → shape (n,n)
# Supply-pair and demand-pair marginals implied by the plan, each (n, n).
p_S12_plan = P.sum(axis=(2,3))  # sum out y1, y2
p_D12_plan = P.sum(axis=(0,1))  # sum out x1, x2

# Top row: prescribed vs. induced supply law; bottom row: same for demand.
fig, axes = plt.subplots(2, 2, figsize=(12,10), constrained_layout=True)
fig.suptitle("Original vs. Induced 2D Marginals", fontsize=16)

# (array, title, horizontal label, vertical label, colormap), row by row.
comparison_panels = [
    (p_X, "Original Supply PMF\np_X(x1,x2)", "x2", "x1", 'viridis'),
    (p_S12_plan, "Induced Supply Marginal\n∑_{y1,y2}P(x1,x2,y1,y2)", "x2", "x1", 'viridis'),
    (p_Y, "Original Demand PMF\np_Y(y1,y2)", "y2", "y1", 'magma'),
    (p_D12_plan, "Induced Demand Marginal\n∑_{x1,x2}P(x1,x2,y1,y2)", "y2", "y1", 'magma'),
]
for ax, (table, heading, col_label, row_label, colour_map) in zip(axes.flat, comparison_panels):
    im = ax.imshow(table, origin='lower', aspect='auto', cmap=colour_map)
    ax.set_title(heading)
    ax.set_xlabel(col_label)
    ax.set_ylabel(row_label)
    fig.colorbar(im, ax=ax)

plt.show()


plt.show()
# Flattened totals: one entry per supply pair / demand pair (row-major),
# matching the row and column order of P_flat.
Sup_tot_flat = Sup_tot.flatten()
Dem_tot_flat = Dem_tot.flatten()
C_flat_unscaled = C_flat * C.max()  # undo the division by C.max()

sup = Sup_tot_flat[:, None]      # shape (m,1): varies along rows (supply)
dem = Dem_tot_flat[None, :]      # shape (1,m): varies along columns (demand)
delta = dem - sup                # shape (m,m): demand total minus supply total

# Expected shortfall and surplus cost components under the plan.
exp_under = np.sum(P_flat * cu * np.maximum(delta, 0))
exp_over  = np.sum(P_flat * co *  np.maximum(-delta,0))


# expected total cost
exp_cost = np.sum(P_flat * C_flat_unscaled)


# expected threshold-adjustment cost.
# The penalty depends on the SUPPLY total, so it must vary along the rows of
# P_flat: use the (m,1) column `sup`. A bare shape-(m,) vector broadcasts
# along the last (demand) axis and would price the demand total instead.
exp_adj = np.sum(
    P_flat * ca * np.maximum(sup - threshold, 0)
)

print(f"Total expected cost:        {exp_cost:.3f}")
print(f" Under-capacity (lost sales): {exp_under:.3f}")
print(f" Over-capacity (holding):      {exp_over:.3f}")
print(f" Threshold-adjustment:        {exp_adj:.3f}")

# --- print summary of the full 4D plan ---
print("Joint plan P shape (flattened):", P_flat.shape)
print("Total mass of P (should be 1):", P_flat.sum())

# --- compute the two event-masks, shape (m,m) ---
mask_less    = sup <  dem
mask_greater = sup >  dem

# --- event probabilities ---
p_less    = P_flat[mask_less].sum()      # Pr(X1+X2 < Y1+Y2)
p_greater = P_flat[mask_greater].sum()   # Pr(X1+X2 > Y1+Y2)
p_equal   = 1.0 - p_less - p_greater     # remaining (tie) mass

print(f"P(X1+X2 < Y1+Y2) = {p_less:.6f}")
print(f"P(Y1+Y2 < X1+X2) = {p_greater:.6f}")
print(f"P(X1+X2 = Y1+Y2) = {p_equal:.6f}")


In [None]:
#@title Quality of Service: Feasibility Test LP


import numpy as np
from scipy.optimize import linprog
from scipy.special import factorial
import pandas as pd

# 1. PARAMETERS
# NOTE(review): not identical to the Scenario 3 cell, which uses n = 26 and
# lambda2_s2 = 10 -- confirm the difference is intended.
n = 20  # grid points per axis (values 0..n-1)
# Supply mixture: each triple is passed as (l1, l2, l12) to bivar_poisson_pmf
lambda1_s1, lambda2_s1, lambda12_s1 = 5, 10, 1
lambda1_s2, lambda2_s2, lambda12_s2 = 12, 15, 3
w_s = 0.4  # weight on the first supply component
# Demand mixture
lambda1_d1, lambda2_d1, lambda12_d1 = 8, 8, 2
lambda1_d2, lambda2_d2, lambda12_d2 = 12, 20, 5
w_d = 0.3  # weight on the first demand component

# 2. BIVARIATE POISSON PMF FUNCTION
def bivar_poisson_pmf(l1, l2, l12, n):
    """Return an n×n array of a correlated bivariate Poisson PMF.

    Entry ``[i, j]`` is
    ``exp(-(l1 + l2 + l12)) * sum_k l1**(i-k)/(i-k)! * l2**(j-k)/(j-k)! * l12**k/k!``
    for ``k = 0..min(i, j)``; the table is then divided by its total so the
    truncated grid sums to one.

    Args:
        l1, l2: rate parameters of the two coordinate-specific terms.
        l12: rate parameter of the shared (coupling) term.
        n: number of grid points per axis (values 0..n-1).

    Returns:
        Array of shape ``(n, n)``.
    """
    # rate**r / r! depends only on r: tabulate once instead of recomputing
    # powers and factorials in the innermost loop.
    terms_1 = [l1**r / factorial(r) for r in range(n)]
    terms_2 = [l2**r / factorial(r) for r in range(n)]
    terms_12 = [l12**r / factorial(r) for r in range(n)]

    coef = np.exp(-(l1 + l2 + l12))
    P = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            P[i, j] = coef * sum(
                terms_1[i-k] * terms_2[j-k] * terms_12[k]
                for k in range(min(i, j) + 1)
            )
    return P / P.sum()

# Build supply and demand joint PMFs
# Two-component mixtures for the joint supply and demand laws, renormalized.
p_X1 = bivar_poisson_pmf(lambda1_s1, lambda2_s1, lambda12_s1, n)
p_X2 = bivar_poisson_pmf(lambda1_s2, lambda2_s2, lambda12_s2, n)
p_X  = w_s * p_X1 + (1 - w_s) * p_X2
p_X  = p_X / p_X.sum()
p_Y1 = bivar_poisson_pmf(lambda1_d1, lambda2_d1, lambda12_d1, n)
p_Y2 = bivar_poisson_pmf(lambda1_d2, lambda2_d2, lambda12_d2, n)
p_Y  = w_d * p_Y1 + (1 - w_d) * p_Y2
p_Y  = p_Y / p_Y.sum()

# Collapse each joint law onto the coordinate total i + j. Totals range over
# 0..2n-2; the vectors are padded to length T = 2n.
T = 2 * n
cell_totals = np.add.outer(np.arange(n), np.arange(n)).ravel()
p_sup_tot = np.bincount(cell_totals, weights=p_X.ravel(), minlength=T)
p_dem_tot = np.bincount(cell_totals, weights=p_Y.ravel(), minlength=T)
p_sup_tot = p_sup_tot / p_sup_tot.sum()
p_dem_tot = p_dem_tot / p_dem_tot.sum()

# LP setup
# LP over couplings of the two total distributions. The T*T decision
# variables are the coupling entries in row-major order: index u*T + v is the
# mass on (supply total u, demand total v).
T = len(p_sup_tot)
m = T * T

# Equality constraints: the first T rows sum each coupling row (supply
# marginal), the next T rows sum each coupling column (demand marginal).
row_sum_constraints = np.kron(np.eye(T), np.ones((1, T)))  # row u: ones at u*T:(u+1)*T
col_sum_constraints = np.kron(np.ones((1, T)), np.eye(T))  # row v: ones at v::T
A_eq = np.vstack([row_sum_constraints, col_sum_constraints])
b_eq = np.concatenate([p_sup_tot, p_dem_tot])

# Objectives: indicator of {u < v} (supply total below demand total) and of
# {u > v}, flattened in the same row-major order as the variables.
c_xy = np.triu(np.ones((T, T)), k=1).ravel()
c_yx = np.tril(np.ones((T, T)), k=-1).ravel()


def _solve_coupling_lp(cost):
    """Minimize ``cost @ x`` over non-negative couplings; raise if the solve fails."""
    result = linprog(cost, A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method='highs')
    if not result.success:
        # Without this check a failed solve only surfaces later, when its
        # missing objective value is used in arithmetic.
        raise RuntimeError(f"linprog did not converge: {result.message}")
    return result


# Solve LPs (maximization is done by minimizing the negated objective)
res_min_xy = _solve_coupling_lp(c_xy)
res_max_xy = _solve_coupling_lp(-c_xy)
res_min_yx = _solve_coupling_lp(c_yx)
res_max_yx = _solve_coupling_lp(-c_yx)

S_min_xy = res_min_xy.fun
S_max_xy = -res_max_xy.fun
S_min_yx = res_min_yx.fun
S_max_yx = -res_max_yx.fun

# Feasibility grid: a level alpha is marked feasible when it is at least the
# smallest attainable event probability (with a 1e-12 tolerance).
# NOTE(review): this treats alpha as an upper bound on the event probability;
# if alpha must be attained exactly, alpha <= S_max_* is also required.
alphas = np.linspace(0, 1, 101)
feas_xy = alphas >= (S_min_xy - 1e-12)
feas_yx = alphas >= (S_min_yx - 1e-12)

# Feasibility table (built but not displayed here)
df = pd.DataFrame({
    'alpha': alphas,
    'feasible(X<Y)': feas_xy,
    'feasible(Y<X)': feas_yx
})

# Print summary
print(f"Min P(X<Y) = {S_min_xy:.6f}, Max P(X<Y) = {S_max_xy:.6f}")
print(f"Min P(Y<X) = {S_min_yx:.6f}, Max P(Y<X) = {S_max_yx:.6f}")


In [None]:
#@title Quality of Service: Feasibility Test + Modified Sinkhorn + Comparison of Outputs
n = 20  # grid points per axis; coordinate arrays below are np.arange(n)

# Cost parameters
co, cu = 1, 10       # over- and under-capacity unit costs (co=overage/holding, cu=underage/stockout)
threshold, ca = 20, 5  # threshold and capacity adjustment cost

# Sinkhorn parameter
epsilon = 0.1

# Sampling parameters for histograms / couplings
hist_samples = 10     # number of random demand pairs for histograms
coupling_samples = 10 # number of random demand pairs for supply coupling plots

# QoS Parameters
alpha_service = 0.6     # target "service level" P[X_tot >= Y_tot]
beta_target = 1.0 - alpha_service  # desired P[X_tot < Y_tot] (stock-out probability); complement of alpha_service

# Imports
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import factorial
from scipy.optimize import linprog
from scipy.sparse import lil_matrix, csr_matrix
import time
np.random.seed(42)  # fixed seed so the random draws made later in this cell are reproducible



# 1. Domains & Bimodal Bivariate Poisson 2D PMFs
print("--- Section 1: Defining Domains & 2D Bimodal Bivariate Poisson PMFs ---")
# All four coordinates share one grid 0..n-1 (the four names alias the same array)
X1_coords = np.arange(n)
X2_coords = X1_coords
Y1_coords = X1_coords
Y2_coords = X1_coords

# --- Bimodal Bivariate Poisson Model Parameters
# Each component is described by (lambda1, lambda2, lambda12), the three
# arguments passed to bivar_poisson_pmf; the PMF is a two-component mixture.

# Supply, component 1
lambda1_s1 = 5
lambda2_s1 = 10
lambda12_s1 = 1
# Supply, component 2
lambda1_s2 = 12
lambda2_s2 = 15
lambda12_s2 = 3
# Mixture weight of the first supply component (the second gets 1 - w_s)
w_s = 0.4

# Demand, component 1
lambda1_d1 = 8
lambda2_d1 = 8
lambda12_d1 = 2
# Demand, component 2
lambda1_d2 = 12
lambda2_d2 = 20
lambda12_d2 = 5
# Mixture weight of the first demand component (the second gets 1 - w_d)
w_d = 0.3


def bivar_poisson_pmf(l1, l2, l12, size_n):
    """Return a size_n x size_n correlated bivariate Poisson PMF, renormalised.

    Entry ``[r, c]`` is proportional to::

        exp(-(l1 + l2 + l12)) * sum_{k=0}^{min(r, c)}
            l1**(r-k)/(r-k)! * l2**(c-k)/(c-k)! * l12**k/k!

    The table is truncated to indices ``0..size_n-1`` and then divided by its
    own sum so that it adds up to one on the truncated grid.

    Parameters
    ----------
    l1, l2, l12 : float
        Rate parameters; ``l12`` drives the shared (correlating) term.
    size_n : int
        Number of grid points along each axis.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(size_n, size_n)``. All zeros when any rate is
        negative or when the un-normalised table sums to at most 1e-9.
    """
    P_matrix = np.zeros((size_n, size_n))

    # Negative rates are rejected once, up front, instead of inside every cell.
    if l1 < 0 or l2 < 0 or l12 < 0:
        return P_matrix

    coef = np.exp(-(l1 + l2 + l12))

    # Each lambda**i / i! term depends only on i, so tabulate it once rather
    # than recomputing powers and factorials inside the triple loop.
    terms_1 = [l1**i / factorial(i) for i in range(size_n)]
    terms_2 = [l2**i / factorial(i) for i in range(size_n)]
    terms_12 = [l12**i / factorial(i) for i in range(size_n)]

    for r_idx in range(size_n):
        for c_idx in range(size_n):
            term_sum = 0
            for k in range(min(r_idx, c_idx) + 1):
                term_sum += terms_1[r_idx - k] * terms_2[c_idx - k] * terms_12[k]
            P_matrix[r_idx, c_idx] = coef * term_sum

    current_sum = P_matrix.sum()
    return P_matrix / current_sum if current_sum > 1e-9 else np.zeros_like(P_matrix)

# --- Supply PMF p_X: two-component mixture, weight w_s on the first component ---
p_X1_component = bivar_poisson_pmf(lambda1_s1, lambda2_s1, lambda12_s1, n)
p_X2_component = bivar_poisson_pmf(lambda1_s2, lambda2_s2, lambda12_s2, n)
p_X = w_s * p_X1_component + (1 - w_s) * p_X2_component
p_X = p_X / p_X.sum()  # renormalise the mixture so it sums to one

# --- Demand PMF p_Y: two-component mixture, weight w_d on the first component ---
p_Y1_component = bivar_poisson_pmf(lambda1_d1, lambda2_d1, lambda12_d1, n)
p_Y2_component = bivar_poisson_pmf(lambda1_d2, lambda2_d2, lambda12_d2, n)
p_Y = w_d * p_Y1_component + (1 - w_d) * p_Y2_component
p_Y = p_Y / p_Y.sum()  # renormalise the mixture so it sums to one

# Side-by-side heatmaps of the two PMFs (rows = first coordinate, columns = second)
fig_pmf, axes_pmf = plt.subplots(1, 2, figsize=(12, 5))

im_s = axes_pmf[0].imshow(p_X, cmap='viridis', aspect='auto', origin='lower')
axes_pmf[0].set_title('$\\mu(x_1,x_2)$')
axes_pmf[0].set_xlabel('$x_2$')
axes_pmf[0].set_ylabel('$x_1$')
fig_pmf.colorbar(im_s, ax=axes_pmf[0], label='$P(x_1,x_2)$')

im_d = axes_pmf[1].imshow(p_Y, cmap='viridis', aspect='auto', origin='lower')
axes_pmf[1].set_title('$\\nu(y_1,y_2)$')
axes_pmf[1].set_xlabel('$y_2$')
axes_pmf[1].set_ylabel('$y_1$')
fig_pmf.colorbar(im_d, ax=axes_pmf[1], label='$P(y_1,y_2)$')

plt.tight_layout()
plt.show()

# ---------------------------
# 2. Cost Matrix and OT Setup
# ---------------------------
print("\n--- Section 2: Cost Matrix and Initial OT Setup ---")
# Component grids with 'ij' indexing: axis 0 follows the first coordinate
X1g_mesh, X2g_mesh = np.meshgrid(X1_coords, X2_coords, indexing='ij')
Y1g_mesh, Y2g_mesh = np.meshgrid(Y1_coords, Y2_coords, indexing='ij')

# Totals on the n x n grids: Sup_tot_grid[x1, x2] = x1 + x2, Dem_tot_grid[y1, y2] = y1 + y2
Sup_tot_grid = X1g_mesh + X2g_mesh
Dem_tot_grid = Y1g_mesh + Y2g_mesh

# Cost tensor C[x1, x2, y1, y2], built by broadcasting supply totals over the
# first two axes and demand totals over the last two:
#   co * (supply - demand)^+  +  cu * (demand - supply)^+  +  ca * (supply - threshold)^+
supply_total_4d = Sup_tot_grid[:, :, None, None]
demand_total_4d = Dem_tot_grid[None, None, :, :]
surplus_4d = supply_total_4d - demand_total_4d
C = (co * np.maximum(0, surplus_4d)
     + cu * np.maximum(0, -surplus_4d)
     + ca * np.maximum(0, supply_total_4d - threshold))

# Unscaled cost in matrix form, used for the final expected-cost figures
C_flat_unscaled = C.reshape(n*n, n*n)

# Scale the cost to [0, 1] (left untouched when the maximum is not positive)
largest_cost = C.max()
C_scaled = C / largest_cost if largest_cost > 0 else C

# Matrix form for OT: rows index (x1, x2) pairs, columns index (y1, y2) pairs
m_pairs = n * n
C_flat = C_scaled.reshape(m_pairs, m_pairs)

# Gibbs kernel exp(-C / epsilon), normalised to unit total mass when that
# mass is not negligible; it is the kernel handed to the Sinkhorn routines
pi_I = np.exp(-C_flat / epsilon)
pi_I_sum = pi_I.sum()
if pi_I_sum > 1e-9:
    pi_I /= pi_I_sum


# Marginals as vectors of length n*n (row-major copies of the 2D PMFs)
p_X_flat = p_X.flatten()
p_Y_flat = p_Y.flatten()

# Totals as vectors of length n*n, aligned with the flattened marginals
Sup_tot_flat = Sup_tot_grid.flatten()
Dem_tot_flat = Dem_tot_grid.flatten()

# Stock-out mask: True where total supply is strictly below total demand
mask_under = Sup_tot_flat[:, None] < Dem_tot_flat[None, :]

# ------------------------------------
# Helper Sinkhorn Functions
# ------------------------------------
def sinkhorn_two_way(K_kernel, mu_target, nu_target, it_max=5000, tol=1e-9, check_interval=50):
    """Balance a kernel against two marginals: P = diag(u) @ K @ diag(v).

    Alternately rescales ``v`` and ``u`` so the column and row sums of the
    plan approach ``nu_target`` and ``mu_target``. Every ``check_interval``
    iterations the loop stops early if ``u`` moved by less than ``tol``
    (Euclidean norm) during that iteration.

    Returns the plan divided by its total mass, or an all-zero array of the
    same shape when that mass is at most 1e-9.
    """
    guard = 1e-12  # keeps the divisions finite when a scaling factor is zero

    n_rows = mu_target.shape[0]
    n_cols = nu_target.shape[0]
    u = np.ones(n_rows) / n_rows
    v = np.ones(n_cols) / n_cols

    for step in range(it_max):
        u_before = u  # `u` is rebound below, never mutated in place
        v = nu_target / (K_kernel.T @ u + guard)
        u = mu_target / (K_kernel @ v + guard)

        is_check_step = step % check_interval == 0
        if is_check_step and np.linalg.norm(u - u_before) < tol:
            break

    plan = u[:, None] * K_kernel * v[None, :]
    total_mass = plan.sum()
    if total_mass > 1e-9:
        return plan / total_mass
    return np.zeros_like(plan)


def sinkhorn_three_way(K_kernel, mu_target, nu_target, qos_mask, qos_beta, it_max=8000, tol=1e-9, check_interval=50):
    """3-way Sinkhorn: cycles through row, column and QoS rescalings of a plan.

    Each iteration rescales the rows to ``mu_target``, the columns to
    ``nu_target`` and, when the mass on ``qos_mask`` exceeds ``qos_beta``,
    moves the plan onto the QoS constraint.

    The QoS step scales the masked cells so they carry exactly ``qos_beta``
    and the remaining cells so they carry ``1 - qos_beta``. Scaling only the
    masked cells and then renormalising the whole plan (the previous
    behaviour) divides by a total below one and pushes the masked mass back
    above ``qos_beta``, so the constraint was violated right after the step
    that was meant to enforce it.

    Parameters
    ----------
    K_kernel : array of shape (rows, cols)
        Non-negative starting kernel; it is copied, not modified.
    mu_target, nu_target : 1-D arrays
        Target row sums and column sums.
    qos_mask : boolean array of shape (rows, cols)
        Cells whose total mass is capped by ``qos_beta``.
    qos_beta : float
        Upper bound on the masked mass.
    it_max, tol, check_interval : int, float, int
        Iteration cap, convergence tolerance, and how often convergence
        (marginal errors and QoS mass) is checked.

    Returns
    -------
    numpy.ndarray
        The plan divided by its total mass, or all zeros when that mass is
        at most 1e-9.
    """
    P = np.array(K_kernel, dtype=float)  # float copy so in-place scaling is safe
    P_sum = P.sum()
    P /= P_sum if P_sum > 1e-9 else 1.0  # start from the normalised kernel

    qos_mask = np.asarray(qos_mask, dtype=bool)
    rest_mask = ~qos_mask

    for it in range(it_max):
        # Row and column marginal rescalings
        P *= (mu_target / (P.sum(axis=1) + 1e-12))[:, None]
        P *= (nu_target / (P.sum(axis=0) + 1e-12))[None, :]

        # QoS step, only when the cap is exceeded
        current_qos_mass = P[qos_mask].sum()
        if current_qos_mass > qos_beta + 1e-12:
            rest_mass = P[rest_mask].sum()
            if rest_mass > 1e-12 and qos_beta <= 1.0:
                # Masked cells carry qos_beta, the others 1 - qos_beta, so the
                # plan keeps unit mass and sits exactly on the constraint.
                P[qos_mask] *= qos_beta / current_qos_mass
                P[rest_mask] *= (1.0 - qos_beta) / rest_mass
            else:
                # No usable mass outside the mask (or a cap above one):
                # fall back to scaling the mask and renormalising the plan.
                P[qos_mask] *= qos_beta / (current_qos_mass + 1e-12)
                P_sum_after_qos = P.sum()
                P /= P_sum_after_qos if P_sum_after_qos > 1e-9 else 1.0

        if it % check_interval == 0:
            err_mu = np.abs(P.sum(axis=1) - mu_target).max()
            err_nu = np.abs(P.sum(axis=0) - nu_target).max()
            current_qos_val = P[qos_mask].sum()
            # Stop once both marginals match and the QoS cap is met or undershot
            if err_mu < tol and err_nu < tol and current_qos_val <= qos_beta + tol:
                break

    final_P_sum = P.sum()
    return P / final_P_sum if final_P_sum > 1e-9 else np.zeros_like(P)

# ---------------------------
# 3. Plot Cost slices (Plan Independent)
# ---------------------------
print("\n--- Section 3: Plotting Cost Slices (Plan Independent) ---")
# Up to ten evenly spaced diagonal supply pairs (i, i) are shown
num_slice_plots = min(10, n)
sup_pairs_indices = []
if n > 0:
    sup_pairs_indices = [(i, i) for i in np.linspace(0, n-1, num_slice_plots, dtype=int)]

# Two rows of panels; at least one column so the figure can always be created
slice_grid_cols = max((num_slice_plots + 1) // 2, 1)
fig_cost, axes_cost = plt.subplots(2, slice_grid_cols, figsize=(18, 7), constrained_layout=True)
if num_slice_plots == 0:
    axes_cost = np.array([[axes_cost]])  # wrap so .flatten() below still works
axes_cost_flat = axes_cost.flatten()


# One heatmap per fixed supply pair: the scaled cost over all (y1, y2)
for ax, (x1_idx, x2_idx) in zip(axes_cost_flat, sup_pairs_indices):
    cost_slice = C_scaled[x1_idx, x2_idx, :, :]
    im = ax.imshow(cost_slice, origin='lower', aspect='auto', cmap='magma')
    ax.set_title(f"Cost | X=({X1_coords[x1_idx]},{X2_coords[x2_idx]})")
    ax.set_xlabel("Y2 Index")
    ax.set_ylabel("Y1 Index")

if num_slice_plots > 0:
    # Shared colorbar, taken from the last panel drawn
    fig_cost.colorbar(im, ax=axes_cost.ravel().tolist(), fraction=0.02, pad=0.01, label="Scaled Cost")
else:
    placeholder_ax = axes_cost_flat[0]
    placeholder_ax.text(0.5, 0.5, "No slices to plot (n=0)", ha='center', va='center')
    placeholder_ax.axis('off')


plt.suptitle("Cost Function Slices C(X1,X2,Y1,Y2) for Fixed (X1,X2)")
plt.show()

# ---------------------------
# 5. Histograms: Supply_total & Demand_total (Original Marginals - Plan Independent)
# ---------------------------
print("\n--- Section 5: Plotting Histograms of Total Supply and Demand (Original Marginals - Plan Independent) ---")
# Totals of two coordinates in 0..n-1 range over 0..2(n-1)
max_total_val = 2 * (n - 1) if n > 0 else 0
total_coords = np.arange(max_total_val + 1)
p_sup_tot_orig = np.zeros(max_total_val + 1)
p_dem_tot_orig = np.zeros(max_total_val + 1)


if n > 0:
    # Accumulate every cell's probability into the bin of its coordinate sum.
    # np.add.at adds unbuffered and in row-major order, like a nested loop.
    supply_bins = (X1_coords[:, None] + X2_coords[None, :]).ravel()
    demand_bins = (Y1_coords[:, None] + Y2_coords[None, :]).ravel()
    np.add.at(p_sup_tot_orig, supply_bins, p_X.ravel())
    np.add.at(p_dem_tot_orig, demand_bins, p_Y.ravel())

fig_hist_totals, ax_hist_totals = plt.subplots(1, 2, figsize=(12, 4), constrained_layout=True)

supply_ax = ax_hist_totals[0]
supply_ax.bar(total_coords, p_sup_tot_orig, color='C3', width=0.8)
supply_ax.set_title('Original Supply Total PMF P(X1+X2)')
supply_ax.set_xlabel('X1+X2')
supply_ax.set_ylabel('P')

demand_ax = ax_hist_totals[1]
demand_ax.bar(total_coords, p_dem_tot_orig, color='C4', width=0.8)
demand_ax.set_title('Original Demand Total PMF P(Y1+Y2)')
demand_ax.set_xlabel('Y1+Y2')
demand_ax.set_ylabel('P')

plt.show()

# --------------------------------------------------
# Generate Fixed Random Demand Pairs for Consistent Analysis
# --------------------------------------------------
print("\n--- Generating Fixed Random Demand Pairs for Consistent Analysis ---")


def _sample_demand_pairs(num_samples, samples_name, plot_label, error_label):
    """Draw `num_samples` demand pairs (y1_idx, y2_idx) at random from p_Y.

    Flat indices are drawn with probabilities `p_Y_flat` and converted back to
    2D grid indices with (idx // n, idx % n). Returns an empty list when no
    samples are requested or when the draw is rejected by NumPy.
    """
    if num_samples <= 0:
        print(f"{samples_name} = 0, no pairs generated for {plot_label}.")
        return []
    try:
        flat_indices = np.random.choice(m_pairs, size=num_samples, p=p_Y_flat)
    except ValueError as e:
        # np.random.choice raises ValueError for an invalid probability vector
        print(f"Error generating {error_label} pairs: {e}")
        return []
    pairs = [(idx // n, idx % n) for idx in flat_indices]
    print(f"Generated {len(pairs)} pairs for {plot_label}.")
    return pairs


d_pairs_for_hist = []
d_pairs_for_coupling = []

# The same two lists are reused for every plan so the plots stay comparable.
# Histogram pairs are drawn first, then coupling pairs, which keeps the
# random stream (seeded above) in a fixed order.
if p_Y_flat.sum() > 1e-9 and m_pairs > 0 and n > 0:
    d_pairs_for_hist = _sample_demand_pairs(
        hist_samples, "hist_samples", "histograms", "histogram")
    d_pairs_for_coupling = _sample_demand_pairs(
        coupling_samples, "coupling_samples", "coupling plots", "coupling")
else:
    print("Cannot generate demand pairs: Prerequisite variables missing or invalid (p_Y_flat, m_pairs, n).")

# --------------------------------------------------
# Function to Perform Analysis and Plotting for a given Plan
# --------------------------------------------------
def _draw_heatmap_panel(fig, ax, data, title, xlabel, ylabel, cmap='viridis', cbar_label=None):
    """Draw `data` as a lower-origin heatmap on `ax` with its own colorbar."""
    image = ax.imshow(data, origin='lower', aspect='auto', cmap=cmap)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if cbar_label is None:
        fig.colorbar(image, ax=ax)
    else:
        fig.colorbar(image, ax=ax, label=cbar_label)
    return image


def perform_plan_analysis_and_plotting(P_flat_current_plan, plan_label_suffix,
                                       d_pairs_hist_input, d_pairs_coupling_input):
    """Print diagnostics and draw all plots for one transport plan.

    Parameters
    ----------
    P_flat_current_plan : array of shape (n*n, n*n) or None
        Plan over flattened (x1, x2) rows and (y1, y2) columns. A ``None`` or
        (near-)empty plan is reported and skipped.
    plan_label_suffix : str
        Label used in every printed line and figure title.
    d_pairs_hist_input : sequence of (y1_idx, y2_idx)
        Demand pairs for the conditional total-supply histograms (Section 6).
    d_pairs_coupling_input : sequence of (y1_idx, y2_idx)
        Demand pairs for the conditional supply heatmaps (Section 7).

    The function reads the module-level grid, marginal and cost variables
    (``n``, ``p_X``, ``p_Y``, ``C_flat_unscaled``, ``Sup_tot_flat``, ...) and
    returns nothing; its output is figures and printed text.
    """
    print(f"\n\n=== Detailed Analysis & Plots for: {plan_label_suffix} ===")
    if P_flat_current_plan is None or P_flat_current_plan.sum() < 1e-9:
        print(f"Skipping analysis for {plan_label_suffix} as the plan is invalid or empty.")
        return

    # Past this point the plan carries mass, so n >= 1 and the reshape is valid.
    P_final_4D_current = P_flat_current_plan.reshape((n, n, n, n))

    # ---------------------------
    # 4. Plot OT coupling slices P(Y1,Y2 | X1,X2)
    # ---------------------------
    print(f"\n--- Section 4 (for {plan_label_suffix}): Plotting OT Coupling Slices ---")
    fig_ot, axes_ot = plt.subplots(2, max((num_slice_plots + 1) // 2, 1), figsize=(18, 7),
                                   constrained_layout=True, squeeze=False)
    axes_ot_flat = axes_ot.flatten()

    im2 = None
    for ax, (x1_idx, x2_idx) in zip(axes_ot_flat, sup_pairs_indices):
        # Conditional P(Y | X): joint slice divided by the supply marginal
        conditional_P_Y_given_X = P_final_4D_current[x1_idx, x2_idx, :, :] / (p_X[x1_idx, x2_idx] + 1e-12)
        im2 = ax.imshow(conditional_P_Y_given_X, origin='lower', aspect='auto', cmap='viridis')
        ax.set_title(f"P(Y|X=({X1_coords[x1_idx]},{X2_coords[x2_idx]}))")
        ax.set_xlabel("Y2 Index")
        ax.set_ylabel("Y1 Index")

    # Hide panels left over when the number of slices does not fill the grid
    for unused_ax in axes_ot_flat[len(sup_pairs_indices):]:
        unused_ax.axis('off')

    if im2 is not None:
        # Shared colorbar, taken from the last panel drawn
        fig_ot.colorbar(im2, ax=axes_ot.ravel().tolist(), fraction=0.02, pad=0.01, label="Conditional Probability")
    plt.suptitle(f"OT Coupling Slices P(Y1,Y2 | X1,X2) - {plan_label_suffix}")
    plt.show()

    # ---------------------------
    # 6. Histograms: P(Supply_total | random Y=(y1,y2) from p_Y)
    # ---------------------------
    print(f"\n--- Section 6 (for {plan_label_suffix}): Histograms P(X_tot | Y) for Random Demand Pairs ---")
    if d_pairs_hist_input is not None and len(d_pairs_hist_input) > 0:
        num_hist_plots = len(d_pairs_hist_input)
        fig_hist_cond_S, axes_hist_cond_S = plt.subplots(1, num_hist_plots,
                                                         figsize=(max(4 * num_hist_plots, 5), 4),
                                                         constrained_layout=True, squeeze=False)
        axes_hist_cond_S_flat = axes_hist_cond_S.flatten()

        for k_ax, (y1_idx, y2_idx) in enumerate(d_pairs_hist_input):
            ax = axes_hist_cond_S_flat[k_ax]
            # Conditional P(X | Y): joint slice divided by the demand marginal
            p_X_given_Y = P_final_4D_current[:, :, y1_idx, y2_idx] / (p_Y[y1_idx, y2_idx] + 1e-12)
            # Collapse (x1, x2) onto the total x1 + x2; Sup_tot_flat holds that
            # total for every flattened supply pair.
            p_cond_sup_tot = np.bincount(Sup_tot_flat, weights=p_X_given_Y.ravel(),
                                         minlength=max_total_val + 1)
            ax.bar(total_coords, p_cond_sup_tot, alpha=0.7, width=0.8)
            ax.set_title(f"P(X_tot|Y=({Y1_coords[y1_idx]},{Y2_coords[y2_idx]}))")
            ax.set_xlabel('X1+X2')
            if k_ax == 0:
                ax.set_ylabel('P')

        plt.suptitle(f"Conditional Supply Total PMFs P(X1+X2 | Y1,Y2) - {plan_label_suffix}")
        plt.show()
    else:
        print(f"Skipping Section 6 for {plan_label_suffix} as no input demand pairs were provided or n=0.")

    # ---------------------------
    # 7. Plot P(X1,X2 | Y1,Y2) for random demand pairs
    # ---------------------------
    print(f"\n--- Section 7 (for {plan_label_suffix}): Plotting P(X | Y) for Random Demand Pairs ---")
    if d_pairs_coupling_input is not None and len(d_pairs_coupling_input) > 0:
        num_coupling_samples_actual = len(d_pairs_coupling_input)
        num_coupling_cols = min(num_coupling_samples_actual, 5)
        # Ceiling division: enough rows to hold every sample
        num_coupling_rows = (num_coupling_samples_actual + num_coupling_cols - 1) // num_coupling_cols

        fig_coupling, axes_coupling = plt.subplots(num_coupling_rows, num_coupling_cols,
                                                   figsize=(18, 7),
                                                   constrained_layout=True, squeeze=False)
        axes_coupling_flat = axes_coupling.flatten()

        for k_ax, (y1_idx, y2_idx) in enumerate(d_pairs_coupling_input):
            p_X_given_Y_2D = P_final_4D_current[:, :, y1_idx, y2_idx] / (p_Y[y1_idx, y2_idx] + 1e-12)
            _draw_heatmap_panel(fig_coupling, axes_coupling_flat[k_ax], p_X_given_Y_2D,
                                f"P(X|Y=({Y1_coords[y1_idx]},{Y2_coords[y2_idx]}))",
                                "X2 Index", "X1 Index", cmap='plasma', cbar_label="Prob")

        # Hide any unused sub-plots
        for unused_ax in axes_coupling_flat[num_coupling_samples_actual:]:
            unused_ax.axis('off')
        plt.suptitle(f"Conditional Supply PMFs P(X1,X2 | Y1,Y2) - {plan_label_suffix}")
        plt.show()
    else:
        print(f"Skipping Section 7 for {plan_label_suffix} as no input demand pairs were provided or n=0.")

    # ---------------------------
    # 8. Plot Derived 2D Marginals from the 4D Plan P(X_i, Y_j)
    # ---------------------------
    print(f"\n--- Section 8 (for {plan_label_suffix}): Plotting Derived 2D Cross-Marginals P(Xi, Yj) ---")
    # (title, plan axes summed out, x-axis label, y-axis label); the plan axes
    # are ordered (x1, x2, y1, y2), so summing two of them leaves (Xi, Yj).
    cross_marginal_specs = [
        ("P(X1, Y1)", (1, 3), "Y1 Index", "X1 Index"),
        ("P(X1, Y2)", (1, 2), "Y2 Index", "X1 Index"),
        ("P(X2, Y1)", (0, 3), "Y1 Index", "X2 Index"),
        ("P(X2, Y2)", (0, 2), "Y2 Index", "X2 Index"),
    ]
    fig_derived_marg, axes_derived_marg = plt.subplots(2, 2, figsize=(11, 10), constrained_layout=True)
    fig_derived_marg.suptitle(f"Derived 2D Cross-Marginals P(Xi,Yj) - {plan_label_suffix}", fontsize=14)
    for ax, (title, summed_axes, xlabel, ylabel) in zip(axes_derived_marg.flat, cross_marginal_specs):
        _draw_heatmap_panel(fig_derived_marg, ax, P_final_4D_current.sum(axis=summed_axes),
                            title, xlabel, ylabel)
    plt.show()

    # ---------------------------
    # 9. Induced supply & demand marginals from P vs Original
    # ---------------------------
    print(f"\n--- Section 9 (for {plan_label_suffix}): Comparing Original vs. Induced Supply/Demand Marginals ---")
    p_S12_from_plan = P_final_4D_current.sum(axis=(2, 3))
    p_D12_from_plan = P_final_4D_current.sum(axis=(0, 1))

    # Top row: supply (original, induced); bottom row: demand (original, induced)
    induced_panel_specs = [
        (p_X, "Original Supply PMF p_X(x1,x2)", "X2 Index", "X1 Index", 'viridis'),
        (p_S12_from_plan, f"Induced Supply P(x1,x2) ({plan_label_suffix})", "X2 Index", "X1 Index", 'viridis'),
        (p_Y, "Original Demand PMF p_Y(y1,y2)", "Y2 Index", "Y1 Index", 'magma'),
        (p_D12_from_plan, f"Induced Demand P(y1,y2) ({plan_label_suffix})", "Y2 Index", "Y1 Index", 'magma'),
    ]
    fig_induced_marg, axes_induced_marg = plt.subplots(2, 2, figsize=(12, 10), constrained_layout=True)
    fig_induced_marg.suptitle(f"Original vs. Induced 2D Marginals - {plan_label_suffix}", fontsize=16)
    for ax, (data, title, xlabel, ylabel, cmap) in zip(axes_induced_marg.flat, induced_panel_specs):
        _draw_heatmap_panel(fig_induced_marg, ax, data, title, xlabel, ylabel, cmap=cmap)
    plt.show()

    print(f"Max diff original p_X_flat vs induced ({plan_label_suffix}): {np.abs(p_X_flat - p_S12_from_plan.flatten()).max():.2e}")
    print(f"Max diff original p_Y_flat vs induced ({plan_label_suffix}): {np.abs(p_Y_flat - p_D12_from_plan.flatten()).max():.2e}")

    # ---------------------------
    # 10. Final Cost Calculations and Event Probabilities
    # ---------------------------
    print(f"\n--- Section 10 (for {plan_label_suffix}): Final Cost Calculations and Event Probabilities ---")
    # Expected costs use the unscaled cost; its three components are also
    # reported separately.
    exp_total_cost = np.sum(P_flat_current_plan * C_flat_unscaled)
    delta_sup_dem = Dem_tot_flat[None, :] - Sup_tot_flat[:, None]  # demand minus supply
    exp_under_capacity_cost = np.sum(P_flat_current_plan * cu * np.maximum(delta_sup_dem, 0))
    exp_over_capacity_cost = np.sum(P_flat_current_plan * co * np.maximum(-delta_sup_dem, 0))
    cost_component_adj = ca * np.maximum(0, Sup_tot_flat[:, None] - threshold)
    exp_adj_cost = np.sum(P_flat_current_plan * cost_component_adj)

    print(f"\nPlan Type: {plan_label_suffix}")
    print(f"Total expected cost:                {exp_total_cost:.3f}")
    print(f"  Under-capacity (lost sales):    {exp_under_capacity_cost:.3f}")
    print(f"  Over-capacity (holding):        {exp_over_capacity_cost:.3f}")
    print(f"  Threshold-adjustment:           {exp_adj_cost:.3f}")

    mask_less_final = delta_sup_dem > 0      # supply total < demand total
    mask_greater_final = delta_sup_dem < 0   # supply total > demand total
    mask_equal_final = delta_sup_dem == 0

    p_less_final = P_flat_current_plan[mask_less_final].sum()
    p_greater_final = P_flat_current_plan[mask_greater_final].sum()
    # Measured from the plan itself (not as 1 - less - greater), so the sum
    # printed below actually checks that the plan carries unit mass.
    p_equal_final = P_flat_current_plan[mask_equal_final].sum()

    print(f"\nEvent Probabilities from Plan ({plan_label_suffix}):")
    print(f"  P(X_tot < Y_tot) (Stock-out):   {p_less_final:.6f}")
    print(f"  P(X_tot > Y_tot) (Over-supply): {p_greater_final:.6f}")
    print(f"  P(X_tot = Y_tot) (Balanced):    {p_equal_final:.6f}")
    print(f"  Sum of event probabilities:     {p_less_final + p_greater_final + p_equal_final:.6f}")
    print(f"=== End of Detailed Analysis & Plots for: {plan_label_suffix} ===")

# --------------------------------------------------
# 2-bis. Part 1: Unconstrained (No QoS) Plan
# --------------------------------------------------
print("\n--- Section 2-bis Part 1: Calculating Unconstrained OT Plan (2-way Sinkhorn) ---")
# Balance the Gibbs kernel against both marginals only (no QoS constraint)
P_flat_noqos = sinkhorn_two_way(pi_I, p_X_flat, p_Y_flat)

# Stock-out probability of that plan: its mass on cells with supply < demand
if P_flat_noqos is not None:
    p_less_noqos = P_flat_noqos[mask_under].sum()
else:
    p_less_noqos = np.nan
print(f"Stock-out P(X_tot < Y_tot) for NO QOS plan: {p_less_noqos:.6f}")

# Full set of diagnostics and plots for the unconstrained plan
perform_plan_analysis_and_plotting(
    P_flat_noqos,
    "Unconstrained (No QoS) Plan",
    d_pairs_for_hist,
    d_pairs_for_coupling,
)


# ---------------------------------------------------------------------------------
# 5-c. QoS Feasibility Bounds (Full 2D Marginals) - MORE ACCURATE CHECK
# ---------------------------------------------------------------------------------
# Two LPs over all couplings of p_X and p_Y give the smallest and largest
# achievable stock-out probability P(X_tot < Y_tot); the QoS target is
# feasible when it lies inside that range.
LP_WARN_N = 10  # above this grid size a runtime warning is printed
LP_MAX_N = 30   # above this grid size the LPs are not solved at all

print("\n--- Section 5-c: QoS Feasibility Bounds (Full 2D Marginals) ---")
print(f"Grid size n = {n}. LP variables: n^4 = {n**4}. LP constraints: 2*n^2 = {2*n**2}.")
if n > LP_WARN_N:
    print(f"WARNING: This LP can be computationally intensive for n > {LP_WARN_N}.")

time_start_lp = time.time()

N_supply_pairs = m_pairs
N_demand_pairs = m_pairs
N_lp_vars = N_supply_pairs * N_demand_pairs
N_lp_cons = N_supply_pairs + N_demand_pairs

# Equality constraints A x = b on the flattened plan x[s * N_demand_pairs + d]:
# the first N_supply_pairs rows fix the supply marginal, the rest the demand one.
b_eq_full = np.concatenate([p_X_flat, p_Y_flat])
A_eq_full_sparse = lil_matrix((N_lp_cons, N_lp_vars), dtype=float)

for s_idx in range(N_supply_pairs):
    # Row s: sum over all demand pairs d of x[s, d]
    A_eq_full_sparse[s_idx, s_idx * N_demand_pairs : (s_idx + 1) * N_demand_pairs] = 1.0
for d_idx in range(N_demand_pairs):
    # Row N_supply_pairs + d: sum over all supply pairs s of x[s, d]
    con_idx = N_supply_pairs + d_idx
    A_eq_full_sparse[con_idx, d_idx::N_demand_pairs] = 1.0

A_eq_full_csr = A_eq_full_sparse.tocsr()
print("A_eq matrix for LP constructed.")

# Objective: indicator of the stock-out event, flattened row-major so that
# entry s * N_demand_pairs + d matches the LP variable for (s, d).
c_objective_stockout = (Sup_tot_flat[:, None] < Dem_tot_flat[None, :]).astype(float).ravel()
print("LP objective vector (c_objective_stockout) constructed.")

S_min_under_full, S_max_under_full = -np.inf, np.inf # Initialize
res_min_under_full_success, res_max_under_full_success = False, False

# One flag drives the solve, the skip message and the reporting below.
lp_attempted = n <= LP_MAX_N

if lp_attempted:
    print("Solving for S_min_under_full (minimum P(X_tot < Y_tot))...")
    res_min_under_full = linprog(c_objective_stockout, A_eq=A_eq_full_csr, b_eq=b_eq_full,
                                bounds=(0, None), method='highs')
    if res_min_under_full.success:
        S_min_under_full = res_min_under_full.fun
        res_min_under_full_success = True
    else:
        print(f"⚠️ WARNING: LP for S_min_under_full FAILED: {res_min_under_full.message}")

    print("Solving for S_max_under_full (maximum P(X_tot < Y_tot))...")
    # Maximise by minimising the negated objective
    res_max_under_full = linprog(-c_objective_stockout, A_eq=A_eq_full_csr, b_eq=b_eq_full,
                                bounds=(0, None), method='highs')
    if res_max_under_full.success:
        S_max_under_full = -res_max_under_full.fun
        res_max_under_full_success = True
    else:
        print(f"⚠️ WARNING: LP for S_max_under_full FAILED: {res_max_under_full.message}")
else:
    print(f"Skipping LP solving for S_min/S_max_under_full due to n={n} > {LP_MAX_N}.")


time_end_lp = time.time()
print(f"LP solving (if attempted) took {time_end_lp - time_start_lp:.2f} seconds.")

P_flat_qos_is_feasible_based_on_2D_marginals = False
if res_min_under_full_success and res_max_under_full_success:
    if np.isfinite(S_min_under_full) and np.isfinite(S_max_under_full) and S_min_under_full <= S_max_under_full + 1e-9: # Allow small tolerance
        print(f"\nAchievable P(X_tot < Y_tot) based on full 2D marginals ∈ [{S_min_under_full:.6f}, {S_max_under_full:.6f}]")
        if S_min_under_full - 1e-9 <= beta_target <= S_max_under_full + 1e-9: # Allow small tolerance
            print(f"✅ Desired β*={beta_target:.4f} IS WITHIN this feasible range.")
            P_flat_qos_is_feasible_based_on_2D_marginals = True
        else:
            print(f"⚠️ Desired β*={beta_target:.4f} IS OUTSIDE this feasible range.")
    else:
        print("⚠️ Could not reliably determine feasibility bounds from full 2D marginals due to LP solver issues or invalid bounds.")
elif lp_attempted: # LP was attempted but failed for one or both bounds
    print("⚠️ Could not reliably determine feasibility bounds from full 2D marginals due to LP solver issues for S_min and/or S_max.")
else: # LP was skipped
    print("⚠️ Feasibility bounds from full 2D marginals not determined as LP solving was skipped.")
    # Heuristic: without LP bounds, use the unconstrained plan's stock-out
    # probability to decide whether to try the 3-way Sinkhorn at all.
    if p_less_noqos is not None and p_less_noqos <= beta_target + 0.05 : # If unconstrained is already good or close
        print("Heuristically assuming QoS might be feasible as unconstrained plan is close/better than target.")
        P_flat_qos_is_feasible_based_on_2D_marginals = True # Tentative
    else:
        print("Unconstrained plan does not meet QoS target; proceeding with caution for 3-way Sinkhorn if LP was skipped.")
        # True feasibility is unknown here: allow the attempt only when the
        # unconstrained plan is actually worse than the target (a NaN
        # stock-out probability compares False and therefore disables it).
        P_flat_qos_is_feasible_based_on_2D_marginals = bool(p_less_noqos > beta_target)

# ------------------------------------------------------------------
# 2-bis. Part 2: QoS Constrained Plan (Conditional)
# ------------------------------------------------------------------
# The 3-way Sinkhorn solver is only called when the preceding feasibility
# assessment left P_flat_qos_is_feasible_based_on_2D_marginals set to True.
print("\n--- Section 2-bis Part 2: Calculating QoS Constrained Plan (Conditional) ---")
P_flat_qos = None
p_less_qos = np.nan
qos_plan_valid_and_computed = False

if not P_flat_qos_is_feasible_based_on_2D_marginals:
    print("Skipping QoS-compliant plan computation due to prior infeasibility assessment, LP failure, or n too large for LP and unconstrained plan already met target.")
else:
    print(f"Attempting to compute QoS-compliant plan (target P[X_tot<Y_tot] <= {beta_target:.6f})...")
    P_flat_qos = sinkhorn_three_way(pi_I, p_X_flat, p_Y_flat,
                                    mask_under, beta_target)
    # A plan with (near-)zero total mass is treated as a solver failure.
    # `not (... > ...)` keeps a NaN total on the failure path as well.
    if P_flat_qos is None or not (P_flat_qos.sum() > 1e-9):
        print("⚠️ QoS plan (3-way Sinkhorn) computation resulted in an invalid plan (e.g. all zeros).")
        P_flat_qos = None
    else:
        p_less_qos = P_flat_qos[mask_under].sum()
        print(f"Stock-out P(X_tot < Y_tot) for QOS plan (3-way): {p_less_qos:.6f}")
        # The plan counts as "computed" even when it misses the target; it is
        # kept for analysis, and the final selection re-checks the target.
        qos_plan_valid_and_computed = True
        if p_less_qos <= beta_target + 1e-5:  # small tolerance on the target
            print(f"✅ QoS plan from 3-way Sinkhorn ({p_less_qos:.6f}) meets target ({beta_target:.6f}).")
        else:
            print(f"⚠️ QoS plan from 3-way Sinkhorn ({p_less_qos:.6f}) did not meet target ({beta_target:.6f}).")

# Analysis and plots for the QoS plan, when one exists.
if qos_plan_valid_and_computed and P_flat_qos is not None:
    perform_plan_analysis_and_plotting(P_flat_qos, "QoS Constrained Plan",
                                       d_pairs_for_hist, d_pairs_for_coupling)
elif P_flat_qos_is_feasible_based_on_2D_marginals:
    # The solver ran but its output was rejected above.
    print("QoS plan was attempted but resulted in an invalid plan. No plots/analysis for QoS plan.")
else:
    # The solver was never invoked.
    print("QoS plan was not attempted. No plots/analysis for QoS plan.")


# Choose the plan reported in the overall summary: the QoS plan when it exists
# and meets the stock-out target (within tolerance), otherwise the
# unconstrained plan if one is available.
P_flat_final = None
final_plan_type = "None"

if P_flat_qos is not None and qos_plan_valid_and_computed and (p_less_qos <= beta_target + 1e-5):
    P_flat_final = P_flat_qos
    final_plan_type = "QoS (3-way Sinkhorn)"
    print("\n>>> Final selected plan for overall summary: QoS (3-way Sinkhorn).")
elif P_flat_noqos is not None:
    P_flat_final = P_flat_noqos
    final_plan_type = "No QoS (2-way Sinkhorn)"
    print("\n>>> Final selected plan for overall summary: No QoS (2-way Sinkhorn) (QoS plan not available, not valid, or did not meet target).")
else:
    print("\n>>> No valid plan could be computed.")


print("\n============== FINAL STOCK-OUT PROBABILITY COMPARISON ==============")
print(f"Target β* (desired P[X_tot<Y_tot]) : {beta_target:.6f}")
# The feasibility section tests `p_less_noqos is not None`, so None is a
# possible value here; formatting None with ':.6f' would raise TypeError.
if p_less_noqos is None:
    print("Without QoS constraint (2-way)        : Not available.")
else:
    print(f"Without QoS constraint (2-way)        : {p_less_noqos:.6f}")
if qos_plan_valid_and_computed:
    print(f"With    QoS constraint (3-way)        : {p_less_qos:.6f} (if computed and valid)")
else:
    print("With    QoS constraint (3-way)        : Not available or not validly computed.")

if P_flat_final is not None:
    print(f"Final selected plan type              : {final_plan_type}")
    print(f"Final selected plan P[X_tot<Y_tot]    : {P_flat_final[mask_under].sum():.6f}")
else:
    print("No final plan selected.")
print("================================================================\n")


# ---------------------------------------------------------------
# (Optional) Evaluate a grid of α values based on Full 2D Marginal feasibility
# ---------------------------------------------------------------
# α is the service level P(X_tot >= Y_tot), so the stock-out probability is
# β = 1 - α. A target is attainable only if β is at least the smallest
# achievable stock-out S_min_under_full, i.e. α <= 1 - S_min_under_full.
if not res_min_under_full_success:
    print("\nSkipping Alpha Grid Evaluation as S_min_under_full was not successfully computed via LP.")
else:
    print("\n--- (Optional) Alpha Grid Evaluation based on Full 2D Marginal Feasibility ---")
    alpha_grid = np.linspace(0.0, 1.0, 11)

    # 1e-9 is a numerical tolerance on the bound.
    feasible_alpha_full = alpha_grid <= (1.0 - S_min_under_full + 1e-9)

    if not res_max_under_full_success:
        print(f"Min P(X_tot < Y_tot) is {S_min_under_full:.4f} (S_max not reliably computed).")
    else:
        print(f"Feasibility for P(X_tot < Y_tot) is in [{S_min_under_full:.4f}, {S_max_under_full:.4f}]")
    print(f"This means target service level alpha = P(X_tot >= Y_tot) must be <= {1.0 - S_min_under_full:.4f}")

    for alpha_val, is_feasible in zip(alpha_grid, feasible_alpha_full):
        print(f" α = {alpha_val:.2f} (target stock-out {1-alpha_val:.2f}): Feasible based on LP? {'YES' if is_feasible else 'NO'}")

    # Worked example for a single service level.
    alpha_example = 0.90
    print(f"\nExample: α = {alpha_example:.2f} (target P[X_tot<Y_tot] <= {1-alpha_example:.2f})")
    if alpha_example <= 1.0 - S_min_under_full + 1e-9:
        print("   • Feasible based on 2D marginals (LP)?", "YES",
              f"(needs α ≤ {1.0 - S_min_under_full:.4f})")
    else:
        print("   • Feasible based on 2D marginals (LP)?", "NO",
              f"(needs α ≤ {1.0 - S_min_under_full:.4f})")


print("\n--- Script Finished ---")

In [None]:
#@title Monte Carlo Simulations Against Benchmarks
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import factorial
import time
import matplotlib # Added for colormaps

# --- PARAMETERS ---
# Parameters from "Script B" which used 4D cost evaluation primarily
N_GRID_POINTS = 26          # grid points per coordinate
N_SIMULATION_DAYS = 100000  # Monte Carlo days
N_SAA_SAMPLES = 1000        # Number of demand samples for SAA policy (can be slow with 4D cost)

# Supply Poisson parameters.
# Each tuple is passed straight to bipoisson as
# (mean_independent_part1, mean_independent_part2, mean_shared_part), one tuple per mode.
# Only one configuration is assigned; a previously listed alternative
# ([(11, 6, 1), (20, 11, 2)] with weight 0.4) was immediately overwritten and
# therefore never took effect, so it is kept here as a note only.
L_S_PARAM = [(11, 6, 1), (20, 12, 1)]
W_S_PARAM = 0.5

# Demand Poisson parameters (config - can be overridden by search)
L_D_PARAM = [(5, 2, 1), (14, 20, 2)]
W_D_PARAM = 0.4


# # Demand Poisson parameters (from Script B)
# L_D_PARAM = [(20, 10, 2), (4, 15, 5)]
# W_D_PARAM = 0.5

# Cost parameters (from Script B - co, cu for mismatch; ca, threshold for capacity)
# These define the 4D cost matrix C
CO_MISMATCH, CU_MISMATCH = 1, 10
THRESHOLD_CAP, CA_CAP = 20, 5

# Parameters for the 1D aggregate cost function (used by the FPD-OT TopK policy).
# The same unit costs are reused, applied to total supply / total demand.
C_O_AGG, C_U_AGG = CO_MISMATCH, CU_MISMATCH      # aggregate mismatch costs
X_THRESH_AGG, C_T_AGG = THRESHOLD_CAP, CA_CAP    # aggregate threshold logic


# FPD-OT parameters
EPSILON_OT = 0.1  # From Script B
K_TOP_FPD = 5     # Number of top candidates for FPD-OT TopK policy

# Simulation RNG (fixed seed for reproducibility)
RNG = np.random.default_rng(42)

# --- HELPER FUNCTIONS (FROM USER'S COMPACT SCRIPT) ---
def _poisson_weights(n_points, lam):
    """Return exp(-lam) * lam**k / k! for k = 0 .. n_points-1.

    Uses the recurrence w[k] = w[k-1] * lam / k, so no large power or
    factorial is ever formed. For lam >= 0 every intermediate value is <= 1.
    """
    weights = np.zeros(n_points)
    if n_points > 0:
        weights[0] = np.exp(-lam)
        for k in range(1, n_points):
            weights[k] = weights[k - 1] * lam / k
    return weights


def bipoisson(n_points, l1, l2, l12):
    """
    Generates a bivariate Poisson PMF on the grid {0..n_points-1}^2.

    l1, l2, l12 are treated as the direct means of three underlying
    independent Poisson variables U, V, W, and
    P(X=i, Y=j) = sum_{k=0}^{min(i,j)} [P(U=i-k;l1) * P(V=j-k;l2) * P(W=k;l12)].

    The result is renormalised to sum to 1 over the truncated grid. If the
    total mass on the grid is not above 1e-9 (including the case where
    exp(-lambda) underflows to zero), a uniform PMF is returned instead.

    Parameters
    ----------
    n_points : int
        Number of grid points per axis.
    l1, l2, l12 : float
        Means of U, V (independent parts) and W (shared part).

    Returns
    -------
    numpy.ndarray of shape (n_points, n_points)
    """
    p_u = _poisson_weights(n_points, l1)
    p_v = _poisson_weights(n_points, l2)
    p_w = _poisson_weights(n_points, l12)

    pmf = np.zeros((n_points, n_points))
    # Shared component W = k shifts the independent product U x V by k along
    # the diagonal: cell (i, j) with i, j >= k receives p_w[k]*p_u[i-k]*p_v[j-k].
    for k in range(n_points):
        m = n_points - k
        pmf[k:, k:] += p_w[k] * (p_u[:m, None] * p_v[None, :m])

    current_pmf_sum = pmf.sum()
    if current_pmf_sum > 1e-9:
        return pmf / current_pmf_sum
    # Negligible mass on the grid: fall back to a uniform PMF.
    return np.ones((n_points, n_points)) / (n_points * n_points if n_points > 0 else 1.0)


def bimodal(n_points, L_params, w_mix):
    """Return a two-component mixture of bivariate Poisson PMFs on an n_points x n_points grid.

    L_params[0] and L_params[1] are unpacked as the arguments of bipoisson for
    the first and second mode; w_mix is the weight of the first mode. The
    mixture is renormalised to sum to 1; if its total mass is not above 1e-9,
    a uniform PMF is returned instead.
    """
    first_mode = bipoisson(n_points, *L_params[0])
    second_mode = bipoisson(n_points, *L_params[1])
    mixture = w_mix * first_mode + (1 - w_mix) * second_mode

    total_mass = mixture.sum()
    if total_mass > 1e-9:
        return mixture / total_mass

    cell_count = n_points * n_points if n_points > 0 else 1
    return np.ones((n_points, n_points)) / cell_count

def plot_heatmap(M, title, xlabel="x-index", ylabel="y-index",
                 cmap="viridis", figsize=(6, 5), cbar=True):
    """Draw the 2D array M as a heat-map in a new figure and show it.

    Thin wrapper around matplotlib's imshow so that all heat-maps share the
    same layout: row 0 at the bottom (origin="lower"), automatic aspect ratio,
    and an optional colour bar when cbar is true.
    """
    fig, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(M, origin="lower", cmap=cmap, aspect="auto")
    if cbar:
        fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.tight_layout()
    plt.show()

# --- 1. GENERATE MARGINAL DISTRIBUTIONS (p_X, p_Y) ---
print("Generating marginal distributions...")
n = N_GRID_POINTS
coords = np.arange(n)

# Joint PMFs on the n x n grid; axis 0 is the first component, axis 1 the second.
p_X = bimodal(n, L_S_PARAM, W_S_PARAM)  # supply P(X1, X2)
p_Y = bimodal(n, L_D_PARAM, W_D_PARAM)  # demand P(Y1, Y2)

# ---------- VISUAL: SUPPLY & DEMAND MARGINALS ----------
plot_heatmap(p_X, "Supply joint PMF  p_X(x1, x2)", "x2", "x1")
plot_heatmap(p_Y, "Demand joint PMF  p_Y(y1, y2)", "y2", "y1")

# Row-major flattened copies used as the Sinkhorn marginals.
p_X_flat_sinkhorn = p_X.flatten()
p_Y_flat_sinkhorn = p_Y.flatten()

# Sampling versions: clipped at zero and renormalised so they are valid
# probability vectors; uniform fallback if the mass is negligible.
p_X_flat_sampling = np.maximum(0, p_X_flat_sinkhorn)
sum_px_flat = p_X_flat_sampling.sum()
if sum_px_flat > 1e-9:
    p_X_flat_sampling = p_X_flat_sampling / sum_px_flat
else:
    p_X_flat_sampling = np.ones_like(p_X_flat_sampling) / p_X_flat_sampling.size

p_Y_flat_sampling = np.maximum(0, p_Y_flat_sinkhorn)
sum_py_flat = p_Y_flat_sampling.sum()
if sum_py_flat > 1e-9:
    p_Y_flat_sampling = p_Y_flat_sampling / sum_py_flat
else:
    p_Y_flat_sampling = np.ones_like(p_Y_flat_sampling) / p_Y_flat_sampling.size
print("Marginal distributions generated.")

# --- 2. COMPUTE 4D COST MATRIX C AND FPD-OT PLAN P ---
print("Computing 4D Cost Matrix C and FPD-OT plan P...")
start_fpd_time = time.time()

X1_coords_mesh_fpd, X2_coords_mesh_fpd = np.meshgrid(coords, coords, indexing='ij')
Y1_coords_mesh_fpd, Y2_coords_mesh_fpd = np.meshgrid(coords, coords, indexing='ij')

# Total supply x1 + x2 and total demand y1 + y2 on the (n, n) grids.
Sup_tot_grid_fpd = X1_coords_mesh_fpd + X2_coords_mesh_fpd
Dem_tot_grid_fpd = Y1_coords_mesh_fpd + Y2_coords_mesh_fpd

# Full 4D Cost Matrix C[x1,x2,y1,y2]:
# over-supply cost + under-supply cost + penalty for total supply above the threshold.
C_4D = (np.maximum(0, Sup_tot_grid_fpd[..., None, None] - Dem_tot_grid_fpd[None, None, ...]) * CO_MISMATCH +
        np.maximum(0, Dem_tot_grid_fpd[None, None, ...] - Sup_tot_grid_fpd[..., None, None]) * CU_MISMATCH +
        CA_CAP * np.maximum(0, Sup_tot_grid_fpd[..., None, None] - THRESHOLD_CAP))

# Scale costs into [0, 1] (skipped when the maximum is not positive).
C_4D_scaled = C_4D / (C_4D.max() if C_4D.max() > 0 else 1.0)

m_flat = n * n
C_4D_flat_scaled = C_4D_scaled.reshape(m_flat, m_flat)
phi_base_measure = np.ones_like(C_4D_flat_scaled) / (m_flat * m_flat if m_flat > 0 else 1.0)

# Ideal plan for Sinkhorn (using scaled cost, as in Script B)
pi_I_sinkhorn = phi_base_measure * np.exp(-C_4D_flat_scaled / EPSILON_OT)
if pi_I_sinkhorn.sum() > 1e-9:
    pi_I_sinkhorn /= pi_I_sinkhorn.sum()
else:
    pi_I_sinkhorn = np.ones_like(pi_I_sinkhorn) / (m_flat * m_flat if m_flat > 0 else 1)


# Sinkhorn scaling: alternately rescale columns (v) and rows (u) so the plan
# u_i * pi_I[i, j] * v_j matches the demand and supply marginals.
MAX_SINKHORN_ITERS = 2000
u_sink = np.ones(m_flat)
v_sink = np.ones(m_flat)
for iter_idx in range(MAX_SINKHORN_ITERS):
    pi_I_T_u = pi_I_sinkhorn.T @ u_sink
    v_sink = p_Y_flat_sinkhorn / np.where(pi_I_T_u == 0, 1e-12, pi_I_T_u)  # guard exact zeros
    u_prev_sink = u_sink.copy()
    pi_I_v = pi_I_sinkhorn @ v_sink
    u_sink = p_X_flat_sinkhorn / np.where(pi_I_v == 0, 1e-12, pi_I_v)
    # Convergence is judged on the change in u only, after a 50-iteration warm-up.
    if iter_idx > 50 and np.linalg.norm(u_sink - u_prev_sink) < 1e-9:
        print(f"  Sinkhorn converged at iteration {iter_idx}.")
        break
else:
    # for/else: runs only when the loop was NOT left via `break`, so the
    # message is no longer printed when convergence happens on the last iteration.
    print("  Sinkhorn reached max iterations.")

# Row/column scaling by broadcasting; equivalent to diag(u) @ pi_I @ diag(v)
# without materialising two dense (m_flat, m_flat) diagonal matrices.
P_flat_final = (u_sink[:, None] * pi_I_sinkhorn) * v_sink[None, :]
if P_flat_final.sum() > 1e-9:
    P_flat_final /= P_flat_final.sum()
P_4D_optimal_plan = P_flat_final.reshape(n, n, n, n)
end_fpd_time = time.time()
print(f"4D Cost Matrix C and FPD-OT Plan P computed in {end_fpd_time - start_fpd_time:.2f} seconds.")

# ---------- VISUAL: OPTIMAL TRANSPORT PLAN & COST SLICES ----------
# Three representative demand pairs to slice on: mid-grid, lowest and highest.
mid = n // 2
demo_slices = [
    (mid, mid),      # centre of the demand grid
    (0, 0),          # very low demand
    (n-1, n-1),      # very high demand
]

for (y1_idx, y2_idx) in demo_slices:
    title_suffix = f"(y1={coords[y1_idx]}, y2={coords[y2_idx]})"

    # Slice of the plan over (x1, x2) at the fixed demand pair (not renormalised).
    P_slice = P_4D_optimal_plan[:, :, y1_idx, y2_idx]
    plot_heatmap(P_slice, f"Optimal plan conditional on demand {title_suffix}",
                 xlabel="x2", ylabel="x1")

    # Cost over (x1, x2) at the same demand pair.
    C_slice = C_4D[:, :, y1_idx, y2_idx]
    plot_heatmap(C_slice, f"Cost slice {title_suffix}",
                 xlabel="x2", ylabel="x1", cmap="magma")

# ---------- VISUAL: CROSS-MARGINALS ----------
# Plan axes are (x1, x2, y1, y2); each cross-marginal sums out two of them.
S1D1 = P_4D_optimal_plan.sum(axis=(1, 3))      # leaves (x1, y1)
S2D2 = P_4D_optimal_plan.sum(axis=(0, 2))      # leaves (x2, y2)
S1D2 = P_4D_optimal_plan.sum(axis=(1, 2))      # leaves (x1, y2)
S2D1 = P_4D_optimal_plan.sum(axis=(0, 3))      # leaves (x2, y1)

for cross_marginal, cross_title, cross_xlabel, cross_ylabel in (
    (S1D1, "Cross-marginal  S1–D1  (x1 vs y1)", "y1", "x1"),
    (S2D2, "Cross-marginal  S2–D2  (x2 vs y2)", "y2", "x2"),
    (S1D2, "Cross-marginal  S1–D2  (x1 vs y2)", "y2", "x1"),
    (S2D1, "Cross-marginal  S2–D1  (x2 vs y1)", "y1", "x2"),
):
    plot_heatmap(cross_marginal, cross_title, cross_xlabel, cross_ylabel)



# --- 3. PRE-CALCULATE OTHER BENCHMARK POLICIES ---
print("Pre-calculating benchmark policy parameters...")


def _clamp_supply_index(value):
    """Clamp an integer supply level to the grid index range [0, n-1]."""
    return min(n-1, max(0, value))


# Component means under the joint PMFs (axis 0 = first component, axis 1 = second).
E_X1 = (coords[:, None] * p_X).sum()
E_X2 = (coords[None, :] * p_X).sum()
E_Y1 = (coords[:, None] * p_Y).sum()
E_Y2 = (coords[None, :] * p_Y).sum()
E_X_tot = E_X1 + E_X2
E_Y_tot = E_Y1 + E_Y2
# Share of mean supply carried by component 1; used to split a total into (x1, x2).
prop_X1_hist = E_X1 / E_X_tot if E_X_tot > 1e-9 else 0.5

# Policy 1: SMS (Static Mean Supply) - always supply the rounded mean supply.
sms_policy = (_clamp_supply_index(int(round(E_X1))),
              _clamp_supply_index(int(round(E_X2))))
print(f"  SMS Policy: Supply {sms_policy}")

# Policy 2: Total Mean Match - supply the mean total demand, split by prop_X1_hist.
tmm_x1 = int(round(E_Y_tot * prop_X1_hist))
tmm_x2 = int(round(E_Y_tot * (1 - prop_X1_hist)))
tmm_policy = (_clamp_supply_index(tmm_x1), _clamp_supply_index(tmm_x2))
print(f"  Total Mean Match Policy: Supply {tmm_policy} (Targeting Y_tot_mean={E_Y_tot:.2f})")

# Policy 3: Newsvendor (Total) - critical-ratio quantile of total demand.
# PMF of y1 + y2 on {0, ..., 2n-2}, accumulated cell by cell.
pmf_Y_tot_nv = np.zeros(2 * n - 1)
for i in range(n):
    for j in range(n):
        pmf_Y_tot_nv[coords[i] + coords[j]] += p_Y[i, j]
if pmf_Y_tot_nv.sum() > 1e-9:
    pmf_Y_tot_nv /= pmf_Y_tot_nv.sum()
cdf_Y_tot_nv = pmf_Y_tot_nv.cumsum()
if (CO_MISMATCH + CU_MISMATCH) > 1e-9:
    cr_nv = CU_MISMATCH / (CO_MISMATCH + CU_MISMATCH)
else:
    cr_nv = 0.5
# First index whose CDF value is >= the critical ratio.
target_total_supply_nv = np.searchsorted(cdf_Y_tot_nv, cr_nv)
nv_x1 = int(round(target_total_supply_nv * prop_X1_hist))
nv_x2 = int(round(target_total_supply_nv * (1 - prop_X1_hist)))
nv_policy = (_clamp_supply_index(nv_x1), _clamp_supply_index(nv_x2))
print(f"  Newsvendor (Total) Policy: Supply {nv_policy} (Targeting X_tot*={target_total_supply_nv})")

# Policy 4: SAA (Optimized for 4D Cost Matrix C_4D)
# Pick the single supply pair (x1, x2) minimising the average cost over
# N_SAA_SAMPLES demand pairs sampled from the demand PMF.
print(f"  Calculating SAA Policy (N_SAA={N_SAA_SAMPLES}, 4D Cost)... This will take time.")
start_saa_time = time.time()
if N_SAA_SAMPLES <= 0:
    raise ValueError("N_SAA_SAMPLES must be positive to compute the SAA policy.")

# Sample demand pairs (y1_idx, y2_idx). Drawn one at a time so the RNG stream
# (and therefore every later random draw) is unchanged.
saa_demand_samples_idx = [divmod(RNG.choice(m_flat, p=p_Y_flat_sampling), n) for _ in range(N_SAA_SAMPLES)]
saa_y1_samples = np.array([y1_s for y1_s, _ in saa_demand_samples_idx], dtype=int)
saa_y2_samples = np.array([y2_s for _, y2_s in saa_demand_samples_idx], dtype=int)

# One gather instead of an n * n * N_SAA Python loop:
# C_4D[:, :, y1s, y2s] has shape (n, n, N_SAA); averaging over the last axis
# gives the sample-average cost of every candidate supply pair.
# (Summation order differs from the scalar loop; results are identical for
# integer-valued costs and equal up to rounding for float costs.)
saa_avg_costs_4d = C_4D[:, :, saa_y1_samples, saa_y2_samples].sum(axis=-1) / N_SAA_SAMPLES

saa_x1_idx, saa_x2_idx = np.unravel_index(np.argmin(saa_avg_costs_4d), (n,n))
saa_policy = (saa_x1_idx, saa_x2_idx)
end_saa_time = time.time()
print(f"  SAA Policy (4D Cost): Supply {saa_policy} (Min Avg Cost: {np.min(saa_avg_costs_4d):.2f}, Took {end_saa_time-start_saa_time:.2f}s)")

# Pre-calculate FPD-OT Mean and Mode lookups.
# For every demand pair (y1, y2) the plan slice over (x1, x2) is normalised
# and its mean and arg-max supply indices are stored, indexed by [y1, y2].
fpd_mean_x1_lookup = np.zeros((n,n))
fpd_mean_x2_lookup = np.zeros((n,n))
fpd_mode_x1_lookup = np.zeros((n,n), dtype=int)
fpd_mode_x2_lookup = np.zeros((n,n), dtype=int)
X1_mgrid, X2_mgrid = np.meshgrid(coords, coords, indexing='ij')
print("  Pre-calculating FPD-OT Mean/Mode lookups...")
for y1_idx_pre in range(n):
    for y2_idx_pre in range(n):
        p_cond_pre_slice = P_4D_optimal_plan[:, :, y1_idx_pre, y2_idx_pre]
        sum_p_cond_pre_slice = p_cond_pre_slice.sum()
        # Uniform fallback when the plan puts (almost) no mass on this demand pair.
        p_cond_pre_norm = (
            p_cond_pre_slice / sum_p_cond_pre_slice
            if sum_p_cond_pre_slice > 1e-12
            else np.ones((n, n)) / (n*n if n > 0 else 1.0)
        )

        # Conditional means of x1 and x2.
        fpd_mean_x1_lookup[y1_idx_pre, y2_idx_pre] = (X1_mgrid * p_cond_pre_norm).sum()
        fpd_mean_x2_lookup[y1_idx_pre, y2_idx_pre] = (X2_mgrid * p_cond_pre_norm).sum()

        # Conditional mode: arg-max over the row-major flattened slice, split into (x1, x2).
        mode_idx_flat = np.argmax(p_cond_pre_norm)
        m_x1, m_x2 = divmod(mode_idx_flat, n)
        fpd_mode_x1_lookup[y1_idx_pre, y2_idx_pre] = m_x1
        fpd_mode_x2_lookup[y1_idx_pre, y2_idx_pre] = m_x2
print("Benchmark policy parameters pre-calculated.")

# --- 4. MONTE CARLO SIMULATION LOOP ---
policy_names_list = [
    "FPD-OT (Cond. Mean)", "FPD-OT (Cond. Mode)", "FPD-OT (Sampling)", "FPD-OT TopK",
    "Total Mean Match", "Newsvendor (Total)", "SAA", "SMS"
]
# Realised cost per policy per day, and the (x1_idx, x2_idx) decision taken each day.
daily_costs_all = {name: np.zeros(N_SIMULATION_DAYS) for name in policy_names_list}
daily_supply_decisions_all = {name: np.zeros((N_SIMULATION_DAYS, 2), dtype=int) for name in policy_names_list}
supply_indices_flat = np.arange(n * n) # For FPD-OT sampling


def _clamp_grid_index(idx):
    """Return idx as an int clamped to the valid grid index range [0, n-1]."""
    return min(n-1, max(0, int(idx)))


def _print_detailed_report(day_number, y1_idx, y2_idx, decisions):
    """Print the demand realisation and every policy's supply decision for one day."""
    print(f"\n---------------- DETAILED REPORT: Day {day_number}/{N_SIMULATION_DAYS} ----------------")
    print(f"Demand Realization (y1, y2): ({coords[y1_idx]}, {coords[y2_idx]})")
    print("Supply Decisions by Policy (x1_value, x2_value):")
    for policy_name_print, (x1_idx_print, x2_idx_print) in decisions.items():
        # Clamp before indexing coords so a stray index cannot raise.
        safe_x1_idx_print = _clamp_grid_index(x1_idx_print)
        safe_x2_idx_print = _clamp_grid_index(x2_idx_print)
        x1_val_print = coords[safe_x1_idx_print]
        x2_val_print = coords[safe_x2_idx_print]
        print(f"  {policy_name_print:<25}: ({x1_val_print}, {x2_val_print}) -- Indices: ({safe_x1_idx_print}, {safe_x2_idx_print})")
    print("-----------------------------------------------------------------------\n")


print(f"\nRunning Monte Carlo simulation for {N_SIMULATION_DAYS} days...")
start_sim_loop_time = time.time()
# Zero-based day index on which one extra compact report is printed.
REPORT_DAY = N_SIMULATION_DAYS // 100

# Reporting cadence is loop-invariant, so it is fixed once here.
# progress_interval == 0 (fewer than 10 days) disables the modulo-based progress line.
progress_interval = N_SIMULATION_DAYS // 10
if N_SIMULATION_DAYS >= 100:
    print_interval = N_SIMULATION_DAYS // 10   # ten detailed reports per run
elif N_SIMULATION_DAYS >= 10:
    print_interval = 1                         # 10-99 days: report every day
else:
    # Fewer than 10 days: every day is reported via the explicit check in the
    # loop. Previously print_interval was left undefined here, which raised
    # NameError on the first day.
    print_interval = 0

for day_iter in range(N_SIMULATION_DAYS):
    day_number = day_iter + 1
    if progress_interval > 0 and day_number % progress_interval == 0:
        print(f"  Simulating day {day_number}/{N_SIMULATION_DAYS}...")
    elif N_SIMULATION_DAYS < 10 and (day_iter == 0 or day_iter == N_SIMULATION_DAYS - 1):
        # Very short simulations: announce only the first and last day.
        print(f"  Simulating day {day_number}/{N_SIMULATION_DAYS}...")

    # --- Demand realisation (first RNG draw of the day) ---
    demand_flat_idx = RNG.choice(m_flat, p=p_Y_flat_sampling)
    y1_idx_today, y2_idx_today = divmod(demand_flat_idx, n)

    # --- Supply distribution given today's demand, from the FPD-OT plan ---
    p_cond_today_slice = P_4D_optimal_plan[:, :, y1_idx_today, y2_idx_today]
    p_cond_today_sum_slice = p_cond_today_slice.sum()

    if p_cond_today_sum_slice > 1e-12:
        p_cond_norm_today = p_cond_today_slice / p_cond_today_sum_slice
    else:
        # No usable mass for this demand pair: fall back to uniform.
        p_cond_norm_today = np.ones((n,n)) / (n*n if n > 0 else 1.0)

    # Clip and renormalise so the vector is a valid probability argument for RNG.choice.
    p_cond_flat_norm_today = np.maximum(0, p_cond_norm_today.flatten())
    current_sum_p_cond_flat_norm = p_cond_flat_norm_today.sum()
    if current_sum_p_cond_flat_norm > 1e-9:
        p_cond_flat_norm_today /= current_sum_p_cond_flat_norm
    else:
        p_cond_flat_norm_today = np.ones(m_flat)/(m_flat if m_flat > 0 else 1.0)

    # --- FPD-OT conditional mean (rounded to the grid) ---
    mean_x1_today = int(round(fpd_mean_x1_lookup[y1_idx_today, y2_idx_today]))
    mean_x2_today = int(round(fpd_mean_x2_lookup[y1_idx_today, y2_idx_today]))
    fpd_mean_decision = (_clamp_grid_index(mean_x1_today), _clamp_grid_index(mean_x2_today))

    # --- FPD-OT conditional mode ---
    mode_x1_today = fpd_mode_x1_lookup[y1_idx_today, y2_idx_today]
    mode_x2_today = fpd_mode_x2_lookup[y1_idx_today, y2_idx_today]
    fpd_mode_decision = (mode_x1_today, mode_x2_today)

    # --- FPD-OT sampling (second RNG draw of the day) ---
    samp_flat_idx = RNG.choice(supply_indices_flat, p=p_cond_flat_norm_today)
    s_x1, s_x2 = divmod(samp_flat_idx, n)
    fpd_sampling_decision = (s_x1, s_x2)

    # --- FPD-OT TopK: among the K most probable supply pairs, take the one
    # with the lowest aggregate cost against today's total demand ---
    yt_today_val = coords[y1_idx_today] + coords[y2_idx_today]
    actual_K_TOP_FPD = min(K_TOP_FPD, m_flat)
    if actual_K_TOP_FPD <= 0:
        sorted_indices_topk = []
    elif np.any(p_cond_flat_norm_today > 0):
        # argsort is ascending, so the last K entries are the K largest.
        sorted_indices_topk = np.argsort(p_cond_flat_norm_today)[-actual_K_TOP_FPD:]
    else:
        # All-zero probabilities (should not happen after normalisation).
        sorted_indices_topk = np.array([np.argmax(p_cond_flat_norm_today)])

    # Defaults to the mode when there are no candidates.
    fpd_topk_decision_current_day = fpd_mode_decision
    best_cost_topk_internal = float('inf')
    for flat_idx_topk in sorted_indices_topk:
        xi_idx, xj_idx = divmod(flat_idx_topk, n)
        total_supply_val = coords[xi_idx] + coords[xj_idx]
        cost_internal = (C_O_AGG * max(0, total_supply_val - yt_today_val) +
                         C_U_AGG * max(0, yt_today_val - total_supply_val) +
                         C_T_AGG * max(0, total_supply_val - X_THRESH_AGG))
        # Strict '<': on cost ties the candidate visited first (lowest probability) wins.
        if cost_internal < best_cost_topk_internal:
            best_cost_topk_internal = cost_internal
            fpd_topk_decision_current_day = (xi_idx, xj_idx)

    all_policy_decisions = {
        "FPD-OT (Cond. Mean)": fpd_mean_decision,
        "FPD-OT (Cond. Mode)": fpd_mode_decision,
        "FPD-OT (Sampling)": fpd_sampling_decision,
        "FPD-OT TopK": fpd_topk_decision_current_day,
        "Total Mean Match": tmm_policy,
        "Newsvendor (Total)": nv_policy,
        "SAA": saa_policy,
        "SMS": sms_policy
    }

    # --- Detailed report: every print_interval days, or every day for runs under 10 days ---
    if (print_interval > 0 and day_number % print_interval == 0) or N_SIMULATION_DAYS < 10:
        _print_detailed_report(day_number, y1_idx_today, y2_idx_today, all_policy_decisions)

    # --- One extra compact report on day index REPORT_DAY ---
    if day_iter == REPORT_DAY:
        y1_val = int(coords[y1_idx_today])
        y2_val = int(coords[y2_idx_today])
        yt_val = y1_val + y2_val

        print(f"\n--- Monte-Carlo day {day_iter+1}: demand realisation ---")
        print(f"y1 = {y1_val},  y2 = {y2_val},  total = {yt_val}")
        print("----------------------------------------------------------")
        print("Supply decisions:")
        for pname, (sx1, sx2) in all_policy_decisions.items():
            print(f"{pname:<25s}: supply = ({int(sx1)}, {int(sx2)})")
        print("----------------------------------------------------------------\n")

    # --- Evaluate every policy against today's demand with the full 4D cost ---
    for policy_name, (x1_pol_idx, x2_pol_idx) in all_policy_decisions.items():
        # Integer, in-bounds indices for both the cost lookup and storage.
        x1_pol_idx_int = _clamp_grid_index(x1_pol_idx)
        x2_pol_idx_int = _clamp_grid_index(x2_pol_idx)

        cost = C_4D[x1_pol_idx_int, x2_pol_idx_int, y1_idx_today, y2_idx_today]
        daily_costs_all[policy_name][day_iter] = cost
        daily_supply_decisions_all[policy_name][day_iter, 0] = x1_pol_idx_int
        daily_supply_decisions_all[policy_name][day_iter, 1] = x2_pol_idx_int

end_sim_loop_time = time.time()
print(f"Simulation loop completed in {end_sim_loop_time - start_sim_loop_time:.2f} seconds.")
# --- 5. RESULTS ---
# print("\n--- Simulation Results (Evaluated with Full 4D Cost Matrix) ---")
# print(f"{'Policy':<25s} {'Avg Cost':>12s} {'Std Dev':>12s}")
# avg_costs_final = {}
# for name_res in policy_names_list: # Uses updated policy_names_list
#     mean_cost = daily_costs_all[name_res].mean()
#     std_dev_cost = daily_costs_all[name_res].std()
#     avg_costs_final[name_res] = mean_cost
#     print(f"{name_res:<25s} {mean_cost:12.4f} {std_dev_cost:12.4f}")

# plt.figure(figsize=(15, 8))
# sorted_policies_plot = sorted(avg_costs_final, key=avg_costs_final.get)
# means_plot = [avg_costs_final[p_name] for p_name in sorted_policies_plot]
# sem_plot = [daily_costs_all[p_name].std() / np.sqrt(N_SIMULATION_DAYS) for p_name in sorted_policies_plot]

# cmap = matplotlib.colormaps['viridis']
# bar_colors = cmap(np.linspace(0, 1, len(sorted_policies_plot)))

# bars = plt.bar(sorted_policies_plot, means_plot, yerr=sem_plot, capsize=5, color=bar_colors)
# plt.ylabel("Average Daily Cost (from 4D Cost Matrix C)")
# plt.xlabel("Allocation Strategy")
# plt.title(rf"Benchmark: Avg Daily Costs (N_days={N_SIMULATION_DAYS})\nFPD-OT $\epsilon$={EPSILON_OT}, Cost (4D): Co_m={CO_MISMATCH}, Cu_m={CU_MISMATCH}, Ca_c={CA_CAP}, Thresh_c={THRESHOLD_CAP}")
# plt.xticks(rotation=45, ha="right")
# plt.grid(axis='y', linestyle='--', alpha=0.7)
# plt.tight_layout()
# for i, bar in enumerate(bars):
#     yval = bar.get_height()
#     plt.text(bar.get_x() + bar.get_width()/2.0, yval + sem_plot[i] + 0.01*max(means_plot), f'{yval:.2f}', ha='center', va='bottom', fontsize=9)
# plt.show()

# print("\nAnalysis complete.")
# --- 5. RESULTS ---
# Per-policy mean and (population) standard deviation of the realised daily cost.
print("\n--- Simulation Results (Evaluated with Full 4D Cost Matrix) ---")
print(f"{'Policy':<25s} {'Avg Cost':>12s} {'Std Dev':>12s}")
avg_costs_final = {}
for name_res in policy_names_list:
    policy_daily_costs = daily_costs_all[name_res]
    mean_cost = policy_daily_costs.mean()
    std_dev_cost = policy_daily_costs.std()
    print(f"{name_res:<25s} {mean_cost:12.4f} {std_dev_cost:12.4f}")
    avg_costs_final[name_res] = mean_cost

# --- Plotting Distribution of Chosen Supply Pairs for Each Policy ---
# For each policy, count how often every (x1, x2) index pair was chosen over
# the simulation and plot the relative frequencies as a heat-map.
print("\n--- Plotting Distribution of Chosen Supply Pairs for Each Policy ---")
for policy_name_viz in policy_names_list:
    supply_choices_policy = daily_supply_decisions_all[policy_name_viz]
    x1_choices = supply_choices_policy[:, 0]
    x2_choices = supply_choices_policy[:, 1]

    # Only in-range index pairs are counted; anything else is reported per day.
    in_bounds = (x1_choices >= 0) & (x1_choices < n) & (x2_choices >= 0) & (x2_choices < n)
    for day_idx_viz in (~in_bounds).nonzero()[0]:
        print(f"Warning: Out-of-bounds supply choice index for policy {policy_name_viz} on day {day_idx_viz}: ({x1_choices[day_idx_viz]}, {x2_choices[day_idx_viz]})")

    # np.add.at accumulates correctly when the same cell appears many times,
    # replacing the per-day Python counting loop.
    supply_heatmap = np.zeros((n, n))
    np.add.at(supply_heatmap, (x1_choices[in_bounds], x2_choices[in_bounds]), 1)

    if N_SIMULATION_DAYS > 0:
        supply_heatmap /= N_SIMULATION_DAYS # Normalize counts to relative frequencies
    else:
        supply_heatmap = np.zeros((n,n)) # Avoid division by zero if no simulation days

    plot_heatmap(supply_heatmap,
                 f"Distribution of Chosen Supply (x1,x2)\nPolicy: {policy_name_viz}",
                 xlabel=f"x2 Supply Index (0 to {n-1})",
                 ylabel=f"x1 Supply Index (0 to {n-1})",
                 cbar=True)
# END OF SECTION FOR HEATMAPS

# --- Bar chart of average daily cost per policy ---
# Policies are ordered from lowest to highest mean cost; error bars show the
# standard error of the mean (SEM) and each bar is annotated with its height.
plt.figure(figsize=(15, 8))
sorted_policies_plot = sorted(avg_costs_final, key=avg_costs_final.get)
means_plot = [avg_costs_final[p_name] for p_name in sorted_policies_plot]

# SEM = sample standard deviation (ddof=1) / sqrt(number of simulated days).
# The denominator is loop-invariant, so compute it once; fall back to 1 when
# there are no simulated days to avoid dividing by zero.
sem_denominator = np.sqrt(N_SIMULATION_DAYS if N_SIMULATION_DAYS > 0 else 1)
sem_plot = [daily_costs_all[p_name].std(ddof=1) / sem_denominator for p_name in sorted_policies_plot]

cmap_name = 'viridis'  # Example: 'viridis', 'plasma', 'coolwarm'
# Look the colormap up through pyplot so this block depends only on the `plt`
# alias: a bare `matplotlib.colormaps[...]` raises NameError (not
# AttributeError) when the top-level `matplotlib` name was never imported.
cmap = plt.get_cmap(cmap_name)
# Sample the middle of the colormap (0.2-0.8), skipping its extreme ends.
bar_colors = cmap(np.linspace(0.2, 0.8, len(sorted_policies_plot)))

bars = plt.bar(sorted_policies_plot, means_plot, yerr=sem_plot, capsize=5, color=bar_colors, width=0.7)
plt.ylabel("Average Daily Cost (from 4D Cost Matrix C)", fontsize=12)
plt.xlabel("Allocation Strategy", fontsize=12)
plt.title(rf"Benchmark: Avg Daily Costs (N_days={N_SIMULATION_DAYS})", fontsize=14)  # Simplified title for clarity in plot
# Sub-notes can be added in the figure caption in the thesis
# plt.title(rf"Benchmark: Avg Daily Costs (N_days={N_SIMULATION_DAYS})\nFPD-OT $\epsilon$={EPSILON_OT}, Cost (4D): Co_m={CO_MISMATCH}, Cu_m={CU_MISMATCH}, Ca_c={CA_CAP}, Thresh_c={THRESHOLD_CAP}")
plt.xticks(rotation=45, ha="right", fontsize=10)
plt.yticks(fontsize=10)
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.tight_layout(pad=1.5)

# Add text labels on bars. The vertical gap above each error bar is 1% of the
# tallest bar; it is the same for every bar, so compute it once (and guard the
# empty case, where max() would raise).
label_offset = 0.01 * max(means_plot) if means_plot else 0.0
for bar, sem_value in zip(bars, sem_plot):
    yval = bar.get_height()
    # A NaN or non-positive SEM contributes no extra offset.
    error_pad = sem_value if sem_value > 0 else 0
    plt.text(bar.get_x() + bar.get_width()/2.0, yval + error_pad + label_offset, f'{yval:.2f}', ha='center', va='bottom', fontsize=8, color='black')

plt.show()

print("\nAnalysis complete.")