In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertConfig
from datasets import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from peft import LoraConfig, TaskType, get_peft_model
import os

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Reproducibility: one seed for torch's RNG, plus deterministic cuDNN kernels.
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

cpu


In [2]:
# Load the training data ('train.csv' is a placeholder — point it at your dataset)
# and keep a reproducible 1% sample; `random_state=SEED` selects the same rows every run.
train_df = pd.read_csv('train.csv').sample(frac=0.01, random_state=SEED)

In [3]:
# Build the cleaned frame in one pass:
#   * target       -> 1 when the toxicity score is >= 0.5, else 0
#   * comment_text -> NaN becomes '', values are forced to str, lowercased, and
#                     every character that is neither a word character nor
#                     whitespace is removed
# NOTE(review): bert-base-uncased already lowercases, and stripping punctuation
# throws away cues such as '!!!' — consider whether this step helps at all.
train_df = train_df.assign(
    target=(train_df['target'] >= 0.5).astype(int),
    comment_text=(
        train_df['comment_text']
        .fillna('')
        .astype(str)
        .str.lower()
        .str.replace(r'[^\w\s]', '', regex=True)
    ),
)

# 80/20 train/validation split, seeded for reproducibility.
# NOTE(review): if the positive class is rare, `stratify=train_df['target']`
# would keep the class ratio equal in both splits.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_df['comment_text'],
    train_df['target'],
    test_size=0.2,
    random_state=SEED,
)


In [4]:
# WordPiece tokenizer matching the bert-base-uncased checkpoint used below.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')


def tokenize_function(examples):
    """Tokenize a batch of texts to exactly 128 token ids (truncating or padding)."""
    return tokenizer(
        examples['text'],
        max_length=128,
        truncation=True,
        padding='max_length',
    )


# Wrap the pandas splits as Hugging Face Datasets with `text` / `labels` columns.
train_dataset = Dataset.from_dict(dict(text=train_texts.tolist(), labels=train_labels.tolist()))
val_dataset = Dataset.from_dict(dict(text=val_texts.tolist(), labels=val_labels.tolist()))

# Batched `map` adds the tokenizer outputs (e.g. input_ids, attention_mask) as columns.
tokenized_train_dataset, tokenized_val_dataset = (
    split.map(tokenize_function, batched=True) for split in (train_dataset, val_dataset)
)


Map:   0%|          | 0/14439 [00:00<?, ? examples/s]

Map:   0%|          | 0/3610 [00:00<?, ? examples/s]

In [5]:
# Expose only the model inputs and the labels, as torch tensors.
for _tokenized_split in (tokenized_train_dataset, tokenized_val_dataset):
    _tokenized_split.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# Create DataLoader
from transformers import DataCollatorWithPadding
# NOTE(review): tokenize_function already pads every example to max_length=128,
# so this collator has nothing left to pad and effectively only batches.
# Dynamic (per-batch) padding would require dropping padding='max_length' there.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

train_dataloader = DataLoader(
    tokenized_train_dataset,
    batch_size=32,
    shuffle=True,  # new permutation each epoch
    collate_fn=data_collator,
)
val_dataloader = DataLoader(
    tokenized_val_dataset,
    batch_size=32,
    collate_fn=data_collator,
)

