In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from ZW_utils import std_classes
from ZW_model import GPT
from ZW_Opt import *
from split_functions import bound_creation, layout_to_string_single_1d,equipments_to_strings
from thermo_validity import *
from tqdm.notebook import trange

# --- Load the pre-computed dataset of layouts and their objective values ---
classes = std_classes
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted files.
layouts = np.load("M2_data_300_8_augmented_layouts.npy", allow_pickle=True)
results = np.load("M2_data_300_8_augmented_results.npy", allow_pickle=True)
print(len(layouts), len(results))
# Convert each layout to its string representation (helper from split_functions).
layouts = equipments_to_strings(layouts, classes)
# Linear rescaling of the raw objective: 125 maps to 1 and 300 maps to 0.
results = 1 - (results - 125) / 175
# Sort both arrays by ascending normalised score.
indices = np.argsort(results)
sorted_results = np.array(results)[indices]
sorted_layouts = np.array(layouts)[indices]
# De-duplicate layouts. np.unique returns the index of the first occurrence, so with
# the ascending sort above each duplicated layout keeps its LOWEST normalised score.
# NOTE(review): confirm this is intended rather than keeping the best score per layout.
unique, indices = np.unique(sorted_layouts, return_index=True)
unique_results = sorted_results[indices]
unique_layouts = sorted_layouts[indices]
print(len(unique_layouts), len(unique_results))
# `layouts` becomes a plain list so evaluation() can use list.index for lookups.
layouts = unique_layouts.tolist()
results = unique_results
# Layouts/scores discovered during self-play are appended here by evaluation().
new_layouts = []
new_results = []

68026 68026
67345 67345


In [3]:
def evaluation(layout):
    """Return the reward of a finished layout given as a 1-D array of class indices.

    The layout is converted to its string form and checked with ``validity``.
    Known layouts (found earlier in this session or present in the loaded
    dataset) return their stored score; otherwise the layout is optimised with
    ``PSO`` and the result is normalised with the same 125 -> 1, 300 -> 0
    rescaling used for the dataset.

    Returns:
        -1     if ``validity`` rejects the layout,
        -0.5   if the optimisation raises an exception,
        -0.25  if the optimiser result is >= 300,
        otherwise the stored or freshly computed normalised score.

    Side effects:
        A newly optimised layout with result < 300 is appended to the
        module-level ``new_layouts`` / ``new_results`` lists and printed.
    """
    layout = layout.astype(int)
    valid_string = validity([layout_to_string_single_1d(layout)])
    if len(valid_string) == 0:
        return -1
    layout_key = valid_string[0]

    # Reuse a known score. list.index raises ValueError on a miss, so each
    # list is scanned once instead of twice (`in` followed by `.index`).
    for known_layouts, known_results in ((new_layouts, new_results), (layouts, results)):
        try:
            return known_results[known_layouts.index(layout_key)]
        except ValueError:
            continue

    # One-hot encode the integer layout: row i has a 1 in column layout[i].
    ohe = np.zeros((len(layout), len(classes)))
    ohe[np.arange(len(layout)), layout] = 1
    equipment, bounds, _, _ = bound_creation(ohe)

    # Swarm size scales with the number of decision variables.
    swarmsize_factor = 7
    nv = len(bounds)
    particle_size = swarmsize_factor * nv
    # NOTE(review): equipment codes 5 and 9 shrink the swarm; their meaning is
    # defined outside this notebook -- confirm against bound_creation.
    if 5 in equipment:
        particle_size -= swarmsize_factor
    if 9 in equipment:
        particle_size -= 2 * swarmsize_factor
    iterations = 30
    try:
        a = PSO(objective_function, bounds, particle_size, iterations, nv, equipment)
        if a.result < 300:
            # standardization between 125 and 300 to 1 and 0
            value = 1 - (a.result - 125) / 175
            new_layouts.append(layout_key)
            new_results.append(value)
            print(value, layout_key)
        else:
            value = -0.25
    except Exception:
        # Best-effort: a failed optimisation is penalised instead of aborting
        # self-play. KeyboardInterrupt/SystemExit are deliberately not caught
        # so a long run can still be stopped.
        value = -0.5
    return value


class Flowsheet:
    """Single-player sequence-building environment.

    A state is a float array of ``column_count`` slots. Slot 0 is fixed to 0,
    unfilled slots hold ``EMPTY`` (-1), and each action writes its index into
    the first unfilled slot. Choosing ``TERMINAL_ACTION`` (11) ends the episode
    and triggers ``evaluation`` of the filled prefix.
    """

    EMPTY = -1  # marker for a slot that has not been filled yet
    TERMINAL_ACTION = 11  # action that ends the episode (see check_win)

    def __init__(self) -> None:
        self.column_count = 23
        self.action_size = 12

    def __repr__(self):
        return "Flowsheet"

    def _first_empty_column(self, row):
        """Index of the first EMPTY slot in a 1-D row, or ``column_count`` if none."""
        empty_columns = np.where(row == self.EMPTY)[0]
        if empty_columns.size == 0:
            return self.column_count
        return int(empty_columns[0])

    def get_initial_state(self):
        """Return a fresh state: slot 0 set to 0, every other slot EMPTY."""
        blank_state = np.full(self.column_count, float(self.EMPTY))
        blank_state[0] = 0
        return blank_state

    def get_next_state(self, state, action):
        """Write ``action`` into the first EMPTY slot of ``state``.

        The array is modified IN PLACE and also returned; copy it first if the
        original must be kept. Raises IndexError when no EMPTY slot is left.
        """
        column = np.where(state == self.EMPTY)[0][0]
        state[column] = action
        return state

    def get_valid_moves(self, state):
        """Boolean mask over the flattened state marking the EMPTY slots."""
        return state.reshape(-1) == self.EMPTY

    def check_win(self, state, action):
        """True when ``action`` is the terminal action; ``state`` is not inspected."""
        return bool(action == self.TERMINAL_ACTION)

    def get_value_and_terminated(self, state, action):
        """Return ``(value, terminated)`` after ``action`` produced ``state``.

        * terminal action       -> (evaluation(filled prefix), True)
        * no EMPTY slot left    -> (-1, True)
        * otherwise             -> (-1, False)
        """
        if self.check_win(state, action):
            value = evaluation(self.get_encoded_state(state))
            return value, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return -1, True
        return -1, False

    def get_encoded_state(self, state):
        """Strip the EMPTY tail from a state.

        For a 1-D state, return the prefix before the first EMPTY slot (at most
        ``column_count`` entries). For a 2-D batch, return a list with one such
        prefix per row; the prefixes may have different lengths.
        """
        if len(state.shape) == 2:
            return [row[: self._first_empty_column(row)] for row in state]
        return state[: self._first_empty_column(state)]

In [4]:
class SPG:
    """State holder for one self-play game that runs alongside others.

    Attributes:
        state:  current game state, initialised from ``game.get_initial_state()``.
        memory: list the self-play loop appends (state, action_probs) pairs to.
        root:   root node of the current search tree (set by the search).
        node:   leaf awaiting network evaluation in the current search step.
    """

    def __init__(self, game):
        # Search bookkeeping starts empty; the MCTS fills it in.
        self.root, self.node = None, None
        self.memory = list()
        self.state = game.get_initial_state()
class Node:
    """One node of the search tree.

    Holds the state reached by ``action_taken`` from ``parent``, the prior
    probability assigned to that action, and the running visit statistics.
    """

    def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game, self.args = game, args
        self.state = state
        self.parent, self.action_taken = parent, action_taken
        self.prior = prior

        self.children = []

        self.visit_count = visit_count
        self.value_sum = 0

    def is_fully_expanded(self):
        """A node counts as expanded as soon as it has at least one child."""
        return bool(self.children)

    def select(self):
        """Return the child with the highest UCB score (first one wins ties), or None."""
        winner, top_score = None, -np.inf
        for candidate in self.children:
            score = self.get_ucb(candidate)
            if score > top_score:
                winner, top_score = candidate, score
        return winner

    def get_ucb(self, child):
        """UCB score of ``child`` as seen from this node.

        The mean value is mapped by (mean + 1) / 2; an unvisited child gets 0.
        The exploration term is C * sqrt(N_parent) / (N_child + 1) * prior.
        """
        visits = child.visit_count
        q_value = 0 if visits == 0 else ((child.value_sum / visits) + 1) / 2
        exploration = self.args["C"] * (np.sqrt(self.visit_count) / (visits + 1)) * child.prior
        return q_value + exploration

    def expand(self, policy):
        """Create one child per action whose probability in ``policy`` is positive."""
        for action, prob in enumerate(policy):
            if not prob > 0:
                continue
            # The game mutates the array it is given, so work on a copy.
            next_state = self.game.get_next_state(self.state.copy(), action)
            self.children.append(Node(self.game, self.args, next_state, self, action, prob))

    def backpropagate(self, value):
        """Add ``value`` and one visit to this node and every ancestor up to the root."""
        current = self
        while current is not None:
            current.value_sum += value
            current.visit_count += 1
            current = current.parent

