In [1]:
import torch
import esm
from model import SparseForTokenClassification
from typing import Tuple, List
from pair_ranking_script import load_fasta_as_tuples

In [2]:
_, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()

In [3]:
data1 = load_fasta_as_tuples("NZ_ANAO01000007_products_concatenated_pros.fasta")
data2 = load_fasta_as_tuples("NZ_ANAO01000007_products_proteins.fasta")
print(len(data1[0][1]))

51986


In [4]:
batch_labels1, batch_strs1, batch_tokens1 = batch_converter(data1) # GENOME
batch_labels2, batch_strs2, batch_tokens2 = batch_converter(data2) # PROTEINS
print(batch_tokens1.shape)
print(batch_tokens2.shape)


torch.Size([1, 51988])
torch.Size([152, 1072])


In [5]:
version = '3C' # or '5B'
#checkpoint = torch.load(f'./models/{version}/config_and_model.pth', map_location='cpu', weights_only=False)
checkpoint = torch.load(f'/home/thibaut/mahdi/models/{version}/config_and_model.pth', map_location='cpu', weights_only=False)

In [6]:
device = "cuda:1"
config = checkpoint['config']
model_state_dict = checkpoint['model_state_dict']
sparse_model = SparseForTokenClassification(config=config)
sparse_model.load_state_dict(model_state_dict)
sparse_model = sparse_model.to(device)
sparse_model = sparse_model.eval()

# Complete genome embeddings 

In [7]:
with torch.no_grad():
    batch_tokens1 = batch_tokens1.to(device)
    output = sparse_model(input_ids=batch_tokens1, output_attentions = False)

# 1m53s for 52k aa sequence     

In [8]:
output.logits.shape # Genome embeddings
# torch.Size([1, 51988, 1280])


torch.Size([1, 51988, 1280])

# Non-contextualised proteins embeddings

In [9]:
with torch.no_grad():
    batch_tokens2 = batch_tokens2.to(device)
    output2 = sparse_model(input_ids=batch_tokens2, output_attentions = False)


In [10]:
output2.logits.shape # Padded protein embeddings
# torch.Size([number_proteins, longest_protein+2, 1280])

torch.Size([152, 1072, 1280])

In [11]:
# Function to remove padding from embeddings - returns a list of embeddings. 
def remove_padding_embeddings(data, output) : 
    out = output.logits
    res = []
    for i, prot in enumerate(data) : 
        size = len(prot[1])
        ex = out[i][1:size+1]
        res.append(ex)
    return(res)

In [12]:
A = remove_padding_embeddings(data2, output2) # List of non-padded non-contextualised protein embeddings. 

# Contextualised protein embeddings

In [9]:
def cumulative_sum(liste) : 
        r = 0
        res = [0]
        for l in liste : 
            r+=l
            res.append(int(r))
        return res

def extract_embeddding(embeddings, data, index):
    proteins_sizes = [len(prot[1]) for prot in data]
    embeddings = embeddings[:,1:-1,:]
    proteins_size_cs = cumulative_sum(proteins_sizes)
    start, end = proteins_size_cs[index], proteins_size_cs[index+1]
    subset = embeddings[:,start:end,:]
    return subset

In [13]:
B = extract_embeddding(output.logits, data2, 2)  
# output.logits is the embeddings of the genome 
# data2 is the data list ('prot_id', 'SEQ') for every protein in the genome
# 2 is the index of the desired protein in the genome
print(B.shape)
print(len(data2[2][1]))

torch.Size([1, 306, 1280])
306


# Proteins pairs ranking

In [7]:
with torch.no_grad():
    batch_tokens1 = batch_tokens1.to(device)
    prots_lengths = [len(prot[1]) for prot in data2]
    print(prots_lengths)
    output2 = sparse_model(input_ids=batch_tokens1, output_attentions = True, two_step_selection= True, proteins_sizes = torch.tensor(prots_lengths))

# Very long (approx 3h for 52k genome) - due to import number of pairs 
# Returns the dictionnary of protein pairs with score

