In [1]:
#### boolq for moe

In [2]:
import torch
import datasets
import transformers
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()

from transformers import (AutoTokenizer, AutoModelForSequenceClassification, default_data_collator, AdamW, 
                          get_linear_schedule_with_warmup)

from transformerlab.MoE import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__}")
print(f"Running on device: {device}")

  metric = load_metric("squad")


Using transformers v4.44.0.dev0 and datasets v2.20.0
Running on device: cuda


In [3]:
from datasets import load_dataset

boolq = load_dataset("super_glue", "boolq")
boolq

DatasetDict({
    train: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 9427
    })
    validation: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3270
    })
    test: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3245
    })
})

In [4]:
from transformers import AutoTokenizer

bert_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_ckpt)

In [5]:
def tokenize_and_encode(examples): 
    return tokenizer(examples['question'], examples['passage'], truncation="only_second")

boolq_enc = boolq.map(tokenize_and_encode, batched=True)

In [6]:
train_ds = boolq_enc["train"].select(range(1000))
eval_ds = boolq_enc["validation"].select(range(400))

In [7]:
###### trainer
from transformers import TrainingArguments, Trainer

In [8]:
class PruningTrainingArguments(TrainingArguments):
    def __init__(self, *args, initial_threshold=1., final_threshold=0.1, initial_warmup=1, final_warmup=2, final_lambda=0.,
                 mask_scores_learning_rate=0., **kwargs): 
        super().__init__(*args, **kwargs)

        self.initial_threshold = initial_threshold
        self.final_threshold = final_threshold
        self.initial_warmup = initial_warmup
        self.final_warmup = final_warmup
        self.final_lambda = final_lambda
        self.mask_scores_learning_rate = mask_scores_learning_rate

In [9]:
class PruningTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
        if self.args.max_steps > 0:
            self.t_total = self.args.max_steps
            self.args.num_train_epochs = self.args.max_steps // (len(self.get_train_dataloader()) // self.args.gradient_accumulation_steps) + 1
        else:
            self.t_total = len(self.get_train_dataloader()) // self.args.gradient_accumulation_steps * self.args.num_train_epochs
            
        
    def create_optimizer_and_scheduler(self, num_training_steps: int):
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
        {
           "params": [p for n, p in self.model.named_parameters() if "mask_scores" in n and p.requires_grad],
           "lr": self.args.mask_scores_learning_rate,
        },
        {
           "params": [
              p
              for n, p in self.model.named_parameters()
              if "gate_lin" in n and p.requires_grad and not any(nd in n for nd in no_decay)
            ],
           "lr": self.args.learning_rate,
           "weight_decay": self.args.weight_decay,
        },
        {
           "params": [
              p
              for n, p in self.model.named_parameters()
              if "gate_lin" in n and p.requires_grad and any(nd in n for nd in no_decay)
            ],
           "lr": self.args.learning_rate,
           "weight_decay": 0.0,
        },
       ]

        
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        self.lr_scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.t_total
        )
        
        
    def compute_loss(self, model, inputs, return_outputs=False):
            
        threshold, regu_lambda = self._schedule_threshold(
            step=self.state.global_step+1,
            total_step=self.t_total,
            warmup_steps=self.args.warmup_steps,
            final_threshold=self.args.final_threshold,
            initial_threshold=self.args.initial_threshold,
            final_warmup=self.args.final_warmup,
            initial_warmup=self.args.initial_warmup,
            final_lambda=self.args.final_lambda,
        )
        inputs["threshold"] = threshold  
        outputs = model(**inputs)

        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        
        return (loss, outputs) if return_outputs else loss
    
    
    def _schedule_threshold(
        self,
        step: int,
        total_step: int,
        warmup_steps: int,
        initial_threshold: float,
        final_threshold: float,
        initial_warmup: int,
        final_warmup: int,
        final_lambda: float,
    ):
        if step <= initial_warmup * warmup_steps:
            threshold = initial_threshold
        elif step > (total_step - final_warmup * warmup_steps):
            threshold = final_threshold
        else:
            spars_warmup_steps = initial_warmup * warmup_steps
            spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
            mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
            threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
        regu_lambda = final_lambda * threshold / final_threshold
        return threshold, regu_lambda

In [10]:
masked_config = MaskedBertConfig(pruning_method='topK', mask_init='constant', mask_scale=0.0)

bert_model = MaskedBertForSequenceClassification.from_pretrained(bert_ckpt, config=masked_config).to(device)

In [11]:
import torch

def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_all_mb = (param_size + buffer_size) / 1024**2
    return size_all_mb

model_size = get_model_size(bert_model)
print(f"Model size: {model_size:.2f} MB")


Model size: 633.72 MB


In [12]:
batch_size = 2
learning_rate = 2e-5
logging_steps = len(train_ds) // batch_size

# pruning params
initial_threshold = 1.0
initial_warmup = 1
final_warmup = 3
final_lambda = 0

args = PruningTrainingArguments(
    output_dir="checkpoints",
    eval_strategy = "epoch",
    learning_rate = learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    logging_steps=logging_steps,
    weight_decay=0.01,
    initial_threshold=initial_threshold,
    initial_warmup=initial_warmup,
    final_warmup=final_warmup,
    final_lambda=final_lambda,
    disable_tqdm=False,
    report_to=None,
    fp16=True
    )

In [13]:
import numpy as np
from datasets import load_metric

accuracy_score = load_metric('accuracy')

def compute_metrics(pred):
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_score.compute(predictions=predictions, references=labels)

In [14]:
pruning_trainer = PruningTrainer(
    model=bert_model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [15]:
def fine_prune(final_threshold, num_train_epochs, mask_scores_learning_rate=1e-2):
    pruning_trainer.args.final_threshold = final_threshold
    pruning_trainer.args.mask_scores_learning_rate = mask_scores_learning_rate
    pruning_trainer.args.num_train_epochs = num_train_epochs
    pruning_trainer.args.warmup_steps = pruning_trainer.args.logging_steps * num_train_epochs * 0.1
    print(f"Fine-pruning {(1-pruning_trainer.args.final_threshold)*100:.2f}% of weights with lr = {pruning_trainer.args.learning_rate} and mask_lr = {pruning_trainer.args.mask_scores_learning_rate} and {pruning_trainer.args.warmup_steps} warmup steps")
    pruning_trainer.train()

In [16]:
fine_prune(1.0, 3)

Fine-pruning 0.00% of weights with lr = 2e-05 and mask_lr = 0.01 and 150.0 warmup steps




  0%|          | 0/375 [00:00<?, ?it/s]

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/lip/.local/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/lip/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lip/transformerlab/transformerlab/MoE.py", line 677, in forward
    outputs = self.bert(
  File "/home/lip/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lip/transformerlab/transformerlab/MoE.py", line 206, in forward
    encoder_outputs = self.encoder(
  File "/home/lip/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lip/transformerlab/transformerlab/MoE.py", line 285, in forward
    layer_outputs = layer_module(
  File "/home/lip/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lip/transformerlab/transformerlab/MoE.py", line 403, in forward
    intermediate_output = self.intermediate(attention_output, threshold=threshold)
  File "/home/lip/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lip/transformerlab/transformerlab/MoE.py", line 593, in forward
    hidden_states = self.dense(hidden_states, threshold=threshold)
  File "/home/lip/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lip/transformerlab/transformerlab/MoE.py", line 553, in forward
    selected_masked_weight = masked_weight[selected_mask_index]
RuntimeError: CUDA out of memory. Tried to allocate 4.68 GiB (GPU 0; 11.76 GiB total capacity; 10.23 GiB already allocated; 45.50 MiB free; 10.28 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
