In [71]:
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import random

# Load tokenizer and dataset

In [72]:
TASK_ATTRS = {
    "ag_news": {"load_args": ("ag_news",), "sent_keys": ("text",), 'split':'test'},
    "sst2": {"load_args": ("glue", "sst2"), "sent_keys": ("sentence",), 'split':'validation'},
    "qnli": {"load_args": ("glue", "qnli"), "sent_keys": ("question", "sentence"), 'split':'validation'},
    "mrpc": {"load_args": ("glue", "mrpc"), "sent_keys": ("sentence1", "sentence2"), 'split':'test'},
}

TASK="ag_news"  # select from ("ag_news", "mrpc", "qnli", "sst2")

load_args = TASK_ATTRS[TASK]["load_args"]
sent_keys = TASK_ATTRS[TASK]["sent_keys"]
split = TASK_ATTRS[TASK]["split"]

dataset1 = load_dataset(*load_args, split=split)
sample_size = int(0.1 * len(dataset1))

# Set a random seed for reproducibility
random.seed(42)

# Sample 10% of the dataset
dataset1 = dataset1.select(random.sample(range(len(dataset1)), sample_size))

print(dataset1[0])

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset1 = dataset1.map(
    lambda ex: tokenizer(
        *(ex[k] for k in sent_keys), max_length=tokenizer.model_max_length, truncation=True
    ),
    batched=True,
)
if "label" in dataset1.column_names:
    dataset1 = dataset1.rename_column("label", "labels")

remove_keys = [
    name for name in  dataset1.column_names
    if name not in ['labels', 'input_ids', 'token_type_ids', 'attention_mask']
]
dataset1 = dataset1.remove_columns(remove_keys)

TASK="sst2"  # select from ("ag_news", "mrpc", "qnli", "sst2")

load_args = TASK_ATTRS[TASK]["load_args"]
sent_keys = TASK_ATTRS[TASK]["sent_keys"]
split = TASK_ATTRS[TASK]["split"]

dataset2 = load_dataset(*load_args, split=split)

print(dataset2[0])


dataset2 = dataset2.map(
    lambda ex: tokenizer(
        *(ex[k] for k in sent_keys), max_length=tokenizer.model_max_length, truncation=True
    ),
    batched=True,
)
if "label" in dataset2.column_names:
    dataset2 = dataset2.rename_column("label", "labels")

remove_keys = [
    name for name in  dataset2.column_names
    if name not in ['labels', 'input_ids', 'token_type_ids', 'attention_mask']
]
dataset2 = dataset2.remove_columns(remove_keys)

{'text': 'Photos Plus Music Equals an Expensive iPod (washingtonpost.com) washingtonpost.com - First Apple put some color on the iPod, when it offered the iPod mini in a palette of pastel hues, and now it has put some color inside it, in the form of the new iPod Photo.', 'label': 3}
{'sentence': "it 's a charming and often affecting journey . ", 'label': 1, 'idx': 0}


In [73]:
dataset1

Dataset({
    features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 760
})

In [74]:
dataset2

Dataset({
    features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 872
})

In [75]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collate_fn = DataCollatorWithPadding(
    tokenizer=tokenizer, padding="longest", pad_to_multiple_of=8
)
test_loader_task1 = DataLoader(
    dataset1, batch_size=256, collate_fn=collate_fn
)

test_loader_task2 = DataLoader(
    dataset2, batch_size=256, collate_fn=collate_fn
)

# Load distilled data

In [76]:
import os
import json
import torch

task1, task2 = 'ag_news', 'sst2'

#data_path = f"distilled_data_examples/{TASK}/1_shot-1_step-1_epoch-soft_label-cls_al"
#data_path_1 = r"C:\Users\alber\OneDrive\dataset-distillation-with-attention-labels\distilled_data_examples\{task1}\test"
#data_path_2 = r"C:\Users\alber\OneDrive\dataset-distillation-with-attention-labels\distilled_data_examples\{task2}\test"
data_path_1 = os.path.normpath(os.path.join("../distilled_data_examples/data", task1, "test"))
data_path_2 = os.path.normpath(os.path.join("../distilled_data_examples/data", task2, "test"))
#data_path_2 = os.path.join("..", "distilled_data_examples", task2, "test")
#data_path = f"../distilled_data_examples/ag_sst2_with_sst2_label"

config_1 = json.load(open(os.path.join(data_path_1, "config.json")))
task1_data = torch.load(os.path.join(data_path_1, "data_dict"))
config_2 = json.load(open(os.path.join(data_path_2, "config.json")))
task2_data = torch.load(os.path.join(data_path_2, "data_dict"))


