In [24]:
import gymnasium as gym
import cv2
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import kl, Normal, Independent, OneHotCategorical
from torch.nn.init import kaiming_uniform_, constant_

# Report the library versions this notebook was executed with (one per line).
print(
    f"gymnasium: {gym.__version__}",
    f"cv2: {cv2.__version__}",
    f"numpy: {np.__version__}",
    f"torch: {torch.__version__}",
    sep="\n",
)

gymnasium: 1.2.2
cv2: 4.12.0
numpy: 2.0.2
torch: 2.4.1+cu121


In [25]:
class ResizeObservationWrapper(gym.ObservationWrapper):
    """Observation wrapper that returns the resized rendered frame of the environment.

    The observation produced by the wrapped environment is ignored; instead
    ``env.render()`` is called and its result is resized to ``(height, width)``.

    Args:
        env: Environment to wrap; its ``render()`` result is passed to ``cv2.resize``.
        width: Target frame width in pixels.
        height: Target frame height in pixels.
    """

    def __init__(self, env, width, height):
        super().__init__(env)
        self.env = env
        self.width, self.height = width, height
        frame_shape = (self.height, self.width, 3)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=frame_shape, dtype=np.uint8
        )

    def observation(self, observation):
        """Return the current rendered frame resized to the configured size."""
        # The incoming observation is deliberately replaced by the rendered image.
        frame = self.env.render()
        # cv2.resize expects the target size as (width, height).
        target_size = (self.width, self.height)
        return cv2.resize(frame, target_size, interpolation=cv2.INTER_AREA)

class ChannelFirstEnv(gym.ObservationWrapper):
    """Observation wrapper that moves the channel axis to the front: [H, W, C] -> [C, H, W]."""

    def __init__(self, env):
        super().__init__(env)
        wrapped_shape = self.observation_space.shape
        # Last axis (channels) first, followed by the first two axes (height, width).
        channel_first_shape = wrapped_shape[-1:] + wrapped_shape[:2]
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=channel_first_shape,
            dtype=np.uint8,
        )

    def observation(self, observation):
        """Permute an [H, W, C] array into [C, H, W]."""
        return np.transpose(observation, (2, 0, 1))

class FrameSkipWrapper(gym.Wrapper):
    """Repeat each action for ``skip`` environment steps and sum the rewards.

    Args:
        env: Environment to wrap.
        skip: Number of times every action is applied; must be at least 1.

    Raises:
        ValueError: If ``skip`` is smaller than 1 (``step`` would otherwise have
            nothing to return).
    """

    def __init__(self, env, skip):
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self._skip = skip

    def step(self, action):
        """Apply ``action`` up to ``skip`` times.

        Returns:
            The last observation, the summed reward, and the ``terminated``,
            ``truncated`` and ``info`` values of the last executed step.
        """
        total_reward = 0.0
        for _ in range(self._skip):
            next_state, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            # Stop repeating as soon as the episode ends.
            if terminated or truncated:
                break
        return next_state, total_reward, terminated, truncated, info

