In [1]:
import os
import gzip
import torch
import pickle
import json
import genova
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from genova.utils.BasicClass import Residual_seq
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler as GradScaler
import torch.nn as nn
import torch.optim as optim
from genova.data.sampler import GenovaSampler
from genova.data.prefetcher import DataPrefetcher
from torch.nn.functional import pad

import wandb

import collections
from torch._six import string_classes

In [2]:
# Load the experiment configuration from a YAML file (path is relative to the kernel's working directory).
cfg = OmegaConf.load('configs/genova_dda_light.yaml')

In [3]:
# Read the spectrum index (indexed by 'Spec Index'), keep only the rows stored in '1_3.msgp',
# and rename the MSGP-specific columns to the 'Serialized ...' names that GenovaDataset reads.
# NOTE(review): the CSV path is an absolute, user-specific location — make it configurable.
spec_header = pd.read_csv('/home/z37mao/genova_dataset_index.csv', index_col='Spec Index', low_memory=False)
spec_header = spec_header[spec_header['MSGP File Name']=='1_3.msgp']
spec_header = spec_header.rename(columns={'MSGP File Name':'Serialized File Name',
                                          'MSGP Datablock Pointer':'Serialized File Pointer',
                                          'MSGP Datablock Length':'Serialized Data Length'})

In [51]:
# Override two model-size settings of the loaded config in place for this experiment.
cfg['encoder']['d_relation'] = 512
cfg['hidden_size'] = 1024

In [52]:
from torch.utils.data import Sampler
class GenovaSampler(Sampler):
    """Holds per-unit memory-cost coefficients derived from the model config.

    Only ``__init__`` is defined here: the instance reads the encoder
    hyper-parameters from ``cfg`` and stores one coefficient per cost term.
    The trailing notes on the coefficients name the quantity each one is
    meant to be multiplied by (``node_num**2`` or ``num_all_edges``).
    The leading factor 4 is presumably bytes per float32 value — TODO confirm.
    """

    def __init__(self, cfg):
        self.cfg = cfg

        # Sub-sections of the encoder config that are read repeatedly below.
        encoder_cfg = self.cfg['encoder']
        node_cfg = encoder_cfg['node_encoder']
        edge_cfg = encoder_cfg['edge_encoder']
        path_cfg = encoder_cfg['path_encoder']

        # Raw hyper-parameters.
        self.hidden_size = self.cfg['hidden_size']
        self.d_relation = encoder_cfg['d_relation']
        self.num_layers = encoder_cfg['num_layers']
        self.d_node = node_cfg['d_node']
        self.d_node_expansion = node_cfg['expansion_factor']
        self.edge_expansion = edge_cfg['expansion_factor']
        self.edge_d_edge = edge_cfg['d_edge']
        self.path_expansion = path_cfg['expansion_factor']
        self.path_d_edge = path_cfg['d_edge']

        # Node encoder terms.
        expanded_node = self.d_node_expansion * self.d_node
        self.node_sparse = 4 * ((29 + self.d_node_expansion) * self.d_node)
        self.node = 4 * (2 * expanded_node + 4 * (expanded_node + self.hidden_size) / 2)

        # Relation (attention) terms.
        self.relation_matrix = 4 * 7 * self.d_relation * self.num_layers  # x node_num**2
        self.relation_ffn = (
            4 * (3 * self.d_relation + 13 * self.hidden_size) * self.num_layers
            + 4 * 8 * self.hidden_size
        )

        # Edge encoder terms.
        self.edge_matrix = 4 * (
            8 * self.edge_d_edge
            + 2 * self.d_relation
            + 2 * self.edge_expansion * self.edge_d_edge
        )  # x node_num**2
        self.edge_sparse = 4 * (4 + self.edge_expansion) * self.edge_d_edge  # x num_all_edges

        # Path encoder terms.
        self.path_matrix = 4 * (
            8 * self.path_d_edge
            + 2 * self.d_relation
            + 4 * self.path_expansion * self.path_d_edge
        )  # x node_num**2
        self.path_sparse = 4 * (9 + self.path_expansion) * self.path_d_edge  # x num_all_edges

In [53]:
# Instantiate the memory-cost coefficient holder defined in the previous cell.
# (`test` is not defined at this point of the notebook; the attributes read from `a`
# later — node_sparse, edge_matrix, path_matrix, relation_ffn, ... — are the ones
# set by the class above.)
a = GenovaSampler(cfg)

In [6]:
import random
from torch.utils.data import Sampler
from typing import Iterator, List


