In [None]:
import numpy as np
from matplotlib import pyplot as plt
import torch
from tqdm import tqdm
import seaborn as sns
import pickle
import os
from IGT.environments import IGTEnv, NonStochasticBanditEnv, NonStationaryEnv
from IGT.train import train

## Hyperparameters for Parameter Sweep

In [None]:
# Base hyperparameters shared by every run of the sweep.
# `trails` (sic) is the number of trials per epoch; the spelling matches train(trails=...).
trails, epochs, bins = 200, 20, 10   # epochs reduced for faster execution while looping
lr = 0.1
num_arms = 2
scaling_factor = 100                 # raw reward values below are divided by this
time_stamp_change = 100              # trial number at which the environment changes
stationary = False
window_size = 20                     # bin size (in trials) for analyzing arm probabilities

# Reward mean / std and initial probability of each arm, before the change ...
reward_values = np.divide([40, 40], scaling_factor)
reward_std = np.divide([1, 1], scaling_factor)
initial_probabilities = np.array((0.8, 0.2))

# ... and the values that take effect after time_stamp_change.
reward_values_changed = np.divide([40, 40], scaling_factor)
reward_std_changed = np.divide([1, 1], scaling_factor)
changed_probabilities = np.array((0.2, 0.8))

# Which environment parameters change: 'reward', 'probability', or 'both'.
change_type = 'probability'

# Swept parameters: 20 evenly spaced points on [0, 1] each.
del_lim_values = np.linspace(0.0, 1.0, num=20)
var_values = np.linspace(0.0, 1.0, num=20)

# 1D-sweep results keyed by parameter value; filled later from the 2D sweep.
results = {sweep_name: {} for sweep_name in ('del_lim', 'var')}

## Helper Functions for Data Analysis