class PixelNormalizationWrapper(gym.Wrapper):
    """Rescale observations with ``state / 255.0 - 0.5`` on both ``reset`` and ``step``.

    For inputs in [0, 255] this yields values in [-0.5, 0.5].

    NOTE(review): ``observation_space`` is left untouched, so it still describes
    the un-normalized observations of the wrapped environment.
    """

    def __init__(self, env):
        super().__init__(env)

    def step(self, action):
        """Step the wrapped environment and normalize the returned observation."""
        next_state, reward, terminated, truncated, info = self.env.step(action)
        return self._pixel_normalization(next_state), reward, terminated, truncated, info

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment and normalize the first observation.

        Args:
            seed: Optional seed forwarded to the wrapped environment's ``reset``.
            options: Optional reset options forwarded to the wrapped environment.

        Returns:
            Tuple ``(normalized_state, info)``.
        """
        # Forward the keyword-only arguments so seeded resets work through this wrapper.
        state, info = self.env.reset(seed=seed, options=options)
        return self._pixel_normalization(state), info

    def _pixel_normalization(self, state):
        """Map pixel values from [0, 255] to [-0.5, 0.5]."""
        return state / 255.0 - 0.5

In [26]:
class Dreamer:
    """Dreamer-style agent: learns a latent world model and trains an actor/critic in imagination.

    The class relies on names defined in other cells of this notebook: the
    network classes (``Encoder``, ``Decoder``, ``RSSM``, ``RewardPredictor``,
    ``Actor``, ``Critic``), ``ReplayBuffer``, the helper functions
    ``create_normal_dist`` / ``compute_lambda_returns`` and the module-level
    hyperparameters (``device``, ``batch_size``, ``horizon_length``, ...).
    """

    def __init__(self, model_path=""):
        """Create networks, replay buffer and optimizers.

        Args:
            model_path: Path of a checkpoint to restore. An empty string
                (the default) skips loading.
        """
        # model
        self.encoder = Encoder().to(device)
        self.decoder = Decoder().to(device)
        self.rssm = RSSM().to(device)
        self.reward_predictor = RewardPredictor().to(device)
        self.actor = Actor().to(device)
        self.critic = Critic().to(device)

        self.memory = ReplayBuffer()

        # optimizer: all world-model components share one optimizer
        self.world_model_params = (
              list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.rssm.parameters())
            + list(self.reward_predictor.parameters())
        )

        self.world_model_optimizer = optim.Adam(self.world_model_params, lr=world_model_lr)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

        # Latest values, kept only for logging in print_losses().
        self.reconstruction_loss = 0
        self.reward_loss = 0
        self.kl_loss = 0
        self.actor_loss = 0
        self.critic_loss = 0
        self.score = 0

        if model_path:
            print(f"Loading checkpoint from: {model_path}")
            self.load_checkpoint(model_path)

    def load_checkpoint(self, path):
        """Restore all network and optimizer states from the checkpoint at ``path``."""
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=device)

        self.encoder.load_state_dict(checkpoint['encoder'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.rssm.load_state_dict(checkpoint['rssm'])
        self.reward_predictor.load_state_dict(checkpoint['reward_predictor'])
        self.actor.load_state_dict(checkpoint['actor'])
        self.critic.load_state_dict(checkpoint['critic'])

        self.world_model_optimizer.load_state_dict(checkpoint['world_model_optimizer'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])
        print("Checkpoint loaded successfully.")

    def save_checkpoint(self, path):
        """Write all network and optimizer states to ``path`` (same keys as ``load_checkpoint`` reads)."""
        torch.save({
            'encoder': self.encoder.state_dict(),
            'decoder': self.decoder.state_dict(),
            'rssm': self.rssm.state_dict(),
            'reward_predictor': self.reward_predictor.state_dict(),
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'world_model_optimizer': self.world_model_optimizer.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
        }, path)

    def train(self, env, initial_episodes=20, checkpoint_interval=10):
        """Run the full training loop.

        Each iteration performs ``world_model_update_steps`` updates (world
        model followed by actor/critic), then collects ``collect_episodes`` new
        episodes and prints the latest losses.

        Args:
            env: Environment used to collect experience.
            initial_episodes: Episodes collected before the first update to
                seed the replay buffer.
            checkpoint_interval: A checkpoint ``model_ep{episode}.pth`` is
                written every this many iterations.
        """
        print("Collecting experiences...")
        self.collect_episodes(env, initial_episodes)  # initial data collection
        print(f"Collected experiences : {len(self.memory)}")

        for episode in range(1, episodes + 1):
            for _ in range(world_model_update_steps):
                experiences = self.memory.sample(batch_size, batch_length)
                posteriors, deterministics = self.dynamic_learning(experiences)
                self.behavior_learning(posteriors, deterministics)

            self.collect_episodes(env, collect_episodes)
            self.print_losses(episode)  # print the training losses

            if episode % checkpoint_interval == 0:
                self.save_checkpoint(f'model_ep{episode}.pth')

    def print_losses(self, episode):
        """Print the last episode score and the most recent losses on one line."""
        # float() accepts both the initial plain-number placeholders and 0-dim tensors,
        # so this is safe to call before the first update.
        print(f"episode: {episode:4} | "
              f"score: {int(self.score):3} | "
              f"Recon: {float(self.reconstruction_loss):8.2f} | "
              f"Reward: {float(self.reward_loss):6.4f} | "
              f"KL: {float(self.kl_loss):6.4f} | "
              f"Actor: {float(self.actor_loss):6.4f} | "
              f"Critic: {float(self.critic_loss):6.4f}")

    def _embed_observation(self, observation):
        """Encode a single numpy observation (3, 64, 64) into an embedding of shape (1, 1024)."""
        batch = torch.from_numpy(observation).unsqueeze(0).float().to(device)
        return self.encoder(batch)

    @torch.no_grad()
    def collect_episodes(self, env, episodes):
        """Play ``episodes`` episodes with the current actor and store every transition.

        ``self.score`` is set to the total reward of the last collected episode.

        Args:
            env: Environment to interact with.
            episodes: Number of episodes to collect.
        """
        for _ in range(episodes):
            posterior, deterministic = self.rssm.recurrent_model_input_init(1) # (1, 30), (1, 200)
            action = torch.zeros(1, action_size).to(device)                    # (1, 2)

            state, _info = env.reset() # (3, 64, 64)
            embedded_state = self._embed_observation(state) # (1, 1024)

            done = False
            episode_reward = 0

            while not done:
                deterministic = self.rssm.recurrent_model(posterior, action, deterministic)  # (1, 30), (1, 2), (1, 200) -> (1, 200)
                _, posterior = self.rssm.representation_model(embedded_state, deterministic) # (1, 1024), (1, 200) -> (1, 30)
                action = self.actor(posterior, deterministic)                                # (1, 30), (1, 200) -> (1, 2)

                # The one-hot action is stored as-is; the environment receives its index.
                buffer_action = action.cpu().numpy()
                env_action = buffer_action.argmax()

                next_state, reward, terminated, truncated, info = env.step(env_action)
                done = terminated or truncated

                self.memory.add(state, buffer_action, reward, next_state, done)

                embedded_state = self._embed_observation(next_state)
                state = next_state
                episode_reward += reward

            self.score = episode_reward

    def dynamic_learning(self, experiences):
        """Roll the RSSM over a sampled batch of sequences and update the world model.

        Args:
            experiences: Tuple ``(states, next_states, actions, rewards, dones)``
                as returned by ``ReplayBuffer.sample``; with the notebook
                defaults the shapes are
                states:      (50, 50, 3, 64, 64)
                next_states: (50, 50, 3, 64, 64)
                actions:     (50, 50, 2)
                rewards:     (50, 50, 1)
                dones:       (50, 50, 1)

        Returns:
            Detached ``(posteriors, deterministics)`` for time steps 1..T-1,
            i.e. shapes (B, T-1, 30) and (B, T-1, 200).
        """
        prior_means = []
        prior_stds = []
        posteriors = []
        posterior_means = []
        posterior_stds = []
        deterministics = []

        states, _, actions, _, _ = experiences
        # Take the batch layout from the data itself rather than from global settings.
        num_sequences, sequence_length = states.shape[:2]
        prior, deterministic = self.rssm.recurrent_model_input_init(num_sequences) # (50, 30), (50, 200)
        embedded_states = self.encoder(states) # (50, 50, 3, 64, 64) -> (50, 50, 1024)

        for t in range(1, sequence_length):
            deterministic = self.rssm.recurrent_model(prior, actions[:, t - 1], deterministic) # (50, 30), (50, 2), (50, 200) -> (50, 200)
            prior_dist, prior = self.rssm.transition_model(deterministic)
            posterior_dist, posterior = self.rssm.representation_model(embedded_states[:, t], deterministic)

            prior_means.append(prior_dist.mean)
            prior_stds.append(prior_dist.scale)
            posteriors.append(posterior)
            posterior_means.append(posterior_dist.mean)
            posterior_stds.append(posterior_dist.scale)
            deterministics.append(deterministic)

            # The next step is conditioned on the observation-informed posterior sample.
            prior = posterior

        # size: (50, 49, 30)
        prior_means = torch.stack(prior_means, dim=1)
        prior_stds = torch.stack(prior_stds, dim=1)
        posteriors = torch.stack(posteriors, dim=1)
        posterior_means = torch.stack(posterior_means, dim=1)
        posterior_stds = torch.stack(posterior_stds, dim=1)

        # size: (50, 49, 200)
        deterministics = torch.stack(deterministics, dim=1)

        self.world_model_update(
            experiences=experiences,
            prior_means=prior_means,
            prior_stds=prior_stds,
            posteriors=posteriors,
            posterior_means=posterior_means,
            posterior_stds=posterior_stds,
            deterministics=deterministics,
        )
        return posteriors.detach(), deterministics.detach()

    def world_model_update(
        self,
        experiences,
        prior_means,
        prior_stds,
        posteriors,
        posterior_means,
        posterior_stds,
        deterministics,
    ):
        """Take one optimizer step on the world model.

        The loss is ``KL(posterior || prior)`` (bounded below by ``free_nats``)
        minus the mean log-likelihood of the observations and of the rewards
        for time steps 1..T-1. The three loss terms are stored (detached) on
        the instance for logging.
        """
        states, _, _, rewards, _ = experiences

        # reconstruction loss
        reconstruction_dist = self.decoder(posteriors, deterministics)
        reconstruction_loss = reconstruction_dist.log_prob(states[:, 1:])

        # reward loss
        reward_dist = self.reward_predictor(posteriors, deterministics)
        reward_loss = reward_dist.log_prob(rewards[:, 1:])

        # kl divergence loss
        prior_dist = create_normal_dist(prior_means, prior_stds, event_shape=1,)
        posterior_dist = create_normal_dist(posterior_means, posterior_stds, event_shape=1,)

        kl_loss = torch.mean(kl.kl_divergence(posterior_dist, prior_dist))
        # Free nats: the KL term stops contributing gradient once it drops below the threshold.
        kl_loss = torch.max(torch.tensor(free_nats).to(device), kl_loss)

        world_model_loss = (
              kl_loss
            - reconstruction_loss.mean()
            - reward_loss.mean()
        )

        self.world_model_optimizer.zero_grad()
        world_model_loss.backward()
        nn.utils.clip_grad_norm_(self.world_model_params, clip_grad, grad_norm_type)
        self.world_model_optimizer.step()

        # Keep only detached values for logging so no autograd history is retained.
        self.reconstruction_loss = (-reconstruction_loss.mean()).detach()
        self.reward_loss = (-reward_loss.mean()).detach()
        self.kl_loss = kl_loss.detach()

    def behavior_learning(self, posteriors, deterministics):
        """Imagine ``horizon_length`` steps from every given latent state and update actor and critic.

        Args:
            posteriors: Detached stochastic states, shape (B, T, stochastic_size).
            deterministics: Detached deterministic states, shape (B, T, deterministic_size).
        """
        priors = posteriors.reshape(-1, stochastic_size) # (50, 49, 30) -> (2450, 30)
        deterministics = deterministics.reshape(-1, deterministic_size) # (50, 49, 200) -> (2450, 200)

        imagined_states = []
        imagined_deterministics = []

        for t in range(horizon_length):
            action = self.actor(priors, deterministics) # (2450, 30), (2450, 200) -> (2450, 2)
            deterministics = self.rssm.recurrent_model(priors, action, deterministics) # (2450, 30), (2450, 2), (2450, 200) -> (2450, 200)
            _, priors = self.rssm.transition_model(deterministics) # (2450, 200) -> (2450, 30)

            imagined_states.append(priors)
            imagined_deterministics.append(deterministics)

        imagined_states = torch.stack(imagined_states, dim=1)                 # (2450, 15, 30)
        imagined_deterministics = torch.stack(imagined_deterministics, dim=1) # (2450, 15, 200)

        self.agent_update(imagined_states, imagined_deterministics)

    def agent_update(self, priors, deterministics):
        """Update the actor to maximize, and the critic to regress, the lambda-returns.

        Args:
            priors: Imagined stochastic states, shape (N, horizon_length, stochastic_size).
            deterministics: Imagined deterministic states, shape (N, horizon_length, deterministic_size).
        """
        predicted_rewards = self.reward_predictor(priors, deterministics).mean # (2450, 15, 1)
        values = self.critic(priors, deterministics)                           # (2450, 15, 1)

        lambda_returns = compute_lambda_returns(predicted_rewards, values)

        # 1) Actor: gradients flow back through the imagined trajectory.
        actor_loss = -torch.mean(lambda_returns)

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), clip_grad, grad_norm_type)
        self.actor_optimizer.step()

        # 2) Critic: regress detached targets from detached states (all but the last step).
        pred_values = self.critic(priors.detach()[:, :-1], deterministics.detach()[:, :-1])
        target_values = lambda_returns.detach()
        value_loss = F.mse_loss(pred_values, target_values)

        self.critic_optimizer.zero_grad()
        value_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), clip_grad, grad_norm_type)
        self.critic_optimizer.step()

        # Detached copies for logging only.
        self.actor_loss = actor_loss.detach()
        self.critic_loss = value_loss.detach()

In [27]:
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of transitions with contiguous-sequence sampling.

    Storage shapes come from the module-level ``capacity``, ``observation_shape``
    and ``action_size``; sampled tensors are placed on the module-level ``device``.
    """

    def __init__(self):
        self.index = 0  # slot that the next add() writes to
        self.size = 0   # number of valid transitions currently stored
        self.capacity = capacity

        self.states_buffer      = np.empty((self.capacity, *observation_shape), dtype=np.float32)
        self.next_states_buffer = np.empty((self.capacity, *observation_shape), dtype=np.float32)
        self.actions_buffer     = np.empty((self.capacity, action_size), dtype=np.float32)
        self.rewards_buffer     = np.empty((self.capacity, 1), dtype=np.float32)
        self.dones_buffer       = np.empty((self.capacity, 1), dtype=np.float32)

    def __len__(self):
        """Number of stored transitions (at most ``capacity``)."""
        # ``index`` wraps around, so it cannot be used to tell how full the buffer is.
        return self.size

    def add(self, observation, action, reward, next_observation, done):
        """Store one transition, overwriting the oldest one once the buffer is full."""
        self.states_buffer[self.index]      = observation
        self.actions_buffer[self.index]     = action
        self.rewards_buffer[self.index]     = reward
        self.next_states_buffer[self.index] = next_observation
        self.dones_buffer[self.index]       = done

        self.index = (self.index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size, seq_len):
        """Sample ``batch_size`` sequences of ``seq_len`` consecutive transitions.

        Args:
            batch_size: Number of sequences to draw.
            seq_len: Length of every sequence.

        Returns:
            Tuple ``(states, next_states, actions, rewards, dones)`` of tensors;
            with the notebook defaults the shapes are
            states:      (50, 50, 3, 64, 64)
            next_states: (50, 50, 3, 64, 64)
            actions:     (50, 50, 2)
            rewards:     (50, 50, 1)
            dones:       (50, 50, 1)

        Raises:
            ValueError: If fewer than ``seq_len`` transitions are stored.
        """
        current_size = len(self)

        if current_size < seq_len:
            raise ValueError(
                f"too small dataset: {current_size} stored transitions, "
                f"need at least seq_len={seq_len}"
            )

        # Slot of the oldest stored transition: 0 while filling up, the write head once full.
        oldest = (self.index - current_size) % self.capacity
        # Start offsets are counted from the oldest transition and limited so that a
        # sequence never runs past the newest one (i.e. never crosses the write head).
        start_offsets = np.random.randint(0, current_size - seq_len + 1, batch_size).reshape(-1, 1) # (50, 1)
        steps = np.arange(seq_len).reshape(1, -1)                                                    # (1, 50)
        indices = (oldest + start_offsets + steps) % self.capacity                                   # (50, 50)

        states      = torch.as_tensor(self.states_buffer[indices], device=device).float()
        next_states = torch.as_tensor(self.next_states_buffer[indices], device=device).float()
        actions     = torch.as_tensor(self.actions_buffer[indices], device=device)
        rewards     = torch.as_tensor(self.rewards_buffer[indices], device=device)
        dones       = torch.as_tensor(self.dones_buffer[indices], device=device)

        return states, next_states, actions, rewards, dones

