In [None]:
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
from datetime import datetime
import pandas as pd

# Compute device shared by the code below: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def initalize(seed):
    """Seed all random number generators and request deterministic PyTorch kernels.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy, and PyTorch
            (CPU generator and the current CUDA device).
    """
    # Seed the generators in a fixed order: Python, NumPy, Torch CPU, Torch CUDA.
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN and PyTorch to pick reproducible implementations only.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

def make_env(seed):
    """Create a Taxi-v3 environment and seed it together with its spaces.

    Args:
        seed: Seed passed to the initial ``reset`` call and to the action and
            observation spaces.

    Returns:
        The environment, already reset once with ``seed``.
    """
    taxi_env = gym.make("Taxi-v3")
    taxi_env.reset(seed=seed)
    # Seed the action space first, then the observation space.
    for space in (taxi_env.action_space, taxi_env.observation_space):
        space.seed(seed)
    return taxi_env

def preprocess_state_with_walls(state, env):
    """Turn a discrete state id into a concatenation of one-hot vectors.

    The id is decoded with ``env.unwrapped.decode`` into four indices (row,
    column, passenger location, destination), each of which is one-hot encoded.
    Despite the function name, no wall information is part of the encoding.

    Args:
        state: Discrete state id understood by ``env.unwrapped.decode``.
        env: Environment whose unwrapped instance provides ``decode``.

    Returns:
        A float64 NumPy vector of length 19 (5 + 5 + 5 + 4).
    """
    def one_hot(index, size):
        # Unit vector of length `size` with a 1.0 at position `index`.
        encoded = np.zeros(size)
        encoded[index] = 1.0
        return encoded

    row, col, pass_loc, dest_idx = env.unwrapped.decode(state)
    segments = [
        one_hot(row, 5),
        one_hot(col, 5),
        one_hot(pass_loc, 5),
        one_hot(dest_idx, 4),
    ]
    return np.concatenate(segments)

def save_quadratic_state(model, folder, tag):
    """Write a snapshot of the quadratic model to ``folder``.

    Three files are produced, all suffixed with ``tag``: a heatmap PNG of the
    weight matrix, the matrix as CSV, and the model's ``state_dict``.

    Args:
        model: Module exposing a ``W`` parameter and ``state_dict()``.
        folder: Existing directory that receives the files.
        tag: Label embedded in the file names, the plot title and the log line.
    """
    weights = model.W.data.cpu().numpy()

    heatmap_path = os.path.join(folder, f"quadratic_W_{tag}.png")
    csv_path = os.path.join(folder, f"quadratic_W_{tag}.csv")
    checkpoint_path = os.path.join(folder, f"quadratic_model_{tag}.pth")

    # Heatmap of W.
    plt.figure(figsize=(6, 5))
    plt.imshow(weights, cmap='coolwarm', aspect='auto')
    plt.colorbar()
    plt.title(f"Quadratic Weight Matrix W @ Episode {tag}")
    plt.xlabel("Feature Index")
    plt.ylabel("Feature Index")
    plt.tight_layout()
    plt.savefig(heatmap_path)
    plt.close()

    # Raw matrix and model checkpoint.
    np.savetxt(csv_path, weights, delimiter=',')
    torch.save(model.state_dict(), checkpoint_path)

    print(f"[Saved] Quadratic model + W @ episode {tag}")


