In [None]:
import json
import torch
import genova
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from collections import OrderedDict
from torch.utils.data import DataLoader
from torch.nn.functional import pad

In [None]:
from genova.utils.BasicClass import Ion, Residual_seq

In [None]:
# Make GPU 1 the current CUDA device; argument-less `.cuda()` calls below place tensors there.
torch.cuda.set_device(torch.device('cuda', 1))

In [None]:
# Token -> index vocabulary (used by the dataset and for decoding), plus its index -> token inverse.
with open('genova/utils/dictionary') as vocab_file:
    dictionary = json.load(vocab_file)
reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))

In [None]:
# Evaluation subset: one experiment, small spectra only.
DATA_DIR = '/data/z37mao/genova/pretrain_data_sparse/'
cfg = OmegaConf.load('configs/genova_dda_light.yaml')
spec_header = pd.read_csv(DATA_DIR + 'genova_psm.csv', index_col='index')
# Keep experiment PXD008844 and only spectra whose 'Node Number' is at most 256.
spec_header = spec_header[spec_header['Experiment Name'] == 'PXD008844']
small_spec = spec_header[spec_header['Node Number'] <= 256]
dataset = genova.data.GenovaDataset(cfg, dictionary=dictionary, spec_header=small_spec,
                                    dataset_dir_path=DATA_DIR)
collate_fn = genova.data.GenovaCollator(cfg, mode='eval')
# No shuffling: later cells pair batch i with small_spec.iloc[i]
# (assumes the dataset yields items in small_spec row order — TODO confirm).
dl = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)

model = genova.models.Genova(cfg).cuda()
# Remap tensors stored on cuda:0 to the current CUDA device, i.e. the one model.cuda() just used.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load('/data/z37mao/save/Genova_model.pt',
                        map_location={'cuda:0': 'cuda:%d' % torch.cuda.current_device()})
# Keys presumably carry a 'module.' prefix from a (Distributed)DataParallel wrapper. Strip it only
# where present instead of blindly dropping the first 7 characters of every key.
state_dict = OrderedDict(
    (key[len('module.'):] if key.startswith('module.') else key, value)
    for key, value in checkpoint['model_state_dict'].items()
)
model.load_state_dict(state_dict)
# Inference only: eval mode and frozen parameters.
model = model.eval()
model.requires_grad_(False)

In [None]:
def encoder_input_cuda(encoder_input):
    """Move tensors nested one level inside `encoder_input` onto the current CUDA device.

    `encoder_input` maps section names to inner mappings. Every inner value that is a
    `torch.Tensor` is replaced in place by its `.cuda()` copy; all other values are
    left untouched. The same (mutated) outer mapping is returned for convenience.
    """
    for section in encoder_input.values():
        for name, value in section.items():
            if isinstance(value, torch.Tensor):
                # Re-assigning an existing key does not change dict size, so this is
                # safe while iterating.
                section[name] = value.cuda()
    return encoder_input

