In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from ZW_utils import std_classes
from ZW_model import GPT
from ZW_Opt import *
from split_functions import bound_creation, layout_to_string_single_1d,equipments_to_strings
from thermo_validity import *
from tqdm.notebook import trange

In [2]:
# NOTE: GPT is already imported in the first cell; this repeated import is redundant.
from ZW_model import GPT
# Positional hyper-parameters are defined by ZW_model.GPT (not visible in this notebook).
# NOTE(review): the leading 12 presumably matches Flowsheet.action_size -- confirm.
model = GPT(12,32,4,2,22,0.1)
# Start from previously trained weights.
model.load_state_dict(torch.load('GPT_NA_psitest/M1_model_10.pt'))
classes = std_classes
# Previously evaluated layouts and their raw objective values.
# NOTE: allow_pickle=True unpickles arbitrary objects -- only load trusted files.
layouts = np.load("M2_data_300_8_augmented_layouts.npy", allow_pickle=True)
results = np.load("M2_data_300_8_augmented_results.npy", allow_pickle=True)
print(len(layouts), len(results))
# Convert every layout to its string form (helper from split_functions).
layouts = equipments_to_strings(layouts, classes)
# Linear rescaling of the raw result: 125 -> 1 and 300 -> 0 (same formula as in evaluation()).
results = 1 - (results - 125) / 175
# Sort ascending by rescaled score ...
indices = np.argsort(results)
sorted_results = np.array(results)[indices]
sorted_layouts = np.array(layouts)[indices]
# ... then drop duplicate layout strings. np.unique(return_index=True) reports the
# first occurrence of each string in the sorted array.
# NOTE(review): with an ascending sort that keeps the LOWEST score of each duplicated
# layout -- confirm this is intended rather than keeping the best one.
unique, indices = np.unique(sorted_layouts, return_index=True)
unique_results = sorted_results[indices]
unique_layouts = sorted_layouts[indices]
print(len(unique_layouts), len(unique_results))
# evaluation() looks layouts up with list.index(), hence the plain Python list.
layouts = unique_layouts.tolist()
results = unique_results
# Cache of layouts/scores discovered during this session; filled by evaluation().
new_layouts = []
new_results = []

68026 68026
67345 67345


In [3]:
class Flowsheet:
    """Single-player "game" in which a flowsheet is written one token at a time.

    A state is a 1-D float array with ``column_count`` slots. ``-1`` marks an
    empty slot; slot 0 is pre-filled with token ``0``. Every action writes one
    token into the left-most empty slot, and token ``11`` ends the episode.
    """

    def __init__(self) -> None:
        self.column_count = 23  # number of slots in a state
        self.action_size = 12  # number of distinct tokens / actions

    def __repr__(self):
        # Alphazero.learn interpolates the game into checkpoint file names.
        return "Flowsheet"

    def get_initial_state(self):
        """Return a fresh state: token 0 in slot 0, every other slot empty (-1)."""
        blank_state = np.full(self.column_count, -1.0)
        blank_state[0] = 0
        return blank_state

    def get_next_state(self, state, action):
        """Write ``action`` into the first empty slot of ``state``.

        The array is modified IN PLACE and the same object is returned, so
        callers that need the previous state must copy it beforehand.

        Raises:
            IndexError: if ``state`` has no empty slot left.
        """
        column = np.where(state == -1)[0][0]
        state[column] = action
        return state

    def get_valid_moves(self, state):
        """Return a uint8 mask with 1 for every still-empty slot of ``state``."""
        return (state.reshape(-1) == -1).astype(np.uint8)

    def check_win(self, state, action):
        """Return True when ``action`` is the terminating token 11.

        ``state`` is unused; it is kept for interface compatibility.
        """
        return bool(action is not None and action == 11)

    def get_value_and_terminated(self, state, action):
        """Return ``(value, terminated)`` for ``state`` reached through ``action``.

        * terminating action -> ``(evaluation(encoded_state), True)``
        * no empty slot left -> ``(-1, True)``
        * otherwise          -> ``(-1, False)``
        """
        if self.check_win(state, action):
            value = evaluation(self.get_encoded_state(state))
            return value, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return -1, True
        return -1, False

    def get_encoded_state(self, state):
        """Return the filled prefix of ``state``.

        If the design has fewer than ``column_count`` tokens, the slice stops
        just before the first ``-1``; a completely filled state is returned as
        is. The result is a slice of ``state``, not a copy.
        """
        empty_columns = np.flatnonzero(state == -1)
        column = empty_columns[0] if empty_columns.size else self.column_count
        return state[:column]
    
