In [1]:
import os
os.environ["XDG_CACHE_HOME"] = "/home/olab/tomerronen1/xdg_cache"
os.environ["AUTH_TOKEN"] = "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"

In [2]:
import sled
from transformers import AutoModelForSeq2SeqLM, AutoConfig
model_name = "tau/bart-base-sled-govreport"
auth_token = "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"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=auth_token)

# from transformers.models.bart.modeling_bart import BartModel


There were unexpected keys in the checkpoint model loaded: ['_lm_head.weight'].


In [3]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)

In [4]:
import torch
from torch import Tensor


class GradCache:
    def __init__(self):
        self.cache = []

    def __call__(self, grad: Tensor) -> None:
        if (grad != 0).any():
            grad = grad.cpu().detach().clone()
            self.cache.append(grad)


In [5]:
from datasets import load_dataset
dataset = load_dataset("ccdv/govreport-summarization", split="validation[:10]")

No config specified, defaulting to: gov_report_summarization_dataset/document
Reusing dataset gov_report_summarization_dataset (/home/olab/tomerronen1/xdg_cache/huggingface/datasets/ccdv___gov_report_summarization_dataset)/document/1.0.0/57ca3042de9c40c218cc94084cbc80a99a161036134bfc88112c57d251443590)


In [6]:
batch = tokenizer(dataset[:2]["report"], return_tensors='pt', padding=True, truncation=True, max_length=200)
input_ids = batch["input_ids"]
attention_mask=batch["attention_mask"]
batch_size = input_ids.shape[0]
num_input_tokens = input_ids.shape[1]
gen_batch = model.generate(**batch, num_beams=1, max_length=30)
decoder_attention_mask = (gen_batch != tokenizer.pad_token_id).long()
print(tokenizer.batch_decode(gen_batch))



['</s><s>Part of the Mariana Islands Archipelago, the CNMI is a chain of 14 islands in the western Pacific Ocean, just north</s>', '</s><s>The U.S. pipeline network includes both interstate and intrastate pipelines, the vast majority of which fall into the latter category:</s>']


In [7]:
smoothing_factor = 0.15
smoothing_num_samples = 2
embedding_layer = model._underlying_model.model.shared
inputs_embeds = embedding_layer(input_ids)
inputs_embeds = torch.repeat_interleave(inputs_embeds, repeats=smoothing_num_samples, dim=0)
std_range = inputs_embeds.max(dim=-1, keepdims=True).values - inputs_embeds.min(dim=-1, keepdims=True).values
noise = torch.normal(torch.zeros_like(inputs_embeds), torch.ones_like(inputs_embeds) * smoothing_factor * std_range)
inputs_embeds = inputs_embeds + noise

attention_mask = torch.repeat_interleave(attention_mask, repeats=smoothing_num_samples, dim=0)
decoder_input_ids = torch.repeat_interleave(gen_batch, repeats=smoothing_num_samples, dim=0)
decoder_attention_mask = torch.repeat_interleave(decoder_attention_mask, repeats=smoothing_num_samples, dim=0)
embed_dim = inputs_embeds.shape[-1]

In [8]:
inputs_embeds = inputs_embeds.unbind()
caches = []
for sequence_embeds in inputs_embeds:
    grad_cache = GradCache()
    sequence_embeds.register_hook(grad_cache)
    caches.append(grad_cache)
inputs_embeds = torch.stack(inputs_embeds)

In [9]:
model_output = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                     decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask)

In [10]:
logits = model_output.logits[:,:-2,:]  # without additional token [-1], without the token that replaces the forced EOS [-2].
generated_tokens = gen_batch[:,1:-1]  # without the EOS that starts the generation, without the forced EOS at the end.
num_target_tokens = generated_tokens.shape[1]
generated_tokens = torch.repeat_interleave(generated_tokens, repeats=smoothing_num_samples, dim=0)

if (smoothing_factor == 0) and (smoothing_num_samples == 1):
    manual_greedy_output = logits.argmax(-1)
    assert (manual_greedy_output == generated_tokens).all()

In [11]:
logprobs = logits.log_softmax(dim=-1)  # more correct, but then every word in the dictionary participates - is that good?
# logprobs = logits
generated_logprobs = torch.gather(input=logprobs, dim=-1, index=generated_tokens.unsqueeze(-1)).squeeze(-1)

if (smoothing_factor == 0) and (smoothing_num_samples == 1):
    manual_greedy_generated_logprobs = logprobs.max(dim=-1).values
    assert (generated_logprobs == manual_greedy_generated_logprobs).all()

In [12]:
from tqdm import tqdm
# i_target_token = 6
# i_target_token = 7
# i_target_token = "entire_sequence"
i_target_token = "all_individual_target_tokens"
if i_target_token == "all_individual_target_tokens":
    for i_example in tqdm(range(generated_logprobs.shape[0])):
        for i_token in range(generated_logprobs.shape[1]):
            generated_logprobs[i_example][i_token].backward(retain_graph=True)
            model.zero_grad()
