In [None]:
import os
import argparse

import torch

import rlcard
from rlcard.utils import (
    get_device,
    set_seed,
    tournament,
    reorganize,
    Logger,
    plot_curve,
)
from rlcard.agents import (
    DQNAgent,
    RandomAgent,
    CDQNAgent
)

from rlcard.models.gin_rummy_rule_models import GinRummyNoviceRuleAgent


In [None]:
### Environment and learner setup
# Resolve the compute device, fix the global seed, then build the seeded
# 'gin-rummy' environment (in that order, as the seed must be set first).
device = get_device()
set_seed(42)
env = rlcard.make('gin-rummy', config={'seed': 42})

# Constructor arguments for the learning agent, gathered in one place so the
# hyperparameters can be read at a glance.
agent_kwargs = dict(
    num_actions=env.num_actions,
    state_shape=env.state_shape[0],
    mlp_layers=[64, 64],
    device=device,
    learning_rate=5e-5,
    replay_memory_init_size=100,
    number_discards=3,
    execution_step=False,
    optimization_step=True,
)
agent = CDQNAgent(**agent_kwargs)

# Alternative: restore the agent from a saved checkpoint instead of starting fresh.
#path = 
#agent = CDQNAgent.from_checkpoint(checkpoint=torch.load(path))


In [None]:
### Seat the players
# (GinRummyNoviceRuleAgent is re-imported here as in the first cell; it is not
# seated in this cell.)
from rlcard.models.gin_rummy_rule_models import GinRummyNoviceRuleAgent

# Seat 0 is the learning agent, seat 1 a RandomAgent opponent.
random_opponent = RandomAgent(num_actions=env.num_actions)
agents = [agent, random_opponent]
env.set_agents(agents)


In [None]:
### Training loop
with Logger('final/cdqnv2/cdqn_3_False') as logger:
    for episode in range(20000):
        # Play one game in training mode and attach payoffs to the transitions.
        raw_trajectories, episode_payoffs, _, _ = env.run(is_training=True)
        per_seat_transitions = reorganize(raw_trajectories, episode_payoffs)

        # Only seat 0 (the learning agent) is fed transitions; the opponent in
        # seat 1 is never trained.
        for transition in per_seat_transitions[0]:
            agents[0].feed(transition)

        # Every 100th episode, run a 10000-game tournament with the currently
        # seated agents and log whatever `tournament` returns.
        if episode % 100 == 0:
            evaluation = tournament(env, 10000)
            logger.log_performance(episode, evaluation)

        # The checkpoint is written after every single episode.
        agent.save_checkpoint('final/cdqnv2/cdqn_3_False')

    # Keep the logger's output locations for later cells.
    csv_path, fig_path = logger.csv_path, logger.fig_path


In [None]:
#plot_curve(csv_path, fig_path, 'cdqn')

# Save the whole agent object with torch.save.
# NOTE(review): this directory differs from the Logger/checkpoint directory used
# by the training cell ('final/cdqnv2/cdqn_3_False') -- confirm it is the
# intended destination for this run.
save_dir = 'final/cdqn/cdqn_results_reward_shaped_6500'

# Create the target directory first: torch.save does not create missing parent
# directories, and failing here would discard a finished training run.
os.makedirs(save_dir, exist_ok=True)

save_path = os.path.join(save_dir, 'model.pth')
torch.save(agent, save_path)
print('Model saved in', save_path)