def evaluation(layout):
    """Score a finished layout, using the notebook-level caches when possible.

    Args:
        layout: 1-D array of integer-valued equipment tokens.

    Returns:
        * ``-1`` if ``validity`` rejects the layout string;
        * the cached score if the string is already in ``new_layouts`` or
          ``layouts``;
        * ``1 - (result - 125) / 175`` if the PSO result is below 300 (the
          layout and score are then appended to ``new_layouts`` /
          ``new_results``);
        * ``-0.25`` if the PSO result is 300 or more;
        * ``-0.5`` if the optimisation raises an exception.
    """
    layout = layout.astype(int)
    valid_string = validity([layout_to_string_single_1d(layout)])
    if len(valid_string) == 0:
        return -1
    key = valid_string[0]

    # MEMORY: reuse a known score. One list scan per cache (list.index) instead
    # of an ``in`` test followed by a second scan for the position.
    for known_layouts, known_results in ((new_layouts, new_results), (layouts, results)):
        try:
            return known_results[known_layouts.index(key)]
        except ValueError:
            continue

    # One-hot encoding from the integer tokens (object dtype kept as in the
    # original, in case bound_creation depends on it).
    ohe = np.zeros((len(layout), len(classes)), dtype=object)
    for position, token in enumerate(layout):
        ohe[position, token] = 1
    equipment, bounds, _, _ = bound_creation(ohe)

    swarmsize_factor = 7
    nv = len(bounds)
    particle_size = swarmsize_factor * nv
    # NOTE(review): the swarm is reduced when equipment codes 5 / 9 are present,
    # presumably because they add variables that need fewer particles -- confirm.
    if 5 in equipment:
        particle_size -= swarmsize_factor
    if 9 in equipment:
        particle_size -= 2 * swarmsize_factor
    iterations = 30

    try:
        a = PSO(objective_function, bounds, particle_size, iterations, nv, equipment)
        if a.result < 300:
            # standardization between 125 and 300 to 1 and 0
            value = 1 - (a.result - 125) / 175
            print(valid_string, value)
            new_layouts.append(key)
            new_results.append(value)
        else:
            value = -0.25
    except Exception:
        # Best-effort: a failed optimisation gets a fixed penalty. Only Exception
        # is caught so that interrupting the kernel (KeyboardInterrupt) still
        # stops a long run instead of being turned into a score.
        value = -0.5
    return value