class GenovaSampler(Sampler[List[int]]):
    """Batch sampler that packs dataset indices under an estimated GPU-memory budget.

    Samples are grouped into bins by node count (see ``generate_bins``).  Every
    batch is drawn from a single bin, picked at random with probability
    proportional to the number of samples not yet served from it.  Samples are
    taken in bin order and appended while the estimated footprint of the padded
    batch stays within ``gpu_capacity / scale_factor - error_tol``.

    Args:
        data_source (Dataset): dataset to sample from.  ``data_source[i][0]`` must
            be a mapping whose ``'node_feat'`` and ``'rel_type'`` entries expose
            ``.shape``.  NOTE(review): a dataset whose items are plain nested
            dicts does not match this layout — confirm which dataset is intended.
        cfg: mapping providing ``hidden_size`` and ``encoder.{d_relation,
            num_layers, edge_encoder.{expansion_factor, d_edge}}``.
        gpu_capacity: memory budget; the cost coefficients are divided by
            ``1024 ** 3``, so this is presumably expressed in GiB — TODO confirm.
        scale_factor: safety divisor applied to ``gpu_capacity``.
        error_tol: extra margin subtracted from the scaled budget.

    ``__len__`` is intentionally not defined: the number of batches depends on
    the random bin order and is only known after iterating.
    """

    def __init__(self, data_source, cfg, gpu_capacity, scale_factor=1.37, error_tol=0.2) -> None:
        self.data_source = data_source
        self.bins = self.generate_bins(data_source)
        self.cfg = cfg
        self.gpu_capacity = gpu_capacity / scale_factor
        self.error_tol = error_tol
        self.hidden_size = self.cfg['hidden_size']
        self.d_relation = self.cfg['encoder']['d_relation']
        self.num_layers = self.cfg['encoder']['num_layers']
        self.expansion = self.cfg['encoder']['edge_encoder']['expansion_factor']
        self.d_edge = self.cfg['encoder']['edge_encoder']['d_edge']

        # Cost coefficients; each is multiplied by the quantity named on the right.
        self.relation_matrix = 4 * 3 * self.d_relation * self.num_layers / 1024 ** 3  # x node_num**2
        self.relation_gpu2 = 4 * (3 * self.d_relation + 11 * self.hidden_size) * \
                             self.num_layers / 1024 ** 3  # x node_num
        self.edge_matrix = 4 * 2 * (self.d_relation + self.expansion * self.d_edge) / 1024 ** 3  # x node_num**2
        self.edge_sparse = 4 * (9 + self.expansion) * self.d_edge / 1024 ** 3  # x num_all_edges

        # Aliases under the names the iteration code historically used, so that
        # either spelling resolves to the same coefficient.
        self.relation_gpu1 = self.relation_matrix
        self.edge_gpu1 = self.edge_matrix
        self.edge_gpu2 = self.edge_sparse

    def __iter__(self) -> Iterator[List[int]]:
        """Yield non-empty lists of dataset indices, one list per batch."""
        bin_ids = list(range(len(self.bins)))
        cursor = {b: 0 for b in bin_ids}                       # next unread position per bin
        remaining = {b: len(self.bins[b]) for b in bin_ids}    # samples left per bin
        budget = self.gpu_capacity - self.error_tol

        while any(remaining.values()):
            # Exhausted bins have weight 0 and can no longer be selected.
            which_bin = random.choices(bin_ids, weights=[remaining[b] for b in bin_ids])[0]
            members = self.bins[which_bin]
            position = cursor[which_bin]

            batch = []
            max_node = 1
            relation_dense_used = 0
            relation_linear_used = 0
            edge_dense_used = 0
            edge_sparse_used = 0

            while position < len(members):
                i = members[position]  # index in data_source
                d = self.data_source[i]
                num_all_edges = d[0]['rel_type'].shape[0]
                node_num = d[0]['node_feat'].shape[0]

                # Every sample in a batch is padded to the largest node count, so
                # admitting a larger sample inflates what is already in the batch.
                new_max_node = max(max_node, node_num)
                growth = 1.0 * new_max_node / max_node

                relation_dense = self.relation_matrix * new_max_node * new_max_node
                relation_linear = self.relation_gpu2 * new_max_node
                edge_dense = self.edge_matrix * new_max_node * new_max_node
                edge_sparse_cost = self.edge_sparse * num_all_edges

                rescaled_relation_dense = relation_dense_used * growth ** 2
                rescaled_relation_linear = relation_linear_used * growth
                rescaled_edge_dense = edge_dense_used * growth ** 2
                # The sparse edge term depends on edge counts only, so it is not rescaled.

                projected = (rescaled_relation_dense + rescaled_relation_linear
                             + rescaled_edge_dense + edge_sparse_used
                             + relation_dense + relation_linear + edge_dense + edge_sparse_cost)
                over_budget = projected > budget

                if over_budget and batch:
                    # Leave this sample unconsumed; it will open a later batch.
                    break

                batch.append(i)
                max_node = new_max_node
                relation_dense_used = rescaled_relation_dense + relation_dense
                relation_linear_used = rescaled_relation_linear + relation_linear
                edge_dense_used = rescaled_edge_dense + edge_dense
                edge_sparse_used += edge_sparse_cost
                position += 1

                if over_budget:
                    # A single sample that exceeds the budget on its own is emitted
                    # as a singleton batch.  Without this the sampler would yield an
                    # empty batch and never advance past the sample.
                    break

            cursor[which_bin] = position
            remaining[which_bin] = len(members) - position
            yield batch

    def generate_bins(self, data_source):
        """Group dataset indices into three bins by node count.

        Bin 0 holds samples with fewer than 128 nodes, bin 1 fewer than 256 and
        bin 2 fewer than 512.  Samples with 512 or more nodes are not placed in
        any bin and are therefore never sampled.  Indices keep dataset order
        inside each bin.  Note that this iterates over the whole dataset once.
        """
        bins = [[] for _ in range(3)]
        for i, d in enumerate(data_source):
            node_num = d[0]['node_feat'].shape[0]
            if node_num < 128:
                bins[0].append(i)
            elif node_num < 256:
                bins[1].append(i)
            elif node_num < 512:
                bins[2].append(i)
        return bins

