In [3]:
import hydra
import wandb
import torch
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
import genova
import numpy as np
import pandas as pd
from omegaconf import OmegaConf, open_dict
from genova.utils.BasicClass import Residual_seq
from torch.utils.data import DataLoader

In [2]:
from itertools import combinations_with_replacement

# Lookup table: every combination (with replacement) of 1 to 6 residues from
# the amino-acid list, mapped to the mass reported by Residual_seq.
aalist = Residual_seq.output_aalist()
aa_datablock_dict = {}
for length in range(1, 7):
    aa_datablock_dict.update(
        (combo, Residual_seq(combo).mass)
        for combo in combinations_with_replacement(aalist, length)
    )

In [3]:
# Initialise Hydra with the 'configs' directory and compose 'config.yaml'.
# NOTE(review): Hydra can normally be initialised only once per process, so
# re-running this cell in the same kernel may fail — confirm.
hydra.initialize('configs')
cfg = hydra.compose('config.yaml')
# Override the task name and the wandb project for this run; the assignments
# are made inside open_dict (presumably so keys may be added to a struct-mode
# config — confirm).
with open_dict(cfg):
    cfg.task = 'optimum_path'
    cfg.wandb.project = 'optimum_path'

In [4]:
# Read the dataset index (indexed by 'Spec Index') and keep only the rows
# that belong to the '1_3.msgp' file.
spec_header = (
    pd.read_csv(
        '/home/z37mao/genova_dataset_index.csv',
        low_memory=False,
        index_col='Spec Index',
    )
    .loc[lambda frame: frame['MSGP File Name'] == '1_3.msgp']
)
#spec_header = spec_header[spec_header['Node Number']<=512]

In [5]:
# Show the column labels of the index table.
spec_header.columns

Index(['PSMs Peptide ID', 'Annotated Sequence', 'Modifications',
       'Master Protein Accessions', 'Protein Accessions', 'Charge',
       'DeltaScore', 'DeltaCn', 'Rank', 'Search Engine Rank', 'm/z [Da]',
       'MH+ [Da]', 'Theo. MH+ [Da]', 'DeltaM [ppm]', 'Deltam/z [Da]',
       'Intensity', 'Activation Type', 'NCE [%]', 'MS Order',
       'Isolation Interference [%]', 'Ion Inject Time [ms]', 'RT [min]',
       'First Scan', 'Master Scan(s)', 'Spectrum File', 'File ID.1', 'XCorr',
       'Percolator q-Value', 'Percolator PEP', 'Percolator SVMScore',
       'MGFS Experiment Name', 'MGFS_Datablock_Pointer',
       'MGFS_Datablock_Length', 'Last Scan', 'Peptides Matched',
       'Identifying Node', 'PSM Ambiguity', 'Node Number', 'Relation Num',
       'Edge Num', 'MSGP File Name', 'MSGP Datablock Pointer',
       'MSGP Datablock Length', 'Experiment Name', 'Raw File ID',
       'Spectrum ID'],
      dtype='object')

In [5]:
# Build the genova Task from the composed config, a save directory and the
# residue-mass lookup, with distributed=False.
# NOTE(review): the save directory here is '.../Genova/save' while the
# checkpoint cells further down use '.../genova/save' — confirm which is meant.
task = genova.task.Task(cfg,'/home/z37mao/Genova/save', aa_datablock_dict=aa_datablock_dict, distributed=False)

