In [11]:
from utils.card_engine import Card_Game, Card_Env, random_agent

In [13]:
import math
import random
from collections import namedtuple, deque

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import matplotlib.pyplot as plt

from itertools import count

# Replay Memory

In [16]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        # save a transition
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

# Q network

In [19]:
class DQN(nn.Module):

    # n_input: the current state
    #  (1x52)    +  (56x52)       +       (1x52): the current state
    #    ^hand       ^who plays each card  ^cards not seen yet
    #                       + cards played
    # n_output: probability of playing each card
    #   (1x52)
    def __init__(self, n_input, n_output):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_input, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_output)

    def forward(self, x):
        if x.dtype == torch.float32:
            x = F.relu(self.layer1(x))
            x = F.relu(self.layer2(x))
        else:
            x=x.to(torch.float32)
            x = F.relu(self.layer1(x))
            x = F.relu(self.layer2(x))
        return self.layer3(x)

# Training

### The network agent
Selects a move according to epsilon-greedy policy:
sometimes uses the model to select move, sometimes just select one randomally

In [50]:
'''
A single step optimization of the model using Deep Q-Learning
1) samples a batch from memory, concatenates all the tensors into a single one
2) computes Q(s_t, a_t) and V(s_{t+1}) = max_a Q(s_{t+1}, a), where s_t --(a_t)--> s_{t+1}
3) computes the loss
4) updates the target network (which is computing V(s_{t+1})) at every step with soft update
'''
def optimize_model():
    transitions = []
    for turn, mem in memory.items():
        if len(mem) >= BATCH_SIZE:
            transitions += mem.sample(BATCH_SIZE)
    if transitions == []:
        return

    # transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # mask the non final states and find the corresponding next states
    # We need an illegal move to be a non-final state
    # Right now, we are throwing out all the final states which include the case when
    # the agent ends the game prematurely after playing an illegal move
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.bool)
    # print('non_final_mask is', non_final_mask)
    non_final_next_states = [s for s in batch.next_state if s is not None]
    if non_final_next_states == []:
        return
    non_final_next_states = torch.cat(non_final_next_states)
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    
    # print('reward_batch is', reward_batch)
    
    
    # compute Q(s_t, a)
    # for each state in the batch, find the value of the corresponding action
    state_action_values = policy_net(state_batch.to(torch.float)).gather(1, action_batch)
    # compute V(s_{t+1}) = max_a Q(s_{t+1}, a) for the next_states using the target_net
    next_state_values = torch.zeros(len(transitions), device=device)
    with torch.no_grad():
        # print(non_final_mask.shape, target_net(non_final_next_states).max(1).values
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values

    # R + \gamma max_a Q(s', a)
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # compute the Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # back propagate
    optimizer.zero_grad()
    loss.backward()

    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


In [25]:


# if GPU is to be used
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)
if torch.cuda.is_available():
    print("CUDA is available. GPU can be used.")
    print(f"Device name: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available. Using CPU.")

env = Card_Env()

CUDA is available. GPU can be used.
Device name: NVIDIA GeForce GTX 1650


In [27]:
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000

# I am making batch_size small here so that we can test if this goes through in shorter time
BATCH_SIZE = 100
MEMORY_SIZE = 10000

# Learning rate of the optimizer
LR = 1e-4

# soft update rate
TAU = 0.005

# future discount
GAMMA = 1.0


state = env.game.get_network_input()

n_input = len(state)
n_actions = 52

policy_net = DQN(n_input, n_actions).to(device)
# use a target network to prevent oscillation or divergence
target_net = DQN(n_input, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())


optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = {turn: ReplayMemory(MEMORY_SIZE) for turn in range(13)}

steps_done = 0

'''
Given the game state, select an action by the epsilon-greedy policy
'''
def select_action(game):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1

    # epsilon-greedy choice
        
    if sample > eps_threshold:
        with torch.no_grad():
            # return the index of the card with highest probability
            # predicted from the policy net
            # print(policy_net(game.get_network_input().to(torch.float32).to(device)))
            return policy_net(game.get_network_input().to(torch.float32).to(device)).max(0).indices.view(1,1)
    else:
        # random select a legal action
        return torch.tensor([[game.sample_legal_move()]], device=device, dtype=torch.long) #changed from long



We load the newest traing Q function.

In [30]:
policy_net.load_state_dict(torch.load('latest_q_function.pth'))

<All keys matched successfully>

In [56]:

# fp =  DQN(n_input, 52).to(device)
# fp.load_state_dict(torch.load('latest_q_function.pth', map_location=torch.device('cpu')))

num_episodes = 1000

for i_episode in range(num_episodes):
    env.reset()
    state = torch.tensor(env.game.get_network_input(), dtype=torch.float32, device=device).unsqueeze(0)

    player_ind = random.randint(0, 3)
    while env.game.current_player != player_ind:
        move = env.game.sample_legal_move()
        env.game.play_card(move)
    
    for t in count():
        # Select action based on policy network
        
        with torch.no_grad():
            q_values = policy_net(state)
            action = q_values.max(1)[1].view(1, 1)

        # Perform action in the environment
        observation, reward, terminated = env.step(action.item(),fp=None)
        reward = torch.tensor([reward], device=device)
        done = terminated

        # Compute next state
        if not terminated:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
        else:
            next_state = None

        # Store transition in memory
        # int(env.game.turn_counter / 4)
        memory[t].push(state, action, next_state, reward)

        # Move to next state
        state = next_state

        # Perform optimization step
        optimize_model()

        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            # print(f'Episode {i_episode} ended in {t} steps.')
            break
print('done')
# Save the Q-function model at the end of training
# torch.save(policy_net.state_dict(), 'latest_q_function.pth')

  state = torch.tensor(env.game.get_network_input(), dtype=torch.float32, device=device).unsqueeze(0)
  next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)


done


In [54]:
[len(mem) for i, mem in memory.items()]

[1458, 309, 181, 73, 27, 9, 2, 0, 0, 0, 0, 0, 0]

Look at the Q function for an initial state.


In [None]:
env.reset()
game = env.game
state = torch.tensor(game.get_network_input(), dtype=torch.long, device=device).unsqueeze(0)

torch.set_printoptions(profile="full")
policy_net.eval()  # Set the network to evaluation mode
q_values = policy_net(state)
print(state)
print("Initial Q-values:", q_values)

The maximal value corresponds to the right move.

In [None]:
torch.max(q_values)

The index of the maxima value.


In [None]:
torch.argmax(q_values)

Save the Q function for next training.

In [None]:
torch.save(policy_net.state_dict(), 'latest_q_function.pth')
policy_net.load_state_dict(torch.load('latest_q_function.pth'))