### Experiment Notebook

This notebook integrates the separate components to form the experiment. 

In [5]:
import torch
import torch.nn.utils.prune as prune
from torch_cka import CKA
import datasets
import transformers
import numpy as np
import matplotlib.pyplot as plt
import os

from torch.utils.data import DataLoader
from bertviz import model_view, head_view

# (captured cell output) from .autonotebook import tqdm as notebook_tqdm


In [19]:
### Encode various tasks
def preprocess_cola(examples, tokenizer, max_len):
    """Tokenize a batch of CoLA examples, padding each one to ``max_len``.

    Args:
        examples: Batch mapping holding the raw text under ``'sentence'``.
        tokenizer: Callable tokenizer applied to the sentences.
        max_len: Value forwarded as the tokenizer's ``max_length``.

    Returns:
        Whatever ``tokenizer`` returns for the batch.
    """
    padding_options = {'padding': 'max_length', 'max_length': max_len}
    return tokenizer(examples['sentence'], **padding_options)

def preprocess_sst2(examples, tokenizer):
    """Tokenize a batch of SST-2 examples to a fixed length.

    Args:
        examples: Batch mapping holding the raw text under ``'sentence'``.
        tokenizer: Callable tokenizer applied to the sentences.

    Returns:
        Whatever ``tokenizer`` returns for the batch.
    """
    # truncation=True matches the QQP/RTE preprocessors: padding alone does not
    # shorten over-long inputs, so without it the encoded sequences could end
    # up with different lengths and fail to stack into a batch.
    return tokenizer(examples['sentence'], padding='max_length', truncation=True)

def preprocess_qqp(examples, tokenizer):
    """Tokenize a batch of QQP question pairs with padding and truncation.

    Args:
        examples: Batch mapping holding the two questions under
            ``"question1"`` and ``"question2"``.
        tokenizer: Callable tokenizer applied to the question pairs.

    Returns:
        Whatever ``tokenizer`` returns for the batch.
    """
    first_questions = examples["question1"]
    second_questions = examples["question2"]
    return tokenizer(
        first_questions,
        second_questions,
        padding="max_length",
        truncation=True,
    )

def preprocess_rte(examples, tokenizer):
    """Tokenize a batch of RTE sentence pairs with padding and truncation.

    Args:
        examples: Batch mapping holding the pair under ``"sentence1"`` and
            ``"sentence2"``.
        tokenizer: Callable tokenizer applied to the sentence pairs.

    Returns:
        Whatever ``tokenizer`` returns for the batch.
    """
    sentence_pair = (examples["sentence1"], examples["sentence2"])
    return tokenizer(*sentence_pair, padding="max_length", truncation=True)

def encodeTask(task, checkpoint, batch_size):
    """Build train/validation/test DataLoaders for one GLUE task.

    Loads the GLUE dataset for ``task``, tokenizes every split with the BERT
    tokenizer loaded from ``checkpoint`` and wraps the splits in DataLoaders.

    Args:
        task: GLUE task name; one of 'cola', 'rte', 'sst2' or 'qqp'.
        checkpoint: Name or path handed to ``BertTokenizer.from_pretrained``.
        batch_size: Batch size of the train and validation loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, tokenizer)``. The test
        loader always uses a batch size of 1.

    Raises:
        ValueError: If ``task`` is not one of the supported tasks.
    """
    # Reject unsupported tasks before any download happens. Previously this
    # case only printed a message and then crashed on an unbound variable.
    supported_tasks = ('cola', 'rte', 'sst2', 'qqp')
    if task not in supported_tasks:
        raise ValueError(
            f"The task: {task} isn't supported by our codebase currently. "
            "Please select from cola, rte, sst2, and qqp"
        )

    # The GLUE metric used to be loaded here but was never used or returned,
    # so it is no longer fetched.
    dataset = datasets.load_dataset("glue", task)
    tokenizer = transformers.BertTokenizer.from_pretrained(checkpoint, use_fast=True)

    preprocessors = {
        'cola': preprocess_cola,
        'rte': preprocess_rte,
        'sst2': preprocess_sst2,
        'qqp': preprocess_qqp,
    }
    fn_kwargs = {'tokenizer': tokenizer}
    columns = ['input_ids', 'attention_mask']

    if task == 'cola':
        # CoLA pads every input to the longest tokenized sentence across the
        # splits. The previous len(max(sentences)) measured the characters of
        # the lexicographically greatest sentence, not the longest input.
        fn_kwargs['max_len'] = max(
            len(token_ids)
            for split in ('train', 'validation', 'test')
            for token_ids in tokenizer(list(dataset[split]['sentence']))['input_ids']
        )
        columns = ['input_ids', 'token_type_ids', 'attention_mask']

    # NOTE(review): only CoLA exposes token_type_ids and no task exposes the
    # label column; kept as-is so existing consumers of the batches see the
    # same keys -- confirm this is intended for the pair tasks (rte, qqp).
    encoded_dataset = dataset.map(preprocessors[task], batched=True, fn_kwargs=fn_kwargs)
    tokenized_dataset = encoded_dataset.with_format('torch', columns=columns)

    train_loader = DataLoader(tokenized_dataset['train'], batch_size=batch_size)
    val_loader = DataLoader(tokenized_dataset['validation'], batch_size=batch_size)
    test_loader = DataLoader(tokenized_dataset['test'], batch_size=1)  # Inference uses batch size 1

    return train_loader, val_loader, test_loader, tokenizer

In [20]:
# Settings shared by every task
checkpoint = "bert-base-uncased"
batch_size = 16

# Build the train/validation/test loaders (and tokenizer) for each task in turn
task = "cola"
(cola_train_loader,
 cola_val_loader,
 cola_test_loader,
 cola_tokenzier) = encodeTask(task, checkpoint, batch_size)

task = "rte"
(rte_train_loader,
 rte_val_loader,
 rte_test_loader,
 rte_tokenzier) = encodeTask(task, checkpoint, batch_size)

task = "sst2"
(sst2_train_loader,
 sst2_val_loader,
 sst2_test_loader,
 sst2_tokenizer) = encodeTask(task, checkpoint, batch_size)

task = "qqp"
(qqp_train_loader,
 qqp_val_loader,
 qqp_test_loader,
 qqp_tokenizer) = encodeTask(task, checkpoint, batch_size)



In [21]:
# Define model
class BertCustomHead(torch.nn.Module):
    """BERT encoder followed by a linear head selected by task type.

    One linear head and one loss function are registered per supported task
    type; ``forward`` only uses the head matching ``task_type``.

    Args:
        config: Configuration object used to build ``transformers.BertModel``;
            its ``hidden_size`` and ``vocab_size`` size the heads.
        num_classes: Output width of the sequence- and token-classification heads.
        tokenizer: Tokenizer whose ``pad_token_id`` is ignored by the
            summarization loss.
        task_type: Which head ``forward`` applies.

    Raises:
        ValueError: If ``task_type`` is not a supported task type.
    """

    # Task types that have a head; checked before any module is built.
    SUPPORTED_TASK_TYPES = (
        'sequence_classification',
        'token_classification',
        'multiple_choice',
        'summarization',
    )

    def __init__(self, config, num_classes, tokenizer, task_type='sequence_classification'):
        super().__init__()

        # Fail fast: validate before allocating the encoder and the heads.
        if task_type not in self.SUPPORTED_TASK_TYPES:
            raise ValueError("Invalid task type. Supported types: 'sequence_classification', 'token_classification', 'multiple_choice', 'summarization'")

        # NOTE(review): the encoder is built from the config rather than via
        # from_pretrained -- confirm pretrained weights are loaded elsewhere.
        self.bert = transformers.BertModel(config)
        self.task_type = task_type

        self.heads = torch.nn.ModuleDict({
            'sequence_classification': torch.nn.Linear(config.hidden_size, num_classes),
            'token_classification': torch.nn.Linear(config.hidden_size, num_classes),
            'multiple_choice': torch.nn.Linear(config.hidden_size, 1),
            'summarization': torch.nn.Linear(config.hidden_size, config.vocab_size),
        })

        # Loss functions per task type; not used by forward().
        self.loss_fns = {
            'sequence_classification': torch.nn.CrossEntropyLoss(),
            'token_classification': torch.nn.CrossEntropyLoss(),
            'multiple_choice': torch.nn.BCEWithLogitsLoss(),
            'summarization': torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id),
        }

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, decoder_input_ids=None):
        """Run the encoder and apply the head for ``self.task_type``.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Optional attention mask passed to the encoder.
            token_type_ids: Optional segment ids passed to the encoder.
            decoder_input_ids: Accepted for interface compatibility; unused.

        Returns:
            For 'summarization' and 'sequence_classification': the tuple
            ``(last_hidden_state, attentions, logits)`` with the head applied
            to ``last_hidden_state``. For every other task type: only the head
            output computed from ``pooler_output``.
        """
        # token_type_ids is now forwarded for every task type; it used to be
        # dropped silently for everything except sequence classification.
        outputs = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        head = self.heads[self.task_type]

        if self.task_type in ('summarization', 'sequence_classification'):
            # NOTE(review): sequence classification applies its head to the
            # per-token states while token classification (below) uses the
            # pooled output -- this looks inverted relative to the task names;
            # left unchanged so the returned shapes stay the same. Confirm.
            task_output = outputs.last_hidden_state
            attentions = outputs.attentions
            logits = head(task_output)
            return task_output, attentions, logits

        # Remaining task types return the bare head output (no tuple).
        return head(outputs.pooler_output)

In [22]:
# Checkpoint that supplies the encoder configuration, and the label count
checkpoint = "bert-base-uncased"
num_classes = 2  # binary classification (two labels)

# Load the configuration with attention outputs switched on, then build the model
config = transformers.BertConfig.from_pretrained(checkpoint, output_attentions=True)
model = BertCustomHead(
    config,
    num_classes,
    cola_tokenzier,
    task_type='sequence_classification',
)



In [None]:
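### Unused / extra code: earlier per-task encoder drafts, kept only as inert string literals (encodeTask handles these four tasks)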
"""
def encodeRTE():

    task = "rte"
    dataset = datasets.load_dataset("glue", task)
    metric = datasets.load_metric("glue", task, trust_remote_code=True)

    tokenizer = transformers.BertTokenizer.from_pretrained(checkpoint)
    def preprocess_function(examples):
        return tokenizer(examples["sentence1"], examples["sentence2"], padding="max_length", truncation=True)

    encoded_dataset = dataset.map(preprocess_function, batched=True)
    encoded_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    RTE_train_dataset = encoded_dataset['train']
    RTE_validation_dataset = encoded_dataset['validation']
    RTE_train_dataloader = DataLoader(RTE_train_dataset, shuffle=True, batch_size=batch_size)
    RTE_validation_dataloader = DataLoader(RTE_validation_dataset, batch_size=batch_size)
"""


"""
def encodeSST():

    # Setup chosen task and metric
    task = "sst2"
    checkpoint = "bert-base-uncased" # TODO: Update this
    dataset = datasets.load_dataset("glue", task)
    metric = datasets.load_metric("glue", task, trust_remote_code=True)

    # Tokenize Dataset
    tokenizer = transformers.BertTokenizer.from_pretrained(checkpoint, use_fast=True)

    def preprocess_function(examples):
        return tokenizer(examples["sentence"], padding="max_length", truncation=True)

    encoded_dataset = dataset.map(preprocess_function, batched=True)
    encoded_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    # SST dataloaders 
    batch_size = 16
    train_dataset = encoded_dataset['train']
    validation_dataset = encoded_dataset['validation']
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size)
"""

"""
def encodeCola():

    # Setup chosen task and metric
    task = "cola"
    checkpoint = "bert-base-uncased" #TODO: Update this to local fine tuned model
    dataset = datasets.load_dataset("glue", task)
    metric = datasets.load_metric("glue", task)

    # Figure out dataset characteristics
    max_len_train = len(max(dataset['train']['sentence'][:]))
    max_len_val = len(max(dataset['validation']['sentence'][:]))
    max_len_test = len(max(dataset['test']['sentence'][:]))
    max_len = max(max_len_train, max_len_val, max_len_test)

    # Tokenize Dataset
    tokenizer = transformers.BertTokenizer.from_pretrained(checkpoint, use_fast=True)
    # preprocess function
    def preprocess_function(examples):
        return tokenizer(examples['sentence'], padding='max_length', max_length=max_len)
    encoded_dataset = dataset.map(preprocess_function, batched=True)

    return task, metric, encoded_dataset, tokenizer
"""

"""
def encodeQQP():

    task = "qqp"  
    checkpoint = "bert-base-uncased"
    batch_size = 16

    dataset = datasets.load_dataset("glue", task)
    metric = datasets.load_metric("glue", task, )

    def preprocess_function(examples):
        return tokenizer(examples["question1"], examples["question2"], padding="max_length", truncation=True)

    tokenizer = transformers.BertTokenizer.from_pretrained(checkpoint)
    encoded_dataset = dataset.map(preprocess_function, batched=True)
    encoded_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    encoded_dataset['train'] = encoded_dataset['train'].select(range(67349))
    train_dataset = encoded_dataset['train']
    validation_dataset = encoded_dataset['validation']
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size)
"""