QaNetEncoder Multi-GPU (#2692)

* Replace List with ModuleList

* Remove List import

* Rebase

* Remove List import

* Add test

* Pylint
kl2806 committed Apr 22, 2019
1 parent 4c98095 commit 4422d53d68374de5c0129f679fa1d43f18b8f42d
@@ -1,9 +1,8 @@
-from typing import List
-
 from overrides import overrides
 import torch
 from torch.nn import Dropout
 from torch.nn import LayerNorm
+from torch.nn import ModuleList
 from allennlp.modules.feedforward import FeedForward
 from allennlp.modules.residual_with_layer_dropout import ResidualWithLayerDropout
 from allennlp.modules.seq2seq_encoders.multi_head_self_attention import MultiHeadSelfAttention
@@ -74,7 +73,7 @@ def __init__(self,
         else:
             self._input_projection_layer = lambda x: x

-        self._encoder_blocks: List[QaNetEncoderBlock] = []
+        self._encoder_blocks = ModuleList([])
         for block_index in range(num_blocks):
             encoder_block = QaNetEncoderBlock(hidden_dim,
                                               hidden_dim,
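
The swap above is what enables multi-GPU use: torch.nn.DataParallel (and .cuda() in
general) only sees parameters that are registered on the module, and submodules held in
a plain Python list are invisible to that machinery, while a ModuleList registers them
like any other attribute. A minimal sketch of the difference, using hypothetical holder
classes that are not part of this commit:

    import torch
    from torch.nn import Linear, Module, ModuleList

    class PlainListHolder(Module):
        def __init__(self):
            super().__init__()
            # A plain list: the Linear's parameters are NOT registered, so
            # .parameters(), .cuda(), and DataParallel replication all miss them.
            self.blocks = [Linear(4, 4)]

    class ModuleListHolder(Module):
        def __init__(self):
            super().__init__()
            # A ModuleList: the Linear is registered like any direct submodule.
            self.blocks = ModuleList([Linear(4, 4)])

    print(len(list(PlainListHolder().parameters())))   # 0
    print(len(list(ModuleListHolder().parameters())))  # 2 (weight and bias)

The test file is updated next to exercise this on multiple devices.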
@@ -1,4 +1,6 @@
 # pylint: disable=no-self-use,invalid-name
+import torch
+import pytest
 from flaky import flaky
 import numpy
 from numpy.testing import assert_almost_equal
@@ -8,6 +10,8 @@
 from allennlp.common.testing import ModelTestCase
 from allennlp.data.dataset import Batch
 from allennlp.models import Model
+from allennlp.data.iterators import BasicIterator
+from allennlp.training import Trainer


 class QaNetTest(ModelTestCase):
@@ -45,6 +49,23 @@ def test_forward_pass_runs_correctly(self):
     def test_model_can_train_save_and_load(self):
         self.ensure_model_can_train_save_and_load(self.param_file, tolerance=1e-4)

+    @pytest.mark.skipif(torch.cuda.device_count() < 2,
+                        reason="Need multiple GPUs.")
+    def test_multigpu_qanet(self):
+        params = Params.from_file(self.param_file)
+        vocab = Vocabulary.from_instances(self.instances)
+        model = Model.from_params(vocab=vocab, params=params['model']).cuda()
+        optimizer = torch.optim.SGD(self.model.parameters(), 0.01, momentum=0.9)
+        multigpu_iterator = BasicIterator(batch_size=4)
+        multigpu_iterator.index_with(model.vocab)
+        trainer = Trainer(model,
+                          optimizer,
+                          multigpu_iterator,
+                          self.instances,
+                          num_epochs=2,
+                          cuda_device=[0, 1])
+        trainer.train()
+
     def test_batch_predictions_are_consistent(self):
         # The same issue as the bidaf test case.
         # The CNN encoder has problems with this kind of test - it's not properly masked yet, so
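
The new test skips itself unless at least two GPUs are visible, then trains the model
for two epochs with cuda_device=[0, 1] so the Trainer exercises its multi-GPU path. In
plain PyTorch terms that path corresponds roughly to wrapping the model in
torch.nn.DataParallel; a standalone sketch (not code from this commit), guarded the
same way the test is:

    import torch
    from torch.nn import DataParallel, Linear

    if torch.cuda.device_count() >= 2:
        # DataParallel replicates the registered parameters onto each device
        # and splits the batch dimension across them on every forward pass.
        model = DataParallel(Linear(4, 4).cuda(), device_ids=[0, 1])
        batch = torch.randn(8, 4).cuda()
        output = model(batch)  # rows 0-3 run on GPU 0, rows 4-7 on GPU 1

Because replication walks the registered module tree, the ModuleList change above is
exactly what lets each encoder block's parameters reach both devices.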
