Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
16c3a2a
commit fa89fb5
Showing
83 changed files
with
23,426 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Results to be ignored | ||
train_results/ | ||
saved_models/ | ||
logs/ | ||
|
||
# Mac gitignore | ||
**/.DS_Store | ||
**/changelog | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
The GAN Toolkit is dedicated to providing a harassment-free experience for everyone, regardless of gender, gender identity and expression, sexual orientation, disability, physical appearance, body size, age, race, or religion. We do not tolerate harassment of participants in any form. | ||
|
||
This code of conduct applies to all GAN Toolkit spaces, both online and off. Anyone who violates this code of conduct may be sanctioned or expelled from these spaces at the discretion of the IBM Research AI team. | ||
|
||
We may add additional rules over time, which will be made clearly available to participants. Participants are responsible for knowing and abiding by these rules. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Contributing | ||
|
||
This is an open source project, and we appreciate your help! | ||
|
||
We use the GitHub issue tracker to discuss new features and non-trivial bugs. | ||
|
||
In addition to the issue tracker, [#Discussion on | ||
Slack](https://ibm-gan-toolkit.slack.com) is the best way to get into contact with the | ||
project's maintainers. | ||
|
||
To contribute code, documentation, or tests, please submit a pull request to | ||
the GitHub repository. Generally, we expect two maintainers to review your pull | ||
request before it is approved for merging. For more details, see the | ||
[MAINTAINERS](MAINTAINERS.md) page. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Maintainers Guide | ||
|
||
This guide is intended for maintainers - anybody with commit access to one or | ||
more Code Pattern repositories. | ||
|
||
## Methodology | ||
|
||
This repository does not have a traditional release management cycle, but | ||
should instead be maintained as a useful, working, and polished reference at | ||
all times. While all work can therefore be focused on the master branch, the | ||
quality of this branch should never be compromised. | ||
|
||
The remainder of this document details how to merge pull requests to the | ||
repositories. | ||
|
||
## Merge approval | ||
|
||
The project maintainers use LGTM (Looks Good To Me) in comments on the pull | ||
request to indicate acceptance prior to merging. A change requires LGTMs from | ||
two project maintainers. If the code is written by a maintainer, the change | ||
only requires one additional LGTM. | ||
|
||
## Reviewing Pull Requests | ||
|
||
We recommend reviewing pull requests directly within GitHub. This allows a | ||
public commentary on changes, providing transparency for all users. When | ||
providing feedback be civil, courteous, and kind. Disagreement is fine, so long | ||
as the discourse is carried out politely. If we see a record of uncivil or | ||
abusive comments, we will revoke your commit privileges and invite you to leave | ||
the project. | ||
|
||
During your review, consider the following points: | ||
|
||
### Does the change have positive impact? | ||
|
||
Some proposed changes may not represent a positive impact to the project. Ask | ||
whether or not the change will make understanding the code easier, or if it | ||
could simply be a personal preference on the part of the author (see | ||
[bikeshedding](https://en.wiktionary.org/wiki/bikeshedding)). | ||
|
||
Pull requests that do not have a clear positive impact should be closed without | ||
merging. | ||
|
||
### Do the changes make sense? | ||
|
||
If you do not understand what the changes are or what they accomplish, ask the | ||
author for clarification. Ask the author to add comments and/or clarify test | ||
case names to make the intentions clear. | ||
|
||
At times, such clarification will reveal that the author may not be using the | ||
code correctly, or is unaware of features that accommodate their needs. If you | ||
feel this is the case, work up a code sample that would address the pull | ||
request for them, and feel free to close the pull request once they confirm. | ||
|
||
### Does the change introduce a new feature? | ||
|
||
For any given pull request, ask yourself "is this a new feature?" If so, does | ||
the pull request (or associated issue) contain narrative indicating the need | ||
for the feature? If not, ask them to provide that information. | ||
|
||
Are new unit tests in place that test all new behaviors introduced? If not, do | ||
not merge the feature until they are! Is documentation in place for the new | ||
feature? (See the documentation guidelines). If not do not merge the feature | ||
until it is! Is the feature necessary for general use cases? Try and keep the | ||
scope of any given component narrow. If a proposed feature does not fit that | ||
scope, recommend to the user that they maintain the feature on their own, and | ||
close the request. You may also recommend that they see if the feature gains | ||
traction among other users, and suggest they re-submit when they can show such | ||
support. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
web: python agant/main.py --config agant/configs/gan_gan.json |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import argparse | ||
import json | ||
|
||
def argument_parser(): | ||
"""Argument Parser Fucntion. | ||
Parameters | ||
---------- | ||
Returns | ||
------- | ||
conf_data: dict | ||
Dictionary containing all parameters and objects. | ||
""" | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--configuration', type=str, help='Model configuration file') | ||
parser.add_argument('--result_path', type=str, help='Path to save results at') | ||
parser.add_argument('--save_model_path', type=str, help='Path to save model at') | ||
parser.add_argument('--cuda', type=bool, help='Choice for selecting use of GPU') | ||
parser.add_argument('--performance_log', type=str, help='Path to file to record the model log while training and evauation') | ||
parser.add_argument('--epochs', type=int, help='Number of training epochs') | ||
parser.add_argument('--batch_size', type=int, help='Size of batch') | ||
parser.add_argument('--clip_value', type=int, help='Value for gragient clipping') | ||
parser.add_argument('--critic', type=int, help='Value of critic') | ||
parser.add_argument('--lambda_gp', type=int, help='Value of lambda for gradient penalty') | ||
parser.add_argument('--w_loss', type=int, help='Choosing using Wasserstein loss training process') | ||
parser.add_argument('--data_label', type=int, help='Whether data has lable values or not') | ||
parser.add_argument('--classes', type=int, help='Number of classes in data. Determine the use CGAN training process') | ||
parser.add_argument('--data_path', type=str, help='Path to data') | ||
parser.add_argument('--metric_evaluate', type=str, help='Choice of evalution metric') | ||
parser.add_argument('--sample_interval', type=int, help='Interval to sample during training') | ||
|
||
parser.add_argument('--g_choice', type=str, help='Choice of Generator') | ||
parser.add_argument('--g_pre_trained_path', type=str, help='Path to the pre-trained generator network') | ||
parser.add_argument('--g_input_shape', type=int, help='Input shape to the generator network') | ||
parser.add_argument('--g_latent_dim', type=int, help='Size of the noise vector') | ||
parser.add_argument('--g_channels', type=int, help='Number of channels in input image') | ||
parser.add_argument('--g_optimizer', type=str, help='Choice of optimizer for generator') | ||
parser.add_argument('--g_learning_rate', type=float, help='Value of learning rate for generator') | ||
parser.add_argument('--g_b1', type=float, help='Value of b1 for generator') | ||
parser.add_argument('--g_b2', type=float, help='Value of b2 for generator') | ||
parser.add_argument('--g_loss', type=str, help='Choice of loss for generator') | ||
|
||
parser.add_argument('--d_choice', type=str, help='Choice of Discriminator') | ||
parser.add_argument('--d_pre_trained_path', type=str, help='Path to the pre-trained discriminator network') | ||
parser.add_argument('--d_input_shape', type=int, help='Input shape to the discriminator network') | ||
parser.add_argument('--d_channels', type=int, help='Number of channels in input image') | ||
parser.add_argument('--d_optimizer', type=str, help='Choice of optimizer for discriminator') | ||
parser.add_argument('--d_learning_rate', type=float, help='Value of learning rate for discriminator') | ||
parser.add_argument('--d_b1', type=float, help='Value of b1 for discriminator') | ||
parser.add_argument('--d_b2', type=float, help='Value of b2 for discriminator') | ||
parser.add_argument('--d_loss', type=str, help='Choice of loss for discriminator') | ||
opt = parser.parse_args() | ||
|
||
if opt.configuration == None: | ||
conf_data = {} | ||
conf_data['GAN_model'] = {} | ||
conf_data['generator'] = {} | ||
conf_data['generator']['optimizer'] = {} | ||
conf_data['discriminator'] = {} | ||
conf_data['discriminator']['optimizer'] = {} | ||
|
||
conf_data['sample_interval'] = opt.sample_interval | ||
conf_data['result_path'] = opt.result_path | ||
conf_data['save_model_path'] = opt.save_model_path | ||
conf_data['cuda'] = opt.cuda | ||
conf_data['performance_log'] = opt.performance_log | ||
conf_data['data_path'] = opt.data_path | ||
conf_data['metric_evaluate'] = opt.metric_evaluate | ||
|
||
conf_data['GAN_model']['epochs'] = opt.epochs | ||
conf_data['GAN_model']['mini_batch_size'] = opt.batch_size | ||
conf_data['GAN_model']['clip_value'] = opt.clip_value | ||
conf_data['GAN_model']['n_critic'] = opt.critic | ||
conf_data['GAN_model']['lambda_gp'] = opt.lambda_gp | ||
conf_data['GAN_model']['w_loss'] = opt.w_loss | ||
conf_data['GAN_model']['data_label'] = opt.data_label | ||
conf_data['GAN_model']['classes'] = opt.classes | ||
|
||
conf_data['generator']['choice'] = opt.g_choice | ||
conf_data['generator']['pre_trained_path'] = opt.g_pre_trained_path | ||
conf_data['generator']['input_shape'] = opt.g_input_shape | ||
conf_data['generator']['latent_dim'] = opt.g_latent_dim | ||
conf_data['generator']['channels'] = opt.g_channels | ||
conf_data['generator']['optimizer']['choice'] = opt.g_optimizer | ||
conf_data['generator']['optimizer']['learning_rate'] = opt.g_learning_rate | ||
conf_data['generator']['optimizer']['b1'] = opt.g_b1 | ||
conf_data['generator']['optimizer']['b2'] = opt.g_b2 | ||
conf_data['generator']['loss'] = opt.g_loss | ||
|
||
conf_data['discriminator']['choice'] = opt.g_choice | ||
conf_data['discriminator']['pre_trained_path'] = opt.g_pre_trained_path | ||
conf_data['discriminator']['input_shape'] = opt.g_input_shape | ||
conf_data['discriminator']['channels'] = opt.g_channels | ||
conf_data['discriminator']['optimizer']['choice'] = opt.g_optimizer | ||
conf_data['discriminator']['optimizer']['learning_rate'] = opt.g_learning_rate | ||
conf_data['discriminator']['optimizer']['b1'] = opt.g_b1 | ||
conf_data['discriminator']['optimizer']['b2'] = opt.g_b2 | ||
conf_data['discriminator']['loss'] = opt.g_loss | ||
else: | ||
config_file = opt.configuration | ||
with open(config_file) as json_data_file: | ||
conf_data = json.load(json_data_file) | ||
return conf_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import numpy as np | ||
from torch.autograd import Variable | ||
import torch.autograd as autograd | ||
|
||
def compute_gradient_penalty(conf_data): | ||
"""Calculate gradient penalty for WGAN-GP. | ||
Parameters | ||
---------- | ||
conf_data: dict | ||
Dictionary containing all parameters and objects. | ||
Returns | ||
------- | ||
conf_data: dict | ||
Dictionary containing all parameters and objects. | ||
""" | ||
D = conf_data['discriminator_model'] | ||
real_samples = conf_data['real_data_sample'] | ||
fake_samples = conf_data['fake_data_sample'] | ||
# Random weight term for interpolation between real and fake samples | ||
alpha = conf_data['Tensor'](np.random.random((real_samples.size(0), 1, 1, 1))) | ||
|
||
# Get random interpolation between real and fake samples | ||
interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) | ||
d_interpolates = D(interpolates) | ||
fake = Variable(conf_data['Tensor'](real_samples.shape[0], 1).fill_(1.0), requires_grad=False) | ||
# Get gradient w.r.t. interpolates | ||
gradients = autograd.grad(outputs=d_interpolates, inputs=interpolates, | ||
grad_outputs=fake, create_graph=True, retain_graph=True, | ||
only_inputs=True)[0] | ||
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() | ||
|
||
conf_data['gradient_penalty'] = gradient_penalty | ||
return conf_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
{ | ||
"result_path":"train_results/Comb/DCGAN/DCGAN", | ||
"save_model_path":"saved_models/Comb/DCGAN/DCGAN", | ||
"performance_log":"logs/Comb/DCGAN/DCGAN", | ||
"GAN_model":{ | ||
"epochs":"50" | ||
}, | ||
|
||
"generator":{ | ||
"choice":"dcgan" | ||
}, | ||
|
||
"discriminator":{ | ||
"choice":"dcgan" | ||
}, | ||
|
||
"data_path":"datasets/fmnist.p", | ||
"metric_evaluate":"MMD" | ||
} |
Oops, something went wrong.