print(config_1["train_config"]["train_step"])
train_step = config_1["train_config"]["train_step"]
batch_size_per_label = config_1["train_config"]["batch_size_per_label"]
num_labels = config_2["num_labels"]
batch_size = batch_size_per_label * num_labels
attn_lambda = config_1["config"]["attention_loss_lambda"]

print(train_step)
print(batch_size_per_label)
print(num_labels)
print(attn_lambda)

print({k: v.shape for k, v in task1_data.items()})
print({k: v.shape for k, v in task2_data.items()})

1
1
1
2
1.0
{'inputs_embeds': torch.Size([4, 512, 768]), 'labels': torch.Size([4, 4]), 'attention_labels': torch.Size([4, 12, 12, 1, 512]), 'lr': torch.Size([1])}
{'inputs_embeds': torch.Size([2, 512, 768]), 'labels': torch.Size([2, 2]), 'attention_labels': torch.Size([2, 12, 12, 1, 512]), 'lr': torch.Size([1])}


## Dataset and Model Class

In [77]:
class MultiTaskDataset(Dataset):
    def __init__(self, task1_data, task1_labels, task1_attention_labels, task2_data, task2_labels,
                 task2_attention_labels):
        self.inputs = torch.cat([task1_data, task2_data], dim=0)

        # Determine the maximum label dimension
        max_label_dim = max(task1_labels.size(1), task2_labels.size(1))

        # Pad task1_labels and task2_labels to the same size
        self.task1_label_dim = task1_labels.size(1)
        self.task2_label_dim = task2_labels.size(1)

        task1_labels = torch.nn.functional.pad(task1_labels, (0, max_label_dim - task1_labels.size(1)))
        task2_labels = torch.nn.functional.pad(task2_labels, (0, max_label_dim - task2_labels.size(1)))

        # Combine classification labels
        self.labels = torch.cat([task1_labels, task2_labels], dim=0)

        # Combine attention labels
        self.attention_labels = torch.cat([task1_attention_labels, task2_attention_labels], dim=0)

        # Task identifiers (0 for Task 1, 1 for Task 2)
        self.task_ids = torch.cat([
            torch.zeros(task1_data.size(0), dtype=torch.long),  # Task 1 identifier
            torch.ones(task2_data.size(0), dtype=torch.long)   # Task 2 identifier
        ])

        # Store the split point to distinguish between tasks
        self.task1_size = task1_data.size(0)

    def __len__(self):
        return self.inputs.size(0)

    def __getitem__(self, idx):
        if isinstance(idx, (list, torch.Tensor)):  # If idx is a list or tensor, handle batch mode
            return [self.__getitem__(i) for i in idx]
        
        inputs = self.inputs[idx]
        labels = self.labels[idx]
        attention_labels = self.attention_labels[idx]
        task_id = self.task_ids[idx]
    
        # Adjust labels based on the task
        if task_id == 0:  # Task 1
            labels = labels[:self.task1_label_dim]
        else:  # Task 2
            labels = labels[:self.task2_label_dim]
        
        return inputs, labels, attention_labels, task_id

task1_embeds = task1_data['inputs_embeds']
task1_labels = task1_data['labels']
task1_attention_labels = task1_data['attention_labels']
task1_lr = task1_data['lr']

task2_embeds = task2_data['inputs_embeds']
task2_labels = task2_data['labels']
task2_attention_labels = task2_data['attention_labels']
task2_lr = task2_data['lr']

# Dataset and DataLoader
dataset = MultiTaskDataset(task1_embeds, task1_labels, task1_attention_labels,
                           task2_embeds, task2_labels, task2_attention_labels)

In [78]:
class MultiTaskModel(nn.Module):
    def __init__(self, encoder, num_classes_task1=4, num_classes_task2=2, num_layers_to_freeze=4):
        super(MultiTaskModel, self).__init__()
        self.encoder = encoder

        # Freeze specified number of encoder layers
        print(f"Freezing the first {num_layers_to_freeze} layers of the encoder...")
        for i in range(num_layers_to_freeze):
            for param in self.encoder.encoder.layer[i].parameters():
                param.requires_grad = False

        # Task-specific classifiers
        task1_model = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=num_classes_task1
        )
        task2_model = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=num_classes_task2
        )

        # Use their classifier layers
        self.task1_classifier = task1_model.classifier
        self.task2_classifier = task2_model.classifier

    def forward(self, inputs_embeds=None, input_ids=None, token_type_ids=None, task_id=None, attention_mask=None):
        # Convert input_ids to inputs_embeds if input_ids is provided
        if input_ids is not None:
            inputs_embeds = self.encoder.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)

        if inputs_embeds is None:
            raise ValueError("Either inputs_embeds or input_ids must be provided.")

        # Forward pass through the encoder
        encoder_outputs = self.encoder(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, output_attentions=True
        )

        # Shared representation from the encoder (CLS token output)
        shared_representation = encoder_outputs.last_hidden_state[:, 0, :]

        # Task-specific output
        if task_id == None:
            logits = self.task2_classifier(shared_representation)
        elif task_id[0] == 0:
            logits = self.task1_classifier(shared_representation)
        elif task_id[0] == 1:  # Task 2
            logits = self.task2_classifier(shared_representation)

        # Use attentions from the encoder if needed
        attentions = encoder_outputs.attentions

        return logits, attentions

