Some comments from [@sokrypton](https://github.com/sokrypton/)

*   The 3rd track in RoseTTAFold is essentially doing what AlphaFold's triangle
attention is doing: adding a 3D bias to the attention. In RoseTTAFold this is done explicitly via evolving 3D coordinates.

*   In AlphaFold this is done implicitly, essentially by learning the triangle inequality. Often, if every triplet of distances in your distance matrix satisfies the triangle inequality, it is possible to embed the structure in 3D space.

*   The potential benefit of the 3D track is that you remove the expensive $O(L^3)$ operation and replace it with an $O(L^2)$ one.

*   The potential downside is that you can get stuck in a local minimum, as an explicit 3D structure may be harder to refine iteratively than a 2D distance matrix, which is not explicitly forced to be a valid distance matrix at each step of the Evoformer.

*   So it is possible that AlphaFold learned not to enforce the triangle inequality in early layers, but then enforces it in later layers.

See also [these AI conversation logs](https://poe.com/s/03iamy0dmmJF3023ccNp)

In [None]:
import torch
import torch.nn as nn

from typing import Dict, Optional, Callable , List, Any
from functools import partial
from collections.abc import Iterable

In [None]:
# Reference : https://github.com/dptech-corp/Uni-Core/blob/854b8890daa5722ba4e30eed1973564671611f6f/unicore/utils.py

def dict_map(fn, dic, leaf_type):
    """Return a new dict with ``fn`` applied to every leaf of ``dic``.

    A value whose exact type is ``dict`` is handled by recursing into
    ``dict_map``; any other value is delegated to ``tree_map``, which deals
    with lists, tuples and leaves of type ``leaf_type``.
    """
    return {
        key: (
            dict_map(fn, value, leaf_type)
            if type(value) is dict
            else tree_map(fn, value, leaf_type)
        )
        for key, value in dic.items()
    }

def tree_map(fn, tree, leaf_type):
    """Apply ``fn`` to every leaf of a nested dict/list/tuple structure.

    Args:
        fn: callable applied to each leaf.
        tree: a dict, list, tuple, or an instance of ``leaf_type``; containers
            may be nested arbitrarily.
        leaf_type: type (or tuple of types) that counts as a leaf.

    Returns:
        A structure of the same layout with each leaf replaced by ``fn(leaf)``.
        Lists come back as lists, tuples as plain tuples, dicts as dicts.

    Raises:
        ValueError: if ``fn`` fails on a leaf (the original error is chained
            as ``__cause__``), or if a node is neither a supported container
            nor an instance of ``leaf_type``.
    """
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    if isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, x, leaf_type) for x in tree)
    if isinstance(tree, leaf_type):
        try:
            return fn(tree)
        except Exception as err:
            # Only wrap ordinary errors; KeyboardInterrupt/SystemExit must
            # propagate untouched. Chain the cause so the real failure stays
            # visible in the traceback.
            raise ValueError(f"cannot apply {fn} on {tree}.") from err
    raise ValueError(f"{type(tree)} not supported")


# tree_map specialised to torch.Tensor leaves.
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)

