In [None]:
pip install git+https://github.com/lxuechen/private-transformers.git

In [None]:
pip install transformers torch datasets evaluate accelerate -U


In [None]:
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
)

from datasets import load_dataset
from torch.optim import AdamW
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm

In [None]:

# Fetch the pre-trained TinyBERT checkpoint together with its matching tokenizer.
model_name = "prajjwal1/bert-tiny"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name)

# 1 -> every weight is trained; 0 -> only a freshly created linear head is trained.
fullChoice = int(input("Enter 1 for Full Fine-Tuning and 0 for Fine-Tuning with a single layer on top of it: "))

if fullChoice == 0:
    # Freeze all existing weights first ...
    for frozen_weight in model.parameters():
        frozen_weight.requires_grad = False

    # ... then swap in a new 2-way linear head. A newly constructed nn.Linear has
    # requires_grad=True, so it is the only part of the model left trainable.
    model.classifier = nn.Linear(model.bert.pooler.dense.out_features, 2)

print(f"Number of Trainable Parameters= {sum(weight.numel() for weight in model.parameters() if weight.requires_grad)}")

TinyBERT is loaded along with its tokenizer. Depending on the user's input, either all pre-trained layers are frozen and a single new classification layer is trained on top of the model, or full fine-tuning is performed.

In [None]:
def preprocess_data(data):
    """Tokenize the ``sentence`` field of ``data`` with the notebook-level tokenizer.

    Every sequence is truncated and padded to a fixed length of 128 tokens.
    The tokenizer's output is returned unchanged.
    """
    tokenizer_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 128,
    }
    return tokenizer(data["sentence"], **tokenizer_options)

In [None]:
def load_cleaned_data(task):
    """Load and tokenize the GLUE dataset selected by a numeric menu choice.

    Args:
        task: Menu choice. 1 selects SST-2. 2 (QNLI), 3 (MNLI) and 4 (QQP) are
            recognised but not implemented yet.

    Returns:
        A ``(tokenized_data, task_name)`` tuple: the tokenized dataset (raw text
        and index columns removed, ``label`` renamed to ``labels``, format set
        to torch) and the GLUE task name (``"sst2"``).

    Raises:
        NotImplementedError: If ``task`` is 2, 3 or 4.
        ValueError: If ``task`` is not one of the menu choices.
    """
    if task == 1:
        print("SST2 Dataset")

        task_name = "sst2"

        # Load SST-2 and tokenize it in batches.
        dataset = load_dataset("glue", task_name)
        tokenized_data = dataset.map(preprocess_data, batched=True)
        # Drop the columns that are not fed to the model and use the column
        # name "labels" for the targets.
        tokenized_data = tokenized_data.remove_columns(["idx", "sentence"])
        tokenized_data = tokenized_data.rename_column("label", "labels")
        tokenized_data.set_format("torch")
        return tokenized_data, task_name

    # Previously these branches only printed a message and fell through, so the
    # function returned None and the caller's tuple unpacking failed with an
    # unrelated-looking TypeError. Fail explicitly instead.
    not_implemented = {2: "QNLI", 3: "MNLI", 4: "QQP"}
    if task in not_implemented:
        dataset_label = not_implemented[task]
        print(f"{dataset_label} Dataset")
        raise NotImplementedError(
            f"Loading the {dataset_label} dataset is not implemented yet; "
            "only SST2 (choice 1) is currently supported."
        )

    print("Invalid Dataset")
    raise ValueError(
        f"Unknown dataset choice {task!r}; expected 1 (SST2), 2 (QNLI), 3 (MNLI) or 4 (QQP)."
    )


In [None]:
# Ask which GLUE task to use, then load and tokenize it.
taskChoice = int(input("Enter \n1 for SST2, 2 for QNLI, 3 for MNLI and 4 for QQP: "))

tokenized_data , task = load_cleaned_data(taskChoice)

# Only the training split is shuffled. Validation order does not affect the
# metric, and leaving it unshuffled keeps evaluation deterministic and avoids
# drawing from the global RNG on every evaluation pass.
train_dataloader = DataLoader(tokenized_data['train'], shuffle=True, batch_size=1024)
val_dataloader = DataLoader(tokenized_data['validation'], shuffle=False, batch_size=1024)

The dataset is tokenized, cleaned, and split into training and validation sets.

In [None]:

# Training configuration: number of epochs and the device to train on
# (GPU when one is available, otherwise CPU).
epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optimizer and classification loss.
optimizer = AdamW(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Linear learning-rate schedule without warmup, sized to one scheduler step per
# training batch over all epochs.
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epochs * len(train_dataloader),
)

model.to(device)

In [None]:
import evaluate

def evaluate_model(model,dataloader,task):
    """Score ``model`` on ``dataloader`` using the GLUE metric for ``task``.

    The model is switched to eval mode, each batch is moved to the notebook-level
    ``device``, and the arg-max over the logits is compared against the batch's
    ``labels``. Returns the result of the metric's ``compute()``.
    """
    glue_metric = evaluate.load("glue", task)
    model.eval()

    for raw_batch in dataloader:
        on_device = {name: tensor.to(device) for name, tensor in raw_batch.items()}

        # No gradients are needed for scoring.
        with torch.no_grad():
            predicted_labels = model(**on_device).logits.argmax(dim=-1)

        glue_metric.add_batch(predictions=predicted_labels, references=on_device["labels"])

    return glue_metric.compute()

