# In[11]:
import numpy as np
from matplotlib import pyplot as plt

# In[88]:
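"""Action encoding and lattice coordinates.

X marks the cursor (current position); each action moves it one site.
Corner labels are (x, y) board indices, with x increasing to the right
and y increasing downward:

        | (0,0) |  U/1  | (2,0) |
        |  L/4  |   X   |  R/2  |
        | (0,2) |  D/3  | (2,2) |

Actions: 1 = up, 2 = right, 3 = down, 4 = left.
"""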



class BPFEnv:
    """Lattice environment that places a chain of residues one site at a time.

    ``self.state`` is a tuple ``(board, residues, position)``:

    * ``board`` -- square float array of side ``2 * n_residues_max + 1``;
      0 marks an empty site, any other value is the residue placed there.
    * ``residues`` -- residues still to be placed, as numbers (the plotting
      code assumes values of -1 and 1).
    * ``position`` -- ``(x, y)`` board index of the cursor; the next residue
      is placed one step away from it. Starts at the board centre.

    Actions: 1 = up (y - 1), 2 = right (x + 1), 3 = down (y + 1),
    4 = left (x - 1).
    """

    # Cursor displacement (dx, dy) for each action.
    _MOVES = {1: (0, -1), 2: (1, 0), 3: (0, 1), 4: (-1, 0)}

    def __init__(self, n_residues_max, residues):
        """Create an empty board centred on the starting cursor.

        Args:
            n_residues_max: half-width of the board; must exceed len(residues).
            residues: sequence of numeric residue values to place, in order.
        """
        # Every move shifts the cursor by one site, so with fewer residues than
        # n_residues_max the chain cannot walk off the board from the centre.
        assert len(residues) < n_residues_max
        size = n_residues_max * 2 + 1
        # Position is an immutable tuple; increment_lattice builds a new one.
        self.state = (np.zeros((size, size)), residues,
                      (n_residues_max, n_residues_max))

    def increment_lattice(self, action):
        """Return the state produced by placing the next residue via ``action``.

        ``self.state`` is not modified: the board is copied, the remaining
        residues are returned as a new list and the position as a new tuple.

        Args:
            action: one of 1 (up), 2 (right), 3 (down), 4 (left).

        Returns:
            ``(board, residues, position)`` after moving the cursor one site
            and writing the first remaining residue there (overwriting any
            residue already in that cell).

        Raises:
            AssertionError: if ``action`` is not in 1..4.
            IndexError: if no residues remain or the move leaves the board.
        """
        # double check the move makes sense
        assert action in [1,2,3,4], "Action " +str(action) + " is invalid"

        board, residues, position = self.state
        if len(residues) == 0:
            raise IndexError("no residues left to place")

        dx, dy = self._MOVES[action]
        new_position = (position[0] + dx, position[1] + dy)
        # Negative indices would silently wrap around in numpy, so check bounds.
        if not all(0 <= c < n for c, n in zip(new_position, board.shape)):
            raise IndexError("move to " + str(new_position) + " leaves the board")

        # Copy only the array; copying the whole heterogeneous state tuple with
        # np.copy would fail or alias the original board.
        new_board = np.copy(board)
        new_board[new_position] = residues[0]  # tuple -> single-cell index
        return (new_board, list(residues[1:]), new_position)

    def compute_energy(self, state, side_of_board_reward=0):
        """Score the contacts between neighbouring lattice sites.

        Every site is compared with each of its four neighbours and scores
        ``|a + b|``. With residues of -1/1 this gives 2 for two equal
        residues, 1 for a residue next to an empty site, and 0 for opposite
        residues or two empty sites. Because each site looks at both of its
        neighbours, every adjacent pair is counted twice. Each side of a site
        that faces the board edge scores ``side_of_board_reward`` instead.

        NOTE(review): consecutive (bonded) chain neighbours are scored too;
        confirm that is intended.

        Args:
            state: ``(board, residues, position)``; only ``board`` is read.
            side_of_board_reward: score for each site side lying on the edge.

        Returns:
            The total score as a float.
        """
        board = np.asarray(state[0])
        horizontal = np.abs(board[:, 1:] + board[:, :-1]).sum()
        vertical = np.abs(board[1:, :] + board[:-1, :]).sum()
        n_rows, n_cols = board.shape
        # Each row has a left and right edge side; each column a top and bottom one.
        edge_sides = 2 * n_rows + 2 * n_cols
        return float(2 * (horizontal + vertical) + side_of_board_reward * edge_sides)

    def step(self, action):
        """Place the next residue, advance ``self.state`` and score the move.

        Args:
            action: one of 1 (up), 2 (right), 3 (down), 4 (left).

        Returns:
            ``(new_state, reward, done, info)``. If the move lands on an
            occupied site the episode ends with reward 0.01. Otherwise the
            reward is ``compute_energy(new_state)``, and the episode ends once
            every residue has been placed. ``info`` is always ``{}``.
        """
        assert action in [1,2,3,4], "Action " +str(action) + " is invalid"

        reward = 0
        done = False
        new_state = self.increment_lattice(action)

        # SELF-AVOIDING: writing a residue onto an empty site adds one non-zero
        # cell, so an unchanged count means the move overlapped the chain.
        # NOTE(review): the starting centre site is never filled, so moving back
        # onto it is not detected as an overlap.
        if not np.count_nonzero(new_state[0]) > np.count_nonzero(self.state[0]):
            # NOTE(review): this is a small *positive* reward; confirm a penalty
            # was not intended.
            reward += 0.01
            done = True

        # ENERGY
        if not done:
            reward += self.compute_energy(new_state)
            done = len(new_state[1]) == 0

        self.state = new_state
        return new_state, reward, done, {}

    def show(self):
        """Scatter-plot the residues on the current board.

        Sites holding -1 are drawn green and sites holding 1 teal, plotted at
        (first index, second index) coordinates.
        """
        board = np.asarray(self.state[0])
        # argwhere keeps shape (0, 2) when nothing matches, so empty colours plot fine.
        green = np.argwhere(board == -1)
        blue = np.argwhere(board == 1)
        plt.scatter(green[:, 0], green[:, 1], color="green", lw=5)
        plt.scatter(blue[:, 0], blue[:, 1], color="teal", lw=5)