In [3]:
from transformers import EsmForTokenClassification, EsmForMaskedLM 

In [4]:
import torch
import torch.nn as nn
from typing import Tuple
import torch.nn.functional as F
from typing import Union
from functools import lru_cache
import random
import numpy as np
import time

In [5]:
# Initialize parameters : 

batch_size = 1
seq_length = 8192
num_attention_heads = 20
hidden_size = 1280
attention_head_size = int(hidden_size / num_attention_heads)
all_head_size = num_attention_heads * attention_head_size

position_embedding_type = "rotary"

In [6]:
# Initialize layers : 

query_layer = nn.Linear(hidden_size, all_head_size)
key_layer = nn.Linear(hidden_size, all_head_size)
value_layer = nn.Linear(hidden_size, all_head_size)

dropout_layer = nn.Dropout(0.5)

In [7]:
query_layer.in_features, query_layer.out_features

(1280, 1280)

In [8]:
# Implement RoPE : 

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]
    return (x * cos) + (rotate_half(x) * sin)

class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        inv_freq = inv_freq
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )


In [9]:
# Transpose for attentions :

def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

In [10]:
hidden_states = torch.randn((1, seq_length, hidden_size))

In [11]:
hidden_states.shape

torch.Size([1, 8192, 1280])

In [12]:
# Query :

mixed_query_layer = query_layer(hidden_states)
print(mixed_query_layer.shape)
query = transpose_for_scores(mixed_query_layer)
print(query.shape)


torch.Size([1, 8192, 1280])
torch.Size([1, 20, 8192, 64])


In [13]:
query = query * attention_head_size**-0.5

In [14]:
# Key, Value :

key = key_layer(hidden_states)
key = transpose_for_scores(key)

value = value_layer(hidden_states)
value = transpose_for_scores(value)

In [15]:
# Positional embedding : 

if position_embedding_type == "rotary":
    rotary_embeddings = RotaryEmbedding(attention_head_size)
    query, key = rotary_embeddings(query, key)

In [14]:
query.shape

torch.Size([1, 20, 8192, 64])

In [15]:
# Matrix multiplication for attention scores :

start_time = time.time()
attention_scores_full = torch.matmul(query, key.transpose(-1, -2))
print(f'Elapsed time for full matmul : {time.time()-start_time}')
attention_scores_full.shape

Elapsed time for full matmul : 5.058045864105225


torch.Size([1, 20, 8192, 8192])

In [16]:
# Attention mask : 

attention_mask_full = None
if attention_mask_full is not None:
    # Apply the attention mask is (precomputed for all layers in EsmModel forward() function)
    attention_mask_full = attention_mask_full
    attention_scores_full = attention_scores_full + attention_mask_full

In [17]:
# Transformations before Att*V : 

attention_probs_full = nn.functional.softmax(attention_scores_full, dim=-1)
attention_probs_full = dropout_layer(attention_probs_full)

In [18]:
# Context layer :

context = torch.matmul(attention_probs_full, value)
context_layer = context.permute(0, 2, 1, 3).contiguous()