In [None]:
def beam_search(memory, i, beam_width=None, max_len=100, ppm_tol=10):
    """Beam-search decode one spectrum and rank the resulting peptide candidates.

    Relies on notebook globals: ``model``, ``dictionary``, ``reverse_dictionary``,
    ``small_spec``, ``Ion``, ``Residual_seq`` and (by default) ``beam_size``.

    Args:
        memory: encoder output for a single spectrum. It is repeated along dim 0 to
            match the number of live beams, so it is assumed to have batch size 1.
        i: row position of this spectrum in ``small_spec``; used to read its charge
            and precursor m/z.
        beam_width: number of beams; ``None`` uses the global ``beam_size``.
        max_len: decoding stops after this many steps; beams still unfinished then
            are kept as candidates.
        ppm_tol: relative precursor-mass tolerance in ppm (default 10, i.e. 1e-5).

    Returns:
        (seqs, probs): numpy array of peptide strings (first and last token of each
        beam stripped) sorted by descending summed log-probability, and ``exp`` of
        those scores. If any candidate lies within ``ppm_tol`` of the precursor
        mass, only those candidates are returned; otherwise all of them are.
    """
    k = beam_size if beam_width is None else beam_width
    c_term = dictionary['<c_term>']

    # Step 1: expand the lone <n_term> prefix into the k best first tokens.
    tgt_index = torch.full((1, 1), dictionary['<n_term>']).cuda()
    logits = model.output_ffn(model.decoder(tgt_index=tgt_index, memory=memory))
    score, first_aa = torch.topk(torch.log_softmax(logits[:, -1], -1), k, dim=-1)
    score = score.view(-1, 1)  # summed log-probability per live beam, shape (k, 1)
    tgt_index = torch.concat([tgt_index.repeat(k, 1), first_aa.view(-1, 1)], dim=-1)

    done_pep, done_score = [], []
    charge, precursor_mz = small_spec[['Charge', 'm/z [Da]']].iloc[i]

    seq_len = 1
    while k > 0 and seq_len < max_len:
        logits = model.output_ffn(model.decoder(
            tgt_index=tgt_index,
            memory=torch.repeat_interleave(memory, tgt_index.size(0), dim=0)))[:, -1, :]
        # Take the vocabulary size from the logits instead of hard-coding it.
        vocab_size = logits.size(-1)
        # Score every (beam, next token) pair and keep the k best over all of them.
        candidate = (score + torch.log_softmax(logits, -1)).view(1, -1)
        score, flat_index = torch.topk(candidate, k)
        aa = flat_index % vocab_size
        parent = torch.div(flat_index, vocab_size, rounding_mode='floor').view(-1)
        tgt_index = torch.concat([tgt_index[parent], aa.view(-1, 1)], dim=-1)
        score = score.view(-1, 1)
        seq_len += 1

        done_mask = aa.view(-1) == c_term
        # Convert to a Python int so k stays a plain integer (not a CUDA tensor)
        # for topk and the loop condition.
        n_done = int(done_mask.sum())
        if n_done:
            k -= n_done
            done_pep.append(tgt_index[done_mask].cpu())
            done_score.append(score[done_mask].cpu())
            score = score[~done_mask]
            tgt_index = tgt_index[~done_mask]

    if k > 0:
        # Unfinished beams have no closing <c_term>; append one padding column so the
        # [1:-1] slice below strips the same two positions as for finished beams.
        done_pep.append(pad(tgt_index, (0, 1)).cpu())
        done_score.append(score.cpu())

    seqs = np.array([
        ''.join(reverse_dictionary[tok] for tok in row[1:-1].numpy())
        for group in done_pep for row in group
    ])
    scores = torch.concat(done_score).squeeze(1).numpy()
    order = np.argsort(-scores)  # best (highest log-probability) first
    seqs, scores = seqs[order], scores[order]

    precursor_mass = Ion.precursorion2mass(precursor_mz, charge)
    seq_mass = np.array([Residual_seq(seq).mass for seq in seqs])
    mass_mask = np.abs(seq_mass - precursor_mass) / precursor_mass < ppm_tol / 1e6
    if mass_mask.any():
        seqs, scores = seqs[mass_mask], scores[mass_mask]
    return seqs, np.exp(scores)

In [None]:
def match_AA_novor(target: str, predicted: str, cum_mass_tol: float = 0.5,
                   aa_mass_tol: float = 0.1, mass_fn=None) -> int:
    """Count residues of `predicted` that match `target` by Novor-style mass alignment.

    Both sequences are walked in parallel by cumulative (prefix) mass. When the two
    prefix masses agree within `cum_mass_tol`, both pointers advance, and the pair
    counts as a match if the individual residue masses also agree within
    `aa_mass_tol`. Otherwise the pointer with the smaller prefix mass advances.

    Args:
        target: reference peptide, one residue per character.
        predicted: predicted peptide, one residue per character.
        cum_mass_tol: tolerance for aligning prefix masses (same units as the masses).
        aa_mass_tol: tolerance for calling two aligned residues equal.
        mass_fn: callable mapping one residue character to its mass; defaults to
            ``Residual_seq(aa).mass``.

    Returns:
        Number of matched residues (0 if either sequence is empty).
    """
    if mass_fn is None:
        def mass_fn(aa):
            return Residual_seq(aa).mass

    target_mass = np.array([mass_fn(aa) for aa in target], dtype=float)
    predicted_mass = np.array([mass_fn(aa) for aa in predicted], dtype=float)
    target_mass_cum = np.cumsum(target_mass)
    predicted_mass_cum = np.cumsum(predicted_mass)

    num_match = 0
    i = j = 0
    while i < len(target_mass) and j < len(predicted_mass):
        if abs(target_mass_cum[i] - predicted_mass_cum[j]) < cum_mass_tol:
            if abs(target_mass[i] - predicted_mass[j]) < aa_mass_tol:
                num_match += 1
            i += 1
            j += 1
        elif target_mass_cum[i] < predicted_mass_cum[j]:
            i += 1
        else:
            j += 1
    return num_match

