In [1]:
import os
import gzip
import torch
import pickle
import json
import genova
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from genova.utils.BasicClass import Residual_seq
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler as GradScaler
import torch.nn as nn
import torch.optim as optim
#from genova.data.sampler import GenovaSampler
from genova.data.prefetcher import DataPrefetcher
from torch.nn.functional import pad

import wandb

import collections
from torch._six import string_classes

In [2]:
from random import choices

In [3]:
# Model / preprocessing configuration shared by the sampler, collator, dataset and model.
CONFIG_PATH = 'configs/genova_dda_light.yaml'
cfg = OmegaConf.load(CONFIG_PATH)

In [4]:
SPEC_INDEX_CSV = '/home/z37mao/genova_dataset_index.csv'
SERIALIZED_FILE = '1_3.msgp'

# Load the per-spectrum index, keep only spectra stored in one serialized file, and
# rename the MSGP-specific columns to the generic names GenovaDataset reads.
spec_header = (
    pd.read_csv(SPEC_INDEX_CSV, index_col='Spec Index', low_memory=False)
    .loc[lambda frame: frame['MSGP File Name'] == SERIALIZED_FILE]
    .rename(columns={'MSGP File Name': 'Serialized File Name',
                     'MSGP Datablock Pointer': 'Serialized File Pointer',
                     'MSGP Datablock Length': 'Serialized Data Length'})
)

In [5]:
# Bin edges over 'Node Number' for GenovaBatchSampler: (0,128], (128,256], (256,512].
bin_boarders = [0] + [128 * 2 ** k for k in range(3)]

In [7]:
import torch
import numpy as np
from random import choices
from torch.utils.data import Sampler

