In [1]:
import torch
import chop.passes as passes
import optuna
import torch.nn as nn
import copy
import time
import numpy as np

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    AutoModelForMaskedLM,
    BertConfig,
    BertForMaskedLM,
    AlbertConfig,
    AlbertForMaskedLM,
    AdamW
)
from chop import MaseGraph
from chop.tools.utils import deepsetattr
from optuna.samplers import TPESampler
from torch.utils.data import DataLoader
from tqdm import tqdm
from IPython.display import clear_output

# Check if CUDA is available
print("CUDA available:", torch.cuda.is_available())

# Get the GPU name
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))


# checkpoint = "bert-base-uncased"
# tokenizer_checkpoint = "bert-base-uncased"

# checkpoint = "roberta-base"
# tokenizer_checkpoint = "roberta-base"

checkpoint = "albert/albert-base-v2"
tokenizer_checkpoint = "albert/albert-base-v2"


dataset_name = "xu-song/cc100-samples"


  from .autonotebook import tqdm as notebook_tqdm


CUDA available: True
GPU: NVIDIA GeForce RTX 5080


NVIDIA GeForce RTX 5080 with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_37 sm_90 compute_37.
If you want to use the NVIDIA GeForce RTX 5080 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



In [5]:
# dataset preprocessing

dataset = load_dataset(dataset_name, "en", split="train[:100%]")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

def tokenize_function(example):
    return tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=29
    )

# Tokenize
dataset = dataset.map(tokenize_function, batched=True)

# split the dataset in train and test
dataset = dataset.train_test_split(test_size=0.2)

print(dataset)
# print(dataset["train"][0])

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

DatasetDict({
    train: Dataset({
        features: ['text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
})


In [2]:
# create teacher model

teacher_model = AutoModelForMaskedLM.from_pretrained(checkpoint)
print(teacher_model.config)

mg = MaseGraph(
    teacher_model,
    hf_input_names=[
        "input_ids",
        "attention_mask",
        "labels",
    ],
)

mg, _ = passes.init_metadata_analysis_pass(mg)
mg, _ = passes.add_common_metadata_analysis_pass(mg)

Some weights of the model checkpoint at albert/albert-base-v2 were not used when initializing AlbertForMaskedLM: ['albert.pooler.bias', 'albert.pooler.weight']
- This IS expected if you are initializing AlbertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[32mINFO    [0m [34mGetting dummy input for albert/albert-base-v2.[0m


AlbertConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "albert/albert-base-v2",
  "architectures": [
    "AlbertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "down_scale_factor": 1,
  "embedding_size": 128,
  "eos_token_id": 3,
  "gap_size": 0,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "albert",
  "net_structure_type": 0,
  "num_attention_heads": 12,
  "num_hidden_groups": 1,
  "num_hidden_layers": 12,
  "num_memory_blocks": 0,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.48.3",
  "type_vocab_size": 2,
  "vocab_size": 30000
}

tensor([[   2,   21,   49,  123,  247,   84,   14,  126,   53,  208,    3],
        [   2,   48,   25,  483,   42,  378, 2484,   21, 8643,   

In [None]:
def trainer(train_data, teacher_model, student_model, epochs=10):
    """
    Trains a student model using knowledge distillation from a teacher model.
    
    The training process involves projecting teacher model hidden states onto a lower-dimensional
    space using a learned weight matrix W, and optimizing the student model to match these projections.
    
    Parameters:
    - train_data: Dataset containing input examples.
    - teacher_model: Pre-trained teacher model (frozen during training).
    - student_model: Student model to be trained.
    - num_hidden: Number of hidden layers considered for training.
    - epochs: Number of training epochs.
    
    Returns:
    - avg_loss: Final averaged distillation loss after training.
    """
    train_dataloader = DataLoader(train_data, batch_size=8, shuffle=True, collate_fn=lambda x: x[0])
    
    # Projection maps student hidden states to teacher hidden state space
    projection = torch.nn.Linear(student_model.config.hidden_size, 2*teacher_model.config.hidden_size).to('cuda')

    # Jointly optimize student model and projection
    optimizer = AdamW(list(student_model.parameters()) + list(projection.parameters()), lr=5e-5)
    
    # Move models to GPU for faster computation
    student_model.to('cuda')
    teacher_model.to('cuda')
    avg_losses = []
    for epoch in range(epochs):
        student_model.train()  # Set student model to training mode
        
        running_loss = 0.0  # Track cumulative loss per epoch
        
        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            # Convert batch data to tensors and move to GPU
            input_ids = torch.tensor([batch["input_ids"]]).to('cuda')
            attention_mask = torch.tensor([batch["attention_mask"]]).to('cuda')
            
            # Forward pass through the teacher model (frozen, no gradient calculation)
            with torch.no_grad():
                outputs_teacher = teacher_model(input_ids, attention_mask, output_hidden_states=True)
            
            # Forward pass through the student model
            outputs_student = student_model(input_ids, attention_mask, output_hidden_states=True)
            
            H_teacher = outputs_teacher.hidden_states[1:]
            H_student = outputs_student.hidden_states[1:]

            Ht = []
            Hs = []
                
            # Stack student hidden states into a single tensor
            Hs = torch.stack([h for h in H_student])
            
            num_Hs = len(H_student)
            num_Ht = len(H_teacher)
            # print(f'Hs shape: {Hs.shape}')

            # Project student hidden states to teacher hidden state space
            Hs_proj = projection(Hs)
            
            # For every two consecutive hidden layers in the teacher model, concatenate them
            for i in range(num_Hs):
                H0 = H_teacher[int(i * num_Ht / num_Hs)] # uniform sampled teacher hidden states with skip=2
                H1 = H_teacher[i + num_Ht - num_Hs]  # the last Ls states from the teacher hidden states
                H = torch.concatenate([H0, H1], dim=-1)
                Ht.append(H) 
            
            # Stack concatenated teacher hidden layers into a single tensor
            Ht = torch.stack(Ht)

            # Compute distillation loss (Mean Squared Error between student and projected teacher outputs)
            loss = torch.nn.functional.mse_loss(Hs_proj, Ht)
            # print(Hs_proj.shape, Ht.shape, loss)

            # clear_output()  # Clear previous outputs in interactive environments
            # break
            # Backpropagation: Compute gradients and update weights
            optimizer.zero_grad()  # Reset gradients
            loss.backward()  # Compute gradients
            optimizer.step()  # Update model parameters
            
            
            running_loss += loss.item()  # Accumulate loss
        
        # Compute average loss for the epoch
        avg_loss = running_loss / len(train_dataloader)
        avg_losses.append(avg_loss)
        print(f"Epoch {epoch+1} | Avg Distillation Loss: {avg_loss:.4f}")
        
        # Save the model and tokenizer every 5 epochs
        if (epoch) % 5 == 0:
            model_path = "./batch_test_kd_new/"
            # Save model architecture and weights using HuggingFace's save_pretrained
            student_model.save_pretrained(model_path + str(epoch))
            tokenizer.save_pretrained(model_path + str(epoch))

    return avg_losses  # Return the final average loss


In [48]:
# test distillation

# Define a BERT configuration with all parameters
config = AutoConfig.from_pretrained(
    pretrained_model_name_or_path=checkpoint,  # Use same base model as teacher
    hidden_size=128,   # Size of hidden layers
    num_hidden_layers=4,  # Number of transformer layers
    num_attention_heads=2,  # Number of attention heads
    intermediate_size=384,  # Size of the intermediate layer
    hidden_act="gelu",  # Activation function

)

# config = AutoConfig.from_pretrained(
#     pretrained_model_name_or_path=checkpoint,  # Use same base model as teacher
#     hidden_size=384,   # Size of hidden layers
#     num_hidden_layers=6,  # Number of transformer layers
#     num_attention_heads=6,  # Number of attention heads
#     intermediate_size=1536,  # Size of the intermediate layer
#     hidden_act="gelu",  # Activation function

# )

# Initialize an untrained BERT model for MLM
student_model = AutoModelForMaskedLM.from_config(config)
train_data = dataset["train"]

# mgs = MaseGraph(
#     student_model,
#     hf_input_names=[
#         "input_ids",
#         "attention_mask",
#         "labels",
#     ],
# )

# mgs, _ = passes.init_metadata_analysis_pass(mgs)
# mgs, _ = passes.add_common_metadata_analysis_pass(mgs)



# trainer(train_data, teacher_model, student_model, epochs=1)

In [52]:
losses = trainer(train_data, teacher_model, student_model, epochs=16)
print(losses)

Epoch 1/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.47it/s]


Epoch 1 | Avg Distillation Loss: 0.7806


Epoch 2/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.43it/s]


Epoch 2 | Avg Distillation Loss: 0.6202


Epoch 3/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.41it/s]


Epoch 3 | Avg Distillation Loss: 0.5549


Epoch 4/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.34it/s]


Epoch 4 | Avg Distillation Loss: 0.5294


Epoch 5/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.27it/s]


Epoch 5 | Avg Distillation Loss: 0.5038


Epoch 6/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.47it/s]


Epoch 6 | Avg Distillation Loss: 0.4815


Epoch 7/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.32it/s]


Epoch 7 | Avg Distillation Loss: 0.4617


Epoch 8/16: 100%|██████████| 1000/1000 [00:23<00:00, 42.91it/s]


Epoch 8 | Avg Distillation Loss: 0.4476


Epoch 9/16: 100%|██████████| 1000/1000 [00:23<00:00, 42.79it/s]


Epoch 9 | Avg Distillation Loss: 0.4422


Epoch 10/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.19it/s]


Epoch 10 | Avg Distillation Loss: 0.4315


Epoch 11/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.30it/s]


Epoch 11 | Avg Distillation Loss: 0.4178


Epoch 12/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.06it/s]


