pytorch 1.0 (#2165)

* fixes

* allenlp

* fix bug

* reduce=False -> reduction='none'

* fix _sparse_mask

* remove print statement

* fix more tests
joelgrus committed Dec 13, 2018
1 parent d0d6310 commit 8e861e75ed9afee06a046c24539fad5535e5c578
@@ -121,8 +121,8 @@ def __init__(self,
                                "arc representation dim", "arc feedforward output dim")
 
         self._unlabelled_f1 = F1Measure(positive_label=1)
-        self._arc_loss = torch.nn.BCEWithLogitsLoss(reduce=False)
-        self._tag_loss = torch.nn.CrossEntropyLoss(reduce=False)
+        self._arc_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
+        self._tag_loss = torch.nn.CrossEntropyLoss(reduction='none')
         initializer(self)
 
     @overrides
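
A note on the loss-module change above: PyTorch 1.0 deprecates the reduce/size_average flags on loss modules in favor of a single reduction argument, and reduction='none' reproduces the old reduce=False behavior of returning one loss value per element. A minimal standalone sketch (tensor shapes here are illustrative, not from this diff):

    import torch

    # reduction='none' keeps per-element losses instead of averaging to a scalar.
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    logits = torch.randn(4, 10)                 # hypothetical batch of 4, 10 classes
    labels = torch.randint(0, 10, (4,))
    per_example_loss = loss_fn(logits, labels)  # shape (4,), not a scalar
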
@@ -468,8 +468,7 @@ def _create_grammar_state(self,
 
             if global_actions:
                 global_action_tensors, global_action_ids = zip(*global_actions)
-                global_action_tensor = entity_types.new_tensor(torch.cat(global_action_tensors, dim=0),
-                                                               dtype=torch.long)
+                global_action_tensor = torch.cat(global_action_tensors, dim=0).to(entity_types.device).long()
                 global_input_embeddings = self._action_embedder(global_action_tensor)
                 global_output_embeddings = self._output_action_embedder(global_action_tensor)
                 translated_valid_actions[key]['global'] = (global_input_embeddings,
@@ -481,8 +480,9 @@ def _create_grammar_state(self,
                 entity_ids = [entity_map[entity] for entity in entities]
                 entity_linking_scores = linking_scores[entity_ids]
                 entity_type_tensor = entity_types[entity_ids]
-                entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor)
-                entity_type_embeddings = entity_types.new_tensor(entity_type_embeddings, dtype=torch.float)
+                entity_type_embeddings = (self._entity_type_decoder_embedding(entity_type_tensor)
+                                          .to(entity_types.device)
+                                          .float())
                 translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                            entity_type_embeddings,
                                                            list(linked_action_ids))
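
Both hunks above apply the same pattern: Tensor.new_tensor always copies its input and detaches it from the autograd graph, so the replacement moves and casts the existing tensor with .to() instead. A minimal sketch with made-up tensors:

    import torch

    reference = torch.zeros(3)                        # stands in for entity_types
    pieces = [torch.tensor([1, 2]), torch.tensor([3])]
    # Move to the reference tensor's device and cast, with no extra copy,
    # instead of reference.new_tensor(torch.cat(pieces), dtype=torch.long).
    combined = torch.cat(pieces, dim=0).to(reference.device).long()
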
@@ -185,7 +185,7 @@ def make_model(num_layers: int = 6,
     # Initialize parameters with Glorot / fan_avg.
     for p in model.parameters():
         if p.dim() > 1:
-            torch.nn.init.xavier_uniform(p)
+            torch.nn.init.xavier_uniform_(p)
     return model
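
The trailing underscore follows PyTorch's convention for in-place operations; in 1.0 the un-suffixed initializers like xavier_uniform are deprecated aliases of their in-place counterparts. A quick standalone illustration:

    import torch

    linear = torch.nn.Linear(10, 10)
    # Fills the weight matrix in place with Glorot-uniform values.
    torch.nn.init.xavier_uniform_(linear.weight)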


@@ -147,7 +147,7 @@ def sort_batch_by_length(tensor: torch.Tensor, sequence_lengths: torch.Tensor):
     sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
     sorted_tensor = tensor.index_select(0, permutation_index)
 
-    index_range = sequence_lengths.new_tensor(torch.arange(0, len(sequence_lengths)))
+    index_range = torch.arange(0, len(sequence_lengths), device=sequence_lengths.device)
     # This is the equivalent of zipping with index, sorting by the original
     # sequence lengths and returning the now sorted indices.
     _, reverse_mapping = permutation_index.sort(0, descending=False)
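
Factory functions such as torch.arange accept a device keyword, so the index range can be created directly on the same device as sequence_lengths rather than built on the CPU and copied over via new_tensor. For instance, with hypothetical batch lengths:

    import torch

    sequence_lengths = torch.tensor([5, 3, 4])  # lengths for a batch of 3
    index_range = torch.arange(0, len(sequence_lengths),
                               device=sequence_lengths.device)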
@@ -208,7 +208,7 @@ def get_dropout_mask(dropout_probability: float, tensor_for_masking: torch.Tensor):
     This scaling ensures expected values and variances of the output of applying this mask
     and the original tensor are the same.
     """
-    binary_mask = tensor_for_masking.new_tensor(torch.rand(tensor_for_masking.size()) > dropout_probability)
+    binary_mask = (torch.rand(tensor_for_masking.size()) > dropout_probability).to(tensor_for_masking.device)
     # Scale mask by 1/keep_prob to preserve output statistics.
     dropout_mask = binary_mask.float().div(1.0 - dropout_probability)
     return dropout_mask
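
The 1/keep_prob scaling is standard inverted dropout: entries are zeroed with probability p and the survivors are divided by (1 - p), so the expected value of the masked tensor matches the original. A quick numerical check (illustrative only):

    import torch

    dropout_probability = 0.3
    x = torch.ones(100000)
    binary_mask = (torch.rand(x.size()) > dropout_probability).to(x.device)
    dropout_mask = binary_mask.float().div(1.0 - dropout_probability)
    print((dropout_mask * x).mean())  # close to 1.0, i.e. E[mask * x] == E[x]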
@@ -70,4 +70,8 @@ def test_mst_decodes_arc_labels_with_respect_to_unconstrained_scores(self):
         length = torch.LongTensor([3])
         heads, tags = self.model._run_mst_decoding(energy, length) # pylint: disable=protected-access
         assert heads.tolist()[0] == [0, 0, 1]
-        assert tags.tolist()[0] == [0, 1, 0]
+
+        # This test produces different results under PyTorch 0.4.1 and 1.0.
+        # Almost certainly this is because it's underspecified.
+        # TODO(markn): modify this test to have a single correct result
+        assert tags.tolist()[0] in ([0, 1, 0], [0, 1, 1])
@@ -392,9 +392,9 @@ def test_viterbi_decode(self):
         _, argmax_indices = torch.max(sequence_logits, 1)
         assert indices == argmax_indices.data.squeeze().tolist()
 
-        # Test that pairwise potentials effect the sequence correctly and that
+        # Test that pairwise potentials affect the sequence correctly and that
         # viterbi_decode can handle -inf values.
-        sequence_logits = torch.FloatTensor([[0, 0, 0, 3, 4],
+        sequence_logits = torch.FloatTensor([[0, 0, 0, 3, 5],
                                              [0, 0, 0, 3, 4],
                                              [0, 0, 0, 3, 4],
                                              [0, 0, 0, 3, 4],
@@ -421,6 +421,7 @@ def test_viterbi_decode(self):
         transition_matrix = torch.zeros([5, 5])
         transition_matrix[4, 4] = -10
         transition_matrix[4, 3] = -10
+        transition_matrix[3, 4] = -10
         indices, _ = util.viterbi_decode(sequence_logits, transition_matrix)
         assert indices == [3, 3, 3, 3, 3, 3]
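
The added transition_matrix[3, 4] penalty closes a loophole: with it absent, the path [3, 3, 3, 3, 3, 4] scores 19 (unary 3 * 5 + 4, with no penalized transition) and beats the all-threes path's 18. Once every transition into or out of tag 4 costs -10, no one-step unary gain from tag 4 can pay for it, making [3, 3, 3, 3, 3, 3] the unique optimum. Reconstructed as a standalone check (this assumes the logits rows outside the visible hunk context match the repeated [0, 0, 0, 3, 4] rows shown):

    import torch
    from allennlp.nn import util

    sequence_logits = torch.FloatTensor([[0, 0, 0, 3, 5]] + [[0, 0, 0, 3, 4]] * 5)
    transition_matrix = torch.zeros([5, 5])
    transition_matrix[4, 4] = -10
    transition_matrix[4, 3] = -10
    transition_matrix[3, 4] = -10
    indices, _ = util.viterbi_decode(sequence_logits, transition_matrix)
    assert indices == [3, 3, 3, 3, 3, 3]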

@@ -449,6 +450,7 @@ def test_viterbi_decode(self):
         transition_matrix = torch.zeros([5, 5])
         transition_matrix[4, 4] = -10
         transition_matrix[4, 3] = -2
+        transition_matrix[3, 4] = -2
         # The 1st, 4th and 5th sequence elements are observed - they should be
         # equal to 2, 0 and 4. The last tag should be equal to 3, because although
         # the penalty for transitioning to the 4th tag is -2, the unary potential
@@ -515,7 +517,7 @@ def test_sequence_cross_entropy_with_logits_averages_batch_correctly(self):
 
         vector_loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights, average=None)
         # Batch has one completely padded row, so divide by 4.
-        assert loss.data.numpy() == vector_loss.data.sum() / 4
+        assert loss.data.numpy() == vector_loss.sum().item() / 4
 
     def test_sequence_cross_entropy_with_logits_averages_token_correctly(self):
         # test token average is the same as multiplying the per-batch loss
@@ -534,7 +536,7 @@ def test_sequence_cross_entropy_with_logits_averages_token_correctly(self):
         vector_loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights, batch_average=False)
         total_token_loss = (vector_loss * weights.float().sum(dim=-1)).sum()
         average_token_loss = (total_token_loss / weights.float().sum()).detach()
-        assert_almost_equal(loss.detach()[0], average_token_loss[0])
+        assert_almost_equal(loss.detach().item(), average_token_loss.item())
 
     def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
         tensor = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
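
Both assertions above work around the same change: indexing a zero-dimensional tensor with [0], which PyTorch 0.4 only warned about, is an error in 1.0; .item() extracts the Python scalar under both versions. A tiny illustration:

    import torch

    loss = torch.tensor([1.0, 2.0, 3.0]).sum()  # zero-dimensional tensor
    value = loss.item()                         # 6.0, a plain Python float
    # Under PyTorch 1.0, loss[0] raises an IndexError on a 0-dim tensor.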
@@ -137,6 +137,18 @@ def from_params(cls, model_parameters: List, params: Params): # type: ignore
     "bert_adam": BertAdam,
 }
 
+def _safe_sparse_mask(tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """
+    In PyTorch 1.0, Tensor._sparse_mask was changed to Tensor.sparse_mask.
+    This wrapper allows AllenNLP to (temporarily) work with both 1.0 and 0.4.1.
+    """
+    # pylint: disable=protected-access
+    try:
+        return tensor.sparse_mask(mask)
+    except AttributeError:
+        # TODO(joelgrus): remove this and/or warn at some point
+        return tensor._sparse_mask(mask)
+
 
 @Optimizer.register('dense_sparse_adam')
 class DenseSparseAdam(torch.optim.Optimizer):
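
For reference, sparse_mask selects the dense tensor's values at the sparse mask's indices; only the method's name changed between 0.4.1 (_sparse_mask) and 1.0 (sparse_mask). A standalone sketch against the PyTorch 1.0 API:

    import torch

    dense = torch.arange(9, dtype=torch.float).reshape(3, 3)
    indices = torch.tensor([[0, 2],
                            [1, 0]])  # coordinates (0, 1) and (2, 0)
    mask = torch.sparse_coo_tensor(indices, torch.ones(2), (3, 3)).coalesce()
    masked = dense.sparse_mask(mask)
    print(masked._values())           # tensor([1., 6.])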
@@ -223,14 +235,14 @@ def make_sparse(values):
                 # Decay the first and second moment running average coefficient
                 #      old <- b * old + (1 - b) * new
                 # <==> old += (1 - b) * (new - old)
-                old_exp_avg_values = exp_avg._sparse_mask(grad)._values()
+                old_exp_avg_values = _safe_sparse_mask(exp_avg, grad)._values()
                 exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
                 exp_avg.add_(make_sparse(exp_avg_update_values))
-                old_exp_avg_sq_values = exp_avg_sq._sparse_mask(grad)._values()
+                old_exp_avg_sq_values = _safe_sparse_mask(exp_avg_sq, grad)._values()
                 exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
                 exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))
 
-                # Dense addition again is intended, avoiding another _sparse_mask
+                # Dense addition again is intended, avoiding another sparse_mask
                 numer = exp_avg_update_values.add_(old_exp_avg_values)
                 exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
                 denom = exp_avg_sq_update_values.sqrt_().add_(group['eps'])
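
The update leans on the identity spelled out in the comment, b * old + (1 - b) * new == old + (1 - b) * (new - old), which lets the moving averages be updated with a sparse addition touching only the gradient's nonzero entries. A quick sanity check of the algebra:

    import torch

    beta = 0.9
    old = torch.tensor([1.0, 2.0])
    new = torch.tensor([3.0, 0.0])
    assert torch.allclose(beta * old + (1 - beta) * new,
                          old + (1 - beta) * (new - old))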
@@ -7,7 +7,9 @@
 
 # This installs Pytorch for CUDA 8 only. If you are using a newer version,
 # please visit http://pytorch.org/ and install the relevant version.
-torch>=0.4.1,<0.5.0
+# For now AllenNLP works with both PyTorch 1.0 and 0.4.1. Expect that in
+# the future only >=1.0 will be supported.
+torch>=0.4.1
 
 # Parameter parsing (but not on Windows).
 jsonnet==0.10.0 ; sys.platform != 'win32'
@@ -101,7 +101,7 @@
     packages=find_packages(exclude=["*.tests", "*.tests.*",
                                     "tests.*", "tests"]),
     install_requires=[
-        'torch>=0.4.1,<0.5.0',
+        'torch>=0.4.1',
         "jsonnet==0.10.0 ; sys.platform != 'win32'",
         'overrides',
         'nltk',