In [6]:
# Teacher: the full 12-layer bert-base-uncased with a 2-way classification head.
# NOTE(review): transformers reports classifier.weight / classifier.bias as newly
# initialised, and nothing in this notebook (as visible here) fine-tunes the
# teacher. Its logits are therefore untrained, which makes them a weak
# distillation target. Fine-tune the teacher first for a meaningful comparison.
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
).to(device)
teacher_model


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [7]:
# Student config: identical to the teacher's except for half as many encoder layers.
configuration = teacher_model.config.to_dict()
configuration.update(num_hidden_layers=configuration['num_hidden_layers'] // 2)
student_config = BertConfig.from_dict(configuration)

# Functions to initialise a student from the teacher (odd / even encoder layers)
def _copy_teacher_weights(teacher, student, teacher_layer_indices):
    """Initialise `student` from `teacher`, keeping only selected encoder layers.

    Copies the embeddings, plus the pooler and the classifier head when both
    models have them. Then loads teacher encoder layer
    ``teacher_layer_indices[i]`` into student encoder layer ``i``. Without the
    embedding copy, the transplanted encoder layers would be fed randomly
    initialised embeddings.
    """
    teacher_bert = teacher.bert
    student_bert = student.bert

    student_bert.embeddings.load_state_dict(teacher_bert.embeddings.state_dict())

    teacher_pooler = getattr(teacher_bert, 'pooler', None)
    student_pooler = getattr(student_bert, 'pooler', None)
    if teacher_pooler is not None and student_pooler is not None:
        student_pooler.load_state_dict(teacher_pooler.state_dict())

    if hasattr(teacher, 'classifier') and hasattr(student, 'classifier'):
        student.classifier.load_state_dict(teacher.classifier.state_dict())

    teacher_layers = teacher_bert.encoder.layer
    for student_layer, teacher_index in zip(student_bert.encoder.layer, teacher_layer_indices):
        student_layer.load_state_dict(teacher_layers[teacher_index].state_dict())


def distill_odd_layers(teacher, student):
    """Initialise `student` from the teacher's odd-numbered encoder layers.

    Student layer ``i`` (0-based) receives teacher layer ``2 * i`` (0-based),
    i.e. teacher layers 1, 3, 5, ... when counting from 1. Embeddings, pooler
    and classifier are copied as well.
    """
    num_student_layers = len(student.bert.encoder.layer)
    _copy_teacher_weights(teacher, student, [2 * i for i in range(num_student_layers)])


def distill_even_layers(teacher, student):
    """Initialise `student` from the teacher's even-numbered encoder layers.

    Student layer ``i`` (0-based) receives teacher layer ``2 * i + 1`` (0-based),
    i.e. teacher layers 2, 4, 6, ... when counting from 1. Embeddings, pooler
    and classifier are copied as well.
    """
    num_student_layers = len(student.bert.encoder.layer)
    _copy_teacher_weights(teacher, student, [2 * i + 1 for i in range(num_student_layers)])




In [8]:
# Define Loss Functions
num_epochs = 3  # number of training epochs

# KL-divergence loss for soft-label distillation; 'batchmean' divides the summed
# loss by the batch size.
criterion_div = nn.KLDivLoss(reduction='batchmean')
# With a +1 target it penalises student logits that point away from the teacher's.
criterion_cos = nn.CosineEmbeddingLoss()
# Function to train and evaluate a model
def _distillation_loss(student_outputs, teacher_outputs, temperature=2.0):
    """Average the three student objectives into one scalar loss.

    (1) the classification loss the student computed from ``labels``
    (``student_outputs.loss``); (2) the KL divergence between the
    temperature-softened teacher and student distributions; (3) a cosine loss
    aligning student logit vectors with the teacher's.
    """
    student_logits = student_outputs.logits
    teacher_logits = teacher_outputs.logits

    loss_cls = student_outputs.loss
    loss_div = criterion_div(
        torch.log_softmax(student_logits / temperature, dim=-1),
        torch.softmax(teacher_logits / temperature, dim=-1),
    )
    # Target +1 for every row: CosineEmbeddingLoss then rewards similarity.
    cos_target = torch.ones(student_logits.size(0), device=student_logits.device)
    loss_cos = criterion_cos(student_logits, teacher_logits, cos_target)

    return (loss_cls + loss_div + loss_cos) / 3


def _evaluate_accuracy(model, dataloader):
    """Return the fraction of examples whose argmax logit equals ``labels``.

    Returns 0.0 for an empty dataloader instead of dividing by zero.
    """
    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            predictions = torch.argmax(model(**batch).logits, dim=-1)
            total_correct += (predictions == batch['labels']).sum().item()
            total_samples += batch['labels'].size(0)
    return total_correct / total_samples if total_samples else 0.0


def train_and_evaluate_model(student_model, model_name, num_epochs=3, lr=5e-5, temperature=2.0):
    """Distil `teacher_model` into `student_model`, validating after every epoch.

    Each step minimises the mean of the student's classification loss, a
    temperature-softened KL term against the teacher and a logit cosine loss.
    When training finishes, the student and the tokenizer are saved to
    ``f"{model_name}_model"``.

    Args:
        student_model: model to train; it is called with the batches from
            ``train_dataloader`` (including ``labels``) and must return
            ``.loss`` and ``.logits``.
        model_name: label used in log messages and in the output directory name.
        num_epochs: number of passes over ``train_dataloader``.
        lr: initial Adam learning rate; it decays linearly to 0 over training.
        temperature: softmax temperature for the KL distillation term.

    Returns:
        ``(train_losses, eval_accuracies)``: the mean training loss and the
        validation accuracy for each epoch.

    Relies on the notebook globals ``teacher_model``, ``train_dataloader``,
    ``val_dataloader``, ``tokenizer``, ``device``, ``criterion_div`` and
    ``criterion_cos``.
    """
    print(f"Training {model_name} model...")

    # Optimise only trainable parameters (e.g. LoRA leaves most base weights frozen).
    trainable_params = [p for p in student_model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)
    num_training_steps = num_epochs * len(train_dataloader)
    # Linear decay lr -> 0. LinearLR's defaults (start_factor=1/3, end_factor=1.0)
    # would instead ramp the rate *up* from lr/3 to lr over the whole run.
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.0, total_iters=num_training_steps
    )

    train_losses = []
    eval_accuracies = []

    for epoch in range(num_epochs):
        student_model.train()
        teacher_model.eval()  # the teacher only supplies targets: keep dropout off
        total_loss = 0.0

        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            batch = {k: v.to(device) for k, v in batch.items()}

            student_outputs = student_model(**batch)
            with torch.no_grad():
                # Labels are withheld so the teacher does not compute an unused loss.
                teacher_inputs = {k: v for k, v in batch.items() if k != 'labels'}
                teacher_outputs = teacher_model(**teacher_inputs)

            loss = _distillation_loss(student_outputs, teacher_outputs, temperature)
            total_loss += loss.item()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        train_losses.append(total_loss / max(len(train_dataloader), 1))

        eval_accuracy = _evaluate_accuracy(student_model, val_dataloader)
        eval_accuracies.append(eval_accuracy)
        print(f"Epoch {epoch + 1}: Train Loss = {train_losses[-1]:.4f}, Val Acc = {eval_accuracy:.4f}")

    if eval_accuracies:
        print(f"{model_name} Avg Val Acc: {sum(eval_accuracies) / len(eval_accuracies):.4f}")

    # Save the trained model for inference
    output_dir = f"{model_name}_model"
    os.makedirs(output_dir, exist_ok=True)
    student_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"{model_name} model saved to {output_dir} for inference.")

    return train_losses, eval_accuracies