In [28]:
"""
[Encoder]
input  : (B, 3, 64, 64) or (B, T, 3, 64, 64)
output : (B, 1024) or (B, T, 1024)
"""
class Encoder(nn.Module):
    """Convolutional encoder turning 3x64x64 images into 1024-dimensional embeddings.

    Accepts a batch ``(B, 3, 64, 64)`` or a batch of sequences
    ``(B, T, 3, 64, 64)`` and returns ``(B, 1024)`` or ``(B, T, 1024)``.
    """

    def __init__(self):
        super().__init__()
        # Spatial size per stage (kernel 4, stride 2): 64 -> 31 -> 14 -> 6 -> 2.
        channels = (3, 32, 64, 128, 256)
        layers = []
        for in_channels, out_channels in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv2d(in_channels, out_channels, 4, 2))
            layers.append(nn.ReLU())
        layers.append(nn.Flatten())  # (B, 256, 2, 2) -> (B, 1024)
        self.network = nn.Sequential(*layers)
        self.network.apply(initialize_weights)

    def forward(self, x):
        if x.ndim != 5:  # (B, C, H, W)
            return self.network(x)

        # (B, T, C, H, W): fold time into the batch axis for the convolutions, then unfold.
        batch, steps = x.shape[:2]
        merged = x.view(batch * steps, *x.shape[2:])
        return self.network(merged).view(batch, steps, -1)

