
sharenoise #1

Merged
merged 10 commits on Jan 19, 2021
72 changes: 50 additions & 22 deletions dev/examples/test_spe.ipynb

Large diffs are not rendered by default.

152 changes: 119 additions & 33 deletions dev/spe/spe.py
@@ -1,5 +1,5 @@
import math
from typing import Tuple, Union
from typing import Optional, Tuple, Union

import torch
from torch import nn
@@ -74,6 +74,8 @@ def __init__(
self.conv_q.weight.data = torch.rand(self.conv_q.weight.shape)
self.conv_k.weight.data = torch.rand(self.conv_k.weight.shape)

# reset qbar and kbar
self.reset()
return

# smooth init
@@ -88,6 +90,15 @@ def __init__(
self.conv_q.weight.data = init_weight.clone()
self.conv_k.weight.data = init_weight.clone()

def reset(self):
"""
Reset noise.
At training, this is typically done for each new batch.
At testing, this is typically never done.
"""
self.qbar = None
self.kbar = None
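
For context, a minimal sketch of how this per-batch reset is meant to be used (illustrative only; the loop and variable names, and the way the module instance `spe` is built, are assumptions not shown in this diff):

# training: redraw the noise for every batch
for queries, keys in train_loader:
    spe.reset()                      # clear qbar/kbar; fresh noise is drawn lazily on the next forward
    qhat, khat = spe(queries, keys)

# testing: reset once, then reuse the same draw for every call
spe.reset()
for queries, keys in test_loader:
    qhat, khat = spe(queries, keys)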

def forward(
self, queries: torch.Tensor, keys: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -102,7 +113,6 @@ def forward(
"As of current implementation, queries and keys must have the same shape. "\
"got queries: {} and keys: {}".format(queries.shape, keys.shape)

batchsize = queries.shape[0]

if queries.shape[-1] < self.in_features:
raise ValueError('Expected keys/queries of dimension at least'
@@ -118,7 +128,35 @@ def forward(
keys = keys.permute(0, self.ndim + 1, self.ndim + 2,
*[d for d in range(1, self.ndim + 1)])

# Qbar and Kbar should be
# (batchsize, num_realizations, num_heads, keys_dim, *shape).
# If that's not the case, draw them anew; otherwise keep them.
desired_shape = (queries.shape[0], self.num_realizations, *queries.shape[1:])
if self.qbar is None or self.qbar.shape != desired_shape:
self._draw_noise(queries)

# sum over d after multiplying by queries and keys
qhat = (self.qbar * queries[:, None]).sum(axis=3)
khat = (self.kbar * keys[:, None]).sum(axis=3)

# qhat is (batchsize, num_realizations, num_heads, *shape); permute to (batchsize, *shape, num_heads, num_realizations)
qhat = qhat.permute(0, *[x for x in range(3, self.ndim+3)], 2, 1)
khat = khat.permute(0, *[x for x in range(3, self.ndim+3)], 2, 1)

# concatenate with the non-positional part of keys and queries
qhat = torch.cat([qhat, queries_rest], dim=-1)
khat = torch.cat([khat, keys_rest], dim=-1)

return qhat, khat
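
As a sanity check on the shapes in this block (a self-contained sketch with made-up sizes, not part of the diff): qbar has shape (batchsize, num_realizations, num_heads, keys_dim, *shape), the permuted queries are (batchsize, num_heads, keys_dim, *shape), the sum over axis 3 contracts keys_dim, and the final permute yields (batchsize, *shape, num_heads, num_realizations).

import torch

B, R, H, D, L = 2, 8, 4, 16, 50            # made-up sizes: batch, realizations, heads, keys_dim, length
qbar = torch.randn(B, R, H, D, L)          # what _draw_noise stores for the 1-D case (ndim == 1)
queries = torch.randn(B, H, D, L)          # queries after the permute at the top of forward

qhat = (qbar * queries[:, None]).sum(axis=3)   # broadcast over realizations, contract keys_dim -> (B, R, H, L)
qhat = qhat.permute(0, 3, 2, 1)                # -> (B, L, H, R), i.e. (batchsize, *shape, num_heads, num_realizations)
assert qhat.shape == (B, L, H, R)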

def _draw_noise(self, queries):
"""
Generate the random Qbar and Kbar, depending on the parameters,
and store them in the module.
Args:
queries: (batchsize, num_heads, keys_dim, *shape)
"""
batchsize = queries.shape[0]
original_shape = queries.shape[3:]

# decide on the size of the signal to generate
@@ -133,8 +171,8 @@ def forward(
device=self.conv_q.weight.device) / math.sqrt(self.num_realizations * self.in_features)

# apply convolution, get (batchsize*num_realizations, num_heads*keys_dim, *shape)
pe_k = self.conv_k(z)
pe_q = self.conv_q(z)
self.kbar = self.conv_k(z)
self.qbar = self.conv_q(z)

# truncate to desired shape
for dim in range(len(shape)):
@@ -143,28 +181,15 @@ def forward(

indices = [slice(batchsize*self.num_realizations),
slice(self.num_heads*self.in_features)] + [slice(k, k+s, 1), ]
pe_k = pe_k[indices]
pe_q = pe_q[indices]
self.qbar = self.qbar[indices]
self.kbar = self.kbar[indices]

# making (batchsize, num_realizations, num_heads, keys_dim, *shape)
pe_k = pe_k.view(batchsize, self.num_realizations,
self.kbar = self.kbar.view(batchsize, self.num_realizations,
self.num_heads, self.in_features, *original_shape)
pe_q = pe_q.view(batchsize, self.num_realizations,
self.qbar = self.qbar.view(batchsize, self.num_realizations,
self.num_heads, self.in_features, *original_shape)

# sum over d after multiplying by queries and keys
qhat = (pe_q * queries[:, None]).sum(axis=3)
khat = (pe_k * keys[:, None]).sum(axis=3)

# qhat is (batchsize, num_realizations, num_heads, *shape); permute to (batchsize, *shape, num_heads, num_realizations)
qhat = qhat.permute(0, *[x for x in range(3, self.ndim+3)], 2, 1)
khat = khat.permute(0, *[x for x in range(3, self.ndim+3)], 2, 1)

# concatenate with the non-positional part of keys and queries
qhat = torch.cat([qhat, queries_rest], dim=-1)
khat = torch.cat([khat, keys_rest], dim=-1)

return qhat, khat


class SineSPE(nn.Module):
@@ -180,6 +205,10 @@ class SineSPE(nn.Module):
num_realizations: The number of realizations of the stochastic
process (R).
num_sines: The number of sin and cos components (K).
key_shape: The expected shape of keys and queries. Needs to be
set either here, or by calling `reset()`.
share_in_batch: Whether to share the same set of
positional encodings for all examples in the batch.
"""

def __init__(
@@ -188,6 +217,8 @@ def __init__(
in_features: int = 64,
num_realizations: int = 256,
num_sines: int = 10,
key_shape: Optional[Tuple[int, ...]] = None,
Owner Author: looks a bit redundant to me, what about a max_length?

Collaborator: Fixed

share_in_batch: bool = True,
):
super(SineSPE, self).__init__()

@@ -210,6 +241,32 @@ def __init__(
)
)

# bias initial frequencies to low values to favor long-range dependencies
self.freqs.data[...] -= 5.

# reset qbar and kbar
self.reset(key_shape, share_in_batch)

def reset(self,
key_shape: Tuple[int, ...],
Owner Author: I don't really like that we need to provide the key_shape each time; can't we do otherwise?

Collaborator: Fixed

share_in_batch: Optional[bool] = None):
"""
Reset positional encodings.

At training, this is typically done for each new batch.
At testing, this is typically never done.

Args:
key_shape: The expected shape of keys and queries.
share_in_batch: Whether to share the same set of
positional encodings for all examples in the batch.
"""
self.qbar = None
self.kbar = None
self._key_shape = key_shape
if share_in_batch is not None:
self._share_in_batch = share_in_batch
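
A small usage sketch of this flow (illustrative only; num_heads and the concrete sizes are assumptions, and `queries`, `keys`, `train_loader` are placeholder names): key_shape can be given once at construction, or passed to reset() for each new batch, and share_in_batch controls whether a single set of positional encodings is drawn for the whole batch or one per example.

spe = SineSPE(num_heads=8, in_features=64, num_realizations=256,
              num_sines=10, key_shape=(16, 512))   # (batchsize, length); values made up
qhat, khat = spe(queries, keys)                    # uses encodings drawn lazily for key_shape

# or refresh the encodings for each new batch
for queries, keys in train_loader:
    spe.reset(key_shape=queries.shape)
    qhat, khat = spe(queries, keys)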

def forward(
self, queries: torch.Tensor, keys: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -224,9 +281,6 @@ def forward(
"As of current implementation, queries and keys must have the same shape. "\
"got queries: {} and keys: {}".format(queries.shape, keys.shape)

batchsize = queries.shape[0]
length = queries.shape[1]

if queries.shape[-1] < self.in_features:
raise ValueError('Expected keys/queries of dimension at least '
                 f'{self.in_features}, got {queries.shape[-1]}.')
@@ -235,6 +289,39 @@ def forward(
queries, queries_rest = _split_features(queries, self.in_features)
keys, keys_rest = _split_features(keys, self.in_features)

if self.qbar is None:
self._draw_noise()
desired_shape = (*queries.shape, self.num_realizations)
if self.qbar.shape[2:] != desired_shape[2:]:
raise RuntimeError(f'Positional encodings have shape {self.qbar.shape}, '
f'but expected {desired_shape} '
f'(queries have shape {queries.shape})')
length = queries.shape[1]
if self.qbar.shape[1] < length:
raise RuntimeError(f'Positional encodings have length {self.qbar.shape[1]}, '
f'but expected at least {length}')

# sum over the keys_dim after multiplying by queries and keys
qhat = (self.qbar[:, :length] * queries[..., None]).sum(axis=-2)
khat = (self.kbar[:, :length] * keys[..., None]).sum(axis=-2)

# concatenate with the non-positional part of keys and queries
qhat = torch.cat([qhat, queries_rest], dim=-1)
khat = torch.cat([khat, keys_rest], dim=-1)

return qhat, khat

def _draw_noise(self):
"""
Generate the random QBar and Kbar, depending on the parameters,
and store them in the module.
"""
if self._key_shape is None:
raise RuntimeError('`key_shape` is not set. Please call `reset()` first')

batchsize = 1 if self._share_in_batch else self._key_shape[0]
length = self._key_shape[1]

# build omega_q and omega_k,
# with shape (num_heads, keys_dim, length, 2*num_sines)
indices = torch.linspace(0, length-1, length, device=self.freqs.device)
@@ -272,16 +359,15 @@ def forward(
# z is still (batchsize, num_heads, keys_dim, 2*num_sines, num_realizations)
z = z * gains[None, ..., None]

# multiplying z with omega, summing over the sines (l),
# multiplying with queries, summing over key_dim (r)
qhat = torch.einsum('bhdlr,hdnl,bnhd->bnhr', z, omega_q, queries)
khat = torch.einsum('bhdlr,hdnl,bnhd->bnhr', z, omega_k, keys)
# computing the sum over the sines.
# gets (batchsize, num_heads, keys_dim, length, num_realizations)
self.qbar = torch.matmul(omega_q[None], z)
self.kbar = torch.matmul(omega_k[None], z)

# concatenate with the non-positional part of keys and queries
qhat = torch.cat([qhat, queries_rest], dim=-1)
khat = torch.cat([khat, keys_rest], dim=-1)
# permuting them to be (batchsize, length, num_heads, keys_dim, num_realizations)
self.qbar = self.qbar.permute(0, 3, 1, 2, 4)
self.kbar = self.kbar.permute(0, 3, 1, 2, 4)

return qhat, khat
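
A standalone check (made-up sizes, independent of this diff) that the stored matmul plus the later contraction in forward reproduces the original einsum: the matmul contracts the 2*num_sines axis, giving qbar of shape (batchsize, num_heads, keys_dim, length, num_realizations); after the permute and the sum over keys_dim in forward, the result matches the old 'bhdlr,hdnl,bnhd->bnhr' contraction.

import torch

B, H, D, L, K2, R = 2, 4, 16, 32, 20, 8    # made-up sizes; K2 = 2 * num_sines
z = torch.randn(B, H, D, K2, R)            # (batchsize, num_heads, keys_dim, 2*num_sines, num_realizations)
omega_q = torch.randn(H, D, L, K2)         # (num_heads, keys_dim, length, 2*num_sines)
queries = torch.randn(B, L, H, D)          # (batchsize, length, num_heads, keys_dim)

# old formulation: one einsum per forward pass
qhat_old = torch.einsum('bhdlr,hdnl,bnhd->bnhr', z, omega_q, queries)

# new formulation: precompute qbar once in _draw_noise, contract with queries in forward
qbar = torch.matmul(omega_q[None], z)                 # -> (B, H, D, L, R)
qbar = qbar.permute(0, 3, 1, 2, 4)                    # -> (B, L, H, D, R)
qhat_new = (qbar * queries[..., None]).sum(axis=-2)   # -> (B, L, H, R)

assert torch.allclose(qhat_old, qhat_new, atol=1e-4)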


def _split_features(x: torch.Tensor, num_positional: int) -> torch.Tensor: