# Example of Model Training with Distilled Dataset

This notebook demonstrates that we can finetune BERT model with the distilled dataset obtained our experiments.

Our distilled dataset can be donwload from here: https://drive.google.com/file/d/10DkcGEfw9DTWuxBQciin0jGyr9yMQC0H/view?usp=sharing

*Note that this demo does not include the distillation procedures. Source code for dataset distillation is available here: https://github.com/arumaekawa/dataset-distillation-with-attention-labels

# Load tokenizer and dataset

In [125]:
from datasets import load_dataset
from transformers import AutoTokenizer

TASK_ATTRS = {
    "ag_news": {"load_args": ("ag_news",), "sent_keys": ("text",), 'split':'test'},
    "sst2": {"load_args": ("glue", "sst2"), "sent_keys": ("sentence",), 'split':'validation'},
    "qnli": {"load_args": ("glue", "qnli"), "sent_keys": ("question", "sentence"), 'split':'validation'},
    "mrpc": {"load_args": ("glue", "mrpc"), "sent_keys": ("sentence1", "sentence2"), 'split':'test'},
}

TASK1 = "ag_news"  # First dataset
TASK2 = "sst2"  # Second dataset

# Load dataset 1
load_args1 = TASK_ATTRS[TASK1]["load_args"]
sent_keys1 = TASK_ATTRS[TASK1]["sent_keys"]
split1 = TASK_ATTRS[TASK1]["split"]
dataset1 = load_dataset(*load_args1, split=split1).select(range(768))  # First 100 examples

# Load dataset 2
load_args2 = TASK_ATTRS[TASK2]["load_args"]
sent_keys2 = TASK_ATTRS[TASK2]["sent_keys"]
split2 = TASK_ATTRS[TASK2]["split"]
dataset2 = load_dataset(*load_args2, split=split2).select(range(768))  # First 100 examples

print(dataset1[0])
print(dataset2[0])

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def preprocess_dataset(dataset, sent_keys):
    dataset = dataset.map(
        lambda ex: tokenizer(
            *(ex[k] for k in sent_keys), max_length=tokenizer.model_max_length, truncation=True
        ),
        batched=True,
    )
    if "label" in dataset.column_names:
        dataset = dataset.rename_column("label", "labels")
    remove_keys = [
        name for name in dataset.column_names
        if name not in ['labels', 'input_ids', 'attention_mask']
    ]
    dataset = dataset.remove_columns(remove_keys)
    return dataset

# Tokenize datasets
dataset1 = preprocess_dataset(dataset1, sent_keys1)
dataset2 = preprocess_dataset(dataset2, sent_keys2)


{'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.", 'label': 2}
{'sentence': "it 's a charming and often affecting journey . ", 'label': 1, 'idx': 0}


In [126]:
dataset1

Dataset({
    features: ['labels', 'input_ids', 'attention_mask'],
    num_rows: 768
})

In [127]:
dataset2

Dataset({
    features: ['labels', 'input_ids', 'attention_mask'],
    num_rows: 768
})

In [128]:
from datasets import Dataset
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

def concat_sequences_and_cross_label(ex1, ex2):
    # Concatenate input_ids and attention_mask
    input_ids = ex1["input_ids"] + ex2["input_ids"]
    attention_mask = ex1["attention_mask"] + ex2["attention_mask"]

    # Truncate if length exceeds 1024
    #input_ids = input_ids[:1024]
    #attention_mask = attention_mask[:1024]

    # Generate cross label
    cross_label = ex1["labels"] * 2 + ex2["labels"]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        #"label1": ex1["labels"],
        #"label2": ex2["labels"],
        "labels": cross_label,
    }

merged_data = [
    concat_sequences_and_cross_label(ex1, ex2)
    for ex1, ex2 in zip(dataset1, dataset2)
]

merged_dataset = Dataset.from_list(merged_data)

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# def preprocess_function(examples):
#     # Truncate or pad `input_ids` to a fixed length (1024)
#     examples["input_ids"] = [
#         seq[:1024] + [tokenizer.pad_token_id] * max(0, 1024 - len(seq))
#         for seq in examples["input_ids"]
#     ]
#     # Adjust `attention_mask` to match the new `input_ids` length
#     examples["attention_mask"] = [
#         mask[:1024] + [0] * max(0, 1024 - len(mask))
#         for mask in examples["attention_mask"]
#     ]
#     # Keep labels unchanged unless they also need truncation/padding
#     return examples
# 
# # Apply preprocessing to the dataset
# processed_dataset = merged_dataset.map(preprocess_function, batched=True)

