In [3]:
import hydra
import wandb
import torch
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
import genova
import numpy as np
import pandas as pd
from omegaconf import OmegaConf, open_dict
from genova.utils.BasicClass import Residual_seq
from torch.utils.data import DataLoader

In [2]:
from itertools import combinations_with_replacement
# Map every combination (with replacement) of 1 to 6 entries of the residue
# list, as a tuple, to the mass reported by Residual_seq for that tuple.
aalist = Residual_seq.output_aalist()
aa_datablock_dict = {
    combo: Residual_seq(combo).mass
    for length in range(1, 7)
    for combo in combinations_with_replacement(aalist, length)
}

In [3]:
# Initialise Hydra against the 'configs' directory and compose 'config.yaml'.
# NOTE(review): hydra.initialize is expected to fail if run twice in the same
# kernel without clearing the global Hydra state — confirm before re-running.
hydra.initialize('configs')
cfg = hydra.compose('config.yaml')
# open_dict lets the assignments below set keys on the composed config.
with open_dict(cfg):
    cfg.task = 'optimum_path'
    cfg.wandb.project = 'optimum_path'

In [4]:
# Read the dataset index, using the 'Spec Index' column as the row index.
spec_header = pd.read_csv('/home/z37mao/genova_dataset_index.csv',low_memory=False,index_col='Spec Index')
# Keep only the rows whose 'MSGP File Name' is '1_3.msgp'.
spec_header = spec_header[spec_header['MSGP File Name']=='1_3.msgp']
# Optional extra filter on 'Node Number' (currently disabled).
#spec_header = spec_header[spec_header['Node Number']<=512]

In [5]:
spec_header.columns

Index(['PSMs Peptide ID', 'Annotated Sequence', 'Modifications',
       'Master Protein Accessions', 'Protein Accessions', 'Charge',
       'DeltaScore', 'DeltaCn', 'Rank', 'Search Engine Rank', 'm/z [Da]',
       'MH+ [Da]', 'Theo. MH+ [Da]', 'DeltaM [ppm]', 'Deltam/z [Da]',
       'Intensity', 'Activation Type', 'NCE [%]', 'MS Order',
       'Isolation Interference [%]', 'Ion Inject Time [ms]', 'RT [min]',
       'First Scan', 'Master Scan(s)', 'Spectrum File', 'File ID.1', 'XCorr',
       'Percolator q-Value', 'Percolator PEP', 'Percolator SVMScore',
       'MGFS Experiment Name', 'MGFS_Datablock_Pointer',
       'MGFS_Datablock_Length', 'Last Scan', 'Peptides Matched',
       'Identifying Node', 'PSM Ambiguity', 'Node Number', 'Relation Num',
       'Edge Num', 'MSGP File Name', 'MSGP Datablock Pointer',
       'MSGP Datablock Length', 'Experiment Name', 'Raw File ID',
       'Spectrum ID'],
      dtype='object')

In [5]:
# Build the Genova task from the composed config and the residue-mass table.
# The second argument is a directory path — presumably where checkpoints are
# saved; TODO confirm against genova.task.Task.
task = genova.task.Task(cfg,'/home/z37mao/Genova/save', aa_datablock_dict=aa_datablock_dict, distributed=False)

