In [None]:
import pickle
import math
import scipy
import scipy.interpolate
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras import Model
class PositionEncoding1D(Layer):#not using this rn
    """Fixed sinusoidal position-encoding layer (currently unused in this file).

    Builds an ``(input_dim, output_dim)`` table whose row ``p`` holds the
    interleaved pairs ``sin(p * f_k), cos(p * f_k)`` with frequencies
    ``f_k = 10000 ** (-2k / output_dim)``.  ``call`` returns, for each input
    row, the sum over positions of the input value times that position's row
    of the table.
    """

    def __init__(self, output_dim, input_dim):
        super(PositionEncoding1D, self).__init__()
        # (input_dim, 1) column of positions 0 .. input_dim-1.
        pos = tf.cast(tf.expand_dims(tf.range(input_dim), 1), tf.dtypes.float32)
        # (ceil(output_dim / 2),) frequencies exp(-2k * log(10000) / output_dim).
        div_term = tf.exp(tf.cast(tf.range(0, output_dim, 2), tf.dtypes.float32)
                          * -(np.log(10000.0) / output_dim))
        angles = pos * div_term  # (input_dim, ceil(output_dim / 2))
        num_pairs = (output_dim + 1) // 2
        # Stack sin/cos on a new trailing axis and flatten it, so columns come
        # out interleaved as sin0, cos0, sin1, cos1, ...  (see
        # https://stackoverflow.com/questions/46431983/concatenate-two-tensors-in-alternate-fashion-tensorflow)
        both_terms = tf.stack([tf.sin(angles), tf.cos(angles)], axis=-1)
        # For odd output_dim there is one surplus cosine column; drop it.
        self.encoded_pos = tf.reshape(both_terms, [input_dim, 2 * num_pairs])[:, :output_dim]

    def call(self, tensor):
        # Weight each position's encoding row by the input value at that
        # position and sum over positions.
        return tf.math.reduce_sum((tf.expand_dims(tensor, -1) * self.encoded_pos), axis=-2)
class GameState:
    """An ``m x n`` sliding-tile puzzle stored as a flat array of length ``m*n``.

    Cell ``(i, j)`` lives at flat index ``i*n + j``.  Tiles are ``1 .. m*n-1``
    and the blank is ``0``; the solved board is ``[1, 2, ..., m*n-1, 0]``.
    """

    def __init__(self, m, n, initial_state=None, num_shuffles=100):
        """Create a board.

        Args:
            m, n: number of rows and columns.
            initial_state: flat board to start from.  It is copied, so moves
                made on this game never modify the caller's array.  When
                ``None``, start from the solved board and scramble it.
            num_shuffles: number of random actions used to scramble a fresh
                board (ignored when ``initial_state`` is given).
        """
        self.m = m
        self.n = n
        # Each action is a list of blank offsets applied in order by
        # transition(): +1 / -1 move the blank one flat index forward / back,
        # +n / -n move it one row down / up.  Lists allow multi-slide macros.
        self.transition_map = [[1], [n], [-n], [-1]]
        if initial_state is None:
            solved = np.zeros(m * n)
            solved[:-1] = np.arange(1, m * n)
            self.state = solved
            self.scramble_state(num_shuffles)
        else:
            # np.array copies: slide() swaps entries in place, and keeping a
            # reference would silently rewrite the caller's data.
            self.state = np.array(initial_state)

    def scramble_state(self, num_shuffles):
        """Apply ``num_shuffles`` uniformly random actions in place."""
        for _ in range(num_shuffles):
            self.transition(random.randrange(len(self.transition_map)))

    def slide(self, action):
        """Swap the blank with the cell ``action`` flat positions away.

        If that index falls off the board, the opposite offset is used instead,
        so every action is always legal.
        NOTE(review): there is no column-boundary check, so an offset of +/-1
        from the end/start of a row moves the blank onto the adjacent row.
        """
        place1 = np.where(self.state == 0)[0][0]
        place2 = place1 + action
        if place2 < 0 or place2 > self.m * self.n - 1:
            place2 = place1 - action
        self.state[[place1, place2]] = self.state[[place2, place1]]

    def transition(self, index):
        """Apply every slide of action ``index`` from ``transition_map``."""
        for action in self.transition_map[index]:
            self.slide(action)

    def score(self):
        """Number of the first ``m*n-1`` cells that hold their solved tile."""
        return int(np.count_nonzero(self.state[:-1] == np.arange(1, self.m * self.n)))

    def is_complete(self):
        """True when every tile is in its solved position."""
        return self.score() == len(self.state) - 1

    def partial_score(self):
        """True when tile 1 is in the first cell and tile ``m*n-1`` in the second-to-last cell."""
        i15 = np.where(self.state == self.m * self.n - 1)[0][0]
        i1 = np.where(self.state == 1)[0][0]
        return (i15 == len(self.state) - 2) and (i1 == 0)
        
