In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Rectangle
import numpy as np

# Set up the problem parameters
population_size = 10000
disease_rate = 0.01
sensitivity = 0.99  # P(Positive | Disease)
false_positive_rate = 0.01  # P(Positive | No Disease)

# Calculate the numbers
people_with_disease = int(population_size * disease_rate)
people_without_disease = population_size - people_with_disease

true_positives = int(people_with_disease * sensitivity)
false_negatives = people_with_disease - true_positives

false_positives = int(people_without_disease * false_positive_rate)
true_negatives = people_without_disease - false_positives

total_positives = true_positives + false_positives

# Create the visualization
fig = plt.figure(figsize=(16, 10))
# The headline percentage is the test's sensitivity, so derive it rather than hard-code it.
fig.suptitle("Understanding Bayes' Theorem: Why Testing Positive Doesn't Mean "
             f"{sensitivity * 100:g}% Chance of Disease",
             fontsize=16, fontweight='bold', y=0.98)

# Panel 1: Population breakdown
ax1 = plt.subplot(2, 3, 1)
ax1.set_title(f'Step 1: The Population\n({population_size:,} people)',
              fontsize=12, fontweight='bold')
ax1.barh(['Has Disease', 'No Disease'],
         [people_with_disease, people_without_disease],
         color=['#d62728', '#2ca02c'])
ax1.set_xlabel('Number of People')
for i, value in enumerate([people_with_disease, people_without_disease]):
    # Heuristic: a bar shorter than ~15% of the axis (e.g. the 1% "Has Disease"
    # sliver) cannot hold white centred text -- it would land on the white
    # background and vanish -- so label it in black just past the bar's end.
    fits_inside = value >= 0.15 * population_size
    ax1.text(value / 2 if fits_inside else value + 0.01 * population_size, i,
             f'{value:,}\n({value/population_size*100:.0f}%)',
             ha='center' if fits_inside else 'left', va='center',
             fontsize=11, fontweight='bold',
             color='white' if fits_inside else 'black')
ax1.set_xlim(0, population_size)

# Panel 2: Test results for people WITH disease
ax2 = plt.subplot(2, 3, 2)
# For sick people the test's "accuracy" is its sensitivity, P(Positive | Disease).
ax2.set_title(f'Step 2a: Testing the {people_with_disease:,} People WITH Disease\n'
              f'({sensitivity * 100:g}% accurate)',
              fontsize=12, fontweight='bold')
ax2.barh(['Test Positive', 'Test Negative'],
         [true_positives, false_negatives],
         color=['#ff7f0e', '#d62728'])
ax2.set_xlabel('Number of People')
for i, value in enumerate([true_positives, false_negatives]):
    if value <= 0:
        continue  # an empty bar gets no label
    # Narrow bars (e.g. the single false negative) get an outside, dark label;
    # white text on a sliver of a bar is unreadable against the background.
    fits_inside = value >= 0.15 * people_with_disease
    ax2.text(value / 2 if fits_inside else value + 0.01 * people_with_disease, i,
             f'{value:,}',
             ha='center' if fits_inside else 'left', va='center',
             fontsize=11, fontweight='bold',
             color='white' if fits_inside else 'black')
ax2.set_xlim(0, people_with_disease)

# Panel 3: Test results for people WITHOUT disease
ax3 = plt.subplot(2, 3, 3)
# For healthy people the test's "accuracy" is its specificity,
# P(Negative | No Disease) = 1 - false_positive_rate.
ax3.set_title(f'Step 2b: Testing the {people_without_disease:,} People WITHOUT Disease\n'
              f'({(1 - false_positive_rate) * 100:g}% accurate)',
              fontsize=12, fontweight='bold')
ax3.barh(['Test Positive', 'Test Negative'],
         [false_positives, true_negatives],
         color=['#ff7f0e', '#2ca02c'])
ax3.set_xlabel('Number of People')
for i, value in enumerate([false_positives, true_negatives]):
    # The false-positive bar is ~1% of the axis; centred white text on it is
    # invisible, so narrow bars are labelled in black just past their end.
    fits_inside = value >= 0.15 * people_without_disease
    ax3.text(value / 2 if fits_inside else value + 0.01 * people_without_disease, i,
             f'{value:,}',
             ha='center' if fits_inside else 'left', va='center',
             fontsize=11, fontweight='bold',
             color='white' if fits_inside else 'black')
ax3.set_xlim(0, people_without_disease)

# Panel 4: Confusion matrix style
ax4 = plt.subplot(2, 3, 4)
ax4.set_title('Step 3: The Complete Picture', fontsize=12, fontweight='bold')
ax4.axis('off')

