In [1]:
!pip install datasets

Defaulting to user installation because normal site-packages is not writeable


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertConfig, BertModel, Trainer, TrainingArguments
from datasets import load_dataset
from sklearn.metrics import accuracy_score

# Dejavu Sparse MLP and Attention Block
class SparseMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, sparsity_ratio):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.sparsity_ratio = sparsity_ratio
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = F.gelu(self.fc1(x))
        threshold = torch.topk(hidden.abs(), k=int(self.sparsity_ratio * hidden.size(-1)), dim=-1).values[..., -1, None]
        mask = (hidden.abs() >= threshold).float()
        sparse_hidden = hidden * mask
        return self.fc2(sparse_hidden)

class SparseAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, sparsity_ratio):
        super().__init__()
        self.num_heads = num_heads
        self.sparsity_ratio = sparsity_ratio
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        attn_output, _ = self.attention(x, x, x)
        threshold = torch.topk(attn_output.abs(), k=int(self.sparsity_ratio * attn_output.size(-1)), dim=-1).values[..., -1, None]
        mask = (attn_output.abs() >= threshold).float()
        sparse_attn_output = attn_output * mask
        return sparse_attn_output

class DejavuBertBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, hidden_dim, sparsity_ratio):
        super().__init__()
        self.attn = SparseAttention(embed_dim, num_heads, sparsity_ratio)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.mlp = SparseMLP(embed_dim, hidden_dim, embed_dim, sparsity_ratio)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_output = self.attn(x)
        x = x + attn_output
        x = self.ln1(x)
        mlp_output = self.mlp(x)
        x = x + mlp_output
        x = self.ln2(x)
        return x

class DejavuBertModel(nn.Module):
    def __init__(self, bert_config, sparsity_ratio):
        super().__init__()
        self.bert = BertModel(bert_config)
        self.sparsity_ratio = sparsity_ratio
        self.dejavu_blocks = nn.ModuleList([
            DejavuBertBlock(
                embed_dim=bert_config.hidden_size,
                num_heads=bert_config.num_attention_heads,
                hidden_dim=bert_config.intermediate_size,
                sparsity_ratio=sparsity_ratio
            ) for _ in range(bert_config.num_hidden_layers)
        ])
        self.classifier = nn.Linear(bert_config.hidden_size, bert_config.num_labels)

    def forward(self, input_ids, attention_mask, labels=None, calculate_pre_dejavu=False):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        x = outputs.last_hidden_state

        if calculate_pre_dejavu:
            logits = self.classifier(x[:, 0, :])
            return {"logits": logits}

        for block in self.dejavu_blocks:
            x = block(x)
        logits = self.classifier(x[:, 0, :])

        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

def load_data():
    piqa = load_dataset("piqa")
    hellaswag = load_dataset("hellaswag")
    boolq = load_dataset("boolq")
    return piqa, hellaswag, boolq

def tokenize_function_piqa(examples):
    tokenized = tokenizer(examples["goal"], padding="max_length", truncation=True, max_length=max_seq_len)
    tokenized["labels"] = examples["label"]
    return tokenized

def tokenize_function_swag(examples):
    # Combine context with each possible ending
    inputs = [f"{examples['context']} {ending}" for ending in examples["endings"]]
    tokenized = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_seq_len)
    tokenized["labels"] = examples["label"]
    return tokenized

def tokenize_function_boolq(examples):
    inputs = [examples["question"] + " " + examples["passage"]]
    tokenized = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_seq_len)
    tokenized["labels"] = examples["label"]
    return tokenized

def train_and_evaluate(dataset, model, tokenizer, max_seq_len, batch_size, num_epochs, label):
    if label == "piqa":
        tokenized_data = dataset.map(tokenize_function_piqa, batched=True)
    elif label == "hello_swag":
        tokenized_data = dataset.map(tokenize_function_swag, batched=True)
    elif label == "boolq":
        tokenized_data = dataset.map(tokenize_function_boolq, batched=True)

    train_dataset = tokenized_data["train"]
    eval_dataset = tokenized_data["validation"]

    from transformers import DataCollatorWithPadding
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="steps",
        learning_rate=5e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
    )

    def compute_metrics(pred):
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        return {"accuracy": accuracy_score(labels, preds)}

    pre_dejavu_outputs = []
    for batch in eval_dataset:
        with torch.no_grad():
            input_ids = torch.tensor(batch["input_ids"]).unsqueeze(0)
            attention_mask = torch.tensor(batch["attention_mask"]).unsqueeze(0)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, calculate_pre_dejavu=True)
            logits = outputs["logits"]
            preds = logits.argmax(dim=-1).numpy()
            pre_dejavu_outputs.extend(preds)

    pre_dejavu_acc = accuracy_score(eval_dataset["labels"], pre_dejavu_outputs)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    post_dejavu_metrics = trainer.evaluate()
    post_dejavu_acc = post_dejavu_metrics.get("eval_accuracy", 0)

    accuracy_drop = pre_dejavu_acc - post_dejavu_acc
    print("accuracy drop")
    print(accuracy_drop)
    # Write results to a file
    with open("results_dejavu.txt", "a") as f:
        f.write(f"Task: {label}\n")
        f.write(f"Pre-Sparsity Accuracy: {pre_dejavu_acc:.4f}\n")
        f.write(f"Post-Sparsity Accuracy: {post_dejavu_acc:.4f}\n")
        f.write(f"Accuracy Drop: {accuracy_drop:.4f}\n")
        f.write("=" * 50 + "\n")

    return post_dejavu_metrics, pre_dejavu_acc, accuracy_drop

def main():
    global tokenizer, max_seq_len
    max_seq_len = 128
    sparsity_ratio = 0.2
    batch_size = 16
    num_epochs = 1

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert_config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2)
    piqa, hellaswag, boolq = load_data()

    model = DejavuBertModel(bert_config, sparsity_ratio)

    # Open the results file and write a header
    with open("results_dejavu.txt", "w") as f:
        f.write("Dejavu BERT Results\n")
        f.write("=" * 50 + "\n")

    # Evaluate on PIQA
    # label = "piqa"
    # print("Evaluating on PIQA...")
    # train_and_evaluate(piqa, model, tokenizer, max_seq_len, batch_size, num_epochs, label)

    # Evaluate on HellaSwag
    label = "hello_swag"
    print("Evaluating on HellaSwag...")
    train_and_evaluate(hellaswag, model, tokenizer, max_seq_len, batch_size, num_epochs, label)

    # Evaluate on BoolQ
    label = "boolq"
    print("Evaluating on BoolQ...")
    train_and_evaluate(boolq, model, tokenizer, max_seq_len, batch_size, num_epochs, label)

if __name__ == "__main__":
    main()




Evaluating on PIQA...


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss,Validation Loss,Accuracy
500,0.7403,0.693665,0.504897
1000,0.6957,0.693276,0.495103


accuracy drop
-0.003808487486398282
Evaluating on HellaSwag...


Map:   0%|          | 0/39905 [00:00<?, ? examples/s]

KeyError: 'context'