In [7]:
import torch
import sys
sys.path.append('../')
sys.path.append('../../')
from lstm_agent_cql_bc import DecisionLSTM
import yaml
import argparse
from tqdm import tqdm
from TMaze_new.TMaze_new_src.utils.tmaze import TMazeClassicPassive
from TMaze_new.TMaze_new_src.utils import seeds_list
import numpy as np
from tqdm import tqdm
import os

In [8]:
def load_model(seed, exp_name, loss_mode, stacked_input, context_length, segments):
    """Build a DecisionLSTM and restore its weights from a saved checkpoint.

    The checkpoint file name is assembled from the arguments as
    ``{exp_name}_{loss_mode}_{seed}_stacked_{stacked_input}_context_{context_length}_segments_{segments}``
    and looked up under ``../ckpt/tmaze_ckpt_v2/{loss_mode}/{seed}/``.
    The resolved run name is printed so the loaded checkpoint is visible in
    the cell output.

    Args:
        seed: Training seed of the checkpoint to load.
        exp_name: Experiment prefix of the checkpoint name.
        loss_mode: Loss-mode tag; part of both the directory and the file name.
        stacked_input: Value of the ``stacked`` tag in the file name.
        context_length: Value of the ``context`` tag in the file name.
        segments: Value of the ``segments`` tag in the file name.

    Returns:
        The DecisionLSTM in eval mode, moved to ``agent.device``.
    """
    # NOTE(review): the architecture arguments (4, 1, 32, num_layers=2) are
    # hard-coded and must match the ones the checkpoint was trained with.
    agent = DecisionLSTM(4, 1, 32, num_layers=2, mode='tmaze')

    run_name = f'{exp_name}_{loss_mode}_{seed}_stacked_{stacked_input}_context_{context_length}_segments_{segments}'
    print(run_name)
    model_path = f'../ckpt/tmaze_ckpt_v2/{loss_mode}/{seed}/{run_name}.ckpt'

    # map_location remaps the stored tensors onto the agent's device, so a
    # checkpoint written on a GPU still loads on a CPU-only (or differently
    # numbered GPU) machine instead of failing during deserialization.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=agent.device)
    agent.load_state_dict(state_dict)

    agent.eval()
    agent.to(agent.device)

    return agent

In [9]:
# Checkpoint selector: load_model() interpolates these values into the
# checkpoint file name and directory.
exp_name = 'tmaze_v2'
loss_mode = 'bc'        # also selects the action-selection branch in run_episode
stacked_input = False   # also forwarded to agent.forward by run_episode

# Values of the `context` / `segments` tags in the checkpoint name.
context_length = 3
segments = 3

In [10]:
# Restore the seed-1 checkpoint described by the configuration cell.
# run_episode reads this module-level `agent`.
agent = load_model(
    seed=1,
    exp_name=exp_name,
    loss_mode=loss_mode,
    stacked_input=stacked_input,
    context_length=context_length,
    segments=segments,
)

# Make sure the agent is in inference mode and on its own device; binding the
# result to `_` keeps the module repr out of the cell output.
_ = agent.eval()
_ = _.to(agent.device)

tmaze_v2_bc_1_stacked_False_context_3_segments_3


