In [None]:
import numpy as np
import numba
from numba import jit, prange
import matplotlib.pyplot as plt

In [None]:
# Cell 2: Agent Initialization and Tax Functions
@jit(nopython=True)
def initialize_agents(num_agents):
    """Draw an initial wealth for each of ``num_agents`` agents.

    Wealth comes from a three-tier mixture, drawn in this order and then shuffled:
      * ~10% "rich"   agents: uniform on [12, 50)
      * ~40% "middle" agents: uniform on [3, 12)
      * the rest (~50%) "poor" agents: uniform on [0.1, 3)

    The poor tier absorbs the rounding remainder, so the tier sizes always sum
    to exactly ``num_agents``. With independent truncation they sometimes did
    not (e.g. 7 agents -> 3 + 2 + 0 = 5 values), and the array copy failed.

    Returns:
        1-D float32 array of length ``num_agents``.
    """
    n_high = int(0.1 * num_agents)
    n_mid = int(0.4 * num_agents)
    n_low = num_agents - n_mid - n_high  # remainder -> sizes always sum to num_agents

    # Numba's np.concatenate takes a tuple of arrays.
    wealth_values = np.concatenate((
        np.random.uniform(0.1, 3, n_low),
        np.random.uniform(3, 12, n_mid),
        np.random.uniform(12, 50, n_high),
    ))
    np.random.shuffle(wealth_values)
    return wealth_values.astype(np.float32)

# Progressive tax schedule used by calculate_tax.
# TAX_RATES[i] applies to the slice of wealth between TAX_SLABS[i-1] (0 for
# i == 0) and TAX_SLABS[i]; TAX_RATES[-1] applies to everything above
# TAX_SLABS[-1]. Numba freezes module-level arrays into compiled code, so an
# edit here only takes effect once the jitted functions are redefined
# (re-run the notebook from this cell).
TAX_SLABS = np.array([3.0, 6.0, 9.0, 12.0, 15.0])
TAX_RATES = np.array([0.0, 0.05, 0.1, 0.15, 0.2, 0.3])


@jit(nopython=True)
def calculate_tax(wealth):
    """Return the marginal (progressive) tax owed on ``wealth``.

    Each slice of wealth is taxed at the rate of the slab it falls in. With
    the schedule above, wealth <= 3 pays nothing, and wealth above 15 pays 1.5
    on the first 15 plus 30% of the excess.

    The schedule arrays are module-level constants rather than locals. This
    function runs once per agent per simulation step, and the previous version
    allocated two fresh arrays on every call.
    """
    tax = 0.0
    lower = 0.0
    for i in range(len(TAX_SLABS)):
        upper = TAX_SLABS[i]
        if wealth <= upper:
            # Wealth ends inside this slab: tax only the part above `lower`.
            return tax + TAX_RATES[i] * (wealth - lower)
        tax += TAX_RATES[i] * (upper - lower)
        lower = upper
    # Wealth exceeds the top slab: the excess is taxed at the final rate.
    return tax + TAX_RATES[-1] * (wealth - lower)

In [None]:
# Cell 3: Trade Probability and Agent Update Functions
@jit(nopython=True)
def calculate_trade_prob(agent1_wealth, agent2_wealth):
    """Unnormalized trade weight exp(-|w1 - w2|) between two agents.

    Equals 1 for identical wealth and decays exponentially with the wealth
    gap, so agents of similar wealth are the most likely trading partners.
    """
    wealth_gap = np.abs(agent1_wealth - agent2_wealth)
    return np.exp(-wealth_gap)

@jit(nopython=True)
def initialize_trade_probabilities(agents):
    """Build the symmetric pairwise trade-weight matrix for ``agents``.

    Entry [i, j] = calculate_trade_prob(agents[i], agents[j]) for i != j. The
    diagonal is left at 0, so an agent never carries weight as its own partner.

    Note: this is a dense (n, n) float32 matrix, i.e. 4 * n**2 bytes
    (~400 MB for n = 10_000).
    """
    n = len(agents)
    weights = np.zeros((n, n), dtype=np.float32)
    for row in range(n):
        row_wealth = agents[row]
        # Fill only the upper triangle and mirror it, halving the exp() calls.
        for col in range(row + 1, n):
            w = calculate_trade_prob(row_wealth, agents[col])
            weights[row, col] = w
            weights[col, row] = w
    return weights