# Create a table-like visualization: rows are test outcomes, columns are the
# true disease status, plus a column totalling everyone who tested positive.
# The empty "total" cell for negatives is an em dash, written as an escape so
# it cannot be corrupted again by a source-encoding round trip (it had been
# stored as UTF-8 bytes mis-decoded as cp1252, which rendered as garbage).
table_data = [
    ['', 'Actually\nHAS Disease', 'Actually\nNO Disease', 'Total Testing\nPositive'],
    ['Test\nPositive', f'{true_positives}', f'{false_positives}', f'{total_positives}'],
    ['Test\nNegative', f'{false_negatives}', f'{true_negatives}', '\u2014'],
]

# Highlight the positive-test row: true positives green, false positives pink,
# their total yellow.
colors = [
    ['white', 'white', 'white', 'white'],
    ['white', '#90EE90', '#FFB6C6', '#FFE6AA'],
    ['white', 'white', 'white', 'white'],
]

table = ax4.table(cellText=table_data, cellColours=colors,
                  cellLoc='center', loc='center',
                  bbox=[0, 0, 1, 1])
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 3)

# Bold the header row/column and colour-code the two positive-result counts.
for i in range(3):
    for j in range(4):
        cell = table[(i, j)]
        if i == 0 or j == 0:
            cell.set_text_props(weight='bold')
        if i == 1 and j == 1:
            cell.set_text_props(weight='bold', color='darkgreen')
        if i == 1 and j == 2:
            cell.set_text_props(weight='bold', color='darkred')

# Panel 5: The key insight
ax5 = plt.subplot(2, 3, 5)
ax5.set_title('Step 4: Among Those Who Test Positive...', fontsize=12, fontweight='bold')

# Create stacked bar: true positives first, false positives stacked after them.
ax5.barh(['All Positive\nTests'], [true_positives],
         color='#90EE90', label=f'True Positives (have disease): {true_positives}')
ax5.barh(['All Positive\nTests'], [false_positives],
         left=[true_positives],
         color='#FFB6C6', label=f'False Positives (no disease): {false_positives}')

ax5.set_xlabel('Number of People')
# max(..., 1) avoids a singular (0, 0) x-range when nobody tests positive.
ax5.set_xlim(0, max(total_positives, 1))
ax5.legend(loc='upper right', fontsize=9)

# Add percentage annotations centred on each segment; the shares are
# undefined when there are no positive tests, so skip them in that case.
if total_positives > 0:
    for left, width in ((0, true_positives), (true_positives, false_positives)):
        ax5.text(left + width / 2, 0, f'{width/total_positives*100:.0f}%',
                 ha='center', va='center', fontsize=14, fontweight='bold')

# Panel 6: The final answer
ax6 = plt.subplot(2, 3, 6)
ax6.set_title('THE ANSWER', fontsize=12, fontweight='bold')
ax6.axis('off')

# P(Disease | Positive) as a percentage. It is undefined when nobody tests
# positive, so use NaN instead of raising ZeroDivisionError.
probability = (true_positives / total_positives * 100 if total_positives
               else float('nan'))
# The number people intuitively expect is the test's sensitivity.
answer_text = f"""
If you TEST POSITIVE:

Probability you have the disease:

{probability:.0f}%

NOT {sensitivity * 100:g}%!
"""

ax6.text(0.5, 0.5, answer_text,
         ha='center', va='center', fontsize=18, fontweight='bold',
         bbox=dict(boxstyle='round', facecolor='#FFE6AA', edgecolor='black', linewidth=2))

# Add explanation at the bottom. Every number and the comparison word are
# derived from the parameters, so the text stays true if they change.
if false_positives > true_positives:
    comparison = 'This outnumbers'
elif false_positives < true_positives:
    comparison = 'This is fewer than'
else:
    comparison = 'This equals'

explanation = f"""
WHY? Because the disease is rare ({disease_rate * 100:g}%), there are many more healthy people than sick people.
Even though the test is {sensitivity * 100:g}% accurate, that {false_positive_rate * 100:g}% error rate on the {people_without_disease:,} healthy people gives us {false_positives} false positives.
{comparison} the {true_positives} true positives from the {people_with_disease:,} sick people!
"""

fig.text(0.5, 0.02, explanation, ha='center', fontsize=11,
         style='italic', wrap=True,
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Leave room for the suptitle (top) and the explanation box (bottom).
plt.tight_layout(rect=[0, 0.05, 1, 0.96])

output_path = Path('/mnt/user-data/outputs/bayes_visualization.png')
# savefig does not create missing directories; ensure the target exists first.
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print("Visualization saved!")

# Console summary; percentages are derived from the parameters, not hard-coded.
print("\nKey numbers:")
print(f"  Population: {population_size:,}")
print(f"  People with disease: {people_with_disease} ({disease_rate * 100:g}%)")
print(f"  People without disease: {people_without_disease:,} ({(1 - disease_rate) * 100:g}%)")
print("\nTest results:")
print(f"  True positives: {true_positives}")
print(f"  False positives: {false_positives}")
print(f"  Total positives: {total_positives}")
print(f"\nProbability of having disease if you test positive: {probability:.1f}%")

plt.show()