class GenovaBatchSampler(Sampler):
    """Batch sampler yielding variable-size batches of spectrum ids under a GPU-memory budget.

    Rows of ``spec_header`` are split into bins by ``Node Number``. Bin ``k`` holds the
    rows with ``bin_boarders[k] < Node Number <= bin_boarders[k + 1]``.

    Each ``__next__`` call picks one bin at random, weighted by how many of its spectra
    are still unread. It then greedily takes consecutive spectra from that bin until the
    analytic memory estimate (``_estimate_memory``) would exceed ``gpu_capacity``.

    A batch always contains at least one spectrum, so an oversized spectrum cannot stall
    iteration. Each batch is returned as a pandas ``Index`` of spectrum ids.

    Spectra whose ``Node Number`` lies outside ``(bin_boarders[0], bin_boarders[-1]]``
    belong to no bin and are never yielded.

    Args:
        cfg: model config; reads ``hidden_size`` and the ``encoder`` sub-config.
        device: CUDA device whose ``total_memory`` defines the budget.
        gpu_capacity_scaller: fraction of that total memory the batches may use.
        spec_header (pd.DataFrame): indexed by spectrum id, with columns
            ``Node Number``, ``Edge Num`` and ``Relation Num``.
        bin_boarders: increasing bin edges over ``Node Number``.
        shuffle: reshuffle ``spec_header`` at the start of every iteration.
        subgraph_node_num: subgraph-node count used in the node-memory term.
            It was hard-coded to 30; the default keeps that behaviour.
    """

    def __init__(self, cfg, device, gpu_capacity_scaller, spec_header, bin_boarders, shuffle=True,
                 subgraph_node_num=30) -> None:
        super().__init__(data_source=None)
        self.cfg = cfg
        self.bin_boarders = bin_boarders
        self.gpu_capacity = torch.cuda.get_device_properties(device).total_memory*gpu_capacity_scaller
        self.shuffle = shuffle
        self.spec_header = spec_header
        self.subgraph_node_num = subgraph_node_num

        self.hidden_size = self.cfg['hidden_size']
        self.d_relation = self.cfg['encoder']['d_relation']
        self.num_layers = self.cfg['encoder']['num_layers']
        self.d_node = self.cfg['encoder']['node_encoder']['d_node']
        self.d_node_expansion = self.cfg['encoder']['node_encoder']['expansion_factor']
        self.edge_expansion = self.cfg['encoder']['edge_encoder']['expansion_factor']
        self.edge_d_edge = self.cfg['encoder']['edge_encoder']['d_edge']
        self.path_expansion = self.cfg['encoder']['path_encoder']['expansion_factor']
        self.path_d_edge = self.cfg['encoder']['path_encoder']['d_edge']

        # Per-unit coefficients of the analytic memory estimate. The leading factor 4 is
        # presumably bytes per float32 element. NOTE(review): confirm against the model.
        self.node_sparse = 4 * ((29 + self.d_node_expansion) * self.d_node)
        self.node = 4 * ((2 * self.d_node_expansion)*self.d_node + 4 * (self.d_node_expansion*self.d_node+self.hidden_size)/2)

        self.relation_matrix = 4 * 7 * self.d_relation * self.num_layers
        self.relation_ffn = 4 * (3 * self.d_relation + 13 * self.hidden_size) * self.num_layers + 4*8*self.hidden_size

        self.edge_matrix = 4 * (8*self.edge_d_edge + 2*self.d_relation + 2 * self.edge_expansion*self.edge_d_edge)
        self.edge_sparse = 4 * (4 + self.edge_expansion) * self.edge_d_edge

        self.path_matrix = 4 * (8*self.path_d_edge + 2*self.d_relation + 4*self.path_expansion*self.path_d_edge)
        self.path_sparse = 4 * (9 + self.path_expansion) * self.path_d_edge

    def _estimate_memory(self, max_node, batch_num, edge_num, path_num):
        """Return the estimated memory of one batch, in the same units as ``gpu_capacity``.

        Args:
            max_node: largest ``Node Number`` in the batch (everything is padded to it).
            batch_num: number of spectra in the batch.
            edge_num: total ``Edge Num`` over the batch.
            path_num: total ``Relation Num`` over the batch.
        """
        node_consumer = (self.node_sparse*max_node*batch_num*self.subgraph_node_num
                         + self.node*max_node*batch_num)
        edge_consumer = self.edge_matrix*max_node**2*batch_num + self.edge_sparse*edge_num
        path_consumer = self.path_matrix*max_node**2*batch_num + self.path_sparse*path_num
        relation_consumer = self.relation_matrix*max_node**2*batch_num + self.relation_ffn*max_node*batch_num
        return node_consumer + edge_consumer + path_consumer + relation_consumer

    def __iter__(self):
        if self.shuffle:
            self.spec_header = self.spec_header.sample(frac=1)
        self.generate_bins()
        return self

    def __next__(self):
        remaining = self.bin_len - self.bins_readpointer
        # Stop when every *binned* spectrum is consumed. Comparing against
        # len(spec_header) never terminates when some rows fall outside the bins.
        if remaining.sum() == 0:
            raise StopIteration
        bin_index = choices(range(len(self.bins)), weights=remaining)[0]
        bin_df = self.bins[bin_index]
        start = int(self.bins_readpointer[bin_index])

        node_numbers = self._bin_node_number[bin_index]
        edge_nums = self._bin_edge_num[bin_index]
        path_nums = self._bin_relation_num[bin_index]

        max_node = 0
        edge_num = 0
        path_num = 0
        for i in range(start, len(bin_df)):
            max_node = max(max_node, node_numbers[i])
            edge_num += edge_nums[i]
            path_num += path_nums[i]
            batch_num = i - start + 1
            if self._estimate_memory(max_node, batch_num, edge_num, path_num) > self.gpu_capacity:
                # Emit at least one spectrum. Otherwise an oversized first spectrum
                # yields an empty batch without advancing the pointer, forever.
                end = max(i, start + 1)
                self.bins_readpointer[bin_index] = end
                return bin_df.index[start:end]
        self.bins_readpointer[bin_index] = len(bin_df)
        return bin_df.index[start:]

    def generate_bins(self):
        """Split ``spec_header`` into ``Node Number`` bins and reset the per-bin read pointers."""
        node_number = self.spec_header['Node Number']
        self.bins = [self.spec_header[(node_number > low) & (node_number <= high)]
                     for low, high in zip(self.bin_boarders[:-1], self.bin_boarders[1:])]
        # Column arrays per bin, so __next__ avoids a per-row .iloc lookup.
        self._bin_node_number = [b['Node Number'].to_numpy() for b in self.bins]
        self._bin_edge_num = [b['Edge Num'].to_numpy() for b in self.bins]
        self._bin_relation_num = [b['Relation Num'].to_numpy() for b in self.bins]
        self.bin_len = np.array([len(b) for b in self.bins], dtype=int)
        self.bins_readpointer = np.zeros(len(self.bins), dtype=int)