In [7]:
# Keep only spectra whose 'Node Number' is at most 256.
# NOTE(review): this rebinds `spec_header`, so the cell is only a no-op on re-run because the filter is idempotent.
spec_header = spec_header[spec_header['Node Number']<=256]

In [8]:
class GenovaCollator(object):
    """Collate per-spectrum records into one padded batch.

    Each record is a mapping with the keys ``'node_input'``, ``'rel_input'``,
    ``'edge_input'`` and ``'graph_label'``.  Calling the collator returns
    ``(encoder_input, labels)`` where ``encoder_input`` has the keys
    ``'node_input'``, ``'path_input'``, ``'edge_input'`` and ``'rel_mask'`` and
    ``labels`` has ``'node_labels'`` and ``'node_mask'``.
    """

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        node_inputs, path_inputs, edge_inputs, raw_labels = [], [], [], []
        for record in batch:
            node_inputs.append(record['node_input'])
            path_inputs.append(record['rel_input'])
            edge_inputs.append(record['edge_input'])
            raw_labels.append(record['graph_label'])

        # Row 0: node count per record, row 1: second dim of 'node_sourceion' per record.
        node_shape = np.array([item['node_sourceion'].shape for item in node_inputs]).T
        max_node = node_shape[0].max()
        max_subgraph_node = node_shape[1].max()

        node_labels, node_mask = self.nodelabel_collate(raw_labels, max_node)
        encoder_input = {
            'node_input': self.node_collate(node_inputs, max_node, max_subgraph_node),
            'path_input': self.path_collate(path_inputs, max_node, node_shape),
            'edge_input': self.edge_collate(edge_inputs, max_node),
            'rel_mask': self.rel_collate(node_shape, max_node),
        }
        labels = {'node_labels': node_labels, 'node_mask': node_mask}
        return encoder_input, labels

    @staticmethod
    def _batched_coords(inputs, key):
        """Prepend each record's batch position to its coordinate rows, concatenate, transpose."""
        tagged = [pad(item[key], [1, 0], value=batch_idx) for batch_idx, item in enumerate(inputs)]
        return torch.cat(tagged).T

    def node_collate(self, node_inputs, max_node, max_subgraph_node):
        """Zero-pad node tensors to ``(max_node, max_subgraph_node, ...)`` and stack them."""
        charge = torch.IntTensor([item['charge'] for item in node_inputs])
        padded_feat, padded_source = [], []
        for item in node_inputs:
            n_rows, n_cols = item['node_sourceion'].shape
            row_pad = max_node - n_rows
            col_pad = max_subgraph_node - n_cols
            # F.pad lists pads from the last dimension backwards.
            padded_feat.append(pad(item['node_feat'], [0, 0, 0, col_pad, 0, row_pad]))
            padded_source.append(pad(item['node_sourceion'], [0, col_pad, 0, row_pad]))
        return {
            'node_feat': torch.stack(padded_feat),
            'node_sourceion': torch.stack(padded_source),
            'charge': charge,
        }

    def path_collate(self, path_inputs, max_node, node_shape):
        """Concatenate sparse path entries across the batch and pad the dense ``dist`` matrices."""
        rel_type = torch.cat([item['rel_type'] for item in path_inputs]).squeeze(-1)
        rel_error = torch.cat([item['rel_error'] for item in path_inputs])

        coords = self._batched_coords(path_inputs, 'rel_coor')
        # First row: (batch, coord 1, coord 2) flattened into one index.
        flat_position = coords[0] * max_node ** 2 + coords[1] * max_node + coords[2]
        # Second row: the last two coordinates combined using edge_type_num as the radix.
        flat_type = coords[-2] * self.cfg.preprocessing.edge_type_num + coords[-1]
        rel_coor_cated = torch.stack([flat_position, flat_type])

        rel_pos = torch.cat([item['rel_coor'][:, -2] for item in path_inputs])

        padded_dist = []
        for batch_idx, item in enumerate(path_inputs):
            extra = max_node - node_shape[0, batch_idx]
            padded_dist.append(pad(item['dist'], [0, extra, 0, extra]))
        dist = torch.stack(padded_dist)

        return {
            'rel_type': rel_type,
            'rel_error': rel_error,
            'rel_pos': rel_pos,
            'dist': dist,
            'rel_coor_cated': rel_coor_cated,
            'max_node': max_node,
            'batch_num': len(path_inputs),
        }

    def edge_collate(self, edge_inputs, max_node):
        """Concatenate sparse edge entries across the batch."""
        rel_type = torch.cat([item['edge_type'] for item in edge_inputs]).squeeze(-1)
        rel_error = torch.cat([item['edge_error'] for item in edge_inputs])

        coords = self._batched_coords(edge_inputs, 'edge_coor')
        flat_position = coords[0] * max_node ** 2 + coords[1] * max_node + coords[2]
        rel_coor_cated = torch.stack([flat_position, coords[-1]])

        return {
            'rel_type': rel_type,
            'rel_error': rel_error,
            'rel_coor_cated': rel_coor_cated,
            'max_node': max_node,
            'batch_num': len(edge_inputs),
        }

    def rel_collate(self, node_shape, max_node):
        """Build one ``(max_node, max_node, 1)`` additive mask per record.

        Columns below the record's node count are 0; the remaining (padding)
        columns are ``-inf``.
        """
        masks = []
        for valid_nodes in node_shape[0]:
            mask = torch.ones(max_node, max_node, 1) * (-np.inf)
            mask[:, :valid_nodes] = 0
            masks.append(mask)
        return torch.stack(masks)

    def nodelabel_collate(self, node_labels_temp, max_node):
        """Right-pad label vectors with zeros to ``max_node``; return them with a mask that is True on real nodes."""
        node_mask = torch.zeros(len(node_labels_temp), max_node, dtype=torch.bool)
        padded = []
        for row, label in enumerate(node_labels_temp):
            length = label.shape[0]
            node_mask[row, :length] = True
            padded.append(pad(label, [0, max_node - length]))
        return torch.stack(padded), node_mask
    