class DQN(nn.Module):
    """Multi-layer perceptron mapping a state vector to one Q-value per action.

    The network has two hidden layers of 64 units with ReLU activations.
    """

    def __init__(self, state_size, action_size):
        """Build the layers.

        Args:
            state_size: Length of the input feature vector.
            action_size: Number of outputs (one per action).
        """
        super().__init__()
        hidden_units = 64
        layers = [
            nn.Linear(state_size, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_size),
        ]
        self.structure = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values produced by the network for input ``x``."""
        q_values = self.structure(x)
        return q_values

class QuadraticValue(nn.Module):
    """Quadratic form V(x) = x^T W x with a learnable square matrix W."""

    def __init__(self, state_dim):
        """Create the ``state_dim`` x ``state_dim`` weight matrix, starting at zero.

        Args:
            state_dim: Length of the feature vectors the model is applied to.
        """
        super().__init__()
        zero_matrix = torch.zeros(state_dim, state_dim)
        self.W = nn.Parameter(zero_matrix)

    def forward(self, x):
        """Evaluate the quadratic form for every row of ``x``.

        Args:
            x: Tensor of shape (batch_size, state_dim).

        Returns:
            Tensor of shape (batch_size,) holding x_b^T W x_b for each row b.
        """
        weight = self.W
        return torch.einsum('bi,ij,bj->b', x, weight, x)

class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most ``capacity`` transitions."""
        # A bounded deque drops the oldest entry once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition as a 5-tuple."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them column-wise.

        Returns:
            Five tuples: states, actions, rewards, next_states, dones.
        """
        drawn = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into per-field tuples.
        states, actions, rewards, next_states, dones = zip(*drawn)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)


def train_q_with_fixed_quadrature(model, target_net, value_model, episodes, alpha, gamma, epsilon, epsilon_decay,
                                      min_epsilon, optimizer, memory, batch_size, env, rewards, print_every=10):
    """Run epsilon-greedy DQN training with a shaping bonus from a frozen value model.

    For every sampled transition the TD target is
    ``r + (gamma * V(s') - V(s)) + gamma * max_a Q_target(s', a) * (1 - terminated)``
    where ``V`` is ``value_model`` evaluated without gradients.

    Args:
        model: Online Q-network that is optimized.
        target_net: Network used for bootstrapping. It receives a copy of
            ``model``'s weights on every ``print_every``-th episode and on the
            last episode.
        value_model: Module mapping a feature batch to one value per row; it is
            only evaluated here, never updated.
        episodes: Number of episodes to run.
        alpha: Unused; kept so existing callers keep working.
        gamma: Discount factor used for both bootstrapping and shaping.
        epsilon: Exploration rate at the start of this call.
        epsilon_decay: Multiplicative decay applied to epsilon after each episode.
        min_epsilon: Lower bound for epsilon.
        optimizer: Optimizer that steps ``model``'s parameters.
        memory: Replay buffer providing ``push``, ``sample`` and ``len``.
        batch_size: Minibatch size; updates start once the buffer holds this
            many transitions.
        env: Environment with Gymnasium-style ``reset``/``step``.
        rewards: List that receives each episode's total (unshaped) reward.
        print_every: Interval, in episodes, for logging and target-network sync.

    Returns:
        The exploration rate after the final episode.
    """
    current_epsilon = epsilon
    loss_fn = nn.MSELoss()

    def encode(states):
        # One-hot encode a sequence of raw state ids into a float32 batch tensor.
        features = np.array([preprocess_state_with_walls(s, env) for s in states], dtype=np.float32)
        return torch.from_numpy(features).to(device)

    def as_column(values, dtype):
        # Build a (batch, 1) tensor so it lines up with the gathered Q-values.
        return torch.tensor(values, dtype=dtype).unsqueeze(1).to(device)

    for episode in range(episodes):
        state, _ = env.reset()
        done = False
        total_reward = 0

        while not done:
            state_tensor = encode([state])

            if random.random() < current_epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    q_vals = model(state_tensor)
                action = q_vals.argmax().item()

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward

            # Store only true termination as the bootstrap mask: a time-limit
            # truncation ends the episode but the next state still has value.
            memory.push(state, action, reward, next_state, terminated)
            state = next_state

            if len(memory) >= batch_size:
                s_batch, a_batch, r_batch, ns_batch, d_batch = memory.sample(batch_size)

                s_tensor = encode(s_batch)
                ns_tensor = encode(ns_batch)
                a_tensor = as_column(a_batch, torch.int64)
                r_tensor = as_column(r_batch, torch.float32)
                d_tensor = as_column(d_batch, torch.float32)

                q_selected = model(s_tensor).gather(1, a_tensor)

                with torch.no_grad():
                    max_q_next = target_net(ns_tensor).max(1)[0].unsqueeze(1)

                    # The value model yields one scalar per row; reshape to
                    # (batch, 1) so adding it to r_tensor stays element-wise
                    # instead of broadcasting to (batch, batch).
                    v_current_estimate = value_model(s_tensor).reshape(-1, 1)
                    v_next_estimate = value_model(ns_tensor).reshape(-1, 1)
                    shaping_bonus = gamma * v_next_estimate - v_current_estimate
                    shaped_reward = r_tensor + shaping_bonus

                    target = shaped_reward + gamma * max_q_next * (1 - d_tensor)

                loss = loss_fn(q_selected, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        rewards.append(total_reward)
        current_epsilon = max(min_epsilon, current_epsilon * epsilon_decay)

        if episode % print_every == 0 or episode == episodes - 1:
            # The target network is refreshed on the same schedule as logging.
            target_net.load_state_dict(model.state_dict())
            print(f"  [Q] Episode {episode:4d} | Reward: {total_reward:4d} | Epsilon: {current_epsilon:.3f}")

    return current_epsilon

def train_quadratic_value_supervised(feature_value_data, value_model, epochs=100, lr=0.01):
    """Fit ``value_model`` to (feature, value) pairs by full-batch MSE regression.

    A fresh Adam optimizer is created on every call. The model is switched to
    train mode for the fit and back to eval mode afterwards.

    Args:
        feature_value_data: Sequence of ``(feature_vector, target_value)`` pairs.
            When empty, a message is printed and nothing is trained.
        value_model: Module mapping a feature batch to one prediction per row.
        epochs: Number of full-batch gradient steps. With ``0`` no step is taken
            and the reported final loss is ``nan``.
        lr: Learning rate for Adam.
    """
    # Bail out before touching the model or building any arrays.
    if len(feature_value_data) == 0:
        print("  [Quadratic] No data to train on. Skipping.")
        return

    features = np.array([f for f, _ in feature_value_data], dtype=np.float32)
    # Flatten targets to shape (N,) to match the model's per-row output.
    values = np.array([v for _, v in feature_value_data], dtype=np.float32).reshape(-1)

    X = torch.from_numpy(features).to(device)
    y = torch.from_numpy(values).to(device)

    optimizer = optim.Adam(value_model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    loss = None
    value_model.train()
    for _ in range(epochs):
        loss = loss_fn(value_model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    value_model.eval()

    # `loss` stays None when epochs is 0; report nan instead of failing.
    final_loss = loss.item() if loss is not None else float("nan")
    print(f"  [Quadratic] Trained for {epochs} epochs. Final Loss: {final_loss:.4f}")

def train_bilevel_vrail_quadratic(cycles=20, q_train_episodes=100, self_attention_epochs=50,
                                      alpha=0.1, gamma=0.98, epsilon=1.0, epsilon_decay=0.995, min_epsilon=0.01,
                                      dqn_lr=0.001, self_attention_lr=0.01, batch_size=64, memory_size=50000,
                                      seed=42, data_collection_sample_size=2000):
    """Alternate between DQN training and fitting the quadratic value model.

    Each cycle (1) trains the DQN for ``q_train_episodes`` episodes with rewards
    shaped by the current quadratic model, (2) samples states from the replay
    buffer and labels them with ``max_a Q(s, a)`` from the DQN, and (3) refits
    the quadratic model on the most recent 10000 labelled samples.

    Results are written to ``results_quadratic/<timestamp>_seed<seed>/``: model
    checkpoints, the final W matrix as CSV, per-episode rewards as CSV, a reward
    plot, and W snapshots after cycles 5, 10, 15 and the last cycle.

    Args:
        cycles: Number of bi-level cycles.
        q_train_episodes: DQN episodes per cycle.
        self_attention_epochs: Epochs used when fitting the quadratic model.
        alpha: Forwarded to the DQN training routine and recorded in settings.
        gamma: Discount factor.
        epsilon: Initial exploration rate; decay carries over between cycles.
        epsilon_decay: Multiplicative per-episode decay of epsilon.
        min_epsilon: Lower bound for epsilon.
        dqn_lr: Learning rate of the DQN's Adam optimizer.
        self_attention_lr: Learning rate used when fitting the quadratic model.
        batch_size: Replay minibatch size.
        memory_size: Replay buffer capacity.
        seed: Seed for all random number generators and the environment.
        data_collection_sample_size: Maximum number of replay states labelled
            per cycle.

    Returns:
        Tuple ``(rewards, model, value_model, settings)``: the per-episode
        rewards, the trained DQN, the trained quadratic model, and a dict of the
        hyperparameters used.
    """
    initalize(seed)
    env = make_env(seed)

    # Close the environment even if training raises or is interrupted.
    try:
        state_size = 19  # length of the one-hot state encoding (5 + 5 + 5 + 4)
        action_size = env.action_space.n

        model = DQN(state_size, action_size).to(device)
        target_net = DQN(state_size, action_size).to(device)
        target_net.load_state_dict(model.state_dict())
        target_net.eval()

        value_model = QuadraticValue(state_size).to(device)
        value_model.eval()

        dqn_optimizer = optim.Adam(model.parameters(), lr=dqn_lr)
        memory = ReplayBuffer(memory_size)
        rewards = []
        feature_value_data = []
        current_epsilon = epsilon

        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        save_path = "results_quadratic"
        folder_name = f"{timestamp}_seed{seed}"
        save_dir = os.path.join(save_path, folder_name)
        os.makedirs(save_dir, exist_ok=True)

        for cycle in range(cycles):
            print(f"\n🔁 Bi-level cycle {cycle + 1}/{cycles}")

            # Lower level: improve the Q-network under the current shaping.
            model.train()
            current_epsilon = train_q_with_fixed_quadrature(
                model, target_net, value_model, q_train_episodes, alpha, gamma,
                current_epsilon, epsilon_decay, min_epsilon, dqn_optimizer, memory,
                batch_size, env, rewards)
            model.eval()
            target_net.load_state_dict(model.state_dict())

            # Label replayed states with the greedy Q-value of the current DQN.
            collected_data = []
            sample_size = min(len(memory), data_collection_sample_size)
            if sample_size > 0:
                s_batch, _, _, _, _ = zip(*random.sample(memory.buffer, sample_size))
                s_array = np.array([preprocess_state_with_walls(s, env) for s in s_batch], dtype=np.float32)
                s_tensor = torch.from_numpy(s_array).to(device)

                with torch.no_grad():
                    v_estimates = model(s_tensor).max(1)[0].cpu().numpy()

                for feature, value in zip(s_array, v_estimates):
                    collected_data.append((feature, value))

            # Keep only the most recent 10000 labelled samples.
            feature_value_data.extend(collected_data)
            if len(feature_value_data) > 10000:
                feature_value_data = feature_value_data[-10000:]

            # Upper level: refit the quadratic model to the collected labels.
            print(f"  [Cycle {cycle+1}] Collected {len(collected_data)} data points. Total: {len(feature_value_data)}")
            train_quadratic_value_supervised(feature_value_data, value_model,
                                          epochs=self_attention_epochs, lr=self_attention_lr)

            if (cycle + 1) in [5, 10, 15, cycles]:
                # Tag snapshots with the number of DQN episodes completed so far.
                episode_tag = (cycle + 1) * q_train_episodes
                save_quadratic_state(value_model, save_dir, f"{episode_tag}")
    finally:
        env.close()

    settings = dict(cycles=cycles, DQN_episodes_per_cycle=q_train_episodes,
                    self_attention_epochs=self_attention_epochs, alpha=alpha, gamma=gamma,
                    epsilon=epsilon, epsilon_decay=epsilon_decay, min_epsilon=min_epsilon,
                    dqn_lr=dqn_lr, self_attention_lr=self_attention_lr, batch_size=batch_size,
                    memory_size=memory_size, seed=seed, data_collection_sample_size=data_collection_sample_size)

    torch.save(model.state_dict(), os.path.join(save_dir, "dqn_model.pth"))
    torch.save(value_model.state_dict(), os.path.join(save_dir, "quadratic_value_model.pth"))

    np.savetxt(os.path.join(save_dir, "quadratic_W.csv"), value_model.W.data.cpu().numpy(), delimiter=",")
    pd.DataFrame({'episode': list(range(len(rewards))), 'reward': rewards}).to_csv(
        os.path.join(save_dir, "rewards.csv"), index=False)

    plt.figure(figsize=(10, 5))
    plt.plot(rewards, label="Reward")
    plt.plot(moving_average(rewards), label="Moving Avg (50)")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.title(f"Bi-level VRAIL (seed={seed})")
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, "reward_plot.png"))
    plt.close()

    return rewards, model, value_model, settings


def moving_average(data, window=50):
    """Return the simple moving average of ``data`` over full windows only.

    Args:
        data: One-dimensional sequence of numbers.
        window: Number of consecutive samples averaged per output point.

    Returns:
        A float array of length ``len(data) - window + 1``. It is empty when
        ``data`` holds fewer than ``window`` samples.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be a positive integer")

    values = np.asarray(data, dtype=float)
    # With fewer samples than the window, np.convolve(mode='valid') would swap
    # its operands and return values that are not averages of the data.
    if values.size < window:
        return np.empty(0, dtype=float)

    return np.convolve(values, np.ones(window) / window, mode='valid')


In [None]:
# Run the full bi-level training once per seed (0-9); the names below keep the
# results of the most recent run.
for s in range(10):
    (trained_rewards,
     trained_dqn_model,
     trained_value_model,
     training_settings) = train_bilevel_vrail_quadratic(seed=s, cycles=20)


🔁 Bi-level cycle 1/20
  [Q] Episode    0 | Reward: -875 | Epsilon: 0.995
  [Q] Episode   10 | Reward: -650 | Epsilon: 0.946
  [Q] Episode   20 | Reward: -677 | Epsilon: 0.900
  [Q] Episode   30 | Reward: -605 | Epsilon: 0.856
  [Q] Episode   40 | Reward: -731 | Epsilon: 0.814
  [Q] Episode   50 | Reward: -704 | Epsilon: 0.774
  [Q] Episode   60 | Reward: -704 | Epsilon: 0.737
  [Q] Episode   70 | Reward: -506 | Epsilon: 0.701
  [Q] Episode   80 | Reward: -560 | Epsilon: 0.666
  [Q] Episode   90 | Reward: -524 | Epsilon: 0.634
  [Q] Episode   99 | Reward: -470 | Epsilon: 0.606
  [Cycle 1] Collected 2000 data points. Total: 2000
  [Quadratic] Trained for 50 epochs. Final Loss: 3.1450

🔁 Bi-level cycle 2/20
  [Q] Episode    0 | Reward: -218 | Epsilon: 0.603
  [Q] Episode   10 | Reward: -524 | Epsilon: 0.573
  [Q] Episode   20 | Reward: -515 | Epsilon: 0.545
  [Q] Episode   30 | Reward: -335 | Epsilon: 0.519
  [Q] Episode   40 | Reward: -443 | Epsilon: 0.493
  [Q] Episode   50 | Reward: -