#### Notebook for Exploring CKA applied to BERT

In [1]:
import torch
import torch.nn.utils.prune as prune
from torch_cka import CKA
import datasets
import transformers
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from bertviz import model_view, head_view

  from .autonotebook import tqdm as notebook_tqdm


#### Tokenize Dataset

Note that every sentence is currently padded to a fixed length (`max_len`). Equal-length sequences are required so that items can be stacked into batches in PyTorch. Reduction was being used prior.

TODO: Trim the max_len down to something more reasonable so there aren't a bunch of useless tokens.

In [None]:
# Setup chosen task and metric
task = "cola"
checkpoint = "bert-base-uncased"
dataset = datasets.load_dataset("glue", task)
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets` releases
# (moved to the separate `evaluate` package) -- confirm against the installed version.
metric = datasets.load_metric("glue", task)

# Figure out dataset characteristics: the length, in characters, of the longest
# sentence in each split. Lengths must be compared explicitly; a bare max() over
# strings returns the lexicographically greatest sentence, not the longest one.
max_len_train = max(map(len, dataset['train']['sentence']))
max_len_val = max(map(len, dataset['validation']['sentence']))
max_len_test = max(map(len, dataset['test']['sentence']))
max_len = max(max_len_train, max_len_val, max_len_test)

# Tokenize Dataset
# Load the tokenizer that matches `checkpoint`; AutoTokenizer is used so that
# `use_fast=True` actually selects the fast implementation.
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, use_fast=True)


def preprocess_function(examples):
    """Tokenize a batch of examples to exactly ``max_len`` tokens each.

    Args:
        examples: Batch mapping supplied by ``dataset.map(..., batched=True)``;
            only its ``'sentence'`` entry is read.

    Returns:
        The tokenizer output for the batch.
    """
    # truncation=True guarantees no sequence exceeds max_len, so together with
    # padding='max_length' every example has the same length and can be batched.
    return tokenizer(examples['sentence'], padding='max_length', truncation=True, max_length=max_len)


encoded_dataset = dataset.map(preprocess_function, batched=True)

In [None]:
# Collate function (not currently used: the DataLoaders below are created without a custom collate_fn)

#def collate_fn(batch):
#    input_ids = torch.stack([item['input_ids'] for item in batch])
#    attention_mask = torch.stack([item['attention_mask'] for item in batch])
#    #label = torch.stack([item['label'] for item in batch])
#    token_type_ids = torch.stack([item['token_type_ids'] for item in batch])
#    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'token_type_ids': token_type_ids}

In [4]:
# Wrap each split in a PyTorch DataLoader
batch_size = 32

# Expose only the model-input columns, returned as torch tensors
tokenized_dataset = encoded_dataset.with_format(
    'torch',
    columns=['input_ids', 'token_type_ids', 'attention_mask'],
)
print(tokenized_dataset)

# One loader per split, built in train / validation / test order
train_loader, val_loader, test_loader = (
    DataLoader(tokenized_dataset[split_name], batch_size=batch_size)
    for split_name in ('train', 'validation', 'test')
)

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1063
    })
})


If a CUDA out-of-memory (OOM) error is thrown, try reducing `batch_size`.

In [8]:
# Run on the first GPU when one is available, otherwise fall back to the CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_labels = 2
model_directory = 'bert_pruning_model'

# Load the two models to compare (both from `model_directory`) and move them to `device`.
# Alternative for model1, kept for reference -- the stock pretrained checkpoint:
#model1 = transformers.BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_labels, output_attentions=True)
model1, model2 = (
    transformers.BertForSequenceClassification.from_pretrained(
        model_directory,
        num_labels=num_labels,
        output_attentions=True,
    ).to(device)
    for _ in range(2)
)

# Layers that take part in the CKA calculation: the attention module of encoder layers 0-11
model1_names = [f"bert.encoder.layer.{i}.attention" for i in range(12)]
model2_names = list(model1_names)  # independent copy of the same names

In [9]:
# Compare the selected layers of the two models over the training batches.
cka = CKA(model1, model2, model1_name="Model1", model2_name="Model2", model1_layers=model1_names, model2_layers=model2_names, device=device)
# Inference only: with autograd disabled the forward passes do not retain
# computation graphs, which lowers memory use during the comparison.
with torch.no_grad():
    cka.compare(train_loader)
results = cka.export()

  warn("Dataloader for Model 2 is not given. Using the same dataloader for both models.")
| Comparing features |:  80%|███████▉  | 214/268 [00:36<00:09,  5.69it/s]

Show results of CKA calculation

In [None]:
# Print the exported CKA results produced by the comparison cell
print(results)
# NOTE(review): the two positional arguments are presumably the save path ("plot")
# and the figure title ("BERT Comparison") -- confirm against the installed torch_cka.
cka.plot_results("plot", "BERT Comparison")