In [1]:
%cd RL_study/dreamer/

/mnt/c/Users/mingu/OneDrive/바탕 화면/성균관대/리서치인턴/공부/RL_study/dreamer


In [2]:
import gymnasium as gym
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal, kl_divergence

from tqdm import tqdm
from models import *
from logger import Logger


env = gym.make('CarRacing-v2')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


2024-10-04 01:02:27.405868: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-04 01:02:27.413680: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-10-04 01:02:27.420612: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-04 01:02:27.432515: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-04 01:02:27.436037: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attemptin

In [3]:
action_dim = env.action_space.shape[0]
obs_shape = env.observation_space.shape
print("action space: ",action_dim,", obs shape: ", obs_shape,sep='')

action space: 3, obs shape: (96, 96, 3)


In [4]:
def collect_data(env,state_dim, transition_representation, agent,replay_buffer, num_episode, device, training=True):
    print("collecting data...")
    score=0
    for _ in tqdm(range(num_episode)):
        obs, info = env.reset()
        done = False
        experience = []
        prev_state = torch.zeros(1, state_dim).to(device)
        prev_deter = transition_representation.init_hidden(1).to(device)
        prev_action = torch.zeros(1, action_dim).to(device)
        with torch.no_grad():
            while not done:
                #obs(96x96x3) -> (3x96x96) -> (1x3x96x96)
                obs = torch.tensor(obs, dtype=torch.float32).permute(2,0,1).unsqueeze(0).to(device)/255
                # s_t-1, a_t-1, o_t-1 -> s_t
                posterior_mean, posterior_std, prev_deter = transition_representation.posterior(prev_state, prev_action, prev_deter,obs)
                cur_state = posterior_mean + posterior_std*torch.normal(0, 1, posterior_mean.size()).to(device)

                action_mu, action_std = agent(cur_state, prev_deter)
                eps = torch.normal(0, 1, (1,action_dim)).to(device)
                if training:
                    cur_action = torch.tanh(action_mu + action_std*eps)
                else:
                    cur_action = torch.tanh(action_mu)
                next_obs, reward, terminated, truncated, info  = env.step(cur_action[0].cpu().numpy())
                done = terminated or truncated
                
                experience.append((np.array(obs.squeeze(0).cpu()), np.array(cur_action.squeeze(0).detach().cpu()), reward, done))
                
                obs = next_obs
                prev_state = cur_state
                prev_action = cur_action
                score+=reward
        if training:
            for exp in experience:
                replay_buffer.push(exp)
    return score/num_episode

In [5]:
def lambda_return(rewards, values, gamma, lambda_):
    # rewards, values : (Horizon+1, seq*batch)
    # 어렵다
    V_lambda = torch.zeros_like(rewards, device=rewards.device)

    H = rewards.shape[0] - 1
    V_n = torch.zeros_like(rewards, device=rewards.device)
    V_n[H] = values[H]
    for n in range(1, H+1):
        # n-step 계산 하기 위함
        # 각 step의 value 목표
        V_n[:-n] = (gamma ** n) * values[n:]
        for k in range(1, n+1):
            # n step의 reward 합 진행
            if k == n:
                V_n[:-n] += (gamma ** (n-1)) * rewards[k:]
            else:
                V_n[:-n] += (gamma ** (k-1)) * rewards[k:-n+k]

        # add lambda_ weighted n-step target to compute lambda target
        if n == H:
            V_lambda += (lambda_ ** (H-1)) * V_n
        else:
            V_lambda += (1 - lambda_) * (lambda_ ** (n-1)) * V_n
            
    return V_lambda

