In [368]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from einops import rearrange, repeat

from collections import deque

import sys
from datetime import datetime as dt

# Models

## feed-forward

In [136]:
class Dice_embedder(nn.Module):
    def __init__(self, d_emb):
        super().__init__()
        self.embed = nn.Embedding(6, d_emb)
    
    def forward(self, x):
        x = self.embed(x)
        #x = rearrange(x, '... i j -> ... (i j)')
        return x

In [137]:
class Table_embedder(nn.Module):
    def __init__(self, drop):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(24, 16),
            nn.LeakyReLU(), 
            nn.LayerNorm(16), 
            #nn.Dropout(drop),
            nn.Linear(16, 32),
            nn.LeakyReLU(), 
            nn.LayerNorm(32), 
            nn.Dropout(drop),
            nn.Linear(32, 24), 
            nn.LeakyReLU(),
            nn.LayerNorm(24))
    
    def forward(self, x):
        return self.net(x)

In [143]:
class Choose(nn.Module):
    def __init__(self, d_emb, lr, drop):
        super().__init__()
        self.dice = Dice_embedder(d_emb)
        self.table = Table_embedder(drop)
        self.out = nn.Sequential(
            nn.Linear(24 + d_emb * 2, 32),
            nn.LeakyReLU(), 
            nn.LayerNorm(32), 
            nn.Dropout(drop),
            nn.Linear(32, 16),
            nn.LeakyReLU(), 
            nn.LayerNorm(16),
            nn.Dropout(drop),
            nn.Linear(16, 8),
            nn.LeakyReLU(),
            nn.LayerNorm(8),
            nn.Linear(8, 2))
    
        self.weights_init()
        
        self.optimizer = torch.optim.AdamW(self.parameters(), lr = lr)
        self.loss = nn.HuberLoss()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)
    
    def weights_init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                m.bias.data.fill_(0.01)
    
    def forward(self, t, d):
        t = self.table(t)
        d = self.dice(d)
        td = torch.cat((t, d[0], d[1]))
        out = self.out(td)
        return out

In [144]:
class Move(nn.Module):
    def __init__(self, d_emb, lr, drop):
        super().__init__()
        self.dice = Dice_embedder(d_emb)
        self.table = Table(drop)
        self.out = nn.Sequential(
            nn.Linear(24 + d_emb, 64),
            nn.LeakyReLU(), 
            nn.LayerNorm(64), 
            nn.Dropout(drop),
            nn.Linear(64, 64),
            nn.LeakyReLU(), 
            nn.LayerNorm(64),
            nn.Dropout(drop),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.LayerNorm(32),
            nn.Linear(32, 24))
                
        self.weights_init()
        
        self.optimizer = torch.optim.AdamW(self.parameters(), lr = lr)
        self.loss = nn.HuberLoss()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)
        
    def weights_init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                m.bias.data.fill_(0.01)
    
    def forward(self, t, d):
        t = self.table(t)
        d = self.dice(d)
        td = torch.cat((t, d[0]))
        out = self.out(td)
        return out

device(type='cuda')

In [149]:
m_choose_w = Choose(4, 0.0001, 0.1)
m_move1_w = Move(8, 0.0001, 0.1)
m_move2_w = Move(8, 0.0001, 0.1)

In [150]:
models = [m_choose_w, m_move1_w, m_move2_w]
for m in models:
    print('num of parameters:', sum(p.numel() for p in m.parameters() if p.requires_grad))

num of parameters: 3754
num of parameters: 11392
num of parameters: 11392


## Transformer

In [369]:
class Multi_Head_Attention(nn.Module):
    def __init__(self, d_emb, d_hid, heads):
        super().__init__()
        self.d_hid = d_hid
        self.heads = heads
        self.dim_per_head = self.d_hid // self.heads
        
        self.qkv = nn.Linear(d_emb, self.d_hid * 3, bias = False)
        
        
        self.unifyheads = nn.Linear(self.d_hid, d_emb)
    
    def self_attention(self, q, k, v):
        scores = torch.einsum('...ij,...kj->...ik', q, k) / np.sqrt(self.dim_per_head)
        scores = F.softmax(scores, dim = -1)
        return torch.einsum('...ij,...jk->...ik', scores, v)
    
    def forward(self, x):
        qkv = self.qkv(x)
        q = qkv[..., :self.d_hid]
        k = qkv[..., self.d_hid : self.d_hid * 2]
        v = qkv[..., self.d_hid * 2 :]
            
        q = rearrange(q, '... i (h j) -> ... h i j', h = self.heads)
        k = rearrange(k, '... i (h j) -> ... h i j', h = self.heads)
        v = rearrange(v, '... i (h j) -> ... h i j', h = self.heads)
                
        scores = self.self_attention(q, k, v)
        scores = rearrange(scores, '... h i j -> ... i (h j)').contiguous()
                
        return self.unifyheads(scores)