In [4]:
class Node:
    """One node of the search tree used by MCTS.

    Attributes:
        game: object providing ``get_next_state``.
        args: hyper-parameter dict; ``args["C"]`` is the exploration constant.
        state: state array represented by this node.
        parent: parent node, or None for the root.
        action_taken: action that led from the parent to this node.
        prior: prior probability assigned to this node by the policy.
        children: expanded child nodes.
        visit_count: number of backpropagations through this node.
        value_sum: running sum of backpropagated values.
    """

    def __init__(self,game,args,state,parent=None,action_taken=None,prior = 0,visit_count=0):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior

        self.children = []

        self.visit_count = visit_count
        # NOTE(review): starts at -0.1 rather than 0, presumably a deliberate
        # pessimistic offset -- confirm.
        self.value_sum = -0.1

    def is_fully_expanded(self):
        """Return True once ``expand`` has produced at least one child."""
        return len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (first one wins ties).

        Returns None if the node has no children.
        """
        best_child = None
        best_ucb = -np.inf
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        return best_child

    def get_ucb(self,child):
        """Return the selection score of ``child`` as seen from this node.

        The mean value of a visited child is rescaled with ``(mean + 1) / 2``
        (so -1 maps to 0 and 1 maps to 1); an unvisited child gets 0. The
        exploration term is ``C * sqrt(N_parent) / (N_child + 1) * prior``.
        """
        if child.visit_count == 0:
            q_value = 0
        else:
            q_value = ((child.value_sum / child.visit_count)+1)/2
        exploration = (np.sqrt(self.visit_count) / (child.visit_count+1))*child.prior
        return q_value + self.args["C"] * exploration

    def expand(self,policy):
        """Create one child for every action with a strictly positive prior.

        Each child receives its own copy of the state, because
        ``game.get_next_state`` writes into the array it is given.
        """
        for action, prob in enumerate(policy):
            if prob > 0:
                child_state = self.game.get_next_state(self.state.copy(), action)
                self.children.append(
                    Node(self.game,self.args,child_state,self,action,prob)
                )

    def backpropagate(self,value):
        """Add ``value`` to this node and to every ancestor up to the root."""
        node = self
        while node is not None:
            node.value_sum += value
            node.visit_count += 1
            node = node.parent

class MCTS:
    """Policy-guided Monte Carlo tree search over ``game`` states.

    ``model`` is called on a ``(1, sequence)`` long tensor and the logits at the
    last sequence position are used as the prior over the next action.
    """

    # Actions that are never offered to the search (their prior is forced to 0).
    MASKED_ACTIONS = (0, 6, 8, 10)
    # Value backpropagated for a leaf that is expanded but not terminal.
    NON_TERMINAL_VALUE = -0.1

    def __init__(self,game,args,model):
        self.game = game
        self.args = args
        self.model = model

    def _masked_policy(self, state, add_noise=False):
        """Return the normalised prior over actions for ``state``.

        Args:
            state: state array; its encoded prefix is fed to the model.
            add_noise: if True, blend Dirichlet noise into the prior using
                ``args["dirichlet_epsilon"]`` and ``args["dirichlet_alpha"]``
                (done for the root only).
        """
        encoded = torch.tensor(
            self.game.get_encoded_state(state), dtype=torch.long
        ).unsqueeze(0)
        logits = self.model(encoded)[:, -1, :]
        policy = torch.softmax(logits, dim=-1).squeeze(0).detach().numpy()

        if add_noise:
            epsilon = self.args["dirichlet_epsilon"]
            noise = np.random.dirichlet(
                [self.args["dirichlet_alpha"]] * self.game.action_size
            )
            policy = (1 - epsilon) * policy + epsilon * noise

        valid_moves = np.ones(self.game.action_size)
        valid_moves[list(self.MASKED_ACTIONS)] = 0
        policy = policy * valid_moves

        total = np.sum(policy)
        if total > 0:
            return policy / total
        # All probability mass was on masked actions: fall back to a uniform
        # prior over the allowed ones instead of dividing by zero.
        return valid_moves / np.sum(valid_moves)

    @torch.no_grad()
    def search(self,state):
        """Run ``args["num_searches"]`` simulations from ``state``.

        Returns:
            Array of length ``game.action_size`` holding the root children's
            visit counts normalised to sum to 1. ``state`` is not modified.
        """
        root = Node(self.game,self.args,state,visit_count=1)
        # Exploration noise is only added to the root prior.
        root.expand(self._masked_policy(state, add_noise=True))

        for _ in range(self.args["num_searches"]):
            # selection: walk down until a node without children is reached
            node = root
            while node.is_fully_expanded():
                node = node.select()

            value,is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            if not is_terminal:
                # expansion
                node.expand(self._masked_policy(node.state))
                value = self.NON_TERMINAL_VALUE
            # backpropagation
            node.backpropagate(value)

        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs)
        return action_probs

In [5]:
# fw = Flowsheet()
# args = {
#     "C": 1.41,
#     "num_searches": 10,
#     "dirichlet_epsilon": 0.1,
#     "dirichlet_alpha": 0.3,
# }
# mcts = MCTS(fw, args, model)
# state = fw.get_initial_state()
# # policy = mcts.model(
# #             torch.tensor(mcts.game.get_encoded_state(state),dtype=torch.long).unsqueeze(0))
# # print(policy)
# # policy = F.softmax(policy,dim=-1).squeeze(0).detach().numpy()
# # print(policy)
# # valid_moves = fw.get_valid_moves(state)
# # print(valid_moves)
# while True:
#     print(state)
#     mcts_probs = mcts.search(state)
#     action = np.argmax(mcts_probs)
#     state = fw.get_next_state(state, action)
#     value,is_terminal = fw.get_value_and_terminated(state,action)

#     if is_terminal:
#         print(value,fw.get_encoded_state(state))
#         break

In [6]:
class Alphazero:
    """Self-play / training loop that fine-tunes ``model`` with MCTS targets.

    Args:
        model: network called as ``model(tokens)``; MCTS uses the logits at the
            last sequence position as the next-action prior.
        optimizer: torch optimizer over ``model``'s parameters.
        game: environment (see ``Flowsheet``).
        args: dict with the keys ``temperature``, ``batch_size``,
            ``num_iterations``, ``num_selfPlay_iterations``, ``num_epochs`` and
            ``save_path``, plus the keys read by ``MCTS`` / ``Node``.
    """

    # Token used to right-pad shorter sequences inside a batch.
    PAD_TOKEN = 11

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(game, args, model)

    def selfPlay(self):
        """Play one episode and return its training samples.

        Returns:
            List of ``(encoded_state, action_probs, outcome)`` tuples, one per
            move, where ``encoded_state`` is the state BEFORE the move,
            ``action_probs`` is the MCTS visit distribution at that state and
            ``outcome`` is the final value of the episode.
        """
        memory = []
        state = self.game.get_initial_state()

        while True:
            action_probs = self.mcts.search(state)
            # get_next_state writes into `state` in place, so a snapshot must be
            # stored; otherwise every history entry aliases the final state.
            memory.append((state.copy(), action_probs))
            # Temperature -> 0: exploitation; temperature -> inf: exploration
            # (more randomness).
            temperature_action_probs = action_probs ** (1 / self.args["temperature"])
            temperature_action_probs /= np.sum(temperature_action_probs)

            action = np.random.choice(self.game.action_size, p=temperature_action_probs)
            state = self.game.get_next_state(state, action)
            value, is_terminal = self.game.get_value_and_terminated(state, action)
            if is_terminal:
                return [
                    (self.game.get_encoded_state(hist_state), hist_action_probs, value)
                    for hist_state, hist_action_probs in memory
                ]

    def train(self, memory):
        """Run one pass over ``memory`` (shuffled in place) in mini-batches.

        For every sample the logits at the last real token of the state are
        trained towards the MCTS visit distribution with a soft-target
        cross-entropy. The outcome values are collected but not used, because
        the model has no value head here.
        """
        random.shuffle(memory)
        batch_size = self.args["batch_size"]
        for batchIdx in range(0, len(memory), batch_size):
            # Plain slicing already clamps at the end of the list; the previous
            # `len(memory) - 1` bound silently dropped the last sample.
            sample = [
                entry
                for entry in memory[batchIdx : batchIdx + batch_size]
                if len(entry[0]) > 0
            ]
            if not sample:
                continue
            states, policy_targets, _value_targets = zip(*sample)

            lengths = [len(s) for s in states]
            max_length = max(lengths)
            # Right-pad to a rectangular batch.
            # NOTE(review): assumes the model is causal, so padding on the right
            # does not change the logits at earlier positions -- confirm against
            # ZW_model.GPT.
            padded = [
                np.asarray(s).tolist() + [self.PAD_TOKEN] * (max_length - len(s))
                for s in states
            ]
            state = torch.tensor(padded).long()
            last_index = torch.tensor(lengths) - 1
            # Visit distributions are probabilities: keep them as floats
            # (casting to long truncated almost all of them to 0).
            policy_targets = torch.tensor(np.array(policy_targets), dtype=torch.float32)

            out_policy = self.model(state)  # (batch, sequence, action)
            # Same read-out position as MCTS: the last real token of each state.
            logits = out_policy[torch.arange(len(sample)), last_index, :]

            policy_loss = -(policy_targets * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
            # value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training for ``args["num_iterations"]`` rounds.

        After every round the model and optimizer state dicts are written to
        ``args["save_path"]``.
        """
        import os

        save_path = self.args["save_path"]
        # Create the checkpoint directory up front so a missing folder does not
        # abort the run only after a full round of self-play.
        os.makedirs(save_path, exist_ok=True)

        for iteration in range(self.args["num_iterations"]):
            memory = []

            # Inference mode during search (disables dropout and similar layers).
            self.model.eval()
            for _ in trange(self.args["num_selfPlay_iterations"]):
                memory += self.selfPlay()

            self.model.train()
            for _ in trange(self.args["num_epochs"]):
                self.train(memory)

            torch.save(
                self.model.state_dict(), f"{save_path}/model_{iteration}_{self.game}.pt"
            )
            torch.save(
                self.optimizer.state_dict(),
                f"{save_path}/optimizer_{iteration}_{self.game}.pt",
            )

