In [329]:
from transformers import EsmForTokenClassification, EsmForMaskedLM 

In [363]:
import torch
import torch.nn as nn
from typing import Tuple
import torch.nn.functional as F
from typing import Union
from functools import lru_cache
import random
import numpy as np
import time

In [364]:
# Initialize parameters : 

batch_size = 1
seq_length = 1024
num_attention_heads = 20
hidden_size = 1280
attention_head_size = int(hidden_size / num_attention_heads)
all_head_size = num_attention_heads * attention_head_size

position_embedding_type = "rotary"

In [365]:
# Initialize layers : 

query_layer = nn.Linear(hidden_size, all_head_size)
key_layer = nn.Linear(hidden_size, all_head_size)
value_layer = nn.Linear(hidden_size, all_head_size)

dropout_layer = nn.Dropout(0.5)

In [366]:
query_layer.in_features, query_layer.out_features

(1280, 1280)

In [367]:
# Implement RoPE : 

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]
    return (x * cos) + (rotate_half(x) * sin)

class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        inv_freq = inv_freq
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )


In [368]:
# Transpose for attentions :

def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

In [369]:
hidden_states = torch.randn((1, seq_length, hidden_size))

In [370]:
hidden_states.shape

torch.Size([1, 1024, 1280])

In [386]:
# Query :

mixed_query_layer = query_layer(hidden_states)
print(mixed_query_layer.shape)
query = transpose_for_scores(mixed_query_layer)
print(query.shape)


torch.Size([1, 1024, 1280])
torch.Size([1, 20, 1024, 64])


In [387]:
query = query * attention_head_size**-0.5

In [373]:
# Key, Value :

key = key_layer(hidden_states)
key = transpose_for_scores(key)

value = value_layer(hidden_states)
value = transpose_for_scores(value)

In [388]:
# Positional embedding : 

if position_embedding_type == "rotary":
    rotary_embeddings = RotaryEmbedding(attention_head_size)
    query, key = rotary_embeddings(query, key)

In [389]:
query.shape

torch.Size([1, 20, 1024, 64])

In [391]:
# Matrix multiplication for attention scores :

start_time = time.time()
attention_scores_full = torch.matmul(query, key.transpose(-1, -2))
print(f'Elapsed time for full matmul : {time.time()-start_time}')
attention_scores_full.shape

Elapsed time for full matmul : 0.07536983489990234


torch.Size([1, 20, 1024, 1024])

In [377]:
# Attention mask : 

attention_mask_full = None
if attention_mask_full is not None:
    # Apply the attention mask is (precomputed for all layers in EsmModel forward() function)
    attention_mask_full = attention_mask_full
    attention_scores_full = attention_scores_full + attention_mask_full

In [378]:
# Transformations before Att*V : 

attention_probs_full = nn.functional.softmax(attention_scores_full, dim=-1)
attention_probs_full = dropout_layer(attention_probs_full)

In [393]:
# Context layer :

context = torch.matmul(attention_probs_full, value)
context_layer = context.permute(0, 2, 1, 3).contiguous()