In [370]:
class GLU(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.out_size = out_size
        self.linear = nn.Linear(in_size, out_size * 2)
    def forward(self, x):
        x = self.linear(x)
        #x = x[..., : self.out_size] * x[..., self.out_size :].sigmoid()
        x = torch.einsum('...i, ...i->...i', [x[..., : self.out_size], x[..., self.out_size :].sigmoid()])

In [371]:
class Encoder_layer(nn.Module):
    def __init__(self, d_emb, d_hid, hidden_mult, heads, enc_drop):
        super().__init__()
        self.dropout = nn.Dropout(enc_drop)
        
        self.mha = Multi_Head_Attention(d_emb, d_hid, heads)
        self.norm_1 = nn.LayerNorm(d_emb)
        self.ff = nn.Sequential(
            nn.Linear(d_emb, hidden_mult * d_emb),
            #nn.ReLU(),
            #nn.GELU(),
            nn.LeakyReLU(),
            nn.Linear(hidden_mult * d_emb, d_emb)
        )
        #self.ff = GLU(d_emb, d_emb)
        
        self.norm_2 = nn.LayerNorm(d_emb)
        
    def forward(self, x):
        attended = self.mha(x)
        x = attended + x
        x = self.dropout(x)
        x = self.norm_1(x)
        fed_for = self.ff(x)
        x = fed_for + x
        x = self.dropout(x)
        x = self.norm_2(x)
        return x

In [387]:
class Transformer(nn.Module):
    def __init__(self, model_hp):
        super().__init__()
        self.epochs = 0
        self.losses = []
        
        d_emb, seq_length, hidden_mult, order, enc_drop, d_toks, m_toks = model_hp
        
        self.sr_d_emb = np.sqrt(d_emb)
        
        self.dice_emb = nn.Embedding(6, d_emb)
        self.table_emb = nn.Embedding(31, d_emb)
        
        self.cls_toks = d_toks + m_toks * 2
        self.pe = nn.Parameter(torch.rand(seq_length + self.cls_toks, d_emb))
        self.cls_token = nn.Parameter(torch.rand(self.cls_toks, d_emb))
        
        self.encoder = nn.ModuleList()
        for d_hid, heads in order:
            self.encoder.append(Encoder_layer(d_emb, d_hid, hidden_mult, heads, enc_drop))
        
        self.out = nn.Linear(self.cls_toks * d_emb, 50)
                
        self.weights_init()
        
        
    def weights_init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight, gain = nn.init.calculate_gain('relu'))
                #m.bias.data.fill_(0.01)
         
    def forward(self, x):
        d = x[..., :2]
        t = x[..., 2:] + 15
        d = self.dice_emb(d)
        t = self.table_emb(t)
        #x = rearrange(x, '... (s e) -> ... s e', e = self.d_emb)
        cls_toks = repeat(self.cls_token, 's e -> b s e', b = d.size(0))
        x = torch.cat((cls_toks, t, d), dim = 1)
        x = x * self.sr_d_emb + self.pe #[:x.size(1)]
        for enc in self.encoder:
            x = enc(x)
        #out = self.out(rearrange(x, 'i j k -> i (j k)'))
        x = x[:, :self.cls_toks]
        x = rearrange(x, '... s e -> ... (s e)')
        x = self.out(x)
        d = x[:, :2]
        m1 = x[:, 2 : -24]
        m2 = x[:, -24:]
        #return F.softmax(d, dim = -1), F.softmax(m1, dim = -1), F.softmax(m2, dim = -1)
        return d, m1, m2

# Agents

