Optimize PyTorch Classifiers and Object Detectors #2180

Merged · 13 commits · Jun 20, 2023
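Both diffs below make the same core change: the hand-rolled mini-batching (compute num_batch, shuffle an index array with random.shuffle, slice device-resident tensors by hand) is replaced with torch.utils.data.TensorDataset and DataLoader, which provide shuffling, batching, and drop_last natively. This also stops moving the entire preprocessed dataset to the device up front; only one batch at a time is transferred. A minimal standalone sketch of the pattern; the model, data, and hyperparameters are placeholders, not ART code:

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# Placeholder data and model for illustration only
x = np.random.rand(100, 10).astype(np.float32)
y = np.random.randint(0, 2, size=100)
model = torch.nn.Linear(10, 2)

# Before: num_batch = ceil(len(x) / batch_size), random.shuffle an index
# array, slice manually. After: DataLoader handles all of it.
dataset = TensorDataset(torch.from_numpy(x), torch.from_numpy(y))
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, drop_last=False)

for x_batch, y_batch in dataloader:
    logits = model(x_batch)  # shapes: (32, 10) -> (32, 2); last batch may be smaller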
art/estimators/certification/randomized_smoothing/pytorch.py
38 changes: 15 additions & 23 deletions
@@ -23,10 +23,9 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import logging
-from typing import List, Optional, Tuple, Union, Any, TYPE_CHECKING
+from typing import List, Optional, Tuple, Union, TYPE_CHECKING
 
 import warnings
-import random
 from tqdm import tqdm
 import numpy as np

@@ -137,7 +136,7 @@ def fit(  # pylint: disable=W0221
         nb_epochs: int = 10,
         training_mode: bool = True,
         drop_last: bool = False,
-        scheduler: Optional[Any] = None,
+        scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
         **kwargs,
     ) -> None:
         """
@@ -157,6 +156,7 @@ def fit(  # pylint: disable=W0221
             and providing it takes no effect.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model mode
         self._model.train(mode=training_mode)
@@ -172,36 +172,28 @@ def fit(  # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = len(x_preprocessed) / float(batch_size)
-        if drop_last:
-            num_batch = int(np.floor(num_batch))
-        else:
-            num_batch = int(np.ceil(num_batch))
-        ind = np.arange(len(x_preprocessed))
-        std = torch.tensor(self.scale).to(self._device)
-
-        x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
-        y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
+        # Create dataloader
+        x_tensor = torch.from_numpy(x_preprocessed)
+        y_tensor = torch.from_numpy(y_preprocessed)
+        dataset = TensorDataset(x_tensor, y_tensor)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
 
         # Start training
         for _ in tqdm(range(nb_epochs)):
-            # Shuffle the examples
-            random.shuffle(ind)
-
             # Train for one epoch
-            for m in range(num_batch):
-                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
-                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+            for x_batch, y_batch in dataloader:
+                # Move inputs to device
+                x_batch = x_batch.to(self._device)
+                y_batch = y_batch.to(self._device)
 
                 # Add random noise for randomized smoothing
-                i_batch = i_batch + torch.randn_like(i_batch, device=self._device) * std
+                x_batch += torch.randn_like(x_batch) * self.scale
 
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
 
                 # Perform prediction
                 try:
-                    model_outputs = self._model(i_batch)
+                    model_outputs = self._model(x_batch)
                 except ValueError as err:
                     if "Expected more than 1 value per channel when training" in str(err):
                         logger.exception(
@@ -211,7 +203,7 @@ def fit(  # pylint: disable=W0221
                     raise err
 
                 # Form the loss function
-                loss = self._loss(model_outputs[-1], o_batch)
+                loss = self._loss(model_outputs[-1], y_batch)
 
                 # Do training
                 if self._use_amp:  # pragma: no cover
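Beyond the DataLoader swap, the noise line in the randomized-smoothing fit() changed subtly: the old code pre-built a std tensor on the device and passed device= to randn_like; the new code relies on randn_like inheriting the batch's device and dtype, and scales by self.scale directly. A standalone sketch of the smoothing-noise step under the new scheme; scale and the toy data stand in for self.scale and the preprocessed inputs:

import torch
from torch.utils.data import TensorDataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scale = 0.25  # stand-in for self.scale, the smoothing noise std deviation

x = torch.rand(64, 3, 32, 32)
y = torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

for x_batch, y_batch in loader:
    x_batch = x_batch.to(device)
    # randn_like inherits x_batch's device and dtype, so no explicit
    # std tensor or device= argument is needed
    x_batch += torch.randn_like(x_batch) * scale

The in-place += is safe here because default collation stacks each batch into a fresh tensor, so the added noise never leaks back into the underlying dataset.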
art/estimators/classification/pytorch.py
55 changes: 24 additions & 31 deletions
@@ -24,7 +24,6 @@
 import copy
 import logging
 import os
-import random
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING

@@ -309,26 +308,27 @@ def predict(  # pylint: disable=W0221
         :return: Array of predictions of shape `(nb_inputs, nb_classes)`.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model mode
         self._model.train(mode=training_mode)
 
         # Apply preprocessing
         x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)
 
-        results_list = []
+        # Create dataloader
+        x_tensor = torch.from_numpy(x_preprocessed)
+        dataset = TensorDataset(x_tensor)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
 
         # Run prediction with batch processing
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
-        for m in range(num_batch):
-            # Batch indexes
-            begin, end = (
-                m * batch_size,
-                min((m + 1) * batch_size, x_preprocessed.shape[0]),
-            )
+        results_list = []
+        for (x_batch,) in dataloader:
+            # Move inputs to device
+            x_batch = x_batch.to(self._device)
 
             # Run prediction
             with torch.no_grad():
-                model_outputs = self._model(torch.from_numpy(x_preprocessed[begin:end]).to(self._device))
+                model_outputs = self._model(x_batch)
             output = model_outputs[-1]
             output = output.detach().cpu().numpy().astype(np.float32)
             if len(output.shape) == 1:
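The predict() hunk above applies the same DataLoader treatment to inference, with shuffle=False so outputs stay aligned with the inputs and only one batch at a time on the device. A self-contained sketch of the pattern with a placeholder model:

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

model = torch.nn.Linear(10, 3)  # placeholder model
x = np.random.rand(100, 10).astype(np.float32)

# shuffle=False preserves batch order, so row i of the result matches x[i]
loader = DataLoader(TensorDataset(torch.from_numpy(x)), batch_size=32, shuffle=False)

results = []
with torch.no_grad():  # inference only; skip autograd bookkeeping
    for (x_batch,) in loader:  # TensorDataset yields one-element tuples
        results.append(model(x_batch).cpu().numpy())

predictions = np.vstack(results)  # shape (100, 3)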
@@ -373,7 +373,7 @@ def fit(  # pylint: disable=W0221
         nb_epochs: int = 10,
         training_mode: bool = True,
         drop_last: bool = False,
-        scheduler: Optional[Any] = None,
+        scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
         **kwargs,
     ) -> None:
         """
@@ -393,6 +393,7 @@ def fit(  # pylint: disable=W0221
             and providing it takes no effect.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model mode
         self._model.train(mode=training_mode)
@@ -408,32 +409,25 @@ def fit(  # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = len(x_preprocessed) / float(batch_size)
-        if drop_last:
-            num_batch = int(np.floor(num_batch))
-        else:
-            num_batch = int(np.ceil(num_batch))
-        ind = np.arange(len(x_preprocessed))
-
-        x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
-        y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
+        # Create dataloader
+        x_tensor = torch.from_numpy(x_preprocessed)
+        y_tensor = torch.from_numpy(y_preprocessed)
+        dataset = TensorDataset(x_tensor, y_tensor)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
 
         # Start training
         for _ in range(nb_epochs):
-            # Shuffle the examples
-            random.shuffle(ind)
-
             # Train for one epoch
-            for m in range(num_batch):
-                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
-                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+            for x_batch, y_batch in dataloader:
+                # Move inputs to device
+                x_batch = x_batch.to(self._device)
+                y_batch = y_batch.to(self._device)
 
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
 
                 # Perform prediction
                 try:
-                    model_outputs = self._model(i_batch)
+                    model_outputs = self._model(x_batch)
                 except ValueError as err:
                     if "Expected more than 1 value per channel when training" in str(err):
                         logger.exception(
@@ -443,15 +437,14 @@ def fit(  # pylint: disable=W0221
                     raise err
 
                 # Form the loss function
-                loss = self._loss(model_outputs[-1], o_batch)
+                loss = self._loss(model_outputs[-1], y_batch)
 
                 # Do training
                 if self._use_amp:  # pragma: no cover
                     from apex import amp  # pylint: disable=E0611
 
                     with amp.scale_loss(loss, self._optimizer) as scaled_loss:
                         scaled_loss.backward()
-
                 else:
                     loss.backward()