def _new_student():
    """Build a half-depth student from the pretrained bert-base-uncased checkpoint.

    `from_pretrained(..., config=student_config)` loads the checkpoint's
    embeddings, pooler and first `num_hidden_layers` encoder layers into the
    smaller architecture. `from_config` would instead leave every weight that
    the distill_* helpers do not overwrite randomly initialised.
    """
    return AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', config=student_config)


# Train Odd Layer Model
student_model_odd = _new_student()
student_model_odd.to(device)
distill_odd_layers(teacher_model, student_model_odd)
odd_train_losses, odd_eval_accuracies = train_and_evaluate_model(student_model_odd, "odd_layer")

# Train Even Layer Model
student_model_even = _new_student()
student_model_even.to(device)
distill_even_layers(teacher_model, student_model_even)
even_train_losses, even_eval_accuracies = train_and_evaluate_model(student_model_even, "even_layer")

# Train LoRA Model
# LoRA freezes the base weights and trains only small adapter matrices on top of
# them, so the base must be pretrained. On a `from_config` (random) base the
# adapters would be learning on top of frozen random features.
student_model_lora = _new_student()
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,  # Task type for sequence classification
    inference_mode=False,  # Set to False for training
    r=8,  # Rank of the low-rank adaptation
    lora_alpha=32,  # Scaling factor
    lora_dropout=0.1,  # Dropout rate
    bias="none",  # Bias handling
)
student_model_lora = get_peft_model(student_model_lora, peft_config)
student_model_lora.print_trainable_parameters()  # sanity check: only a small fraction should be trainable
student_model_lora.to(device)
lora_train_losses, lora_eval_accuracies = train_and_evaluate_model(student_model_lora, "lora")

# Plotting Comparison
# Epoch numbers are taken from the recorded metrics rather than the global
# `num_epochs`, which can differ from the number of epochs
# train_and_evaluate_model actually ran.
epochs_list = range(1, len(odd_train_losses) + 1)


def _plot_comparison(series, title, ylabel):
    """Draw one line per ``(label, values)`` pair against 1-based epoch numbers."""
    fig, ax = plt.subplots(figsize=(12, 6))
    for label, values in series:
        ax.plot(range(1, len(values) + 1), values, label=label)
    ax.set_title(title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend()
    plt.show()


# Plot Training Loss
_plot_comparison(
    [
        ('Odd Layer Train Loss', odd_train_losses),
        ('Even Layer Train Loss', even_train_losses),
        ('LoRA Train Loss', lora_train_losses),
    ],
    'Training Loss Comparison',
    'Loss',
)

# Plot Validation Accuracy
_plot_comparison(
    [
        ('Odd Layer Val Acc', odd_eval_accuracies),
        ('Even Layer Val Acc', even_eval_accuracies),
        ('LoRA Val Acc', lora_eval_accuracies),
    ],
    'Validation Accuracy Comparison',
    'Accuracy',
)

Training odd_layer model...


Epoch 1/3:   0%|          | 0/452 [00:00<?, ?it/s]

KeyboardInterrupt: 