## Training model

In [79]:
import time
from torch.nn import functional as F
from torch.optim import SGD

def compute_task_loss(logits, labels, num_labels):
    #print(logits.shape)
    #print(labels.shape)
    loss_task = F.cross_entropy(
        logits.view(-1, num_labels), labels, reduction="none"
    )
    return loss_task.mean()

def compute_attn_loss(attentions, attention_labels):
    #attention_labels = attention_labels.unsqueeze(0)
    if attention_labels is not None:
        attn_weights = torch.stack(attentions, dim=1)
        attn_weights = attn_weights[..., : attention_labels.size(-2), :]
        assert attn_weights.shape == attention_labels.shape
        
        #attn_weights = attn_weights / attn_weights.sum(dim=-1, keepdim=True)
        #attention_labels = attention_labels / attention_labels.sum(dim=-1, keepdim=True)
        
        # Clamp attn_weights to avoid log(0)
        #attn_weights = torch.clamp(attn_weights, min=1e-12)
        #attention_labels = torch.clamp(attention_labels, min=1e-12)
    
        # Calculate KL divergence
        attn_loss = F.kl_div(
            torch.log(attn_weights),
            attention_labels,
            reduction="none",
        )
    
        return attn_loss.sum(-1).mean()

    return 0.0

def train_multitask(model, dataset, optimizer, batch_size, train_step, attn_lambda):
    """
    Train function for multitask model.

    Parameters:
    - model: MultiTaskModel instance.
    - dataset: MultiTaskDataset instance containing inputs, labels, and attention labels.
    - optimizer: Optimizer for the model.
    - batch_size: Batch size for training.
    - train_step: Total training steps.
    - attn_lambda: Weight for the attention loss.
    """
    start_time = time.time()
    #optimizer = SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1.0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    for step in range(train_step):
        sample_len = len(dataset)
        batch_start = step * batch_size % sample_len
        batch_end = min(batch_start + batch_size, sample_len)
        batch_indices = list(range(batch_start, batch_end))  # Generate indices as a list
        
        # Retrieve batch
        batch = dataset[batch_indices]

        # Use zip to group components by type
        inputs, labels, attention_labels, task_ids = zip(*batch)
        

        # Convert to tensors if needed
        inputs_embeds = torch.stack(inputs)               # Stack inputs (if tensors)
        inputs_embeds = inputs_embeds.to(device)
        labels = torch.stack(labels)                 # Concatenate labels (if tensors)
        attention_labels = torch.stack(attention_labels)  # Concatenate attention labels (if tensors)
        task_ids = torch.tensor(task_ids)          # Convert task IDs to tensor
        task_ids = task_ids.to(device)
        labels = labels.to(device)
        attention_labels = attention_labels.to(device)
        lr = task1_lr if task_ids[0] == 0 else task2_lr
        lr = F.softplus(lr).item()
        attention_labels = F.softmax(attention_labels, dim=-1)

        # Forward pass for the entire batch
        logits, attentions = model(inputs_embeds=inputs_embeds, task_id=task_ids)

        # Task-specific loss
        loss_task = compute_task_loss(logits, labels, num_labels=logits.shape[1])

        # Attention loss
        loss_attn = compute_attn_loss(attentions, attention_labels)
        #loss_attn = 0

        # Combined loss
        loss = loss_task + attn_lambda * loss_attn
        loss *= lr
        
        # print('normal loss: ', loss_task)
        # print('attention loss: ', loss_attn)

        # Backward pass and optimization
        model.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Step {step + 1}/{train_step}, Task{task_ids[0]}, Loss: {loss.item():.4f}")

    elapsed = time.time() - start_time
    print(f"Training completed in {elapsed:.2f}s")

In [80]:
from tqdm.notebook import tqdm

def evaluate_model(model, test_loader, task_id):
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    total_loss, num_samples = 0, 0
    for batch in tqdm(test_loader, desc="Evaluating model"):
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch.pop("labels")
        #print(labels)

        with torch.no_grad():
            logits, attention = model(**batch, task_id=task_id)
            loss = compute_task_loss(logits, labels, num_labels=logits.shape[1])

        metric.add_batch(
            predictions=logits.argmax(-1).tolist(),
            references=labels.tolist()
        )
        total_loss += loss.item() * len(labels)
        num_samples += len(labels)
        
        print(f'Task {task_id} evaluation loss: {loss.item()}')

    results = metric.compute()
    results["loss"] = total_loss / num_samples

    return results

