Commit 4dbd761

refactor 3/n (Lightning-AI#2709)
* refactor into gpu accelerator
1 parent b34217e commit 4dbd761

File tree

10 files changed, +155 -65 lines

Diff for: .pyrightconfig.json

+1-1
@@ -7,7 +7,7 @@
     "pytorch_lightning/__init__.py",
     "pytorch_lightning/callbacks",
     "pytorch_lightning/core",
-    "pytorch_lightning/accelerators",
+    "pytorch_lightning/accelerator_backends",
     "pytorch_lightning/loggers",
     "pytorch_lightning/logging",
     "pytorch_lightning/metrics",

Diff for: docs/source/conf.py

+1-1
@@ -138,7 +138,7 @@
 exclude_patterns = [
     'api/pytorch_lightning.rst',
     'api/pl_examples.*',
-    'api/pytorch_lightning.accelerators.*',
+    'api/pytorch_lightning.accelerator_backends.*',
     'api/modules.rst',
     'PULL_REQUEST_TEMPLATE.md',

Diff for: pytorch_lightning/accelerator_backends/__init__.py

+3
@@ -0,0 +1,3 @@
+from pytorch_lightning.accelerator_backends.gpu_backend import GPUBackend
+from pytorch_lightning.accelerator_backends.tpu_backend import TPUBackend
+from pytorch_lightning.accelerator_backends.dp_backend import DataParallelBackend

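The three re-exports above define the new package's public surface; the old pytorch_lightning.accelerators package is deleted further down. A minimal usage sketch (illustrative only, not part of the diff) of the import path downstream code would now use:

    # New import location introduced by this commit; the old path was
    # pytorch_lightning.accelerators (removed below).
    from pytorch_lightning.accelerator_backends import (
        GPUBackend,
        TPUBackend,
        DataParallelBackend,
    )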
Diff for: pytorch_lightning/accelerator_backends/dp_backend.py

+117
@@ -0,0 +1,117 @@
+import torch
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.overrides.data_parallel import LightningDataParallel
+from torch import optim
+
+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
+
+
+class DataParallelBackend(object):
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+        self.model_autocast_original_forward = None
+
+    def setup(self, model):
+        # call setup after the ddp process has connected
+        if not self.trainer.testing:
+            self.trainer.setup('fit')
+            model.setup('fit')
+
+        # put model on correct device
+        model.cuda(self.trainer.root_gpu)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
+        self.trainer.optimizers = optimizers
+        self.trainer.lr_schedulers = lr_schedulers
+        self.trainer.optimizer_frequencies = optimizer_frequencies
+
+        # hack forward to do autocast for the user
+        self.model_autocast_original_forward = model.forward
+
+        # init half precision
+        if self.trainer.use_amp:
+            model = self.__init_half_precision(model)
+
+        # init torch data parallel
+        model = self.__init_torch_data_parallel(model)
+
+        self.trainer.model = model
+
+    def __init_torch_data_parallel(self, model):
+        # create list of device ids
+        device_ids = self.trainer.data_parallel_device_ids
+        if isinstance(device_ids, int):
+            device_ids = list(range(device_ids))
+
+        # set dp device
+        torch.cuda.set_device(self.trainer.root_gpu)
+        model = LightningDataParallel(model, device_ids=device_ids)
+        return model
+
+    def __init_half_precision(self, model):
+        native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
+
+        if native_amp_available:
+            self.__init_native_amp(model)
+        else:
+            model = self.__init_nvidia_apex(model)
+        return model
+
+    def __init_native_amp(self, model):
+        model.forward = torch.cuda.amp.autocast()(model.forward)
+
+    def __init_nvidia_apex(self, model):
+        # check for this bug (amp + dp + !01 doesn't work)
+        # https://github.com/NVIDIA/apex/issues/227
+        if self.trainer.amp_level == 'O2':
+            raise MisconfigurationException(
+                f'Amp level {self.trainer.amp_level} with DataParallel is not supported.'
+                f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
+                f' We recommend you switch to ddp if you want to use amp')
+        else:
+            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
+            self.reinit_scheduler_properties(optimizers, self.trainer.lr_schedulers)
+
+        return model
+
+    def train(self):
+        model = self.trainer.model
+        results = self.trainer.run_pretrain_routine(model)
+        return results
+
+    def teardown(self):
+
+        # replace the original fwd function
+        self.trainer.model.forward = self.model_autocast_original_forward
+
+    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
+        """
+        Reinitialize optimizer.step properties added by schedulers
+        """
+        for scheduler in schedulers:
+            scheduler = scheduler['scheduler']
+
+            for optimizer in optimizers:
+                # check that we dont mix users optimizers and schedulers
+                if scheduler.optimizer == optimizer:
+                    # Find the mro belonging to the base lr scheduler class
+                    for i, mro in enumerate(scheduler.__class__.__mro__):
+                        is_regular_scheduler = optim.lr_scheduler._LRScheduler
+                        is_lr_reduce_on_plateau = optim.lr_scheduler.ReduceLROnPlateau
+                        if is_regular_scheduler or is_lr_reduce_on_plateau:
+                            idx = i
+                            state = scheduler.state_dict()
+                        else:
+                            state = None
+
+                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
+                    if state is not None:
+                        scheduler.load_state_dict(state)

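DataParallelBackend folds the old dp_train logic into an explicit setup/train/teardown lifecycle that the trainer drives. A minimal sketch of that call sequence, mirroring the fit() hunk further down; trainer and model are stand-ins for a configured Trainer and a LightningModule, and the sketch itself is illustrative, not code added by the commit:

    # DP dispatch as wired into Trainer.fit() by this commit (sketch).
    trainer.accelerator_backend = DataParallelBackend(trainer)
    trainer.accelerator_backend.setup(model)        # move to GPU, init optimizers, apply AMP, wrap in LightningDataParallel
    results = trainer.accelerator_backend.train()   # runs trainer.run_pretrain_routine on the wrapped model
    trainer.accelerator_backend.teardown()          # restore the original, un-autocast forward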
Diff for: pytorch_lightning/accelerators/gpu_accelerator.py renamed to pytorch_lightning/accelerator_backends/gpu_backend.py

+16-2
@@ -14,13 +14,21 @@
 
 import torch
 
+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
 
-class GPUAccelerator(object):
+
+class GPUBackend(object):
 
     def __init__(self, trainer):
         self.trainer = trainer
 
     def setup(self, model):
+
         # call setup
         if not self.trainer.testing:
             self.trainer.setup('fit')
@@ -38,9 +46,15 @@ def setup(self, model):
         # TODO: remove with dropping NVIDIA AMP support
         native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         if self.trainer.use_amp and not native_amp_available:
-            self._setup_nvidia_apex(model)
+            model = self._setup_nvidia_apex(model)
+        return model
+
+    def train(self, model):
+        results = self.trainer.run_pretrain_routine(model)
+        return results
 
     def _setup_nvidia_apex(self, model):
         model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
         self.trainer.optimizers = optimizers
         self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
+        return model

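Two behavioural points in this rename: setup() now returns the (possibly Apex-initialized) model so fit() can train what comes back, and the backend gains its own train() wrapping run_pretrain_routine. The AMP availability probe it keeps is a plain feature check on the installed torch build; a standalone sketch of the same pattern (variable names here are illustrative, not part of the diff):

    import torch

    # Native AMP exists when torch.cuda.amp.autocast is present on the installed
    # torch build -- the same hasattr probe used in GPUBackend.setup() above.
    NATIVE_AMP_AVAILABLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")

    try:
        from apex import amp  # optional NVIDIA Apex fallback, mirroring the try/except above
        APEX_AVAILABLE = True
    except ImportError:
        APEX_AVAILABLE = False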
Diff for: pytorch_lightning/accelerators/tpu_accelerator.py renamed to pytorch_lightning/accelerator_backends/tpu_backend.py

+1-1
@@ -28,7 +28,7 @@
     XLA_AVAILABLE = True
 
 
-class TPUAccelerator(object):
+class TPUBackend(object):
 
     def __init__(self, trainer):
         self.trainer = trainer

Diff for: pytorch_lightning/accelerators/__init__.py

-2
This file was deleted.

Diff for: pytorch_lightning/trainer/distrib_parts.py

-46
@@ -179,52 +179,6 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
             return model.transfer_batch_to_device(batch, device)
         return move_data_to_device(batch, device)
 
-    def dp_train(self, model):
-        # call setup after the ddp process has connected
-        if not self.testing:
-            self.setup('fit')
-            model.setup('fit')
-
-        model.cuda(self.root_gpu)
-
-        # CHOOSE OPTIMIZER
-        # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
-
-        # hack forward to do autocast for the user
-        model_autocast_original_forward = model.forward
-        if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
-            # wrap the user's forward in autocast and give it back at the end
-            model.forward = torch.cuda.amp.autocast()(model.forward)
-
-        # TODO: remove with dropping NVIDIA AMP support
-        # check for this bug (amp + dp + !01 doesn't work)
-        # https://github.com/NVIDIA/apex/issues/227
-        if self.use_dp and self.use_amp and not NATIVE_AMP_AVALAIBLE and not self.use_tpu:
-            if self.amp_level == 'O2':
-                raise MisconfigurationException(
-                    f'Amp level {self.amp_level} with DataParallel is not supported.'
-                    f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
-                    f' We recommend you switch to ddp if you want to use amp')
-            else:
-                model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
-                self.reinit_scheduler_properties(optimizers, self.lr_schedulers)
-
-        # create list of device ids
-        device_ids = self.data_parallel_device_ids
-        if isinstance(device_ids, int):
-            device_ids = list(range(device_ids))
-
-        # set dp device
-        torch.cuda.set_device(self.root_gpu)
-
-        model = LightningDataParallel(model, device_ids=device_ids)
-
-        result = self.run_pretrain_routine(model)
-        model.forward = model_autocast_original_forward
-
-        return result
-
     def horovod_train(self, model):
         # call setup after the ddp process has connected
         if not self.testing:

Diff for: pytorch_lightning/trainer/trainer.py

+13-12
@@ -51,7 +51,7 @@
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
-from pytorch_lightning.accelerators import GPUAccelerator, TPUAccelerator
+from pytorch_lightning.accelerator_backends import GPUBackend, TPUBackend, DataParallelBackend
 
 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -661,7 +661,7 @@ def __init__(
         # tracks internal state for debugging
         self.dev_debugger = InternalDebugger(self)
         self.config_validator = ConfigValidator(self)
-        self.accelerator = None
+        self.accelerator_backend = None
 
         # Callback system
         self.on_init_end()
@@ -1064,24 +1064,25 @@ def fit(
             self.set_random_port()
             results = self.spawn_ddp_children(model)
 
-        # 1 gpu or dp option triggers training using DP module
-        # easier to avoid NCCL issues
         elif self.use_dp:
-            results = self.dp_train(model)
+            self.accelerator_backend = DataParallelBackend(self)
+            self.accelerator_backend.setup(model)
+            results = self.accelerator_backend.train()
+            self.accelerator_backend.teardown()
 
         elif self.use_horovod:
            results = self.horovod_train(model)
 
         elif self.single_gpu:
-            self.accelerator = GPUAccelerator(self)
-            self.accelerator.setup(model)
-            results = self.run_pretrain_routine(model)
+            self.accelerator_backend = GPUBackend(self)
+            model = self.accelerator_backend.setup(model)
+            results = self.accelerator_backend.train(model)
 
         elif self.use_tpu:
-            self.accelerator = TPUAccelerator(self)
-            self.accelerator.setup()
-            self.accelerator.train(model)
-            self.accelerator.teardown()
+            self.accelerator_backend = TPUBackend(self)
+            self.accelerator_backend.setup()
+            self.accelerator_backend.train(model)
+            self.accelerator_backend.teardown()
 
         # ON CPU
         else:

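With this hunk the DP, single-GPU and TPU branches of fit() all go through self.accelerator_backend, though their call shapes are not yet uniform. For orientation only (nothing below is added by the commit), the interfaces visible in the hunks above:

    # Backend call shapes after this commit (summary sketch, not code from the diff):
    #   GPUBackend:          model = setup(model);  results = train(model)   # no teardown call
    #   DataParallelBackend: setup(model);          results = train();       teardown()
    #   TPUBackend:          setup();               train(model);            teardown()   # result not captured in fit()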
Diff for: tests/models/test_test_loop.py

+3
@@ -38,6 +38,9 @@ def test_single_gpu_test(tmpdir):
 def test_dp_test(tmpdir):
     tutils.set_random_master_port()
 
+    import os
+    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
+
     model = EvalModelTemplate()
     trainer = pl.Trainer(
         default_root_dir=os.getcwd(),

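The two lines added to the test pin it to the first two GPUs. For context, a sketch of the general pattern (illustrative only; the variable must be set before the process makes its first CUDA call, otherwise it has no effect):

    import os

    # Restrict which GPUs CUDA may enumerate; must happen before CUDA is initialized.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

    import torch
    print(torch.cuda.device_count())  # at most 2 devices are now visible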