In [108]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from ZW_utils import std_classes

In [110]:
from ZW_model import GPT
# Rebuild the network and restore its trained weights.
# The positional hyper-parameters must match the ones used when the checkpoint
# was written, otherwise load_state_dict will reject the tensors.
# NOTE(review): the meaning of each positional argument is defined in
# ZW_model.GPT (not shown here) -- consider passing them by keyword.
model = GPT(12,32,4,2,22,0.1)
# map_location='cpu' lets a checkpoint saved from a GPU run load on a machine
# without CUDA; the freshly constructed model lives on the CPU anyway.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load('GPT_NA_psitest/M1_model_10.pt', map_location='cpu'))
classes = std_classes

In [119]:
class Flowsheet:
    """Sequence-building environment exposed through a board-game style API.

    A state is a 1-D NumPy array with ``column_count`` slots. Slot 0 starts at
    0 and every other slot holds ``EMPTY`` (-1) until an action id is written
    into it; actions always fill the first empty slot. Placing ``END_ACTION``
    (11) counts as completing the flowsheet.
    """

    # Sentinel stored in slots that have not been filled yet.
    EMPTY = -1
    # Action id that completes the flowsheet (see check_win).
    END_ACTION = 11

    def __init__(self) -> None:
        self.column_count = 23
        self.action_size = 12

    def __repr__(self):
        return "Flowsheet"

    def _first_empty_column(self, state):
        """Return the index of the first EMPTY slot, or None when there is none."""
        empty_columns = np.where(state == self.EMPTY)[0]
        return int(empty_columns[0]) if empty_columns.size else None

    def get_initial_state(self):
        """Return a fresh float state: slot 0 is 0, all other slots are EMPTY."""
        blank_state = np.full(self.column_count, self.EMPTY, dtype=float)
        blank_state[0] = 0
        return blank_state

    def get_next_state(self, state, action, player=None):
        """Write ``action`` into the first empty slot of ``state``.

        The array is modified in place and also returned, so callers that need
        the previous state must pass a copy.

        ``player`` is accepted and ignored so that callers written against a
        ``(state, action, player)`` signature (Node.expand passes three
        arguments) work with this class.

        Raises:
            IndexError: if ``state`` has no empty slot left.
        """
        column = self._first_empty_column(state)
        if column is None:
            raise IndexError("cannot place action: state has no empty slot")
        state[column] = action
        return state

    def get_valid_moves(self, state):
        """Return a boolean mask with one entry per slot, True where the slot is empty.

        NOTE(review): the mask has ``column_count`` entries, not ``action_size``;
        code that multiplies it with a per-action policy needs the two lengths
        to agree -- confirm which is intended.
        """
        return (state.reshape(-1) == self.EMPTY)

    def check_win(self, state, action):
        """Return True when ``action`` is the action that completes the flowsheet.

        ``state`` is never modified: the action is expected to have been applied
        already (via get_next_state), so writing it again here would insert it
        twice and would fail on a state with no empty slot.
        """
        if action is None:
            return False
        return bool(action == self.END_ACTION)

    def get_value_and_terminated(self, state, action):
        """Return ``(value, is_terminal)`` for ``state`` reached by ``action``.

        * ``(1, True)``  -- ``action`` completed the flowsheet.
        * ``(0, True)``  -- no empty slot is left.
        * ``(0, False)`` -- otherwise.
        """
        if self.check_win(state, action):
            # A completed flowsheet gets a placeholder value of 1; the real
            # value can later come from an optimizer run on the flowsheet.
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player id (sign flip)."""
        return -player

    def get_opponent_value(self, value):
        """Return ``value`` seen from the other player's side (sign flip)."""
        return -value

    def change_perspective(self, state, player):
        """Return ``state * player``.

        NOTE(review): with ``player == -1`` the EMPTY sentinel (-1) becomes 1
        and a stored action id 1 becomes -1, so the slot lookups above no longer
        identify the empty slots of the returned array -- confirm this is intended.
        """
        return state * player

    def get_encoded_state(self, state):
        """Return the filled prefix of ``state`` (everything before the first empty slot).

        The result is a slice (view) of ``state``. When there is no empty slot
        the whole state is returned instead of raising.
        """
        column = self._first_empty_column(state)
        if column is None:
            column = state.shape[0]
        return state[:column]

In [None]:
# Smoke test of the Flowsheet state helpers.
fw = Flowsheet()
s0 = fw.get_initial_state()
print(s0)
# get_next_state writes into the array it is given and returns that same
# array, so s0..s5 are all names for one array; only the print above shows
# the untouched initial state.
s1 = fw.get_next_state(s0,1)
s2 = fw.get_next_state(s1,2)
s3 = fw.get_next_state(s2,3)
s4 = fw.get_next_state(s3,4)
# 11 is the action id that check_win treats as completing the flowsheet.
s5 = fw.get_next_state(s4,11)
print(s5)
# Boolean mask over slots: True where the slot still holds -1.
print(fw.get_valid_moves(s5))
# Filled prefix of the state (everything before the first -1).
print(fw.get_encoded_state(s5))

[ 0. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
 -1. -1. -1. -1. -1.]
[ 0.  1.  2.  3.  4. 11. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
 -1. -1. -1. -1. -1.]
[False False False False False False  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True]
[ 0.  1.  2.  3.  4. 11.]


(0, False)