In [29]:
"""
[Decoder]
input  : posterior (B, T, 30), deterministic (B, T, 200)
output : Normal dist (B, T, 3, 64, 64)
"""
class Decoder(nn.Module):
    """Transposed-convolution decoder from latent state to an image distribution.

    ``forward`` returns a unit-variance Normal over images of shape
    ``observation_shape`` whose mean is the network output.
    """

    def __init__(self):
        super().__init__()
        latent_size = stochastic_size + deterministic_size
        self.network = nn.Sequential(
            nn.Linear(latent_size, 1024),                            # (N, 230)        -> (N, 1024)
            nn.Unflatten(1, (1024, 1, 1)),                           # (N, 1024)       -> (N, 1024, 1, 1)
            nn.ConvTranspose2d(1024, 128, kernel_size=5, stride=2),  # (N, 1024, 1, 1) -> (N, 128, 5, 5)
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),    # (N, 128, 5, 5)  -> (N, 64, 13, 13)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),     # (N, 64, 13, 13) -> (N, 32, 30, 30)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=6, stride=2),      # (N, 32, 30, 30) -> (N, 3, 64, 64)
        )
        self.network.apply(initialize_weights)

    def forward(self, posterior, deterministic):
        mean = temporal_forward(
            self.network, posterior, deterministic, output_shape=observation_shape
        )
        # All image axes form one event, so log_prob sums over every pixel.
        return create_normal_dist(mean, std=1, event_shape=len(observation_shape))

