In [1]:
from pyfoldx.structure import Structure
import torch
from predict import *
import torch.optim as optim
import torch.nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

In [2]:
model = load_model('weights/HERN_dock.ckpt')

In [3]:
def reward(pdb, cdr3, model):
    dock(pdb, cdr3, model, relax=True)
    struct = Structure(code='', path='outputs/docked.pdb')
    interaction_energy = struct.getInterfaceEnergy()['Interaction Energy'][-1]
    
    return float(interaction_energy)

In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [5]:
def generate_square_subsequent_mask(sz: int):
    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

In [6]:
class CriticEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.d_model = 1
        self.embedding = nn.Embedding(20, 32)
        self.pos_encoder = PositionalEncoding(1)
        encoder_layers = TransformerEncoderLayer(32, 8, 64, 0.1)
        self.transformer_encoder = TransformerEncoder(encoder_layers, 1)
        
        self.init_weights()
        
    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        
    def forward(self, src, src_mask):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        
        return output

In [7]:
class StructureCritic(nn.Module):
    def __init__(self, seq_len):
        super().__init__()
        
        self.critic_encoder = CriticEncoder()
        
        self.dense_layers = nn.ModuleList([nn.Linear(32, 20) for _ in range(seq_len)])
        self.softmax_layers = nn.ModuleList([nn.Softmax(dim=1) for _ in range(seq_len)])
        
        self.linear_predict = nn.Linear(seq_len, 1)
    
    def forward(self, src, src_mask):
        dists = []
        sequence = []
        X = self.critic_encoder(src, src_mask)
        
        for residue, dense, softmax in zip(X, self.dense_layers, self.softmax_layers):
            out = dense(residue)
            policy_dist = softmax(out)
            dists.append(policy_dist)
        
        value = self.linear_predict(torch.tensor(src.t()).float())
        
        return dists, value

In [9]:
# Use CondRefineGNN to generate CDR3 sequences

sample_cdr3 = 'AAAAAAAAAAAAA'
sample_out = reward('1nca_imgt.pdb', sample_cdr3, model)



Computing complex energy for structure...
Energy computed.


In [10]:
X = torch.tensor([ALPHABET.index(a) for a in sample_cdr3]).unsqueeze(1).long()
Y = sample_out

In [88]:
structure_critic = StructureCritic(X.shape[0])
optimizer = optim.Adam(structure_critic.parameters())

'''
A2C https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f

advantage = Y - value
actor_loss = (-log_probs * advantage).mean()
critic_loss = 0.5 * advantage.pow(2).mean()

loss = actor_loss + critic_loss
loss.backward()

''' 

for i in range(50):
    dists, value = structure_critic(X, generate_square_subsequent_mask(X.shape[0]))
    log_probs = torch.log((torch.stack(dists))).sum(dim=0)
    advantage = Y - value

    log_loss = (-log_probs * advantage).mean()
    critic_loss = 0.5 * advantage.pow(2).mean()
    loss = log_loss + critic_loss
    print(loss)

    optimizer.zero_grad
    loss.backward()
    optimizer.step()

tensor(988.9739, grad_fn=<AddBackward0>)
tensor(991.3818, grad_fn=<AddBackward0>)
tensor(986.7454, grad_fn=<AddBackward0>)
tensor(984.9760, grad_fn=<AddBackward0>)
tensor(981.1136, grad_fn=<AddBackward0>)
tensor(981.6342, grad_fn=<AddBackward0>)
tensor(979.4203, grad_fn=<AddBackward0>)
tensor(978.9031, grad_fn=<AddBackward0>)
tensor(980.1447, grad_fn=<AddBackward0>)
tensor(973.5017, grad_fn=<AddBackward0>)
tensor(974.2394, grad_fn=<AddBackward0>)
tensor(982.7502, grad_fn=<AddBackward0>)
tensor(975.6188, grad_fn=<AddBackward0>)
tensor(974.9247, grad_fn=<AddBackward0>)
tensor(973.9822, grad_fn=<AddBackward0>)
tensor(974.0204, grad_fn=<AddBackward0>)
tensor(970.6610, grad_fn=<AddBackward0>)
tensor(976.1815, grad_fn=<AddBackward0>)
tensor(971.5253, grad_fn=<AddBackward0>)
tensor(971.1320, grad_fn=<AddBackward0>)
tensor(972.0148, grad_fn=<AddBackward0>)
tensor(970.5977, grad_fn=<AddBackward0>)
tensor(964.1851, grad_fn=<AddBackward0>)
tensor(968.2761, grad_fn=<AddBackward0>)
tensor(964.9650,

  value = self.linear_predict(torch.tensor(src.t()).float())


In [83]:
with torch.no_grad():
    X = torch.tensor([ALPHABET.index(a) for a in '#############']).unsqueeze(1).long()
    dists, value = structure_critic(X, generate_square_subsequent_mask(X.shape[0]))
    seq = [torch.argmax(dist).item() for dist in dists]
    print(''.join([ALPHABET[aa] for aa in seq]))

NPAPSRQAHSNGN


  value = self.linear_predict(torch.tensor(src.t()).float())
