
Always share in batch, request length instead of shape
cifkao committed Jan 19, 2021 · 1 parent eadcd70 · commit ec10f00
Showing 1 changed file with 25 additions and 36 deletions.
61 changes: 25 additions & 36 deletions dev/spe/spe.py
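For context, after this commit the module is configured with a maximum sequence length rather than a full key shape, and a single set of positional codes is always shared across the batch. Below is a minimal usage sketch of the new interface; the import path, the num_heads argument, and the illustrative sizes are assumptions, since only parts of the signature appear in the hunks that follow.

    import torch
    from spe import SineSPE  # import path assumed; the class lives in dev/spe/spe.py

    batch, length, num_heads, key_dim = 8, 100, 4, 64  # illustrative sizes

    layer = SineSPE(num_heads=num_heads, in_features=key_dim,
                    num_realizations=256, num_sines=10,
                    max_len=512)  # max_len may also be omitted and inferred from the first batch

    queries = torch.randn(batch, length, num_heads, key_dim)
    keys = torch.randn(batch, length, num_heads, key_dim)

    # per the visible forward() code, the positional part of the output has
    # num_realizations channels; the first call draws the shared noise
    qhat, khat = layer(queries, keys)

    layer.reset()              # clear qbar/kbar, keep the stored max_len
    layer.reset(max_len=1024)  # or request a longer maximum before the next draw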
@@ -205,10 +205,9 @@ class SineSPE(nn.Module):
num_realizations: The number of realizations of the stochastic
process (R).
num_sines: The number of sin and cos components (K).
key_shape: The expected shape of keys and queries. Needs to be
set either here, or by calling `reset()`.
share_in_batch: Whether to share the same set of
positional encodings for all examples in the batch.
max_len: The maximum expected length of keys and queries. Can be
set either here or by calling `reset()`, otherwise it is
inferred from the actual keys and queries.
"""

def __init__(
@@ -217,8 +216,7 @@ def __init__(
in_features: int = 64,
num_realizations: int = 256,
num_sines: int = 10,
key_shape: Optional[Tuple[int, ...]] = None,
share_in_batch: bool = True,
max_len: Optional[int] = None,
):
super(SineSPE, self).__init__()

@@ -245,33 +243,28 @@ def __init__(
self.freqs.data[...] -= 5.

# reset qbar and kbar
self.reset(key_shape, share_in_batch)
self.reset(max_len)

def reset(self,
key_shape: Tuple[int, ...],
share_in_batch: Optional[bool] = None):
def reset(self, max_len: Optional[int] = None):
"""
Reset positional encodings.
At training, this is typically done for each new batch.
At testing, this is typically never done.
Args:
key_shape: The expected shape of keys and queries.
share_in_batch: Whether to share the same set of
positional encodings for all examples in the batch.
max_len: The maximum expected length of keys and queries.
"""
self.qbar = None
self.kbar = None
self._key_shape = key_shape
if share_in_batch is not None:
self._share_in_batch = share_in_batch
if max_len is not None or not hasattr(self, 'max_len'):
self.max_len = max_len

def forward(
self, queries: torch.Tensor, keys: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform sinusoidal SPE.
Apply SPE to queries and keys.
Expects keys and queries of shape `(batch_size, ..., num_heads,
key_dim)` and outputs keys and queries of shape `(batch_size,
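With the rewritten reset(), the stored max_len survives calls that pass no argument (it is only overwritten when a value is given, or when it was never set), so redrawing the encodings once per training batch stays a one-liner. A sketch of that training-loop pattern, assuming a model that exposes the layer as model.spe and a DataLoader that yields query/key tensors:

    for queries, keys in loader:
        model.spe.reset()                      # qbar/kbar -> None; max_len is kept as-is
        qhat, khat = model.spe(queries, keys)  # first call after reset() redraws the shared noise
        # ... rest of the training step (attention scores, loss, optimizer)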
@@ -282,27 +275,29 @@ def forward(
"got queries: {} and keys: {}".format(queries.shape, keys.shape)

if queries.shape[-1] < self.in_features:
raise ValueError('Expected keys/queries of dimension at least'
raise ValueError('Expected keys/queries of dimension at least '
f'{self.in_features}, got {queries.shape[-1]}.')

# split off the non-positional part
queries, queries_rest = _split_features(queries, self.in_features)
keys, keys_rest = _split_features(keys, self.in_features)

length = keys.shape[1]
if self.qbar is None:
if self.max_len is None:
self.max_len = length
self._draw_noise()
length = queries.shape[1]
if self.qbar.shape[1] < length:
raise RuntimeError(f'Keys/queries have length {length}, '
f'but expected at most {self.qbar.shape[1]}')
desired_shape = (queries.shape[0], self.qbar.shape[1], *queries.shape[2:],
self.num_realizations)
desired_shape = (1, self.qbar.shape[1], *queries.shape[2:], self.num_realizations)
if self.qbar.shape[2:] != desired_shape[2:]:
raise RuntimeError(f'Positional encodings have shape {self.qbar.shape}, '
f'but need {desired_shape} for queries of shape '
f'{queries.shape}')

# sum over the keys_dim after multiplying by queries and keys
# qbar/kbar is (1, max_len, ...), truncating and broadcasting over the batch
qhat = (self.qbar[:, :length] * queries[..., None]).sum(axis=-2)
khat = (self.kbar[:, :length] * keys[..., None]).sum(axis=-2)
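The stored qbar/kbar now carry a leading batch dimension of 1 and a time dimension of max_len; forward() truncates them to the actual length and relies on broadcasting, so every example in the batch sees the same realization. A shape-only sketch of that contraction with illustrative sizes (not the module's own code):

    import torch

    batch, length, max_len, num_heads, key_dim, R = 8, 50, 512, 4, 64, 256

    qbar = torch.randn(1, max_len, num_heads, key_dim, R)    # one realization shared by the batch
    queries = torch.randn(batch, length, num_heads, key_dim)

    # truncate to the actual length, broadcast over the batch, sum over key_dim
    qhat = (qbar[:, :length] * queries[..., None]).sum(dim=-2)
    print(qhat.shape)  # torch.Size([8, 50, 4, 256])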

@@ -317,15 +312,9 @@ def _draw_noise(self):
Generate the random QBar and Kbar, depending on the parameters,
and store them in the module.
"""
if self._key_shape is None:
raise RuntimeError('`key_shape` is not set. Please call `reset()` first')

batchsize = 1 if self._share_in_batch else self._key_shape[0]
length = self._key_shape[1]

# build omega_q and omega_k,
# with shape (num_heads, keys_dim, length, 2*num_sines)
indices = torch.linspace(0, length-1, length, device=self.freqs.device)
indices = torch.linspace(0, self.max_len-1, self.max_len, device=self.freqs.device)

# making sure the frequencies are in [0, 0.5]
freqs = torch.sigmoid(self.freqs[:, :, None, :])/2.
@@ -336,36 +325,36 @@ def _draw_noise(self):
+ self.offsets[:, :, None, :]
)
omega_q = torch.stack([torch.cos(phases_q), torch.sin(phases_q)], dim=-1).view(
self.num_heads, self.in_features, length, 2*self.num_sines
1, self.num_heads, self.in_features, self.max_len, 2*self.num_sines
)

phases_k = (
2 * math.pi
* freqs * indices[None, None, :, None]
)
omega_k = torch.stack([torch.cos(phases_k), torch.sin(phases_k)], dim=-1).view(
self.num_heads, self.in_features, length, 2*self.num_sines
1, self.num_heads, self.in_features, self.max_len, 2*self.num_sines
)

# gains is (num_heads, keys_dim, 2*num_sines). Making them nonnegative with softplus
gains = nn.functional.softplus(self.gains).repeat(1, 1, 2)

# draw noise of appropriate shape on the right device
z = torch.randn(
batchsize, self.num_heads, self.in_features, 2 * self.num_sines,
1, self.num_heads, self.in_features, 2 * self.num_sines,
self.num_realizations,
device=self.freqs.device) / math.sqrt(self.num_realizations * self.in_features)

# scale each of the 2*num_sines by the appropriate gain
# z is still (batchsize, num_heads, keys_dim, 2*num_sines, num_realizations)
# z is still (1, num_heads, keys_dim, 2*num_sines, num_realizations)
z = z * gains[None, ..., None]

# computing the sum over the sines.
# gets (batchsize, num_heads, keys_dim, length, num_realizations)
self.qbar = torch.matmul(omega_q[None], z)
self.kbar = torch.matmul(omega_k[None], z)
# gets (1, num_heads, keys_dim, length, num_realizations)
self.qbar = torch.matmul(omega_q, z)
self.kbar = torch.matmul(omega_k, z)

# permuting them to be (batchsize, length, num_heads, keys_dim, num_realizations)
# permuting them to be (1, length, num_heads, keys_dim, num_realizations)
self.qbar = self.qbar.permute(0, 3, 1, 2, 4)
self.kbar = self.kbar.permute(0, 3, 1, 2, 4)
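After this change, _draw_noise() builds the sinusoidal design matrices and the gain-scaled Gaussian noise with an explicit leading 1 instead of a batch-sized dimension, so a single matmul yields the encodings shared by the whole batch. A shape-only sketch of that last step, with illustrative sizes:

    import torch

    num_heads, key_dim, max_len, num_sines, R = 4, 64, 512, 10, 256

    omega_q = torch.randn(1, num_heads, key_dim, max_len, 2 * num_sines)  # cos/sin columns per position
    z = torch.randn(1, num_heads, key_dim, 2 * num_sines, R)              # gain-scaled noise, one draw for all examples

    qbar = torch.matmul(omega_q, z)     # (1, num_heads, key_dim, max_len, R)
    qbar = qbar.permute(0, 3, 1, 2, 4)  # (1, max_len, num_heads, key_dim, R)
    print(qbar.shape)  # torch.Size([1, 512, 4, 64, 256])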