In [30]:
class RSSM(nn.Module):
    """Container for the recurrent, transition (prior) and representation (posterior) models."""

    def __init__(self):
        super().__init__()
        self.recurrent_model = RecurrentModel()
        self.transition_model = TransitionModel()
        self.representation_model = RepresentationModel()

    def recurrent_model_input_init(self, batch_size):
        """Return the zero initial ``(stochastic, deterministic)`` states for ``batch_size`` sequences."""
        initial_stochastic = self.transition_model.input_init(batch_size)
        initial_deterministic = self.recurrent_model.input_init(batch_size)
        return initial_stochastic, initial_deterministic

"""
[RecurrentModel]
input  : stochastic (B, 30), action (B, 2), deterministic (B, 200)
output : deterministic (B, 200)
"""
class RecurrentModel(nn.Module):
    """GRU-based update of the deterministic state from the previous stochastic state and action."""

    def __init__(self, hidden_size=200):
        super().__init__()
        self.hidden_size = hidden_size

        input_size = stochastic_size + action_size
        self.network = nn.Sequential(
            nn.Linear(input_size, self.hidden_size),
            nn.ELU(),
        )
        self.recurrent = nn.GRUCell(self.hidden_size, deterministic_size)

    def forward(self, stochastic, action, deterministic):
        # Embed [stochastic, action], then advance the GRU whose hidden state is `deterministic`.
        features = self.network(torch.cat((stochastic, action), 1))
        return self.recurrent(features, deterministic)

    def input_init(self, batch_size):
        """Zero deterministic state of shape (batch_size, deterministic_size) on ``device``."""
        return torch.zeros(batch_size, deterministic_size, device=device)