In [None]:
# Build the environment and fine-tune the model loaded in the second cell.
game = Flowsheet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
# Earlier configuration, kept for reference:
# args = {
#     "C":2,
#     "num_searches":60,
#     "num_iterations":3,
#     "num_selfPlay_iterations":500,
#     "num_epochs":4,
#     "batch_size":64,
#     "temperature":1.25,
#     "dirichlet_epsilon":0.25,
#     "dirichlet_alpha":0.3
# }
args = {
    # exploration constant in Node.get_ucb
    "C": 2,
    # simulations per MCTS.search call (i.e. per move)
    "num_searches": 1000,
    # outer self-play + training rounds in Alphazero.learn
    "num_iterations": 4,
    # episodes played per round
    "num_selfPlay_iterations": 50,
    # passes over the collected memory per round
    "num_epochs": 5,
    # mini-batch size in Alphazero.train
    "batch_size": 128,
    # visit-count probabilities are raised to 1/temperature before sampling a move
    "temperature": 1.25,
    # weight of the Dirichlet noise mixed into the root prior
    "dirichlet_epsilon": 0.10,
    # concentration parameter of that Dirichlet noise
    "dirichlet_alpha": 0.3,
    # directory where Alphazero.learn writes model/optimizer checkpoints
    "save_path": "./RL/policy_model_trials",
}
alphazero = Alphazero(model, optimizer, game, args)
# Long-running: every new valid layout triggers a PSO run inside evaluation().
alphazero.learn()

  0%|          | 0/50 [00:00<?, ?it/s]