In [399]:
class Table():
    def __init__(self):
        self.phase = 0  # 0/1/2 -- game/house/done
        self.state = np.zeros(24, dtype = int)
        self.state[0] = 15
        self.state[12] = -15
        self.mine = {0}
        self.his = {12}
        self.home_cells = {6 : 18, 5 : 19, 4 : 20, 3 : 21, 2 : 22, 1 : 23}
    
    def check_phase(self):
        if self.phase == 0:
            condition = lambda x : x > 17
            if all(condition(x) for x in self.mine):
                self.phase = 1
        elif self.phase == 1:
            if not self.mine:
                self.phase = 2
    
    def get_legal_actions(self, die, flag_head):
        die += 1
        legal_actions = np.zeros(24)
        if self.phase == 1:
            flag_out = True
            tempo_actions = []
            for p in range(6, die - 1, -1):
                if self.home_cells[p] in self.mine:
                    tempo_actions.append(p)
                    flag_out = False
            if flag_out:
                legal_actions[list(self.mine)] = 1
            else:
                for p in tempo_actions:
                    if (self.home_cells[p] + die > 23) or (self.home_cells[p - die] not in self.his):
                        legal_actions[self.home_cells[p]] = 1
        elif self.phase == 0:
            for p in self.mine:
                destination = p + die
                if (destination < 24) and (destination not in self.his):
                    if destination < 12 and destination not in self.mine:
                        new_mine = self.mine | {destination}
                        if not self.check_mars(new_mine):
                            legal_actions[p] = 1
                    else:
                        legal_actions[p] = 1
            if flag_head:
                legal_actions[0] = 0
        return legal_actions.astype(bool)
    
    def check_mars(self, new_mine):
        n = 6
        mars_end = 0
        for i in range(7):
            if all(j in new_mine for j in range(i, i + n)):
                mars_end = i + n
        if mars_end:
            for i in range(mars_end, 12):
                if i in self.his:
                    return False
            return True
        return False

    def update(self, action, die, me):
        die += 1
        if me:
            destination = action + die
            self.state[action] -= 1
            if self.state[action] < 1:
                self.mine.remove(action)
            if destination < 24:
                self.state[destination] += 1
                self.mine.add(destination)
        else:
            destination = action + die
            action = (action + 12) % 24
            self.state[action] += 1
            if self.state[action] > -1:
                self.his.remove(action)
            if destination < 24:
                destination = (destination + 12) % 24
                self.state[destination] -= 1
                self.his.add(destination)
    
    def print_table(self):
        row_format = '{:>6}' * 13
        print(row_format.format('', *self.state[11::-1]))
        print(row_format.format('', *self.state[12:]))

In [451]:
class Agent():
    def __init__(self, model, model_hp, device):
        super().__init__()
        self.q_network = model(model_hp).to(device)
        self.t_network = model(model_hp).to(device)
        self.t_network.load_state_dict(self.q_network.state_dict())
        self.table = Table()
        self.flag_first_turn = True
        self.flag_done = False
        self.flag_winner = False
        self.memory_size = 2**13
        self.buffer_size = 2**7
        self.rep_memory = deque(maxlen = self.memory_size) #[2 dice,  24 table cells], size 26
        self.act_memory = deque(maxlen = self.memory_size) #[die, move1, move1]
        self.rew_memory = deque(maxlen = self.memory_size)
        self.dice = None
        self.epsilon = 0.5
    
    def refresh(self, reset = False):
        self.table.phase = 0
        self.flag_first_turn = True
        self.flag_done = False
        self.flag_winner = False
        if reset:
            self.rep_memory = []
            self.act_memory = []
            self.rew_memory = []
    
    def roll(self):
        self.dice = list(np.random.randint(6, size = 2))
    
    def update_rep_memory(self):
        self.rep_memory.append(np.concatenate((self.dice , self.table.state)))
    def update_act_memory(self, d, m1, m2):
        self.act_memory.append([d, m1, m2])
    def update_rew_memory(self):
        if not self.flag_first_turn:
            if self.flag_done:
                if self.flag_winner:
                    r = 20
                else:
                    r = -20
            else:
                r = -1
            self.rew_memory.append(r)
    
    def memory_not_full(self):
        return len(self.rew_memory) != self.memory_size
        
    def num_parameters(self):
        n = sum(p.numel() for p in self.q_network.parameters() if p.requires_grad)
        print('number of parameters:', n)
        
    def calculate(self):
        x = torch.LongTensor(self.rep_memory[-1]).to(device)
        x = x.unsqueeze(0)
        # d = x[..., :2]
        # t = x[..., 2:] + 15
        self.q_network.eval()
        d, m1, m2 = self.q_network(x)
        return d.detach().cpu()[0], m1.detach().cpu()[0], m2.detach().cpu()[0]
    
    def print_replay_memory(self):
        row_format = '{:>6}' * 13
        for i in range(len(self.rep_memory)):
            print(row_format.format('', *self.rep_memory[i][13:1:-1]))
            print(row_format.format('', *self.rep_memory[i][14:]))
            print(self.rep_memory[i][:2] + 1)
            print(self.act_memory[i])
            print()
    def train_the_agent(self, epochs, gamma, update_interval):
        self.q_network.train()
        self.t_network.eval()
        for epoch in range(epochs):
            indices = random.sample(range(self.memory_size - 1), self.buffer_size)
            # Retrieve corresponding elements using the sampled indices
            states = [self.rep_memory[i] for i in indices]
            states = torch.tensor(np.vstack(states)).to(device)
            next_states = [self.rep_memory[i + 1] for i in indices]
            next_states = torch.tensor(np.vstack(next_states)).to(device)
            actions = [self.act_memory[i] for i in indices]
            actions = torch.tensor(actions).to(device)
            rewards = [self.rew_memory[i] for i in indices]
            rewards = torch.tensor(rewards).to(device)

            d, m1, m2 = self.q_network(states)
            q_d_values = d.gather(1, actions[:, 0].unsqueeze(1)).squeeze(1)
            q_m1_values = m1.gather(1, actions[:, 1].unsqueeze(1)).squeeze(1)
            q_m2_values = m2.gather(1, actions[:, 1].unsqueeze(1)).squeeze(1)

            d, m1, m2 = self.t_network(states)
            t_d_values = rewards + gamma * d.max(1)[0]
            t_m1_values = rewards + gamma * m1.max(1)[0]
            t_m2_values = rewards + gamma * m2.max(1)[0]

            loss = self.criterion(torch.cat((q_d_values, q_m1_values, q_m2_values)), 
                                  torch.cat((t_d_values, t_m1_values, t_m2_values)))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if epoch % update_interval == 0:
                self.t_network.load_state_dict(self.q_network.state_dict())