@jit(nopython=True)
def roulette_wheel_selection(probs):
    """Return an index sampled with probability proportional to ``probs``.

    ``probs`` are non-negative weights and need not sum to 1. Rows of the
    trade matrix sum to far more than 1, and the previous version drew
    ``r`` from [0, 1) without scaling. Its cumulative sum crossed ``r``
    within the first few entries, so selection collapsed onto the lowest
    indices.

    Zero-weight entries (e.g. the diagonal, i.e. the agent itself) are never
    returned. If no entry is positive, the last index is returned, as before.
    """
    total = 0.0
    for p in probs:
        if p > 0.0:
            total += p
    if total <= 0.0:
        return len(probs) - 1

    r = np.random.random() * total  # uniform on [0, total)
    cum_prob = 0.0
    last_positive = len(probs) - 1
    for i in range(len(probs)):
        p = probs[i]
        if p > 0.0:
            cum_prob += p
            last_positive = i
            # Strict '<' so an entry is only chosen if it widened the wheel.
            if r < cum_prob:
                return i
    # Reached only if float rounding pushed r up to the final cumulative sum.
    return last_positive

@jit(nopython=True)
def _refresh_trade_probs(agents, trade_probs, k):
    """Recompute row and column ``k`` of ``trade_probs`` from current wealth.

    The diagonal entry trade_probs[k, k] is left untouched.
    """
    wealth_k = agents[k]
    for i in range(len(agents)):
        if i != k:
            prob = calculate_trade_prob(wealth_k, agents[i])
            trade_probs[k, i] = prob
            trade_probs[i, k] = prob


@jit(nopython=True)
def update_agents(agents, trade_probs, transaction_rate=0.075):
    """Perform one trade, modifying ``agents`` and ``trade_probs`` in place.

    Agent 1 is drawn uniformly at random. Agent 2 is drawn from row
    ``agent1_index`` of ``trade_probs`` via ``roulette_wheel_selection``.
    Agent 1 then pays agent 2 ``transaction_rate * w1 * w2 / (w1 + w2)``.

    Row ``agent1_index`` is recomputed from current wealth *before* selection.
    Callers such as ``run_simulation`` change every agent's wealth (tax and
    redistribution) without touching ``trade_probs``, so stored entries for
    agents that have not traded recently go stale.

    Args:
        agents: 1-D wealth array.
        trade_probs: (n, n) trade-weight matrix from
            ``initialize_trade_probabilities``.
        transaction_rate: scale of the transfer (default 0.075, the original
            hard-coded value).
    """
    num_agents = len(agents)
    if num_agents < 2:
        return  # no possible partner

    agent1_index = np.random.randint(0, num_agents)
    _refresh_trade_probs(agents, trade_probs, agent1_index)
    agent2_index = roulette_wheel_selection(trade_probs[agent1_index])
    if agent2_index == agent1_index:
        return  # a self-trade would be a no-op

    wealth1 = agents[agent1_index]
    wealth2 = agents[agent2_index]
    combined = wealth1 + wealth2
    if combined <= 0.0:
        return  # avoid 0/0; transfer is undefined for zero combined wealth

    transaction_amount = transaction_rate * (wealth1 * wealth2) / combined
    agents[agent1_index] -= transaction_amount
    agents[agent2_index] += transaction_amount

    # Keep the two traders' rows/columns consistent with their new wealth.
    _refresh_trade_probs(agents, trade_probs, agent1_index)
    _refresh_trade_probs(agents, trade_probs, agent2_index)

In [None]:
# Cell 4: Simulation Function
@jit(nopython=True)
def _tax_and_redistribute(agents):
    """Deduct ``calculate_tax`` from every agent, then share the total equally (in place).

    Total wealth is unchanged apart from float rounding.
    """
    num_agents = len(agents)
    if num_agents == 0:
        return
    total_tax = 0.0
    for i in range(num_agents):
        tax = calculate_tax(agents[i])
        agents[i] -= tax
        total_tax += tax
    agents += total_tax / num_agents


