diff --git a/MANIFEST.in b/MANIFEST.in index cd797d8..56e7a5b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ include LICENSE.txt -include README.md \ No newline at end of file +include README.md +recursive-include bootstrap/templates *py \ No newline at end of file diff --git a/bootstrap/new.py b/bootstrap/new.py new file mode 100644 index 0000000..ba37ea4 --- /dev/null +++ b/bootstrap/new.py @@ -0,0 +1,91 @@ +import os +from pathlib import Path +from argparse import ArgumentParser + + +file_dir = Path(__file__).parent + + +parser = ArgumentParser() +parser.add_argument("--project_name", type=str, help="Project name") + + +def get_template_file(filename, project_name): + parts = list(filename.parts) + project_index = parts.index(project_name.lower()) + if parts[-1] not in ["__init__.py", "factory.py"]: + parts[-1] = parts[-1].replace("my", "") + template_path = "/".join(parts[project_index + 1:]) + template_path = file_dir / Path("templates/default/project") / template_path + + return template_path + + +def get_file_content(filename, project_name): + template = get_template_file(filename, project_name) + + content = Path(template).read_text() + content = content.replace("{PROJECT_NAME}", project_name) + content = content.replace("{PROJECT_NAME_LOWER}", project_name.lower()) + content = content.replace("{PROJECT_NAME_UPPER}", project_name.upper()) + + return content + + +def write_files(files, project_name): + for f in files: + content = get_file_content(f, project_name) + f.write_text(content) + + +def get_files(directory): + dir_name = directory.stem + if dir_name == "options": + return [directory / "abstract.yaml"] + + to_ret = [] + if dir_name != "models": + to_ret.append(directory / "__init__.py") + + to_ret.append(directory / "factory.py") + custom_file = f"my{dir_name[:-1]}.py" + to_ret.append(directory / custom_file) + + return to_ret + + +if __name__ == "__main__": + args = parser.parse_args() + project_name = args.project_name + + path = Path(f"{project_name.lower()}.bootstrap.pytorch") + path.mkdir() + + print(f"Creating logs directory") + os.mkdir(path / "logs") + + print(f"Creating project directory and __init__.py file") + path = Path(f"{project_name.lower()}.bootstrap.pytorch/{project_name.lower()}") + path.mkdir() + Path(path / "__init__.py").touch() + + print("Creating models directory and __init__ file") + Path(path / "models").mkdir() + Path(path / "models/__init__.py").touch() + + directories = [ + "datasets", + "models/networks", + "models/criterions", + "models/metrics", + ] + + for directory in directories: + print(f"Creating {directory} folder and associated files") + new_dir = path / directory + if directory != "models": + new_dir.mkdir() + files = get_files(new_dir) + write_files(files, project_name) + + print("Project is ready !") diff --git a/bootstrap/templates/default/project/datasets/__init__.py b/bootstrap/templates/default/project/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bootstrap/templates/default/project/datasets/dataset.py b/bootstrap/templates/default/project/datasets/dataset.py new file mode 100644 index 0000000..3a3d641 --- /dev/null +++ b/bootstrap/templates/default/project/datasets/dataset.py @@ -0,0 +1,23 @@ +from bootstrap.datasets.dataset import Dataset + + +class {PROJECT_NAME}Dataset(Dataset): + """ Dataset of Wikipedia Comparable Article + + Parameters + ----------- + """ + def __init__(self, + dir_data, + split='train', + batch_size=4, + shuffle=False, + pin_memory=False, + nb_threads=4): + super({PROJECT_NAME}Dataset, self).__init__(dir_data, split, batch_size, shuffle, pin_memory, nb_threads) + + def __len__(self): + raise NotImplementedError + + def __getitem__(self, i): + raise NotImplementedError diff --git a/bootstrap/templates/default/project/datasets/factory.py b/bootstrap/templates/default/project/datasets/factory.py new file mode 100644 index 0000000..8f86d76 --- /dev/null +++ b/bootstrap/templates/default/project/datasets/factory.py @@ -0,0 +1,42 @@ +from bootstrap.lib.options import Options +from bootstrap.lib.logger import Logger + +from .mydataset import {PROJECT_NAME}Dataset + + +def factory(engine=None): + logger = Logger() + logger('Creating dataset...') + + opt = Options()["dataset"] + + dataset = {} + + if opt.get("train_split", None): + logger("Loading train data") + dataset["train"] = factory_split(opt["train_split"]) + logger(f"Train dataset length is {len(dataset['train'])}") + + if opt.get("eval_split", None): + logger("Loading test data") + dataset["eval"] = factory_split(opt["eval_split"]) + logger(f"Test dataset length is {len(dataset['eval'])}") + + logger("Dataset was created") + return dataset + + +def factory_split(split): + opt = Options()["dataset"] + + shuffle = ("train" in split) + + dataset = {PROJECT_NAME}Dataset( + dir_data=opt["dir"], + split=split, + batch_size=opt["batch_size"], + shuffle=shuffle, + nb_threads=opt["nb_threads"] + ) + + return dataset diff --git a/bootstrap/templates/default/project/models/__init__.py b/bootstrap/templates/default/project/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bootstrap/templates/default/project/models/criterions/__init__.py b/bootstrap/templates/default/project/models/criterions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bootstrap/templates/default/project/models/criterions/criterion.py b/bootstrap/templates/default/project/models/criterions/criterion.py new file mode 100644 index 0000000..e331f35 --- /dev/null +++ b/bootstrap/templates/default/project/models/criterions/criterion.py @@ -0,0 +1,12 @@ +import torch.nn as nn + + +class {PROJECT_NAME}Criterion(nn.Module): + + def __init__(self): + super({PROJECT_NAME}Criterion, self).__init__() + + def forward(self, net_out, batch): + # net_out : output of network + # batch : output of dataset (after collate function) + raise NotImplementedError diff --git a/bootstrap/templates/default/project/models/criterions/factory.py b/bootstrap/templates/default/project/models/criterions/factory.py new file mode 100644 index 0000000..f4560b9 --- /dev/null +++ b/bootstrap/templates/default/project/models/criterions/factory.py @@ -0,0 +1,16 @@ +from bootstrap.lib.options import Options +from bootstrap.lib.logger import Logger + +from .mycriterion import {PROJECT_NAME}Criterion + + +def factory(engine=None, mode=None): + logger = Logger() + logger('Creating criterion for {} mode...'.format(mode)) + + if Options()['model']['criterion'].get('import', False): + criterion = {PROJECT_NAME}Criterion() + else: + raise ValueError() + + return criterion diff --git a/bootstrap/templates/default/project/models/metrics/__init__.py b/bootstrap/templates/default/project/models/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bootstrap/templates/default/project/models/metrics/factory.py b/bootstrap/templates/default/project/models/metrics/factory.py new file mode 100644 index 0000000..6158428 --- /dev/null +++ b/bootstrap/templates/default/project/models/metrics/factory.py @@ -0,0 +1,14 @@ +from bootstrap.lib.options import Options + +from .mymetric import {PROJECT_NAME}Metric + + +def factory(engine=None, mode="train"): + opt = Options()['model.metric'] + + if opt['name'] == '{PROJECT_NAME_LOWER}': + metric = {PROJECT_NAME_LOWER}Metric() + else: + raise ValueError(opt['name']) + + return metric diff --git a/bootstrap/templates/default/project/models/metrics/metric.py b/bootstrap/templates/default/project/models/metrics/metric.py new file mode 100644 index 0000000..b6fd58a --- /dev/null +++ b/bootstrap/templates/default/project/models/metrics/metric.py @@ -0,0 +1,13 @@ +import torch.nn as nn +from bootstrap.lib.logger import Logger + + +class {PROJECT_NAME}Metric(nn.Module): + def __init__(self): + super(Accuracy, self).__init__() + + def forward(self, crit_out, net_out, batch): + # crit_out : output of criterion (dictionnary) + # net_out : output of network + # batch : output of dataset (after collate function) + raise NotImplementedError diff --git a/bootstrap/templates/default/project/models/networks/__init__.py b/bootstrap/templates/default/project/models/networks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bootstrap/templates/default/project/models/networks/factory.py b/bootstrap/templates/default/project/models/networks/factory.py new file mode 100644 index 0000000..fa7e463 --- /dev/null +++ b/bootstrap/templates/default/project/models/networks/factory.py @@ -0,0 +1,22 @@ +from bootstrap.lib.options import Options +from bootstrap.lib.logger import Logger +from bootstrap.models.networks.data_parallel import DataParallel + +from .mynetwork import {PROJECT_NAME}Network + + +def factory(engine): + logger = Logger() + net_opt = Options()["model"]["network"] + logger("Creating Network...") + + if net_opt["name"] == "{PROJECT_NAME_LOWER}network": + # You can use any param to create your network + # You just have to write them in your option file from options/ folder + net = {PROJECT_NAME}Network(net_opt["param1"], net_opt["param2"]) + else: + raise ValueError(opt["name"]) + logger("Network was created") + if torch.cuda.device_count() > 1: + net = DataParallel(net) + return net diff --git a/bootstrap/templates/default/project/models/networks/network.py b/bootstrap/templates/default/project/models/networks/network.py new file mode 100644 index 0000000..f10c7b1 --- /dev/null +++ b/bootstrap/templates/default/project/models/networks/network.py @@ -0,0 +1,12 @@ +import torch.nn as nn + + +class {PROJECT_NAME}Network(nn.Module): + def __init__(self, *args, **kwargs): + super(MyNetwork, self).__init__() + # Assign args + + def forward(self, x): + # x is a dictionnary given by Dataset class + pred = self.net(x) + return pred # This is a tensor (or several tensors) \ No newline at end of file diff --git a/setup.py b/setup.py index b5a522b..baec16e 100644 --- a/setup.py +++ b/setup.py @@ -156,6 +156,7 @@ # package_data={ # Optional # 'sample': ['package_data.dat'], # }, + include_package_data=True, # Although 'package_data' is the preferred approach, in some case you may # need to place data files outside of your packages. See: