In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import gym
from gym import spaces
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from torch.nn.functional import one_hot
# from rdkit import Chem
# from rdkit.Chem import Crippen
import random

In [2]:
state_size = 4
standard_amino_acids = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
action_space = np.arange(0, 20)

In [3]:
class PolicyNetwork(nn.Module):
    def __init__(self, state_size, action_space):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.fc2 = nn.Linear(128, len(action_space))

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        action_probs = torch.softmax(self.fc2(x), dim=0)
        return action_probs

In [4]:
policy_net = PolicyNetwork(state_size, action_space)
policy_net #gives 20 probabilities for the 20 standard amino acids.

PolicyNetwork(
  (fc1): Linear(in_features=4, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=20, bias=True)
)

In [5]:
# s = [-1,-1,-1,0]
# se = []
# for i in range(3):
#     se.append(s)
#     for j in range(len(s)):
#         if s[j] == -1:
#             s[j] = 9
#             break
#     s[3] = s[3] + 1
    
# print(se)

In [6]:
learning_rate = 0.001
num_episodes = 500
len_peptide = 3
max_steps = len_peptide
gamma = 0.99

In [7]:
# Define your reward function
def compute_discounted_rewards(rewards, gamma):
    discounted_rewards = [rewards[-1]]
    for i in range(len(rewards) - 2, -1, -1):
        discounted_rewards.insert(0, rewards[i] + gamma * discounted_rewards[0])
    return discounted_rewards

In [8]:
# Define your training loop
def train():
    # Set up your optimizer
    optimizer = optim.Adam(policy_net.parameters(), lr= learning_rate) 

    # Training loop
    for episode in range(num_episodes):
        # Initialize the state
        state = [-1,-1,-1,0] 

        # Lists to store the trajectory
        states = []
        actions = []
        rewards = []

        # Collect trajectory by interacting with the environment
        for step in range(max_steps):
        # Convert the state to a PyTorch tensor
            
            state_tensor = torch.tensor(state, dtype=torch.float32)
            # Forward pass to get action probabilities
            
            action_probs = policy_net(state_tensor)

            # Sample an action from the action probabilities
            action = torch.multinomial(action_probs, 1).item()

            # Execute the action and observe the next state and reward 
            next_state = action
            reward = action_probs[action].item()
                    
            # Store the trajectory
            states.append(state)
            actions.append(action)
            rewards.append(reward)

            # Update the state
            for i in range(len(state)):
                if state[i] == -1:
                    state[i] = next_state
                    break
            state[3] = state[3] + 1

        # Compute discounted rewards
        discounted_rewards = compute_discounted_rewards(rewards, gamma)

        # Convert trajectory to tensors
        state_tensor = torch.tensor(states, dtype=torch.float32)
        action_tensor = torch.tensor(actions, dtype=torch.int64)
        reward_tensor = torch.tensor(discounted_rewards, dtype=torch.float32)

        # Compute the loss
        action_probs = policy_net(state_tensor)
        selected_action_probs = action_probs.gather(1, action_tensor.unsqueeze(1)).squeeze()
        loss = -torch.mean(torch.log(selected_action_probs) * reward_tensor)

        # Update the policy network
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print training progress
        print(f"Episode: {episode}, Loss: {loss}")

In [9]:
train()

Episode: 0, Loss: 0.2527870833873749
Episode: 1, Loss: 0.12812061607837677
Episode: 2, Loss: 0.4495612382888794
Episode: 3, Loss: 0.3946632444858551
Episode: 4, Loss: 0.31685808300971985
Episode: 5, Loss: 0.23426277935504913
Episode: 6, Loss: 0.2937659025192261
Episode: 7, Loss: 0.7505591511726379
Episode: 8, Loss: 0.10025864839553833
Episode: 9, Loss: 0.7924635410308838
Episode: 10, Loss: 0.47618845105171204
Episode: 11, Loss: 0.06063328683376312
Episode: 12, Loss: 0.8527063727378845
Episode: 13, Loss: 0.17512117326259613
Episode: 14, Loss: 0.08644694834947586
Episode: 15, Loss: 0.11897142976522446
Episode: 16, Loss: 0.9271313548088074
Episode: 17, Loss: 0.22364778816699982
Episode: 18, Loss: 0.24418038129806519
Episode: 19, Loss: 0.4341578781604767
Episode: 20, Loss: 0.1770927459001541
Episode: 21, Loss: 0.9450812935829163
Episode: 22, Loss: 0.11655717343091965
Episode: 23, Loss: 0.18812452256679535
Episode: 24, Loss: 0.1277303397655487
Episode: 25, Loss: 0.2960887849330902
Episode: 