In [2]:
import numpy as np
from copy import copy
from collections import deque
import torch
import torch.nn as nn
from torch.distributions import Categorical
from random import randint
import matplotlib.pyplot as plt

# NOTE(review): anomaly detection is a debugging aid for tracing failing backward passes;
# it is presumably slower than normal execution -- consider turning it off once training is stable.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1c8b2484d90>

In [3]:
# Environment (plus a variant in which the opponent plays randomly)
# SIZE is the number of cards per suit: both environments use it as their hand/deck size
# (state arrays have SIZE * 3 + 1 columns) and PolicyNet uses it as its number of actions.
SIZE = 3


class GOPS:
    """Two-player GOPS-style card game (presumably "Game of Pure Strategy") with explicit moves for both sides.

    The state is a ``(1, 3 * size + 1)`` float array laid out as:

    * ``state[0, 0]``: 0-based index of the value card currently being contested
      (it doubles as the number of points at stake; ``-1`` once the game is over),
    * ``state[0, 1:size + 1]``: value cards still undrawn (1 = available, 0 = used),
    * ``state[0, size + 1:2 * size + 1]``: cards left in the player's hand,
    * ``state[0, 2 * size + 1:]``: cards left in the opponent's hand.

    ``score`` is the running score from the player's point of view.
    """

    def __init__(self):
        self.size = SIZE
        self.state, self.score = np.zeros((1, self.size * 3 + 1)), 0.0  # Placeholders
        self.reset()

    def reset(self):
        """Start a new game: every card available, score zero, first value card drawn."""
        self.state = np.ones((1, self.size * 3 + 1))
        self.score = 0.0
        self.draw()

    def draw(self):
        """Draw a random undrawn value card and make it the current card.

        Raises:
            ValueError: if no value cards remain (raised by ``np.random.choice``).
        """
        remaining = np.nonzero(self.state[0, 1:self.size + 1])[0]
        # Draw a scalar rather than a length-1 array: assigning a size-1 array to a
        # single element is deprecated in recent NumPy releases.
        card = int(np.random.choice(remaining))

        # Update the state
        self.state[0, 0] = card
        self.state[0, card + 1] = 0  # Deck flags are offset by one (slot 0 holds the current card)

    def get_illegal_actions(self):
        """Return the indices of the cards the player has already played."""
        illegal_actions = np.where(self.state[0, self.size + 1:self.size * 2 + 1] == 0)[0]
        return illegal_actions

    def step(self, action, action_opp):
        """Play one trick.

        Args:
            action: index of the card played by the player.
            action_opp: index of the card played by the opponent.

        Returns:
            ``(state, score, done)``: a copy of the new state, the running score,
            and whether the game has ended.
        """
        self.state[0, action + self.size + 1] = 0
        self.state[0, action_opp + self.size * 2 + 1] = 0

        # Update score: the higher card wins the current value card, ties score nothing
        if action > action_opp:
            self.score += self.state[0, 0]
        elif action < action_opp:
            self.score -= self.state[0, 0]

        # Game end conditions
        done = False
        if np.sum(self.state[0, 1:self.size + 1]) == 0:
            # Best score for one player is nChoose2 so use n^2 to denote winning.
            # Uses self.size (not the module constant) to stay consistent with the rest of the class.
            bonus = self.size ** 2
            if self.score > 0:
                self.score += bonus
            elif self.score < 0:
                self.score -= bonus
            done = True
            self.state[0, 0] = -1  # This signifies no current value card. Probably unnecessary
        else:
            self.draw()
        # Copies so callers can store transitions without aliasing the live state
        return self.state.copy(), copy(self.score), done