# collate_fn = DataCollatorWithPadding(
#     tokenizer=tokenizer,
#     padding="max_length",  # Ensures consistent sequence length
#     max_length=1024,       # Set max length explicitly to 1024
#     pad_to_multiple_of=8   # Pad sequences to be divisible by 8 for efficient hardware utilization
# )

collate_fn = DataCollatorWithPadding(
    tokenizer=tokenizer, padding="longest", pad_to_multiple_of=8
)
# Create DataLoader
test_loader = DataLoader(
    merged_dataset, batch_size=256, collate_fn=collate_fn
)

In [129]:
print(merged_dataset[0])

{'input_ids': [101, 10069, 2005, 1056, 1050, 11550, 2044, 7566, 9209, 5052, 3667, 2012, 6769, 2047, 8095, 2360, 2027, 2024, 1005, 9364, 1005, 2044, 7566, 2007, 16654, 6687, 3813, 2976, 9587, 24848, 1012, 102, 101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': 5}


In [130]:
test_loader_iter = iter(test_loader)

# Fetch the first batch of data
sample_batch = next(test_loader_iter)
print({k: v.shape for k, v in sample_batch.items()})
print(sample_batch['input_ids'][9].shape)

{'input_ids': torch.Size([256, 216]), 'attention_mask': torch.Size([256, 216]), 'labels': torch.Size([256])}
torch.Size([216])


In [131]:
# metric

import evaluate
metric = evaluate.load("accuracy")

# Load distilled data


In [132]:
import os
import json
import torch

#data_path = f"distilled_data_examples/{TASK}/1_shot-1_step-1_epoch-soft_label-cls_al"
data_path = "../distilled_data_examples/data/crosstask_merge/"

#config = json.load(open(os.path.join(data_path, "config.json")))
data = torch.load(os.path.join(data_path, "data_dict"))

# train_step = config["train_config"]["train_step"]
# batch_size_per_label = config["train_config"]["batch_size_per_label"]
# #num_labels = config["num_labels"]
# batch_size = batch_size_per_label * num_labels
# attn_lambda = config["config"]["attention_loss_lambda"]

num_labels = 8
attn_lambda = 1.5
batch_size = 2
train_step = 4

print({k: v.shape for k, v in data.items()})
#print(config)

{'inputs_embeds': torch.Size([8, 1024, 768]), 'labels': torch.Size([8, 8]), 'attention_labels': torch.Size([8, 12, 12, 1, 1024]), 'lr': torch.Size([1])}


In [133]:
#data

## Training model

In [134]:
import time
from torch.nn import functional as F
from torch.optim import SGD

def compute_task_loss(logits, labels):
    #print(logits.shape)
    #print(labels.shape)
    loss_task = F.cross_entropy(
        logits.view(-1, num_labels), labels, reduction="none"
    )
    return loss_task.mean()

def compute_attn_loss(attentions, attention_labels):
    if attention_labels is not None:
        attn_weights = torch.stack(attentions, dim=1)
        attn_weights = attn_weights[..., : attention_labels.size(-2), :]
        assert attn_weights.shape == attention_labels.shape
        attn_loss = F.kl_div(
            torch.log(attn_weights + 1e-12),
            attention_labels,
            reduction="none",
        )
        return attn_loss.sum(-1).mean()

    return 0.0

def train(model):
    model.train()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    #optimizer = SGD(model.parameters(), lr=1.0)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

    start_time = time.time()

    for step in range(train_step):
        # get batch
        batch_cycle = step % int(len(data["inputs_embeds"]) / batch_size)
        inputs_embeds = data["inputs_embeds"][
            batch_size * batch_cycle: batch_size* (batch_cycle + 1)
        ].to(device)
        labels = data["labels"][
            batch_size * batch_cycle: batch_size* (batch_cycle + 1)
        ].to(device)
        if "attention_labels" in data:
            attention_labels = F.softmax(
                data["attention_labels"][
                    batch_size * batch_cycle: batch_size* (batch_cycle + 1)
                ].to(device), dim=-1
            )
        else:
            attention_labels = None
        lr = F.softplus(data["lr"][0].to(device))

        # compute loss
        outputs = model(
            inputs_embeds=inputs_embeds, output_attentions=True
        )
        #loss_task = torch.abs(compute_task_loss(outputs.logits, labels))
        loss_task = compute_task_loss(outputs.logits, labels)
        loss_attn = compute_attn_loss(outputs.attentions, attention_labels)
        loss = loss_task + attn_lambda * loss_attn

        loss *= lr
        
        print('normal loss: ', loss_task)
        print('attention loss: ', loss_attn)
        print('final loss: ', loss.item())

        # update model
        model.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Training time: {time.time() - start_time} s")

In [135]:
from tqdm.notebook import tqdm

def evaluate_model(model, test_loader):
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    total_loss, num_samples = 0, 0
    for batch in tqdm(test_loader, desc="Evaluating model"):
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch.pop("labels")
        #print(labels)

        with torch.no_grad():
            outputs = model(**batch)
            loss = compute_task_loss(outputs.logits, labels)

        metric.add_batch(
            predictions=outputs.logits.argmax(-1).tolist(),
            references=labels.tolist()
        )
        total_loss += loss.item() * len(labels)
        num_samples += len(labels)
        
        print('evaluation loss: ', loss.item())

    results = metric.compute()
    results["loss"] = total_loss / num_samples

    return results


In [136]:
from transformers import AutoModelForSequenceClassification, BertConfig, BertForSequenceClassification

config = BertConfig.from_pretrained("bert-base-uncased")
config.max_position_embeddings = 1024
config.num_labels = num_labels
config.attention_probs_dropout_prob = 0.0
config.hidden_dropout_prob = 0.0

# Load the model with the updated configuration
model_base = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=config, ignore_mismatched_sizes=True)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=config, ignore_mismatched_sizes=True)
model_original = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=config, ignore_mismatched_sizes=True)

