In [1]:
import genova
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from genova.utils.BasicClass import Residual_seq

In [2]:
# Load the spectrum index (indexed by 'Spec Index') and keep only the rows
# whose datablocks are stored in the '1_3.msgp' file.
spec_header = pd.read_csv(
    '/home/z37mao/genova_dataset_index.csv',
    index_col='Spec Index',
    low_memory=False,
).loc[lambda frame: frame['MSGP File Name'] == '1_3.msgp']

In [3]:
cfg = OmegaConf.load('configs/genova_dda_light.yaml')

In [60]:
import os
import gzip
import torch
import pickle
from torch.utils.data import Dataset

class GenovaDataset(Dataset):
    """Map-style dataset yielding one pre-processed spectrum per index label.

    Each item is located through ``spec_header`` (a DataFrame indexed by the
    spectrum identifier): the row provides the file name, byte offset and byte
    length of a gzip-compressed pickle inside ``dataset_dir_path``.  What is
    returned depends on ``cfg.task``:

    * ``'node_classification'``  -> ``spec`` (with a ``'graph_label'`` entry)
    * ``'sequence_generation'``  -> ``(spec, target)``
    * ``'optimum_path'``         -> ``(spec, tgt, label)``
    * ``'optimum_path_sequence'`` is not implemented.
    """

    # Longest run of residues that is merged into one block; longer runs are
    # replaced by the placeholder 'X'.
    MAX_COMBINE_LEN = 6

    def __init__(self, cfg, *, spec_header, dataset_dir_path, aa_datablock_dict = None):
        """Store the configuration and the spectrum index.

        Args:
            cfg: configuration object; only ``cfg.task`` is read here.
            spec_header: DataFrame indexed by spectrum id with the columns
                read in ``__getitem__``.
            dataset_dir_path: directory containing the MSGP files.
            aa_datablock_dict: required (truthy) for the
                ``'sequence_generation'`` and ``'optimum_path_sequence'`` tasks.

        Raises:
            AssertionError: if the task needs ``aa_datablock_dict`` and it is
                missing or empty.
        """
        super().__init__()
        self.cfg = cfg
        self.spec_header = spec_header
        self.dataset_dir_path = dataset_dir_path
        # Always define the attribute so later reads never hit AttributeError.
        self.aa_datablock_dict = aa_datablock_dict
        if cfg.task in ('sequence_generation', 'optimum_path_sequence'):
            assert aa_datablock_dict, \
                'aa_datablock_dict is required for task {!r}'.format(cfg.task)

    def __getitem__(self, idx):
        """Load the spectrum labelled ``idx`` and build the task-specific target.

        Raises:
            NotImplementedError: for the ``'optimum_path_sequence'`` task.
            ValueError: if ``cfg.task`` is not one of the known tasks.
        """
        if torch.is_tensor(idx): idx = idx.tolist()
        spec_head = dict(self.spec_header.loc[idx])
        msgp_path = os.path.join(self.dataset_dir_path, spec_head['MSGP File Name'])
        with open(msgp_path, 'rb') as f:
            f.seek(spec_head['MSGP Datablock Pointer'])
            # NOTE(security): pickle can execute arbitrary code while loading;
            # only read MSGP files that come from a trusted source.
            spec = pickle.loads(gzip.decompress(f.read(spec_head['MSGP Datablock Length'])))

        spec['node_input']['charge'] = spec_head['Charge']
        # Keep the unfiltered label: seq2seqblock needs the all-False rows to
        # know which residues must be merged.  The filtered copy (rows with at
        # least one True) is what the other tasks consume.
        full_graph_label = spec.pop('graph_label').T
        graph_label = full_graph_label[full_graph_label.any(-1)]
        node_mass = spec.pop('node_mass')
        seq = spec_head['Annotated Sequence']

        if self.cfg.task == 'node_classification':
            spec['graph_label'] = torch.any(graph_label, 0).long()
            return spec

        elif self.cfg.task == 'optimum_path_sequence':
            raise NotImplementedError

        elif self.cfg.task == 'sequence_generation':
            target = {}
            # NOTE(review): the original passed the filtered label here, which
            # made every "combine" flag False -- confirm against real data.
            seq_blocks = self.seq2seqblock(seq, full_graph_label)
            target['tgt'] = seq_blocks
            target['trans_mask'] = self.trans_mask_sequence_generation(seq_blocks, node_mass)
            return spec, target

        elif self.cfg.task == 'optimum_path':
            trans_mask = torch.Tensor(self.trans_mask_optimum_path(graph_label))
            graph_probability = torch.Tensor(self.graph_probability_gen(graph_label))
            tgt = {}
            # Decoder input is every step but the last; the label is every
            # step but the first (shifted by one).
            tgt['tgt'] = graph_probability[:-1]
            tgt['trans_mask'] = trans_mask
            return spec, tgt, graph_probability[1:]

        raise ValueError('Unknown task: {!r}'.format(self.cfg.task))

    def graph_probability_gen(self, graph_label):
        """Normalise each row of ``graph_label`` so that it sums to one."""
        graph_probability = graph_label/graph_label.sum(-1).unsqueeze(1)
        return graph_probability

    def trans_mask_sequence_generation(self,seq_blocks,node_mass):
        """Build an additive (0 / -inf) mask of shape (len(seq_blocks), node_mass.size).

        Row ``i`` (i >= 1) masks every node whose index is below the insertion
        point of the cumulative mass of the first ``i`` blocks (+0.02) in
        ``node_mass``; row 0 only masks node 0.
        """
        # 'L' is replaced by 'I' before the mass lookup (presumably because the
        # two residues are isobaric -- TODO confirm against Residual_seq).
        seq_mass = np.array([Residual_seq(seq_block.replace('L','I')).mass for seq_block in seq_blocks]).cumsum()
        trans_mask = torch.zeros((seq_mass.size,node_mass.size))
        trans_mask[0,0] = -float('inf')
        for i, board in enumerate(node_mass.searchsorted(seq_mass+0.02,side='right')[:-1],start=1):
            trans_mask[i,:board] = -float('inf')
        return trans_mask

    def trans_mask_optimum_path(self,graph_label):
        """Build an additive (0 / -inf) mask with one row per decoding step.

        The first and last rows of ``graph_label`` are dropped.  Row 0 of the
        result masks node 0 only; row ``i`` masks every node strictly before
        the highest True index of the i-th remaining label row.
        """
        graph_label = graph_label[1:-1]
        trans_mask = torch.zeros((graph_label.shape[0]+1,graph_label.shape[1]))
        trans_mask[0,0] = -float('inf')
        for i, node_pos in enumerate(graph_label,start=1):
            trans_mask[i,:torch.where(node_pos)[0].max().item()] = -float('inf')
        return trans_mask

    def seq2seqblock(self, seq, graph_label):
        """Split ``seq`` into blocks according to which label rows are empty.

        Row ``i + 1`` of ``graph_label`` is paired with ``seq[i]``.  Residues
        whose row is all-False are merged with the following residues up to
        and including the next residue whose row has a True entry.  A merged
        run longer than ``MAX_COMBINE_LEN`` residues is emitted as ``'X'``.

        A run of empty rows that reaches the end of ``graph_label`` without a
        closing non-empty row produces no block (unchanged from the original).

        Returns:
            list[str]: the sequence blocks in order.
        """
        seq_block = []
        combine_start_index = None  # start of the currently open merged run
        for i, combine_flag in enumerate(~graph_label.any(-1)[1:]):
            if combine_flag:
                if combine_start_index is None:
                    combine_start_index = i
            elif combine_start_index is None:
                seq_block.append(seq[i])
            else:
                if i + 1 - combine_start_index > self.MAX_COMBINE_LEN:
                    seq_block.append('X')
                else:
                    seq_block.append(seq[combine_start_index:i+1])
                # Close the run in both cases; the original forgot to do so
                # after emitting 'X', turning every later block into 'X'.
                combine_start_index = None
        return seq_block

    def __len__(self):
        """Number of spectra listed in ``spec_header``."""
        return len(self.spec_header)

