Skip to content

Commit

Permalink
Format the code.
Browse files Browse the repository at this point in the history
  • Loading branch information
BlackHC committed Jun 21, 2019
1 parent 852efa0 commit 26e216c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 15 deletions.
16 changes: 3 additions & 13 deletions src/dataset_enum.py
Expand Up @@ -123,11 +123,7 @@ def apply_noise(idx, sample):

@property
def num_classes(self):
if self in (
DatasetEnum.mnist,
DatasetEnum.repeated_mnist_w_noise,
DatasetEnum.mnist_w_noise,
):
if self in (DatasetEnum.mnist, DatasetEnum.repeated_mnist_w_noise, DatasetEnum.mnist_w_noise):
return 10
elif self in (DatasetEnum.emnist, DatasetEnum.emnist_bymerge):
return 47
Expand All @@ -136,11 +132,7 @@ def num_classes(self):

def create_bayesian_model(self, device):
num_classes = self.num_classes
if self in (
DatasetEnum.mnist,
DatasetEnum.repeated_mnist_w_noise,
DatasetEnum.mnist_w_noise,
):
if self in (DatasetEnum.mnist, DatasetEnum.repeated_mnist_w_noise, DatasetEnum.mnist_w_noise):
return mnist_model.BayesianNet(num_classes=num_classes).to(device)
elif self in (DatasetEnum.emnist, DatasetEnum.emnist_bymerge):
return emnist_model.BayesianNet(num_classes=num_classes).to(device)
Expand Down Expand Up @@ -379,9 +371,7 @@ def get_targets(dataset):
if isinstance(dataset, data.ConcatDataset):
return torch.cat([get_targets(sub_dataset) for sub_dataset in dataset.datasets])

if isinstance(
dataset, (datasets.MNIST,)
):
if isinstance(dataset, (datasets.MNIST,)):
return dataset.targets

raise NotImplementedError(f"Unknown dataset {dataset}!")
3 changes: 1 addition & 2 deletions src/run_experiment.py
Expand Up @@ -126,8 +126,7 @@ def create_experiment_config_argparser(parser):

def main():
parser = argparse.ArgumentParser(
description="BatchBALD",
formatter_class=functools.partial(argparse.ArgumentDefaultsHelpFormatter, width=120),
description="BatchBALD", formatter_class=functools.partial(argparse.ArgumentDefaultsHelpFormatter, width=120)
)
parser.add_argument("--experiment_task_id", type=str, default=None, help="experiment id")
parser.add_argument(
Expand Down

0 comments on commit 26e216c

Please sign in to comment.