In [None]:
MAX_NODE_NUMBER = 256
# Keep only spectra with at most MAX_NODE_NUMBER nodes.
spec_header = spec_header.loc[spec_header['Node Number'].le(MAX_NODE_NUMBER)]

In [8]:
class GenovaCollator(object):
    """Collate a list of per-spectrum records into one padded batch.

    Each record is a dict with keys ``node_input``, ``rel_input``, ``edge_input`` and
    ``graph_label``. Calling the collator returns ``(encoder_input, labels)``:

    * ``encoder_input`` holds padded node inputs, concatenated path and edge inputs,
      and the additive relation mask.
    * ``labels`` holds ``node_labels`` and a boolean ``node_mask`` that is True for
      real (non-padded) nodes.
    """

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        def column(key):
            return [record[key] for record in batch]

        node_inputs = column('node_input')
        path_inputs = column('rel_input')
        edge_inputs = column('edge_input')
        graph_labels = column('graph_label')

        # Row 0: node count per record; row 1: second dim of 'node_sourceion' per record.
        node_shape = np.array([ni['node_sourceion'].shape for ni in node_inputs]).T
        max_node = node_shape[0].max()
        max_subgraph_node = node_shape[1].max()

        encoder_input = {
            'node_input': self.node_collate(node_inputs, max_node, max_subgraph_node),
            'path_input': self.path_collate(path_inputs, max_node, node_shape),
            'edge_input': self.edge_collate(edge_inputs, max_node),
            'rel_mask': self.rel_collate(node_shape, max_node),
        }
        node_labels, node_mask = self.nodelabel_collate(graph_labels, max_node)
        labels = {'node_labels': node_labels, 'node_mask': node_mask}
        return encoder_input, labels

    def node_collate(self, node_inputs, max_node, max_subgraph_node):
        """Zero-pad node features / sources to (max_node, max_subgraph_node) and stack them."""
        charge = torch.IntTensor([ni['charge'] for ni in node_inputs])
        feats, sources = [], []
        for ni in node_inputs:
            n_node, n_sub = ni['node_sourceion'].shape
            node_gap = max_node - n_node
            sub_gap = max_subgraph_node - n_sub
            # F.pad takes (left, right) pairs starting from the LAST dimension.
            feats.append(pad(ni['node_feat'], [0, 0, 0, sub_gap, 0, node_gap]))
            sources.append(pad(ni['node_sourceion'], [0, sub_gap, 0, node_gap]))
        return {'node_feat': torch.stack(feats),
                'node_sourceion': torch.stack(sources),
                'charge': charge}

    def path_collate(self, path_inputs, max_node, node_shape):
        """Concatenate path (relation) entries across the batch and pad the distance matrices."""
        rel_type = torch.concat([p['rel_type'] for p in path_inputs]).squeeze(-1)
        rel_error = torch.concat([p['rel_error'] for p in path_inputs])
        # Prepend each entry's record index as a new first column, then transpose so
        # coor[k] is the k-th coordinate column.
        coor = torch.concat([pad(p['rel_coor'], [1, 0], value=b)
                             for b, p in enumerate(path_inputs)]).T
        # Row-major flattening of (record, coor[1], coor[2]) over a (batch, max_node, max_node) grid.
        flat_pair = coor[0]*max_node**2 + coor[1]*max_node + coor[2]
        flat_rel = coor[-2]*self.cfg.preprocessing.edge_type_num + coor[-1]
        rel_coor_cated = torch.stack([flat_pair, flat_rel])

        rel_pos = torch.concat([p['rel_coor'][:, -2] for p in path_inputs])
        dist = torch.stack([
            pad(p['dist'], [0, max_node - node_shape[0, b], 0, max_node - node_shape[0, b]])
            for b, p in enumerate(path_inputs)
        ])

        return {'rel_type': rel_type, 'rel_error': rel_error,
                'rel_pos': rel_pos, 'dist': dist,
                'rel_coor_cated': rel_coor_cated,
                'max_node': max_node, 'batch_num': len(path_inputs)}

    def edge_collate(self, edge_inputs, max_node):
        """Concatenate edge entries across the batch with record-aware flattened coordinates."""
        edge_type = torch.concat([e['edge_type'] for e in edge_inputs]).squeeze(-1)
        edge_error = torch.concat([e['edge_error'] for e in edge_inputs])
        coor = torch.concat([pad(e['edge_coor'], [1, 0], value=b)
                             for b, e in enumerate(edge_inputs)]).T
        flat_pair = coor[0]*max_node**2 + coor[1]*max_node + coor[2]
        rel_coor_cated = torch.stack([flat_pair, coor[-1]])

        return {'rel_type': edge_type, 'rel_error': edge_error,
                'rel_coor_cated': rel_coor_cated,
                'max_node': max_node, 'batch_num': len(edge_inputs)}

    def rel_collate(self, node_shape, max_node):
        """Build per-record additive masks of shape (max_node, max_node, 1).

        Columns past the record's node count are -inf; all other entries are 0.
        """
        def one_mask(n_node):
            mask = torch.ones(max_node, max_node, 1).mul(-np.inf)
            mask[:, :n_node] = 0
            return mask

        return torch.stack([one_mask(n) for n in node_shape[0]])

    def nodelabel_collate(self, node_labels_temp, max_node):
        """Zero-pad node labels to ``max_node``; the mask is True for real nodes."""
        node_mask = torch.ones(len(node_labels_temp), max_node, dtype=torch.bool)
        padded = []
        for row, label in enumerate(node_labels_temp):
            n_node = label.shape[0]
            node_mask[row, n_node:] = 0
            padded.append(pad(label, [0, max_node - n_node]))
        return torch.stack(padded), node_mask
    