In [6]:
# NOTE(review): the same (spec_header, path) pair is passed twice. If the two
# pairs are the training and evaluation inputs, evaluation runs on the
# training data — confirm against Task.initialize.
task.initialize(spec_header,'/home/z37mao/',spec_header,'/home/z37mao/')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgenova[0m (use `wandb login --relogin` to force relogin)


In [7]:
# Drive the Task's training generator. Each time it yields, run an evaluation
# pass and log the step, the training loss and the evaluation loss divided by
# total_seq_len.
for loss_train, total_step in task.train():
    loss_eval, total_seq_len = task.eval()
    print(total_step, loss_train, loss_eval/total_seq_len)

1000 tensor(1.3872, device='cuda:0', grad_fn=<DivBackward0>) tensor(1.0219, device='cuda:0')
2000 tensor(0.9709, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.9210, device='cuda:0')
3000 tensor(0.9101, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.7885, device='cuda:0')
4000 tensor(0.8814, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.7286, device='cuda:0')


KeyboardInterrupt: 

In [None]:
# Run one evaluation pass and keep both returned values for inspection.
loss_cum, total_seq_len = task.eval()

In [None]:
# Manual training setup (independent of genova.task.Task), placed on 'cuda'.
model = genova.models.Genova(cfg).to('cuda')
ds = genova.data.GenovaDataset(cfg,spec_header=spec_header,dataset_dir_path='/home/z37mao/', aa_datablock_dict=aa_datablock_dict)
# NOTE(review): 0.95 and [0,128,256,512] are not explained here — presumably a
# GPU memory fraction and node-count bucket boundaries; confirm against
# GenovaBatchSampler.
sampler = genova.data.GenovaBatchSampler(cfg,'cuda',0.95,spec_header,[0,128,256,512], model)
collate_fn = genova.data.GenovaCollator(cfg)
dl = DataLoader(ds, batch_sampler=sampler, collate_fn=collate_fn, pin_memory=True, num_workers=4, prefetch_factor=4)
# `dl` is rebound: the DataLoader is wrapped in a DataPrefetcher for 'cuda'.
dl = genova.data.DataPrefetcher(dl,'cuda')
# KLDivLoss takes log-probabilities as input; the training cells apply
# log_softmax to the model output before calling it.
loss_fn = torch.nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.Adam(model.parameters(),lr=2e-4)
# Gradient scaler used together with autocast in the training cells.
scaler = GradScaler()

In [None]:
# Inspect the optimizer's current state dict.
optimizer.state_dict()

In [None]:
def train(dl,loss_fn,optimizer,scaler,model):
    """Train with autocast/GradScaler and periodically report the mean loss.

    This is a generator: it iterates over ``dl`` for 40 epochs and, after
    every 100th optimisation step, yields the mean loss of those 100 steps.

    Args:
        dl: Iterable yielding ``(encoder_input, decoder_input,
            graph_probability, label, label_mask)`` batches.
        loss_fn: Callable taking ``(log_probabilities, target)`` and
            returning a scalar loss tensor.
        optimizer: Optimizer; ``zero_grad`` is called before every step and
            it is stepped through ``scaler``.
        scaler: Object providing ``scale``, ``step`` and ``update``
            (e.g. a ``GradScaler``).
        model: Callable accepting ``encoder_input``, ``decoder_input`` and
            ``graph_probability`` keyword arguments.

    Yields:
        tuple: ``(mean_loss, total_step)``. ``mean_loss`` is a detached
        tensor holding the average loss of the last 100 steps and
        ``total_step`` is the number of completed steps (100, 200, ...).

    Raises:
        FloatingPointError: If the loss of a step is NaN.
    """
    report_every = 100
    total_step = 0
    loss_cum = 0.0
    for epoch in range(0, 40):
        print('new epoch')
        for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
            optimizer.zero_grad()
            with autocast():
                output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
                output = output.log_softmax(-1)
                loss = loss_fn(output[label_mask],label[label_mask])
            # NaN never compares equal to anything (including itself), so a
            # `!= float('nan')` comparison can never fail; test it explicitly.
            if torch.isnan(loss).item():
                raise FloatingPointError(f'loss is NaN at step {total_step + 1}')
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # Accumulate a detached copy so the running sum does not keep
            # every step's autograd graph alive.
            loss_cum += loss.detach()
            total_step += 1
            # Report only after the 100th loss of the window has been added,
            # so the mean really covers `report_every` steps.
            if total_step % report_every == 0:
                yield loss_cum / report_every, total_step
                loss_cum = 0.0

In [None]:
# Run a single forward pass on the first batch only (the loop breaks right
# away), leaving the batch tensors, `output` and `loss` in the namespace for
# the inspection cells below. No backward pass or optimizer step is done.
for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
    optimizer.zero_grad()
    with autocast():
        output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
        output = output.log_softmax(-1)
        loss = loss_fn(output[label_mask],label[label_mask])
    break

In [None]:
# Fetch the dataset item for the first spectrum index and unpack its three
# components.
spec, graph_label, node_mass = ds[spec_header.index[0]]

In [None]:
# Square boolean mask over nodes: in row r, every column from the
# searchsorted position of (node_mass[r] + largest edge mass + 0.04) onwards
# is set to True.
# NOTE(review): assumes node_mass is a sorted 1-D array; 0.04 is presumably a
# mass tolerance — confirm.
node_num = node_mass.size
largest_edge_mass = max(aa_datablock_dict.values())
first_masked_col = node_mass.searchsorted(node_mass + largest_edge_mass + 0.04)
edge_mask = torch.zeros(node_num, node_num, dtype=bool)
for row, col_start in enumerate(first_masked_col):
    edge_mask[row, col_start:] = True

In [None]:
# Additionally mark every position where spec['rel_input']['dist'] is non-zero.
edge_mask = torch.logical_or(edge_mask,spec['rel_input']['dist']!=0)

In [None]:
# Boolean mask of the non-zero entries of graph_label @ edge_mask.
# The `!= 0` comparison already yields booleans, so `.bool()` is a no-op here.
trans_mask=((graph_label@edge_mask.int())!=0).bool()

In [None]:
# Convert the boolean mask into values: 0.0 where True, -inf where False.
# NOTE(review): `trans_mask` is rebound from a boolean to a float tensor, so
# this cell is not safe to re-run on its own output.
trans_mask = torch.where(trans_mask,0.0,-float('inf'))

In [None]:
# Display the mask built above.
trans_mask

In [None]:
# Inspect the batch's decoder_input['trans_mask'] (last axis squeezed),
# indexed by label_mask.
decoder_input['trans_mask'].squeeze(-1)[label_mask]

In [None]:
# Targets selected by label_mask (the second argument passed to loss_fn).
label[label_mask]

In [None]:
# Log-softmax model outputs selected by label_mask (the first argument passed
# to loss_fn).
output[label_mask]

In [None]:
# Display the current value of `loss`.
loss

In [None]:
# Scratch copy of the body of `train`, used to inspect a forward pass at cell
# level. `yield` is only legal inside a function, so the running average is
# printed here instead; with a top-level `yield` the cell does not compile.
total_step = 1
for epoch in range(0, 40):
    print('new epoch')
    for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
        if total_step%100 == 1: loss_cum = 0
        elif total_step%100 == 0: print(total_step, loss_cum/100)
        optimizer.zero_grad()
        with autocast():
            output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
            output = output.log_softmax(-1)
            loss = loss_fn(output[label_mask],label[label_mask])
        # Debugging stop: only the first batch of each epoch is run forward.
        # NOTE(review): everything after this `break` is unreachable; remove
        # the `break` to actually train, or drop the statements below.
        break
        # NaN never equals anything, so it has to be tested with isnan.
        assert not torch.isnan(loss).item()
        # Detach so the running sum does not retain autograd graphs.
        loss_cum += loss.detach()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_step += 1

In [None]:
# Create the training generator; no training runs until it is iterated.
a=train(dl,loss_fn,optimizer,scaler,model)

In [None]:
# Consume the training generator and log each reported (step, mean loss) pair.
for loss_average, total_step in a:
    print(total_step, loss_average.item())

In [None]:
# Pull the first batch out of `dl` and stop, leaving its five components in
# the namespace.
for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
    break

In [None]:
# Shape of the batch's graph_probability.
graph_probability.shape

In [None]:
# Shape of the batch's transition mask.
decoder_input['trans_mask'].shape

In [None]:
# Sum of label_mask (the number of selected positions if the mask is boolean).
label_mask.sum()

In [None]:
# Create a fresh training generator (a generator cannot be restarted once
# consumed); no training runs until it is iterated.
a=train(dl,loss_fn,optimizer,scaler,model)

In [None]:
# Consume the training generator and print each reported pair.
# NOTE(review): this rebinds the global `loss` that other cells inspect.
for loss, total_step in a:
    #loss, total_step = next(a)
    print(loss, total_step)

In [None]:
# Display the model's full state dict (output can be very large).
model.state_dict()

In [None]:
# Pull the first batch out of `dl` and stop, leaving its five components in
# the namespace.
for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
    break

In [None]:
# Forward pass under autocast on the batch currently held in the namespace:
# model output -> log-softmax over the last axis -> loss on the positions
# selected by label_mask.
with autocast():
    output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
    output = output.log_softmax(-1)
    loss = loss_fn(output[label_mask],label[label_mask])

In [None]:
# Display the current value of `loss`.
loss

In [None]:
import os

In [None]:
# `os.(` is a syntax error. The intended call is presumably os.path.exists,
# which the following cell uses for a file in the same directory — confirm.
if os.path.exists('/home/z37mao/genova/save'):
    print('kfjsadlkf')

In [None]:
# Check whether a file named 'fjklsfj.pt' exists in the save directory.
os.path.exists(os.path.join('/home/z37mao/genova/save','fjklsfj.pt'))

In [None]:
# Write model and optimizer state dicts into a single checkpoint file.
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, '/home/z37mao/genova/save/test.pt')