In [11]:
def run_episode(seed, episode_timeout, corridor_length, stacked_input, loss_mode):
    """Play one TMazeClassicPassive episode with the notebook-level ``agent``.

    The environment state ``{x, y, hint}`` is extended with two channels
    before it is shown to the agent:

    * ``flag``  -- 1 on the first step at which ``x == env.corridor_length``,
      0 otherwise (including the initial observation);
    * ``noise`` -- a random integer drawn from {-1, 0, 1}.

    The ``x`` channel is dropped (``state[:, :, 1:]``) right before the
    forward pass, so the agent receives ``{y, hint, flag, noise}``.

    Args:
        seed: Seed passed to the environment and to ``np.random.seed``
            (note: this reseeds NumPy's global RNG).
        episode_timeout: Maximum number of environment steps; also passed to
            the environment as ``episode_length``.
        corridor_length: Corridor length passed to the environment.
        stacked_input: Forwarded unchanged to ``agent.forward``.
        loss_mode: ``'cql'`` evaluates all four actions and picks the argmax
            of ``min(q1, q2)``, updating the LSTM hidden state only on the
            last candidate; ``'bc'`` does a single forward pass and takes the
            argmax of the predicted action distribution. For any other value
            all four forward passes run with hidden-state updates and the
            action comes from the last ``action_preds``.

    Returns:
        The reward returned by the last ``env.step`` call, or ``0.0`` when
        ``episode_timeout`` is 0 and no step is taken.
    """
    channels = 5  # {x, y, hint, flag, noise}
    create_video = False  # flip to True to print a textual trace of the episode
    # Every tensor is placed on the agent's own device; a bare `.cuda()` breaks
    # on CPU-only hosts and can pick the wrong GPU.
    device = agent.device

    env = TMazeClassicPassive(
        episode_length=episode_timeout,
        corridor_length=corridor_length,
        penalty=0,
        seed=seed,
        goal_reward=1.0)

    state = env.reset()  # {x, y, hint}
    np.random.seed(seed)
    where_i = state[0]   # latest x position, refreshed after every step
    mem_state = state[2]  # hint shown at reset

    state = np.concatenate((state, np.array([0])))  # {x, y, hint, flag}
    state = np.concatenate((state, np.array([np.random.randint(low=-1, high=1+1)])))  # {x, y, hint, flag, noise}

    if create_video == True:
        print("down, required act: 3" if mem_state == -1.0 else "up,  required act: 1")

    state = torch.tensor(state).reshape(1, 1, channels)
    Flag = 0    # becomes 1 once the junction flag has been emitted
    rtg = 1.0   # return-to-go conditioning, decreased by every received reward
    agent.init_hidden(1)

    # Defined up front so a zero-length episode returns 0.0 instead of raising
    # UnboundLocalError at the final `return`.
    reward = 0.0

    for _ in range(episode_timeout):
        with torch.no_grad():
            q_values = []
            for possible_action in [0, 1, 2, 3]:  # 4 possible actions
                action_tensor = torch.tensor([[[possible_action]]],
                                             dtype=torch.long,
                                             device=device)
                rtg_tensor = torch.tensor([[[rtg]]],
                                          dtype=torch.float32,
                                          device=device)
                if loss_mode == 'cql':
                    # Advance the recurrent state once per environment step:
                    # only on the last candidate action.
                    update_lstm_hidden = possible_action == 3
                else:
                    update_lstm_hidden = True

                action_preds, q1, q2, _unused = agent.forward(
                    states=state[:, :, 1:].to(device).float(),
                    actions=action_tensor,
                    returns_to_go=rtg_tensor,
                    update_hidden=update_lstm_hidden,
                    stacked_input=stacked_input,
                )

                q_values.append(torch.minimum(q1, q2))

                if loss_mode == 'bc':
                    # Behaviour cloning needs a single pass; the action comes
                    # from action_preds, not from the Q-values.
                    break

            if loss_mode == 'cql':
                # Select the action with the largest pessimistic Q-value.
                # assumes each Q output holds one element, so the flat argmax
                # index equals the action id -- TODO confirm
                q_values = torch.cat(q_values, dim=-1)
                action = torch.argmax(q_values).item()
            else:
                action = torch.argmax(torch.softmax(action_preds, dim=-1).squeeze()).item()

        state, reward, done, info = env.step(action)
        where_i = state[0]

        rtg -= reward

        # {x, y, hint} -> {x, y, hint, flag}
        if state[0] != env.corridor_length:
            state = np.concatenate((state, np.array([0])))
        else:
            if Flag != 1:
                state = np.concatenate((state, np.array([1])))
                Flag = 1
            else:
                state = np.concatenate((state, np.array([0])))

        # {x, y, hint, flag} -> {x, y, hint, flag, noise}
        state = np.concatenate((state, np.array([np.random.randint(low=-1, high=1+1)])))
        state = state.reshape(1, 1, channels)
        state = torch.from_numpy(state).float().to(device)

        if done:
            if create_video == True:
                if np.round(where_i, 4) == np.round(corridor_length, 4):
                    print("Junction achieved üòÄ ‚úÖ‚úÖ‚úÖ")
                    print("Chosen act:", "up" if action == 1 else "down" if action == 3 else "wrong")
                    if mem_state == -1 and action == 3:
                        print("Correct choice üòÄ ‚úÖ‚úÖ‚úÖ")
                    elif mem_state == 1 and action == 1:
                        print("Correct choice üòÄ ‚úÖ‚úÖ‚úÖ")
                    else:
                        print("Wrong choice üò≠ ‚õîÔ∏è‚õîÔ∏è‚õîÔ∏è")
                else:
                    print("Junction is not achieved üò≠ ‚õîÔ∏è‚õîÔ∏è‚õîÔ∏è")
            break

    return reward

In [12]:
# Evaluate the loaded agent at two horizons with the same seeds:
#   9  -> segments * context_length
#   90 -> a ten times longer episode
# The corridor always ends two steps before the timeout.
eval_seeds = [0, 1, 2, 5]  # alternative: seeds_list[::25]

for run_idx, episode_timeout in enumerate([9, 90]):
    if run_idx > 0:
        # Visual separator between the reports of consecutive horizons.
        print('-'*100)

    corridor_length = episode_timeout - 2

    rewards = []
    for seed in tqdm(eval_seeds):
        reward = run_episode(
            seed=seed,
            episode_timeout=episode_timeout,
            corridor_length=corridor_length,
            stacked_input=stacked_input,
            loss_mode=loss_mode,
        )
        rewards.append(reward)
        print(f"seed: {seed}, reward: {reward}")

    # Mean and standard deviation of the final rewards over the seeds.
    print(np.mean(rewards), np.std(rewards))


  0%|          | 0/4 [00:00<?, ?it/s]

seed: 0, reward: 1.0


100%|██████████| 4/4 [00:00<00:00,  9.55it/s]


seed: 1, reward: 1.0
seed: 2, reward: 1.0
seed: 5, reward: 1.0
1.0 0.0
----------------------------------------------------------------------------------------------------


 25%|██▌       | 1/4 [00:00<00:01,  1.76it/s]

seed: 0, reward: 1.0


 50%|█████     | 2/4 [00:01<00:01,  1.67it/s]

seed: 1, reward: 1.0


 75%|███████▌  | 3/4 [00:01<00:00,  2.01it/s]

seed: 2, reward: 1.0


100%|██████████| 4/4 [00:02<00:00,  1.83it/s]

seed: 5, reward: 1.0
1.0 0.0



