In [40]:
import streamlit as st
import pandas as pd
import numpy as np
from jax import random, jit
import jax.numpy as jnp
from datetime import datetime

In [60]:
import json

# JSON text is UTF-8 by specification; pass the encoding explicitly so loading does
# not depend on the platform's default locale encoding (e.g. cp1252 on Windows).
with open("../data/question_answer_pairs.json", "r", encoding="utf-8") as json_file:
    dataset = json.load(json_file)

In [61]:
# Record count of the loaded dataset (n_arms is later set to this same value).
num_records = len(dataset)
num_records

35

In [62]:
# Constants
n_arms = len(dataset)       # one arm per entry in `dataset`
context_dim = 3             # length of the context feature vector (columns of `weights`)
exploration_threshold = 5   # questions chosen at random before switching to greedy exploitation


In [63]:
# Initialize parameters: one linear weight row per arm, all starting at zero.
weights = jnp.zeros(shape=(n_arms, context_dim))
# Step size, and the discount applied to the best expected reward in update_Q_values.
learning_rate, gamma = 0.1, 0.95

In [64]:
# Initialize per-arm Q-values and per-arm counts, both zero-filled vectors of length n_arms.
Q_values = jnp.zeros(shape=(n_arms,))
counts = jnp.zeros(shape=(n_arms,))

In [65]:
# Last-attempted timestamps keyed by (type, level): types 0-5, levels 1-5; all start as None.
last_attempted_timestamps = dict.fromkeys(
    (type_, level) for type_ in range(6) for level in range(1, 6)
)

# Session counters, all starting at zero.
total_questions = correct_answers = hints_used = 0

In [66]:
def contextual_bandit_selection(weights, context, exploration_count,
                                threshold=None, num_arms=None):
    """Pick an arm: uniformly at random during warm-up, greedily on linear scores after.

    Deliberately not wrapped in ``@jit``. The Python ``if`` on ``exploration_count``
    cannot be evaluated on a traced value, and ``np.random`` inside a jitted function
    would run only once, at trace time.

    Args:
        weights: per-arm weight matrix, shape (num_arms, context_dim).
        context: context feature vector of length context_dim.
        exploration_count: number of questions asked so far.
        threshold: warm-up length; defaults to the module-level ``exploration_threshold``.
        num_arms: number of arms; defaults to the module-level ``n_arms``.

    Returns:
        ``(arm_index, expected_rewards)``. ``arm_index`` is a plain Python int in both
        branches. ``expected_rewards`` is a zero vector while exploring, otherwise
        ``weights @ context``.
    """
    if threshold is None:
        threshold = exploration_threshold
    if num_arms is None:
        num_arms = n_arms

    if exploration_count < threshold:
        # Exploration: uniform random arm; no score estimates yet.
        return int(np.random.randint(num_arms)), jnp.zeros(num_arms)

    expected_rewards = jnp.dot(weights, context)
    return int(jnp.argmax(expected_rewards)), expected_rewards

In [57]:
def calculate_time_since_last_attempt():
    """Return hours since the last recorded attempt, clamped to [0, 24].

    Reads the module-level ``total_questions`` and
    ``last_attempted_timestamps['last_attempt']``.

    Returns:
        0.0 before any question has been asked (``total_questions == 0``).
        0.0 when no attempt time is recorded (key missing or ``None``, the value the
        timestamp dict is initialised with).
        Otherwise the elapsed hours, clamped into [0.0, 24.0].
    """
    if total_questions == 0:
        return 0.0

    last_time = last_attempted_timestamps.get('last_attempt')
    if last_time is None:
        # Nothing recorded yet; previously a stored None made the subtraction raise.
        return 0.0

    elapsed_hours = (datetime.now() - last_time).total_seconds() / 3600.0
    # Lower bound guards against a clock that moved backwards; upper bound caps at a day.
    return min(max(elapsed_hours, 0.0), 24.0)


In [50]:
@jit
def update_Q_values(arm_index, reward, Q_values, expected_rewards, gamma, learning_rate):
    """Move one arm's Q-value toward ``reward + gamma * max(expected_rewards)``.

    Args:
        arm_index: arm whose estimate is updated.
        reward: observed reward for that arm.
        Q_values: current per-arm Q-value vector.
        expected_rewards: score vector whose maximum is used in the target.
        gamma: discount applied to that maximum.
        learning_rate: step size toward the target.

    Returns:
        A new Q-value array; the input array is not modified.
    """
    current_q = Q_values[arm_index]
    target = reward + gamma * jnp.max(expected_rewards)
    return Q_values.at[arm_index].add(learning_rate * (target - current_q))

In [None]:
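"""
Context features to use:
recent_performance: ratio of correct answers
timestamp_since_last_question_practiced: time elapsed since the question was last practiced
number_of_hints_per_question: hints used on each question
"""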

In [59]:
# NOTE(review): `df` is not created by any earlier cell. The only `df =` is in the last
# cell, with different columns, so this preview depends on hidden kernel state.
df.iloc[:2]

Unnamed: 0,problem,level,type,solution,stage,source
0,Kevin Kangaroo begins hopping on a number line...,Level 5,Algebra,Kevin hops $1/3$ of the remaining distance wit...,train,MATH
1,The ratio of the areas of two squares is $\fra...,Level 4,Algebra,We start off by simplifying the ratio $\frac{1...,train,MATH


In [None]:
def run_contextual_bandit(recent_performance, time_of_day, session_count):
    """Select one question with the linear contextual bandit and learn from a simulated answer.

    Args:
        recent_performance: context feature (ratio of recent correct answers).
        time_of_day: context feature.
        session_count: context feature.

    Returns:
        ``(question, solution, correct)`` for the selected row, where ``correct`` is the
        simulated outcome.

    Side effects:
        Updates the module-level ``weights``, ``Q_values`` and ``total_questions``.
        Stores the attempt time in ``last_attempted_timestamps['last_attempt']``.
    """
    global weights, Q_values, total_questions

    # Exactly three features, matching context_dim and the (n_arms, context_dim) weights.
    # The old 4-feature vector (with an undefined `question_count`) could not be
    # multiplied by the weights.
    context = jnp.array([recent_performance, time_of_day, session_count])

    # Exploration is driven by the file's own question counter.
    arm_index, expected_rewards = contextual_bandit_selection(weights, context, total_questions)
    arm_index = int(arm_index)  # plain int for positional .iloc lookup and indexing

    # NOTE(review): `questions` is not defined in any visible cell; presumably a
    # DataFrame with 'question'/'solution' columns built from `dataset`. Confirm.
    selected_question = questions.iloc[arm_index]

    # Simulated user response (demo only): correct with probability 0.5.
    correct = np.random.rand() < 0.5
    reward = 1.0 if correct else 0.0

    Q_values = update_Q_values(arm_index, reward, Q_values, expected_rewards, gamma, learning_rate)
    weights = weights.at[arm_index].set(weights[arm_index] + learning_rate * reward * context)

    # Record this attempt so calculate_time_since_last_attempt has a real timestamp.
    last_attempted_timestamps['last_attempt'] = datetime.now()
    total_questions += 1

    return selected_question['question'], selected_question['solution'], correct


In [1]:
# Scratch check: `/` is true division in Python 3, so an int ratio displays as a float.
0 / 1

0.0

In [None]:
import jax.numpy as jnp
import jax.random as jrandom
import streamlit as st

# Constants
NUM_QUESTIONS = 35       # Q-table rows; main() buckets recent performance into one of these
NUM_ARS = 3              # Number of arms (question sets); spelled NUM_ARMS in the later cell
INITIAL_Q_VALUES = 0.5   # starting estimate for every (row, arm) cell
EXPLORATION_RATE = 0.2   # epsilon: probability of picking a random arm

# Q-table and visit counts, both shaped (NUM_QUESTIONS, NUM_ARS).
Q_values = jnp.full(shape=(NUM_QUESTIONS, NUM_ARS), fill_value=INITIAL_Q_VALUES)
counts = jnp.zeros(shape=(NUM_QUESTIONS, NUM_ARS))

# Function to select a question using epsilon-greedy strategy
def select_question(rng_key, context):
    """Choose an arm for one Q-table row with an epsilon-greedy rule.

    Args:
        rng_key: JAX PRNG key. It is split internally so the explore/exploit draw and
            the random-arm draw use independent subkeys.
        context: row index into the module-level ``Q_values``.

    Returns:
        A 0-d integer JAX array. With probability ``EXPLORATION_RATE`` it is a uniformly
        random arm in ``[0, NUM_ARS)``; otherwise it is ``argmax(Q_values[context])``.
    """
    # JAX keys must be consumed once. Drawing both the coin flip and the random arm
    # from the same key made the two draws statistically dependent.
    explore_key, arm_key = jrandom.split(rng_key)
    if jrandom.uniform(explore_key) < EXPLORATION_RATE:
        # Explore: uniformly random arm.
        return jrandom.randint(arm_key, (), 0, NUM_ARS)
    # Exploit: best current estimate for this row.
    return jnp.argmax(Q_values[context])

# Custom reward function based on your criteria
def custom_reward(is_correct, hints_used, timestamp_since_last, correct_ratio):
    """Score one answer for the bandit update.

    Incorrect answers score 0.0. A correct answer starts at 1.0 and loses 0.2 per hint
    used. It loses another 0.5 when ``timestamp_since_last`` is below 10 (minutes, per
    the slider in ``main``), then gains ``0.5 * correct_ratio``. The final score is
    floored at 0.0.
    """
    if not is_correct:
        return 0.0

    score = 1.0 - 0.2 * hints_used
    attempted_too_soon = timestamp_since_last < 10
    if attempted_too_soon:
        score -= 0.5
    score += 0.5 * correct_ratio
    return max(score, 0.0)

# Function to update Q-values using the reward received
def update_q_values(Q_values, counts, context, arm_index, reward):
    """Apply an incremental-mean update to one Q-table cell.

    Increments ``counts[context, arm_index]`` by one. Then moves
    ``Q_values[context, arm_index]`` toward ``reward`` by a ``1 / count`` fraction, so
    the cell holds the running average of the rewards it has received.

    Returns:
        ``(Q_values, counts)`` as new arrays; the inputs are not modified.
    """
    new_counts = counts.at[context, arm_index].add(1)
    visits = new_counts[context, arm_index]
    current = Q_values[context, arm_index]
    new_q = Q_values.at[context, arm_index].add((reward - current) / visits)
    return new_q, new_counts

# Main function to run the Streamlit app
def main():
    """Render the Streamlit demo: suggest an arm, read feedback, apply one Q update.

    NOTE(review): Streamlit re-executes the whole script on every widget interaction.
    The module-level Q_values/counts updated here are therefore rebuilt from their
    initial values on the next rerun. Persisting learning would need
    ``st.session_state``. The update below also runs on every rerun, using whatever
    answer the radio currently shows.
    """
    global Q_values, counts
    import secrets

    st.title("Math Mentor AI - Contextual Bandit")

    # User inputs
    recent_performance = st.slider("Recent Performance Ratio (0-1)", 0.0, 1.0, 0.5)
    hints_used = st.slider("Number of Hints Used (0-5)", 0, 5, 2)
    timestamp_since_last = st.slider("Time Since Last Attempt (minutes)", 0, 30, 5)
    correct_ratio = st.slider("Correct Answer Ratio (0-1)", 0.0, 1.0, 0.5)

    # Bucket the 0-1 performance ratio into a Q-table row in [0, NUM_QUESTIONS - 1].
    context = int(recent_performance * (NUM_QUESTIONS - 1))

    # Fresh seed per run. A fixed PRNGKey(0) made the explore/exploit draw identical
    # on every run, so EXPLORATION_RATE never had any effect.
    rng_key = jrandom.PRNGKey(secrets.randbits(31))
    arm_index = select_question(rng_key, context)

    st.write(f"Suggested Question Index: {arm_index}")

    # User feedback on the suggested question.
    answer = st.radio("Did you answer correctly?", ("Yes", "No"))
    is_correct = 1 if answer == "Yes" else 0

    reward = custom_reward(is_correct, hints_used, timestamp_since_last, correct_ratio)
    Q_values, counts = update_q_values(Q_values, counts, context, arm_index, reward)

    st.write(f"Reward Received: {reward}")


# Run the Streamlit app.
# NOTE(review): inside a Jupyter kernel __name__ is also "__main__", so executing this
# cell calls main() outside `streamlit run`.
if __name__ == "__main__":
    main()


In [None]:




# Initialize Q-values and counts
# Q_values = jnp.zeros((NUM_CONTEXTS, NUM_ARMS))  # All Q-values initialized to 0
# counts = jnp.zeros((NUM_CONTEXTS, NUM_ARMS))  # All counts initialized to 0







# Main training loop
def train_bandit(num_episodes=None, seed=0):
    """Run a simulated epsilon-greedy training loop and return the learned tables.

    Each episode:
      1. draws a random context;
      2. selects an arm with ``select_question``;
      3. simulates a response (correctness, hints, time since the last question);
      4. scores it with ``custom_reward`` and applies ``update_q_values``.

    NOTE(review): this cell runs before the cell that defines NUM_CONTEXTS and the
    three-argument ``select_question``, so it fails on a fresh kernel. It also
    duplicates the ``train_bandit`` defined in that later cell.

    Args:
        num_episodes: number of episodes to simulate; defaults to the module-level
            ``NUM_EPISODES``.
        seed: seed for the JAX PRNG (0 keeps the previous starting key).

    Returns:
        ``(Q_values, counts)`` after training; the module-level globals are updated too.
    """
    global Q_values, counts
    if num_episodes is None:
        num_episodes = NUM_EPISODES
    rng_key = jrandom.PRNGKey(seed)

    for _ in range(num_episodes):
        # One fresh subkey per draw. The old loop reused a single key everywhere, so
        # every episode saw the same context, explore decision and simulated response.
        rng_key, ctx_key, arm_key, correct_key, hint_key, time_key = jrandom.split(rng_key, 6)

        context_index = jrandom.randint(ctx_key, (), 0, NUM_CONTEXTS)
        arm_index = select_question(arm_key, Q_values, context_index)

        # Simulated user response (demo only); replace with real user feedback.
        is_correct = jrandom.uniform(correct_key) < 0.5
        hints_used = jrandom.randint(hint_key, (), 0, 3)
        timestamp_since_last = jrandom.randint(time_key, (), 1, 20)
        correct_ratio = 0.5  # placeholder; should be computed from the user's history

        reward = custom_reward(is_correct, hints_used, timestamp_since_last, correct_ratio)
        Q_values, counts = update_q_values(Q_values, counts, context_index, arm_index, reward)

    return Q_values, counts

# Train, then print both learned tables.
final_Q_values, final_counts = train_bandit()
for label, table in (("Final Q-values:\n", final_Q_values), ("Final counts:\n", final_counts)):
    print(label, table)


In [None]:
import jax
import jax.numpy as jnp
from jax import random as jrandom

# Constants
NUM_CONTEXTS = 3         # Example number of contexts
NUM_ARMS = 35            # Example number of arms (questions)
EXPLORATION_RATE = 0.1   # Epsilon for the epsilon-greedy strategy
NUM_EPISODES = 1000      # Simulated episodes run by train_bandit(); it reads this name

# Initialize Q-values and counts, both shaped (NUM_CONTEXTS, NUM_ARMS).
Q_values = jnp.zeros((NUM_CONTEXTS, NUM_ARMS))  # All Q-values initialized to 0
counts = jnp.zeros((NUM_CONTEXTS, NUM_ARMS))  # All counts initialized to 0

# Function to select a question using epsilon-greedy strategy
@jax.jit
def select_question(rng_key, Q_values, context_index):
    """Choose an arm with an epsilon-greedy rule, written to trace under ``jax.jit``.

    A Python ``if`` on a traced value raises inside ``jit``. Instead, both candidate
    arms are computed and ``jnp.where`` selects one.

    Args:
        rng_key: JAX PRNG key. It is split so the explore decision and the random arm
            use independent subkeys.
        Q_values: Q-table; ``Q_values[context_index]`` gives the per-arm values.
        context_index: row of ``Q_values`` to act on.

    Returns:
        A 0-d integer array. With probability ``EXPLORATION_RATE`` it is a random arm
        in ``[0, NUM_ARMS)``; otherwise it is ``argmax(Q_values[context_index])``.
    """
    explore_key, arm_key = jrandom.split(rng_key)
    random_arm = jrandom.randint(arm_key, (), 0, NUM_ARMS)
    greedy_arm = jnp.argmax(Q_values[context_index])
    explore = jrandom.uniform(explore_key) < EXPLORATION_RATE
    return jnp.where(explore, random_arm, greedy_arm)

# Custom reward function based on user response and context
@jax.jit
def custom_reward(is_correct, hints_used, timestamp_since_last, correct_ratio):
    """Reward for one simulated answer, expressed with traceable array ops.

    The previous body branched with a Python ``if`` and called built-in ``max`` on
    traced values, which raises as soon as ``jit`` traces the function.
    ``jnp.where`` and ``jnp.maximum`` express the same rules:
      - incorrect -> 0.0
      - correct   -> 1.0 - 0.2 * hints_used
                     minus 0.5 when timestamp_since_last < 10
                     plus 0.5 * correct_ratio
                     floored at 0.0

    Returns:
        A 0-d float JAX array.
    """
    reward = 1.0 - 0.2 * hints_used
    reward = reward - jnp.where(timestamp_since_last < 10, 0.5, 0.0)
    reward = reward + 0.5 * correct_ratio
    correct = jnp.asarray(is_correct).astype(bool)
    return jnp.where(correct, jnp.maximum(reward, 0.0), 0.0)

# Function to update Q-values based on received reward
@jax.jit
def update_q_values(Q_values, counts, context_index, arm_index, reward):
    """Apply a jitted incremental-mean update to one Q-table cell.

    Increments ``counts[context_index, arm_index]`` by one. Then moves the matching
    Q-value toward ``reward`` by a ``1 / count`` fraction, keeping a running average
    of the rewards that cell has received.

    Returns:
        ``(Q_values, counts)`` as new arrays; the inputs are not modified.
    """
    new_counts = counts.at[context_index, arm_index].add(1)
    visits = new_counts[context_index, arm_index]
    current = Q_values[context_index, arm_index]
    new_q = Q_values.at[context_index, arm_index].add((reward - current) / visits)
    return new_q, new_counts

# Main training loop
def train_bandit(num_episodes=None, seed=0):
    """Run a simulated epsilon-greedy training loop and return the learned tables.

    Each episode:
      1. draws a random context;
      2. selects an arm with ``select_question``;
      3. simulates a response (correctness, hints, time since the last question);
      4. scores it with ``custom_reward`` and applies ``update_q_values``.

    Args:
        num_episodes: number of episodes to simulate; defaults to the module-level
            ``NUM_EPISODES``.
        seed: seed for the JAX PRNG (0 keeps the previous starting key).

    Returns:
        ``(Q_values, counts)`` after training; the module-level globals are updated too.
    """
    global Q_values, counts
    if num_episodes is None:
        num_episodes = NUM_EPISODES
    rng_key = jrandom.PRNGKey(seed)

    for _ in range(num_episodes):
        # One fresh subkey per draw. The old loop reused a single key everywhere, so
        # every episode saw the same context, explore decision and simulated response.
        rng_key, ctx_key, arm_key, correct_key, hint_key, time_key = jrandom.split(rng_key, 6)

        context_index = jrandom.randint(ctx_key, (), 0, NUM_CONTEXTS)
        arm_index = select_question(arm_key, Q_values, context_index)

        # Simulated user response (demo only); replace with real user feedback.
        is_correct = jrandom.uniform(correct_key) < 0.5
        hints_used = jrandom.randint(hint_key, (), 0, 3)
        timestamp_since_last = jrandom.randint(time_key, (), 1, 20)
        correct_ratio = 0.5  # placeholder; should be computed from the user's history

        reward = custom_reward(is_correct, hints_used, timestamp_since_last, correct_ratio)
        Q_values, counts = update_q_values(Q_values, counts, context_index, arm_index, reward)

    return Q_values, counts

# Train, then print both learned tables.
final_Q_values, final_counts = train_bandit()
for label, table in (("Final Q-values:\n", final_Q_values), ("Final counts:\n", final_counts)):
    print(label, table)


In [1]:
import pandas as pd

# Output path for the toy per-user starting data.
USER_INITIAL_DATA_FILE = 'user_initial_data.csv'  # Specify your file path

# One column per field, one row per user.
data = dict(
    user_id=['user1', 'user2', 'user3'],
    recent_performance=[0.5, 0.7, 0.8],
    hints_used=[2, 1, 3],
    timestamp_since_last=[5, 3, 2],
    correct_ratio=[0.5, 0.8, 0.9],
)
df = pd.DataFrame(data)

# Persist without the integer index, then confirm.
df.to_csv(USER_INITIAL_DATA_FILE, index=False)
print(f"CSV file '{USER_INITIAL_DATA_FILE}' created successfully.")


CSV file 'user_initial_data.csv' created successfully.
