In [1]:
# 8IJS (presumably a PDB ID) nanobody sequence; source organism noted as Homo sapiens.
# One-letter amino-acid codes, 126 residues, split over two literals for readability.
nb_seq = (
    "MDVQLVESGGGLVNPGGSLRLSCAASGRTFSSYSMGWFRQAPGKEREFVVAISKGGYKYD"
    "AVSLEGRFTISRDNAKNTVYLQINSLRPEDTAVYYCASSRAYGSSRLRLADTYEYWGQGTLVTVSS"
)

In [8]:
import torch
import numpy as np

## Constants for sequence preprocessing (residue vocabulary and residue-to-index mapping)

In [3]:
# Residue vocabulary: the 20 standard amino acids as one-letter codes.
res_types = "ACDEFGHIKLMNPQRSTVWY"

# Residue letter -> column index (0..19) in the one-hot encoding.
res_to_n = dict(zip(res_types, range(len(res_types))))

# Candidate atom names (commented out; not used in the visible code):
# atom_types = ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "NZ", "OH"]


def res_to_num(x):
    """Map a residue letter to its integer index.

    Letters in ``res_types`` map to 0..19; anything else maps to the extra
    "unknown" index ``len(res_to_n)`` (i.e. 20).
    """
    return res_to_n.get(x, len(res_to_n))

In [4]:
def get_one_hot(targets, nb_classes=21):
    """One-hot encode integer class indices.

    Args:
        targets: Integer array-like of class indices in ``[0, nb_classes)``,
            of any shape. Lists, tuples and numpy arrays are all accepted.
        nb_classes: Size of the one-hot axis. The default 21 covers the 20
            standard residues plus one "unknown" class.

    Returns:
        A float64 numpy array of shape ``targets.shape + (nb_classes,)``.
    """
    # Convert once so `.shape` below works for plain Python sequences too,
    # not only for numpy arrays.
    targets = np.asarray(targets)
    # Row k of the identity matrix is the one-hot vector for class k.
    res = np.eye(nb_classes)[targets.reshape(-1)]
    return res.reshape(targets.shape + (nb_classes,))


def get_encoding(sequence):
    """One-hot encode an amino-acid sequence.

    Each residue letter is mapped to an index with ``res_to_num``. Letters
    outside the 20-residue vocabulary get the extra "unknown" index. The
    indices are then one-hot encoded with ``get_one_hot``.

    Returns:
        A float64 numpy array of shape ``(len(sequence), 21)``.
    """
    # dtype=int keeps an empty sequence usable: np.array([]) defaults to
    # float64, which numpy rejects as an index array.
    indices = np.array([res_to_num(x) for x in sequence], dtype=int)
    return get_one_hot(indices)

In [11]:
# One-hot encode the full nanobody sequence: one row per residue, 21 columns
# (20 standard residues + 1 "unknown"). Note: get_encoding returns a numpy array.
encoding = get_encoding(sequence=nb_seq)

torch.Size([126, 21])

In [16]:
# Per-residue rigid frames, initialised to the identity transform.
# Derive the length from the sequence instead of hard-coding 126, so it
# stays in sync if nb_seq changes.
seq_len = len(nb_seq)
origin = torch.zeros(seq_len, 3)  # rigid origin (translation), shape (seq_len, 3)
# Identity rotations, shape (seq_len, 3, 3). expand() returns a broadcast
# view in which every residue shares the same 3x3 storage; call .clone()
# before writing to it in place.
rot = torch.eye(3).unsqueeze(0).expand(seq_len, -1, -1)

tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])

In [63]:
# Short demo sequence and the clamp radius (rel_pos_dim) for the
# relative-position helpers below.
seq, dim = "MDVQL", 64

In [79]:
def get_rel_pos(sequence, rel_pos_dim, shift=False):
    """Return the pairwise relative-position matrix for ``sequence``.

    Entry ``[i, j]`` is ``i - j``, clamped to ``[-rel_pos_dim, rel_pos_dim]``.
    (``node_rel_pos`` in this notebook uses the opposite sign, ``j - i``.)

    Args:
        sequence: Any sized object; only ``len(sequence)`` is used.
        rel_pos_dim: Clamp radius for the relative offsets.
        shift: If True, add ``rel_pos_dim`` so every entry lies in
            ``[0, 2 * rel_pos_dim]`` (e.g. for use as an embedding index).
            Defaults to False, which keeps the signed output.

    Returns:
        An int64 tensor of shape ``(len(sequence), len(sequence))``.
    """
    positions = torch.arange(len(sequence))
    # Broadcasting (n, 1) - (1, n) produces the full (n, n) matrix of i - j.
    rel_pos = positions.unsqueeze(-1) - positions.unsqueeze(0)
    # Offsets farther apart than rel_pos_dim saturate at +/- rel_pos_dim.
    rel_pos = rel_pos.clamp(min=-rel_pos_dim, max=rel_pos_dim)
    if shift:
        rel_pos = rel_pos + rel_pos_dim
    return rel_pos

In [75]:
# One-hot node features for the short demo sequence.
node_features = get_encoding(seq)

# Chain id -> sequence mapping ("H" presumably = heavy chain; TODO confirm).
# This must be a dict literal: the set {"H", seq} made get_encoding encode
# the whole strings "H" and seq, in hash-dependent order, instead of residues.
sequence_dict = {"H": seq}

# Encode the residues of chain "H" as a float tensor in torch's default dtype.
encoding = torch.tensor(get_encoding(sequence_dict["H"]), dtype=torch.get_default_dtype())

def node_rel_pos(node_features, rel_pos_dim):
    """Return shifted, clamped relative positions between nodes.

    The node count is read from ``node_features.shape[-2]``. Entry ``[i, j]``
    is ``j - i`` clamped to ``[-rel_pos_dim, rel_pos_dim]`` and then shifted
    by ``+rel_pos_dim``, so every value lies in ``[0, 2 * rel_pos_dim]``.
    """
    n_nodes = node_features.shape[-2]
    idx = torch.arange(n_nodes)
    # (1, n) - (n, 1) broadcasts to (n, n) with offsets[i, j] = j - i.
    offsets = idx.unsqueeze(0) - idx.unsqueeze(1)
    return torch.clamp(offsets, -rel_pos_dim, rel_pos_dim) + rel_pos_dim

In [80]:
# Signed relative positions i - j for the demo sequence, clamped to [-dim, dim].
rp1 = get_rel_pos(sequence=seq, rel_pos_dim=dim)
rp1

tensor([[ 0, -1, -2, -3, -4],
        [ 1,  0, -1, -2, -3],
        [ 2,  1,  0, -1, -2],
        [ 3,  2,  1,  0, -1],
        [ 4,  3,  2,  1,  0]])

In [74]:
# Relative positions j - i for the same nodes, clamped to [-dim, dim] and shifted by +dim.
rp2 = node_rel_pos(node_features=node_features, rel_pos_dim=dim)
rp2

tensor([[64, 65, 66, 67, 68],
        [63, 64, 65, 66, 67],
        [62, 63, 64, 65, 66],
        [61, 62, 63, 64, 65],
        [60, 61, 62, 63, 64]])