# 🛡️ Interactive Squad Simulation
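Follow the directions at the top of each cell. After running the interactive simulation cell, use the sliders and dropdowns it displays to run and visualize a single simulation; the last cell adds a button that runs every armor/threat combination in batch.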

Colab Environment

In [None]:
# If you are working in colab, use this block to clone the GitHub repo, install dependencies, and pull down the latest changes.
# Start from a clean copy: delete any previous clone under /content, then re-clone.
%cd /content/
!rm -rf SE3250-Spring2025-SquadSimulation
!git clone https://github.com/SuprMunchkin/SE3250-Spring2025-SquadSimulation.git
%cd SE3250-Spring2025-SquadSimulation
# NOTE(review): dependencies are installed from the requirements.txt of the branch
# checked out by `git clone`, i.e. before the `git checkout` below. If your branch
# changes requirements.txt, re-run %pip after switching branches.
%pip install -r requirements.txt

# Switch to the branch you want to run.
!git fetch origin
!git checkout "main" #Change this line to use your branch
# List the working directory to confirm the clone and checkout worked.
%ls

VS Code Environment

In [None]:
# Local (e.g. VS Code) setup: install dependencies. The relative path assumes the
# notebook's working directory is the folder containing requirements.txt.
%pip install -r ./requirements.txt

In [None]:
# Import and setup
import sys
import yaml
import matplotlib.pyplot as plt
from IPython.display import display
from ipywidgets import interact
import pprint
pp = pprint.PrettyPrinter(indent=4)

# Custom imports
import os
sys.path.append("../models")
from models.squad_simulation import run_simulation
# Load the simulation configuration. The widget dropdowns below are built from its
# 'armor_profiles' and 'threat_probs' sections.
yaml_path = "config/simulation.yaml"
# Explicit encoding so the file decodes the same way regardless of platform locale.
with open(yaml_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Define interactive runner
def run_interactive_sim(blue_stock, red_stock, direction_deviation, map_size, armor_type, environment):
    """
    Run one squad simulation and plot the results.

    Shows a map of the blue patrol path and the red position(s), prints how
    many blue and red units remain, then plots the blue exhaustion data over
    simulation time.

    Parameters:
        blue_stock (int): Number of blue units.
        red_stock (int): Number of red units.
        direction_deviation (int): Direction deviation in degrees.
        map_size (int): Size of the map in meters; also used as the plot extent.
        armor_type (str): Type of armor for blue units.
        environment (str): Simulation environment.
    """
    params = {
        "blue_stock": blue_stock,
        "red_stock": red_stock,
        "direction_deviation": direction_deviation,
        "map_size": map_size,
        "armor_type": armor_type,
        "environment": environment
    }
    result = run_simulation(params, full_log=True)

    _plot_patrol_map(
        result['blue']['position_history'],
        result['red']['current_position'],
        map_size,
    )

    print(f"👥 Blue Remaining: {result['blue']['stock']} / {params['blue_stock']}")
    print(f"🔴 red Remaining: {result['red']['stock']} / {params['red_stock']}")

    _plot_exhaustion(result['blue']['exhaustion_data'])


def _plot_patrol_map(blue_positions, red_position, map_size):
    """Plot the blue patrol path and red position(s) on a square map of side ``map_size``."""
    plt.figure(figsize=(8, 8))

    # Blue patrol path, with its first and last points highlighted.
    if blue_positions:
        x_vals, y_vals = zip(*blue_positions)
        plt.plot(x_vals, y_vals, label='Blue Patrol Path', color='blue')
        plt.scatter(x_vals[0], y_vals[0], c='green', label='Start', zorder=5)
        plt.scatter(x_vals[-1], y_vals[-1], c='purple', label='End', zorder=5)

    # Red data is either a list of (x, y) points (moving) or a single (x, y) pair (stationary).
    if red_position:
        if isinstance(red_position[0], (list, tuple)):
            hx_vals, hy_vals = zip(*red_position)
            plt.plot(hx_vals, hy_vals, label='red Path', linestyle='--', color='red')
        else:
            hx, hy = red_position
            plt.scatter(hx, hy, c='red', label='red Position', zorder=5)
    else:
        print("Warning: No red position data available.")

    # Equal aspect by resizing the axes box, so the explicit map extent is kept.
    # plt.axis('equal') would re-enable autoscaling and discard these limits.
    ax = plt.gca()
    ax.set_aspect('equal', adjustable='box')
    ax.set_xlim(0, map_size)
    ax.set_ylim(0, map_size)
    plt.title("Squad Movement Simulation")
    plt.xlabel("X Position")
    plt.ylabel("Y Position")
    plt.legend()
    plt.grid(True)
    plt.show()


def _plot_exhaustion(exhaustion_data):
    """
    Plot one curve per column of ``exhaustion_data`` against the row index.

    Column 0 is labelled 'Threshold'; later columns are labelled 'Soldier1',
    'Soldier2', ... in order.
    """
    # Rows are transposed into per-column series (assumes one row per time step — confirm).
    series = list(zip(*exhaustion_data))

    plt.figure(figsize=(6, 4))
    for idx, values in enumerate(series):
        label = 'Threshold' if idx == 0 else f'Soldier{idx}'
        plt.plot(range(len(values)), values, label=label)
    plt.title("Squad Exhaustion vs Time (per Soldier)")
    plt.ylim(0, 1000)
    plt.xlabel("Simulation Time (min)")
    plt.ylabel("Squad Exhaustion ()")
    # Display the per-curve labels assigned above (skip when nothing was plotted).
    if series:
        plt.legend(fontsize=6, ncol=2)
    plt.show()

# Create interactive widget interface
# One widget per run_interactive_sim parameter: (min, max[, step]) tuples become
# sliders, lists become dropdowns. The keyword order sets the widget order.
widget_options = {
    "blue_stock": (1, 20),
    "red_stock": (1, 40),
    "direction_deviation": (0, 45, 5),
    "map_size": (1000, 4000, 500),
    "armor_type": list(config['armor_profiles']),
    "environment": list(config['threat_probs']),
}
interact(run_interactive_sim, **widget_options)

In [None]:
from ipywidgets import Button, VBox, Output
import itertools
import pandas as pd
import sys
import yaml
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display
from tqdm.notebook import tqdm
import pprint
import os
sys.path.append("../models")
from models.squad_simulation import run_simulation

# Reload the simulation configuration so this cell can run on its own.
yaml_path = "config/simulation.yaml"
# Explicit encoding so the file decodes the same way regardless of platform locale.
with open(yaml_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

pp = pprint.PrettyPrinter(indent=4)

# Simulations run per armor/environment combination.
number_of_runs = 50
armor_types = list(config['armor_profiles'])
environments = list(config['threat_probs'])
# Every (armor, environment) pair, armor-major (same order as itertools.product).
combinations = [(armor, env) for armor in armor_types for env in environments]
output = Output()
# Per-combination lists of per-run values, filled by run_all_combinations.
total_hostiles_killed_dict = dict()
total_warfighters_killed_dict = dict()
total_patrol_distance_dict = dict()

def plot_histograms(data_dict, title_prefix, xlabel, ylabel, step=1, ncols=3):
    """
    Draw one histogram per entry of ``data_dict`` on a grid of subplots.

    Each subplot shows integer-width bins, a dashed red line at the mean and,
    when there is more than one value, dotted green lines at mean ± one
    standard deviation. Entries with no values show a placeholder message.

    Parameters:
        data_dict (dict): Maps a label to a list of numeric values.
        title_prefix (str): First line of each subplot title.
        xlabel (str): X-axis label for every subplot.
        ylabel (str): Y-axis label for every subplot.
        step (int): Spacing between x-axis ticks.
        ncols (int): Number of subplot columns; rows are added as needed.
    """
    # Size the grid to the data instead of a fixed 3x3, which would raise
    # IndexError for more than 9 entries.
    n_plots = max(len(data_dict), 1)
    nrows = (n_plots + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
    axes = axes.flatten()
    fig.subplots_adjust(hspace=0.5, wspace=0.3)

    for ax, (key, values) in zip(axes, data_dict.items()):
        ax.set_title(f'{title_prefix}\nfor {key}', fontsize=8)
        if not values:
            ax.text(0.5, 0.5, 'No data to display', ha='center', va='center', transform=ax.transAxes)
            continue

        min_val = int(min(values))
        max_val = int(max(values))
        # One bin per integer; the +2 gives max_val its own closing bin edge.
        ax.hist(values, bins=range(min_val, max_val + 2), align='left', rwidth=0.4)
        ax.set_xlabel(xlabel, fontsize=6)
        ax.set_ylabel(ylabel, fontsize=6)
        ax.set_xticks(range(min_val, max_val + 1, step))

        mean = np.mean(values)
        std = np.std(values)
        ax.axvline(mean, color='red', linestyle='--', label=f'Mean: {mean:.2f}')
        if len(values) > 1:
            ax.axvline(mean + std, color='green', linestyle=':', label=f'Std: {std:.2f}')
            ax.axvline(mean - std, color='green', linestyle=':')
        ax.grid(True, axis='y', linestyle='--', alpha=0.3)
        ax.legend(fontsize=6)

    # Hide grid cells left over when the entry count isn't a multiple of ncols.
    for ax in axes[len(data_dict):]:
        ax.set_visible(False)
    plt.show()

def plot_line_from_dict(data_dict, armor_types, environments, title, ylabel):
    """
    Plot the per-combination mean as one line per environment across armor types.

    Parameters:
        data_dict (dict): Maps "<armor>\\n+<env>" keys to lists of per-run values.
        armor_types (list): Armor names, in x-axis order.
        environments (list): Environment names; one line is drawn for each.
        title (str): Figure title.
        ylabel (str): Y-axis label.

    Combinations with no recorded values are drawn as NaN (a gap in the line)
    rather than as a mean of 0, which would look like a real result.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    x_indices = np.arange(len(armor_types))
    for env in environments:
        ys = []
        for armor in armor_types:
            values = data_dict.get(f"{armor}\n+{env}", [])
            ys.append(np.mean(values) if values else np.nan)
        ax.plot(x_indices, ys, marker='o', label=f'Environment: {env}')
    ax.set_title(title)
    ax.set_xlabel("Armor Type")
    ax.set_ylabel(ylabel)
    ax.set_xticks(x_indices)
    ax.set_xticklabels(armor_types)
    ax.legend()
    plt.show()

def _simulate_combination(armor, env, blue_stock, red_stock, direction_deviation):
    """
    Run ``number_of_runs`` simulations for one armor/environment pair.

    Returns:
        tuple: ``(summary, hostiles_killed_list, warfighters_killed_list,
        patrol_distance_list)``. ``summary`` is one row of the results table;
        each list holds one value per run (patrol distance truncated to km).
    """
    hostiles_killed_list = []
    warfighters_killed_list = []
    patrol_distance_list = []
    total_blue_remaining = 0
    total_red_remaining = 0
    total_red_spawned = 0
    total_effective_movement = 0
    total_hostiles_killed = 0
    total_warfighters_killed = 0

    for _ in tqdm(range(number_of_runs), desc=f"Running Simulations for {armor}-{env}", leave=False):
        # Fresh params per run in case run_simulation mutates its input.
        params = {
            "blue_stock": blue_stock,
            "red_stock": red_stock,
            "direction_deviation": direction_deviation,
            "armor_type": armor,
            "environment": env
        }
        result = run_simulation(params, full_log=False)

        blue_remaining = result['blue']['stock']
        # NOTE(review): scales by the starting squad size, not survivors — confirm intent.
        effective_movement = result['blue']['patrol_distance'] * blue_stock
        hostiles_killed = result['blue']['hostiles_killed']
        red_remaining = 0
        red_spawned = 0
        warfighters_killed = 0
        for red in result['red_patrols']:
            red_remaining += red['stock']
            # Each red patrol is counted as spawning red_stock hostiles.
            red_spawned += red_stock
            warfighters_killed += red['warfighters_killed']

        # Cross-check the blue-side and red-side casualty counts.
        if hostiles_killed != (red_spawned - red_remaining):
            print(f"data error: Hostiles Killed {hostiles_killed} != Red Spawned {red_spawned} Red Remaining {red_remaining}")
        if warfighters_killed != (blue_stock - blue_remaining):
            print(f"data error: Warfighters Killed {warfighters_killed} != Blue Casualties {blue_stock - blue_remaining}")

        hostiles_killed_list.append(hostiles_killed)
        warfighters_killed_list.append(warfighters_killed)
        # Whole kilometres so the histograms can use integer bins.
        patrol_distance_list.append(int(result['blue']['patrol_distance'] / 1000))

        total_blue_remaining += blue_remaining
        total_red_remaining += red_remaining
        total_red_spawned += red_spawned
        total_effective_movement += effective_movement
        total_hostiles_killed += hostiles_killed
        total_warfighters_killed += warfighters_killed

    summary = {
        "Armor": armor,
        "Environment": env,
        "Average_Blue_Remaining": total_blue_remaining / number_of_runs,
        "Average_red_Remaining": total_red_remaining / number_of_runs,
        "Average_Effective_Movement": total_effective_movement / number_of_runs,
        # Hostiles killed per blue warfighter per run (totals span all runs,
        # so normalise by runs as well as squad size).
        "Average_Blue_Lethality": total_hostiles_killed / (blue_stock * number_of_runs),
        # Warfighters killed per spawned hostile; NaN when no hostiles spawned.
        "Average_red_Lethality": (
            total_warfighters_killed / total_red_spawned if total_red_spawned else float("nan")
        ),
    }
    return summary, hostiles_killed_list, warfighters_killed_list, patrol_distance_list


def _save_histogram_csv(path):
    """Write every per-run value from the three global result dicts to ``path`` as CSV."""
    histogram_data = []
    for combo, hostiles_killed in total_hostiles_killed_dict.items():
        for tbk_val, thk_val, tpd_val in zip(hostiles_killed, total_warfighters_killed_dict[combo], total_patrol_distance_dict[combo]):
            histogram_data.append({
                "Combination": combo,
                "Blue_Kills": tbk_val,
                "Hostile_Kills": thk_val,
                "Patrol_Distance_km": tpd_val
            })
    pd.DataFrame(histogram_data).to_csv(path, index=False)


def run_all_combinations(_):
    """
    Button callback: simulate every armor/environment combination and report.

    Fills the global per-combination dicts, draws histograms and mean-vs-armor
    line plots, displays the summary table, and writes 'metrics_data.csv' and
    'histogram_data.csv'. All output goes to the ``output`` widget.

    Parameters:
        _: The clicked Button (unused).
    """
    output.clear_output()
    blue_stock = 10
    red_stock = 40
    direction_deviation = 10
    results = []
    with output:
        for armor, env in tqdm(combinations, desc="Processing Combinations"):
            key = f"{armor}\n+{env}"
            summary, hostiles, warfighters, distances = _simulate_combination(
                armor, env, blue_stock, red_stock, direction_deviation
            )
            total_hostiles_killed_dict[key] = hostiles
            total_warfighters_killed_dict[key] = warfighters
            total_patrol_distance_dict[key] = distances
            results.append(summary)

        plot_histograms(total_hostiles_killed_dict, "Hostiles Killed", "Number of Hostiles Killed", "Frequency", step=40)
        plot_histograms(total_warfighters_killed_dict, "Warfighters Killed", "Number of Warfighters Killed by Hostiles", "Frequency", step=1)
        plot_histograms(total_patrol_distance_dict, "Patrol Distance", "Patrol Distance (km)", "Frequency", step=1)
        plot_line_from_dict(total_hostiles_killed_dict, armor_types, environments, f"Mean Hostiles Killed vs Armor ({number_of_runs} simulations per combination)", "Mean Hostiles Killed")
        plot_line_from_dict(total_warfighters_killed_dict, armor_types, environments, f"Mean Warfighters Killed vs Armor ({number_of_runs} simulations per combination)", "Mean Warfighters Killed")
        plot_line_from_dict(total_patrol_distance_dict, armor_types, environments, f"Mean Patrol Distance vs Armor ({number_of_runs} simulations per combination)", "Mean Patrol Distance per run")

        df_results = pd.DataFrame(results)
        display(df_results)
        df_results.to_csv("metrics_data.csv", index=False)
        print("Accumulated metrics saved to 'metrics_data.csv'")
        _save_histogram_csv("histogram_data.csv")
        print("Histogram data saved to 'histogram_data.csv'")

# Batch-run launcher: clicking the button calls run_all_combinations, which
# writes its plots, table and messages into the `output` area shown beneath it.
run_button = Button(
    description=f"Run All Armor/Threat Combinations ({number_of_runs} times each)",
)
run_button.on_click(run_all_combinations)
VBox(children=[run_button, output])