encoder_weights = torch.load("encoder_trained.pth", map_location="cpu")
model.bert.load_state_dict(encoder_weights, strict=False)
for layer in model.bert.encoder.layer[:3]:
    for param in layer.parameters():
        param.requires_grad = False
model_original.bert.load_state_dict(encoder_weights, strict=False)
for layer in model_original.bert.encoder.layer[:3]:
    for param in layer.parameters():
        param.requires_grad = False

# model = AutoModelForSequenceClassification.from_pretrained(
#     "roberta-base",
#     num_labels=num_labels,
#     attention_probs_dropout_prob=0.0,
#     hidden_dropout_prob=0.0,
# )

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized because the shapes did not match:
- bert.embeddings.position_embeddings.weight: found shape torch.Size([512, 768]) in the checkpoint and torch.Size([1024, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for pre

In [137]:
# Evaluate bofore training
results = evaluate_model(model_base, test_loader)
print("-"*40)
print("Before training:", results)
print("-"*40)

Evaluating model:   0%|          | 0/3 [00:00<?, ?it/s]

evaluation loss:  2.145791530609131
evaluation loss:  2.1161842346191406
evaluation loss:  2.146599292755127
----------------------------------------
Before training: {'accuracy': 0.11848958333333333, 'loss': 2.1361916859944663}
----------------------------------------


In [138]:
# Train model
print("-"*40)
train(model)
print("-"*40)



----------------------------------------
normal loss:  tensor(9.6460, grad_fn=<MeanBackward0>)
attention loss:  tensor(6.2577, grad_fn=<MeanBackward0>)
final loss:  1.7797644138336182
normal loss:  tensor(0.6637, grad_fn=<MeanBackward0>)
attention loss:  tensor(4.5334, grad_fn=<MeanBackward0>)
final loss:  0.6979489326477051
normal loss:  tensor(4.6353, grad_fn=<MeanBackward0>)
attention loss:  tensor(3.7034, grad_fn=<MeanBackward0>)
final loss:  0.9529080390930176
normal loss:  tensor(2.7002, grad_fn=<MeanBackward0>)
attention loss:  tensor(4.2404, grad_fn=<MeanBackward0>)
final loss:  0.8472919464111328
Training time: 36.96550393104553 s
----------------------------------------


In [139]:
def move_state_dict_to_device(state_dict, device):
    return {key: value.to(device) for key, value in state_dict.items()}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trained_state_dict = model.state_dict()
original_state_dict = model_original.state_dict()
original_state_dict = move_state_dict_to_device(original_state_dict, device)
combined_state_dict = {}
lambda_coef = 0.8

for key in original_state_dict:
    combined_state_dict[key] = (1-lambda_coef)*original_state_dict[key] + lambda_coef*trained_state_dict[key]
#combined_state_dict.to(device)

model.load_state_dict(combined_state_dict)

<All keys matched successfully>

In [140]:
# Evaluate after training
results = evaluate_model(model, test_loader)
print("-"*40)
print("After training:", results)
print("-"*40)

Evaluating model:   0%|          | 0/3 [00:00<?, ?it/s]

evaluation loss:  2.1401870250701904
evaluation loss:  2.115279197692871
evaluation loss:  2.1786112785339355
----------------------------------------
After training: {'accuracy': 0.13411458333333334, 'loss': 2.1446925004323325}
----------------------------------------
