# Config Setting And Data/Model Preprocess

In [1]:
class args:

    ## common args
    seed = 42
    model_name = "t5-small"
    model_cache_dir = "./model"

    is_limit_num_of_tran_and_eval_samples = (
        True  # if True, You should setting num_train_examples and num_evaluate_examples
    )
    ## training args
    is_train = False
    batch_size = 4
    num_train_examples = 30000
    epoch = 1

    max_input_length = 1024
    max_target_length = 128
    output_dir = "./result/t5-small-test-summarization"
    # output_dir = f"./result/t5-small-test-summarization-{num_train_examples}"

    # evaluate args
    is_eva = False
    num_evaluate_examples = 3000
    check_point = f"result/t5-small-test-summarization/checkpoint-51012"
    # check_point = f"./result/t5-small-test-summarization-{num_train_examples}/checkpoint-{(num_train_examples+batch_size-1)//batch_size*epoch}"

    ## influence args
    damping = 3e-3
    lissa_depth = 0.25
    lissa_repeat = 1

In [2]:
import random
import torch
import numpy as np
from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    AutoTokenizer,
)

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f42c01d7110>

## Data Exploration

In [3]:
from datasets import load_dataset

raw_datasets = load_dataset("EdinburghNLP/xsum", cache_dir="./data")

Using the latest cached version of the dataset since EdinburghNLP/xsum couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at data/EdinburghNLP___xsum/default/1.2.0/40db7604fedb616a9d2b0673d11838fa5be8451c (last modified on Tue Jan  7 22:02:03 2025).


In [4]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

In [5]:
if args.is_limit_num_of_tran_and_eval_samples:
    raw_datasets["train"] = (
        raw_datasets["train"].shuffle(seed=args.seed).select(range(args.num_train_examples))
    )
    raw_datasets["validation"] = (
        raw_datasets["validation"]
        .shuffle(args.seed)
        .select(range(args.num_evaluate_examples))
    )
    raw_datasets["test"] = (
        raw_datasets["test"].shuffle(args.seed).select(range(args.num_evaluate_examples))
    )

In [6]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 30000
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 3000
    })
})

In [7]:
# load metric file
from datasets import load_metric

metric = load_metric(path="./metric/rouge.py")
metric

  metric = load_metric(path="./metric/rouge.py")


Metric(name: "rouge", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each predictions
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLSum"`: rougeLsum splits text using `"
"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_agregator: Return aggregates if this is set to True
Retu

In [8]:
# import nltk
# nltk.download('punkt_tab')

## Model Download and Exploration 

In [9]:
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name, cache_dir="./model")
model

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

## Tokenization

In [10]:
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model")

In [11]:
tokenizer(["Hello, this one sentence!", "This is another sentence."])

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