Epoch 12 | Avg Distillation Loss: 0.4115


Epoch 13/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.13it/s]


Epoch 13 | Avg Distillation Loss: 0.4055


Epoch 14/16: 100%|██████████| 1000/1000 [00:23<00:00, 42.61it/s]


Epoch 14 | Avg Distillation Loss: 0.3995


Epoch 15/16: 100%|██████████| 1000/1000 [00:23<00:00, 42.92it/s]


Epoch 15 | Avg Distillation Loss: 0.3972


Epoch 16/16: 100%|██████████| 1000/1000 [00:23<00:00, 43.15it/s]

Epoch 16 | Avg Distillation Loss: 0.3907
[0.7805956998169422, 0.6201834033280611, 0.554949565961957, 0.5294291850030423, 0.5037965133711696, 0.4815288019776344, 0.46170535109192135, 0.447559874817729, 0.44215384248644113, 0.4315080543383956, 0.4177843748256564, 0.41151871693879366, 0.40545996534079315, 0.3994827117025852, 0.39724541855230927, 0.39071879657357933]





In [49]:
# Save the model, tokenizer, and weights
model_path = "./batch_test_kd_new/og"

# Save model architecture and weights using HuggingFace's save_pretrained
student_model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)


('./batch_test_kd_new/og/tokenizer_config.json',
 './batch_test_kd_new/og/special_tokens_map.json',
 './batch_test_kd_new/og/spiece.model',
 './batch_test_kd_new/og/added_tokens.json',
 './batch_test_kd_new/og/tokenizer.json')

In [59]:

# load the model
student_model = AutoModelForMaskedLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

trainer(train_data, teacher_model, student_model, epochs=1)

# # Additionally save the model state dict
# torch.save(student_model.state_dict(), f"{model_path}/pytorch_model.bin")


  loss = torch.nn.functional.mse_loss(Y[:, None, ...], Ht)