In [None]:
from tqdm import tqdm

In [None]:
beam_size = 30
# Evaluate only the first MAX_SPECTRA spectra (quick debug run); None = the whole loader.
MAX_SPECTRA = 16
right_aa = 0        # residues matched by match_AA_novor, summed over spectra
right_peptide = 0   # spectra whose top-ranked candidate equals the label exactly
n_spectra = 0
n_label_aa = 0
for i, (encoder_input, node_mask, labels) in enumerate(dl):
    encoder_input = encoder_input_cuda(encoder_input)
    label = small_spec['Annotated Sequence'].iloc[i]
    # Compare with L mapped to I (the two residues are isobaric).
    target = label.replace('L', 'I')
    memory = model.encoder(**encoder_input)
    predicted_seqs, probability = beam_search(memory, i)
    if predicted_seqs[0] == target:
        right_peptide += 1
    # Accumulate for every evaluated spectrum: breaking before this line used to skip
    # right_aa for the last spectrum while still counting it in right_peptide.
    right_aa += match_AA_novor(target, predicted_seqs[0])
    n_spectra += 1
    n_label_aa += len(target)
    # Break at the end of the body so i / label / memory / predicted_seqs all refer to the
    # same (last evaluated) spectrum for the inspection cells that follow.
    if MAX_SPECTRA is not None and n_spectra >= MAX_SPECTRA:
        break

{'spectra': n_spectra,
 'peptide_recall': right_peptide / max(n_spectra, 1),
 'aa_recall': right_aa / max(n_label_aa, 1)}

In [None]:
# Scratch inspection of state left by the evaluation loop: its top-ranked candidate.
predicted_seqs[0]

In [None]:
# Mass of the residue pair QQ, compared with KE next (presumably probing a near-isobaric mix-up — TODO confirm).
Residual_seq('QQ').mass

In [None]:
Residual_seq('KE').mass

In [None]:
predicted_seqs[0]

In [None]:
# Label in the L->I alphabet the evaluation loop compares against.
label.replace('L','I')

In [None]:
# Reset the beam width before stepping through beam_search() by hand on the leftover `memory`.
k = beam_size

In [None]:
# Manual copy of beam_search()'s first step: score the token after <n_term> and keep the top k.
tgt_index = torch.full((1,1),dictionary['<n_term>']).cuda()
output = model.output_ffn(model.decoder(tgt_index=tgt_index, memory=memory))
perplexity = torch.log_softmax(output[:,-1],-1)
perplexity, aa_index = torch.topk(perplexity, k, dim=-1)
perplexity = perplexity.view(-1,1)
aa_index = aa_index.view(-1,1)
tgt_index = torch.concat([tgt_index.repeat(k,1),aa_index],dim=-1)

In [None]:
# Manual copy of ONE iteration of beam_search()'s decoding loop, continuing from the cell above.
# NOTE(review): done_pep / done_perplexity are reset on every run, so beams finished in earlier
# runs of this cell are discarded.
done_pep = []
done_perplexity = []
charge, precursor_mz = small_spec[['Charge','m/z [Da]']].iloc[i]
    
probability = model.output_ffn(model.decoder(tgt_index=tgt_index, 
                                                 memory=torch.repeat_interleave(memory,tgt_index.size(0),dim=0)))[:,-1,:]
probability = torch.log_softmax(probability,-1)
perplexity = perplexity + probability
perplexity = perplexity.view(1,-1)
perplexity, aa_index = torch.topk(perplexity,k)
# NOTE(review): %30 and //30 hard-code a 30-token vocabulary (the logits width).
aa = aa_index%30
tgt_index = tgt_index[torch.div(aa_index, 30, rounding_mode='floor').view(-1)]
perplexity = perplexity.view(-1,1)
tgt_index = torch.concat([tgt_index,aa.view(-1,1)],dim=-1)
done_mask = aa.squeeze(0)==dictionary['<c_term>']
print(tgt_index)
if done_mask.sum()>0:
    # NOTE(review): subtracting a tensor turns k into a 0-dim tensor from here on.
    k-=done_mask.sum()
    done_pep.append(tgt_index[done_mask].cpu())
    done_perplexity.append(perplexity[done_mask].cpu())
    perplexity = perplexity[~done_mask]
    tgt_index = tgt_index[~done_mask]

