Skip to content

Commit

Permalink
automatically download dsprite dataset if not exists
Browse files Browse the repository at this point in the history
  • Loading branch information
1Konny committed May 2, 2018
1 parent 699c054 commit 66fcd41
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
10 changes: 7 additions & 3 deletions dataset.py
Expand Up @@ -62,9 +62,13 @@ def return_data(args):
dset = CustomImageFolder

elif name.lower() == 'dsprites':
root = Path(dset_dir).joinpath('dsprites-dataset')
data = np.load(root.joinpath('dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'),
encoding='bytes')
root = Path(dset_dir).joinpath('dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
if not root.exists():
import subprocess
print('Now download dsprites-dataset')
subprocess.call(['./download_dsprites.sh'])
print('Finished')
data = np.load(root, encoding='bytes')
data = torch.from_numpy(data['imgs']).unsqueeze(1).float()
train_kwargs = {'data_tensor':data}
dset = CustomTensorDataset
Expand Down
7 changes: 7 additions & 0 deletions download_dsprites.sh
@@ -0,0 +1,7 @@
#! /bin/sh

mkdir data
cd data
git clone https://github.com/deepmind/dsprites-dataset.git
cd dsprites-dataset
rm -rf .git* *.md LICENSE *.ipynb *.gif *.hdf5

0 comments on commit 66fcd41

Please sign in to comment.