new_context_layer_shape = context_layer.size()[:-2] + (all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
print(context_layer.shape)

torch.Size([1, 8192, 1280])


In [19]:
# Output
output_attentions = True
outputs = (context_layer, attention_probs_full) if output_attentions else (context_layer,)

In [20]:
outputs[0].shape, outputs[1].shape

(torch.Size([1, 8192, 1280]), torch.Size([1, 20, 8192, 8192]))

In [21]:
##########

Longformer Attention (sliding window) : 

In [22]:
# Layers transformations methods : 

def _skew(x, direction, padding_value):
    '''Convert diagonals into columns (or columns into diagonals depending on `direction`'''
    x_padded = F.pad(x, direction, value=padding_value)
    x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
    return x_padded

def _skew2(x, padding_value):
    '''shift every row 1 step to right converting columns into diagonals'''
    # X = B x C x M x L
    B, C, M, L = x.size()
    x = F.pad(x, (0, M + 1), value=padding_value)  # B x C x M x (L+M+1)
    x = x.view(B, C, -1)  # B x C x ML+MM+M
    x = x[:, :, :-M]  # B x C x ML+MM
    x = x.view(B, C, M, M + L)  # B x C, M x L+M
    x = x[:, :, :, :-1]
    return x

def _chunk(x, w):
    dim = int(x.size(1) // (w * 2))
    x = x.view(x.size(0), dim, int(w * 2), x.size(2))

    chunk_size = list(x.size())
    chunk_size[1] = chunk_size[1] * 2 - 1

    chunk_stride = list(x.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    
    return x.as_strided(size=chunk_size, stride=chunk_stride)

@lru_cache()
def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str):
    if isinstance(d, int):
        affected_seq_len = w * d
        mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
        mask = mask[None, :, None, :]
    else:
        affected_seq_len = w * d.max()
        head_masks = []
        d_list = d.cpu().numpy().tolist()
        for d in d_list:
            one_head_mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
            head_masks.append(one_head_mask)
        mask = torch.stack(head_masks, dim=-2)
        mask = mask[None, :, :, :]

    ending_mask = None if autoregressive else mask.flip(dims=(1, 3)).bool().to(device)
    return affected_seq_len, mask.bool().to(device), ending_mask

def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int):
    diagonals_list = []
    for j in range(-d * w, d, d):
        diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8)
        diagonal_mask[:-j] = 1
        diagonals_list.append(diagonal_mask)
    return torch.stack(diagonals_list, dim=-1)

def mask_invalid_locations(input_tensor: torch.Tensor, w: int, d: Union[torch.Tensor, int], autoregressive: bool) -> torch.Tensor:
    affected_seq_len, beginning_mask, ending_mask = _get_invalid_locations_mask(w, d, autoregressive, input_tensor.device)
    seq_len = input_tensor.size(1)
    beginning_input = input_tensor[:, :affected_seq_len, :, :w+1]
    beginning_mask = beginning_mask[:, :seq_len].expand(beginning_input.size())
    beginning_input.masked_fill_(beginning_mask, -float('inf'))
    if not autoregressive:
        ending_input = input_tensor[:, -affected_seq_len:, :, -(w+1):]
        ending_mask = ending_mask[:, -seq_len:].expand(ending_input.size())
        ending_input.masked_fill_(ending_mask, -float('inf'))

def sliding_chunks_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):

    bsz, num_heads,seqlen, head_dim = q.size()

    assert seqlen % (w * 2) == 0
    assert q.size() == k.size()

    chunks_count = seqlen // w - 1

    # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
    q = q.reshape(bsz * num_heads, seqlen, head_dim)
    k = k.reshape(bsz * num_heads, seqlen, head_dim)

    chunk_q = _chunk(q, w)
    chunk_k = _chunk(k, w)
    chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k))  # multiply

    # convert diagonals into columns
    diagonal_chunk_attn = _skew(chunk_attn, direction=(0, 0, 0, 1), padding_value=padding_value)
    diagonal_attn = diagonal_chunk_attn.new_empty((bsz * num_heads, chunks_count + 1, w, w * 2 + 1))
    
    diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, :w + 1]
    diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, :w + 1]
    # - copying the lower triangle
    diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, - (w + 1):-1, w + 1:]
    diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, :w - 1, 1 - w:]

    # separate bsz and num_heads dimensions again
    diagonal_attn = diagonal_attn.view(bsz, num_heads, seqlen, 2 * w + 1)
    mask_invalid_locations(diagonal_attn, w, 1, True)

    return diagonal_attn

def sliding_chunks_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
    
    bsz, num_heads,seqlen, head_dim = v.size()
    assert seqlen % (w * 2) == 0
    assert prob.size()[:3] == v.size()[:3]
    assert prob.size(3) == 2 * w + 1
    chunks_count = seqlen // w - 1
    # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size 2w
    chunk_prob = prob.reshape(bsz * num_heads, seqlen // w, w, 2 * w + 1)

    # group bsz and num_heads dimensions into one
    v = v.reshape(bsz * num_heads, seqlen, head_dim)

    # pad seqlen with w at the beginning of the sequence and another w at the end
    padded_v = F.pad(v, (0, 0, w, w), value=-1)

    # chunk padded_v into chunks of size 3w and an overlap of size w
    chunk_v_size = (bsz * num_heads, chunks_count + 1, 3 * w, head_dim)
    chunk_v_stride = padded_v.stride()
    chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2]
    chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride)

    skewed_prob = _skew2(chunk_prob, padding_value=0)
    context = torch.einsum('bcwd,bcdh->bcwh', (skewed_prob, chunk_v))
    return context.view(bsz, num_heads, seqlen, head_dim)


