In [89]:
pip install torch -U

[33mDEPRECATION: pytorch-lightning 1.6.5 has a non-standard dependency specifier torch>=1.8.*. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pytorch-lightning or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


In [90]:
import torch
import esm
from model import SparseForTokenClassification
from typing import Tuple, List


### Extracting Embedding for Both Protein Sequences and for Genome
##### Extracting the embedding for a single protein


In [91]:
# Only the alphabet is kept from the pretrained ESM-2 loader; the first returned
# value is discarded.
_, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# Callable that turns (label, sequence) tuples into labels, strings and tokens
# (see how it is unpacked in the next cell).
batch_converter = alphabet.get_batch_converter()


In [92]:
# Using a short 10-amino-acid sequence as a proof of concept.

data  = [('seq_1', 'MLKKLSVFLI')]

# batch_tokens is what gets passed to the model as input_ids further down.
batch_labels, batch_strs, batch_tokens = batch_converter(data)

##### Recreate the model

In [93]:
# Name of the model sub-directory to load the checkpoint from.
version = '3C' # or '5B'
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects,
# so only load checkpoint files that come from a trusted source.
checkpoint = torch.load(f'./models/{version}/config_and_model.pth', map_location='cpu', weights_only=False)
# Adjust the relative path above if the checkpoints live elsewhere; avoid committing machine-specific absolute paths.

In [94]:
# Rebuild the network from the config stored in the checkpoint, restore its
# weights, then move it to the target device and switch it to eval mode.
device = "cpu"
config, model_state_dict = checkpoint['config'], checkpoint['model_state_dict']

sparse_model = SparseForTokenClassification(config=config)
sparse_model.load_state_dict(model_state_dict)
sparse_model = sparse_model.to(device).eval()

In [95]:
# Inference-only forward pass: autograd is disabled so no computation graph is
# kept alive for the logits and the per-layer attention maps.
with torch.no_grad():
    output = sparse_model(input_ids=batch_tokens,
                          output_attentions=True)


In [96]:
# Shape of the model output; the recorded run shows (1, 12, 1280) for the 10-residue input.
output.logits.size()

torch.Size([1, 12, 1280])

### We can extract the sequence embeddings using


In [97]:
# Drop the first and last token positions so only the 10 residue rows remain.
output.logits[0, 1:-1].size()

torch.Size([10, 1280])

In [98]:
# Number of attention tensors returned by the model (16 in the recorded run).
len(output.attentions)

16

In [99]:
# Shape of every attention tensor, in the order the model returned them.
[layer_attention.size() for layer_attention in output.attentions]

[torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12]),
 torch.Size([1, 20, 12, 12])]

# Thibaut 

Could you include the code here that combines all the attention heads and layers into a single 10x10 matrix? It should be in a single function. For now, we will provide equal weights to each matrix, but the function should be able to take a different weighting scheme for each layer. See scaffold below


