In [1]:
import json
import random
from collections import deque

import gym
import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

In [33]:
env = gym.make("ALE/BattleZone-v5", render_mode="rgb_array")

# Show the spaces the agent has to work with (action count, observation layout).
for label, space in [("Action Space:", env.action_space), ("Observation Space:", env.observation_space)]:
    print(label, space)

class DQN(nn.Module):
    """Convolutional Q-network: maps a (batch, C, H, W) image batch to one Q-value per action.

    Args:
        input_shape: (channels, height, width) of a single observation.
        num_actions: number of discrete actions, i.e. the width of the output layer.
    """

    def __init__(self, input_shape, num_actions):
        super().__init__()
        # (out_channels, kernel_size, stride) per conv layer; every conv is followed by a ReLU.
        conv_specs = [(32, 8, 4), (64, 4, 2), (64, 4, 2), (64, 3, 1)]
        layers = []
        in_channels = input_shape[0]
        for out_channels, kernel, stride in conv_specs:
            layers += [nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride), nn.ReLU()]
            in_channels = out_channels
        self.conv = nn.Sequential(*layers)

        # The head's input size depends on the observation size, so probe the conv stack once.
        flat_features = self._calculate_conv_output(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )

    def _calculate_conv_output(self, shape):
        """Return how many features the conv stack yields for one input of `shape` (C, H, W)."""
        with torch.no_grad():
            probe = self.conv(torch.zeros(1, *shape))
        return probe.numel()

    def forward(self, x):
        """Return Q-values of shape (batch, num_actions) for `x` of shape (batch, C, H, W)."""
        features = self.conv(x)
        batch = features.shape[0]
        return self.fc(features.reshape(batch, -1))

class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions.

    When `capacity` transitions are held, pushing a new one evicts the oldest.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns:
            (states, actions, rewards, next_states, dones): states and next_states are
            stacked with ``np.array``; actions, rewards and dones are tuples.
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = map(tuple, zip(*batch))
        return np.array(states), actions, rewards, np.array(next_states), dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

