# In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt
import tensorflow as tf

# Assuming Connect4Env, CNN_QNetwork, and PERMemory are already defined

def compute_reward(board, player):
    """Return a shaping bonus for horizontal three-in-a-row patterns.

    Every horizontal window of four cells in rows 0-5 (windows starting at
    columns 0-3) is inspected.  A window that contains exactly one empty
    cell (value 0) contributes:

    * +0.5 when the other three cells all belong to ``player``;
    * +0.2 when the other three cells all belong to the opponent.

    Only horizontal windows are examined; vertical and diagonal lines are
    not scored here.

    Args:
        board: Array indexed as ``board[row, col_slice]`` whose cells are
            compared against 0, ``player`` and the opponent id.
        player: Id of the side being rewarded.  The opponent is 2 when
            ``player`` is 1, and 1 otherwise.

    Returns:
        The accumulated bonus (the int ``0`` when no window matched).
    """
    opponent = 1 if player != 1 else 2
    shaping = 0
    for row in range(6):
        for start in range(4):
            window = board[row, start:start + 4]
            # Both patterns need exactly one free cell, so test that first.
            if np.sum(window == 0) != 1:
                continue
            if np.sum(window == player) == 3:
                shaping += 0.5
            # NOTE(review): an opponent threat *adds* to the reward here;
            # confirm a bonus (rather than a penalty) is intended.
            if np.sum(window == opponent) == 3:
                shaping += 0.2
    return shaping