class MCTS_Parallel:
    """Batched Monte-Carlo tree search over several self-play games at once.

    The network is queried once for all roots and then once per search
    iteration for all leaves that still need expansion.
    """

    # Token used to right-pad shorter sequences in a batch (the same literal
    # 12 the training loop pads with).
    PAD_TOKEN = 12

    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model
        # Mask over actions: 1 = allowed, 0 = never proposed by the search.
        # NOTE(review): actions 0, 6, 8 and 10 are disabled here; the reason is
        # not visible in this notebook -- confirm against the class list.
        self.valid_moves = np.ones(self.game.action_size)
        self.valid_moves[[0, 6, 8, 10]] = 0

    @torch.no_grad()
    def _evaluate(self, states):
        """Run the network on a batch of raw states.

        Returns:
            policy: float array of shape (batch, n_actions), softmax over actions.
            value:  float array of shape (batch,).

        The batch axis is always kept, including for a single state, so
        ``policy[i]`` is always a full probability vector.
        """
        encoded = self.game.get_encoded_state(states)
        lengths = [len(sequence) for sequence in encoded]
        longest = max(lengths)
        padded = [
            sequence.tolist() + [self.PAD_TOKEN] * (longest - len(sequence))
            for sequence in encoded
        ]
        tokens = torch.tensor(padded, dtype=torch.long)
        policy_logits, value = self.model(tokens, torch.tensor(lengths))
        policy = torch.softmax(policy_logits, dim=-1).detach().cpu().numpy()
        value = value.detach().cpu().numpy().reshape(-1)
        return policy, value

    def _masked_policy(self, policy_row):
        """Zero out disabled actions and renormalise to a probability vector."""
        masked = policy_row * self.valid_moves
        return masked / np.sum(masked)

    @torch.no_grad()
    def search(self, states, spGames):
        """Build a search tree for every game in ``spGames``.

        Args:
            states:  2-D array, one current state per game (same order as spGames).
            spGames: game holders; ``root`` is replaced and ``node`` is used as
                     scratch space. Results are read from ``spg.root.children``.
        """
        policy, _ = self._evaluate(states)

        # Mix exploration noise into the root priors, one independent
        # Dirichlet sample per game.
        epsilon = self.args["dirichlet_epsilon"]
        noise = np.random.dirichlet(
            [self.args["dirichlet_alpha"]] * self.game.action_size,
            size=policy.shape[0],
        )
        policy = (1 - epsilon) * policy + epsilon * noise

        for i, spg in enumerate(spGames):
            spg.root = Node(self.game, self.args, states[i], visit_count=1)
            spg.root.expand(self._masked_policy(policy[i]))

        for _ in range(self.args["num_searches"]):
            # Phase 1: walk every tree down to a leaf. Terminal leaves are
            # backed up immediately; the rest are queued for the network.
            for spg in spGames:
                spg.node = None
                node = spg.root
                while node.is_fully_expanded():
                    node = node.select()
                value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
                if is_terminal:
                    node.backpropagate(value)
                else:
                    spg.node = node

            # Phase 2: evaluate all queued leaves in one batch. This must run
            # once per search iteration, after every game has picked its leaf;
            # running it per game would re-expand stale leaves of other games.
            expandable_spGames = [idx for idx, spg in enumerate(spGames) if spg.node is not None]
            if not expandable_spGames:
                continue
            leaf_states = np.stack([spGames[idx].node.state for idx in expandable_spGames])
            policy, value = self._evaluate(leaf_states)

            for row, idx in enumerate(expandable_spGames):
                node = spGames[idx].node
                node.expand(self._masked_policy(policy[row]))
                # Back up a plain float so node statistics stay numeric.
                node.backpropagate(float(value[row]))

