Skip to content

Commit

Permalink
compute importance measure in mini-batches
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreasMadsen committed Feb 20, 2021
1 parent 2bb494d commit 083ef34
Showing 1 changed file with 39 additions and 29 deletions.
68 changes: 39 additions & 29 deletions comp550/dataset/roar.py
Expand Up @@ -87,64 +87,74 @@ def collate(self, observations):
def uncollate(self, observations):
return self._base_dataset.uncollate(observations)

def _importance_measure_random(self, observation):
return torch.tensor(self._rng.rand(*observation['sentence'].shape))
def _importance_measure_random(self, batch):
return torch.tensor(self._rng.rand(*batch['sentence'].shape))

def _importance_measure_attention(self, observation):
def _importance_measure_attention(self, batch):
with torch.no_grad():
_, alpha = self._model(self.collate([observation]))
return torch.squeeze(alpha, dim=0)
_, alpha = self._model(batch)
return torch.squeeze(alpha)

def _importance_measure_gradient(self, observation):
def _importance_measure_gradient(self, batch):
# Make a shallow copy, because batch['sentence'] will be overwritten
batch = batch.copy()

# TODO: Ensure that the padding is zero. In theory we don't need padding.
# Setup batch to be a one-hot encoded float32 with require_grad. This is neccesary
# as torch does not allow computing grad w.r.t. to an int-tensor.
batch = self.collate([observation])
batch['sentence'] = torch.nn.functional.one_hot(batch['sentence'], len(self.vocabulary))
batch['sentence'] = batch['sentence'].type(torch.float32)
batch['sentence'].requires_grad = True

# Compute model
y, _ = self._model(batch)

# Compute gradient
yc = y[0, observation['label']]
yc_wrt_x, = torch.autograd.grad(yc, (batch['sentence'], ))
# Select correct label, as we would like gradient of y[correct_label] w.r.t. x
yc = y[torch.arange(len(batch['label'])), batch['label']]
# autograd.grad must take a scalar, however we would like $d y_{i,c}/d x_i$
# to be computed as a batch, meaning for each $i$. To work around this,
# use that for $g(x) = \sum_i f(x_i)$, we have $d g(x)/d x_{x_i} = d f(x_i)/d x_{x_i}$.
# The gradient of the sum, is therefore equivalent to the batch_gradient.
yc_wrt_x, = torch.autograd.grad(torch.sum(yc, axis=0), (batch['sentence'], ))

# Normalize the vector-gradient per token into one scalar
return torch.norm(torch.squeeze(yc_wrt_x, 0), 2, dim=1)
return torch.norm(yc_wrt_x, 2, dim=2)

def _importance_measure_integrated_gradient(self, observation):
# Implement as x .* (1/k) .* sum([f'((i/k) .* x) for i in range(1, k+1))
pass

def _mask_observation(self, observation):
importance = self._importance_measure_fn(observation)
def _mask_batch(self, batch):
batch_importance = self._importance_measure_fn(batch)

masked_batch = []
with torch.no_grad():
# Prevent masked tokens from being "removed"
importance[torch.logical_not(observation['mask'])] = -np.inf
for importance, observation in zip(batch_importance, self.uncollate(batch)):
# Trim importance to the observation length
importance = importance[0:len(observation['sentence'])]

# Ensure that already "removed" tokens continues to be "removed"
importance[observation['sentence'] == self.tokenizer.mask_token_id] = np.inf
# Prevent masked tokens from being "removed"
importance[torch.logical_not(observation['mask'])] = -np.inf

# Tokens to remove.
# Ensure that k does not exceed the number of un-masked tokens, if it does
# masked tokens will be "removed" too.
k = torch.minimum(torch.tensor(self._k), torch.sum(observation['mask']))
_, remove_indices = torch.topk(importance, k=k, sorted=False)
# Ensure that already "removed" tokens continues to be "removed"
importance[observation['sentence'] == self.tokenizer.mask_token_id] = np.inf

# "Remove" top-k important tokens
observation['sentence'][remove_indices] = self.tokenizer.mask_token_id
# Tokens to remove.
# Ensure that k does not exceed the number of un-masked tokens, if it does
# masked tokens will be "removed" too.
k = torch.minimum(torch.tensor(self._k), torch.sum(observation['mask']))
_, remove_indices = torch.topk(importance, k=k, sorted=False)

return observation
# "Remove" top-k important tokens
observation['sentence'][remove_indices] = self.tokenizer.mask_token_id
masked_batch.append(observation)

return masked_batch

def _mask_dataset(self, dataloader, name):
outputs = []
for batched_observation in tqdm(dataloader(batch_size=1, num_workers=0, shuffle=False),
desc=f'Building {name} dataset', leave=False):
outputs.append(self._mask_observation(self.uncollate(batched_observation)[0]))
for batch in tqdm(dataloader(batch_size=self.batch_size, num_workers=0, shuffle=False),
desc=f'Building {name} dataset', leave=False):
outputs += self._mask_batch(batch)
return outputs

def prepare_data(self):
Expand Down

0 comments on commit 083ef34

Please sign in to comment.