In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt

class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with ``maxlen`` drops the oldest entry once it is full.
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)

    def store(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` tuple."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            ``(states, actions, rewards, next_states, dones)`` where the two
            state entries are stacked into numpy arrays and the remaining
            three are plain tuples in the sampled order.
        """
        picked = random.sample(self.memory, batch_size)
        columns = tuple(zip(*picked))
        states, actions, rewards, next_states, dones = columns
        return np.array(states), actions, rewards, np.array(next_states), dones

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.memory)

class CategoricalDQN(nn.Module):
    def __init__(self, action_dim, atoms_num, v_min, v_max):
        super(CategoricalDQN, self).__init__()
        self.action_dim = action_dim
        self.atoms_num = atoms_num
        self.v_min = v_min
        self.v_max = v_max

        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, action_dim * atoms_num)

    def forward(self, state):
        x = F.relu(self.conv1(state))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.softmax(x.view(-1, self.atoms_num), dim=1).view(-1, self.action_dim, self.atoms_num)
        return x

    def act(self, state, epsilon):
        if random.random() > epsilon:
            state = torch.FloatTensor(state).unsqueeze(0).permute(0, 3, 1, 2)  # Convert to (batch_size, channels, height, width)
            dist = self.forward(state)
            dist = dist.mul(torch.linspace(self.v_min, self.v_max, self.atoms_num))
            action = dist.sum(2).max(1)[1].item()
        else:
            action = random.choice(range(self.action_dim))
        return action

def projection_distribution(target_model, next_states, rewards, dones, v_min, v_max, atoms_num, gamma):
    """Project the Bellman-updated target distribution onto the fixed support.

    Args:
        target_model: callable mapping ``next_states`` to probabilities of
            shape ``(batch, action_dim, atoms_num)``.
        next_states: batch of next states; only ``size(0)`` is read here.
        rewards: float tensor of shape ``(batch,)``.
        dones: float tensor of shape ``(batch,)``; 1.0 disables bootstrapping.
        v_min, v_max: bounds of the value support.
        atoms_num: number of atoms on the support.
        gamma: discount factor.

    Returns:
        Tensor ``(batch, atoms_num)`` holding the projected distribution of
        the greedy next action. Each row sums to the same total mass as the
        selected next-state distribution.
    """
    batch_size = next_states.size(0)
    delta_z = float(v_max - v_min) / (atoms_num - 1)

    with torch.no_grad():
        next_dist = target_model(next_states).detach()
        device = next_dist.device
        support = torch.linspace(v_min, v_max, atoms_num, device=device)

        # Greedy action w.r.t. the expected value of each distribution.
        next_action = (next_dist * support).sum(2).max(1)[1]
        batch_index = torch.arange(batch_size, device=device)
        next_dist = next_dist[batch_index, next_action]

        Tz = rewards.unsqueeze(1) + gamma * (1 - dones.unsqueeze(1)) * support.unsqueeze(0)
        Tz = Tz.clamp(min=v_min, max=v_max)
        # Fractional atom index; clamped again so float rounding can never
        # push ceil() past the last atom.
        b = ((Tz - v_min) / delta_z).clamp(0, atoms_num - 1)
        l = b.floor().long()
        u = b.ceil().long()

        lower_weight = u.float() - b
        upper_weight = b - l.float()
        # When b lands exactly on an atom (l == u) both weights above are 0
        # and the mass would vanish; assign all of it to that atom instead.
        lower_weight = lower_weight + (l == u).float()

        # Integer row offsets into the flattened (batch * atoms) buffer.
        offset = (batch_index * atoms_num).unsqueeze(1)

        proj_dist = torch.zeros(batch_size, atoms_num, dtype=next_dist.dtype, device=device)
        proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * lower_weight).view(-1))
        proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * upper_weight).view(-1))
    return proj_dist

def train(eval_model, target_model, buffer, v_min, v_max, atoms_num, gamma, batch_size, optimizer, count, update_freq):
    """Perform one gradient step on ``eval_model`` from a replay mini-batch.

    The loss is the cross-entropy between the projected target distribution
    and the online network's distribution for the taken actions. Whenever
    ``count`` is a multiple of ``update_freq`` the target network is
    overwritten with the online weights.

    Returns:
        The scalar loss as a float, or ``0.0`` when the buffer holds fewer
        than ``batch_size`` transitions (no update is performed then).
    """
    # Too little experience so far: nothing to learn from.
    if len(buffer) < batch_size:
        return 0.0

    sampled = buffer.sample(batch_size)
    state_batch, action_batch, reward_batch, next_state_batch, done_batch = sampled

    # Sampled states are channel-last; the network consumes channel-first.
    state_batch = torch.FloatTensor(state_batch).permute(0, 3, 1, 2)
    action_batch = torch.LongTensor(action_batch)
    reward_batch = torch.FloatTensor(reward_batch)
    next_state_batch = torch.FloatTensor(next_state_batch).permute(0, 3, 1, 2)
    done_batch = torch.FloatTensor(done_batch)

    target_dist = projection_distribution(
        target_model, next_state_batch, reward_batch, done_batch,
        v_min, v_max, atoms_num, gamma,
    )

    # Pick, for every sample, the distribution of the action actually taken.
    all_dists = eval_model(state_batch)
    gather_index = action_batch.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, atoms_num)
    chosen_dist = all_dists.gather(1, gather_index).squeeze(1)
    # Keep probabilities away from 0 and 1 so the logarithm stays finite.
    chosen_dist = chosen_dist.clamp(1e-3, 1 - 1e-3)
    loss = -(target_dist * torch.log(chosen_dist)).sum(1).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    sync_due = count % update_freq == 0
    if sync_due:
        target_model.load_state_dict(eval_model.state_dict())

    return loss.item()

def _preprocess_frame(frame, size=84):
    """Return ``frame`` as a ``size`` x ``size`` uint8 grayscale image."""
    frame = np.asarray(frame)
    if frame.ndim == 3:
        # Colour frame: collapse the channel axis into a single plane.
        frame = frame.mean(axis=2)
    tensor = torch.as_tensor(frame, dtype=torch.float32)[None, None]
    resized = F.interpolate(tensor, size=(size, size), mode="area")
    return resized[0, 0].round().clamp(0, 255).to(torch.uint8).numpy()


def main():
    """Train a categorical DQN agent on Pong and plot the episode rewards.

    Each state is the 4 most recent 84x84 grayscale frames stacked along the
    last axis (channel-last), matching what ``CategoricalDQN.act`` and
    ``train`` permute into channel-first form.
    """
    # Ask the emulator for grayscale frames (ALE ``obs_type`` option); the
    # helper above also copes with colour frames if that is not honoured.
    env = gym.make("ALE/Pong-v5", obs_type="grayscale")
    action_dim = env.action_space.n

    v_min = -10
    v_max = 10
    atoms_num = 51
    gamma = 0.99
    learning_rate = 1e-3
    batch_size = 64
    capacity = 10000
    epsilon_init = 1.0
    epsilon_decay = 0.995
    epsilon_min = 0.01
    update_freq = 1000
    num_episodes = 1000

    eval_model = CategoricalDQN(action_dim, atoms_num, v_min, v_max)
    target_model = CategoricalDQN(action_dim, atoms_num, v_min, v_max)
    target_model.load_state_dict(eval_model.state_dict())
    optimizer = torch.optim.Adam(eval_model.parameters(), lr=learning_rate)
    buffer = ReplayBuffer(capacity)

    epsilon = epsilon_init
    episode_rewards = []
    # Global update counter: it must survive across episodes, otherwise the
    # target network would be re-synchronised at the start of every episode.
    count = 0

    try:
        for episode in range(num_episodes):
            # Gymnasium's reset() returns (observation, info).
            frame, _ = env.reset()
            frame = _preprocess_frame(frame)
            state = np.stack([frame] * 4, axis=2)  # (84, 84, 4)
            done = False
            episode_reward = 0

            while not done:
                action = eval_model.act(state, epsilon)
                frame, reward, terminated, truncated, _ = env.step(action)
                # The episode ends on either signal ...
                done = terminated or truncated
                new_frame = _preprocess_frame(frame)[:, :, None]
                # Drop the oldest frame and append the newest one.
                next_state = np.append(state[:, :, 1:], new_frame, axis=2)
                # ... but only a true termination stops bootstrapping.
                buffer.store(state, action, reward, next_state, terminated)
                episode_reward += reward
                state = next_state

                if len(buffer) > batch_size:
                    train(eval_model, target_model, buffer, v_min, v_max, atoms_num,
                          gamma, batch_size, optimizer, count, update_freq)
                    count += 1

                # NOTE(review): epsilon decays once per environment step, so
                # it reaches epsilon_min within roughly the first thousand
                # steps; confirm this schedule is intended.
                if epsilon > epsilon_min:
                    epsilon *= epsilon_decay

            episode_rewards.append(episode_reward)
            print(f"Episode {episode + 1}: Reward = {episode_reward}")

            if (episode + 1) % 10 == 0:
                avg_reward = np.mean(episode_rewards[-10:])
                print(f"Average reward over last 10 episodes: {avg_reward}")
    finally:
        # Release the emulator even if training is interrupted.
        env.close()

    plt.plot(episode_rewards)
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('Episode Rewards')
    plt.show()

# Start training only when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
import random
from collections import deque

class replay_buffer(object):
    """Bounded FIFO transition store that returns batched observations."""

    def __init__(self, capacity):
        # Oldest transitions fall out automatically once ``capacity`` is hit.
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)

    def store(self, observation, action, reward, next_observation, done):
        """Save one transition, adding a leading batch axis to observations."""
        # The extra axis lets sample() build a batch with one concatenate.
        batched_obs = np.expand_dims(observation, axis=0)
        batched_next = np.expand_dims(next_observation, axis=0)
        self.memory.append([batched_obs, action, reward, batched_next, done])

    def sample(self, batch_size):
        """Return ``batch_size`` random transitions as column-wise batches.

        Observations come back as arrays concatenated along axis 0; actions,
        rewards and done flags come back as tuples in the sampled order.
        """
        chosen = random.sample(self.memory, batch_size)
        obs_col, act_col, rew_col, next_col, done_col = zip(*chosen)
        obs_batch = np.concatenate(obs_col, axis=0)
        next_batch = np.concatenate(next_col, axis=0)
        return obs_batch, act_col, rew_col, next_batch, done_col

    def __len__(self):
        """Number of stored transitions."""
        return len(self.memory)


class NoisyLinear(nn.Module):
    """Linear layer whose weights and biases carry learnable noise.

    In training mode the effective parameters are ``mu + sigma * epsilon``,
    where ``epsilon`` is a noise buffer refreshed by :meth:`reset_noise`.
    In evaluation mode only the ``mu`` parameters are used.
    """

    def __init__(self, input_dim, output_dim, std_init=0.4):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.std_init = std_init

        weight_shape = (output_dim, input_dim)
        # Storage is allocated uninitialised; reset_parameter() and
        # reset_noise() below fill every tensor before first use.
        self.weight_mu = nn.Parameter(torch.empty(weight_shape, dtype=torch.float32))
        self.weight_sigma = nn.Parameter(torch.empty(weight_shape, dtype=torch.float32))

        self.bias_mu = nn.Parameter(torch.empty(output_dim, dtype=torch.float32))
        self.bias_sigma = nn.Parameter(torch.empty(output_dim, dtype=torch.float32))

        # Noise is a buffer, not a parameter: saved with the module, never trained.
        self.register_buffer('weight_epsilon', torch.empty(weight_shape, dtype=torch.float32))
        self.register_buffer('bias_epsilon', torch.empty(output_dim, dtype=torch.float32))

        self.reset_parameter()
        self.reset_noise()

    def forward(self, input):
        """Apply the (noisy in training mode) affine map to ``input``."""
        if not self.training:
            return F.linear(input, self.weight_mu, self.bias_mu)
        noisy_weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        noisy_bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return F.linear(input, noisy_weight, noisy_bias)

    def _scale_noise(self, size):
        """Sample ``size`` values of ``sign(x) * sqrt(|x|)`` with x ~ N(0, 1)."""
        raw = torch.randn(size)
        return raw.sign() * raw.abs().sqrt()

    def reset_parameter(self):
        """Initialise the mean and sigma parameters in place."""
        fan_in_root = np.sqrt(self.input_dim)
        bound = 1.0 / fan_in_root

        # In-place initialisation must not be recorded by autograd.
        with torch.no_grad():
            self.weight_mu.uniform_(-bound, bound)
            self.bias_mu.uniform_(-bound, bound)

            self.weight_sigma.fill_(self.std_init / fan_in_root)
            self.bias_sigma.fill_(self.std_init / np.sqrt(self.output_dim))

    def reset_noise(self):
        """Resample the weight noise (outer product) and the bias noise."""
        eps_in = self._scale_noise(self.input_dim)
        eps_out = self._scale_noise(self.output_dim)

        self.weight_epsilon.copy_(torch.outer(eps_out, eps_in))
        self.bias_epsilon.copy_(self._scale_noise(self.output_dim))


class categorical_dqn(nn.Module):
    """Fully connected categorical DQN (C51) with noisy output layers.

    For every action the network outputs a probability distribution over
    ``atoms_num`` return values evenly spaced on ``[v_min, v_max]``.
    """

    def __init__(self, observation_dim, action_dim, atoms_num, v_min, v_max):
        super(categorical_dqn, self).__init__()
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.atoms_num = atoms_num
        self.v_min = v_min
        self.v_max = v_max

        self.fc1 = nn.Linear(self.observation_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.noisy1 = NoisyLinear(128, 512)
        self.noisy2 = NoisyLinear(512, self.action_dim * self.atoms_num)

    def forward(self, observation):
        """Return atom probabilities of shape ``(batch, action_dim, atoms_num)``."""
        x = F.relu(self.fc1(observation))
        x = F.relu(self.fc2(x))
        x = F.relu(self.noisy1(x))
        x = self.noisy2(x)
        # Softmax over the atoms of each (sample, action) pair.
        x = F.softmax(x.view(-1, self.atoms_num), 1).view(-1, self.action_dim, self.atoms_num)
        return x

    def reset_noise(self):
        """Resample the noise of both noisy layers."""
        self.noisy1.reset_noise()
        self.noisy2.reset_noise()

    def act(self, observation, epsilon):
        """Epsilon-greedy action for the first observation in the batch.

        Args:
            observation: tensor accepted by :meth:`forward`.
            epsilon: probability of taking a uniformly random action.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if random.random() > epsilon:
            # No gradients are needed to pick an action.
            with torch.no_grad():
                dist = self.forward(observation)
                # Use this model's own support bounds rather than whatever
                # ``v_min``/``v_max`` happen to exist as module globals.
                support = torch.linspace(
                    self.v_min, self.v_max, self.atoms_num, device=dist.device
                )
                action = (dist * support).sum(2).max(1)[1][0].item()
        else:
            action = random.randrange(self.action_dim)
        return action


def projection_distribution(target_model, next_observation, reward, done, v_min, v_max, atoms_num, gamma):
    """Project the Bellman-updated target distribution onto the fixed support.

    Args:
        target_model: callable mapping ``next_observation`` to probabilities
            of shape ``(batch, action_dim, atoms_num)``.
        next_observation: batch of next observations; only ``size(0)`` is
            read here.
        reward: float tensor of shape ``(batch,)``.
        done: float tensor of shape ``(batch,)``; 1.0 disables bootstrapping.
        v_min, v_max: bounds of the value support.
        atoms_num: number of atoms on the support.
        gamma: discount factor.

    Returns:
        Tensor ``(batch, atoms_num)`` holding the projected distribution of
        the greedy next action. Each row sums to the same total mass as the
        selected next-state distribution.
    """
    batch_size = next_observation.size(0)
    delta_z = float(v_max - v_min) / (atoms_num - 1)

    with torch.no_grad():
        next_prob = target_model(next_observation).detach()
        # next_prob: [batch_size, action_dim, atoms_num]
        device = next_prob.device
        support = torch.linspace(v_min, v_max, atoms_num, device=device)
        # support: [atoms_num]

        # The support-weighted values are used ONLY to rank the actions; the
        # projection below must redistribute the raw probabilities.
        next_action = (next_prob * support).sum(2).max(1)[1]
        # next_action: [batch_size]
        gather_index = next_action.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, atoms_num)
        next_dist = next_prob.gather(1, gather_index).squeeze(1)
        # next_dist: [batch_size, atoms_num]

        reward = reward.unsqueeze(1).expand_as(next_dist)
        done = done.unsqueeze(1).expand_as(next_dist)
        support = support.unsqueeze(0).expand_as(next_dist)

        Tz = reward + (1 - done) * support * gamma
        Tz = Tz.clamp(min=v_min, max=v_max)
        # Fractional atom index; clamped again so float rounding can never
        # push ceil() past the last atom.
        b = ((Tz - v_min) / delta_z).clamp(0, atoms_num - 1)
        l = b.floor().long()
        u = b.ceil().long()

        lower_weight = u.float() - b
        upper_weight = b - l.float()
        # When b lands exactly on an atom (l == u) both weights above are 0
        # and the mass would vanish; assign all of it to that atom instead.
        lower_weight = lower_weight + (l == u).float()

        # Integer row offsets into the flattened (batch * atoms) buffer.
        offset = (torch.arange(batch_size, device=device) * atoms_num).unsqueeze(1)

        proj_dist = torch.zeros(batch_size, atoms_num, dtype=torch.float32, device=device)
        proj_dist.view(-1).index_add_(0, (offset + l).view(-1), (next_dist * lower_weight).view(-1))
        proj_dist.view(-1).index_add_(0, (offset + u).view(-1), (next_dist * upper_weight).view(-1))
    return proj_dist


def train(eval_model, target_model, buffer, v_min, v_max, atoms_num, gamma, batch_size, optimizer, count, update_freq):
    """Perform one gradient step on ``eval_model`` from a replay mini-batch.

    The loss is the cross-entropy between the projected target distribution
    and the online network's distribution for the taken actions. After the
    step the noise of both networks is resampled, and whenever ``count`` is a
    multiple of ``update_freq`` the target network is overwritten with the
    online weights.

    Returns:
        The scalar loss as a float, or ``0.0`` when the buffer holds fewer
        than ``batch_size`` transitions (no update is performed then).
    """
    # Sampling more transitions than are stored would raise; skip the update.
    if len(buffer) < batch_size:
        return 0.0

    observation, action, reward, next_observation, done = buffer.sample(batch_size)

    observation = torch.FloatTensor(observation)
    action = torch.LongTensor(action)
    reward = torch.FloatTensor(reward)
    next_observation = torch.FloatTensor(next_observation)
    done = torch.FloatTensor(done)

    proj_dist = projection_distribution(target_model, next_observation, reward, done, v_min, v_max, atoms_num, gamma)

    dist = eval_model.forward(observation)
    # Pick, for every sample, the distribution of the action actually taken.
    action = action.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, atoms_num)
    dist = dist.gather(1, action).squeeze(1)
    # Clamp the values in place through a detached view so the logarithm
    # stays finite while gradients still flow as if no clamp were applied.
    dist.detach().clamp_(0.01, 0.99)
    # Cross-entropy against the fixed target; this equals the KL divergence
    # up to a term that does not depend on the online network.
    loss = - (proj_dist * dist.log()).sum(1).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    eval_model.reset_noise()
    target_model.reset_noise()

    if count % update_freq == 0:
        target_model.load_state_dict(eval_model.state_dict())

    return loss.item()


if __name__ == '__main__':
    # Hyper-parameters. They stay at module level on purpose: other code in
    # this cell may read names such as ``v_min``/``v_max`` as globals.
    episode = 100000
    epsilon_init = 0.95
    epsilon_decay = 0.95
    epsilon_min = 0.01
    update_freq = 200
    gamma = 0.99
    learning_rate = 1e-3
    atoms_num = 51
    v_min = -10
    v_max = 10
    batch_size = 64
    capacity = 10000
    exploration = 100  # number of initial episodes played without training
    render = False

    # Gymnasium selects the renderer at construction time; in 'human' mode
    # the window is refreshed by reset()/step(), so no explicit render()
    # calls are needed in the loop.
    env = gym.make('CartPole-v1', render_mode='human' if render else None)
    env = env.unwrapped
    action_dim = env.action_space.n
    observation_dim = env.observation_space.shape[0]
    count = 0
    target_net = categorical_dqn(observation_dim, action_dim, atoms_num, v_min, v_max)
    eval_net = categorical_dqn(observation_dim, action_dim, atoms_num, v_min, v_max)
    target_net.load_state_dict(eval_net.state_dict())
    optimizer = torch.optim.Adam(eval_net.parameters(), lr=learning_rate)
    buffer = replay_buffer(capacity)
    weight_reward = None
    epsilon = epsilon_init

    for i in range(episode):
        # Gymnasium's reset() returns (observation, info); passing the whole
        # tuple on is what triggers the "inhomogeneous shape" ValueError.
        obs, _ = env.reset()
        reward_total = 0
        while True:
            action = eval_net.act(torch.FloatTensor(np.expand_dims(obs, 0)), epsilon)
            # Gymnasium's step() returns five values.
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            count += 1
            # Only a true termination should stop bootstrapping in the target.
            buffer.store(obs, action, reward, next_obs, terminated)
            reward_total += reward
            obs = next_obs
            if i > exploration:
                train(eval_net, target_net, buffer, v_min, v_max, atoms_num, gamma, batch_size, optimizer, count, update_freq)
            if done:
                if epsilon > epsilon_min:
                    epsilon = epsilon * epsilon_decay
                # ``is None`` so a legitimate running average of 0 is kept.
                if weight_reward is None:
                    weight_reward = reward_total
                else:
                    weight_reward = 0.99 * weight_reward + 0.01 * reward_total
                print('episode: {}  reward: {}  weight_reward: {:.3f}  epsilon: {:.2f}'.format(i+1, reward_total, weight_reward, epsilon))
                break

  logger.deprecation(


ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.