In [6]:
def train(batch,state_dim,deterministic_dim, device, transition_representation, reward_model, observation, actor, value, model_optimizer, actor_optimizer, critic_optimizer):
    obs_seq = []
    action_seq = []
    reward_seq = []
    #batch = batch, seq, (obs, action, reward, done)
    for seq in batch:
        obs_temp=[]
        action_temp=[]
        reward_temp=[]
        for (obs, action, reward, done) in seq:
            obs_temp.append(obs)
            action_temp.append(action)
            reward_temp.append(reward)
        obs_seq.append(obs_temp)
        action_seq.append(action_temp)
        reward_seq.append(reward_temp)
    obs_seq = torch.tensor(obs_seq, dtype=torch.float32).to(device)
    action_seq = torch.tensor(action_seq, dtype=torch.float32).to(device)
    reward_seq = torch.tensor(reward_seq, dtype=torch.float32).to(device)
    batch_size, seq_len, _, _, _ = obs_seq.size()
    
    prev_deter = transition_representation.init_hidden(batch_size).to(device)
    prev_state = torch.zeros(batch_size, state_dim).to(device)
    
    states = torch.zeros(seq_len,batch_size, state_dim).to(device)
    deters = torch.zeros(seq_len,batch_size, deterministic_dim).to(device)
    
    beta=0.1 #kl조절
    imagine_horizon=15
    gamma=0.99
    lambda_=0.95
    kl_loss = 0
    reconstruction_loss = 0
    reward_loss = 0
    
    total_kl_loss = 0
    total_reconstruction_loss = 0
    total_reward_loss = 0
    
    action_prev = action_seq[:, 0].to(device)
    total_loss=torch.zeros(1).to(device)
    for t in range(1,seq_len):
        obs = obs_seq[:, t].to(device)
        action = action_seq[:, t].to(device)
        reward = reward_seq[:, t].to(device)
        prior_mean, prior_std, _ = transition_representation(prev_state, action_prev, prev_deter)
        posterior_mean, posterior_std, cur_deter = transition_representation.posterior(prev_state, action_prev, prev_deter,obs)
        
        state = posterior_mean + posterior_std*torch.normal(0, 1, posterior_mean.size()).to(device)
        obs_pred = observation(state, cur_deter)
        reconstruction_loss = nn.functional.mse_loss(obs_pred, obs)
        
        
        reward_pred = reward_model(state, cur_deter)
        reward_loss = nn.functional.mse_loss(reward_pred, reward)
        
        prior = Normal(prior_mean, prior_std)
        posterior = Normal(posterior_mean, posterior_std)
        kl_loss = kl_divergence(posterior, prior).mean()

        
        total_loss += reconstruction_loss + reward_loss + beta*kl_loss

        action_prev = action
        prev_state = state
        prev_deter = cur_deter
        
        states[t] = state
        deters[t] = cur_deter
        
        total_kl_loss += kl_loss.item()
        total_reconstruction_loss += reconstruction_loss.item()
        total_reward_loss += reward_loss.item()
    model_optimizer.zero_grad()
    total_loss.backward()
    model_optimizer.step()

    
    ##actor, critic 학습
    
    #states (seq, batch, state_dim) -> (seq*batch, state_dim)
    #deters (seq, batch, deterministic_dim) -> (seq*batch, deterministic_dim)
    states = states.view(-1, state_dim).detach()
    deters = deters.view(-1, deterministic_dim).detach()
    
    imagined_states = [states]
    imagined_deters = [deters]
    
    rewards = []
    values = []
    
    
    rewards.append(reward_model(states, deters).squeeze())
    values.append(value(states, deters).squeeze())
    
    for t in range(1,imagine_horizon+1):
        action_mu, action_std = actor(imagined_states[t-1], imagined_deters[t-1])
        eps = torch.normal(0, 1, (action_mu.size())).to(device)
        action = torch.tanh(action_mu + action_std*eps)
        
        prior_mean, prior_std, deter = transition_representation(imagined_states[t-1], action, imagined_deters[t-1])
        state = prior_mean + prior_std*torch.normal(0, 1, prior_mean.size()).to(device)
        
        imagined_states.append(state)
        imagined_deters.append(deter)
        
        rewards.append(reward_model(imagined_states[t], imagined_deters[t]).squeeze())
        values.append( value(imagined_states[t], imagined_deters[t]).squeeze())
    
    imagined_states = torch.stack(imagined_states, dim=0)
    imagined_deters = torch.stack(imagined_deters, dim=0)
    values = torch.stack(values, dim=0)
    rewards = torch.stack(rewards, dim=0)
    
    returns = lambda_return(rewards, values,0.99, 0.95)
    
    critic_loss = nn.functional.mse_loss(values[1:],returns[1:].detach())
    critic_optimizer.zero_grad()
    critic_loss.backward(retain_graph=True)
    torch.nn.utils.clip_grad_norm_(value.parameters(), max_norm=100)
    critic_optimizer.step()
    
    actor_loss = -returns.mean()
    actor_optimizer.zero_grad()
    actor_loss.backward()
    torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=100)
    actor_optimizer.step()
    
    print("actor loss: ",actor_loss.item(),", critic loss: ",critic_loss.item(),sep='')
    
    return total_kl_loss/(seq_len-1), total_reconstruction_loss/(seq_len-1), total_reward_loss/(seq_len-1), actor_loss.item(), critic_loss.item()
    

