Skip to content

Commit fabd989

Browse files
committed
tiny imagenet support added
1 parent e505b3c commit fabd989

File tree

6 files changed

+118
-20
lines changed

6 files changed

+118
-20
lines changed

README.md

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ The codebase currently only supports single-machine single-gpu training. We will
1313

1414
Please see [`GETTING_STARTED`](docs/GETTING_STARTED.md) for brief installation instructions and basic usage examples.
1515

16-
## Model Zoo
17-
18-
We provide a large set of baseline results as proof of repository's efficiency. (coming soon)
19-
2016
## Active Learning Methods Supported
2117
* Uncertainty Sampling
2218
* Least Confidence
@@ -32,11 +28,16 @@ We provide a large set of baseline results as proof of repository's efficiency.
3228

3329

3430
## Datasets Supported
35-
* CIFAR10
36-
* CIFAR100
37-
* MNIST
38-
* SVHN
39-
* TinyImageNet (coming soon)
31+
* [CIFAR10/100](https://www.cs.toronto.edu/~kriz/cifar.html)
32+
* [MNIST](http://yann.lecun.com/exdb/mnist/)
33+
* [SVHN](http://ufldl.stanford.edu/housenumbers/)
34+
* [TinyImageNet](https://www.kaggle.com/c/tiny-imagenet) (Download the zip file [here](http://cs231n.stanford.edu/tiny-imagenet-200.zip))
35+
36+
37+
## Model Zoo
38+
39+
We provide a large set of baseline results as proof of repository's efficiency. (coming soon)
40+
4041

4142
## Citing this repository
4243

pycls/core/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@
217217
# ---------------------------------------------------------------------------- #
218218
_C.DATASET = CN()
219219
_C.DATASET.NAME = None
220+
# For Tiny ImageNet dataset, ROOT_DIR must be set to the dataset folder ("data/tiny-imagenet-200/"). For others, the outder "data" folder where all datasets can be stored is expected.
220221
_C.DATASET.ROOT_DIR = None
221222
# Specifies the proportion of data in train set that should be considered as the validation data
222223
_C.DATASET.VAL_RATIO = 0.1

pycls/datasets/data.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .utils import helpers
1717
import pycls.utils.logging as lu
1818
from pycls.datasets.sampler import IndexedSequentialSampler
19+
from pycls.datasets.tiny_imagenet import TinyImageNet
1920

2021
logger = lu.get_logger(__name__)
2122

@@ -148,7 +149,7 @@ def getDataset(self, save_dir, isTrain=True, isDownload=False):
148149
if self.dataset == "MNIST":
149150
mnist = datasets.MNIST(save_dir, train=isTrain, transform=preprocess_steps, download=isDownload)
150151
return mnist, len(mnist)
151-
152+
152153
elif self.dataset == "CIFAR10":
153154
cifar10 = datasets.CIFAR10(save_dir, train=isTrain, transform=preprocess_steps, download=isDownload)
154155
return cifar10, len(cifar10)
@@ -159,11 +160,18 @@ def getDataset(self, save_dir, isTrain=True, isDownload=False):
159160

160161
elif self.dataset == "SVHN":
161162
if isTrain:
162-
svhn = SVHN(save_dir,split='train', transform=preprocess_steps, download=isDownload)
163+
svhn = datasets.SVHN(save_dir, split='train', transform=preprocess_steps, download=isDownload)
163164
else:
164-
svhn = SVHN(save_dir, split='test', transform=preprocess_steps, download=isDownload)
165+
svhn = datasets.SVHN(save_dir, split='test', transform=preprocess_steps, download=isDownload)
165166
return svhn, len(svhn)
166-
# TinyImageNet Implementation Needed
167+
168+
elif self.dataset == "TINYIMAGENET":
169+
if isTrain:
170+
tiny = TinyImageNet(save_dir, split='train', transform=preprocess_steps)
171+
else:
172+
tiny = TinyImageNet(save_dir, split='val', transform=preprocess_steps)
173+
return tiny, len(tiny)
174+
167175
else:
168176
print("Either the specified {} dataset is not added or there is no if condition in getDataset function of Data class".format(self.dataset))
169177
logger.info("Either the specified {} dataset is not added or there is no if condition in getDataset function of Data class".format(self.dataset))
@@ -200,7 +208,7 @@ def makeLUVSets(self, train_split_ratio, val_split_ratio, data, seed_id, save_di
200208

201209
assert isinstance(train_split_ratio, float),"Train split ratio is of {} datatype instead of float".format(type(train_split_ratio))
202210
assert isinstance(val_split_ratio, float),"Val split ratio is of {} datatype instead of float".format(type(val_split_ratio))
203-
assert self.dataset in ["MNIST","CIFAR10","CIFAR100", "SVHN"], "Sorry the dataset {} is not supported. Currently we support ['MNIST','CIFAR10', 'CIFAR100', 'SVHN']".format(self.dataset)
211+
assert self.dataset in ["MNIST","CIFAR10","CIFAR100", "SVHN", "TINYIMAGENET"], "Sorry the dataset {} is not supported. Currently we support ['MNIST','CIFAR10', 'CIFAR100', 'SVHN', 'TINYIMAGENET']".format(self.dataset)
204212

205213
lSet = []
206214
uSet = []
@@ -262,7 +270,7 @@ def makeTVSets(self, train_split_ratio, val_split_ratio, data, seed_id, save_dir
262270

263271
assert isinstance(train_split_ratio, float),"Train split ratio is of {} datatype instead of float".format(type(train_split_ratio))
264272
assert isinstance(val_split_ratio, float),"Val split ratio is of {} datatype instead of float".format(type(val_split_ratio))
265-
assert self.dataset in ["MNIST","CIFAR10","CIFAR100", "SVHN"], "Sorry the dataset {} is not supported. Currently we support ['MNIST','CIFAR10', 'CIFAR100', 'SVHN']".format(self.dataset)
273+
assert self.dataset in ["MNIST","CIFAR10","CIFAR100", "SVHN", "TINYIMAGENET"], "Sorry the dataset {} is not supported. Currently we support ['MNIST','CIFAR10', 'CIFAR100', 'SVHN', 'TINYIMAGENET']".format(self.dataset)
266274

267275
trainSet = []
268276
valSet = []
@@ -377,7 +385,7 @@ def getTestLoader(self, data, test_batch_size, seed_id=0):
377385
torch.manual_seed(seed_id)
378386
np.random.seed(seed_id)
379387

380-
if self.dataset in ["MNIST","CIFAR10","CIFAR100"]:
388+
if self.dataset in ["MNIST","CIFAR10","CIFAR100", "TINYIMAGENET"]:
381389
n_datapts = len(data)
382390
idx = [i for i in range(n_datapts)]
383391
#np.random.shuffle(idx)

pycls/datasets/tiny_imagenet.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import os
2+
import numpy as np
3+
4+
import torch
5+
import torchvision.datasets as datasets
6+
7+
from typing import Any
8+
9+
10+
class TinyImageNet(datasets.ImageFolder):
11+
"""`Tiny ImageNet Classification Dataset.
12+
13+
Args:
14+
root (string): Root directory of the ImageNet Dataset.
15+
split (string, optional): The dataset split, supports ``train``, or ``val``.
16+
transform (callable, optional): A function/transform that takes in an PIL image
17+
and returns a transformed version. E.g, ``transforms.RandomCrop``
18+
target_transform (callable, optional): A function/transform that takes in the
19+
target and transforms it.
20+
loader (callable, optional): A function to load an image given its path.
21+
22+
Attributes:
23+
classes (list): List of the class name tuples.
24+
class_to_idx (dict): Dict with items (class_name, class_index).
25+
wnids (list): List of the WordNet IDs.
26+
wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
27+
samples (list): List of (image path, class_index) tuples
28+
targets (list): The class_index value for each image in the dataset
29+
"""
30+
def __init__(self, root: str, split: str = 'train', **kwargs: Any) -> None:
31+
self.root = root
32+
assert self.check_root(), "Something is wrong with the Tiny ImageNet dataset. Download the official dataset zip from http://cs231n.stanford.edu/tiny-imagenet-200.zip and unzip it inside {}.".format(self.root)
33+
self.split = datasets.utils.verify_str_arg(split, "split", ("train", "val"))
34+
35+
wnid_to_classes = self.load_wnid_to_classes()
36+
37+
super(TinyImageNet, self).__init__(self.split_folder, **kwargs)
38+
self.wnids = self.classes
39+
self.wnid_to_idx = self.class_to_idx
40+
self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
41+
self.class_to_idx = {cls: idx
42+
for idx, clss in enumerate(self.classes)
43+
for cls in clss}
44+
# Tiny ImageNet val directory structure is not similar to that of train's
45+
# So a custom loading function is necessary
46+
if self.split == 'val':
47+
self.root = root
48+
self.imgs, self.target = self.load_val_data()
49+
self.samples = [(self.imgs[idx],self.targets[idx]) for idx in range(len(self.imgs))]
50+
self.root = os.path.join(self.root, 'val')
51+
52+
53+
# Split folder is used for the 'super' call. Since val directory is not structured like the train,
54+
# we simply use train's structure to get all classes and other stuff
55+
@property
56+
def split_folder(self) -> str:
57+
return os.path.join(self.root, 'train')
58+
59+
60+
def load_val_data(self):
61+
imgs, targets = [], []
62+
with open(os.path.join(self.root, 'val', 'val_annotations.txt'), 'r') as file:
63+
for line in file:
64+
if line.split()[1] in self.wnids:
65+
img_file, wnid = line.split('\t')[:2]
66+
imgs.append(os.path.join(self.root, 'val', 'images', img_file))
67+
targets.append(wnid)
68+
targets = np.array([self.wnid_to_idx[wnid] for wnid in targets])
69+
return imgs, targets
70+
71+
72+
def load_wnid_to_classes(self):
73+
wnid_to_classes = {}
74+
with open(os.path.join(self.root, 'words.txt'), 'r') as file:
75+
lines = file.readlines()
76+
lines = [x.split('\t') for x in lines]
77+
wnid_to_classes = {x[0]:x[1].strip() for x in lines}
78+
return wnid_to_classes
79+
80+
def check_root(self):
81+
tinyim_set = ['words.txt', 'wnids.txt', 'train', 'val', 'test']
82+
for x in os.scandir(self.root):
83+
if x.name not in tinyim_set:
84+
return False
85+
return True

tools/ensemble_al.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,12 @@ def main(cfg):
110110
if not os.path.exists(cfg.OUT_DIR):
111111
os.mkdir(cfg.OUT_DIR)
112112
# Create "DATASET" specific directory
113-
dataset_out_dir = os.path.join(cfg.OUT_DIR, cfg.DATASET.NAME)
113+
dataset_out_dir = os.path.join(cfg.OUT_DIR, cfg.DATASET.NAME, cfg.MODEL.TYPE)
114114
if not os.path.exists(dataset_out_dir):
115115
os.mkdir(dataset_out_dir)
116116
# Creating the experiment directory inside the dataset specific directory
117117
# all logs, labeled, unlabeled, validation sets are stroed here
118-
# E.g., output/CIFAR10/{timestamp or cfg.EXP_NAME based on arguments passed}
118+
# E.g., output/CIFAR10/resnet18/{timestamp or cfg.EXP_NAME based on arguments passed}
119119
if cfg.EXP_NAME == 'auto':
120120
now = datetime.now()
121121
exp_dir = f'{now.year}_{now.month}_{now.day}_{now.hour}{now.minute}{now.second}'
@@ -246,7 +246,10 @@ def main(cfg):
246246
save_plot_values([plot_episode_xvalues, plot_episode_yvalues], \
247247
["plot_episode_xvalues", "plot_episode_yvalues"], out_dir=cfg.EXP_DIR, saveInTextFormat=True)
248248

249-
249+
250+
# No need to perform active sampling in the last episode iteration
251+
if cur_episode == cfg.ACTIVE_LEARNING.MAX_ITER:
252+
break
250253

251254
# Active Sample
252255
print("======== ENSEMBLE ACTIVE SAMPLING ========\n")

tools/train_al.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def main(cfg):
113113
# Create "DATASET" specific directory
114114
dataset_out_dir = os.path.join(cfg.OUT_DIR, cfg.DATASET.NAME, cfg.MODEL.TYPE)
115115
if not os.path.exists(dataset_out_dir):
116-
os.mkdir(dataset_out_dir)
116+
os.makedirs(dataset_out_dir)
117117
# Creating the experiment directory inside the dataset specific directory
118118
# all logs, labeled, unlabeled, validation sets are stroed here
119119
# E.g., output/CIFAR10/resnet18/{timestamp or cfg.EXP_NAME based on arguments passed}

0 commit comments

Comments
 (0)