optimize generation caching (#12)
Over 10x speedup: adds MLP caching and optimizes attention caching.
Uses changes from https://t.co/BTwo6NKq9H.
neverix authored Nov 3, 2021
1 parent 0b3d648 commit 47de7a2
Showing 1 changed file with 48 additions and 9 deletions.
rudalle/dalle/transformer.py: 48 additions & 9 deletions
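The description claims more than a 10x end-to-end speedup. A rough way to see where the savings come from is to count how many attention-score rows each generation step computes with and without a cache. The snippet below is a back-of-envelope sketch under assumed operation counts, not a measurement of this repository; the real wall-clock gain is smaller than the printed ratio because projections, softmax and sampling still run every step.

# Back-of-envelope sketch (assumed operation counts, not measurements) of why
# the caches help autoregressive generation: without them, step t recomputes
# attention and MLP work for all t positions; with them, only the newest
# position does fresh work while past keys/values/outputs are reused.
def score_ops(seq_len, cached):
    total = 0
    for t in range(1, seq_len + 1):
        fresh_queries = 1 if cached else t   # query rows scored at step t
        total += fresh_queries * t           # each scored row attends over t keys
    return total

s = 1024  # e.g. a 32x32 grid of image tokens
print(score_ops(s, cached=False) / score_ops(s, cached=True))  # roughly (2*s + 1)/3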
@@ -146,7 +146,7 @@ def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
layernorm_output = self.input_layernorm(hidden_states)

# Self attention.
attention_output, has_cache = self.attention(
attention_output, att_has_cache = self.attention(
layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)

if self.cogview_sandwich_layernorm:
@@ -159,15 +159,16 @@ def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
layernorm_output = self.post_attention_layernorm(layernorm_input)

# MLP.
mlp_output = self.mlp(layernorm_output)
mlp_output, mlp_has_cache = self.mlp(
layernorm_output, has_cache=has_cache, use_cache=use_cache)

if self.cogview_sandwich_layernorm:
mlp_output = self.before_second_addition_layernorm(mlp_output)

# Second residual connection.
output = layernorm_input + mlp_output

return output, has_cache
return output, att_has_cache and mlp_has_cache


class DalleSelfAttention(torch.nn.Module):
@@ -212,6 +213,11 @@ def __init__(self, hidden_size, num_attention_heads,
self.dense = torch.nn.Linear(hidden_size, hidden_size)
self.output_dropout = torch.nn.Dropout(output_dropout_prob)

# Cache
self.past_key = None
self.past_value = None
self.past_output = None

def _transpose_for_scores(self, tensor):
""" Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
@@ -227,6 +233,7 @@ def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask):
)
else:
attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:]
attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
if self.cogview_pb_relax:
# normalize attention scores. Should not affect resulting softmax value
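The mask cropping added above keeps the causal mask's query dimension in step with the number of score rows actually computed. A minimal shape check, with all tensor shapes assumed for illustration rather than taken from this file:

# Minimal shape illustration (shapes assumed): when cached generation scores
# only the newest query position, the full lower-triangular mask must be
# cropped to its last rows before it is applied to the scores.
import torch

s_total, s_new = 6, 1
ltor_mask = torch.tril(torch.ones(1, 1, s_total, s_total))  # full causal mask
attention_scores = torch.randn(1, 1, s_new, s_total)        # rows for the new position only
ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:]   # crop to [1, 1, s_new, s_total]
masked = attention_scores * ltor_mask - 10000.0 * (1.0 - ltor_mask)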
@@ -258,10 +265,10 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)

# Can be simplified, but I didn't for readability's sake
if use_cache and has_cache:
value_layer = torch.cat((self.past_value, value_layer), dim=-2)
query_layer = torch.cat((self.past_query, query_layer), dim=-2)
key_layer = torch.cat((self.past_key, key_layer), dim=-2)
value_layer = torch.cat((self.past_value, value_layer), dim=-2)
attention_scores = self._calculate_attention_scores(
query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
)
@@ -271,13 +278,17 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
)

if use_cache:
self.past_query = query_layer
self.past_key = key_layer
self.past_value = value_layer
has_cache = True
else:
self.past_key = None
self.past_value = None
self.past_output = None
has_cache = False

if use_cache and has_cache:
attention_scores = attention_scores[..., -1:, :]

# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

@@ -298,6 +309,16 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):

# Output. [b, s, h]
output = self.dense(context_layer)

if use_cache:
# Can be simplified, but I didn't for readability's sake
if has_cache:
output = torch.cat((self.past_output, output), dim=-2)
self.past_output = output
else:
self.past_output = output
has_cache = True

output = self.output_dropout(output)
return output, has_cache
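For intuition about why the score slicing above is safe, here is a standalone check, not part of the module, using single-head unbatched tensors: scoring only the newest query against cached keys and values reproduces the last row of full causal self-attention.

# Standalone illustration, not rudalle code: with a key/value cache, scoring
# only the newest query gives the same result as the last row of full causal
# self-attention over the whole sequence.
import math
import torch

def full_causal_attention(q, k, v):
    # q, k, v: [s, d]
    s, d = q.shape
    scores = q @ k.t() / math.sqrt(d)
    scores = scores.masked_fill(torch.tril(torch.ones(s, s)) == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
s, d = 5, 8
q, k, v = (torch.randn(s, d) for _ in range(3))

# Cached path: keep past keys/values, compute scores for the last query only.
k_cache, v_cache = k[:-1], v[:-1]
k_all = torch.cat((k_cache, k[-1:]), dim=0)
v_all = torch.cat((v_cache, v[-1:]), dim=0)
scores_last = q[-1:] @ k_all.t() / math.sqrt(d)   # the last row needs no causal mask
out_last = torch.softmax(scores_last, dim=-1) @ v_all

assert torch.allclose(out_last, full_causal_attention(q, k, v)[-1:], atol=1e-6)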

@@ -321,12 +342,30 @@ def __init__(self, hidden_size, output_dropout_prob):
# Project back to h.
self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
self.dropout = torch.nn.Dropout(output_dropout_prob)
# MLP cache
self.past_x = None

def forward(self, hidden_states, has_cache=False, use_cache=False):
if has_cache and use_cache:
hidden_states = hidden_states[:, -1:]

def forward(self, hidden_states):
# [b, s, 4hp]
x = self.dense_h_to_4h(hidden_states)
x = gelu(x)
# [b, s, h]
x = self.dense_4h_to_h(x)
if use_cache:
# Can be simplified, but I didn't for readability's sake
if has_cache:
x = torch.cat((self.past_x, x), dim=-2)
self.past_x = x
else:
self.past_x = x

has_cache = True
else:
self.past_x = None
has_cache = False
output = self.dropout(x)
return output

return output, has_cache
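Because the MLP acts on every position independently, running it only on the newest position and concatenating cached outputs for the earlier ones matches running it over the whole sequence. A small standalone check of that property (illustrative code, not this module; dropout is omitted so the comparison is deterministic):

# Standalone illustration, not rudalle code: a position-wise MLP applied to the
# last token only, concatenated with previously cached outputs, matches the MLP
# applied to the full sequence.
import torch

torch.manual_seed(0)
b, s, h = 2, 7, 16
mlp = torch.nn.Sequential(
    torch.nn.Linear(h, 4 * h),
    torch.nn.GELU(),
    torch.nn.Linear(4 * h, h),
)
hidden = torch.randn(b, s, h)

full = mlp(hidden)                                  # run over every position
cached = torch.cat((mlp(hidden[:, :-1]),            # pretend this was cached earlier
                    mlp(hidden[:, -1:])), dim=1)    # only the newest position is new work
assert torch.allclose(full, cached, atol=1e-6)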
