In [91]:
import torch
import torch.nn as nn
from typing import Tuple
import torch.nn.functional as F
from typing import Union
from functools import lru_cache
import random
import numpy as np
import time

In [295]:
# Initialize parameters : 

batch_size = 1
seq_length = 4096
num_attention_heads = 20
hidden_size = 1280
attention_head_size = int(hidden_size / num_attention_heads)
all_head_size = num_attention_heads * attention_head_size

position_embedding_type = "rotary"

In [296]:
# Initialize layers : 

query_layer = nn.Linear(hidden_size, all_head_size)
key_layer = nn.Linear(hidden_size, all_head_size)
value_layer = nn.Linear(hidden_size, all_head_size)

dropout_layer = nn.Dropout(0.5)

In [297]:
# Implement RoPE : 

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]
    return (x * cos) + (rotate_half(x) * sin)

class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        inv_freq = inv_freq
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )


In [298]:
# Transpose for attentions :

def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

In [299]:
hidden_states = torch.randn((1, seq_length, hidden_size))
hidden_states.shape

torch.Size([1, 4096, 1280])

In [300]:
# Query :

mixed_query_layer = query_layer(hidden_states)
print(mixed_query_layer.shape)
query = transpose_for_scores(mixed_query_layer)
query = query * attention_head_size**-0.5
print(query.shape)

torch.Size([1, 4096, 1280])
torch.Size([1, 20, 4096, 64])


In [301]:
# Key, Value :

key = key_layer(hidden_states)
key = transpose_for_scores(key)

value = value_layer(hidden_states)
value = transpose_for_scores(value)

In [302]:
# Positional embedding : 

if position_embedding_type == "rotary":
    rotary_embeddings = RotaryEmbedding(attention_head_size)
    query, key = rotary_embeddings(query, key)

In [303]:
# Matrix multiplication for attention scores :
#tmean = []
#for _ in range(0, 1) : 
start_time = time.time()
attention_scores_full = torch.matmul(query, key.transpose(-1, -2))
print(f'Elapsed time for full matmul : {time.time()-start_time}')
#tmean.append(time.time()-start_time)
attention_scores_full.shape

#print(sum(tmean)/len(tmean))

Elapsed time for full matmul : 0.8946590423583984


torch.Size([1, 20, 4096, 4096])

In [304]:
# Attention mask : 

attention_mask_full = None
if attention_mask_full is not None:
    # Apply the attention mask is (precomputed for all layers in EsmModel forward() function)
    attention_mask_full = attention_mask_full
    attention_scores_full = attention_scores_full + attention_mask_full

In [305]:
# Transformations before Att*V : 

attention_probs_full = nn.functional.softmax(attention_scores_full, dim=-1)
attention_probs_full = dropout_layer(attention_probs_full)

In [306]:
# Context layer :
t=time.time()
context = torch.matmul(attention_probs_full, value)
print(f'Elapsed time for context layer matmul : {time.time()-t}')
context_layer = context.permute(0, 2, 1, 3).contiguous()