### Without Differential Privacy

In [None]:
def trainModel(model, optimizer, train_dataloader, val_dataloader, loss_fn, lr_scheduler, tqdm, task, epochs=5, dp=False):
    """Train ``model`` for ``epochs`` epochs, evaluating on the validation set after each.

    Args:
        model: Model whose output exposes ``.logits``; called as ``model(**batch)``.
        optimizer: Optimizer to step. When ``dp`` is True it must accept
            ``step(loss=...)`` (i.e. have a private-transformers PrivacyEngine
            attached), because that call performs the backward pass itself.
        train_dataloader: Yields dict batches that include a ``"labels"`` entry.
        val_dataloader: Passed to ``evaluate_model`` after every epoch.
        loss_fn: Loss used in the non-DP path, called as ``loss_fn(logits, labels)``.
        lr_scheduler: Stepped once per training batch.
        tqdm: Progress-bar wrapper applied to ``train_dataloader``.
        task: GLUE task name forwarded to ``evaluate_model``.
        epochs: Number of passes over ``train_dataloader``.
        dp: If True, use the differentially private update path.

    Batches are moved to the notebook-level ``device``.
    """
    # Imported locally so the DP branch also works when this cell runs before
    # the cell that imports `F` at notebook level.
    import torch.nn.functional as F

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_dataloader):

            # Forward pass
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)

            if dp:
                # private-transformers needs the *per-example* loss, a 1-D tensor
                # of shape (batch_size,), to clip each example's gradient. A
                # mean-reduced loss reshaped to (1,) does not provide that.
                loss = F.cross_entropy(outputs.logits, batch["labels"], reduction="none")

                # Backward pass and update with DP (do not call loss.backward()).
                optimizer.step(loss=loss)

            else:
                loss = loss_fn(outputs.logits, batch["labels"])
                total_loss += loss.detach().float()

                # Backward pass and update without DP
                loss.backward()
                optimizer.step()

            lr_scheduler.step()
            optimizer.zero_grad()

        if not dp:
            # The running loss is only accumulated on the non-DP path.
            train_epoch_loss = total_loss / len(train_dataloader)
            train_ppl = torch.exp(train_epoch_loss)
            print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} ")

        # Evaluate on validation set
        with torch.no_grad():
            val_accuracy = evaluate_model(model, val_dataloader, task)
            print(f"Epoch {epoch+1}, Validation Accuracy {'with' if dp else 'without'} DP: {val_accuracy}")

    print("Training complete!")

In [None]:
# Run the non-private baseline. `epochs` is passed explicitly so the number of
# training epochs always matches the value the LR scheduler was sized with.
trainModel(
    model=model,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    lr_scheduler=lr_scheduler,
    tqdm=tqdm,
    task=task,
    epochs=epochs,
)

### With Differential Privacy


In [None]:
import transformers, torch
from private_transformers import PrivacyEngine
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Start again from the pre-trained checkpoint. The `model` object left over from
# the previous section has already been fine-tuned on this data WITHOUT privacy;
# continuing from it would void the privacy guarantee and make the DP / non-DP
# comparison meaningless.
model = BertForSequenceClassification.from_pretrained(model_name)
if not fullChoice:
    # Same head-only setup as above: freeze everything, train a new linear head.
    for param in model.parameters():
        param.requires_grad = False
    model.classifier = nn.Linear(model.bert.pooler.dense.out_features, 2)
model.to(device)

# Batch size and epoch count are taken from the objects actually used for
# training so the privacy accounting matches the real run.
privacy_engine = PrivacyEngine(
    model,
    batch_size=train_dataloader.batch_size,
    sample_size=tokenized_data['train'].num_rows,
    epochs=epochs,
    max_grad_norm=0.2,
    target_epsilon=3,
    clipping_mode="ghost"
)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.01)
privacy_engine.attach(optimizer)

# The scheduler from the non-DP section is bound to the old optimizer (and has
# already run to completion), so build a new one for this optimizer.
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader) * epochs,
)

for epoch in range(epochs):
    # train() mode is required here, as in the non-DP loop.
    model.train()
    for batch in tqdm(train_dataloader):

        optimizer.zero_grad()
        # Forward pass
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)

        # private-transformers needs the per-example loss, shape (batch_size,),
        # so it can clip each example's gradient; a mean-reduced loss reshaped
        # to (1,) does not provide that.
        loss = F.cross_entropy(outputs.logits, batch["labels"], reduction="none")
        # Backward pass and update with DP (do not call loss.backward()).
        optimizer.step(loss=loss)

        lr_scheduler.step()

    # Evaluate on validation set
    with torch.no_grad():
        val_accuracy = evaluate_model(model, val_dataloader, task)
        print(f"Epoch {epoch+1}, Validation Accuracy DP: {val_accuracy}")

print("Training complete!")