Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include LICENSE.txt
include README.md
include README.md
recursive-include bootstrap/templates *py
57 changes: 57 additions & 0 deletions bootstrap/new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from pathlib import Path
from argparse import ArgumentParser


def replace_content(file_path, prj_name):
content = file_path.read_text()
content = content.replace('{PROJECT_NAME}', prj_name)
content = content.replace('{PROJECT_NAME_LOWER}', prj_name.lower())
content = content.replace(' # noqa: E999', '')
return content


def new_project(prj_name, prj_dir):
# will be rename into project_name.lower() + suffix
# ex: dataset.py -> myproject.py
files_to_rename = ['dataset.py', 'criterion.py', 'metric.py', 'network.py', 'options.yaml']

# will be rename into project_name.lower()
# ex: project/datasets -> myproject/datasets
dirs_to_rename = ['project']

path = Path(prj_dir)
path = path / f'{prj_name.lower()}.bootstrap.pytorch'
path.mkdir()
tpl_path = Path(__file__).parent / 'templates' / 'default'

print(f'Creating project {prj_name.lower()} in {path}')

# recursive iteration over directories and files
for p in tpl_path.rglob('*'):

# absolute path to local path
# ex: bootstrap.pytorch/templates/default/project -> project
tpl_local_path = p.relative_to(tpl_path)

# replace name of directories
local_path = p.relative_to(tpl_path)
for dir_name in dirs_to_rename:
local_path = Path(str(local_path).replace(dir_name, prj_name.lower()))

if p.is_dir():
Path(path / local_path).mkdir()

if p.is_file():
content = replace_content(tpl_path / tpl_local_path, prj_name)
if p.name in files_to_rename:
local_path = Path(local_path.parent / f'{prj_name.lower()}{p.suffix}')
print(local_path)
Path(path / local_path).write_text(content)