In [6]:
# The same index frame and directory are passed twice — presumably once for
# training and once for evaluation; verify against Task.initialize.
task.initialize(spec_header,'/home/z37mao/',spec_header,'/home/z37mao/')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgenova[0m (use `wandb login --relogin` to force relogin)


In [7]:
# task.train() yields (loss_train, total_step) pairs; after each one, run an
# evaluation pass and print the step, the training loss, and the evaluation
# loss divided by total_seq_len.
# NOTE(review): per the recorded output, loss_train is printed as a tensor that
# still carries a grad_fn.
for loss_train, total_step in task.train():
    loss_eval, total_seq_len = task.eval()
    print(total_step, loss_train, loss_eval/total_seq_len)

1000 tensor(1.3872, device='cuda:0', grad_fn=<DivBackward0>) tensor(1.0219, device='cuda:0')
2000 tensor(0.9709, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.9210, device='cuda:0')
3000 tensor(0.9101, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.7885, device='cuda:0')
4000 tensor(0.8814, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.7286, device='cuda:0')


KeyboardInterrupt: 

In [None]:
# Run a single evaluation pass; task.eval() returns a pair that is unpacked
# into the cumulative loss and total_seq_len.
loss_cum, total_seq_len = task.eval()

In [None]:
# Manual training setup (model, data pipeline, loss, optimizer, AMP scaler)
# used by the hand-written training cells below.
model = genova.models.Genova(cfg).to('cuda')
ds = genova.data.GenovaDataset(cfg,spec_header=spec_header,dataset_dir_path='/home/z37mao/', aa_datablock_dict=aa_datablock_dict)
# NOTE(review): the meaning of 0.95 and [0,128,256,512] is defined by
# GenovaBatchSampler — presumably a GPU-memory fraction and size bucket
# boundaries; confirm against its signature.
sampler = genova.data.GenovaBatchSampler(cfg,'cuda',0.95,spec_header,[0,128,256,512], model)
collate_fn = genova.data.GenovaCollator(cfg)
dl = DataLoader(ds, batch_sampler=sampler, collate_fn=collate_fn, pin_memory=True, num_workers=4, prefetch_factor=4)
# Wrap the loader; presumably this moves batches to 'cuda' ahead of use — confirm.
dl = genova.data.DataPrefetcher(dl,'cuda')
# The training cells pass log_softmax output to this loss as its first argument.
loss_fn = torch.nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.Adam(model.parameters(),lr=2e-4)
scaler = GradScaler()

In [None]:
optimizer.state_dict()

In [None]:
def train(dl,loss_fn,optimizer,scaler,model,num_epochs=40,log_interval=100):
    """Train ``model`` under autocast and periodically yield the mean loss.

    Args:
        dl: Iterable of ``(encoder_input, decoder_input, graph_probability,
            label, label_mask)`` batches; it is iterated once per epoch.
        loss_fn: Callable applied to the masked log-softmax output and the
            masked labels; its result is back-propagated.
        optimizer: Object whose ``zero_grad`` is called before every batch and
            which is stepped through ``scaler``.
        scaler: Object providing ``scale``, ``step`` and ``update``
            (e.g. ``GradScaler``).
        model: Callable accepting ``encoder_input``, ``decoder_input`` and
            ``graph_probability`` keyword arguments.
        num_epochs: Number of passes over ``dl``.
        log_interval: Number of optimisation steps averaged per yielded value.

    Yields:
        ``(mean_loss, total_step)`` after every ``log_interval`` steps.
        ``mean_loss`` is a detached tensor averaging exactly the last
        ``log_interval`` losses. A trailing partial window is not reported.

    Raises:
        AssertionError: If a batch produces a NaN loss.
    """
    total_step = 0
    loss_cum = 0
    for epoch in range(0, num_epochs):
        print('new epoch')
        for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
            optimizer.zero_grad()
            with autocast():
                output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
                output = output.log_softmax(-1)
                loss = loss_fn(output[label_mask],label[label_mask])
            # NaN never compares equal to anything (including itself), so an
            # `!= float('nan')` check can never fail; use isnan instead.
            assert not torch.isnan(loss).any(), 'loss is NaN'
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # Detach so the running sum does not keep autograd history alive.
            loss_cum = loss_cum + loss.detach()
            total_step += 1
            # Report only after the current step has been accumulated, so the
            # average covers exactly `log_interval` losses and none is dropped.
            if total_step % log_interval == 0:
                yield loss_cum/log_interval, total_step
                loss_cum = 0

In [None]:
# Debug cell: run the forward pass and loss computation on the first batch
# only, leaving the batch tensors, `output` and `loss` in the kernel for the
# inspection cells below. No backward pass or optimizer step is performed.
for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
    optimizer.zero_grad()
    with autocast():
        output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
        output = output.log_softmax(-1)
        loss = loss_fn(output[label_mask],label[label_mask])
    break

In [None]:
# Fetch the dataset item for the first index entry; here the item unpacks into
# three values.
spec, graph_label, node_mass = ds[spec_header.index[0]]

In [None]:
# Build a (node_num, node_num) boolean mask. For each node x, searchsorted
# gives the first position y whose mass is at least
# node_mass[x] + (largest mass in aa_datablock_dict) + 0.04, and every column
# from y onward is set to True in row x.
# NOTE(review): assumes node_mass is a sorted 1-D array (searchsorted requires
# sorted input) — confirm.
node_num = node_mass.size
edge_mask = torch.zeros(node_num,node_num,dtype=bool)
for x,y in enumerate(node_mass.searchsorted(node_mass+max(aa_datablock_dict.values())+0.04)):
    edge_mask[x,y:] = True

In [None]:
# Additionally mark every position where spec['rel_input']['dist'] is non-zero.
edge_mask = torch.logical_or(edge_mask,spec['rel_input']['dist']!=0)

In [None]:
# Matrix-multiply graph_label with the integer edge mask and flag the non-zero
# entries of the product.
trans_mask=((graph_label@edge_mask.int())!=0).bool()

In [None]:
# Convert the boolean mask to additive form: 0.0 where True, -inf where False.
# NOTE(review): this overwrites trans_mask with a float tensor, so re-running
# this cell alone would pass a non-boolean condition to torch.where.
trans_mask = torch.where(trans_mask,0.0,-float('inf'))

In [None]:
trans_mask

In [None]:
decoder_input['trans_mask'].squeeze(-1)[label_mask]

In [None]:
label[label_mask]

In [None]:
output[label_mask]

In [None]:
loss

In [None]:
# Debug cell: forward pass and loss computation on the first batch only.
# This was a pasted copy of the train() body. It used `yield` at module level
# (a SyntaxError outside a function), and everything after `break` was
# unreachable. Only the part that could ever run is kept; use train() for the
# actual optimisation loop.
total_step = 1
for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
    optimizer.zero_grad()
    with autocast():
        output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
        output = output.log_softmax(-1)
        loss = loss_fn(output[label_mask],label[label_mask])
    break

In [None]:
# Create the training generator; train() contains `yield`, so no training work
# happens until the generator is iterated.
a=train(dl,loss_fn,optimizer,scaler,model)

In [None]:
# Drive the training generator and print each reported average loss as a
# Python number alongside the step count.
for loss_average, total_step in a:
    print(total_step, loss_average.item())

In [None]:
# Pull only the first batch from the loader so its tensors can be inspected.
for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
    break

In [None]:
graph_probability.shape

In [None]:
decoder_input['trans_mask'].shape

In [None]:
label_mask.sum()

In [None]:
# Create a fresh training generator (replaces any previous one bound to `a`).
a=train(dl,loss_fn,optimizer,scaler,model)

In [None]:
# Drive the training generator and print each (average loss, step) pair.
# Note that this rebinds the global name `loss` to the yielded average.
for loss, total_step in a:
    #loss, total_step = next(a)
    print(loss, total_step)

In [None]:
model.state_dict()

In [None]:
# Pull only the first batch from the loader for the manual forward pass below.
for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
    break

In [None]:
# Forward pass and loss on the batch currently held in the kernel, run under
# autocast as in train(). No gradients are zeroed or applied here.
with autocast():
    output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
    output = output.log_softmax(-1)
    loss = loss_fn(output[label_mask],label[label_mask])

In [None]:
loss

In [None]:
import os

In [None]:
# Check that the save directory exists. The original `os.(...)` was a
# SyntaxError; os.path.exists is the call used by the neighbouring cell.
if os.path.exists('/home/z37mao/genova/save'):
    print('kfjsadlkf')

In [None]:
# Check whether a given file name exists inside the save directory.
# NOTE(review): this path uses lowercase 'genova', while the Task cell uses
# '/home/z37mao/Genova/save' — confirm which directory is intended.
os.path.exists(os.path.join('/home/z37mao/genova/save','fjklsfj.pt'))

In [None]:
# Write a test checkpoint holding the model and optimizer state dicts.
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, '/home/z37mao/genova/save/test.pt')

