Skip to content

Commit

Permalink
Refactor serialize & deserialize of models.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancern committed Sep 9, 2019
1 parent 92ec094 commit 6e975f0
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 20 deletions.
9 changes: 9 additions & 0 deletions asm2vec/internal/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ def to_dict(self) -> Dict[str, Any]:
'jobs': self.jobs
}

def populate(self, rep: Dict[bytes, Any]) -> None:
    """Load training parameters from a bytes-keyed mapping.

    Every attribute listed below is (re)assigned. When its key is absent
    from ``rep`` the default shown in the table is used instead.

    :param rep: mapping from bytes keys to parameter values.
    """
    # (attribute name, key in ``rep``, default used when the key is missing)
    fields = (
        ('d', b'd', 200),
        ('initial_alpha', b'alpha', 0.0025),
        ('alpha_update_interval', b'alpha_update_interval', 10000),
        ('num_of_rnd_walks', b'rnd_walks', 3),
        ('neg_samples', b'neg_samples', 25),
        ('iteration', b'iteration', 1),
        ('jobs', b'jobs', 4),
    )
    for attribute, key, default in fields:
        setattr(self, attribute, rep.get(key, default))


class SequenceWindow:
def __init__(self, sequence: List[Instruction], vocabulary: Dict[str, Token]):
Expand Down
7 changes: 4 additions & 3 deletions asm2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ def serialize(self) -> Dict[str, Any]:
'vocab': asm2vec.repo.serialize_vocabulary(self.vocab)
}

def populate(self, rep: Dict[bytes, Any]) -> None:
    """Restore ``params`` and ``vocab`` from a bytes-keyed representation.

    :param rep: mapping holding the serialized parameters under
        ``b'params'`` and the serialized vocabulary under ``b'vocab'``.
    :raises KeyError: if either of those two entries is missing.
    """
    # Start from a default-constructed parameter object and let it read
    # its own fields from the serialized mapping.
    self.params = asm2vec.internal.training.Asm2VecParams()
    self.params.populate(rep[b'params'])
    self.vocab = asm2vec.repo.deserialize_vocabulary(rep[b'vocab'])


class Asm2Vec:
Expand Down
35 changes: 18 additions & 17 deletions asm2vec/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ def _serialize_token(token: Token) -> Dict[str, Any]:
}


def _deserialize_token(rep: Dict[str, Any]) -> Token:
name = rep['name']
v = np.array(rep['v'])
v_pred = np.array(rep['v_pred'])
count = rep['count']
frequency = rep['frequency']
def _deserialize_token(rep: Dict[bytes, Any]) -> Token:
name = rep[b'name'].decode('utf-8')
v = np.array(rep[b'v'])
v_pred = np.array(rep[b'v_pred'])
count = rep[b'count']
frequency = rep[b'frequency']

token = Token(VectorizedToken(name, v, v_pred))
token.count = count
Expand All @@ -106,16 +106,17 @@ def serialize_vocabulary(vocab: Dict[str, Token]) -> Dict[str, Any]:
return dict(zip(vocab.keys(), map(_serialize_token, vocab.values())))


def deserialize_vocabulary(rep: Dict[bytes, Any]) -> Dict[str, Token]:
    """Rebuild a vocabulary from its bytes-keyed serialized form.

    :param rep: mapping from UTF-8 encoded token names to serialized tokens.
    :return: mapping from decoded token names to deserialized ``Token`` objects.
    """
    return {
        raw_name.decode('utf-8'): _deserialize_token(token_rep)
        for raw_name, token_rep in rep.items()
    }


def _serialize_sequence(seq: List[asm2vec.asm.Instruction]) -> List[Any]:
    """Convert each instruction into an ``[op, args]`` pair.

    :param seq: instructions to serialize, in order.
    :return: one two-element list per instruction, holding the results of
        its ``op()`` and ``args()`` calls.
    """
    return [[instruction.op(), instruction.args()] for instruction in seq]


def _deserialize_sequence(rep: List[Any]) -> List[asm2vec.asm.Instruction]:
    """Rebuild a list of instructions from ``[op, args]`` pairs.

    Text may arrive either as ``bytes`` (decoded here as UTF-8) or as
    already-decoded ``str`` (passed through unchanged).

    :param rep: sequence of items whose element 0 is the opcode and whose
        element 1 is the serialized arguments.
    :return: one ``Instruction`` per item, in the same order.
    """
    def to_text(value: Any) -> Any:
        # Only byte strings need decoding; anything else is left untouched.
        return value.decode('utf-8') if isinstance(value, bytes) else value

    def decode_args(args: Any) -> Any:
        # NOTE(review): ``_serialize_sequence`` stores whatever
        # ``Instruction.args()`` returns, presumably a list of operand
        # strings (confirm against asm2vec.asm). A list has no ``decode``
        # method, so containers are decoded element by element.
        if isinstance(args, (list, tuple)):
            return type(args)(to_text(arg) for arg in args)
        return to_text(args)

    return [
        asm2vec.asm.Instruction(to_text(instr_rep[0]), decode_args(instr_rep[1]))
        for instr_rep in rep
    ]


def _serialize_vectorized_function(func: VectorizedFunction, include_sequences: bool) -> Dict[str, Any]:
Expand All @@ -131,11 +132,11 @@ def _serialize_vectorized_function(func: VectorizedFunction, include_sequences:
return data


def _deserialize_vectorized_function(rep: Dict[bytes, Any]) -> VectorizedFunction:
    """Rebuild a ``VectorizedFunction`` from its bytes-keyed serialized form.

    :param rep: mapping with ``b'name'`` (UTF-8 encoded), ``b'id'`` and
        ``b'v'``; ``b'sequences'`` is optional and treated as empty when
        absent.
    :raises KeyError: if ``b'name'``, ``b'id'`` or ``b'v'`` is missing.
    """
    name = rep[b'name'].decode('utf-8')
    fid = rep[b'id']
    v = np.array(rep[b'v'])
    sequences = [_deserialize_sequence(seq_rep) for seq_rep in rep.get(b'sequences', [])]
    return VectorizedFunction(SequentialFunction(fid, name, sequences), v)


Expand All @@ -158,7 +159,7 @@ def serialize_function_repo(repo: FunctionRepository, flags: int) -> Dict[str, A
return data


def deserialize_function_repo(rep: Dict[bytes, Any]) -> FunctionRepository:
    """Rebuild a ``FunctionRepository`` from its bytes-keyed serialized form.

    Both entries are optional: a missing ``b'funcs'`` yields no functions
    and a missing ``b'vocab'`` yields an empty vocabulary.

    :param rep: mapping that may hold serialized functions under
        ``b'funcs'`` and a serialized vocabulary under ``b'vocab'``.
    """
    funcs = [_deserialize_vectorized_function(func_rep) for func_rep in rep.get(b'funcs', [])]
    vocab = deserialize_vocabulary(rep.get(b'vocab', {}))
    return FunctionRepository(funcs, vocab)

0 comments on commit 6e975f0

Please sign in to comment.