class GenovaDataset(Dataset):
    """Map-style dataset that reads one gzip-compressed pickled spectrum per item.

    ``spec_header`` is a DataFrame looked up by label (``.loc``); each row gives
    the file name, byte offset and byte length of the serialized record inside
    ``dataset_dir_path``, plus the precursor ``'Charge'``.
    """

    def __init__(self, cfg, *, spec_header, dataset_dir_path):
        super().__init__()
        self.cfg = cfg
        self.spec_header = spec_header
        self.dataset_dir_path = dataset_dir_path

    def __len__(self):
        return len(self.spec_header)

    def __getitem__(self, idx):
        """Load, decompress and unpickle the record whose index label is ``idx``."""
        index = idx.tolist() if torch.is_tensor(idx) else idx
        header_row = dict(self.spec_header.loc[index])

        file_path = os.path.join(self.dataset_dir_path, header_row['Serialized File Name'])
        with open(file_path, 'rb') as handle:
            handle.seek(header_row['Serialized File Pointer'])
            compressed = handle.read(header_row['Serialized Data Length'])

        # NOTE(review): unpickling can execute arbitrary code — only read trusted dataset files.
        spec = pickle.loads(gzip.decompress(compressed))

        spec['node_input']['charge'] = header_row['Charge']
        del spec['node_mass']  # not returned to the caller
        # Collapse the last label dimension: a node is positive if any entry along it is set.
        spec['graph_label'] = torch.any(spec['graph_label'], -1).long()
        return spec

