In [9]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from transformers import AutoTokenizer, AutoModel
from datasets import Dataset
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
import matplotlib.pyplot as plt
import nbimporter
from torch.utils.data import DataLoader, random_split, TensorDataset
from small_network import SmallNetwork, InverseNetwork, kl_uniform_loss, combined_loss, init_weights
from concept_classification import ConceptClassifier
import time

In [10]:
device0 = torch.device("cuda:0")
device1 = torch.device("cuda:1")
device2 = torch.device("cuda:2")
device3 = torch.device("cuda:3")

device=device1

model_name="distilbert/distilbert-base-uncased"   # microsoft/mpnet-base, distilbert/distilbert-base-uncased
REVERSE=True # True: debiasing  False:Enhancing
resconstruction=True

In [11]:
# SwiGLU Activation Function
class SwiGLU(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SwiGLU, self).__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(input_dim, hidden_dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.gelu(self.linear1(x)) * self.linear2(x)

def contrastive_loss(cos_sim, margin=0.1, reverse=False):     # False: debiasing, True: enhance the shortcuts
    if reverse:
        return torch.mean(torch.relu(1 - cos_sim - margin))
    loss = torch.mean(torch.relu(cos_sim - margin))
    return loss


In [12]:
class RobertaSentimentClassifier(nn.Module):
    def __init__(self, num_labels=2, model_name="roberta-base", small_network=None, debiasing_module=True):
        super(RobertaSentimentClassifier, self).__init__()
        self.model_name = model_name

        # Load RoBERTa model
        self.roberta = AutoModel.from_pretrained(model_name)
        hidden_size = self.roberta.config.hidden_size

        # Small Network
        self.small_network = small_network

        self.projection = nn.Linear(hidden_size, hidden_size)  # Adjust input size for concatenated embeddings
        self.swiglu = SwiGLU(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size , num_labels)
        self.debiasing_module=debiasing_module

    def forward(self, input_ids, attention_mask=None):
        # Forward pass through RoBERTa
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state[:, 0, :]  # Shape: [batch_size, hidden_size]

        # Forward pass through Small Network
        if self.small_network:
            reembedded_states = self.small_network(hidden_states)  # Shape: [batch_size, hidden_size]
        else:
            reembedded_states = torch.zeros_like(hidden_states)

        concept_proj=self.projection(hidden_states)
        content_proj=self.projection(reembedded_states)
        cos_sim = F.cosine_similarity(concept_proj, content_proj, dim=-1)
        
        # Concatenate original RoBERTa embeddings and reembedded outputs
        #combined_states = torch.cat([hidden_states, reembedded_states], dim=-1)  # Shape: [batch_size, hidden_size * 2]
        if self.debiasing_module:
            x = self.swiglu(concept_proj)
        else:
            x = self.swiglu(content_proj)
        logits = self.fc(x)
        return logits, cos_sim

In [13]:
roberta_model = AutoModel.from_pretrained(model_name).to(device)
roberta_model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.model_max_length

512

In [14]:
# Dataset for Sentiment Classification

# Tokenization function
def tokenization(batched_text):
    return tokenizer(batched_text['text'], padding="max_length", truncation=True, max_length=32) # 记得改这里

# Load dataset from CSV
def load_data_from_csv(file_path):
    df = pd.read_csv(file_path)
    texts = df["clean_text"].tolist()
    labels = df["label"].tolist()
    return texts, labels

# Load train and test datasets
if REVERSE:
    train_texts, train_labels = load_data_from_csv("group_a_train.csv")
    test_texts, test_labels = load_data_from_csv("group_b_test.csv")
else:
    train_texts, train_labels = load_data_from_csv("group_a_train.csv")
    test_texts, test_labels = load_data_from_csv("group_a_test.csv")
    
# Split the training data into training and validation sets, 80% train, 20% validation
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_texts, train_labels, test_size=0.2, random_state=42
)

# Convert to HuggingFace Dataset format
sentiment_train_data = Dataset.from_dict({"text": train_texts, "label": train_labels})
sentiment_val_data = Dataset.from_dict({"text": val_texts, "label": val_labels})
sentiment_test_data = Dataset.from_dict({"text": test_texts, "label": test_labels})

# Tokenize the datasets
sentiment_train_data = sentiment_train_data.map(tokenization, batched=True)
sentiment_val_data = sentiment_val_data.map(tokenization, batched=True)
sentiment_test_data = sentiment_test_data.map(tokenization, batched=True)

# Set the format for PyTorch DataLoader compatibility
columns = ["input_ids", "attention_mask", "label"]
sentiment_train_data.set_format(type="torch", columns=columns)
sentiment_val_data.set_format(type="torch", columns=columns)
sentiment_test_data.set_format(type="torch", columns=columns)