def train_dqn(episodes=1000):
    """Train a DQN agent on ``Connect4Env`` and plot the training curves.

    Each episode the agent (always moving as player 1) picks columns
    epsilon-greedily from ``model``, stores every transition in a
    ``PERMemory`` together with its TD error, and, once the memory holds at
    least one batch, fits ``model`` on a sampled batch whose targets are
    bootstrapped from ``target_model``.  The target network is synced every
    ``sync_every`` episodes and epsilon decays once per episode.

    After the last episode the online model is saved to
    ``trained_connect4_model.h5`` and the collected statistics are plotted.

    Args:
        episodes: Number of training episodes to run.

    Returns:
        None.

    NOTE(review): only ``env.step(1, ...)`` is called in this loop, so no
    opponent move is made here -- confirm whether ``Connect4Env`` plays the
    opponent internally or whether solo training is intended.
    """
    env = Connect4Env()
    action_size = 7
    model = CNN_QNetwork(action_size)
    target_model = CNN_QNetwork(action_size)
    optimizer = tf.keras.optimizers.Adam(1e-3)
    model.compile(optimizer=optimizer, loss='mse')

    memory = PERMemory(10000)
    gamma = 0.99
    epsilon = 1.0
    epsilon_min = 0.1
    epsilon_decay = 0.995
    batch_size = 64
    sync_every = 10

    rewards_list = []
    epsilons = []
    moving_avg_rewards = []
    q_values_per_episode = []

    for ep in range(episodes):
        state = env.reset()
        # Snapshot the opening position (copied, in case the env reuses its
        # board buffer) and its legal moves.  The per-episode Q-value
        # statistic below is evaluated on this fixed position, so it is
        # comparable across episodes and never sees an empty move list.
        start_state = np.array(state)
        start_valid_actions = list(env.get_valid_actions())
        total_reward = 0
        done = False

        while not done:
            valid_actions = env.get_valid_actions()
            if np.random.rand() < epsilon:
                # Explore: uniformly random legal column.
                action = random.choice(valid_actions)
            else:
                # Exploit: best Q-value among legal columns only.
                q_values = model(np.expand_dims(state, 0))[0].numpy()
                masked_q = [
                    q_values[a] if a in valid_actions else -np.inf
                    for a in range(action_size)
                ]
                action = int(np.argmax(masked_q))

            next_state, reward, done = env.step(1, action)
            shaped_reward = reward + compute_reward(next_state, 1)
            total_reward += shaped_reward

            # TD error of the fresh transition, passed to the memory
            # alongside the transition itself.
            # NOTE(review): the bootstrap max is taken over all columns,
            # including ones that may be full in next_state.
            q_next = target_model(np.expand_dims(next_state, 0))[0].numpy()
            target = shaped_reward + gamma * np.max(q_next) * (1 - int(done))
            pred = model(np.expand_dims(state, 0))[0][action].numpy()
            error = target - pred

            memory.add((state, action, shaped_reward, next_state, done), error)
            state = next_state

            if len(memory.buffer) >= batch_size:
                batch, indices = memory.sample(batch_size)
                states, actions, rewards, next_states, dones = zip(*batch)
                states = np.array(states)
                next_states = np.array(next_states)

                q_next_batch = target_model(next_states).numpy()
                # Start from the current predictions so only the taken
                # action's output contributes to the MSE loss.
                q_target = model(states).numpy()

                td_errors = []
                for i, (act, rew, terminal) in enumerate(
                        zip(actions, rewards, dones)):
                    target_q = (
                        rew
                        + gamma * np.max(q_next_batch[i]) * (1 - int(terminal))
                    )
                    td_errors.append(target_q - q_target[i, act])
                    q_target[i, act] = target_q

                memory.update(indices, td_errors)
                model.train_on_batch(states, q_target)

        if ep % sync_every == 0:
            target_model.set_weights(model.get_weights())

        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        # Track stats (epsilon is recorded after this episode's decay).
        rewards_list.append(total_reward)
        epsilons.append(epsilon)
        if len(rewards_list) >= 10:
            moving_avg_rewards.append(np.mean(rewards_list[-10:]))
        else:
            moving_avg_rewards.append(total_reward)

        # Mean Q-value over the legal moves of the opening position.
        # Pinned to the CPU; the original author did this to avoid CUDA
        # memory issues.
        with tf.device('/cpu:0'):
            q_vals = model(np.expand_dims(start_state, 0))[0].numpy()
            if start_valid_actions:
                avg_q = np.mean([q_vals[a] for a in start_valid_actions])
            else:
                # No legal move to average over; record NaN explicitly
                # instead of triggering numpy's empty-mean warning.
                avg_q = float('nan')
            q_values_per_episode.append(avg_q)

        print(f"Episode {ep}, Total Reward: {total_reward:.2f}, Epsilon: {epsilon:.2f}")

    # Save the model after training.
    # NOTE(review): Keras HDF5 saving is limited to Functional/Sequential
    # models -- confirm CNN_QNetwork can be saved and reloaded this way.
    model.save('trained_connect4_model.h5')

    def plot_training_curves():
        """Show a 2x2 grid of training statistics plus a Q-value figure."""
        plt.figure(figsize=(16, 10))

        plt.subplot(2, 2, 1)
        plt.plot(rewards_list)
        plt.title('Total Reward per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.grid(True)

        plt.subplot(2, 2, 2)
        plt.plot(epsilons)
        plt.title('Epsilon Decay')
        plt.xlabel('Episode')
        plt.ylabel('Epsilon')
        plt.grid(True)

        plt.subplot(2, 2, 3)
        plt.plot(moving_avg_rewards)
        plt.title('Moving Average Reward (window=10)')
        plt.xlabel('Episode')
        plt.ylabel('Avg Reward')
        plt.grid(True)

        plt.subplot(2, 2, 4)
        plt.hist(rewards_list, bins=30)
        plt.title('Reward Distribution')
        plt.xlabel('Total Episode Reward')
        plt.ylabel('Frequency')
        plt.grid(True)

        plt.tight_layout()
        plt.show()

        # Additional Q-value tracking
        plt.figure()
        plt.plot(q_values_per_episode)
        plt.title("Average Q-Value per Episode")
        plt.xlabel("Episode")
        plt.ylabel("Q-Value")
        plt.grid(True)
        plt.show()

    plot_training_curves()


def _prompt_for_column(valid_actions):
    """Ask the human for a column until the entry is in ``valid_actions``.

    Non-numeric input and columns not listed in ``valid_actions`` are
    rejected with a message and the prompt is repeated, so the caller
    always receives a playable column.  ``valid_actions`` is expected to
    contain at least one column.
    """
    while True:
        raw = input("Enter your move (0-6): ")
        try:
            column = int(raw)
        except ValueError:
            print("Please enter a whole number between 0 and 6.")
            continue
        if column in valid_actions:
            return column
        print(f"Column {column} is not available; choose from {list(valid_actions)}.")


def play_with_trained_model():
    """Play an interactive Connect 4 game against the saved model.

    Loads ``trained_connect4_model.h5``, then alternates turns: the human
    moves as player 1 via console input (validated against the
    environment's legal moves) and the model moves as player 2 by taking
    the highest Q-value among the legal columns.  The board is rendered
    after every move and the loop ends as soon as ``env.step`` reports the
    game is done.

    NOTE(review): training in this file only ever moves as player 1, while
    the model is queried here to move as player 2 on an unmodified state --
    confirm the state encoding makes that meaningful.
    """
    # Load the trained model.
    # NOTE(review): if CNN_QNetwork is a custom class, load_model may need
    # ``custom_objects`` -- confirm against its definition.
    model = tf.keras.models.load_model('trained_connect4_model.h5')

    env = Connect4Env()
    done = False
    state = env.reset()

    print("Let's play Connect 4! You are Player 1 (X), and the AI is Player 2 (O).")

    while not done:
        # Player 1 (user): keep asking until a legal column is entered.
        action = _prompt_for_column(env.get_valid_actions())
        state, reward, done = env.step(1, action)
        env.render()

        if done:
            print("Game Over!")
            if reward == 1:
                print("You win!")
            else:
                print("It's a tie!")
            break

        # AI's turn (player 2): greedy over legal columns only.
        print("AI's turn...")
        q_values = model(np.expand_dims(state, 0))[0].numpy()
        valid_actions = env.get_valid_actions()
        q_values_filtered = [
            q_values[a] if a in valid_actions else -np.inf for a in range(7)
        ]
        ai_action = int(np.argmax(q_values_filtered))

        state, reward, done = env.step(2, ai_action)
        env.render()

        if done:
            print("Game Over!")
            if reward == 1:
                print("AI wins!")
            else:
                print("It's a tie!")
            break

# Run the pipeline only when executed directly (script or notebook cell),
# so importing this module does not kick off training or block on input().
if __name__ == "__main__":
    # Start training
    train_dqn(1000)

    # After training is complete, let the user play
    play_with_trained_model()