In [None]:
!pip uninstall accelerate transformers

In [None]:
# Install accelerate and transformers into the runtime (no versions are pinned here).
!pip install accelerate transformers

In [None]:
# Colab only: mount Google Drive at /content/drive so files stored there can be read by path.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import pandas as pd
import os
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer, AdamW, get_scheduler, BertModel, BertTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.decomposition import PCA
from scipy.spatial.distance import cosine
from torch.utils.data import DataLoader
import time
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, precision_recall_fscore_support, roc_auc_score
import gc

In [None]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("GPU not available, using CPU")

# Data Loading & Preprocessing

In [None]:
# Dataset location. The Kaggle path is kept (commented out) for reference; the Colab path,
# which lives on the mounted Google Drive, is the active one.
#file_path = '/kaggle/input/hospital-comments/Dataset_v6.csv' #Kaggle
file_path ='/content/drive/MyDrive/Reviews_Dataset/Dataset_v6.csv' #Colab

# Load the whole CSV into a DataFrame.
df = pd.read_csv(file_path)

# Preview the first rows as the cell output.
df.head()

In [None]:
# Clean and filter the dataset.
# A row is kept only when every column below is present (non-null) and non-empty.
# Each column is listed once; repeating a name does not change the result.
columns_to_check = [
    'ProcessedValence', 'ProcessedUnit', 'ProcessedType', 'Gender', 'Ethnicity', 'Age', 'Comment', 'Hospital',
    'Employment Status', 'Access to Transportation', 'Income/Poverty Level'
]
# 1) Drop rows with a missing value in any required column.
df_cleaned = df.dropna(subset=columns_to_check)
# 2) Drop rows whose comment is empty or whitespace-only.
df_filtered = df_cleaned[df_cleaned['Comment'].str.strip().astype(bool)]
# 3) Drop rows holding an empty string in any required column.
#    .copy() makes df_final an independent frame: later cells drop and add columns on it,
#    which on a filtered slice triggers pandas' SettingWithCopyWarning.
df_final = df_filtered[~(df_filtered[columns_to_check].eq("").any(axis=1))].copy()

In [None]:
# Drop unnecessary columns
columns_to_drop = ['CommentLength', 'Type', 'CleanedComment', 'ExperienceDate', 'ExperienceDateString', 'Valence',
                   'Unit', 'Code', 'Entities', 'ProcessedComment', 'CovidRelated', 'CovidPeriod', 'Day']
# Rebind to the result instead of dropping in place: df_final was produced by boolean
# filtering, and an in-place drop on such a slice raises SettingWithCopyWarning. The
# rebinding also gives the following cells a fresh frame to add columns to.
df_final = df_final.drop(columns=columns_to_drop)

In [None]:
# Encode categorical columns
# Each column is listed (and therefore encoded) exactly once.
columns_to_encode = [
    'ProcessedUnit', 'ProcessedType', 'Gender', 'Ethnicity', 'Age', 'Hospital',
    'Employment Status', 'Access to Transportation', 'Income/Poverty Level', 'ProcessedValence'
]
# One fitted encoder per column is kept here so integer codes can be mapped back to the
# original categories later (label_encoders[column].inverse_transform / .classes_).
label_encoders = {}
# `label_encoder` stays defined for existing users of that name; after the loop it is the
# encoder fitted on the last encoded column.
label_encoder = LabelEncoder()
for column in tqdm(columns_to_encode, desc="Encoding Columns"):
    if column in df_final.columns:
        label_encoder = LabelEncoder()
        # Output name: "Encoded" + column name with spaces replaced by underscores.
        df_final[f'Encoded{column.replace(" ", "_")}'] = label_encoder.fit_transform(df_final[column])
        label_encoders[column] = label_encoder


In [None]:
# Keep the raw comment text plus the integer-encoded feature and target columns.
columns_to_keep = [
    'Comment',
    'EncodedProcessedUnit', 'EncodedProcessedType',
    'EncodedGender', 'EncodedEthnicity', 'EncodedAge',
    'EncodedHospital',
    'EncodedEmployment_Status', 'EncodedAccess_to_Transportation', 'EncodedIncome/Poverty_Level',
    'EncodedProcessedValence',
]
df_encoded = df_final.loc[:, columns_to_keep]
df_encoded.head()

In [None]:
# Hold out 30 % of the rows, then halve the hold-out into validation and test sets.
# Both splits use the same fixed seed so the partition is reproducible.
train_data, temp_data = train_test_split(df_encoded, random_state=42, test_size=0.3)
val_data, test_data = train_test_split(temp_data, random_state=42, test_size=0.5)

# Custom Dataset Class

