-
Notifications
You must be signed in to change notification settings - Fork 141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[tools] introduce circle plus generator #12977
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Circle+ generator | ||
|
||
Circle+ is a circle file which contains training hyperparameters. <br/> | ||
This tool generates a circle+ by adding training hyperparameters to a circle file.<br/> | ||
It also helps handle circle+ file, such as checking whether the circle file contains training hyperparameters. <br/> | ||
|
||
## Requirements | ||
|
||
This tool tested on python3.8. | ||
|
||
1. (optional) Set python virtaul environment. | ||
|
||
``` | ||
python3 -m venv venv | ||
source ./venv/bin/activate | ||
``` | ||
|
||
2. Install required pakcages. | ||
|
||
Currently, only `flatbuffers` is needed. | ||
```bash | ||
python3 -m pip install -r requirements.txt | ||
``` | ||
|
||
## Inject training hpyerparameters using json file | ||
|
||
<!--to be updated --> | ||
|
||
## Print training hyperparameters in circle file | ||
|
||
You can check whether the circle file contains the training hyperparameters.</br> | ||
If you run the `main.py` without providing a json file, it will display training hyperparameters in the given circle file. | ||
|
||
Try this with the files in [example](./example/). | ||
```bash | ||
python3 main.py example/sample.circle | ||
|
||
# expected output | ||
# | ||
# check hyperparameters in example/sample.circle | ||
# No hyperparameters | ||
``` | ||
```bash | ||
python3 main.py example/sample_tparam.circle | ||
|
||
# expected output | ||
# | ||
# check hyperparameters in example/sample_tpram.circle | ||
# { | ||
# "optimizer": { | ||
# "type": "sgd", | ||
# "args": { | ||
# "learningRate": 0.0010000000474974513 | ||
# } | ||
# }, | ||
# "loss": { | ||
# "type": "sparse categorical crossentropy", | ||
# "args": { | ||
# "fromLogits": true, | ||
# "reduction": "SumOverBatchSize" | ||
# } | ||
# }, | ||
# "batchSize": 64 | ||
# } | ||
``` | ||
|
||
If it doesn't work well with example files, please check their md5sum to make sure they're not broken. | ||
|
||
```bash | ||
$ md5sum example/sample.circle example/sample_tpram.circle | ||
|
||
df287dea52cf5bf16bc9dc720e8bca04 example/sample.circle | ||
6e736e0544acc7ccb727cbc8f77add94 example/sample_tpram.circle | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import typing | ||
|
||
from schema import circle_schema_generated as cir_gen | ||
from lib.train_param import TrainParam | ||
|
||
|
||
class CirclePlus(): | ||
''' Wrapper class of circle_schema_generated.ModelT''' | ||
TINFO_META_TAG = "CIRCLE_TRAINING" | ||
|
||
def __init__(self): | ||
self.model: cir_gen.ModelT = cir_gen.ModelT() | ||
|
||
@classmethod | ||
def from_file(cls, circle_file: str): | ||
'''Create CirclePlus based on circle file''' | ||
new_circle_plus = cls() | ||
with open(circle_file, 'rb') as f: | ||
new_circle_plus.model = cir_gen.ModelT.InitFromPackedBuf(f.read()) | ||
return new_circle_plus | ||
|
||
def get_train_param(self) -> typing.Union[TrainParam, None]: | ||
'''Return TrainInfo, if it exists''' | ||
metadata = self.model.metadata | ||
buffers = self.model.buffers | ||
|
||
if metadata == None: | ||
return None | ||
|
||
for m in metadata: | ||
if m.name.decode("utf-8") == self.TINFO_META_TAG: | ||
buff: cir_gen.BufferT = buffers[m.buffer] | ||
tparam: TrainParam = TrainParam.from_buff(buff.data) | ||
return tparam | ||
|
||
return None |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import json | ||
|
||
from lib.utils import * | ||
from schema import circle_traininfo_generated as ctr_gen | ||
|
||
|
||
class TrainParam(): | ||
'''Wrapper class of circle_traninfo_generated.ModelTrainingT''' | ||
|
||
def __init__(self): | ||
self.train_param = ctr_gen.ModelTrainingT() | ||
|
||
@classmethod | ||
def from_buff(cls, buff): | ||
'''Create TrainInfo from packed(serialized) buffer''' | ||
new_tparam = cls() | ||
new_tparam.train_param = ctr_gen.ModelTrainingT.InitFromPackedBuf(bytearray(buff)) | ||
return new_tparam | ||
|
||
def dump_as_json(self) -> str: | ||
'''Return JSON formmated string''' | ||
tparam = self.train_param | ||
name_opt = OptimizerNamer() | ||
name_loss = LossNamer() | ||
name_rdt = LossReductionNamer() | ||
|
||
json_form = {} | ||
json_form["optimizer"] = { | ||
"type": name_opt(tparam.optimizerOpt), | ||
"args": tparam.optimizerOpt.__dict__ | ||
} | ||
json_form["loss"] = { | ||
"type": name_loss(tparam.lossfnOpt), | ||
"args": tparam.lossfnOpt.__dict__, | ||
} | ||
json_form["loss"]["args"]["reduction"] = name_rdt(tparam.lossReductionType) | ||
json_form["batchSize"] = tparam.batchSize | ||
|
||
return json.dumps(json_form, indent=4) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from schema import circle_traininfo_generated as ctr_gen | ||
|
||
|
||
class OptimizerNamer: | ||
'''Return name(string) based on ModelTraining.OptimizerOpt''' | ||
names = {ctr_gen.SGDOptionsT: 'sgd', ctr_gen.AdamOptionsT: 'adam'} | ||
|
||
def __call__(cls, opt): | ||
try: | ||
name = cls.names[type(opt)] | ||
except: | ||
print(f"unknown optimizer {type(opt)}") | ||
return name | ||
|
||
|
||
class LossNamer: | ||
'''Return name(string) based on ModelTraining.LossfnOpt''' | ||
names = { | ||
ctr_gen.SparseCategoricalCrossentropyOptionsT: 'sparse categorical crossentropy', | ||
ctr_gen.CategoricalCrossentropyOptionsT: 'categorical crossentorpy', | ||
ctr_gen.MeanSquaredErrorOptionsT: 'mean squared error' | ||
} | ||
|
||
def __call__(cls, lossfn): | ||
try: | ||
name = cls.names[type(lossfn)] | ||
except: | ||
print(f"unknown lossfn {type(lossfn)}") | ||
return name | ||
|
||
|
||
class LossReductionNamer: | ||
'''Return name(string) based on ModelTraining.LossReductionType ''' | ||
names = { | ||
ctr_gen.LossReductionType.SumOverBatchSize: 'SumOverBatchSize', | ||
ctr_gen.LossReductionType.Sum: 'Sum', | ||
} | ||
|
||
def __call__(cls, rdt): | ||
try: | ||
name = cls.names[rdt] | ||
except: | ||
print(f"unknown loss reduction type {rdt}") | ||
return name |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import argparse | ||
import typing | ||
|
||
from lib.circle_plus import CirclePlus | ||
from lib.train_param import TrainParam | ||
|
||
|
||
def get_cmd_args(): | ||
parser = argparse.ArgumentParser( | ||
prog='circle_plus_gen', | ||
description='circle_plus_gen help handle circle file with training hyperparameters' | ||
) | ||
|
||
parser.add_argument('input', help='input circle file') | ||
|
||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def print_training_hparameters(in_circle_file) -> typing.NoReturn: | ||
''' | ||
if in_circle_file has training hyperparameters, print it out | ||
''' | ||
print(f"check hyperparameters in {in_circle_file}") | ||
|
||
circle_model: CirclePlus = CirclePlus.from_file(in_circle_file) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any reason to write the type explicitly? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No special reason. |
||
tinfo: typing.Union[TrainParam, None] = circle_model.get_train_param() | ||
|
||
if tinfo == None: | ||
print("No hyperparameters") | ||
else: | ||
print(tinfo.dump_as_json()) | ||
zetwhite marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
if __name__ == "__main__": | ||
args = get_cmd_args() | ||
|
||
print_training_hparameters(args.input) | ||
|
||
# TODO: add a function that injects training parameter into circle file |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
flatbuffers==24.3.25 |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,5 @@ | ||||||
# In this directory, *_generated.py is auto generated by flatc. | ||||||
# | ||||||
# * flatc version 23.5.26 is used | ||||||
# ./flatc --python --gen-onefile --gen-object-api ../../../nnpackage/schema/circle_schema.fbs | ||||||
# ./flatc --python --gen-onefile --one-object-api ../../../runtime/libs/circle-schema/include/circle_traininfo.fbs | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
typo again.. I'll fix it in next PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll add this feature in next PR