class GenovaDataset(Dataset):
    """Map-style dataset reading gzip-compressed pickled spectra from serialized files.

    ``spec_header`` is looked up by label (``.loc``). Each row must provide
    ``Serialized File Name``, ``Serialized File Pointer`` (byte offset),
    ``Serialized Data Length`` (byte count) and ``Charge``.
    """

    def __init__(self, cfg, *, spec_header, dataset_dir_path):
        super().__init__()
        self.cfg = cfg
        self.spec_header = spec_header
        self.dataset_dir_path = dataset_dir_path

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        row = dict(self.spec_header.loc[idx])
        path = os.path.join(self.dataset_dir_path, row['Serialized File Name'])
        with open(path, 'rb') as fh:
            fh.seek(row['Serialized File Pointer'])
            blob = fh.read(row['Serialized Data Length'])
        # SECURITY: pickle.loads can execute arbitrary code; only read trusted data files.
        spec = pickle.loads(gzip.decompress(blob))

        spec['node_input']['charge'] = row['Charge']
        del spec['node_mass']
        # Reduce the last axis of the label with logical OR and cast to int64.
        spec['graph_label'] = spec['graph_label'].any(dim=-1).long()
        return spec

    def __len__(self):
        return self.spec_header.shape[0]

In [9]:
DATASET_DIR = '/home/z37mao/'  # directory holding the serialized '.msgp' files
ds = GenovaDataset(cfg, dataset_dir_path=DATASET_DIR, spec_header=spec_header)