In [12]:
if model_name in ["t5-small", "t5-base", "t5-larg", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

In [13]:
max_input_length = args.max_input_length
max_target_length = args.max_target_length


def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["summary"], max_length=max_target_length, truncation=True
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [14]:
preprocess_function(raw_datasets["train"][:2])



{'input_ids': [[21603, 10, 86, 10256, 6, 6098, 7, 33, 1966, 21, 3135, 11, 12162, 53, 2061, 5, 299, 16, 2789, 6, 1363, 411, 7, 12940, 31, 7, 515, 56, 1243, 415, 5779, 56, 18682, 12, 43, 3, 9, 1075, 16, 1260, 1073, 5, 30358, 7, 33, 1461, 11264, 57, 2069, 789, 11, 819, 3081, 43, 72, 4333, 147, 7209, 7, 11, 12, 483, 8, 194, 8, 496, 930, 5, 94, 19, 3, 9, 1516, 606, 16, 8, 2925, 12355, 122, 1433, 13, 2061, 1002, 30, 893, 596, 13, 4395, 9, 31, 7, 12991, 1050, 5, 275, 2199, 8, 22982, 3141, 56, 129, 996, 1723, 12, 1588, 8, 540, 21, 1566, 2061, 12, 4285, 8, 496, 239, 6, 34, 54, 1492, 34, 30, 136, 20, 4571, 162, 26, 1291, 616, 5, 3271, 7, 43, 150, 1390, 12, 1130, 3237, 5, 486, 8, 798, 6, 3, 19585, 5678, 33, 1966, 21, 1898, 496, 716, 11, 79, 174, 6323, 23, 138, 6059, 12, 143, 1516, 1112, 5, 290, 33, 641, 72, 145, 3, 8630, 6980, 3, 9, 6615, 2720, 7, 16, 2789, 11, 165, 4924, 12, 66, 538, 2061, 19, 9909, 12, 8944, 8, 22982, 3141, 31, 7, 11352, 12, 125, 79, 580, 3, 9, 96, 18782, 485, 6, 3452, 825, 121

In [15]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

In [16]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

import nltk
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = [
        "\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds
    ]
    decoded_labels = [
        "\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels
    ]

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    # Add mean generated length
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions
    ]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

# Summary Train

## Train

In [16]:
batch_size = args.batch_size
train_args = Seq2SeqTrainingArguments(
    output_dir=args.output_dir,
    logging_dir=args.output_dir,
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=args.epoch,
    predict_with_generate=True,
    fp16=True,
    disable_tqdm=False,
)
trainer = Seq2SeqTrainer(
    model,
    train_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [21]:
# resume_from_checkpoint = "result/t5-small-test-summarization-50000/checkpoint-2000"
resume_from_checkpoint = None
if args.is_train:
    if resume_from_checkpoint:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else :
        trainer.train()

  0%|          | 0/22500 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
  2%|▏         | 500/22500 [01:59<1:37:53,  3.75it/s]

{'loss': 3.1649, 'grad_norm': 4.244799613952637, 'learning_rate': 1.9558222222222223e-05, 'epoch': 0.07}


  4%|▍         | 1000/22500 [03:58<1:16:32,  4.68it/s]

{'loss': 2.929, 'grad_norm': 3.5437469482421875, 'learning_rate': 1.911377777777778e-05, 'epoch': 0.13}


  7%|▋         | 1500/22500 [05:57<1:20:15,  4.36it/s]

{'loss': 2.8983, 'grad_norm': 3.4916257858276367, 'learning_rate': 1.8669333333333334e-05, 'epoch': 0.2}


  9%|▉         | 2000/22500 [07:55<1:29:37,  3.81it/s]

{'loss': 2.8832, 'grad_norm': 3.9176409244537354, 'learning_rate': 1.822488888888889e-05, 'epoch': 0.27}


 11%|█         | 2500/22500 [09:54<1:11:59,  4.63it/s]

{'loss': 2.8741, 'grad_norm': 5.255237102508545, 'learning_rate': 1.7781333333333335e-05, 'epoch': 0.33}


 13%|█▎        | 3000/22500 [11:53<1:28:11,  3.68it/s]

{'loss': 2.83, 'grad_norm': 3.6619646549224854, 'learning_rate': 1.733688888888889e-05, 'epoch': 0.4}


 16%|█▌        | 3500/22500 [13:52<1:17:58,  4.06it/s]

{'loss': 2.8311, 'grad_norm': 3.1096231937408447, 'learning_rate': 1.6892444444444447e-05, 'epoch': 0.47}


 18%|█▊        | 4000/22500 [15:49<57:14,  5.39it/s]  

{'loss': 2.8156, 'grad_norm': 4.834927082061768, 'learning_rate': 1.6448000000000002e-05, 'epoch': 0.53}


 20%|██        | 4500/22500 [17:48<1:08:35,  4.37it/s]

{'loss': 2.8241, 'grad_norm': 4.271838188171387, 'learning_rate': 1.6004444444444444e-05, 'epoch': 0.6}


 22%|██▏       | 5000/22500 [19:49<1:02:31,  4.67it/s]

{'loss': 2.8166, 'grad_norm': 6.328220844268799, 'learning_rate': 1.556e-05, 'epoch': 0.67}


 24%|██▍       | 5500/22500 [21:48<1:05:05,  4.35it/s]

{'loss': 2.7787, 'grad_norm': 3.903757333755493, 'learning_rate': 1.5116444444444445e-05, 'epoch': 0.73}


 27%|██▋       | 6000/22500 [23:45<1:15:04,  3.66it/s]

{'loss': 2.791, 'grad_norm': 2.9944801330566406, 'learning_rate': 1.4672000000000001e-05, 'epoch': 0.8}


 29%|██▉       | 6500/22500 [25:48<1:13:35,  3.62it/s]

{'loss': 2.7827, 'grad_norm': 3.4611856937408447, 'learning_rate': 1.4227555555555557e-05, 'epoch': 0.87}


 31%|███       | 7000/22500 [27:47<1:12:02,  3.59it/s]

{'loss': 2.7835, 'grad_norm': 4.128263473510742, 'learning_rate': 1.378311111111111e-05, 'epoch': 0.93}


 33%|███▎      | 7500/22500 [29:48<53:14,  4.70it/s]  

{'loss': 2.8058, 'grad_norm': 3.7389602661132812, 'learning_rate': 1.3338666666666668e-05, 'epoch': 1.0}


                                                    
 33%|███▎      | 7500/22500 [33:24<53:14,  4.70it/s]

{'eval_loss': 2.5254247188568115, 'eval_rouge1': 27.4092, 'eval_rouge2': 7.2048, 'eval_rougeL': 21.463, 'eval_rougeLsum': 21.468, 'eval_gen_len': 18.7957, 'eval_runtime': 215.9428, 'eval_samples_per_second': 13.893, 'eval_steps_per_second': 3.473, 'epoch': 1.0}


 36%|███▌      | 8000/22500 [35:25<1:02:42,  3.85it/s]  

{'loss': 2.7493, 'grad_norm': 3.330371141433716, 'learning_rate': 1.2894222222222224e-05, 'epoch': 1.07}


 38%|███▊      | 8500/22500 [37:28<1:06:38,  3.50it/s]

{'loss': 2.7557, 'grad_norm': 3.736124038696289, 'learning_rate': 1.2449777777777778e-05, 'epoch': 1.13}


 40%|████      | 9000/22500 [39:28<52:28,  4.29it/s]  

{'loss': 2.7355, 'grad_norm': 3.215278387069702, 'learning_rate': 1.2005333333333333e-05, 'epoch': 1.2}


 42%|████▏     | 9500/22500 [41:27<54:11,  4.00it/s]  

{'loss': 2.7344, 'grad_norm': 5.4099907875061035, 'learning_rate': 1.1561777777777779e-05, 'epoch': 1.27}


 44%|████▍     | 10000/22500 [43:26<56:38,  3.68it/s] 

{'loss': 2.7492, 'grad_norm': 6.402969837188721, 'learning_rate': 1.1117333333333333e-05, 'epoch': 1.33}


 47%|████▋     | 10500/22500 [45:25<51:51,  3.86it/s]  

{'loss': 2.7221, 'grad_norm': 3.665576934814453, 'learning_rate': 1.067288888888889e-05, 'epoch': 1.4}


 49%|████▉     | 11000/22500 [47:25<42:55,  4.47it/s]  

{'loss': 2.7397, 'grad_norm': 2.9963107109069824, 'learning_rate': 1.0228444444444446e-05, 'epoch': 1.47}


 51%|█████     | 11500/22500 [49:29<40:01,  4.58it/s]  

{'loss': 2.7295, 'grad_norm': 4.257570266723633, 'learning_rate': 9.78488888888889e-06, 'epoch': 1.53}


 53%|█████▎    | 12000/22500 [51:29<38:57,  4.49it/s]  

{'loss': 2.7338, 'grad_norm': 3.348665237426758, 'learning_rate': 9.340444444444445e-06, 'epoch': 1.6}


 56%|█████▌    | 12500/22500 [53:29<36:58,  4.51it/s]  

{'loss': 2.7494, 'grad_norm': 4.473400115966797, 'learning_rate': 8.896000000000001e-06, 'epoch': 1.67}


 58%|█████▊    | 13000/22500 [55:28<27:04,  5.85it/s]  

{'loss': 2.7647, 'grad_norm': 6.455132007598877, 'learning_rate': 8.451555555555557e-06, 'epoch': 1.73}


 60%|██████    | 13500/22500 [57:29<27:36,  5.43it/s]  

{'loss': 2.7294, 'grad_norm': 3.945760488510132, 'learning_rate': 8.007111111111112e-06, 'epoch': 1.8}


 62%|██████▏   | 14000/22500 [59:30<38:02,  3.72it/s]  

{'loss': 2.721, 'grad_norm': 3.464590311050415, 'learning_rate': 7.563555555555556e-06, 'epoch': 1.87}


 64%|██████▍   | 14500/22500 [1:01:30<34:34,  3.86it/s]

{'loss': 2.7211, 'grad_norm': 3.765570640563965, 'learning_rate': 7.1191111111111124e-06, 'epoch': 1.93}


 67%|██████▋   | 15000/22500 [1:03:30<28:06,  4.45it/s]  

{'loss': 2.6831, 'grad_norm': 3.893583059310913, 'learning_rate': 6.674666666666667e-06, 'epoch': 2.0}


                                                       
 67%|██████▋   | 15001/22500 [1:07:02<132:34:21, 63.64s/it]

{'eval_loss': 2.4939675331115723, 'eval_rouge1': 28.1305, 'eval_rouge2': 7.5866, 'eval_rougeL': 22.0035, 'eval_rougeLsum': 22.0039, 'eval_gen_len': 18.8123, 'eval_runtime': 210.6905, 'eval_samples_per_second': 14.239, 'eval_steps_per_second': 3.56, 'epoch': 2.0}


 69%|██████▉   | 15500/22500 [1:09:01<27:53,  4.18it/s]    

{'loss': 2.6889, 'grad_norm': 3.740978717803955, 'learning_rate': 6.230222222222223e-06, 'epoch': 2.07}


 71%|███████   | 16000/22500 [1:11:00<24:18,  4.46it/s]

{'loss': 2.7198, 'grad_norm': 3.1007449626922607, 'learning_rate': 5.785777777777778e-06, 'epoch': 2.13}


 73%|███████▎  | 16500/22500 [1:13:00<21:12,  4.72it/s]

{'loss': 2.7154, 'grad_norm': 4.052583694458008, 'learning_rate': 5.342222222222223e-06, 'epoch': 2.2}


 76%|███████▌  | 17000/22500 [1:14:59<23:42,  3.87it/s]

{'loss': 2.6686, 'grad_norm': 3.2006192207336426, 'learning_rate': 4.897777777777778e-06, 'epoch': 2.27}


 78%|███████▊  | 17500/22500 [1:17:01<19:29,  4.28it/s]

{'loss': 2.6971, 'grad_norm': 3.740164279937744, 'learning_rate': 4.453333333333334e-06, 'epoch': 2.33}


 80%|████████  | 18000/22500 [1:19:01<18:51,  3.98it/s]

{'loss': 2.6882, 'grad_norm': 3.7211720943450928, 'learning_rate': 4.008888888888889e-06, 'epoch': 2.4}


 82%|████████▏ | 18500/22500 [1:21:01<14:14,  4.68it/s]

{'loss': 2.7057, 'grad_norm': 4.675158500671387, 'learning_rate': 3.564444444444445e-06, 'epoch': 2.47}


 84%|████████▍ | 19000/22500 [1:22:59<15:04,  3.87it/s]

{'loss': 2.6978, 'grad_norm': 4.1983561515808105, 'learning_rate': 3.12e-06, 'epoch': 2.53}


 87%|████████▋ | 19500/22500 [1:24:59<12:21,  4.05it/s]

{'loss': 2.7346, 'grad_norm': 3.343820333480835, 'learning_rate': 2.675555555555556e-06, 'epoch': 2.6}


 89%|████████▉ | 20000/22500 [1:27:00<10:57,  3.80it/s]

{'loss': 2.6643, 'grad_norm': 3.5886309146881104, 'learning_rate': 2.2320000000000004e-06, 'epoch': 2.67}


 91%|█████████ | 20500/22500 [1:28:56<09:31,  3.50it/s]

{'loss': 2.7, 'grad_norm': 3.9172167778015137, 'learning_rate': 1.7875555555555556e-06, 'epoch': 2.73}


 93%|█████████▎| 21000/22500 [1:30:56<05:21,  4.66it/s]

{'loss': 2.6812, 'grad_norm': 3.8187644481658936, 'learning_rate': 1.343111111111111e-06, 'epoch': 2.8}


 96%|█████████▌| 21500/22500 [1:32:55<03:17,  5.05it/s]

{'loss': 2.7033, 'grad_norm': 3.627049207687378, 'learning_rate': 8.986666666666667e-07, 'epoch': 2.87}


 98%|█████████▊| 22000/22500 [1:34:53<01:51,  4.50it/s]

{'loss': 2.7053, 'grad_norm': 3.444369316101074, 'learning_rate': 4.542222222222223e-07, 'epoch': 2.93}


100%|██████████| 22500/22500 [1:36:51<00:00,  3.98it/s]

{'loss': 2.7083, 'grad_norm': 3.530238389968872, 'learning_rate': 1.0666666666666668e-08, 'epoch': 3.0}


                                                       
100%|██████████| 22500/22500 [1:40:21<00:00,  3.74it/s]

{'eval_loss': 2.4836490154266357, 'eval_rouge1': 28.2465, 'eval_rouge2': 7.6712, 'eval_rougeL': 22.1983, 'eval_rougeLsum': 22.197, 'eval_gen_len': 18.8313, 'eval_runtime': 209.0429, 'eval_samples_per_second': 14.351, 'eval_steps_per_second': 3.588, 'epoch': 3.0}
{'train_runtime': 6021.3342, 'train_samples_per_second': 14.947, 'train_steps_per_second': 3.737, 'train_loss': 2.7623339409722223, 'epoch': 3.0}





## Evaluate

In [33]:
if args.is_eval:
    batch_size = args.batch_size
    train_args = Seq2SeqTrainingArguments(
        args.check_point,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        predict_with_generate=True,
        fp16=True,
        disable_tqdm=False,
    )
    trainer = Seq2SeqTrainer(
        model,
        train_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(args.check_point)
    trainer.evaluate()

  trainer = Seq2SeqTrainer(
100%|██████████| 750/750 [03:27<00:00,  3.61it/s]


{'eval_loss': 2.5135397911071777,
 'eval_model_preparation_time': 0.0018,
 'eval_rouge1': 28.123,
 'eval_rouge2': 7.7164,
 'eval_rougeL': 22.0584,
 'eval_rougeLsum': 22.0585,
 'eval_gen_len': 18.8183,
 'eval_runtime': 207.9819,
 'eval_samples_per_second': 14.424,
 'eval_steps_per_second': 3.606}

# Influence-function

### Get Test Example

In [17]:
model = AutoModelForSeq2SeqLM.from_pretrained(args.check_point)
# shuffle the test dataset and select 10 examples
tokenized_datasets["test"].shuffle(seed=args.seed)
influence_fn_examples = tokenized_datasets["test"].select(range(10))

input_documents = influence_fn_examples["document"]
target_documents = influence_fn_examples["summary"]
model_inputs = tokenizer(
    input_documents,
    max_length=max_input_length,
    padding=True,
    truncation=True,
    return_tensors="pt",
)
model_inputs = model_inputs.to(model.device)
##  Generate Summary
outputs = model.generate(**model_inputs, max_length=max_target_length)

In [18]:
for i, input_documents in enumerate(input_documents):
    print("Example:", i)
    print("Input Document:", input_documents)
    print("Generated Summary:", tokenizer.decode(outputs[i], skip_special_tokens=True))
    print("Target Summary:", target_documents[i])
    print("=" * 50)

Example: 0
Input Document: Sarah Johnson was one of 21 women heading to Liverpool when their minibus was hit by a lorry on the M62.
Her friend Bethany Jones, 18, was killed while Ms Johnson and several others were badly hurt.
Minibus driver James Johnson was jailed for more than six years for causing Bethany's death, in April 2013.
Ms Johnson, who broke her shoulder, back and pelvis, said the help she received from a charity while in hospital led her to want to support others.
Speaking publicly for the first time about the crash, Ms Johnson described how everyone was "excited and giddy" for the hen party.
"To me the impact was just a massive explosion," she said.  "I thought the bus had blown up.
"I remember the bus dropping on its side. The next thing, I woke up on the roadside so I'd actually come out of the window."
Ms Johnson was taken to Leeds General Infirmary where she, along with Bethany's sister Amy Firth, underwent major surgery and spent time in intensive care.
Whilst she wa

In [19]:
batch_size = 1
train_args = Seq2SeqTrainingArguments(
    args.check_point,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    fp16=True,
    disable_tqdm=False,
)

trainer = Seq2SeqTrainer(
    model,
    train_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=influence_fn_examples,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.evaluate()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
  0%|          | 0/10 [00:00<?, ?it/s]Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
 20%|██        | 2/10 [00:00<00:00, 15.23it/s]Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
100%|██████████| 10/10 [00:01<00:00,  7.98it/s]


{'eval_loss': 2.421166181564331,
 'eval_model_preparation_time': 0.0022,
 'eval_rouge1': 27.1753,
 'eval_rouge2': 5.412,
 'eval_rougeL': 20.8743,
 'eval_rougeLsum': 20.8016,
 'eval_gen_len': 19.0,
 'eval_runtime': 1.5798,
 'eval_samples_per_second': 6.33,
 'eval_steps_per_second': 6.33}

### Build Influence Function

We refer to the code implementation of this paper on [Influence Funciton](ACL2020 Explaining black box predictions and unveiling data artifacts through influence functions):

1. 计算测试样本的梯度（L_TEST GRADIENT）：
这部分代码首先计算测试样本相对于模型参数的梯度。这是通过在测试样本上运行前向传播，然后计算损失函数相对于模型参数的梯度来实现的。
```python
######## L_TEST GRADIENT ########
model.zero_grad()
test_loss = model(input_ids, segment_ids, input_mask, label_ids)
test_grads = autograd.grad(test_loss, param_influence)
################
````


2. 计算逆Hessian向量积（IHVP）：
这部分代码使用 Lissa 算法（通过 get_inverse_hvp_lissa 函数）来近似计算逆Hessian矩阵与测试梯度的乘积。这是 influence score 计算的核心步骤，因为它允许我们估计如果训练数据中移除或稍微修改某个样本，模型参数将如何变化。
```python
######## IHVP ########
model.train()
logger.info("######## START COMPUTING IHVP ########")
inverse_hvp = get_inverse_hvp_lissa(test_grads, model, device, param_influence, train_dataloader_lissa, damping=damping, num_samples=args.lissa_repeat, recursion_depth=int(len(train_examples)*args.lissa_depth))
logger.info("######## FINISHED COMPUTING IHVP ########")
################
```
3. 计算训练样本的影响力（INFLUENCE）：
这部分代码遍历训练数据集，对于每个训练样本，计算其梯度，并使用之前计算的逆Hessian向量积来估计该样本对测试样本预测的影响力。这是通过计算训练样本梯度和逆Hessian向量积的点积来实现的。
```python
######## INFLUENCE ########
influences = np.zeros(len(train_dataloader.dataset))
for train_idx, (_input_ids, _input_mask, _segment_ids, _label_ids, _) in enumerate(tqdm(train_dataloader, desc="Train set index")):
    model.zero_grad()
    train_loss = model(_input_ids, _segment_ids, _input_mask, _label_ids)
    train_grads = autograd.grad(train_loss, param_influence)
    influences[train_idx] = torch.dot(inverse_hvp, gather_flat_grad(train_grads)).item()
################
```


In [20]:
param_optimizer = list(model.named_parameters())
for n, p in param_optimizer:
    print(n)

shared.weight
encoder.block.0.layer.0.SelfAttention.q.weight
encoder.block.0.layer.0.SelfAttention.k.weight
encoder.block.0.layer.0.SelfAttention.v.weight
encoder.block.0.layer.0.SelfAttention.o.weight
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight
encoder.block.0.layer.0.layer_norm.weight
encoder.block.0.layer.1.DenseReluDense.wi.weight
encoder.block.0.layer.1.DenseReluDense.wo.weight
encoder.block.0.layer.1.layer_norm.weight
encoder.block.1.layer.0.SelfAttention.q.weight
encoder.block.1.layer.0.SelfAttention.k.weight
encoder.block.1.layer.0.SelfAttention.v.weight
encoder.block.1.layer.0.SelfAttention.o.weight
encoder.block.1.layer.0.layer_norm.weight
encoder.block.1.layer.1.DenseReluDense.wi.weight
encoder.block.1.layer.1.DenseReluDense.wo.weight
encoder.block.1.layer.1.layer_norm.weight
encoder.block.2.layer.0.SelfAttention.q.weight
encoder.block.2.layer.0.SelfAttention.k.weight
encoder.block.2.layer.0.SelfAttention.v.weight
encoder.block.2.layer.0.SelfAttentio

In [21]:
frozen = ["shared.weight"]

param_influence = []
for n, p in param_optimizer:
    if not any(fr in n for fr in frozen):
        param_influence.append(p)
    elif "shared.weight" in n:
        pass  # need gradients through embedding layer for computing saliency map
    else:
        p.requires_grad = False

param_influence

[Parameter containing:
 tensor([[-0.0135, -0.0746,  0.0004,  ...,  0.0210, -0.0337, -0.1187],
         [-0.0306, -0.0240,  0.0503,  ..., -0.0444, -0.0494,  0.0637],
         [ 0.0539, -0.0559,  0.0016,  ..., -0.0335, -0.0202, -0.0460],
         ...,
         [-0.0888,  0.0146, -0.1137,  ..., -0.0082, -0.0970,  0.0397],
         [ 0.0002,  0.0165, -0.0426,  ..., -0.0972,  0.0247, -0.0032],
         [-0.1095, -0.0024, -0.0035,  ...,  0.0542, -0.0421, -0.0520]],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([[ 0.0696,  0.0976, -0.0691,  ..., -0.6685, -0.0591,  0.4304],
         [ 0.0281,  0.0197, -0.2820,  ..., -0.6439,  0.0988,  0.0655],
         [ 0.5398, -0.0160,  0.0962,  ..., -0.1355, -0.0318,  0.3481],
         ...,
         [ 0.0089,  0.2795,  0.0293,  ..., -0.1555,  0.4314, -0.1776],
         [-0.3142, -0.0810, -0.1309,  ...,  0.1979,  0.3735, -0.4425],
         [ 0.1962,  0.4015, -0.1885,  ..., -0.5950,  0.1754,  0.4051]],
        device='cuda:0', r

In [22]:
param_shape_tensor = []
param_size = 0
for p in param_influence:
    tmp_p = p.clone().detach()
    param_shape_tensor.append(tmp_p)
    param_size += torch.numel(tmp_p)

print("  Parameter size = %d" % param_size)

  Parameter size = 44057088


In [23]:
print(influence_fn_examples[0].keys())

dict_keys(['document', 'summary', 'id', 'input_ids', 'attention_mask', 'labels'])


In [24]:
import bert_util
import importlib
importlib.reload(bert_util) # reload when changes to bert_util.py

<module 'bert_util' from '/home/psz/workspace/AI/My-project/in-fn-for-generation/bert_util.py'>

In [25]:


import torch.autograd as autograd
from tqdm import tqdm

input_ids = [
    torch.tensor(example["input_ids"], dtype=torch.long).unsqueeze(0)
    for example in influence_fn_examples
]
label_ids = [
    torch.tensor(example["labels"], dtype=torch.long).unsqueeze(0)
    for example in influence_fn_examples
]
attn_mask = [
    torch.tensor(example["attention_mask"], dtype=torch.long).unsqueeze(0)
    for example in influence_fn_examples
]

# for each test sample, we should compute the influence socre on train dataset
for idx, (input_id, label_id, attention_mask) in enumerate(
    zip(input_ids, label_ids, attn_mask)
):
    print(
        f"====================test example: {idx}======================================"
    )
    input_id = input_id.to(model.device)
    label_id = label_id.to(model.device)
    attention_mask = attention_mask.to(model.device)

    
    # get test example grad
    model.zero_grad()
    output = model(input_id, attention_mask=attention_mask, labels=label_id)
    # test_loss = output.loss 
    # *scaled the loss to avoid the grad is too large
    scaled_loss = output.loss * 1e-5
    print(" loss:", scaled_loss)
    
    # test_grads = autograd.grad(test_loss, param_influence)
    test_grads = autograd.grad(scaled_loss, param_influence)
    # reload train dataset
    train_dataloader_lissa = tokenized_datasets["train"]
    print("len of traindataset", len(train_dataloader_lissa))

    device = model.device

    ######## IHVP ########
    model.train()
    print("######## START COMPUTING IHVP ########")
    inverse_hvp = bert_util.get_inverse_hvp_lissa(
        test_grads,
        model,
        device,
        param_influence,
        train_dataloader_lissa,
        damping=args.damping,
        num_samples=args.lissa_repeat,
        recursion_depth=int(len(train_dataloader_lissa) * args.lissa_depth),
    )
    print("######## FINISHED COMPUTING IHVP ########")
    break

 loss: tensor(2.2270e-05, device='cuda:0', grad_fn=<MulBackward0>)
len of traindataset 30000
######## START COMPUTING IHVP ########
Recursion at depth 0: norm is 0.000041
Recursion at depth 200: norm is 0.003136
Recursion at depth 400: norm is 0.004833
Recursion at depth 600: norm is 0.005764
Recursion at depth 800: norm is 0.006274
Recursion at depth 1000: norm is 0.006554
Recursion at depth 1200: norm is 0.006707
Recursion at depth 1400: norm is 0.006791
Recursion at depth 1600: norm is 0.006837
Recursion at depth 1800: norm is 0.006863
Recursion at depth 2000: norm is 0.006877
Recursion at depth 2200: norm is 0.006884
Recursion at depth 2400: norm is 0.006888
Recursion at depth 2600: norm is 0.006891
Recursion at depth 2800: norm is 0.006892
Recursion at depth 3000: norm is 0.006893
Recursion at depth 3200: norm is 0.006893
Recursion at depth 3400: norm is 0.006893
Recursion at depth 3600: norm is 0.006893
Recursion at depth 3800: norm is 0.006893
Recursion at depth 4000: norm is 0.