In [1]:
import json
import torch
import genova
from datetime import datetime
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from genova.utils.BasicClass import Residual_seq
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler as GradScaler
import torch.nn as nn
import torch.optim as optim

In [2]:
# Configuration for the light DDA Genova model.
cfg = OmegaConf.load('configs/genova_dda_light.yaml')

# Spectrum header table, indexed by its 'index' column.
spec_header = pd.read_csv(
    '/data/z37mao/genova/pretrain_data_sparse/genova_psm.csv',
    index_col='index',
)
# Keep only the Cerebellum and HeLa experiments.
spec_header = spec_header.loc[spec_header['Experiment Name'].isin(['Cerebellum', 'HeLa'])]
# Restrict to spectra whose 'Node Number' is at most 256.
small_spec = spec_header.loc[spec_header['Node Number'].le(256)]

In [3]:
import os
import gzip
import torch
import pickle
from torch.utils.data import Dataset
import numpy as np
from torch.nn.functional import pad

class GenovaCollator(object):
    """Collate ``(encoder_record, label)`` samples into one padded batch.

    Instances are callable and intended to be passed as ``collate_fn`` to a
    ``DataLoader``.
    """

    def __init__(self,cfg):
        self.cfg = cfg

    def __call__(self,batch):
        """Build the encoder input, padded labels and padding mask for a batch.

        Args:
            batch: sequence of samples; element 0 of each sample is the encoder
                record (a dict) and element 1 is the label tensor.

        Returns:
            ``(encoder_input, labels, node_mask)`` where ``labels`` is the stack
            of label tensors right-padded with zeros to the longest label.
        """
        encoder_records = []
        raw_labels = []
        for sample in batch:
            encoder_records.append(sample[0])
            raw_labels.append(sample[1])
        encoder_input, node_mask = self.encoder_collate(encoder_records)

        longest = max(lbl.shape[0] for lbl in raw_labels)
        labels = torch.stack([pad(lbl,[0,longest-lbl.shape[0]]) for lbl in raw_labels])
        return encoder_input, labels, node_mask

    def encoder_collate(self, encoder_records):
        """Pad and merge per-sample encoder records.

        Args:
            encoder_records: list of dicts providing the keys
                ``node_sourceion``, ``node_feat``, ``rel_mask``, ``dist``,
                ``charge``, ``rel_coor``, ``rel_type``, ``edge_pos`` and
                ``rel_error``.

        Returns:
            ``(encoder_input, node_mask)``. ``encoder_input`` is a dict with the
            sections ``node_input``, ``edge_input`` and ``rel_input``;
            ``node_mask`` is a bool tensor of shape ``(batch, max_node)`` that is
            ``True`` at padded node positions.
        """
        # Row 0 holds every sample's node count, row 1 its subgraph width.
        shapes = np.array([np.array(rec['node_sourceion'].shape) for rec in encoder_records]).T
        max_node = shapes[0].max()
        max_subgraph_node = shapes[1].max()

        # Edge-level tensors are simply concatenated over the batch.
        rel_type = torch.concat([rec['rel_type'] for rec in encoder_records])
        edge_pos = torch.concat([rec['edge_pos'] for rec in encoder_records])
        rel_error = torch.concat([rec['rel_error'] for rec in encoder_records]).unsqueeze(-1)

        feats, sources, masks, dists, charges, coors = [], [], [], [], [], []
        node_mask = torch.zeros(len(encoder_records),max_node,dtype=bool)
        for idx, rec in enumerate(encoder_records):
            n_nodes, n_sub = rec['node_sourceion'].shape
            node_gap = max_node-n_nodes
            sub_gap = max_subgraph_node-n_sub

            feats.append(pad(rec['node_feat'],[0,0,0,sub_gap,0,node_gap]))
            sources.append(pad(rec['node_sourceion'],[0,sub_gap,0,node_gap]))
            # Added columns are filled with -inf; the rows added afterwards are zeros.
            col_padded = pad(rec['rel_mask'],[0,node_gap],value=-float('inf'))
            masks.append(pad(col_padded,[0,0,0,node_gap]))
            dists.append(pad(rec['dist'],[0,node_gap,0,node_gap]))
            charges.append(rec['charge'])

            coor = rec['rel_coor']
            # First row: flat index into a (batch, max_node, max_node) layout.
            flat_index = idx*max_node**2+coor[0]*max_node+coor[1]
            # Second row: last two coordinate rows combined with a hard-coded factor of 100.
            second_row = coor[-2]*100+coor[-1]
            coors.append(torch.stack([flat_index,second_row]))
            node_mask[idx,n_nodes:] = True

        # 0 on the diagonal, 1 above it, 2 below it.
        lower = torch.tril(2*torch.ones(max_node,max_node),-1)
        upper = torch.triu(torch.ones(max_node,max_node),1)
        drctn = torch.zeros(max_node,max_node)+lower+upper

        node_input = {
            'node_feat': torch.stack(feats),
            'node_sourceion': torch.stack(sources),
            'charge': torch.IntTensor(charges),
        }
        edge_input = {
            'rel_type': rel_type,
            'edge_pos': edge_pos,
            'rel_error': rel_error,
            'dist': torch.stack(dists),
            'rel_coor_cated': torch.concat(coors,dim=1),
            'batch_num': len(encoder_records),
            'max_node': max_node,
        }
        rel_input = {
            'drctn': drctn.int().unsqueeze(0),
            'rel_mask': torch.stack(masks).unsqueeze(-1),
        }

        encoder_input = {'node_input':node_input,'edge_input':edge_input,'rel_input':rel_input}
        return encoder_input, node_mask

