-
Notifications
You must be signed in to change notification settings - Fork 153
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tools] introduce circle plus generator
This introduces a circle plus generator, which helps handle a circle file with training hyperparameters. As a first feature, it checks whether a circle file contains training parameters or not. ONE-DCO-1.0-Signed-off-by: seunghui youn <sseung.youn@samsung.com>
- Loading branch information
Showing
11 changed files
with
20,940 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Circle+ generator | ||
|
||
If a circle file has training hyperparameters, we usually call it a 'circle+' file.<br/> | ||
This tool generates a circle+ file by injecting training hyperparameters into a circle file.<br/> | ||
It also helps handle circle+ file, such as checking whether the circle file contains training hyperparameters. <br/> | ||
|
||
## Requirements | ||
|
||
1. (optional) Set python virtaul environment. | ||
|
||
This tool tested on python3.8. | ||
|
||
``` | ||
python3 -m venv venv | ||
source /venv/bin/activate | ||
``` | ||
|
||
2. Install required pakcages. | ||
|
||
Currently, only `flatbuffers==24.3.25` is needed. | ||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Inject training parameters using json file | ||
|
||
<!--to be updated --> | ||
|
||
## Check if the circle file contains training parameters | ||
|
||
You can check whether the circle file contains the training parameters.</br> | ||
If you run the `main.py` without providing a json file, it will check training parameters and display them. | ||
|
||
Try this with the files in [example](./example/). | ||
```bash | ||
python3 main.py example/mnist.circle | ||
``` | ||
```bash | ||
python3 main.py example/mnist_with_tparam.circle | ||
``` |
Binary file not shown.
Binary file not shown.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import typing | ||
|
||
from schema import circle_schema_generated as cir_gen | ||
from lib.train_info import TrainInfo | ||
|
||
|
||
class CirclePlus(cir_gen.ModelT): | ||
''' | ||
Wrapper of auto generated circle_schema_generated.ModelT | ||
''' | ||
TINFO_META_TAG = "CIRCLE_TRAINING" | ||
|
||
def __init__(self): | ||
super().__init__() | ||
|
||
@classmethod | ||
def from_file(cls, circle_file: str): | ||
'''Create CirclePlus based on circle file''' | ||
with open(circle_file, 'rb') as f: | ||
circle = super().InitFromPackedBuf(f.read()) | ||
circle.__class__ = CirclePlus | ||
return circle | ||
|
||
def get_train_info(self) -> typing.Union[TrainInfo, None]: | ||
'''Return TrainInfo, if it exists''' | ||
if not self.metadata: | ||
return None | ||
|
||
for meta in self.metadata: | ||
if meta.name.decode("utf-8") == self.TINFO_META_TAG: | ||
buff: cir_gen.BufferT = self.buffers[meta.buffer] | ||
tinfo: TrainInfo = TrainInfo.from_buff(buff.data) | ||
return tinfo | ||
|
||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import json | ||
|
||
from schema import circle_traininfo_generated as ctr_gen | ||
''' | ||
Wrappers of the auto generated classes in circle_traininfo_generated | ||
Wrapper classes provides additional interfaces(e.g. initializer) to the auto generated classes. | ||
''' | ||
|
||
# Optimizers | ||
|
||
|
||
class SGD(ctr_gen.SGDOptionsT): | ||
name = ['sgd', 'stocasticgradientdescent'] | ||
|
||
|
||
class Adam(ctr_gen.AdamOptionsT): | ||
name = ['adam'] | ||
|
||
|
||
# Loss | ||
|
||
|
||
class SparseCategoricalCrossEntropy(ctr_gen.SparseCategoricalCrossentropyOptionsT): | ||
name = [ | ||
'sparse categorical crossentropy', 'sparsecategoricalcrossentropy', 'sparsecce' | ||
] | ||
|
||
|
||
class CategoricalCrossEntropy(ctr_gen.CategoricalCrossentropyOptionsT): | ||
name = ['categorical crossentropy', 'categoricalcrossentropy', 'cce'] | ||
|
||
|
||
class MeanSqauaredError(ctr_gen.MeanSquaredErrorOptionsT): | ||
name = ['mean squared error', 'mse'] | ||
|
||
|
||
# TrainInfo | ||
|
||
|
||
class TrainInfo(ctr_gen.ModelTrainingT): | ||
TRAINING_FILE_IDENTIFIER = b"CTR0" | ||
|
||
def __init__(self): | ||
super().__init__() | ||
|
||
@classmethod | ||
def from_buff(cls, buff): | ||
'''Create TrainInfo from buffer(byte array)''' | ||
tinfo = super().InitFromPackedBuf(bytearray(buff)) | ||
tinfo.__class__ = TrainInfo | ||
|
||
# convert ModelTrainingT.optimizerOpt to wrapped class | ||
if tinfo.optimizer == ctr_gen.Optimizer.SGD: | ||
tinfo.optimizerOpt.__class__ = SGD | ||
elif tinfo.optimizer == ctr_gen.Optimizer.ADAM: | ||
tinfo.optimizerOpt.__class__ = Adam | ||
else: | ||
raise RuntimeError(f"Unknown optimizer {tinfo.optimizer}") | ||
|
||
# convert ModelTrainingT.lossfnOpt to wrapped class | ||
if tinfo.lossfn == ctr_gen.LossFn.SPARSE_CATEGORICAL_CROSSENTROPY: | ||
tinfo.lossfnOpt.__class__ = SparseCategoricalCrossEntropy | ||
elif tinfo.lossfn == ctr_gen.LossFn.CATEGORICAL_CROSSENTROPY: | ||
tinfo.lossfnOpt.__class__ = CategoricalCrossEntropy | ||
elif tinfo.lossfn == ctr_gen.LossFn.MEAN_SQUARED_ERROR: | ||
tinfo.lossfnOpt.__class__ = MeanSqauaredError | ||
else: | ||
raise RuntimeError(f"Unknown lossfn {tinfo.lossfn}") | ||
|
||
return tinfo | ||
|
||
def dump_as_json(self) -> str: | ||
'''Return JSON frommated string''' | ||
json_form = {} | ||
json_form["optimizer"] = { | ||
"type": self.optimizerOpt.name[0], | ||
"args": self.optimizerOpt.__dict__ | ||
} | ||
json_form["loss"] = { | ||
"type": self.lossfnOpt.name[0], | ||
"args": self.lossfnOpt.__dict__, | ||
} | ||
reduction_str = { | ||
ctr_gen.LossReductionType.SumOverBatchSize: "SumOverBatchSize", | ||
ctr_gen.LossReductionType.Sum: "Sum", | ||
} | ||
json_form["loss"]["args"]["reduction"] = reduction_str[self.lossReductionType] | ||
json_form["batchSize"] = self.batchSize | ||
|
||
return json.dumps(json_form, indent=4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import argparse | ||
import logging | ||
import json | ||
import typing | ||
|
||
from lib.circle_plus import CirclePlus | ||
from lib.train_info import TrainInfo | ||
|
||
|
||
def get_cmd_args(): | ||
parser = argparse.ArgumentParser( | ||
prog='circle plus generator', | ||
description='help handle circle file with training hyper parameters') | ||
|
||
parser.add_argument( | ||
'input_circle_file', metavar="input.circle", type=str, help='input circle file') | ||
|
||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def check(in_circle_file) -> typing.NoReturn: | ||
''' | ||
Check in_circle_file has training hyperparameters and print it. | ||
''' | ||
circle_model: CirclePlus = CirclePlus.from_file(in_circle_file) | ||
tinfo = circle_model.get_train_info() | ||
|
||
print(f"check hyperparameters in {in_circle_file}") | ||
if tinfo == None: | ||
print("No hyperparameters") | ||
else: | ||
print(tinfo.dump_as_json()) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = get_cmd_args() | ||
|
||
check(args.input_circle_file) | ||
|
||
# TODO: add a function that injects training parameter into circle file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
flatbuffers==24.3.25 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# In this directory, *_generated.py is auto generated by flatc. | ||
# | ||
# * flatc version 23.5.26 is used | ||
# * /ONE/nnpackage/schema/circle_schema.fbs is used | ||
# * /ONE/runtime/libs/circle-schema/include/circle_traininfo.fbs is used | ||
# | ||
# ./flatc --python --gen-onefile --gen-object-api circle_schema.fbs | ||
# ./flatc --python --gen-onefile --one-object-api cricle_traininfo.fbs |
Oops, something went wrong.