# Config Setting And Data/Model Preprocess

In [18]:
class args:

    ## common args
    seed = 42
    model_name = "t5-small"
    model_cache_dir = "./model"

    is_limit_num_of_tran_and_eval_samples = (
        True  # if True, You should setting num_train_examples and num_evaluate_examples
    )
    ## training args
    is_train = True
    batch_size = 4
    num_train_examples = 30000
    epoch = 1

    max_input_length = 1024
    max_target_length = 128
    # output_dir = "./result/t5-small-test-summarization"
    output_dir = f"./result/t5-small-test-summarization-{num_train_examples}"

    # evaluate args
    is_eval = True
    num_evaluate_examples = 3000
    # check_point = f"result/t5-small-test-summarization/checkpoint-51012"
    check_point = f"./result/t5-small-test-summarization-{num_train_examples}/checkpoint-{(num_train_examples+batch_size-1)//batch_size*epoch}"

    ## influence args
    loss_scale = 1e-2
    influence_on_decision = True
    damping = 3e-3
    lissa_depth = 0.15
    lissa_repeat = 1

In [19]:
import random
import torch
import numpy as np
from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    AutoTokenizer,
)

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

<torch._C.Generator at 0x7fb9d0093c50>

## Data Exploration

In [20]:
from datasets import load_dataset

raw_datasets = load_dataset("EdinburghNLP/xsum", cache_dir="./data")

Using the latest cached version of the dataset since EdinburghNLP/xsum couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at data/EdinburghNLP___xsum/default/1.2.0/40db7604fedb616a9d2b0673d11838fa5be8451c (last modified on Tue Jan 14 15:22:51 2025).


In [21]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

In [22]:
if args.is_limit_num_of_tran_and_eval_samples:
    raw_datasets["train"] = (
        raw_datasets["train"].shuffle(seed=args.seed).select(range(args.num_train_examples))
    )
    raw_datasets["validation"] = (
        raw_datasets["validation"]
        .shuffle(args.seed)
        .select(range(args.num_evaluate_examples))
    )
    raw_datasets["test"] = (
        raw_datasets["test"].shuffle(args.seed).select(range(args.num_evaluate_examples))
    )

In [23]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 30000
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 3000
    })
})

In [24]:
# load metric file
from datasets import load_metric

metric = load_metric(path="./metric/rouge.py")
metric

Metric(name: "rouge", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each predictions
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLSum"`: rougeLsum splits text using `"
"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_agregator: Return aggregates if this is set to True
Retu

In [25]:
# import nltk
# nltk.download('punkt_tab')

## Model Download and Exploration 

In [26]:
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name, cache_dir="./model")

## Tokenization

In [27]:
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model")