In [193]:
# `pad` is used throughout the collator; import it with the class so the cell
# does not depend on a later cell having been run first.
from torch.nn.functional import pad


class GenovaCollator(object):
    """Collate a list of ``GenovaDataset`` items into padded batch tensors.

    Only the ``'optimum_path'`` task is implemented; the other known tasks
    raise ``NotImplementedError``.
    """

    def __init__(self, cfg):
        """Keep the configuration; ``cfg.task`` and
        ``cfg.preprocessing.edge_type_num`` are read later."""
        self.cfg = cfg

    def __call__(self, batch):
        """Collate ``batch`` (a list of dataset items) for the configured task.

        Returns (for ``'optimum_path'``):
            ``(encoder_input, decoder_input, graph_probability, label, label_mask)``

        Raises:
            NotImplementedError: for 'sequence_generation' / 'node_classification'.
            ValueError: for any other task name.
        """
        if self.cfg.task == 'optimum_path':
            spec = [record[0] for record in batch]
            tgt = [record[1] for record in batch]
            label = [record[2] for record in batch]
            encoder_input = self.encoder_collate(spec)
            decoder_input, graph_probability = self.decoder_collate(tgt)
            label, label_mask = self.label_collate(label)
            return encoder_input, decoder_input, graph_probability, label, label_mask

        elif self.cfg.task == 'sequence_generation':
            raise NotImplementedError

        elif self.cfg.task == 'node_classification':
            raise NotImplementedError

        # Previously fell through and returned None for an unknown task.
        raise ValueError('Unknown task: {!r}'.format(self.cfg.task))

    def decoder_collate(self, decoder_input):
        """Pad and stack the decoder targets and transition masks.

        Args:
            decoder_input: list of dicts with 2-D tensors ``'tgt'`` and
                ``'trans_mask'`` (padded using the shape of ``'tgt'``).

        Returns:
            ``(decoder_input, graph_probability)`` where ``decoder_input`` holds
            ``'trans_mask'`` (batch, step, node, 1) and ``'self_mask'``
            (step, step, 1), and ``graph_probability`` is (batch, step, node).
        """
        if self.cfg.task != 'optimum_path':
            raise NotImplementedError(
                'decoder_collate is not implemented for task {!r}'.format(self.cfg.task))

        tgts_list = [record['tgt'] for record in decoder_input]
        trans_mask_list = [record['trans_mask'] for record in decoder_input]
        # Plain Python ints keep the pad widths independent of numpy scalars.
        seqdblock_max = max(tgt.shape[0] for tgt in tgts_list)
        node_max = max(tgt.shape[1] for tgt in tgts_list)

        graph_probability = []
        trans_mask = []
        for tgt, mask in zip(tgts_list, trans_mask_list):
            row_pad = seqdblock_max - tgt.shape[0]
            col_pad = node_max - tgt.shape[1]
            graph_probability.append(pad(tgt, [0, col_pad, 0, row_pad]))
            # Padded node columns are masked out (-inf); padded step rows are
            # left at 0.
            mask = pad(mask, [0, col_pad], value=-float('inf'))
            trans_mask.append(pad(mask, [0, 0, 0, row_pad]))
        graph_probability = torch.stack(graph_probability)
        # Strictly-upper-triangular -inf mask: a step may not look at later steps.
        self_mask = (-float('inf')*torch.ones(seqdblock_max,seqdblock_max)) \
            .triu(diagonal=1).unsqueeze(-1)
        decoder_input = {'trans_mask': torch.stack(trans_mask).unsqueeze(-1),
                         'self_mask': self_mask}
        return decoder_input, graph_probability

    def label_collate(self, labels):
        """Zero-pad 2-D label tensors to a common shape and stack them.

        Returns:
            ``(labels, mask)``: the stacked labels (batch, step, node) and a
            bool mask (batch, step) that is False on padded steps.
        """
        if self.cfg.task != 'optimum_path':
            raise NotImplementedError(
                'label_collate is not implemented for task {!r}'.format(self.cfg.task))

        seqdblock_max = max(label.shape[0] for label in labels)
        node_max = max(label.shape[1] for label in labels)
        result = []
        result_pading_mask = torch.ones(len(labels),seqdblock_max,dtype=bool)
        for i, label in enumerate(labels):
            result_pading_mask[i, label.shape[0]:] = 0
            result.append(pad(label, [0, node_max - label.shape[1],
                                      0, seqdblock_max - label.shape[0]]))
        return torch.stack(result), result_pading_mask

    def encoder_collate(self, spec):
        """Collate the encoder side of every record into one dict.

        Each record must provide ``'node_input'``, ``'rel_input'`` and
        ``'edge_input'``.  The padding sizes come from the shape of
        ``node_input['node_sourceion']``.
        """
        node_inputs = [record['node_input'] for record in spec]
        path_inputs = [record['rel_input'] for record in spec]
        edge_inputs = [record['edge_input'] for record in spec]

        # Row 0: first dimension per record; row 1: second dimension per record.
        node_shape = np.array([node_input['node_sourceion'].shape for node_input in node_inputs]).T
        max_node = node_shape[0].max()
        max_subgraph_node = node_shape[1].max()

        node_input = self.node_collate(node_inputs, max_node, max_subgraph_node)
        path_input = self.path_collate(path_inputs, max_node, node_shape)
        edge_input = self.edge_collate(edge_inputs, max_node)
        rel_mask = self.rel_collate(node_shape, max_node)
        encoder_input = {'node_input':node_input,'path_input':path_input,
                         'edge_input':edge_input,'rel_mask':rel_mask}
        return encoder_input

    def node_collate(self, node_inputs, max_node, max_subgraph_node):
        """Zero-pad ``'node_feat'`` / ``'node_sourceion'`` and stack them.

        ``'node_feat'`` is padded on its first two dimensions only (its last
        dimension is left untouched); ``'charge'`` becomes an IntTensor.
        """
        node_feat = []
        node_sourceion = []
        charge = torch.IntTensor([node_input['charge'] for node_input in node_inputs])
        for node_input in node_inputs:
            node_num, node_subgraph_node = node_input['node_sourceion'].shape
            pad_sub = int(max_subgraph_node) - node_subgraph_node
            pad_node = int(max_node) - node_num
            node_feat.append(pad(node_input['node_feat'],
                                 [0, 0, 0, pad_sub, 0, pad_node]))
            node_sourceion.append(pad(node_input['node_sourceion'],
                                      [0, pad_sub, 0, pad_node]))
        return {'node_feat':torch.stack(node_feat),'node_sourceion':torch.stack(node_sourceion),'charge':charge}

    def path_collate(self, path_inputs, max_node, node_shape):
        """Concatenate the path (relation) inputs of the whole batch.

        The batch index is prepended to every ``'rel_coor'`` row so the
        concatenated coordinates can be flattened into one index per batch.
        """
        rel_type = torch.concat([path_input['rel_type'] for path_input in path_inputs]).squeeze(-1)
        rel_error = torch.concat([path_input['rel_error'] for path_input in path_inputs])
        # pad(..., [1, 0], value=i) adds a leading column holding the batch index.
        rel_coor = torch.concat([pad(path_input['rel_coor'],[1,0],value=i) for i, path_input in enumerate(path_inputs)]).T
        # Row 0: flat index batch*max_node**2 + coor1*max_node + coor2.
        # Row 1: the last two coordinates combined using edge_type_num as the stride.
        rel_coor_cated = torch.stack([rel_coor[0]*max_node**2+rel_coor[1]*max_node+rel_coor[2],
                                      rel_coor[-2]*self.cfg.preprocessing.edge_type_num+rel_coor[-1]])

        rel_pos = torch.concat([path_input['rel_coor'][:,-2] for path_input in path_inputs])
        dist = torch.stack([
            pad(path_input['dist'], [0, int(max_node - node_shape[0, i]),
                                     0, int(max_node - node_shape[0, i])])
            for i, path_input in enumerate(path_inputs)])

        return {'rel_type':rel_type,'rel_error':rel_error,
                'rel_pos':rel_pos,'dist':dist,
                'rel_coor_cated':rel_coor_cated,
                'max_node': max_node, 'batch_num': len(path_inputs)}

    def edge_collate(self, edge_inputs, max_node):
        """Concatenate the edge inputs of the whole batch (same layout as
        ``path_collate`` but without ``'rel_pos'`` / ``'dist'``)."""
        rel_type = torch.concat([edge_input['edge_type'] for edge_input in edge_inputs]).squeeze(-1)
        rel_error = torch.concat([edge_input['edge_error'] for edge_input in edge_inputs])
        # Leading column = batch index, as in path_collate.
        rel_coor = torch.concat([pad(edge_input['edge_coor'],[1,0],value=i) for i, edge_input in enumerate(edge_inputs)]).T
        rel_coor_cated = torch.stack([rel_coor[0]*max_node**2+rel_coor[1]*max_node+rel_coor[2],
                                      rel_coor[-1]])

        return {'rel_type':rel_type,'rel_error':rel_error,
                'rel_coor_cated':rel_coor_cated,
                'max_node': max_node, 'batch_num': len(edge_inputs)}

    def rel_collate(self, node_shape, max_node):
        """Build one (max_node, max_node, 1) additive mask per record.

        Columns below the record's node count are 0, the padded columns are -inf.
        """
        rel_masks = []
        for node_num in node_shape[0]:
            rel_mask = -np.inf * torch.ones(max_node,max_node,1)
            rel_mask[:,:int(node_num)] = 0
            rel_masks.append(rel_mask)
        rel_masks = torch.stack(rel_masks)
        return rel_masks

    def nodelabel_collate(self, node_labels_temp, max_node):
        """Zero-pad 1-D node labels to ``max_node`` and return them together
        with a bool mask that is False on the padded positions."""
        node_mask = torch.ones(len(node_labels_temp),max_node).bool()
        node_labels = []
        for i, node_label in enumerate(node_labels_temp):
            node_mask[i, node_label.shape[0]:] = 0
            node_labels.append(pad(node_label,[0,max_node-node_label.shape[0]]))
        node_labels = torch.stack(node_labels)
        return node_labels, node_mask