In [23]:
# Layers and reshaping for longformer 

query_states = query_layer(hidden_states)
key_states = key_layer(hidden_states)
value_states = value_layer(hidden_states)

query_states = query_states.view(batch_size, seq_length, num_attention_heads, attention_head_size).transpose(1, 2)
key_states = key_states.view(batch_size, seq_length, num_attention_heads, attention_head_size).transpose(1, 2)
value_states = value_states.view(batch_size, seq_length, num_attention_heads, attention_head_size).transpose(1, 2)


In [24]:
# RoPE

if position_embedding_type == "rotary":
    rotary_embeddings = RotaryEmbedding(attention_head_size)
    query_states, key_states = rotary_embeddings(query_states, key_states)


In [25]:
# Attention scores

window_size = 256

attention_scores = sliding_chunks_matmul_qk(query_states, key_states, window_size, padding_value=0)*(1/(seq_length**0.5))
attention_scores.shape # torch.Size([batch_size, num_attention_heads, seq_length, w*2+1])

torch.Size([1, 20, 8192, 513])

In [26]:
# Attention mask : 

attention_mask = None

if attention_mask is not None:
    attention_scores = attention_scores + attention_mask

In [27]:
# Attention normalization and context layer

attention_probs = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32)
attention_probs = dropout_layer(attention_probs)

context = sliding_chunks_matmul_pv(attention_probs, value_states, window_size)
context_layer = context.transpose(1, 2).contiguous()
output = context_layer.view(context_layer.size(0),context_layer.size(1),context_layer.size(2)*context_layer.size(3))

In [28]:
output.shape # torch.Size([batch_size, seq_length, hidden_size])

torch.Size([1, 8192, 1280])

Sparse Attention (ours)

In [18]:
# Input : hidden states + sequence information (proteins + interactions map)

# Random proteins information : 
def generate_list(n, Amin, Amax):
    ni_list = []
    remaining = n
    while remaining > Amin:
        ni = random.randint(Amin, min(Amax, remaining))
        ni_list.append(ni)
        remaining -= ni
    ni_list.append(remaining)
    return ni_list

# Random proteins interactions : 
def generate_couples(n_couples, n_len):
    couples = []
    while len(couples) < n_couples :
        i = random.randint(0, n_len - 1)
        j = random.randint(0, n_len - 1)
        if abs(i - j) > 1 and (i, j) not in couples and (j, i) not in couples:
            couples.append((i, j))
    return couples

proteins_list = generate_list(seq_length, Amin = 50, Amax = 200) 
proteins_cs =  [0]+list(np.cumsum(np.array(proteins_list)))
n = len(proteins_list)
print(n)
proteins_interactions = generate_couples(n_couples = 5 * n, n_len = n) # Max number of couples : n * (n - 3) / 2
print(len(proteins_interactions)), print(int(n * (n - 3) / 2))

66
330
2079


(None, None)

In [25]:
def rows_to_crows(rows, n): 
    rows = np.array(rows)
    counts = np.bincount(rows, minlength=n+1)
    
    # Si le tableau résultant est plus court que n+1, on ajoute des zéros à la fin
    if counts.size < n+1:
        counts = np.pad(counts, (0, n+1-counts.size), constant_values=0)
    counts_cs = np.cumsum(counts)
    return counts_cs.tolist()

