In [1]:
from geometric_graph_transformer import GraphTransformer
import torch
from torch import nn

# Network Structure
## Encoder
- standard node attention with pair bias
- pair updates via node-wise outer product and triangle multiplication

## Decoder
- Node attention via IPA with distance-based attention for points
- IPA layers do not use shared weights
- No pair updates

In [15]:
# Random seed: fixes model initialisation and the random example inputs so the
# notebook reproduces under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

node_dim = 32  # hidden dimension of node features
pair_dim = 32  # hidden dimension of pair features
encoder_depth = 2  # passed as `depth` to the encoder GraphTransformer
decoder_depth = 1  # passed as `depth` to the decoder GraphTransformer

# ---- Encoder node-feature updates ----
# IPA is not used in the encoder for this example (plain node attention with pair bias).
# NOTE(review): these values are marked as library defaults — confirm against GraphTransformer.
encoder_node_kwargs = dict(
    dim_head=32,
    heads=8,
    bias=False,
)

# ---- Encoder pair-feature updates ----
# Node-wise outer product + triangle multiplication; triangle attention is
# disabled (chosen empirically).
encoder_pair_kwargs = dict(
    heads=4,
    dim_head=24,
    dropout=0,
    tri_mul_dim=pair_dim,
    do_checkpoint=True,
    ff_mult=2,
    do_tri_mul=True,
    do_tri_attn=False,
)

# ---- Decoder node-feature updates ----
# IPA is used in the decoder for this example.
# NOTE(review): these values are marked as library defaults — confirm against GraphTransformer.
decoder_node_kwargs = dict(
    heads=12,
    dim_query_scalar=16,
    dim_query_point=4,
    dim_value_scalar=16,
    dim_value_point=8,
    pre_norm=True,
    use_dist_attn=True,
)

# ---- Decoder pair-feature updates ----
# None: the decoder performs no pair updates.
decoder_pair_kwargs = None


In [16]:
# Settings common to both stacks: feature widths and no weight sharing across layers.
_shared_cfg = dict(node_dim=node_dim, pair_dim=pair_dim, share_weights=False)

# Encoder: non-IPA node attention plus pair updates.
encoder = GraphTransformer(
    depth=encoder_depth,
    use_ipa=False,
    node_update_kwargs=encoder_node_kwargs,
    pair_update_kwargs=encoder_pair_kwargs,
    **_shared_cfg,
)


def _make_decoder_proj(dim):
    """Return LayerNorm -> Linear(dim, dim), mapping encoder output into decoder input space."""
    return nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))


# Projections from encoder output to decoder input (node first, then pair).
decoder_node_proj = _make_decoder_proj(node_dim)
decoder_pair_proj = _make_decoder_proj(pair_dim)

# Decoder: IPA-based node updates; pair updates are controlled by decoder_pair_kwargs.
decoder = GraphTransformer(
    depth=decoder_depth,
    use_ipa=True,
    node_update_kwargs=decoder_node_kwargs,
    pair_update_kwargs=decoder_pair_kwargs,
    **_shared_cfg,
)


In [17]:
# Example input: a batch of one graph with 30 nodes, random node and pair features.
b, n = 1, 30
node_feats = torch.randn(b, n, node_dim)
pair_feats = torch.randn(b, n, n, pair_dim)

# Encoder pass; only the node and pair outputs are kept.
node_feats, pair_feats, *_ = encoder(node_feats=node_feats, pair_feats=pair_feats)
print(f"Encoder out shapes:\n    Node : {node_feats.shape}\n    Pair: {pair_feats.shape}\n")

# Map encoder output into the decoder's input space, then decode.
# The decoder's third output (`rigids`) is consumed by the coordinate-prediction cell.
proj_node_feats = decoder_node_proj(node_feats)
proj_pair_feats = decoder_pair_proj(pair_feats)
node_feats, pair_feats, rigids, _ = decoder(
    node_feats=proj_node_feats,
    pair_feats=proj_pair_feats,
)
print(f"Decoder out shapes:\n    Node : {node_feats.shape}\n    Pair: {pair_feats.shape}\n")

Encoder out shapes:
    Node : torch.Size([1, 30, 32])
    Pair: torch.Size([1, 30, 30, 32])

Decoder out shapes:
    Node : torch.Size([1, 30, 32])
    Pair: torch.Size([1, 30, 30, 32])



# Predict BB coordinates from decoder output
- Linearly project `node_feats` to per-residue points expressed in each residue's local frame
- Apply `rigids` to map those local points into the global frame

In [5]:
from einops.layers.torch import Rearrange
point_dim = 4  # points predicted per residue, e.g. one per backbone atom
CA_posn = 1  # index of the CA atom within the predicted points
assert 0 <= CA_posn < point_dim, "CA_posn must index one of the predicted points"