In [194]:
ds = GenovaDataset(cfg,spec_header=spec_header,dataset_dir_path='/home/z37mao/')
collator = GenovaCollator(cfg)

In [196]:
result = collator(batch)

In [200]:
result[2].shape

torch.Size([5, 8, 184])

In [195]:
batch = [ds[index] for index in spec_header.index[0:5]]

In [116]:
decoder_input = [record[1] for record in batch]
labels = [record[2] for record in batch]

In [79]:
shape_list = np.array([label.shape for label in labels])

In [86]:
from torch.nn.functional import pad

In [93]:
pad(label,[0,node_max-label.shape[1],0,seqdblock_max-label.shape[0]])

torch.Size([8, 184])

In [91]:
node_max-label.shape[1]

149

In [88]:
seqdblock_max = shape_list[:,0].max()
node_max = shape_list[:,1].max()

In [94]:
result = []
for label in labels:
    label = pad(label,[0,node_max-label.shape[1],0,seqdblock_max-label.shape[0]])
    result.append(label)

In [97]:
torch.stack(result).shape

torch.Size([5, 8, 184])

In [69]:
# Shape of the second record's decoder target (the original line was an
# unfinished `for` statement, which is a SyntaxError).
tgt[1]['tgt'].shape

torch.Size([6, 94])

