
Commit 0fe933e

Borda, ethanwharris, and rohitgr7 authored
fixing TPU tests (Lightning-AI#2632)
* init
* rename
* tpu_core_idx
* idx 8
* idxs
* @pl_multi_process_test
* assert
* assert
* deamon
* no close
* imort
* msg
* use_single_gpu
* dataset
* idx
* fix idx
* dataset
* format
* add pickable
* typo
* apex
* typo
* wip
* wip
* wip
* wip
* wip
* wip
* wip
* wip
* docs
* typo
* tests
* tests
* tests
* tests
* tests
* tests
* tests
* tests
* tests
* tests
* tests
* tests
* tests
* tests
* tests
* tests
* tests
* docs
* docs
* Apply suggestions from code review
  Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
  Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
* Apply suggestions from code review
  Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
* docs
* Apply suggestions from code review
  Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
1 parent 84c507c commit 0fe933e

23 files changed: +339 -192 lines changed

.circleci/config.yml

+5 -3

@@ -62,10 +62,11 @@ references:
 # happened to the job in Kubernetes. If we try MAX_CHECKS times and
 # still the job hasn't finished, give up and return the starting
 # non-zero status code.
-while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else echo "Job not finished yet"; fi; sleep 30; done && \
+printf "Waiting for job to finish: " && \
+while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "."; fi; sleep $CHECK_SPEEP; done && \
 echo "Done waiting. Job status code: $status_code" && \
 # Allow time for logs to flush.
-sleep 30 && \
+sleep 10 && \
 echo "JOB_NAME: $job_name" && \
 gcloud logging read "resource.type=k8s_container resource.labels.project_id=$GOOGLE_PROJECT_ID resource.labels.location=$GOOGLE_COMPUTE_ZONE resource.labels.cluster_name=$GKE_CLUSTER resource.labels.namespace_name=default resource.labels.pod_name:$job_name" --limit 10000000 --order asc --format 'value(textPayload)' --project=$GOOGLE_PROJECT_ID > /tmp/full_output.txt && \
 if grep -q '<?xml version="1.0" ?>' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '/<?xml version="1.0" ?>/'; else mv /tmp/full_output.txt xx00; fi && \
@@ -101,7 +102,8 @@ jobs:
 docker:
 - image: circleci/python:3.7
 environment:
-- MAX_CHECKS: 60
+- MAX_CHECKS: 240
+- CHECK_SPEEP: 5
 steps:
 - checkout
 - go/install
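Note for readers: the dense one-line `while` loop above is a bounded polling pattern — check the Kubernetes job status up to MAX_CHECKS times, sleeping CHECK_SPEEP seconds between checks, and map success/failure to an exit code. A rough Python sketch of the same logic (illustration only, not part of the repo; `check_job_status` stands in for the `kubectl ... | grep` calls):

    import time

    def wait_for_job(check_job_status, max_checks=240, check_sleep=5):
        # check_job_status() should return 'succeeded', 'failed' or 'running',
        # mirroring what the kubectl/jsonpath greps in the CI script detect
        status_code = 2  # non-zero default: job never finished within the budget
        print("Waiting for job to finish: ", end="")
        for _ in range(max_checks):
            status = check_job_status()
            if status == "failed":
                status_code = 1
                break
            if status == "succeeded":
                status_code = 0
                break
            print(".", end="", flush=True)
            time.sleep(check_sleep)
        print(f"\nDone waiting. Job status code: {status_code}")
        return status_code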

.github/workflows/ci-testing.yml

+4

@@ -30,6 +30,10 @@ jobs:
 # TODO: temporary fix till hanging jobs on macOS for py38 is resolved
 - python-version: 3.8
   os: macOS-10.15
+# TODO: temporary fix till pyYaml can be installed, see: https://github.com/actions/setup-python/issues/114
+- python-version: 3.7
+  os: ubuntu-18.04
+  requires: 'minimal'
 
 # Timeout: https://stackoverflow.com/a/59076067/4521646
 timeout-minutes: 25

.github/workflows/tpu-testing.yml

+7 -5

@@ -14,6 +14,8 @@ env:
 GKE_CLUSTER: lightning-cluster
 GKE_ZONE: us-central1-a
 IMAGE: gcr.io/${{ secrets.GKE_PROJECT }}/tpu-testing-image
+MAX_CHECKS: 240
+CHECK_SPEEP: 5
 
 jobs:
 setup-build-publish-deploy:
@@ -82,17 +84,17 @@ jobs:
 job_name=${job_name% created} && \
 echo "Waiting on kubernetes job: $job_name in cluster: $GKE_CLUSTER" && \
 i=0 && \
-# 30 checks spaced 30s apart = 900s total.
-max_checks=30 && \
+# 60 checks spaced 30s apart = 900s total.
 status_code=2 && \
 # Check on the job periodically. Set the status code depending on what
-# happened to the job in Kubernetes. If we try max_checks times and
+# happened to the job in Kubernetes. If we try MAX_CHECKS times and
 # still the job hasn't finished, give up and return the starting
 # non-zero status code.
-while [ $i -lt $max_checks ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else echo "Job not finished yet"; fi; sleep 30; done && \
+printf "Waiting for job to finish: " && \
+while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "." ; fi; sleep $CHECK_SPEEP; done && \
 echo "Done waiting. Job status code: $status_code" && \
 # Allow time for logs to flush.
-sleep 60 && \
+sleep 10 && \
 echo "JOB_NAME: $job_name" && \
 echo "GKE_CLUSTER: $GKE_CLUSTER" && \
 echo "GKE_ZONE: $GKE_ZONE" && \

CHANGELOG.md

+2

@@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `weights_save_path` getting ignored when `logger=False` is passed to Trainer ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))
 
+- Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))
+
 ## [0.8.5] - 2020-07-09
 
 ### Added

docs/source/new-project.rst

-2

@@ -6,8 +6,6 @@
 import torch
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
-from torchvision.datasets import MNIST
-from torchvision import transforms
 
 
 Quick Start

pytorch_lightning/__init__.py

+1 -1

@@ -54,11 +54,11 @@
 # We are not importing the rest of the lightning during the build process, as it may not be compiled yet
 else:
     from pytorch_lightning.core import LightningDataModule, LightningModule, data_loader
+    from pytorch_lightning.core.step_result import TrainResult, EvalResult
     from pytorch_lightning.callbacks import Callback
     from pytorch_lightning.trainer import Trainer
     from pytorch_lightning.utilities.seed import seed_everything
     from pytorch_lightning import metrics
-    from pytorch_lightning.core.step_result import TrainResult, EvalResult
 
 __all__ = [
     'Trainer',

pytorch_lightning/accelerator_backends/ddp_spawn_backend.py

+10 -9

@@ -30,26 +30,27 @@ class DDPSpawnBackend(object):
 
     def __init__(self, trainer):
         self.trainer = trainer
-        self.q = None
+        self.mp_queue = None
 
     def setup(self):
         self.trainer.set_random_port()
 
         # pass in a state q
         smp = mp.get_context('spawn')
-        self.q = smp.SimpleQueue()
+        self.mp_queue = smp.SimpleQueue()
 
     def train(self, model, nprocs):
-        mp.spawn(self.ddp_train, nprocs=nprocs, args=(self.q, model,))
+        mp.spawn(self.ddp_train, nprocs=nprocs, args=(self.mp_queue, model,))
 
     def teardown(self, model):
         # restore main state with best weights
-        best_path = self.q.get()
-        results = self.q.get()
-        last_path = self.q.get()
+        best_path = self.mp_queue.get()
+        results = self.mp_queue.get()
+        last_path = self.mp_queue.get()
 
         # transfer back the best path to the trainer
         self.trainer.checkpoint_callback.best_model_path = best_path
+        # todo, pass also bets score
 
         # load last weights
         if last_path is not None and not self.trainer.testing:
@@ -59,13 +60,13 @@ def teardown(self, model):
         self.trainer.model = model
         return results
 
-    def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
+    def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
         """
         Entry point for ddp
 
         Args:
             process_idx:
-            q:
+            mp_queue: multiprocessing queue
             model:
             is_master:
             proc_offset:
@@ -166,7 +167,7 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
         model = self.trainer.get_model()
 
         # persist info in ddp_spawn
-        self.trainer.transfer_ddp_spawn_state_on_fit_end(model, q, results)
+        self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
 
         # clean up memory
         torch.cuda.empty_cache()
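For orientation, the `q` → `mp_queue` rename above is cosmetic; the underlying pattern is that each spawned worker pushes its state into a shared `SimpleQueue`, which the parent process drains in `teardown()`. A minimal stand-alone sketch of that pattern (illustrative names, not Lightning code):

    import torch.multiprocessing as mp

    def _worker(process_idx, mp_queue):
        # each spawned process hands its result back to the parent via the queue
        mp_queue.put((process_idx, process_idx ** 2))

    def run_demo(nprocs=2):
        # same 'spawn' context as DDPSpawnBackend.setup()
        smp = mp.get_context('spawn')
        mp_queue = smp.SimpleQueue()
        mp.spawn(_worker, nprocs=nprocs, args=(mp_queue,))

        # all workers have exited once mp.spawn returns, so the queue can be drained safely
        while not mp_queue.empty():
            print(mp_queue.get())

    if __name__ == '__main__':
        run_demo()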

pytorch_lightning/accelerator_backends/gpu_backend.py

+3 -2

@@ -14,6 +14,7 @@
 
 import torch
 
+from pytorch_lightning.core import LightningModule
 try:
     from apex import amp
 except ImportError:
@@ -45,15 +46,15 @@ def setup(self, model):
 
         # TODO: remove with dropping NVIDIA AMP support
         native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
-        if self.trainer.use_amp and not native_amp_available:
+        if APEX_AVAILABLE and self.trainer.use_amp and not native_amp_available:
            model = self._setup_nvidia_apex(model)
         return model
 
     def train(self, model):
         results = self.trainer.run_pretrain_routine(model)
         return results
 
-    def _setup_nvidia_apex(self, model):
+    def _setup_nvidia_apex(self, model: LightningModule):
         model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
         self.trainer.optimizers = optimizers
         self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
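The new `APEX_AVAILABLE` guard above assumes the usual optional-dependency pattern, where the flag is set by the `try/except ImportError` block at the top of the file. A minimal sketch of that pattern (the flag name matches the diff; everything else is illustrative, not the exact file contents):

    # guard an optional dependency so the rest of the module can check a flag
    try:
        from apex import amp
        APEX_AVAILABLE = True
    except ImportError:
        amp = None
        APEX_AVAILABLE = False

    def maybe_wrap_with_apex(model, optimizers, amp_level='O2'):
        # only call into apex when it is actually importable
        if APEX_AVAILABLE:
            model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
        return model, optimizers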

pytorch_lightning/accelerator_backends/tpu_backend.py

+68 -37

@@ -13,11 +13,15 @@
 # limitations under the License.
 
 import os
+
+import torch
+import torch.multiprocessing as mp
+
+from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning import _logger as log
 
-
 try:
     import torch_xla
     import torch_xla.core.xla_model as xm
@@ -33,31 +37,52 @@ class TPUBackend(object):
     def __init__(self, trainer):
         self.trainer = trainer
         self.start_method = None
+        self.mp_queue = None
 
     def setup(self):
         rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')
 
         if not XLA_AVAILABLE:
-            raise MisconfigurationException('No TPU devices found.')
+            raise MisconfigurationException('PyTorch XLA not installed.')
+
+        # see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2
+        self.start_method = 'fork'
+
+        # pass in a state q
+        smp = mp.get_context(self.start_method)
+        self.mp_queue = smp.SimpleQueue()
+
+    def teardown(self, model):
+        # restore main state with best weights
+        best_path = self.mp_queue.get()
+        results = self.mp_queue.get()
+        last_path = self.mp_queue.get()
 
-        # COLAB_GPU is an env var available by default in Colab environments.
-        self.start_method = 'fork' if self.trainer.on_colab_kaggle else 'spawn'
+        # transfer back the best path to the trainer
+        self.trainer.checkpoint_callback.best_model_path = best_path
+        # todo, pass also bets score
 
-    def teardown(self):
+        # load last weights
+        if last_path and not self.trainer.testing:
+            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
+            model.load_state_dict(ckpt)
+
+        self.trainer.model = model
 
         # when training completes, load the weights back in main process
         self.__load_weights_on_main_process()
+        return results
 
-    def train(self, model):
+    def train(self, model: LightningModule):
         self.trainer.model = model
 
         # train
         if self.trainer.tpu_id is not None:
-            self.tpu_train_in_process(self.trainer.tpu_id, model)
+            self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue)
         else:
             xmp.spawn(
                 self.tpu_train_in_process,
-                args=(model,),
+                args=(model, self.trainer, self.mp_queue),
                 nprocs=self.trainer.tpu_cores,
                 start_method=self.start_method
            )
@@ -71,63 +96,69 @@ def __load_weights_on_main_process(self):
 
         self.trainer.model = model
 
-    def tpu_train_in_process(self, tpu_core_idx, model):
+    def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None):
         """
         Here we are inside each individual process
         """
-        if not self.trainer.testing:
-            self.trainer.setup('fit')
+        if not trainer:
+            trainer = self.trainer
+        if not trainer.testing:
+            trainer.setup('fit')
         model.setup('fit')
 
         # setup TPU training
-        self.__setup_tpu_training(model)
+        self.__setup_tpu_training(model, trainer)
 
         # Run the pretrain routine
-        self.trainer.run_pretrain_routine(model)
+        results = trainer.run_pretrain_routine(model)
 
         # save weights at the end of training
-        self.__save_end_of_training_weights(model)
+        self.__save_end_of_training_weights(model, trainer)
 
-    def __save_end_of_training_weights(self, model):
+        # persist info in spawn
+        trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
 
+    def __save_end_of_training_weights(self, model: LightningModule, trainer):
         # when training ends on these platforms dump weights to get out of the main process
-        if self.trainer.on_colab_kaggle:
+        if trainer.on_colab_kaggle:
             rank_zero_warn('cleaning up... please do not interrupt')
-            self.trainer.save_spawn_weights(model)
+            trainer.save_spawn_weights(model)
 
-    def __setup_tpu_training(self, model):
+    def __setup_tpu_training(self, model: LightningModule, trainer):
         # use the default device from the process
-        tpu_device = xm.xla_device()
+        # tpu_device = xm.xla_device()
 
         # if given an ordinal device, use this as the device
-        if self.trainer.tpu_id is not None:
-            tpu_device = xm.xla_device(self.trainer.tpu_id)
-
+        if trainer.tpu_id is not None:
+            tpu_device = xm.xla_device(trainer.tpu_id)
+        else:
+            tpu_device = xm.xla_device()
         # track the device and move model to it
-        self.trainer._device = tpu_device
-        model.to(self.trainer._device)
+        trainer._device = tpu_device
+        model.to(trainer._device)
 
         # get the appropriate tpu ranks
-        self.trainer.tpu_local_core_rank = xm.get_local_ordinal()
-        self.trainer.tpu_global_core_rank = xm.get_ordinal()
+        trainer.tpu_local_core_rank = xm.get_local_ordinal()
+        trainer.tpu_global_core_rank = xm.get_ordinal()
 
         # avoid duplicating progress bar
-        if self.trainer.tpu_global_core_rank != 0 and self.trainer.progress_bar_callback is not None:
-            self.trainer.progress_bar_callback.disable()
+        if trainer.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
+            trainer.progress_bar_callback.disable()
 
-        self.trainer.global_rank = self.trainer.tpu_local_core_rank
-        rank_zero_only.rank = self.trainer.global_rank
+        trainer.global_rank = trainer.tpu_local_core_rank
+        rank_zero_only.rank = trainer.global_rank
 
         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
-        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
-        self.trainer.optimizers = optimizers
-        self.trainer.lr_schedulers = lr_schedulers
-        self.trainer.optimizer_frequencies = optimizer_frequencies
+        optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model)
+        trainer.optimizers = optimizers
+        trainer.lr_schedulers = lr_schedulers
+        trainer.optimizer_frequencies = optimizer_frequencies
 
         # init 16 bit for TPU
-        if self.trainer.precision == 16:
+        if trainer.precision == 16:
             os.environ['XLA_USE_BF16'] = str(1)
 
-        log.info(f'INIT TPU local core: {self.trainer.tpu_local_core_rank},'
-                 f' global rank: {self.trainer.tpu_global_core_rank}')
+        log.info(f'INIT TPU local core: {trainer.tpu_local_core_rank},'
+                 f' global rank: {trainer.tpu_global_core_rank}'
+                 f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')
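On the Float16 side of the fix: for TPUs, `precision=16` is implemented by exporting `XLA_USE_BF16=1` before tensors reach the device, which tells torch_xla to keep float32 tensors in bfloat16 on the TPU (the new log line above makes that setting visible). A tiny sketch of the mechanism (assumes a TPU runtime with torch_xla installed; illustration only):

    import os

    # must be set before tensors are moved to the XLA device
    os.environ['XLA_USE_BF16'] = str(1)

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    x = torch.randn(2, 2).to(device)  # stored as bfloat16 on the TPU under XLA_USE_BF16
    print(x.device)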

pytorch_lightning/core/__init__.py

+5 -1

@@ -305,5 +305,9 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.core.decorators import data_loader
 from pytorch_lightning.core.lightning import LightningModule
 
-__all__ = ['LightningDataModule', 'LightningModule', 'data_loader']
+__all__ = [
+    'LightningDataModule',
+    'LightningModule',
+    'data_loader',
+]
 # __call__ = __all__

pytorch_lightning/core/decorators.py

-2

@@ -1,8 +1,6 @@
 from functools import wraps
 from typing import Callable
 
-import torch
-
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn

0 commit comments