In [None]:
def get_arm_probabilities_over_time(arm_chosen_monitor, num_arms, trails, epochs, window_size,
                                    time_stamp_change=None):
    """
    Calculate the probability of choosing each arm in consecutive, non-overlapping time windows.

    For every epoch the trials are split into ``trails // window_size`` windows (trailing trials
    that do not fill a whole window are ignored). Within a window, each arm's probability is its
    share of the valid choices (entries in ``[0, num_arms)``); a window with no valid choice
    contributes 0 for that epoch. The per-window probabilities are then averaged over epochs.

    Parameters:
    -----------
    arm_chosen_monitor: Tensor or array-like
        Chosen arm indexed as ``arm_chosen_monitor[epoch][trial]``. Each entry must be convertible
        with ``int()`` (Python int, NumPy scalar, or single-element tensor). Entries outside
        ``[0, num_arms)`` are skipped.
    num_arms: int
        Number of arms in the bandit
    trails: int
        Number of trials per epoch (spelling kept to match ``train(trails=...)``)
    epochs: int
        Number of epochs to average over
    window_size: int
        Number of trials per window; must be positive
    time_stamp_change: int or None, optional
        Trial at which the environment changes. If None, the module/notebook-level global
        ``time_stamp_change`` is used (the value earlier callers relied on implicitly).

    Returns:
    --------
    dict with keys
        'probs': {arm: ndarray of shape (num_windows,)} epoch-averaged probabilities
        'window_centers': list of window centers, in trial units
        'pre_change' / 'post_change': {arm: mean probability over the windows whose center is
            at/before vs. after the change}; both are empty unless there is at least one
            window on each side of the change
        'change_window_idx': index of the first window whose center is after the change, or None

    Raises:
    -------
    ValueError
        If ``window_size`` is not positive.
    TypeError
        If ``time_stamp_change`` is omitted and no global of that name exists.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")
    if time_stamp_change is None:
        # Backward compatibility: this function used to read the notebook constant implicitly.
        time_stamp_change = globals().get('time_stamp_change')
        if time_stamp_change is None:
            raise TypeError("time_stamp_change was not given and no global "
                            "'time_stamp_change' is defined")

    num_windows = trails // window_size
    window_centers = [(w + 0.5) * window_size for w in range(num_windows)]
    arm_probs_over_time = {arm: np.zeros(num_windows) for arm in range(num_arms)}

    for epoch in range(epochs):
        epoch_choices = arm_chosen_monitor[epoch]
        for window_idx in range(num_windows):
            start_idx = window_idx * window_size
            # num_windows * window_size <= trails, so each window lies inside the recorded trials.
            window_choices = (int(epoch_choices[t]) for t in range(start_idx, start_idx + window_size))
            # Skip out-of-range entries (e.g. a sentinel for "no arm chosen"); negative values
            # used to raise KeyError.
            valid = [arm for arm in window_choices if 0 <= arm < num_arms]
            if not valid:
                continue  # nothing to normalize by; this window contributes 0 for this epoch
            for arm in range(num_arms):
                arm_probs_over_time[arm][window_idx] += valid.count(arm) / len(valid)

    # Average across epochs
    for arm in range(num_arms):
        arm_probs_over_time[arm] /= epochs

    # First window whose center falls after the change.
    change_window_idx = next(
        (idx for idx, center in enumerate(window_centers) if center > time_stamp_change), None)

    # Both averages need at least one window on each side of the change: None means no window
    # after it, 0 means no window at/before it (an empty slice would average to NaN).
    if change_window_idx is not None and change_window_idx > 0:
        arm_probs_pre_change = {arm: np.mean(arm_probs_over_time[arm][:change_window_idx])
                                for arm in range(num_arms)}
        arm_probs_post_change = {arm: np.mean(arm_probs_over_time[arm][change_window_idx:])
                                 for arm in range(num_arms)}
    else:
        arm_probs_pre_change, arm_probs_post_change = {}, {}

    return {
        'probs': arm_probs_over_time,
        'window_centers': window_centers,
        'pre_change': arm_probs_pre_change,
        'post_change': arm_probs_post_change,
        'change_window_idx': change_window_idx
    }

def calculate_adaptation_time(arm_probs_over_time, window_centers, time_stamp_change,
                              target_arm=1, baseline_arm=0):
    """
    Calculate the time it takes for the agent to adapt after the environment change.

    Adaptation is the first window whose center lies strictly after ``time_stamp_change`` and in
    which ``target_arm``'s selection probability strictly exceeds ``baseline_arm``'s. The defaults
    reproduce the original definition (arm 1 overtaking arm 0); pass other arms when a different
    arm is expected to take over. Times are measured at window centers, so their resolution is
    limited to the spacing of ``window_centers``.

    Parameters:
    -----------
    arm_probs_over_time: dict
        Maps arm index -> sequence of selection probabilities, one per window
    window_centers: list
        Centers of the time windows (trial units), aligned with the probability sequences
    time_stamp_change: int
        Trial number when the environment changes
    target_arm: int, optional
        Arm expected to become preferred after the change (default 1)
    baseline_arm: int, optional
        Arm it has to overtake (default 0)

    Returns:
    --------
    dict with keys
        'adapted': bool
        'adaptation_trial': window center where adaptation was detected, or None
        'adaptation_time': adaptation_trial - time_stamp_change, or None
    """
    for idx, center in enumerate(window_centers):
        # Arms are only looked up for post-change windows, so a missing arm key is never touched
        # when no window follows the change.
        if (center > time_stamp_change
                and arm_probs_over_time[target_arm][idx] > arm_probs_over_time[baseline_arm][idx]):
            return {
                'adapted': True,
                'adaptation_trial': center,
                'adaptation_time': center - time_stamp_change
            }

    return {
        'adapted': False,
        'adaptation_trial': None,
        'adaptation_time': None
    }

## Parameter Sweep: 2D Grid over del_lim × var (20 × 20 = 400 combinations)

In [None]:
# Run the full del_lim x var grid. Per-combination results go into results_2d[del_lim][var];
# pre/post-change arm-selection probabilities also go into four matrices for the heatmaps.

# Base seed for the sweep. Each grid point gets its own deterministic seed, so any single
# combination can be re-run in isolation and reproduce its stored result.
# NOTE(review): this seeds only the global NumPy and torch RNGs; if NonStationaryEnv or train
# draw from other sources (e.g. Python's `random`), seed those too — TODO confirm.
sweep_seed = 0

# Nested dictionary for 2D results
results_2d = {}

# Pre-change and post-change probability matrices for heatmaps. Cells stay NaN when a run
# yields no pre/post-change split, so missing data is not mistaken for a probability of 0.
matrix_shape = (len(del_lim_values), len(var_values))
arm0_pre_matrix = np.full(matrix_shape, np.nan)
arm0_post_matrix = np.full(matrix_shape, np.nan)
arm1_pre_matrix = np.full(matrix_shape, np.nan)
arm1_post_matrix = np.full(matrix_shape, np.nan)

total_combinations = len(del_lim_values) * len(var_values)

# The context manager closes the progress bar even if a run raises or the sweep is interrupted.
with tqdm(total=total_combinations, desc="Parameter combinations") as pbar:
    for i, del_lim in enumerate(del_lim_values):
        results_2d[del_lim] = {}

        for j, var in enumerate(var_values):
            combo_seed = sweep_seed + i * len(var_values) + j
            np.random.seed(combo_seed)
            torch.manual_seed(combo_seed)

            # Fresh environment per combination so no state carries over between runs.
            env = NonStationaryEnv(
                num_arms=num_arms,
                mean_reward=reward_values,
                std=reward_std,
                probabilities=initial_probabilities,
                mean_rew_change=reward_values_changed,
                std_change=reward_std_changed,
                probabilities_change=changed_probabilities,
                stationary=stationary,
                time_stamp_change=time_stamp_change,
                change_type=change_type
            )

            # Only del_lim and gpi_var vary; all other model settings are held fixed.
            reward_monitor, arm_chosen_monitor, avg_counts, ip_monitor, dp_monitor, _ = train(
                env,
                trails=trails,
                epochs=epochs,
                lr=lr,
                bins=bins,
                STN_spike_output=None,
                d1_amp=0.5,
                d2_amp=0.3,
                gpi_threshold=0.2,
                max_gpi_iters=50,
                STN_neurons=256,
                stn_mean=0,
                stn_std=0,
                del_lim=del_lim,  # Varied parameter 1
                train_IP=False,
                del_med=None,
                printing=False,  # Turn off printing for loop
                gpi_mean=1,
                gpi_var=var  # Varied parameter 2
            )

            arm_probs_data = get_arm_probabilities_over_time(
                arm_chosen_monitor, num_arms, trails, epochs, window_size
            )

            adaptation_data = calculate_adaptation_time(
                arm_probs_data['probs'],
                arm_probs_data['window_centers'],
                time_stamp_change
            )

            results_2d[del_lim][var] = {
                'arm_probs': arm_probs_data['probs'],
                'window_centers': arm_probs_data['window_centers'],
                'adaptation': adaptation_data,
                'pre_change': arm_probs_data['pre_change'],
                'post_change': arm_probs_data['post_change']
            }

            # pre_change / post_change are either both populated or both empty.
            if 0 in arm_probs_data['pre_change']:
                arm0_pre_matrix[i, j] = arm_probs_data['pre_change'][0]
                arm0_post_matrix[i, j] = arm_probs_data['post_change'][0]

            if 1 in arm_probs_data['pre_change']:
                arm1_pre_matrix[i, j] = arm_probs_data['pre_change'][1]
                arm1_post_matrix[i, j] = arm_probs_data['post_change'][1]

            pbar.update(1)

# Save the full 2D results
output_file_2d = "param_sweep_2d_results.pkl"
with open(output_file_2d, 'wb') as f:
    pickle.dump(results_2d, f)

# Also save the heatmap matrices for easier access later
heatmap_file = "arm_probability_heatmaps.pkl"
heatmap_data = {
    'arm0_pre': arm0_pre_matrix,
    'arm0_post': arm0_post_matrix,
    'arm1_pre': arm1_pre_matrix,
    'arm1_post': arm1_post_matrix,
    'del_lim_values': del_lim_values,
    'var_values': var_values
}

with open(heatmap_file, 'wb') as f:
    pickle.dump(heatmap_data, f)

print(f"Full 2D parameter sweep results saved to {output_file_2d}")
print(f"Heatmap data saved to {heatmap_file}")

## Extract 1D Slices (fixed var, fixed del_lim) for Backward Compatibility

In [None]:
# Rebuild the original 1D-sweep `results` dict from the 2D grid so the plotting cells that
# expect it keep working:
#   results['del_lim'][d] = results_2d[d][var closest to fixed_var]
#   results['var'][v]     = results_2d[del_lim closest to fixed_del_lim][v]
# The fixed values are snapped to the nearest grid point because the grid need not contain
# them exactly (e.g. 0.1 is not a point of linspace(0, 1, 20)).
fixed_var = 0.1
fixed_del_lim = 1.0

closest_var_idx = np.argmin(np.abs(var_values - fixed_var))
closest_var = var_values[closest_var_idx]

closest_del_lim_idx = np.argmin(np.abs(del_lim_values - fixed_del_lim))
closest_del_lim = del_lim_values[closest_del_lim_idx]

results['del_lim'].update({d: results_2d[d][closest_var] for d in del_lim_values})
results['var'].update({v: results_2d[closest_del_lim][v] for v in var_values})

output_file = "param_sweep_results.pkl"
with open(output_file, 'wb') as fh:
    pickle.dump(results, fh)

print(f"Backward-compatible results saved to {output_file}")

## Results Analysis: Adaptation Times

In [None]:
# Adaptation time along each 1D slice; parameter values whose run never adapted are omitted.


def _adapted_points(sweep_results, param_values):
    """Return (values, adaptation_times) for the runs in ``sweep_results`` that adapted, in order."""
    kept_values, kept_times = [], []
    for value in param_values:
        outcome = sweep_results[value]['adaptation']
        if not outcome['adapted']:
            continue
        kept_values.append(value)
        kept_times.append(outcome['adaptation_time'])
    return kept_values, kept_times


del_lim_values_valid, del_lim_adapt_times = _adapted_points(results['del_lim'], del_lim_values)
var_values_valid, var_adapt_times = _adapted_points(results['var'], var_values)

fig, (ax_del_lim, ax_var) = plt.subplots(1, 2, figsize=(15, 6))

ax_del_lim.plot(del_lim_values_valid, del_lim_adapt_times, 'o-', linewidth=2)
ax_del_lim.set_xlabel('del_lim value')
ax_del_lim.set_ylabel('Adaptation Time (trials)')
ax_del_lim.set_title('Effect of del_lim on Adaptation Time')
ax_del_lim.grid(True, alpha=0.3)

ax_var.plot(var_values_valid, var_adapt_times, 'o-', linewidth=2, color='orange')
ax_var.set_xlabel('var value')
ax_var.set_ylabel('Adaptation Time (trials)')
ax_var.set_title('Effect of var on Adaptation Time')
ax_var.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

## Results Analysis: Arm Selection Probabilities

In [None]:
# Windowed arm-selection probabilities for a handful of values from each 1D sweep slice.


def plot_arm_probability_panels(sweep_results, param_values, param_name, change_trial,
                                panel_positions=(0, 5, 10, 15, -1)):
    """Plot arm-selection probability over trials for selected values of one swept parameter.

    Parameters
    ----------
    sweep_results : dict
        Maps each parameter value to a run record with 'window_centers' and 'arm_probs'
        ({arm: per-window probabilities}), as stored in results['del_lim'] / results['var'].
    param_values : array-like
        The swept values in sweep order; these are the keys of ``sweep_results``.
    param_name : str
        Parameter name used in the subplot titles.
    change_trial : float
        Trial of the environment change, drawn as a dashed vertical line.
    panel_positions : sequence of int
        Positions in ``param_values`` to plot (negative positions count from the end).
        Positions past either end are clamped and duplicates dropped, so a coarser grid does not
        raise IndexError; at most the first 6 remaining positions are drawn (2 x 3 grid).
    """
    n_values = len(param_values)
    if n_values == 0:
        return
    clamped = (min(max(pos + n_values if pos < 0 else pos, 0), n_values - 1)
               for pos in panel_positions)
    positions = list(dict.fromkeys(clamped))[:6]

    plt.figure(figsize=(15, 10))
    for panel, pos in enumerate(positions):
        value = param_values[pos]
        run = sweep_results[value]
        plt.subplot(2, 3, panel + 1)
        for arm, probs in run['arm_probs'].items():
            plt.plot(run['window_centers'], probs,
                     marker='o', linestyle='-', label=f'Arm {arm}')

        plt.axvline(x=change_trial, color='black', linestyle='--', linewidth=2)
        plt.title(f'{param_name} = {value:.2f}')
        plt.xlabel('Trial (window center)')
        plt.ylabel('Selection probability')
        plt.ylim(0, 1)
        plt.grid(True, alpha=0.3)
        if panel == 0:
            plt.legend()

    plt.tight_layout()
    plt.show()


plot_arm_probability_panels(results['del_lim'], del_lim_values, 'del_lim', time_stamp_change)
plot_arm_probability_panels(results['var'], var_values, 'var', time_stamp_change)

## Heat Map Visualization

In [None]:
# Summarize the 2D sweep as heatmaps (rows: del_lim, columns: var):
#   1. adaptation time, 2. adapted / not adapted, 3. per-arm selection probability before and
#   after the change, 4. the post-minus-pre change of those probabilities.


def _sparse_ticklabels(values, every=2):
    """Return one tick label per heatmap cell, labelling only every ``every``-th value.

    seaborn places a list of tick labels on consecutive cells starting from the first, so the
    list must have exactly one entry per row/column. Passing ``values[::2]`` instead puts the
    labels on the wrong cells (and leaves the second half unlabelled); blank strings avoid that.
    """
    return [f'{value:.2f}' if k % every == 0 else '' for k, value in enumerate(values)]


def _sweep_heatmap(matrix, title, cbar_label, **heatmap_kwargs):
    """Draw ``matrix`` (del_lim rows x var columns) on the current axes and return the Axes."""
    heatmap_ax = sns.heatmap(matrix,
                             xticklabels=_sparse_ticklabels(var_values),
                             yticklabels=_sparse_ticklabels(del_lim_values),
                             cbar_kws={'label': cbar_label},
                             **heatmap_kwargs)
    heatmap_ax.set_xlabel('var')
    heatmap_ax.set_ylabel('del_lim')
    heatmap_ax.set_title(title)
    heatmap_ax.tick_params(axis='x', labelrotation=45)
    return heatmap_ax


# Adaptation time (NaN = never adapted) and a binary adapted flag for every grid point.
grid_shape = (len(del_lim_values), len(var_values))
adaptation_matrix = np.full(grid_shape, np.nan)
adapted_matrix = np.zeros(grid_shape)
for i, del_lim in enumerate(del_lim_values):
    for j, var in enumerate(var_values):
        adapt_data = results_2d[del_lim][var]['adaptation']
        adapted_matrix[i, j] = adapt_data['adapted']
        if adapt_data['adapted']:
            adaptation_matrix[i, j] = adapt_data['adaptation_time']

if np.isnan(adaptation_matrix).all():
    # Every cell would be masked, leaving nothing to color.
    print("No parameter combination adapted; skipping the adaptation-time heatmap.")
else:
    plt.figure(figsize=(10, 8))
    ax = _sweep_heatmap(adaptation_matrix, 'Adaptation Time for Parameter Combinations',
                        'Adaptation Time (trials)',
                        cmap="viridis", mask=np.isnan(adaptation_matrix))
    plt.tight_layout()
    plt.show()

plt.figure(figsize=(10, 8))
ax = _sweep_heatmap(adapted_matrix, 'Parameter Combinations Where Agent Successfully Adapted',
                    'Adapted (1) / Not adapted (0)',
                    cmap=["white", "green"], vmin=0, vmax=1)
plt.tight_layout()
plt.show()

# Arm selection probabilities before/after the change (2 x 2 grid: arm x pre/post).
arm_panels = [
    (arm0_pre_matrix,
     f'Arm 0 Selection Probability BEFORE Change (Initial prob: {initial_probabilities[0]})'),
    (arm0_post_matrix,
     f'Arm 0 Selection Probability AFTER Change (Changed prob: {changed_probabilities[0]})'),
    (arm1_pre_matrix,
     f'Arm 1 Selection Probability BEFORE Change (Initial prob: {initial_probabilities[1]})'),
    (arm1_post_matrix,
     f'Arm 1 Selection Probability AFTER Change (Changed prob: {changed_probabilities[1]})'),
]

plt.figure(figsize=(16, 14))
arm_axes = []
for panel, (matrix, title) in enumerate(arm_panels, start=1):
    plt.subplot(2, 2, panel)
    arm_axes.append(_sweep_heatmap(matrix, title, 'Selection Probability',
                                   cmap="coolwarm", vmin=0, vmax=1))
ax0_pre, ax0_post, ax1_pre, ax1_post = arm_axes
plt.tight_layout()
plt.show()

# Post-change minus pre-change selection probability (positive = chosen more after the change).
arm0_diff_matrix = arm0_post_matrix - arm0_pre_matrix
arm1_diff_matrix = arm1_post_matrix - arm1_pre_matrix

plt.figure(figsize=(15, 6))

plt.subplot(1, 2, 1)
ax0_diff = _sweep_heatmap(arm0_diff_matrix, 'Arm 0 Selection Probability CHANGE',
                          'Probability Change', cmap="RdBu_r", vmin=-1, vmax=1, center=0)

plt.subplot(1, 2, 2)
ax1_diff = _sweep_heatmap(arm1_diff_matrix, 'Arm 1 Selection Probability CHANGE',
                          'Probability Change', cmap="RdBu_r", vmin=-1, vmax=1, center=0)

plt.tight_layout()
plt.show()