In [None]:
# Inline copy of beam_search() on `memory` / `i` left over from the evaluation loop, so the
# intermediate state (done_pep, tgt_index, perplexity, ...) can be inspected by later cells.
k = beam_size
tgt_index = torch.full((1, 1), dictionary['<n_term>']).cuda()
output = model.output_ffn(model.decoder(tgt_index=tgt_index, memory=memory))
perplexity = torch.log_softmax(output[:, -1], -1)
perplexity, aa_index = torch.topk(perplexity, k, dim=-1)
perplexity = perplexity.view(-1, 1)
aa_index = aa_index.view(-1, 1)
tgt_index = torch.concat([tgt_index.repeat(k, 1), aa_index], dim=-1)

done_pep = []
done_perplexity = []

charge, precursor_mz = small_spec[['Charge', 'm/z [Da]']].iloc[i]

seq_len = 1
while k > 0 and seq_len < 100:
    probability = model.output_ffn(model.decoder(tgt_index=tgt_index,
                                                 memory=torch.repeat_interleave(memory, tgt_index.size(0), dim=0)))[:, -1, :]
    vocab_size = probability.size(-1)  # logits width, instead of a hard-coded 30
    probability = torch.log_softmax(probability, -1)
    perplexity = perplexity + probability
    perplexity = perplexity.view(1, -1)
    perplexity, aa_index = torch.topk(perplexity, k)
    aa = aa_index % vocab_size
    tgt_index = tgt_index[torch.div(aa_index, vocab_size, rounding_mode='floor').view(-1)]
    perplexity = perplexity.view(-1, 1)
    tgt_index = torch.concat([tgt_index, aa.view(-1, 1)], dim=-1)
    done_mask = aa.squeeze(0) == dictionary['<c_term>']
    if done_mask.sum() > 0:
        k -= int(done_mask.sum())  # keep k a plain Python int
        done_pep.append(tgt_index[done_mask].cpu())
        done_perplexity.append(perplexity[done_mask].cpu())
        perplexity = perplexity[~done_mask]
        tgt_index = tgt_index[~done_mask]
    seq_len += 1

# Collect beams still unfinished when decoding stops. This must run once, after the loop:
# inside the loop it would re-append every live beam at every step, duplicating candidates.
if k != 0:
    done_pep.append(pad(tgt_index, (0, 1)).cpu())
    done_perplexity.append(perplexity.cpu())

In [None]:
# Scratch inspection of the beams produced by the cell above.
done_pep

In [None]:
tgt_index

In [None]:
pad(tgt_index,(0,1))

In [None]:
label

In [None]:
# Decode token ids to strings, dropping the first (<n_term>) and the last (<c_term> or pad) column.
# The comprehension's `i` is local to it and does not overwrite the spectrum index `i`.
seqs = []
for seq_group in done_pep:
    for seq_index in seq_group:
        seqs.append(''.join([reverse_dictionary[i] for i in seq_index[1:-1].numpy()]))

In [None]:
seqs

In [None]:
done_pep

In [None]:
tgt_index

In [None]:
# Continue decoding from the current tgt_index / perplexity / k state until every beam emits
# <c_term>. The length cap matches beam_search() (<n_term> + 100 tokens), so a beam that never
# terminates cannot loop forever; beams still open at the cap are kept, as in beam_search().
k = int(k)  # may arrive as a 0-dim tensor from earlier scratch cells
done_pep = []
done_perplexity = []

