-
Notifications
You must be signed in to change notification settings - Fork 151
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tools] introduce circle plus generator (#12977)
This introduces a circle plus generator, which helps handle a circle file with training hyperparameters. As a first feature, it checks whether a circle file contains training parameters or not. ONE-DCO-1.0-Signed-off-by: seunghui youn <sseung.youn@samsung.com>
- Loading branch information
Showing
12 changed files
with
20,963 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Circle+ generator | ||
|
||
Circle+ is a circle file which contains training hyperparameters. <br/> | ||
This tool generates a circle+ by adding training hyperparameters to a circle file.<br/> | ||
It also helps handle circle+ file, such as checking whether the circle file contains training hyperparameters. <br/> | ||
|
||
## Requirements | ||
|
||
This tool tested on python3.8. | ||
|
||
1. (optional) Set python virtaul environment. | ||
|
||
``` | ||
python3 -m venv venv | ||
source ./venv/bin/activate | ||
``` | ||
|
||
2. Install required pakcages. | ||
|
||
Currently, only `flatbuffers` is needed. | ||
```bash | ||
python3 -m pip install -r requirements.txt | ||
``` | ||
|
||
## Inject training hpyerparameters using json file | ||
|
||
<!--to be updated --> | ||
|
||
## Print training hyperparameters in circle file | ||
|
||
You can check whether the circle file contains the training hyperparameters.</br> | ||
If you run the `main.py` without providing a json file, it will display training hyperparameters in the given circle file. | ||
|
||
Try this with the files in [example](./example/). | ||
```bash | ||
python3 main.py example/sample.circle | ||
|
||
# expected output | ||
# | ||
# check hyperparameters in example/sample.circle | ||
# No hyperparameters | ||
``` | ||
```bash | ||
python3 main.py example/sample_tparam.circle | ||
|
||
# expected output | ||
# | ||
# check hyperparameters in example/sample_tpram.circle | ||
# { | ||
# "optimizer": { | ||
# "type": "sgd", | ||
# "args": { | ||
# "learningRate": 0.0010000000474974513 | ||
# } | ||
# }, | ||
# "loss": { | ||
# "type": "sparse categorical crossentropy", | ||
# "args": { | ||
# "fromLogits": true, | ||
# "reduction": "SumOverBatchSize" | ||
# } | ||
# }, | ||
# "batchSize": 64 | ||
# } | ||
``` | ||
|
||
If it doesn't work well with example files, please check their md5sum to make sure they're not broken. | ||
|
||
```bash | ||
$ md5sum example/sample.circle example/sample_tpram.circle | ||
|
||
df287dea52cf5bf16bc9dc720e8bca04 example/sample.circle | ||
6e736e0544acc7ccb727cbc8f77add94 example/sample_tpram.circle | ||
``` |
Binary file not shown.
Binary file not shown.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import typing | ||
|
||
from schema import circle_schema_generated as cir_gen | ||
from lib.train_param import TrainParam | ||
|
||
|
||
class CirclePlus(): | ||
''' Wrapper class of circle_schema_generated.ModelT''' | ||
TINFO_META_TAG = "CIRCLE_TRAINING" | ||
|
||
def __init__(self): | ||
self.model: cir_gen.ModelT = cir_gen.ModelT() | ||
|
||
@classmethod | ||
def from_file(cls, circle_file: str): | ||
'''Create CirclePlus based on circle file''' | ||
new_circle_plus = cls() | ||
with open(circle_file, 'rb') as f: | ||
new_circle_plus.model = cir_gen.ModelT.InitFromPackedBuf(f.read()) | ||
return new_circle_plus | ||
|
||
def get_train_param(self) -> typing.Union[TrainParam, None]: | ||
'''Return TrainInfo, if it exists''' | ||
metadata = self.model.metadata | ||
buffers = self.model.buffers | ||
|
||
if metadata == None: | ||
return None | ||
|
||
for m in metadata: | ||
if m.name.decode("utf-8") == self.TINFO_META_TAG: | ||
buff: cir_gen.BufferT = buffers[m.buffer] | ||
tparam: TrainParam = TrainParam.from_buff(buff.data) | ||
return tparam | ||
|
||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import json | ||
|
||
from lib.utils import * | ||
from schema import circle_traininfo_generated as ctr_gen | ||
|
||
|
||
class TrainParam(): | ||
'''Wrapper class of circle_traninfo_generated.ModelTrainingT''' | ||
|
||
def __init__(self): | ||
self.train_param = ctr_gen.ModelTrainingT() | ||
|
||
@classmethod | ||
def from_buff(cls, buff): | ||
'''Create TrainInfo from packed(serialized) buffer''' | ||
new_tparam = cls() | ||
new_tparam.train_param = ctr_gen.ModelTrainingT.InitFromPackedBuf(bytearray(buff)) | ||
return new_tparam | ||
|
||
def dump_as_json(self) -> str: | ||
'''Return JSON formmated string''' | ||
tparam = self.train_param | ||
name_opt = OptimizerNamer() | ||
name_loss = LossNamer() | ||
name_rdt = LossReductionNamer() | ||
|
||
json_form = {} | ||
json_form["optimizer"] = { | ||
"type": name_opt(tparam.optimizerOpt), | ||
"args": tparam.optimizerOpt.__dict__ | ||
} | ||
json_form["loss"] = { | ||
"type": name_loss(tparam.lossfnOpt), | ||
"args": tparam.lossfnOpt.__dict__, | ||
} | ||
json_form["loss"]["args"]["reduction"] = name_rdt(tparam.lossReductionType) | ||
json_form["batchSize"] = tparam.batchSize | ||
|
||
return json.dumps(json_form, indent=4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from schema import circle_traininfo_generated as ctr_gen | ||
|
||
|
||
class OptimizerNamer: | ||
'''Return name(string) based on ModelTraining.OptimizerOpt''' | ||
names = {ctr_gen.SGDOptionsT: 'sgd', ctr_gen.AdamOptionsT: 'adam'} | ||
|
||
def __call__(cls, opt): | ||
try: | ||
name = cls.names[type(opt)] | ||
except: | ||
print(f"unknown optimizer {type(opt)}") | ||
return name | ||
|
||
|
||
class LossNamer: | ||
'''Return name(string) based on ModelTraining.LossfnOpt''' | ||
names = { | ||
ctr_gen.SparseCategoricalCrossentropyOptionsT: 'sparse categorical crossentropy', | ||
ctr_gen.CategoricalCrossentropyOptionsT: 'categorical crossentorpy', | ||
ctr_gen.MeanSquaredErrorOptionsT: 'mean squared error' | ||
} | ||
|
||
def __call__(cls, lossfn): | ||
try: | ||
name = cls.names[type(lossfn)] | ||
except: | ||
print(f"unknown lossfn {type(lossfn)}") | ||
return name | ||
|
||
|
||
class LossReductionNamer: | ||
'''Return name(string) based on ModelTraining.LossReductionType ''' | ||
names = { | ||
ctr_gen.LossReductionType.SumOverBatchSize: 'SumOverBatchSize', | ||
ctr_gen.LossReductionType.Sum: 'Sum', | ||
} | ||
|
||
def __call__(cls, rdt): | ||
try: | ||
name = cls.names[rdt] | ||
except: | ||
print(f"unknown loss reduction type {rdt}") | ||
return name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import argparse | ||
import typing | ||
|
||
from lib.circle_plus import CirclePlus | ||
from lib.train_param import TrainParam | ||
|
||
|
||
def get_cmd_args(): | ||
parser = argparse.ArgumentParser( | ||
prog='circle_plus_gen', | ||
description='circle_plus_gen help handle circle file with training hyperparameters' | ||
) | ||
|
||
parser.add_argument('input', help='input circle file') | ||
|
||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def print_training_hparameters(in_circle_file) -> typing.NoReturn: | ||
''' | ||
if in_circle_file has training hyperparameters, print it out | ||
''' | ||
print(f"check hyperparameters in {in_circle_file}") | ||
|
||
circle_model: CirclePlus = CirclePlus.from_file(in_circle_file) | ||
tinfo: typing.Union[TrainParam, None] = circle_model.get_train_param() | ||
|
||
if tinfo == None: | ||
print("No hyperparameters") | ||
else: | ||
print(tinfo.dump_as_json()) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = get_cmd_args() | ||
|
||
print_training_hparameters(args.input) | ||
|
||
# TODO: add a function that injects training parameter into circle file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
flatbuffers==24.3.25 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# In this directory, *_generated.py is auto generated by flatc. | ||
# | ||
# * flatc version 23.5.26 is used | ||
# ./flatc --python --gen-onefile --gen-object-api ../../../nnpackage/schema/circle_schema.fbs | ||
# ./flatc --python --gen-onefile --one-object-api ../../../runtime/libs/circle-schema/include/circle_traininfo.fbs |
Oops, something went wrong.