In [34]:
from __future__ import annotations

from pathlib import Path

import numpy as np

import torch
# from torch.utils.data import ConcatDataset

import nni

from datasets import load_dataset, load_metric
# from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction
# from transformers.trainer import Trainer
# from transformers.training_args import TrainingArguments


from transformers import (
    AutoTokenizer, 
    DataCollatorForSeq2Seq , 
    AutoModelForSeq2SeqLM, 
    Seq2SeqTrainingArguments, 
    Seq2SeqTrainer,
    # t5 
    T5TokenizerFast,
    T5ForConditionalGeneration
)
from transformers import DataCollatorForSeq2Seq
import evaluate
from rich import print


In [19]:
from TALib import TALib
# Instantiate the project helper class.
# NOTE(review): the cells visible here only read class-level attributes and call
# class-level helpers on `TALib`; the `ta_lib` instance itself is not referenced
# in them - confirm whether it is still needed.
ta_lib = TALib()

In [20]:
def build_model(pretrained_model_name_or_path: str = TALib.CHECKPOINT, task_name: str = None):
    """Load a pretrained sequence-to-sequence model.

    Args:
        pretrained_model_name_or_path: Identifier or path handed to
            ``AutoModelForSeq2SeqLM.from_pretrained``. Defaults to
            ``TALib.CHECKPOINT``.
        task_name: Not used by this function; accepted only so existing
            callers that pass it keep working.

    Returns:
        The model object returned by ``AutoModelForSeq2SeqLM.from_pretrained``.
    """
    return AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)

In [21]:
# tokenizer = AutoTokenizer.from_pretrained(TALib.TK_ckpt)

In [22]:
# type(tokenizer)

In [23]:
def prepare_datasets(task_name: str = None , tokenizer: T5TokenizerFast = None, cache_dir: str = None):
    """Load the billsum ``train`` and ``test`` splits and tokenize both.

    Args:
        task_name: Not used by this function; kept for call compatibility.
        tokenizer: Tokenizer handed to
            ``TALib.preprocess_function_pass_tokenizer`` to build the
            batched preprocessing function.
        cache_dir: Optional dataset cache directory, forwarded to
            ``load_dataset``. ``None`` keeps the library default.

    Returns:
        A ``(tokenized_train, tokenized_test)`` tuple: the two splits after
        ``map(preprocess_function, batched=True)``.
    """
    preprocess_function = TALib.preprocess_function_pass_tokenizer(tokenizer)

    def _load_and_tokenize(split: str):
        # Same recipe for every split: load it, then tokenize it in batches.
        raw_split = load_dataset("billsum", split=split, cache_dir=cache_dir)
        return raw_split.map(preprocess_function, batched=True)

    tokenized_billsum = _load_and_tokenize("train")
    tokenized_billsum_test = _load_and_tokenize("test")
    return tokenized_billsum, tokenized_billsum_test

