# To analyse the policy of a multilayer perceptron, we load a model, enumerate every possible game state (which is feasible for Leduc Hold'em), and select the action greedily in each state.

In [1]:
import sys 
sys.path.append('..')


import os 
import torch 
import rlcard 
from rlcard.envs.leducholdem import LeducholdemEnv
from rlcard.agents import RandomAgent, DQNAgent, NFSPAgent
from rlcard.utils import (
    set_seed,
    reorganize,
)

In [2]:
def load_agent(dir: str, agent: str = 'dqn'):
    """
    Restore an agent from a checkpoint file written with ``torch.save``.

    Args:
        dir: Path to the checkpoint file (a file such as ``checkpoint_dqn.pt``,
            despite the parameter name).
        agent: Which agent class to restore: ``'dqn'`` or ``'nfsp'``.

    Returns:
        The agent built by the selected class's ``from_checkpoint``.

    Raises:
        ValueError: If ``agent`` is not one of the supported names.
    """
    if agent == 'dqn':
        from rlcard.agents.dqn_agent import DQNAgent
        agent_class = DQNAgent
    elif agent == 'nfsp':
        from rlcard.agents.nfsp_agent import NFSPAgent
        agent_class = NFSPAgent
    else:
        # Without this branch an unknown name surfaced as an UnboundLocalError
        # on ``agent_class`` below, hiding the actual mistake.
        raise ValueError(f"Unknown agent type {agent!r}; expected 'dqn' or 'nfsp'.")
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    return agent_class.from_checkpoint(torch.load(dir))


In [15]:
def create_env(seed: int = 0) -> LeducholdemEnv:
    """
    Build a Leduc Hold'em environment seeded with ``seed``.

    The global seed is set via ``set_seed`` first; the same value is then
    passed to the environment through its config.
    """
    env_config = {'seed': seed}
    set_seed(seed)
    return rlcard.make('leduc-holdem', config=env_config)


def fill_env_with_agents(env: LeducholdemEnv, agents: list[DQNAgent]) -> None:
    """
    Seat ``agents`` in ``env`` and pad any remaining seats with random agents.

    Args:
        env: Environment whose ``num_players`` seats are filled.
        agents: Agents for the first seats, in seat order. The caller's list is
            left unmodified; padding is added to a copy.
    """
    # range() of a non-positive count is empty, so no padding is added when
    # the caller already supplied enough agents.
    num_missing = env.num_players - len(agents)
    padding = [RandomAgent(num_actions=env.num_actions) for _ in range(num_missing)]
    env.set_agents(list(agents) + padding)


def _current_memory_size(memory) -> int:
    """Return the number of transitions currently held by ``memory``."""
    # The previous loop read ``memory.memory_size`` as a plain value in the
    # ``while`` condition but called it as a method in the progress message,
    # so one of the two had to fail. Accept both forms so they always agree.
    # TODO(review): confirm which form this fork's Memory exposes and simplify.
    size = memory.memory_size
    return size() if callable(size) else size


def collect_trajectories(env: LeducholdemEnv, agent: NFSPAgent | DQNAgent, max_memory_size: int | None = None):
    """
    Play episodes until the agent's replay memory reaches ``max_memory_size``.

    The agent takes seat 0 and any remaining seats are filled with random
    agents. Episodes run with ``is_training=False``; each of seat 0's
    transitions is passed to ``agent.feed``.

    Args:
        env: Leduc Hold'em environment to play in.
        agent: A DQN agent, or an NFSP agent (whose ``_rl_agent.memory`` is used).
        max_memory_size: Number of stored transitions to reach. ``None`` means
            the memory's own ``max_memory_size``.

    Returns:
        The memory object that was filled.

    Raises:
        ValueError: If ``max_memory_size`` exceeds the memory's own
            ``max_memory_size`` (presumably an unreachable target).
    """
    fill_env_with_agents(env, [agent])
    memory = agent.memory if isinstance(agent, DQNAgent) else agent._rl_agent.memory

    # ``is None`` rather than ``or``: an explicit 0 is a valid target and must
    # not be silently replaced by the memory's maximum.
    if max_memory_size is None:
        max_memory_size = memory.max_memory_size
    else:
        capacity = getattr(memory, 'max_memory_size', None)
        if capacity is not None and max_memory_size > capacity:
            # NOTE(review): assumes the memory never grows beyond its
            # ``max_memory_size``; otherwise this loop could never finish.
            raise ValueError(
                f"max_memory_size={max_memory_size} exceeds the memory's "
                f"max_memory_size={capacity}."
            )

    while (current_size := _current_memory_size(memory)) < max_memory_size:
        print(f"Current memory size: {current_size}", end='\r')
        if isinstance(agent, NFSPAgent):
            # Resample the NFSP agent's episode policy before each episode.
            agent.sample_episode_policy()
        trajectories, payoffs = env.run(is_training=False)

        # Reorganize the data to be state, action, reward, next_state, done
        trajectories = reorganize(trajectories, payoffs)

        # Only seat 0 (the agent under study) is fed; other seats play randomly.
        # NOTE(review): depending on the agent implementation, ``feed`` may
        # also run training updates; confirm if weights must stay frozen.
        for transition in trajectories[0]:
            agent.feed(transition)
    return memory