In [None]:
class ReviewDataset(torch.utils.data.Dataset):
    """Pre-tokenized review dataset that splits long reviews into fixed-size chunks.

    Every review is tokenized, cut to at most ``max_length`` tokens and split into chunks.
    Each chunk becomes one dataset item that carries the review's label and protected
    attributes, so a long review contributes several items.

    Args:
        reviews: iterable of review texts.
        labels, genders, ethnicities, incomes: per-review values, iterated in step with
            ``reviews``; each must be convertible with ``torch.tensor``.
        tokenizer: Hugging Face style tokenizer (must provide ``tokenize``,
            ``convert_tokens_to_string``, ``num_special_tokens_to_add`` and be callable).
        chunk_size: length of every encoded item, special tokens included.
        max_length: maximum number of tokens kept per review before chunking.
    """

    def __init__(self, reviews, labels, genders, ethnicities, incomes, tokenizer, chunk_size=128, max_length=8192):
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.max_length = max_length
        self.data = self.tokenize_data(reviews, labels, genders, ethnicities, incomes)

    def tokenize_data(self, reviews, labels, genders, ethnicities, incomes):
        """Tokenize and chunk every review; return a list with one dict per chunk."""
        # Reserve room for the special tokens the tokenizer adds. Encoding a chunk of
        # exactly chunk_size tokens with max_length=chunk_size would otherwise truncate
        # the chunk's last tokens and silently lose them.
        content_size = max(1, self.chunk_size - self.tokenizer.num_special_tokens_to_add(pair=False))

        tokenized_data = []
        rows = zip(reviews, labels, genders, ethnicities, incomes)
        for review_id, (review, label, gender, ethnicity, income) in enumerate(rows):
            tokens = self.tokenizer.tokenize(review)[:self.max_length]
            token_chunks = [tokens[i:i + content_size] for i in range(0, len(tokens), content_size)]
            # NOTE: a review that tokenizes to nothing produces no chunks and therefore no item.

            for chunk in token_chunks:
                # Rebuild real text from the sub-word pieces. A plain ' '.join would leave
                # continuation pieces (e.g. "##ing") in the text as separate words, which
                # are then re-tokenized into different tokens.
                chunk_text = self.tokenizer.convert_tokens_to_string(chunk)
                inputs = self.tokenizer(
                    chunk_text,
                    max_length=self.chunk_size,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                # squeeze(0) removes only the batch axis added by return_tensors='pt'.
                input_ids = inputs['input_ids'].squeeze(0)
                attention_mask = inputs['attention_mask'].squeeze(0)
                if 'token_type_ids' in inputs:
                    token_type_ids = inputs['token_type_ids'].squeeze(0)
                else:
                    token_type_ids = torch.zeros_like(input_ids)

                tokenized_data.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "token_type_ids": token_type_ids,
                    "label": label,
                    "gender": gender,
                    "ethnicity": ethnicity,
                    "income": income,
                    # True for every chunk of a review that needed more than one chunk.
                    "long": len(token_chunks) > 1,
                    # Number of tokens (not words) in this chunk before special tokens.
                    "numberOfWords": len(chunk),
                    # Position of the source review in the input; lets consumers regroup chunks.
                    "review_id": review_id
                })

        return tokenized_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return chunk ``idx`` as a dict of tensors."""
        item = self.data[idx]
        return {
            "input_ids": item["input_ids"],
            "attention_mask": item["attention_mask"],
            "token_type_ids": item["token_type_ids"],
            "labels": torch.tensor(item["label"]),
            "long": torch.tensor(item["long"], dtype=torch.bool),
            "numberOfWords": torch.tensor(item["numberOfWords"]),
            "encoded_gender": torch.tensor(item["gender"]),
            "encoded_ethnicity": torch.tensor(item["ethnicity"]),
            "encoded_income": torch.tensor(item["income"]),
            "review_id": torch.tensor(item["review_id"])
        }

# Fairness Loss

In [None]:
def compute_fairness_loss(logits, labels, genders, ethnicities, incomes):
    """Differentiable equalized-odds penalty averaged over three protected attributes.

    For each attribute, each class and each pair of groups present in the batch, the
    absolute gap between the groups' true-positive rates and between their false-positive
    rates is computed. Gaps are averaged over classes and pairs, then over the three
    attributes, and the TPR and FPR averages are finally averaged together.

    Rates use the softmax probability of a class instead of a hard ``argmax`` decision:
    an ``argmax`` has no gradient, so a penalty built on it cannot influence training.
    When the predictions are one-hot both formulations give the same value.

    Args:
        logits: (batch, num_labels) classifier outputs; the class count is read from it.
        labels: true class indices, one per sample.
        genders, ethnicities, incomes: group ids, one per sample. Any shape holding
            ``batch`` elements is accepted (e.g. (batch,) or (batch, 1)).

    Returns:
        Scalar tensor on ``logits.device``. An attribute with fewer than two groups in
        the batch contributes 0 (there is no pair to compare).
    """
    probs = torch.softmax(logits, dim=1)
    target_device = probs.device
    num_labels = probs.shape[1]

    labels = labels.reshape(-1).to(target_device)
    class_ids = torch.arange(num_labels, device=target_device)
    # (batch, num_labels): 1 where the sample's true class is that column, else 0.
    positive_weight = (labels.unsqueeze(1) == class_ids.unsqueeze(0)).to(probs.dtype)
    negative_weight = 1.0 - positive_weight

    def calculate_group_disparity(group_ids):
        # Flatten first: comparing (batch, 1) ids against (batch,) masks would broadcast
        # to a batch x batch matrix and give every group the same rate.
        group_ids = group_ids.reshape(-1).to(target_device)
        unique_groups = torch.unique(group_ids)
        num_groups = unique_groups.numel()
        if num_groups < 2:
            # No pair of groups exists; dividing by the pair count (0) would yield NaN.
            zero = probs.new_zeros(())
            return zero, zero

        # (groups, batch) membership matrix.
        member = (group_ids.unsqueeze(0) == unique_groups.unsqueeze(1)).to(probs.dtype)
        positives = member @ positive_weight          # (groups, num_labels) sample counts
        negatives = member @ negative_weight
        # clamp(min=1) only changes empty denominators; their numerators are 0, so the
        # rate is 0 for a group without positives (or negatives) of that class.
        tpr = (member @ (probs * positive_weight)) / positives.clamp(min=1)
        fpr = (member @ (probs * negative_weight)) / negatives.clamp(min=1)

        # Every unordered pair of groups (i < j).
        pairs = torch.triu_indices(num_groups, num_groups, offset=1, device=target_device)
        avg_tpr_diff = (tpr[pairs[0]] - tpr[pairs[1]]).abs().mean()
        avg_fpr_diff = (fpr[pairs[0]] - fpr[pairs[1]]).abs().mean()
        return avg_tpr_diff, avg_fpr_diff

    avg_tpr_diff_gender, avg_fpr_diff_gender = calculate_group_disparity(genders)
    avg_tpr_diff_ethnicity, avg_fpr_diff_ethnicity = calculate_group_disparity(ethnicities)
    avg_tpr_diff_income, avg_fpr_diff_income = calculate_group_disparity(incomes)

    avg_tpr_diff = (avg_tpr_diff_gender + avg_tpr_diff_ethnicity + avg_tpr_diff_income) / 3
    avg_fpr_diff = (avg_fpr_diff_gender + avg_fpr_diff_ethnicity + avg_fpr_diff_income) / 3

    fairness_loss = (avg_tpr_diff + avg_fpr_diff) / 2
    return fairness_loss

# Custom Trainer Class and Functions

In [None]:
def custom_collate(batch):
    """Stack a list of per-sample tensor dicts into one dict of batched tensors.

    The field names are taken from the first sample; every sample must provide a
    tensor of matching shape for each of them.
    """
    field_names = batch[0].keys()
    return {name: torch.stack([sample[name] for sample in batch]) for name in field_names}

In [None]:
class GradientReversalLayer(torch.autograd.Function):
    """Autograd function that passes values through unchanged and negates the gradient."""

    @staticmethod
    def forward(ctx, tensor):
        # Forward pass: hand back a view of the input with the same shape.
        return tensor.view_as(tensor)

    @staticmethod
    def backward(ctx, upstream_grad):
        # Backward pass: flip the sign of the incoming gradient.
        return -upstream_grad

In [None]:
class AdversarialDebiasingModel(nn.Module):
    """Encoder with a task head and an adversary head fed through gradient reversal.

    Args:
        bert_model: encoder whose output at index 1 is used as the pooled representation
            and whose ``config.hidden_size`` gives that representation's width.
        num_labels: output size of the task classifier.
        num_protected_attributes: output size of the adversary classifier.
    """

    def __init__(self, bert_model, num_labels, num_protected_attributes):
        super().__init__()
        hidden_size = bert_model.config.hidden_size
        self.bert = bert_model
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.adv_classifier = nn.Linear(hidden_size, num_protected_attributes)
        self.grl = GradientReversalLayer.apply

    def forward(self, input_ids, attention_mask, token_type_ids=None, use_grl=True, **kwargs):
        """Return ``(logits, adv_logits)`` when ``use_grl`` is true, otherwise ``logits`` only.

        Extra keyword arguments are accepted and ignored.
        """
        encoder_outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = encoder_outputs[1]
        logits = self.classifier(pooled_output)
        if not use_grl:
            return logits
        # The adversary sees the same features, but its gradient reaches the encoder negated.
        reversed_features = self.grl(pooled_output)
        return logits, self.adv_classifier(reversed_features)
In [None]:
class CustomTrainer:
    def __init__(self, model, train_dataset, eval_dataset, test_dataset, optimizer, criterion, tokenizer, sensitive_features, save_dir="model_checkpoints",
                 fairness_weight=0.1, adv_weight=0.1, use_grl=True):
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.criterion = criterion
        self.sensitive_features = sensitive_features
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.train_loader = self.get_train_dataloader()
        self.val_loader = self.get_eval_dataloader()
        self.test_dataset = test_dataset
        self.test_loader = self.get_test_dataloader()
        self.save_dir = save_dir
        self.history = {
            'train_loss': [], 'val_loss': [], 'accuracy': [],
            'f1': [], 'precision': [], 'recall': [], 'auc': [],
            'equalized_odds_genders': [], 'equalized_odds_ethnicity': [],
            'equalized_odds_age': [], 'weat_score': []
        }
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        print(f"Using device: {self.device}")
        self.fairness_weight = fairness_weight
        self.adv_weight = adv_weight
        self.use_grl = use_grl

        # Disable adversarial layers if adv_weight is 0
        if self.adv_weight == 0:
            for param in self.model.adv_classifier.parameters():
                param.requires_grad = False

        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

    def get_train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=64,
            collate_fn=custom_collate,
            shuffle=True,
        )

    def get_eval_dataloader(self, eval_dataset=None):
        eval_dataset = self.eval_dataset
        return DataLoader(
            eval_dataset,
            batch_size=32,
            collate_fn=custom_collate,
        )

    def get_test_dataloader(self, test_dataset=None):
        test_dataset = self.test_dataset
        return DataLoader(
            test_dataset,
            batch_size=32,
            collate_fn=custom_collate,
        )


    def calculate_losses(self, logits, adv_logits, batch):
        labels = batch['labels'].to(self.device)
        long_flags = batch["long"].to(self.device)
        num_words = batch["numberOfWords"].to(self.device)
        genders = batch["encoded_gender"].unsqueeze(1).float().to(self.device)
        ethnicities = batch["encoded_ethnicity"].unsqueeze(1).float().to(self.device)
        incomes = batch["encoded_income"].unsqueeze(1).float().to(self.device)

        # Main task loss
        loss_task = self.criterion(logits, labels)
        loss_adv = 0
        if self.adv_weight > 0:
            protected_labels = torch.cat((genders, ethnicities, incomes), dim=1)
            loss_adv = self.criterion(adv_logits, protected_labels)
        # Fairness loss
        fairness_loss = 0
        if self.fairness_weight > 0:
            fairness_loss = self.compute_fairness_loss(logits, labels, genders, ethnicities, incomes)
        # Adjust loss for long reviews
        adjusted_loss = self.adjust_loss_for_long_reviews(loss_task, long_flags, num_words)
        # Combine losses
        total_loss = adjusted_loss + self.fairness_weight * fairness_loss - self.adv_weight * loss_adv
        return total_loss

    def compute_fairness_loss(self, logits, labels, genders, ethnicities, incomes):
        probs = torch.nn.Softmax(dim=1)(logits)
        predicted_labels = torch.argmax(probs, dim=1)

        def calculate_group_disparity(labels, predicted_labels, group_ids):
            unique_groups = torch.unique(group_ids)
            tpr_diff_sum = torch.tensor(0.0, device=device)
            fpr_diff_sum = torch.tensor(0.0, device=device)
            num_labels = 4

            for label in range(num_labels):
                for i in range(len(unique_groups)):
                    for j in range(i + 1, len(unique_groups)):
                        group_i_mask = (group_ids == unique_groups[i]).to(device)
                        group_j_mask = (group_ids == unique_groups[j]).to(device)

                        true_positives_group_i = torch.sum((predicted_labels == label) & (labels == label) & group_i_mask)
                        true_positives_group_j = torch.sum((predicted_labels == label) & (labels == label) & group_j_mask)
                        false_positives_group_i = torch.sum((predicted_labels == label) & (labels != label) & group_i_mask)
                        false_positives_group_j = torch.sum((predicted_labels == label) & (labels != label) & group_j_mask)
                        positives_group_i = torch.sum((labels == label) & group_i_mask)
                        positives_group_j = torch.sum((labels == label) & group_j_mask)
                        negatives_group_i = torch.sum((labels != label) & group_i_mask)
                        negatives_group_j = torch.sum((labels != label) & group_j_mask)

                        tpr_group_i = true_positives_group_i.float() / positives_group_i.float() if positives_group_i != 0 else torch.tensor(0.0, device=device)
                        tpr_group_j = true_positives_group_j.float() / positives_group_j.float() if positives_group_j != 0 else torch.tensor(0.0, device=device)
                        fpr_group_i = false_positives_group_i.float() / negatives_group_i.float() if negatives_group_i != 0 else torch.tensor(0.0, device=device)
                        fpr_group_j = false_positives_group_j.float() / negatives_group_j.float() if negatives_group_j != 0 else torch.tensor(0.0, device=device)

                        tpr_diff = torch.abs(tpr_group_i - tpr_group_j)
                        fpr_diff = torch.abs(fpr_group_i - fpr_group_j)
                        tpr_diff_sum += tpr_diff
                        fpr_diff_sum += fpr_diff

            avg_tpr_diff = tpr_diff_sum / (num_labels * (len(unique_groups) * (len(unique_groups) - 1) / 2))
            avg_fpr_diff = fpr_diff_sum / (num_labels * (len(unique_groups) * (len(unique_groups) - 1) / 2))
            return avg_tpr_diff, avg_fpr_diff

        avg_tpr_diff_gender, avg_fpr_diff_gender = calculate_group_disparity(labels, predicted_labels, genders)
        avg_tpr_diff_ethnicity, avg_fpr_diff_ethnicity = calculate_group_disparity(labels, predicted_labels, ethnicities)
        avg_tpr_diff_income, avg_fpr_diff_income = calculate_group_disparity(labels, predicted_labels, incomes)

        avg_tpr_diff = (avg_tpr_diff_gender + avg_tpr_diff_ethnicity + avg_tpr_diff_income) / 3
        avg_fpr_diff = (avg_fpr_diff_gender + avg_fpr_diff_ethnicity + avg_fpr_diff_income) / 3

        fairness_loss = (avg_tpr_diff + avg_fpr_diff) / 2
        return fairness_loss

    def adjust_loss_for_long_reviews(self, loss, long_flags, num_words):
        adjusted_loss = loss / num_words.float()
        adjusted_loss = torch.where(long_flags, adjusted_loss, loss)
        return adjusted_loss.mean()

    def train(self, num_epochs):
        """Train for ``num_epochs`` epochs, evaluating and checkpointing after each one.

        After every epoch the validation metrics, the equalized-odds gaps and the WEAT
        score are computed, appended to ``self.history``, printed, and saved together
        with the model and optimizer state to ``<save_dir>/epoch_<n>.pth``.
        """
        # The adversary head is only run when its loss actually enters the total loss.
        # The same flag is passed to the model AND used to unpack its output, because the
        # model returns a (logits, adv_logits) tuple only when asked for the adversary.
        run_adversary = bool(self.use_grl and self.adv_weight > 0)

        total_start_time = time.time()
        for epoch in range(num_epochs):
            epoch_start_time = time.time()
            self.model.train()
            train_loss = 0

            progress_bar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
            for batch_idx, batch in progress_bar:
                self.optimizer.zero_grad()
                inputs = {k: v.to(self.device) for k, v in batch.items() if k in ['input_ids', 'attention_mask', 'token_type_ids']}
                outputs = self.model(**inputs, use_grl=run_adversary)
                if run_adversary:
                    logits, adv_logits = outputs
                else:
                    logits, adv_logits = outputs, None
                total_loss = self.calculate_losses(logits, adv_logits, batch)
                total_loss.backward()
                self.optimizer.step()
                batch_loss = total_loss.item()
                train_loss += batch_loss

                progress_bar.set_postfix({
                    'loss': f'{batch_loss:.4f}',
                    'progress': f'{(batch_idx+1)/len(self.train_loader)*100:.2f}%'
                })
                # No per-batch gc.collect()/empty_cache(): the batch tensors are released
                # when the names are rebound, and forcing a full collection on every step
                # only slows training down. Memory is still cleaned up once per epoch below.

            epoch_end_time = time.time()
            epoch_runtime = epoch_end_time - epoch_start_time
            train_loss /= len(self.train_loader)

            val_loss, accuracy, f1, precision, recall, auc = self.evaluate()
            # The third value is stored under the historical key 'equalized_odds_age' but
            # is reported as the income gap (see the print below).
            equalized_odds_genders, equalized_odds_ethnicity, equalized_odds_age = self.calculate_equalized_odds()
            weat_score = self.calculate_weat()

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['accuracy'].append(accuracy)
            self.history['f1'].append(f1)
            self.history['precision'].append(precision)
            self.history['recall'].append(recall)
            self.history['auc'].append(auc)
            self.history['equalized_odds_genders'].append(equalized_odds_genders)
            self.history['equalized_odds_ethnicity'].append(equalized_odds_ethnicity)
            self.history['equalized_odds_age'].append(equalized_odds_age)
            self.history['weat_score'].append(weat_score)

            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Training Loss: {train_loss:.4f}")
            print(f"Validation Loss: {val_loss:.4f}")
            print(f"Accuracy: {accuracy:.4f}")
            print(f"F1 Score: {f1:.4f}")
            print(f"Precision: {precision:.4f}")
            print(f"Recall: {recall:.4f}")
            print(f"AUC: {auc:.4f}")
            print(f"Equalized Odds (Genders): {equalized_odds_genders:.4f}")
            print(f"Equalized Odds (Ethnicity): {equalized_odds_ethnicity:.4f}")
            print(f"Equalized Odds (Income): {equalized_odds_age:.4f}")
            print(f"WEAT Score: {weat_score:.4f}")
            print(f"Epoch duration: {epoch_runtime:.2f} seconds")

            # Save model checkpoint
            checkpoint_path = os.path.join(self.save_dir, f"epoch_{epoch+1}.pth")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'accuracy': accuracy,
                'f1': f1,
                'precision': precision,
                'recall': recall,
                'auc': auc,
                'equalized_odds_genders': equalized_odds_genders,
                'equalized_odds_ethnicity': equalized_odds_ethnicity,
                'equalized_odds_age': equalized_odds_age,
                'weat_score': weat_score
            }, checkpoint_path)

            # Per-epoch cleanup of cached GPU blocks and unreachable Python objects.
            torch.cuda.empty_cache()
            gc.collect()

        total_end_time = time.time()
        total_duration = total_end_time - total_start_time
        print(f"\nTotal training time: {total_duration:.2f} seconds")

    def plot_metrics(self):
        epochs = range(1, len(self.history['train_loss']) + 1)

        plt.figure(figsize=(15, 10))

        plt.subplot(2, 2, 1)
        plt.plot(epochs, self.history['train_loss'], label='Train Loss')
        plt.plot(epochs, self.history['val_loss'], label='Validation Loss')
        plt.title('Training and Validation Loss')
        plt.legend()

        plt.subplot(2, 2, 2)
        plt.plot(epochs, self.history['accuracy'], label='Accuracy')
        plt.plot(epochs, self.history['f1'], label='F1 Score')
        plt.title('Accuracy and F1 Score')
        plt.legend()

        plt.subplot(2, 2, 3)
        plt.plot(epochs, self.history['precision'], label='Precision')
        plt.plot(epochs, self.history['recall'], label='Recall')
        plt.plot(epochs, self.history['auc'], label='AUC')
        plt.title('Precision, Recall, and AUC')
        plt.legend()

        plt.subplot(2, 2, 4)
        plt.plot(epochs, self.history['equalized_odds_genders'], label='EO Genders')
        plt.plot(epochs, self.history['equalized_odds_ethnicity'], label='EO Ethnicity')
        plt.plot(epochs, self.history['equalized_odds_age'], label='EO Income')
        plt.plot(epochs, self.history['weat_score'], label='WEAT Score')
        plt.title('Fairness Metrics')
        plt.legend()

        plt.tight_layout()
        plt.show()

    def evaluate_test_set(self):
        self.model.eval()
        test_loss = 0
        all_labels = []
        all_preds = []
        all_probs = []
        all_protected_labels = []
        current_review_preds = []
        current_review_probs = []
        current_review_label = None
        current_review_protected = None

        with torch.no_grad():
            for batch in tqdm(self.test_loader, desc="Evaluating Test Set"):
                inputs = {k: v.to(self.device) for k, v in batch.items() if k in ['input_ids', 'attention_mask', 'token_type_ids']}
                labels = batch['labels'].to(self.device)
                long_flags = batch['long'].to(self.device)
                protected_labels = torch.stack([batch['encoded_gender'], batch['encoded_ethnicity'], batch['encoded_income']], dim=1).to(self.device)

                logits, _ = self.model(**inputs)
                loss = self.criterion(logits, labels)
                test_loss += loss.item()

                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(logits, dim=1)

                for i in range(len(labels)):
                    if not long_flags[i] or (long_flags[i] and not current_review_preds):
                        if current_review_label is not None:
                            all_labels.append(current_review_label)
                            all_preds.append(np.argmax(np.mean(current_review_preds, axis=0)))
                            all_probs.append(np.mean(current_review_probs, axis=0))
                            all_protected_labels.append(current_review_protected)

                        current_review_label = labels[i].cpu().item()
                        current_review_preds = [preds[i].cpu().numpy()]
                        current_review_probs = [probs[i].cpu().numpy()]
                        current_review_protected = protected_labels[i].cpu().numpy()
                    else:
                        current_review_preds.append(preds[i].cpu().numpy())
                        current_review_probs.append(probs[i].cpu().numpy())

        if current_review_label is not None:
            all_labels.append(current_review_label)
            all_preds.append(np.argmax(np.mean(current_review_preds, axis=0)))
            all_probs.append(np.mean(current_review_probs, axis=0))
            all_protected_labels.append(current_review_protected)

        test_loss /= len(self.test_loader)

        all_labels = np.array(all_labels)
        all_preds = np.array(all_preds)
        all_probs = np.array(all_probs)
        all_protected_labels = np.array(all_protected_labels)

        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average='macro')
        precision = precision_score(all_labels, all_preds, average='macro')
        recall = recall_score(all_labels, all_preds, average='macro')
        auc = roc_auc_score(all_labels, all_probs, average='macro', multi_class='ovr')
        equalized_odds_genders, equalized_odds_ethnicity, equalized_odds_income = self.calculate_equalized_odds2(all_labels, all_preds, all_protected_labels)
        weat_score = self.calculate_weat()

        print("\nTest Set Evaluation:")
        print(f"Test Loss: {test_loss:.4f}")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"F1 Score: {f1:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"AUC: {auc:.4f}")
        print(f"Equalized Odds (Genders): {equalized_odds_genders:.4f}")
        print(f"Equalized Odds (Ethnicity): {equalized_odds_ethnicity:.4f}")
        print(f"Equalized Odds (Income): {equalized_odds_income:.4f}")
        print(f"WEAT Score: {weat_score:.4f}")

        return test_loss, accuracy, f1, precision, recall, auc, equalized_odds_genders, equalized_odds_ethnicity, equalized_odds_income, weat_score

    def evaluate(self):
        self.model.eval()
        val_loss = 0
        all_labels = []
        all_preds = []
        all_probs = []
        current_review_preds = []
        current_review_probs = []
        current_review_label = None

        with torch.no_grad():
            for batch in self.val_loader:
                inputs = {k: v.to(self.device) for k, v in batch.items() if k in ['input_ids', 'attention_mask', 'token_type_ids']}
                labels = batch['labels'].to(self.device)
                long_flags = batch['long'].to(self.device)

                logits, _ = self.model(**inputs)
                loss = self.criterion(logits, labels)
                val_loss += loss.item()

                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(logits, dim=1)

                for i in range(len(labels)):
                    if not long_flags[i] or (long_flags[i] and not current_review_preds):
                        if current_review_label is not None:
                            all_labels.append(current_review_label)
                            all_preds.append(np.argmax(np.mean(current_review_preds, axis=0)))
                            all_probs.append(np.mean(current_review_probs, axis=0))

                        current_review_label = labels[i].cpu().item()
                        current_review_preds = [preds[i].cpu().numpy()]
                        current_review_probs = [probs[i].cpu().numpy()]
                    else:
                        current_review_preds.append(preds[i].cpu().numpy())
                        current_review_probs.append(probs[i].cpu().numpy())

        if current_review_label is not None:
            all_labels.append(current_review_label)
            all_preds.append(np.argmax(np.mean(current_review_preds, axis=0)))
            all_probs.append(np.mean(current_review_probs, axis=0))

        val_loss /= len(self.val_loader)

        all_labels = np.array(all_labels)
        all_preds = np.array(all_preds)
        all_probs = np.array(all_probs)

        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average='macro')
        precision = precision_score(all_labels, all_preds, average='macro')
        recall = recall_score(all_labels, all_preds, average='macro')
        auc = roc_auc_score(all_labels, all_probs, average='macro', multi_class='ovr')

        return val_loss, accuracy, f1, precision, recall, auc

    def calculate_equalized_odds(self):
        self.model.eval()
        all_labels = []
        all_preds = []
        all_protected_labels = []
        current_review_preds = []
        current_review_label = None
        current_review_protected = None

        with torch.no_grad():
            for batch in self.val_loader:
                inputs = {k: v.to(self.device) for k, v in batch.items() if k in ['input_ids', 'attention_mask', 'token_type_ids']}
                labels = batch['labels'].to(self.device)
                long_flags = batch['long'].to(self.device)
                protected_labels = torch.stack([batch['encoded_gender'], batch['encoded_ethnicity'], batch['encoded_income']], dim=1).to(self.device)

                logits, _ = self.model(**inputs)
                preds = torch.argmax(logits, dim=1)

                for i in range(len(labels)):
                    if not long_flags[i] or (long_flags[i] and not current_review_preds):
                        if current_review_label is not None:
                            all_labels.append(current_review_label)
                            all_preds.append(np.argmax(np.mean(current_review_preds, axis=0)))
                            all_protected_labels.append(current_review_protected)

                        current_review_label = labels[i].cpu().item()
                        current_review_preds = [preds[i].cpu().numpy()]
                        current_review_protected = protected_labels[i].cpu().numpy()
                    else:
                        current_review_preds.append(preds[i].cpu().numpy())

            if current_review_label is not None:
                all_labels.append(current_review_label)
                all_preds.append(np.argmax(np.mean(current_review_preds, axis=0)))
                all_protected_labels.append(current_review_protected)

        all_labels = np.array(all_labels)
        all_preds = np.array(all_preds)
        all_protected_labels = np.array(all_protected_labels)

        results = []
        for i in range(all_protected_labels.shape[1]):
            tpr_diffs = []
            fpr_diffs = []
            unique_groups = np.unique(all_protected_labels[:, i])
            for label in range(4):  # Assuming 4 sentiment labels
                for group_a in unique_groups:
                    for group_b in unique_groups:
                        if group_a >= group_b:
                            continue
                        group_a_indices = all_protected_labels[:, i] == group_a
                        group_b_indices = all_protected_labels[:, i] == group_b

                        tpr_a = np.sum((all_preds[group_a_indices] == label) & (all_labels[group_a_indices] == label)) / np.sum(all_labels[group_a_indices] == label)
                        tpr_b = np.sum((all_preds[group_b_indices] == label) & (all_labels[group_b_indices] == label)) / np.sum(all_labels[group_b_indices] == label)
                        fpr_a = np.sum((all_preds[group_a_indices] == label) & (all_labels[group_a_indices] != label)) / np.sum(all_labels[group_a_indices] != label)
                        fpr_b = np.sum((all_preds[group_b_indices] == label) & (all_labels[group_b_indices] != label)) / np.sum(all_labels[group_b_indices] != label)

                        tpr_diffs.append(abs(tpr_a - tpr_b))
                        fpr_diffs.append(abs(fpr_a - fpr_b))

            results.append((np.mean(tpr_diffs) + np.mean(fpr_diffs)) / 2)

        return results[0], results[1], results[2]  # Assuming the order is gender, ethnicity, income


    def calculate_equalized_odds2(self, all_labels, all_preds, all_protected_labels, num_labels=4):
        """Average pairwise equalized-odds gap for the first three protected attributes.

        For every protected attribute (one column of ``all_protected_labels``), every
        sentiment label in ``range(num_labels)`` and every unordered pair of groups,
        the absolute difference in true-positive rate and in false-positive rate
        (one-vs-rest for that label) is collected. The score of an attribute is
        ``(mean(TPR gaps) + mean(FPR gaps)) / 2``.

        A rate is undefined when a group has no positives (TPR) or no negatives (FPR)
        for a label. Such pairs are skipped instead of producing ``0 / 0 = nan``,
        which previously turned the whole score into ``nan``.

        Args:
            all_labels: 1-D array-like of true class ids.
            all_preds: 1-D array-like of predicted class ids, aligned with ``all_labels``.
            all_protected_labels: 2-D array-like, one row per sample and one column
                per protected attribute (at least three columns are required).
            num_labels: number of sentiment classes to evaluate (default 4).

        Returns:
            Tuple with the scores of columns 0, 1 and 2. A score is ``nan`` when no
            pair of groups has a defined TPR gap or a defined FPR gap.
        """
        all_labels = np.asarray(all_labels)
        all_preds = np.asarray(all_preds)
        all_protected_labels = np.asarray(all_protected_labels)

        def rate(hits, pool):
            # Share of the samples in `pool` that are also `hits`; None for an empty pool.
            pool_size = np.sum(pool)
            if pool_size == 0:
                return None
            return np.sum(hits & pool) / pool_size

        def pairwise_gaps(rates):
            # |r_a - r_b| for every unordered pair of groups whose rates are both defined.
            return [
                abs(rates[a] - rates[b])
                for a in range(len(rates))
                for b in range(a + 1, len(rates))
                if rates[a] is not None and rates[b] is not None
            ]

        results = []
        for i in range(all_protected_labels.shape[1]):
            attribute = all_protected_labels[:, i]
            # Group masks do not depend on the label, so build them once per attribute.
            group_masks = [attribute == group for group in np.unique(attribute)]
            tpr_diffs = []
            fpr_diffs = []
            for label in range(num_labels):
                predicted = all_preds == label
                actual = all_labels == label
                tpr_diffs.extend(pairwise_gaps([rate(predicted, actual & mask) for mask in group_masks]))
                fpr_diffs.extend(pairwise_gaps([rate(predicted, ~actual & mask) for mask in group_masks]))

            if tpr_diffs and fpr_diffs:
                results.append((np.mean(tpr_diffs) + np.mean(fpr_diffs)) / 2)
            else:
                # Nothing comparable (e.g. a single group): report "undefined" explicitly.
                results.append(float('nan'))

        return results[0], results[1], results[2]


    def calculate_weat(self):
        """WEAT-style gender association score from the model's input word embeddings.

        For each target set (male terms, female terms) the mean cosine similarity to
        the healthcare-role attributes minus the mean cosine similarity to the
        other-profession attributes is computed; the return value is the male
        differential minus the female differential.

        Only words that exist as a single token in the tokenizer vocabulary are used.
        Words mapped to the tokenizer's unknown-token id are dropped; otherwise every
        out-of-vocabulary word would contribute the same ``[UNK]`` embedding.

        Returns:
            The association difference as a float; ``nan`` if one of the word sets
            has no in-vocabulary entry.
        """
        # Define target and attribute words
        target_words_1 = ['man', 'male', 'boy', 'father', 'son', 'brother', 'uncle', 'husband', 'gentleman', 'sir']
        target_words_2 = ['woman', 'female', 'girl', 'mother', 'daughter', 'sister', 'aunt', 'wife', 'lady', 'madam']
        attribute_words_1 = ['doctor', 'nurse', 'therapist', 'surgeon', 'physician', 'specialist', 'practitioner', 'clinician', 'technician', 'aide']
        attribute_words_2 = ['teacher', 'engineer', 'scientist', 'librarian', 'manager', 'administrator', 'director', 'supervisor', 'assistant', 'worker']

        # Pull the embedding table to NumPy once instead of once per helper call.
        embedding_layer = self.model.bert.embeddings.word_embeddings
        vocab_size = embedding_layer.num_embeddings
        embedding_matrix = embedding_layer.weight.detach().cpu().numpy()
        unk_id = getattr(self.tokenizer, 'unk_token_id', None)

        def in_vocab_ids(words):
            # Token ids of the words that map to a real (non-unknown, in-range) vocabulary entry.
            candidate_ids = (self.tokenizer.convert_tokens_to_ids(word) for word in words)
            return [
                idx for idx in candidate_ids
                if idx is not None and idx != unk_id and 0 <= idx < vocab_size
            ]

        def mean_cosine_similarity(words, attributes):
            # Mean cosine similarity over all (word, attribute) pairs that are in the vocabulary.
            word_indices = in_vocab_ids(words)
            attr_indices = in_vocab_ids(attributes)
            similarities = [
                1 - cosine(embedding_matrix[word_idx], embedding_matrix[attr_idx])
                for word_idx in word_indices
                for attr_idx in attr_indices
            ]
            if not similarities:
                return float('nan')
            return np.mean(similarities)

        s1 = mean_cosine_similarity(target_words_1, attribute_words_1) - mean_cosine_similarity(target_words_1, attribute_words_2)
        s2 = mean_cosine_similarity(target_words_2, attribute_words_1) - mean_cosine_similarity(target_words_2, attribute_words_2)
        return s1 - s2



In [None]:
class CustomTrainer(Trainer):
    """``transformers.Trainer`` subclass for the review-classification batches.

    Differences from the stock trainer, as implemented below:

    * train / eval batches are assembled with ``custom_collate``;
    * the loss is ``loss_fct(logits, labels)`` divided by the total word count of
      the batch, plus an optional weighted term from ``compute_fairness_loss``;
    * ``train`` saves the model to ``args.output_dir`` once training has finished
      and runs one extra evaluation when the evaluation strategy is ``"epoch"``.
    """

    def __init__(self, *args, loss_fct=None, fairness_weight=0, **kwargs):
        """Create the trainer.

        Args:
            *args, **kwargs: forwarded unchanged to ``Trainer.__init__``.
            loss_fct: callable ``(logits, labels) -> loss``; defaults to
                ``nn.CrossEntropyLoss()``.
            fairness_weight: multiplier of the fairness term in ``compute_loss``.
                The default 0 disables the term (the value that used to be hard-coded).
        """
        super().__init__(*args, **kwargs)
        self.loss_fct = loss_fct if loss_fct is not None else nn.CrossEntropyLoss()
        self.fairness_weight = fairness_weight

    def get_train_dataloader(self):
        """Return a shuffled ``DataLoader`` over ``self.train_dataset`` using ``custom_collate``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            collate_fn=custom_collate,
            shuffle=True,
            num_workers=self.args.dataloader_num_workers,
        )

    def get_eval_dataloader(self, eval_dataset=None):
        """Return an unshuffled ``DataLoader`` over ``eval_dataset`` (default: ``self.eval_dataset``)."""
        # `is None` rather than `or`: an explicitly passed dataset must not be
        # replaced just because it evaluates as falsy (e.g. it is empty).
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        return DataLoader(
            eval_dataset,
            batch_size=self.args.eval_batch_size,
            collate_fn=custom_collate,
            num_workers=self.args.dataloader_num_workers,
        )

    def train(self, *args, **kwargs):
        """Run ``Trainer.train``, then save the model and optionally evaluate once.

        Returns:
            Whatever ``Trainer.train`` returns (previously the result was discarded).
        """
        train_output = super().train(*args, **kwargs)
        self.save_model(self.args.output_dir)  # Save the final model once training is done
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear the GPU cache

        # `eval_strategy` is the current TrainingArguments field; `evaluation_strategy`
        # is the deprecated spelling and is only consulted as a fallback.
        strategy = getattr(self.args, "eval_strategy", None)
        if strategy is None:
            strategy = getattr(self.args, "evaluation_strategy", None)
        if strategy == "epoch":
            self.evaluate()
        return train_output

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the word-count-normalised classification loss for one batch.

        Args:
            model: model called as ``model(**inputs)``; its output must expose ``.logits``.
            inputs: batch dict. The entries ``label``/``labels``, ``long``,
                ``numberOfWords``, ``encoded_gender``, ``encoded_ethnicity`` and
                ``encoded_income`` are popped (the dict is mutated); everything that
                remains is passed to the model.
            return_outputs: when true, return ``(loss, outputs)`` instead of the loss.
            num_items_in_batch: accepted because recent ``Trainer`` versions pass it;
                not used here.

        Returns:
            ``loss_fct(logits, labels) / numberOfWords.sum()`` plus
            ``fairness_weight * compute_fairness_loss(...)`` when the weight is non-zero.
        """
        # Accept both spellings of the label key; the hand-written evaluation loop in
        # this notebook reads `labels`, while this method originally popped `label`.
        labels = inputs.pop("labels") if "labels" in inputs else inputs.pop("label")
        inputs.pop("long")  # bookkeeping flag, not a model input
        num_words = inputs.pop("numberOfWords")
        genders = inputs.pop("encoded_gender")
        ethnicities = inputs.pop("encoded_ethnicity")
        incomes = inputs.pop("encoded_income")

        outputs = model(**inputs)
        logits = outputs.logits
        loss = self.loss_fct(logits, labels)
        total_loss = loss / num_words.sum()

        # Only pay for the fairness term when it can influence the result; this also
        # keeps a non-finite fairness value from poisoning the loss through `0 * nan`.
        if self.fairness_weight:
            fairness_loss = compute_fairness_loss(inputs, logits, labels, genders, ethnicities, incomes)
            total_loss = total_loss + self.fairness_weight * fairness_loss

        return (total_loss, outputs) if return_outputs else total_loss

# Training Config

In [None]:
def get_optimizer(model, optimizer_name, learning_rate=0.1, weight_decay=0.05):
    """Build an optimizer over all parameters of ``model``.

    Args:
        model: module whose ``parameters()`` are optimised.
        optimizer_name: ``"adamw"`` or ``"sgd"``.
        learning_rate: step size. The default of 0.1 is kept for compatibility;
            callers fine-tuning a pre-trained model should pass a value explicitly.
        weight_decay: weight-decay coefficient, applied to both optimizers
            (it used to be silently ignored for ``"sgd"``).

    Returns:
        A ``torch.optim.AdamW`` or ``torch.optim.SGD`` instance.

    Raises:
        ValueError: if ``optimizer_name`` is not supported.
    """
    if optimizer_name == "adamw":
        # torch's AdamW replaces the deprecated `transformers.AdamW`.
        return torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    elif optimizer_name == "sgd":
        return torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer_name}")

In [None]:
def get_loss_function(loss_name):
    """Return a new loss module for ``loss_name``.

    Supported names are ``"cross_entropy"`` (``nn.CrossEntropyLoss``) and
    ``"mse"`` (``nn.MSELoss``); anything else raises ``ValueError``.
    """
    supported = (
        ("cross_entropy", nn.CrossEntropyLoss),
        ("mse", nn.MSELoss),
    )
    for known_name, loss_cls in supported:
        if loss_name == known_name:
            # Instantiate on demand so every call hands back a fresh module.
            return loss_cls()
    raise ValueError(f"Unsupported loss function: {loss_name}")

In [None]:
def compute_metrics(eval_pred):
    """Turn an ``(logits, labels)`` evaluation pair into summary metrics.

    The predicted class is the arg-max over the last axis of the logits.
    Precision, recall and F1 use ``average='weighted'``; accuracy is the plain
    fraction of matching predictions.

    Returns:
        dict with the keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    raw_scores, gold = eval_pred
    predicted = torch.tensor(raw_scores).argmax(dim=-1).numpy()

    weighted_precision, weighted_recall, weighted_f1, _support = precision_recall_fscore_support(
        gold, predicted, average='weighted'
    )
    metrics = dict(
        accuracy=accuracy_score(gold, predicted),
        f1=weighted_f1,
        precision=weighted_precision,
        recall=weighted_recall,
    )
    return metrics

