Added base dataloader and experiment for pytorch
ahmed-shariff committed Jul 19, 2019
1 parent 6f2986e commit 4f6441e
Showing 2 changed files with 380 additions and 0 deletions.
280 changes: 280 additions & 0 deletions mlpipeline/pytorch/_dataloader.py
@@ -0,0 +1,280 @@
import torch
import json
from mlpipeline.base import DataLoaderABC
from mlpipeline.utils import ExecutionModeKeys, log


class Datasets():
"""Class to store the datasets"""
    # pylint: disable=too-many-arguments

def __init__(self,
train_dataset_file_path,
test_dataset_file_path=None,
validation_dataset_file_path=None,
class_encoding=None,
                 train_data_load_function=lambda file_path: (json.load(open(file_path, "r")), []),
                 test_data_load_function=None,
                 test_size=None,
                 # use_cache=True  # Need to implement this one?
                 validation_size=None):
"""
Keyword arguments:
train_dataset_file_path -- The path to the file containing the train dataset
test_dataset_file_path -- The path to the file containing the test dataset.
If this is None, a portion of the train dataset will be allocated as
the test dataset based on the `test_size`.
validation_dataset_file_path -- The path to the file containing the validation dataset.
If this is None, a portion of the train dataset will be allocated as
the validation dataset based on the `validation_size` after
allocating the test dataset.
class_encoding -- Dict. The index to class name mapping of the dataset. Will be logged.
train_data_load_function -- The function that will be used the content of the files passed above.
This is a callable, that takes the file path and return the dataset.
The returned value should allow selecting rows using python's slicing
(eg: pandas.DataFrame, python lists, numpy.array). Will be used to
load the file_passed through `train_daset_file_path`,
`validation_dataset_file_path`. Also will be used to load the
`test_dataset_file_path` if `test_data_load_function` is None.
test_data_load_function -- Similar to `train_data_load_function`. This parameter can be used to
define a seperate loading process for the test_dataset. If
`test_dataset_file_path` is not None, this callable will be used to
load the file's content. Also, if this parameter is set and
`test_dataset_file_path` is None, instead of allocating a portion of
the train_dataset as test_dataset, the files`train_dataset_file_path`
passed will be loaded using this callable. Note that it is the
callers responsibility to ensure there are no intersections between
train and test dataset when data is loaded using this parameter.
test_size -- Float between 0 and 1. The portion of the train dataset to allocate
as the test dataset based if `test_dataset_file_path` not given and
`test_data_load_function` is None.
validation_size -- Float between 0 and 1. The portion of the train dataset to allocate
as the validadtion dataset based if
`validation_dataset_file_path` not given.
"""
self._train_dataset = self._load_data(train_dataset_file_path,
train_data_load_function)
if test_dataset_file_path is None:
if test_data_load_function is not None:
self._test_dataset = self._load_data(train_dataset_file_path,
test_data_load_function)
else:
                if test_size is None:
                    log("Datasets: Using default 'test_size': 0.1")
                    test_size = 0.1
                assert 0 <= test_size <= 1
                # Split the first `test_count` rows off as the test dataset
                test_count = round(len(self._train_dataset) * test_size)
                self._test_dataset = self._train_dataset[:test_count]
                self._train_dataset = self._train_dataset[test_count:]
else:
if test_size is not None:
log("Datasetes: Ignoring 'test_size'")
if test_data_load_function is None:
self._test_dataset = self._load_data(test_dataset_file_path,
train_data_load_function)
else:
self._test_dataset = self._load_data(test_dataset_file_path,
test_data_load_function)

if validation_dataset_file_path is None:
            if validation_size is None:
                log("Datasets: Using default 'validation_size': 0.1")
                validation_size = 0.1
            assert 0 <= validation_size <= 1
            # Split the first `validation_count` rows off as the validation dataset
            validation_count = round(len(self._train_dataset) * validation_size)
            self._validation_dataset = self._train_dataset[:validation_count]
            self._train_dataset = self._train_dataset[validation_count:]
else:
if validation_size is not None:
log("Datasetes: Ignoring 'validation_size'")
self._validation_dataset = self._load_data(validation_dataset_file_path,
train_data_load_function)

        if class_encoding is not None:
            assert isinstance(class_encoding, dict)
        # Always set the attribute so later access does not raise AttributeError
        self.class_encoding = class_encoding
log("Datasets- Train dataset size: {}".format(len(self._train_dataset)))
log("Datasets- Test dataset size: {}".format(len(self._test_dataset)))
log("Datasets- Validation dataset size: {}".format(len(self._validation_dataset)))

def _load_data(self,
data_file_path,
data_load_function):
"""Helper function to load the data using the provided `data_load_function`"""
data, used_labels = data_load_function(data_file_path)
try:
self.used_labels.update(used_labels)
except AttributeError:
self.used_labels = set(used_labels)

        # Cheap way of checking if slicing is supported
        try:
            data[0:2:2]
        except Exception as e:
            raise TypeError(
                "Check if the object returned by 'data_load_function' supports slicing!") from e
        return data

@property
def train_dataset(self):
"""The pandas dataframe representing the training dataset"""
return self._train_dataset

@train_dataset.setter
def train_dataset(self, value):
self._train_dataset = value

@property
def test_dataset(self):
"""The pandas dataframe representing the training dataset"""
return self._test_dataset

@test_dataset.setter
def test_dataset(self, value):
self._test_dataset = value

@property
def validation_dataset(self):
"""The pandas dataframe representing the validation dataset"""
return self._validation_dataset

@validation_dataset.setter
def validation_dataset(self, value):
self._validation_dataset = value
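
As a usage sketch, the following shows how a Datasets instance could be constructed with a custom load function. The file path, the 'label' column, and the load_csv helper are illustrative assumptions; the only contract taken from the code above is that the load function returns a (data, used_labels) tuple whose first element supports slicing.

import pandas as pd

# Hypothetical loader: reads a CSV with a 'label' column and returns the
# frame along with the labels that occur in it.
def load_csv(file_path):
    data = pd.read_csv(file_path)
    return data, data['label'].unique().tolist()

datasets = Datasets(train_dataset_file_path="data/train.csv",  # hypothetical path
                    class_encoding={0: "cat", 1: "dog"},
                    train_data_load_function=load_csv,
                    test_size=0.2,
                    validation_size=0.1)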


class BaseTorchDataLoader(DataLoaderABC):
"""Base DataLoader implementation for using with pytoch"""

    def __init__(self,
                 datasets,
                 pytorch_dataset_factory,
                 batch_size,
                 train_transforms=None,
                 test_transforms=None):
        super().__init__()
        assert isinstance(datasets, Datasets)
        assert isinstance(pytorch_dataset_factory, DatasetFactory)
        self.datasets = datasets
        self.pytorch_dataset_factory = pytorch_dataset_factory
        self.batch_size = batch_size
        # Avoid mutable default arguments; fall back to empty transform lists
        self.train_transforms = train_transforms if train_transforms is not None else []
        self.test_transforms = test_transforms if test_transforms is not None else []

    def get_dataloader_summery(self, **kwargs):
        # `summery` is set in the base class DataLoaderABC (the spelling follows the ABC)
        return self.summery

def get_train_sample_count(self):
return len(self.datasets.train_dataset)

def get_test_sample_count(self):
"""Retruns the number of test samples"""
return len(self.datasets.test_dataset)

    def get_train_input(self, mode, **kwargs):
        self.log("batch size: {}, mode: {}".format(self.batch_size, mode))

        if mode == ExecutionModeKeys.TRAIN:
            transform = self.train_transforms
            batch_size = self.batch_size
            dataset_mode = ExecutionModeKeys.TRAIN
        else:
            transform = self.test_transforms
            batch_size = 1
            dataset_mode = ExecutionModeKeys.TEST
        # The train dataset is used in both cases; the mode only controls the
        # transforms and the batch size applied to it.
        dataset = self.pytorch_dataset_factory.create_instance(
            current_data=self.datasets.train_dataset,
            transform=transform,
            mode=dataset_mode)
        dl = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         collate_fn=dataset.collate_fn)
        return dl

    def get_test_input(self, data=None, **kwargs):
        if data is None:
            data = self.datasets.test_dataset
        dataset = self.pytorch_dataset_factory.create_instance(
            current_data=data,
            transform=self.test_transforms,
            mode=ExecutionModeKeys.TEST)
        dl = torch.utils.data.DataLoader(dataset,
                                         batch_size=1,
                                         collate_fn=dataset.collate_fn)
        return dl

# TODO: Temp placeholder function until the behaviour is replaced.
def set_validation_set(self, dataset):
pass


class DatasetFactory():
    """
    This class is used to create the dataset objects used by the different
    stages of the DataLoader.
    """
    def __init__(self, dataset_class, **kwargs):
        self.dataset_class = dataset_class
        self.kwargs = kwargs

    def create_instance(self, current_data, mode, transform=None):
        obj = self.dataset_class(**self.kwargs)
        obj._inject_params(current_data, mode, transform)
        # Fall back to torch's default collate function if the dataset does not define one
        obj.collate_fn = getattr(obj, 'collate_fn', torch.utils.data.dataloader.default_collate)
        return obj


class DatasetBasic(torch.utils.data.Dataset):
"""The Base dataset class."""

    def __init__(self, *args, **kwargs):
        self._current_data = None
        self._transform = None
        self._mode = None

    # Properties must be class attributes to behave as descriptors; assigning
    # `property(...)` objects to instance attributes in __init__ does not work.
    @property
    def current_data(self):
        """The data currently being used by the dataset"""
        return self._current_data

    @current_data.setter
    def current_data(self, value):
        self._current_data = value

    @property
    def transform(self):
        """The transforms to be applied"""
        return self._transform

    @transform.setter
    def transform(self, value):
        self._transform = value

    @property
    def mode(self):
        """The current mode of the experiment (ExecutionModeKeys)"""
        return self._mode

    @mode.setter
    def mode(self, value):
        self._mode = value

    def _inject_params(self, current_data, mode, transform=None):
        self.current_data = current_data
        self.transform = transform
        self.mode = mode

def pre_process(self):
raise NotImplementedError()

def __len__(self):
if self.current_data is None:
return 0
return len(self.current_data)

def __getitem__(self, idx):
raise NotImplementedError()
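
To tie the pieces together, here is a minimal end-to-end sketch: a concrete DatasetBasic subclass plugged into DatasetFactory and BaseTorchDataLoader. The PairDataset class, the 'input' and 'label' columns, and the transforms are hypothetical, and `datasets` is assumed to be a Datasets instance like the one sketched earlier.

class PairDataset(DatasetBasic):
    """Hypothetical dataset returning (input, label) pairs from a DataFrame."""

    def __getitem__(self, idx):
        row = self.current_data.iloc[idx]
        value = row['input']
        # Apply the configured transforms, if any, in order
        for transform in self.transform or []:
            value = transform(value)
        return value, row['label']


factory = DatasetFactory(PairDataset)
dataloader = BaseTorchDataLoader(datasets=datasets,
                                 pytorch_dataset_factory=factory,
                                 batch_size=32,
                                 train_transforms=[float])
train_dl = dataloader.get_train_input(mode=ExecutionModeKeys.TRAIN)

Since PairDataset does not define a collate_fn, the factory falls back to torch's default_collate when building the DataLoader.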
100 changes: 100 additions & 0 deletions mlpipeline/pytorch/_experiment.py
@@ -0,0 +1,100 @@
import os
import torch
from mlpipeline.base import ExperimentABC
from mlpipeline.utils import ExecutionModeKeys


class BaseTorchExperiment(ExperimentABC):
    def __init__(self, versions, **kwargs):
        super().__init__(versions, **kwargs)
        self.model = None
        self.topk_k = None
        self.logging_iteration = None
        self.criterion = None
        self.optimizer = None
        # lr_scheduler is referenced in save_checkpoint, so initialize it here
        self.lr_scheduler = None
        self.checkpoint_saving_per_epoch = None
        self.use_cuda = None
        self.save_history_checkpoints_count = None

    def setup_model(self, version, experiment_dir):
        # The second `{}` is left as a placeholder for the history checkpoint index
        self.history_file_name = "{}/model_params{}.tch".format(experiment_dir.rstrip("/"), "{}")
        self.file_name = self.history_file_name.format(0)

    def pre_execution_hook(self, version, experiment_dir, exec_mode=ExecutionModeKeys.TEST):
        self.log("Version spec: {}".format(version))
        self.current_version = version
        self.logging_iteration = 10
        self.save_history_checkpoints_count = 10
        if os.path.isfile(self.file_name):
            self.log("Loading parameters from: {}".format(self.file_name))
            self.load_history_checkpoint(self.file_name)
        else:
            self.epochs_params = 0
            self.log("No checkpoint found")

def get_current_version(self):
return self.current_version

    def get_trained_step_count(self):
        ret_val = (self.epochs_params
                   * self.dataloader.get_train_sample_count()
                   / self.dataloader.batch_size)
        self.log("steps_trained: {}".format(ret_val))
        return ret_val

    def save_checkpoint(self, epoch):
        directory = os.path.dirname(self.file_name)
        if not os.path.exists(directory):
            os.makedirs(directory)

        if self.save_history_checkpoints_count is not None:
            if self.save_history_checkpoints_count < 1:
                raise ValueError("save_history_checkpoints_count should be 1 or higher. "
                                 "Else set it to None to completely disable this feature.")
            # Rotate the history checkpoints: shift each existing checkpoint one
            # index up so that index 0 is free for the new checkpoint.
            for history_idx in range(self.save_history_checkpoints_count - 1, -1, -1):
                history_file_name = self.history_file_name.format(history_idx)
                if os.path.exists(history_file_name):
                    os.replace(history_file_name, self.history_file_name.format(history_idx + 1))
            self.log("History checkpoints: {}".format(self.save_history_checkpoints_count))
        torch.save({
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': None if self.optimizer is None else self.optimizer.state_dict(),
            'validation': self.dataloader.datasets.validation_dataset,
            'lr_scheduler': None if self.lr_scheduler is None else self.lr_scheduler.state_dict()
        }, self.file_name)
        self.log("Saved checkpoint for epoch: {} at {}".format(epoch + 1, self.file_name))

def load_history_checkpoint(self, checkpoint_file_name, load_optimizer=True, export_mode=False):
self.log("Loading: {}".format(checkpoint_file_name), log_to_file=True)
checkpoint = torch.load(checkpoint_file_name)
        self.epochs_params = checkpoint['epoch']
self.model.load_state_dict(checkpoint['state_dict'])
if export_mode:
return

if load_optimizer:
self.optimizer.load_state_dict(checkpoint['optimizer'])
if checkpoint['lr_scheduler'] is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
if checkpoint['validation'] is not None:
self.dataloader.set_validation_set(checkpoint['validation'])

    def get_ancient_checkpoint_file_name(self, epoch_from_last=None):
        if epoch_from_last is None:
            epoch_from_last = self.save_history_checkpoints_count
        elif epoch_from_last == 0:
            history_file_name = self.history_file_name.format(0)
            if os.path.exists(history_file_name):
                return history_file_name
        elif epoch_from_last > self.save_history_checkpoints_count:
            raise ValueError("`epoch_from_last` should be less than or equal to "
                             "`self.save_history_checkpoints_count`.")

        if self.save_history_checkpoints_count < 1:
            raise ValueError("save_history_checkpoints_count should be 1 or higher. "
                             "Else set it to None to completely disable this feature.")
        # Walk back from the oldest requested checkpoint to the most recent one
        # and return the first file that exists.
        for history_idx in range(epoch_from_last, 0, -1):
            history_file_name = self.history_file_name.format(history_idx)
            if os.path.exists(history_file_name):
                return history_file_name
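
A minimal sketch of a concrete experiment built on this base class. The SimpleExperiment class, the model, and the loop are hypothetical, and the train_loop hook name is an assumption about how the surrounding mlpipeline machinery invokes training; self.dataloader is assumed to be attached by the framework before training starts.

import torch.nn as nn


class SimpleExperiment(BaseTorchExperiment):
    """Hypothetical experiment: a linear classifier trained with SGD."""

    def setup_model(self, version, experiment_dir):
        super().setup_model(version, experiment_dir)
        self.model = nn.Linear(10, 2)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)

    def train_loop(self, input_fn, **kwargs):  # hook name assumed
        for epoch in range(10):
            for inputs, labels in input_fn:
                self.optimizer.zero_grad()
                loss = self.criterion(self.model(inputs), labels)
                loss.backward()
                self.optimizer.step()
            # Persist a checkpoint (and rotate history files) after each epoch
            self.save_checkpoint(epoch)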