In [None]:
# Read the test checkpoint back and display it to check the round trip.
torch.load('/home/z37mao/genova/save/test.pt')

In [None]:
# `DDP` was never imported in this notebook, so the bare call raised NameError.
from torch.nn.parallel import DistributedDataParallel as DDP
# NOTE(review): DistributedDataParallel needs a process group set up via
# torch.distributed.init_process_group(...) first; no such call is visible in
# this notebook, so confirm it happens before this cell is run.
DDP(model,device_ids=[0])

In [None]:
# Pull only the first batch from the loader so encoder_input can be inspected.
for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
    break

In [None]:
encoder_input['path_input']['dist'][0]

In [None]:
ds[spec_header.index[0]][0]

In [None]:
from itertools import combinations_with_replacement
aalist = Residual_seq.output_aalist()
# Collect the mass of every combination (with replacement) of 1 to 6 entries
# of the residue list, then reduce to the sorted unique values.
all_edge_mass = np.unique(np.array([
    Residual_seq(combo).mass
    for length in range(1, 7)
    for combo in combinations_with_replacement(aalist, length)
]))

In [None]:
from itertools import combinations_with_replacement
# Rebuild the lookup from each combination (with replacement) of 1 to 6
# entries of the residue list, as a tuple, to its Residual_seq mass.
aalist = Residual_seq.output_aalist()
aa_datablock_dict = {
    combo: Residual_seq(combo).mass
    for length in range(1, 7)
    for combo in combinations_with_replacement(aalist, length)
}