In [None]:
# Configuration Parameters
# Everything the later cells read by name is defined here so the notebook survives
# Restart & Run All: `learning_rate` feeds TrainingArguments and get_optimizer,
# `optimizer_name` feeds get_optimizer.
base_model_name = 'bert-base-uncased'  # 'roberta-base' was the alternative tried earlier
optimizer_name = 'adamw'
# Same step size as the AdamW optimizer built for the adversarial model in this notebook.
# NOTE(review): the previously commented-out value was 5e-3 — confirm the intended rate.
learning_rate = 1e-5

# Bare BERT encoder handed to AdversarialDebiasingModel below.
bert_model = BertModel.from_pretrained(base_model_name)


In [None]:
# Tokenizer matching `base_model_name`; it is passed to every ReviewDataset and to the trainer below.
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

In [None]:
# Each split is built from its OWN frame so comments, sentiment labels and the three
# encoded protected attributes stay row-aligned. (Validation and test used to receive
# the *training* demographics, which mislabels every fairness metric computed on them.)
# `.values` is used throughout so ReviewDataset gets plain arrays rather than a Series
# with a shuffled index — presumably it indexes by position; confirm against the class.
train_dataset = ReviewDataset(train_data['Comment'].values, train_data['EncodedProcessedValence'].values,
                              train_data['EncodedGender'].values, train_data['EncodedEthnicity'].values, train_data['EncodedIncome/Poverty_Level'].values, tokenizer)
