Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include LICENSE.txt
include README.md
include README.md
recursive-include bootstrap/templates *py
91 changes: 91 additions & 0 deletions bootstrap/new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import os
from pathlib import Path
from argparse import ArgumentParser


file_dir = Path(__file__).parent


parser = ArgumentParser()
parser.add_argument("--project_name", type=str, help="Project name")


def get_template_file(filename, project_name):
parts = list(filename.parts)
project_index = parts.index(project_name.lower())
if parts[-1] not in ["__init__.py", "factory.py"]:
parts[-1] = parts[-1].replace("my", "")
template_path = "/".join(parts[project_index + 1:])
template_path = file_dir / Path("templates/default/project") / template_path

return template_path


def get_file_content(filename, project_name):
template = get_template_file(filename, project_name)

content = Path(template).read_text()
content = content.replace("{PROJECT_NAME}", project_name)
content = content.replace("{PROJECT_NAME_LOWER}", project_name.lower())
content = content.replace("{PROJECT_NAME_UPPER}", project_name.upper())

return content


def write_files(files, project_name):
for f in files:
content = get_file_content(f, project_name)
f.write_text(content)


def get_files(directory):
dir_name = directory.stem
if dir_name == "options":
return [directory / "abstract.yaml"]

to_ret = []
if dir_name != "models":
to_ret.append(directory / "__init__.py")

to_ret.append(directory / "factory.py")
custom_file = f"my{dir_name[:-1]}.py"
to_ret.append(directory / custom_file)

return to_ret


if __name__ == "__main__":
args = parser.parse_args()
project_name = args.project_name

path = Path(f"{project_name.lower()}.bootstrap.pytorch")
path.mkdir()

print(f"Creating logs directory")
os.mkdir(path / "logs")

print(f"Creating project directory and __init__.py file")
path = Path(f"{project_name.lower()}.bootstrap.pytorch/{project_name.lower()}")
path.mkdir()
Path(path / "__init__.py").touch()

print("Creating models directory and __init__ file")
Path(path / "models").mkdir()
Path(path / "models/__init__.py").touch()

directories = [
"datasets",
"models/networks",
"models/criterions",
"models/metrics",
]

for directory in directories:
print(f"Creating {directory} folder and associated files")
new_dir = path / directory
if directory != "models":
new_dir.mkdir()
files = get_files(new_dir)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As specified here, we would like to follow this path pattern bootstrap/templates/default/project/datasets/dataset.py instead of your current bootstrap/template/datasets/template_dataset.py.

write_files(files, project_name)

print("Project is ready !")
Empty file.
23 changes: 23 additions & 0 deletions bootstrap/templates/default/project/datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from bootstrap.datasets.dataset import Dataset


class {PROJECT_NAME}Dataset(Dataset):
""" Dataset of Wikipedia Comparable Article

Parameters
-----------
"""
def __init__(self,
dir_data,
split='train',
batch_size=4,
shuffle=False,
pin_memory=False,
nb_threads=4):
super({PROJECT_NAME}Dataset, self).__init__(dir_data, split, batch_size, shuffle, pin_memory, nb_threads)

def __len__(self):
raise NotImplementedError

def __getitem__(self, i):
raise NotImplementedError
42 changes: 42 additions & 0 deletions bootstrap/templates/default/project/datasets/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from bootstrap.lib.options import Options
from bootstrap.lib.logger import Logger

from .mydataset import {PROJECT_NAME}Dataset


def factory(engine=None):
logger = Logger()
logger('Creating dataset...')

opt = Options()["dataset"]

dataset = {}

if opt.get("train_split", None):
logger("Loading train data")
dataset["train"] = factory_split(opt["train_split"])
logger(f"Train dataset length is {len(dataset['train'])}")

if opt.get("eval_split", None):
logger("Loading test data")
dataset["eval"] = factory_split(opt["eval_split"])
logger(f"Test dataset length is {len(dataset['eval'])}")

logger("Dataset was created")
return dataset


def factory_split(split):
opt = Options()["dataset"]

shuffle = ("train" in split)

dataset = {PROJECT_NAME}Dataset(
dir_data=opt["dir"],
split=split,
batch_size=opt["batch_size"],
shuffle=shuffle,
nb_threads=opt["nb_threads"]
)

return dataset
Empty file.
Empty file.
12 changes: 12 additions & 0 deletions bootstrap/templates/default/project/models/criterions/criterion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch.nn as nn


class {PROJECT_NAME}Criterion(nn.Module):

def __init__(self):
super({PROJECT_NAME}Criterion, self).__init__()

def forward(self, net_out, batch):
# net_out : output of network
# batch : output of dataset (after collate function)
raise NotImplementedError
16 changes: 16 additions & 0 deletions bootstrap/templates/default/project/models/criterions/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from bootstrap.lib.options import Options
from bootstrap.lib.logger import Logger

from .mycriterion import {PROJECT_NAME}Criterion


def factory(engine=None, mode=None):
logger = Logger()
logger('Creating criterion for {} mode...'.format(mode))

if Options()['model']['criterion'].get('import', False):
criterion = {PROJECT_NAME}Criterion()
else:
raise ValueError()

return criterion
Empty file.
14 changes: 14 additions & 0 deletions bootstrap/templates/default/project/models/metrics/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from bootstrap.lib.options import Options

from .mymetric import {PROJECT_NAME}Metric


def factory(engine=None, mode="train"):
opt = Options()['model.metric']

if opt['name'] == '{PROJECT_NAME_LOWER}':
metric = {PROJECT_NAME_LOWER}Metric()
else:
raise ValueError(opt['name'])

return metric
13 changes: 13 additions & 0 deletions bootstrap/templates/default/project/models/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch.nn as nn
from bootstrap.lib.logger import Logger


class {PROJECT_NAME}Metric(nn.Module):
def __init__(self):
super(Accuracy, self).__init__()

def forward(self, crit_out, net_out, batch):
# crit_out : output of criterion (dictionnary)
# net_out : output of network
# batch : output of dataset (after collate function)
raise NotImplementedError
Empty file.
22 changes: 22 additions & 0 deletions bootstrap/templates/default/project/models/networks/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from bootstrap.lib.options import Options
from bootstrap.lib.logger import Logger
from bootstrap.models.networks.data_parallel import DataParallel

from .mynetwork import {PROJECT_NAME}Network


def factory(engine):
logger = Logger()
net_opt = Options()["model"]["network"]
logger("Creating Network...")

if net_opt["name"] == "{PROJECT_NAME_LOWER}network":
# You can use any param to create your network
# You just have to write them in your option file from options/ folder
net = {PROJECT_NAME}Network(net_opt["param1"], net_opt["param2"])
else:
raise ValueError(opt["name"])
logger("Network was created")
if torch.cuda.device_count() > 1:
net = DataParallel(net)
return net
12 changes: 12 additions & 0 deletions bootstrap/templates/default/project/models/networks/network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch.nn as nn


class {PROJECT_NAME}Network(nn.Module):
def __init__(self, *args, **kwargs):
super(MyNetwork, self).__init__()
# Assign args

def forward(self, x):
# x is a dictionnary given by Dataset class
pred = self.net(x)
return pred # This is a tensor (or several tensors)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@
# package_data={ # Optional
# 'sample': ['package_data.dat'],
# },
include_package_data=True,

# Although 'package_data' is the preferred approach, in some case you may
# need to place data files outside of your packages. See:
Expand Down