In [28]:
if model_name in ["t5-small", "t5-base", "t5-larg", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

In [29]:
max_input_length = args.max_input_length
max_target_length = args.max_target_length

def preprocess_function(examples):
    
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["summary"], max_length=max_target_length, truncation=True
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [30]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

Map: 100%|██████████| 30000/30000 [00:08<00:00, 3537.47 examples/s]
Map: 100%|██████████| 3000/3000 [00:00<00:00, 3658.49 examples/s]
Map: 100%|██████████| 3000/3000 [00:00<00:00, 3777.15 examples/s]


In [31]:
import pickle
with open("data/xsm_train_30000.pickle", "wb") as f:
    pickle.dump(tokenized_datasets["train"], f)
with open("data/xsm_validation_3000.pickle", "wb") as f:
    pickle.dump(tokenized_datasets["validation"], f)
with open("data/xsm_test_3000.pickle", "wb") as f:
    pickle.dump(tokenized_datasets["test"], f)

In [16]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Summary Train

## Train

In [17]:
batch_size = args.batch_size
train_args = Seq2SeqTrainingArguments(
    output_dir=args.output_dir,
    logging_dir=args.output_dir,
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=args.epoch,
    predict_with_generate=True,
    fp16=True,
    disable_tqdm=False,
)
trainer = Seq2SeqTrainer(
    model,
    train_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

  trainer = Seq2SeqTrainer(


In [20]:
# resume_from_checkpoint = "result/t5-small-test-summarization-50000/checkpoint-2000"
resume_from_checkpoint = None
if args.is_train:
    if resume_from_checkpoint:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else :
        trainer.train()

  0%|          | 0/7500 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
  7%|▋         | 500/7500 [02:00<31:22,  3.72it/s]

{'loss': 3.1672, 'grad_norm': 4.169085502624512, 'learning_rate': 1.867466666666667e-05, 'epoch': 0.07}


 13%|█▎        | 1000/7500 [04:00<23:22,  4.63it/s]

{'loss': 2.9334, 'grad_norm': 3.778669595718384, 'learning_rate': 1.7344000000000002e-05, 'epoch': 0.13}


 20%|██        | 1500/7500 [06:00<23:04,  4.33it/s]

{'loss': 2.902, 'grad_norm': 3.7677221298217773, 'learning_rate': 1.601066666666667e-05, 'epoch': 0.2}


 27%|██▋       | 2000/7500 [08:00<24:17,  3.77it/s]

{'loss': 2.8889, 'grad_norm': 4.0293097496032715, 'learning_rate': 1.4677333333333334e-05, 'epoch': 0.27}


 33%|███▎      | 2500/7500 [10:00<18:00,  4.63it/s]

{'loss': 2.8816, 'grad_norm': 3.1108665466308594, 'learning_rate': 1.3344000000000001e-05, 'epoch': 0.33}


 40%|████      | 3000/7500 [11:59<20:19,  3.69it/s]

{'loss': 2.8388, 'grad_norm': 3.582239866256714, 'learning_rate': 1.2010666666666668e-05, 'epoch': 0.4}


 47%|████▋     | 3500/7500 [13:58<16:29,  4.04it/s]

{'loss': 2.842, 'grad_norm': 2.6782193183898926, 'learning_rate': 1.0677333333333335e-05, 'epoch': 0.47}


 53%|█████▎    | 4000/7500 [15:55<10:51,  5.37it/s]

{'loss': 2.8281, 'grad_norm': 5.320406436920166, 'learning_rate': 9.344e-06, 'epoch': 0.53}


 60%|██████    | 4500/7500 [17:54<11:23,  4.39it/s]

{'loss': 2.8394, 'grad_norm': 4.065776824951172, 'learning_rate': 8.010666666666668e-06, 'epoch': 0.6}


 67%|██████▋   | 5000/7500 [19:55<08:51,  4.70it/s]

{'loss': 2.8347, 'grad_norm': 6.069305896759033, 'learning_rate': 6.680000000000001e-06, 'epoch': 0.67}


 73%|███████▎  | 5500/7500 [21:54<07:35,  4.39it/s]

{'loss': 2.7995, 'grad_norm': 3.9715235233306885, 'learning_rate': 5.346666666666667e-06, 'epoch': 0.73}


 80%|████████  | 6000/7500 [23:50<06:46,  3.69it/s]

{'loss': 2.8151, 'grad_norm': 3.0418689250946045, 'learning_rate': 4.013333333333334e-06, 'epoch': 0.8}


 87%|████████▋ | 6500/7500 [25:52<04:35,  3.63it/s]

{'loss': 2.8104, 'grad_norm': 3.5562727451324463, 'learning_rate': 2.68e-06, 'epoch': 0.87}


 93%|█████████▎| 7000/7500 [27:51<02:18,  3.61it/s]

{'loss': 2.8149, 'grad_norm': 3.6993963718414307, 'learning_rate': 1.3466666666666668e-06, 'epoch': 0.93}


100%|██████████| 7500/7500 [29:50<00:00,  4.75it/s]

{'loss': 2.8404, 'grad_norm': 3.7193663120269775, 'learning_rate': 1.6e-08, 'epoch': 1.0}


                                                   
100%|██████████| 7500/7500 [33:21<00:00,  3.75it/s]

{'eval_loss': 2.560678005218506, 'eval_rouge1': 26.5342, 'eval_rouge2': 6.7979, 'eval_rougeL': 20.8162, 'eval_rougeLsum': 20.8365, 'eval_gen_len': 18.818, 'eval_runtime': 210.3405, 'eval_samples_per_second': 14.263, 'eval_steps_per_second': 3.566, 'epoch': 1.0}
{'train_runtime': 2001.3864, 'train_samples_per_second': 14.99, 'train_steps_per_second': 3.747, 'train_loss': 2.869096305338542, 'epoch': 1.0}





## Evaluate

In [24]:
if args.is_eval:
    batch_size = args.batch_size
    train_args = Seq2SeqTrainingArguments(
        args.check_point,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        predict_with_generate=True,
        fp16=True,
        disable_tqdm=False,
    )
    trainer = Seq2SeqTrainer(
        model,
        train_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(args.check_point)
    res = trainer.evaluate()
res

  trainer = Seq2SeqTrainer(
100%|██████████| 750/750 [03:26<00:00,  3.63it/s]


{'eval_loss': 2.5910897254943848,
 'eval_model_preparation_time': 0.0019,
 'eval_rouge1': 26.4958,
 'eval_rouge2': 6.6502,
 'eval_rougeL': 20.6039,
 'eval_rougeLsum': 20.5968,
 'eval_gen_len': 18.8133,
 'eval_runtime': 206.8396,
 'eval_samples_per_second': 14.504,
 'eval_steps_per_second': 3.626}

# Influence-function

### Get Test Example

In [17]:
# model = AutoModelForSeq2SeqLM.from_pretrained(args.check_point,torch_dtype=torch.bfloat16)
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = AutoModelForSeq2SeqLM.from_pretrained(args.check_point,torch_dtype=torch.bfloat16).to(device)

# shuffle the test dataset and select 10 examples
tokenized_datasets["test"].shuffle(seed=args.seed)
influence_fn_examples = tokenized_datasets["test"].select(range(10))
import pickle 
save_path = "./data/xsum-sample_10.pick"
with open(save_path, "wb") as f:
    pickle.dump(influence_fn_examples, f)


### Build Influence Function

In [28]:
param_optimizer = list(model.named_parameters())
for n, p in param_optimizer:
    print(n)

shared.weight
encoder.block.0.layer.0.SelfAttention.q.weight
encoder.block.0.layer.0.SelfAttention.k.weight
encoder.block.0.layer.0.SelfAttention.v.weight
encoder.block.0.layer.0.SelfAttention.o.weight
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight
encoder.block.0.layer.0.layer_norm.weight
encoder.block.0.layer.1.DenseReluDense.wi.weight
encoder.block.0.layer.1.DenseReluDense.wo.weight
encoder.block.0.layer.1.layer_norm.weight
encoder.block.1.layer.0.SelfAttention.q.weight
encoder.block.1.layer.0.SelfAttention.k.weight
encoder.block.1.layer.0.SelfAttention.v.weight
encoder.block.1.layer.0.SelfAttention.o.weight
encoder.block.1.layer.0.layer_norm.weight
encoder.block.1.layer.1.DenseReluDense.wi.weight
encoder.block.1.layer.1.DenseReluDense.wo.weight
encoder.block.1.layer.1.layer_norm.weight
encoder.block.2.layer.0.SelfAttention.q.weight
encoder.block.2.layer.0.SelfAttention.k.weight
encoder.block.2.layer.0.SelfAttention.v.weight
encoder.block.2.layer.0.SelfAttentio

In [29]:
frozen = ["shared.weight"]

param_influence = []
for n, p in param_optimizer:
    if not any(fr in n for fr in frozen):
        param_influence.append(p)
    elif "shared.weight" in n:
        pass  # need gradients through embedding layer for computing saliency map
    else:
        p.requires_grad = False

param_influence

[Parameter containing:
 tensor([[-0.0153, -0.0728,  0.0043,  ...,  0.0277, -0.0354, -0.1187],
         [-0.0383, -0.0253,  0.0483,  ..., -0.0444, -0.0483,  0.0688],
         [ 0.0535, -0.0581,  0.0047,  ..., -0.0354, -0.0187, -0.0420],
         ...,
         [-0.0928,  0.0115, -0.1108,  ..., -0.0062, -0.1011,  0.0393],
         [ 0.0025,  0.0171, -0.0422,  ..., -0.0957,  0.0299,  0.0016],
         [-0.1084, -0.0009, -0.0052,  ...,  0.0525, -0.0435, -0.0522]],
        device='cuda:0', dtype=torch.bfloat16, requires_grad=True),
 Parameter containing:
 tensor([[ 0.0713,  0.0977, -0.0703,  ..., -0.6719, -0.0559,  0.4316],
         [ 0.0242,  0.0205, -0.2871,  ..., -0.6484,  0.0981,  0.0674],
         [ 0.5430, -0.0146,  0.1001,  ..., -0.1348, -0.0327,  0.3516],
         ...,
         [ 0.0054,  0.2832,  0.0288,  ..., -0.1514,  0.4336, -0.1807],
         [-0.3184, -0.0830, -0.1357,  ...,  0.2031,  0.3770, -0.4453],
         [ 0.1982,  0.4023, -0.1875,  ..., -0.5977,  0.1816,  0.4082]],
    

In [30]:
param_shape_tensor = []
param_size = 0
for p in param_influence:
    tmp_p = p.clone().detach()
    param_shape_tensor.append(tmp_p)
    param_size += torch.numel(tmp_p)

print("  Parameter size = %d" % param_size)

  Parameter size = 44057088


In [31]:
print(influence_fn_examples[0].keys())

dict_keys(['document', 'summary', 'id', 'input_ids', 'attention_mask', 'labels'])


In [33]:

def to_tensor(x):
    return torch.tensor(x, dtype=torch.long).unsqueeze(0)

input_ids = [
    torch.tensor(example["input_ids"], dtype=torch.long).unsqueeze(0)
    for example in influence_fn_examples
]
label_ids = [
    torch.tensor(example["labels"], dtype=torch.long).unsqueeze(0)
    for example in influence_fn_examples
]
attn_mask = [
    torch.tensor(example["attention_mask"], dtype=torch.long).unsqueeze(0)
    for example in influence_fn_examples
]

import torch.autograd as autograd
from tqdm import tqdm
import pickle
import os
import sys
import bert_util
import importlib
importlib.reload(bert_util) # reload when changes to bert_util.py

with open("output.log","w") as f: 
    # sys.stdout = f
# for each test sample, we should compute the influence socre on train dataset
    for idx, (input_id, label_id, attention_mask) in enumerate(
        zip(input_ids, label_ids, attn_mask)
    ):
        print(
            f"====================test example: {idx}======================================"
        )
        input_id = input_id.to(model.device)
        label_id = label_id.to(model.device)
        attention_mask = attention_mask.to(model.device)

        # get test example grad
        model.zero_grad()
        output = model(input_id, attention_mask=attention_mask, labels=label_id)
        # test_loss = output.loss
        # *scaled the loss to avoid the grad is too large
        scaled_loss = output.loss * args.loss_scale
        print(" loss:", scaled_loss)

        # test_grads = autograd.grad(test_loss, param_influence)
        test_grads = autograd.grad(scaled_loss, param_influence)
        
        # reload train dataset
        train_dataloader_lissa = tokenized_datasets["train"]
        print("len of traindataset", len(train_dataloader_lissa))

        device = model.device

        ######## IHVP ########
        model.train()
        print("######## START COMPUTING IHVP ########")
        inverse_hvp = bert_util.get_inverse_hvp_lissa(
            test_grads,
            model,
            device,
            param_influence,
            train_dataloader_lissa,
            loss_scale=args.loss_scale,
            damping=args.damping,
            num_samples=args.lissa_repeat,
            recursion_depth=int(len(train_dataloader_lissa) * args.lissa_depth),
        )
        print("######## FINISHED COMPUTING IHVP ########")
        print("inverse_hvp:", inverse_hvp)

        influences = np.zeros(len(train_dataloader_lissa))
        train_tok_sal_lists = []
        for train_idx,  sample in enumerate(
            tqdm(train_dataloader_lissa, desc="Train set index")
        ):
            (doc, summ, _, _input_ids, _input_mask, _label_ids) = sample.values()
            

            _input_ids = to_tensor(_input_ids).to(device)
            _input_mask = to_tensor(_input_mask).to(device)
            _label_ids = to_tensor(_label_ids).to(device)

            ######## L_TRAIN GRADIENT ########
            model.zero_grad()
            output = model(_input_ids,attention_mask = _input_mask, labels= _label_ids)
            train_loss = output.loss * args.loss_scale
            train_grads = autograd.grad(train_loss, param_influence)
            influences[train_idx] = torch.dot(
                inverse_hvp, bert_util.gather_flat_grad(train_grads)
            ).item()
            # print(influences[train_idx])
        if args.influence_on_decision:
            pickle.dump(influences, open(os.path.join(args.output_dir, "influences_test_" + str(idx) + ".pkl"), "wb"))
        else:
            pickle.dump(influences, open(os.path.join(args.output_dir, "influences_on_x_test_" + str(idx) + ".pkl"), "wb"))
    
sys.stdout = sys.__stdout__

 loss: tensor(0.0230, device='cuda:0', grad_fn=<MulBackward0>)
len of traindataset 30000
######## START COMPUTING IHVP ########
Recursion at depth 0: norm is 0.125863
Recursion at depth 200: norm is 8.654841
Recursion at depth 400: norm is 9.190510
Recursion at depth 600: norm is 9.313373
Recursion at depth 800: norm is 9.314757
Recursion at depth 1000: norm is 9.303090
Recursion at depth 1200: norm is 9.333562
Recursion at depth 1400: norm is 9.347337
Recursion at depth 1600: norm is 9.399684
Recursion at depth 1800: norm is 9.389694
Recursion at depth 2000: norm is 9.459481
Recursion at depth 2200: norm is 9.496763
Recursion at depth 2400: norm is 9.490389
Recursion at depth 2600: norm is 9.489272
Recursion at depth 2800: norm is 9.554091
Recursion at depth 3000: norm is 9.551081
Recursion at depth 3200: norm is 9.567074
Recursion at depth 3400: norm is 9.569996
Recursion at depth 3600: norm is 9.564068
Recursion at depth 3800: norm is 9.588473
Recursion at depth 4000: norm is 9.5859

Train set index: 100%|██████████| 30000/30000 [19:35<00:00, 25.52it/s]


 loss: tensor(0.0279, device='cuda:0', grad_fn=<MulBackward0>)
len of traindataset 30000
######## START COMPUTING IHVP ########
Recursion at depth 0: norm is 0.155072
Recursion at depth 200: norm is 10.832218
Recursion at depth 400: norm is 12.113464
Recursion at depth 600: norm is 12.663466
Recursion at depth 800: norm is 12.468938
Recursion at depth 1000: norm is 12.908704
Recursion at depth 1200: norm is 12.941697
Recursion at depth 1400: norm is 13.546585
Recursion at depth 1600: norm is 13.635581
Recursion at depth 1800: norm is 13.969149
Recursion at depth 2000: norm is 15.808049
Recursion at depth 2200: norm is 15.455676
Recursion at depth 2400: norm is 15.791433
Recursion at depth 2600: norm is 15.824063
Recursion at depth 2800: norm is 15.973456
Recursion at depth 3000: norm is 16.217779
Recursion at depth 3200: norm is 16.106342
Recursion at depth 3400: norm is 16.023407
Recursion at depth 3600: norm is 16.005775
Recursion at depth 3800: norm is 15.428082
Recursion at depth 4

Train set index: 100%|██████████| 30000/30000 [20:19<00:00, 24.60it/s]


 loss: tensor(0.0165, device='cuda:0', grad_fn=<MulBackward0>)
len of traindataset 30000
######## START COMPUTING IHVP ########
Recursion at depth 0: norm is 0.092706
Recursion at depth 200: norm is 6.444446
Recursion at depth 400: norm is 6.829885
Recursion at depth 600: norm is 6.885294
Recursion at depth 800: norm is 6.919654
Recursion at depth 1000: norm is 6.970932
Recursion at depth 1200: norm is 6.977200
Recursion at depth 1400: norm is 7.011880
Recursion at depth 1600: norm is 7.200767
Recursion at depth 1800: norm is 7.262277
Recursion at depth 2000: norm is 7.262067
Recursion at depth 2200: norm is 7.283370
Recursion at depth 2400: norm is 7.295090
Recursion at depth 2600: norm is 7.306193
Recursion at depth 2800: norm is 7.299913
Recursion at depth 3000: norm is 7.300943
Recursion at depth 3200: norm is 7.302222
Recursion at depth 3400: norm is 7.497781
Recursion at depth 3600: norm is 7.532568
Recursion at depth 3800: norm is 7.530846
Recursion at depth 4000: norm is 7.5388

Train set index: 100%|██████████| 30000/30000 [19:14<00:00, 25.98it/s]


 loss: tensor(0.0445, device='cuda:0', grad_fn=<MulBackward0>)
len of traindataset 30000
######## START COMPUTING IHVP ########
Recursion at depth 0: norm is 0.204399
Recursion at depth 200: norm is 14.280125
Recursion at depth 400: norm is 14.905560
Recursion at depth 600: norm is 15.009358
Recursion at depth 800: norm is 15.176588
Recursion at depth 1000: norm is 15.215185
Recursion at depth 1200: norm is 15.313102
Recursion at depth 1400: norm is 15.695178
Recursion at depth 1600: norm is 15.765046
Recursion at depth 1800: norm is 15.867601
Recursion at depth 2000: norm is 16.162296
Recursion at depth 2200: norm is 16.197598
Recursion at depth 2400: norm is 16.234915
Recursion at depth 2600: norm is 16.610086
Recursion at depth 2800: norm is 16.515530
Recursion at depth 3000: norm is 16.513325
Recursion at depth 3200: norm is 16.503174
Recursion at depth 3400: norm is 16.516851
Recursion at depth 3600: norm is 16.543299
Recursion at depth 3800: norm is 16.673193
Recursion at depth 4

Train set index: 100%|██████████| 30000/30000 [19:38<00:00, 25.46it/s]


 loss: tensor(0.0234, device='cuda:0', grad_fn=<MulBackward0>)
len of traindataset 30000
######## START COMPUTING IHVP ########
Recursion at depth 0: norm is 0.170336
Recursion at depth 200: norm is 11.983575
Recursion at depth 400: norm is 13.733478
Recursion at depth 600: norm is 13.830777
Recursion at depth 800: norm is 13.903862
Recursion at depth 1000: norm is 14.524327
Recursion at depth 1200: norm is 14.893278
Recursion at depth 1400: norm is 14.848613
Recursion at depth 1600: norm is 14.609480
Recursion at depth 1800: norm is 14.784294
Recursion at depth 2000: norm is 14.824445
Recursion at depth 2200: norm is 15.790593
Recursion at depth 2400: norm is 15.774895
Recursion at depth 2600: norm is 15.756003
Recursion at depth 2800: norm is 15.744400
Recursion at depth 3000: norm is 15.577909
Recursion at depth 3200: norm is 15.302768
Recursion at depth 3400: norm is 15.673893
Recursion at depth 3600: norm is 15.589507
Recursion at depth 3800: norm is 15.733931
Recursion at depth 4

Train set index: 100%|██████████| 30000/30000 [19:51<00:00, 25.17it/s]


 loss: tensor(0.0246, device='cuda:0', grad_fn=<MulBackward0>)
len of traindataset 30000
######## START COMPUTING IHVP ########
Recursion at depth 0: norm is 0.121810
Recursion at depth 200: norm is 8.360209
Recursion at depth 400: norm is 8.875549
Recursion at depth 600: norm is 8.956800
Recursion at depth 800: norm is 9.219600
Recursion at depth 1000: norm is 9.276332
Recursion at depth 1200: norm is 9.298748
Recursion at depth 1400: norm is 9.585082
Recursion at depth 1600: norm is 9.706594
Recursion at depth 1800: norm is 9.719789
Recursion at depth 2000: norm is 9.821321
Recursion at depth 2200: norm is 9.823512
Recursion at depth 2400: norm is 10.042622
Recursion at depth 2600: norm is 10.199114
Recursion at depth 2800: norm is 10.252500
Recursion at depth 3000: norm is 10.149735
Recursion at depth 3200: norm is 10.355267
Recursion at depth 3400: norm is 10.569717
Recursion at depth 3600: norm is 10.553460
Recursion at depth 3800: norm is 10.888927
Recursion at depth 4000: norm i

Train set index: 100%|██████████| 30000/30000 [19:33<00:00, 25.57it/s]


 loss: tensor(0.0278, device='cuda:0', grad_fn=<MulBackward0>)
len of traindataset 30000
######## START COMPUTING IHVP ########
Recursion at depth 0: norm is 0.107856
Recursion at depth 200: norm is 7.392168
Recursion at depth 400: norm is 7.791256
Recursion at depth 600: norm is 7.870235
Recursion at depth 800: norm is 7.898916
Recursion at depth 1000: norm is 7.916296
Recursion at depth 1200: norm is 7.913392
Recursion at depth 1400: norm is 7.947817
Recursion at depth 1600: norm is 8.050328
Recursion at depth 1800: norm is 8.083349
Recursion at depth 2000: norm is 8.092225
Recursion at depth 2200: norm is 8.076015
Recursion at depth 2400: norm is 8.053933
Recursion at depth 2600: norm is 8.228823
Recursion at depth 2800: norm is 8.516932
Recursion at depth 3000: norm is 8.596601
Recursion at depth 3200: norm is 16.101881
Recursion at depth 3400: norm is 14.428217
Recursion at depth 3600: norm is 14.353481
Recursion at depth 3800: norm is 14.163577
Recursion at depth 4000: norm is 13

Train set index: 100%|██████████| 30000/30000 [21:18<00:00, 23.47it/s]


 loss: tensor(0.0325, device='cuda:0', grad_fn=<MulBackward0>)
len of traindataset 30000
######## START COMPUTING IHVP ########
Recursion at depth 0: norm is 0.197950
Recursion at depth 200: norm is 13.846393
Recursion at depth 400: norm is 14.411400
Recursion at depth 600: norm is 28.548025
Recursion at depth 800: norm is 28.233511
Recursion at depth 1000: norm is 28.361376
Recursion at depth 1200: norm is 27.838142
Recursion at depth 1400: norm is 27.102139
Recursion at depth 1600: norm is 25.911852
Recursion at depth 1800: norm is 25.813135
Recursion at depth 2000: norm is 25.709654
Recursion at depth 2200: norm is 25.945692
Recursion at depth 2400: norm is 26.054483
Recursion at depth 2600: norm is 26.188921
Recursion at depth 2800: norm is 25.942804
Recursion at depth 3000: norm is 25.851004
Recursion at depth 3200: norm is 26.019230
Recursion at depth 3400: norm is 32.092506
Recursion at depth 3600: norm is 31.799833
Recursion at depth 3800: norm is 31.699993
Recursion at depth 4

Train set index: 100%|██████████| 30000/30000 [20:17<00:00, 24.63it/s]


 loss: tensor(0.0370, device='cuda:0', grad_fn=<MulBackward0>)
len of traindataset 30000
######## START COMPUTING IHVP ########
Recursion at depth 0: norm is 0.115739
Recursion at depth 200: norm is 7.929761
Recursion at depth 400: norm is 8.412388
Recursion at depth 600: norm is 8.447202
Recursion at depth 800: norm is 8.478775
Recursion at depth 1000: norm is 8.507248
Recursion at depth 1200: norm is 8.617958
Recursion at depth 1400: norm is 8.604177
Recursion at depth 1600: norm is 8.625663
Recursion at depth 1800: norm is 8.663033
Recursion at depth 2000: norm is 8.698076
Recursion at depth 2200: norm is 8.711054
Recursion at depth 2400: norm is 8.651614
Recursion at depth 2600: norm is 8.759000
Recursion at depth 2800: norm is 8.761760
Recursion at depth 3000: norm is 8.762509
Recursion at depth 3200: norm is 8.842787
Recursion at depth 3400: norm is 10.055709
Recursion at depth 3600: norm is 10.067150
Recursion at depth 3800: norm is 10.001190
Recursion at depth 4000: norm is 10.

Train set index: 100%|██████████| 30000/30000 [19:00<00:00, 26.31it/s]


 loss: tensor(0.0246, device='cuda:0', grad_fn=<MulBackward0>)
len of traindataset 30000
######## START COMPUTING IHVP ########
Recursion at depth 0: norm is 0.113240
Recursion at depth 200: norm is 7.889523
Recursion at depth 400: norm is 8.210333
Recursion at depth 600: norm is 9.093567
Recursion at depth 800: norm is 9.162790
Recursion at depth 1000: norm is 9.094481
Recursion at depth 1200: norm is 9.126731
Recursion at depth 1400: norm is 9.175330
Recursion at depth 1600: norm is 9.221914
Recursion at depth 1800: norm is 9.320100
Recursion at depth 2000: norm is 9.317429
Recursion at depth 2200: norm is 9.369915
Recursion at depth 2400: norm is 9.487803
Recursion at depth 2600: norm is 9.511250
Recursion at depth 2800: norm is 9.536650
Recursion at depth 3000: norm is 9.503697
Recursion at depth 3200: norm is 9.507857
Recursion at depth 3400: norm is 9.542685
Recursion at depth 3600: norm is 9.570725
Recursion at depth 3800: norm is 9.603239
Recursion at depth 4000: norm is 9.7288

Train set index: 100%|██████████| 30000/30000 [18:57<00:00, 26.37it/s]