In [401]:
def turn(agent, ag, flag_print = False):
    if flag_print:
        agent.table.print_table()
    if agent.table.phase == 2:
        if not agent.flag_done:
            agent.flag_done = True
            if not ag.flag_done:
                agent.flag_winner = True
            agent.update_rew_memory()
    else:
        agent.update_rew_memory()
        agent.roll() # roll the dice
        if flag_print:
            print(' ' * 36, agent.dice[0] + 1, agent.dice[1] + 1)
        agent.update_rep_memory()
        d, m1, m2 = agent.calculate()
        if agent.epsilon is not None:
            if np.random.random() < agent.epsilon:
                d = np.random.choice(dice_index)
            else:
                d = F.softmax(d, dim = -1).numpy()
                d = np.random.choice(dice_index, p = d)
        else:
            d = F.softmax(d, dim = -1).numpy()
            d = np.argmax(d)
        d1 = agent.dice.pop(d)
        d2 = agent.dice[0]
        flag_head = False
        mask = agent.table.get_legal_actions(d1, flag_head)
        if mask.sum() > 0:
            if agent.epsilon is not None:
                if np.random.random() < agent.epsilon:
                    m1 = np.random.choice(actions_index[mask])
                else:
                    m1 = F.softmax(m1[mask], dim = -1).numpy()
                    m1 = np.random.choice(actions_index[mask], p = m1)
            else:
                m1 = F.softmax(m1[mask], dim = -1).numpy()
                m1 = np.argmax(m1)
            agent.table.update(m1, d1, True)
            ag.table.update(m1, d1, False)
            agent.table.check_phase()
            if flag_print:
                print('1st', m1, d1 + 1)
                agent.table.print_table()
        else:
            m1 = -1
            if flag_print:
                print('no legal moves')
        if agent.flag_first_turn:
            agent.flag_first_turn = False
        elif m1 == 0:
            flag_head = True
        mask = agent.table.get_legal_actions(d2, flag_head)
        if mask.sum() > 0:
            if agent.epsilon is not None:
                if np.random.random() < agent.epsilon:
                    m2 = np.random.choice(actions_index[mask])
                else:
                    m2 = F.softmax(m2[mask], dim = -1).numpy()
                    m2 = np.random.choice(actions_index[mask], p = m2)
            else:
                m2 = F.softmax(m2[mask], dim = -1).numpy()
                m2 = np.argmax(m2)
            agent.table.update(m2, d2, True)
            ag.table.update(m2, d2, False)
            agent.table.check_phase()
            if flag_print:
                print('2nd', m2, d2 + 1)
                agent.table.print_table()
        else:
            m2 = -1
            if flag_print:
                print('no legal moves')
        agent.update_act_memory(d, m1, m2)

In [402]:
def session(ag_1, ag_2):
    ag_1.refresh(reset = False)
    ag_2.refresh(reset = False)
    while not (ag_1.flag_done and ag_2.flag_done):
        turn(ag_1, ag_2)
        turn(ag_2, ag_1)

In [None]:
def fill_the_replay_memory(ag_1, ag_2):
    while ag_1.memory_not_full() or ag_2.memory_not_full():
        session(ag_1, ag_2)
        session(ag_2, ag_1)

In [453]:
def collect_and_train(ag_1, ag_2, num_games, train_interval, epochs, gamma, update_interval):
    for i in range(num_games):
        if i % train_interval == 0:
            ag_1.train_the_agent(epochs, gamma, update_interval)
            ag_1.train_the_agent(epochs, gamma, update_interval)
        session(ag_1, ag_2)
        session(ag_2, ag_1)

# Create agents

In [403]:
cuda_core = 0
if torch.cuda.is_available():
    torch.cuda.set_device(cuda_core)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [404]:
#device = torch.device('cpu')

In [408]:
dice_index = np.arange(2)
actions_index = np.arange(24)

d_emb = 8
seq_length = 26
hidden_mult = 2
# order = [(128, 8), (128, 8), 
#          (64, 4), (64, 4), 
#          (64, 2), (64, 2), 
#          (32, 1), (32, 1)]
order = [(32, 4), (32, 4), (32, 2), (32, 2), (32, 1), (32, 1)]
#heads_order = [1] * 12
enc_drop = 0.05  # 0.005
d_toks = 1
m_toks = 2

lr = 0.0001

In [409]:
model1_hp = d_emb, seq_length, hidden_mult, order, enc_drop, d_toks, m_toks
model = Transformer
# optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
# criterion = nn.HuberLoss()
#print('num of parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad))

In [445]:
ag_1 = Agent(model, model1_hp, device)
ag_1.optimizer = torch.optim.AdamW(ag_1.q_network.parameters(), lr = lr)
ag_1.criterion = nn.MSELoss()
ag_2 = Agent(model, model1_hp, device)
ag_2.optimizer = torch.optim.AdamW(ag_2.q_network.parameters(), lr = lr)
ag_2.criterion = nn.MSELoss()

In [411]:
# start
session(ag_1, ag_2)

In [None]:
num_games, train_interval, epochs, gamma, update_interval

In [412]:
# print('agent 1')
# turn(ag_1, ag_2, flag_print = True)
# print('\nagent 2')
# turn(ag_2, ag_1, flag_print = True)

In [413]:
# states = torch.tensor(np.vstack(ag_1.rep_memory)).to(device)
# next_states = torch.tensor(np.vstack(ag_1.rep_memory[1:11])).to(device)
# actions = torch.tensor(ag_1.act_memory[:10]).to(device)

In [414]:
indices = np.random.choice(range(20 - 1), 8, replace = False)
# Retrieve corresponding elements using the sampled indices
states = [ag_1.rep_memory[i] for i in indices]
states = torch.tensor(np.vstack(states)).to(device)
next_states = [ag_1.rep_memory[i + 1] for i in indices]
next_states = torch.tensor(np.vstack(next_states)).to(device)
actions = [ag_1.act_memory[i] for i in indices]
actions = torch.tensor(actions).to(device)
rewards = [ag_1.rew_memory[i] for i in indices]
rewards = torch.tensor(rewards).to(device)

In [431]:
d, m1, m2 = ag_1.q_network(states)
q_d_values = d.gather(1, actions[:, 0].unsqueeze(1)).squeeze(1)
q_m1_values = m1.gather(1, actions[:, 1].unsqueeze(1)).squeeze(1)
q_m2_values = m2.gather(1, actions[:, 1].unsqueeze(1)).squeeze(1)

In [436]:
d, m1, m2 = ag_1.t_network(next_states)
t_d_values = rewards + gamma * d.max(1)[0]
t_m1_values = rewards + gamma * m1.max(1)[0]
t_m2_values = rewards + gamma * m2.max(1)[0]

In [447]:
loss = ag_1.criterion(torch.cat((q_d_values, q_m1_values, q_m2_values)), 
                    torch.cat((t_d_values, t_m1_values, t_m2_values)))

In [448]:
loss

tensor(2.6969, device='cuda:0', grad_fn=<MseLossBackward0>)

In [441]:
t_m1_values

tensor([2.0230, 1.0306, 1.3290, 1.9921, 1.6633, 1.7826, 2.1213, 2.7031],
       device='cuda:0', grad_fn=<AddBackward0>)

In [435]:
gamma = 0.9

In [None]:
def train_the_agent(self, epochs, gamma, update_interval):
    self.q_network.train()
    self.t_network.eval()
    for epoch in range(epochs)
        indices = random.sample(range(self.memory_size - 1), self.buffer_size)
        # Retrieve corresponding elements using the sampled indices
        states = [self.rep_memory[i] for i in indices]
        states = torch.tensor(np.vstack(states)).to(device)
        next_states = [self.rep_memory[i + 1] for i in indices]
        next_states = torch.tensor(np.vstack(next_states)).to(device)
        actions = [self.act_memory[i] for i in indices]
        actions = torch.tensor(actions).to(device)
        rewards = [self.rew_memory[i] for i in indices]
        rewards = torch.tensor(rewards).to(device)

        d, m1, m2 = self.q_network(states)
        q_d_values = d.gather(1, actions[:, 0].unsqueeze(1)).squeeze(1)
        q_m1_values = m1.gather(1, actions[:, 1].unsqueeze(1)).squeeze(1)
        q_m2_values = m2.gather(1, actions[:, 1].unsqueeze(1)).squeeze(1)
        
        d, m1, m2 = self.t_network(states)
        t_d_values = rewards + gamma * d.max(1)[0]
        t_m1_values = rewards + gamma * m1.max(1)[0]
        t_m2_values = rewards + gamma * m2.max(1)[0]
        
        loss = self.criterion(torch.cat((q_d_values, q_m1_values, q_m2_values)), 
                              torch.cat((t_d_values, t_m1_values, t_m2_values)))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        if epoch % update_interval == 0:
            self.t_network.load_state_dict(self.q_network.state_dict())

