Skip to content

Commit

Permalink
Merge pull request #31841 from Dr15Jones/addedPSetTemplate
Browse files Browse the repository at this point in the history
Added PSetTemplate to allow description of a PSet
  • Loading branch information
cmsbuild committed Oct 19, 2020
2 parents 1f87312 + bc71867 commit 9a90074
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 4 deletions.
6 changes: 4 additions & 2 deletions FWCore/ParameterSet/python/Config.py
Expand Up @@ -1990,7 +1990,9 @@ def __init__(self,*arg,**args):

def testProcessDumpPython(self):
self.assertEqual(Process("test").dumpPython(),
"""import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")
"""import FWCore.ParameterSet.Config as cms
process = cms.Process("test")
process.maxEvents = cms.untracked.PSet(
input = cms.optional.untracked.int32,
Expand All @@ -2011,7 +2013,7 @@ def testProcessDumpPython(self):
emptyRunLumiMode = cms.obsolete.untracked.string,
eventSetup = cms.untracked.PSet(
forceNumberOfConcurrentIOVs = cms.untracked.PSet(
allowAnyLabel_=cms.required.untracked.uint32
),
numberOfConcurrentIOVs = cms.untracked.uint32(1)
),
Expand Down
4 changes: 4 additions & 0 deletions FWCore/ParameterSet/python/Mixins.py
Expand Up @@ -364,6 +364,10 @@ def dumpPython(self, options=PrintOptions()):
# usings need to go first
resultList = usings
resultList.extend(others)
if self.__validator is not None:
options.indent()
resultList.append(options.indentation()+"allowAnyLabel_="+self.__validator.dumpPython(options))
options.unindent()
return ',\n'.join(resultList)+'\n'
def __repr__(self):
return self.dumpPython()
Expand Down
73 changes: 71 additions & 2 deletions FWCore/ParameterSet/python/Types.py
Expand Up @@ -59,6 +59,8 @@ def __setattr__(self,name, value):
if v is not None:
return setattr(v,name,value)
else:
if not name.startswith('_'):
raise AttributeError("%r object has no attribute %r" % (self.__class__.__name__, name))
return object.__setattr__(self, name, value)
def __bool__(self):
v = self.__dict__.get('_ProxyParameter__value',None)
Expand All @@ -71,7 +73,9 @@ def dumpPython(self, options=PrintOptions()):
v = "cms."+self._dumpPythonName()
if not _ParameterTypeBase.isTracked(self):
v+=".untracked"
return v+'.'+self.__type.__name__
if hasattr(self.__type, "__name__"):
return v+'.'+self.__type.__name__
return v+'.'+self.__type.dumpPython(options)
def validate_(self,value):
return isinstance(value,self.__type)
def convert_(self,value):
Expand Down Expand Up @@ -138,6 +142,19 @@ def __call__(self,value):
raise RuntimeError("Cannot convert "+str(value)+" to 'allowed' type")
return chosenType(value)

class _PSetTemplate(object):
def __init__(self, *args, **kargs):
self._pset = PSet(*args,**kargs)
self.__dict__['_PSetTemplate__value'] = None
def __call__(self, value):
self.__dict__
return self._pset.clone(**value)
def dumpPython(self, options=PrintOptions()):
v =self.__dict__.get('_ProxyParameter__value',None)
if v is not None:
return v.dumpPython(options)
return "PSetTemplate(\n"+_Parameterizable.dumpPython(self._pset, options)+options.indentation()+")"


class _ProxyParameterFactory(object):
"""Class type for ProxyParameter types to allow nice syntax"""
Expand All @@ -160,7 +177,17 @@ def __call__(self, *args):
return self.type(_AllowedParameterTypes(*args))

return _AllowedWrapper(self.__isUntracked, self.__type)

if name == 'PSetTemplate':
class _PSetTemplateWrapper(object):
def __init__(self, untracked, type):
self.untracked = untracked
self.type = type
def __call__(self,*args,**kargs):
if self.untracked:
return untracked(self.type(_PSetTemplate(*args,**kargs)))
return self.type(_PSetTemplate(*args,**kargs))
return _PSetTemplateWrapper(self.__isUntracked, self.__type)

type = globals()[name]
if not issubclass(type, _ParameterTypeBase):
raise AttributeError
Expand Down Expand Up @@ -1859,6 +1886,27 @@ def testRequired(self):
self.assertEqual(p1.foo.value(),3)
self.failIf(p1.foo.isTracked())
self.assertRaises(ValueError,setattr,p1, 'bar', 'bad')
#PSetTemplate use
p1 = PSet(aPSet = required.PSetTemplate())
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.required.PSetTemplate(\n\n )\n)')
p1.aPSet = dict()
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n\n )\n)')
p1 = PSet(aPSet=required.PSetTemplate(a=required.int32))
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.required.PSetTemplate(\n a = cms.required.int32\n )\n)')
p1.aPSet = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n a = cms.int32(5)\n )\n)')
self.assertEqual(p1.aPSet.a.value(), 5)
p1 = PSet(aPSet=required.untracked.PSetTemplate(a=required.int32))
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.required.untracked.PSetTemplate(\n a = cms.required.int32\n )\n)')
p1.aPSet = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.untracked.PSet(\n a = cms.int32(5)\n )\n)')
self.assertEqual(p1.aPSet.a.value(), 5)
p1 = PSet(allowAnyLabel_=required.PSetTemplate(a=required.int32))
self.assertEqual(p1.dumpPython(), 'cms.PSet(\n allowAnyLabel_=cms.required.PSetTemplate(\n a = cms.required.int32\n )\n)')
p1.foo = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n foo = cms.PSet(\n a = cms.int32(5)\n ),\n allowAnyLabel_=cms.required.PSetTemplate(\n a = cms.required.int32\n )\n)')
self.assertEqual(p1.foo.a.value(), 5)

def testOptional(self):
p1 = PSet(anInt = optional.int32)
self.assert_(hasattr(p1,"anInt"))
Expand Down Expand Up @@ -1887,6 +1935,27 @@ def testOptional(self):
self.failIf(p1.f)
p1.f.append(3)
self.assert_(p1.f)
#PSetTemplate use
p1 = PSet(aPSet = optional.PSetTemplate())
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.optional.PSetTemplate(\n\n )\n)')
p1.aPSet = dict()
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n\n )\n)')
p1 = PSet(aPSet=optional.PSetTemplate(a=optional.int32))
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.optional.PSetTemplate(\n a = cms.optional.int32\n )\n)')
p1.aPSet = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n a = cms.int32(5)\n )\n)')
self.assertEqual(p1.aPSet.a.value(), 5)
p1 = PSet(aPSet=optional.untracked.PSetTemplate(a=optional.int32))
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.optional.untracked.PSetTemplate(\n a = cms.optional.int32\n )\n)')
p1.aPSet = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.untracked.PSet(\n a = cms.int32(5)\n )\n)')
self.assertEqual(p1.aPSet.a.value(), 5)
p1 = PSet(allowAnyLabel_=optional.PSetTemplate(a=optional.int32))
self.assertEqual(p1.dumpPython(), 'cms.PSet(\n allowAnyLabel_=cms.optional.PSetTemplate(\n a = cms.optional.int32\n )\n)')
p1.foo = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n foo = cms.PSet(\n a = cms.int32(5)\n ),\n allowAnyLabel_=cms.optional.PSetTemplate(\n a = cms.optional.int32\n )\n)')
self.assertEqual(p1.foo.a.value(), 5)


def testAllowed(self):
p1 = PSet(aValue = required.allowed(int32, string))
Expand Down

0 comments on commit 9a90074

Please sign in to comment.