[tools] introduce circle plus generator #12977

Merged
merged 5 commits into from
May 14, 2024
74 changes: 74 additions & 0 deletions tools/circle_plus_gen/README.md
@@ -0,0 +1,74 @@
# Circle+ generator

Circle+ is a circle file which contains training hyperparameters. <br/>
This tool generates a circle+ by adding training hyperparameters to a circle file.<br/>
It also helps handle circle+ files, for example by checking whether a circle file contains training hyperparameters. <br/>

## Requirements

This tool was tested on Python 3.8.

1. (optional) Set up a Python virtual environment.

```
python3 -m venv venv
source ./venv/bin/activate
```

2. Install required packages.

Currently, only `flatbuffers` is needed.
```bash
python3 -m pip install -r requirements.txt
```

## Inject training hyperparameters using a JSON file

<!--to be updated -->
Review comment (Contributor Author): I'll add this feature in the next PR.


## Print training hyperparameters in circle file

You can check whether a circle file contains training hyperparameters.<br/>
If you run `main.py` without providing a JSON file, it displays the training hyperparameters stored in the given circle file.

Try this with the files in [example](./example/).
```bash
python3 main.py example/sample.circle

# expected output
#
# check hyperparameters in example/sample.circle
# No hyperparameters
```
```bash
python3 main.py example/sample_tpram.circle

# expected output
#
# check hyperparameters in example/sample_tpram.circle
# {
#     "optimizer": {
#         "type": "sgd",
#         "args": {
#             "learningRate": 0.0010000000474974513
#         }
#     },
#     "loss": {
#         "type": "sparse categorical crossentropy",
#         "args": {
#             "fromLogits": true,
#             "reduction": "SumOverBatchSize"
#         }
#     },
#     "batchSize": 64
# }
```

If this doesn't work with the example files, check their md5sum to make sure they are not corrupted.

```bash
$ md5sum example/sample.circle example/sample_tpram.circle

df287dea52cf5bf16bc9dc720e8bca04 example/sample.circle
6e736e0544acc7ccb727cbc8f77add94 example/sample_tpram.circle
```
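
The same integrity check can be done from Python where `md5sum` is not available. The following is a minimal sketch using only the standard library; the helper names `file_md5` and `check_files` are hypothetical and not part of this tool.

```python
import hashlib


def file_md5(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex MD5 digest of the file at `path`, read in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops when read() returns empty bytes
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def check_files(expected: dict) -> dict:
    """Map each path to True if its MD5 digest matches the expected one."""
    return {path: file_md5(path) == want for path, want in expected.items()}
```

Calling `check_files` with the two digests listed above, keyed by `example/sample.circle` and `example/sample_tpram.circle`, should report `True` for both files if they are intact.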
Binary file added tools/circle_plus_gen/example/sample.circle
Binary file not shown.
Binary file added tools/circle_plus_gen/example/sample_tpram.circle
Binary file not shown.
Empty file.
36 changes: 36 additions & 0 deletions tools/circle_plus_gen/lib/circle_plus.py
@@ -0,0 +1,36 @@
import typing

from schema import circle_schema_generated as cir_gen
from lib.train_param import TrainParam


class CirclePlus():
    ''' Wrapper class of circle_schema_generated.ModelT'''
    TINFO_META_TAG = "CIRCLE_TRAINING"

    def __init__(self):
        self.model: cir_gen.ModelT = cir_gen.ModelT()

    @classmethod
    def from_file(cls, circle_file: str):
        '''Create CirclePlus based on circle file'''
        new_circle_plus = cls()
        with open(circle_file, 'rb') as f:
            new_circle_plus.model = cir_gen.ModelT.InitFromPackedBuf(f.read())
        return new_circle_plus

    def get_train_param(self) -> typing.Union[TrainParam, None]:
        '''Return TrainParam, if it exists'''
        metadata = self.model.metadata
        buffers = self.model.buffers

        # compare against None by identity, not equality
        if metadata is None:
            return None

        # training info is stored in the buffer referenced by the tagged metadata entry
        for m in metadata:
            if m.name.decode("utf-8") == self.TINFO_META_TAG:
                buff: cir_gen.BufferT = buffers[m.buffer]
                tparam: TrainParam = TrainParam.from_buff(buff.data)
                return tparam

        return None
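
The lookup in `get_train_param` can be sketched without the generated flatbuffers classes. The `Metadata` and `Buffer` dataclasses below are hypothetical stand-ins for the generated metadata and buffer objects, and they assume only the fields used above (`name` as bytes, `buffer` as an index, `data` as the payload).

```python
from dataclasses import dataclass
from typing import List, Optional

TINFO_META_TAG = "CIRCLE_TRAINING"


@dataclass
class Metadata:  # hypothetical stand-in for the generated metadata object
    name: bytes
    buffer: int  # index into the model's buffer list


@dataclass
class Buffer:  # hypothetical stand-in for the generated buffer object
    data: bytes


def find_train_buffer(metadata: Optional[List[Metadata]],
                      buffers: List[Buffer]) -> Optional[bytes]:
    """Return the raw training-info payload, or None if the tag is absent."""
    if metadata is None:
        return None
    for m in metadata:
        if m.name.decode("utf-8") == TINFO_META_TAG:
            return buffers[m.buffer].data
    return None
```

The metadata entry only names the payload and points at a buffer index, so a model without the tag (or without any metadata) simply yields `None`.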
39 changes: 39 additions & 0 deletions tools/circle_plus_gen/lib/train_param.py
@@ -0,0 +1,39 @@
import json

from lib.utils import OptimizerNamer, LossNamer, LossReductionNamer
from schema import circle_traininfo_generated as ctr_gen


class TrainParam():
    '''Wrapper class of circle_traininfo_generated.ModelTrainingT'''

    def __init__(self):
        self.train_param = ctr_gen.ModelTrainingT()

    @classmethod
    def from_buff(cls, buff):
        '''Create TrainParam from packed(serialized) buffer'''
        new_tparam = cls()
        new_tparam.train_param = ctr_gen.ModelTrainingT.InitFromPackedBuf(bytearray(buff))
        return new_tparam

    def dump_as_json(self) -> str:
        '''Return JSON formatted string'''
        tparam = self.train_param
        name_opt = OptimizerNamer()
        name_loss = LossNamer()
        name_rdt = LossReductionNamer()

        json_form = {}
        json_form["optimizer"] = {
            "type": name_opt(tparam.optimizerOpt),
            "args": dict(tparam.optimizerOpt.__dict__)
        }
        json_form["loss"] = {
            "type": name_loss(tparam.lossfnOpt),
            # copy the attribute dict so adding "reduction" does not mutate lossfnOpt
            "args": dict(tparam.lossfnOpt.__dict__),
        }
        json_form["loss"]["args"]["reduction"] = name_rdt(tparam.lossReductionType)
        json_form["batchSize"] = tparam.batchSize

        return json.dumps(json_form, indent=4)
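
`dump_as_json` builds its `args` entries from each option object's `__dict__`. Because `__dict__` is the object's live attribute dictionary, writing a key into it also sets an attribute on the object, so copying it first keeps the dump free of side effects. A minimal sketch of that point; `FakeLossOptions` and `loss_args_as_json` are hypothetical names, not part of the generated schema code.

```python
import json


class FakeLossOptions:
    """Hypothetical stand-in for a generated loss options object."""

    def __init__(self):
        self.fromLogits = True


def loss_args_as_json(opt, reduction_name: str) -> str:
    # dict(...) makes a shallow copy, so `opt` keeps only its own fields.
    args = dict(opt.__dict__)
    args["reduction"] = reduction_name
    return json.dumps(args, indent=4)


opt = FakeLossOptions()
text = loss_args_as_json(opt, "SumOverBatchSize")
```

After the call, `opt` still has only its `fromLogits` attribute, while `text` contains both `fromLogits` and `reduction`.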
44 changes: 44 additions & 0 deletions tools/circle_plus_gen/lib/utils.py
@@ -0,0 +1,44 @@
from schema import circle_traininfo_generated as ctr_gen


class OptimizerNamer:
    '''Return name(string) based on ModelTraining.OptimizerOpt'''
    names = {ctr_gen.SGDOptionsT: 'sgd', ctr_gen.AdamOptionsT: 'adam'}

    def __call__(self, opt):
        try:
            name = self.names[type(opt)]
        except KeyError:
            # fall back to a placeholder so that `name` is always bound
            print(f"unknown optimizer {type(opt)}")
            name = 'unknown'
        return name


class LossNamer:
    '''Return name(string) based on ModelTraining.LossfnOpt'''
    names = {
        ctr_gen.SparseCategoricalCrossentropyOptionsT: 'sparse categorical crossentropy',
        ctr_gen.CategoricalCrossentropyOptionsT: 'categorical crossentropy',
        ctr_gen.MeanSquaredErrorOptionsT: 'mean squared error'
    }

    def __call__(self, lossfn):
        try:
            name = self.names[type(lossfn)]
        except KeyError:
            print(f"unknown lossfn {type(lossfn)}")
            name = 'unknown'
        return name


class LossReductionNamer:
    '''Return name(string) based on ModelTraining.LossReductionType'''
    names = {
        ctr_gen.LossReductionType.SumOverBatchSize: 'SumOverBatchSize',
        ctr_gen.LossReductionType.Sum: 'Sum',
    }

    def __call__(self, rdt):
        try:
            name = self.names[rdt]
        except KeyError:
            print(f"unknown loss reduction type {rdt}")
            name = 'unknown'
        return name
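
The three namer classes share one pattern: a dictionary lookup that must still return something when the key is not registered. A hedged sketch of that pattern using `dict.get`; `SGDOptions`, `AdamOptions`, and the `"unknown"` fallback are illustrative assumptions, not the generated classes or a documented behavior of this tool.

```python
class SGDOptions:  # hypothetical stand-in for ctr_gen.SGDOptionsT
    pass


class AdamOptions:  # hypothetical stand-in for ctr_gen.AdamOptionsT
    pass


class OptimizerNamer:
    """Map an optimizer options object to a display name."""
    names = {SGDOptions: "sgd", AdamOptions: "adam"}

    def __call__(self, opt) -> str:
        # dict.get returns the fallback instead of raising KeyError,
        # so the result is defined even for unregistered types.
        return self.names.get(type(opt), "unknown")
```

Keying the table on `type(opt)` matches exact classes only, so a subclass of a registered options class would also fall through to the fallback.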
40 changes: 40 additions & 0 deletions tools/circle_plus_gen/main.py
@@ -0,0 +1,40 @@
import argparse
import typing

from lib.circle_plus import CirclePlus
from lib.train_param import TrainParam


def get_cmd_args():
    parser = argparse.ArgumentParser(
        prog='circle_plus_gen',
        description='circle_plus_gen helps handle circle files with training hyperparameters'
    )

    parser.add_argument('input', help='input circle file')

    args = parser.parse_args()
    return args


def print_training_hparameters(in_circle_file) -> None:
    '''
    if in_circle_file has training hyperparameters, print them out
    '''
    print(f"check hyperparameters in {in_circle_file}")

    circle_model: CirclePlus = CirclePlus.from_file(in_circle_file)

Review comment (Contributor): Is there any reason to write the type explicitly?

Reply (Contributor Author): No special reason. I thought that writing the type here spares the reader from having to look inside CirclePlus.from_file.

    tinfo: typing.Union[TrainParam, None] = circle_model.get_train_param()

    if tinfo is None:
        print("No hyperparameters")
    else:
        print(tinfo.dump_as_json())
zetwhite marked this conversation as resolved.


if __name__ == "__main__":
    args = get_cmd_args()

    print_training_hparameters(args.input)

    # TODO: add a function that injects training parameter into circle file
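
`parse_args` also accepts an explicit argument list, which makes the command-line interface easy to exercise without a shell. A small sketch that rebuilds a parser with the same single positional argument; `build_parser` is a hypothetical helper name and not part of `main.py`.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser with the single positional `input` argument."""
    parser = argparse.ArgumentParser(prog="circle_plus_gen")
    parser.add_argument("input", help="input circle file")
    return parser


# Passing a list bypasses sys.argv, which is handy in tests.
args = build_parser().parse_args(["example/sample.circle"])
```

Here `args.input` holds the path that was passed in the list, exactly as it would when given on the command line.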
1 change: 1 addition & 0 deletions tools/circle_plus_gen/requirements.txt
@@ -0,0 +1 @@
flatbuffers==24.3.25
5 changes: 5 additions & 0 deletions tools/circle_plus_gen/schema/__init__.py
@@ -0,0 +1,5 @@
# In this directory, *_generated.py is auto generated by flatc.
#
# * flatc version 23.5.26 is used
# ./flatc --python --gen-onefile --gen-object-api ../../../nnpackage/schema/circle_schema.fbs
# ./flatc --python --gen-onefile --one-object-api ../../../runtime/libs/circle-schema/include/circle_traininfo.fbs
Review comment (Contributor Author), suggested change:
- # ./flatc --python --gen-onefile --one-object-api ../../../runtime/libs/circle-schema/include/circle_traininfo.fbs
+ # ./flatc --python --gen-onefile --gen-object-api ../../../runtime/libs/circle-schema/include/circle_traininfo.fbs

Typo again. I'll fix it in the next PR.