Skip to content

Commit

Permalink
Fixing slow pipeline tests (huggingface#14260)
Browse files Browse the repository at this point in the history
* Fiixng slow pipeline tests

* Remove the image-segmentaiton override.

* Fixing clamping only in training.

* Wav2vec2.

* Remove last mention of `no_grad`.

* Fixing copies.

* Rename.
  • Loading branch information
Narsil authored and Alberto Bégué committed Jan 27, 2022
1 parent a99113a commit f3189d9
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 66 deletions.
7 changes: 4 additions & 3 deletions src/transformers/models/detr/modeling_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,9 +648,10 @@ def forward(
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)

if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
if self.training:
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

outputs = (hidden_states,)

Expand Down
5 changes: 4 additions & 1 deletion src/transformers/models/unispeech/modeling_unispeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,10 @@ def _conv_out_length(input_length, kernel_size, stride):
return input_lengths

def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# on inference mode.
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
batch_size = attention_mask.shape[0]

attention_mask = torch.zeros(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,10 @@ def _conv_out_length(input_length, kernel_size, stride):
return input_lengths

def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# on inference mode.
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
batch_size = attention_mask.shape[0]

attention_mask = torch.zeros(
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/models/wav2vec2/modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,10 @@ def _conv_out_length(input_length, kernel_size, stride):
return input_lengths

def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# on inference mode.
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
batch_size = attention_mask.shape[0]

attention_mask = torch.zeros(
Expand Down
3 changes: 0 additions & 3 deletions src/transformers/pipelines/image_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:

return super().__call__(*args, **kwargs)

def get_inference_context(self):
return torch.no_grad

def preprocess(self, image):
image = load_image(image)
target_size = torch.IntTensor([[image.height, image.width]])
Expand Down
110 changes: 54 additions & 56 deletions src/transformers/pipelines/table_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,76 +93,74 @@ def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), *args, *
)

def batch_inference(self, **inputs):
with torch.no_grad():
return self.model(**inputs)
return self.model(**inputs)

def sequential_inference(self, **inputs):
"""
Inference used for models that need to process sequences in a sequential fashion, like the SQA models which
handle conversational query related to a table.
"""
with torch.no_grad():
all_logits = []
all_aggregations = []
prev_answers = None
batch_size = inputs["input_ids"].shape[0]

input_ids = inputs["input_ids"].to(self.device)
attention_mask = inputs["attention_mask"].to(self.device)
token_type_ids = inputs["token_type_ids"].to(self.device)
token_type_ids_example = None

for index in range(batch_size):
# If sequences have already been processed, the token type IDs will be created according to the previous
# answer.
if prev_answers is not None:
prev_labels_example = token_type_ids_example[:, 3] # shape (seq_len,)
model_labels = np.zeros_like(prev_labels_example.cpu().numpy()) # shape (seq_len,)

token_type_ids_example = token_type_ids[index] # shape (seq_len, 7)
for i in range(model_labels.shape[0]):
segment_id = token_type_ids_example[:, 0].tolist()[i]
col_id = token_type_ids_example[:, 1].tolist()[i] - 1
row_id = token_type_ids_example[:, 2].tolist()[i] - 1

if row_id >= 0 and col_id >= 0 and segment_id == 1:
model_labels[i] = int(prev_answers[(col_id, row_id)])

token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device)

input_ids_example = input_ids[index]
attention_mask_example = attention_mask[index] # shape (seq_len,)
all_logits = []
all_aggregations = []
prev_answers = None
batch_size = inputs["input_ids"].shape[0]

input_ids = inputs["input_ids"].to(self.device)
attention_mask = inputs["attention_mask"].to(self.device)
token_type_ids = inputs["token_type_ids"].to(self.device)
token_type_ids_example = None

for index in range(batch_size):
# If sequences have already been processed, the token type IDs will be created according to the previous
# answer.
if prev_answers is not None:
prev_labels_example = token_type_ids_example[:, 3] # shape (seq_len,)
model_labels = np.zeros_like(prev_labels_example.cpu().numpy()) # shape (seq_len,)

token_type_ids_example = token_type_ids[index] # shape (seq_len, 7)
outputs = self.model(
input_ids=input_ids_example.unsqueeze(0),
attention_mask=attention_mask_example.unsqueeze(0),
token_type_ids=token_type_ids_example.unsqueeze(0),
)
logits = outputs.logits
for i in range(model_labels.shape[0]):
segment_id = token_type_ids_example[:, 0].tolist()[i]
col_id = token_type_ids_example[:, 1].tolist()[i] - 1
row_id = token_type_ids_example[:, 2].tolist()[i] - 1

if self.aggregate:
all_aggregations.append(outputs.logits_aggregation)
if row_id >= 0 and col_id >= 0 and segment_id == 1:
model_labels[i] = int(prev_answers[(col_id, row_id)])

all_logits.append(logits)
token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device)

dist_per_token = torch.distributions.Bernoulli(logits=logits)
probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to(
dist_per_token.probs.device
)
input_ids_example = input_ids[index]
attention_mask_example = attention_mask[index] # shape (seq_len,)
token_type_ids_example = token_type_ids[index] # shape (seq_len, 7)
outputs = self.model(
input_ids=input_ids_example.unsqueeze(0),
attention_mask=attention_mask_example.unsqueeze(0),
token_type_ids=token_type_ids_example.unsqueeze(0),
)
logits = outputs.logits

coords_to_probs = collections.defaultdict(list)
for i, p in enumerate(probabilities.squeeze().tolist()):
segment_id = token_type_ids_example[:, 0].tolist()[i]
col = token_type_ids_example[:, 1].tolist()[i] - 1
row = token_type_ids_example[:, 2].tolist()[i] - 1
if col >= 0 and row >= 0 and segment_id == 1:
coords_to_probs[(col, row)].append(p)
if self.aggregate:
all_aggregations.append(outputs.logits_aggregation)

all_logits.append(logits)

dist_per_token = torch.distributions.Bernoulli(logits=logits)
probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to(
dist_per_token.probs.device
)

coords_to_probs = collections.defaultdict(list)
for i, p in enumerate(probabilities.squeeze().tolist()):
segment_id = token_type_ids_example[:, 0].tolist()[i]
col = token_type_ids_example[:, 1].tolist()[i] - 1
row = token_type_ids_example[:, 2].tolist()[i] - 1
if col >= 0 and row >= 0 and segment_id == 1:
coords_to_probs[(col, row)].append(p)

prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs}
prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs}

logits_batch = torch.cat(tuple(all_logits), 0)
logits_batch = torch.cat(tuple(all_logits), 0)

return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0))
return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0))

def __call__(self, *args, **kwargs):
r"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pipelines_audio_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_large_model_pt(self):
self.assertEqual(
nested_simplify(output, decimals=4),
[
{"score": 0.9809, "label": "go"},
{"score": 0.981, "label": "go"},
{"score": 0.0073, "label": "up"},
{"score": 0.0064, "label": "_unknown_"},
{"score": 0.0015, "label": "down"},
Expand Down

0 comments on commit f3189d9

Please sign in to comment.