In [7]:
state_dim=64
deterministic_dim=256
model_lr=1e-4
actor_critc_lr=1e-4
transition_representation=TransitionRepresentationModel(state_dim, action_dim).to(device)
observation=ObservationModel(state_dim,deterministic_dim, obs_shape[2]).to(device)
reward=RewardModel(state_dim,deterministic_dim).to(device)

agent=Agent(state_dim,deterministic_dim, action_dim).to(device)
value=ValueModel(state_dim,deterministic_dim).to(device)

model_params = list(transition_representation.parameters()) + list(observation.parameters()) + list(reward.parameters())
model_optimizer = optim.Adam(model_params, lr=model_lr)
actor_optimizer = optim.Adam(agent.parameters(), lr=actor_critc_lr)
critic_optimizer = optim.Adam(value.parameters(), lr=actor_critc_lr)

#state, action, reward, next_state, done 저장하고 sampling 가능
replay_buffer = ReplayBufferSeq(100000)
logger = Logger('./logs')

In [8]:
num_epochs = 10000
batch_size = 64
seq_len = 50

world_episodes = 1
update_step = 20

seed_episodes = 5
test_interval = 3
save_interval = 20
print("collecting seed data...")
collect_data(env,state_dim, transition_representation, agent, replay_buffer, seed_episodes, device)

for epoch in range(num_epochs):
    train_score=collect_data(env,state_dim, transition_representation, agent, replay_buffer, world_episodes, device)
    logger.log(epoch*update_step,train_score=train_score)

    if len(replay_buffer) < batch_size*seq_len:
        continue
    
    #train world model and actor, critic
    for _ in range(update_step):
        batch = replay_buffer.sample_seq(batch_size, seq_len)
        kl_loss,reconst_loss, reward_loss, actor_loss, critic_loss=train(batch,state_dim,deterministic_dim, device, transition_representation, reward, observation, agent, value, model_optimizer, actor_optimizer, critic_optimizer)
        logger.log(epoch*update_step+_,epoch=epoch, kl_loss=kl_loss, reconst_loss=reconst_loss, reward_loss=reward_loss, actor_loss=actor_loss, critic_loss=critic_loss)

    if epoch % test_interval == 0:
        test_score=collect_data(env,state_dim, transition_representation, agent, replay_buffer, world_episodes, device,training=False)
        logger.log(epoch*update_step,test_score=test_score)
    if epoch % save_interval == 0:
        torch.save(transition_representation.state_dict(), 'transition_representation.pth')
        torch.save(observation.state_dict(), 'observation.pth')
        torch.save(reward.state_dict(), 'reward.pth')
        torch.save(agent.state_dict(), 'agent.pth')
        torch.save(value.state_dict(), 'value.pth')