val_dataset = ReviewDataset(val_data['Comment'].values, val_data['EncodedProcessedValence'].values,
                            val_data['EncodedGender'].values, val_data['EncodedEthnicity'].values, val_data['EncodedIncome/Poverty_Level'].values, tokenizer)
test_dataset = ReviewDataset(test_data['Comment'].values, test_data['EncodedProcessedValence'].values,
                             test_data['EncodedGender'].values, test_data['EncodedEthnicity'].values, test_data['EncodedIncome/Poverty_Level'].values, tokenizer)

In [None]:
# Adversarial-debiasing classifier built from `bert_model`: 4 output labels, 3 protected attributes.
model = AdversarialDebiasingModel(bert_model, num_labels=4, num_protected_attributes=3)
# AdamW over all model parameters with a fixed learning rate of 1e-5.
optimizer = optim.AdamW(model.parameters(), lr=1e-5)
# Cross-entropy criterion handed to the trainer in the next cell.
criterion = nn.CrossEntropyLoss()

In [None]:
# NOTE(review): these arguments (three datasets, optimizer, criterion, tokenizer, sensitive_features,
# save_dir, adv_weight, use_grl) do not fit the `CustomTrainer(Trainer)` subclass defined in this
# notebook, which forwards its arguments to transformers' Trainer. Presumably this cell targets an
# earlier hand-written trainer with the same name — confirm which definition is live on Run All.
trainer = CustomTrainer(model, train_dataset, val_dataset, test_dataset, optimizer, criterion, tokenizer, sensitive_features=['gender', 'ethnicity', 'income'], save_dir='/content/drive/MyDrive/Reviews_Dataset/model_checkpoints', fairness_weight=0.3, adv_weight=0, use_grl=False)