if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--project_name', type=str, default='MyProject')
parser.add_argument('--project_dir', type=str, default='.')
args = parser.parse_args()
new_project(args.project_name, args.project_dir)
118 changes: 118 additions & 0 deletions bootstrap/templates/default/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Boostrap.pytorch
logs/*
data/*
!.gitkeep
docs/src
!logger.py

# Apple
.DS_Store
._.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
# lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
.static_storage/
.media/
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

*.swp
*.nfs*
29 changes: 29 additions & 0 deletions bootstrap/templates/default/LICENSE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2020+, {PROJECT_NAME}
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2 changes: 2 additions & 0 deletions bootstrap/templates/default/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include LICENSE.txt
include README.md
33 changes: 33 additions & 0 deletions bootstrap/templates/default/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# {PROJECT_NAME}

## Install

[Conda](https://docs.conda.io/en/latest/miniconda.html)

```bash
conda create --name {PROJECT_NAME_LOWER} python=3
source activate {PROJECT_NAME_LOWER}

cd $HOME
git clone --recursive https://github.com/{PROJECT_NAME}/{PROJECT_NAME_LOWER}.bootstrap.pytorch.git
cd {PROJECT_NAME_LOWER}.bootstrap.pytorch
pip install -r requirements.txt
```

## Reproducing results

Run experiment:
```bash
python -m bootstrap.run \
-o {PROJECT_NAME_LOWER}/options/{PROJECT_NAME_LOWER}.yaml \
--exp.dir logs/{PROJECT_NAME_LOWER}/1_exp
```

Display training and evaluation figures:
```bash
open logs/{PROJECT_NAME_LOWER}/1_exp/view.html
```

Display table of results:
```bash
python -m bootstrap.compare -o
1 change: 1 addition & 0 deletions bootstrap/templates/default/project/__version__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = '0.0.0'
Empty file.
54 changes: 54 additions & 0 deletions bootstrap/templates/default/project/datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
import torch.utils.data as tdata
from bootstrap.datasets import transforms as btf


class {PROJECT_NAME}Dataset(tdata.Dataset): # noqa: E999

def __init__(
self,
dir_data,
split='train',
batch_size=4,
shuffle=False,
pin_memory=False,
nb_threads=4,
*args,
**kwargs):
self.dir_data = dir_data
self.split = split
self.batch_size = batch_size
self.shuffle = shuffle
self.pin_memory = pin_memory
self.nb_threads = nb_threads
self.sampler = None

self.collate_fn = btf.Compose([
btf.ListDictsToDictLists(),
btf.StackTensors()
])

self.nb_items = kwargs['nb_items']
self.data = torch.randn(self.nb_items, 10)
self.target = torch.zeros(self.nb_items)
self.target[:int(self.nb_items / 2)].fill_(1)

def make_batch_loader(self, batch_size=None, shuffle=None):
batch_loader = tdata.DataLoader(
dataset=self,
batch_size=self.batch_size if batch_size is None else batch_size,
shuffle=self.shuffle if shuffle is None else shuffle,
pin_memory=self.pin_memory,
num_workers=self.nb_threads,
collate_fn=self.collate_fn,
sampler=self.sampler)
return batch_loader

def __len__(self):
return self.data.shape[0]

def __getitem__(self, idx):
item = {}
item['data'] = self.data[idx]
item['target'] = self.target[idx]
return item
40 changes: 40 additions & 0 deletions bootstrap/templates/default/project/datasets/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from bootstrap.lib.options import Options
from bootstrap.lib.logger import Logger
from .{PROJECT_NAME_LOWER} import {PROJECT_NAME}Dataset # noqa: E999


def factory(engine=None):
Logger()('Creating dataset...')

opt = Options()['dataset']

dataset = {}

if opt.get('train_split', None):
dataset['train'] = factory_split(opt['train_split'])

if opt.get('eval_split', None):
dataset['eval'] = factory_split(opt['eval_split'])

return dataset


def factory_split(split):
opt = Options()['dataset']

shuffle = ('train' in split)

dict_opt = opt.asdict()
dict_opt.pop('dir', None)
dict_opt.pop('batch_size', None)
dict_opt.pop('nb_threads', None)
dataset = {PROJECT_NAME}Dataset( # noqa: E999
dir_data=opt['dir'],
split=split,
batch_size=opt['batch_size'],
shuffle=shuffle,
nb_threads=opt['nb_threads'],
**dict_opt
)

return dataset
Empty file.
Empty file.
16 changes: 16 additions & 0 deletions bootstrap/templates/default/project/models/criterions/criterion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn


class {PROJECT_NAME}Criterion(nn.Module): # noqa: E999

def __init__(self, *args, **kwargs):
super({PROJECT_NAME}Criterion, self).__init__() # noqa: E999
self.bce_loss = nn.BCELoss()

def forward(self, net_out, batch):
pred = net_out['pred'].squeeze(1)
target = batch['target']
loss = self.bce_loss(pred, target)
out = {'loss': loss}
return out
13 changes: 13 additions & 0 deletions bootstrap/templates/default/project/models/criterions/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from bootstrap.lib.options import Options
from .{PROJECT_NAME_LOWER} import {PROJECT_NAME}Criterion # noqa: E999


def factory(engine=None, mode=None):
opt = Options()['model.criterion']

if opt['name'] == '{PROJECT_NAME_LOWER}':
criterion = {PROJECT_NAME}Criterion(**opt) # noqa: E999
else:
raise ValueError(opt['name'])

return criterion
Empty file.
13 changes: 13 additions & 0 deletions bootstrap/templates/default/project/models/metrics/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from bootstrap.lib.options import Options
from .{PROJECT_NAME_LOWER} import {PROJECT_NAME}Metric # noqa: E999


def factory(engine=None, mode='train'):
opt = Options()['model.metric']

if opt['name'] == '{PROJECT_NAME_LOWER}':
metric = {PROJECT_NAME}Metric(**opt) # noqa: E999
else:
raise ValueError(opt['name'])

return metric
Loading