class GenovaDataset(Dataset):
    """Dataset reading gzip-compressed, pickled spectrum records by file offset.

    Args:
        cfg: configuration object; stored on the instance and not otherwise
            used by this class.
        spec_header: pandas DataFrame with one row per spectrum. The columns
            'Serialized File Name', 'Serialized File Pointer',
            'Serialized Data Length' and 'Charge' are read.
        dataset_dir_path: directory containing the serialized files.
    """

    def __init__(self, cfg, *, spec_header, dataset_dir_path):
        super().__init__()
        self.cfg = cfg
        self.spec_header = spec_header
        self.dataset_dir_path = dataset_dir_path

    def __getitem__(self, idx):
        """Return ``(spec, label)`` for the row at position ``idx``.

        ``spec`` is the unpickled record with 'charge' added and the
        'path_label', 'edge_type' and 'edge_error' entries removed. ``label`` is
        ``torch.any(path_label, -1)`` converted to a long tensor.

        Raises:
            KeyError: if the record lacks 'path_label', 'edge_type' or
                'edge_error'.
        """
        if torch.is_tensor(idx): idx = idx.tolist()
        spec_head = dict(self.spec_header.iloc[idx])

        file_path = os.path.join(self.dataset_dir_path, spec_head['Serialized File Name'])
        # Read only this record's byte range, then close the file before decoding.
        with open(file_path,'rb') as f:
            f.seek(spec_head['Serialized File Pointer'])
            payload = f.read(spec_head['Serialized Data Length'])

        # SECURITY: pickle.loads can execute arbitrary code; only point this
        # dataset at serialized files from a trusted source.
        spec = pickle.loads(gzip.decompress(payload))

        spec['charge'] = spec_head['Charge']
        label = torch.any(spec.pop('path_label'),-1).long()
        # These entries are dropped from the returned record; their values are unused.
        spec.pop('edge_type')
        spec.pop('edge_error')
        return spec, label

    def __len__(self):
        """Number of rows in the spectrum header."""
        return len(self.spec_header)

