diff --git a/bnn_models/__init__.py b/bnn_models/__init__.py index 9d89716..ec0c844 100644 --- a/bnn_models/__init__.py +++ b/bnn_models/__init__.py @@ -1,10 +1,11 @@ from copy import deepcopy +from sacred import Ingredient +import torch.nn as nn + from bnn_models.lenet import BLeNet from bnn_models.resnet import BClassifier from datasets import datasets -from sacred import Ingredient -import torch.nn as nn from utils import load_config_from_checkpoint bnn_models = Ingredient("bnn_model", ingredients=(datasets,)) diff --git a/bnn_models/lenet.py b/bnn_models/lenet.py index ee3bcbc..5f7ec9a 100644 --- a/bnn_models/lenet.py +++ b/bnn_models/lenet.py @@ -1,10 +1,11 @@ from typing import Tuple import bnn -from cls_models.base import BaseClassifier import torch import torch.nn as nn +from cls_models.base import BaseClassifier + class BLeNet(BaseClassifier): def __init__( diff --git a/bnn_models/resnet.py b/bnn_models/resnet.py index 0cb3432..cc0bd54 100644 --- a/bnn_models/resnet.py +++ b/bnn_models/resnet.py @@ -1,9 +1,10 @@ from math import ceil, floor, log2 import bnn -from cls_models.base import BaseClassifier import torch import torch.nn as nn + +from cls_models.base import BaseClassifier from utils import init_weights N_FEATUREMAPS = 32 diff --git a/cae_models/__init__.py b/cae_models/__init__.py index 31f34de..76a7368 100644 --- a/cae_models/__init__.py +++ b/cae_models/__init__.py @@ -1,6 +1,7 @@ -from datasets import datasets from sacred import Ingredient import torch.nn as nn + +from datasets import datasets from utils import load_config_from_checkpoint from .medium import MediumCAE diff --git a/cae_models/base.py b/cae_models/base.py index 77847ee..57297a1 100644 --- a/cae_models/base.py +++ b/cae_models/base.py @@ -7,6 +7,7 @@ import torch.nn.functional as tf from torch.optim import Adam from torch.optim.lr_scheduler import MultiStepLR + from utils import save_sample_images diff --git a/cae_models/resnet.py b/cae_models/resnet.py index 17a8539..2f918c6 100644 --- a/cae_models/resnet.py +++ b/cae_models/resnet.py @@ -9,6 +9,7 @@ import torch from torch import Tensor import torch.nn as nn + from utils import init_weights from .base import BaseCAE diff --git a/cls_models/__init__.py b/cls_models/__init__.py index 9029080..6ba7c30 100644 --- a/cls_models/__init__.py +++ b/cls_models/__init__.py @@ -1,6 +1,8 @@ from copy import deepcopy import re +from sacred import Ingredient + from bnn_models.lenet import BLeNet from bnn_models.resnet import BClassifier from cls_models.lenet import LeNet @@ -9,7 +11,6 @@ from cls_models.small import SmallClassifier from cls_models.toy import ToyClassifier from datasets import datasets -from sacred import Ingredient from utils import load_config_from_checkpoint from .base import BaseClassifier diff --git a/cls_models/base.py b/cls_models/base.py index e03d05d..c0effc9 100644 --- a/cls_models/base.py +++ b/cls_models/base.py @@ -8,6 +8,7 @@ from torch import Tensor, optim import torch.nn.functional as tf from torch.utils.data import DataLoader + from utils import entropy from .utils import set_model_to_mode diff --git a/cls_models/resnet.py b/cls_models/resnet.py index d8b62e5..5ffe2d0 100644 --- a/cls_models/resnet.py +++ b/cls_models/resnet.py @@ -1,10 +1,11 @@ from math import floor, log2 -from gan_models.resnet import ResidualBlock import torch.nn as nn from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.conv import _ConvNd +from gan_models.resnet import ResidualBlock + from .base import BaseClassifier N_FEATUREMAPS = 32 diff --git a/cls_models/toy.py b/cls_models/toy.py index 18f4809..a436826 100644 --- a/cls_models/toy.py +++ b/cls_models/toy.py @@ -1,6 +1,7 @@ -from datasets.toy import ToyDataset import torch.nn as nn +from datasets.toy import ToyDataset + from .base import BaseClassifier diff --git a/confident_classifier.py b/confident_classifier.py index 33c0423..9259755 100644 --- a/confident_classifier.py +++ b/confident_classifier.py @@ -2,10 +2,6 @@ import logging from typing import Any, Dict -from cls_models import set_model_to_mode -from cls_models.base import BaseClassifier -from datasets import load_data -from eval_ood_detection import eval_classifier from pytorch_lightning import LightningModule from pytorch_lightning.utilities.types import STEP_OUTPUT import torch @@ -13,6 +9,11 @@ import torch.nn as nn import torch.nn.functional as tf +from cls_models import set_model_to_mode +from cls_models.base import BaseClassifier +from datasets import load_data +from eval_ood_detection import eval_classifier + class ConfidentClassifier(LightningModule): def __init__( diff --git a/cvae_models/__init__.py b/cvae_models/__init__.py index af87026..9984365 100644 --- a/cvae_models/__init__.py +++ b/cvae_models/__init__.py @@ -1,8 +1,9 @@ +from sacred import Ingredient +import torch.nn as nn + from cvae_models.medium import MediumCVAE from cvae_models.small import SmallCVAE from datasets import datasets -from sacred import Ingredient -import torch.nn as nn from utils import load_config_from_checkpoint cvae_models = Ingredient("cvae_model", ingredients=(datasets,)) diff --git a/datasets/__init__.py b/datasets/__init__.py index fcfcf84..a8d16e0 100644 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -3,6 +3,14 @@ import logging from typing import Dict, Optional, Tuple +import numpy as np +from sacred import Ingredient +import torch +from torch.utils.data import Dataset, WeightedRandomSampler +import torchvision.transforms as trans +import torchvision.transforms.functional as ftrans +from typing_extensions import TypedDict + from datasets.celeba import CelebA from datasets.cifar10 import CIFAR10 from datasets.cifar100 import CIFAR100 @@ -18,13 +26,6 @@ from datasets.svhn import SVHN from datasets.tinyimagenet import TinyImageNet from datasets.toy import ToyDataset, ToyDataset2, ToyDataset3, ToyDataset4, ToyDataset5 -import numpy as np -from sacred import Ingredient -import torch -from torch.utils.data import Dataset, WeightedRandomSampler -import torchvision.transforms as trans -import torchvision.transforms.functional as ftrans -from typing_extensions import TypedDict from utils import IncompatibleRange, get_range datasets = Ingredient("dataset") diff --git a/datasets/clawa2.py b/datasets/clawa2.py index b83b180..79eb4ab 100644 --- a/datasets/clawa2.py +++ b/datasets/clawa2.py @@ -1,11 +1,12 @@ import logging from typing import Callable, Optional -from datasets.datahandlers.cl import AwA2 import numpy as np import torch from torch.utils.data import Dataset, random_split +from datasets.datahandlers.cl import AwA2 + EVAL_RATIO = 0.2 diff --git a/datasets/datahandlers/cl/awa2.py b/datasets/datahandlers/cl/awa2.py index ccfc4c1..8fc4a2c 100644 --- a/datasets/datahandlers/cl/awa2.py +++ b/datasets/datahandlers/cl/awa2.py @@ -1,13 +1,14 @@ # PyTorch Dataloader based on 'https://github.com/dfan/awa2-zero-shot-learning' -import numpy as np -import torch +import logging import os -from os.path import exists, join, splitext, isfile +from os.path import exists, isfile, join, splitext +from typing import Callable, Optional, Sequence, Union + from PIL import Image +import numpy as np +import torch from torch.utils.data import Dataset -import logging -from typing import Callable, Optional, Union, Sequence class AwA2(Dataset): diff --git a/datasets/datahandlers/cl/cub.py b/datasets/datahandlers/cl/cub.py index bc9fae3..a829bea 100644 --- a/datasets/datahandlers/cl/cub.py +++ b/datasets/datahandlers/cl/cub.py @@ -26,8 +26,8 @@ def __init__( Args: root: Root path of the dataset. split: Split to use. Valid options: ('train', 'test', 'all') - minimum_attribute_certainty: Minimum certainty of the annotated attribute as provided by the human - annotator. Valid options: (1, 2, 3, 4) + minimum_attribute_certainty: Minimum certainty of the annotated attribute + as provided by the human annotator. Valid options: (1, 2, 3, 4) transform: Image transforms. target_transform: Target transforms. target_type: Target types to return. Valid options: ('attr', 'class') diff --git a/datasets/tinyimagenet.py b/datasets/tinyimagenet.py index 07ce1a4..20acc4c 100644 --- a/datasets/tinyimagenet.py +++ b/datasets/tinyimagenet.py @@ -1,10 +1,11 @@ import logging from typing import Callable, Optional -from datasets.datahandlers.cl import TinyImageNet as myTinyImageNet import torch from torch.utils.data import Dataset, random_split +from datasets.datahandlers.cl import TinyImageNet as myTinyImageNet + EVAL_RATIO = 0.2 diff --git a/datasets/toy.py b/datasets/toy.py index fa500a0..2d2c145 100644 --- a/datasets/toy.py +++ b/datasets/toy.py @@ -279,7 +279,7 @@ def __len__(self) -> int: if __name__ == "__main__": - import matplotlib + import matplotlib # noqa: F401 # matplotlib.use("Qt5Agg") import matplotlib.pyplot as plt diff --git a/eval_ood_detection.py b/eval_ood_detection.py index 1591189..5c63cea 100644 --- a/eval_ood_detection.py +++ b/eval_ood_detection.py @@ -3,9 +3,6 @@ import os from os.path import abspath, exists, expanduser, join -from cls_models import cls_models, load_cls_model, set_model_to_mode -from cls_models.base import BaseClassifier -from datasets import datasets, load_data from eval.binary import aupr, auroc, ece, fprxtpr from logging_utils import log_config import numpy as np @@ -15,6 +12,10 @@ from tabulate import tabulate import torch from torch.utils.data import ConcatDataset, DataLoader + +from cls_models import cls_models, load_cls_model, set_model_to_mode +from cls_models.base import BaseClassifier +from datasets import datasets, load_data from utils import extract_exp_id_from_path, format_int_list, get_range, init_experiment ex = Experiment("evaluate OOD detection", ingredients=[datasets, cls_models]) diff --git a/eval_ood_detection_deep_ensembles.py b/eval_ood_detection_deep_ensembles.py index b54c31a..1ad2f68 100644 --- a/eval_ood_detection_deep_ensembles.py +++ b/eval_ood_detection_deep_ensembles.py @@ -1,8 +1,6 @@ import os from os.path import abspath, exists, expanduser, join -from cls_models import cls_models, load_cls_model -from datasets import datasets, load_data from eval.binary import aupr, auroc, ece, fprxtpr from logging_utils import log_config import numpy as np @@ -12,6 +10,9 @@ from tabulate import tabulate import torch from torch.utils.data import ConcatDataset, DataLoader + +from cls_models import cls_models, load_cls_model +from datasets import datasets, load_data from utils import ( entropy, extract_exp_id_from_path, diff --git a/gan_models/__init__.py b/gan_models/__init__.py index ecdef88..8c27e91 100644 --- a/gan_models/__init__.py +++ b/gan_models/__init__.py @@ -1,9 +1,10 @@ from copy import deepcopy from typing import Tuple -from datasets import datasets from sacred import Ingredient import torch.nn as nn + +from datasets import datasets from utils import load_config_from_checkpoint from .dcgan import Discriminator as DCGANDiscriminator diff --git a/gan_models/dcgan.py b/gan_models/dcgan.py index faa6376..1313799 100644 --- a/gan_models/dcgan.py +++ b/gan_models/dcgan.py @@ -2,6 +2,7 @@ import torch from torch.distributions.normal import Normal import torch.nn as nn + from utils import init_weights LATENT_DIM = 128 diff --git a/gan_models/resnet.py b/gan_models/resnet.py index 064de3d..fb55189 100644 --- a/gan_models/resnet.py +++ b/gan_models/resnet.py @@ -9,6 +9,7 @@ import torch from torch.distributions.uniform import Uniform import torch.nn as nn + from utils import init_weights LATENT_DIM = 128 diff --git a/gan_models/small.py b/gan_models/small.py index d9e1fd4..4f769bf 100644 --- a/gan_models/small.py +++ b/gan_models/small.py @@ -2,6 +2,7 @@ import torch from torch.distributions.uniform import Uniform import torch.nn as nn + from utils import init_weights LATENT_DIM = 128 diff --git a/gan_models/toy.py b/gan_models/toy.py index f98f9d2..73d1e4a 100644 --- a/gan_models/toy.py +++ b/gan_models/toy.py @@ -1,8 +1,9 @@ -from datasets.toy import ToyDataset from pytorch_lightning import LightningModule import torch from torch.distributions.uniform import Uniform import torch.nn as nn + +from datasets.toy import ToyDataset from utils import init_gan_weights diff --git a/gen.py b/gen.py index 0fa31a6..e481438 100644 --- a/gen.py +++ b/gen.py @@ -1,9 +1,6 @@ from copy import deepcopy from typing import Any, Dict -from cls_models.base import BaseClassifier, set_model_to_mode -from datasets import load_data -from eval_ood_detection import eval_classifier from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.accelerators.registry import AcceleratorRegistry import torch @@ -11,6 +8,10 @@ import torch.nn as nn import torch.nn.functional as tf +from cls_models.base import BaseClassifier, set_model_to_mode +from datasets import load_data +from eval_ood_detection import eval_classifier + class GEN(LightningModule): def __init__( diff --git a/meta_models/__init__.py b/meta_models/__init__.py index 729838d..8f07346 100644 --- a/meta_models/__init__.py +++ b/meta_models/__init__.py @@ -1,5 +1,6 @@ from sacred import Ingredient import torch.nn as nn + from utils import load_config_from_checkpoint from .fc import MetaClassifier diff --git a/meta_models/fc.py b/meta_models/fc.py index 4599698..036d21d 100644 --- a/meta_models/fc.py +++ b/meta_models/fc.py @@ -1,4 +1,5 @@ import torch.nn as nn + from utils import init_weights diff --git a/train_cae.py b/train_cae.py index a4a5fcf..87cf2e5 100644 --- a/train_cae.py +++ b/train_cae.py @@ -11,10 +11,11 @@ matplotlib.use("Agg") -from cae_models import cae_models, load_cae_model -from datasets import datasets, load_data from logging_utils import log_config from logging_utils.lightning_sacred import SacredLogger + +from cae_models import cae_models, load_cae_model +from datasets import datasets, load_data from utils import TimeEstimator, get_experiment_folder, init_experiment ex = Experiment("train_cae", ingredients=[datasets, cae_models]) diff --git a/train_classifier.py b/train_classifier.py index 305dbef..1363d2f 100644 --- a/train_classifier.py +++ b/train_classifier.py @@ -2,8 +2,6 @@ from os.path import exists import shutil -from cls_models import cls_models, load_cls_model -from datasets import datasets, load_data from logging_utils import log_config from logging_utils.lightning_sacred import SacredLogger from pytorch_lightning import Trainer @@ -11,6 +9,9 @@ from sacred import Experiment import torch from torch.utils.data import DataLoader + +from cls_models import cls_models, load_cls_model +from datasets import datasets, load_data from utils import TimeEstimator, get_experiment_folder, init_experiment ex = Experiment("train_classifier", ingredients=(cls_models, datasets)) diff --git a/train_confident_classifier.py b/train_confident_classifier.py index 1eb87d6..59b70b9 100644 --- a/train_confident_classifier.py +++ b/train_confident_classifier.py @@ -14,12 +14,13 @@ matplotlib.use("Agg") +from logging_utils import log_config +from logging_utils.lightning_sacred import SacredLogger + from cls_models import cls_models, load_cls_model from confident_classifier import ConfidentClassifier from datasets import datasets, load_data from gan_models import gan_models, load_gan_model -from logging_utils import log_config -from logging_utils.lightning_sacred import SacredLogger from uqgan import CustomCheckpointIO from utils import TimeEstimator, get_experiment_folder, init_experiment diff --git a/train_gen.py b/train_gen.py index d218e51..a1445e1 100644 --- a/train_gen.py +++ b/train_gen.py @@ -3,11 +3,6 @@ from os.path import exists import shutil -from cls_models import cls_models, load_cls_model -from cls_models.base import BaseClassifier -from datasets import datasets, load_data -from gan_models import gan_models, load_gan_model -from gen import GEN from logging_utils import log_config from logging_utils.lightning_sacred import SacredLogger from pytorch_lightning import Trainer @@ -17,8 +12,14 @@ from sacred.run import Run import torch from torch.utils.data import DataLoader + +from cls_models import cls_models, load_cls_model +from cls_models.base import BaseClassifier +from datasets import datasets, load_data +from gan_models import gan_models, load_gan_model +from gen import GEN from uqgan import CustomCheckpointIO -from utils import TimeEstimator, init_experiment, get_experiment_folder +from utils import TimeEstimator, get_experiment_folder, init_experiment from vae_models import load_vae_model, vae_models ex = Experiment("train_gen", ingredients=[datasets, gan_models, cls_models, vae_models]) diff --git a/train_meta_aggregator.py b/train_meta_aggregator.py index fbaf34e..f68ac40 100644 --- a/train_meta_aggregator.py +++ b/train_meta_aggregator.py @@ -3,9 +3,6 @@ from os.path import exists, join import shutil -from cls_models import cls_models, load_cls_model, set_model_to_mode -from cls_models.base import BaseClassifier -from datasets import datasets, load_data from logging_utils import log_config import numpy as np from sacred import Experiment @@ -14,6 +11,10 @@ from sklearn.metrics import confusion_matrix import torch from torch.utils.data import ConcatDataset, DataLoader, Subset + +from cls_models import cls_models, load_cls_model, set_model_to_mode +from cls_models.base import BaseClassifier +from datasets import datasets, load_data from utils import entropy, get_experiment_folder, get_range, init_experiment ex = Experiment("train_meta_aggregator", ingredients=(cls_models, datasets)) diff --git a/train_meta_classifier.py b/train_meta_classifier.py index 69469fe..0d36f80 100644 --- a/train_meta_classifier.py +++ b/train_meta_classifier.py @@ -3,8 +3,6 @@ from os.path import exists import shutil -from cls_models import cls_models, load_cls_model -from datasets import datasets, load_data from logging_utils import log_config from logging_utils.lightning_sacred import SacredLogger from pytorch_lightning import Trainer @@ -12,6 +10,9 @@ from sacred import Experiment import torch from torch.utils.data import ConcatDataset, DataLoader, WeightedRandomSampler + +from cls_models import cls_models, load_cls_model +from datasets import datasets, load_data from utils import TimeEstimator, get_experiment_folder, init_experiment ex = Experiment("train_meta_classifier", ingredients=(cls_models, datasets)) diff --git a/train_uqgan.py b/train_uqgan.py index be80029..b9a991c 100644 --- a/train_uqgan.py +++ b/train_uqgan.py @@ -14,15 +14,16 @@ matplotlib.use("Agg") from copy import deepcopy +from logging_utils import log_config +from logging_utils.lightning_sacred import SacredLogger +from pytorch_lightning.callbacks import ModelCheckpoint + from cae_models import cae_models, load_cae_model from cae_models.identity import IdentityCAE from cls_models import cls_models, load_cls_model from cls_models.base import BaseClassifier from datasets import datasets, load_data from gan_models import gan_models, load_gan_model -from logging_utils import log_config -from logging_utils.lightning_sacred import SacredLogger -from pytorch_lightning.callbacks import ModelCheckpoint from uqgan import UQGAN, CustomCheckpointIO from utils import TimeEstimator, get_experiment_folder, init_experiment diff --git a/uqgan.py b/uqgan.py index c518993..df15825 100644 --- a/uqgan.py +++ b/uqgan.py @@ -5,15 +5,15 @@ import re from typing import Any, Callable, Dict, Optional, Union +import matplotlib + from cls_models import set_model_to_mode from cls_models.base import BaseClassifier from datasets import load_data from datasets.toy import ToyDataset, ToyDataset2, ToyDataset3 -import matplotlib matplotlib.use("Agg") -from eval_ood_detection import eval_classifier import matplotlib.colors as colors import matplotlib.pyplot as plt import numpy as np @@ -29,6 +29,8 @@ import torch.autograd as autograd import torch.nn as nn import torch.nn.functional as tf + +from eval_ood_detection import eval_classifier from utils import cosine_loss_classwise, min_norm_loss, p_norm_loss diff --git a/vae_models/__init__.py b/vae_models/__init__.py index 7063292..c03a668 100644 --- a/vae_models/__init__.py +++ b/vae_models/__init__.py @@ -1,6 +1,7 @@ -from datasets import datasets from sacred import Ingredient import torch.nn as nn + +from datasets import datasets from utils import load_config_from_checkpoint from .sensoyetal2020 import AutoEncoder diff --git a/visualization/plot_GAN_samples.py b/visualization/plot_GAN_samples.py index f0d34dc..24bf519 100644 --- a/visualization/plot_GAN_samples.py +++ b/visualization/plot_GAN_samples.py @@ -1,12 +1,13 @@ -from cae_models import cae_models, load_cae_model -from datasets import datasets -from gan_models import gan_models, load_gan_model from logging_utils import log_config from sacred import Experiment import torch import torch.nn.functional as tf from torchvision.utils import make_grid, save_image +from cae_models import cae_models, load_cae_model +from datasets import datasets +from gan_models import gan_models, load_gan_model + ex = Experiment("Plot GAN OOD samples", ingredients=[datasets, gan_models, cae_models]) device = None diff --git a/visualization/plot_toy_example.py b/visualization/plot_toy_example.py index e9e4908..1a8b429 100644 --- a/visualization/plot_toy_example.py +++ b/visualization/plot_toy_example.py @@ -1,3 +1,11 @@ +from logging_utils import log_config +import matplotlib.colors as colors +import matplotlib.pyplot as plt +import numpy as np +from sacred import Experiment +from scipy.stats import gaussian_kde +import torch + from cls_models import ToyClassifier, cls_models, load_cls_model from datasets import ( # noqa ToyDataset2, @@ -7,13 +15,6 @@ datasets, ) from gan_models import ToyDiscriminator, ToyGenerator, gan_models, load_gan_model -from logging_utils import log_config -import matplotlib.colors as colors -import matplotlib.pyplot as plt -import numpy as np -from sacred import Experiment -from scipy.stats import gaussian_kde -import torch from utils import init_experiment ex = Experiment("Plot Toy Example", ingredients=[gan_models, cls_models, datasets])