In [1]:
import os
os.environ["XDG_CACHE_HOME"] = "/home/olab/tomerronen1/xdg_cache"
os.environ["AUTH_TOKEN"] = "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"

In [2]:
import sled
from transformers import AutoModelForSeq2SeqLM, AutoConfig
device = "cuda"
model_name = "tau/bart-base-sled-govreport"
auth_token = "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"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=auth_token)
model.to(device)
# from transformers.models.bart.modeling_bart import BartModel
pass

There were unexpected keys in the checkpoint model loaded: ['_lm_head.weight'].


In [3]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)

In [4]:
import torch
from torch import Tensor


class GradCache:
    def __init__(self):
        self.cache = []

    def __call__(self, grad: Tensor) -> None:
        if (grad != 0).any():
            grad = grad.cpu().detach().clone()
            self.cache.append(grad)


In [5]:
from datasets import load_dataset
dataset = load_dataset("ccdv/govreport-summarization", split="validation[:10]")

No config specified, defaulting to: gov_report_summarization_dataset/document
Reusing dataset gov_report_summarization_dataset (/home/olab/tomerronen1/xdg_cache/huggingface/datasets/ccdv___gov_report_summarization_dataset)/document/1.0.0/57ca3042de9c40c218cc94084cbc80a99a161036134bfc88112c57d251443590)


In [6]:
batch = tokenizer(dataset[5:8]["report"], return_tensors='pt', padding=True, truncation=True, max_length=200)
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
batch_size = input_ids.shape[0]
num_input_tokens = input_ids.shape[1]
gen_batch = model.generate(input_ids=input_ids, attention_mask=attention_mask, num_beams=1, max_length=30)
decoder_attention_mask = (gen_batch != tokenizer.pad_token_id).long()
print(tokenizer.batch_decode(gen_batch))



['</s><s>Among health care programs, Medicaid is the largest as measured by enrollment (over 73 million in fiscal year 2017) and the second largest as</s>', "</s><s>Why GAO Did This Study\n\nGAO's simulations suggest that the state and local government sector will likely continue to face a difference</s>", '</s><s>Why GAO Did This Study\n\nThe Department of Defense (DOD) guidance states that the Air Force and other services are responsible</s>']


In [7]:
smoothing_factor = 0.05
smoothing_num_samples = 3
embedding_layer = model._underlying_model.model.shared
inputs_embeds = embedding_layer(input_ids)
inputs_embeds = torch.repeat_interleave(inputs_embeds, repeats=smoothing_num_samples, dim=0)
std_range = inputs_embeds.max(dim=-1, keepdims=True).values - inputs_embeds.min(dim=-1, keepdims=True).values
noise = torch.normal(torch.zeros_like(inputs_embeds), torch.ones_like(inputs_embeds) * smoothing_factor * std_range)
inputs_embeds = inputs_embeds + noise

attention_mask = torch.repeat_interleave(attention_mask, repeats=smoothing_num_samples, dim=0)
decoder_input_ids = torch.repeat_interleave(gen_batch, repeats=smoothing_num_samples, dim=0)
decoder_attention_mask = torch.repeat_interleave(decoder_attention_mask, repeats=smoothing_num_samples, dim=0)
embed_dim = inputs_embeds.shape[-1]

In [8]:
inputs_embeds = inputs_embeds.unbind()
caches = []
for sequence_embeds in inputs_embeds:
    grad_cache = GradCache()
    sequence_embeds.register_hook(grad_cache)
    caches.append(grad_cache)
inputs_embeds = torch.stack(inputs_embeds)

In [9]:
model_output = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                     decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask)

In [10]:
logits = model_output.logits[:,1:-2,:]  # without additional generated token [-1], without the token that replaces the forced EOS [-2].
generated_tokens = gen_batch[:,2:-1]  # without the EOS+BOS that start the generation, without the forced EOS at the end.
num_target_tokens = generated_tokens.shape[1]
generated_tokens_no_repeats = generated_tokens.clone()
generated_tokens = torch.repeat_interleave(generated_tokens, repeats=smoothing_num_samples, dim=0)