In [None]:
self.rep_memory
self.act_memory
self.rew_memory

In [391]:
2**7

128

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import namedtuple

# Define the replay buffer
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state'))
replay_buffer = []

# Define your Q-network
class QNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize the Q-network and target network
input_dim = ...  # Dimensionality of your state
output_dim = ...  # Number of possible actions
q_network = QNetwork(input_dim, output_dim)
target_network = QNetwork(input_dim, output_dim)
target_network.load_state_dict(q_network.state_dict())

# Define other hyperparameters
batch_size = 32
learning_rate = 0.001
discount_factor = 0.9
optimizer = optim.Adam(q_network.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    # Sample a mini-batch from the replay buffer
    batch = random.sample(replay_buffer, batch_size)
    state_batch = torch.stack([transition.state for transition in batch])
    action_batch = torch.tensor([transition.action for transition in batch])
    reward_batch = torch.tensor([transition.reward for transition in batch])
    next_state_batch = torch.stack([transition.next_state for transition in batch])
    
    # Compute Q-values for the current state-action pairs using the Q-network
    q_values = q_network(state_batch)
    q_values = q_values.gather(1, action_batch.unsqueeze(1)).squeeze(1)
    
    # Compute the target Q-values for the next states using the target network
    target_q_values = target_network(next_state_batch)
    max_q_values = target_q_values.max(1)[0]
    
    # Compute the expected Q-values using the Q-learning update rule
    expected_q_values = reward_batch + discount_factor * max_q_values
    
    # Compute the loss (e.g., mean squared error) between the predicted Q-values and the expected Q-values
    loss = nn.MSELoss()(q_values, expected_q_values)
    
    # Update the Q-network parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # Update the target network parameters periodically
    if epoch % target_update_interval == 0:
        target_network.load_state_dict(q_network.state_dict())


In [None]:
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10

In [185]:
def get_reward(table, roll, action):
    reward = 0
    flag = False
    roll = roll + 1
    target = action + roll
    if table[action] < 1:
        reward -=2
    elif target > 23 or table[target] < 0:
        reward -= 1
    else:
        reward += table[action] / 6
        flag = True
    return reward, flag

def get_reward_end(table, roll, action):
    reward = 0
    flag = False
    roll = roll + 1
    target = action + roll
    if table[action] < 1:
        reward -= 2
    elif target > 23:
        reward += 1
        flag = True
    elif table[target] < 0:
        reward -= 1
    else:
        flag = True
    return reward, flag

In [None]:
def turn(models, table, roll):
    reward = 0
    roll = torch.randint(6, (2,))
    d = roll.to(device)
    t = table.to(device)
    chosen_actions = []
    outs = []
    out = models[0](t, d)
    outs.append(out)
    choice = out.argsort()[::-1]
    for i, ch in enumerate(choice):
        out = models[i + 1](t, d[ch : ch + 1])
        spots = (table > 0).nonzero().view(-1)
        actions = out.argsort()[::-1]
        if torch.all(spots > 17):
            for a in actions:
                r, flag = get_reward_end(table, roll[choice], a)
                reward += r
                if flag:
                    break
            a = -1
        else:
            for a in actions:
                r, flag = get_reward(table, roll[choice], a)
                reward += r
                if flag:
                    break
            a = -1
        outs.append(out)
        chosen_actions.append(a)

# WORKS

In [576]:
def check_mars(new_mine):
    flag_mars = False
    n = 6
    for i in range(7):
        if all(j in new_mine for j in range(i, i + n)):
            flag_mars = True
    return flag_mars

In [579]:
new_mine = {10, 8, 11, 6, 3, 1, 4, 5}
check_mars(new_mine)

False

In [44]:
emb = nn.Embedding(4, 2)

In [45]:
inp = torch.LongTensor([[0, 2], [1, 3]])

In [46]:
inp

tensor([[0, 2],
        [1, 3]])

In [47]:
emb(inp)

tensor([[[ 1.7558,  0.4006],
         [-1.3213, -1.0613]],

        [[ 0.2940, -1.1072],
         [-1.1620, -1.4370]]], grad_fn=<EmbeddingBackward0>)

In [161]:
#tempo = table_w > 1
#torch.nonzero(table_w > 1, as_tuple = True)
(table_w >= 0).nonzero(as_tuple = True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22]),)