def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute only the trailing ``len(inds)`` dimensions of ``tensor``.

    ``inds`` indexes into the last ``len(inds)`` dimensions (0 is the first of
    those trailing dimensions); all leading dimensions keep their position.
    For example ``inds=(2, 0, 1)`` turns ``[*, A, B, C]`` into ``[*, C, A, B]``.
    """
    offset = -len(inds)
    # Leading dimensions are addressed by their positive index ...
    leading = list(range(len(tensor.shape[:offset])))
    # ... trailing ones by a negative index counted from the end.
    trailing = [offset + i for i in inds]
    return tensor.permute(leading + trailing)

In [None]:
# Reference : https://github.com/dptech-corp/Uni-Fold/blob/1a301710392ecf97991aebc3276aad9d0f77178f/unifold/modules/common.py#L299

def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    num_batch_dims: int,
) -> Any:
    """Call ``layer`` on slices of the flattened batch dimensions and stitch the results.

    Args:
        layer: callable invoked as ``layer(**chunk)`` for every chunk.
        inputs: keyword arguments for ``layer``; entries whose value is ``None``
            are dropped before anything else happens. The remaining values are
            traversed with ``tensor_tree_map``.
        chunk_size: number of flattened batch entries handed to ``layer`` per call.
        num_batch_dims: how many leading dimensions of every tensor are treated
            as batch dimensions and flattened into one.

    Returns:
        The output of ``layer`` for the whole batch, as a tensor or a tuple of
        tensors, with the leading dimension reshaped back to the original
        batch dimensions.

    Raises:
        ValueError: if ``inputs`` is empty, if an input has an unsupported type,
            or if ``layer`` returns something other than a tensor or a tuple.
    """
    # TODO: support inplace add to output
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")
    # Collect the shapes of every tensor found in a (nested) input structure.
    def _dict_get_shapes(input):
        shapes = []
        if type(input) is torch.Tensor:
            shapes.append(input.shape)
        elif type(input) is dict:
            for v in input.values():
                shapes.extend(_dict_get_shapes(v))
        elif isinstance(input, Iterable):
            for v in input:
                shapes.extend(_dict_get_shapes(v))
        else:
            raise ValueError("Not supported")
        return shapes
    inputs = {k: v for k, v in inputs.items() if v is not None}
    initial_dims = [shape[:num_batch_dims] for shape in _dict_get_shapes(inputs)]
    # Per batch axis, take the largest size seen across all inputs.
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d
    # Ceiling division: the last chunk may be shorter than chunk_size.
    num_chunks = (flat_batch_dim + chunk_size - 1) // chunk_size
    # Collapse the batch dimensions of one tensor into a single leading axis.
    def _flat_inputs(t):
        t = t.view(-1, *t.shape[num_batch_dims:])
        assert (
            t.shape[0] == flat_batch_dim or t.shape[0] == 1
        ), "batch dimension must be 1 or equal to the flat batch dimension"
        return t
    flat_inputs = tensor_tree_map(_flat_inputs, inputs)
    out = None
    for i in range(num_chunks):
        chunk_start = i * chunk_size
        chunk_end = min((i + 1) * chunk_size, flat_batch_dim)
        # Inputs with a flat batch size of 1 are passed unchanged to every
        # chunk; all others are sliced to the current chunk.
        def select_chunk(t):
            if t.shape[0] == 1:
                return t[0:1]
            else:
                return t[chunk_start:chunk_end]
        chunkes = tensor_tree_map(select_chunk, flat_inputs)
        output_chunk = layer(**chunkes)
        if out is None:
            # Allocate the full-size output lazily, once the per-entry output
            # shape is known from the first chunk.
            out = tensor_tree_map(
                lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk
            )
        out_type = type(output_chunk)
        if out_type is tuple:
            for x, y in zip(out, output_chunk):
                x[chunk_start:chunk_end] = y
        elif out_type is torch.Tensor:
            out[chunk_start:chunk_end] = output_chunk
        else:
            raise ValueError("Not supported")
    # Restore the original batch dimensions in front of the per-entry shape.
    reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
    out = tensor_tree_map(reshape, out)
    return out

In [None]:
# Reference : https://github.com/dptech-corp/Uni-Fold/blob/1a301710392ecf97991aebc3276aad9d0f77178f/unifold/modules/common.py

class Linear(nn.Linear):
    """``nn.Linear`` with a selectable weight-initialisation scheme.

    The bias, when present, is first set to zero. ``init`` then selects how
    the weight (and for ``"gating"`` also the bias) is initialised:

    * ``"default"``: truncated normal with variance ``1 / fan_in``.
    * ``"relu"``: truncated normal with variance ``2 / fan_in``.
    * ``"glorot"``: Xavier/Glorot uniform with gain 1.
    * ``"gating"``: zero weight, bias filled with 1 (if a bias exists).
    * ``"normal"``: Kaiming normal with ``nonlinearity="linear"``.
    * ``"final"``: zero weight, bias left at 0.

    Args:
        d_in: size of each input sample.
        d_out: size of each output sample.
        bias: whether the layer has an additive bias.
        init: name of the initialisation scheme (see above).

    Raises:
        ValueError: if ``init`` is not one of the names listed above.
    """

    def __init__(
        self,
        d_in: int,
        d_out: int,
        bias: bool = True,
        init: str = "default",
    ):
        super().__init__(d_in, d_out, bias=bias)

        self.use_bias = bias

        if self.use_bias:
            with torch.no_grad():
                self.bias.fill_(0)

        if init == "default":
            self._trunc_normal_init(1.0)
        elif init == "relu":
            self._trunc_normal_init(2.0)
        elif init == "glorot":
            self._glorot_uniform_init()
        elif init == "gating":
            self._zero_init(self.use_bias)
        elif init == "normal":
            self._normal_init()
        elif init == "final":
            self._zero_init(False)
        else:
            raise ValueError("Invalid init method.")

    def _trunc_normal_init(self, scale=1.0):
        """Fill the weight from a normal truncated at two standard deviations.

        The resulting weights have variance ``scale / fan_in``.
        """
        # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
        TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978
        _, fan_in = self.weight.shape
        scale = scale / max(1, fan_in)
        # Truncation shrinks the standard deviation by the factor above, so the
        # untruncated std is enlarged to compensate.
        std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR
        # trunc_normal_ takes *absolute* cut-offs. The correction factor is
        # only valid for a cut at +/- 2 standard deviations, so the bounds
        # must be scaled by std instead of using the default +/- 2.0.
        nn.init.trunc_normal_(
            self.weight, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
        )

    def _glorot_uniform_init(self):
        """Fill the weight with Xavier/Glorot uniform values (gain 1)."""
        nn.init.xavier_uniform_(self.weight, gain=1)

    def _zero_init(self, use_bias=True):
        """Zero the weight; if ``use_bias`` is true, set the bias to 1."""
        with torch.no_grad():
            self.weight.fill_(0.0)
            if use_bias:
                self.bias.fill_(1.0)

    def _normal_init(self):
        """Fill the weight with Kaiming-normal values for a linear activation."""
        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="linear")


In [None]:
# Reference : https://github.com/dptech-corp/Uni-Core/blob/44f6386f4dcd7137fc1e5d5e768117d635d64a26/unicore/modules/layer_norm.py

# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import functional as F

"""
try:
    import unicore_fused_layernorm
    import unicore_fused_layernorm_backward_gamma_beta
    HAS_LAYER_NORM = True