In [26]:
def sparse_attention_matrix(query, key, proteins_interactions, proteins_cs, proteins_list):
    tt = time.time()
    batch_size, num_heads, seq_len, all_head_size = query.shape
    start_time = time.time()
    attentions = []
    block_positions = []
    block_sizes = []
    sorted_proteins_interactions = sorted(proteins_interactions, key=lambda x: x[0])
    max2 = 0
    max3 = 0
    tfor = time.time()
    spm = []
    for i, j in sorted_proteins_interactions :
        
        # Blocks
        query_block = query[:, :, proteins_cs[i]:proteins_cs[i+1], :]
        key_block = key[:, :, proteins_cs[j]:proteins_cs[j+1], :]
        
        # Compute attention matrix for the 2 blocks
        attention_block = torch.matmul(query_block, key_block.transpose(-1, -2)).squeeze()
        attentions.append(attention_block) 

        max2 = max(max2, attention_block.shape[-2])
        max3 = max(max3, attention_block.shape[-1])

    at = time.time()-start_time
    
    # Padding for bsr storage :
    padded_attentions = []
    for attention_block in attentions:
        s = attention_block.shape
        padded_attention_block = F.pad(attention_block, (0, max3 - s[-1], 0, max2 - s[-2]), "constant", 0)
        padded_attentions.append(padded_attention_block)
        
    # Create block sparse matrix with torch.sparse_bsr_tensor
    
    # columns and rows : 
    t2 = time.time()
    col_indices = [x[1] for x in sorted_proteins_interactions]
    rows = [x[0] for x in sorted_proteins_interactions]
    crow_indices = [0] + rows_to_crows(rows, len(proteins_cs) - 1)[:-1]
    crow_tensor = torch.stack([torch.tensor(crow_indices)] * 20)
    col_tensor = torch.stack([torch.tensor(col_indices)] * 20)

    # values :    
    concatenated_attentions = torch.stack(padded_attentions, dim=1)

    #sparse_matrix = create_sparse_coo_with_variable_blocks(attentions, block_positions, seq_len, seq_len)
    sparse_matrix = torch.sparse_bsr_tensor(crow_tensor, col_tensor, concatenated_attentions, size = [num_heads, len(proteins_list)*max2, len(proteins_list)*max3])
    mt = time.time() - t2
    ft = time.time()-tt
    print(at, mt, ft)
    return sparse_matrix

In [27]:
attentions_scores_sparse = sparse_attention_matrix(query, key, proteins_interactions, proteins_cs, proteins_list)
#print(attentions_scores_sparse.size())
#print(attentions_scores_sparse)

0.4056520462036133 0.8604848384857178 1.9431049823760986


In [None]:
dense_scores = attentions_scores_sparse.to_dense()
dense_scores.shape

In [None]:
total_elements = torch.prod(torch.tensor(attentions_scores_sparse.shape)).item()  # Nombre total d'éléments dans le tensor dense
non_zero_count = attentions_scores_sparse._nnz()*max(proteins_list)*max(proteins_list)*20  # Nombre d'éléments non nuls stockés dans le sparse tensor
zero_count = total_elements - non_zero_count  # Nombre d'éléments nuls
zero_percentage = (zero_count / total_elements) * 100  # Pourcentage de zéros

print(f"Total elements: {total_elements}")
print(f"Non-zero elements: {non_zero_count}")
print(f"Zero count: {zero_count}")
print(f"Percentage of zero elements: {zero_percentage:.2f}%")