In [None]:
max(aa_datablock_dict.values())

In [None]:
# Fetch the dataset item for the first index entry.
# NOTE(review): the following cells index this item by key (spec['node_mass']),
# while an earlier cell unpacks the same lookup into three values — the two
# usages look like different versions of the dataset class; confirm.
spec = ds[spec_header.index[0]]

In [None]:
# Number of entries in the item's 'node_mass' array.
node_num = spec['node_mass'].size

In [None]:
# Build a (node_num, node_num) boolean mask. For each node x, every column
# from the first position whose mass is at least
# node_mass[x] + all_edge_mass[-1] + 0.04 is set in row x.
# all_edge_mass comes from np.unique, so its last element is the largest mass.
# NOTE(review): assumes spec['node_mass'] is sorted, as searchsorted requires.
edge_mask = torch.zeros(node_num,node_num,dtype=bool)
for x,y in enumerate(spec['node_mass'].searchsorted(spec['node_mass']+all_edge_mass[-1]+0.04)):
    edge_mask[x,y:] = 1

In [None]:
# Additionally mark every position where spec['rel_input']['dist'] is non-zero.
edge_mask = torch.logical_or(edge_mask,spec['rel_input']['dist']!=0)

In [None]:
edge_mask

In [None]:
# Matrix-multiply graph_label with the integer edge mask and cast the product
# to bool (non-zero becomes True).
# NOTE(review): graph_label is only assigned in later cells of this section,
# so this cell relies on out-of-order execution or on an earlier binding.
b=(graph_label@edge_mask.int()).bool()

In [None]:
b

In [None]:
# Display the additive form of the mask: 0.0 where b is True, -inf elsewhere.
torch.where(b,0.0,-float('inf'))

In [None]:
# Take the transpose of the item's 'graph_label' entry.
graph_label = spec['graph_label'].T

In [None]:
# Keep only the rows that have at least one truthy entry along the last axis.
graph_label = graph_label[torch.any(graph_label,-1)]

In [None]:
graph_label

In [None]:
spec['graph_label'].bool().T

In [None]:
# Scratch value for the assert experiment below (rebinds the global `a`).
a='node'

In [None]:
# Scratch value for the assert experiment below.
c=None

In [None]:
# Passes when c is truthy or a equals 'node'; with c=None and a='node' the
# second operand makes it pass.
assert c or a=='node'

In [6]:
import pandas as pd

In [7]:
# Reload the full, unfiltered dataset index, indexed by 'Spec Index'.
spec_header = pd.read_csv('/home/z37mao/genova_dataset_index.csv',low_memory=False,index_col='Spec Index')