new_context_layer_shape = context_layer.size()[:-2] + (all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
print(context_layer.shape)

torch.Size([1, 1024, 1280])


In [380]:
# Output
output_attentions = True
outputs = (context_layer, attention_probs_full) if output_attentions else (context_layer,)

In [381]:
outputs[0].shape, outputs[1].shape

(torch.Size([1, 1024, 1280]), torch.Size([1, 20, 1024, 513]))

In [None]:
##########

Longformer Attention (sliding window) : 

In [131]:
# Layers transformations methods : 

def _skew(x, direction, padding_value):
    '''Convert diagonals into columns (or columns into diagonals depending on `direction`'''
    x_padded = F.pad(x, direction, value=padding_value)
    x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
    return x_padded

def _skew2(x, padding_value):
    '''shift every row 1 step to right converting columns into diagonals'''
    # X = B x C x M x L
    B, C, M, L = x.size()
    x = F.pad(x, (0, M + 1), value=padding_value)  # B x C x M x (L+M+1)
    x = x.view(B, C, -1)  # B x C x ML+MM+M
    x = x[:, :, :-M]  # B x C x ML+MM
    x = x.view(B, C, M, M + L)  # B x C, M x L+M
    x = x[:, :, :, :-1]
    return x

def _chunk(x, w):
    dim = int(x.size(1) // (w * 2))
    x = x.view(x.size(0), dim, int(w * 2), x.size(2))

    chunk_size = list(x.size())
    chunk_size[1] = chunk_size[1] * 2 - 1

    chunk_stride = list(x.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    
    return x.as_strided(size=chunk_size, stride=chunk_stride)

@lru_cache()
def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str):
    if isinstance(d, int):
        affected_seq_len = w * d
        mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
        mask = mask[None, :, None, :]
    else:
        affected_seq_len = w * d.max()
        head_masks = []
        d_list = d.cpu().numpy().tolist()
        for d in d_list:
            one_head_mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
            head_masks.append(one_head_mask)
        mask = torch.stack(head_masks, dim=-2)
        mask = mask[None, :, :, :]

    ending_mask = None if autoregressive else mask.flip(dims=(1, 3)).bool().to(device)
    return affected_seq_len, mask.bool().to(device), ending_mask

def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int):
    diagonals_list = []
    for j in range(-d * w, d, d):
        diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8)
        diagonal_mask[:-j] = 1
        diagonals_list.append(diagonal_mask)
    return torch.stack(diagonals_list, dim=-1)

def mask_invalid_locations(input_tensor: torch.Tensor, w: int, d: Union[torch.Tensor, int], autoregressive: bool) -> torch.Tensor:
    affected_seq_len, beginning_mask, ending_mask = _get_invalid_locations_mask(w, d, autoregressive, input_tensor.device)
    seq_len = input_tensor.size(1)
    beginning_input = input_tensor[:, :affected_seq_len, :, :w+1]
    beginning_mask = beginning_mask[:, :seq_len].expand(beginning_input.size())
    beginning_input.masked_fill_(beginning_mask, -float('inf'))
    if not autoregressive:
        ending_input = input_tensor[:, -affected_seq_len:, :, -(w+1):]
        ending_mask = ending_mask[:, -seq_len:].expand(ending_input.size())
        ending_input.masked_fill_(ending_mask, -float('inf'))

def sliding_chunks_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):

    bsz, num_heads,seqlen, head_dim = q.size()

    assert seqlen % (w * 2) == 0
    assert q.size() == k.size()

    chunks_count = seqlen // w - 1

    # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
    q = q.reshape(bsz * num_heads, seqlen, head_dim)
    k = k.reshape(bsz * num_heads, seqlen, head_dim)

    chunk_q = _chunk(q, w)
    chunk_k = _chunk(k, w)
    chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k))  # multiply

    # convert diagonals into columns
    diagonal_chunk_attn = _skew(chunk_attn, direction=(0, 0, 0, 1), padding_value=padding_value)
    diagonal_attn = diagonal_chunk_attn.new_empty((bsz * num_heads, chunks_count + 1, w, w * 2 + 1))
    
    diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, :w + 1]
    diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, :w + 1]
    # - copying the lower triangle
    diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, - (w + 1):-1, w + 1:]
    diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, :w - 1, 1 - w:]

    # separate bsz and num_heads dimensions again
    diagonal_attn = diagonal_attn.view(bsz, num_heads, seqlen, 2 * w + 1)
    mask_invalid_locations(diagonal_attn, w, 1, True)

    return diagonal_attn

