In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from ZW_utils import std_classes
from ZW_model import GPT
from ZW_Opt import *
from split_functions import bound_creation, layout_to_string_single_1d
from thermo_validity import *
from tqdm.notebook import trange

# Class table imported from ZW_utils; `evaluation` uses its length as the
# number of columns of the one-hot encoding.
classes = std_classes

In [None]:
def evaluation(layout):
    """Score a layout by validating it and then optimising it with PSO.

    Parameters
    ----------
    layout : numpy.ndarray
        Array of class indices; it is cast to ``int`` before use.

    Returns
    -------
    number
        * ``-100`` when ``validity`` returns an empty result for the layout
          string,
        * ``-10`` when the PSO run raises an exception,
        * ``-5`` when the PSO result is not below ``1e6``,
        * otherwise the PSO result itself.
    """
    layout = layout.astype(int)
    valid_string = validity([layout_to_string_single_1d(layout)])
    if len(valid_string) == 0:
        return -100

    # One-hot encode the integer layout: one row per position, one column per class.
    ohe = np.zeros((len(layout), len(classes)))
    for position, class_idx in enumerate(layout):
        ohe[position, class_idx] = 1

    # Only the equipment list and the bounds are needed here; the remaining
    # two return values of bound_creation are ignored.
    equipment, bounds, _, _ = bound_creation(ohe)

    # Swarm size scales with the number of decision variables and is reduced
    # when equipment codes 5 and/or 9 are present.
    swarmsize_factor = 7
    nv = len(bounds)
    particle_size = swarmsize_factor * nv
    if 5 in equipment:
        particle_size -= swarmsize_factor
    if 9 in equipment:
        particle_size -= 2 * swarmsize_factor
    iterations = 30

    try:
        a = PSO(objective_function, bounds, particle_size, iterations, nv, equipment)
        if a.result < 1e6:
            value = a.result
            print(valid_string, value)
        else:
            value = -5
    except Exception:
        # A failed optimisation is scored with a penalty instead of aborting
        # the search. `Exception` (rather than a bare `except`) lets
        # KeyboardInterrupt / SystemExit propagate, so the kernel can still be
        # interrupted during a long run.
        value = -10
    return value


class Flowsheet:
    """Single-player sequence-building environment used by the tree search.

    A state is a fixed-length array in which ``-1`` marks an unfilled slot.
    The initial state has slot 0 set to ``0`` and every other slot empty.
    Actions are written into the first empty slot, left to right.
    """

    def __init__(self) -> None:
        self.column_count = 23  # number of slots in a state
        self.action_size = 12   # number of distinct action indices

    def __repr__(self):
        return "Flowsheet"

    def _first_empty_slot(self, state):
        """Return the index of the first ``-1`` slot, or ``column_count`` if there is none."""
        empty_slots = np.where(state == -1)[0]
        if len(empty_slots) == 0:
            return self.column_count
        return empty_slots[0]

    def get_initial_state(self):
        """Return a fresh state: slot 0 holds ``0``, all other slots are ``-1``."""
        blank_state = np.ones(self.column_count) * -1
        blank_state[0] = 0
        return blank_state

    def get_next_state(self, state, action, player=1):
        """Write ``action`` into the first empty slot of ``state`` and return it.

        The array is modified IN PLACE and the same object is returned, so
        callers that need to keep the previous state must copy it first.
        ``player`` is accepted for interface compatibility and is unused.

        Raises
        ------
        IndexError
            If the state has no empty slot left.
        """
        column = self._first_empty_slot(state)
        if column >= len(state):
            raise IndexError(
                "cannot place action: state has no empty (-1) slot left"
            )
        state[column] = action
        return state

    def get_valid_moves(self, state):
        """Return a flat boolean mask that is True for every empty (``-1``) slot."""
        return state.reshape(-1) == -1

    def check_win(self, state, action):
        """Return True when ``action`` is 11, the action that ends the sequence."""
        if action is None:
            return False
        return bool(action == 11)

    def get_value_and_terminated(self, state, action):
        """Return ``(value, terminated)`` for ``state`` reached via ``action``.

        * action 11: the filled prefix is scored with ``evaluation`` and the
          episode is terminated,
        * no empty slot left: ``(0, True)``,
        * otherwise: ``(0, False)``.
        """
        if self.check_win(state, action):
            value = evaluation(self.get_encoded_state(state))
            return value, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    # The three methods below are identities: there is a single player, so
    # there is no opponent and no perspective to flip.
    def get_opponent(self, player):
        return player

    def get_opponent_value(self, value):
        return value

    def change_perspective(self, state, player):
        return state

    def get_encoded_state(self, state):
        """Return the filled prefix of ``state`` (everything before the first ``-1``).

        The result is a slice (view) of ``state``, not a copy.
        """
        return state[:self._first_empty_slot(state)]

