In [None]:
import os

# Cache directory for Hugging Face downloads. Presumably this has to be set before
# `transformers` / `datasets` are imported (they resolve cache paths at import time) — TODO confirm.
# `setdefault` lets another user point it elsewhere through the environment without editing the notebook.
os.environ.setdefault("XDG_CACHE_HOME", "/home/olab/tomerronen1/xdg_cache")

# Never hardcode API tokens in a notebook (anyone who can read the file can use them).
# Prompt for the token only when it is not already provided through the environment.
if not os.environ.get("AUTH_TOKEN"):
    import getpass
    os.environ["AUTH_TOKEN"] = getpass.getpass("Hugging Face auth token: ")

In [None]:
import os

import sled  # imported for its side effects — presumably registers the SLED architecture with the transformers Auto classes; TODO confirm
from transformers import AutoModelForSeq2SeqLM, AutoConfig

device = "cpu"
model_name = "tau/bart-base-sled-govreport"
# Read the token from the environment instead of hardcoding it here. An empty/missing value
# becomes None, i.e. no explicit token is passed to `from_pretrained`.
auth_token = os.environ.get("AUTH_TOKEN") or None
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=auth_token)
# nn.Module.to returns the module itself; assigning it also keeps the cell from echoing the model repr.
model = model.to(device)

In [None]:
from transformers import AutoTokenizer

# Tokenizer from the same checkpoint repo (and with the same credentials) as the model.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_name,
    use_auth_token=auth_token,
)

In [None]:
import torch
from torch import Tensor


class GradCache:
    """Backward hook that records every non-zero gradient it is called with.

    Usage: ``tensor.register_hook(GradCache())``. Each call appends a detached CPU copy of the
    incoming gradient to ``self.cache`` (in call order). Gradients that are zero everywhere are
    ignored. The hook returns None, so the gradient flowing through autograd is left unchanged.
    """

    def __init__(self):
        # Detached CPU copies of the recorded gradients, one per non-zero call.
        self.cache = []

    def __call__(self, grad: Tensor) -> None:
        # Skip all-zero gradients (e.g. from a backward pass that did not depend on this tensor).
        if torch.count_nonzero(grad) == 0:
            return
        # copy=True guarantees a fresh tensor even when `grad` already lives on the CPU.
        self.cache.append(grad.detach().to("cpu", copy=True))


In [None]:
from datasets import load_dataset

# Only the first 10 examples of the GovReport validation split (HF split-slicing syntax).
dataset = load_dataset(
    path="ccdv/govreport-summarization",
    split="validation[:10]",
)

In [None]:
# Tokenize three reports (dataset rows 5, 6, 7), padded to a common length and truncated to 200 tokens.
batch = tokenizer(
    dataset[5:8]["report"],
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=200,
)
input_ids = batch.input_ids.to(device)
attention_mask = batch.attention_mask.to(device)
batch_size, num_input_tokens = input_ids.shape

# Greedy decoding (num_beams=1), at most 30 decoder positions. These generations are later fed
# back as decoder inputs so their token log-probabilities can be attributed to the input.
gen_batch = model.generate(input_ids=input_ids, attention_mask=attention_mask, num_beams=1, max_length=30)
decoder_attention_mask = gen_batch.ne(tokenizer.pad_token_id).long()
print(tokenizer.batch_decode(gen_batch))

In [None]:
# SmoothGrad-style input noise. With smoothing_factor = 0 and smoothing_num_samples = 1 the noise
# is exactly zero and each example appears once (the sanity checks in later cells require this).
# A noisy alternative setting: smoothing_factor = 0.05, smoothing_num_samples = 3.
smoothing_factor = 0
smoothing_num_samples = 1
noise_seed = 0  # fixed seed so the sampled noise (and hence the saliency maps) is reproducible
torch.manual_seed(noise_seed)

embedding_layer = model._underlying_model.model.shared
inputs_embeds = embedding_layer(input_ids)
# repeat_interleave keeps the noisy copies of each example adjacent:
# rows [i * smoothing_num_samples, (i + 1) * smoothing_num_samples) all belong to example i.
inputs_embeds = torch.repeat_interleave(inputs_embeds, repeats=smoothing_num_samples, dim=0)
# Per-token range of the embedding values; the noise std is smoothing_factor times this range.
std_range = inputs_embeds.max(dim=-1, keepdim=True).values - inputs_embeds.min(dim=-1, keepdim=True).values
noise = torch.normal(torch.zeros_like(inputs_embeds), torch.ones_like(inputs_embeds) * smoothing_factor * std_range)
inputs_embeds = inputs_embeds + noise

# Rebuild the repeated masks from their sources rather than repeating the current globals in place:
# the old in-place form re-repeated them every time this cell was re-run.
attention_mask = torch.repeat_interleave(batch["attention_mask"].to(device), repeats=smoothing_num_samples, dim=0)
decoder_input_ids = torch.repeat_interleave(gen_batch, repeats=smoothing_num_samples, dim=0)
decoder_attention_mask = torch.repeat_interleave(
    (gen_batch != tokenizer.pad_token_id).long(), repeats=smoothing_num_samples, dim=0
)
embed_dim = inputs_embeds.shape[-1]