[1014, 174, 306, 296, 299, 85, 240, 466, 388, 374, 492, 490, 65, 391, 236, 471, 291, 644, 200, 103, 609, 85, 145, 142, 66, 430, 463, 515, 573, 433, 461, 210, 157, 72, 420, 314, 270, 67, 383, 437, 390, 338, 283, 114, 350, 390, 113, 799, 215, 308, 311, 494, 131, 298, 330, 537, 152, 449, 339, 308, 322, 301, 593, 41, 140, 483, 416, 225, 463, 173, 326, 477, 185, 144, 125, 475, 214, 218, 276, 242, 486, 323, 228, 490, 477, 432, 128, 392, 338, 420, 90, 172, 182, 566, 43, 423, 487, 788, 125, 471, 377, 623, 160, 978, 268, 168, 40, 223, 487, 170, 466, 308, 174, 399, 448, 235, 534, 660, 238, 273, 416, 223, 234, 458, 273, 310, 365, 188, 289, 167, 84, 434, 230, 534, 658, 526, 182, 965, 516, 288, 1070, 391, 631, 689, 162, 295, 380, 217, 258, 153, 239, 106]


51988
init
done
51988
init
done
51988
init
done
51988
init
done
51988
init
done
51988
init
done
51988
init


# Protein Pair Attention

In [8]:
with torch.no_grad():
    batch_tokens1 = batch_tokens1.to(device)
    prots_lengths = [len(prot[1]) for prot in data2]
    prot_int = [1, 2]
    print(prots_lengths)
    output3 = sparse_model(input_ids=batch_tokens1, output_attentions = True, two_step_selection= False, proteins_sizes = torch.tensor(prots_lengths), proteins_interactions = torch.tensor(prot_int))

# Approx 1m50 to run for the 52k genome

[1014, 174, 306, 296, 299, 85, 240, 466, 388, 374, 492, 490, 65, 391, 236, 471, 291, 644, 200, 103, 609, 85, 145, 142, 66, 430, 463, 515, 573, 433, 461, 210, 157, 72, 420, 314, 270, 67, 383, 437, 390, 338, 283, 114, 350, 390, 113, 799, 215, 308, 311, 494, 131, 298, 330, 537, 152, 449, 339, 308, 322, 301, 593, 41, 140, 483, 416, 225, 463, 173, 326, 477, 185, 144, 125, 475, 214, 218, 276, 242, 486, 323, 228, 490, 477, 432, 128, 392, 338, 420, 90, 172, 182, 566, 43, 423, 487, 788, 125, 471, 377, 623, 160, 978, 268, 168, 40, 223, 487, 170, 466, 308, 174, 399, 448, 235, 534, 660, 238, 273, 416, 223, 234, 458, 273, 310, 365, 188, 289, 167, 84, 434, 230, 534, 658, 526, 182, 965, 516, 288, 1070, 391, 631, 689, 162, 295, 380, 217, 258, 153, 239, 106]


In [12]:
def compute_combined_attention(attentions: Tuple[torch.Tensor, ...], weight_scheme: List[float]) -> torch.Tensor:
    x,y,n1,n2 = len(attentions), attentions[0][0].shape[0], attentions[0][0].shape[1], attentions[0][0].shape[2]
    assert len(weight_scheme) == x * y
    
    weighted_sum = torch.zeros_like(attentions[0][0][0])

    stacked_attentions = torch.stack(attentions, dim=0)
    weights = torch.tensor(weight_scheme, dtype=stacked_attentions.dtype, device=stacked_attentions.device)

    flattened_attentions = stacked_attentions.view(x * y, n1, n2)
    weighted_sum = torch.sum(flattened_attentions * weights[:, None, None], dim=0)
    
    return weighted_sum

In [15]:
print(len(output3.attentions), output3.attentions[0].shape)
weight_scheme = [1]*(len(output3.attentions)*output3.attentions[0][0].shape[0]) # num_layers * num_heads
print(compute_combined_attention(output3.attentions,weight_scheme).shape)

16 torch.Size([1, 20, 174, 306])
torch.Size([174, 306])
