# NNI in T5
Reference (NNI tutorial on pruning BERT for GLUE, adapted here to T5): https://nni.readthedocs.io/zh/stable/tutorials/new_pruning_bert_glue.html

## Import libraries

In [4]:
from transformers import (
    AutoTokenizer, 
    DataCollatorForSeq2Seq , 
    AutoModelForSeq2SeqLM, 
    Seq2SeqTrainingArguments, 
    Seq2SeqTrainer
)
from transformers import DataCollatorForSeq2Seq
import evaluate
from rich import print

In [5]:
from TALib import TALib
# Instantiate the project helper class.
# NOTE(review): the functions visible in this notebook use TALib's class-level
# attributes and functions (e.g. TALib.CHECKPOINT) rather than this instance —
# confirm whether `ta_lib` is needed elsewhere.
ta_lib = TALib()

## Start


In [6]:
from __future__ import annotations

from pathlib import Path

import numpy as np

import torch
from torch.utils.data import ConcatDataset

import nni

from datasets import load_dataset, load_metric
# from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction
# from transformers.trainer import Trainer
# from transformers.training_args import TrainingArguments

# skip_exec = True

In [7]:
def build_model_and_tokenizer():
    """Load the seq2seq model and its tokenizer.

    The model weights come from ``TALib.CHECKPOINT`` and the tokenizer from
    ``TALib.TK_ckpt``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(TALib.CHECKPOINT)
    seq2seq_tokenizer = AutoTokenizer.from_pretrained(TALib.TK_ckpt)
    return seq2seq_model, seq2seq_tokenizer

In [8]:
def prepare_datasets(tokenizer):
    """Load and tokenize the billsum train and test splits.

    Args:
        tokenizer: Tokenizer handed to ``TALib.preprocess_function_pass_tokenizer``
            to build the batched preprocessing function.

    Returns:
        A ``(tokenized_train, tokenized_test)`` tuple.
    """
    train_raw = load_dataset("billsum", split="train")
    to_features = TALib.preprocess_function_pass_tokenizer(tokenizer)
    train_tokenized = train_raw.map(to_features, batched=True)

    # The same preprocessing function is reused for the test split.
    test_tokenized = load_dataset("billsum", split="test").map(to_features, batched=True)

    return train_tokenized, test_tokenized

In [9]:
def prepare_traced_trainer(model, tokenizer, tokenized_billsum):
    """Build an NNI-traced ``Seq2SeqTrainer`` for the given model.

    Args:
        model: The seq2seq model to train.
        tokenizer: Tokenizer used by the data collator, the trainer and the
            metric function built by ``TALib.compute_metrics_pass_tokenizer``.
        tokenized_billsum: Either a mapping with ``"train"`` and ``"test"``
            entries, or a ``(train, test)`` pair such as the one returned by
            ``prepare_datasets`` in this notebook.

    Returns:
        The trainer created through ``nni.trace(Seq2SeqTrainer)``.
    """
    compute_metrics = TALib.compute_metrics_pass_tokenizer(tokenizer)

    # NOTE(review): the collator receives the checkpoint identifier rather than
    # the model object; kept as in the original — confirm this is intended.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=TALib.CHECKPOINT)

    training_args = Seq2SeqTrainingArguments(
        output_dir="TA_billsum_model",
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=4,
        lr_scheduler_type="linear",
        seed=42,
        fp16=True,  # mixed-precision training
        logging_steps=10,
        predict_with_generate=True,
    )

    # prepare_datasets() returns a (train, test) tuple, while a split-keyed
    # mapping is indexed by name; accept both so either can be passed through.
    if isinstance(tokenized_billsum, (tuple, list)):
        train_split, eval_split = tokenized_billsum
    else:
        train_split = tokenized_billsum["train"]
        eval_split = tokenized_billsum["test"]

    # The trainer class is wrapped with nni.trace before instantiation so the
    # resulting trainer can be handed to TransformersEvaluator by the callers.
    trainer = nni.trace(Seq2SeqTrainer)(
        model=model,
        args=training_args,
        train_dataset=train_split,
        eval_dataset=eval_split,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    return trainer

In [None]:
# def build_fine_tuning_model():
#     model , _ = build_model_and_tokenizer()
#     return model


# if not skip_exec:
#     Path('./output/bert_finetuned').mkdir(exist_ok=True, parents=True)
#     build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')

## Distillers

In [1]:
from transformers.models.t5 import T5ForConditionalGeneration

In [3]:
from nni.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller
from nni.compression.utils import TransformersEvaluator

In [None]:
def dynamic_distiller(student_model: T5ForConditionalGeneration, teacher_model: T5ForConditionalGeneration,
                      student_trainer: Seq2SeqTrainer):
    """Create a ``DynamicLayerwiseDistiller`` over the student's encoder blocks.

    One config entry is generated per entry of ``student_model.encoder.block``,
    each using ``'link': "auto"``, ``'lambda': 0.9`` and the ``'mse'`` apply
    method.

    Args:
        student_model: Model to be distilled into.
        teacher_model: Model providing the distillation targets.
        student_trainer: Traced trainer wrapped in a ``TransformersEvaluator``.

    Returns:
        The configured ``DynamicLayerwiseDistiller``
        (``origin_loss_lambda=0.1``).
    """

    def teacher_predict(batch, teacher_model):
        # Forward the batch through the teacher as keyword arguments.
        return teacher_model(**batch)

    config_list = []
    for block_idx in range(len(student_model.encoder.block)):
        config_list.append({
            'op_names': [f'encoder.block.{block_idx}'],
            'link': "auto",
            'lambda': 0.9,
            'apply_method': 'mse',
        })

    evaluator = TransformersEvaluator(student_trainer)
    return DynamicLayerwiseDistiller(
        student_model,
        config_list,
        evaluator,
        teacher_model,
        teacher_predict,
        origin_loss_lambda=0.1,
    )


def dynamic_distillation(student_model: T5ForConditionalGeneration,
                         teacher_model: T5ForConditionalGeneration,
                         tokenizer,
                         tokenizer_billsum,
                         max_steps: int | None,
                         max_epochs: int | None):
    """Distill ``teacher_model`` into ``student_model`` with the dynamic distiller.

    The teacher is temporarily moved to the student trainer's device and put
    in eval mode; its original device and training flag are restored
    afterwards, even if distillation raises.

    Args:
        student_model: Model to train; it is wrapped by the distiller during
            compression and unwrapped once compression finishes.
        teacher_model: Model providing the distillation targets.
        tokenizer: Tokenizer forwarded to ``prepare_traced_trainer``.
        tokenizer_billsum: Tokenized dataset forwarded to
            ``prepare_traced_trainer``.
        max_steps: Forwarded to ``distiller.compress``.
        max_epochs: Forwarded to ``distiller.compress``.
    """
    student_trainer = prepare_traced_trainer(student_model, tokenizer, tokenizer_billsum)

    # Remember the teacher's state so it can be put back exactly as received.
    ori_teacher_device = teacher_model.device
    training = teacher_model.training
    try:
        teacher_model.to(student_trainer.args.device).eval()

        distiller = dynamic_distiller(student_model, teacher_model, student_trainer)
        distiller.compress(max_steps, max_epochs)
        distiller.unwrap_model()
    finally:
        # Restore the teacher regardless of whether compression succeeded.
        teacher_model.to(ori_teacher_device).train(training)

In [None]:
def adapt_distiller(student_model: T5ForConditionalGeneration,
                    teacher_model: T5ForConditionalGeneration,
                    student_trainer: Seq2SeqTrainer):
    """Create an ``Adaptive1dLayerwiseDistiller`` over the student's encoder blocks.

    One config entry is generated per entry of ``student_model.encoder.block``,
    each using ``'lambda': 0.9`` and the ``'mse'`` apply method.

    Args:
        student_model: Model to be distilled into.
        teacher_model: Model providing the distillation targets.
        student_trainer: Traced trainer wrapped in a ``TransformersEvaluator``.

    Returns:
        The configured ``Adaptive1dLayerwiseDistiller``
        (``origin_loss_lambda=0.1``).
    """

    def teacher_predict(batch, teacher_model):
        # Forward the batch through the teacher as keyword arguments.
        return teacher_model(**batch)

    config_list = []
    for block_idx in range(len(student_model.encoder.block)):
        config_list.append({
            'op_names': [f'encoder.block.{block_idx}'],
            'lambda': 0.9,
            'apply_method': 'mse',
        })

    evaluator = TransformersEvaluator(student_trainer)
    return Adaptive1dLayerwiseDistiller(
        student_model,
        config_list,
        evaluator,
        teacher_model,
        teacher_predict,
        origin_loss_lambda=0.1,
    )


def adapt_distillation(student_model: T5ForConditionalGeneration,
                       teacher_model: T5ForConditionalGeneration,
                       tokenizer,
                       tokenizer_billsum,
                       max_steps: int | None, max_epochs: int | None):
    """Distill ``teacher_model`` into ``student_model`` with the adaptive distiller.

    The teacher is temporarily moved to the student trainer's device and put
    in eval mode; its original device and training flag are restored
    afterwards, even if distillation raises.

    Args:
        student_model: Model to train; it is wrapped by the distiller during
            compression and unwrapped once compression finishes.
        teacher_model: Model providing the distillation targets.
        tokenizer: Tokenizer forwarded to ``prepare_traced_trainer``.
        tokenizer_billsum: Tokenized dataset forwarded to
            ``prepare_traced_trainer``.
        max_steps: Forwarded to ``distiller.compress``.
        max_epochs: Forwarded to ``distiller.compress``.
    """
    student_trainer = prepare_traced_trainer(student_model, tokenizer, tokenizer_billsum)

    # Remember the teacher's state so it can be put back exactly as received.
    ori_teacher_device = teacher_model.device
    training = teacher_model.training
    try:
        target_device = student_trainer.args.device
        teacher_model.to(target_device).eval()

        distiller = adapt_distiller(student_model, teacher_model, student_trainer)

        # Three random integer tensors of shape [8, 128], passed positionally
        # to track_forward before compression starts.
        # NOTE(review): this triple comes from the BERT tutorial (input ids,
        # token type ids, attention mask); confirm how T5's forward interprets
        # the three positional arguments.
        dummy_input = (
            torch.randint(0, 10000, [8, 128]),
            torch.randint(0, 2, [8, 128]),
            torch.randint(0, 2, [8, 128]),
        )
        dummy_input = [tensor.to(target_device) for tensor in dummy_input]
        distiller.track_forward(*dummy_input)

        distiller.compress(max_steps, max_epochs)
        distiller.unwrap_model()
    finally:
        # Restore the teacher regardless of whether compression succeeded.
        teacher_model.to(ori_teacher_device).train(training)

In [None]:
from nni.compression.pruning import MovementPruner
from nni.compression.speedup import ModelSpeedup
from nni.compression.utils.external.external_replacer import TransformersAttentionReplacer


def pruning_attn():
    """Prune the attention Linear layers with ``MovementPruner`` and save the result.

    Writes the masks to ``./output/pruning/attn_masks.pth`` and the masked
    model to ``./output/pruning/attn_masked_model.pth``.
    """
    Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)
    model, tokenizer = build_model_and_tokenizer()

    # The trainer builder indexes the dataset by "train"/"test", so the two
    # splits are packed into a mapping (passing None would fail there).
    train_split, test_split = prepare_datasets(tokenizer)
    trainer = prepare_traced_trainer(model, tokenizer, {"train": train_split, "test": test_split})
    evaluator = TransformersEvaluator(trainer)

    config_list = [{
        'op_types': ['Linear'],
        # Target the attention sub-modules of T5 blocks; the BERT-style pattern
        # from the tutorial cannot match the `encoder.block.{i}` naming that
        # the distillers in this notebook rely on.
        # NOTE(review): assumes the attention modules are named SelfAttention /
        # EncDecAttention under `<stack>.block.<i>.layer.<j>` — confirm against
        # the loaded checkpoint.
        'op_names_re': [r'(encoder|decoder)\.block\.[0-9]+\.layer\.[0-9]+\.(SelfAttention|EncDecAttention)\..*'],
        'sparse_threshold': 0.1,
        'granularity': [64, 64]
    }]

    # NOTE(review): the step schedule below is copied from the BERT tutorial;
    # confirm it suits the number of training steps for this dataset.
    pruner = MovementPruner(model, config_list, evaluator, warmup_step=9000, cooldown_begin_step=36000, regular_scale=10)
    pruner.compress(None, 4)
    pruner.unwrap_model()

    masks = pruner.get_masks()
    Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
    torch.save(masks, './output/pruning/attn_masks.pth')
    torch.save(model, './output/pruning/attn_masked_model.pth')


# if not skip_exec:
#     pruning_attn()

In [None]:
def speedup_attn():
    """Apply the saved attention masks with ``ModelSpeedup``, then re-train by distillation.

    Reads the masked model and masks written by ``pruning_attn``, speeds the
    model up, runs ``dynamic_distillation`` for 3 epochs against a freshly
    loaded teacher, and saves the result to
    ``./output/pruning/attn_pruned_model.pth``.
    """
    # NOTE(review): these files hold whole pickled objects; only load files
    # produced by this notebook. Some torch versions may require
    # `weights_only=False` to load a full model — confirm for the installed one.
    model = torch.load('./output/pruning/attn_masked_model.pth', map_location='cpu')
    masks = torch.load('./output/pruning/attn_masks.pth', map_location='cpu')

    # Three random integer tensors of shape [8, 128] used as the example input.
    # NOTE(review): this triple comes from the BERT tutorial; confirm how T5's
    # forward interprets the three positional arguments.
    dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
    replacer = TransformersAttentionReplacer(model)
    ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()

    # Fine-tune the pruned model by distilling from the unpruned checkpoint.
    # The original called an undefined `build_finetuning_model` with an
    # undefined `task_name`, and omitted the tokenizer/dataset arguments that
    # dynamic_distillation requires.
    teacher_model, tokenizer = build_model_and_tokenizer()
    train_split, test_split = prepare_datasets(tokenizer)
    tokenized_billsum = {"train": train_split, "test": test_split}
    dynamic_distillation(model, teacher_model, tokenizer, tokenized_billsum, None, 3)
    torch.save(model, './output/pruning/attn_pruned_model.pth')


# `skip_exec` is only assigned in a commented-out line earlier in the notebook,
# so a bare reference would raise NameError. Default to skipping execution
# unless the flag has been explicitly defined as False.
if not globals().get("skip_exec", True):
    speedup_attn()