In [217]:
def sparse_attention_matrix_subblock(query, key, proteins_interactions, proteins_cs, proteins_list, block_size):
    batch_size, num_heads, seq_len, all_head_size = query.shape
    start_time = time.time()
    sub_attentions = []
    block_positions = []
    block_sizes = []
    sorted_proteins_interactions = sorted(proteins_interactions, key=lambda x: x[0])
    max2 = 0
    max3 = 0

    num_blocks = [0]+[int(length//block_size) for length in proteins_list]
    num_blocks_cs = np.cumsum(np.array(num_blocks)).tolist()

    # Iterating over pairs to compute block-wise attention
    for i, j in sorted_proteins_interactions:
        #print(i, j)
        query_block = query[:, :, proteins_cs[i]:proteins_cs[i+1], :]
        key_block = key[:, :, proteins_cs[j]:proteins_cs[j+1], :]
        attention_block = torch.matmul(query_block, key_block.transpose(-1, -2)).squeeze()

        # Cut the attention blocks into smaller block_size x block_size blocks
        num_rows, num_cols = attention_block.shape[-2], attention_block.shape[-1]
        for start_row in range(0, num_rows, block_size):
            for start_col in range(0, num_cols, block_size):
                if start_row + block_size <= num_rows and start_col + block_size <= num_cols:
                    sub_block = attention_block[:, start_row:start_row+block_size, start_col:start_col+block_size]
                    sub_attentions.append(sub_block)
                    #print(start_row, start_col)
                    block_positions.append((num_blocks_cs[i] + start_row // 50, num_blocks_cs[j] + start_col // 50))

                    
    assert len(sub_attentions) == len(block_positions)
    print(f'Time for attention matrix computation: {time.time() - start_time}')
    
    # Create the BSR matrix
    col_indices = [x[1] for x in block_positions]
    rows = [x[0] for x in block_positions]
    crow_indices = [0] + rows_to_crows(rows, len(proteins_cs) - 1)[:-1]
    
    crow_tensor = torch.stack([torch.tensor(crow_indices)] * 20)
    col_tensor = torch.stack([torch.tensor(col_indices)] * 20)
    
    value_tensor = torch.stack(sub_attentions, dim=1)
    sparse_matrix = torch.sparse_bsr_tensor(crow_tensor, col_tensor, value_tensor, size=[num_heads, sum(num_blocks)*block_size, sum(num_blocks)*block_size])

    print(f'Time for sparse matrix creation: {time.time() - start_time}')

    return sparse_matrix

In [218]:
sparse_matrix_subblock = sparse_attention_matrix_subblock(query, key, proteins_interactions, proteins_cs, proteins_list, block_size=50)
# sparse_matrix_subblock

Time for attention matrix computation: 0.1936037540435791
Time for sparse matrix creation: 0.22324180603027344


In [219]:
dense_scores_sub = sparse_matrix_subblock.to_dense()
dense_scores_sub.shape

torch.Size([20, 3300, 3300])

In [220]:
block_size = 50
total_elements = torch.prod(torch.tensor(sparse_matrix_subblock.shape)).item()  # Nombre total d'éléments dans le tensor dense
non_zero_count = sparse_matrix_subblock._nnz()*block_size*block_size*20  # Nombre d'éléments non nuls stockés dans le sparse tensor
zero_count = total_elements - non_zero_count  # Nombre d'éléments nuls
zero_percentage = (zero_count / total_elements) * 100  # Pourcentage de zéros

print(f"Total elements: {total_elements}")
print(f"Non-zero elements: {non_zero_count}")
print(f"Zero count: {zero_count}")
print(f"Percentage of zero elements: {zero_percentage:.2f}%")

Total elements: 217800000
Non-zero elements: 30950000
Zero count: 186850000
Percentage of zero elements: 85.79%


In [221]:
def create_sparse_coo_with_variable_blocks(attentions, block_positions, max_rows, max_cols):
    # Listes pour stocker les données du tensor COO
    indices_i = []
    indices_j = []
    values = []
    
    # Parcourir chaque bloc et ses indices positionnels
    for (block, (block_i, block_j)) in zip(attentions, block_positions):
        # Aplatir le bloc
        flat_block = block.flatten()
        
        # Obtenir les indices relatifs à l'intérieur du bloc
        num_rows, num_cols = block.shape
        for idx, val in enumerate(flat_block):
            # Calculer l'indice relatif dans le bloc
            row_idx = idx // num_cols
            col_idx = idx % num_cols
            
            # Calculer les indices globaux
            global_row_idx = block_i + row_idx
            global_col_idx = block_j + col_idx
            
            # Assurer que les indices sont dans les limites
            if global_row_idx < max_rows and global_col_idx < max_cols:
                indices_i.append(global_row_idx)
                indices_j.append(global_col_idx)
                values.append(val)
    
    indices = torch.LongTensor([indices_i, indices_j])
    values = torch.FloatTensor(values)
    size = (max_rows, max_cols)
    
    sparse_tensor = torch.sparse_coo_tensor(indices, values, size)
    
    return sparse_tensor

# Exemple d'utilisation
blocks = [torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])]
positions = [(0, 0), (2, 1)]
max_rows, max_cols = 10000, 100000  # Taille globale de la matrice

sparse_matrix = create_sparse_coo_with_variable_blocks(blocks, positions, max_rows, max_cols)
dense_matrix = sparse_matrix.to_dense()