In [None]:
# Load the checkpoint back; as the cell's last expression the whole dict is
# displayed. torch.load unpickles the file, so only load checkpoints you trust.
torch.load('/home/z37mao/genova/save/test.pt')

In [None]:
# `DDP` is not imported anywhere in this notebook, so the bare call raised
# NameError; bring in PyTorch's DistributedDataParallel under that alias.
from torch.nn.parallel import DistributedDataParallel as DDP

# NOTE(review): DistributedDataParallel requires an initialised
# torch.distributed process group — confirm one is set up before running.
DDP(model,device_ids=[0])

In [None]:
# Pull the first batch out of `dl` and stop, leaving its five components in
# the namespace.
for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
    break

In [None]:
# First element of the batch's encoder_input['path_input']['dist'].
encoder_input['path_input']['dist'][0]

In [None]:
# First component of the dataset item for the first spectrum index.
ds[spec_header.index[0]][0]

In [None]:
from itertools import combinations_with_replacement

# Sorted, de-duplicated masses of every combination (with replacement) of
# 1 to 6 residues from the amino-acid list.
aalist = Residual_seq.output_aalist()
all_edge_mass = np.unique(np.array([
    Residual_seq(combo).mass
    for length in range(1, 7)
    for combo in combinations_with_replacement(aalist, length)
]))