while k > 0 and tgt_index.size(1) < 101:
    probability = model.output_ffn(model.decoder(tgt_index=tgt_index,
                                                 memory=torch.repeat_interleave(memory, tgt_index.size(0), dim=0)))[:, -1, :]
    vocab_size = probability.size(-1)  # logits width, instead of a hard-coded 30
    probability = torch.log_softmax(probability, -1)
    perplexity = perplexity + probability
    perplexity = perplexity.view(1, -1)
    perplexity, aa_index = torch.topk(perplexity, k)
    aa = aa_index % vocab_size
    tgt_index = tgt_index[torch.div(aa_index, vocab_size, rounding_mode='floor').view(-1)]
    perplexity = perplexity.view(-1, 1)
    tgt_index = torch.concat([tgt_index, aa.view(-1, 1)], dim=-1)
    done_mask = aa.squeeze(0) == dictionary['<c_term>']
    if done_mask.sum() > 0:
        k -= int(done_mask.sum())
        done_pep.append(tgt_index[done_mask].cpu())
        done_perplexity.append(perplexity[done_mask].cpu())
        perplexity = perplexity[~done_mask]
        tgt_index = tgt_index[~done_mask]

if k > 0:
    # Pad unfinished beams by one column so the [1:-1] slice strips the same positions.
    done_pep.append(pad(tgt_index, (0, 1)).cpu())
    done_perplexity.append(perplexity.cpu())

# Decode token ids to strings, dropping the first (<n_term>) and last (<c_term>/pad) column.
seqs = []
for seq_group in done_pep:
    for seq_index in seq_group:
        seqs.append(''.join([reverse_dictionary[tok] for tok in seq_index[1:-1].numpy()]))



In [None]:
# Scratch: manual version of beam_search()'s ranking and precursor-mass filtering.
label

In [None]:
seqs

In [None]:
# Flatten the per-group score tensors into one numpy array.
# NOTE(review): run once only — afterwards done_perplexity is a numpy array, not a list of tensors.
done_perplexity = torch.concat(done_perplexity).squeeze(1).numpy()

In [None]:
# Rank candidates by descending summed log-probability (same as beam_search()).
seqs = np.array(seqs)[np.argsort(-done_perplexity)]
done_perplexity = -np.sort(-done_perplexity)

In [None]:
precursor_mass = Ion.precursorion2mass(precursor_mz, charge)
seq_mass = np.array([Residual_seq(seq).mass for seq in seqs])

In [None]:
np.abs(seq_mass-precursor_mass)

In [None]:
precursor_mass

In [None]:
# Keep candidates within 10 ppm (10e-6 == 1e-5) of the precursor mass. Unlike beam_search(), there
# is no fallback when nothing passes. NOTE(review): not idempotent — seq_mass is not filtered, so a
# second run indexes the already-filtered arrays with a mask of the old length.
done_perplexity = done_perplexity[np.abs(seq_mass-precursor_mass)/precursor_mass<10e-6]
seqs=seqs[np.abs(seq_mass-precursor_mass)/precursor_mass<10e-6]

In [None]:
# Per-residue masses of the best remaining candidate.
np.array([Residual_seq(aa).mass for aa in seqs[0]])

In [None]:
def match_AA_novor(target: str, predicted: str, cum_mass_tol: float = 0.5,
                   aa_mass_tol: float = 0.1, mass_fn=None) -> int:
    """Count residues of `predicted` that match `target` by Novor-style mass alignment.

    Both sequences are walked in parallel by cumulative (prefix) mass. When the two
    prefix masses agree within `cum_mass_tol`, both pointers advance, and the pair
    counts as a match if the individual residue masses also agree within
    `aa_mass_tol`. Otherwise the pointer with the smaller prefix mass advances.

    Args:
        target: reference peptide, one residue per character.
        predicted: predicted peptide, one residue per character.
        cum_mass_tol: tolerance for aligning prefix masses (same units as the masses).
        aa_mass_tol: tolerance for calling two aligned residues equal.
        mass_fn: callable mapping one residue character to its mass; defaults to
            ``Residual_seq(aa).mass``.

    Returns:
        Number of matched residues (0 if either sequence is empty).
    """
    if mass_fn is None:
        def mass_fn(aa):
            return Residual_seq(aa).mass

    target_mass = np.array([mass_fn(aa) for aa in target], dtype=float)
    predicted_mass = np.array([mass_fn(aa) for aa in predicted], dtype=float)
    target_mass_cum = np.cumsum(target_mass)
    predicted_mass_cum = np.cumsum(predicted_mass)

    num_match = 0
    i = j = 0
    while i < len(target_mass) and j < len(predicted_mass):
        if abs(target_mass_cum[i] - predicted_mass_cum[j]) < cum_mass_tol:
            if abs(target_mass[i] - predicted_mass[j]) < aa_mass_tol:
                num_match += 1
            i += 1
            j += 1
        elif target_mass_cum[i] < predicted_mass_cum[j]:
            i += 1
        else:
            j += 1
    return num_match