In [70]:
tgt[0]['tgt'].shape

torch.Size([8, 184])

In [41]:
spec, tgt, graph_probability = ds['Cerebellum:F12.2:11100']

In [None]:
np.where(graph_label)[0][1:]-np.where(graph_label)[0][:-1]

In [None]:
np.where(graph_label)[0][1:]-np.where(graph_label)[0][:-1]

In [None]:
a=np.where(~graph_label)[0]
for i, flag in enumerate(np.where(~graph_label)[0]):
    if flag:
        min_index = i
    
    

In [None]:
spec

In [None]:
seq_blocks = target['tgt']
seq_mass = np.array([Residual_seq(seq_block.replace('L','I')).mass for seq_block in seq_blocks]).cumsum()

In [None]:
seq_mass

In [None]:
seq

In [None]:
np.array([1,1,1,1,1,1,0],dtype=bool)

In [None]:
def seq2seqblock(self, ):
    """Scratch prototype: split the notebook-global ``seq`` into blocks.

    Reads the notebook globals ``spec`` and ``seq``; ``self`` is unused.
    Position ``i + 1`` of ``spec['graph_label'].any(0)`` is paired with
    ``seq[i]``.  Residues whose flag is False are merged with the following
    residues up to and including the next residue whose flag is True.

    Returns:
        list[str]: the sequence blocks in order (the original built this list
        but never returned it).
    """
    seq_block = []
    combine_start_index = None  # start of the currently open merged run
    for i, combine_flag in enumerate(~spec['graph_label'].any(0)[1:]):
        if combine_flag:
            if combine_start_index is None:
                combine_start_index = i
        elif combine_start_index is None:
            seq_block.append(seq[i])
        else:
            seq_block.append(seq[combine_start_index:i+1])
            combine_start_index = None
    return seq_block