def groupoid_basis(p):
    """Return a sparse 0/1 basis matrix of shape ``(3 * p**2, 9 * p**2)``.

    The rows of the 3x3 identity (built as the circulant of ``[1, 0, 0]``) are
    repeated ``p**2`` times and each row is placed as its own 1x3 block on the
    block diagonal, so row ``c`` has a single 1 in column ``3*c + c % 3``.
    """
    identity_rows = tuple(scipy.linalg.circulant(np.eye(3)[0]))
    repeated_rows = identity_rows * p ** 2
    return scipy.linalg.block_diag(*repeated_rows)
    
class GroupoidDecompositionLayer(Layer):
    """Reduce each matrix in a batch to a scalar: ``trace(tensor @ groupoid_basis(2))``.

    ``groupoid_basis(2)`` has shape ``(12, 36)``, so the last axis of
    ``tensor`` must have length 12; the output drops the last two axes.
    """

    def __init__(self, m, n, name=None, trainable=True, dtype='float32'):
        # Forward ``trainable`` rather than hard-coding True so the argument
        # (and the value restored from get_config) is honoured.
        super(GroupoidDecompositionLayer, self).__init__(name=name, trainable=trainable)
        # NOTE(review): p is fixed at 2 regardless of m and n; m and n are only
        # stored so get_config() can round-trip them.
        self.basis = groupoid_basis(2)
        self.m = m
        self.n = n

    def call(self, tensor):
        # groupoid_basis returns a float64 NumPy array; cast it to the input's
        # dtype explicitly instead of relying on matmul's implicit conversion.
        basis = tf.cast(self.basis, tensor.dtype)
        # NOTE(review): worked out by hand, verify -- for a (12, 12) input T the
        # product is (12, 36) and its trace is T[0,0] + T[4,1] + T[8,2] + T[9,3],
        # i.e. only 4 of the 144 entries reach the output.
        return tf.linalg.trace(tf.matmul(tensor, basis))

    def get_config(self):
        config = super().get_config()
        config.update({
            "m": self.m,
            "n": self.n
        })
        return config
class PermutationEquivariantLayer(Layer):
    """Weight-shared layer mapping its input to ``output_dim`` features.

    Holds a single ``(output_dim, 2)`` weight ``w``.  For an input whose axis 1
    is the feature axis (e.g. ``(batch, d)``) the output is

        out[..., j] = (sum of tensor over axis 1) * (sum_k w[k, 0] + output_dim * w[j, 1])

    NOTE(review): the result depends on the input only through its sum over
    axis 1, so it is permutation-*invariant* in that axis, and every output
    column is the same scalar times a per-column constant -- confirm intended.
    """

    def __init__(self, output_dim, name=None, trainable=True, dtype='float32'):
        # Forward ``trainable`` rather than hard-coding True so the argument
        # (and the value restored from get_config) is honoured.
        super(PermutationEquivariantLayer, self).__init__(name=name, trainable=trainable)
        # Tiny symmetric init keeps the initial outputs close to zero.
        w_init = tf.random_uniform_initializer(-1e-7, 1e-7)
        self.w = tf.Variable(
            initial_value=w_init(shape=(output_dim, 2), dtype=dtype),
            trainable=True,
        )
        self.output_dim = output_dim

    def call(self, tensor):
        # Closed form of broadcasting tensor * w into a
        # (..., d, output_dim, output_dim) array and summing it back down:
        # same result without materialising the O(d * output_dim**2)
        # intermediate (hundreds of MB for output_dim=144 at batch 128).
        column_scale = (tf.math.reduce_sum(self.w[:, 0])
                        + tf.cast(self.output_dim, self.w.dtype) * self.w[:, 1])
        return tf.expand_dims(tf.math.reduce_sum(tensor, axis=1), -1) * column_scale

    def get_config(self):
        config = super().get_config()
        config.update({
            "output_dim": self.output_dim
        })
        return config

