In [5]:
import sys, os
from omegaconf import OmegaConf
import torch
import numpy as np


# --- Locations ---------------------------------------------------------------
# Both paths default to the values originally hard-coded in this notebook, but
# can be overridden with environment variables so the cell also runs on
# machines where those directories do not exist (the import below raises
# ModuleNotFoundError when `policy_utils` cannot be found on sys.path).
CLEO_REPO = os.environ.get("CLEO_REPO", "/home/jgershon/git/cleo")
if CLEO_REPO not in sys.path:
    # Guard keeps sys.path from growing every time this cell is re-run.
    sys.path.append(CLEO_REPO)
from policy_utils import PolicyMPNN, alphabet


test_pdb = os.environ.get(
    "CLEO_TEST_PDB",
    "/home/jgershon/projects/itopt/declan/TS1_trp_6_conformers_0008_000_10-atomized-bb-False_4_9_MPNN.pdb",
)

# --- Reproducibility ---------------------------------------------------------
# Seed the RNGs of the libraries imported by this notebook so that a
# Restart & Run All produces the same random draws.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Policy configuration ----------------------------------------------------
# Built directly as an OmegaConf object so `config` only ever refers to one
# kind of thing. The keys are passed to PolicyMPNN unchanged; their exact
# meaning is defined by policy_utils (not visible here).
config = OmegaConf.create(
    {
        "model_type": "protein_mpnn",
        "pdb": test_pdb,
        "batch_size": 6,
        "temperature": 1.0,
        "omit_AA": "CX",
        "lr": 1e-4,
        "fixed_residues": "A29 A30 A31 A32 A60 A61 A62 A63 A100 A101 A102 A103",
    }
)

policy_mpnn = PolicyMPNN(config)


ModuleNotFoundError: No module named 'policy_utils'

In [38]:

# Policy-gradient training loop with a running-mean baseline.
#
# Each step samples a batch of sequences from the policy, scores every
# sequence by how many positions were decoded as `aa_of_interest`, and
# minimises -(log-prob of the sampled sequence) * (reward - baseline).

# get features
feature_dict = policy_mpnn.featurize_pdb(config.pdb)

# Mean batch reward of every completed step. Starts empty: seeding it with a
# placeholder 0 would be averaged into every later baseline and bias it low.
reward_history = []

# define dummy reward: count occurrences of one amino-acid letter
aa_of_interest = "E"
aa_index_of_interest = alphabet.index(aa_of_interest)

# encode to get initial state (done once; every step starts from clones of it)
h_V, h_E, E_idx = policy_mpnn.encode_initial_state(feature_dict)

N_train = 100

for i in range(N_train):

    # clone initial state features so each rollout gets fresh tensors
    h_V_in = h_V.clone()
    h_E_in = h_E.clone()
    E_idx_in = E_idx.clone()

    # set requires grad == True
    h_V_in.requires_grad = True
    h_E_in.requires_grad = True

    # run the policy
    out = policy_mpnn.rollout(feature_dict, h_V_in, h_E_in, E_idx_in)

    # one-hot mask selecting, at every position, the token that was sampled
    seq_mask = torch.nn.functional.one_hot(out["S"], num_classes=len(alphabet)).float()

    # keep only the sampled tokens' log-probs and sum them per sequence
    # (assumes log_probs is laid out as (batch, position, alphabet) - TODO confirm)
    batched_log_probs = (out["log_probs"] * seq_mask).sum(dim=(-1,-2))

    # reward = number of positions decoded as the amino acid of interest
    batched_reward = (out["S"] == aa_index_of_interest).sum(dim=-1).float()
    mean_reward = batched_reward.mean().item()

    # baseline = mean reward over all *previous* steps; 0.0 before any history
    baseline = float(np.mean(reward_history)) if reward_history else 0.0

    # store reward history (after the baseline, so the current step is excluded)
    reward_history.append(mean_reward)

    # baseline subtracted reward
    baseline_subtracted_reward = batched_reward - baseline
    # baseline_subtracted_reward = torch.clamp(batched_reward-baseline, min=0)

    # compute loss
    loss = -(batched_log_probs * baseline_subtracted_reward).mean()

    # optimizer update
    policy_mpnn.optimizer.zero_grad()
    loss.backward()
    policy_mpnn.optimizer.step()

    if i%10 == 0:
        print(f"Step {i}: Reward = {mean_reward:.2f}")

Step 0: Reward = 6.33
Step 10: Reward = 10.00
Step 20: Reward = 10.33
Step 30: Reward = 13.17
Step 40: Reward = 17.17
Step 50: Reward = 19.33
Step 60: Reward = 28.67
Step 70: Reward = 28.50
Step 80: Reward = 30.17
Step 90: Reward = 32.17


In [39]:
# Print every sequence of the most recent rollout batch as a string of letters.
B = out["S"].shape[0]

for i, seq in enumerate(out["S"]):
    # Translate each token index into its character from `alphabet`.
    seq_str = "".join(alphabet[int(s)] for s in seq)
    print(f"Decoded sequence {i}: {seq_str}")

Decoded sequence 0: QTREEEIEVEIENTDWRAELYEEATEGETIATLKDSEHPEIGQAVEKFMQELLKLVREEVPALREEVEKLEEATRINLKEAEVELEPNEEKGKTRIRIEQGDKEIQGEMFAEEALAGLEPGETVRISMELLEPEG
Decoded sequence 1: FRESEEVRSEVEVTEIKAGLFKEITEKETIATLREEEHKEIAKAYKKFLEGYIELIREKVPALKEEIEKLEEAYGIDLAEAEIELEENEEEGETEVTIKQGDKELKGKIEAEEALAFLEEGKEVTLEWELWTELE
Decoded sequence 2: QKESEKVKVEIPVTQEIAELYRENTEEKTIATLEDPEHPEIAEAVKEMLEELMKAIREKVPALEPEIKEIEEATGINLRTAEVELEADPESGKTRVTIRQGDKELVGEITAEEALALLKPGENVTLERELWTEEE
Decoded sequence 3: VEESEEVVVEQEVSEEKAFLYREATPPETIATLEDEEHPNIAEAVEEMLKVIDELVRAKVPALEEEVTELKEATGIDLTEAEIKLEPEPEEGVTKVMIEQGDKELAGEVTAEEALSLLEPGETVTLEGEGKTPLE
Decoded sequence 4: EISEEEIEVEEKVTEERAELYRESIDEETIATHRDSKHEEIAKAVVEAEKEIEKLIRENVPALEEIVEKRIEASGIDMRTARVELKENPEKGKYEILYEQGDKELLGEKTAEEAEKMLEPGSLIEQRAELKEREE
Decoded sequence 5: MESEEEVEVLIEIDEERAEYYKSATEEETIATLENPEDKEIAEAVVKFLAPWRELVEEEVPALQEIVEEIEGATGIDLREAEVELKAEPEKGKAEEEVTQGDKRLEREISAEEALALLHPGEELRLRGVLKIEEE


In [24]:
# Display the current value of `loss`, i.e. whatever the training loop assigned last.
# NOTE(review): this cell's recorded execution count (24) is lower than the
# training cell's (38), so the saved output may come from an earlier run -
# re-run after training to confirm.
loss

tensor(-0., grad_fn=<NegBackward0>)

In [14]:
# Show the reward of each sequence in the last batch, one tensor per line.
# NOTE(review): the recorded execution count (14) predates the training cell
# (38), so the saved output may not match the current `batched_reward`.
for b in batched_reward.unbind(0):
    print(b)

tensor(124.)
tensor(124.)