In [None]:
tgt

In [None]:
# Scratch version of the residue-grouping loop, run directly on the notebook
# globals `spec` and `seq`; the result is collected in `tgt`.
# NOTE(review): `combine_start_index` lives in the notebook namespace here, so
# a value left over from another cell changes the outcome of this loop.
tgt = []
for i, combine_flag in enumerate(~spec['graph_label'].any(0)[1:]):
    if combine_flag: 
        # Referencing the name raises NameError when no run is open yet; the
        # bare except then records the start of a new run.
        try:
            combine_start_index
        except:
            combine_start_index = i
    else:
        # With an open run: emit the merged slice and close the run.
        # Without one the slice raises NameError and the single residue is used.
        try:
            tgt.append(seq[combine_start_index:i+1])
            del(combine_start_index)
        except:
            tgt.append(seq[i])

In [None]:
tgt

In [None]:
seq[4:8]

In [None]:
combine_start_index = 1

In [None]:
combine_start_index

In [None]:
seq['memory_mask']

In [None]:
graph_label = spec['graph_label'].T
graph_propobility = graph_label/torch.where(graph_label.sum(-1)==0,1,graph_label.sum(-1)).unsqueeze(1)
graph_propobility = graph_propobility[torch.any(graph_label,-1)]

In [None]:
graph_propobility

