Skip to content

Commit

Permalink
[tools] introduce circle plus generator
Browse files Browse the repository at this point in the history
This introduces a circle plus generator, which helps handle a circle file with training hyperparameters.
As a first feature, Let's check whether a circle file contains training parameters or not.

ONE-DCO-1.0-Signed-off-by: seunghui youn <sseung.youn@samsung.com>
  • Loading branch information
zetwhite committed May 9, 2024
1 parent bcf4468 commit 483bf38
Show file tree
Hide file tree
Showing 9 changed files with 20,947 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tools/circle_plus_gen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Circle+ generator

If a circle file has training hyperparameters, we usually call it a 'circle+' file.<br/>
This tool generates a circle+ file by injecting training hyperparameters into a circle file.<br/>
It also helps handle circle+ file, such as checking whether the circle file contains training hyperparameters. <br/>

## Requirements

1. (optional) Set python virtaul environment.

This tool tested on python3.8.

```
python3 -m venv venv
source /venv/bin/activate
```

2. Install required pakcages.

Currently, only `flatbuffers==24.3.25` is needed.
```bash
pip install -r requirements.txt
```

## Inject training parameters using json file

<!--to be updated -->

## Check if the circle file contains training parameters

You can check whether the circle file contains the training parameters.</br>
If you run the `main.py` without providing a json file, it will check training parameters and display them.

Try this with the files in [example](./example/).
```bash
python3 main.py example/mnist.circle
```
```bash
python3 main.py example/mnist_with_tparam.circle
```
Empty file.
35 changes: 35 additions & 0 deletions tools/circle_plus_gen/lib/circle_plus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import typing

from schema import circle_schema_generated as cir_gen
from lib.train_info import TrainInfo


class CirclePlus(cir_gen.ModelT):
'''
Wrapper of auto generated circle_schema_generated.ModelT
'''
TINFO_META_TAG = "CIRCLE_TRAINING"

def __init__(self):
super().__init__()

@classmethod
def from_file(cls, circle_file: str):
'''Create CirclePlus based on circle file'''
with open(circle_file, 'rb') as f:
circle = super().InitFromPackedBuf(f.read())
circle.__class__ = CirclePlus
return circle

def get_train_info(self) -> typing.Union[TrainInfo, None]:
'''Return TrainInfo, if it exists in circle'''
if not self.metadata:
return None

for meta in self.metadata:
if meta.name.decode("utf-8") == self.TINFO_META_TAG:
buff: cir_gen.BufferT = self.buffers[meta.buffer]
tinfo: TrainInfo = TrainInfo.from_buff(buff.data)
return tinfo

return None
98 changes: 98 additions & 0 deletions tools/circle_plus_gen/lib/train_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import json

from schema import circle_traininfo_generated as ctr_gen
from flatbuffers.builder import Builder
'''
Wrappers of the auto generated classes in circle_traininfo_generated
Wrapper classes provides additional interfaces(e.g. initializer) to the flatbuffer schema based auto-generated classes.
'''

# Optimizers


class SGD(ctr_gen.SGDOptionsT):
name = ['sgd', 'stocasticgradientdescent']


class Adam(ctr_gen.AdamOptionsT):
name = ['adam']


# Loss


class SparseCategoricalCrossEntropy(ctr_gen.SparseCategoricalCrossentropyOptionsT):
name = [
'sparse categorical crossentropy', 'sparsecategoricalcrossentropy', 'sparsecce'
]


class CategoricalCrossEntropy(ctr_gen.CategoricalCrossentropyOptionsT):
name = ['categorical crossentropy', 'categoricalcrossentropy', 'cce']


class MeanSqauaredError(ctr_gen.MeanSquaredErrorOptionsT):
name = ['mean squared error', 'mse']


# TrainInfo


class TrainInfo(ctr_gen.ModelTrainingT):
TRAINING_FILE_IDENTIFIER = b"CTR0"

def __init__(self):
super().__init__()

@classmethod
def from_buff(cls, buff):
'''Create TrainInfo from buffer(byte array)'''
tinfo = super().InitFromPackedBuf(bytearray(buff))
tinfo.__class__ = TrainInfo
return tinfo

def _to_dict(self) -> dict:
'''Convert TrainInfo to dictionary
The dictionary is usually for easy conversion to JSON format, later.
'''
ret = {}

# optimizer
opt_str = {
ctr_gen.Optimizer.SGD: SGD.name[0],
ctr_gen.Optimizer.ADAM: Adam.name[0]
}
ret["optimizer"] = {
"type": opt_str[self.optimizer],
"args": self.optimizerOpt.__dict__
}

# loss
loss_str = {
ctr_gen.LossFn.SPARSE_CATEGORICAL_CROSSENTROPY:
SparseCategoricalCrossEntropy.name[0],
ctr_gen.LossFn.CATEGORICAL_CROSSENTROPY:
CategoricalCrossEntropy.name[0],
ctr_gen.LossFn.MEAN_SQUARED_ERROR:
MeanSqauaredError.name[0]
}
ret["loss"] = {
"type": loss_str[self.lossfn],
"args": self.lossfnOpt.__dict__,
}

# reductions
reduction_str = {
ctr_gen.LossReductionType.SumOverBatchSize: "SumOverBatchSize",
ctr_gen.LossReductionType.Sum: "Sum",
}
ret["loss"]["args"]["reduction"] = reduction_str[self.lossReductionType]

# batchsize
ret["batchSize"] = self.batchSize
return ret

def dump(self) -> str:
dict = self._to_dict()
return json.dumps(dict, indent=4)
41 changes: 41 additions & 0 deletions tools/circle_plus_gen/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import argparse
import logging
import json
import typing

from lib.circle_plus import CirclePlus
from lib.train_info import TrainInfo


def get_cmd_args():
parser = argparse.ArgumentParser(
prog='circle plus generator',
description='help handle circle file with training hyper parameters')

parser.add_argument(
'input_circle_file', metavar="input.circle", type=str, help='input circle file')

args = parser.parse_args()
return args


def check(in_circle_file) -> typing.NoReturn:
'''
Check in_circle_file has training hyperparameters and print it.
'''
circle_model: CirclePlus = CirclePlus.from_file(in_circle_file)
tinfo = circle_model.get_train_info()

print(f"check hyperparameters in {in_circle_file}")
if tinfo == None:
print("No hyperparameters")
else:
print(tinfo.dump())


if __name__ == "__main__":
args = get_cmd_args()

check(args.input_circle_file)

# TODO: add a function that injects training parameter into circle file
1 change: 1 addition & 0 deletions tools/circle_plus_gen/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
flatbuffers==24.3.25
8 changes: 8 additions & 0 deletions tools/circle_plus_gen/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# In this directory, *_generated.py is auto generated by flatc.
#
# * flatc version 23.5.26 is used
# * /ONE/nnpackage/schema/circle_schema.fbs is used
# * /ONE/runtime/libs/circle-schema/include/circle_traininfo.fbs is used
#
# ./flatc --python --gen-onefile --gen-object-api circle_schema.fbs
# ./flatc --python --gen-onefile --one-object-api cricle_traininfo.fbs

0 comments on commit 483bf38

Please sign in to comment.