Merge pull request #2180 from f4str/torch-dataloaders
Optimize PyTorch Classifiers and Object Detectors
beat-buesser committed Jun 20, 2023
2 parents 011ab1e + f3fcf19 commit 4bfed67
Showing 9 changed files with 541 additions and 559 deletions.
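
Both PyTorch estimator diffs below make the same structural change: hand-rolled NumPy batching (np.arange index slicing plus random.shuffle, with floor/ceil bookkeeping for drop_last) is replaced by torch.utils.data.TensorDataset and DataLoader, which own shuffling, batch slicing, and drop_last. For orientation, here is a minimal sketch of the adopted pattern; the toy data and the names model, loss_fn, and optimizer are illustrative stand-ins, not ART code:

# Minimal sketch of the DataLoader-based training pattern this PR adopts.
# The data and model here are toys; only the batching pattern matches the PR.
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

x = np.random.rand(100, 10).astype(np.float32)
y = np.random.randint(0, 3, size=100)

model = nn.Linear(10, 3)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# TensorDataset pairs the arrays; DataLoader handles shuffling,
# batching, and drop_last without manual index arithmetic.
dataset = TensorDataset(torch.from_numpy(x), torch.from_numpy(y))
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=False)

for x_batch, y_batch in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(model(x_batch), y_batch)
    loss.backward()
    optimizer.step()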
art/estimators/certification/randomized_smoothing/pytorch.py (38 changes: 15 additions & 23 deletions)
@@ -23,10 +23,9 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import logging
-from typing import List, Optional, Tuple, Union, Any, TYPE_CHECKING
+from typing import List, Optional, Tuple, Union, TYPE_CHECKING
 
 import warnings
-import random
 from tqdm import tqdm
 import numpy as np

@@ -137,7 +136,7 @@ def fit(  # pylint: disable=W0221
         nb_epochs: int = 10,
         training_mode: bool = True,
         drop_last: bool = False,
-        scheduler: Optional[Any] = None,
+        scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
         **kwargs,
     ) -> None:
         """
@@ -157,6 +156,7 @@ def fit(  # pylint: disable=W0221
                        and providing it takes no effect.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model mode
         self._model.train(mode=training_mode)
@@ -172,36 +172,28 @@
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = len(x_preprocessed) / float(batch_size)
-        if drop_last:
-            num_batch = int(np.floor(num_batch))
-        else:
-            num_batch = int(np.ceil(num_batch))
-        ind = np.arange(len(x_preprocessed))
-        std = torch.tensor(self.scale).to(self._device)
-
-        x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
-        y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
+        # Create dataloader
+        x_tensor = torch.from_numpy(x_preprocessed)
+        y_tensor = torch.from_numpy(y_preprocessed)
+        dataset = TensorDataset(x_tensor, y_tensor)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
 
         # Start training
         for _ in tqdm(range(nb_epochs)):
-            # Shuffle the examples
-            random.shuffle(ind)
-
             # Train for one epoch
-            for m in range(num_batch):
-                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
-                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+            for x_batch, y_batch in dataloader:
+                # Move inputs to device
+                x_batch = x_batch.to(self._device)
+                y_batch = y_batch.to(self._device)
 
                 # Add random noise for randomized smoothing
-                i_batch = i_batch + torch.randn_like(i_batch, device=self._device) * std
+                x_batch += torch.randn_like(x_batch) * self.scale
 
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
 
                 # Perform prediction
                 try:
-                    model_outputs = self._model(i_batch)
+                    model_outputs = self._model(x_batch)
                 except ValueError as err:
                     if "Expected more than 1 value per channel when training" in str(err):
                         logger.exception(
@@ -211,7 +203,7 @@ def fit(  # pylint: disable=W0221
                     raise err
 
                 # Form the loss function
-                loss = self._loss(model_outputs[-1], o_batch)
+                loss = self._loss(model_outputs[-1], y_batch)
 
                 # Do training
                 if self._use_amp:  # pragma: no cover
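Two effects of the fit() hunks above are easy to miss in the diff: inputs now stay on the CPU inside the DataLoader and are moved to self._device one batch at a time (previously all of x_preprocessed and y_preprocessed were pushed to the device up front), and the smoothing noise is sampled per batch with torch.randn_like(x_batch) * self.scale, so the precomputed std tensor is gone. A condensed sketch of the resulting loop, assuming a generic model, optimizer, and loss_fn; ART's preprocessing, AMP, and scheduler handling are omitted:

# Condensed sketch of the new randomized-smoothing training loop.
# model, optimizer, loss_fn, and scale are assumed inputs, not ART internals.
import torch
from torch.utils.data import TensorDataset, DataLoader

def fit_smoothed(model, optimizer, loss_fn, x_tensor, y_tensor,
                 scale=0.25, batch_size=128, nb_epochs=10,
                 device="cpu", drop_last=False):
    dataset = TensorDataset(x_tensor, y_tensor)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=True, drop_last=drop_last)
    for _ in range(nb_epochs):
        for x_batch, y_batch in dataloader:
            # Move one batch at a time to the device.
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            # Sample Gaussian noise per batch, directly on-device.
            x_batch = x_batch + torch.randn_like(x_batch) * scale
            optimizer.zero_grad()
            loss = loss_fn(model(x_batch), y_batch)
            loss.backward()
            optimizer.step()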
art/estimators/classification/pytorch.py (55 changes: 24 additions & 31 deletions)
@@ -24,7 +24,6 @@
 import copy
 import logging
 import os
-import random
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING

@@ -309,26 +308,27 @@ def predict(  # pylint: disable=W0221
         :return: Array of predictions of shape `(nb_inputs, nb_classes)`.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model mode
         self._model.train(mode=training_mode)
 
         # Apply preprocessing
         x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)
 
-        results_list = []
+        # Create dataloader
+        x_tensor = torch.from_numpy(x_preprocessed)
+        dataset = TensorDataset(x_tensor)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
 
         # Run prediction with batch processing
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
-        for m in range(num_batch):
-            # Batch indexes
-            begin, end = (
-                m * batch_size,
-                min((m + 1) * batch_size, x_preprocessed.shape[0]),
-            )
+        results_list = []
+        for (x_batch,) in dataloader:
+            # Move inputs to device
+            x_batch = x_batch.to(self._device)
 
             # Run prediction
             with torch.no_grad():
-                model_outputs = self._model(torch.from_numpy(x_preprocessed[begin:end]).to(self._device))
+                model_outputs = self._model(x_batch)
             output = model_outputs[-1]
             output = output.detach().cpu().numpy().astype(np.float32)
             if len(output.shape) == 1:
@@ -373,7 +373,7 @@ def fit(  # pylint: disable=W0221
         nb_epochs: int = 10,
         training_mode: bool = True,
         drop_last: bool = False,
-        scheduler: Optional[Any] = None,
+        scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
         **kwargs,
     ) -> None:
         """
@@ -393,6 +393,7 @@ def fit(  # pylint: disable=W0221
                        and providing it takes no effect.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model mode
         self._model.train(mode=training_mode)
@@ -408,32 +409,25 @@
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = len(x_preprocessed) / float(batch_size)
-        if drop_last:
-            num_batch = int(np.floor(num_batch))
-        else:
-            num_batch = int(np.ceil(num_batch))
-        ind = np.arange(len(x_preprocessed))
-
-        x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
-        y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
+        # Create dataloader
+        x_tensor = torch.from_numpy(x_preprocessed)
+        y_tensor = torch.from_numpy(y_preprocessed)
+        dataset = TensorDataset(x_tensor, y_tensor)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
 
         # Start training
         for _ in range(nb_epochs):
-            # Shuffle the examples
-            random.shuffle(ind)
-
             # Train for one epoch
-            for m in range(num_batch):
-                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
-                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+            for x_batch, y_batch in dataloader:
+                # Move inputs to device
+                x_batch = x_batch.to(self._device)
+                y_batch = y_batch.to(self._device)
 
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
 
                 # Perform prediction
                 try:
-                    model_outputs = self._model(i_batch)
+                    model_outputs = self._model(x_batch)
                 except ValueError as err:
                     if "Expected more than 1 value per channel when training" in str(err):
                         logger.exception(
@@ -443,15 +437,14 @@ def fit(  # pylint: disable=W0221
                     raise err
 
                 # Form the loss function
-                loss = self._loss(model_outputs[-1], o_batch)
+                loss = self._loss(model_outputs[-1], y_batch)
 
                 # Do training
                 if self._use_amp:  # pragma: no cover
                     from apex import amp  # pylint: disable=E0611
 
                     with amp.scale_loss(loss, self._optimizer) as scaled_loss:
                         scaled_loss.backward()
-
                 else:
                     loss.backward()

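The predict() hunk follows the same pattern with shuffle=False, which keeps batch order aligned with the input so the concatenated outputs match x row for row. A self-contained sketch, under the assumption that the model returns a single 2-D logits tensor; ART's real predict() additionally applies preprocessing and unpacks model_outputs[-1]:

# Sketch of the reworked batched prediction: a non-shuffling DataLoader
# plus torch.no_grad(), replacing the manual begin/end index slicing.
# predict_batched is a hypothetical helper, not an ART API.
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

def predict_batched(model, x: np.ndarray, batch_size=128, device="cpu"):
    dataset = TensorDataset(torch.from_numpy(x))
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
    results = []
    for (x_batch,) in dataloader:  # a single-tensor TensorDataset yields 1-tuples
        x_batch = x_batch.to(device)
        with torch.no_grad():
            output = model(x_batch)
        results.append(output.detach().cpu().numpy().astype(np.float32))
    return np.vstack(results)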