In [10]:
GPU_INDEX = 1
device = torch.device('cuda', GPU_INDEX)
torch.cuda.set_device(device)  # make this the default CUDA device for later .cuda() calls

In [11]:
collate_fn = GenovaCollator(cfg)
sampler = GenovaBatchSampler(cfg=cfg, device=device, gpu_capacity_scaller=0.9,
                             spec_header=spec_header, bin_boarders=bin_boarders)
# The sampler decides batch composition, so the loader receives it as batch_sampler.
dl = DataPrefetcher(
    DataLoader(ds, batch_sampler=sampler, collate_fn=collate_fn, num_workers=2, pin_memory=True),
    device,
    non_blocking=True,
)

In [12]:
# Smoke test: pull a single batch through the prefetcher to check the pipeline runs.
# The batch is left bound to `i`.
for i in dl:
    break

In [None]:
model = genova.GenovaEncoder(cfg, bin_classification=True)
model = model.cuda()  # moves parameters to the current CUDA device

In [None]:
LEARNING_RATE = 2e-4
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler()  # loss scaling for mixed-precision (autocast) training
print(torch.cuda.memory_allocated())  # bytes held by tensors after moving the model

In [None]:
# Train while comparing the sampler's analytic memory estimate (`theo`) with the memory
# actually allocated on the GPU (`real`), and report batches where the two disagree.
# The coefficients come from `sampler` (the GenovaBatchSampler driving `dl`); the
# original cell referenced an undefined name `a` and raised NameError.
MEMORY_OFFSET = 47562752*4   # NOTE(review): empirical constant, presumably a fixed baseline (e.g. weights) — confirm
MEMORY_FUDGE = 0.75          # NOTE(review): empirical scale factor — confirm
RATIO_LOW, RATIO_HIGH = 0.6, 1.7

loss_detect = 0
for i, (encoder_input, labels) in enumerate(dl, start=1):
    optimizer.zero_grad()
    with autocast():
        output = model(**encoder_input)
        loss = loss_fn(output[labels['node_mask']], labels['node_labels'][labels['node_mask']])
    loss_detect += loss.item()
    max_node = encoder_input['edge_input']['max_node']
    batch_num = encoder_input['edge_input']['batch_num']
    edge_num = len(encoder_input['edge_input']['rel_type'])
    path_num = len(encoder_input['path_input']['rel_type'])
    subgraph_node = encoder_input['node_input']['node_feat'].shape[-2]

    node_consumer = sampler.node_sparse*max_node*batch_num*subgraph_node + sampler.node*max_node*batch_num
    edge_consumer = sampler.edge_matrix*max_node**2*batch_num + sampler.edge_sparse*edge_num
    path_consumer = sampler.path_matrix*max_node**2*batch_num + sampler.path_sparse*path_num
    relation_consumer = sampler.relation_matrix*max_node**2*batch_num + sampler.relation_ffn*max_node*batch_num
    theo = node_consumer+edge_consumer+path_consumer+relation_consumer
    theo = (theo+MEMORY_OFFSET)*MEMORY_FUDGE
    real = torch.cuda.memory_allocated()
    ratio = theo/real
    print(ratio)
    if ratio < RATIO_LOW or ratio > RATIO_HIGH:
        print(ratio, max_node, edge_num, path_num, subgraph_node)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()


In [None]:
# Memory estimate for a hypothetical batch, using the coefficients of `sampler`
# (the GenovaBatchSampler built above). The original referenced an undefined `a`.
max_node = 128
batch_num = 32
edge_num = 1e6
path_num = 5e5