In [168]:
table_w

tensor([ 15.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
        -15.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.])

In [171]:
tens = torch.randint(15, (24,))
tens

tensor([11,  1,  4,  6,  4,  0,  3, 10, 12,  9,  2,  5,  3,  7,  8,  3, 10,  5,
        12,  4, 13, 12,  8,  4])

In [176]:
inds = (tens > 11).nonzero().view(-1)
inds

tensor([ 8, 18, 20, 21])

In [179]:
ch_inds = torch.arange(16, 24)
ch_inds

tensor([16, 17, 18, 19, 20, 21, 22, 23])

In [177]:
inds[0] = 17
inds

tensor([17, 18, 20, 21])

In [182]:
torch.all(inds > 18)

tensor(False)

In [114]:
class Dice_embedder(nn.Module):
    def __init__(self, d_emb):
        super().__init__()
        self.embed = nn.Embedding(6, d_emb)
    
    def forward(self, x):
        x = self.embed(x)
        #x = rearrange(x, '... i j -> ... (i j)')
        return x

In [115]:
model = Dice_embedder(4)

In [162]:
roll = torch.randint(6, (2,))
roll

tensor([3, 4])

In [165]:
3 in roll

True

In [153]:
roll[1:2]

tensor([1])

In [132]:
result = model(roll[:1])
result

tensor([[ 1.5480, -0.9881,  0.6946, -1.4169]], grad_fn=<EmbeddingBackward0>)

In [127]:
tens = torch.zeros(5,)
tens

tensor([0., 0., 0., 0., 0.])

In [134]:
torch.cat((tens, result[0]))

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.5480, -0.9881,  0.6946,
        -1.4169], grad_fn=<CatBackward0>)

In [45]:
class Multi_Head_Attention(nn.Module):
    def __init__(self, d_emb, d_hid, heads):
        super().__init__()
        self.d_hid = d_hid
        self.heads = heads
        self.dim_per_head = self.d_hid // self.heads
        
        self.qkv = nn.Linear(d_emb, self.d_hid * 3, bias = False)
        
        
        self.unifyheads = nn.Linear(self.d_hid, d_emb)
    
    def self_attention(self, q, k, v):
        scores = torch.einsum('...ij,...kj->...ik', q, k) / np.sqrt(self.dim_per_head)
        scores = F.softmax(scores, dim = -1)
        return torch.einsum('...ij,...jk->...ik', scores, v)
    
    def forward(self, x):
        qkv = self.qkv(x)
        q = qkv[..., :self.d_hid]
        k = qkv[..., self.d_hid : self.d_hid * 2]
        v = qkv[..., self.d_hid * 2 :]
            
        q = rearrange(q, '... i (h j) -> ... h i j', h = self.heads)
        k = rearrange(k, '... i (h j) -> ... h i j', h = self.heads)
        v = rearrange(v, '... i (h j) -> ... h i j', h = self.heads)
                
        scores = self.self_attention(q, k, v)
        scores = rearrange(scores, '... h i j -> ... i (h j)').contiguous()
                
        return self.unifyheads(scores)