# Create DataLoaders
train_loader_sentiment = DataLoader(sentiment_train_data, batch_size=16, shuffle=True)
val_loader_sentiment = DataLoader(sentiment_val_data, batch_size=16, shuffle=False)
test_loader_sentiment = DataLoader(sentiment_test_data, batch_size=16, shuffle=False)

print(sentiment_train_data)
print(sentiment_val_data)
print(sentiment_test_data)


Map:   0%|          | 0/1280 [00:00<?, ? examples/s]

Map:   0%|          | 0/320 [00:00<?, ? examples/s]

Map:   0%|          | 0/400 [00:00<?, ? examples/s]

Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 1280
})
Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 320
})
Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 400
})


In [15]:
# Data for small network
data_path_small_network = "imbalanced_concepts.csv"  # Replace with your CSV path
data_small_network = pd.read_csv(data_path_small_network)

# Tokenize and process data
def tokenize_data(reviews):
    tokenized_data = tokenizer(reviews, padding=True, truncation=True, max_length=256, return_tensors="pt")
    input_ids = tokenized_data["input_ids"]
    attention_mask = tokenized_data["attention_mask"]
    return TensorDataset(input_ids, attention_mask)


# Prepare datasets
dataset_small_network = tokenize_data(data_small_network['clean_text'].tolist())
data_size_small_network = len(dataset_small_network)
train_size_small_network = int(0.8 * data_size_small_network)
val_size_small_network = data_size_small_network - train_size_small_network

train_data_small_network, val_data_small_network = random_split(dataset_small_network, [train_size_small_network, val_size_small_network])

train_loader_small_network = DataLoader(train_data_small_network, batch_size=16, shuffle=True)
val_loader_small_network = DataLoader(val_data_small_network, batch_size=16, shuffle=False)

In [31]:
lambda_reconstruction=0.9

start_time=time.time()
# Initialize models
small_network = SmallNetwork(input_dim=768, bottleneck_dim=384).to(device)
inverse_network = InverseNetwork(input_dim=768, bottleneck_dim=384).to(device)

# Apply initialization
#small_network.apply(init_weights)
#inverse_network.apply(init_weights)

concept_classifier = ConceptClassifier(num_labels=2, model_name=model_name).to(device)
concept_classifier.load_state_dict(torch.load("concept_classifier.pt"))
concept_classifier.eval()  # Freeze concept classifier during small network training

# Optimizers
small_network_optimizer = optim.AdamW(small_network.parameters(), lr=0.0001)
inverse_network_optimizer = optim.AdamW(inverse_network.parameters(), lr=0.0003)


# Early stopping and metrics
patience = 3
best_val_loss = float("inf")
wait = 0

train_losses_small = []
train_losses_sentiment = []
val_losses = []
concept_classifier.eval()

# Training Loop
num_epochs = 3
num_inverse_epochs=1
num_small_epochs=2

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    if resconstruction:
        # Step 0: Train inverse model to reconstruct the inputs
        small_network.eval()
        inverse_network.train()
        for _ in range(num_inverse_epochs):
            for batch in tqdm(train_loader_small_network, desc="Training Inverse Network"):
                input_ids, attention_mask = [x.to(device) for x in batch]
        
                # Forward pass through RoBERTa
                with torch.no_grad():
                    roberta_output = roberta_model(input_ids=input_ids, attention_mask=attention_mask)
                    roberta_embeddings = roberta_output.last_hidden_state
        
                # Forward pass through Small Network
                reembedded_output = small_network(roberta_embeddings)
        
                # Forward pass through Inverse Network
                reconstructed_output = inverse_network(reembedded_output)
        
                # Compute Reconstruction Loss
                reconstruction_loss = F.mse_loss(reconstructed_output, roberta_embeddings)
                inverse_network_optimizer.zero_grad()
                reconstruction_loss.backward()
                torch.nn.utils.clip_grad_norm_(inverse_network.parameters(), max_norm=1.0)
                inverse_network_optimizer.step()
            print(f"  Inverse Network Train Loss: {reconstruction_loss:.4f}")
        
    # Step 1: Train Small Network to Remove Concept Information
    small_network.train()
    inverse_network.eval()
    for _ in range(num_small_epochs):
        total_small_network_loss = 0
        for batch in tqdm(train_loader_small_network, desc="Training Small Network"):
            input_ids, attention_mask = [x.to(device) for x in batch]
    
            # Forward pass through RoBERTa
            with torch.no_grad():
                roberta_output = roberta_model(input_ids=input_ids, attention_mask=attention_mask)
                roberta_embeddings = roberta_output.last_hidden_state
    
            # Forward pass through Small Network
            reembedded_output = small_network(roberta_embeddings)
            if resconstruction:
                # Forward pass through Inverse Network
                reconstructed_output = inverse_network(reembedded_output)
    
            # Compute KL divergence loss
            logits = concept_classifier(input_ids=None, attention_mask=None, embeddings=reembedded_output, return_probs=False)
            kl_loss = kl_uniform_loss(logits, num_classes=2)
    
            # Combine losses for Small Network
            if resconstruction:
                total_loss = combined_loss(kl_loss, roberta_embeddings, reconstructed_output.detach(), lambda_reconstruction=lambda_reconstruction)  # Detach reconstructed_output
            else:
                total_loss= kl_loss
                
            # Backpropagation for Small Network
            small_network_optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(small_network.parameters(), max_norm=1.0)
            small_network_optimizer.step()
    
            total_small_network_loss += total_loss.item()
    
        avg_small_network_loss = total_small_network_loss / len(train_loader_small_network)
        print(f"  Small Network Train Loss: {avg_small_network_loss:.4f}")
        
        torch.save(small_network.state_dict(), f"Shuo_small_network_ep_{epoch}.pt")
        torch.save(inverse_network.state_dict(), f"Shuo_inverse_network_ep_{epoch}.pt")

  concept_classifier.load_state_dict(torch.load("concept_classifier.pt"))