class RandomGOPS:
    """GOPS-style card game in which the opponent plays a uniformly random remaining card.

    The state is a ``(1, 3 * size + 1)`` float array laid out as:

    * ``state[0, 0]``: 0-based index of the value card currently being contested
      (it doubles as the number of points at stake; ``-1`` once the game is over),
    * ``state[0, 1:size + 1]``: value cards still undrawn (1 = available, 0 = used),
    * ``state[0, size + 1:2 * size + 1]``: cards left in the player's hand,
    * ``state[0, 2 * size + 1:]``: cards left in the opponent's hand.

    ``score`` is the running score from the player's point of view.
    """

    def __init__(self):
        self.size = SIZE
        self.state, self.score = np.zeros((1, self.size * 3 + 1)), 0.0  # Placeholders
        self.reset()

    def reset(self):
        """Start a new game: every card available, score zero, first value card drawn."""
        self.state = np.ones((1, self.size * 3 + 1))
        self.score = 0.0
        self.draw()

    def draw(self):
        """Draw a random undrawn value card and make it the current card.

        Raises:
            ValueError: if no value cards remain (raised by ``np.random.choice``).
        """
        remaining = np.nonzero(self.state[0, 1:self.size + 1])[0]
        # Draw a scalar rather than a length-1 array: assigning a size-1 array to a
        # single element is deprecated in recent NumPy releases.
        card = int(np.random.choice(remaining))

        # Update the state
        self.state[0, 0] = card
        self.state[0, card + 1] = 0  # Deck flags are offset by one (slot 0 holds the current card)

    def get_illegal_actions(self):
        """Return the indices of the cards the player has already played."""
        illegal_actions = np.where(self.state[0, self.size + 1:self.size * 2 + 1] == 0)[0]
        return illegal_actions

    def step(self, action):
        """Play one trick against a randomly chosen opponent card.

        Args:
            action: index of the card played by the player.

        Returns:
            ``(state, score, done)``: a copy of the new state, the running score,
            and whether the game has ended.
        """
        # Update game with random opponent move
        action_opp = np.random.choice(np.nonzero(self.state[0, self.size * 2 + 1:])[0])
        self.state[0, action + self.size + 1] = 0
        self.state[0, action_opp + self.size * 2 + 1] = 0

        # Update score: the higher card wins the current value card, ties score nothing
        if action > action_opp:
            self.score += self.state[0, 0]
        elif action < action_opp:
            self.score -= self.state[0, 0]

        # Game end conditions
        done = False
        if np.sum(self.state[0, 1:self.size + 1]) == 0:
            # Best score for one player is nChoose2 so use n^2 to denote winning.
            # Uses self.size (not the module constant) to stay consistent with the rest of the class.
            bonus = self.size ** 2
            if self.score > 0:
                self.score += bonus
            elif self.score < 0:
                self.score -= bonus
            done = True
            self.state[0, 0] = -1  # This signifies no current value card. Probably unnecessary
        else:
            self.draw()
        # Copies so callers can store transitions without aliasing the live state
        return self.state.copy(), copy(self.score), done

In [4]:
# NN representing s->p(a) map
class PolicyNet(nn.Module):
    """Fully connected policy network mapping a game state to action probabilities.

    Args:
        widths: sizes of the hidden layers (at least one entry).
        path: optional path to a saved ``state_dict`` to load after construction.
    """

    def __init__(self, widths=(8, 8), path=None):
        super().__init__()

        # One slot for the current value card plus three blocks of SIZE availability
        # flags (value deck, player hand, opponent hand). The score is not a feature.
        self.num_features = SIZE * 3 + 1
        self.num_actions = SIZE

        # Linear/ReLU stack followed by a softmax head over the actions
        modules = [nn.Linear(self.num_features, widths[0]), nn.ReLU()]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        modules += [nn.Linear(widths[-1], self.num_actions), nn.Softmax(dim=1)]
        self.layers = nn.Sequential(*modules)

        if path is not None:
            # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
            self.load_state_dict(torch.load(path))

    def forward(self, state):
        """Return action probabilities of shape ``(batch, num_actions)`` for ``state``.

        ``state`` may be a NumPy array (the environments are NumPy-based) or a tensor.
        """
        # as_tensor converts arrays and passes tensors through without a copy warning
        state = torch.as_tensor(state, dtype=torch.float32)
        action_probs = self.layers(state)
        return action_probs

    def get_action(self, state, illegal_actions):
        """Sample a legal action index for a single state (batch size 1).

        Args:
            state: state of shape ``(1, num_features)``.
            illegal_actions: indices of actions that must not be chosen.

        Returns:
            The sampled action as a Python int.
        """
        # Sampling needs no gradients; no_grad also makes the in-place masking safe
        with torch.no_grad():
            action_probs = self.forward(state)
            illegal = torch.as_tensor(illegal_actions, dtype=torch.long)
            action_probs[0, illegal] = 0  # Mask out illegal actions

            # Categorical renormalises the masked probabilities before sampling
            action = Categorical(probs=action_probs).sample()
        return action.item()
    
class QNet(nn.Module):
    """Placeholder for a Q-value network; no layers are defined yet."""

    def __init__(self):
        # The original stub had an empty body, which made the whole cell a syntax error.
        super().__init__()
        # TODO: define the network layers and a forward() method


In [5]:
# Collecting batches
# The reason we are temporarily storing in a deque is for future compatibility reasons
def get_batch(env, memory, num_batches, policy):
    """Play ``num_batches`` complete episodes and record every transition.

    Each step appends ``[state, action, reward, done]`` to ``memory``, where
    ``state`` is the state the action was chosen in and ``reward``/``done`` come
    from ``env.step(action)``. The environment is reset after every episode.

    ``env`` must expose ``state``, ``get_illegal_actions()``, ``step(action)`` and
    ``reset()``; ``policy`` must expose ``get_action(state, illegal_actions)``.
    """
    for _ in range(num_batches):
        current = env.state.copy()
        while True:
            chosen = policy.get_action(current, env.get_illegal_actions())
            following, reward, finished = env.step(chosen)
            memory.append([current, chosen, reward, finished])
            if finished:
                break
            current = following
        # Leave the environment ready for the next episode
        env.reset()
    return