torch.save(transition_representation.state_dict(), 'transition_representation.pth')
torch.save(observation.state_dict(), 'observation.pth')
torch.save(reward.state_dict(), 'reward.pth')
torch.save(agent.state_dict(), 'agent.pth')
torch.save(value.state_dict(), 'value.pth')

collecting seed data...
collecting data...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:57<00:00, 11.44s/it]


collecting data...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:11<00:00, 11.14s/it]

2024-10-04 01:03:38,565 global_step: 0,train_score: -51.127819548872935, 



  obs_seq = torch.tensor(obs_seq, dtype=torch.float32).to(device)
  reward_loss = nn.functional.mse_loss(reward_pred, reward)


actor loss: -0.16721706092357635, critic loss: 0.02095862850546837
2024-10-04 01:03:54,752 global_step: 0,epoch: 0, kl_loss: 0.04724697250758811, reconst_loss: 0.05853355789975244, reward_loss: 0.21097925012665136, actor_loss: -0.16721706092357635, critic_loss: 0.02095862850546837, 
actor loss: -0.08586236834526062, critic loss: 0.0065489462576806545
2024-10-04 01:04:08,117 global_step: 1,epoch: 0, kl_loss: 0.04654910202062099, reconst_loss: 0.0583842969974693, reward_loss: 0.18402682248578997, actor_loss: -0.08586236834526062, critic_loss: 0.0065489462576806545, 
actor loss: -0.007309821899980307, critic loss: 0.013196372427046299
2024-10-04 01:04:19,666 global_step: 2,epoch: 0, kl_loss: 0.0455610510532041, reconst_loss: 0.0580580295348654, reward_loss: 0.15516277976638201, actor_loss: -0.007309821899980307, critic_loss: 0.013196372427046299, 
actor loss: 0.08251304179430008, critic loss: 0.039441704750061035
2024-10-04 01:04:30,616 global_step: 3,epoch: 0, kl_loss: 0.0448546020114528

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:10<00:00, 10.86s/it]

2024-10-04 01:07:37,593 global_step: 0,test_score: -92.53731343283484, 





collecting data...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:12<00:00, 12.64s/it]

2024-10-04 01:07:50,240 global_step: 20,train_score: -63.503649635037185, 





actor loss: 0.40864279866218567, critic loss: 0.06950075179338455
2024-10-04 01:08:01,156 global_step: 20,epoch: 1, kl_loss: 0.03281098168001187, reconst_loss: 0.0550857380184592, reward_loss: 0.1697316064252233, actor_loss: 0.40864279866218567, critic_loss: 0.06950075179338455, 
actor loss: 0.4085176885128021, critic loss: 0.060808081179857254
2024-10-04 01:08:12,098 global_step: 21,epoch: 1, kl_loss: 0.03223165118952795, reconst_loss: 0.05479680815217446, reward_loss: 0.18455084345341488, actor_loss: 0.4085176885128021, critic_loss: 0.060808081179857254, 
actor loss: 0.4122876524925232, critic loss: 0.054538071155548096
2024-10-04 01:08:23,029 global_step: 22,epoch: 1, kl_loss: 0.03185971272748192, reconst_loss: 0.054573401048475384, reward_loss: 0.1842106553982487, actor_loss: 0.4122876524925232, critic_loss: 0.054538071155548096, 
actor loss: 0.42517340183258057, critic loss: 0.052988890558481216
2024-10-04 01:08:34,166 global_step: 23,epoch: 1, kl_loss: 0.03114822572477314, recons

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:10<00:00, 10.54s/it]

2024-10-04 01:11:40,378 global_step: 40,train_score: -73.50993377483466, 