In [100]:
def compute_combined_attention(
    attentions: Tuple[torch.Tensor, ...],
    weight_scheme: List[float],
    strip_special_tokens: bool = False,
) -> torch.Tensor:
    """Collapse all layers and heads of attention into one weighted-sum matrix.

    Args:
        attentions: One tensor per layer, each shaped
            ``(1, n_heads, seq_len, seq_len)`` (a single sequence per call).
        weight_scheme: Either one weight per layer (``n_layers`` values, each
            applied to every head of that layer) or one weight per
            (layer, head) pair (``n_layers * n_heads`` values in layer-major
            order, i.e. index ``layer * n_heads + head``).
        strip_special_tokens: When True, drop the first and last token position
            on both axes, the same ``1:-1`` slice this notebook uses to go from
            12 token rows to the 10 residue embeddings. Defaults to False so the
            full ``(seq_len, seq_len)`` matrix is returned as before.

    Returns:
        The weighted sum of all attention matrices, shaped
        ``(seq_len, seq_len)``, or ``(seq_len - 2, seq_len - 2)`` when
        ``strip_special_tokens`` is True.

    Raises:
        ValueError: If ``attentions`` is empty, if the tensors are not 4-D with
            a batch size of 1, or if ``weight_scheme`` has neither ``n_layers``
            nor ``n_layers * n_heads`` entries.
    """
    if len(attentions) == 0:
        raise ValueError("attentions must contain at least one layer")

    # (n_layers, batch, n_heads, seq_len, seq_len)
    stacked_attentions = torch.stack(tuple(attentions), dim=0)
    if stacked_attentions.dim() != 5 or stacked_attentions.shape[1] != 1:
        raise ValueError(
            "each attention tensor must have shape (1, n_heads, seq_len, seq_len); "
            f"got {tuple(attentions[0].shape)}"
        )
    n_layers, _, n_heads, n_rows, n_cols = stacked_attentions.shape

    weights = torch.tensor(
        weight_scheme,
        dtype=stacked_attentions.dtype,
        device=stacked_attentions.device,
    )
    if weights.dim() != 1:
        raise ValueError("weight_scheme must be a flat sequence of numbers")

    if weights.numel() == n_layers * n_heads:
        per_head_weights = weights
    elif weights.numel() == n_layers:
        # Per-layer scheme: every head of a layer shares that layer's weight.
        per_head_weights = weights.repeat_interleave(n_heads)
    else:
        raise ValueError(
            f"weight_scheme must have {n_layers} (per layer) or "
            f"{n_layers * n_heads} (per layer and head) entries, got {weights.numel()}"
        )

    # Batch size is 1, so flattening yields one matrix per (layer, head) pair
    # in layer-major order, matching the order of per_head_weights.
    flattened_attentions = stacked_attentions.reshape(n_layers * n_heads, n_rows, n_cols)
    combined = torch.sum(flattened_attentions * per_head_weights[:, None, None], dim=0)

    if strip_special_tokens:
        combined = combined[1:-1, 1:-1]
    return combined

In [101]:
# Number of attention tensors, one per layer (16 in the recorded run).
# NOTE: "nb_layars" is a typo for "nb_layers"; the name is kept because later cells reference it.
nb_layars = len(output.attentions)
# Size of the first dimension once the batch axis is indexed away, i.e. the number of heads (20 here).
nb_heads = output.attentions[0][0].shape[0]

In [102]:
# Number of weights needed: one per (layer, head) attention matrix.
nb_layars * nb_heads

320

In [103]:
# Equal weighting: a weight of 1 for every (layer, head) attention matrix.
weight_scheme = [1 for _ in range(nb_layars * nb_heads)]
len(weight_scheme), weight_scheme[:10]

