Skip to content
This repository was archived by the owner on Jul 18, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,23 @@ answer_preprocessing: none
# Those weights can be next used by weighted samplers (e.g. kFoldWeightedSampler)
export_sample_weights: ''

# Shuffle the indices of the input (source) files/samples.
# Leaving that to false will results in the original order of files samples,
# i.e. C1, then C2, then C3 etc.
shuffle_indices: False

# Generate and export (potentially shuffled) indices (LOADED)
# If not empty, will:
# * shuffle indices of all samples and export them to a file.
# * use those indices during sampling.
export_indices: ''

# Import (potentially shuffled) indices (LOADED)
# If not empty, will:
# * import them to a file.
# * use those indices during sampling.
import_indices: ''

streams:
####################################################################
# 2. Keymappings associated with INPUT and OUTPUT streams.
Expand Down
12 changes: 10 additions & 2 deletions configs/vqa_med_2019/default_vqa_med_2019.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ training:
split: training_validation
resize_image: &resize_image [224, 224]
batch_size: 64
# Generate and export shuffled indices.
shuffle_indices: True
export_indices: shuffled_indices.npy

# Default sampler during training.
sampler:
name: kFoldWeightedRandomSampler
folds: 5
folds: 10
epochs_per_fold: 20
# Use four workers for loading images.
dataloader:
num_workers: 4
Expand All @@ -35,10 +39,14 @@ validation:
split: training_validation
resize_image: *resize_image
batch_size: 64
# Import shuffled indices.
import_indices: shuffled_indices.npy

# Default sampler during validation.
sampler:
name: kFoldRandomSampler
folds: 5
folds: 10
epochs_per_fold: 20
# Use four workers for loading images.
dataloader:
num_workers: 4
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pipeline:
import_word_mappings_from_globals: True
streams:
inputs: predictions
outputs: predicted_categories
outputs: predicted_category_names
globals:
vocabulary_size: num_categories
word_mappings: category_word_mappings
Expand Down Expand Up @@ -72,6 +72,6 @@ pipeline:
viewer:
type: StreamViewer
priority: 100.4
input_streams: questions,category_names,predicted_categories
input_streams: questions,category_names,predicted_category_names

#: pipeline
8 changes: 6 additions & 2 deletions ptp/application/sampler_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,11 @@ def build(problem, config, problem_subset_name):
folds = config["folds"]
if folds < 2:
raise ConfigurationError("kFoldRandomSampler requires at least two 'folds'")
# Get epochs per fold (default: 1).
epochs_per_fold = config.get("epochs_per_fold", 1)

# Create the sampler object.
sampler = ptp_samplers.kFoldRandomSampler(len(problem), folds, problem_subset_name == 'training')
sampler = ptp_samplers.kFoldRandomSampler(len(problem), folds, epochs_per_fold, problem_subset_name == 'training')

###########################################################################
# Handle fourd special case: kFoldWeightedRandomSampler.
Expand All @@ -202,9 +204,11 @@ def build(problem, config, problem_subset_name):
folds = config["folds"]
if folds < 2:
raise ConfigurationError("kFoldRandomSampler requires at least two 'folds'")
# Get epochs per fold (default: 1).
epochs_per_fold = config.get("epochs_per_fold", 1)

# Create the sampler object.
sampler = ptp_samplers.kFoldWeightedRandomSampler(weights, len(problem), folds, problem_subset_name == 'training')
sampler = ptp_samplers.kFoldWeightedRandomSampler(weights, len(problem), folds, epochs_per_fold, problem_subset_name == 'training')

elif name in ['BatchSampler', 'DistributedSampler']:
# Sorry, don't support those. Yet;)
Expand Down
29 changes: 23 additions & 6 deletions ptp/components/problems/image_text_to_class/vqa_med_2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,29 @@ def __init__(self, name, config):
source_image_folder = os.path.join(split_folder, 'VQAMed2019_Test_Images')
self.dataset = self.load_testset(source_file, source_image_folder)

