Skip to content

Commit

Permalink
Revert "Allow for partial module loads in compiler cache."
Browse files Browse the repository at this point in the history
This reverts commit 2782704.
  • Loading branch information
braxtonmckee committed Feb 16, 2023
1 parent be4ce9b commit 41361cd
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 260 deletions.
25 changes: 12 additions & 13 deletions typed_python/compiler/binary_shared_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@


class LoadedBinarySharedObject(LoadedModule):
def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlobalVariableDefinitions):
super().__init__(functionPointers, serializedGlobalVariableDefinitions)
def __init__(self, binarySharedObject, diskPath, functionPointers, globalVariableDefinitions):
super().__init__(functionPointers, globalVariableDefinitions)

self.binarySharedObject = binarySharedObject
self.diskPath = diskPath
Expand All @@ -36,32 +36,30 @@ def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlo
class BinarySharedObject:
"""Models a shared object library (.so) loadable on linux systems."""

def __init__(self, binaryForm, functionTypes, serializedGlobalVariableDefinitions, globalDependencies):
def __init__(self, binaryForm, functionTypes, globalVariableDefinitions):
"""
Args:
binaryForm: a bytes object containing the actual compiled code for the module
serializedGlobalVariableDefinitions: a map from name to GlobalVariableDefinition
globalDependencies: a dict from function linkname to the list of global variables it depends on
binaryForm - a bytes object containing the actual compiled code for the module
globalVariableDefinitions - a map from name to GlobalVariableDefinition
"""
self.binaryForm = binaryForm
self.functionTypes = functionTypes
self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions
self.globalDependencies = globalDependencies
self.globalVariableDefinitions = globalVariableDefinitions
self.hash = sha_hash(binaryForm)

@property
def definedSymbols(self):
return self.functionTypes.keys()

@staticmethod
def fromDisk(path, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
def fromDisk(path, globalVariableDefinitions, functionNameToType):
with open(path, "rb") as f:
binaryForm = f.read()

return BinarySharedObject(binaryForm, functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
return BinarySharedObject(binaryForm, functionNameToType, globalVariableDefinitions)

@staticmethod
def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
def fromModule(module, globalVariableDefinitions, functionNameToType):
target_triple = llvm.get_process_triple()
target = llvm.Target.from_triple(target_triple)
target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default')
Expand All @@ -82,7 +80,7 @@ def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType,
)

with open(os.path.join(tf, "module.so"), "rb") as so_file:
return BinarySharedObject(so_file.read(), functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
return BinarySharedObject(so_file.read(), functionNameToType, globalVariableDefinitions)

def load(self, storageDir):
"""Instantiate this .so in temporary storage and return a dict from symbol -> integer function pointer"""
Expand Down Expand Up @@ -129,7 +127,8 @@ def loadFromPath(self, modulePath):
self,
modulePath,
functionPointers,
self.serializedGlobalVariableDefinitions
self.globalVariableDefinitions
)
loadedModule.linkGlobalVariables()

return loadedModule
176 changes: 70 additions & 106 deletions typed_python/compiler/compiler_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@
import os
import uuid
import shutil
from typed_python.compiler.loaded_module import LoadedModule
from typed_python.compiler.binary_shared_object import BinarySharedObject

from typing import Optional, List

from typed_python.compiler.binary_shared_object import LoadedBinarySharedObject, BinarySharedObject
from typed_python.compiler.directed_graph import DirectedGraph
from typed_python.compiler.typed_call_target import TypedCallTarget
from typed_python.SerializationContext import SerializationContext
from typed_python import Dict, ListOf

Expand Down Expand Up @@ -55,173 +52,146 @@ def __init__(self, cacheDir):

ensureDirExists(cacheDir)

self.loadedBinarySharedObjects = Dict(str, LoadedBinarySharedObject)()
self.loadedModules = Dict(str, LoadedModule)()
self.nameToModuleHash = Dict(str, str)()

self.moduleManifestsLoaded = set()
self.modulesMarkedValid = set()
self.modulesMarkedInvalid = set()

for moduleHash in os.listdir(self.cacheDir):
if len(moduleHash) == 40:
self.loadNameManifestFromStoredModuleByHash(moduleHash)

# the set of functions with an associated module in loadedBinarySharedObjects
self.targetsLoaded: Dict[str, TypedCallTarget] = {}

# the set of functions with linked and validated globals (i.e. ready to be run).
self.targetsValidated = set()

self.function_dependency_graph = DirectedGraph()
# dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions)
self.global_dependencies = Dict(str, ListOf(str))()
self.targetsLoaded = {}

def hasSymbol(self, linkName: str) -> bool:
"""NB this will return True even if the linkName is ultimately unretrievable."""
def hasSymbol(self, linkName):
return linkName in self.nameToModuleHash

def getTarget(self, linkName: str) -> TypedCallTarget:
if not self.hasSymbol(linkName):
raise ValueError(f'symbol not found for linkName {linkName}')
def getTarget(self, linkName):
assert self.hasSymbol(linkName)

self.loadForSymbol(linkName)

return self.targetsLoaded[linkName]

def dependencies(self, linkName: str) -> Optional[List[str]]:
"""Returns all the function names that `linkName` depends on"""
return list(self.function_dependency_graph.outgoing(linkName))
def markModuleHashInvalid(self, hashstr):
with open(os.path.join(self.cacheDir, hashstr, "marked_invalid"), "w"):
pass

def loadForSymbol(self, linkName: str) -> None:
"""Loads the whole module, and any submodules, into LoadedBinarySharedObjects"""
def loadForSymbol(self, linkName):
moduleHash = self.nameToModuleHash[linkName]

self.loadModuleByHash(moduleHash)

if linkName not in self.targetsValidated:
dependantFuncs = self.dependencies(linkName) + [linkName]
globalsToLink = {} # dict from modulehash to list of globals.
for funcName in dependantFuncs:
if funcName not in self.targetsValidated:
funcModuleHash = self.nameToModuleHash[funcName]
# append to the list of globals to link for a given module. TODO: optimise this, don't double-link.
globalsToLink[funcModuleHash] = globalsToLink.get(funcModuleHash, []) + self.global_dependencies.get(funcName, [])

for moduleHash, globs in globalsToLink.items(): # this works because loadModuleByHash loads submodules too.
if globs:
definitionsToLink = {x: self.loadedBinarySharedObjects[moduleHash].serializedGlobalVariableDefinitions[x]
for x in globs
}
self.loadedBinarySharedObjects[moduleHash].linkGlobalVariables(definitionsToLink)
if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink):
raise RuntimeError('failed to validate globals when loading:', linkName)

