In [72]:
import numpy as np
import random

In [102]:
W_z = 0.2
W_u = 0.8

In [None]:
def simulate_data():
    # Initial state
    states = [[10, 0]]  # Tumor size, side effects
    actions = []
    for t in range(20):
        current_state = states[-1]
        
        # if t < 5:  
        if t % 3 ==0:
        # if random.random() < 0.5:  
            action = 1
            next_state = [current_state[0] * W_u, current_state[1] + W_z]
        else: 
            action = 0
            next_state = current_state
            
            
        actions.append(action)
        states.append(next_state)
    return states[:-1], actions

# Feature map (counterfactuals)
def feature_map(state, action):
    tumor, side_effects = state
    if action == 1:  # Treatment
        return [tumor * W_u, side_effects + W_z]
    else:  # No treatment
        return [tumor, side_effects]

In [110]:
def batch_max_margin_cirl(states, actions, max_iters=10, epsilon=1e-3):
    """
    This function implements the CIRL algorithm for batched data.
    
    Args:
        states (list): List of states.
        actions (list): List of actions.
        max_iters (int, optional): Maximum number of iterations. Defaults to 10.
        epsilon (float, optional): Convergence threshold. Defaults to 1e-3.
    
    Returns:
        tuple: A tuple containing the final reward weights, policies, and feature expectations.
    """
    feature_expectations = []
    policies = []
    w = np.random.rand(2)  # Random initial reward weights

    # Compute expert features
    expert_features = np.mean([feature_map(s, a) for s, a in zip(states, actions)], axis=0)

    for _ in range(max_iters):
        # Compute feature expectations for current policy
        policy_features = np.array([
            np.dot(w, feature_map(state, np.argmax([np.dot(w, feature_map(state, a)) for a in [0, 1]])))
            for state in states
        ])
        policy_features = np.mean(policy_features, axis=0)

        # Append to policies and feature expectations
        policies.append(w.copy())
        feature_expectations.append(policy_features)

        # Orthogonal projection
        difference = expert_features - policy_features
        projection = np.dot(difference, difference) / np.dot(difference, difference) if np.linalg.norm(difference) else 0
        w += projection * difference

        # Check convergence
        if np.linalg.norm(expert_features - policy_features) < epsilon:
            break

    return w, policies, feature_expectations# CIRL algorithm
def batch_max_margin_cirl(states, actions, max_iters=10, epsilon=1e-3):
    feature_expectations = []
    policies = []
    w = np.random.rand(2)  # Random initial reward weights
    expert_features = np.mean([feature_map(s, a) for s, a in zip(states, actions)], axis=0)

    for _ in range(max_iters):
        # Compute feature expectations for current policy
        policy_features = []
        for state in states:
            action_values = [np.dot(w, feature_map(state, a)) for a in [0, 1]]
            best_action = np.argmax(action_values)
            policy_features.append(feature_map(state, best_action))
        policy_features = np.mean(policy_features, axis=0)

        # Append to policies and feature expectations
        policies.append(w.copy())
        for policy in policies:
            print(policy)
        feature_expectations.append(policy_features)

        # Orthogonal projection
        difference = expert_features - policy_features
        projection = np.dot(difference, difference) / np.dot(difference, difference)
        
        w += projection * difference

        # Check convergence
        if np.linalg.norm(expert_features - policy_features) < epsilon:
            break

    return w, policies, feature_expectations


In [111]:
# Run the toy example
states, actions = simulate_data()
print(states)
print(actions)

[[10, 0], [8.0, 0.2], [8.0, 0.2], [8.0, 0.2], [6.4, 0.4], [6.4, 0.4], [6.4, 0.4], [5.120000000000001, 0.6000000000000001], [5.120000000000001, 0.6000000000000001], [5.120000000000001, 0.6000000000000001], [4.096000000000001, 0.8], [4.096000000000001, 0.8], [4.096000000000001, 0.8], [3.276800000000001, 1.0], [3.276800000000001, 1.0], [3.276800000000001, 1.0], [2.621440000000001, 1.2], [2.621440000000001, 1.2], [2.621440000000001, 1.2], [2.097152000000001, 1.4]]
[1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0]


In [112]:
reward_weights, policies, feature_expectations = batch_max_margin_cirl(states, actions)



[0.38153405 0.83379892]
[0.38153405 0.83379892]
[0.00736317 0.89379892]
[0.38153405 0.83379892]
[0.00736317 0.89379892]
[0.61861949 0.76379892]
[0.38153405 0.83379892]
[0.00736317 0.89379892]
[0.61861949 0.76379892]
[0.22347709 0.83379892]
[0.38153405 0.83379892]
[0.00736317 0.89379892]
[0.61861949 0.76379892]
[0.22347709 0.83379892]
[0.02625341 0.83379892]
[0.38153405 0.83379892]
[0.00736317 0.89379892]
[0.61861949 0.76379892]
[0.22347709 0.83379892]
[0.02625341 0.83379892]
[0.63750973 0.70379892]
[0.38153405 0.83379892]
[0.00736317 0.89379892]
[0.61861949 0.76379892]
[0.22347709 0.83379892]
[0.02625341 0.83379892]
[0.63750973 0.70379892]
[0.24236733 0.77379892]
[0.38153405 0.83379892]
[0.00736317 0.89379892]
[0.61861949 0.76379892]
[0.22347709 0.83379892]
[0.02625341 0.83379892]
[0.63750973 0.70379892]
[0.24236733 0.77379892]
[-0.05316035  0.80379892]
[0.38153405 0.83379892]
[0.00736317 0.89379892]
[0.61861949 0.76379892]
[0.22347709 0.83379892]
[0.02625341 0.83379892]
[0.63750973 0.

In [109]:
print("Learned Reward Weights:", reward_weights)

Learned Reward Weights: [0.17007761 0.73421491]