(320, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [104]:
# Combine every layer and head with equal weights. The recorded result has 12 columns,
# i.e. it still includes the two extra token positions around the 10 residues.
compute_combined_attention(output.attentions, weight_scheme)

tensor([[16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000,
         16.0000, 16.0000, 16.0000, 16.0000],
        [16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000,
         16.0000, 16.0000, 16.0000, 16.0000],
        [16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000,
         16.0000, 16.0000, 16.0000, 16.0000],
        [16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000,
         16.0000, 16.0000, 16.0000, 16.0000],
        [16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000,
         16.0000, 16.0000, 16.0000, 16.0000],
        [16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000,
         16.0000, 16.0000, 16.0000, 16.0000],
        [16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000,
         16.0000, 16.0000, 16.0000, 16.0000],
        [16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000, 16.0000,
         16.0000, 16.0000, 16.

### Computing the Complete Genome Annotation

In [105]:
# Four toy proteins that are treated together as one small genome.
data = [
    ('seq_1', 'MLKKLSVFLI'),
    ('seq_2', 'MLSVFLI'),
    ('seq_3', 'LKKLSVFLI'),
    ('seq_4', 'MLKKLSVIMMMKV'),
]

# Join the protein sequences end to end into a single record and remember each
# protein's length; the lengths are passed to the model as proteins_sizes later.
prots_concatenated = "".join(sequence for _, sequence in data)
prots_lengths = [len(sequence) for _, sequence in data]
genome = [("some_genome", prots_concatenated)]
genome, prots_lengths

([('some_genome', 'MLKKLSVFLIMLSVFLILKKLSVFLIMLKKLSVIMMMKV')], [10, 7, 9, 13])

In [106]:
# Length of the concatenated genome sequence (39 residues in the recorded run).
len(genome[0][1])

39

In [107]:
# Tokenise the concatenated genome as a single record; this overwrites the
# batch_* variables created for the single-protein example above.
batch_labels, batch_strs, batch_tokens = batch_converter(genome)

In [108]:
# Inference-only forward pass over the whole genome. proteins_sizes carries the
# length of each protein inside the concatenated sequence. Autograd is disabled
# so no computation graph is retained for the outputs.
with torch.no_grad():
    output = sparse_model(input_ids=batch_tokens,
                          output_attentions=True,
                          proteins_sizes=torch.tensor(prots_lengths))


In [109]:
# Fields returned by the model; the recorded run lists 'logits' and 'attentions'.
output.keys()

odict_keys(['logits', 'attentions'])

In [110]:
# Shape of the genome-level output: (1, 41, 1280) in the recorded run, i.e. 39 residues plus 2 extra token positions.
output['logits'].size()

torch.Size([1, 41, 1280])

In [111]:
# Number of attention tensors returned for the genome (16 in the recorded run).
len(output['attentions'])

16

In [112]:
# Shape of the first attention tensor: (1, 20, 41, 41) in the recorded run.
output['attentions'][0].size()

torch.Size([1, 20, 41, 41])

In [113]:
# Sanity check: the protein lengths add up to the genome length (39), two fewer than the 41 token positions.
sum(prots_lengths)

39

In [114]:
# THIBAUT: my understanding is that two_step_selection=True determines whether or not all pairwise scores are returned.
# Adding it to the model call (see the next cell) causes an error.


In [116]:
# Same genome-level forward pass, additionally requesting two_step_selection.
# NOTE(review): the notebook reports that passing two_step_selection=True raises
# an error; the cause is not visible here and still needs to be investigated.
# Autograd is disabled because this is inference only.
with torch.no_grad():
    output = sparse_model(input_ids=batch_tokens,
                          output_attentions=True,
                          proteins_sizes=torch.tensor(prots_lengths),
                          two_step_selection=True)


### Ranking the pairwise interactions

1. I know that we need to run

In [2]:
from pair_ranking_script import get_pairs_dataset
# The code below is run step by step so that only the parts that are actually needed get used.
# output = get_pairs_dataset("../data/LR699048.fa", "./data/LR699048.pt", "./data/LR699048.txt", plot_fragment_heatmap = False)
# The call above generates the following:
# all_pairs: attention between all pairs of proteins
# all_nc_pairs: attention between all pairs of non-consecutive proteins
# global_length: total number of all pairwise comparisons
# non_consecutive_length: total number of all non-consecutive pairwise comparisons
# database_dic: same as all_pairs but in dictionary format, where the key is the genome name
# all_top_pairs_labeled: same as all_pairs but with protein labels instead of protein indexes. Uses the labels_file to get the labels
# all_top_pairs_labeled_nc: same as all_nc_pairs but with protein labels instead of protein indexes. Uses the labels_file to get the labels

1 assembled fragments
Number of sequences to process : 0
10 cpus availables
0 fragments processed


In [4]:
from pair_ranking_script import load_fasta_as_tuples

# Load the FASTA records that the pair-ranking steps will work on.
data_assembled = load_fasta_as_tuples("../data/LR699048.fa")

# Empty accumulators for stepping through the pair-ranking pipeline manually.
# They are defined at cell level (no indentation) so the cell is valid Python.
global_length = []
non_consecutive_length = []
# NOTE(review): named *_dict but initialised as a list; confirm the intended type.
all_length_dict = []
j = 0
all_top_pairs = []
all_top_pairs_nc = []
all_top_pairs_labeled = []
all_top_pairs_labeled_nc = []
database_dic = {}
mlp = []
genomes = []

In [None]:
### THIBAUT