class DQNAgent:
    """Epsilon-greedy DQN agent with a policy network, a target network and a uniform replay buffer.

    Observations are (H, W, C) frames as returned by the environment; they are converted
    to float32 (C, H, W) tensors on the agent's device before reaching the networks.
    Pixel values are passed through unscaled, exactly as during training.
    """

    def __init__(self, input_shape, num_actions):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_actions = num_actions
        self.epsilon = 1.0
        self.epsilon_min = 0.1
        self.epsilon_decay = 0.99995  # multiplicative decay, applied once per gradient step
        self.gamma = 0.95
        self.learning_rate = 0.0001
        self.batch_size = 32
        # NOTE(review): every transition holds two raw observation frames; with the
        # (210, 160, 3) uint8 frames of this env a full buffer needs roughly 6 GB of RAM.
        self.memory_capacity = 30000

        self.policy_net = DQN(input_shape, num_actions).to(self.device)
        self.target_net = DQN(input_shape, num_actions).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.memory = ReplayBuffer(self.memory_capacity)

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.loss_fn = nn.MSELoss()

    def _frames_to_tensor(self, frames):
        """Convert a stacked (B, H, W, C) frame array into a float32 (B, C, H, W) tensor on self.device."""
        # Transfer the raw frames first and cast on the target device, instead of
        # materialising a float32 copy on the host and moving that.
        frames_t = torch.as_tensor(np.asarray(frames), device=self.device)
        return frames_t.permute(0, 3, 1, 2).float()

    def select_action(self, state):
        """Pick an action for one (H, W, C) observation using an epsilon-greedy policy.

        With probability ``self.epsilon`` a uniformly random action index is returned;
        otherwise the index of the largest Q-value predicted by the policy network.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.num_actions - 1)
        state_t = self._frames_to_tensor(np.asarray(state)[np.newaxis])
        with torch.no_grad():  # inference only: no autograd graph needed
            q_values = self.policy_net(state_t)
        return q_values.argmax(dim=1).item()

    def update(self):
        """Run one gradient step on a sampled minibatch, then decay epsilon.

        Does nothing until the replay buffer holds at least ``batch_size`` transitions.
        Targets are r + gamma * max_a' Q_target(s', a'), with the bootstrap term
        zeroed for transitions stored with done=1.
        """
        if len(self.memory) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)
        states = self._frames_to_tensor(states)
        next_states = self._frames_to_tensor(next_states)
        actions = torch.tensor(actions, dtype=torch.int64, device=self.device).unsqueeze(1)
        # Explicit float32 so a float64 reward (e.g. numpy scalars) can't promote the
        # target to a different dtype than the float32 Q-values.
        rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device)
        dones = torch.tensor(dones, dtype=torch.float32, device=self.device)

        # Q(s, a) of the actions actually taken, shape (batch,). squeeze(1) rather than
        # squeeze() keeps the batch dimension even when batch_size == 1.
        q_values = self.policy_net(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
        target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        loss = self.loss_fn(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay exploration once per gradient step without undershooting epsilon_min.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def update_target_network(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

num_episodes = 47
max_steps_per_episode = 10000
# The network is channels-first, while the env's observations are (H, W, C).
obs_height, obs_width, obs_channels = env.observation_space.shape
input_shape = (obs_channels, obs_height, obs_width)
num_actions = env.action_space.n
target_update_every = 5  # episodes between target-network syncs
rewards_history = []  # undiscounted return of each training episode, in order

agent = DQNAgent(input_shape, num_actions)

for episode in tqdm(range(num_episodes), desc="Training Progress"):
    state, _ = env.reset()
    total_reward = 0

    for step in range(max_steps_per_episode):
        action = agent.select_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        # Only true termination ends the value bootstrap; a time-limit truncation
        # leaves next_state a valid non-terminal state, so store `terminated`.
        agent.memory.push(state, action, reward, next_state, terminated)
        agent.update()

        state = next_state
        total_reward += reward

        if done:
            break

    # Exactly one entry per episode; float() keeps it JSON-serializable.
    rewards_history.append(float(total_reward))

    if episode % target_update_every == 0:
        agent.update_target_network()

    print(f"Episode {episode+1}/{num_episodes}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")

env.close()

torch.save(agent.policy_net.state_dict(), "dqn_policy_weights.pth")
torch.save(agent.target_net.state_dict(), "dqn_target_weights.pth")

with open("rewards_history.json", "w") as f:
    json.dump(rewards_history, f)

def record_video(env, agent, out_path, fps=30, greedy=False):
    """Play one episode of `agent` in `env` and save the rendered frames as a video.

    Args:
        env: environment created with ``render_mode="rgb_array"`` so ``env.render()``
            returns an image array.
        agent: object exposing ``select_action(state)``, and an ``epsilon`` attribute
            when ``greedy`` is True.
        out_path: destination file passed to ``imageio.mimsave``; format follows the extension.
        fps: playback frame rate of the saved video.
        greedy: if True, temporarily set ``agent.epsilon`` to 0 so the recording shows the
            learned policy without random exploration; the old value is restored afterwards.
    """
    saved_epsilon = agent.epsilon if greedy else None
    if greedy:
        agent.epsilon = 0.0
    try:
        frames = []
        state, _ = env.reset()
        done = False
        while not done:
            frames.append(env.render())
            action = agent.select_action(state)
            state, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
        # Also keep the frame of the state the episode ended in.
        frames.append(env.render())
    finally:
        if greedy:
            agent.epsilon = saved_epsilon

    imageio.mimsave(out_path, frames, fps=fps)

# The training env is closed once training finishes, so record with a fresh instance,
# and write next to the other artifacts rather than to a machine-specific absolute path.
video_env = gym.make("ALE/BattleZone-v5", render_mode="rgb_array")
try:
    record_video(video_env, agent, "battlezone_play.mp4", greedy=True)
finally:
    video_env.close()


Action Space: Discrete(18)
Observation Space: Box(0, 255, (210, 160, 3), uint8)


Training Progress:   2%|▌                        | 1/47 [01:08<52:08, 68.01s/it]

Episode 1/47, Total Reward: 5000.0, Epsilon: 0.9326


Training Progress:   4%|█                        | 2/47 [01:55<41:53, 55.85s/it]

Episode 2/47, Total Reward: 2000.0, Epsilon: 0.8886


Training Progress:   6%|█▌                       | 3/47 [02:40<37:19, 50.91s/it]

Episode 3/47, Total Reward: 1000.0, Epsilon: 0.8484


Training Progress:   9%|██▏                      | 4/47 [03:54<43:02, 60.06s/it]

Episode 4/47, Total Reward: 0.0, Epsilon: 0.7870


Training Progress:  11%|██▋                      | 5/47 [05:02<44:07, 63.03s/it]

Episode 5/47, Total Reward: 3000.0, Epsilon: 0.7354


Training Progress:  13%|███▏                     | 6/47 [05:57<41:10, 60.25s/it]

Episode 6/47, Total Reward: 1000.0, Epsilon: 0.6968


Training Progress:  15%|███▋                     | 7/47 [06:57<40:09, 60.24s/it]

Episode 7/47, Total Reward: 2000.0, Epsilon: 0.6581


Training Progress:  17%|████▎                    | 8/47 [08:21<44:06, 67.86s/it]

Episode 8/47, Total Reward: 0.0, Epsilon: 0.6084


Training Progress:  19%|████▊                    | 9/47 [09:10<39:08, 61.80s/it]

Episode 9/47, Total Reward: 0.0, Epsilon: 0.5812


Training Progress:  21%|█████                   | 10/47 [09:54<34:45, 56.37s/it]

Episode 10/47, Total Reward: 1000.0, Epsilon: 0.5576


Training Progress:  23%|█████▌                  | 11/47 [10:45<32:52, 54.78s/it]

Episode 11/47, Total Reward: 1000.0, Epsilon: 0.5313


Training Progress:  26%|██████▏                 | 12/47 [11:45<32:47, 56.21s/it]

Episode 12/47, Total Reward: 1000.0, Epsilon: 0.5026


Training Progress:  28%|██████▋                 | 13/47 [13:16<37:46, 66.66s/it]

Episode 13/47, Total Reward: 2000.0, Epsilon: 0.4628


Training Progress:  30%|███████▏                | 14/47 [14:14<35:21, 64.30s/it]

Episode 14/47, Total Reward: 3000.0, Epsilon: 0.4387


Training Progress:  32%|███████▋                | 15/47 [14:55<30:32, 57.27s/it]

Episode 15/47, Total Reward: 2000.0, Epsilon: 0.4224


Training Progress:  34%|████████▏               | 16/47 [16:17<33:19, 64.49s/it]

Episode 16/47, Total Reward: 2000.0, Epsilon: 0.3926


Training Progress:  36%|████████▋               | 17/47 [17:43<35:27, 70.93s/it]

Episode 17/47, Total Reward: 4000.0, Epsilon: 0.3630


Training Progress:  38%|████████▊              | 18/47 [20:45<50:32, 104.56s/it]

Episode 18/47, Total Reward: 11000.0, Epsilon: 0.3080


Training Progress:  40%|█████████▋              | 19/47 [21:24<39:33, 84.78s/it]

Episode 19/47, Total Reward: 1000.0, Epsilon: 0.2975


Training Progress:  43%|██████████▏             | 20/47 [22:47<37:56, 84.30s/it]

Episode 20/47, Total Reward: 1000.0, Epsilon: 0.2762


Training Progress:  45%|██████████▋             | 21/47 [23:38<32:13, 74.36s/it]

Episode 21/47, Total Reward: 1000.0, Epsilon: 0.2640


Training Progress:  47%|███████████▏            | 22/47 [24:50<30:36, 73.47s/it]

Episode 22/47, Total Reward: 8000.0, Epsilon: 0.2474


Training Progress:  49%|███████████▋            | 23/47 [25:51<27:53, 69.74s/it]

Episode 23/47, Total Reward: 1000.0, Epsilon: 0.2342


Training Progress:  51%|████████████▎           | 24/47 [27:11<27:52, 72.72s/it]

Episode 24/47, Total Reward: 5000.0, Epsilon: 0.2180


Training Progress:  53%|████████████▊           | 25/47 [29:30<33:58, 92.67s/it]

Episode 25/47, Total Reward: 2000.0, Epsilon: 0.1906


Training Progress:  55%|█████████████▎          | 26/47 [31:03<32:26, 92.71s/it]

Episode 26/47, Total Reward: 3000.0, Epsilon: 0.1751


Training Progress:  57%|█████████████▊          | 27/47 [32:13<28:41, 86.06s/it]

Episode 27/47, Total Reward: 1000.0, Epsilon: 0.1642


Training Progress:  60%|██████████████▎         | 28/47 [33:25<25:53, 81.75s/it]

Episode 28/47, Total Reward: 14000.0, Epsilon: 0.1538


Training Progress:  62%|██████████████▊         | 29/47 [35:38<29:09, 97.17s/it]

Episode 29/47, Total Reward: 6000.0, Epsilon: 0.1364


Training Progress:  64%|███████████████▎        | 30/47 [36:41<24:39, 87.03s/it]

Episode 30/47, Total Reward: 1000.0, Epsilon: 0.1287


Training Progress:  66%|███████████████▊        | 31/47 [37:30<20:06, 75.40s/it]

Episode 31/47, Total Reward: 2000.0, Epsilon: 0.1231


Training Progress:  68%|████████████████▎       | 32/47 [38:28<17:35, 70.39s/it]

Episode 32/47, Total Reward: 1000.0, Epsilon: 0.1166


Training Progress:  70%|████████████████▊       | 33/47 [39:36<16:13, 69.56s/it]

Episode 33/47, Total Reward: 5000.0, Epsilon: 0.1096


Training Progress:  72%|█████████████████▎      | 34/47 [41:31<18:02, 83.30s/it]

Episode 34/47, Total Reward: 0.0, Epsilon: 0.1000


Training Progress:  74%|█████████████████▊      | 35/47 [42:24<14:48, 74.01s/it]

Episode 35/47, Total Reward: 2000.0, Epsilon: 0.1000


Training Progress:  77%|██████████████████▍     | 36/47 [43:26<12:54, 70.38s/it]

Episode 36/47, Total Reward: 2000.0, Epsilon: 0.1000


Training Progress:  79%|██████████████████▉     | 37/47 [44:31<11:29, 68.92s/it]

Episode 37/47, Total Reward: 1000.0, Epsilon: 0.1000


Training Progress:  81%|███████████████████▍    | 38/47 [45:17<09:18, 62.04s/it]

Episode 38/47, Total Reward: 1000.0, Epsilon: 0.1000


Training Progress:  83%|███████████████████▉    | 39/47 [46:32<08:47, 65.91s/it]

Episode 39/47, Total Reward: 0.0, Epsilon: 0.1000


Training Progress:  85%|████████████████████▍   | 40/47 [47:38<07:42, 66.04s/it]

Episode 40/47, Total Reward: 1000.0, Epsilon: 0.1000


Training Progress:  87%|████████████████████▉   | 41/47 [49:08<07:19, 73.21s/it]

Episode 41/47, Total Reward: 2000.0, Epsilon: 0.1000


Training Progress:  89%|█████████████████████▍  | 42/47 [50:15<05:56, 71.40s/it]

Episode 42/47, Total Reward: 0.0, Epsilon: 0.1000


Training Progress:  91%|█████████████████████▉  | 43/47 [50:59<04:12, 63.17s/it]

Episode 43/47, Total Reward: 2000.0, Epsilon: 0.1000


Training Progress:  94%|██████████████████████▍ | 44/47 [51:57<03:04, 61.45s/it]

Episode 44/47, Total Reward: 2000.0, Epsilon: 0.1000


Training Progress:  96%|██████████████████████▉ | 45/47 [53:39<02:27, 73.82s/it]

Episode 45/47, Total Reward: 2000.0, Epsilon: 0.1000


Training Progress:  98%|███████████████████████▍| 46/47 [55:41<01:28, 88.17s/it]

Episode 46/47, Total Reward: 5000.0, Epsilon: 0.1000


Training Progress: 100%|████████████████████████| 47/47 [56:25<00:00, 72.03s/it]

Episode 47/47, Total Reward: 2000.0, Epsilon: 0.1000





In [35]:
# Re-save the training artifacts: episode rewards first, then both networks' weights.
with open("rewards_history.json", "w") as history_file:
    json.dump(rewards_history, history_file)
for net, weights_path in ((agent.policy_net, "dqn_policy_weights.pth"), (agent.target_net, "dqn_target_weights.pth")):
    torch.save(net.state_dict(), weights_path)

In [41]:
policy_weights = torch.load("dqn_policy_weights.pth")

# Sanity check: show only the first entry of the state dict (name, shape, values).
first_entry = next(iter(policy_weights.items()), None)
if first_entry is not None:
    name, weight = first_entry
    print(f"{name}: {weight.shape}")
    print(weight)

conv.0.weight: torch.Size([32, 3, 8, 8])
tensor([[[[ 0.0166, -0.0155,  0.0329,  ..., -0.0142, -0.0528,  0.0652],
          [-0.0304, -0.0211,  0.0529,  ..., -0.0389,  0.0263,  0.0414],
          [-0.0318, -0.0312, -0.0301,  ..., -0.1082, -0.1115, -0.0698],
          ...,
          [-0.0299,  0.0591,  0.0680,  ..., -0.0145, -0.0033,  0.0792],
          [-0.0555, -0.0620,  0.0859,  ..., -0.0184, -0.1053,  0.0798],
          [-0.0714, -0.1015,  0.0053,  ..., -0.0668, -0.0589, -0.0394]],

         [[-0.0657,  0.0215,  0.0217,  ...,  0.0607,  0.0148, -0.0263],
          [ 0.0151,  0.0118, -0.0085,  ..., -0.0392, -0.0620,  0.0486],
          [ 0.0149,  0.0315,  0.1018,  ...,  0.0433, -0.0978, -0.0556],
          ...,
          [ 0.0769,  0.0185,  0.0964,  ...,  0.0683, -0.0710,  0.0325],
          [ 0.0281,  0.0636,  0.0039,  ...,  0.0331, -0.0523,  0.0727],
          [-0.0264, -0.0233,  0.0522,  ...,  0.0415,  0.0164, -0.0671]],

         [[-0.0321,  0.0774,  0.0380,  ..., -0.0131,  0.0013,

In [49]:
import torch

# Model parameters; they must match the architecture used for training.
input_shape = (3, 210, 160)  # (channels, height, width) of one RGB frame
num_actions = 18  # size of the env's action space (Discrete(18))

# Rebuild the network (DQN is defined in the training cell) and load the trained weights.
# map_location="cpu" lets weights saved from a GPU run load on a CPU-only machine.
model = DQN(input_shape, num_actions)
model.load_state_dict(torch.load("dqn_policy_weights.pth", map_location="cpu"))
model.eval()  # evaluation mode (this DQN has no dropout/batch-norm, so outputs are unaffected)

# Synthetic, reproducible input of shape (1, C, H, W), defined here so the cell runs on a
# fresh kernel. NOTE(review): values lie in [0, 1) while training fed raw 0-255 pixels, so
# this only smoke-tests the forward pass; use a real env frame for meaningful Q-values.
input_generator = torch.Generator().manual_seed(0)
random_input_tensor = torch.rand(1, *input_shape, generator=input_generator)

with torch.no_grad():  # inference only; gradients are not needed
    q_values = model(random_input_tensor)  # one Q-value per action

print("Q-values for the input:", q_values)

# Index of the action with the highest Q-value.
best_action = q_values.argmax().item()
print(f"Best action based on the Q-values: {best_action}")


Q-values for the input: tensor([[0.2117, 0.2958, 0.1526, 0.1713, 0.3313, 0.2374, 0.2093, 0.2602, 0.1805,
         0.2498, 0.1894, 0.2541, 0.2192, 0.3707, 0.4587, 0.3392, 0.2776, 0.4882]])
Best action based on the Q-values: 17
