@@ -43,13 +43,16 @@ class ApplicationMixin(object):
4343 dim (int): dimension of embeddings
4444 gpus (list of int, optional): GPU ids, default is all GPUs
4545 cpu_per_gpu (int, optional): number of CPU threads per GPU, default is all CPUs
46+ gpu_memory_limit (int, optional): memory limit per GPU in bytes, default is all memory
4647 float_type (dtype, optional): type of parameters
4748 index_type (dtype, optional): type of graph indexes
4849 """
49- def __init__ (self , dim , gpus = [], cpu_per_gpu = auto , float_type = cfg .float_type , index_type = cfg .index_type ):
50+ def __init__ (self , dim , gpus = [], cpu_per_gpu = auto , gpu_memory_limit = auto ,
51+ float_type = cfg .float_type , index_type = cfg .index_type ):
5052 self .dim = dim
5153 self .gpus = gpus
5254 self .cpu_per_gpu = cpu_per_gpu
55+ self .gpu_memory_limit = gpu_memory_limit
5356 self .float_type = float_type
5457 self .index_type = index_type
5558 self .set_format ()
@@ -236,7 +239,8 @@ def get_solver(self, **kwargs):
236239 num_sampler_per_worker = auto
237240 else :
238241 num_sampler_per_worker = self .cpu_per_gpu - 1
239- return solver .GraphSolver (self .dim , self .float_type , self .index_type , self .gpus , num_sampler_per_worker )
242+ return solver .GraphSolver (self .dim , self .float_type , self .index_type , self .gpus , num_sampler_per_worker ,
243+ self .gpu_memory_limit )
240244
241245 def node_classification (self , X = None , Y = None , file_name = None , portions = (0.02 ,), normalization = False , times = 1 ,
242246 patience = 100 ):
@@ -513,7 +517,8 @@ def get_solver(self, **kwargs):
513517 num_sampler_per_worker = auto
514518 else :
515519 num_sampler_per_worker = self .cpu_per_gpu - 1
516- return solver .GraphSolver (self .dim , self .float_type , self .index_type , self .gpus , num_sampler_per_worker )
520+ return solver .GraphSolver (self .dim , self .float_type , self .index_type , self .gpus , num_sampler_per_worker ,
521+ self .gpu_memory_limit )
517522
518523
519524class KnowledgeGraphApplication (ApplicationMixin ):
@@ -573,7 +578,8 @@ def get_solver(self, **kwargs):
573578 num_sampler_per_worker = auto
574579 else :
575580 num_sampler_per_worker = self .cpu_per_gpu - 1
576- return solver .KnowledgeGraphSolver (self .dim , self .float_type , self .index_type , self .gpus , num_sampler_per_worker )
581+ return solver .KnowledgeGraphSolver (self .dim , self .float_type , self .index_type , self .gpus , num_sampler_per_worker ,
582+ self .gpu_memory_limit )
577583
578584 def entity_prediction (self , H = None , R = None , T = None , file_name = None , save_file = None , target = "tail" , k = 10 ,
579585 backend = cfg .backend ):
@@ -1032,7 +1038,8 @@ def get_solver(self, **kwargs):
10321038 else :
10331039 num_sampler_per_worker = self .cpu_per_gpu - 1
10341040
1035- return solver .VisualizationSolver (self .dim , self .float_type , self .index_type , self .gpus , num_sampler_per_worker )
1041+ return solver .VisualizationSolver (self .dim , self .float_type , self .index_type , self .gpus , num_sampler_per_worker ,
1042+ self .gpu_memory_limit )
10361043
10371044 def visualization (self , Y = None , file_name = None , save_file = None , figure_size = 10 , scale = 2 ):
10381045 """
0 commit comments