In [None]:
# Split the batch into one tensor per (noisy) sequence and give each its own GradCache hook, so
# the gradient reaching every sequence is recorded separately on each backward pass.
per_sequence_embeds = inputs_embeds.unbind(dim=0)
caches = [GradCache() for _ in per_sequence_embeds]
for seq_embeds, seq_cache in zip(per_sequence_embeds, caches):
    seq_embeds.register_hook(seq_cache)
# Re-assemble the batch from the hooked slices; the model must consume this stacked tensor.
inputs_embeds = torch.stack(per_sequence_embeds)

In [None]:
# Forward pass on the (possibly noisy) hooked input embeddings, with the greedy generations as
# decoder inputs; back-propagating from its logits triggers the GradCache hooks.
forward_inputs = {
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask,
    "decoder_input_ids": decoder_input_ids,
    "decoder_attention_mask": decoder_attention_mask,
}
model_output = model(**forward_inputs)

In [None]:
# Align the logits with the tokens they predict (logits at position j score token j + 1).
# Targets: drop the two tokens that start the generation (EOS + BOS) and the forced final EOS.
# Logits: drop position 0 (it predicts the BOS), position -2 (it predicts the forced EOS) and
# position -1 (it predicts a token beyond the generation).
logits = model_output.logits[:, 1:-2, :]
generated_tokens_no_repeats = gen_batch[:, 2:-1].clone()
num_target_tokens = generated_tokens_no_repeats.shape[1]
generated_tokens = generated_tokens_no_repeats.repeat_interleave(smoothing_num_samples, dim=0)

greedy_check_enabled = smoothing_factor == 0 and smoothing_num_samples == 1
if greedy_check_enabled:
    # Without noise, the argmax of the forward pass must reproduce the greedy generation exactly.
    print("Checking")
    assert logits.argmax(dim=-1).eq(generated_tokens).all()

In [None]:
# Normalize over the whole vocabulary, so each generated token's score also depends on every
# competing token. (Using the raw logits instead would avoid that coupling.)
logprobs = torch.log_softmax(logits, dim=-1)
# Log-probability of the token that was actually generated at each target position.
generated_logprobs = logprobs.gather(-1, generated_tokens[..., None])[..., 0]

if smoothing_factor == 0 and smoothing_num_samples == 1:
    print("Checking")
    # Greedy decoding chose the argmax, so the generated token must hold the maximal log-prob.
    assert (generated_logprobs == logprobs.amax(dim=-1)).all()

In [None]:
from tqdm.auto import tqdm

# Which generated log-probability to back-propagate into the input embeddings:
#   - an int (e.g. 6 or 7): that target position, for every row;
#   - "entire_sequence": the summed log-probability of each generated sequence;
#   - "all_individual_target_tokens": every (row, target token) pair separately.
# Only the last mode records num_target_tokens gradients per sequence, which is the layout the
# gradient-stacking cell below reshapes to.
i_target_token = "all_individual_target_tokens"

if i_target_token == "all_individual_target_tokens":
    # One backward pass per scalar; the GradCache hooks record the input gradient of each pass in order.
    for i_example in tqdm(range(generated_logprobs.shape[0])):
        for i_token in range(generated_logprobs.shape[1]):
            generated_logprobs[i_example, i_token].backward(retain_graph=True)
            model.zero_grad()
else:
    if i_target_token == "entire_sequence":
        logprobs_to_differentiate = generated_logprobs.sum(dim=-1)
        for gen_seq in tokenizer.batch_decode(generated_tokens):
            print(gen_seq)
    else:
        logprobs_to_differentiate = generated_logprobs[:, i_target_token]
        print(tokenizer.convert_ids_to_tokens(generated_tokens[:, i_target_token]))

    # There are batch_size * smoothing_num_samples rows (one per noisy copy). Iterating only up to
    # batch_size silently skipped the remaining noisy samples.
    for i in range(logprobs_to_differentiate.shape[0]):
        logprobs_to_differentiate[i].backward(retain_graph=True)
        model.zero_grad()


In [None]:
# Each GradCache holds one [num_input_tokens, embed_dim] gradient per backward pass that gave its
# sequence a non-zero gradient. The per-target-token backward mode normally yields num_target_tokens
# of them per sequence; any other count cannot be reshaped below, so fail with a readable message.
cache_lengths = [len(grad_cache.cache) for grad_cache in caches]
if any(length != num_target_tokens for length in cache_lengths):
    raise ValueError(
        f"expected {num_target_tokens} cached gradients per sequence, got {cache_lengths}; "
        "run the backward cell with i_target_token = 'all_individual_target_tokens'"
    )
# [bsz * smoothing_num_samples, target tokens, input tokens, embed dim]
grads = torch.stack([torch.stack(grad_cache.cache) for grad_cache in caches])
# The noisy copies of each example are adjacent (repeat_interleave), so the leading dim splits
# cleanly into [bsz, smoothing_num_samples].
grads = grads.view(batch_size, smoothing_num_samples, num_target_tokens, num_input_tokens, embed_dim)

In [None]:
import torch