In [None]:
# Dataset over the filtered header, batched and padded by GenovaCollator.
ds = GenovaDataset(
    cfg,
    spec_header=small_spec,
    dataset_dir_path='/data/z37mao/genova/pretrain_data_sparse/',
)
collate_fn = GenovaCollator(cfg)
dl = DataLoader(
    ds,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

# Encoder built with bin_classification=True and moved to the GPU.
model = genova.GenovaEncoder(cfg, bin_classification=True).cuda()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
# Gradient scaler used together with autocast in the training loop.
scaler = GradScaler()

In [None]:
# Release unoccupied cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [None]:
# Reload the configuration and the spectrum header table.
cfg = OmegaConf.load('configs/genova_dda_light.yaml')
spec_header = pd.read_csv(
    '/data/z37mao/genova/pretrain_data_sparse/genova_psm.csv',
    index_col='index',
)
# This time keep only the PXD008844 experiment.
spec_header = spec_header.loc[spec_header['Experiment Name'].eq('PXD008844')]
# Restrict to spectra whose 'Node Number' is at most 256.
small_spec = spec_header.loc[spec_header['Node Number'].le(256)]

In [None]:
# Rebuild the dataset and loader on the current `small_spec` selection.
ds = GenovaDataset(
    cfg,
    spec_header=small_spec,
    dataset_dir_path='/data/z37mao/genova/pretrain_data_sparse/',
)
collate_fn = GenovaCollator(cfg)
dl = DataLoader(
    ds,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

In [None]:
# tqdm wraps the DataLoader in the evaluation loop to show a progress bar.
from tqdm import tqdm
# Release unoccupied cached GPU memory before running the evaluation.
torch.cuda.empty_cache()

In [None]:
# Evaluate `model` on `dl`, accumulating per-batch accuracy, recall and precision
# of the positive class (label 1). The sums are divided by the batch count `i`
# when they are reported.
accuracy = 0
recall = 0
precision = 0

# Evaluation must run in eval mode (e.g. dropout disabled); the previous mode is
# restored afterwards so that a later training cell is not affected.
was_training = model.training
model.eval()
try:
    for i, (encoder_input, labels, node_mask) in enumerate(tqdm(dl),start=1):
        encoder_input = encoder_input_cuda(encoder_input)
        labels = labels.cuda()
        with torch.no_grad():
            with autocast():
                output = model(**encoder_input)

        # Only real (non-padded) nodes take part in the metrics.
        valid = ~node_mask
        predicted = torch.argmax(output[valid],-1)
        target = labels[valid]
        correct = predicted==target
        true_positive = correct[target==1].sum()

        accuracy += correct.sum()/target.shape[0]
        # NOTE: a batch without positive labels (or without positive predictions)
        # produces NaN here, exactly as before.
        recall += true_positive/(target==1).sum()
        precision += true_positive/(predicted==1).sum()
finally:
    model.train(was_training)

In [None]:
# Report the accumulated metrics averaged over the `i` evaluated batches.
print('accuracy: {}, recall: {}, precision: {}'.format(
    *(metric.item()/i for metric in (accuracy, recall, precision))
))

In [4]:
def encoder_input_cuda(encoder_input, device='cuda'):
    """Move every tensor in a two-level encoder-input dict to ``device``.

    Args:
        encoder_input: dict of dicts (section -> key -> value). Values that are
            ``torch.Tensor`` are moved; all other values are left untouched.
        device: target device, by default ``'cuda'`` (the current CUDA device),
            which matches the previous hard-coded ``.cuda()`` behaviour.

    Returns:
        The same ``encoder_input`` object, modified in place.
    """
    for section in encoder_input.values():
        # Only values are replaced, never keys, so iterating items() is safe.
        for key, value in section.items():
            if isinstance(value, torch.Tensor):
                section[key] = value.to(device)
    return encoder_input

In [None]:
# One pass of mixed-precision training over `dl`; the mean loss of every
# `detect_period` batches is printed.
detect_period = 100
loss_detect = 0
for i, (encoder_input, labels, node_mask) in enumerate(dl, start=1):
    encoder_input = encoder_input_cuda(encoder_input)
    labels = labels.cuda()

    optimizer.zero_grad()
    with autocast():
        output = model(**encoder_input)
        # Padded node positions (node_mask is True) are excluded from the loss.
        loss = loss_fn(output[~node_mask], labels[~node_mask])

    # Scaled backward pass and optimizer step.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    loss_detect += loss.item()
    if i % detect_period == 0:
        print(loss_detect / detect_period)
        loss_detect = 0

In [5]:
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group if one has been initialized.

    Safe to call when no group exists (for example after a failed setup), in
    which case it does nothing.
    """
    if dist.is_initialized():
        dist.destroy_process_group()

In [6]:
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

In [12]:
def train(rank, world_size):
    """Run one pass of DistributedDataParallel training on GPU ``rank``.

    Intended as the per-process entry point for ``mp.spawn``.

    Args:
        rank: index of this worker; also used as its CUDA device index.
        world_size: total number of workers.

    Rank 0 prints the mean loss every 100 batches and writes a checkpoint to
    '/data/z37mao/save/model_checkpoint.pt' every 10000 batches.
    """
    # Local import: the file's import cells do not provide DistributedSampler.
    from torch.utils.data.distributed import DistributedSampler

    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)
    try:
        # Make `.cuda()` (here and in encoder_input_cuda) target this rank's GPU
        # instead of every process defaulting to device 0.
        torch.cuda.set_device(rank)

        ds = GenovaDataset(cfg, spec_header=small_spec, dataset_dir_path='/data/z37mao/genova/pretrain_data_sparse/')
        collate_fn = GenovaCollator(cfg)
        # Give each rank a distinct shard; the sampler does the shuffling, so
        # `shuffle` must not also be passed to the DataLoader.
        sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=True)
        dl = DataLoader(ds,batch_size=4,collate_fn=collate_fn,num_workers=4,sampler=sampler)

        model = genova.GenovaEncoder(cfg,bin_classification=True).cuda()
        ddp_model = DDP(model, device_ids=[rank])
        loss_fn = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(ddp_model.parameters(),lr=1e-5)
        scaler = GradScaler()

        loss_detect = 0
        detect_period = 100
        for i, (encoder_input, labels, node_mask) in enumerate(dl,start=1):
            encoder_input = encoder_input_cuda(encoder_input)
            labels = labels.cuda()
            optimizer.zero_grad()
            with autocast():
                # The forward pass must go through the DDP wrapper; calling the
                # bare model skips gradient synchronization between ranks.
                output = ddp_model(**encoder_input)
                loss = loss_fn(output[~node_mask],labels[~node_mask])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if rank == 0:
                loss_detect+=loss.item()
                if i%detect_period==0:
                    print(loss_detect/detect_period)
                    loss_detect = 0
                if i%10000==0:
                    torch.save(ddp_model.state_dict(), '/data/z37mao/save/model_checkpoint.pt')
    finally:
        # Always tear down the process group, even if training raised.
        cleanup()

In [10]:
def run(demo_fn, world_size):
    """Spawn ``world_size`` worker processes running ``demo_fn`` and wait for them.

    ``world_size`` is forwarded to every worker as the only extra argument,
    both as the process count and as the value passed to ``demo_fn``.

    NOTE(review): the output recorded in this notebook shows the workers
    failing with "Can't get attribute 'train' on <module '__main__'>" when
    ``demo_fn`` is defined in a notebook cell — presumably it has to live in an
    importable module for spawned processes to find it; TODO confirm.
    """
    spawn_options = dict(args=(world_size,), nprocs=world_size, join=True)
    mp.spawn(demo_fn, **spawn_options)

In [13]:
# Bare expression: displays the `train` function object as the cell output.
train

<function __main__.train(rank, world_size)>

In [14]:
# Launch one training worker per visible GPU.
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
world_size = n_gpus
# NOTE(review): the output recorded for this cell shows the spawned workers failing with
# "Can't get attribute 'train' on <module '__main__'>" — presumably because `train` is defined
# in a notebook cell rather than an importable module; TODO confirm and move it to a .py file.
run(train, world_size)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/z37mao/anaconda3/envs/genova_torch/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/z37mao/anaconda3/envs/genova_torch/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/z37mao/anaconda3/envs/genova_torch/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/z37mao/anaconda3/envs/genova_torch/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>


ProcessExitedException: process 2 terminated with exit code 1