Epoch 1/3


Training Inverse Network: 100%|████████████████████████████████████████████████| 289/289 [00:06<00:00, 47.27it/s]


  Inverse Network Train Loss: 0.3865


Training Small Network: 100%|██████████████████████████████████████████████████| 289/289 [00:06<00:00, 48.09it/s]


  Small Network Train Loss: 0.3721


Training Small Network: 100%|██████████████████████████████████████████████████| 289/289 [00:05<00:00, 56.35it/s]


  Small Network Train Loss: 0.3767
Epoch 2/3


Training Inverse Network: 100%|████████████████████████████████████████████████| 289/289 [00:06<00:00, 45.38it/s]


  Inverse Network Train Loss: 0.2409


Training Small Network: 100%|██████████████████████████████████████████████████| 289/289 [00:06<00:00, 46.48it/s]


  Small Network Train Loss: 0.2191


Training Small Network: 100%|██████████████████████████████████████████████████| 289/289 [00:05<00:00, 48.82it/s]


  Small Network Train Loss: 0.2195
Epoch 3/3


Training Inverse Network: 100%|████████████████████████████████████████████████| 289/289 [00:05<00:00, 55.70it/s]


  Inverse Network Train Loss: 0.1282


Training Small Network: 100%|██████████████████████████████████████████████████| 289/289 [00:05<00:00, 48.80it/s]


  Small Network Train Loss: 0.1292


Training Small Network: 100%|██████████████████████████████████████████████████| 289/289 [00:04<00:00, 58.29it/s]


  Small Network Train Loss: 0.1294


In [32]:
# 收敛性曲线
# small 400 steps IMDB
# pmnet      0.0692  0.0442  0.0159  0.0163  0.0069  0.0068
# distilbert 0.0665  0.0405  0.0136  0.0136  0.0061  0.0061
# roberta    0.0583  0.0298  0.0105  0.0105  0.0051  0.0051

# small 300 steps yelp
# pmnet      0.0535  0.0511  0.0302  0.0302  0.0141  0.0141
# distilbert 0.0444  0.0426  0.0244  0.0244  0.0142  0.0142
# roberta    0.0346  0.0291  0.0173  0.0173  0.0110  0.0110


In [33]:
# Step 2: Freeze Small Network and Train Sentiment Classifier
MARGIN=0.0
lamda=0.1
DEBIASING_MODULE=True

roberta_sentiment_classifier = RobertaSentimentClassifier(num_labels=2, model_name=model_name, small_network=small_network, debiasing_module=DEBIASING_MODULE).to(device)
# Freeze the parameters of the RoBERTa backbone
for param in roberta_sentiment_classifier.roberta.parameters():
    param.requires_grad = False

if roberta_sentiment_classifier.small_network is not None:
    for param in roberta_sentiment_classifier.small_network.parameters():
        param.requires_grad = False

# Ensure all additional layers are trainable
for param in roberta_sentiment_classifier.swiglu.parameters():
    param.requires_grad = True
for param in roberta_sentiment_classifier.projection.parameters():
    param.requires_grad = True