class SlideAgent:
    """Deep-Q-learning agent for the ``m x n`` sliding puzzle in ``GameState``.

    The Q-network scores one (board, action) pair per input row: the row is
    the board concatenated with the board after the action (``encode_input``),
    and the network emits one scalar per row.  ``dqn_target`` is a clone of
    ``dqn`` marked non-trainable and refreshed only by ``update_target``; it
    supplies the bootstrap term of the target in ``update_weights``.
    """

    def __init__(self,
                 m,
                 n,
                 exploration_probability=0.9,
                 exploration_decay=0.015,
                 gamma=0.9,
                 num_trials_per_update=5,
                 max_turns_per_game=50,
                 max_cycle_penalty=0.00004,
                 experience_replay_sample_size=128,
                 learning_rate=0.001,
                 num_hidden_units=4,
                 print_every_step=False):
        """Build the Q-network, its target copy and an empty replay buffer.

        Args:
            m, n: board rows and columns.
            exploration_probability: initial epsilon for epsilon-greedy play;
                multiplied by ``1 - exploration_decay`` after each ``train_once``.
            exploration_decay: per-``train_once`` multiplicative epsilon decay.
            gamma: discount applied to the target network's Q estimate.
            num_trials_per_update: games played per ``train_once`` call.
            max_turns_per_game: move cap per game; also caps the scramble length.
            max_cycle_penalty: scale of the penalty for revisiting a board.
            experience_replay_sample_size: minibatch size drawn from the buffer.
            learning_rate: Adam learning rate.
            num_hidden_units: number of stacked PermutationEquivariantLayer blocks.
            print_every_step: print every board while playing.
        """
        self.num_trials_per_update = num_trials_per_update
        self.max_turns_per_game = max_turns_per_game
        self.max_cycle_penalty = max_cycle_penalty
        self.print_every_step = print_every_step
        self.experience_replay_sample_size = experience_replay_sample_size
        self.m = m
        self.n = n
        self.gamma = gamma
        self.exploration_probability = exploration_probability
        self.exploration_decay = exploration_decay
        self.num_policies = len(GameState(m, n).transition_map)
        # Input row = current board followed by the board after the candidate action.
        self.input_size = self.m * self.n * 2
        input_layer = Input(shape=(self.input_size,))
        # Treat each of the input_size scalars as a length-1 token for self-attention.
        reshaped = Reshape([self.input_size, 1])(input_layer)
        attention = Attention(use_scale=True)([reshaped, reshaped])
        attention = Reshape([self.input_size])(attention)
        concat = Concatenate()([attention, input_layer])
        hidden_layers = []
        for _ in range(num_hidden_units):
            hidden_layers.append(LeakyReLU()(PermutationEquivariantLayer(self.input_size)(concat)))
            # Each following block reads the sum of all previous blocks' outputs.
            concat = sum(hidden_layers)
        # 144 = 12 * 12; 12 matches the row count of groupoid_basis(2)
        # (3 * 2**2) used inside GroupoidDecompositionLayer.
        representation_layer = PermutationEquivariantLayer(144)(concat)
        representation_layer = Reshape([12, 12])(representation_layer)
        output_layer = GroupoidDecompositionLayer(self.m, self.n)(representation_layer)
        # [encoded inputs, rewards].
        # NOTE(review): grows without bound across train_once calls.
        self.experience_replay_buffer = [np.zeros([0, self.input_size], dtype=np.float32),
                                         np.array([], dtype=np.float32)]
        self.dqn = Model(inputs=[input_layer], outputs=[output_layer])
        self.dqn.compile(loss='huber', optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate))
        # Register the custom layer classes in Keras' global custom-object registry.
        custom_objects = {'PermutationEquivariantLayer': PermutationEquivariantLayer,
                          'GroupoidDecompositionLayer': GroupoidDecompositionLayer}
        tf.keras.utils.get_custom_objects().update(custom_objects)
        self.dqn_target = tf.keras.models.clone_model(self.dqn)
        self.dqn_target.trainable = False
        self.update_target()

    def update_target(self):
        """Copy the online network's weights into ``dqn_target``."""
        self.dqn_target.set_weights(self.dqn.get_weights())

    def choose_policy(self, state, target=False):
        """Return ``(action, q)`` for the highest-scoring action on one board.

        ``state`` is a single board (any shape with ``m*n`` elements).  With
        ``target=True`` the scores come from ``dqn_target``.
        """
        state = np.repeat(np.reshape(state, [1, self.m * self.n]), self.num_policies, axis=0)
        policies = np.arange(self.num_policies)
        q_values = self.predict_q(state, policies, target).numpy()
        choice = np.argmax(q_values.reshape([-1, self.num_policies]))
        return choice, q_values[choice]

    def predict_q(self, state, policy, target=False):
        """Q-values for each ``(state[i], policy[i])`` pair, one scalar per row.

        With ``target=True`` the non-trainable ``dqn_target`` is used and
        gradients are stopped.
        """
        model_input_vector = tf.convert_to_tensor(self.encode_input(state, policy).reshape([-1, self.input_size]))
        if target:
            return tf.stop_gradient(self.dqn_target(model_input_vector))
        else:
            return self.dqn(model_input_vector)

    def encode_input(self, state, policy):
        """Return rows ``[state[i], state[i] after action policy[i]]`` as a float array."""
        S = np.zeros([state.shape[0], state.shape[1] * 2])
        for i in range(state.shape[0]):
            # Work on a copy: GameState may keep a reference to the array it is
            # given and transition() swaps in place, which would overwrite the
            # "before" half (and the caller's array) with the "after" board.
            g = GameState(self.m, self.n, initial_state=state[i].copy())
            g.transition(policy[i])
            S[i] = np.concatenate([state[i], g.state])
        return S

    def reward(self, game, replay, policy):
        """Shaped reward for reaching ``game.state``.

        Returns ``1 / (m*n)`` when the puzzle is solved.  Otherwise: a small
        bonus looked up (nearest neighbour) from ``game.score()`` when
        ``game.partial_score()`` holds, minus a penalty for every board in
        ``replay[:-1]`` equal to the current one, larger for recent repeats.
        ``policy`` is unused.
        """
        score_adjusted = 0
        reward_lut = [0, 0, 0.1, 0.1, 0.5, 0.9, 1, 1, 1, 1, 10, 10, 80, 100, 175]
        lut_indices = np.linspace(0, self.m * self.n - 1, len(reward_lut))
        reward_fn = scipy.interpolate.interp1d(lut_indices, reward_lut, kind='nearest')
        if game.is_complete():
            return 1.0 / (self.m * self.n)
        if game.partial_score():
            score_adjusted += reward_fn(game.score()) / (self.m * self.n * 10000)
        for i, move in enumerate(replay[:-1]):
            if np.all(game.state == move[0]):
                # penalize periodic behavior, especially short cycles
                score_adjusted -= self.max_cycle_penalty * (0.1 + 1 / (len(replay) - i))
        return score_adjusted

    def play_game(self, shuffles):
        """Play one epsilon-greedy episode on a board scrambled by ``shuffles`` random moves.

        Returns ``(replay, won, score)``; each replay entry is
        ``[board after the move, action, reward]``.  The episode stops when the
        puzzle is solved, after ``max_turns_per_game`` moves, or when the board
        equals the one from two moves earlier (a period-2 back-and-forth).
        """
        replay = []
        game = GameState(self.m, self.n, None, shuffles)
        while game.is_complete():
            game = GameState(self.m, self.n, None, shuffles)  # just in case the scramble solved it
        for _ in range(self.max_turns_per_game):
            if np.random.rand() < self.exploration_probability:
                policy = random.randrange(0, self.num_policies)
            else:
                policy, _ = self.choose_policy(game.state, target=False)
            game.transition(policy)
            # Snapshot the board: game.state is swapped in place on every move,
            # so storing the array itself would make every entry alias the
            # final board (and make every earlier entry "equal" the current one).
            replay.append([game.state.copy(), policy, self.reward(game, replay, policy)])
            if self.print_every_step:
                print(f'{np.reshape(game.state,[self.m,self.n])}'.replace(' 0',' _').replace('[0','[_'))
            if game.is_complete():
                break
            # replay[-3] is the board from two moves ago: a match means we are
            # stuck in a period-2 orbit.
            if len(replay) > 3 and np.array_equal(game.state, replay[-3][0]):
                break
        if self.print_every_step and replay:
            print(replay[-1][-1])
        return replay, game.is_complete(), game.score()

    def update_weights(self, inputs, rewards):
        """One gradient step of ``dqn`` towards ``rewards + gamma * Q_target(next board)``.

        ``inputs`` rows are ``encode_input`` encodings; their second half (the
        board after the action) is treated as the next state and its greedy
        value is read from ``dqn_target``.  Returns the ``train_on_batch`` loss.
        """
        # NOTE(review): there is no terminal mask, so solved boards are still bootstrapped.
        next_states = np.asarray(inputs)[:, self.m * self.n:]
        Q1 = np.array([self.choose_policy(state, target=True)[1] for state in next_states],
                      dtype=np.float32)
        loss = self.dqn.train_on_batch(tf.convert_to_tensor(inputs),
                                       tf.convert_to_tensor(rewards + self.gamma * Q1))
        return loss

    def _difficulty(self, epoch):
        """Scramble length for ``epoch``: grows like sqrt(epoch), capped at max_turns_per_game."""
        return min(self.max_turns_per_game, int(np.sqrt(max(epoch, 0)) * 0.05 + 5))

    def train_once(self, epoch):
        """Play ``num_trials_per_update`` games, store them and train on a replay sample.

        Returns ``(win_rate, rewards, loss, mean_score, difficulty)``; ``loss`` is
        ``None`` until the buffer holds ``experience_replay_sample_size`` rows.
        """
        transition_batch = []
        state_batch = []
        reward_batch = []
        scores = []
        win_rate = 0
        difficulty = self._difficulty(epoch)
        # When the curriculum steps up to longer scrambles, keep at least some
        # exploration.
        if self._difficulty(epoch - 1) < difficulty:
            self.exploration_probability = max(self.exploration_probability, 0.125)
        for _ in range(self.num_trials_per_update):
            replay_buffer, won, score = self.play_game(difficulty)
            state_batch += [x[0] for x in replay_buffer]
            transition_batch += [x[1] for x in replay_buffer]
            reward_batch += [x[-1] for x in replay_buffer]
            scores.append(score)
            if won:
                win_rate += 1 / self.num_trials_per_update
        # NOTE(review): state_batch holds boards *after* each move, so
        # encode_input pairs s_t with T(s_t, a_t), while the stored reward belongs
        # to the transition s_{t-1} -> s_t.  Confirm this is intended.
        batch = self.encode_input(np.array(state_batch), np.array(transition_batch, dtype=np.int32))
        self.experience_replay_buffer[0] = np.concatenate([self.experience_replay_buffer[0], batch], axis=0)
        self.experience_replay_buffer[1] = np.concatenate(
            [self.experience_replay_buffer[1], np.array(reward_batch, dtype=np.float32)])
        loss = None
        if self.experience_replay_buffer[0].shape[0] >= self.experience_replay_sample_size:
            experience_sample = np.random.choice(np.arange(self.experience_replay_buffer[0].shape[0]),
                                                 self.experience_replay_sample_size)
            inputs = tf.convert_to_tensor(self.experience_replay_buffer[0][experience_sample], dtype=tf.float32)
            rewards = tf.convert_to_tensor(self.experience_replay_buffer[1][experience_sample], dtype=tf.float32)
            loss = self.update_weights(inputs, rewards)
        self.exploration_probability = (1 - self.exploration_decay) * self.exploration_probability
        mean_score = np.mean(scores)
        return win_rate, np.array(reward_batch), loss, mean_score, difficulty
        
