In [1]:
# 8IJS, homo sapience nanobody sequence
nb_seq = "MDVQLVESGGGLVNPGGSLRLSCAASGRTFSSYSMGWFRQAPGKEREFVVAISKGGYKYDAVSLEGRFTISRDNAKNTVYLQINSLRPEDTAVYYCASSRAYGSSRLRLADTYEYWGQGTLVTVSS"

In [2]:
import torch
import numpy as np

## constants for sequence preprocessing

In [3]:
res_types = "ACDEFGHIKLMNPQRSTVWY" # residue types
res_to_n = {x: i for i, x in enumerate(res_types)} # create {res : idx} dictionary
# atom_types = ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "NZ", "OH"]
res_to_num = lambda x: res_to_n[x] if x in res_to_n else len(res_to_n) # res string to int

In [5]:
def get_one_hot(targets, nb_classes=21):
    print(targets)
    res = np.eye(nb_classes)[np.array(targets).reshape(-1)]
    return res.reshape(list(targets.shape) + [nb_classes])

def get_encoding(sequence):
    one_hot_amino = get_one_hot(np.array([res_to_num(x) for x in sequence]))
    return one_hot_amino

# def encoding(sequence):
#     nb_classes = len(res_types)
#     res_to_num = np.array([res_to_num(x) for x in sequence])

#     res = np.eye(nb_classes)[np.array(res_to_num).reshape(-1)]
#     encoding = res.reshape(list(targets.shape) + [nb_classes])
#     return encoding

In [7]:
encoding = get_encoding(nb_seq)
encoding

[10  2 17 13  9 17  3 15  5  5  5  9 17 11 12  5  5 15  9 14  9 15  1  0
  0 15  5 14 16  4 15 15 19 15 10  5 18  4 14 13  0 12  5  8  3 14  3  4
 17 17  0  7 15  8  5  5 19  8 19  2  0 17 15  9  3  5 14  4 16  7 15 14
  2 11  0  8 11 16 17 19  9 13  7 11 15  9 14 12  3  2 16  0 17 19 19  1
  0 15 15 14  0 19  5 15 15 14  9 14  9  0  2 16 19  3 19 18  5 13  5 16
  9 17 16 17 15 15]


array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [16]:
seq_len = 126
origin = torch.zeros(seq_len, 3) # rigid origin, 3dim
rot = torch.eye(3).unsqueeze(0).expand(seq_len,-1,-1)

tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])

In [63]:
seq = "MDVQL"
dim = 64

In [79]:
def get_rel_pos(sequence, rel_pos_dim): # node의 상대 위치 반환
    # broadcasting에 의해 tensor 크기 맞춰짐 (seq_len, seq_len)
    rel_pos = torch.arange(len(sequence)).unsqueeze(-1) - torch.arange(len(sequence)).unsqueeze(0)

    # rel_pos_dim 의 두 배 만큼 clamp, 모든 값 양수 되도록 조정
    rel_pos = rel_pos.clamp(min=-rel_pos_dim, max=rel_pos_dim) # + rel_pos_dim
    return rel_pos

In [75]:
node_features = get_encoding(seq)
sequence_dict = {"H", seq}

encoding = torch.tensor(get_encoding(sequence_dict), dtype=torch.get_default_dtype())

def node_rel_pos(node_features, rel_pos_dim):
    relative_positions = (torch.arange(node_features.shape[-2])[None] - torch.arange(node_features.shape[-2])[:, None])
    relative_positions = relative_positions.clamp(min=-rel_pos_dim, max=rel_pos_dim) + rel_pos_dim
    return relative_positions