# Predict `point_dim` 3-D points per residue, expressed in that residue's local frame.
to_points = nn.Sequential(
    nn.LayerNorm(node_dim),
    nn.Linear(node_dim, 3 * point_dim, bias=False),
    Rearrange("b n (d c) -> b n d c", c=3),
)

# (b, n, point_dim, 3) local-frame points from the decoder node features.
local_points = to_points(node_feats)

# Pin the local CA to the frame origin so that (assuming `rigids.apply` is
# rotate-then-translate) CA lands exactly on the rigid translation (helps empirically).
# Use CA_posn rather than a hard-coded index. Use an out-of-place index_fill rather than
# an in-place slice write into the autograd-tracked Linear/Rearrange output.
local_points = local_points.index_fill(
    2, torch.tensor([CA_posn], device=local_points.device), 0.0
)

# Place points in the global frame by applying the rigids.
global_points = rigids.apply(local_points)

print(f"point shape : {global_points.shape}")


point shape : torch.Size([1, 30, 4, 3])


# Use Decoder with pair features and shared IPA weights

In [6]:
# This decoder also performs pair updates, reusing the encoder's pair settings.
# Copy the dict so the encoder and decoder configs are not the same object;
# with aliasing, mutating one would silently change the other.
decoder_pair_kwargs = dict(encoder_pair_kwargs)
shared_decoder_depth = 2  # passed as `depth`; presumably the IPA block is applied this many times

decoder = GraphTransformer(
    node_dim=node_dim,
    pair_dim=pair_dim,
    depth=shared_decoder_depth,
    use_ipa=True,  # use IPA in decoder
    node_update_kwargs=decoder_node_kwargs,
    pair_update_kwargs=decoder_pair_kwargs,  # add pair updates
    share_weights=True,  # share weights across decoder layers
)

In [10]:
from rigids import Rigids

# Example input. Sizes are defined first so this cell does not rely on `b` / `n`
# leaking from an earlier cell.
b, n = 1, 30
node_feats = torch.randn(b, n, node_dim)
pair_feats = torch.randn(b, n, n, pair_dim)

# Native rigids must be provided when decoder weights are shared; they are built from
# (random, for this example) native backbone coordinates.
# NOTE(review): 4 atoms per residue — presumably N, CA, C, O; confirm against
# Rigids.RigidFromBackbone.
native_bb = torch.randn(b, n, 4, 3)
true_rigids = Rigids.RigidFromBackbone(native_bb)

# Example forward pass through the encoder.
node_feats, pair_feats, *_ = encoder(
    node_feats=node_feats,
    pair_feats=pair_feats,
)
print(f"Encoder out shapes:\n    Node : {node_feats.shape}\n    Pair: {pair_feats.shape}\n")

# Decoder outputs get their own names so `node_feats` / `pair_feats` keep holding the
# encoder output; later cells that reuse them then get encoder, not decoder, features.
dec_node_feats, dec_pair_feats, rigids, fape_aux = decoder(
    node_feats=decoder_node_proj(node_feats),
    pair_feats=decoder_pair_proj(pair_feats),
    true_rigids=true_rigids,
)
print(f"Decoder out shapes:\n    Node : {dec_node_feats.shape}\n    Pair: {dec_pair_feats.shape}\n")

print(f"Aux. FAPE Loss: {fape_aux}")

Encoder out shapes:
    Node : torch.Size([1, 30, 32])
    Pair: torch.Size([1, 30, 30, 32])

Decoder out shapes:
    Node : torch.Size([1, 30, 32])
    Pair: torch.Size([1, 30, 30, 32])

Aux. FAPE Loss: 0.9309613704681396


In [13]:
# Initial rigids for a decoy structure can also be passed to the decoder.
# Here the decoy is simply a copy of the native backbone, to demonstrate the API.
decoy_coords = native_bb.clone()

# Run the encoder on fresh inputs so this cell is idempotent. Decoder outputs are never
# fed back through the decoder input projections as if they were encoder output.
enc_node_feats, enc_pair_feats, *_ = encoder(
    node_feats=torch.randn(b, n, node_dim),
    pair_feats=torch.randn(b, n, n, pair_dim),
)

dec_node_feats, dec_pair_feats, rigids, fape_aux = decoder(
    node_feats=decoder_node_proj(enc_node_feats),
    pair_feats=decoder_pair_proj(enc_pair_feats),
    true_rigids=true_rigids,
    rigids=Rigids.RigidFromBackbone(decoy_coords),
)
print(f"Decoder out shapes:\n    Node : {dec_node_feats.shape}\n    Pair: {dec_pair_feats.shape}\n")

print(f"Aux. FAPE Loss: {fape_aux}")

Decoder out shapes:
    Node : torch.Size([1, 30, 32])
    Pair: torch.Size([1, 30, 30, 32])

Aux. FAPE Loss: 0.93241286277771