['GCaAaTACHE'] 0.5604368088678346
['GCaAaCTACHE'] 0.52279113140133
['GCaAaHTACHE'] 0.6794793853626417
['GCaAaTACACHE'] 0.5634792805858594
['GCaATCaTACHE'] 0.5058928230311039
['GCaATaTACHE'] 0.47015235913015785
['GCaATaCTACHE'] 0.3250663040050634
['GCaATHTACaHE'] 0.285350469369848
['GCaATHTaACHE'] 0.24115967386291493
['GCaAT1HaT-1AC1HE'] 0.6437615555124303
['GCaAT1Ha-1TAC1HE'] 0.7732992717159927
['GCaAT1HaT1AC-1HE'] 0.7905289036849845
['GCaAT1Ha1TAC-1HE'] 0.7880765282791153
['GCaAT1HaT1ACH-1E'] 0.8114631933678725
['GCaAT1HaT1AC-1ACHE'] 0.5882970342130283
['GCaAT1HaT1AC-1CHE'] 0.7876724510122081
['GCaAT1HaT1AC-1AHE'] 0.7871710233199772
['GCaAT1HaT1ACA-1HE'] 0.6954849349281007
['GCaAT1HaT1ACH-1HE'] 0.8111654759343214
['GCaAT1HaT1ACH-1ACHE'] 0.7697891917278165
['GCaAT1HaT1ACAH-1E'] 0.7335905798466188
['GCaAT1HaT1ACHT-1E'] 0.8026948408095282
['GCaAT1HaT1ACHC-1E'] 0.8058351074810595
['GCaAT1HaT1AC-1CAHE'] 0.7865550023914307
['GCaAT1HaT1ACH-1AE'] 0.8054652361734704
['GCaAT1HaT1ACHA-1E'] 0.187