In [None]:
# Sanity check of match_AA_novor: 'GV' in the prediction stands in for 'R' in the target
# (their masses are compared in the cells below).
match_AA_novor('RIVAPPGGR','GVIVAPPGGR')

In [None]:
done_perplexity

In [None]:
# NOTE(review): reads row 0, not the row `i` of the spectrum whose beams are being inspected — TODO confirm.
charge, precursor_mz = small_spec[['Charge','m/z [Da]']].iloc[0]

In [None]:
Residual_seq('GV').mass

In [None]:
Residual_seq('R').mass

In [None]:
small_spec

In [None]:
# 10 ppm precursor-mass filter again. seq_mass is recomputed from the current `seqs`, so the two
# masks agree, but charge/precursor_mz now come from row 0 (cell above).
precursor_mass = Ion.precursorion2mass(precursor_mz, charge)
seq_mass = np.array([Residual_seq(seq).mass for seq in seqs])
done_perplexity = done_perplexity[np.abs(seq_mass-precursor_mass)/precursor_mass<10e-6]
seqs=seqs[np.abs(seq_mass-precursor_mass)/precursor_mass<10e-6]

In [None]:
np.exp(done_perplexity)

In [None]:
seqs

In [None]:
done_perplexity

In [None]:
seqs

In [None]:
# FIXME(review): `seq` is not defined by any cell here (comprehension variables do not leak);
# probably meant `seqs`. torch.concat also fails once done_perplexity is a numpy array.
seq[list(np.argsort(-torch.concat(done_perplexity).squeeze(1).numpy()))]

In [None]:
# NOTE(review): only works while done_perplexity is still a list of tensors.
list(np.argsort(-torch.concat(done_perplexity).squeeze(1).numpy()))

In [None]:
np.argsort(-torch.concat(done_perplexity).squeeze(1).numpy())

In [None]:
np.sort(done_perplexity,axis=0)

In [None]:
# FIXME(review): `seq` is undefined here as well.
''.join([reverse_dictionary[i] for i in seq[1:-1].numpy()])

In [None]:
torch.concat(done_perplexity)

In [None]:
done_perplexity

In [None]:
# Positions of beams whose newest token is <c_term> (done_id is not used by any later cell).
done_id = np.argwhere((aa==dictionary['<c_term>']).cpu())

In [None]:
# FIXME(review): `done_num` is never defined — NameError on a fresh kernel.
k-=done_num

In [None]:
k

In [None]:
# FIXME(review): incomplete cell — `k = ` has no right-hand side, so this is a SyntaxError.
# Also np.argwhere returns a numpy array, which has no `.nelement()` (that is a torch method; numpy uses `.size`).
if np.argwhere((aa_index==dictionary['<c_term>']).cpu()).nelement() != 0:
    k = 

In [None]:
# Manual copy of the beam re-indexing step (hard-codes a 30-token vocabulary).
tgt_index = tgt_index[torch.div(aa_index, 30, rounding_mode='floor').view(-1)]
perplexity = perplexity.view(-1,1)
tgt_index = torch.concat([tgt_index,aa.view(-1,1)],dim=-1)
print(tgt_index)
print(perplexity)

In [None]:
tgt_index

In [None]:
labels

In [None]:
aa_index==dictionary['<c_term>']

In [None]:
# FIXME(review): incomplete `if` (no colon, no body) — SyntaxError as written.
if np.argwhere((aa_index==dictionary['<c_term>']).cpu()).nelement() != 0

In [None]:
dictionary['<c_term>']

In [None]:
# Prototype of a batched beam-search step (batch_id offsets each batch's beams in tgt_index).
# FIXME(review): `batch_id` is first defined in the next cell, so this fails on a fresh kernel.
perplexity = perplexity + probability
perplexity = perplexity.view(1,-1)
perplexity, aa_index = torch.topk(perplexity,k)
perplexity = perplexity.view(-1,1)
aa_index = aa_index.view(-1)
tgt_index = torch.concat([tgt_index[torch.div(aa_index, 30, rounding_mode='floor').view(-1)+batch_id],(aa_index%30).view(-1,1)],dim=-1)