In [8]:
spec_header

Unnamed: 0_level_0,PSMs Peptide ID,Annotated Sequence,Modifications,Master Protein Accessions,Protein Accessions,Charge,DeltaScore,DeltaCn,Rank,Search Engine Rank,...,PSM Ambiguity,Node Number,Relation Num,Edge Num,MSGP File Name,MSGP Datablock Pointer,MSGP Datablock Length,Experiment Name,Raw File ID,Spectrum ID
Spec Index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Cerebellum:F13.7:37229,50988685,YNPENLATLER,,,Q9DBZ5,2,,,,,...,,143,36397,163962,15_0.msgp,0,832578,Cerebellum,F13.7,37229
Cerebellum:F13.13:9789,54248074,AASDPNPAEPAR,,,Q8R1C6,2,,,,,...,,56,4301,23019,15_0.msgp,832578,110004,Cerebellum,F13.13,9789
Cerebellum:F7.1:41688,286333,EALGGPAWDYR,,,Q9Z329,2,,,,,...,,117,22123,119537,15_0.msgp,942582,606525,Cerebellum,F7.1,41688
Cerebellum:F11.2:82491,31864240,DIVPGDIVEIAVGDK,,,O55143,3,,,,,...,,62,7281,24973,15_0.msgp,1549107,131192,Cerebellum,F11.2,82491
Cerebellum:F8.14:72416,22997133,EQLQDmGLIDLFSPEK,,,P32261,2,,,,,...,,71,15622,38849,15_0.msgp,1680299,184789,Cerebellum,F8.14,72416
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Cerebellum:F8.13:70773,22388500,mVVTLTFGDIVAVR,,,P51880,2,,,,,...,,123,31184,79596,20_15.msgp,2230317852,422790,Cerebellum,F8.13,70773
Cerebellum:F11.3:53756,32323483,GPADTGFLNQGDTWSSPR,,,Q5DTT2,2,,,,,...,,111,34391,121662,20_15.msgp,2230740642,612327,Cerebellum,F11.3,53756
Cerebellum:F8.23:78710,28813254,FPPFFTLQPNVDTR,,,Q9CQ80,2,,,,,...,,163,66467,240161,20_15.msgp,2231352969,1196705,Cerebellum,F8.23,78710
Cerebellum:F8.20:31518,26469489,mNINGQWEGEVNGR,,,P47941,2,,,,,...,,84,16832,57983,20_15.msgp,2232549674,287927,Cerebellum,F8.20,31518


In [13]:
# Shuffle all rows (frac=1) with a fixed seed and keep the result in `a`.
a=spec_header.sample(frac=1,random_state=0)

In [11]:
# Display a seeded shuffle of all rows; the same random_state gives the same order.
spec_header.sample(frac=1,random_state=0)

