In [None]:
import torch
import random
import numpy as np
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from torch.utils.data import DataLoader

class TransformerBitFlipFramework:
    def __init__(self, model_name="distilbert-base-uncased", num_labels=2):
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        self.original_state_dict = {k: v.clone() for k, v in self.model.state_dict().items()}  # Save original weights

    def preprocess(self, text):
        return self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    def predict(self, text):
        inputs = self.preprocess(text)
        with torch.no_grad():
            outputs = self.model(**inputs)
        return torch.argmax(outputs.logits, dim=1).item()

    def train(self, train_dataloader, epochs=5, learning_rate=2e-5):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        self.model.train()

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        loss_fn = torch.nn.CrossEntropyLoss()

        for epoch in range(epochs):
            epoch_loss = 0.0
            for batch in train_dataloader:
                input_ids, attention_mask, labels = (t.to(device) for t in batch)
                optimizer.zero_grad()

                outputs = self.model(input_ids, attention_mask=attention_mask)
                loss = loss_fn(outputs.logits, labels)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

            print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}")

    def evaluate(self, dataloader):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        self.model.eval()

        total, correct = 0, 0
        with torch.no_grad():
            for batch in dataloader:
                input_ids, attention_mask, labels = (t.to(device) for t in batch)
                outputs = self.model(input_ids, attention_mask=attention_mask)
                _, preds = torch.max(outputs.logits, dim=1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()

        accuracy = 100 * correct / total
        return accuracy

    def inject_random_bit_flip(self):
        import numpy as np
        param_name, param_tensor = random.choice(list(self.model.named_parameters()))
        param_data = param_tensor.detach().cpu().numpy()
        param_flat = param_data.flatten()
        byte_index = random.randint(0, param_flat.nbytes - 1) 
        bit_position = random.randint(0, 7)  
        param_as_bytes = np.frombuffer(param_flat.view(np.uint8), dtype=np.uint8)
        param_as_bytes[byte_index] ^= 1 << bit_position
        modified_param = torch.tensor(
            param_flat.reshape(param_data.shape), dtype=param_tensor.dtype
        )
        param_tensor.data = modified_param.to(param_tensor.device)

        return param_name, byte_index, bit_position


    def reset_model_weights(self):
        # Reset weights to original state
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                param.copy_(self.original_state_dict[name])

    def run_experiment(self, val_dataloader, num_flips=10):
        original_accuracy = self.evaluate(val_dataloader)
        print(f"Original Accuracy: {original_accuracy:.2f}%")
        flip_results = []

        for i in range(num_flips):
            param_name, byte_index, bit_position = self.inject_random_bit_flip()
            new_accuracy = self.evaluate(val_dataloader)
            flip_results.append((param_name, byte_index, bit_position, new_accuracy))
            print(f"Flip {i+1}: {param_name}, Byte {byte_index}, Bit {bit_position}, Accuracy: {new_accuracy:.2f}%")

            # Reset the model weights to avoid cumulative effects
            self.reset_model_weights()

        return flip_results


In [12]:
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset

# Define a custom dataset class for IMDb
class IMDbDataset(Dataset):
    def __init__(self, tokenizer, split="train", max_length=128):
        self.dataset = load_dataset("imdb", split=split)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]["text"]
        label = self.dataset[idx]["label"]
        inputs = self.tokenizer(
            text, truncation=True, padding="max_length", max_length=self.max_length, return_tensors="pt"
        )
        return inputs["input_ids"].squeeze(0), inputs["attention_mask"].squeeze(0), torch.tensor(label)

# Initialize framework and dataset
framework = TransformerBitFlipFramework()

# Load IMDb dataset for training and evaluation
train_dataset = IMDbDataset(framework.tokenizer, split="train[:5000]")  # Use a subset for speed
test_dataset = IMDbDataset(framework.tokenizer, split="test[:1000]")    # Use a subset for evaluation

# Create data loaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16)

# Train the model
framework.train(train_dataloader, epochs=3)

# Evaluate the model
accuracy = framework.evaluate(test_dataloader)
print(f"Accuracy on IMDb test set: {accuracy:.2f}%")

# Run experiment with bit flips
results = framework.run_experiment(test_dataloader, num_flips=5)
print("Bit Flip Experiment Results:")
for result in results:
    print(result)


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3, Loss: 5.4888
Epoch 2/3, Loss: 0.0387
Epoch 3/3, Loss: 0.0068
Accuracy on IMDb test set: 100.00%
Original Accuracy: 100.00%
Flip 1: distilbert.transformer.layer.0.output_layer_norm.weight, Byte 2328, Bit 5, Accuracy: 100.00%
Flip 2: distilbert.embeddings.LayerNorm.weight, Byte 2775, Bit 5, Accuracy: 98.60%
Flip 3: distilbert.transformer.layer.5.attention.q_lin.bias, Byte 684, Bit 4, Accuracy: 98.60%
Flip 4: distilbert.transformer.layer.1.output_layer_norm.weight, Byte 1657, Bit 5, Accuracy: 98.60%
Flip 5: distilbert.transformer.layer.0.sa_layer_norm.bias, Byte 945, Bit 0, Accuracy: 98.60%
Bit Flip Experiment Results:
('distilbert.transformer.layer.0.output_layer_norm.weight', 2328, 5, 100.0)
('distilbert.embeddings.LayerNorm.weight', 2775, 5, 98.6)
('distilbert.transformer.layer.5.attention.q_lin.bias', 684, 4, 98.6)
('distilbert.transformer.layer.1.output_layer_norm.weight', 1657, 5, 98.6)
('distilbert.transformer.layer.0.sa_layer_norm.bias', 945, 0, 98.6)