except:
    print("fused_layer_norm is not installed corrected")
    HAS_LAYER_NORM = False

if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7:
    HAS_LAYER_NORM = False
"""

# Switch for the fused layer-norm kernels. It is hard-coded to False, so
# LayerNorm below always selects its plain F.layer_norm implementation.
HAS_LAYER_NORM = False

class FusedLayerNormFastFunction(torch.autograd.Function):
    """Autograd function that delegates layer norm to the fused extension kernels.

    Both passes call into ``unicore_fused_layernorm`` and
    ``unicore_fused_layernorm_backward_gamma_beta``, which must be importable
    at call time.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        """Normalise ``input`` and stash what the backward pass needs."""
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        x = input.contiguous()
        gamma = weight.contiguous()
        beta = bias.contiguous()
        output, mean, invvar = unicore_fused_layernorm.forward(
            x, ctx.normalized_shape, gamma, beta, ctx.eps
        )
        ctx.save_for_backward(x, gamma, beta, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (input, weight, bias); the shape and eps get None."""
        x, gamma, beta, mean, invvar = ctx.saved_tensors
        shape, eps = ctx.normalized_shape, ctx.eps
        grad_input = unicore_fused_layernorm.backward(
            grad_output.contiguous(), mean, invvar, x, shape, gamma, beta, eps
        )
        grad_weight, grad_bias = unicore_fused_layernorm_backward_gamma_beta.backward(
            grad_output.contiguous(), mean, invvar, x, shape, gamma, beta, eps
        )
        return grad_input, grad_weight, grad_bias, None, None

# Sizes of the leading normalised dimension for which the fused path is chosen.
FUSED_LAYER_NORM_SUPPORT_DIM = {
    64, 128, 192, 256, 320, 384, 512, 640, 768,
    1024, 1280, 1536, 1792, 2048, 2560, 5120,
}


class LayerNorm(torch.nn.Module):
    """Layer normalisation with learnable scale and shift.

    The fused implementation is used when ``HAS_LAYER_NORM`` is true, the
    first normalised dimension is in ``FUSED_LAYER_NORM_SUPPORT_DIM`` and the
    input is a CUDA tensor. In every other case ``F.layer_norm`` is used.

    The implementation is selected through bound methods rather than closures
    stored on the instance. A closure captures the original ``self``, so a
    ``copy.deepcopy`` of the module would keep normalising with the original
    module's parameters, and the module could not be pickled.

    Args:
        normalized_shape: int or sequence of ints, the trailing shape to
            normalise over.
        eps: value added to the variance for numerical stability.
        elementwise_affine: must be true; only the affine variant is supported.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        assert elementwise_affine, "only elementwise_affine=True is supported"
        self.weight = Parameter(torch.empty(self.normalized_shape))
        self.bias = Parameter(torch.empty(self.normalized_shape))
        self.reset_parameters()
        # Decided once at construction time, like the original selection.
        self._use_fused = bool(
            HAS_LAYER_NORM
            and self.normalized_shape[0] in FUSED_LAYER_NORM_SUPPORT_DIM
        )

    @property
    def func(self):
        """The normalisation callable currently in use (kept for compatibility)."""
        return self._fused_layer_norm if self._use_fused else self._torch_layer_norm

    def _torch_layer_norm(self, input):
        """Plain PyTorch layer norm, with parameters cast to the input dtype."""
        return F.layer_norm(
            input,
            self.normalized_shape,
            self.weight.type(input.dtype),
            self.bias.type(input.dtype),
            self.eps,
        )

    def _fused_layer_norm(self, input):
        """Fused layer norm for CUDA inputs; other inputs use the PyTorch path."""
        if input.is_cuda:
            return FusedLayerNormFastFunction.apply(
                input,
                self.weight.type(input.dtype),
                self.bias.type(input.dtype),
                self.normalized_shape,
                self.eps,
            )
        return self._torch_layer_norm(input)

    def reset_parameters(self):
        """Reset the scale to ones and the shift to zeros."""
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        """Normalise ``input`` over its trailing ``normalized_shape`` dimensions."""
        return self.func(input)

    def extra_repr(self):
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine=True'.format(**self.__dict__)

In [None]:
# Reference : https://github.com/dptech-corp/Uni-Fold/blob/1a301710392ecf97991aebc3276aad9d0f77178f/unifold/modules/triangle_multiplication.py

from functools import partialmethod

class TriangleMultiplication(nn.Module):
    """Triangle multiplicative update of a pair representation.

    The input ``z`` is layer-normalised and projected to two gated operands
    ``a`` and ``b``. These are contracted over one of the two pair axes, and
    the result is normalised and projected back to ``d_pair`` channels.

    Args:
        d_pair: channel size of the pair representation.
        d_hid: channel size of the operands ``a`` and ``b``.
        outgoing: selects which pair axis is contracted (see ``forward``).
    """

    def __init__(self, d_pair, d_hid, outgoing=True):
        super().__init__()
        self.outgoing = outgoing

        # Joint projection and gate for both operands: the first d_hid output
        # channels belong to "a", the last d_hid to "b".
        self.linear_ab_p = Linear(d_pair, d_hid * 2)
        self.linear_ab_g = Linear(d_pair, d_hid * 2, init="gating")

        self.linear_g = Linear(d_pair, d_pair, init="gating")
        self.linear_z = Linear(d_hid, d_pair, init="final")

        self.layer_norm_in = LayerNorm(d_pair)
        self.layer_norm_out = LayerNorm(d_hid)

        self._alphafold_original_mode = False

    def _chunk_2d(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        block_size: int = None,
    ) -> torch.Tensor:
        """Block-wise variant of the update that bounds peak memory.

        Called by ``forward`` with ``z`` already layer-normalised and ``mask``
        already unsqueezed and scaled. ``mask`` must not be ``None`` here even
        though the signature has that default. Unlike the non-chunked path,
        the output bias and the sigmoid gate are applied here, so the return
        value is the finished update.

        Args:
            z: normalised pair representation; the block loops assume the two
                pair axes (dims -3 and -2) have the same length.
            mask: mask with a trailing singleton channel dimension.
            block_size: edge length of the output blocks; ``None`` processes
                the whole axis as a single block.

        Returns:
            Tensor with the same shape as ``z``.
        """
        # avoid too small chunk size
        # block_size = max(block_size, 256)
        new_z = z.new_zeros(z.shape)
        dim1 = z.shape[-3]
        if block_size is None:
            # The default used to fail in the arithmetic below; treat it as
            # "one block covering the whole axis".
            block_size = max(dim1, 1)

        def _slice_linear(z, linear: Linear, a=True):
            # Apply only the half of the joint a/b projection that is needed.
            d_hid = linear.bias.shape[0] // 2
            index = 0 if a else d_hid
            p = (
                nn.functional.linear(z, linear.weight[index : index + d_hid])
                + linear.bias[index : index + d_hid]
            )
            return p

        def _chunk_projection(z, mask, a=True):
            # Masked projection times its sigmoid gate, for one operand.
            p = _slice_linear(z, self.linear_ab_p, a) * mask
            p *= torch.sigmoid(_slice_linear(z, self.linear_ab_g, a))
            return p

        num_chunk = (dim1 + block_size - 1) // block_size
        for i in range(num_chunk):
            chunk_start = i * block_size
            chunk_end = min(chunk_start + block_size, dim1)
            if self.outgoing:
                a_chunk = _chunk_projection(
                    z[..., chunk_start:chunk_end, :, :],
                    mask[..., chunk_start:chunk_end, :, :],
                    a=True,
                )
                a_chunk = permute_final_dims(a_chunk, (2, 0, 1))
            else:
                a_chunk = _chunk_projection(
                    z[..., :, chunk_start:chunk_end, :],
                    mask[..., :, chunk_start:chunk_end, :],
                    a=True,
                )
                a_chunk = a_chunk.transpose(-1, -3)

            for j in range(num_chunk):
                j_chunk_start = j * block_size
                j_chunk_end = min(j_chunk_start + block_size, dim1)
                if self.outgoing:
                    b_chunk = _chunk_projection(
                        z[..., j_chunk_start:j_chunk_end, :, :],
                        mask[..., j_chunk_start:j_chunk_end, :, :],
                        a=False,
                    )
                    b_chunk = b_chunk.transpose(-1, -3)
                else:
                    b_chunk = _chunk_projection(
                        z[..., :, j_chunk_start:j_chunk_end, :],
                        mask[..., :, j_chunk_start:j_chunk_end, :],
                        a=False,
                    )
                    b_chunk = permute_final_dims(b_chunk, (2, 0, 1))
                # Channel-wise matrix product of the two operand blocks.
                x_chunk = torch.matmul(a_chunk, b_chunk)
                del b_chunk
                x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
                x_chunk = self.layer_norm_out(x_chunk)
                x_chunk = self.linear_z(x_chunk)
                x_chunk *= torch.sigmoid(
                    self.linear_g(
                        z[..., chunk_start:chunk_end, j_chunk_start:j_chunk_end, :]
                    )
                )
                new_z[
                    ..., chunk_start:chunk_end, j_chunk_start:j_chunk_end, :
                ] = x_chunk
                del x_chunk
            del a_chunk
        return new_z

    def forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        block_size=None,
    ):
        """Compute the triangle multiplicative update.

        Args:
            z: pair representation with ``d_pair`` channels in the last dim.
            mask: mask over the pair positions, shaped like ``z`` without its
                channel dimension. ``None`` means every position is valid.
            block_size: if given and the module is in eval mode, the
                memory-saving block-wise path is used.

        Returns:
            In eval mode with ``block_size`` set: one tensor, the finished
            update (bias and sigmoid gate applied).
            Otherwise: a tuple ``(x, g)`` with the projected update and the
            gate logits, both computed *without* their bias terms. The biases
            are available from ``get_output_bias`` and the caller is expected
            to apply them together with the sigmoid gate.
        """
        if mask is None:
            # The signature allows None, which used to fail at the unsqueeze
            # below; an all-ones mask leaves every position unmasked.
            mask = z.new_ones(z.shape[:-1])
        mask = mask.unsqueeze(-1)
        if not self._alphafold_original_mode:
            # scale by 1/sqrt(dim) for numerical stability
            mask = mask * (mask.shape[-2] ** -0.5)

        z = self.layer_norm_in(z)
        if not self.training and block_size is not None:
            return self._chunk_2d(z, mask, block_size=block_size)

        g = nn.functional.linear(z, self.linear_g.weight)
        if self.training:
            ab = self.linear_ab_p(z) * mask * torch.sigmoid(self.linear_ab_g(z))
        else:
            # In-place products save memory when not training.
            ab = self.linear_ab_p(z)
            ab *= mask
            ab *= torch.sigmoid(self.linear_ab_g(z))
        a, b = torch.chunk(ab, 2, dim=-1)
        del z, ab

        if self.outgoing:
            a = permute_final_dims(a, (2, 0, 1))
            b = b.transpose(-1, -3)
        else:
            b = permute_final_dims(b, (2, 0, 1))
            a = a.transpose(-1, -3)
        x = torch.matmul(a, b)
        del a, b

        x = permute_final_dims(x, (1, 2, 0))

        x = self.layer_norm_out(x)
        x = nn.functional.linear(x, self.linear_z.weight)
        return x, g

    def get_output_bias(self):
        """Return the biases that the tuple-returning path of ``forward`` leaves out.

        Returns:
            ``(linear_z.bias, linear_g.bias)``: the bias of the output
            projection and the bias of the gate.
        """
        return self.linear_z.bias, self.linear_g.bias