In [None]:
a={'fsdaf':'afds'}

In [None]:
'fsdaf' in a

In [None]:
# Scratch prototype of a 0 / -inf mask over the spectrum nodes, built from the
# notebook globals `seq` and `spec`.
# `step_mass` is presumably the running prefix mass of the sequence -- TODO
# confirm against Residual_seq.  The last residue is dropped, 'L' is mapped to
# 'I', and 0.02 is subtracted from every mass.
seq_mass = genova.utils.BasicClass.Residual_seq(seq[:-1].replace('L','I')).step_mass-0.02
# One row per mass plus a leading row that stays all zero.
memory_mask = np.zeros((seq_mass.size+1,spec['node_mass'].size))
# Row i masks every node before the insertion point of seq_mass[i-1] in node_mass.
for i, board in enumerate(spec['node_mass'].searchsorted(seq_mass),start=1):
    memory_mask[i,:board] = -float('inf')

In [None]:
memory_mask = memory_mask[np.newaxis]

In [None]:
spec['node_mass']

In [None]:
np.repeat(memory_mask,cfg.decoder.num_heads,axis=0)

In [None]:
torch.where(spec['graph_label'][1])[0]

In [None]:
np.where(graph_label[0])[0]

In [None]:
spec['graph_label']/torch.where(spec['graph_label'].sum(-1)==0,1,spec['graph_label'].sum(-1)).unsqueeze(1)

In [None]:
spec['graph_label'].shape

In [None]:
for i, board in enumerate(spec['node_mass'].searchsorted(a),start=1):
    b[i,:board] = -float('inf')

In [None]:
import hydra

In [None]:
seq[:-1]

In [None]:
a=genova.utils.BasicClass.Residual_seq(seq[:-1].replace('L','I')).step_mass-0.02

In [None]:
genova.utils.BasicClass.Residual_seq(seq.replace('L','I')).step_mass-0.02

In [None]:
len('DLVILLYETALLSSGFSLEDPQTHANR')

In [None]:
genova.utils.seq

In [None]:
spec_header.loc['Cerebellum:F12.12:46468']['Annotated Sequence']

In [None]:
spec_header.index

In [None]:
len(set(spec_header.index))

In [None]:
genova.utils.BasicClass.Residual_seq('DLVLLLYDTALSSSGFSLFDPQTHNNR'.replace('L','I')).step_mass

In [None]:
genova.utils.BasicClass.Residual_seq('DLVILLYETALLSSGFSLEDPQTHANR'.replace('L','I')).step_mass

In [None]:
with open('/data/z37mao/genova_new/Cerebellum.mgfs', 'rb') as f:
    f.seek(spec_head['MSGP Datablock Pointer'])
    spec = pickle.loads(gzip.decompress(f.read(spec_head['MSGP Datablock Length'])))

In [None]:
spec_header

In [None]:
import json

In [None]:
with open('genova/utils/dictionary') as f:
    dictionary = json.load(f)

In [None]:
from itertools import combinations_with_replacement
from genova.utils.BasicClass import Residual_seq

# Enumerate the mass of every multiset of 1 to 6 entries drawn from
# Residual_seq.output_aalist(), then print how many distinct masses exist.
all_edge_mass = []
aalist = Residual_seq.output_aalist()
for num in range(1,7):
    # combinations_with_replacement yields each unordered combination once.
    for i in combinations_with_replacement(aalist,num):
        all_edge_mass.append(Residual_seq(i).mass)