# Train a 4x4 agent indefinitely: log every 5 iterations (and refresh the
# target network), plot every 20, save both networks every 100.
agent=SlideAgent(4,4,print_every_step=False)
agent.dqn.summary()  # summary() prints the table itself and returns None
iteration=0
losses=[]
while True:
    win_rate, rewards, loss, mean_score, difficulty=agent.train_once(iteration+1)
    iteration+=1
    losses.append(loss)
    if iteration%5==0:
        print(f'win rate:{win_rate} on iteration {iteration}, loss: {loss}, mean score: {mean_score}, difficulty: {difficulty}')
        agent.update_target()
    if iteration%20==0:
        plt.subplot(1,2,1)
        plt.plot(np.arange(len(rewards)), rewards)
        plt.subplot(1,2,2)
        # train_once returns loss=None until the replay buffer holds a full
        # sample; plot only the recorded losses at their original iterations.
        recorded=[(i, l) for i, l in enumerate(losses) if l is not None]
        if recorded:
            plt.scatter([i for i, _ in recorded], [l for _, l in recorded])
        plt.show()
    if iteration%100==0:
        agent.dqn.save('dqn')
        agent.dqn_target.save('dqn_target')
        #open(f'agent_{iteration}.pkl','wb+').write(pickle.dumps(agent))