else:
    if i_target_token == "entire_sequence":
        logprobs_to_derivate = generated_logprobs.sum(dim=-1)
        for gen_seq in tokenizer.batch_decode(generated_tokens):
            print(gen_seq)
    else:
        logprobs_to_derivate = generated_logprobs[:,i_target_token]
        print(tokenizer.convert_ids_to_tokens(generated_tokens[:,i_target_token]))

    for i in range(batch_size):
        logprobs_to_derivate[i].backward(retain_graph=True)
        model.zero_grad()


100%|██████████| 4/4 [00:58<00:00, 14.67s/it]


In [28]:
grads = torch.stack([torch.stack(grad_cache.cache) for grad_cache in caches]) # bsz * smooth_samples, target tokens, input tokens, embed dim
grads = grads.view(batch_size, smoothing_num_samples, num_target_tokens, num_input_tokens, embed_dim)
grads = grads.mean(dim=1)  # avg over noisy samples

In [29]:
torch.repeat_interleave(torch.tensor([
    [1,2,3,4],
    [5,6,7,8],
]), dim=0, repeats=3).view(2,3,-1)[1]

tensor([[5, 6, 7, 8],
        [5, 6, 7, 8],
        [5, 6, 7, 8]])

In [30]:
# from functools import reduce
saliency = grads.abs().sum(dim=-1)  # abd grad method. bsz, target tokens, input tokens
# saliency = (grads * inputs_embeds.unsqueeze(1)).sum(dim=-1)  # grad dot input method. bsz, target tokens, input tokens
saliency = saliency - saliency.min(dim=-1, keepdims=True).values
saliency = saliency / saliency.max(dim=-1, keepdims=True).values
# saliency = grad_dot_input.softmax(dim=-1)

In [31]:
import pandas as pd
import matplotlib.pyplot as plt

cols = {}
i_example = 0
for i_target_token in range(saliency.shape[1]):
    input_tokens = tokenizer.convert_ids_to_tokens(input_ids[i_example])
    input_tokens = pd.Series(input_tokens).str.replace('Ġ', '_').replace('<s>', 'BOS')
    target_token = tokenizer.convert_ids_to_tokens([generated_tokens[i_example][i_target_token]])[0].replace('Ġ', '_').replace('<s>', 'BOS')
    cols[f"Input Tokens_{target_token}"] = input_tokens
    cols[f"Saliency_{target_token}"] = saliency[i_example][i_target_token].detach().cpu().numpy()

saliency_df = pd.DataFrame(cols)
cmap = plt.cm.get_cmap("coolwarm")
saliency_df = saliency_df.style.background_gradient(cmap=cmap)

print(tokenizer.decode(input_ids[i_example]))
print()
print(tokenizer.decode(generated_tokens[i_example]))
saliency_df

<s>Part of the Mariana Islands Archipelago, the CNMI is a chain of 14 islands in the western Pacific Ocean, just north of Guam and about 3,200 miles west of Hawaii. The CNMI has a total population of 53,890, according to the CNMI’s 2016 Household, Income, and Expenditures Survey. Almost 90 percent of the population (48,200) resided on the island of Saipan, with an additional 6 percent (3,056) on the island of Tinian and 5 percent (2,635) on the island of Rota. The Consolidated Natural Resources Act of 2008 amended the U.S.– CNMI covenant to apply federal immigration law to the CNMI after a transition period. To provide for an orderly transition from the CNMI immigration system to the U.S. federal immigration system under the immigration laws of the United States, DHS established the CW program in 2011. Under the program, foreign workers are</s>

<s>Part of the Mariana Islands Archipelago, the CNMI is a chain of 14 islands in the western Pacific Ocean, just north


