Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add drop_last option to method fit of PyTorchClassifier #1883

Merged
merged 7 commits into from
Nov 10, 2022
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions art/estimators/classification/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def fit( # pylint: disable=W0221
batch_size: int = 128,
nb_epochs: int = 10,
training_mode: bool = True,
drop_last: bool = False,
**kwargs,
) -> None:
"""
Expand All @@ -373,8 +374,11 @@ def fit( # pylint: disable=W0221
:param batch_size: Size of batches.
:param nb_epochs: Number of epochs to use for training.
:param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode.
:param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
the last batch will be smaller. (default: ``False``)
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
and providing it takes no effect.
and providing it takes no effect.
"""
import torch # lgtm [py/repeated-import]

Expand All @@ -392,18 +396,25 @@ def fit( # pylint: disable=W0221
# Check label shape
y_preprocessed = self.reduce_labels(y_preprocessed)

num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
num_batch = len(x_preprocessed) / float(batch_size)
if drop_last:
num_batch = int(np.floor(num_batch))
else:
num_batch = int(np.ceil(num_batch))
ind = np.arange(len(x_preprocessed))

x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)

# Start training
for _ in range(nb_epochs):
# Shuffle the examples
random.shuffle(ind)

# Train for one epoch
for m in range(num_batch):
i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]

# Zero the parameter gradients
self._optimizer.zero_grad()
Expand Down