In [124]:
class Node:
    """One node of the Monte-Carlo search tree.

    The ``game`` object must provide ``get_next_state(state, action, player)``,
    ``change_perspective(state, player)`` and ``get_opponent_value(value)``;
    ``args`` must contain the exploration constant under the key ``"C"``.
    NOTE(review): ``expand`` passes three arguments to ``get_next_state`` --
    make sure the game passed in accepts that signature.
    """

    def __init__(self,game,args,state,parent=None,action_taken=None,prior = 0,visit_count=0):
        self.game, self.args = game, args
        self.state = state
        self.parent, self.action_taken, self.prior = parent, action_taken, prior

        # Filled by expand(); an empty list means the node is still a leaf.
        self.children = []

        self.visit_count = visit_count
        self.value_sum = 0

    def is_fully_expanded(self):
        """Return True once expand() has attached at least one child."""
        return bool(self.children)

    def select(self):
        """Return the child with the highest UCB score (first one on ties), or None."""
        winner, top_score = None, -np.inf
        for candidate in self.children:
            score = self.get_ucb(candidate)
            if score > top_score:
                winner, top_score = candidate, score
        return winner

    def get_ucb(self,child):
        """Return the UCB score of ``child`` as seen from this node."""
        visits = child.visit_count
        if visits == 0:
            exploit = 0
        else:
            # The child's mean value is stored from the opponent's point of
            # view, so map it from [-1, 1] to [0, 1] and invert it.
            exploit = 1 - ((child.value_sum / visits) + 1) / 2
        explore = self.args["C"] * (np.sqrt(self.visit_count) / (visits + 1)) * child.prior
        return exploit + explore

    def expand(self,policy):
        """Attach one child per action whose probability in ``policy`` is positive."""
        for action, prob in enumerate(policy):
            if not (prob > 0):
                continue
            # Work on a copy so this node's own state is left untouched.
            next_state = self.game.get_next_state(self.state.copy(), action, 1)
            next_state = self.game.change_perspective(next_state, player=-1)
            self.children.append(
                Node(self.game, self.args, next_state, self, action, prob)
            )

    def backpropagate(self,value):
        """Add ``value`` to this node and every ancestor, flipping its sign per level."""
        node = self
        while node is not None:
            node.value_sum += value
            node.visit_count += 1
            value = node.game.get_opponent_value(value)
            node = node.parent

class MCTS:
    """Monte-Carlo tree search guided by a policy/value model.

    ``args`` must contain ``"num_searches"``, ``"dirichlet_epsilon"`` and
    ``"dirichlet_alpha"`` (plus ``"C"``, which Node reads). ``model`` is called
    with a ``(1, sequence_length)`` long tensor and must return a
    ``(policy_logits, value)`` pair.
    """

    def __init__(self,game,args,model):
        self.game = game
        self.args = args
        self.model = model

    @torch.no_grad()
    def _policy_and_value(self, state):
        """Run the model on ``state``; return (softmax policy as ndarray, raw value)."""
        # The encoded state is fed as integer token ids, matching how the
        # model is called elsewhere in this notebook (dtype=torch.long).
        tokens = torch.tensor(
            self.game.get_encoded_state(state), dtype=torch.long
        ).unsqueeze(0)
        logits, value = self.model(tokens)
        return torch.softmax(logits, dim=1).squeeze(0).numpy(), value

    def _mask_and_normalise(self, policy, state):
        """Zero the entries the game marks invalid and rescale the rest to sum to 1."""
        # NOTE(review): this element-wise product needs get_valid_moves to
        # return one entry per policy entry -- confirm for the game in use.
        masked = policy * self.game.get_valid_moves(state)
        return masked / np.sum(masked)

    @torch.no_grad()
    def search(self,state):
        """Run ``args["num_searches"]`` simulations from ``state``.

        Returns an array of length ``game.action_size`` holding the root
        children's visit counts normalised to sum to 1.
        """
        root = Node(self.game,self.args,state,visit_count=1)

        # Root prior: model policy blended with Dirichlet noise for exploration.
        policy, _ = self._policy_and_value(state)
        epsilon = self.args["dirichlet_epsilon"]
        policy = (1-epsilon)*policy + epsilon\
            *np.random.dirichlet([self.args["dirichlet_alpha"]]*self.game.action_size)
        root.expand(self._mask_and_normalise(policy, state))

        for _ in range(self.args["num_searches"]):
            # Selection: follow the best-UCB child down to a leaf.
            node = root
            while node.is_fully_expanded():
                node = node.select()

            value,is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                # Expansion: evaluate the leaf with the model and attach children.
                policy, value = self._policy_and_value(node.state)
                value = value.item()
                node.expand(self._mask_and_normalise(policy, node.state))

            # Backpropagation: push the value up to the root.
            node.backpropagate(value)

        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs)
        return action_probs

In [None]:
# Build the game, the search configuration and the MCTS driver, then run a
# single forward pass of the model on the encoded initial state.
game = Flowsheet()
args = dict(
    C=1,                     # exploration constant read by Node.get_ucb
    dirichlet_epsilon=0.25,  # weight of the Dirichlet noise mixed into the root policy
    dirichlet_alpha=0.3,     # concentration parameter of that Dirichlet noise
    num_searches=100,        # simulations per MCTS.search call
)
model.eval()
mcts = MCTS(game, args, model)
state = game.get_initial_state()
# NOTE(review): the name `input` shadows the Python builtin in this namespace.
input = torch.tensor(
    game.get_encoded_state(state),
    dtype=torch.long,
).unsqueeze(0)
output = model(input)