In [9]:
# Build the dataset over the filtered index; serialized files are read from the given directory.
# NOTE(review): the directory is an absolute, user-specific path — make it configurable.
ds = GenovaDataset(cfg, spec_header=spec_header, dataset_dir_path='/home/z37mao/')

In [10]:
# Select CUDA device index 1 and make it the current device, so later bare `.cuda()` calls target it.
device = torch.device("cuda", 1)
torch.cuda.set_device(device)

In [11]:
# Fixed batch size of 2 with shuffling; batches are padded and assembled by GenovaCollator.
collate_fn = GenovaCollator(cfg)
dl = DataLoader(ds,batch_size=2,collate_fn=collate_fn,num_workers=2,pin_memory=True,shuffle=True)
# Wrap the loader in the project's DataPrefetcher (presumably moves batches to `device` ahead of use — TODO confirm).
dl = DataPrefetcher(dl,device,non_blocking=True)

In [12]:
# Instantiate the encoder with the binary-classification option enabled and move it to the current CUDA device.
model = genova.GenovaEncoder(cfg, bin_classification=True).cuda()

In [13]:
# Loss, optimizer and AMP gradient scaler for the mixed-precision loop below.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=2e-4)
scaler = GradScaler()
# Bytes currently allocated on the GPU before any batch is processed (baseline for the estimates below).
print(torch.cuda.memory_allocated())

49141760


In [14]:
# Training loop that doubles as a calibration run for the memory estimator `a`:
# for every batch it compares the theoretical footprint (`theo`) with the bytes
# reported by torch.cuda.memory_allocated() (`real`) and prints the ratio.
# The measurement is taken after the forward pass and before backward, so the
# statement order below matters.
loss_detect = 0
for i, (encoder_input, labels) in enumerate(dl,start=1):
    optimizer.zero_grad()
    with autocast():
        output = model(**encoder_input)
        #print(torch.cuda.memory_allocated())
        # Only positions where node_mask is True contribute to the loss.
        loss = loss_fn(output[labels['node_mask']], labels['node_labels'][labels['node_mask']])
    loss_detect += loss.item()
    # Batch geometry that drives the estimate.
    max_node = encoder_input['edge_input']['max_node']
    batch_num = encoder_input['edge_input']['batch_num']
    edge_num = len(encoder_input['edge_input']['rel_type'])
    path_num = len(encoder_input['path_input']['rel_type'])
    
    # Each term multiplies a per-unit coefficient from `a` by the size it scales with.
    node_consumer = a.node_sparse*max_node*batch_num*encoder_input['node_input']['node_feat'].shape[-2] + a.node*max_node*batch_num
    edge_consumer = a.edge_matrix*max_node**2*batch_num + a.edge_sparse*edge_num
    path_consumer = a.path_matrix*max_node**2*batch_num + a.path_sparse*path_num
    relation_cosumer = a.relation_matrix*max_node**2*batch_num + a.relation_ffn*max_node*batch_num
    theo = node_consumer+edge_consumer+path_consumer+relation_cosumer
    # NOTE(review): 47562752*4 and 0.75 are unexplained constants (presumably a fixed
    # model/optimizer overhead and an empirical correction) — name them and confirm.
    theo = (theo+47562752*4)*0.75
    real = torch.cuda.memory_allocated()
    print(theo/real)
    # Print the batch geometry as well when the estimate is far off.
    if theo/real<0.6 or theo/real>1.7: 
        print(theo/real, max_node, edge_num, path_num, encoder_input['node_input']['node_feat'].shape[-2])
    #break
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()


1.1758629789036574
1.1062283793546832
1.184342462925717
1.1694907760531084
1.111275617800659
0.9530165178015136
1.1162969288166287
1.013237634354432
1.067330823676462
1.1021279067609544
0.9155439790548126
1.0512562298407386
1.0768710402923782
1.0018907895558191
1.0622090865403924
1.093791103050106
1.0874644502303223
1.1888044866748761
1.051291796099107
1.0694206540412683
1.0569850070399587
1.0687628326138294
1.0672270279636393
1.0401522960288367
0.918915103915548
1.09941596824614
0.9548673889922893
1.0888620705032441
1.0639840298527226
1.101571302655665
1.1136648415923973
1.0943862790846755
1.120265955028626
1.0208712571018181
1.0237185134878826
1.0165160836048261
1.0458986441919296
1.1365850940601048
1.0509711974333302
1.086606576960875
1.1206292992506237
1.1631902328044557
0.9304607617090035
1.0525930451115229
1.0967673187365399
0.9996852095982668
1.071223234283682
1.1538481461527215
1.0706229302858175
1.0001047949073443
1.099256515433375
1.1939748892484328
1.0532840815439373
1.12313