Unnamed: 0_level_0,PSMs Peptide ID,Annotated Sequence,Modifications,Master Protein Accessions,Protein Accessions,Charge,DeltaScore,DeltaCn,Rank,Search Engine Rank,...,PSM Ambiguity,Node Number,Relation Num,Edge Num,MSGP File Name,MSGP Datablock Pointer,MSGP Datablock Length,Experiment Name,Raw File ID,Spectrum ID
Spec Index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Cerebellum:F12.10:52794,39882814,TYSWDNAQVILAGNK,,,P62823,2,,,,,...,,267,158253,438260,19_7.msgp,4154829323,2359243,Cerebellum,F12.10,52794
Cerebellum:F8.16:33507,23898832,VQISPDSGGLPER,,,Q3U0V1,2,,,,,...,,121,25290,115681,12_7.msgp,2074650244,589351,Cerebellum,F8.16,33507
Cerebellum:F8.23:86350,28874403,LVQIEYALAAVAGGAPSVGIK,,,P49722,3,,,,,...,,493,658867,1165369,23_3.msgp,11111018265,6452749,Cerebellum,F8.23,86350
Cerebellum:F8.15:58395,23491358,LEPAFLSGLR,,,Q80YV3,2,,,,,...,,54,3498,11018,6_15.msgp,327975046,64944,Cerebellum,F8.15,58395
Cerebellum:F13.14:19880,54858015,EHDPVGQMVNNPK,,,Q9ERK4,2,,,,,...,,86,14294,59594,16_12.msgp,162623782,294135,Cerebellum,F13.14,19880
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Hela:F3.8:6524,3845409,VQSGNINAAK,,P49368,P49368,2,0.6176,0.0,1.0,1.0,...,Unambiguous,73,5570,15112,29_7.msgp,2554365801,85870,Hela,F3.8,6524
Hela:F4.5:18440,585438,TAVDSGIPLLTNFQVTK,,P31327,P31327,2,0.8398,0.0,1.0,1.0,...,Unambiguous,149,54945,78054,26_12.msgp,13952170504,451280,Hela,F4.5,18440
Hela:F1.9:9442,15067615,mPEMHFK,M1(Oxidation),Q09666,Q09666,3,0.3471,0.0,1.0,1.0,...,Unambiguous,95,7305,20692,31_14.msgp,302159330,122221,Hela,F1.9,9442
Cerebellum:F8.19:39970,25881401,ILLTEPPmNPTK,,,P61161,2,,,,,...,,127,27338,122274,11_6.msgp,302808274,618952,Cerebellum,F8.19,39970


In [14]:
# Shuffle the already-shuffled frame again with the same seed. Per the recorded
# output, the resulting order differs from the single shuffle above.
a.sample(frac=1,random_state=0)

Unnamed: 0_level_0,PSMs Peptide ID,Annotated Sequence,Modifications,Master Protein Accessions,Protein Accessions,Charge,DeltaScore,DeltaCn,Rank,Search Engine Rank,...,PSM Ambiguity,Node Number,Relation Num,Edge Num,MSGP File Name,MSGP Datablock Pointer,MSGP Datablock Length,Experiment Name,Raw File ID,Spectrum ID
Spec Index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Hela:F1.22:17777,19393508,IIDVSTVK,,Q9Y3B8,Q9Y3B8,2,0.3274,0.0,1.0,1.0,...,Unambiguous,94,6360,11007,31_3.msgp,236025336,76872,Hela,F1.22,17777
Cerebellum:F8.17:19369,24436617,VFIAQSR,,,Q8BHN5,2,,,,,...,,74,4444,9399,6_1.msgp,132304570,61650,Cerebellum,F8.17,19369
Hela:F1.8:24129,14869521,MmLDDIVSR,M2(Oxidation),Q92945,Q92945,2,0.0286,0.0,1.0,1.0,...,Unambiguous,67,5430,30286,29_10.msgp,778607241,146200,Hela,F1.8,24129
Cerebellum:F7.7:59751,4082932,NATLLFPESIR,,,Q5SVR0,2,,,,,...,,55,5044,16154,6_14.msgp,692569555,86548,Cerebellum,F7.7,59751
Cerebellum:F7.21:68547,12493569,LLAESVTEVEIFGK,,,Q9JHU4,2,,,,,...,,141,39463,125817,20_2.msgp,4942360751,651098,Cerebellum,F7.21,68547
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Hela:F1.2:10519,12721102,TQPTSLPK,,P27816,P27816,2,0.3981,0.0,1.0,1.0,...,Unambiguous,75,3965,7306,32_10.msgp,587064553,51851,Hela,F1.2,10519
Cerebellum:F13.3:23343,49182889,TSPSEEYWR,,,P28571,2,,,,,...,,114,18184,90477,9_12.msgp,1246872730,463867,Cerebellum,F13.3,23343
Cerebellum:F7.23:36924,13474391,ImNTFSVmPSPK,,,Q7TMM9; Q9CWF2; Q922F4,2,,,,,...,,175,54922,247661,7_6.msgp,694514044,1247063,Cerebellum,F7.23,36924
Cerebellum:F10.3:25006,30046520,FEAPLFNAR,,,Q91VD9,2,,,,,...,,90,10415,38592,8_12.msgp,1326487701,196014,Cerebellum,F10.3,25006