In [None]:
# Train with the trainer constructed above for 10 epochs.
# NOTE(review): `num_epochs` is not a parameter of transformers' Trainer.train — this call
# presumably belongs to the hand-written trainer; confirm.
trainer.train(num_epochs=10)

In [None]:
# Plot the metrics recorded by the trainer (method defined outside this section —
# presumably training/validation curves; confirm).
trainer.plot_metrics()

In [None]:
# Evaluate the trained model on the test split held by the trainer (method defined outside this section).
trainer.evaluate_test_set()

In [None]:
# Resume training from the epoch-5 checkpoint for 10 epochs.
# NOTE(review): this path is relative to the working directory, whereas the trainer was
# created with an absolute Drive `save_dir` — verify the checkpoint is actually found here.
trainer.train_from_checkpoint("model_checkpoints/epoch_5.pth", num_epochs=10)

In [None]:
# Load the pre-trained encoder with a freshly initialised 4-way classification head.
# Without `num_labels` the head defaults to 2 outputs, which cannot represent the four
# sentiment classes used throughout this notebook and makes the cross-entropy loss fail
# on labels 2 and 3. Passing `num_labels=4` replaces the manual classifier swap that was
# previously sketched here in commented-out code.
model = AutoModelForSequenceClassification.from_pretrained(
    base_model_name,
    num_labels=4,
    ignore_mismatched_sizes=True,
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

In [None]:
# Rebuild the three splits with the tokenizer loaded in the previous cell.
# Each split takes comments, labels and protected attributes from its OWN frame;
# validation and test used to receive the *training* demographics, which misaligns
# rows and invalidates fairness metrics. `.values` passes plain arrays instead of a
# Series with a shuffled index — presumably ReviewDataset indexes by position; confirm.
train_dataset = ReviewDataset(train_data['Comment'].values, train_data['EncodedProcessedValence'].values,
                              train_data['EncodedGender'].values, train_data['EncodedEthnicity'].values, train_data['EncodedIncome/Poverty_Level'].values, tokenizer)
val_dataset = ReviewDataset(val_data['Comment'].values, val_data['EncodedProcessedValence'].values,
                            val_data['EncodedGender'].values, val_data['EncodedEthnicity'].values, val_data['EncodedIncome/Poverty_Level'].values, tokenizer)
test_dataset = ReviewDataset(test_data['Comment'].values, test_data['EncodedProcessedValence'].values,
                             test_data['EncodedGender'].values, test_data['EncodedEthnicity'].values, test_data['EncodedIncome/Poverty_Level'].values, tokenizer)

In [None]:
# Hugging Face TrainingArguments for the fine-tuning run, grouped by concern.
training_args = TrainingArguments(
    # Where checkpoints and logs are written, and how often to log.
    output_dir='/kaggle/working/resultsR',
    logging_dir='/kaggle/working/logsR',
    logging_steps=10,
    # Optimisation schedule.
    num_train_epochs=20,
    warmup_steps=500,
    learning_rate=learning_rate,
    weight_decay=0.01,
    # Batch sizes per device.
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Evaluate and checkpoint once per epoch, keep a single checkpoint on disk
    # and restore the best one when training ends.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
)

In [None]:
# Build the optimizer first, then hand it to the custom trainer.
# The second slot of `optimizers` is the LR scheduler; it is left as None.
fine_tune_optimizer = get_optimizer(model, optimizer_name, learning_rate)
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    optimizers=(fine_tune_optimizer, None),
)

In [None]:
# Placeholder for an API key; nothing in the visible cells reads it.
# NOTE(review): never commit a real key here — load it from an environment variable or a prompt instead.
API_KEY = ""

In [None]:
# Fine-tune the sequence-classification model with the settings in `training_args`
# (CustomTrainer.train also saves the model to `output_dir` afterwards).
trainer.train()

In [None]:
# Evaluate the fine-tuned model on the held-out test split and print the metrics dict
# (keys come from `compute_metrics` plus whatever the Trainer adds).
eval_result = trainer.evaluate(eval_dataset=test_dataset)
print(f"Evaluation Results: {eval_result}")

In [None]:
# NOTE(review): duplicate of the training cell above. Running it calls train() a second time on
# the already fine-tuned model — presumably a leftover from interactive use; confirm and remove.
trainer.train()

In [None]:
# NOTE(review): duplicate of the test-set evaluation cell above; it overwrites `eval_result`
# with the metrics measured after the second train() call — confirm it is still needed.
eval_result = trainer.evaluate(eval_dataset=test_dataset)
print(f"Evaluation Results: {eval_result}")