In [24]:
def prepare_traced_trainer(model:T5ForConditionalGeneration, task_name = None, load_best_model_at_end=False):
    """Build an NNI-traced ``Seq2SeqTrainer`` for training ``model`` on billsum.

    The tokenizer is loaded from ``TALib.TK_ckpt``; the train and evaluation
    datasets come from ``prepare_datasets``.

    Args:
        model: The model to train.
        task_name: Not used by this function; kept for call compatibility.
        load_best_model_at_end: Currently NOT forwarded to the training
            arguments. NOTE(review): callers in this notebook pass ``True``;
            wiring it through presumably also needs a matching save strategy -
            confirm the intended behaviour before enabling it.

    Returns:
        An ``nni.trace``-wrapped ``Seq2SeqTrainer`` instance.
    """
    tokenizer = T5TokenizerFast.from_pretrained(TALib.TK_ckpt)
    compute_metrics = TALib.compute_metrics_pass_tokenizer(tokenizer)

    # prepare_datasets returns the tokenized train split and the tokenized test
    # split as two separate datasets, so they are used directly. (Indexing the
    # train split with "train"/"test" would look up columns of that name and
    # the real test split would never be evaluated.)
    train_dataset, eval_dataset = prepare_datasets(None, tokenizer, None)

    # Hand the collator the model object itself rather than a checkpoint name:
    # a plain string gives it no model methods to work with.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    training_args = Seq2SeqTrainingArguments(
        output_dir="TA_billsum_model",
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=4,
        lr_scheduler_type="linear",
        seed=42,
        # NOTE(review): mixed precision presumably needs a CUDA device -
        # confirm before running this on a CPU-only machine.
        fp16=True,
        logging_steps=10,
        predict_with_generate=True,
    )

    # nni.trace records the constructor arguments of the trainer it wraps.
    traced_trainer_cls = nni.trace(Seq2SeqTrainer)
    return traced_trainer_cls(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

In [25]:
def build_finetuning_model(task_name: str , state_dict_path: str):
    """Return a fine-tuned model, loading cached weights when they exist.

    If ``state_dict_path`` exists, its state dict is loaded into a freshly
    built model. Otherwise the model is trained with the trainer from
    ``prepare_traced_trainer`` and its state dict is saved to
    ``state_dict_path``.

    Args:
        task_name: Not used by this function; kept for call compatibility.
        state_dict_path: File path of the cached state dict.

    Returns:
        The model with fine-tuned weights.
    """
    # NOTE(review): the model is built from TALib.TK_ckpt while build_model's
    # default is TALib.CHECKPOINT - confirm this is the intended checkpoint.
    model = build_model(TALib.TK_ckpt, None)
    checkpoint_file = Path(state_dict_path)

    if checkpoint_file.exists():
        # Map tensors to CPU so a checkpoint saved from a GPU run still loads
        # on a machine without one; load_state_dict copies the values into the
        # model's own parameters.
        model.load_state_dict(torch.load(checkpoint_file, map_location="cpu"))
        return model

    # Create the target directory BEFORE training so the final save cannot
    # fail on a missing folder after the whole training run has finished.
    checkpoint_file.parent.mkdir(parents=True, exist_ok=True)

    trainer = prepare_traced_trainer(model, None, True)
    trainer.train()
    torch.save(model.state_dict(), checkpoint_file)
    return model


In [26]:
# Path('./output/t5_finetuned').mkdir(exist_ok=True, parents=True)

In [27]:
# model = AutoModelForSeq2SeqLM.from_pretrained(TALib.CHECKPOINT)

In [28]:
# torch.save(model.state_dict(), f'./output/t5_finetuned/ta_t5.bin')

In [31]:

# Set to True to skip fine-tuning / checkpoint loading in this cell.
skip_exec = False
model_name = "ta_t5"
if not skip_exec:
    state_dict_path = f'./output/t5_finetuned/{model_name}.bin'
    # Create the directory the checkpoint is actually written to. The cell used
    # to create './output/bert_finetuned', which did not match the path passed
    # to build_finetuning_model.
    Path(state_dict_path).parent.mkdir(exist_ok=True, parents=True)
    model = build_finetuning_model(None, state_dict_path)

In [32]:
# Show the module tree of `model`, which is assigned by the fine-tuning cell
# only when `skip_exec` is False.
model

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [33]:
from nni.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller
from nni.compression.utils import TransformersEvaluator

In [35]:
def dynamic_distiller(student_model: T5ForConditionalGeneration,
                      teacher_model: T5ForConditionalGeneration,
                      student_trainer: Seq2SeqTrainer):
    """Build a ``DynamicLayerwiseDistiller`` over the student's encoder blocks.

    One config entry is created per encoder block, each using an ``'mse'``
    apply method with ``lambda`` 0.9; the distiller itself is created with
    ``origin_loss_lambda=0.1``.

    Args:
        student_model: Model to be distilled; its ``encoder.block`` list
            determines how many layers are configured.
        teacher_model: Model whose outputs guide the student.
        student_trainer: Trainer wrapped in a ``TransformersEvaluator`` for
            the distiller.

    Returns:
        The configured ``DynamicLayerwiseDistiller``.
    """
    # T5 keeps its encoder layers under `encoder.block` (see the printed module
    # tree of the model); there is no `bert.encoder.layer` attribute on it.
    layer_num = len(student_model.encoder.block)

    config_list = [{
        'op_names': [
            f'encoder.block.{i}'
        ],
        # NOTE(review): "auto" presumably links each student block to the
        # same-named teacher module - confirm against the NNI documentation.
        # An explicit alternative is a list of teacher module names, e.g.
        # [f'encoder.block.{j}' for j in range(i, layer_num)].
        'link': "auto",
        'lambda': 0.9,
        'apply_method': 'mse',
    } for i in range(layer_num)]

    evaluator = TransformersEvaluator(student_trainer)

    def teacher_predict(batch, teacher_model):
        # Run the teacher on the batch by unpacking it as keyword arguments.
        return teacher_model(**batch)

    return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)


def dynamic_distillation(student_model: T5ForConditionalGeneration, 
                         teacher_model: T5ForConditionalGeneration,
                         max_steps: int | None, max_epochs: int | None):
    """Distill ``teacher_model`` into ``student_model`` in place.

    The teacher is moved to the student trainer's device and put in eval mode
    for the duration of the run; its original device and training flag are
    restored afterwards, even if distillation raises.

    Args:
        student_model: Model that is trained by the distiller.
        teacher_model: Model providing the distillation targets.
        max_steps: Passed through to ``distiller.compress``.
        max_epochs: Passed through to ``distiller.compress``.

    Returns:
        None. ``student_model`` is modified in place.
    """
    student_trainer = prepare_traced_trainer(student_model, None, True)

    # Remember the teacher's state so it can be put back exactly as found.
    ori_teacher_device = teacher_model.device
    training = teacher_model.training
    teacher_model.to(student_trainer.args.device).eval()

    try:
        distiller = dynamic_distiller(student_model, teacher_model, student_trainer)
        try:
            distiller.compress(max_steps, max_epochs)
        finally:
            # Always unwrap the model, also when compress() fails part-way.
            distiller.unwrap_model()
    finally:
        teacher_model.to(ori_teacher_device).train(training)

SyntaxError: expression expected after dictionary key and ':' (565429998.py, line 12)