# Processing batches into stacks
def stack_batch(memory):
    """Stack stored transitions into arrays.

    Accepts the 4-element records ``[state, action, reward, done]`` written by
    ``get_batch`` as well as the 5-element layout ``(state, action, logprob,
    reward, done)`` that this function originally unpacked (which made it fail
    on ``get_batch`` output).

    Args:
        memory: non-empty iterable of transition records.

    Returns:
        ``(states, actions, rewards)``: states concatenated along axis 0, and
        actions and rewards stacked into 1-D arrays.

    Raises:
        ValueError: if ``memory`` is empty (raised by NumPy).
    """
    records = list(memory)
    states = np.concatenate([record[0] for record in records])
    actions = np.stack([record[1] for record in records])
    # The reward sits second-to-last in both supported layouts
    rewards = np.stack([record[-2] for record in records])
    return states, actions, rewards

In [6]:
def test(env, model, count):
    rewards = []
    for t in range(count):
        state = env.state.copy()
        done = False
        while not done:
            action = model.get_action(state,env.get_illegal_actions())
            state_new, reward, done = env.step(action)
            state = state_new
        env.reset()
        rewards.append(reward)
    return sum(rewards)/(100*count)

In [8]:
# Training script: collect transitions with the current policy, then take one SGD step
# per stored transition, nudging the policy towards a slightly "tweaked" copy of itself.
E = RandomGOPS()
M = PolicyNet()
Mem = deque(maxlen=10000)
optim = torch.optim.SGD(M.parameters(), lr=0.01)
# Both the input and the target passed to this loss are log-probabilities (log_target=True)
loss_fn = nn.KLDivLoss(log_target=True, reduction='sum')
normaliser = nn.LogSoftmax(dim=1)

# Fill the replay memory by playing 1000 episodes with the untrained policy
get_batch(E, Mem, 1000, M)

train_losses = []
test_losses = []  # Holds the values returned by test(), i.e. scaled scores rather than losses
# NOTE(review): only the first 100 stored transitions are ever used for updates.
for i in range(100):
    # Records are [state, action, reward, done]; the reward is the running score returned by env.step
    s, a, r, _ = Mem[i]
    logprobs = torch.log(M(s))    
    # The target is a detached copy so gradients only flow through `logprobs`
    tweaked_logprobs = torch.clone(logprobs).detach()
    # NOTE(review): this *lowers* the log-probability of the taken action when r > 0, which looks
    # like the opposite of reinforcing rewarded actions -- confirm the intended sign.
    tweaked_logprobs[0, a] -= 0.001*r # Tweak policy
    # Renormalise so the target is a valid log-distribution again
    norm_tweaked_logprobs = normaliser(tweaked_logprobs)
        
    loss = loss_fn(logprobs, norm_tweaked_logprobs)
    
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    print(i, loss.item())
    train_losses.append(loss.item())
    # Evaluate the updated policy over 100 fresh episodes after every update
    test_losses.append(test(E, M, 100))

0 0.0
1 -6.556675913316212e-08
2 1.1195545084774494e-05
3 0.0
4 4.803005140274763e-07
5 1.2064585462212563e-05
6 1.8015271052718163e-07
7 2.0631705410778522e-07
8 9.19075682759285e-06
9 1.267326297238469e-07
10 2.0397419575601816e-07
11 1.2220581993460655e-05
12 -2.519422359625878e-08
13 9.219365892931819e-08
14 1.1462834663689137e-05
15 5.812034942209721e-07
16 3.1324452720582485e-07
17 1.2984848581254482e-05
18 0.0
19 1.0387157090008259e-07
20 1.2148404493927956e-05
21 -6.212965075746979e-08
22 5.920883268117905e-07
23 9.164214134216309e-06
24 3.9217411540448666e-07
25 4.2532337829470634e-07
26 9.158742614090443e-06
27 -9.264986999824032e-08
28 1.5895057003945112e-07
29 1.3704528100788593e-05
30 3.8222060538828373e-07
31 4.4517219066619873e-07
32 1.3679498806595802e-05
33 1.1920928955078125e-07
34 5.094625521451235e-07
35 1.115433406084776e-05
36 -5.62691084837752e-08
37 0.0
38 -1.1920928955078125e-07
39 0.0
40 1.5451223589479923e-07
41 1.1290307156741619e-05
42 1.2782402336597443e-0

In [7]:
# Training loss per update step (requires the training cell to have been run first)
plt.scatter(range(len(train_losses)), train_losses)
plt.xlabel("update step")
plt.ylabel("KL-divergence loss")
plt.title("Training loss")
plt.show()

NameError: name 'train_losses' is not defined

In [None]:
# Evaluation result after each update step (requires the training cell to have been run first)
plt.scatter(range(len(test_losses)), test_losses)
plt.xlabel("update step")
plt.ylabel("value returned by test()")
plt.title("Evaluation against the random opponent")
plt.show()