From ec10f00ca9f1143869d5fa0a479d9465397d3017 Mon Sep 17 00:00:00 2001
From: Ondrej Cifka
Date: Tue, 19 Jan 2021 16:42:22 +0100
Subject: [PATCH] Always share in batch, request length instead of shape

---
 dev/spe/spe.py | 61 +++++++++++++++++++++-----------------------------
 1 file changed, 25 insertions(+), 36 deletions(-)

diff --git a/dev/spe/spe.py b/dev/spe/spe.py
index 20359f3..bc52da3 100644
--- a/dev/spe/spe.py
+++ b/dev/spe/spe.py
@@ -205,10 +205,9 @@ class SineSPE(nn.Module):
         num_realizations: The number of realizations of the stochastic
             process (R).
         num_sines: The number of sin and cos components (K).
-        key_shape: The expected shape of keys and queries. Needs to be
-            set either here, or by calling `reset()`.
-        share_in_batch: Whether to share the same set of
-            positional encodings for all examples in the batch.
+        max_len: The maximum expected length of keys and queries. Can be
+            set either here or by calling `reset()`, otherwise it is
+            inferred from the actual keys and queries.
     """
 
     def __init__(
@@ -217,8 +216,7 @@ def __init__(
         in_features: int = 64,
         num_realizations: int = 256,
         num_sines: int = 10,
-        key_shape: Optional[Tuple[int, ...]] = None,
-        share_in_batch: bool = True,
+        max_len: Optional[int] = None,
     ):
         super(SineSPE, self).__init__()
 
@@ -245,11 +243,9 @@ def __init__(
         self.freqs.data[...] -= 5.
 
         # reset qbar and kbar
-        self.reset(key_shape, share_in_batch)
+        self.reset(max_len)
 
-    def reset(self,
-              key_shape: Tuple[int, ...],
-              share_in_batch: Optional[bool] = None):
+    def reset(self, max_len: Optional[int] = None):
         """
         Reset positional encodings.
 
@@ -257,21 +253,18 @@ def reset(self,
         At testing, this is typically never done.
 
         Args:
-            key_shape: The expected shape of keys and queries.
-            share_in_batch: Whether to share the same set of
-                positional encodings for all examples in the batch.
+            max_len: The maximum expected length of keys and queries.
         """
         self.qbar = None
         self.kbar = None
-        self._key_shape = key_shape
-        if share_in_batch is not None:
-            self._share_in_batch = share_in_batch
+        if max_len is not None or not hasattr(self, 'max_len'):
+            self.max_len = max_len
 
     def forward(
         self, queries: torch.Tensor, keys: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Perform sinusoidal SPE.
+        Apply SPE to queries and keys.
 
         Expects keys and queries of shape `(batch_size, ..., num_heads,
         key_dim)` and outputs keys and queries of shape `(batch_size,
@@ -282,27 +275,29 @@ def forward(
             "got queries: {} and keys: {}".format(queries.shape, keys.shape)
 
         if queries.shape[-1] < self.in_features:
-            raise ValueError('Expected keys/queries of dimension at least'
+            raise ValueError('Expected keys/queries of dimension at least '
                              f'{self.in_features}, got {queries.shape[-1]}.')
 
         # split off the non-positional part
         queries, queries_rest = _split_features(queries, self.in_features)
         keys, keys_rest = _split_features(keys, self.in_features)
 
+        length = keys.shape[1]
         if self.qbar is None:
+            if self.max_len is None:
+                self.max_len = length
             self._draw_noise()
-        length = queries.shape[1]
         if self.qbar.shape[1] < length:
             raise RuntimeError(f'Keys/queries have length {length}, '
                                f'but expected at most {self.qbar.shape[1]}')
-        desired_shape = (queries.shape[0], self.qbar.shape[1], *queries.shape[2:],
-                         self.num_realizations)
+        desired_shape = (1, self.qbar.shape[1], *queries.shape[2:], self.num_realizations)
         if self.qbar.shape[2:] != desired_shape[2:]:
             raise RuntimeError(f'Positional encodings have shape {self.qbar.shape}, '
                                f'but need {desired_shape} for queries of shape '
                                f'{queries.shape}')
 
         # sum over the keys_dim after multiplying by queries and keys
+        # qbar/kbar is (1, max_len, ...), truncating and broadcasting over the batch
         qhat = (self.qbar[:, :length] * queries[..., None]).sum(axis=-2)
         khat = (self.kbar[:, :length] * keys[..., None]).sum(axis=-2)
 
@@ -317,15 +312,9 @@ def _draw_noise(self):
         Generate the random QBar and Kbar, depending on the parameters,
         and store them in the module.
         """
-        if self._key_shape is None:
-            raise RuntimeError('`key_shape` is not set. Please call `reset()` first')
-
-        batchsize = 1 if self._share_in_batch else self._key_shape[0]
-        length = self._key_shape[1]
-
         # build omega_q and omega_k,
         # with shape (num_heads, keys_dim, length, 2*num_sines)
-        indices = torch.linspace(0, length-1, length, device=self.freqs.device)
+        indices = torch.linspace(0, self.max_len-1, self.max_len, device=self.freqs.device)
 
         # making sure the frequencies are in [0, 0.5]
         freqs = torch.sigmoid(self.freqs[:, :, None, :])/2.
@@ -336,7 +325,7 @@ def _draw_noise(self):
             + self.offsets[:, :, None, :]
         )
         omega_q = torch.stack([torch.cos(phases_q), torch.sin(phases_q)], dim=-1).view(
-            self.num_heads, self.in_features, length, 2*self.num_sines
+            1, self.num_heads, self.in_features, self.max_len, 2*self.num_sines
         )
 
         phases_k = (
@@ -344,7 +333,7 @@ def _draw_noise(self):
             * freqs * indices[None, None, :, None]
         )
         omega_k = torch.stack([torch.cos(phases_k), torch.sin(phases_k)], dim=-1).view(
-            self.num_heads, self.in_features, length, 2*self.num_sines
+            1, self.num_heads, self.in_features, self.max_len, 2*self.num_sines
         )
 
         # gains is (num_heads, keys_dim, 2*num_sines). Making then nonnegative with softplus
@@ -352,20 +341,20 @@ def _draw_noise(self):
 
         # draw noise of appropriate shape on the right device
         z = torch.randn(
-            batchsize, self.num_heads, self.in_features, 2 * self.num_sines,
+            1, self.num_heads, self.in_features, 2 * self.num_sines,
             self.num_realizations,
             device=self.freqs.device) / math.sqrt(self.num_realizations * self.in_features)
 
         # scale each of the 2*num_sines by the appropriate gain
-        # z is still (batchsize, num_heads, keys_dim, 2*num_sines, num_realizations)
+        # z is still (1, num_heads, keys_dim, 2*num_sines, num_realizations)
         z = z * gains[None, ..., None]
 
         # computing the sum over the sines.
-        # gets (batchsize, num_heads, keys_dim, length, num_realizations)
-        self.qbar = torch.matmul(omega_q[None], z)
-        self.kbar = torch.matmul(omega_k[None], z)
+        # gets (1, num_heads, keys_dim, length, num_realizations)
+        self.qbar = torch.matmul(omega_q, z)
+        self.kbar = torch.matmul(omega_k, z)
 
-        # permuting them to be (batchsize, length, num_heads, keys_dim, num_realizations)
+        # permuting them to be (1, length, num_heads, keys_dim, num_realizations)
         self.qbar = self.qbar.permute(0, 3, 1, 2, 4)
         self.kbar = self.kbar.permute(0, 3, 1, 2, 4)
 
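For reference, a minimal usage sketch of the module's behaviour after this patch (not part of the patch itself). The import path, the first constructor argument `num_heads`, and the tensor sizes are assumptions for illustration; they are not taken from the hunks above.

# Usage sketch under the assumptions stated above.
import torch

from spe import SineSPE  # hypothetical import path for the patched dev/spe/spe.py

batch_size, seq_len, num_heads, key_dim = 4, 128, 8, 64  # illustrative sizes

# After this patch, only a maximum length (or nothing at all) is configured;
# the encodings always have batch dimension 1 and broadcast over the batch.
spe_encoder = SineSPE(num_heads, in_features=key_dim,
                      num_realizations=256, num_sines=10, max_len=seq_len)

queries = torch.randn(batch_size, seq_len, num_heads, key_dim)
keys = torch.randn(batch_size, seq_len, num_heads, key_dim)

# forward() infers the actual length from the keys, truncates qbar/kbar to it,
# and shares the same realization across all examples in the batch.
qhat, khat = spe_encoder(queries, keys)

# Redrawing the noise (e.g. once per training epoch) keeps the stored max_len
# unless a new value is passed to reset().
spe_encoder.reset()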