In [8]:
class InvariantPointAttention(torch.nn.Module):
    def __init__(self, node_dim, edge_dim, heads=12, head_dim=16, n_query_points=4, n_value_points=8, **kwargs):
        super().__init__()
        self.heads = heads
        self.head_dim = head_dim
        self.n_query_points = n_query_points

        node_scalar_attention_inner_dim = heads * head_dim
        node_vector_attention_inner_dim = 3 * n_query_points * heads
        node_vector_attention_value_dim = 3 * n_value_points * heads
        after_final_cat_dim = heads * edge_dim + heads * head_dim + heads * n_value_points * 4

        point_weight_init_value = torch.log(torch.exp(torch.full((heads,), 1.)) - 1.)
        self.point_weight = torch.nn.Parameter(point_weight_init_value)

        self.to_scalar_qkv = torch.nn.Linear(node_dim, 3 * node_scalar_attention_inner_dim, bias=False)
        self.to_vector_qk = torch.nn.Linear(node_dim, 2 * node_vector_attention_inner_dim, bias=False)
        self.to_vector_v = torch.nn.Linear(node_dim, node_vector_attention_value_dim, bias=False)
        self.to_scalar_edge_attention_bias = torch.nn.Linear(edge_dim, heads, bias=False)
        self.final_linear = torch.nn.Linear(after_final_cat_dim, node_dim)

        with torch.no_grad():
            self.final_linear.weight.fill_(0.0)
            self.final_linear.bias.fill_(0.0)

    def forward(self, node_features, edge_features, rigid):
        # Classic attention on nodes
        scalar_qkv = self.to_scalar_qkv(node_features).chunk(3, dim=-1)
        scalar_q, scalar_k, scalar_v = map(lambda t: rearrange(t, 'n (h d) -> h n d', h=self.heads), scalar_qkv)
        node_scalar = torch.einsum('h i d, h j d -> h i j', scalar_q, scalar_k) * self.head_dim ** (-1 / 2)

        # Linear bias on edges
        edge_bias = rearrange(self.to_scalar_edge_attention_bias(edge_features), 'i j h -> h i j')

        # Reference frame attention
        wc = (2 / self.n_query_points) ** (1 / 2) / 6
        vector_qk = self.to_vector_qk(node_features).chunk(2, dim=-1)
        vector_q, vector_k = map(lambda x: vec_from_tensor(rearrange(x, 'n (h p d) -> h n p d', h=self.heads, d=3)),
                                 vector_qk)
        rigid_ = rigid.unsqueeze(0).unsqueeze(-1)  # add head and point dimension to rigids

        global_vector_k = rigid_ @ vector_k
        global_vector_q = rigid_ @ vector_q
        global_frame_distance = wc * global_vector_q.unsqueeze(-2).dist(global_vector_k.unsqueeze(-3)).sum(
            -1) * rearrange(self.point_weight, "h -> h () ()")

        # Combining attentions
        attention_matrix = (3 ** (-1 / 2) * (node_scalar + edge_bias - global_frame_distance)).softmax(-1)

        # Obtaining outputs
        edge_output = (rearrange(attention_matrix, 'h i j -> i h () j') * rearrange(edge_features,
                                                                                    'i j d -> i () d j')).sum(-1)
        scalar_node_output = torch.einsum('h i j, h j d -> i h d', attention_matrix, scalar_v)

        vector_v = vec_from_tensor(
            rearrange(self.to_vector_v(node_features), 'n (h p d) -> h n p d', h=self.heads, d=3))
        global_vector_v = rigid_ @ vector_v
        attended_global_vector_v = global_vector_v.map(
            lambda x: torch.einsum('h i j, h j p -> h i p', attention_matrix, x))
        vector_node_output = rigid_.inv() @ attended_global_vector_v
        vector_node_output = torch.stack(
            [vector_node_output.norm(), vector_node_output.x, vector_node_output.y, vector_node_output.z], dim=-1)

        # Concatenate along heads and points
        edge_output = rearrange(edge_output, 'n h d -> n (h d)')
        scalar_node_output = rearrange(scalar_node_output, 'n h d -> n (h d)')
        vector_node_output = rearrange(vector_node_output, 'h n p d -> n (h p d)')

        combined = torch.cat([edge_output, scalar_node_output, vector_node_output], dim=-1)

        return node_features + self.final_linear(combined)