1.0477064510514407
0.9774399494457855
0.973765316862146
1.1698660395903517
1.052127849002849
1.0338940776809378
1.0985642427757398
0.9401924864265491
1.0308089653820365
1.0407902488442389
0.9922100552708162
0.9415730073381712
1.0364000650279912
1.0614682900105243
1.1336783096022898
1.08428641232564
1.143862789739175
1.1066450828307248
0.9445559836357758
1.088352219731672
0.9873542107404196
1.0661215375185715
1.0780956598427738
1.1128104539319572
0.8909329963873986
1.0786975090537154
1.1206033555578019
0.9818801086153205
1.0381801912821116
1.086201116763538
1.1911462594085502
0.9209358971645466
1.1252307771205854
1.0858086954738189
1.0691283188535239
1.1257219482536829
1.0929468525262769
1.0201862715823864
1.0878213252823423
1.0529633943637142
1.100402300763617
1.0835159669324148
1.1120366936555661
1.0516975962247859
1.0166518292494584
0.920388660219008
1.0249653767202553
0.8672988923684517
1.1395224429994164
1.0469488160831488
1.0241942866144818
1.1202992357082908
1.097913057136492
1.0

1.2114537900605848
1.103132353666626
0.8574084723484043
1.015070830563519
1.0266186829330965
1.0954126991398805
1.1284366897624218
1.0361410217040288
0.9826885110916186
1.0792586915759887
1.1098142403211595
1.0612897744483107
0.9064042579175033
1.1032062114990198
1.1209557545017346
1.1809757767343474
1.078439173581725
1.055521564831472
0.9984470927695068
0.995252743361917
1.099973410033046
1.1195459306438973
1.057015511559333
1.1109031330878163
1.05549240631741
0.9817089187837948
1.1081115373376693
1.1042236813184445
1.1026076582616091
1.0793788828986297
1.1305396117721778
1.0343690194856503
1.149461735879104
1.0591598610755346
1.0970474001535468
1.040276890802275
0.8954497968273121
0.9451904850613115
1.0928055630459423
1.0050025796407511
1.1178521088548214
1.1197127870642472
1.1208284486897766
1.1367105940721762
1.0359275278430224
1.0838093416950945
1.074516961460939
1.0399603607895898
1.1053189868955413
0.8301064896951478
1.1584313705379787
1.0771159512367394
1.0055582308216509
1.097

1.0991000030127223
1.078446070231364
1.0347493481411563
1.136235198502757
1.0564405236634447
1.1658219097017146
1.0986765914524805
1.0076062342078917
1.1762938421431641
1.1184204025357953
1.043732325354821
1.05716204873653
1.1449691937312583
1.1119476095574379
1.0753434886823825
1.202254725389947
1.1487100495900053
1.0758838244456628
1.0552020940512044
1.1077433092828657
1.1373886485346585
0.981593242452324
1.0494452707499766
1.0247294653311283
1.0538744605842922
1.0372146355617946
1.1318330152221614
1.0351102194556823
1.1810656536674191
1.059095004159809
0.9786539646736215
1.1775098476210222
0.9311615544207476
1.029438011844723
0.998980918739672
1.1579697960901527
0.9162497128657047
0.9854656958574028
1.0652101827514215
1.0859334876281108
1.1941792705142047
1.1153416998315306
1.0705750362011284
1.0242094096929486
1.110065922534836
1.020799511704916
1.1441660990179052
1.095089418611277
0.9176527379412481
1.0807871031799876
1.1804247994826826
1.0340795903438187
1.0571946833424108
1.1154

1.0951983081797858
0.9428190387473404
1.0944151862077776
0.9527298713316752
1.0983283174986458
1.047538158613473
1.1362300191249735
1.0485819827911005
1.1067850789536362
1.051407792732825
0.9213759536388187
1.0030474058002723
1.1629330150638857
1.1510157575007611
0.9595038530148414
1.1118218708266816
1.104388867201781
1.143635361336795
1.0883614734646372
0.9833711721701329
1.0072068068694482
0.9211400506455959
1.1532189519731728
1.1341791461806436
0.9686110989150057
1.06690337657843
0.9023259504695254
1.1101714614099805
1.0442713542708542
1.0372137256928735
0.9950525864649378
1.0247006210925795
1.09588299511912
1.1271216023960426
1.037142780601701
1.0017366787882707
1.1280519641451356
0.9266522597948607
1.1432113612492898
1.079014008884137
1.133550832183517
1.0381113071416028
1.046024621146372
1.1373911346867935
1.1155237124236475
1.020611172412582
1.1001237712558383
1.132808320672758
1.1097300746326033
1.1377008769401744
0.9272166888570258
1.0641810458075436
1.1718276976130555
0.92860