In [None]:
class Node:
    """One node of the search tree.

    Parameters
    ----------
    game : object
        Environment providing ``get_next_state``.
    args : dict
        Search settings; ``args["C"]`` weights the exploration term.
    state : array-like
        State represented by this node.
    parent : Node or None
        Parent node (``None`` for the root).
    action_taken : int or None
        Action that led from the parent to this node.
    prior : float
        Prior probability assigned to ``action_taken`` by the policy.
    visit_count : int
        Initial visit count.
    """

    def __init__(self,game,args,state,parent=None,action_taken=None,prior = 0,visit_count=0):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior

        self.children = []

        self.visit_count = visit_count
        self.value_sum = 0

    def is_fully_expanded(self):
        """Return True once the node has at least one child."""
        return len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (first one wins ties).

        Returns ``None`` when the node has no children.
        """
        best_child = None
        best_ucb = -np.inf
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        return best_child

    def get_ucb(self,child):
        """Return the selection score of ``child`` as seen from this node.

        The score is the child's rescaled mean value plus an exploration
        bonus proportional to its prior.
        """
        if child.visit_count == 0:
            q_value = 0
        else:
            # Single-player setting: the child's mean value is used directly
            # (no sign flip) and shifted/scaled by (v + 1) / 2.
            # NOTE(review): this lands in [0, 1] only if backed-up values lie
            # in [-1, 1] — confirm the range of values passed to backpropagate.
            q_value = ((child.value_sum / child.visit_count)+1)/2
        exploration = self.args["C"] * (np.sqrt(self.visit_count) / (child.visit_count+1))*child.prior
        return q_value + exploration

    def expand(self,policy):
        """Create one child for every action with a strictly positive prior."""
        for action, prob in enumerate(policy):
            if prob > 0:
                # Copy first: get_next_state is applied to the copy so this
                # node's own state is left untouched.
                child_state = self.state.copy()
                child_state = self.game.get_next_state(child_state,action,1)

                child = Node(self.game,self.args,child_state,self,action,prob)
                self.children.append(child)

    def backpropagate(self,value):
        """Add ``value`` and one visit to this node and to every ancestor."""
        node = self
        while node is not None:
            node.value_sum += value
            node.visit_count += 1
            node = node.parent

class MCTS:
    """Monte-Carlo tree search guided by a policy/value model.

    Parameters
    ----------
    game : object
        Environment providing ``action_size``, ``get_encoded_state`` and
        ``get_value_and_terminated``.
    args : dict
        Needs ``"num_searches"``, ``"dirichlet_epsilon"`` and
        ``"dirichlet_alpha"`` (plus whatever ``Node`` reads).
    model : callable
        Called as ``model(tokens, lengths)`` and expected to return
        ``(policy_logits, value)``.
    """

    # Action indices that are never offered to the search: their probability
    # is zeroed in every policy before the policy is renormalised.
    MASKED_ACTIONS = (0, 6, 8, 10)

    def __init__(self,game,args,model):
        self.game = game
        self.args = args
        self.model = model

    def _policy_and_value(self, state):
        """Run the model on ``state``; return (softmax policy as ndarray, value tensor)."""
        tokens = torch.tensor(self.game.get_encoded_state(state),dtype=torch.long).unsqueeze(0)
        lengths = torch.tensor([len(row) for row in tokens])
        logits, value = self.model(tokens, lengths)
        policy = torch.softmax(logits,dim=-1).squeeze(0).detach().numpy()
        return policy, value

    def _mask_and_normalise(self, policy):
        """Zero the masked actions in ``policy`` (in place) and rescale it to sum to 1."""
        valid_moves = np.ones(self.game.action_size)
        valid_moves[list(self.MASKED_ACTIONS)] = 0
        policy *= valid_moves
        policy /= np.sum(policy)
        return policy

    @torch.no_grad()
    def search(self,state):
        """Run ``args["num_searches"]`` simulations from ``state``.

        Returns
        -------
        numpy.ndarray
            Length ``game.action_size`` array with the visit counts of the
            root's children, normalised to sum to 1.
        """
        root = Node(self.game,self.args,state,visit_count=1)

        # Root policy: mix Dirichlet noise into the model's policy before
        # masking, so that different root actions get explored.
        policy, _ = self._policy_and_value(state)
        epsilon = self.args["dirichlet_epsilon"]
        policy = (1-epsilon)*policy + epsilon\
            *np.random.dirichlet([self.args["dirichlet_alpha"]]*self.game.action_size)
        root.expand(self._mask_and_normalise(policy))

        for _ in range(self.args["num_searches"]):
            # selection: walk down to a node without children
            node = root
            while node.is_fully_expanded():
                node = node.select()

            # NOTE: a terminal leaf is re-scored by the game on every visit.
            value,is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)

            if not is_terminal:
                # expansion: the model supplies both the priors and the value
                policy, value = self._policy_and_value(node.state)
                policy = self._mask_and_normalise(policy)
                value = value.item()
                node.expand(policy)

            # backpropagation
            node.backpropagate(value)

        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs)
        return action_probs

In [None]:
class Alphazero:
    """Self-play / training loop around an MCTS-guided policy-value model.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(states, lengths)``; returns ``(policy_logits, value)``.
    optimizer : torch.optim.Optimizer
        Optimiser over the model's parameters.
    game : object
        Environment (initial state, transitions, termination, encoding).
    args : dict
        Needs ``"temperature"``, ``"batch_size"``, ``"num_iterations"``,
        ``"num_selfPlay_iterations"``, ``"num_epochs"`` and ``"save_path"``,
        plus the keys read by ``MCTS``.
    """

    # Token used to right-pad encoded states to a common length within a batch.
    PAD_TOKEN = 12

    def __init__(self,model,optimizer,game,args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(game,args,model)

    def selfPlay(self):
        """Play one episode and return its training samples.

        Returns
        -------
        list of tuple
            One ``(encoded_state, action_probs, outcome)`` per step, where
            ``outcome`` is the final value of the episode.
        """
        memory = []
        player = 1
        state = self.game.get_initial_state()

        while True:
            neutral_state = self.game.change_perspective(state,player)
            action_probs = self.mcts.search(neutral_state)
            # Store a snapshot: the game may update the state array in place,
            # which would otherwise turn every stored entry into the final
            # state of the episode.
            memory.append((neutral_state.copy(),action_probs,player))

            # Temperature -> 0 favours exploitation, temperature -> inf
            # favours exploration (more randomness).
            temperature_action_probs = action_probs**(1/self.args["temperature"])
            temperature_action_probs /= np.sum(temperature_action_probs)

            action = np.random.choice(self.game.action_size,p=temperature_action_probs)
            state = self.game.get_next_state(state,action,player)
            value,is_terminal = self.game.get_value_and_terminated(state,action)
            if is_terminal:
                # Every step of the episode is labelled with the final value.
                return [
                    (self.game.get_encoded_state(hist_state), hist_action_probs, value)
                    for hist_state, hist_action_probs, _hist_player in memory
                ]
            player = self.game.get_opponent(player)

    def train(self,memory):
        """Run one pass over ``memory`` in shuffled mini-batches.

        ``memory`` is shuffled in place. Every sample is used exactly once,
        including a final batch that is smaller than ``batch_size``.
        """
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for batch_start in range(0,len(memory),batch_size):
            # Plain slicing clamps at the end of the list, so the last sample
            # is kept and a trailing one-element batch is not empty.
            sample = memory[batch_start:batch_start+batch_size]
            states,policy_targets,value_targets = zip(*sample)

            # Encoded states have different lengths: right-pad them to the
            # longest one and pass the true lengths to the model.
            true_lengths = [len(x) for x in states]
            max_length = max(true_lengths)
            lengths = torch.tensor(true_lengths)
            states = [x.tolist()+[self.PAD_TOKEN]*(max_length-len(x)) for x in states]

            states = torch.tensor(np.array(states)).long()
            policy_targets = torch.tensor(np.array(policy_targets)).float()
            value_targets = torch.tensor(np.array(value_targets).reshape(-1,1)).float()

            out_policy, out_value = self.model(states,lengths)

            policy_loss = F.cross_entropy(out_policy,policy_targets)
            value_loss = F.mse_loss(out_value,value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training; save model and optimiser each iteration."""
        import os

        # Create the output directory up front so a missing folder fails now
        # rather than after a complete self-play/training iteration.
        save_path = self.args["save_path"]
        os.makedirs(save_path, exist_ok=True)

        for iteration in range(self.args["num_iterations"]):
            memory = []

            # Self-play must run in eval mode (e.g. dropout disabled);
            # otherwise the model stays in train mode from construction or
            # from the previous iteration.
            self.model.eval()
            for _ in trange(self.args["num_selfPlay_iterations"]):
                memory += self.selfPlay()

            self.model.train()
            for _ in trange(self.args["num_epochs"]):
                self.train(memory)

            torch.save(self.model.state_dict(),f"{save_path}/model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(),f"{save_path}/optimizer_{iteration}_{self.game}.pt")

In [None]:
class LSTM_packed(nn.Module):
    """Embedding + multi-layer LSTM with a policy head and a value head.

    Variable-length sequences are handled by packing, so positions beyond a
    sequence's true length do not influence its output.

    Parameters
    ----------
    embd_size : int
        Embedding dimension.
    hidden_size : int
        LSTM hidden dimension.
    vocab_size : int, default 13
        Number of distinct input tokens.
    action_size : int, default 12
        Number of policy logits.
    num_layers : int, default 2
        Number of stacked LSTM layers.
    dropout : float, default 0.1
        Dropout between LSTM layers.
    """

    def __init__(self, embd_size, hidden_size, vocab_size=13, action_size=12,
                 num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embd_size)
        self.lstm = nn.LSTM(embd_size, hidden_size, num_layers=num_layers,
                            batch_first=True, dropout=dropout)
        self.valuehead = nn.Linear(hidden_size, 1)
        self.policyhead = nn.Linear(hidden_size, action_size)

    def forward(self, x, lengths):
        """Return ``(policy_logits, value)`` for a batch of padded sequences.

        Parameters
        ----------
        x : torch.Tensor
            Batch-first token indices (cast to ``long`` here).
        lengths : torch.Tensor or list
            True length of each sequence in ``x``.
        """
        embedded = self.embedding(x.long())
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        _, (hidden, _) = self.lstm(packed)
        # Both heads read the final hidden state of the last LSTM layer.
        last_hidden = hidden[-1]
        value = self.valuehead(last_hidden)
        policy = self.policyhead(last_hidden)
        return policy, value

In [None]:
import matplotlib.pyplot as plt
# Smoke test: build a short example sequence and look at what an untrained
# network predicts for it.
ttt = Flowsheet()
state = ttt.get_initial_state()
for _token in (1, 2, 3, 4, 11):
    state = ttt.get_next_state(state, _token)
print(state)

# Only the filled prefix of the state is fed to the network.
encoded_state = ttt.get_encoded_state(state)
print(encoded_state)
tensor_state = torch.tensor(encoded_state).unsqueeze(0)
lengths = torch.tensor([len(row) for row in tensor_state])

# Untrained model; to inspect a trained one, load its weights first:
# model.load_state_dict(torch.load('model_2.pt'))
model = LSTM_packed(64, 256)
model.eval()

policy, value = model(tensor_state, lengths)
value = value.item()
policy = torch.softmax(policy, dim=1).squeeze(0).detach().numpy()
print(value)

# Bar chart of the predicted probability of each action.
plt.bar(range(ttt.action_size), policy)
plt.show()

In [None]:
# Seed every RNG used in this notebook so that model initialisation, the
# Dirichlet noise / action sampling and the batch shuffling are repeatable.
# NOTE(review): whether the optimiser behind `evaluation` draws from these
# same generators cannot be seen here — confirm if full determinism matters.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

game = Flowsheet()
model = LSTM_packed(64,256)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=1e-4)
args = {
    "C":2,                          # weight of the exploration term in Node.get_ucb
    "num_searches":500,             # simulations per MCTS.search call
    "num_iterations":3,             # self-play + training rounds in Alphazero.learn
    "num_selfPlay_iterations":500,  # episodes generated per round
    "num_epochs":8,                 # passes over the collected samples per round
    "batch_size":64,                # mini-batch size in Alphazero.train
    "temperature":1.25,             # exponent 1/temperature applied to the search probabilities
    "dirichlet_epsilon":0.25,       # share of Dirichlet noise mixed into the root policy
    "dirichlet_alpha":0.3,          # concentration parameter of that noise
    "save_path":"./RL/policy_value_model_trials"  # where model/optimizer checkpoints are written
}
alphazero = Alphazero(model,optimizer,game,args)
alphazero.learn()

In [None]:
# game = Flowsheet()
# player = 1
# args = {
#     "C": 2,
#     "num_searches": 100,
#     "dirichlet_epsilon":0.0,
#     "dirichlet_alpha":0.3
# }
# model = LSTM_packed(64,256)
# model.load_state_dict(torch.load('RL/policy_value_model_trials/model_1_Flowsheet.pt'))
# model.eval()
# mcts = MCTS(game, args,model)
# state = game.get_initial_state()
# while True:
#     neutral_state = game.change_perspective(state,player)
#     mcts_probs = mcts.search(neutral_state)
#     action = np.argmax(mcts_probs)
        
#     state = game.get_next_state(state, action, player)
#     value,is_terminal = game.get_value_and_terminated(state,action)

#     if is_terminal:
#         print(state)