new_context_layer_shape = context_layer.size()[:-2] + (all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
print(context_layer.shape)

Elapsed time for context layer matmul : 1.059563159942627
torch.Size([1, 4096, 1280])


In [307]:
# Output
output_attentions = True
outputs = (context_layer, attention_probs_full) if output_attentions else (context_layer,)

In [308]:
outputs[0].shape, outputs[1].shape

(torch.Size([1, 4096, 1280]), torch.Size([1, 20, 4096, 4096]))

Proteins blocks attentions : 

In [361]:
# Input : hidden states + sequence information (proteins + interactions map)

# Random proteins information : 
def generate_list(n, Amin, Amax):
    ni_list = []
    remaining = n
    while remaining > Amin:
        ni = random.randint(Amin, min(Amax, remaining))
        ni_list.append(ni)
        remaining -= ni
    ni_list.append(remaining)
    return ni_list

# Random proteins interactions : 
def generate_couples(n_couples, n_len):
    couples = []
    while len(couples) < n_couples :
        i = random.randint(0, n_len - 1)
        j = random.randint(0, n_len - 1)
        if abs(i - j) > 1 and (i, j) not in couples and (j, i) not in couples:
            couples.append((i, j))
    return couples

proteins_sizes = generate_list(seq_length, Amin = 65, Amax = 260) 
proteins_cs =  [0]+list(np.cumsum(np.array(proteins_sizes)))
print(proteins_cs)
n = len(proteins_sizes)
print(n)
proteins_interactions = generate_couples(n_couples = 5 * n, n_len = n) # Max number of couples : n * (n - 3) / 2
sorted_proteins_interactions = sorted(proteins_interactions, key=lambda x: x[0])
print(len(proteins_interactions)), print(int(n * (n - 3) / 2))

[0, 197, 445, 623, 825, 1075, 1243, 1403, 1509, 1605, 1729, 1917, 2147, 2241, 2321, 2566, 2636, 2841, 3049, 3158, 3246, 3392, 3514, 3638, 3723, 3821, 4077, 4096]
27
135
324


(None, None)

In [369]:
def reshape_tensor_padding(tensor, proteins_sizes, block_size):
    _, seq_len, hidden_dim = tensor.shape
    sub_blocks = []
    start_index = 0
    padding_storage = []

    for i, size in enumerate(proteins_sizes):
        num_full_blocks = size // block_size  # Nombre de sous-blocs entiers
        end_index = start_index + num_full_blocks * block_size  # Index de fin pour les sous-blocs entiers
        if num_full_blocks > 0:
            for block in range(num_full_blocks) : 
                sub_blocks.append(tensor[0][start_index + block * block_size : start_index + (block+1) * block_size].reshape(block_size, hidden_dim).unsqueeze(0))
        
        remainder = size % block_size
        padding_storage.append((i, num_full_blocks, remainder))
        
        if remainder > 0:
            padding = torch.zeros((block_size - remainder, hidden_dim), dtype=tensor.dtype, device=tensor.device)
            remainder_block = torch.cat([tensor[0][end_index:end_index + remainder], padding], dim=0)
            sub_blocks.append(remainder_block.unsqueeze(0))
        
        start_index += size
    if sub_blocks:
        result_tensor = torch.cat(sub_blocks, dim=0)
            
    return result_tensor.unsqueeze(0), padding_storage

In [370]:
def chunk_proteins_padding(sorted_proteins_interactions, proteins_lengths, block_size):
    num_blocks = [(length + block_size - 1) // block_size for length in proteins_lengths]  # (length + block_size - 1) // block_size arrondit à l'entier supérieur

    num_blocks_cs = np.cumsum([0] + num_blocks).tolist()
    index = 0
    chunked_blocks = []
    
    for h in num_blocks:
        for j in range(h):
            chunked_blocks.append(index)
            index += 1

    block_interactions = []
    
    for i, j in sorted_proteins_interactions:
        for k in range(num_blocks[i]):
            for h in range(num_blocks[j]):
                block_interactions.append((num_blocks_cs[i] + k, num_blocks_cs[j] + h))

    return block_interactions, chunked_blocks

In [371]:
def transpose_for_scores_padding(x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(new_x_shape)
        x_permuted = x.permute(0, 3, 2, 1, 4)
        return x_permuted

In [372]:
def sparse_attention_matrix_padding(query, key, proteins_interactions, proteins_cs, proteins_list, block_size):
    batch_size, num_heads, bloc_size_exp, num_blocks, all_head_size = query.shape
    assert bloc_size_exp == block_size
    start_time = time.time()
    attentions = []
    block_positions = []
    sorted_proteins_interactions = sorted(proteins_interactions, key=lambda x: x[0])
    
    for i, j in sorted_proteins_interactions :
        # Blocks
        query_block = query[:, :, :,i:i+1, :]
        key_block = key[:, :, :,j:j+1, :]
        
        # Compute attention matrix for the 2 blocks
        attention_block = torch.matmul(query_block, key_block.transpose(-1, -2)).squeeze()
        attentions.append(attention_block) 

    amc = time.time()-start_time

    # Create block sparse matrix with torch.sparse_bsr_tensor
    
    # columns and rows : 
    t2 = time.time()
    
    col_indices = [x[1] for x in sorted_proteins_interactions]
    rows = [x[0] for x in sorted_proteins_interactions]
    crow_indices = [0] + rows_to_crows(rows, len(proteins_cs) - 1)[:-1] # CHECK CROW
    crow_tensor = torch.stack([torch.tensor(crow_indices)] * 20)
    col_tensor = torch.stack([torch.tensor(col_indices)] * 20)

    # values :    
    concatenated_attentions = torch.stack(attentions, dim=1)

    #sparse_matrix = create_sparse_coo_with_variable_blocks(attentions, block_positions, seq_len, seq_len)
    sparse_matrix = torch.sparse_bsr_tensor(crow_tensor, col_tensor, concatenated_attentions, size = [num_heads, num_blocks*block_size, num_blocks*block_size])
    smc = time.time() - t2
    
    return sparse_matrix, amc, smc

In [403]:
block_size = 64

chunked_hidden_states_padding, padding_storage = reshape_tensor_padding(hidden_states, proteins_sizes, block_size)
print(hidden_states.shape)
print(padding_storage)
print(chunked_hidden_states_padding.shape)
chunked_mixed_query_layer_padding = query_layer(chunked_hidden_states_padding)
chunked_query_padding = transpose_for_scores_padding(chunked_mixed_query_layer_padding)
chunked_query_padding = chunked_query_padding * attention_head_size**-0.5
print(chunked_query_padding.shape)

chunked_key_padding = key_layer(chunked_hidden_states_padding)
chunked_key_padding = transpose_for_scores_padding(chunked_key_padding)

chunked_value_padding = value_layer(chunked_hidden_states_padding)
chunked_value_padding = transpose_for_scores_padding(chunked_value_padding)
chunked_value_padding = chunked_value_padding.reshape(1, num_attention_heads , -1, attention_head_size).squeeze()

print(chunked_value_padding.shape)
print(len(padding_storage))

torch.Size([1, 4096, 1280])
[(0, 3, 5), (1, 3, 56), (2, 2, 50), (3, 3, 10), (4, 3, 58), (5, 2, 40), (6, 2, 32), (7, 1, 42), (8, 1, 32), (9, 1, 60), (10, 2, 60), (11, 3, 38), (12, 1, 30), (13, 1, 16), (14, 3, 53), (15, 1, 6), (16, 3, 13), (17, 3, 16), (18, 1, 45), (19, 1, 24), (20, 2, 18), (21, 1, 58), (22, 1, 60), (23, 1, 21), (24, 1, 34), (25, 4, 0), (26, 0, 19)]
torch.Size([1, 76, 64, 1280])
torch.Size([1, 20, 64, 76, 64])
torch.Size([20, 4864, 64])
27


In [398]:
chunked_interactions_padding, chunked_blocks_padding = chunk_proteins_padding(sorted_proteins_interactions, proteins_sizes, block_size)
proteins_chunked_sizes_padding = [block_size for _ in chunked_blocks_padding]
proteins_chunked_cs_padding =  [0]+list(np.cumsum(np.array(proteins_chunked_sizes_padding)))

In [399]:
tt = time.time()
spmd, amc, smc = sparse_attention_matrix_padding(chunked_query_padding, chunked_key_padding, chunked_interactions_padding, proteins_chunked_cs_padding, proteins_chunked_sizes_padding, block_size)
print(f'Elapsed time : {time.time()-tt}')

Elapsed time : 1.1163229942321777


In [400]:
amc, smc

(1.0970592498779297, 0.005131959915161133)

In [401]:
from torch.sparse._triton_ops import bsr_softmax, bsr_dense_mm # Only supported by CUDA and Triton ? 

# chunked_sparse_probs = bsr_softmax(chunked_sparse_matrix, dim=-1)
chunked_sparse_probs_pd = sparse_bsr_dropout(spmd, 0.1, True)
print(chunked_sparse_probs_pd.shape)

torch.Size([20, 4864, 4864])


In [402]:
tv = time.time()
batch_context = []
for batch in range(chunked_value_padding.shape[0]):
    chunked_context = torch.sparse._triton_ops.bsr_dense_mm(chunked_sparse_probs_pd[batch], chunked_value_padding[batch])
    batch_context.append(chunked_context)
    
chunked_context = torch.stack(batch_context, dim=0).unsqueeze(0)
print(f'Elapsed time for context layer matmul : {time.time()-tv}')
# Only way to perform batched matmul between sparse tensor and dense tensor (bmm sparse not implemented yet) - max : dim = 2 * dim = 2

print(chunked_context.shape)

chunked_context_layer = chunked_context.permute(0, 2, 1, 3).contiguous()
new_chunked_context_layer_shape = chunked_context_layer.size()[:-2] + (all_head_size,)
chunked_context_layer = chunked_context_layer.view(new_chunked_context_layer_shape)
print(chunked_context_layer.shape)

TypeError: 'NoneType' object is not callable

In [61]:
# Divide k, q, v based on block_size : 

def reshape_tensor(tensor, proteins_sizes, block_size):
    _, seq_len, hidden_dim = tensor.shape
    sub_blocks = []
    start_index = 0

        for size in proteins_sizes:
        num_full_blocks = size // block_size  # Nombre de sous-blocs entiers
        end_index = start_index + num_full_blocks * block_size  # Index de fin pour les sous-blocs entiers
        if num_full_blocks > 0:
            sub_blocks.append(tensor[0][start_index:end_index].reshape(num_full_blocks * block_size, hidden_dim))
        start_index += size

    if sub_blocks:
        result_tensor = torch.cat(sub_blocks, dim=0)
    print(result_tensor.shape)
    return result_tensor.unsqueeze(0)

In [76]:
# Divide interactions in blocks : 

block_size = 50

def chunk_proteins(sorted_proteins_interactions, proteins_lengths, block_size):
    # Calculer le nombre complet de blocs pour chaque protéine
    num_blocks = [length // block_size for length in proteins_lengths]

    # Cumulative sum pour obtenir les indices de début pour chaque protéine dans la grille globale
    num_blocks_cs = np.cumsum([0] + num_blocks).tolist()
    index = 0
    chunked_blocks = []
    
    for h in num_blocks : 
        for j in range(h) : 
            chunked_blocks.append(index)
            index +=1
    block_interactions = []
              
    for i, j in sorted_proteins_interactions:
        for k in range(num_blocks[i]):
            num_blocks_cs[i] + k
            for h in range(num_blocks[j]):
                block_interactions.append((num_blocks_cs[i] + k, num_blocks_cs[j] + h))

    return block_interactions, chunked_blocks


In [63]:
def rows_to_crows(rows, n): 
    rows = np.array(rows)
    counts = np.bincount(rows, minlength=n+1)
    
    if counts.size < n+1:
        counts = np.pad(counts, (0, n+1-counts.size), constant_values=0)
    counts_cs = np.cumsum(counts)
    
    return counts_cs.tolist()

In [64]:
def sparse_attention_matrix(query, key, proteins_interactions, proteins_cs, proteins_list, block_size):
    batch_size, num_heads, seq_len, all_head_size = query.shape
    start_time = time.time()
    attentions = []
    block_positions = []
    sorted_proteins_interactions = sorted(proteins_interactions, key=lambda x: x[0])
    
    for i, j in sorted_proteins_interactions :
        # Blocks
        query_block = query[:, :, proteins_cs[i]:proteins_cs[i+1], :]
        key_block = key[:, :, proteins_cs[j]:proteins_cs[j+1], :]
        
        # Compute attention matrix for the 2 blocks
        attention_block = torch.matmul(query_block, key_block.transpose(-1, -2)).squeeze()
        attentions.append(attention_block) 

    amc = time.time()-start_time

    # Create block sparse matrix with torch.sparse_bsr_tensor
    
    # columns and rows : 
    t2 = time.time()
    col_indices = [x[1] for x in sorted_proteins_interactions]
    rows = [x[0] for x in sorted_proteins_interactions]
    crow_indices = [0] + rows_to_crows(rows, len(proteins_cs) - 1)[:-1]
    crow_tensor = torch.stack([torch.tensor(crow_indices)] * 20)
    col_tensor = torch.stack([torch.tensor(col_indices)] * 20)

    # values :    
    concatenated_attentions = torch.stack(attentions, dim=1)

    #sparse_matrix = create_sparse_coo_with_variable_blocks(attentions, block_positions, seq_len, seq_len)
    sparse_matrix = torch.sparse_bsr_tensor(crow_tensor, col_tensor, concatenated_attentions, size = [num_heads, len(proteins_list)*block_size, len(proteins_list)*block_size])
    smc = time.time() - t2
    
    return sparse_matrix, amc, smc

In [134]:
# Query, Key, Value chunked by block_size
block_size = 50

chunked_hidden_states = reshape_tensor(hidden_states, proteins_sizes, block_size)
print(chunked_hidden_states.shape)
chunked_mixed_query_layer = query_layer(chunked_hidden_states)
chunked_query = transpose_for_scores_padding(chunked_mixed_query_layer)
chunked_query = chunked_query * attention_head_size**-0.5

chunked_key = key_layer(chunked_hidden_states)
chunked_key = transpose_for_scores_padding(chunked_key)

chunked_value = value_layer(chunked_hidden_states)
chunked_value = transpose_for_scores_padding(chunked_value).squeeze()
print(chunked_value.shape)

torch.Size([1, 48, 50, 1280])
torch.Size([20, 50, 48, 64])


In [133]:
chunked_interactions, chunked_blocks = chunk_proteins(sorted_proteins_interactions, proteins_sizes, block_size)
proteins_chunked_sizes = [block_size for _ in chunked_blocks]
proteins_chunked_cs =  [0]+list(np.cumsum(np.array(proteins_chunked_sizes)))

In [23]:
chunked_sparse_matrix, attention_time, matrix_time = sparse_attention_matrix(chunked_query, chunked_key, chunked_interactions, proteins_chunked_cs, proteins_chunked_sizes, block_size)
chunked_sparse_matrix.shape

  sparse_matrix = torch.sparse_bsr_tensor(crow_tensor, col_tensor, concatenated_attentions, size = [num_heads, len(proteins_list)*block_size, len(proteins_list)*block_size])


torch.Size([20, 1450, 1450])

In [24]:
def sparse_bsr_dropout(x, p, training):
    values = x.values()  
    dropped_values = F.dropout(values, p=p, training=training)  
    new_sparse_tensor = torch.sparse_bsr_tensor(x.crow_indices(), x.col_indices(), dropped_values, size=x.size())
    return new_sparse_tensor

In [25]:
from torch.sparse._triton_ops import bsr_softmax # Only supported by CUDA and Triton ? 

# chunked_sparse_probs = bsr_softmax(chunked_sparse_matrix, dim=-1)
chunked_sparse_probs = sparse_bsr_dropout(chunked_sparse_matrix, 0.1, True)

In [395]:
tv = time.time()
batch_context = []
for batch in range(chunked_value.shape[0]):
    chunked_context = torch.sparse.mm(chunked_sparse_probs[batch], chunked_value[batch])
    batch_context.append(chunked_context)
    
chunked_context = torch.stack(batch_context, dim=0).unsqueeze(0)
print(f'Elapsed time for context layer matmul : {time.time()-tv}')
# Only way to perform batched matmul between sparse tensor and dense tensor (bmm sparse not implemented yet) - max : dim = 2 * dim = 2

print(chunked_context.shape)

chunked_context_layer = chunked_context.permute(0, 2, 1, 3).contiguous()
new_chunked_context_layer_shape = chunked_context_layer.size()[:-2] + (all_head_size,)
chunked_context_layer = chunked_context_layer.view(new_chunked_context_layer_shape)
print(chunked_context_layer.shape)

RuntimeError: mat2 must be a matrix, got 3-D tensor

In [27]:
blocks_sizes_eval = [i for i in range(2, 101, 5)]

In [28]:
# TO RUN TO EVALUATE BLOCK_SIZE (very slow for 2 - start higher ?)

attention_times = []
sparse_matrix_creation_times = []
total_times = []
zero_values = []
for block_size in blocks_sizes_eval : 
    attention_times_mean = []
    sparse_matrix_creation_times_mean = []
    total_times_mean = []
    zero_values_mean = []
    for s in range(1, 5) : 
        tt = time.time()
        chunked_hidden_states = reshape_tensor(hidden_states, proteins_sizes, block_size)
        
        chunked_mixed_query_layer = query_layer(chunked_hidden_states)
        chunked_query = transpose_for_scores(chunked_mixed_query_layer)
        chunked_query = chunked_query * attention_head_size**-0.5
        
        chunked_key = key_layer(chunked_hidden_states)
        chunked_key = transpose_for_scores(chunked_key)
        
        chunked_value = value_layer(chunked_hidden_states)
        chunked_value = transpose_for_scores(chunked_value).squeeze()
    
        chunked_interactions, chunked_blocks = chunk_proteins(sorted_proteins_interactions, proteins_sizes, block_size)
        proteins_chunked_sizes = [block_size for _ in chunked_blocks]
        proteins_chunked_cs =  [0]+list(np.cumsum(np.array(proteins_chunked_sizes)))
    
        chunked_sparse_matrix, attention_time, matrix_time = sparse_attention_matrix(chunked_query, chunked_key, chunked_interactions, proteins_chunked_cs, proteins_chunked_sizes, block_size)

        
        chunked_sparse_probs = sparse_bsr_dropout(chunked_sparse_matrix, 0.1, True)
    
        batch_context = []
        for batch in range(chunked_value.shape[0]):
            chunked_context = torch.sparse.mm(chunked_sparse_probs[batch], chunked_value[batch])
            batch_context.append(chunked_context)
            
        chunked_context = torch.stack(batch_context, dim=0).unsqueeze(0)
    
        chunked_context_layer = chunked_context.permute(0, 2, 1, 3).contiguous()
        new_chunked_context_layer_shape = chunked_context_layer.size()[:-2] + (all_head_size,)
        chunked_context_layer = chunked_context_layer.view(new_chunked_context_layer_shape)
    
        output_attentions = True
        outputs = (chunked_context_layer, chunked_sparse_probs) if output_attentions else (chunked_context_layer,)
        ft = time.time()-tt
        nv = seq_length - chunked_hidden_states.shape[1]

        attention_times_mean.append(attention_time)
        sparse_matrix_creation_times_mean.append(matrix_time)
        total_times_mean.append(ft)
        zero_values_mean.append(nv)
        print(zero_values_mean)

    zero_values.append(sum(zero_values_mean)/len(zero_values_mean))
    attention_times.append(sum(attention_times_mean)/len(attention_times_mean))
    sparse_matrix_creation_times.append(sum(sparse_matrix_creation_times_mean)/len(sparse_matrix_creation_times_mean))
    total_times.append(sum(total_times_mean)/len(total_times_mean))

KeyboardInterrupt: 

In [None]:
import pickle 

with open('/home/thibaut/blocks_sizes_8192.pickle', 'wb') as fichier0:
    pickle.dump(blocks_sizes_eval, fichier0)

with open('/home/thibaut/attention_times_8192.pickle', 'wb') as fichier1:
    pickle.dump(attention_times, fichier1)

with open('/home/thibaut/matrix_times_8192.pickle', 'wb') as fichier2:
    pickle.dump(sparse_matrix_creation_times, fichier2)

with open('/home/thibaut/global_times_8192.pickle', 'wb') as fichier3:
    pickle.dump(total_times, fichier3)

with open('/home/thibaut/zero_values_8192.pickle', 'wb') as fichier4:
    pickle.dump(zero_values, fichier4)