In [72]:
# What-if estimate: evaluate the same cost formula as the calibration loop for a
# hypothetical batch, without the additive constant and the 0.75 factor used there.
max_node = 128
batch_num = 32
edge_num = 1e6
path_num = 5e5

# 30 stands in for node_feat.shape[-2], which the calibration loop reads from the batch.
node_consumer = a.node_sparse*max_node*batch_num*30 + a.node*max_node*batch_num
edge_consumer = a.edge_matrix*max_node**2*batch_num + a.edge_sparse*edge_num
path_consumer = a.path_matrix*max_node**2*batch_num + a.path_sparse*path_num
relation_cosumer = a.relation_matrix*max_node**2*batch_num + a.relation_ffn*max_node*batch_num
theo = node_consumer+edge_consumer+path_consumer+relation_cosumer

In [73]:
# Express the estimate in units of 1024**3.
theo/1024**3

57.87671661376953

In [49]:
# Inspect the relation dimension currently stored in the config.
cfg['encoder']['d_relation']

256

In [15]:
# Second-to-last dimension of node_feat for the last batch seen by the loop above.
encoder_input['node_input']['node_feat'].shape[-2]

9

In [16]:
# Scratch arithmetic: three quarters of 1527211008.
1527211008/4*3

1145408256.0

In [17]:
# Scratch arithmetic: the previous result plus five times 45208880.
1145408256+45208880*5

1371452656

In [18]:
# Scratch arithmetic: ratio of two byte counts.
728486896/1213900288

0.6001208692356781

In [19]:
# NOTE(review): unfinished scratch expression — it raises SyntaxError as written; complete or delete this cell.
18844*

SyntaxError: invalid syntax (1716935593.py, line 1)

In [None]:
# Shape of the concatenated edge-type tensor for the last batch.
encoder_input['edge_input']['rel_type'].shape

In [None]:
# Identity test: `!= None` on a tensor goes through the tensor's comparison operator
# instead of checking whether a value is present.
if encoder_input['path_input']['rel_pos'] is not None:
    print('success')

In [None]:
# Scratch embedding table with 50150 rows of width 54, used by the lookup experiment below.
test = nn.Embedding(50150,54)

In [None]:
# Edge-type tensor with a trailing size-1 dimension removed, if it has one.
encoder_input['edge_input']['rel_type'].squeeze(-1)

In [None]:
# Shape produced by looking the edge types up in the scratch embedding `test`.
test(encoder_input['edge_input']['rel_type'].squeeze(-1)).shape

In [None]:
# Display the whole edge_input section of the last batch.
encoder_input['edge_input']

In [None]:
def encoder_input_cuda(encoder_input, device):
    """Move every tensor in a collated encoder input to ``device`` in place.

    Args:
        encoder_input: dict whose values are either tensors (for example a
            top-level ``'rel_mask'``) or sections, i.e. dict-like objects that
            map names to tensors and non-tensor metadata.
        device: target passed to ``Tensor.to``.

    Returns:
        The same ``encoder_input`` object, with tensors replaced by their copies
        on ``device``.  Non-tensor entries inside sections are left untouched.
    """
    for section_key in encoder_input:
        section = encoder_input[section_key]
        if isinstance(section, torch.Tensor):
            # A top-level tensor is moved as a whole; iterating over it as if it
            # were a section would index the tensor with its own rows.
            encoder_input[section_key] = section.to(device)
            continue
        for key in section:
            value = section[key]
            if isinstance(value, torch.Tensor):
                section[key] = value.to(device)
    return encoder_input


import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Distributed (DDP) training script: one process per GPU.
# LOCAL_RANK must be set in the environment by the launcher.
local_rank = int(os.environ['LOCAL_RANK'])

# Only the first local process talks to Weights & Biases.
if local_rank == 0:
    wandb.init(project="Genova", entity="rxnatalie")

torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')  # NCCL is the fastest and recommended backend for GPU devices

# Build the model
device = torch.device("cuda", local_rank)