"""
[TransitionModel]
input  : deterministic (B, 200)
output : Normal dist, prior (B, 30)
"""
class TransitionModel(nn.Module):
    """Prior over the stochastic state, predicted from the deterministic state alone."""

    def __init__(self, min_std=0.1, hidden_size=200):
        super().__init__()
        self.min_std = min_std
        self.hidden_size = hidden_size

        # The last layer emits mean and raw std side by side (2 * stochastic_size).
        layers = [
            nn.Linear(deterministic_size, self.hidden_size),
            nn.ELU(),
            nn.Linear(self.hidden_size, stochastic_size * 2),
        ]
        self.network = nn.Sequential(*layers)
        self.network.apply(initialize_weights)

    def forward(self, deterministic):
        """Return ``(prior_dist, sample)``; the sample is reparameterized (``rsample``)."""
        prior_dist = create_normal_dist(self.network(deterministic), min_std=self.min_std)
        return prior_dist, prior_dist.rsample()

    def input_init(self, batch_size):
        """Zero stochastic state of shape (batch_size, stochastic_size) on ``device``."""
        return torch.zeros(batch_size, stochastic_size, device=device)

"""
[RepresentationModel]
input  : embedded_observation (B, 1024), deterministic (B, 200)
output : Normal dist, posterior (B, 30)
"""
class RepresentationModel(nn.Module):
    """Posterior over the stochastic state, given the observation embedding and deterministic state."""

    def __init__(self, min_std=0.1, hidden_size=200):
        super().__init__()
        self.min_std = min_std
        self.hidden_size = hidden_size

        input_size = embedded_state_size + deterministic_size
        # The last layer emits mean and raw std side by side (2 * stochastic_size).
        layers = [
            nn.Linear(input_size, self.hidden_size),
            nn.ELU(),
            nn.Linear(self.hidden_size, stochastic_size * 2),
        ]
        self.network = nn.Sequential(*layers)
        self.network.apply(initialize_weights)

    def forward(self, embedded_observation, deterministic):
        """Return ``(posterior_dist, sample)``; the sample is reparameterized (``rsample``)."""
        features = torch.cat((embedded_observation, deterministic), 1)
        posterior_dist = create_normal_dist(self.network(features), min_std=self.min_std)
        return posterior_dist, posterior_dist.rsample()

