In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
from tmrl import get_environment
import os
import time

# Pick the compute device once; every tensor and model in this notebook is moved onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)


class SimpleDQN(nn.Module):
    """Q-network: three ReLU conv layers over a 4-channel image stack, then two linear layers.

    The flattened conv features are concatenated with a 9-element auxiliary
    vector before ``fc1``; ``fc2`` emits one value per output (action).

    Args:
        h, w: spatial size of the input images (used to size ``fc1``).
        outputs: number of output units.
    """

    def __init__(self, h, w, outputs):
        super().__init__()
        # Attribute names double as state_dict keys in saved checkpoints -- keep them stable.
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        flat_features = self._get_conv_out(h, w)

        self.fc1 = nn.Linear(flat_features + 9, 512)
        self.fc2 = nn.Linear(512, outputs)

    def _encode(self, images):
        """Run the three conv+ReLU stages."""
        for conv in (self.conv1, self.conv2, self.conv3):
            images = F.relu(conv(images))
        return images

    def _get_conv_out(self, h, w):
        """Return the number of conv features produced for one h x w input."""
        probe = self._encode(torch.zeros(1, 4, h, w))
        return int(np.prod(probe.size()))

    def forward(self, x, additional_inputs):
        """Map (images, auxiliary vector) batches to per-output Q-values."""
        flat = self._encode(x).view(x.size(0), -1)
        joined = torch.cat((flat, additional_inputs), dim=1)
        return self.fc2(F.relu(self.fc1(joined)))

class SimpleDQNAgent:
    """Epsilon-greedy DQN agent with a replay buffer and a target network.

    States are indexed as tuples: element 3 is the image stack fed to the conv
    layers, and elements 0, 1, 2, 4, 5 are concatenated into the auxiliary
    vector (``preprocess_observation`` builds tuples in this layout).
    """

    def __init__(self, n_actions, memory_size=100000, batch_size=32, gamma=0.99, epsilon_start=1.0, epsilon_final=0.01, epsilon_decay=0.995, learning_rate=0.001):
        self.n_actions = n_actions
        # Replay buffer; the oldest transitions are evicted once memory_size is reached.
        self.memory = deque(maxlen=memory_size)
        self.batch_size = batch_size
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_final = epsilon_final
        self.epsilon_decay = epsilon_decay
        # Step counter; incremented by the training loop and stored in checkpoints.
        self.steps = 0

        # Both networks are built for 64x64 image inputs.
        self.policy_net = SimpleDQN(64, 64, n_actions).to(device)
        self.target_net = SimpleDQN(64, 64, n_actions).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

    @staticmethod
    def _aux_features(state):
        """Concatenate the non-image state components (indices 0, 1, 2, 4, 5)."""
        return np.concatenate([state[0], state[1], state[2], state[4], state[5]])

    def select_action(self, state):
        """Return a (1, 1) long tensor holding an epsilon-greedy action index."""
        if random.random() > self.epsilon:
            with torch.no_grad():
                image = torch.FloatTensor(state[3]).unsqueeze(0).to(device)
                additional = torch.FloatTensor(self._aux_features(state)).unsqueeze(0).to(device)
                return self.policy_net(image, additional).max(1)[1].view(1, 1)
        return torch.tensor([[random.randrange(self.n_actions)]], device=device, dtype=torch.long)

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        """Take one gradient step on a random minibatch and decay epsilon.

        Returns:
            The minibatch loss as a float, or None when the buffer holds fewer
            than ``batch_size`` transitions (no update is made).
        """
        if len(self.memory) < self.batch_size:
            return None

        batch = random.sample(self.memory, self.batch_size)
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*batch)

        state_image_batch = torch.FloatTensor(np.stack([s[3] for s in state_batch])).to(device)
        state_additional_batch = torch.FloatTensor(np.stack([self._aux_features(s) for s in state_batch])).to(device)
        next_state_image_batch = torch.FloatTensor(np.stack([s[3] for s in next_state_batch])).to(device)
        next_state_additional_batch = torch.FloatTensor(np.stack([self._aux_features(s) for s in next_state_batch])).to(device)

        action_batch = torch.LongTensor(action_batch).to(device)
        reward_batch = torch.FloatTensor(reward_batch).to(device)
        done_batch = torch.FloatTensor(done_batch).to(device)

        q_values = self.policy_net(state_image_batch, state_additional_batch).gather(1, action_batch.unsqueeze(1))
        # The TD target is a constant; no_grad avoids building an unused graph through target_net.
        with torch.no_grad():
            next_q_values = self.target_net(next_state_image_batch, next_state_additional_batch).max(1)[0]
        expected_q_values = reward_batch + (1 - done_batch) * self.gamma * next_q_values

        loss = F.smooth_l1_loss(q_values, expected_q_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        # Element-wise gradient clamping to [-1, 1] for stability.
        for param in self.policy_net.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        self.epsilon = max(self.epsilon_final, self.epsilon * self.epsilon_decay)
        return loss.item()

    def update_target_network(self):
        """Copy the policy network weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save_checkpoint(self, episode, directory="checkpoints_simple_left_turn"):
        """Write policy/optimizer state, epsilon and step count to ``directory``."""
        os.makedirs(directory, exist_ok=True)
        checkpoint = {
            'episode': episode,
            'model_state_dict': self.policy_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'steps': self.steps
        }
        filename = os.path.join(directory, f"checkpoint_episode_{episode}_step_{self.steps}.pth")
        torch.save(checkpoint, filename)
        print(f"Checkpoint saved: {filename}")

    def load_checkpoint(self, filename):
        """Restore a checkpoint written by ``save_checkpoint``; return its episode number.

        ``map_location`` remaps stored tensors onto ``device`` so checkpoints
        saved on a GPU machine also load on CPU-only hosts.
        """
        checkpoint = torch.load(filename, map_location=device)
        self.policy_net.load_state_dict(checkpoint['model_state_dict'])
        self.target_net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epsilon = checkpoint['epsilon']
        self.steps = checkpoint['steps']
        return checkpoint['episode']
def preprocess_observation(obs):
    """Normalize a raw observation into the 6-tuple state used by the agents.

    Components 0, 1, 2, 4, 5 are flattened to 1-D arrays and clipped to
    [-1, 1]; component 0 is first divided by 300 and components 1 and 2 by pi.
    Component 3 (the image) is cast to float32 and scaled by 1/255.

    NOTE(review): the names below (steering, gyro) are the author's labels;
    confirm they match the environment's actual observation layout.

    Returns:
        (speed, steering, gyro, image, prev_action, action)
    """
    def _flat(component):
        return np.array(component).flatten()

    image = np.array(obs[3]).astype(np.float32) / 255.0
    speed = np.clip(_flat(obs[0]) / 300.0, -1, 1)
    steering = np.clip(_flat(obs[1]) / np.pi, -1, 1)
    gyro = np.clip(_flat(obs[2]) / np.pi, -1, 1)
    prev_action, action = (np.clip(_flat(obs[idx]), -1, 1) for idx in (4, 5))

    return (speed, steering, gyro, image, prev_action, action)

def env_action_to_agent_action(env_action, n_actions):
    """Convert an environment-side action into a discrete action index.

    - Integers are assumed to already be indices and are returned unchanged.
    - Floats are treated as a steering value in [-1, 1] and mapped to the
      nearest of ``n_actions`` evenly spaced buckets (the inverse of
      ``agent_action_to_env_action``). Out-of-range values are clamped to a
      valid index.
    - Arrays are treated as per-action scores; the argmax index is returned.

    Raises:
        ValueError: for any other input type.
    """
    if isinstance(env_action, (int, np.integer)):
        return env_action
    if isinstance(env_action, (float, np.floating)):
        # Round (not truncate) so e.g. 0.9 with 3 actions maps to index 2, not 1.
        index = int(round((float(env_action) + 1) * (n_actions - 1) / 2))
        return min(max(index, 0), n_actions - 1)
    if isinstance(env_action, np.ndarray):
        return int(np.argmax(env_action))
    raise ValueError(f"Unexpected action type: {type(env_action)}")

def agent_action_to_env_action(agent_action, n_actions, throttle=1.0, brake=0.0):
    """Map a discrete action index to a ``[throttle, brake, steering]`` command.

    Indices ``0 .. n_actions - 1`` are spread evenly over steering in [-1, 1].
    With a single action there is nothing to spread, so steering is 0.0
    (instead of dividing by zero).

    Args:
        agent_action: integer action index.
        n_actions: size of the discrete action space.
        throttle: constant throttle value emitted with every action.
        brake: constant brake value emitted with every action.
    """
    if n_actions <= 1:
        steering = 0.0
    else:
        steering = (agent_action * 2 / (n_actions - 1)) - 1
    return [throttle, brake, steering]

def detect_crash(obs, prev_obs, speed_threshold=1.0):
    """Heuristic crash check: has the speed (norm of component 0) dropped sharply?

    Returns True only when ``prev_obs`` is given and the speed fell by more
    than ``speed_threshold`` between ``prev_obs`` and ``obs``; an increase in
    speed never counts.
    """
    if prev_obs is None:
        return False
    speed_drop = np.linalg.norm(prev_obs[0]) - np.linalg.norm(obs[0])
    return bool(speed_drop > speed_threshold)

def train(env, agent, num_episodes, max_steps_per_episode, crash_penalty=10.0, checkpoint_every=1000, target_update_every=10):
    """Train a single ``SimpleDQNAgent`` online in ``env``.

    Every environment step stores the transition and runs one replay update.
    A speed-drop crash (``detect_crash``) subtracts ``crash_penalty`` from the
    step reward.

    Args:
        env: environment with ``reset()`` and a 5-tuple ``step(action)``.
        agent: the agent to train (its ``steps`` counter is advanced here).
        num_episodes: number of episodes to run.
        max_steps_per_episode: step cap per episode.
        crash_penalty: amount subtracted from the reward on a detected crash.
        checkpoint_every: save a checkpoint whenever ``agent.steps`` is a multiple of this.
        target_update_every: sync the target network every this many episodes.
    """
    for episode in range(num_episodes):
        obs, _ = env.reset()
        state = preprocess_observation(obs)
        total_reward = 0
        episode_start_time = time.time()
        prev_obs = None
        crashes = 0

        for step in range(max_steps_per_episode):
            agent_action = agent.select_action(state)
            action_idx = agent_action.item()
            env_action = agent_action_to_env_action(action_idx, agent.n_actions)
            next_obs, env_reward, terminated, truncated, _ = env.step(env_action)
            next_state = preprocess_observation(next_obs)

            if detect_crash(next_obs, prev_obs):
                crashes += 1
                env_reward -= crash_penalty
            done = terminated or truncated
            reward = env_reward

            agent.remember(state, action_idx, reward, next_state, done)
            agent.replay()

            state = next_state
            prev_obs = next_obs
            total_reward += reward

            agent.steps += 1
            if agent.steps % checkpoint_every == 0:
                agent.save_checkpoint(episode)

            if done:
                break

        # Measured after the loop so it is defined even if no step was taken.
        elapsed_time = time.time() - episode_start_time

        if episode % target_update_every == 0:
            agent.update_target_network()

        print(f"Episode {episode}, Total Reward: {total_reward:.2f}, Crashes: {crashes}, Epsilon: {agent.epsilon:.4f}, Time: {elapsed_time:.2f}s")

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
from tmrl import get_environment
import os
import time
import wandb
from collections import deque, defaultdict
# Pick the compute device once; every tensor and model below is moved onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# SimpleDQN and SimpleDQNAgent are defined earlier in this cell and are reused unchanged below.

class Router(nn.Module):
    """Q-network over expert choices: conv stack + auxiliary vector -> one value per expert.

    Same layer layout as ``SimpleDQN``, so the attribute names (and hence
    state_dict keys) are conv1..conv3, fc1, fc2.

    Args:
        num_experts: number of outputs (one per expert).
        h, w: input image size used to size ``fc1`` (default 64x64).
        additional_dim: length of the auxiliary vector concatenated before ``fc1``.
    """

    def __init__(self, num_experts, h=64, w=64, additional_dim=9):
        super(Router, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Size of the flattened conv features for an h x w input.
        self.conv_out_size = self._get_conv_out(h, w)

        self.fc1 = nn.Linear(self.conv_out_size + additional_dim, 512)
        self.fc2 = nn.Linear(512, num_experts)

    def _get_conv_out(self, h, w):
        """Return the flattened conv feature count for one h x w input."""
        with torch.no_grad():
            o = F.relu(self.conv1(torch.zeros(1, 4, h, w)))
            o = F.relu(self.conv2(o))
            o = F.relu(self.conv3(o))
        return int(np.prod(o.size()))

    def forward(self, x, additional_inputs):
        """Return per-expert values for a batch of (images, auxiliary vectors)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = torch.cat((x, additional_inputs), dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class MixtureOfExpertsDQN:
    """Router trained with DQN to choose which pre-trained expert agent acts.

    The router is a Q-network whose "actions" are expert indices. At each
    step, the chosen expert's own ``select_action`` produces the driving
    action. The router learns from ``(state, expert_idx, reward, next_state,
    done)`` transitions. Only router parameters are optimized here; the
    experts are not updated.
    """

    def __init__(self, expert_agents, learning_rate=0.001, gamma=0.99, epsilon_start=1.0, epsilon_final=0.01, epsilon_decay=0.995):
        self.expert_agents = expert_agents
        self.num_experts = len(expert_agents)
        self.router = Router(self.num_experts).to(device)
        self.target_router = Router(self.num_experts).to(device)
        self.target_router.load_state_dict(self.router.state_dict())
        self.optimizer = optim.Adam(self.router.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_final = epsilon_final
        self.epsilon_decay = epsilon_decay
        # Router replay buffer; the oldest transitions are evicted once full.
        self.memory = deque(maxlen=100000)

    @staticmethod
    def _aux_features(state):
        """Concatenate the non-image state components (indices 0, 1, 2, 4, 5)."""
        return np.concatenate([state[0], state[1], state[2], state[4], state[5]])

    def _to_batch(self, states):
        """Stack states into (image, auxiliary) float tensors on ``device``."""
        images = torch.FloatTensor(np.array([s[3] for s in states])).to(device)
        aux = torch.FloatTensor(np.array([self._aux_features(s) for s in states])).to(device)
        return images, aux

    def select_action(self, state):
        """Choose an expert epsilon-greedily and return ``(expert_action, expert_idx)``."""
        if random.random() > self.epsilon:
            with torch.no_grad():
                image = torch.FloatTensor(state[3]).unsqueeze(0).to(device)
                additional_inputs = torch.FloatTensor(self._aux_features(state)).unsqueeze(0).to(device)
                # Router outputs are per-expert Q-values (trained by TD in train_step).
                expert_q_values = self.router(image, additional_inputs)
                expert_idx = expert_q_values.argmax().item()
        else:
            expert_idx = random.randint(0, self.num_experts - 1)

        selected_expert = self.expert_agents[expert_idx]
        return selected_expert.select_action(state), expert_idx

    def remember(self, state, expert_idx, reward, next_state, done):
        """Append one router transition to the replay buffer."""
        self.memory.append((state, expert_idx, reward, next_state, done))

    def train_step(self, batch_size):
        """Take one TD update of the router on a random minibatch and decay epsilon.

        Returns:
            The loss as a float, or None if the buffer holds fewer than
            ``batch_size`` transitions.
        """
        if len(self.memory) < batch_size:
            return None

        batch = random.sample(self.memory, batch_size)
        states, expert_indices, rewards, next_states, dones = zip(*batch)

        states_image, states_additional = self._to_batch(states)
        next_states_image, next_states_additional = self._to_batch(next_states)
        expert_indices = torch.LongTensor(expert_indices).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        dones = torch.FloatTensor(dones).to(device)

        current_q_values = self.router(states_image, states_additional).gather(1, expert_indices.unsqueeze(1))
        # The TD target is a constant; no_grad skips building a graph through target_router.
        with torch.no_grad():
            next_q_values = self.target_router(next_states_image, next_states_additional).max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        loss = F.smooth_l1_loss(current_q_values, target_q_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        # Element-wise clamp to [-1, 1], matching SimpleDQNAgent.replay, to damp TD-error spikes.
        for param in self.router.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        self.epsilon = max(self.epsilon_final, self.epsilon * self.epsilon_decay)

        return loss.item()

    def update_target_network(self):
        """Copy the online router weights into the target router."""
        self.target_router.load_state_dict(self.router.state_dict())

    def save_model(self, path):
        """Save router, target router, optimizer state and epsilon to ``path/moe_model.pth``."""
        os.makedirs(path, exist_ok=True)
        torch.save({
            'router_state_dict': self.router.state_dict(),
            'target_router_state_dict': self.target_router.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
        }, os.path.join(path, 'moe_model.pth'))

    def load_model(self, path):
        """Restore state written by ``save_model``; tensors are remapped onto ``device``."""
        checkpoint = torch.load(os.path.join(path, 'moe_model.pth'), map_location=device)
        self.router.load_state_dict(checkpoint['router_state_dict'])
        self.target_router.load_state_dict(checkpoint['target_router_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epsilon = checkpoint['epsilon']
def train_mixture_of_experts(env, moe_agent, num_episodes, max_steps_per_episode, batch_size, crash_penalty=0.2, save_every=100):
    """Train the MoE router online, with one router update attempt per environment step.

    Args:
        env: environment with ``reset()`` and a 5-tuple ``step(action)``.
        moe_agent: agent exposing ``select_action``, ``remember``,
            ``train_step(batch_size)``, ``update_target_network`` and ``save_model``.
        num_episodes: number of episodes to run.
        max_steps_per_episode: step cap per episode.
        batch_size: router replay minibatch size.
        crash_penalty: amount subtracted from the step reward when ``detect_crash`` fires.
        save_every: save the router every this many episodes.
    """
    total_crashes = 0
    total_time = 0
    total_loss = 0
    expert_usage = [0] * moe_agent.num_experts
    n_actions = moe_agent.expert_agents[0].n_actions

    for episode in range(num_episodes):
        obs, _ = env.reset()
        state = preprocess_observation(obs)
        total_reward = 0
        episode_crashes = 0
        episode_loss = 0
        episode_start_time = time.time()
        prev_obs = None
        episode_expert_usage = [0] * moe_agent.num_experts
        crash_reward = 0
        steps_taken = 0

        for step in range(max_steps_per_episode):
            agent_action, expert_idx = moe_agent.select_action(state)
            expert_usage[expert_idx] += 1
            episode_expert_usage[expert_idx] += 1
            env_action = agent_action_to_env_action(agent_action.item(), n_actions)
            next_obs, env_reward, terminated, truncated, _ = env.step(env_action)
            next_state = preprocess_observation(next_obs)
            done = terminated or truncated

            if detect_crash(next_obs, prev_obs):
                episode_crashes += 1
                env_reward -= crash_penalty
                crash_reward -= crash_penalty

            moe_agent.remember(state, expert_idx, env_reward, next_state, done)

            loss = moe_agent.train_step(batch_size)
            if loss is not None:
                episode_loss += loss

            state = next_state
            prev_obs = next_obs
            total_reward += env_reward
            steps_taken = step + 1

            if done:
                break

        episode_time = time.time() - episode_start_time
        total_time += episode_time
        total_crashes += episode_crashes
        total_loss += episode_loss

        if episode % 10 == 0:
            moe_agent.update_target_network()

        if (episode + 1) % save_every == 0:
            save_path = f"router_model_episode_{episode + 1}"
            moe_agent.save_model(save_path)
            print(f"Router model saved at episode {episode + 1}")

        print(f"Episode {episode}, Total Reward: {total_reward:.2f}, Loss: {episode_loss:.4f}, Crashes: {episode_crashes}, Crash Reward: {crash_reward:.2f}, Time: {episode_time:.2f}s")

        # max(..., 1) keeps the ratios/average defined for a zero-step episode.
        denom = max(steps_taken, 1)
        metrics = {
            "episode": episode,
            "total_reward": total_reward,
            "loss": episode_loss,
            "crashes": episode_crashes,
            "crash_reward": crash_reward,
            "time": episode_time,
            "average_reward": total_reward / denom,
            "total_crashes": total_crashes,
            "total_time": total_time,
            "episode_duration": episode_time,
            "experts_used": sum(1 for usage in episode_expert_usage if usage > 0),
        }
        for i, usage in enumerate(episode_expert_usage):
            metrics[f"expert_{i}_usage"] = usage
            metrics[f"expert_{i}_usage_ratio"] = usage / denom
            metrics[f"expert_{i}_total_usage"] = expert_usage[i]
        # One wandb.log call per episode: each call advances wandb's step, so
        # separate calls would spread a single episode over several steps.
        wandb.log(metrics)

    average_loss = total_loss / num_episodes if num_episodes else 0.0
    print(f"Training completed. Total crashes: {total_crashes}, Total time: {total_time:.2f}s, Average loss: {average_loss:.4f}")

def _evaluate_episode(env, moe_agent, max_steps_per_episode, expert_usage):
    """Roll out one episode without learning.

    Also increments the cumulative ``expert_usage`` counts in place.

    Returns:
        (total_reward, crashes, steps_taken, episode_time, per-expert usage list)
    """
    n_actions = moe_agent.expert_agents[0].n_actions
    obs, _ = env.reset()
    state = preprocess_observation(obs)
    total_reward = 0
    episode_crashes = 0
    steps_taken = 0
    episode_start_time = time.time()
    prev_obs = None
    episode_expert_usage = [0] * moe_agent.num_experts

    for step in range(max_steps_per_episode):
        agent_action, expert_idx = moe_agent.select_action(state)
        expert_usage[expert_idx] += 1
        episode_expert_usage[expert_idx] += 1
        env_action = agent_action_to_env_action(agent_action.item(), n_actions)
        next_obs, reward, terminated, truncated, _ = env.step(env_action)
        next_state = preprocess_observation(next_obs)
        done = terminated or truncated

        if detect_crash(next_obs, prev_obs):
            episode_crashes += 1

        state = next_state
        prev_obs = next_obs
        total_reward += reward
        steps_taken = step + 1

        if done:
            break

    episode_time = time.time() - episode_start_time
    return total_reward, episode_crashes, steps_taken, episode_time, episode_expert_usage


def test_mixture_of_experts(env, moe_agent, num_episodes, max_steps_per_episode, greedy=False):
    """Evaluate ``moe_agent`` for ``num_episodes`` episodes and log results to wandb.

    No transitions are stored and no training steps are taken.

    Args:
        greedy: when True, temporarily set ``epsilon`` to 0 on the MoE agent
            and on every expert that has one, so that evaluation does not
            explore. The original values are restored afterwards, even on error.
    """
    total_rewards = []
    total_crashes = 0
    total_time = 0
    expert_usage = [0] * moe_agent.num_experts

    saved_epsilons = []
    if greedy:
        for holder in [moe_agent, *moe_agent.expert_agents]:
            if hasattr(holder, "epsilon"):
                saved_epsilons.append((holder, holder.epsilon))
                holder.epsilon = 0.0

    try:
        for episode in range(num_episodes):
            total_reward, episode_crashes, steps_taken, episode_time, episode_expert_usage = _evaluate_episode(
                env, moe_agent, max_steps_per_episode, expert_usage)

            total_time += episode_time
            total_crashes += episode_crashes
            total_rewards.append(total_reward)

            # max(..., 1) keeps the ratios/average defined for a zero-step episode.
            denom = max(steps_taken, 1)
            metrics = {
                "test_episode": episode,
                "test_total_reward": total_reward,
                "test_crashes": episode_crashes,
                "test_time": episode_time,
                "test_average_reward": total_reward / denom,
                "test_total_crashes": total_crashes,
                "test_total_time": total_time,
                "test_episode_duration": episode_time,
                "test_experts_used": sum(1 for usage in episode_expert_usage if usage > 0),
            }
            for i, usage in enumerate(episode_expert_usage):
                metrics[f"test_expert_{i}_usage"] = usage
                metrics[f"test_expert_{i}_usage_ratio"] = usage / denom
            # Single call so all per-episode metrics share one wandb step.
            wandb.log(metrics)

            print(f"Test Episode {episode}, Total Reward: {total_reward:.2f}, Crashes: {episode_crashes}, Time: {episode_time:.2f}s")
    finally:
        for holder, epsilon in saved_epsilons:
            holder.epsilon = epsilon

    avg_reward = sum(total_rewards) / len(total_rewards) if total_rewards else 0.0
    print(f"Average Reward over {num_episodes} episodes: {avg_reward:.2f}")
    final_metrics = {"test_average_reward_overall": avg_reward}
    for i, usage in enumerate(expert_usage):
        final_metrics[f"test_expert_{i}_total_usage"] = usage
    wandb.log(final_metrics)

# test_mixture_of_experts (above) runs evaluation episodes only; it never stores transitions or updates weights.

import warnings
warnings.filterwarnings("ignore")

if __name__ == "__main__":
    env = get_environment()
    n_actions = 3

    # Initialize wandb. The try/finally below closes the run even if training
    # or evaluation raises midway.
    wandb.init(project="moe-router-scratch")
    try:
        # Load the three pre-trained expert agents.
        expert1 = SimpleDQNAgent(n_actions)
        expert1.load_checkpoint("checkpoints_simple/checkpoint_episode_177_step_30000.pth")
        expert2 = SimpleDQNAgent(n_actions)
        expert2.load_checkpoint("checkpoints_simple_left_turn/checkpoint_episode_179_step_78000.pth")
        expert3 = SimpleDQNAgent(n_actions)
        expert3.load_checkpoint("checkpoints_simple_right_turn/checkpoint_episode_179_step_34000.pth")

        moe_agent = MixtureOfExpertsDQN([expert1, expert2, expert3])

        # Train the Mixture-of-Experts router.
        train_mixture_of_experts(env, moe_agent, num_episodes=4000, max_steps_per_episode=1000000, batch_size=32)

        # Save the final model before testing so a failure during evaluation does not lose it.
        moe_agent.save_model("router_model_final")

        # Test the Mixture of Experts.
        test_mixture_of_experts(env, moe_agent, num_episodes=100, max_steps_per_episode=5000)
    finally:
        # Close the wandb run.
        wandb.finish()


cuda
cuda
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33madamyvashisth[0m ([33madamyvashisth-indian-institute-of-technology-roorkee[0m). Use [1m`wandb login --relogin`[0m to force relogin


Episode 0, Total Reward: -0.38, Loss: 0.0371, Crashes: 4, Crash Reward: -0.80, Time: 5.04s
Episode 1, Total Reward: 77.16, Loss: 0.1430, Crashes: 25, Crash Reward: -5.00, Time: 51.23s
Episode 2, Total Reward: 2.58, Loss: 0.0099, Crashes: 6, Crash Reward: -1.20, Time: 4.84s
Episode 3, Total Reward: 19.92, Loss: 0.0530, Crashes: 10, Crash Reward: -2.00, Time: 20.97s
Episode 4, Total Reward: 578.00, Loss: 0.1097, Crashes: 33, Crash Reward: -6.60, Time: 44.33s
Episode 5, Total Reward: 8.14, Loss: 46.8231, Crashes: 12, Crash Reward: -2.40, Time: 6.54s
Episode 6, Total Reward: 0.62, Loss: 0.0438, Crashes: 5, Crash Reward: -1.00, Time: 4.57s
Episode 7, Total Reward: -0.80, Loss: 0.0296, Crashes: 4, Crash Reward: -0.80, Time: 4.43s
Episode 8, Total Reward: -1.00, Loss: 0.0205, Crashes: 5, Crash Reward: -1.00, Time: 4.32s
Episode 9, Total Reward: 77.60, Loss: 109.3447, Crashes: 35, Crash Reward: -7.00, Time: 44.51s
Episode 10, Total Reward: 1.57, Loss: 15.6411, Crashes: 7, Crash Reward: -1.40, 

Exception in thread Thread-462053 (__send_act_get_obs_and_wait):
Traceback (most recent call last):
  File "d:\mini_conda\envs\collabkart\Lib\threading.py", line 1038, in _bootstrap_inner
    self.run()
  File "d:\mini_conda\envs\collabkart\Lib\site-packages\ipykernel\ipkernel.py", line 761, in run_closure
    _threading_Thread_run(self)
  File "d:\mini_conda\envs\collabkart\Lib\threading.py", line 975, in run
    self._target(*self._args, **self._kwargs)
  File "d:\mini_conda\envs\collabkart\Lib\site-packages\rtgym\envs\real_time_env.py", line 438, in __send_act_get_obs_and_wait
    self.__update_obs_rew_terminated_truncated()  # capture observation
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "d:\mini_conda\envs\collabkart\Lib\site-packages\rtgym\envs\real_time_env.py", line 452, in __update_obs_rew_terminated_truncated
    o, r, d, i = self.interface.get_obs_rew_terminated_info()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "d:\mini_conda\envs\coll

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
from tmrl import get_environment
import os
import time
import wandb
from collections import deque, defaultdict
# Pick the compute device once; every tensor and model in this cell is moved onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

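# NOTE: this cell does not define SimpleDQN, SimpleDQNAgent, preprocess_observation, agent_action_to_env_action or detect_crash; it reuses the definitions from the previous cell, which must be run first.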

class Router(nn.Module):
    """Softmax gating network: maps a 4-channel image batch to per-expert weights.

    Only the image is used; the auxiliary state vector is not an input.
    Attribute names (conv1..conv3, fc) are the state_dict keys used by saved models.

    Args:
        num_experts: number of output weights.
        h, w: input image size used to size ``fc`` (default 64x64, for which the
            flattened feature size is 32 * 7 * 7).
    """

    def __init__(self, num_experts, h=64, w=64):
        super(Router, self).__init__()
        self.conv1 = nn.Conv2d(4, 16, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=2)
        # Infer the flattened size instead of hard-coding it, so other input sizes work.
        with torch.no_grad():
            n_flat = self._conv_features(torch.zeros(1, 4, h, w)).size(1)
        self.fc = nn.Linear(n_flat, num_experts)

    def _conv_features(self, x):
        """Apply the conv+ReLU stack and flatten to (batch, features)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        return x.view(x.size(0), -1)

    def forward(self, x):
        """Return a (batch, num_experts) tensor of softmax weights."""
        return F.softmax(self.fc(self._conv_features(x)), dim=1)

class MixtureOfExpertsDQN:
    """Soft mixture of DQN experts gated by a softmax router.

    Acting: the expert with the highest router weight chooses the action.
    Learning is split into two methods that the caller alternates between:

    - ``train_router`` makes a TD update of the router only, through the
      router-weighted sum of expert Q-values.
    - ``train_experts`` runs replay updates on each expert using the
      transitions it generated.
    """

    def __init__(self, expert_agents, learning_rate=0.001, gamma=0.99):
        self.expert_agents = expert_agents
        self.num_experts = len(expert_agents)
        self.router = Router(self.num_experts).to(device)
        self.optimizer = optim.Adam(self.router.parameters(), lr=learning_rate)
        self.gamma = gamma

    @staticmethod
    def _aux_features(state):
        """Concatenate the non-image state components (indices 0, 1, 2, 4, 5)."""
        return np.concatenate([state[0], state[1], state[2], state[4], state[5]])

    def _to_batch(self, states):
        """Stack states into (image, auxiliary) float tensors on ``device``."""
        images = torch.FloatTensor(np.array([s[3] for s in states])).to(device)
        aux = torch.FloatTensor(np.array([self._aux_features(s) for s in states])).to(device)
        return images, aux

    def select_action(self, state):
        """Let the highest-weight expert act; return ``(expert_action, expert_idx)``."""
        with torch.no_grad():
            image = torch.FloatTensor(state[3]).unsqueeze(0).to(device)
            expert_probs = self.router(image)

            expert_idx = expert_probs.argmax().item()
            selected_expert = self.expert_agents[expert_idx]

            return selected_expert.select_action(state), expert_idx

    def train_step(self, experiences, update_router=True):
        """Dispatch to ``train_router`` or ``train_experts`` and return its loss (or None)."""
        if update_router:
            return self.train_router(experiences)
        return self.train_experts(experiences)

    def train_router(self, experiences):
        """One TD update of the router on a list of (state, action, reward, next_state, done).

        Returns:
            The loss as a float, or None for an empty list.
        """
        if not experiences:
            return None

        states, actions, rewards, next_states, dones = zip(*experiences)

        # NOTE(review): the whole list is one batch; very long episodes may exhaust device memory.
        states_image, states_additional = self._to_batch(states)
        next_states_image, next_states_additional = self._to_batch(next_states)
        actions = torch.LongTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        dones = torch.FloatTensor(dones).to(device)

        with torch.no_grad():
            # Expert Q-values are constants for the router update:
            # shape (batch, num_experts, n_actions).
            expert_q = torch.stack(
                [expert.policy_net(states_image, states_additional) for expert in self.expert_agents], dim=1)
            next_expert_q = torch.stack(
                [expert.target_net(next_states_image, next_states_additional) for expert in self.expert_agents], dim=1)
            # Weight next-state Q-values with the router's weights for the *next* states,
            # and keep the TD target out of the autograd graph.
            next_probs = self.router(next_states_image)
            next_q_values = (next_probs.unsqueeze(2) * next_expert_q).sum(dim=1)
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values.max(1)[0]

        expert_probs = self.router(states_image)
        current_q_values = (expert_probs.unsqueeze(2) * expert_q).sum(dim=1)

        loss = F.smooth_l1_loss(current_q_values.gather(1, actions.unsqueeze(1)), target_q_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train_experts(self, experiences_dict):
        """Feed each expert its own transitions, then run one replay update per expert.

        Returns:
            Mean of the non-None replay losses, or None if there were none.
        """
        losses = []
        for expert_idx, experiences in experiences_dict.items():
            if experiences:
                expert = self.expert_agents[expert_idx]
                for experience in experiences:
                    expert.remember(*experience)
                loss = expert.replay()
                if loss is not None:
                    losses.append(loss)
        return np.mean(losses) if losses else None

    def update_target_networks(self):
        """Sync each expert's target network with its policy network."""
        for expert in self.expert_agents:
            expert.update_target_network()

    def save_model(self, path):
        """Save router/optimizer state to ``path/moe_model.pth`` plus one checkpoint per expert."""
        os.makedirs(path, exist_ok=True)
        torch.save({
            'router_state_dict': self.router.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, os.path.join(path, 'moe_model.pth'))
        for i, expert in enumerate(self.expert_agents):
            expert.save_checkpoint(episode=0, directory=os.path.join(path, f'expert_{i}'))

    def load_model(self, path):
        """Restore router/optimizer state written by ``save_model``, remapped onto ``device``."""
        checkpoint = torch.load(os.path.join(path, 'moe_model.pth'), map_location=device)
        self.router.load_state_dict(checkpoint['router_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

def train_mixture_of_experts(env, moe_agent, num_episodes, max_steps_per_episode, crash_penalty=2.0, save_every=100):
    """Train the soft MoE, alternating router and expert updates per episode.

    Even episodes collect a flat list of transitions for ``train_router``.
    Odd episodes collect transitions grouped by acting expert for
    ``train_experts``. Exactly one ``train_step`` is made at the end of each episode.

    Args:
        env: environment with ``reset()`` and a 5-tuple ``step(action)``.
        moe_agent: agent exposing ``select_action``, ``train_step``,
            ``update_target_networks``, ``save_model`` and ``expert_agents``.
        num_episodes: number of episodes to run.
        max_steps_per_episode: step cap per episode.
        crash_penalty: amount subtracted from the step reward when ``detect_crash`` fires.
        save_every: save a checkpoint whenever ``episode % save_every == 0``.
    """
    total_crashes = 0
    total_time = 0
    total_loss = 0
    expert_usage = [0] * moe_agent.num_experts
    n_actions = moe_agent.expert_agents[0].n_actions

    for episode in range(num_episodes):
        obs, _ = env.reset()
        state = preprocess_observation(obs)
        total_reward = 0
        episode_crashes = 0
        episode_loss = 0
        episode_start_time = time.time()
        prev_obs = None
        episode_expert_usage = [0] * moe_agent.num_experts
        steps_taken = 0

        update_router = episode % 2 == 0  # Alternate between updating router and experts
        experiences = []
        experiences_dict = defaultdict(list)

        for step in range(max_steps_per_episode):
            agent_action, expert_idx = moe_agent.select_action(state)
            expert_usage[expert_idx] += 1
            episode_expert_usage[expert_idx] += 1
            env_action = agent_action_to_env_action(agent_action.item(), n_actions)
            next_obs, env_reward, terminated, truncated, _ = env.step(env_action)
            next_state = preprocess_observation(next_obs)
            done = terminated or truncated

            if detect_crash(next_obs, prev_obs):
                episode_crashes += 1
                env_reward -= crash_penalty

            experience = (state, agent_action.item(), env_reward, next_state, done)
            # Only build the buffer this episode's update will actually consume.
            if update_router:
                experiences.append(experience)
            else:
                experiences_dict[expert_idx].append(experience)

            state = next_state
            prev_obs = next_obs
            total_reward += env_reward
            steps_taken = step + 1

            if done:
                break

        if update_router:
            loss = moe_agent.train_step(experiences, update_router=True)
        else:
            loss = moe_agent.train_step(experiences_dict, update_router=False)

        if loss is not None:
            episode_loss = loss
            total_loss += episode_loss

        episode_time = time.time() - episode_start_time
        total_time += episode_time
        total_crashes += episode_crashes

        if episode % 10 == 0:
            moe_agent.update_target_networks()

        # max(..., 1) keeps the ratios/average defined for a zero-step episode.
        denom = max(steps_taken, 1)
        metrics = {
            "episode": episode,
            "total_reward": total_reward,
            "crashes": episode_crashes,
            "time": episode_time,
            "loss": episode_loss,
            "epsilon": moe_agent.expert_agents[0].epsilon,  # Assuming all experts have the same epsilon
            "average_reward": total_reward / denom,
            "total_crashes": total_crashes,
            "total_time": total_time,
            "average_loss": total_loss / (episode + 1),
            "episode_duration": episode_time,
            "experts_used": sum(1 for usage in episode_expert_usage if usage > 0),
            "update_router": update_router,
        }
        for i, usage in enumerate(episode_expert_usage):
            metrics[f"expert_{i}_usage"] = usage
            metrics[f"expert_{i}_usage_ratio"] = usage / denom
            metrics[f"expert_{i}_total_usage"] = expert_usage[i]
        # One wandb.log call per episode so every metric lands on the same wandb step.
        wandb.log(metrics)

        print(f"Episode {episode}, Total Reward: {total_reward:.2f}, Crashes: {episode_crashes}, Time: {episode_time:.2f}s, Loss: {episode_loss:.4f}, Update Router: {update_router}")

        if episode % save_every == 0:
            moe_agent.save_model(f"checkpoints_moe_dynamic/episode_{episode}")

def test_mixture_of_experts(env, moe_agent, num_episodes, max_steps_per_episode):
    """Evaluate a mixture-of-experts agent and log per-episode metrics to wandb.

    Runs ``num_episodes`` rollouts without calling any training step. For each
    episode it counts crashes (via ``detect_crash``) and how many steps each
    expert was selected for. All metrics for one episode are sent in a single
    ``wandb.log`` call so they share one wandb step.

    Args:
        env: Environment with ``reset() -> (obs, info)`` and
            ``step(action) -> (obs, reward, terminated, truncated, info)``.
        moe_agent: Agent exposing ``num_experts``, ``expert_agents`` and
            ``select_action(state) -> (action, expert_idx)``, where ``action``
            supports ``.item()``.
        num_episodes: Number of evaluation episodes to run.
        max_steps_per_episode: Maximum number of environment steps per episode.

    Returns:
        The mean total reward over all episodes, or ``None`` when
        ``num_episodes`` is 0 (nothing was evaluated).
    """
    # NOTE(review): exploration is whatever moe_agent.select_action does. If it is
    # epsilon-greedy, these evaluation numbers include random actions -- confirm.

    def usage_ratios(counts):
        """Return each count as a fraction of their sum (all 0.0 if the sum is 0)."""
        total = sum(counts)
        return [count / total if total else 0.0 for count in counts]

    total_rewards = []
    total_crashes = 0
    total_time = 0
    expert_usage = [0] * moe_agent.num_experts
    # Looked up once; assumed constant for the whole evaluation.
    n_actions = moe_agent.expert_agents[0].n_actions

    for episode in range(num_episodes):
        obs, _ = env.reset()
        state = preprocess_observation(obs)
        total_reward = 0
        episode_crashes = 0
        episode_start_time = time.time()
        prev_obs = None
        episode_expert_usage = [0] * moe_agent.num_experts
        # Tracked explicitly so an empty episode (max_steps_per_episode <= 0)
        # does not read an unbound loop variable.
        steps_taken = 0

        for step in range(max_steps_per_episode):
            agent_action, expert_idx = moe_agent.select_action(state)
            expert_usage[expert_idx] += 1
            episode_expert_usage[expert_idx] += 1
            env_action = agent_action_to_env_action(agent_action.item(), n_actions)
            next_obs, reward, terminated, truncated, _ = env.step(env_action)
            steps_taken = step + 1
            next_state = preprocess_observation(next_obs)
            done = terminated or truncated

            if detect_crash(next_obs, prev_obs):
                episode_crashes += 1

            state = next_state
            prev_obs = next_obs
            total_reward += reward

            if done:
                break

        episode_time = time.time() - episode_start_time
        total_time += episode_time
        total_crashes += episode_crashes
        total_rewards.append(total_reward)

        expert_usage_ratio = usage_ratios(episode_expert_usage)

        metrics = {
            "test_episode": episode,
            "test_total_reward": total_reward,
            "test_crashes": episode_crashes,
            "test_time": episode_time,
            "test_average_reward": total_reward / steps_taken if steps_taken else 0.0,
            "test_total_crashes": total_crashes,
            "test_total_time": total_time,
            "test_episode_duration": episode_time,
            "test_experts_used": sum(1 for usage in episode_expert_usage if usage > 0),
        }
        for i, usage in enumerate(episode_expert_usage):
            metrics[f"test_expert_{i}_usage"] = usage
            metrics[f"test_expert_{i}_usage_ratio"] = expert_usage_ratio[i]
        # One call per episode: every wandb.log call advances the wandb step, so
        # separate calls would put the per-expert metrics on different rows.
        wandb.log(metrics)

        print(f"Test Episode {episode}, Total Reward: {total_reward:.2f}, Crashes: {episode_crashes}, Time: {episode_time:.2f}s")

    if not total_rewards:
        print("No test episodes were run.")
        return None

    avg_reward = sum(total_rewards) / len(total_rewards)
    print(f"Average Reward over {num_episodes} episodes: {avg_reward:.2f}")
    summary = {"test_average_reward_overall": avg_reward}
    # Share of all evaluation steps handled by each expert.
    for i, ratio in enumerate(usage_ratios(expert_usage)):
        summary[f"test_overall_expert_{i}_usage_ratio"] = ratio
    wandb.log(summary)
    return avg_reward

# Script entry point: train the mixture of experts with alternating updates,
# evaluate it, then save the final model.

import warnings
# NOTE(review): this silences *all* warnings for the rest of the kernel session,
# including deprecation and numerical warnings from torch/numpy.
warnings.filterwarnings("ignore")

if __name__ == "__main__":
    env = get_environment()
    n_actions = 3

    # Initialize wandb
    wandb.init(project="mixture-of-experts-dqn-alternating")

    # Everything after wandb.init runs under try/finally so the wandb run is
    # closed even if loading fails or training is interrupted (KeyboardInterrupt).
    try:
        # Pre-trained expert checkpoints, one per expert, in expert-index order.
        expert_checkpoints = [
            "checkpoints_moe/episode_200/expert_0/checkpoint_episode_0_step_40000.pth",
            "checkpoints_moe/episode_200/expert_1/checkpoint_episode_0_step_46000.pth",
            "checkpoints_moe/episode_200/expert_2/checkpoint_episode_0_step_91000.pth",
        ]
        experts = []
        for checkpoint_path in expert_checkpoints:
            expert = SimpleDQNAgent(n_actions)
            expert.load_checkpoint(checkpoint_path)
            experts.append(expert)
        # Also expose each expert under its own name for use in later cells.
        expert1, expert2, expert3 = experts

        moe_agent = MixtureOfExpertsDQN(experts)

        # Train the Mixture of Experts with alternating updates
        train_mixture_of_experts(env, moe_agent, num_episodes=4000, max_steps_per_episode=1000000)

        # Test the Mixture of Experts
        test_mixture_of_experts(env, moe_agent, num_episodes=100, max_steps_per_episode=5000)

        # Save the final model
        moe_agent.save_model("checkpoints_moe_alternating/final_model")
    finally:
        # Close wandb run
        wandb.finish()

cuda
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33madamyvashisth[0m ([33madamyvashisth-indian-institute-of-technology-roorkee[0m). Use [1m`wandb login --relogin`[0m to force relogin


NameError: name 'SimpleDQNAgent' is not defined

In [3]:
import warnings
# NOTE(review): silences *all* warnings for the rest of the kernel session.
warnings.filterwarnings("ignore")
if __name__ == "__main__":
    env = get_environment()
    n_actions = 3

    # Initialize wandb
    wandb.init(project="mixture-of-experts-dqn-trial-2")

    # Everything after wandb.init runs under try/finally so the wandb run is
    # closed even if loading fails or training is interrupted (KeyboardInterrupt).
    try:
        # Pre-trained expert checkpoints, one per expert, in expert-index order.
        expert_checkpoints = [
            "checkpoints_moe/episode_200/expert_0/checkpoint_episode_0_step_40000.pth",
            "checkpoints_moe/episode_200/expert_1/checkpoint_episode_0_step_46000.pth",
            "checkpoints_moe/episode_200/expert_2/checkpoint_episode_0_step_91000.pth",
        ]
        experts = []
        for checkpoint_path in expert_checkpoints:
            expert = SimpleDQNAgent(n_actions)
            expert.load_checkpoint(checkpoint_path)
            experts.append(expert)
        # Also expose each expert under its own name for use in later cells.
        expert1, expert2, expert3 = experts

        # Resume from the previously saved mixture-of-experts model.
        moe_agent = MixtureOfExpertsDQN(experts)
        moe_agent.load_model(path="checkpoints_moe/final_model")

        # Train the Mixture of Experts
        train_mixture_of_experts(env, moe_agent, num_episodes=4000, max_steps_per_episode=1000000)

        # Test the Mixture of Experts
        test_mixture_of_experts(env, moe_agent, num_episodes=100, max_steps_per_episode=5000)

        # Save the final model (overwrites the checkpoint loaded above; only
        # reached when training and testing complete without an exception).
        moe_agent.save_model("checkpoints_moe/final_model")
    finally:
        # Close wandb run
        wandb.finish()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33madamyvashisth[0m ([33madamyvashisth-indian-institute-of-technology-roorkee[0m). Use [1m`wandb login --relogin`[0m to force relogin


Episode 0, Total Reward: -2.69, Crashes: 13, Time: 14.04s, Loss: 151.1328
Checkpoint saved: checkpoints_moe_1/episode_0\expert_0\checkpoint_episode_0_step_40000.pth
Checkpoint saved: checkpoints_moe_1/episode_0\expert_1\checkpoint_episode_0_step_46000.pth
Checkpoint saved: checkpoints_moe_1/episode_0\expert_2\checkpoint_episode_0_step_91000.pth
Episode 1, Total Reward: -3.71, Crashes: 7, Time: 10.07s, Loss: 70.5242
Episode 2, Total Reward: -1.58, Crashes: 4, Time: 10.05s, Loss: 92.2249
Episode 3, Total Reward: -0.53, Crashes: 2, Time: 9.95s, Loss: 53.6625
Episode 4, Total Reward: 2.38, Crashes: 3, Time: 9.77s, Loss: 60.3140
Episode 5, Total Reward: 2.54, Crashes: 2, Time: 10.17s, Loss: 53.7422
Episode 6, Total Reward: -4.49, Crashes: 6, Time: 10.01s, Loss: 64.1602


KeyboardInterrupt: 

In [None]:
# Evaluate the saved mixture-of-experts model (no training) and log to wandb.
wandb.init(project="mixture-of-experts-dqn-trial-1")
# try/finally ensures the wandb run is closed even if loading or evaluation fails
# or is interrupted; previously this cell never called wandb.finish().
try:
    env = get_environment()
    n_actions = 3

    # Pre-trained expert checkpoints, one per expert, in expert-index order.
    expert_checkpoints = [
        "checkpoints_moe/episode_200/expert_0/checkpoint_episode_0_step_40000.pth",
        "checkpoints_moe/episode_200/expert_1/checkpoint_episode_0_step_46000.pth",
        "checkpoints_moe/episode_200/expert_2/checkpoint_episode_0_step_91000.pth",
    ]
    experts = []
    for checkpoint_path in expert_checkpoints:
        expert = SimpleDQNAgent(n_actions)
        expert.load_checkpoint(checkpoint_path)
        experts.append(expert)
    # Also expose each expert under its own name for use in later cells.
    expert1, expert2, expert3 = experts

    moe_agent = MixtureOfExpertsDQN(experts)
    moe_agent.load_model(path="checkpoints_moe/final_model")
    test_mixture_of_experts(env, moe_agent, num_episodes=1000, max_steps_per_episode=1000000)
finally:
    wandb.finish()