# Reduce every gradient vector (over the embedding dim) to one scalar per (target, input) token pair.
grad_to_scalar_method = "l1_norm"
grad_to_scalar_factor = embed_dim  # constant rescaling; cancels out under min-max normalization
if grad_to_scalar_method == "l1_norm":
    saliency = grads.abs().sum(dim=-1)
elif grad_to_scalar_method == "grad_dot_input":
    # inputs_embeds is [bsz * smoothing_num_samples, input tokens, embed dim] with each example's
    # noisy copies adjacent. Reshape it to line up with grads
    # [bsz, smoothing_num_samples, target tokens, input tokens, embed dim]. A bare unsqueeze(1)
    # broadcast every example's gradients against every other example's embeddings.
    aligned_inputs_embeds = inputs_embeds.detach().to(grads.device).reshape(
        batch_size, smoothing_num_samples, 1, num_input_tokens, embed_dim
    )
    saliency = (grads * aligned_inputs_embeds).sum(dim=-1)
elif grad_to_scalar_method == "l2_norm":
    # NOTE: this is the *squared* L2 norm (no square root is taken).
    saliency = (grads ** 2).sum(dim=-1)
else:
    raise ValueError(f"Unknown grad_to_scalar_method: {grad_to_scalar_method!r}")
saliency = saliency / grad_to_scalar_factor

# saliency shape before smoothing: [bsz, smoothing_num_samples, target tokens, input tokens]
saliency = saliency.mean(dim=1)  # average over noisy samples -> [bsz, target tokens, input tokens]

# Alternative configurations:
#   ("softmax_over_entire_sequence", "sum", "none")
#   ("log_softmax_per_target_token", "max", "none")
#   ("min_max_per_target_token", "max", "divide_by_sum")
normalization_method, agg_method, post_agg_normalization = "min_max_per_target_token", "max", "none"

if normalization_method == "min_max_per_target_token":
    saliency = saliency - saliency.min(dim=-1, keepdim=True).values
    # The clamp avoids 0/0 = NaN for a target token whose saliency is equal across all input tokens.
    saliency = saliency / saliency.max(dim=-1, keepdim=True).values.clamp_min(torch.finfo(saliency.dtype).tiny)
elif normalization_method == "log_softmax_per_target_token":
    saliency = saliency.log_softmax(dim=-1)
elif normalization_method == "softmax_over_entire_sequence":
    temperature = 1.
    flat_saliency = saliency.reshape(batch_size, -1)
    saliency = (flat_saliency / temperature).softmax(dim=-1).reshape(batch_size, num_target_tokens, num_input_tokens)
else:
    raise ValueError(f"Unknown normalization_method: {normalization_method!r}")

# Aggregate over target tokens -> [bsz, input tokens]
if agg_method == "max":
    saliency_agg = saliency.max(dim=1).values
elif agg_method == "sum":
    saliency_agg = saliency.sum(dim=1)
else:
    raise ValueError(f"Unknown agg_method: {agg_method!r}")

if post_agg_normalization == "none":
    pass
elif post_agg_normalization == "softmax":
    saliency_agg = saliency_agg.softmax(dim=-1)
elif post_agg_normalization in ("divide_by_sum", "devide_by_sum"):  # old misspelled key still accepted
    saliency_agg = saliency_agg / saliency_agg.sum(dim=-1, keepdim=True)
else:
    raise ValueError(f"Unknown post_agg_normalization: {post_agg_normalization!r}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

i_example = 2  # index (within the batch) of the example to display
assert 0 <= i_example < batch_size, f"i_example must be in [0, {batch_size})"


def _pretty_token(token: str) -> str:
    """Return ``token`` with 'Ġ' replaced by '_' and '<s>' replaced by 'BOS', for display."""
    return token.replace('Ġ', '_').replace('<s>', 'BOS')


input_tokens = [_pretty_token(tok) for tok in tokenizer.convert_ids_to_tokens(input_ids[i_example].tolist())]
target_tokens = [
    _pretty_token(tok)
    for tok in tokenizer.convert_ids_to_tokens(generated_tokens_no_repeats[i_example].tolist())
]

cols = {
    "Input Token MaxPool": input_tokens,
    "Saliency Agg": saliency_agg[i_example].detach().cpu().numpy(),
}
# A separate loop variable keeps the backward cell's `i_target_token` setting from being clobbered.
for i_target, target_token in enumerate(target_tokens):
    # Include the target position in the column names. The same token can be generated more than
    # once, and columns keyed on the token alone silently overwrote each other.
    cols[f"Input Tokens_{i_target}_{target_token}"] = input_tokens
    cols[f"Saliency_{i_target}_{target_token}"] = saliency[i_example, i_target].detach().cpu().numpy()

saliency_df = pd.DataFrame(cols)
# `plt.cm.get_cmap` (matplotlib.cm.get_cmap) was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is the supported replacement.
saliency_styler = saliency_df.style.background_gradient(cmap=plt.get_cmap("coolwarm"))

print(tokenizer.decode(input_ids[i_example]))
print()
print(tokenizer.decode(generated_tokens_no_repeats[i_example]))
saliency_styler