# np.unique both removes duplicate masses and sorts them.
all_edge_mass = np.unique(np.array(all_edge_mass))
print(len(all_edge_mass))

In [None]:
from collections import OrderedDict
from itertools import combinations_with_replacement
from genova.utils.BasicClass import Residual_seq

# Build a lookup from mass to the list of residue combinations (1 to 6 entries
# from Residual_seq.output_aalist()) that have exactly that mass.
aa_candidate_datablock_temp = {}
aalist = Residual_seq.output_aalist()
for num in range(1,7):
    for i in combinations_with_replacement(aalist,num):
        mass = Residual_seq(i).mass
        # A single entry is stored as-is; a combination is joined and wrapped
        # in square brackets.
        if len(i)==1: aa_db = i[0]
        else: aa_db = '[{}]'.format(''.join(i))
        # NOTE(review): the key is the raw float mass, so grouping relies on
        # exact float equality between combinations.
        if mass in aa_candidate_datablock_temp: aa_candidate_datablock_temp[mass].append(aa_db)
        else: aa_candidate_datablock_temp[mass] = [aa_db]

# Re-insert the entries in ascending mass order.
aa_candidate_datablock = OrderedDict()
for aa_db_mass in sorted(list(aa_candidate_datablock_temp.keys())):
    aa_candidate_datablock[aa_db_mass] = aa_candidate_datablock_temp[aa_db_mass]

In [None]:
np.array(list(aa_candidate_datablock.keys()))

In [None]:
aa_candidate_datablock = OrderedDict()
for aa_db_mass in sorted(list(aa_candidate_datablock_temp.keys())):
    aa_candidate_datablock[aa_db_mass] = aa_candidate_datablock_temp[aa_db_mass]

In [None]:
aa_candidate_datablock

In [None]:
from collections import OrderedDict

In [None]:
for aa_db_mass in aa_candidate_datablock.keys():
    print(aa_db_mass)
    break

In [None]:
a=list(aa_candidate_datablock.items())

In [None]:
a[0]

In [None]:
aa_candidate_datablock

In [None]:
20214.1205 in aa_datablock

In [None]:
aa_datablock

In [None]:
'[ab][cd]'.split('[')

In [None]:
1/3*np.log(1/3)-1/3*np.log(0.98)

In [None]:
from genova.utils.BasicClass import Composition

In [None]:
Composition(' ').mass

In [37]:
a=torch.rand([32,32,32,256])+(-float('inf')*torch.ones((32,32,32))).triu(diagonal=1).unsqueeze(-1)

In [29]:
a=(-float('inf')*torch.ones((32,32))).triu(diagonal=1)

In [128]:
torch.rand(4,32,32,25)+torch.rand(32,32).unsqueeze(-1)

tensor([[[[0.8755, 1.3568, 0.9401,  ..., 1.3450, 0.8885, 1.3301],
          [1.1024, 1.3881, 0.9311,  ..., 1.0911, 1.3286, 1.4917],
          [1.3135, 1.3447, 1.3518,  ..., 0.9081, 1.0507, 0.8434],
          ...,
          [1.1560, 1.2221, 1.7016,  ..., 1.8448, 1.6369, 0.9617],
          [0.5528, 1.0723, 1.2672,  ..., 0.8729, 0.9173, 0.9895],
          [1.5915, 0.8110, 1.1644,  ..., 1.1764, 1.2189, 1.5209]],

         [[1.1561, 1.1784, 1.0456,  ..., 1.0298, 0.2042, 0.3168],
          [1.0195, 1.7352, 0.7770,  ..., 0.9642, 1.5835, 0.8905],
          [1.6254, 1.2550, 0.8563,  ..., 1.1306, 0.8263, 0.7572],
          ...,
          [1.1828, 1.8997, 1.9298,  ..., 0.9694, 1.0890, 1.1368],
          [1.6457, 1.9019, 1.4826,  ..., 1.2657, 1.3327, 1.9615],
          [1.7724, 1.0165, 1.8038,  ..., 1.9391, 1.2788, 1.5405]],

         [[0.9806, 0.5712, 0.3364,  ..., 0.9172, 0.2433, 0.6554],
          [1.0414, 1.3599, 0.9872,  ..., 1.2863, 0.8959, 1.4692],
          [1.5346, 0.9124, 1.4485,  ..., 1