# NOTE(review): `small_spec` is not defined anywhere in this notebook — confirm which index frame is intended.
ds = GenovaDataset(cfg, spec_header=small_spec, dataset_dir_path='./pretrain_data_sparse/')
# num_train_samples = 2000
# ds = Subset(ds, np.arange(num_train_samples))

collate_fn = GenovaCollator(cfg)
# NOTE(review): the batch sampler reads samples as `data_source[i][0][...]`; confirm this matches the dataset's
# item layout.  Every rank also iterates the full dataset here — there is no per-rank sharding.
sampler = GenovaSampler(ds, cfg, 13)
dl = DataLoader(ds, batch_sampler=sampler, collate_fn=collate_fn, num_workers=1)
# dl = DataLoader(ds,batch_size=4,collate_fn=collate_fn,num_workers=1,shuffle=True)
model = genova.GenovaEncoder(cfg, bin_classification=True).to(local_rank)
model = DDP(model, device_ids=[local_rank])
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=2e-4)
scaler = GradScaler()

CHECKPOINT_PATH = './save/sampler_test/model_max.pt'
# #checkpoint = torch.load(CHECKPOINT_PATH,map_location = {'cuda:%d' % 0: 'cuda:%d' % local_rank})['model_state_dict']
# checkpoint = torch.load(CHECKPOINT_PATH,map_location = {'cuda:%d' % 0: 'cuda:%d' % local_rank})
# if list(model.state_dict().keys())[0].startswith('module'):
#     #model.load_state_dict(checkpoint)
#     model.load_state_dict(OrderedDict([('module.'+key, v) for key, v in checkpoint.items()]))
# else:
#     #model.load_state_dict(OrderedDict([(key[7:], v) for key, v in checkpoint.items()]))
#     model.load_state_dict(checkpoint)

# Running sums, reset every `detect_period` batches after they are logged.
loss_detect = 0
min_loss = 10000
detect_period = 50
accuracy = 0
recall = 0
precision = 0
for epoch in range(2):
    print('Epoch:', epoch)
    # GenovaCollator returns (encoder_input, labels) where labels holds
    # 'node_labels' and 'node_mask'; the mask is True on real (non-padding) nodes.
    for i, (encoder_input, labels) in enumerate(dl, start=1):
        if i % 50 == 0:
            print('Sample:', i)
        # 'rel_mask' is a top-level tensor rather than a dict section, so it is
        # moved explicitly and re-attached after the nested sections are moved.
        rel_mask = encoder_input.pop('rel_mask').to(device)
        encoder_input = encoder_input_cuda(encoder_input, device)
        encoder_input['rel_mask'] = rel_mask
        node_labels = labels['node_labels'].to(device)
        node_mask = labels['node_mask'].to(device)
        optimizer.zero_grad()
        with autocast():
            output = model(**encoder_input)
            loss = loss_fn(output[node_mask], node_labels[node_mask])
        if local_rank == 0:
            pred = torch.argmax(output[node_mask], -1)
            target = node_labels[node_mask]
            hits = pred == target
            accuracy += hits.sum() / target.shape[0]
            # NOTE(review): both ratios become NaN when a batch has no positive targets / predictions.
            recall += (hits[target == 1]).sum() / (target == 1).sum()
            precision += (hits[target == 1]).sum() / (pred == 1).sum()
            loss_detect += loss.item()
            if i % detect_period == 0:
                wandb.log({"loss": loss_detect / detect_period,
                           "accuracy": accuracy / detect_period,
                           "recall": recall / detect_period,
                           "precision": precision / detect_period}
                          )
                torch.save({'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict()}, CHECKPOINT_PATH)
                loss_detect, accuracy, recall, precision = 0, 0, 0, 0
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

In [None]:
# NOTE(review): `label` is only assigned in the next cell, so this cell fails on a fresh top-to-bottom run.
label

In [None]:
# Scratch check of the label reduction that GenovaDataset.__getitem__ applies.
# NOTE(review): `spec` is not defined at notebook level in the visible cells.
label = spec.pop('graph_label')
print(label.shape)
# Collapse the last dimension: 1 where any entry along it is set, else 0.
label = torch.any(label, -1).long()

In [None]:
# Display the 'rel_input' section of a raw record (relies on the scratch variable `spec`).
spec['rel_input']

In [None]:
# Display the 'node_input' section of the dataset record whose index label is 0.
ds[0]['node_input']

In [None]:
# NOTE(review): `node_input` is not defined at notebook level in the visible cells.
node_input[0]['node_feat'].shape