"""
[RewardPredictor]
input  : posterior (B, T, 30), deterministic (B, T, 200)
output : Normal dist (B, T, 1)
"""
class RewardPredictor(nn.Module):
    """MLP predicting a unit-variance Normal over the scalar reward from the latent state."""

    def __init__(self, hidden_size=400):
        super().__init__()
        self.hidden_size = hidden_size

        latent_size = stochastic_size + deterministic_size
        layers = [
            nn.Linear(latent_size, self.hidden_size),
            nn.ELU(),
            nn.Linear(self.hidden_size, 1),
        ]
        self.network = nn.Sequential(*layers)
        self.network.apply(initialize_weights)

    def forward(self, posterior, deterministic):
        mean = temporal_forward(self.network, posterior, deterministic, output_shape=(1,))
        return create_normal_dist(mean, std=1, event_shape=1)

In [31]:
"""
[Actor]
input  : posterior (B, 30), deterministic (B, 200)
output : action (B, 2)
"""
class Actor(nn.Module):
    """Discrete policy that outputs a one-hot action with straight-through gradients."""

    def __init__(self, hidden_size=400):
        super().__init__()
        self.hidden_size = hidden_size

        latent_size = stochastic_size + deterministic_size
        layers = [
            nn.Linear(latent_size, self.hidden_size),
            nn.ELU(),
            nn.Linear(self.hidden_size, action_size),
        ]
        self.network = nn.Sequential(*layers)
        self.network.apply(initialize_weights)

    def forward(self, posterior, deterministic):
        logits = self.network(torch.cat((posterior, deterministic), -1))
        dist = OneHotCategorical(logits=logits)
        probs = dist.probs
        # Straight-through estimator: the forward value is the sampled one-hot action,
        # while the gradient is taken with respect to the class probabilities.
        return dist.sample() + probs - probs.detach()

"""
[Critic]
input  : posterior (B, 30), deterministic (B, 200)
output : value (B, 1)
"""
class Critic(nn.Module):
    """MLP estimating a scalar value for each latent state (2-D or 3-D input)."""

    def __init__(self, hidden_size=400):
        super().__init__()
        self.hidden_size = hidden_size

        latent_size = stochastic_size + deterministic_size
        layers = [
            nn.Linear(latent_size, self.hidden_size),
            nn.ELU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ELU(),
            nn.Linear(self.hidden_size, 1),
        ]
        self.network = nn.Sequential(*layers)
        self.network.apply(initialize_weights)

    def forward(self, posterior, deterministic):
        return temporal_forward(self.network, posterior, deterministic, output_shape=(1,))

In [32]:
def initialize_weights(module):
    """Kaiming-uniform weights and zero bias for Linear / Conv2d / ConvTranspose2d modules.

    Intended for ``nn.Module.apply``; every other module type is left untouched.
    """
    if isinstance(module, nn.Linear):
        kaiming_uniform_(module.weight)
    elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        kaiming_uniform_(module.weight, nonlinearity="relu")
    else:
        return
    constant_(module.bias, 0)

def temporal_forward(network, x, y, output_shape):
    """Concatenate ``x`` and ``y`` on the last axis and run ``network`` on every row.

    Args:
        network: Module applied to 2-D input ``(N, C1 + C2)``.
        x, y: Tensors of shape ``(B, C)`` or ``(B, T, C)`` with matching leading axes.
        output_shape: Per-row output shape of ``network``.

    Returns:
        Tensor of shape ``(B, *output_shape)`` or ``(B, T, *output_shape)``.

    Raises:
        ValueError: If the concatenated input is neither 2-D nor 3-D.
    """
    x = torch.cat((x, y), -1) # (..., C1), (..., C2) -> (..., C1 + C2)

    if x.ndim not in (2, 3):
        raise ValueError(f"input dimension must be 2 or 3 | current x.shape: {x.shape}")

    # Fold all leading axes into one batch axis for the network, then restore them.
    leading_dims = x.shape[:-1]
    flat_output = network(x.reshape(-1, x.shape[-1]))
    return flat_output.reshape(*leading_dims, *output_shape)