In [None]:
# Batched setup: batch_size*beam_size rows with zero scores, all starting at token 21
# (presumably the <n_term> index — TODO use dictionary['<n_term>']). batch_id repeats each
# batch's row offset (b*beam_size) beam_size times.
batch_size = 1
perplexity = torch.zeros((batch_size*beam_size,1)).cuda()
tgt_index = torch.full((batch_size*beam_size,1),21).cuda()
batch_id = torch.repeat_interleave(torch.LongTensor([i*beam_size for i in range(batch_size)]),beam_size).cuda()

In [None]:
# First step: all rows of a batch start identical, so only the first row of each batch
# ([::beam_size]) supplies the top-beam_size first tokens.
output = model.output_ffn(model.decoder(tgt_index=tgt_index, memory=memory))
perplexity = perplexity + torch.log_softmax(output[:,-1],-1)
perplexity, aa_index = torch.topk(perplexity, beam_size, dim=-1)
perplexity = perplexity[::beam_size].reshape((-1,1))
aa_index = aa_index[::beam_size].reshape((-1,1))
tgt_index = torch.concat([tgt_index,aa_index],dim=-1)

In [None]:
tgt_index

In [None]:
# Decoding steps 2..10 with no <c_term> handling.
# NOTE(review): the loop variable reuses `i`, overwriting the spectrum index other cells read.
for i in range(2,11):
    output = model.output_ffn(model.decoder(tgt_index=tgt_index, memory=memory))[:,-1,:]
    perplexity = perplexity + torch.log_softmax(output,-1)
    perplexity = perplexity.view(batch_size,-1)
    perplexity, aa_index = torch.topk(perplexity,beam_size)
    perplexity = perplexity.view(-1,1)
    aa_index = aa_index.view(-1)
    tgt_index = torch.concat([tgt_index[torch.div(aa_index, 30, rounding_mode='floor').view(-1)+batch_id],(aa_index%30).view(-1,1)],dim=-1)
print(tgt_index)
print(torch.exp(perplexity))

In [None]:
labels

In [None]:
# One more step, this time passing node_mask as memory_key_padding_mask (earlier steps did not).
output = model.output_ffn(model.decoder(memory_key_padding_mask=node_mask, tgt_index=tgt_index, memory=memory))[:,-1]
perplexity = perplexity + torch.log_softmax(output,-1)
perplexity = perplexity.view(batch_size,-1)
perplexity, aa_index = torch.topk(perplexity,beam_size)
perplexity = perplexity.view(-1,1)
aa_index = aa_index.view(-1)
tgt_index = torch.concat([tgt_index[torch.div(aa_index, 30, rounding_mode='floor').view(-1)+batch_id],(aa_index%30).view(-1,1)],dim=-1)
print(tgt_index)
print(perplexity)

In [None]:
labels

In [1]:
import torch
import torch.nn as nn

In [None]:
# Scratch section unrelated to the evaluation above: requires_grad experiments on small modules.
# NOTE(review): execution counts in this section are out of order (e.g. In [40] before In [22]).
a=nn.LayerNorm(10)

In [None]:
a.bias.requires_grad = False

In [None]:
# Print the LayerNorm's parameters (weight and the now-frozen bias).
for i in a.parameters():
    print(i)

In [None]:
a=nn.Sequential(nn.LayerNorm(10),nn.Linear(10,20))

In [None]:
# Binds `i` to the first module yielded by a.modules(), then stops.
for i in a.modules():
    break

In [3]:
a=torch.rand(10,20)

In [6]:
b=torch.ones(20)

In [9]:
b.requires_grad

False

In [40]:
# Scratch arithmetic for scaling constants (purpose not recorded here).
(2*6)**0.25

1.8612097182041991

In [22]:
2**-0.5

0.7071067811865476

In [41]:
import math

In [42]:
# sqrt(6 / (fan_in + fan_out)) with 512 x 512 — the form of the Xavier/Glorot uniform bound.
math.sqrt(6/(512+512))

0.07654655446197431

In [52]:
(2*6)**-0.25

0.537284965911771

In [47]:
1/math.sqrt(2)

0.7071067811865475