node_consumer = sampler.node_sparse*max_node*batch_num*30 + sampler.node*max_node*batch_num
edge_consumer = sampler.edge_matrix*max_node**2*batch_num + sampler.edge_sparse*edge_num
path_consumer = sampler.path_matrix*max_node**2*batch_num + sampler.path_sparse*path_num
relation_cosumer = sampler.relation_matrix*max_node**2*batch_num + sampler.relation_ffn*max_node*batch_num
theo = node_consumer+edge_consumer+path_consumer+relation_cosumer

In [None]:
theo / 2**30  # estimate expressed in GiB

In [None]:
# Inspect the relation width that GenovaBatchSampler reads into `d_relation`.
cfg['encoder']['d_relation']

In [None]:
encoder_input['node_input']['node_feat'].size(-2)  # padded subgraph-node dimension of the last batch

In [None]:
1527211008 * 3 / 4  # scratch arithmetic (three quarters of a byte count)

In [None]:
1145408256 + 5 * 45208880  # scratch arithmetic

In [None]:
float(728486896) / 1213900288  # scratch ratio, presumably estimated vs. measured memory

In [None]:
# NOTE(review): this cell held the incomplete expression `18844*`, a SyntaxError that
# aborts Restart & Run All. The intended right operand is unknown — TODO restore or delete.
# 18844*

In [None]:
edge_rel_type = encoder_input['edge_input']['rel_type']
edge_rel_type.shape

In [None]:
# Identity test: `is not None` is the correct check. `!=` on tensors/arrays is an
# elementwise operator, not a presence check.
if encoder_input['path_input']['rel_pos'] is not None:
    print('success')

In [None]:
# Scratch embedding used to probe the edge rel_type indices below.
test = nn.Embedding(num_embeddings=50150, embedding_dim=54)

In [None]:
torch.squeeze(encoder_input['edge_input']['rel_type'], -1)

In [None]:
squeezed_rel_type = encoder_input['edge_input']['rel_type'].squeeze(-1)
test(squeezed_rel_type).shape

In [None]:
# Inspect the collated edge inputs. `encoder_input` is left over from a training-loop cell.
encoder_input['edge_input']

In [None]:
def encoder_input_cuda(encoder_input, device):
    """Move every tensor in a dict of dicts to ``device``, in place.

    Non-tensor values (e.g. ``max_node``, ``batch_num``) are left untouched.
    The same, mutated, outer dict is returned.
    """
    for section in encoder_input.values():
        for key, value in section.items():
            if isinstance(value, torch.Tensor):
                section[key] = value.to(device)
    return encoder_input


import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Multi-GPU training; expects a launcher such as torchrun to set LOCAL_RANK.
local_rank = int(os.environ['LOCAL_RANK'])

if local_rank == 0:
    wandb.init(project="Genova", entity="rxnatalie")

torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')  # NCCL is the fastest and recommended backend on GPUs

# Build the model
device = torch.device("cuda", local_rank)

# NOTE(review): the original cell used the undefined names `small_spec` and
# `GenovaSampler` (whose import is commented out at the top of the notebook).
# They are replaced by this notebook's `spec_header` and `GenovaBatchSampler`.
# TODO confirm the intended spectrum subset and dataset directory.
# TODO batches are not sharded across ranks here; every rank iterates the same data.
train_spec_header = spec_header
ds = GenovaDataset(cfg, spec_header=train_spec_header, dataset_dir_path='./pretrain_data_sparse/')
# num_train_samples = 2000
# ds = Subset(ds, np.arange(num_train_samples))