In [None]:
def load_agent(checkpoint_path, target_device):
    """Load a pickled agent object and move it onto ``target_device``.

    Args:
        checkpoint_path: Path of a file readable by ``torch.load`` that contains
            an object exposing ``set_device`` (presumably written with
            ``torch.save(agent, path)`` as in the save cell -- confirm).
        target_device: Value passed to the loaded object's ``set_device``.

    Returns:
        The loaded object, after ``set_device(target_device)`` has been called.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects, so only load
    # files you produced yourself. Recent PyTorch releases default to
    # weights_only=True, which rejects full pickled objects -- if loading fails
    # on a newer install, pass weights_only=False explicitly.
    loaded = torch.load(checkpoint_path)
    loaded.set_device(target_device)
    return loaded


device = get_device()

# Seed numpy, torch, random
set_seed(42)

# Make the environment with seed
env = rlcard.make('gin-rummy', config={'seed': 42})

# Baseline opponents. torch, RandomAgent and GinRummyNoviceRuleAgent are already
# imported in the notebook's first cell, so they are not re-imported here.
# NOTE(review): the name `random` shadows the stdlib module name if that module
# is ever imported in this notebook; kept for compatibility with other cells.
random = RandomAgent(num_actions=env.num_actions)
rule = GinRummyNoviceRuleAgent()

# Previously trained agents, each loaded and moved to `device` in turn.
dqnRs = load_agent('final/dqn/dqn_results_reward_shaped_10000/model.pth', device)
dqn = load_agent('experiments/gin_rummy_dqn_result_no_constraints/model.pth', device)

rs1 = load_agent('final/rs/1/model.pth', device)
rs2 = load_agent('final/rs/2/model.pth', device)
rs3 = load_agent('final/rs/3/model.pth', device)
rs4 = load_agent('final/rs/4/model.pth', device)
rs5 = load_agent('final/rs/5/model.pth', device)

cdqn3 = load_agent('final/cdqnv2/3/model.pth', device)

# Match-up evaluated by the following cell: cdqn3 in seat 0 against dqn in seat 1.
env.set_agents([cdqn3, dqn])


In [None]:
# Run ten independent 10000-game tournaments with the currently seated agents
# and print the three values each one returns, followed by a separator line.
separator = "=" * 30
for _ in range(10):
    payoffs, unsafe, wins = tournament(env, 10000)
    print(payoffs, unsafe, wins, separator, sep="\n")

In [None]:
def read_data(file_path, eval_every=100):
    """Collect the logged reward series from a text log file.

    Every line containing the substring ``'reward'`` is split on ``'|'``; the
    text after the first bar is evaluated as a Python expression and its
    ``[0][0]`` element is appended to the result (presumably the seat-0 payoff
    of the logged tournament result -- confirm against the Logger output).

    Args:
        file_path: Path of the log file to read.
        eval_every: Positive episode gap between two consecutive logged
            evaluations, used to build the x-axis. Defaults to 100, the
            interval that was previously hard-coded.

    Returns:
        A tuple ``(episodes, rewards)`` of two equal-length lists, where
        ``episodes`` is ``[0, eval_every, 2 * eval_every, ...]``.
    """
    rewards = []
    with open(file_path, 'r') as file:
        for line in file:
            if 'reward' not in line:
                continue
            parts = line.split('|')
            # NOTE(review): eval executes whatever expression the log contains.
            # Only use this on log files you generated yourself; consider
            # ast.literal_eval if the logged values are plain literals.
            reward_list = eval(parts[1])
            rewards.append(reward_list[0][0])

    # Derive the x-axis from the number of rewards actually parsed, so that the
    # two returned lists always have the same length (a fixed 100-point axis
    # breaks plotting for logs of any other length).
    episodes = list(range(0, eval_every * len(rewards), eval_every))
    return episodes, rewards


import matplotlib.pyplot as plt

# Log files whose reward curves are drawn on a single set of axes, in this order.
file_paths = [
    'final/rs/1/log.txt',
    'final/rs/2/log.txt',
    'final/rs/3/log.txt',
    'final/rs/4/log.txt',
    'final/rs/5/log.txt',
    'final/dqn/dqn_results_reward_shaped/results_10k.txt',
]

fig, ax = plt.subplots()

# One curve per log; the legend entry is 'cdqn-<n>' with n counting from 1.
for i, file_path in enumerate(file_paths, start=1):
    episodes, rewards = read_data(file_path)
    print(i)
    ax.plot(episodes, rewards, label=f'cdqn-{i}')

ax.set_xlabel('episode')
ax.set_ylabel('reward')
ax.legend()
ax.grid()

plt.show()

# Write the same figure to disk after it has been displayed.
fig.savefig('final/testfig')

In [None]:
import numpy as np

from collections import OrderedDict

# Reload the three agents inspected in this cell. Each one is loaded and moved
# onto `device` before the next file is touched.
reloaded_agents = []
for checkpoint_path in (
    'experiments/gin_rummy_dqn_result_no_constraints/model.pth',
    'final/dqn/dqn_results_reward_shaped_10000/model.pth',
    'final/cdqnv2/3/model.pth',
):
    reloaded = torch.load(checkpoint_path)
    reloaded.set_device(device)
    reloaded_agents.append(reloaded)

dqn, dqnRs, cdqn3 = reloaded_agents

# Hand-written state dictionary used to probe a single decision of `cdqn3`.
# It carries four keys: 'obs', 'legal_actions', 'raw_legal_actions' and 'raw_obs'.
# Both 'obs' and 'raw_obs' are 5 x 52 arrays of 0/1 values; 'legal_actions' maps
# the same eleven action ids listed in 'raw_legal_actions' to None.
# NOTE(review): 'obs' and 'raw_obs' are identical except for the last row,
# column index 21 (0 in 'obs', 1 in 'raw_obs') -- confirm this difference is an
# intentional edit for the experiment and not a copy error.
state = {'obs': np.array([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0,
        0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1,
        1, 1, 1, 1, 1, 1, 0, 0]]), 'legal_actions': OrderedDict([(6, None), (18, None), (20, None), (25, None), (26, None), (41, None), (44, None), (47, None), (48, None), (56, None), (57, None)]), 'raw_legal_actions': [6, 18, 20, 25, 26, 41, 44, 47, 48, 56, 57], 'raw_obs': np.array([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1,
        0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1,
        1, 1, 1, 1, 1, 1, 0, 0]])}

# Last expression of the cell: the notebook displays whatever eval_step returns
# for this state.
cdqn3.eval_step(state)