# Ok, now we got the whole dataset (for given "split").
self.ix = np.arange(len(self.dataset))
if self.config["import_indices"] != '':
# Try to load indices from the file.
self.ix = np.load(os.path.join(self.app_state.log_dir, self.config["import_indices"]))
self.logger.info("Imported indices from '{}'".format(os.path.join(self.app_state.log_dir, self.config["export_indices"])))
else:
# Ok, check whether we want to shuffle.
if self.config["shuffle_indices"]:
np.random.shuffle(self.ix)
# Export if required.
if self.config["export_indices"] != '':
# export indices to file.
np.save(os.path.join(self.app_state.log_dir, self.config["export_indices"]), self.ix)
self.logger.info("Exported indices to '{}'".format(os.path.join(self.app_state.log_dir, self.config["export_indices"])))

# Display exemplary sample.
self.logger.info("Exemplary sample:\n [ category: {}\t image_ids: {}\t question: {}\t answer: {} ]".format(
self.dataset[0][self.key_category_ids],
self.dataset[0][self.key_image_ids],
self.dataset[0][self.key_questions],
self.dataset[0][self.key_answers]
self.logger.info("Exemplary sample 0 ({}):\n [ category: {}\t image_ids: {}\t question: {}\t answer: {} ]".format(
self.ix[0],
self.category_idx_to_word[self.dataset[self.ix[0]][self.key_category_ids]],
self.dataset[self.ix[0]][self.key_image_ids],
self.dataset[self.ix[0]][self.key_questions],
self.dataset[self.ix[0]][self.key_answers]
))

# Check if we want the problem to calculate and export the weights.
Expand Down Expand Up @@ -703,7 +720,7 @@ def __getitem__(self, index):
:return: DataDict({'indices', 'images', 'images_ids','questions', 'answers', 'category_ids', 'image_sizes'})
"""
# Get item.
item = self.dataset[index]
item = self.dataset[self.ix[index]]

# Create the resulting sample (data dict).
data_dict = self.create_data_dict(index)
Expand Down
83 changes: 57 additions & 26 deletions ptp/utils/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@ class kFoldRandomSampler(Sampler):
Every time __iter__() method is called, it moves to next fold/set of folds.
"""

def __init__(self, num_samples, num_folds, all_but_current_fold = True):
def __init__(self, num_samples, num_folds, epochs_per_fold = 1, all_but_current_fold = True):
"""
Initializes the sampler by generating the indices associated with the fold(s) that are to be used.

:param num_samples: Size of the dataset
:param num_samples: Size of the dataset

:param num_folds: Number of folds
:param all_but_current_fold: Operation mode (DEFAULT: True):
When True, generates indices for all-but-one folds (for training). \
When False, generates indices for only one fold (for validation). \

:param epochs_per_fold: Number of epochs that need to pass before sampler moves to next fold(s) (DEFAULT: 1)

:param all_but_current_fold: Operation mode (DEFAULT: True): \
When True, generates indices for all-but-one folds (for training) \
When False, generates indices for only one fold (for validation)
"""
# Get number of samples (size of "whole dataset").
if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
Expand All @@ -55,12 +59,23 @@ def __init__(self, num_samples, num_folds, all_but_current_fold = True):
num_folds <= 0:
raise ValueError("num_folds should be a positive integeral "
"value, but got num_folds={}".format(num_folds))
self.num_folds = num_folds

# Get number epochs per fold.
if not isinstance(epochs_per_fold, _int_classes) or isinstance(epochs_per_fold, bool) or \
epochs_per_fold <= 0:
raise ValueError("epochs_per_fold should be a positive integeral "
"value, but got num_folds={}".format(epochs_per_fold))

# Store fold-related parameres.
self.all_but_current_fold = all_but_current_fold
# Initialize current "fold" as -1, so then dataloder will call next() for the first time
# it will return samples for 0-th fold/all-but-0th fold.
self.current_fold = -1
self.num_folds = num_folds
self.epochs_per_fold = epochs_per_fold

# Initialize current "fold" so it will return samples for 0-th fold/all-but-0th fold.
self.current_fold = 0
# "Left epochs": +1 is related to "initial", additional generation of indices - below.
self.epochs_left = self.epochs_per_fold +1

# Generate "initial" indices.
self.indices = self.regenerate_indices()

Expand All @@ -73,9 +88,7 @@ def regenerate_indices(self):
# Fold size and indices.
all_indices = range(self.num_samples)
fold_size = ceil(self.num_samples / self.num_folds)

# Modulo current fold number by total number of folds.
fold = self.current_fold % self.num_folds
fold = self.current_fold

# Generate indices associated with the given fold / all except the given fold.
if self.all_but_current_fold:
Expand Down Expand Up @@ -106,11 +119,17 @@ def __iter__(self):
"""
Return "shuffled" indices.
"""
# Next fold.
self.current_fold += 1
# "Decrease" the number of epochs with this fold.
self.epochs_left = self.epochs_left - 1
if self.epochs_left <= 0:
# Next fold, modulo by the total number of folds.
self.current_fold = (self.current_fold + 1) % self.num_folds

# Regenerate indices.
self.indices = self.regenerate_indices()
# Regenerate indices.
self.indices = self.regenerate_indices()

# Reset epochs counter.
self.epochs_left = self.epochs_per_fold

# Return permutated indices.
return (self.indices[i] for i in torch.randperm(len(self.indices)))
Expand All @@ -132,23 +151,30 @@ class kFoldWeightedRandomSampler(kFoldRandomSampler):
Every time __iter__() method is called, it moves to next fold/set of folds.
"""

def __init__(self, weights, num_samples, num_folds, all_but_current_fold = True, replacement=True):
def __init__(self, weights, num_samples, num_folds, epochs_per_fold = 1, all_but_current_fold = True, replacement=True):
"""
Initializes the sampler by generating the indices associated with the fold(s) that are to be used.

:param num_samples: Size of the dataset
:param num_samples: Size of the dataset

:param num_folds: Number of folds
:param all_but_current_fold: Operation mode (DEFAULT: True):
When True, generates indices for all-but-one folds (for training). \
When False, generates indices for only one fold (for validation). \

:param epochs_per_fold: Number of epochs that need to pass before sampler moves to next fold(s) (DEFAULT: 1)

:param all_but_current_fold: Operation mode (DEFAULT: True): \
When True, generates indices for all-but-one folds (for training) \
When False, generates indices for only one fold (for validation)

:params weights: a sequence of weights, not necessary summing up to one

:param num_samples: number of samples to draw

:param replacement: if ``True``, samples are drawn with replacement.
If not, they are drawn without replacement, which means that when a
sample index is drawn for a row, it cannot be drawn again for that row.
"""
# Call k-fold base class constructor.
super().__init__(num_samples, num_folds, all_but_current_fold)
super().__init__(num_samples, num_folds, epochs_per_fold, all_but_current_fold)
# Get replacement flag.
if not isinstance(replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
Expand All @@ -159,12 +185,17 @@ def __init__(self, weights, num_samples, num_folds, all_but_current_fold = True,
self.weights = torch.tensor(weights, dtype=torch.double)

def __iter__(self):
# Next fold.
self.current_fold += 1
# "Decrease" the number of epochs with this fold.
self.epochs_left = self.epochs_left - 1
if self.epochs_left <= 0:
# Next fold, modulo by the total number of folds.
self.current_fold = (self.current_fold + 1) % self.num_folds

# Regenerate indices.
self.indices = self.regenerate_indices()
# Regenerate indices.
self.indices = self.regenerate_indices()

# Reset epochs counter.
self.epochs_left = self.epochs_per_fold

# Select the corresponging weights.
weights = torch.take(self.weights, torch.tensor(self.indices))
Expand Down
27 changes: 27 additions & 0 deletions tests/samplers_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,33 @@ def test_kfold_random_sampler_current_fold(self):
self.assertIn(ix, indices)


def test_kfold_random_sampler_current_fold_10epochs(self):
""" Tests the k-fold sampler current_fold mode. """

# Create the sampler.
sampler = kFoldRandomSampler(20, 3, 10, all_but_current_fold=False)

# First 10 epochs - the same indices from 0-7 range.
for _ in range(10):
# Test zero-th fold.
indices = list(iter(sampler))
# Check number of samples.
self.assertEqual(len(indices), 7)
# Check presence of all indices.
for ix in range(0,7):
self.assertIn(ix, indices)

# Next 10 epochs - the same indices from 7-14 range.
for _ in range(10):
# Test zero-th fold.
indices = list(iter(sampler))
# Check number of samples.
self.assertEqual(len(indices), 7)
# Check presence of all indices.
for ix in range(7,14):
self.assertIn(ix, indices)


def test_kfold_random_sampler_all_but_current_fold(self):
""" Tests the k-fold sampler all_but_current_fold mode. """

Expand Down