Skip to content

Commit

Permalink
Raise exceptions instead of using asserts in modeling_openai huggingf…
Browse files Browse the repository at this point in the history
…ace#12789 (huggingface#14386)

* Raise exceptions instead of using asserts for control flow in modeling_openai huggingface#12789

* reformatted file
  • Loading branch information
nbertagnolli authored and Alberto Bégué committed Jan 27, 2022
1 parent 888b664 commit 56103ba
Showing 1 changed file with 23 additions and 26 deletions.
49 changes: 23 additions & 26 deletions src/transformers/models/openai/modeling_openai.py
Expand Up @@ -83,13 +83,16 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
# del init_params[1]
init_params = [arr.squeeze() for arr in init_params]

try:
assert model.tokens_embed.weight.shape == init_params[1].shape
assert model.positions_embed.weight.shape == init_params[0].shape
except AssertionError as e:
e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
e.args += (model.positions_embed.weight.shape, init_params[0].shape)
raise
# Check that the token and position embedding weight shapes match those of the init parameters.
if model.tokens_embed.weight.shape != init_params[1].shape:
raise ValueError(
f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape: {init_params[1].shape}"
)

if model.positions_embed.weight.shape != init_params[0].shape:
raise ValueError(
f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape: {init_params[0].shape}"
)

model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
model.positions_embed.weight.data = torch.from_numpy(init_params[0])
Expand All @@ -100,7 +103,8 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):

for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
name = name[6:] # skip "model/"
assert name[-2:] == ":0"
if name[-2:] != ":0":
raise ValueError(f"Layer {name} does not end with :0")
name = name[:-2]
name = name.split("/")
pointer = model
Expand All @@ -120,20 +124,11 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
if len(scope_names) >= 2:
num = int(scope_names[1])
pointer = pointer[num]
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise

# Ensure that the pointer and array have compatible shapes.
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")

logger.info(f"Initialize PyTorch weight {name}")
pointer.data = torch.from_numpy(array)
return model
Expand All @@ -147,7 +142,8 @@ def __init__(self, nx, n_positions, config, scale=False):
super().__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implementation]
assert n_state % config.n_head == 0
if n_state % config.n_head != 0:
raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}")
self.register_buffer(
"bias", torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions)
)
Expand Down Expand Up @@ -804,9 +800,10 @@ def forward(
else:
batch_size, sequence_length = inputs_embeds.shape[:2]

assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
# Without a padding token defined, only a batch size of 1 can be handled.
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

if self.config.pad_token_id is None:
sequence_lengths = -1
else:
Expand Down

0 comments on commit 56103ba

Please sign in to comment.