@jit(nopython=True)
def run_simulation(num_agents, num_time_steps, record_interval=1000, tax_interval=1):
    """Run the trading/taxation model and track inequality over time.

    Each step performs one trade via ``update_agents``. Every
    ``tax_interval`` steps, all agents are taxed with ``calculate_tax`` and
    the proceeds are redistributed equally. The default ``tax_interval=1``
    taxes after every trade, which is the original behavior.

    Args:
        num_agents: number of agents.
        num_time_steps: number of trades to simulate.
        record_interval: a Gini coefficient is recorded every this many steps.
        tax_interval: tax/redistribution is applied every this many steps.

    Returns:
        (agents, gini_indices): the final float32 wealth array, and a float32
        array of length ``num_time_steps // record_interval + 1`` whose entry
        k is the Gini coefficient after ``k * record_interval`` steps
        (entry 0 is the initial state).

    Raises:
        ValueError: if ``record_interval`` or ``tax_interval`` is < 1.
    """
    if record_interval < 1:
        raise ValueError("record_interval must be >= 1")
    if tax_interval < 1:
        raise ValueError("tax_interval must be >= 1")

    agents = initialize_agents(num_agents)
    trade_probs = initialize_trade_probabilities(agents)

    gini_indices = np.zeros(num_time_steps // record_interval + 1, dtype=np.float32)
    gini_indices[0] = gini_coefficient(agents)

    for step in range(num_time_steps):
        update_agents(agents, trade_probs)

        completed = step + 1
        if completed % tax_interval == 0:
            _tax_and_redistribute(agents)

        if completed % record_interval == 0:
            gini_indices[completed // record_interval] = gini_coefficient(agents)

    return agents, gini_indices

@jit(nopython=True)
def gini_coefficient(wealths):
    """Gini coefficient of a 1-D wealth array.

    Uses G = sum_i (2*i - n - 1) * x_(i) / (n * sum x), with x_(i) the
    wealths sorted ascending and i = 1..n. G is 0 for perfect equality and
    approaches (n - 1) / n when one agent holds everything (non-negative wealth).

    Empty input and zero total wealth return 0.0 instead of dividing by zero.
    Sums are accumulated in float64, even for float32 input.
    """
    n = len(wealths)
    if n == 0:
        return 0.0
    sorted_wealths = np.sort(wealths).astype(np.float64)
    total = np.sum(sorted_wealths)
    if total == 0.0:
        return 0.0
    rank = np.arange(1, n + 1)
    return np.sum((2 * rank - n - 1) * sorted_wealths) / (n * total)

In [None]:
# Cell 5: Run Simulation
# Simulation parameters
num_agents = 10000
num_time_steps = 100000
SEED = 42  # fixed so that Restart & Run All reproduces the same run


@jit(nopython=True)
def seed_numba_rng(seed):
    """Seed the random generator used inside Numba-compiled functions.

    Numba keeps its own generator state. Calling np.random.seed from plain
    Python seeds NumPy's generator only and would not affect run_simulation.
    """
    np.random.seed(seed)


seed_numba_rng(SEED)

# Run the simulation
final_wealths, gini_indices = run_simulation(num_agents, num_time_steps)

print(f"Final Gini coefficient: {gini_indices[-1]:.4f}")

In [None]:
# Cell 6: Visualize Gini Coefficient Over Time
# One Gini value was recorded every 1000 steps, starting at step 0.
recorded_steps = np.arange(0, num_time_steps + 1, 1000)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(recorded_steps, gini_indices)
ax.set_title('Gini Coefficient Over Time')
ax.set_xlabel('Time Steps')
ax.set_ylabel('Gini Coefficient')
ax.grid(True)
plt.show()

In [None]:
# Cell 7: Visualize Final Wealth Distribution
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(final_wealths, bins=50, edgecolor='black')
ax.set_title('Final Wealth Distribution')
ax.set_xlabel('Wealth')
ax.set_ylabel('Frequency')
ax.grid(True)
plt.show()