In [46]:
class GLU(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.out_size = out_size
        self.linear = nn.Linear(in_size, out_size * 2)
    def forward(self, x):
        x = self.linear(x)
        #x = x[..., : self.out_size] * x[..., self.out_size :].sigmoid()
        x = torch.einsum('...i, ...i->...i', [x[..., : self.out_size], x[..., self.out_size :].sigmoid()])

In [47]:
class Encoder_layer(nn.Module):
    def __init__(self, d_emb, d_hid, hidden_mult, heads, enc_drop):
        super().__init__()
        self.dropout = nn.Dropout(enc_drop)
        
        self.mha = Multi_Head_Attention(d_emb, d_hid, heads)
        self.norm_1 = nn.LayerNorm(d_emb)
        self.ff = nn.Sequential(
            nn.Linear(d_emb, hidden_mult * d_emb),
            #nn.ReLU(),
            #nn.GELU(),
            nn.LeakyReLU(),
            nn.Linear(hidden_mult * d_emb, d_emb)
        )
        #self.ff = GLU(d_emb, d_emb)
        
        self.norm_2 = nn.LayerNorm(d_emb)
        
    def forward(self, x):
        attended = self.mha(x)
        x = attended + x
        x = self.dropout(x)
        x = self.norm_1(x)
        fed_for = self.ff(x)
        x = fed_for + x
        x = self.dropout(x)
        x = self.norm_2(x)
        return x

In [48]:
class Transformer(nn.Module):
    def __init__(self, model_hp):
        super().__init__()
        self.epochs = 0
        self.losses = []
        
        d_emb, seq_length, hidden_mult, order, enc_drop, d_toks, m1_toks, m2_toks = model_hp
        self.d_toks = d_toks
        self.m1_toks = m1_toks
        self.m2_toks = m2_toks
        self.d_emb = d_emb
                
        # self.emb = nn.Sequential(
        #     nn.Linear(in_dim, in_dim * d_emb),
        #     #nn.ReLU()
        #     #nn.GELU()
        #     nn.LeakyReLU()
        # )
        self.emb = GLU(seq_length, seq_length * d_emb)
        
        class_toks = d_toks + m1_toks + m2_toks
        self.pe = nn.Parameter(torch.rand(seq_length + class_toks, d_emb))
        self.cls_token = nn.Parameter(torch.rand(class_toks, d_emb))
        
        self.encoder = nn.ModuleList()
        for d_hid, heads in order:
            self.encoder.append(Encoder_layer(d_emb, d_hid, hidden_mult, heads, enc_drop))
        
        # self.out = nn.Sequential(
        #     nn.Linear(d_emb, d_emb * 2),
        #     #nn.ReLU(),
        #     #nn.GELU(),
        #     nn.LeakyReLU(),
        #     #GLU(d_emb *, d_emb),
        #     nn.Linear(d_emb * 2, num_classes)
        # )
        
        self.out_d = GLU(d_emb * d_toks, 2)
        self.out_m1 = GLU(d_emb * m1_toks, 24)
        self.out_m2 = GLU(d_emb * m2_toks, 24)
         
    def forward(self, x):
        x = self.emb(x)
        x = rearrange(x, '... (s e) -> ... s e', e = self.d_emb)
        if self.cls_token_dim:
            cls_toks = repeat(self.cls_token, 's e -> b s e', b = x.size(0))
            x = torch.cat((cls_toks, x), dim = 1)
        x = x + self.pe #[:x.size(1)]
        for enc in self.encoder:
            x = enc(x)
        #out = self.out(rearrange(x, 'i j k -> i (j k)'))
        if self.cls_token_dim:
            m = x[:, :3]
            d = x[:, 3]
        else:
            x = x.mean(dim = 1)
        x = self.out(x)
        return x

In [14]:
class Transformer(nn.Module):
    def __init__(self, model_hp):
        super().__init__()
        self.epochs = 0
        self.losses = []
        
        d_emb, seq_length, hidden_mult, order, enc_drop, n_d_toks, n_m1_toks, n_m2_toks, lr, cuda = model_hp
        self.n_d_toks = n_d_toks
        self.n_m1_toks = n_m1_toks
        self.n_m2_toks = n_m2_toks
        self.d_emb = d_emb
                
        self.emb = GLU(seq_length, seq_length * d_emb)
        
        self.n_cls_toks = n_d_toks + n_m1_toks + n_m2_toks
        self.pe = nn.Parameter(torch.rand(seq_length + self.n_cls_toks, d_emb))
        self.cls_token = nn.Parameter(torch.rand(self.n_cls_toks, d_emb))
        
        self.encoder = nn.ModuleList()
        for d_hid, heads in order:
            self.encoder.append(Encoder_layer(d_emb, d_hid, hidden_mult, heads, enc_drop))
        
        self.n_out_d = d_emb * n_d_toks
        self.out_d = GLU(self.n_out_d, 2)
        self.n_out_m1 = d_emb * n_m1_toks
        self.out_m1 = GLU(self.n_out_m1, 24)
        self.n_out_m2 = d_emb * n_m2_toks
        self.out_m2 = GLU(self.n_out_m2, 24)
        
        self.weights_init()
        
        self.optimizer = torch.optim.AdamW(self.parameters(), lr = lr)
        self.loss = nn.HuberLoss()
        if torch.cuda.is_available():
            torch.cuda.set_device(cuda)
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.to(self.device)
        
    def weights_init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                #m.bias.data.fill_(0.01)
         
    def forward(self, x):
        x = self.emb(x)
        x = rearrange(x, '... (s e) -> ... s e', e = self.d_emb)
        cls_toks = repeat(self.cls_token, 's e -> b s e', b = x.size(0))
        x = torch.cat((cls_toks, x), dim = 1)
        x = x + self.pe #[:x.size(1)]
        for enc in self.encoder:
            x = enc(x)
        #out = self.out(rearrange(x, 'i j k -> i (j k)'))
        x = x[:, self.n_cls_toks]
        x = rearrange(x, '... s e -> ... (s e)')
        d = x[:, :self.n_out_d]
        m1 = x[:, self.n_out_d : -self.n_out_m2]
        m2 = x[:, -self.n_out_m2:]
                
        return self.out_d(d), self.out_m1(m1), self.out_m2(m2)