Skip to content

Commit

Permalink
[tools] introduce circle plus generator (#12977)
Browse files Browse the repository at this point in the history
This introduces a circle plus generator, which helps handle a circle file with training hyperparameters.
As a first feature, it checks whether a circle file contains training parameters or not.

ONE-DCO-1.0-Signed-off-by: seunghui youn <sseung.youn@samsung.com>
  • Loading branch information
zetwhite committed May 14, 2024
1 parent 7336cda commit 87506c7
Show file tree
Hide file tree
Showing 12 changed files with 20,963 additions and 0 deletions.
74 changes: 74 additions & 0 deletions tools/circle_plus_gen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Circle+ generator

Circle+ is a circle file which contains training hyperparameters. <br/>
This tool generates a circle+ by adding training hyperparameters to a circle file.<br/>
It also helps handle circle+ file, such as checking whether the circle file contains training hyperparameters. <br/>

## Requirements

This tool tested on python3.8.

1. (optional) Set python virtaul environment.

```
python3 -m venv venv
source ./venv/bin/activate
```

2. Install required pakcages.

Currently, only `flatbuffers` is needed.
```bash
python3 -m pip install -r requirements.txt
```

## Inject training hpyerparameters using json file

<!--to be updated -->

## Print training hyperparameters in circle file

You can check whether the circle file contains the training hyperparameters.</br>
If you run the `main.py` without providing a json file, it will display training hyperparameters in the given circle file.

Try this with the files in [example](./example/).
```bash
python3 main.py example/sample.circle

# expected output
#
# check hyperparameters in example/sample.circle
# No hyperparameters
```
```bash
python3 main.py example/sample_tparam.circle

# expected output
#
# check hyperparameters in example/sample_tpram.circle
# {
# "optimizer": {
# "type": "sgd",
# "args": {
# "learningRate": 0.0010000000474974513
# }
# },
# "loss": {
# "type": "sparse categorical crossentropy",
# "args": {
# "fromLogits": true,
# "reduction": "SumOverBatchSize"
# }
# },
# "batchSize": 64
# }
```

If it doesn't work well with example files, please check their md5sum to make sure they're not broken.

```bash
$ md5sum example/sample.circle example/sample_tpram.circle

df287dea52cf5bf16bc9dc720e8bca04 example/sample.circle
6e736e0544acc7ccb727cbc8f77add94 example/sample_tpram.circle
```
Binary file added tools/circle_plus_gen/example/sample.circle
Binary file not shown.
Binary file added tools/circle_plus_gen/example/sample_tpram.circle
Binary file not shown.
Empty file.
36 changes: 36 additions & 0 deletions tools/circle_plus_gen/lib/circle_plus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import typing

from schema import circle_schema_generated as cir_gen
from lib.train_param import TrainParam


class CirclePlus():
''' Wrapper class of circle_schema_generated.ModelT'''
TINFO_META_TAG = "CIRCLE_TRAINING"

def __init__(self):
self.model: cir_gen.ModelT = cir_gen.ModelT()

@classmethod
def from_file(cls, circle_file: str):
'''Create CirclePlus based on circle file'''
new_circle_plus = cls()
with open(circle_file, 'rb') as f:
new_circle_plus.model = cir_gen.ModelT.InitFromPackedBuf(f.read())
return new_circle_plus

def get_train_param(self) -> typing.Union[TrainParam, None]:
'''Return TrainInfo, if it exists'''
metadata = self.model.metadata
buffers = self.model.buffers

if metadata == None:
return None

for m in metadata:
if m.name.decode("utf-8") == self.TINFO_META_TAG:
buff: cir_gen.BufferT = buffers[m.buffer]
tparam: TrainParam = TrainParam.from_buff(buff.data)
return tparam

return None
39 changes: 39 additions & 0 deletions tools/circle_plus_gen/lib/train_param.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import json

from lib.utils import *
from schema import circle_traininfo_generated as ctr_gen


class TrainParam():
'''Wrapper class of circle_traninfo_generated.ModelTrainingT'''

def __init__(self):
self.train_param = ctr_gen.ModelTrainingT()

@classmethod
def from_buff(cls, buff):
'''Create TrainInfo from packed(serialized) buffer'''
new_tparam = cls()
new_tparam.train_param = ctr_gen.ModelTrainingT.InitFromPackedBuf(bytearray(buff))
return new_tparam

def dump_as_json(self) -> str:
'''Return JSON formmated string'''
tparam = self.train_param
name_opt = OptimizerNamer()
name_loss = LossNamer()
name_rdt = LossReductionNamer()

json_form = {}
json_form["optimizer"] = {
"type": name_opt(tparam.optimizerOpt),
"args": tparam.optimizerOpt.__dict__
}
json_form["loss"] = {
"type": name_loss(tparam.lossfnOpt),
"args": tparam.lossfnOpt.__dict__,
}
json_form["loss"]["args"]["reduction"] = name_rdt(tparam.lossReductionType)
json_form["batchSize"] = tparam.batchSize

return json.dumps(json_form, indent=4)
44 changes: 44 additions & 0 deletions tools/circle_plus_gen/lib/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from schema import circle_traininfo_generated as ctr_gen


class OptimizerNamer:
'''Return name(string) based on ModelTraining.OptimizerOpt'''
names = {ctr_gen.SGDOptionsT: 'sgd', ctr_gen.AdamOptionsT: 'adam'}

def __call__(cls, opt):
try:
name = cls.names[type(opt)]
except:
print(f"unknown optimizer {type(opt)}")
return name


class LossNamer:
'''Return name(string) based on ModelTraining.LossfnOpt'''
names = {
ctr_gen.SparseCategoricalCrossentropyOptionsT: 'sparse categorical crossentropy',
ctr_gen.CategoricalCrossentropyOptionsT: 'categorical crossentorpy',
ctr_gen.MeanSquaredErrorOptionsT: 'mean squared error'
}

def __call__(cls, lossfn):
try:
name = cls.names[type(lossfn)]
except:
print(f"unknown lossfn {type(lossfn)}")
return name


class LossReductionNamer:
'''Return name(string) based on ModelTraining.LossReductionType '''
names = {
ctr_gen.LossReductionType.SumOverBatchSize: 'SumOverBatchSize',
ctr_gen.LossReductionType.Sum: 'Sum',
}

def __call__(cls, rdt):
try:
name = cls.names[rdt]
except:
print(f"unknown loss reduction type {rdt}")
return name
40 changes: 40 additions & 0 deletions tools/circle_plus_gen/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import argparse
import typing

from lib.circle_plus import CirclePlus
from lib.train_param import TrainParam


def get_cmd_args():
parser = argparse.ArgumentParser(
prog='circle_plus_gen',
description='circle_plus_gen help handle circle file with training hyperparameters'
)

parser.add_argument('input', help='input circle file')

args = parser.parse_args()
return args


def print_training_hparameters(in_circle_file) -> typing.NoReturn:
'''
if in_circle_file has training hyperparameters, print it out
'''
print(f"check hyperparameters in {in_circle_file}")

circle_model: CirclePlus = CirclePlus.from_file(in_circle_file)
tinfo: typing.Union[TrainParam, None] = circle_model.get_train_param()

if tinfo == None:
print("No hyperparameters")
else:
print(tinfo.dump_as_json())


if __name__ == "__main__":
args = get_cmd_args()

print_training_hparameters(args.input)

# TODO: add a function that injects training parameter into circle file
1 change: 1 addition & 0 deletions tools/circle_plus_gen/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
flatbuffers==24.3.25
5 changes: 5 additions & 0 deletions tools/circle_plus_gen/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# In this directory, *_generated.py is auto generated by flatc.
#
# * flatc version 23.5.26 is used
# ./flatc --python --gen-onefile --gen-object-api ../../../nnpackage/schema/circle_schema.fbs
# ./flatc --python --gen-onefile --one-object-api ../../../runtime/libs/circle-schema/include/circle_traininfo.fbs
Loading

0 comments on commit 87506c7

Please sign in to comment.