for param in roberta_sentiment_classifier.fc.parameters():
    param.requires_grad = True

# Optimizer for trainable parameters
trainable_params = [
    {'params': roberta_sentiment_classifier.swiglu.parameters()},
    {'params': roberta_sentiment_classifier.projection.parameters()},
    {'params': roberta_sentiment_classifier.fc.parameters()}
]

sentiment_optimizer = optim.AdamW(trainable_params, lr=0.0003)

# Loss function remains the same
criterion = nn.CrossEntropyLoss()


if roberta_sentiment_classifier.small_network is not None:
    for param in roberta_sentiment_classifier.small_network.parameters():
        param.requires_grad = False


num_classifier_epochs=50
patience = 5
best_val_loss = float("inf")
wait = 0

train_losses_sentiment = []
val_losses = []
small_network.eval()
for epoch in range(num_classifier_epochs):
    print(f"Epoch {epoch + 1}/{num_classifier_epochs}")

    roberta_sentiment_classifier.train()
    total_sentiment_loss = 0

    for batch in tqdm(train_loader_sentiment, desc="Training Sentiment Classifier"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        sentiment_optimizer.zero_grad()
        logits, cos_sim = roberta_sentiment_classifier(input_ids=input_ids, attention_mask=attention_mask)
        loss_ce = F.cross_entropy(logits, labels)
        loss_contrastive = contrastive_loss(cos_sim, margin=MARGIN, reverse=REVERSE)
        loss=loss_ce+lamda*loss_contrastive
        loss.backward()
        sentiment_optimizer.step()

        total_sentiment_loss += loss.item()

    avg_sentiment_loss = total_sentiment_loss / len(train_loader_sentiment)
    train_losses_sentiment.append(avg_sentiment_loss)
    print(f"  Sentiment Classifier Train Loss: {avg_sentiment_loss:.4f}")

    #############################################################################
    
    # Validation Phase
    roberta_sentiment_classifier.eval()
    total_val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(val_loader_sentiment, desc="Validating Sentiment Classifier"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits, cos_sim = roberta_sentiment_classifier(input_ids=input_ids, attention_mask=attention_mask)
            total_val_loss += F.cross_entropy(logits, labels).item()

            preds = torch.argmax(logits, dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_val_loss = total_val_loss / len(val_loader_sentiment)
    val_losses.append(avg_val_loss)
    val_accuracy = correct / total
    print(f"  Validation Loss: {avg_val_loss:.4f}")
    print(f"  Validation Accuracy: {val_accuracy:.4f}")

    # Early Stopping
    if avg_val_loss < best_val_loss:
        print(f"  Validation loss improved from {best_val_loss:.4f} to {avg_val_loss:.4f}. Saving model...")
        best_val_loss = avg_val_loss
        best_model_state = roberta_sentiment_classifier.state_dict()
        wait = 0
    else:
        wait += 1
        print(f"  No improvement in validation loss for {wait}/{patience} epochs.")
        if wait >= patience:
            print("Early stopping triggered.")
            break

# Save the best model
if best_model_state:
    roberta_sentiment_classifier.load_state_dict(best_model_state)
    torch.save(roberta_sentiment_classifier.state_dict(), "Shuo_sentiment_classifier.pt")
    print("Best Sentiment Classifier model saved!")
end_time=time.time()
print((end_time-start_time)/1280)

Epoch 1/50


Training Sentiment Classifier: 100%|█████████████████████████████████████████████| 80/80 [00:01<00:00, 45.81it/s]


  Sentiment Classifier Train Loss: 0.4038


Validating Sentiment Classifier: 100%|███████████████████████████████████████████| 20/20 [00:00<00:00, 58.18it/s]


  Validation Loss: 0.1894
  Validation Accuracy: 0.9344
  Validation loss improved from inf to 0.1894. Saving model...
Epoch 2/50


Training Sentiment Classifier: 100%|█████████████████████████████████████████████| 80/80 [00:01<00:00, 45.78it/s]


  Sentiment Classifier Train Loss: 0.2252


Validating Sentiment Classifier: 100%|███████████████████████████████████████████| 20/20 [00:00<00:00, 56.13it/s]


  Validation Loss: 0.2801
  Validation Accuracy: 0.8844
  No improvement in validation loss for 1/5 epochs.
Epoch 3/50


Training Sentiment Classifier: 100%|█████████████████████████████████████████████| 80/80 [00:01<00:00, 46.52it/s]


  Sentiment Classifier Train Loss: 0.2015


Validating Sentiment Classifier: 100%|███████████████████████████████████████████| 20/20 [00:00<00:00, 38.81it/s]


  Validation Loss: 0.1973
  Validation Accuracy: 0.9344
  No improvement in validation loss for 2/5 epochs.
Epoch 4/50


Training Sentiment Classifier: 100%|█████████████████████████████████████████████| 80/80 [00:01<00:00, 55.84it/s]


  Sentiment Classifier Train Loss: 0.2253


Validating Sentiment Classifier: 100%|███████████████████████████████████████████| 20/20 [00:00<00:00, 51.12it/s]


  Validation Loss: 0.2172
  Validation Accuracy: 0.9187
  No improvement in validation loss for 3/5 epochs.
Epoch 5/50


Training Sentiment Classifier: 100%|█████████████████████████████████████████████| 80/80 [00:01<00:00, 47.06it/s]


  Sentiment Classifier Train Loss: 0.1927


Validating Sentiment Classifier: 100%|███████████████████████████████████████████| 20/20 [00:00<00:00, 68.76it/s]


  Validation Loss: 0.1896
  Validation Accuracy: 0.9281
  No improvement in validation loss for 4/5 epochs.
Epoch 6/50


Training Sentiment Classifier: 100%|█████████████████████████████████████████████| 80/80 [00:01<00:00, 77.28it/s]


  Sentiment Classifier Train Loss: 0.1591


Validating Sentiment Classifier: 100%|███████████████████████████████████████████| 20/20 [00:00<00:00, 84.80it/s]


  Validation Loss: 0.2348
  Validation Accuracy: 0.9156
  No improvement in validation loss for 5/5 epochs.
Early stopping triggered.
Best Sentiment Classifier model saved!
0.05298670399934054


In [34]:
b=time.time()
# Evaluate the sentiment classifier
roberta_sentiment_classifier.eval()

# Initialize variables for storing predictions and labels
all_preds = []
all_labels = []

with torch.no_grad():  # Disable gradient computation for evaluation
    for sample in tqdm(test_loader_sentiment, desc="Evaluating Test Set"):
        # Extract input_ids, attention_mask, and labels from the batch
        input_ids = sample['input_ids'].to(device)
        attention_mask = sample['attention_mask'].to(device)
        labels = sample['label'].to(device)

        # Get logits from the model
        logits, cos_sim = roberta_sentiment_classifier(input_ids=input_ids, attention_mask=attention_mask)

        # Convert logits to predicted labels
        preds = torch.argmax(logits, dim=1)

        # Append predictions and labels for metric computation
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Compute Metrics
precision = precision_score(all_labels, all_preds, average="binary")  # Use 'binary' for 2-class tasks
recall = recall_score(all_labels, all_preds, average="binary")
accuracy = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds, average="binary")


# Print Metrics
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

e=time.time()
print((e-b)/400)

Evaluating Test Set: 100%|███████████████████████████████████████████████████████| 25/25 [00:00<00:00, 56.76it/s]

Accuracy: 0.9200
Precision: 0.9421
Recall: 0.8950
F1 Score: 0.9179
0.0011224699020385741





In [None]:
92.00 91.96
90.75 90.82
91.50 91.71
91.50 91.75


In [46]:
ROBERTA
    IMDB
        b test
            baseline
                Accuracy: 0.7883
                Precision: 0.9220
                Recall: 0.6300
                F1 Score: 0.7485
            
            ours
                Accuracy: 0.8350
                Precision: 0.7965
                Recall: 0.9000
                F1 Score: 0.8451
                no reconstruction
                    Accuracy: 0.8133
                    Precision: 0.9017
                    Recall: 0.7033
                    F1 Score: 0.7903

        a test
            baseline
                Accuracy: 0.8850
                Precision: 0.8367
                Recall: 0.9567
                F1 Score: 0.8927
            
            ours
                Accuracy: 0.8967
                Precision: 0.8889
                Recall: 0.9067
                F1 Score: 0.8977

    yelp
        b test 
            baseline fl:Accuracy: 0.8900 F1 Score: 0.8952
                Accuracy: 0.8925
                Precision: 0.8651
                Recall: 0.9300
                F1 Score: 0.8964
  
            ours
                Accuracy: 0.9150
                Precision: 0.9323
                Recall: 0.8950
                F1 Score: 0.9133
                no reconstruction
                    Accuracy: 0.7975
                    Precision: 0.7133
                    Recall: 0.9950
                    F1 Score: 0.8309

        a test
            baseline
                Accuracy: 0.9375
                Precision: 0.9730
                Recall: 0.9000
                F1 Score: 0.9351
            
            ours
                Accuracy: 0.9475
                Precision: 0.9686
                Recall: 0.9250
                F1 Score: 0.9463



DistilBERT
    IMDB
        b test
            baseline
                Accuracy: 0.8167
                Precision: 0.7987
                Recall: 0.8467
                F1 Score: 0.8220
            ours
                Accuracy: 0.8400
                Precision: 0.8248
                Recall: 0.8633
                F1 Score: 0.8436
                no reconstruction
                    Accuracy: 0.8083
                    Precision: 0.7697
                    Recall: 0.8800
                    F1 Score: 0.8212
        
        
        a test
            baseline
                Accuracy: 0.8400
                Precision: 0.7982
                Recall: 0.9100
                F1 Score: 0.8505
            ours
                Accuracy: 0.8550
                Precision: 0.8562
                Recall: 0.8533
                F1 Score: 0.8548

    yelp
        b test
            baseline
                Accuracy: 0.8975
                Precision: 0.8472
                Recall: 0.9700
                F1 Score: 0.9044
            
            ours
                Accuracy: 0.9200
                Precision: 0.9078
                Recall: 0.9350
                F1 Score: 0.9212
                no reconstruction
                    Accuracy: 0.9150
                    Precision: 0.9611
                    Recall: 0.8650
                    F1 Score: 0.9105

        a test
            baseline
                Accuracy: 0.9475
                Precision: 0.9453
                Recall: 0.9500
                F1 Score: 0.9476

            
            ours
                Accuracy: 0.9525
                Precision: 0.9641
                Recall: 0.9400
                F1 Score: 0.9519

MPnet
    IMDB
        b test
            baseline
                Accuracy: 0.7933
                Precision: 0.7699
                Recall: 0.8367
                F1 Score: 0.8019
            ours
                Accuracy: 0.8150
                Precision: 0.8247
                Recall: 0.8000
                F1 Score: 0.8122
                no reconstruction
                    Accuracy: 0.7983
                    Precision: 0.8327
                    Recall: 0.7467
                    F1 Score: 0.7873

        a test
            baseline
                Accuracy: 0.8733
                Precision: 0.8972
                Recall: 0.8433
                F1 Score: 0.8694

            ours
                Accuracy: 0.8883
                Precision: 0.8923
                Recall: 0.8833
                F1 Score: 0.8878

    yelp
        b test
            baseline
                Accuracy: 0.8900
                Precision: 0.9432
                Recall: 0.8300
                F1 Score: 0.8830
    
            ours
                Accuracy: 0.9075
                Precision: 0.9137
                Recall: 0.9000
                F1 Score: 0.9068
                no reconstruction
                    Accuracy: 0.9025
                    Precision: 0.9497
                    Recall: 0.8500
                    F1 Score: 0.8971

        a test
            baseline
                Accuracy: 0.9275
                Precision: 0.8869
                Recall: 0.9800
                F1 Score: 0.9311

            ours
                Accuracy: 0.9500
                Precision: 0.9500
                Recall: 0.9500
                F1 Score: 0.9500

IndentationError: unexpected indent (4247027720.py, line 2)

In [33]:
123/1280

0.09609375

In [65]:
## baseline
focal_loss=False

baseline_sentiment_classifier = RobertaSentimentClassifier(num_labels=2, model_name=model_name, small_network=None).to(device)
# Freeze the parameters of the RoBERTa backbone
for param in baseline_sentiment_classifier.roberta.parameters():
    param.requires_grad = False


# Ensure all additional layers are trainable
for param in baseline_sentiment_classifier.swiglu.parameters():
    param.requires_grad = True
for param in baseline_sentiment_classifier.fc.parameters():
    param.requires_grad = True
for param in baseline_sentiment_classifier.projection.parameters():
    param.requires_grad = True

# Optimizer for trainable parameters
trainable_params = [
    {'params': baseline_sentiment_classifier.projection.parameters()},
    {'params': baseline_sentiment_classifier.swiglu.parameters()},
    {'params': baseline_sentiment_classifier.fc.parameters()}
]

baseline_optimizer = optim.AdamW(trainable_params, lr=0.0003)

if focal_loss:
    criterion = FocalLoss(alpha=0.25, gamma=2, reduction='mean')  # 你可以根据需求调整超参数
else:
    criterion = torch.nn.CrossEntropyLoss()

num_classifier_epochs=50
# Early stopping and metrics
patience = 5
best_val_loss = float("inf")
wait = 0

train_losses_sentiment = []
val_losses = []

b=time.time()
for epoch in range(num_classifier_epochs):
    print(f"Epoch {epoch + 1}/{num_classifier_epochs}")

    baseline_sentiment_classifier.train()
    total_sentiment_loss = 0

    for batch in tqdm(train_loader_sentiment, desc="Training Sentiment Classifier"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        baseline_optimizer.zero_grad()
        logits, cos_sim = baseline_sentiment_classifier(input_ids=input_ids, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        baseline_optimizer.step()

        total_sentiment_loss += loss.item()

    avg_sentiment_loss = total_sentiment_loss / len(train_loader_sentiment)
    train_losses_sentiment.append(avg_sentiment_loss)
    print(f"  Sentiment Classifier Train Loss: {avg_sentiment_loss:.4f}")

    #############################################################################
    
    # Validation Phase
    baseline_sentiment_classifier.eval()
    total_val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(val_loader_sentiment, desc="Validating Sentiment Classifier"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits, cos_sim = baseline_sentiment_classifier(input_ids=input_ids, attention_mask=attention_mask)
            total_val_loss += F.cross_entropy(logits, labels).item()

            preds = torch.argmax(logits, dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_val_loss = total_val_loss / len(val_loader_sentiment)
    val_losses.append(avg_val_loss)
    val_accuracy = correct / total
    print(f"  Validation Loss: {avg_val_loss:.4f}")
    print(f"  Validation Accuracy: {val_accuracy:.4f}")

    # Early Stopping
    if avg_val_loss < best_val_loss:
        print(f"  Validation loss improved from {best_val_loss:.4f} to {avg_val_loss:.4f}. Saving model...")
        best_val_loss = avg_val_loss
        best_model_state = baseline_sentiment_classifier.state_dict()
        wait = 0
    else:
        wait += 1
        print(f"  No improvement in validation loss for {wait}/{patience} epochs.")
        if wait >= patience:
            print("Early stopping triggered.")
            break

# Save the best model
if best_model_state:
    baseline_sentiment_classifier.load_state_dict(best_model_state)
    torch.save(baseline_sentiment_classifier.state_dict(), "Shuo_sentiment_classifier_baseline.pt")
    print("Best Sentiment Classifier model saved!")

e=time.time()
print((e-b)/1280)

Epoch 1/50


Training Sentiment Classifier: 100%|███████████████████████████████████████████████████| 80/80 [00:00<00:00, 136.66it/s]


  Sentiment Classifier Train Loss: 0.3942


Validating Sentiment Classifier: 100%|█████████████████████████████████████████████████| 20/20 [00:00<00:00, 187.98it/s]


  Validation Loss: 0.1780
  Validation Accuracy: 0.9344
  Validation loss improved from inf to 0.1780. Saving model...
Epoch 2/50


Training Sentiment Classifier: 100%|███████████████████████████████████████████████████| 80/80 [00:00<00:00, 168.22it/s]


  Sentiment Classifier Train Loss: 0.2399


Validating Sentiment Classifier: 100%|█████████████████████████████████████████████████| 20/20 [00:00<00:00, 174.04it/s]


  Validation Loss: 0.2102
  Validation Accuracy: 0.9313
  No improvement in validation loss for 1/5 epochs.
Epoch 3/50


Training Sentiment Classifier: 100%|███████████████████████████████████████████████████| 80/80 [00:00<00:00, 107.21it/s]


  Sentiment Classifier Train Loss: 0.2031


Validating Sentiment Classifier: 100%|██████████████████████████████████████████████████| 20/20 [00:00<00:00, 68.67it/s]


  Validation Loss: 0.1705
  Validation Accuracy: 0.9375
  Validation loss improved from 0.1780 to 0.1705. Saving model...
Epoch 4/50


Training Sentiment Classifier: 100%|████████████████████████████████████████████████████| 80/80 [00:01<00:00, 66.88it/s]


  Sentiment Classifier Train Loss: 0.1793


Validating Sentiment Classifier: 100%|██████████████████████████████████████████████████| 20/20 [00:00<00:00, 67.67it/s]


  Validation Loss: 0.1983
  Validation Accuracy: 0.9250
  No improvement in validation loss for 1/5 epochs.
Epoch 5/50


Training Sentiment Classifier: 100%|████████████████████████████████████████████████████| 80/80 [00:01<00:00, 65.06it/s]


  Sentiment Classifier Train Loss: 0.1741


Validating Sentiment Classifier: 100%|██████████████████████████████████████████████████| 20/20 [00:00<00:00, 67.46it/s]


  Validation Loss: 0.1791
  Validation Accuracy: 0.9281
  No improvement in validation loss for 2/5 epochs.
Epoch 6/50


Training Sentiment Classifier: 100%|████████████████████████████████████████████████████| 80/80 [00:01<00:00, 65.97it/s]


  Sentiment Classifier Train Loss: 0.1614


Validating Sentiment Classifier: 100%|██████████████████████████████████████████████████| 20/20 [00:00<00:00, 67.63it/s]


  Validation Loss: 0.1689
  Validation Accuracy: 0.9344
  Validation loss improved from 0.1705 to 0.1689. Saving model...
Epoch 7/50


Training Sentiment Classifier: 100%|████████████████████████████████████████████████████| 80/80 [00:01<00:00, 66.62it/s]


  Sentiment Classifier Train Loss: 0.1723


Validating Sentiment Classifier: 100%|██████████████████████████████████████████████████| 20/20 [00:00<00:00, 70.07it/s]


  Validation Loss: 0.1835
  Validation Accuracy: 0.9281
  No improvement in validation loss for 1/5 epochs.
Epoch 8/50


Training Sentiment Classifier: 100%|████████████████████████████████████████████████████| 80/80 [00:00<00:00, 87.71it/s]


  Sentiment Classifier Train Loss: 0.1644


Validating Sentiment Classifier: 100%|██████████████████████████████████████████████████| 20/20 [00:00<00:00, 93.52it/s]


  Validation Loss: 0.1776
  Validation Accuracy: 0.9250
  No improvement in validation loss for 2/5 epochs.
Epoch 9/50


Training Sentiment Classifier: 100%|████████████████████████████████████████████████████| 80/80 [00:00<00:00, 92.75it/s]


  Sentiment Classifier Train Loss: 0.1553


Validating Sentiment Classifier: 100%|██████████████████████████████████████████████████| 20/20 [00:00<00:00, 90.28it/s]


  Validation Loss: 0.2096
  Validation Accuracy: 0.9250
  No improvement in validation loss for 3/5 epochs.
Epoch 10/50


Training Sentiment Classifier: 100%|███████████████████████████████████████████████████| 80/80 [00:00<00:00, 177.33it/s]


  Sentiment Classifier Train Loss: 0.1573


Validating Sentiment Classifier: 100%|█████████████████████████████████████████████████| 20/20 [00:00<00:00, 193.52it/s]


  Validation Loss: 0.1892
  Validation Accuracy: 0.9250
  No improvement in validation loss for 4/5 epochs.
Epoch 11/50


Training Sentiment Classifier: 100%|████████████████████████████████████████████████████| 80/80 [00:00<00:00, 84.28it/s]


  Sentiment Classifier Train Loss: 0.1476


Validating Sentiment Classifier: 100%|██████████████████████████████████████████████████| 20/20 [00:00<00:00, 79.73it/s]


  Validation Loss: 0.1768
  Validation Accuracy: 0.9281
  No improvement in validation loss for 5/5 epochs.
Early stopping triggered.
Best Sentiment Classifier model saved!
0.011106026731431485


In [60]:
b=time.time()
# Evaluate the baseline sentiment classifier
baseline_sentiment_classifier.eval()

# Initialize variables for storing predictions and labels
all_preds = []
all_labels = []

with torch.no_grad():  # Disable gradient computation for evaluation
    for sample in tqdm(test_loader_sentiment, desc="Evaluating Test Set"):
        # Extract input_ids, attention_mask, and labels from the batch
        input_ids = sample['input_ids'].to(device)
        attention_mask = sample['attention_mask'].to(device)
        labels = sample['label'].to(device)

        # Get logits from the model
        logits, cos_sim = baseline_sentiment_classifier(input_ids=input_ids, attention_mask=attention_mask)

        # Convert logits to predicted labels
        preds = torch.argmax(logits, dim=1)

        # Append predictions and labels for metric computation
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Compute Metrics
precision = precision_score(all_labels, all_preds, average="binary")  # Use 'binary' for 2-class tasks
recall = recall_score(all_labels, all_preds, average="binary")
accuracy = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds, average="binary")


# Print Metrics
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
e=time.time()

Evaluating Test Set: 100%|██████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 61.03it/s]

Accuracy: 0.9050
Precision: 0.8682
Recall: 0.9550
F1 Score: 0.9095





In [39]:
(e-b)/400

0.0019940930604934693

In [29]:
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Softmax over logits to get probabilities
        inputs = F.softmax(inputs, dim=-1)
        
        # One-hot encoding of targets
        targets = F.one_hot(targets, num_classes=inputs.size(1))
        
        # Cross entropy loss component
        cross_entropy_loss = -targets * torch.log(inputs)
        
        # Focal loss component
        loss = self.alpha * (1 - inputs) ** self.gamma * cross_entropy_loss
        
        if self.reduction == 'mean':
            return loss.sum() / targets.size(0)
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