In [4]:
def transformer_input_to_mlp_input(transformer_input):
    """
    Convert a transformer-model input into the input expected by the MLP model.

    Not implemented yet. Raising makes accidental use fail loudly instead of
    silently returning ``None``.

    Raises:
        NotImplementedError: Always, until the conversion is written.
    """
    raise NotImplementedError("transformer_input_to_mlp_input is not implemented yet")

In [38]:
# Both DQN checkpoints from the random-agent experiments live under this root.
_dqn_checkpoint_root = "/home/kacperwyrwal/mlpractical-assignment4/rlcard-mlp/notebooks/experiments/random_agent/dqn"

dqn_mlp_dir = f"{_dqn_checkpoint_root}/mlp/checkpoint_dqn.pt"
dqn_transformer_dir = f"{_dqn_checkpoint_root}/transformer/checkpoint_dqn.pt"

dqn_mlp_agent = load_agent(dqn_mlp_dir, 'dqn')
dqn_transformer_agent = load_agent(dqn_transformer_dir, 'dqn')


INFO - Restoring model from checkpoint...

INFO - Restoring model from checkpoint...


In [39]:
# Fill the transformer DQN agent's memory by playing against random opponents.
env = create_env(seed=0)
memory = collect_trajectories(env, agent=dqn_transformer_agent)

In [40]:
# Draw one batch from the memory and display the shape of its first element.
sampled_batch = memory.sample()
sampled_batch[0].shape

(32, 10, 36)

In [23]:
# Display max_memory_size of the transformer agent's memory.
transformer_memory = dqn_transformer_agent.memory
transformer_memory.max_memory_size

49099

In [16]:
# Fill the MLP DQN agent's memory by playing against random opponents.
env = create_env(seed=0)
memory = collect_trajectories(env, agent=dqn_mlp_agent)

In [19]:
# Display max_memory_size of the memory filled above.
filled_memory_max_size = memory.max_memory_size
filled_memory_max_size

49099

In [37]:
# Patch the transformer checkpoint so it records a 128-step sequence length and
# the 'transformer' estimator network, then write it back to the same path.
patched_sequence_length = 128
dqn_transformer_dir_state = torch.load(dqn_transformer_dir)
dqn_transformer_dir_state['q_estimator']['memory_sequence_length'] = patched_sequence_length
dqn_transformer_dir_state['memory_sequence_length'] = patched_sequence_length
dqn_transformer_dir_state['q_estimator']['estimator_network'] = 'transformer'
dqn_transformer_dir_state['memory']['max_sequence_length'] = patched_sequence_length
# Save to a temporary file and atomically swap it into place, so an interrupted
# save cannot corrupt the only copy of the checkpoint being overwritten.
_tmp_checkpoint_path = dqn_transformer_dir + '.tmp'
torch.save(dqn_transformer_dir_state, _tmp_checkpoint_path)
os.replace(_tmp_checkpoint_path, dqn_transformer_dir)

In [36]:
# Inspect the 'memory' section of the loaded transformer checkpoint.
memory_checkpoint_state = dqn_transformer_dir_state['memory']
memory_checkpoint_state

{'memory_size': 49099,
 'batch_size': 32,
 'memory': [Transition(state=array([1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.]), action=1, reward=0, next_state=array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0.]), done=False, legal_actions=[0, 2]),
  Transition(state=array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0.]), action=2, reward=-2.0, next_state=array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0.]), done=True, legal_actions=[2, 3]),
  Transition(state=array([1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [29]:
# Inspect the 'memory' section of the loaded transformer checkpoint.
memory_checkpoint_state = dqn_transformer_dir_state['memory']
memory_checkpoint_state

{'memory_size': 49099,
 'batch_size': 32,
 'memory': [Transition(state=array([1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.]), action=1, reward=0, next_state=array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0.]), done=False, legal_actions=[0, 2]),
  Transition(state=array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0.]), action=2, reward=-2.0, next_state=array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0.]), done=True, legal_actions=[2, 3]),
  Transition(state=array([1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.