def sliding_chunks_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
    
    bsz, num_heads,seqlen, head_dim = v.size()
    assert seqlen % (w * 2) == 0
    assert prob.size()[:3] == v.size()[:3]
    assert prob.size(3) == 2 * w + 1
    chunks_count = seqlen // w - 1
    # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size 2w
    chunk_prob = prob.reshape(bsz * num_heads, seqlen // w, w, 2 * w + 1)

    # group bsz and num_heads dimensions into one
    v = v.reshape(bsz * num_heads, seqlen, head_dim)

    # pad seqlen with w at the beginning of the sequence and another w at the end
    padded_v = F.pad(v, (0, 0, w, w), value=-1)

    # chunk padded_v into chunks of size 3w and an overlap of size w
    chunk_v_size = (bsz * num_heads, chunks_count + 1, 3 * w, head_dim)
    chunk_v_stride = padded_v.stride()
    chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2]
    chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride)

    skewed_prob = _skew2(chunk_prob, padding_value=0)
    context = torch.einsum('bcwd,bcdh->bcwh', (skewed_prob, chunk_v))
    return context.view(bsz, num_heads, seqlen, head_dim)


In [132]:
# Layers and reshaping for longformer 

query_states = query_layer(hidden_states)
key_states = key_layer(hidden_states)
value_states = value_layer(hidden_states)

query_states = query_states.view(batch_size, seq_length, num_attention_heads, attention_head_size).transpose(1, 2)
key_states = key_states.view(batch_size, seq_length, num_attention_heads, attention_head_size).transpose(1, 2)
value_states = value_states.view(batch_size, seq_length, num_attention_heads, attention_head_size).transpose(1, 2)


In [133]:
# RoPE

if position_embedding_type == "rotary":
    rotary_embeddings = RotaryEmbedding(attention_head_size)
    query_states, key_states = rotary_embeddings(query_states, key_states)


In [134]:
# Attention scores

window_size = 256

attention_scores = sliding_chunks_matmul_qk(query_states, key_states, window_size, padding_value=0)*(1/(seq_length**0.5))
attention_scores.shape # torch.Size([batch_size, num_attention_heads, seq_length, w*2+1])

torch.Size([1, 20, 1024, 513])

In [135]:
# Attention mask : 

attention_mask = None

if attention_mask is not None:
    attention_scores = attention_scores + attention_mask

In [136]:
# Attention normalization and context layer

attention_probs = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32)
attention_probs = dropout_layer(attention_probs)

context = sliding_chunks_matmul_pv(attention_probs, value_states, window_size)
context_layer = context.transpose(1, 2).contiguous()
output = context_layer.view(context_layer.size(0),context_layer.size(1),context_layer.size(2)*context_layer.size(3))

In [137]:
output.shape # torch.Size([batch_size, seq_length, hidden_size])

torch.Size([1, 1024, 1280])

Sparse Attention (ours)

In [138]:
# Input : hidden states + sequence information (proteins + interactions map)

# Random proteins information : 
def generate_list(n, Amin, Amax):
    ni_list = []
    remaining = n
    while remaining > Amin:
        ni = random.randint(Amin, min(Amax, remaining))
        ni_list.append(ni)
        remaining -= ni
    ni_list.append(remaining)
    return ni_list

# Random proteins interactions : 
def generate_couples(n_couples, n_len):
    couples = []
    while len(couples) < n_couples :
        i = random.randint(0, n_len - 1)
        j = random.randint(0, n_len - 1)
        if abs(i - j) > 1 and (i, j) not in couples and (j, i) not in couples:
            couples.append((i, j))
    return couples

proteins_list = generate_list(seq_length, Amin = 50, Amax = 200) 
proteins_cs =  [0]+list(np.cumsum(np.array(proteins_list)))
n = len(proteins_list)
proteins_interactions = generate_couples(n_couples = 10, n_len = n) # Max number of couples : n * (n - 3) / 2
proteins_list, proteins_cs, proteins_interactions

([116, 184, 134, 57, 169, 128, 191, 45],
 [0, 116, 300, 434, 491, 660, 788, 979, 1024],
 [(4, 2),
  (7, 1),
  (0, 3),
  (5, 7),
  (3, 7),
  (2, 0),
  (4, 7),
  (5, 0),
  (0, 4),
  (3, 6)])