Unnamed: 0,Input Tokens_BOS,Saliency_BOS,Input Tokens_Part,Saliency_Part,Input Tokens__of,Saliency__of,Input Tokens__the,Saliency__the,Input Tokens__Mar,Saliency__Mar,Input Tokens_iana,Saliency_iana,Input Tokens__Islands,Saliency__Islands,Input Tokens__Arch,Saliency__Arch,Input Tokens_ipel,Saliency_ipel,Input Tokens_ago,Saliency_ago,"Input Tokens_,","Saliency_,",Input Tokens__CN,Saliency__CN,Input Tokens_MI,Saliency_MI,Input Tokens__is,Saliency__is,Input Tokens__a,Saliency__a,Input Tokens__chain,Saliency__chain,Input Tokens__14,Saliency__14,Input Tokens__islands,Saliency__islands,Input Tokens__in,Saliency__in,Input Tokens__western,Saliency__western,Input Tokens__Pacific,Saliency__Pacific,Input Tokens__Ocean,Saliency__Ocean,Input Tokens__just,Saliency__just,Input Tokens__north,Saliency__north
0,BOS,0.055623,BOS,0.059126,BOS,0.060221,BOS,0.045772,BOS,0.08166,BOS,0.028551,BOS,0.016642,BOS,0.056559,BOS,0.019967,BOS,0.039828,BOS,0.118984,BOS,0.05062,BOS,0.122932,BOS,0.050794,BOS,0.113758,BOS,0.049379,BOS,0.035409,BOS,0.053621,BOS,0.104222,BOS,0.053328,BOS,0.020077,BOS,0.044085,BOS,0.071659,BOS,0.040978
1,Part,0.026105,Part,1.0,Part,0.093235,Part,0.09072,Part,1.0,Part,0.253947,Part,0.052546,Part,0.504376,Part,0.09695,Part,0.213389,Part,0.173612,Part,0.258237,Part,0.229512,Part,0.251115,Part,0.526558,Part,0.338969,Part,0.083182,Part,0.205438,Part,0.181776,Part,0.084664,Part,0.031056,Part,0.112768,Part,0.221136,Part,0.086726
2,_of,0.047135,_of,0.091107,_of,0.021457,_of,0.025913,_of,0.115611,_of,0.057366,_of,0.020859,_of,0.192801,_of,0.025735,_of,0.028027,_of,0.036071,_of,0.078806,_of,0.087427,_of,0.032872,_of,0.105001,_of,0.045494,_of,0.010526,_of,0.032575,_of,0.036387,_of,0.032294,_of,0.012404,_of,0.016765,_of,0.117177,_of,0.014918
3,_the,0.030875,_the,0.025073,_the,0.041504,_the,0.050934,_the,0.050066,_the,0.012318,_the,0.041382,_the,0.321956,_the,0.292343,_the,0.746729,_the,0.11562,_the,0.046478,_the,0.098197,_the,0.025572,_the,0.066953,_the,0.058924,_the,0.031219,_the,0.044685,_the,0.341187,_the,0.120956,_the,0.055941,_the,0.075912,_the,0.18874,_the,0.035955
4,_Mar,0.05362,_Mar,0.099316,_Mar,0.065155,_Mar,0.148737,_Mar,0.551219,_Mar,0.439107,_Mar,0.452197,_Mar,0.249629,_Mar,0.059764,_Mar,0.226125,_Mar,0.241216,_Mar,0.361684,_Mar,0.372786,_Mar,0.171917,_Mar,0.375128,_Mar,0.242696,_Mar,0.058592,_Mar,0.078633,_Mar,0.178686,_Mar,0.38866,_Mar,0.167409,_Mar,0.16535,_Mar,0.427631,_Mar,0.039416
5,iana,0.031645,iana,0.078688,iana,0.107273,iana,0.128624,iana,0.178861,iana,1.0,iana,1.0,iana,0.492304,iana,0.176484,iana,0.357944,iana,0.136387,iana,0.221174,iana,0.470318,iana,0.255461,iana,0.299796,iana,0.139274,iana,0.048726,iana,0.119465,iana,0.192232,iana,0.198886,iana,0.065393,iana,0.114694,iana,0.3023,iana,0.040141
6,_Islands,0.036527,_Islands,0.064058,_Islands,0.106836,_Islands,0.141618,_Islands,0.136552,_Islands,0.127849,_Islands,0.786485,_Islands,0.561701,_Islands,0.065291,_Islands,0.345098,_Islands,0.228501,_Islands,0.251986,_Islands,0.441142,_Islands,0.241609,_Islands,0.495553,_Islands,0.134648,_Islands,0.10506,_Islands,0.109915,_Islands,0.237192,_Islands,0.20036,_Islands,0.355096,_Islands,0.105754,_Islands,0.270137,_Islands,0.043051
7,_Arch,0.0843,_Arch,0.122685,_Arch,0.0874,_Arch,0.104262,_Arch,0.197685,_Arch,0.14034,_Arch,0.311118,_Arch,1.0,_Arch,1.0,_Arch,1.0,_Arch,0.275048,_Arch,0.423831,_Arch,0.469572,_Arch,0.276944,_Arch,0.510483,_Arch,0.169355,_Arch,0.076875,_Arch,0.125716,_Arch,0.195704,_Arch,0.11315,_Arch,0.045841,_Arch,0.09486,_Arch,0.366412,_Arch,0.057554
8,ipel,0.020964,ipel,0.095217,ipel,0.062107,ipel,0.041681,ipel,0.047943,ipel,0.07547,ipel,0.018076,ipel,0.092064,ipel,0.033265,ipel,0.13043,ipel,0.147691,ipel,0.139313,ipel,0.193999,ipel,0.086091,ipel,0.152858,ipel,0.071842,ipel,0.063736,ipel,0.077145,ipel,0.158528,ipel,0.036014,ipel,0.023886,ipel,0.046987,ipel,0.229596,ipel,0.027892
9,ago,0.197436,ago,0.043985,ago,0.077027,ago,0.043061,ago,0.088132,ago,0.058586,ago,0.050521,ago,0.236207,ago,0.60135,ago,0.742649,ago,0.171286,ago,0.277867,ago,0.29608,ago,0.169125,ago,0.282014,ago,0.093621,ago,0.043146,ago,0.082781,ago,0.119767,ago,0.043606,ago,0.032098,ago,0.072664,ago,0.232674,ago,0.047986