self.targetsValidated.update(dependantFuncs)

def loadModuleByHash(self, moduleHash: str) -> None:
def loadModuleByHash(self, moduleHash):
"""Load a module by name.
As we load, place all the newly imported typed call targets into
'nameToTypedCallTarget' so that the rest of the system knows what functions
have been uncovered.
"""
if moduleHash in self.loadedBinarySharedObjects:
return
if moduleHash in self.loadedModules:
return True

targetDir = os.path.join(self.cacheDir, moduleHash)

# TODO (Will) - store these names as module consts, use one .dat only
with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f:
callTargets = SerializationContext().deserialize(f.read())

with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f:
serializedGlobalVarDefs = SerializationContext().deserialize(f.read())
try:
with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f:
callTargets = SerializationContext().deserialize(f.read())

with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f:
functionNameToNativeType = SerializationContext().deserialize(f.read())
with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f:
globalVarDefs = SerializationContext().deserialize(f.read())

with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
submodules = SerializationContext().deserialize(f.read(), ListOf(str))
with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f:
functionNameToNativeType = SerializationContext().deserialize(f.read())

with open(os.path.join(targetDir, "function_dependencies.dat"), "rb") as f:
dependency_edgelist = SerializationContext().deserialize(f.read())
with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
submodules = SerializationContext().deserialize(f.read(), ListOf(str))
except Exception:
self.markModuleHashInvalid(moduleHash)
return False

with open(os.path.join(targetDir, "global_dependencies.dat"), "rb") as f:
globalDependencies = SerializationContext().deserialize(f.read())
if not LoadedModule.validateGlobalVariables(globalVarDefs):
self.markModuleHashInvalid(moduleHash)
return False

# load the submodules first
for submodule in submodules:
self.loadModuleByHash(submodule)
if not self.loadModuleByHash(submodule):
return False

modulePath = os.path.join(targetDir, "module.so")

loaded = BinarySharedObject.fromDisk(
modulePath,
serializedGlobalVarDefs,
functionNameToNativeType,
globalDependencies

globalVarDefs,
functionNameToNativeType
).loadFromPath(modulePath)

self.loadedBinarySharedObjects[moduleHash] = loaded
self.loadedModules[moduleHash] = loaded

self.targetsLoaded.update(callTargets)