In [5]:
class Alphazero_Parallel:
    """Training loop: batched self-play with MCTS followed by supervised updates.

    ``args`` keys read here: num_parallel_games, temperature, batch_size,
    num_iterations, num_selfPlay_iterations, num_epochs, save_path.
    """

    # Token used to right-pad encoded states to a common length in a batch.
    PAD_TOKEN = 12

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS_Parallel(game, args, model)

    def selfPlay(self):
        """Play ``num_parallel_games`` games to the end.

        Returns:
            A list of ``(encoded_state, action_probs, outcome)`` tuples, one per
            move of every game; ``outcome`` is the final value of that game.
        """
        return_memory = []
        spGames = [SPG(self.game) for _ in range(self.args["num_parallel_games"])]

        while len(spGames) > 0:
            states = np.stack([spg.state for spg in spGames])
            self.mcts.search(states, spGames)

            # Iterate backwards so finished games can be deleted while looping.
            for i in range(len(spGames))[::-1]:
                spg = spGames[i]
                # The search policy is the normalised visit count of each root child.
                action_probs = np.zeros(self.game.action_size)
                for child in spg.root.children:
                    action_probs[child.action_taken] = child.visit_count
                action_probs /= np.sum(action_probs)
                spg.memory.append((spg.root.state, action_probs))

                # Temperature -> 0 means exploitation, -> inf means exploration.
                temperature_action_probs = action_probs ** (1 / self.args["temperature"])
                temperature_action_probs /= np.sum(temperature_action_probs)

                action = np.random.choice(self.game.action_size, p=temperature_action_probs)
                spg.state = self.game.get_next_state(spg.state, action)
                value, is_terminal = self.game.get_value_and_terminated(spg.state, action)
                if is_terminal:
                    # Every position of the game is labelled with its final outcome.
                    for hist_state, hist_action_probs in spg.memory:
                        return_memory.append((
                            self.game.get_encoded_state(hist_state),
                            hist_action_probs,
                            value,
                        ))
                    del spGames[i]
        return return_memory

    def train(self, memory):
        """Run one pass over ``memory`` in mini-batches (``memory`` is shuffled in place)."""
        random.shuffle(memory)
        batch_size = self.args["batch_size"]
        for batchIdx in range(0, len(memory), batch_size):
            # Plain slicing already clamps at the end of the list, so the last
            # sample is included and no batch is ever empty.
            sample = memory[batchIdx:batchIdx + batch_size]
            states, policy_targets, value_targets = zip(*sample)

            # Encoded states have different lengths: right-pad to the longest.
            lengths = [len(state) for state in states]
            max_length = max(lengths)
            states = [
                state.tolist() + [self.PAD_TOKEN] * (max_length - len(state))
                for state in states
            ]

            states = torch.tensor(np.array(states)).long()
            policy_targets = torch.tensor(np.array(policy_targets)).float()
            value_targets = torch.tensor(np.array(value_targets).reshape(-1, 1)).float()

            out_policy, out_value = self.model(states, torch.tensor(lengths))

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training; save model and optimizer after each iteration."""
        import os  # local import: only needed to create the checkpoint directory

        save_path = self.args["save_path"]
        # Fail early / avoid losing a finished iteration to a missing directory.
        os.makedirs(save_path, exist_ok=True)

        for iteration in range(self.args["num_iterations"]):
            memory = []
            # Inference mode during self-play so dropout does not perturb the search.
            self.model.eval()
            for selfPlay_iteration in trange(self.args["num_selfPlay_iterations"] // self.args["num_parallel_games"]):
                memory += self.selfPlay()
            self.model.train()
            for epoch in trange(self.args["num_epochs"]):
                self.train(memory)
            torch.save(self.model.state_dict(), f"{save_path}/model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"{save_path}/optimizer_{iteration}_{self.game}.pt")

In [6]:
class LSTM_packed(nn.Module):
    """Two-layer LSTM over padded token sequences with a policy and a value head.

    Tokens are embedded (vocabulary of 13, index 12 is the padding index),
    packed by their true lengths, and the final hidden state of the top LSTM
    layer feeds both linear heads.
    """

    def __init__(self, embd_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(13, embd_size, padding_idx=12)
        self.lstm = nn.LSTM(
            embd_size,
            hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.1,
        )
        self.valuehead = nn.Linear(hidden_size, 1)
        self.policyhead = nn.Linear(hidden_size, 12)

    def forward(self, x, lengths):
        """Return ``(policy_logits, value)`` for a batch.

        Args:
            x:       integer token tensor indexed (batch, time).
            lengths: true length of every sequence in the batch.
        """
        embedded = self.embedding(x.long())
        # Packing makes the LSTM stop at each sequence's real end, so the
        # padding tokens never influence the final hidden state.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        _, (final_hidden, _) = self.lstm(packed)
        top_layer = final_hidden[-1]
        policy = self.policyhead(top_layer)
        value = self.valuehead(top_layer)
        return policy, value

In [7]:
# Seed every RNG this notebook uses directly so that weight initialisation,
# Dirichlet noise, action sampling and batch shuffling are repeatable.
# NOTE(review): the PSO optimiser may draw from its own RNG -- confirm it is covered.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

game = Flowsheet()
model = LSTM_packed(64, 256)  # embedding size 64, LSTM hidden size 256
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
args = {
    "C": 2,                          # exploration constant in the UCB score
    "num_searches": 100,             # tree-search iterations per move
    "num_iterations": 4,             # self-play + training rounds
    "num_selfPlay_iterations": 1000, # games per round (divided by num_parallel_games)
    "num_parallel_games": 20,        # games searched together in one batch
    "num_epochs": 5,                 # passes over the collected memory per round
    "batch_size": 100,
    "temperature": 1,                # exponent 1/temperature applied to visit-count probabilities
    "dirichlet_epsilon": 0.1,        # weight of the root exploration noise
    "dirichlet_alpha": 0.3,          # concentration of the root exploration noise
    "save_path": "./RL/policy_value_model_parallel",
}
alphazero = Alphazero_Parallel(model, optimizer, game, args)
alphazero.learn()

  0%|          | 0/50 [00:00<?, ?it/s]

0.4839339949786815 GCACATHTE
0.5798630443947534 GACAHCTE
0.33712997869804795 GC1CA1C-1HCTAE
0.5627772137172904 GHCaCHAaHCTE
0.7897215147759601 GCHCTAE


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

0.8408012079708432 GaHTCaCAE
0.8064632418801045 GTATCHE
0.7379157128928046 GaCTHa1T1TAC-1E
0.885214687026118 GACTAaTHaE
0.7522312423620507 GA1A1CHC-1TE
0.15698779417476116 G-1HC1TAC1E
0.7295440252050646 GCTA1AT1AC-1HE


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

0.8062616709485732 GTHTACE
0.8065767520431921 GHTACE
0.6945383000402641 GAHTAT1C1AC-1E
0.6011197260283179 GAHTAT1C1AC-1CE
0.8065768927948451 GCHTAE
0.588086497738305 GCAHTAE


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

0.4708603842202572 GC-1AHTCA1H1E
0.7296436674537348 GAHTACE
0.6772171210430198 GHTHA-1CHTaA1CA1CaE
0.2651330597435957 GHATHTHACE


  0%|          | 0/5 [00:00<?, ?it/s]

In [8]:
# game = Flowsheet()
# args = {
#     "C":2,
#     "num_searches":100,
#     "num_iterations":6,
#     "num_selfPlay_iterations":500,
#     "num_epochs":30,
#     "batch_size":100,
    # "temperature":1,
#     "dirichlet_epsilon":0.1,
#     "dirichlet_alpha":0.3,
#     "save_path":"./RL/policy_value_model_parallel"
# }
# model = LSTM_packed(64,256)
# model.load_state_dict(torch.load('RL/policy_value_model_parallel/model_2_Flowsheet.pt'))
# model.eval()
# mcts = MCTS_Parallel(game, args,model)
# state = game.get_initial_state()
# while True:
#     mcts_probs = mcts.search(state)
#     action = np.argmax(mcts_probs)
        
#     state = game.get_next_state(state, action)
#     value,is_terminal = game.get_value_and_terminated(state,action)

#     if is_terminal:
#         print(value, state)
#         break