collate_fn = GenovaCollator(cfg)
sampler = GenovaBatchSampler(cfg, device, 0.9, train_spec_header, bin_boarders)
dl = DataLoader(ds, batch_sampler=sampler, collate_fn=collate_fn, num_workers=1)
# dl = DataLoader(ds,batch_size=4,collate_fn=collate_fn,num_workers=1,shuffle=True)
model = genova.GenovaEncoder(cfg, bin_classification=True).to(local_rank)
model = DDP(model, device_ids=[local_rank])
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=2e-4)
scaler = GradScaler()

CHECKPOINT_PATH = './save/sampler_test/model_max.pt'
# #checkpoint = torch.load(CHECKPOINT_PATH,map_location = {'cuda:%d' % 0: 'cuda:%d' % local_rank})['model_state_dict']
# checkpoint = torch.load(CHECKPOINT_PATH,map_location = {'cuda:%d' % 0: 'cuda:%d' % local_rank})
# if list(model.state_dict().keys())[0].startswith('module'):
#     #model.load_state_dict(checkpoint)
#     model.load_state_dict(OrderedDict([('module.'+key, v) for key, v in checkpoint.items()]))
# else:
#     #model.load_state_dict(OrderedDict([(key[7:], v) for key, v in checkpoint.items()]))
#     model.load_state_dict(checkpoint)

loss_detect = 0
detect_period = 50
accuracy = 0
recall = 0
precision = 0
for epoch in range(2):
    print('Epoch:', epoch)
    # GenovaCollator yields (encoder_input, labels), where labels['node_mask'] is True
    # for real (non-padded) nodes.
    for i, (encoder_input, labels) in enumerate(dl, start=1):
        if i % detect_period == 0:
            print('Sample:', i)
        encoder_input = encoder_input_cuda(encoder_input, device)
        node_labels = labels['node_labels'].to(device)
        node_mask = labels['node_mask'].to(device)
        optimizer.zero_grad()
        with autocast():
            output = model(**encoder_input)
            loss = loss_fn(output[node_mask], node_labels[node_mask])
        if local_rank == 0:
            with torch.no_grad():
                pred = torch.argmax(output[node_mask], -1)
                target = node_labels[node_mask]
                hit = pred == target
                accuracy += hit.float().mean().item()
                true_pos = hit[target == 1].sum()
                # clamp(min=1) avoids NaN when a batch has no positive targets/predictions.
                recall += (true_pos / (target == 1).sum().clamp(min=1)).item()
                precision += (true_pos / (pred == 1).sum().clamp(min=1)).item()
            loss_detect += loss.item()
            if i % detect_period == 0:
                wandb.log({"loss": loss_detect / detect_period,
                           "accuracy": accuracy / detect_period,
                           "recall": recall / detect_period,
                           "precision": precision / detect_period}
                          )
                torch.save({'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict()}, CHECKPOINT_PATH)
                loss_detect, accuracy, recall, precision = 0, 0, 0, 0
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

In [None]:
# NOTE(review): `label` is only assigned by a later cell (`label = spec.pop(...)`),
# so this cell raises NameError under Restart & Run All — TODO move below it or delete.
label

In [None]:
# Scratch check of the label conversion that GenovaDataset.__getitem__ applies
# (any() over the last axis, then cast to int64).
# NOTE(review): `spec` is not defined at notebook scope; it is only a local inside
# GenovaDataset.__getitem__. TODO load a raw record into `spec` before running this cell.
label = spec.pop('graph_label')
print(label.shape)
label = torch.any(label, -1).long()

In [None]:
# NOTE(review): `spec` is not defined at notebook scope. TODO define it before running this cell.
spec['rel_input']

In [None]:
# Inspect the node inputs of the first spectrum. GenovaDataset looks rows up by
# `spec_header` label (`.loc`), so use the first real label instead of assuming 0 exists.
ds[ds.spec_header.index[0]]['node_input']

In [None]:
# NOTE(review): `node_input` is not defined at notebook scope; it only appears as a local
# inside GenovaCollator. TODO bind it (e.g. from a `ds` record) before running this cell.
node_input[0]['node_feat'].shape