assert not any(key in self.global_dependencies for key in globalDependencies) # should only happen if there's a hash collision.
self.global_dependencies.update(globalDependencies)

# update the cache's dependency graph with our new edges.
for function_name, dependant_function_name in dependency_edgelist:
self.function_dependency_graph.addEdge(source=function_name, dest=dependant_function_name)
return True

def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, dependencyEdgelist):
def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies):
"""Add new code to the compiler cache.
Args:
binarySharedObject: a BinarySharedObject containing the actual assembler
we've compiled.
nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us
the formal python types for all the objects.
linkDependencies: a set of linknames we depend on directly.
dependencyEdgelist (list): a list of source, dest pairs giving the set of dependency graph for the
module.
binarySharedObject - a BinarySharedObject containing the actual assembler
we've compiled
nameToTypedCallTarget - a dict from linkname to TypedCallTarget telling us
the formal python types for all the objects
linkDependencies - a set of linknames we depend on directly.
"""
dependentHashes = set()

for name in linkDependencies:
dependentHashes.add(self.nameToModuleHash[name])

path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes, dependencyEdgelist)
path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes)

self.loadedBinarySharedObjects[hashToUse] = (
self.loadedModules[hashToUse] = (
binarySharedObject.loadFromPath(os.path.join(path, "module.so"))
)

for n in binarySharedObject.definedSymbols:
self.nameToModuleHash[n] = hashToUse

# link & validate all globals for the new module
self.loadedBinarySharedObjects[hashToUse].linkGlobalVariables()
if not self.loadedBinarySharedObjects[hashToUse].validateGlobalVariables(
self.loadedBinarySharedObjects[hashToUse].serializedGlobalVariableDefinitions):
raise RuntimeError('failed to validate globals in new module:', hashToUse)

def loadNameManifestFromStoredModuleByHash(self, moduleHash) -> None:
if moduleHash in self.moduleManifestsLoaded:
return
def loadNameManifestFromStoredModuleByHash(self, moduleHash):
if moduleHash in self.modulesMarkedValid:
return True

targetDir = os.path.join(self.cacheDir, moduleHash)

# ignore 'marked invalid'
if os.path.exists(os.path.join(targetDir, "marked_invalid")):
# just bail - don't try to read it now

# for the moment, we don't try to clean up the cache, because
# we can't be sure that some process is not still reading the
# old files.
self.modulesMarkedInvalid.add(moduleHash)
return False

with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
submodules = SerializationContext().deserialize(f.read(), ListOf(str))

for subHash in submodules:
self.loadNameManifestFromStoredModuleByHash(subHash)
if not self.loadNameManifestFromStoredModuleByHash(subHash):
self.markModuleHashInvalid(subHash)
return False

with open(os.path.join(targetDir, "name_manifest.dat"), "rb") as f:
self.nameToModuleHash.update(
SerializationContext().deserialize(f.read(), Dict(str, str))
)

self.moduleManifestsLoaded.add(moduleHash)
self.modulesMarkedValid.add(moduleHash)

return True

def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules, dependencyEdgelist):
def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules):
"""Write out a disk representation of this module.
This includes writing both the shared object, a manifest of the function names
Expand Down Expand Up @@ -274,17 +244,11 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule

# write the type manifest
with open(os.path.join(tempTargetDir, "globals_manifest.dat"), "wb") as f:
f.write(SerializationContext().serialize(binarySharedObject.serializedGlobalVariableDefinitions))
f.write(SerializationContext().serialize(binarySharedObject.globalVariableDefinitions))

with open(os.path.join(tempTargetDir, "submodules.dat"), "wb") as f:
f.write(SerializationContext().serialize(ListOf(str)(submodules), ListOf(str)))

with open(os.path.join(tempTargetDir, "function_dependencies.dat"), "wb") as f:
f.write(SerializationContext().serialize(dependencyEdgelist)) # might need a listof

with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f:
f.write(SerializationContext().serialize(binarySharedObject.globalDependencies))

try:
os.rename(tempTargetDir, targetDir)
except IOError:
Expand All @@ -300,7 +264,7 @@ def function_pointer_by_name(self, linkName):
if moduleHash is None:
raise Exception("Can't find a module for " + linkName)

if moduleHash not in self.loadedBinarySharedObjects:
if moduleHash not in self.loadedModules:
self.loadForSymbol(linkName)

return self.loadedBinarySharedObjects[moduleHash].functionPointers[linkName]
return self.loadedModules[moduleHash].functionPointers[linkName]
Loading

0 comments on commit 41361cd

Please sign in to comment.