Epoch 1/1: 100%|██████████| 1000/1000 [00:21<00:00, 47.59it/s]

Epoch 1 | Avg Distillation Loss: 3.1887





3.188658354282379

In [12]:
# TODO: construct student model by transforming the teacher model, or maybe make the construction more general "not limited to bert"

def construct_student_model(trial):
    search_space = get_search_space()
    config = copy.deepcopy(search_space)
    
    # Assign the candidate states to the config for the student models
    for param in search_space.keys():
        param_idx = trial.suggest_int(param, 0, len(search_space[param])-1)
        config[param] = config[param][param_idx]

    new_config = AlbertConfig(**config)
    trial_model = BertForMaskedLM(new_config) # Initialize an untrained bert model using the config

    return trial_model

In [None]:
# Function to retrive the condidate states
# TODO: get params from the model and auotmatically generate the search space
def get_search_space():
    search_space = {
        "num_hidden_layers": [3, 4, 6, 10, 12],
        "num_attention_heads": [2, 3, 4, 6, 12],
        "hidden_size": [384, 768],
        "intermediate_size": [384, 512, 576, 768, 1024, 1536, 2048, 3072],
        "hidden_act": ['gelu', 'relu', 'silu']
    }

    return search_space

# search_space = get_search_space(teacher_model)
# print(search_space)

In [None]:
# TODO: try out inference time instead of training time

def objective(trial):
    # Create model with this config
    checkpoint = "albert/albert-base-v2"
    teacher_model = AutoModelForMaskedLM.from_pretrained(checkpoint)
    student_model = construct_student_model(trial)

    train_data = dataset["train"]

    start = time.time()
    loss = trainer(train_data, teacher_model, student_model, epochs=10)
    training_time = time.time() - start

    print(f"Average loss: {loss}")
    print(f"Training  time: {training_time}")

    trial.set_user_attr("student_model", student_model)

    reward = 1 / (loss + training_time + training_time)

    return reward

In [None]:
sampler = TPESampler()
study = optuna.create_study(direction="maximize", sampler=sampler)  # Minimize loss
study.optimize(objective, n_trials=10)

[I 2025-03-12 21:49:50,057] A new study created in memory with name: no-name-9db8844a-b4ff-4e14-b332-4c4a9b6c86e1
Some weights of the model checkpoint at albert/albert-base-v2 were not used when initializing AlbertForMaskedLM: ['albert.pooler.bias', 'albert.pooler.weight']
- This IS expected if you are initializing AlbertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this

Epoch 1 | Avg Distillation Loss: 3.6414


Epoch 2/10: 100%|██████████| 1000/1000 [00:56<00:00, 17.72it/s]


Epoch 2 | Avg Distillation Loss: 1.0742


Epoch 3/10: 100%|██████████| 1000/1000 [00:56<00:00, 17.55it/s]


Epoch 3 | Avg Distillation Loss: 0.8469


Epoch 4/10: 100%|██████████| 1000/1000 [00:56<00:00, 17.64it/s]


Epoch 4 | Avg Distillation Loss: 0.7542


Epoch 5/10:  68%|██████▊   | 675/1000 [00:37<00:18, 17.81it/s]
[W 2025-03-12 21:54:17,455] Trial 0 failed with parameters: {'num_hidden_layers': 3, 'num_attention_heads': 1, 'hidden_size': 1, 'intermediate_size': 3, 'hidden_act': 1} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/tomyt/anaconda3/envs/mase/lib/python3.11/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipykernel_39555/3787538529.py", line 10, in objective
    loss = trainer(train_data, teacher_model, student_model, epochs=10)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_39555/4027537741.py", line 92, in trainer
    loss.backward()  # Compute gradients
    ^^^^^^^^^^^^^^^
  File "/home/tomyt/anaconda3/envs/mase/lib/python3.11/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/home/

KeyboardInterrupt: 