In [325]:
def sparse_attention_matrix(query, key, proteins_interactions, proteins_cs):
    batch_size, num_heads, seq_len, all_head_size = query.shape
    start_time = time.time()
    attentions = []
    sorted_proteins_interactions_x = sorted(proteins_interactions, key=lambda x: x[0])
    print(sorted_proteins_interactions_x)
    max2 = 0
    max3 = 0
    for i, j in sorted_proteins_interactions :
        # Blocks
        query_block = query[:, :, proteins_cs[i]:proteins_cs[i+1], :]
        key_block = key[:, :, proteins_cs[j]:proteins_cs[j+1], :]
        
        # Compute attention matrix for the 2 blocks
        attention_block = torch.matmul(query_block, key_block.transpose(-1, -2)).squeeze()
        attentions.append(attention_block)
        
        if attention_block.shape[-2] > max2 :
            max2 = attention_block.shape[-2]
        if attention_block.shape[-1] > max3 :
            max3 = attention_block.shape[-1]

    # Padding for bsr storage :
    padded_attentions = []
    for block in attentions :
        s = block.shape
        padding = (0, max3 - s[-1], 0, max2 - s[-2])
        padded_attention_block = F.pad(block, padding, "constant", 0)
        padded_attentions.append(padded_attention_block)
        
    print(f'Time for attention matrix computation : {time.time()-start_time}')
    
    # Create block sparse matrix with torch.sparse_bsr_tensor
    # columns and rows : 
    columns = [x[1] for x in sorted_proteins_interactions_x]
    # values :
    concatenated_attentions = torch.stack(padded_attentions, dim=1)
    print(concatenated_attentions.shape)
    sparse_matrix = torch.sparse_bsr_tensor(indices_tensor, values_tensor, (seq_length*all_head_size, seq_length*all_head_size))

    return sparse_matrix

In [326]:
attentions_scores_sparse = sparse_attention_matrix(query, key, proteins_interactions, proteins_cs)

[(0, 3), (0, 4), (2, 0), (3, 7), (3, 6), (4, 2), (4, 7), (5, 7), (5, 0), (7, 1)]


NameError: name 'sorted_proteins_interactions' is not defined

In [327]:
crow_indices = [0, 2, 4, 8]
col_indices = [0, 1, 2, 0, 4, 5 , 6, 7]
values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[1, 2], [3, 4]], [[5, 6], [7, 8]], [[1, 2], [3, 4]], [[5, 6], [7, 8]],[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
bsr = torch.sparse_bsr_tensor(torch.tensor(crow_indices, dtype=torch.int64),torch.tensor(col_indices, dtype=torch.int64),torch.tensor(values), dtype=torch.int64)
bsr

tensor(crow_indices=tensor([0, 2, 4, 8]),
       col_indices=tensor([0, 1, 2, 0, 4, 5, 6, 7]),
       values=tensor([[[1, 2],
                       [3, 4]],

                      [[5, 6],
                       [7, 8]],

                      [[1, 2],
                       [3, 4]],

                      [[5, 6],
                       [7, 8]],

                      [[1, 2],
                       [3, 4]],

                      [[5, 6],
                       [7, 8]],

                      [[1, 2],
                       [3, 4]],

                      [[5, 6],
                       [7, 8]]]), size=(6, 16), nnz=8, layout=torch.sparse_bsr)

In [328]:
bsr.to_dense()

tensor([[1, 2, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 4, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [5, 6, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [7, 8, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 5, 6, 1, 2, 5, 6],
        [0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 7, 8, 3, 4, 7, 8]])

In [285]:
torch.tensor(values[0]).shape

torch.Size([2, 2])

In [248]:
import torch

# Supposons que vous ayez des valeurs avec des indices inégaux
indices = torch.tensor([[0, 1, 2, 2], [0, 1, 0, 1]])  # Les indices de chaque élément non-nul
values = torch.tensor([1, 2, 3, 4])  # Les valeurs correspondantes
size = (3, 2)  # La taille totale de la matrice

sparse_tensor = torch.sparse_coo_tensor(indices, values, size)
print(sparse_tensor.to_dense())

tensor([[1, 0],
        [0, 2],
        [3, 4]])
