Commit

- removed load_dataset (moved functionality to utils_pytorch.py)
nasimrahaman committed Jul 9, 2019
1 parent 9c86d36 commit ca3eb03
Showing 3 changed files with 134 additions and 5 deletions.
7 changes: 6 additions & 1 deletion evaluate.py
@@ -39,6 +39,12 @@
from tensorflow.python.framework.errors_impl import NotFoundError


# Some more redundant code, but this allows us to not import utils_pytorch
def get_dataset_name():
"""Reads the name of the dataset from the environment variable `AICROWD_DATASET_NAME`."""
return os.getenv("AICROWD_DATASET_NAME", "cars3d")


def evaluate_with_gin(model_dir,
output_dir,
overwrite=False,
@@ -116,7 +122,6 @@ def evaluate(model_dir,
except NotFoundError:
# If we did not train with disentanglement_lib, there is no "previous step",
# so we'll have to rely on the environment variable.
from load_dataset import get_dataset_name
if gin.query_parameter("dataset.name") == "auto":
with gin.unlock_config():
gin.bind_parameter("dataset.name", get_dataset_name())
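A minimal sketch of the fallback this hunk wires up, assuming gin has already parsed the operative config; the wrapper name `resolve_dataset_name` is illustrative and not part of the commit:

import os
import gin

def resolve_dataset_name():
    # Mirrors the hunk above: prefer an explicit gin binding for "dataset.name",
    # otherwise fall back to the AICROWD_DATASET_NAME environment variable.
    if gin.query_parameter("dataset.name") == "auto":
        with gin.unlock_config():
            gin.bind_parameter("dataset.name",
                               os.getenv("AICROWD_DATASET_NAME", "cars3d"))
    return gin.query_parameter("dataset.name")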
7 changes: 3 additions & 4 deletions pytorch/train_pytorch.py
@@ -22,8 +22,8 @@
will be written out. To learn what tracing entails:
https://pytorch.org/docs/stable/jit.html#torch.jit.trace
You'll find a few more utility functions in utils_pytorch.py (for pytorch related stuff) and in
load_dataset.py (for data logistics).
You'll find a few more utility functions in utils_pytorch.py for pytorch related stuff and
for data logistics.
"""

import argparse
@@ -33,7 +33,6 @@
from torch.nn import functional as F

import utils_pytorch as pyu
import load_dataset as load

import aicrowd_helpers

@@ -57,7 +56,7 @@
device = torch.device("cuda" if args.cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = load.get_loader(batch_size=args.batch_size, **kwargs)
train_loader = pyu.get_loader(batch_size=args.batch_size, **kwargs)


class Encoder(nn.Module):
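A minimal sketch of how the swapped-in loader is consumed; the loop body is illustrative, not part of train_pytorch.py, and assumes the dataset named by AICROWD_DATASET_NAME is available locally:

import utils_pytorch as pyu

train_loader = pyu.get_loader(batch_size=32, num_workers=1, pin_memory=True)
for batch in train_loader:
    # Each batch is a CHW image tensor, e.g. torch.Size([32, 3, 64, 64]) for cars3d.
    assert batch.dim() == 4
    break  # one batch is enough for a smoke test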
125 changes: 125 additions & 0 deletions utils_pytorch.py
@@ -6,6 +6,19 @@
import torch
from torch.jit import trace

# ------ Data Loading ------
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

import os
if 'DISENTANGLEMENT_LIB_DATA' not in os.environ:
os.environ.update({'DISENTANGLEMENT_LIB_DATA': os.path.join(os.path.dirname(__file__),
'scratch',
'dataset')})
# noinspection PyUnresolvedReferences
from disentanglement_lib.data.ground_truth.named_data import get_named_ground_truth_data
# --------------------------


ExperimentConfig = namedtuple('ExperimentConfig',
('base_path', 'experiment_name', 'dataset_name'))
@@ -22,6 +22,11 @@ def get_config():
dataset_name=os.getenv("AICROWD_DATASET_NAME", "cars3d"))


def get_dataset_name():
"""Reads the name of the dataset from the environment variable `AICROWD_DATASET_NAME`."""
return os.getenv("AICROWD_DATASET_NAME", "cars3d")


def use_cuda():
"""
Whether to use CUDA for evaluation. Returns True if CUDA is available and
@@ -155,3 +173,110 @@ def _represent(x):
return y

return _represent


class DLIBDataset(Dataset):
"""
No-bullshit data-loading from Disentanglement Library, but with a few sharp edges.
Sharp edge:
Unlike a traditional PyTorch dataset, indexing with _any_ index fetches a random observation,
which means that in general dataset[0] != dataset[0] (illustrated below this class). Also, you'll need to specify the size
of the dataset, which defines the length of one training epoch.
This is done to ensure compatibility with disentanglement_lib.
"""

def __init__(self, name, seed=0, iterator_len=50000):
"""
Parameters
----------
name : str
Name of the dataset to use. You may use `get_dataset_name`.
seed : int
Random seed.
iterator_len : int
Length of the dataset. This defines the length of one training epoch.
"""
self.name = name
self.seed = seed
self.random_state = np.random.RandomState(seed)
self.iterator_len = iterator_len
self.dataset = self.load_dataset()

def load_dataset(self):
return get_named_ground_truth_data(self.name)

def __len__(self):
return self.iterator_len

def __getitem__(self, item):
assert item < self.iterator_len
output = self.dataset.sample_observations(1, random_state=self.random_state)[0]
# Convert output to CHW from HWC
return torch.from_numpy(np.moveaxis(output, 2, 0))
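# Illustration of the sharp edge documented in the class docstring (example only,
# not part of utils_pytorch.py): indexing draws a fresh random observation each time,
# so two reads of the same index (almost surely) return different tensors.
#
#   ds = DLIBDataset(get_dataset_name(), seed=0, iterator_len=100)
#   a, b = ds[0], ds[0]
#   assert a.shape == b.shape and not torch.equal(a, b)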


def get_dataset(name=None, seed=0, iterator_len=50000):
"""
Makes a dataset.
Parameters
----------
name : str
Name of the dataset to use. Defaults to the output of `get_dataset_name`.
seed : int
Random seed.
iterator_len : int
Length of the dataset. This defines the length of one training epoch.
Returns
-------
DLIBDataset
"""
name = get_dataset_name() if name is None else name
return DLIBDataset(name, seed=seed, iterator_len=iterator_len)


def get_loader(name=None, batch_size=32, seed=0, iterator_len=50000, num_workers=0,
**dataloader_kwargs):
"""
Makes a dataset and a data-loader.
Parameters
----------
name : str
Name of the dataset to use. Defaults to the output of `get_dataset_name`.
batch_size : int
Batch size.
seed : int
Random seed.
iterator_len : int
Length of the dataset. This defines the length of one training epoch.
num_workers : int
Number of processes to use for multiprocessed data-loading.
dataloader_kwargs : dict
Keyword arguments for the data-loader.
Returns
-------
DataLoader
"""
name = get_dataset_name() if name is None else name
dlib_dataset = DLIBDataset(name, seed=seed, iterator_len=iterator_len)
loader = DataLoader(dlib_dataset, batch_size=batch_size, shuffle=True,
num_workers=num_workers, **dataloader_kwargs)
return loader


def test_loader():
loader = get_loader(num_workers=2)
for count, b in enumerate(loader):
print(b.shape)
# ^ prints `torch.Size([32, 3, 64, 64])` and means that multiprocessing works
if count > 5:
break
print("Success!")


if __name__ == '__main__':
pass
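Since the `__main__` block above is a no-op, the smoke test has to be invoked explicitly. A minimal sketch, assuming the dataset named by AICROWD_DATASET_NAME exists under DISENTANGLEMENT_LIB_DATA (or the scratch/dataset default set at the top of the file):

import utils_pytorch as pyu

pyu.test_loader()  # prints a few batch shapes such as torch.Size([32, 3, 64, 64]), then "Success!"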