actor loss: 0.8868411183357239, critic loss: 0.14424102008342743
2024-10-04 01:11:51,530 global_step: 40,epoch: 2, kl_loss: 0.02296482807985146, reconst_loss: 0.04810088279904151, reward_loss: 0.18756635884079625, actor_loss: 0.8868411183357239, critic_loss: 0.14424102008342743, 
actor loss: 0.9032642841339111, critic loss: 0.1380825787782669
2024-10-04 01:12:02,518 global_step: 41,epoch: 2, kl_loss: 0.022433251570149953, reconst_loss: 0.04802800274016906, reward_loss: 0.20131853398420296, actor_loss: 0.9032642841339111, critic_loss: 0.1380825787782669, 
actor loss: 0.9197560548782349, critic loss: 0.13106907904148102
2024-10-04 01:12:13,482 global_step: 42,epoch: 2, kl_loss: 0.02230836658225376, reconst_loss: 0.04806892429383434, reward_loss: 0.2067033797645067, actor_loss: 0.9197560548782349, critic_loss: 0.13106907904148102, 
actor loss: 0.9356011748313904, critic loss: 0.12287493795156479
2024-10-04 01:12:24,587 global_step: 43,epoch: 2, kl_loss: 0.02175178408280623, reconst_loss: 

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:10<00:00, 10.45s/it]

2024-10-04 01:15:29,178 global_step: 60,train_score: -79.79797979797951, 





actor loss: 1.4963157176971436, critic loss: 0.11335998773574829
2024-10-04 01:15:40,244 global_step: 60,epoch: 3, kl_loss: 0.01630589144532474, reconst_loss: 0.04471880334372423, reward_loss: 0.15153176794113704, actor_loss: 1.4963157176971436, critic_loss: 0.11335998773574829, 
actor loss: 1.5492439270019531, critic loss: 0.11587461084127426
2024-10-04 01:15:51,301 global_step: 61,epoch: 3, kl_loss: 0.016071407982072204, reconst_loss: 0.04408758597410455, reward_loss: 0.1501746677546477, actor_loss: 1.5492439270019531, critic_loss: 0.11587461084127426, 
actor loss: 1.6075255870819092, critic loss: 0.11851570755243301
2024-10-04 01:16:02,335 global_step: 62,epoch: 3, kl_loss: 0.015919057113042444, reconst_loss: 0.04427242453913299, reward_loss: 0.1934288083171776, actor_loss: 1.6075255870819092, critic_loss: 0.11851570755243301, 
actor loss: 1.666734218597412, critic loss: 0.12089784443378448
2024-10-04 01:16:13,439 global_step: 63,epoch: 3, kl_loss: 0.01570609815380707, reconst_loss:

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:10<00:00, 10.06s/it]

2024-10-04 01:19:19,604 global_step: 60,test_score: -40.35087719298306, 





collecting data...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:11<00:00, 11.24s/it]

2024-10-04 01:19:30,843 global_step: 80,train_score: -65.98639455782379, 





actor loss: 5.037748336791992, critic loss: 0.31155452132225037
2024-10-04 01:19:41,695 global_step: 80,epoch: 4, kl_loss: 0.011375568731098759, reconst_loss: 0.03985837025910008, reward_loss: 0.16434116211092595, actor_loss: 5.037748336791992, critic_loss: 0.31155452132225037, 
actor loss: 5.4593658447265625, critic loss: 0.3270821273326874
2024-10-04 01:19:52,777 global_step: 81,epoch: 4, kl_loss: 0.010876764590870969, reconst_loss: 0.03993202084485365, reward_loss: 0.20035658496413, actor_loss: 5.4593658447265625, critic_loss: 0.3270821273326874, 
actor loss: 5.959342002868652, critic loss: 0.3752315938472748
2024-10-04 01:20:03,683 global_step: 82,epoch: 4, kl_loss: 0.010370833324078394, reconst_loss: 0.03913581819862735, reward_loss: 0.16958629906804737, actor_loss: 5.959342002868652, critic_loss: 0.3752315938472748, 
actor loss: 6.516692161560059, critic loss: 0.42621833086013794
2024-10-04 01:20:14,957 global_step: 83,epoch: 4, kl_loss: 0.009901970064229503, reconst_loss: 0.0390

  0%|                                                                                                                                 | 0/1 [00:00<?, ?it/s]