class TriangleMultiplicationOutgoing(TriangleMultiplication):
    """``TriangleMultiplication`` with ``outgoing`` defaulting to ``True``."""

    def __init__(self, *args, **kwargs):
        # Same as pre-binding the keyword: an explicit ``outgoing=...`` from
        # the caller still takes precedence over the preset.
        kwargs.setdefault("outgoing", True)
        super().__init__(*args, **kwargs)


class TriangleMultiplicationIncoming(TriangleMultiplication):
    """``TriangleMultiplication`` with ``outgoing`` defaulting to ``False``."""

    def __init__(self, *args, **kwargs):
        # Same as pre-binding the keyword: an explicit ``outgoing=...`` from
        # the caller still takes precedence over the preset.
        kwargs.setdefault("outgoing", False)
        super().__init__(*args, **kwargs)

In [None]:
# Modified from https://github.com/dptech-corp/Uni-Fold/blob/1a301710392ecf97991aebc3276aad9d0f77178f/unifold/modules/attentions.py#L349-L402

class TriangleAttention(nn.Module):
    """Triangle self-attention over a square pair representation.

    For an input ``x`` of shape ``[*, L, L, C]`` every row ``i`` attends along
    its second axis. The attention logits between query position ``j`` and key
    position ``k`` receive an additive per-head bias computed from ``x[j, k]``.
    With ``starting=False`` the two pair axes are swapped before and after, so
    the same computation runs along the other axis.

    Args:
        d_in: channel size of the pair representation.
        d_hid: embedding size of the attention module. The normalised input is
            fed to the attention directly, so it must equal ``d_in``.
        num_heads: number of attention heads.
        starting: ``True`` to attend along the second pair axis, ``False`` to
            attend along the first.
    """

    def __init__(
        self,
        d_in,
        d_hid,
        num_heads,
        starting,
    ):
        super().__init__()
        self.starting = starting
        self.layer_norm = nn.LayerNorm(d_in)
        # One additive attention bias per head for every pair position.
        self.linear = nn.Linear(d_in, num_heads, bias=False)
        self.mha = nn.MultiheadAttention(d_hid, num_heads)

    def _attend(self, x, mask=None, bias=None):
        """Run multi-head self-attention along the second-to-last axis of ``x``.

        Args:
            x: rows of the pair representation, ``[*, R, L, C]``.
            mask: optional additive mask, broadcastable to ``[*, R, H, L, L]``.
            bias: optional additive bias, broadcastable to ``[*, R, H, L, L]``.

        Returns:
            Tensor of shape ``[*, R, L, C]``.
        """
        lead_dims = x.shape[:-2]
        n_col, d_model = x.shape[-2:]
        # The attention module is sequence-first: [seq, batch, channel].
        seq_first = x.reshape(-1, n_col, d_model).transpose(0, 1)

        if bias is not None and mask is not None:
            additive = bias + mask
        else:
            additive = bias if bias is not None else mask

        attn_bias = None
        if additive is not None:
            # A 3-D float attn_mask is added to the logits and must have shape
            # [batch * heads, L, L], with the batch index varying slowest.
            attn_bias = (
                additive.expand(*lead_dims, self.mha.num_heads, n_col, n_col)
                .reshape(-1, n_col, n_col)
                .contiguous()
            )

        # The module returns (output, weights); only the output is needed.
        out, _ = self.mha(
            seq_first, seq_first, seq_first, attn_mask=attn_bias, need_weights=False
        )
        return out.transpose(0, 1).reshape(*lead_dims, n_col, d_model)

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        chunk_size: int = None,
    ) -> torch.Tensor:
        """Attend ``chunk_size`` rows at a time to bound the size of the bias tensor.

        Args:
            x: normalised pair representation, ``[*, L, L, C]``.
            mask: optional additive mask, ``[*, L, 1, 1, L]``.
            bias: optional additive bias, ``[*, 1, H, L, L]``.
            chunk_size: number of rows (dim -3 of ``x``) per attention call.

        Raises:
            ValueError: if ``chunk_size`` is missing or smaller than 1.
        """
        if chunk_size is None or chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")
        n_row = x.shape[-3]
        pieces = []
        for start in range(0, n_row, chunk_size):
            stop = start + chunk_size
            # Slice only tensors that really carry a row axis; a size-1 row
            # axis is broadcast and shared by all chunks.
            mask_rows = mask
            if mask is not None and mask.shape[-4] != 1:
                mask_rows = mask[..., start:stop, :, :, :]
            bias_rows = bias
            if bias is not None and bias.shape[-4] != 1:
                bias_rows = bias[..., start:stop, :, :, :]
            pieces.append(
                self._attend(x[..., start:stop, :, :], mask_rows, bias_rows)
            )
        return torch.cat(pieces, dim=-3)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Apply triangle attention.

        Args:
            x: pair representation, ``[*, L, L, C]``.
            attn_mask: optional validity mask over the pair positions,
                ``[*, L, L]``, given in the orientation of ``x``. Non-zero
                (or ``True``) marks positions that may be attended to.
            chunk_size: if given, rows are processed in chunks of this size.

        Returns:
            Tensor with the same shape as ``x``. The output projection bias of
            the attention module is already included.

        Raises:
            ValueError: if ``x`` is not square in its two pair axes, or its
                channel size differs from the attention embedding size.
        """
        if x.dim() < 3 or x.shape[-3] != x.shape[-2]:
            raise ValueError(
                "x must have shape [*, L, L, C] with equal pair axes, "
                f"got {tuple(x.shape)}"
            )
        if x.shape[-1] != self.mha.embed_dim:
            raise ValueError(
                f"channel size {x.shape[-1]} does not match the attention "
                f"embedding size {self.mha.embed_dim}"
            )

        if not self.starting:
            x = x.transpose(-2, -3)
            if attn_mask is not None:
                attn_mask = attn_mask.transpose(-1, -2)

        x = self.layer_norm(x)
        # [*, L, L, H] -> [*, H, L, L] -> [*, 1, H, L, L]; the singleton axis
        # broadcasts the same bias to every row.
        triangle_bias = self.linear(x).movedim(-1, -3).unsqueeze(-4)

        mask_bias = None
        if attn_mask is not None:
            # A large finite negative value (not -inf) keeps the softmax
            # finite even for a row whose keys are all masked.
            masked_value = -1e9 if x.dtype in (torch.float32, torch.float64) else -1e4
            mask_bias = (1.0 - attn_mask.to(x.dtype)) * masked_value
            # [*, L, L] -> [*, L, 1, 1, L]: masks keys, per row.
            mask_bias = mask_bias[..., :, None, None, :]

        if chunk_size is not None:
            x = self._chunk(x, mask_bias, triangle_bias, chunk_size)
        else:
            x = self._attend(x, mask_bias, triangle_bias)

        if not self.starting:
            x = x.transpose(-2, -3)
        return x

    def get_output_bias(self):
        """Return the bias of the attention output projection.

        ``forward`` already applies this bias, so it must not be added again.
        """
        return self.mha.out_proj.bias

In [None]:
class MyModelWithTriangleAttention(nn.Module):
    """Toy model: embedding -> triangle multiplication -> triangle attention -> linear head.

    All sizes are hard-coded in ``__init__``. ``forward`` takes a tensor of
    token ids and returns the output of the final linear layer.
    """

    def __init__(self):
        super().__init__()

        vocab_size = 1024
        self.vocab_size = vocab_size
        embedding_dim = 512
        hidden_size = 512
        num_heads = 16
        num_classes = 4

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.triangle_mutiplication_layer = TriangleMultiplicationOutgoing(
            d_pair=embedding_dim,
            d_hid=hidden_size,
        )

        self.triangle_attn_layer = TriangleAttention(
            d_in = embedding_dim,
            d_hid = hidden_size,
            num_heads = num_heads,
            starting = True
        )

        # The head consumes vocab_size + hidden_size features, i.e. the width of
        # the concatenation built at the end of forward().
        # NOTE(review): presumably this relies on x[1] having a last dimension
        # equal to vocab_size (1024) - TODO confirm.
        self.fc = nn.Linear(vocab_size+hidden_size, num_classes)

    def forward(self, x):
        x = self.embedding(x)
        #print(f"type(x) = ", type(x))

        # NOTE(review): this mask has the full shape of the embedded x,
        # including the channel dimension, while TriangleMultiplication.forward
        # appends another trailing dimension with mask.unsqueeze(-1). It looks
        # like the mask should be shaped x.shape[:-1] - confirm.
        mask = torch.ones(x.size()).to(torch.bool)  # Create an all-true mask

        # NOTE(review): a freshly built module is in training mode and no
        # block_size is passed, so TriangleMultiplication.forward returns the
        # tuple (x, g) rather than a tensor; that tuple is then handed to the
        # attention layer as its input - confirm the intended data flow.
        x = self.triangle_mutiplication_layer(x, mask)
        x = self.triangle_attn_layer(x, mask)

        #x = self.triangle_attn_layer(x)
        #print("len(x) = ", len(x))
        #assert x.shape[0] == x.shape[0], "Batch dimension must match"
        #assert x.shape[1] == 1 and x.shape[1] > 1, "Feature dimension must match"

        #print("x[0].shape = ", x[0].shape)
        #print("x[1].shape = ", x[1].shape)

        # From here on x is treated as a pair of tensors: x[0] is reshaped to a
        # hard-coded (1, 1024, 512) and concatenated with x[1] on the last axis.
        # NOTE(review): the fixed shape only fits one specific input size.
        x = list(x)
        x[0] = x[0].reshape(1, 1024, 512)

        x = tuple(x)

        x = torch.cat((x[0], x[1]), dim=-1)

        x = self.fc(x)

        return x

model = MyModelWithTriangleAttention()

optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

num_epochs = 20

# One token-id tensor per epoch, kept so the inputs can be inspected afterwards.
inputs = []

for epoch in range(num_epochs):

    # Training loop...

    optimizer.zero_grad()
    # torch.randint excludes `high`, so high=1 could only ever produce token
    # id 0. Sample from the whole embedding vocabulary instead. randint already
    # returns int64, so no cast is needed.
    input = torch.randint(low=0, high=model.vocab_size, size=(1024, 1))
    inputs.append(input)

    logits = model(input)
    # NOTE(review): with high=1 every label is 0. The intended label range and
    # shape depend on how the model output is laid out - TODO confirm.
    labels = torch.randint(low=0, high=1, size=(1, 4))
    loss = loss_fn(logits, labels)
    loss.backward()
    optimizer.step()

    print(f"epoch = {epoch} , loss = {loss.item()}")

In [None]:
# Modified from code generated from AI chatbot

import matplotlib.pyplot as plt
import numpy as np

# Stack the recorded inputs into a 2D image: one row per recorded step, one
# column per sequence position. imshow rejects the (steps, L, 1) array that the
# raw list would produce. np.asarray accepts tensors and arrays alike, so this
# cell can be re-run.
inputs = [np.asarray(batch) for batch in inputs]
heatmap = np.stack(inputs).reshape(len(inputs), -1)

# Get sequence length
num_rows, seq_len = heatmap.shape

# Initialize figure
fig, ax = plt.subplots()

# Plot heatmap (aspect='auto' because the image is much wider than it is tall)
im = ax.imshow(heatmap, cmap='RdBu',
   interpolation='nearest', aspect='auto')

# Label axes
# NOTE(review): the plotted data are the recorded inputs (rows) by sequence
# position (columns), not attention weights - confirm these labels.
ax.set_xlabel('Query')
ax.set_ylabel('Key')
# Ticks follow the data extent; one tick per column would be unreadable.
ax.set_xticks(np.arange(0, seq_len, max(1, seq_len // 8)))
ax.set_yticks(np.arange(num_rows))

# Show triangle mask: shade every cell whose column index exceeds its row
# index. This is a single overlay image. axvspan's ymin/ymax are fractions of
# the axes height (0..1), so row indices are not valid values for them, and one
# patch per cell pair would create hundreds of thousands of artists.
row_idx, col_idx = np.indices(heatmap.shape)
upper_triangle = np.ma.masked_where(col_idx <= row_idx, np.ones(heatmap.shape))
ax.imshow(upper_triangle, cmap='Greys', vmin=0, vmax=2, alpha=0.2,
   interpolation='nearest', aspect='auto')

plt.show()