In [81]:
!jupyter nbextension enable --py widgetsnbextension

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [82]:
import ipywidgets as widgets
widgets.IntSlider()

IntSlider(value=0)

In [83]:
def move_state_dict_to_device(state_dict, device):
    return {key: value.to(device) for key, value in state_dict.items()}

In [84]:
from transformers import AutoModelForSequenceClassification, BertModel
import evaluate

metric = evaluate.load("accuracy")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load model
encoder = BertModel.from_pretrained("bert-base-uncased")
model_base = MultiTaskModel(encoder, num_classes_task1=4, num_classes_task2=2, num_layers_to_freeze=0)
encoder.load_state_dict(torch.load("encoder_trained.pth"), strict=False)
print("Length of encoder layers:", len(encoder.encoder.layer))

# Define the number of layers to freeze
num_layers_to_freeze = 3
model = MultiTaskModel(encoder, num_classes_task1=4, num_classes_task2=2, num_layers_to_freeze=num_layers_to_freeze)
model_original = MultiTaskModel(encoder, num_classes_task1=4, num_classes_task2=2, num_layers_to_freeze=num_layers_to_freeze)

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

# Evaluate bofore training
results = evaluate_model(model_base, test_loader_task1, task_id=[0])
print("-"*40)
print("AG Before training:", results)
print("-"*40)
results = evaluate_model(model_base, test_loader_task2, task_id=[1])
print("-"*40)
print("SST Before training:", results)
print("-"*40)

# Train model
print("-"*40)
train_multitask(model, dataset, optimizer, batch_size=2, train_step=6, attn_lambda=1.5)
print("-"*40)

trained_state_dict = model.state_dict()
original_state_dict = model_original.state_dict()
original_state_dict = move_state_dict_to_device(original_state_dict, device)
combined_state_dict = {}
lambda_coef = 0.8

for key in original_state_dict:
    combined_state_dict[key] = (1-lambda_coef)*original_state_dict[key] + lambda_coef*trained_state_dict[key]
#combined_state_dict.to(device)

model.load_state_dict(combined_state_dict)


# Evaluate after training
results = evaluate_model(model, test_loader_task1, task_id=[0])
print("-"*40)
print("AG After training:", results)
print("-"*40)
results = evaluate_model(model, test_loader_task2, task_id=[1])
print("-"*40)
print("SST After training:", results)
print("-"*40)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Freezing the first 0 layers of the encoder...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Length of encoder layers: 12
Freezing the first 3 layers of the encoder...
Freezing the first 3 layers of the encoder...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Evaluating model:   0%|          | 0/3 [00:00<?, ?it/s]

Task [0] evaluation loss: 1.4355992078781128
Task [0] evaluation loss: 1.4456617832183838
Task [0] evaluation loss: 1.4114227294921875
----------------------------------------
AG Before training: {'accuracy': 0.225, 'loss': 1.4310995403089022}
----------------------------------------


Evaluating model:   0%|          | 0/4 [00:00<?, ?it/s]

Task [1] evaluation loss: 0.732970654964447
Task [1] evaluation loss: 0.7354337573051453
Task [1] evaluation loss: 0.7158809900283813
Task [1] evaluation loss: 0.7278229594230652
----------------------------------------
SST Before training: {'accuracy': 0.4369266055045872, 'loss': 0.7280626728994037}
----------------------------------------
----------------------------------------
Step 1/6, Task0, Loss: 1.0087
Step 2/6, Task0, Loss: 0.6036
Step 3/6, Task1, Loss: 0.1274
Step 4/6, Task0, Loss: 0.0162
Step 5/6, Task0, Loss: 0.2838
Step 6/6, Task1, Loss: 0.0797
Training completed in 13.61s
----------------------------------------


Evaluating model:   0%|          | 0/3 [00:00<?, ?it/s]

Task [0] evaluation loss: 1.5628390312194824
Task [0] evaluation loss: 1.473114252090454
Task [0] evaluation loss: 1.4652178287506104
----------------------------------------
AG After training: {'accuracy': 0.3263157894736842, 'loss': 1.5007606079703883}
----------------------------------------


Evaluating model:   0%|          | 0/4 [00:00<?, ?it/s]

Task [1] evaluation loss: 0.7112665176391602
Task [1] evaluation loss: 0.6918203234672546
Task [1] evaluation loss: 0.722996711730957
Task [1] evaluation loss: 0.7092440128326416
----------------------------------------
SST After training: {'accuracy': 0.5034403669724771, 'loss': 0.7087600537396352}
----------------------------------------