if (smoothing_factor == 0) and (smoothing_num_samples == 1):
    print("Checking")
    manual_greedy_output = logits.argmax(-1)
    assert (manual_greedy_output == generated_tokens).all()

In [11]:
logprobs = logits.log_softmax(dim=-1)  # more correct, but then every word in the dictionary participates - is that good?
# logprobs = logits
generated_logprobs = torch.gather(input=logprobs, dim=-1, index=generated_tokens.unsqueeze(-1)).squeeze(-1)

if (smoothing_factor == 0) and (smoothing_num_samples == 1):
    print("Checking")
    manual_greedy_generated_logprobs = logprobs.max(dim=-1).values
    assert (generated_logprobs == manual_greedy_generated_logprobs).all()

In [12]:
from tqdm import tqdm
# i_target_token = 6
# i_target_token = 7
# i_target_token = "entire_sequence"
i_target_token = "all_individual_target_tokens"
if i_target_token == "all_individual_target_tokens":
    for i_example in tqdm(range(generated_logprobs.shape[0])):
        for i_token in range(generated_logprobs.shape[1]):
            generated_logprobs[i_example][i_token].backward(retain_graph=True)
            model.zero_grad()
else:
    if i_target_token == "entire_sequence":
        logprobs_to_derivate = generated_logprobs.sum(dim=-1)
        for gen_seq in tokenizer.batch_decode(generated_tokens):
            print(gen_seq)
    else:
        logprobs_to_derivate = generated_logprobs[:,i_target_token]
        print(tokenizer.convert_ids_to_tokens(generated_tokens[:,i_target_token]))

    for i in range(batch_size):
        logprobs_to_derivate[i].backward(retain_graph=True)
        model.zero_grad()


100%|██████████| 9/9 [00:20<00:00,  2.27s/it]


In [13]:
grads = torch.stack([torch.stack(grad_cache.cache) for grad_cache in caches]) # bsz * smooth_samples, target tokens, input tokens, embed dim
grads = grads.view(batch_size, smoothing_num_samples, num_target_tokens, num_input_tokens, embed_dim)
# grads = grads.mean(dim=1)  # avg over noisy samples

In [14]:
# from functools import reduce
saliency = grads.abs().sum(dim=-1)  # abd grad method. bsz, target tokens, input tokens
saliency = saliency.mean(dim=1) # avg over noisy samples
# saliency = (grads * inputs_embeds.unsqueeze(1)).sum(dim=-1)  # grad dot input method. bsz, target tokens, input tokens
saliency = saliency - saliency.min(dim=-1, keepdims=True).values
saliency = saliency / saliency.max(dim=-1, keepdims=True).values
# saliency = grad_dot_input.softmax(dim=-1)
saliency_maxpool = saliency.max(dim=1).values

In [17]:
import pandas as pd
import matplotlib.pyplot as plt

i_example = 2

cols = {}
input_tokens = tokenizer.convert_ids_to_tokens(input_ids[i_example])
input_tokens = pd.Series(input_tokens).str.replace('Ġ', '_').replace('<s>', 'BOS')
cols["Input Token MaxPool"] = input_tokens
cols["Saliency MaxPool"] = saliency_maxpool[i_example].detach().cpu().numpy()
for i_target_token in range(saliency.shape[1]):
    target_token = tokenizer.convert_ids_to_tokens([generated_tokens_no_repeats[i_example][i_target_token]])[0].replace('Ġ', '_').replace('<s>', 'BOS')
    cols[f"Input Tokens_{target_token}"] = input_tokens
    cols[f"Saliency_{target_token}"] = saliency[i_example][i_target_token].detach().cpu().numpy()

saliency_df = pd.DataFrame(cols)
cmap = plt.cm.get_cmap("coolwarm")
saliency_df = saliency_df.style.background_gradient(cmap=cmap)

