In [8]:
# pip install transformers datasets torch
!pip install -q transformers accelerate bitsandbytes datasets

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [33]:
import torch
import numpy as np
from torch.nn.functional import kl_div, log_softmax
from torch.optim import AdamW
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import pandas as pd

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Load CoLA dataset
from datasets import load_dataset
cola_data = load_dataset("glue", "cola")

# Load HANS dataset using pandas
hans_train = pd.read_csv('/Users/yueyang/Downloads/heuristics_train_set.txt', sep='\t')
hans_eval = pd.read_csv('/Users/yueyang/Downloads/heuristics_evaluation_set.txt', sep='\t')

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Map gold_label to numeric labels 
label_mapping = {"entailment": 1, "non-entailment": 0}

# Preprocess HANS data for PyTorch
class HANSPreprocessor(Dataset):
    def __init__(self, dataframe, tokenizer, max_length=128):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_mapping = label_mapping

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        # Tokenize the sentences
        encoding = self.tokenizer(
            text=row["sentence1"],
            text_pair=row["sentence2"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        # Map gold_label to numeric label
        label = self.label_mapping[row["gold_label"]]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": label,
        }

# Create PyTorch Datasets for HANS
hans_train_dataset = HANSPreprocessor(hans_train, tokenizer)
hans_eval_dataset = HANSPreprocessor(hans_eval, tokenizer)

# DataLoader for HANS
hans_train_loader = DataLoader(hans_train_dataset, batch_size=32, shuffle=True)
hans_eval_loader = DataLoader(hans_eval_dataset, batch_size=32)

# Preprocess CoLA dataset for PyTorch
# def preprocess_cola(batch):
#     encoding = tokenizer(
#         batch["sentence"],
#         truncation=True,
#         padding="max_length",
#         max_length=128,
#         return_tensors="pt"
#     )
#     return {
#         "input_ids": encoding["input_ids"].squeeze(),
#         "attention_mask": encoding["attention_mask"].squeeze(),
#         "labels": torch.tensor(batch["label"], dtype=torch.long)
#     }
def preprocess_cola(batch):
    encoding = tokenizer(
        batch["sentence"],
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt"
    )
    return {
        "input_ids": encoding["input_ids"].squeeze(0),  # Convert to tensor
        "attention_mask": encoding["attention_mask"].squeeze(0),  # Convert to tensor
        "labels": torch.tensor(batch["label"], dtype=torch.long)  # Convert label to tensor
    }


cola_train_dataset = cola_data["train"].map(preprocess_cola)
cola_validation_dataset = cola_data["validation"].map(preprocess_cola)

cola_train_loader = DataLoader(cola_train_dataset, batch_size=32, shuffle=True)
cola_validation_loader = DataLoader(cola_validation_dataset, batch_size=32)

# Define collate function for dataloaders
def collate_fn(batch):
    input_ids = torch.stack([torch.tensor(item["input_ids"]) for item in batch])
    attention_mask = torch.stack([torch.tensor(item["attention_mask"]) for item in batch])
    labels = torch.tensor([item["labels"] for item in batch])  # Use "labels" key here
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
# def collate_fn(batch):
#     input_ids = torch.stack([item["input_ids"] for item in batch])
#     attention_mask = torch.stack([item["attention_mask"] for item in batch])
#     labels = torch.stack([item["labels"] for item in batch])  # Use torch.stack here
#     return {
#         "input_ids": input_ids,
#         "attention_mask": attention_mask,
#         "labels": labels
#     }

# Update DataLoaders with collate_fn
cola_train_loader = DataLoader(cola_train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
cola_validation_loader = DataLoader(cola_validation_dataset, batch_size=32, collate_fn=collate_fn)
hans_train_loader = DataLoader(hans_train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
hans_eval_loader = DataLoader(hans_eval_dataset, batch_size=32, collate_fn=collate_fn)

# # Print a sample from HANS DataLoader
# print(next(iter(hans_train_loader)))



In [25]:
cola_train_dataset = cola_data["train"].map(preprocess_cola)
print(cola_train_dataset[0])  # Verify that input_ids, attention_mask, and labels are tensors


Map: 100%|██████████| 8551/8551 [00:01<00:00, 8041.95 examples/s]

{'sentence': "Our friends won't buy this analysis, let alone the next one we propose.", 'label': 1, 'idx': 0, 'input_ids': [2, 2522, 964, 351, 75, 907, 42, 1966, 6, 905, 1937, 5, 220, 65, 52, 15393, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': 1}





In [36]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load teacher (pre-trained model)
teacher_model = AutoModelForSequenceClassification.from_pretrained("facebook/opt-125m")
teacher_model.to(device)
teacher_model.eval()  # Set teacher to evaluation mode

# Load student model (same architecture, will be fine-tuned)
student_model = AutoModelForSequenceClassification.from_pretrained("facebook/opt-125m")
student_model.to(device)

# Define the task loss (Cross-Entropy) and distillation loss (KL Divergence)
task_loss_fn = CrossEntropyLoss()

def distillation_loss_fn(student_logits, teacher_logits):
    return kl_div(log_softmax(student_logits, dim=-1), teacher_logits.softmax(dim=-1), reduction="batchmean")

# Optimizer for student model
optimizer = AdamW(student_model.parameters(), lr=1e-5)


Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-125m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-125m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Training loop
num_epochs = 1
results_df = pd.DataFrame(columns=["epoch", "in_domain_accuracy", "out_domain_accuracy"])

for epoch in range(num_epochs):
    # Training Phase with CoLA train loader
    student_model.train()
    for batch in tqdm(cola_train_loader, desc=f"Training Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Forward pass (teacher model)
        with torch.no_grad():
            teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits

        # Forward pass (student model)
        student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits

        # Compute losses
        distillation_loss = distillation_loss_fn(student_logits, teacher_logits)
        classification_loss = task_loss_fn(student_logits, labels)
        loss = 0.5 * distillation_loss + 0.5 * classification_loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluation Phase
    student_model.eval()
    in_domain_correct, in_domain_total = 0, 0
    out_domain_correct, out_domain_total = 0, 0

    # In-domain (CoLA validation loader)
    with torch.no_grad():
        for batch in tqdm(cola_validation_loader, desc="Evaluating In-Domain"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions = torch.argmax(logits, dim=-1)

            in_domain_correct += (predictions == labels).sum().item()
            in_domain_total += labels.size(0)

    # Out-of-domain (HANS eval loader)
    with torch.no_grad():
        for batch in tqdm(hans_eval_loader, desc="Evaluating Out-Domain"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions = torch.argmax(logits, dim=-1)

            out_domain_correct += (predictions == labels).sum().item()
            out_domain_total += labels.size(0)

    # Compute accuracies
    in_domain_accuracy = in_domain_correct / in_domain_total
    out_domain_accuracy = out_domain_correct / out_domain_total

    print(f"Epoch {epoch+1}: In-Domain Accuracy = {in_domain_accuracy:.4f}, Out-of-Domain Accuracy = {out_domain_accuracy:.4f}")
    
    # Add results to DataFrame
    new_row = pd.DataFrame({
        "epoch": [epoch + 1],
        "in_domain_accuracy": [in_domain_accuracy],
        "out_domain_accuracy": [out_domain_accuracy]
    })
    results_df = pd.concat([results_df, new_row], ignore_index=True)

    # # Save results to a CSV file after each epoch
    # results_df.to_csv("ContextDistillation_Cola.csv", index=False)

# Final save of the DataFrame after training
results_df.to_csv("ContextDistillation_Cola.csv", index=False)
print("Training results saved to final_training_results.csv")


Training Epoch 1:  28%|██▊       | 76/268 [05:21<13:58,  4.37s/it]