<h4>Invariant Point Attention (IPA): PyTorch Implementation</h4>

<p>This tutorial clarifies the architecture of Invariant Point Attention (IPA) piece by piece. The IPA module was introduced by the DeepMind team in the AlphaFold2 paper, where it is used to refine the predicted coordinates. In this notebook we implement only the attention part of the IPA module; the rest of the module will be covered in later notebooks.</p>

<h5>Supplementary Figure</h5>
<img src='asset/IPA_Supplementary_Figure.PNG' />

<h5>Algorithm</h5>
<img src='asset/IPA_algorithm.png' />
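<p>For reference, the attention weights computed by the algorithm above can be written out as follows, transcribed from the AlphaFold2 supplement. Here γ<sup>h</sup> is the per-head point weight, T<sub>i</sub> is the frame of residue i, b<sub>ij</sub><sup>h</sup> is the pair bias, and the sum runs over the query points p. The factor w<sub>L</sub> = √(1/3) reflects the three logit terms; this notebook's <code>nal</code> generalizes it to two terms when no pair representation is used.</p>

```latex
a_{ij}^{h} = \operatorname{softmax}_j\!\left( w_L \left( \frac{1}{\sqrt{c}}\, {q_i^{h}}^{\top} k_j^{h}
  + b_{ij}^{h}
  - \frac{\gamma^{h} w_C}{2} \sum_{p} \left\lVert T_i \circ \vec{q}_i^{\,hp} - T_j \circ \vec{k}_j^{\,hp} \right\rVert^2 \right) \right),
\qquad w_L = \sqrt{\tfrac{1}{3}}, \quad w_C = \sqrt{\tfrac{2}{9\, N_{\text{query points}}}}
```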

In [1]:
import torch
import torch.nn as nn

<p>The definitions below set the dimensions of the attention blocks:</p>
<ol>
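<li>nh = number of heads, c = channel dimension, nqp = number of query (and key) points per head, npv = number of value points per head</li>
<li>emdb = embedding dimension of the single representation s<sub>i</sub></li>
<li>nqs, nvs = per-head dimension of the scalar queries/keys and of the scalar values</li>
<li>bias = whether the linear projections of q, k, v use a bias term</li>
<li>rpr = whether the pair representation z<sub>ij</sub> is used (it adds the pair-bias term to the attention logits)</li>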
</ol>

In [5]:
emdb, nh, c, nqp, npv, nqs, nvs, bias, rpr = 768, 12, 16, 4, 4, 16, 16, False, True

<p>Scalar q, k, v projections (Algorithm line 1).</p>
<p>As highlighted with the arrow in the supplementary figure, we also need the number of terms that contribute to the attention logits: three (scalar, point and pair-bias terms) when the pair representation is used, or two without it.</p>
<p>nal = number of attention-logit terms</p>

In [6]:
nal = 3 if rpr else 2

qs = nn.Linear(emdb, nqs*nh, bias=bias)
ks = nn.Linear(emdb, nqs*nh, bias=bias)
vs = nn.Linear(emdb, nvs*nh, bias=bias)
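<p>A minimal sketch of how these scalar projections might be applied, assuming a random single representation <code>s</code> of shape (batch, residues, emdb) and a head-splitting layout chosen for illustration (<code>split_heads</code>, <code>b</code> and <code>n</code> are hypothetical names). It computes only the raw q·k/√c term; the notebook's <code>sals</code> additionally folds in the weight w<sub>L</sub>.</p>

```python
import torch
import torch.nn as nn

# Sizes mirror the notebook; b (batch) and n (residues) are illustrative assumptions.
emdb, nh, nqs, nvs = 768, 12, 16, 16
b, n = 2, 10

qs = nn.Linear(emdb, nqs * nh, bias=False)
ks = nn.Linear(emdb, nqs * nh, bias=False)
vs = nn.Linear(emdb, nvs * nh, bias=False)

def split_heads(x: torch.Tensor, d: int) -> torch.Tensor:
    """Hypothetical helper: (..., n, nh*d) -> (..., nh, n, d)."""
    return x.view(*x.shape[:-1], nh, d).transpose(-3, -2)

s = torch.randn(b, n, emdb)  # random stand-in for the single representation s_i

q = split_heads(qs(s), nqs)  # (b, nh, n, nqs)
k = split_heads(ks(s), nqs)  # (b, nh, n, nqs)
v = split_heads(vs(s), nvs)  # (b, nh, n, nvs)

# Scalar term q_i^h . k_j^h / sqrt(c) for every head and residue pair
scalar_logits = torch.einsum('bhid,bhjd->bhij', q, k) * nqs ** -0.5
print(scalar_logits.shape)  # torch.Size([2, 12, 10, 10])
```

<p>Keeping the heads in a separate axis lets a single einsum compute all per-head logits at once.</p>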

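<p>q, k, v projections for point attention (Algorithm line 2). The points are expressed in each residue's local frame and are later mapped to global coordinates with the residue frames T<sub>i</sub>, which makes this part of the attention aware of residue positions and orientations.</p>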

In [4]:
# each head emits nqp (or npv) points per residue, and each point is a 3D vector -> * 3
qp = nn.Linear(emdb, nqp*nh*3, bias=bias)
kp = nn.Linear(emdb, nqp*nh*3, bias=bias)
vp = nn.Linear(emdb, npv*nh*3, bias=bias)
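<p>The sketch below illustrates why the point term is called invariant. It assumes a (batch, residues, heads, points, xyz) layout for the projected points and random residue frames T<sub>i</sub> = (R<sub>i</sub>, t<sub>i</sub>); the helpers <code>random_rotation</code>, <code>to_global</code> and <code>point_dist</code> are hypothetical. Applying one extra global rotation and translation to every frame leaves the summed squared point distances unchanged.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emdb, nh, nqp = 768, 12, 4
b, n = 2, 10  # illustrative batch size and number of residues

# Point projections: nh heads x nqp points x 3 coordinates per residue
qp = nn.Linear(emdb, nh * nqp * 3, bias=False)
kp = nn.Linear(emdb, nh * nqp * 3, bias=False)

def random_rotation(*batch: int) -> torch.Tensor:
    """Random proper rotation matrices (det = +1) built from a QR decomposition."""
    q, r = torch.linalg.qr(torch.randn(*batch, 3, 3))
    q = q * torch.sign(torch.diagonal(r, dim1=-2, dim2=-1)).unsqueeze(-2)
    det = torch.det(q)                                  # +1 or -1
    q[..., :, 0] = q[..., :, 0] * det.unsqueeze(-1)     # flip a column if det = -1
    return q

def to_global(points, rot, trans):
    # points: (b, n, nh, nqp, 3), rot: (b, n, 3, 3), trans: (b, n, 3)
    # T_i o x = R_i x + t_i for every head and point of residue i
    return torch.einsum('bnxy,bnhpy->bnhpx', rot, points) + trans[:, :, None, None, :]

s = torch.randn(b, n, emdb)  # random stand-in for the single representation
q_pts = qp(s).view(b, n, nh, nqp, 3)
k_pts = kp(s).view(b, n, nh, nqp, 3)

def point_dist(rot, trans):
    qg = to_global(q_pts, rot, trans)
    kg = to_global(k_pts, rot, trans)
    diff = qg[:, :, None] - kg[:, None, :]      # (b, n_i, n_j, nh, nqp, 3)
    return (diff ** 2).sum(dim=(-1, -2))        # sum over xyz and points -> (b, n, n, nh)

rot, trans = random_rotation(b, n), torch.randn(b, n, 3)
d1 = point_dist(rot, trans)

# One global rigid motion (G, g) applied to every frame: R_i -> G R_i, t_i -> G t_i + g
G, g = random_rotation(b), torch.randn(b, 3)
rot2 = torch.einsum('bxy,bnyz->bnxz', G, rot)
trans2 = torch.einsum('bxy,bny->bnx', G, trans) + g[:, None, :]
d2 = point_dist(rot2, trans2)

print(torch.allclose(d1, d2, atol=1e-3))  # True
```

<p>Because only differences of globally placed points enter the logits, the rigid motion cancels, which is the property the name "invariant point attention" refers to.</p>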

<p>Excerpt from the AlphaFold2 paper describing the weight initialization of IPA:</p>
<img src='asset/IPA_docs.png' width='75%' />

<p>Define the per-head point weight γ<sup>h</sup>, the softplus of a learnable scalar: pwiv is the initial value of that scalar and pw is the learnable point-weight parameter.</p>
<img src='asset/softplus_function.png' />

In [10]:
# inverse softplus of 1, log(exp(1) - 1), so that gamma = softplus(pw) = 1 at initialization
pwiv = torch.log(torch.exp(torch.full((nh,), 1.)) - 1.)
pw = nn.Parameter(pwiv)
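<p>A quick sanity check of the softplus parameterization, assuming the intended initial value of γ<sup>h</sup> is 1. The raw parameter starts at the inverse softplus of 1 (about 0.5413), softplus maps it back to 1, and softplus keeps γ<sup>h</sup> positive for any raw value, so the point-distance term always lowers the logits of distant point pairs.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

nh = 12

# softplus^{-1}(y) = log(exp(y) - 1); expm1 computes exp(y) - 1 accurately
raw_init = torch.log(torch.expm1(torch.full((nh,), 1.0)))
pw = nn.Parameter(raw_init)

gamma = F.softplus(pw)  # per-head point weight gamma^h, approximately 1 for every head
print(raw_init[0].item(), gamma[0].item())

# softplus stays strictly positive even for very negative raw values
print(F.softplus(torch.tensor([-10.0, 0.0, 10.0])))
```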

In [None]:
# scalar attention-logit scale: w_L / sqrt(c), with w_L = sqrt(1 / nal)
sals = (nal * nqs) ** -0.5
# point attention-logit scale: w_L * w_C, with w_C = sqrt(2 / (9 * nqp))
pals = ((nal * nqp) * (9 / 2)) ** -0.5
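<p>A minimal sketch of how the two scales might combine into the IPA attention weights, assuming the query/key points are already in global coordinates and the heads sit in axis 1. The function <code>ipa_attention_weights</code> and its argument layout are hypothetical; like this notebook, it falls back to w<sub>L</sub> = √(1/2) when no pair bias is given. The extra factor 0.5 is the 1/2 in front of γ<sup>h</sup> w<sub>C</sub>, which is not folded into <code>pals</code>.</p>

```python
import torch
import torch.nn.functional as F

def ipa_attention_weights(q, k, q_pts, k_pts, gamma, pair_bias=None):
    """Sketch of the IPA attention weights a_ij^h.

    q, k:         (b, nh, n, c)          scalar queries / keys
    q_pts, k_pts: (b, nh, n, nqp, 3)     query / key points in global coordinates
    gamma:        (nh,)                  softplus of the point weights (positive)
    pair_bias:    (b, nh, n, n) or None  pair bias b_ij^h
    """
    nal = 3 if pair_bias is not None else 2
    c, nqp = q.shape[-1], q_pts.shape[-2]
    sals = (nal * c) ** -0.5                 # w_L / sqrt(c)
    pals = ((nal * nqp) * (9 / 2)) ** -0.5   # w_L * w_C

    logits = torch.einsum('bhid,bhjd->bhij', q, k) * sals

    # sum_p ||T_i o q_i^hp - T_j o k_j^hp||^2 -> (b, nh, n, n)
    diff = q_pts[:, :, :, None] - k_pts[:, :, None, :]   # (b, nh, n, n, nqp, 3)
    dist2 = (diff ** 2).sum(dim=(-1, -2))
    logits = logits - 0.5 * pals * gamma.view(1, -1, 1, 1) * dist2

    if pair_bias is not None:
        logits = logits + nal ** -0.5 * pair_bias        # w_L * b_ij^h

    return logits.softmax(dim=-1)                        # normalize over j

# Usage with random tensors (illustrative sizes)
b, nh, n, c, nqp = 2, 12, 10, 16, 4
q, k = torch.randn(b, nh, n, c), torch.randn(b, nh, n, c)
q_pts, k_pts = torch.randn(b, nh, n, nqp, 3), torch.randn(b, nh, n, nqp, 3)
gamma = F.softplus(torch.zeros(nh))

a = ipa_attention_weights(q, k, q_pts, k_pts, gamma, pair_bias=torch.randn(b, nh, n, n))
a_no_pair = ipa_attention_weights(q, k, q_pts, k_pts, gamma)
print(a.shape)  # torch.Size([2, 12, 10, 10])
```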

In [None]:
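# implementing line 4: pair bias b_ij^h = LinearNoBias(z_ij), one scalar per head
# pair_dim is an assumed pair-representation width (c_z); AlphaFold2 uses 128
pair_dim = 128
bij = nn.Linear(pair_dim, nh, bias=False) if rpr else None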
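<p>A short sketch of the pair-bias projection in use, assuming a pair representation <code>z</code> of shape (batch, n, n, pair_dim). The width <code>pair_dim</code> = 128 is an assumption matching AlphaFold2's c<sub>z</sub>; the permute moves the head axis forward so the bias lines up with per-head logits of shape (batch, nh, n, n).</p>

```python
import torch
import torch.nn as nn

# Illustrative sizes; pair_dim (c_z) is an assumption
b, n, nh, pair_dim = 2, 10, 12, 128

bij = nn.Linear(pair_dim, nh, bias=False)   # LinearNoBias(z_ij): one scalar per head

z = torch.randn(b, n, n, pair_dim)           # random stand-in for the pair representation z_ij
pair_bias = bij(z).permute(0, 3, 1, 2)       # (b, n, n, nh) -> (b, nh, n, n)
print(pair_bias.shape)  # torch.Size([2, 12, 10, 10])
```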