def create_normal_dist(x, std=None, min_std=0.1, event_shape=None):
    """Build a Normal distribution, optionally wrapped in ``Independent``.

    Args:
        x: The mean when ``std`` is given; otherwise a tensor whose last axis
            is split in half into mean and raw standard deviation.
        std: Fixed standard deviation, or ``None`` to derive it from ``x``.
        min_std: Added to ``softplus(raw_std)`` when the std is derived from ``x``.
        event_shape: If truthy, the number of trailing batch axes reinterpreted
            as event axes.

    Returns:
        ``Normal`` or ``Independent(Normal, event_shape)``.
    """
    if std is not None:
        mean = x
    else:
        mean, raw_std = torch.chunk(x, 2, -1)
        std = F.softplus(raw_std) + min_std

    dist = Normal(mean, std)
    return Independent(dist, event_shape) if event_shape else dist

def compute_lambda_returns(rewards, values, discount=None, lambda_weight=None):
    """Compute TD(lambda) returns along an imagined trajectory.

    With ``H = rewards.shape[1]`` the recursion, evaluated backwards in time, is

        G_t = r_t + discount * ((1 - lambda) * v_{t+1} + lambda * G_{t+1}),
        bootstrapped with G_{H-1} = v_{H-1}.

    Args:
        rewards: Predicted rewards, shape (N, H, ...).
        values: Predicted values, same shape as ``rewards``.
        discount: Discount factor; defaults to the module-level ``gamma``.
        lambda_weight: Lambda mixing weight; defaults to the module-level ``lambda_``.

    Returns:
        Tensor of shape (N, H - 1, ...) holding G_0 .. G_{H-2}.
    """
    discount = gamma if discount is None else discount
    lambda_weight = lambda_ if lambda_weight is None else lambda_weight

    current_rewards = rewards[:, :-1] # r_t     : t in 0 .. H-2
    next_values     = values[:, 1:]   # v_{t+1} : value of the state that follows step t

    lambda_return = next_values[:, -1] # bootstrap with the value of the final state
    # Precompute the part of the target that does not depend on the recursion.
    td_targets = current_rewards + discount * (1 - lambda_weight) * next_values

    lambda_returns = []
    for index in reversed(range(td_targets.shape[1])):
        # Carry the return backwards; each step builds on the return of the step after it.
        lambda_return = td_targets[:, index] + discount * lambda_weight * lambda_return
        lambda_returns.append(lambda_return)
    return torch.stack(lambda_returns[::-1], dim=1)

In [33]:
# --- Observation preprocessing ---
height, width = 64, 64  # size of the resized frames
frame_skip = 2          # number of environment steps each action is repeated for

# --- Optimizer learning rates ---
world_model_lr = 1e-3
actor_lr = 1e-4
critic_lr = 1e-4

# --- Training schedule ---
episodes = 1000                 # outer training iterations
world_model_update_steps = 100  # update steps per outer iteration
collect_episodes = 1            # episodes collected after each iteration

# --- Replay sampling: (batch_size, batch_length) sequences per update ---
batch_size = 50
batch_length = 50

# --- Replay capacity and latent sizes ---
capacity = 100_000
deterministic_size = 200
stochastic_size = 30
embedded_state_size = 1024

# --- Loss / imagination settings ---
free_nats = 3        # lower bound applied to the KL loss
horizon_length = 15  # imagined steps per start state

# --- Gradient clipping and lambda-return coefficients ---
clip_grad = 100
grad_norm_type = 2
gamma = 0.99
lambda_ = 0.95

In [34]:
# Build the pixel-based CartPole pipeline; the trailing comments describe the observation
# after each wrapper. Width/height are passed by keyword because the wrapper's signature
# is (env, width, height) -- positional (height, width) would silently swap them.
env = gym.make("CartPole-v1", render_mode="rgb_array")          # rendered frame : (400, 600, 3)
env = ResizeObservationWrapper(env, width=width, height=height) # state : (64, 64, 3)
env = ChannelFirstEnv(env)                                      # state : (3, 64, 64)
env = FrameSkipWrapper(env, skip=frame_skip)
env = PixelNormalizationWrapper(env)                            # [0 ~ 255] -> [-0.5 ~ 0.5]

In [35]:
# Model dimensions are derived from the wrapped environment.
observation_shape = env.observation_space.shape # (3, 64, 64)
action_size = env.action_space.n                # 2

# Use the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Empty string trains from scratch; set to a saved checkpoint
# (e.g. 'model_ep200.pth') to resume from it.
model_path = ''

agent = Dreamer(model_path)

In [36]:
# Start training: seeds the replay buffer, then alternates model/agent updates with
# data collection and writes a checkpoint every 10 iterations.
agent.train(env)

Collecting experiences...


  from pkg_resources import resource_stream, resource_exists


Collected experiences : 170


KeyboardInterrupt: 