In [None]:
from itertools import combinations_with_replacement

# Residue combination (1 to 6 residues, with replacement) -> mass.
# NOTE(review): this repeats the lookup-table cell near the top of the notebook.
aalist = Residual_seq.output_aalist()
aa_datablock_dict = {
    combo: Residual_seq(combo).mass
    for length in range(1, 7)
    for combo in combinations_with_replacement(aalist, length)
}

In [None]:
# Largest mass in the residue-combination lookup table.
max(aa_datablock_dict.values())

In [None]:
# Dataset item for the first spectrum index, kept as a single object.
# NOTE(review): an earlier cell unpacks this same lookup into three values,
# whereas the cells below index `spec` by string keys; the two usages assume
# different return types — confirm which matches the current GenovaDataset.
spec = ds[spec_header.index[0]]

In [None]:
# Number of elements in spec['node_mass'].
node_num = spec['node_mass'].size

In [None]:
# Square boolean mask over nodes: in row r, every column from the
# searchsorted position of (node mass + last entry of all_edge_mass + 0.04)
# onwards is set.
# NOTE(review): assumes spec['node_mass'] is a sorted 1-D array; 0.04 is
# presumably a mass tolerance — confirm.
node_mass_values = spec['node_mass']
first_masked_col = node_mass_values.searchsorted(node_mass_values + all_edge_mass[-1] + 0.04)
edge_mask = torch.zeros(node_num, node_num, dtype=bool)
for row, col_start in enumerate(first_masked_col):
    edge_mask[row, col_start:] = True

In [None]:
# Additionally mark every position where spec['rel_input']['dist'] is non-zero.
edge_mask = torch.logical_or(edge_mask,spec['rel_input']['dist']!=0)

In [None]:
# Display the combined edge mask.
edge_mask

In [None]:
# Boolean view of graph_label @ edge_mask (True where the product is non-zero).
# NOTE(review): uses whatever `graph_label` is in the kernel — it is assigned
# both by the dataset-unpack cell above and from spec['graph_label'] in cells
# further down, so the result depends on execution order.
b=(graph_label@edge_mask.int()).bool()

In [None]:
# Display the boolean product mask.
b

In [None]:
# Turn the boolean mask into values: 0.0 where True, -inf where False.
torch.where(b,0.0,-float('inf'))

In [None]:
# Transposed view of spec['graph_label'].
graph_label = spec['graph_label'].T

In [None]:
# Keep only the rows that contain at least one non-zero entry.
graph_label = graph_label[torch.any(graph_label,-1)]

In [None]:
# Display the filtered graph label rows.
graph_label

In [None]:
# Boolean, transposed view of the unfiltered spec['graph_label'].
spec['graph_label'].bool().T

In [None]:
# Scratch value for the assert check below.
# NOTE(review): rebinds `a`, which earlier cells use for the training generator.
a='node'

In [None]:
# Scratch value for the assert check below.
c=None

In [None]:
# Passes when `c` is truthy or `a` equals 'node'.
assert c or a=='node'