print(tokenizer.decode(input_ids[i_example]))
print()
print(tokenizer.decode(generated_tokens_no_repeats[i_example]))
saliency_df

<s>DOD guidance states that the Air Force and other services are responsible for providing trained and ready forces to fulfill the current and future operational requirements of the combatant commands. The Air Force is specifically responsible for gaining and maintaining air superiority. The Air Force Strategic Master Plan states that the Air Force must focus clearly on the capabilities that will allow freedom of maneuver and decisive action in highly contested spaces, including high-end air capabilities. Fifth generation fighter capabilities and ready and trained Airmen who are properly equipped for their missions are central components of the Air Force’s ability to provide air superiority in contested environments. The F-22 is the Air Force’s fifth generation, air superiority fighter that incorporates a stealthy and highly maneuverable airframe, advanced integrated avionics, and engines capable of sustained supersonic flight. The F-22 is optimized for air-to-air combat, able to carry

Unnamed: 0,Input Token MaxPool,Saliency MaxPool,Input Tokens_Why,Saliency_Why,Input Tokens__GA,Saliency__GA,Input Tokens_O,Saliency_O,Input Tokens__Did,Saliency__Did,Input Tokens__This,Saliency__This,Input Tokens__Study,Saliency__Study,Input Tokens_Ċ,Saliency_Ċ,Input Tokens_The,Saliency_The,Input Tokens__Department,Saliency__Department,Input Tokens__of,Saliency__of,Input Tokens__Defense,Saliency__Defense,Input Tokens__(,Saliency__(,Input Tokens_D,Saliency_D,Input Tokens_OD,Saliency_OD,Input Tokens_),Saliency_),Input Tokens__guidance,Saliency__guidance,Input Tokens__states,Saliency__states,Input Tokens__that,Saliency__that,Input Tokens__the,Saliency__the,Input Tokens__Air,Saliency__Air,Input Tokens__Force,Saliency__Force,Input Tokens__and,Saliency__and,Input Tokens__other,Saliency__other,Input Tokens__services,Saliency__services,Input Tokens__are,Saliency__are,Input Tokens__responsible,Saliency__responsible
0,BOS,0.16486,BOS,0.098741,BOS,0.152556,BOS,0.104258,BOS,0.123056,BOS,0.132433,BOS,0.16486,BOS,0.116258,BOS,0.151474,BOS,0.062367,BOS,0.05709,BOS,0.039576,BOS,0.069373,BOS,0.0576,BOS,0.043292,BOS,0.056719,BOS,0.056476,BOS,0.047976,BOS,0.073847,BOS,0.108807,BOS,0.053167,BOS,0.048691,BOS,0.058481,BOS,0.045893,BOS,0.043958,BOS,0.089036,BOS,0.056352
1,D,0.769523,D,0.592635,D,0.618353,D,0.547684,D,0.472447,D,0.437512,D,0.447562,D,0.419002,D,0.642006,D,0.65948,D,0.668204,D,0.583263,D,0.524276,D,0.769523,D,0.708626,D,0.530258,D,0.47105,D,0.174106,D,0.256265,D,0.311701,D,0.136012,D,0.095955,D,0.112305,D,0.106167,D,0.084127,D,0.178638,D,0.101602
2,OD,1.0,OD,0.682167,OD,0.959405,OD,1.0,OD,0.842323,OD,0.831259,OD,0.722748,OD,0.671043,OD,1.0,OD,1.0,OD,1.0,OD,1.0,OD,1.0,OD,1.0,OD,1.0,OD,0.924066,OD,0.740124,OD,0.221081,OD,0.3684,OD,0.50032,OD,0.259667,OD,0.157294,OD,0.199112,OD,0.14631,OD,0.148256,OD,0.254731,OD,0.16154
3,_guidance,1.0,_guidance,0.572172,_guidance,0.777143,_guidance,0.496871,_guidance,0.664234,_guidance,0.532292,_guidance,0.623308,_guidance,0.603161,_guidance,0.8179,_guidance,0.389893,_guidance,0.406904,_guidance,0.251585,_guidance,0.681084,_guidance,0.317305,_guidance,0.260006,_guidance,1.0,_guidance,1.0,_guidance,0.60612,_guidance,0.743476,_guidance,0.752886,_guidance,0.263698,_guidance,0.177807,_guidance,0.345761,_guidance,0.197393,_guidance,0.151463,_guidance,0.382753,_guidance,0.253557
4,_states,1.0,_states,0.373943,_states,0.779725,_states,0.381445,_states,0.546816,_states,0.41547,_states,0.452867,_states,0.39102,_states,0.551053,_states,0.204438,_states,0.192344,_states,0.082171,_states,0.324412,_states,0.127218,_states,0.09579,_states,0.381683,_states,0.294265,_states,1.0,_states,0.689397,_states,0.491639,_states,0.146829,_states,0.111529,_states,0.213572,_states,0.148789,_states,0.084028,_states,0.269003,_states,0.140308
5,_that,1.0,_that,0.191397,_that,0.296265,_that,0.217866,_that,0.209049,_that,0.208256,_that,0.281267,_that,0.229374,_that,0.275252,_that,0.146275,_that,0.134412,_that,0.074445,_that,0.166445,_that,0.096562,_that,0.082518,_that,0.159891,_that,0.178732,_that,0.262538,_that,1.0,_that,0.547976,_that,0.155372,_that,0.093605,_that,0.200397,_that,0.15407,_that,0.075474,_that,0.242832,_that,0.128959
6,_the,1.0,_the,0.094477,_the,0.12078,_the,0.103155,_the,0.089075,_the,0.110343,_the,0.116724,_the,0.089307,_the,0.165305,_the,0.133386,_the,0.122108,_the,0.09862,_the,0.13949,_the,0.066048,_the,0.091351,_the,0.118683,_the,0.125295,_the,0.118871,_the,0.49702,_the,1.0,_the,0.263792,_the,0.151314,_the,0.227471,_the,0.226977,_the,0.10128,_the,0.240178,_the,0.117766
7,_Air,1.0,_Air,0.10364,_Air,0.123386,_Air,0.120513,_Air,0.094982,_Air,0.136509,_Air,0.123922,_Air,0.086936,_Air,0.142282,_Air,0.257806,_Air,0.150231,_Air,0.07658,_Air,0.182132,_Air,0.076391,_Air,0.070021,_Air,0.102605,_Air,0.112668,_Air,0.063216,_Air,0.207046,_Air,0.548464,_Air,1.0,_Air,0.461734,_Air,0.294223,_Air,0.254532,_Air,0.143194,_Air,0.259905,_Air,0.133327
8,_Force,1.0,_Force,0.082017,_Force,0.109616,_Force,0.108996,_Force,0.08141,_Force,0.109839,_Force,0.095591,_Force,0.076725,_Force,0.115857,_Force,0.138896,_Force,0.127181,_Force,0.066587,_Force,0.193193,_Force,0.062752,_Force,0.063772,_Force,0.098987,_Force,0.114964,_Force,0.058291,_Force,0.168518,_Force,0.410065,_Force,0.376635,_Force,1.0,_Force,0.320979,_Force,0.300305,_Force,0.145108,_Force,0.186424,_Force,0.114953
9,_and,1.0,_and,0.069623,_and,0.075284,_and,0.083393,_and,0.088033,_and,0.079819,_and,0.093439,_and,0.064433,_and,0.09599,_and,0.064555,_and,0.076089,_and,0.030834,_and,0.133693,_and,0.038557,_and,0.037911,_and,0.093219,_and,0.081831,_and,0.071405,_and,0.1983,_and,0.382002,_and,0.15915,_and,0.3093,_and,1.0,_and,0.537176,_and,0.256401,_and,0.30005,_and,0.18524
