# BERT MODEL WITH S-LEARNER

In [None]:
import os
import torch
import pandas as pd
import numpy as np
import logging
import warnings
import dowhy
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from dowhy import CausalModel
from econml.dml import DML
from econml.sklearn_extensions.linear_model import StatsModelsLinearRegression
from econml.metalearners import SLearner
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, matthews_corrcoef
from sklearn.model_selection import StratifiedKFold

# PREPROCESSING DATASET

In [None]:
# Read the raw Women's Clothing E-Commerce reviews table.
df = pd.read_csv("Womens Clothing E-Commerce Reviews.csv")

# Discard every review that lacks text, age, rating or the recommendation flag.
# One multi-column dropna (how='any') drops exactly the rows the column-by-column
# calls would.
df = df.dropna(subset=['Review Text', 'Age', 'Rating', 'Recommended IND'])

# Binarise the treatment variable: 1 when the star rating is 4 or higher, else 0.
df['Rating'] = (df['Rating'] >= 4).astype('int64')

In [None]:
# Build a class-balanced sample of `n_per_class` reviews per value of
# 'Recommended IND'. Rows are drawn with replacement only when a class has
# fewer rows than requested: duplicated rows could otherwise end up in both the
# training and the validation split of the same CV fold and inflate the scores.
n_per_class = 1000

recommended_pool = df[df['Recommended IND'] == 1]
recommended_1 = recommended_pool.sample(
    n=n_per_class, replace=len(recommended_pool) < n_per_class, random_state=42
)

not_recommended_pool = df[df['Recommended IND'] == 0]
recommended_0 = not_recommended_pool.sample(
    n=n_per_class, replace=len(not_recommended_pool) < n_per_class, random_state=42
)

# Stack both classes into the DataFrame consumed by load_data().
data_file = pd.concat([recommended_1, recommended_0])

In [None]:
def load_data(data_file):
    """Split the review DataFrame into four parallel Python lists.

    Args:
        data_file: A pandas DataFrame (despite the name, not a path) that has
            the columns 'Review Text', 'Recommended IND', 'Rating' and 'Age'.

    Returns:
        Tuple ``(texts, labels, treatment, confounding)``: review texts, the
        'Recommended IND' outcome labels, the 'Rating' treatment values and the
        'Age' confounder values, each as a list in row order.
    """
    columns = ('Review Text', 'Recommended IND', 'Rating', 'Age')
    texts, labels, treatment, confounding = (data_file[col].tolist() for col in columns)
    return texts, labels, treatment, confounding

In [None]:
# Unpack the balanced sample into review texts, outcome labels,
# treatment indicators and the age confounder.
texts, labels, treatment, confounding = load_data(data_file=data_file)


# CLASS DEFINITIONS FOR BERT

In [None]:
# Training dataset: tokenised reviews paired with labels and per-sample effect weights.
class TextClassificationDatasetT(Dataset):
    """Map-style dataset yielding token ids, attention mask, label and effect weight.

    Args:
        texts: Sequence of raw strings.
        labels: Indexable of integer class labels aligned with ``texts``
            (list, NumPy array or tensor).
        estimated_effects: Indexable of per-sample weights aligned with ``texts``.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt',
            max_length=..., padding='max_length', truncation=True)`` that
            returns 'input_ids' and 'attention_mask' tensors.
        max_length: Length every item is padded or truncated to.
        padding, truncation: Accepted for backward compatibility but unused;
            items are always padded to ``max_length`` and truncated.
    """

    def __init__(self, texts, labels, estimated_effects, tokenizer, max_length, padding=True, truncation=True):
        self.texts = texts
        self.labels = labels
        self.estimated_effects = estimated_effects
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            return_tensors='pt',
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            # as_tensor avoids the copy-construct warning when labels is already a
            # tensor, and int64 is the dtype CrossEntropyLoss needs for class indices.
            'labels': torch.as_tensor(self.labels[idx], dtype=torch.long),
            'effects': torch.as_tensor(self.estimated_effects[idx], dtype=torch.float),
        }


In [None]:
# Validation dataset: tokenised reviews paired with their class labels.
class TextClassificationDataset(Dataset):
    """Map-style dataset yielding fixed-length token ids, attention mask and label.

    Args:
        texts: Sequence of raw strings.
        labels: Indexable of integer class labels aligned with ``texts``
            (list, NumPy array or tensor).
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt',
            max_length=..., padding='max_length', truncation=True)`` that
            returns 'input_ids' and 'attention_mask' tensors.
        max_length: Length every item is padded or truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            return_tensors='pt',
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            # as_tensor avoids the copy-construct warning for tensor labels and
            # pins the int64 dtype CrossEntropyLoss expects.
            'labels': torch.as_tensor(self.labels[idx], dtype=torch.long),
        }

In [None]:
# Pretrained BERT encoder topped with dropout and a linear classification head.
class BERTClassifier(nn.Module):
    """Sequence classifier built on ``BertModel``'s ``pooler_output``.

    Args:
        bert_model_name: Name or path passed to ``BertModel.from_pretrained``.
        num_classes: Number of output logits produced per sequence.
    """

    def __init__(self, bert_model_name, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits of shape (batch, num_classes)."""
        pooled = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        return self.fc(self.dropout(pooled))

# DEFINE BERT VARIABLES

In [None]:
# Checkpoint used for both tokenizer and encoder, and the classifier head size.
bert_model_name = 'bert-base-uncased'
num_classes = 2

# Tokenisation and optimisation hyper-parameters.
max_length = 256  # tokens per review after padding/truncation
batch_size = 16
learning_rate = 2e-5
num_epochs = 1

In [None]:
# Stratified K-fold splitter: each fold keeps the class balance of the labels
# passed to skf.split(), and the shuffle is seeded for reproducibility.
num_folds = 5
skf = StratifiedKFold(n_splits=num_folds, random_state=42, shuffle=True)

In [None]:
# WordPiece tokenizer matching the BERT checkpoint selected above.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=bert_model_name)

In [None]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE(review): the cross-validation loop builds a fresh BERTClassifier per fold,
# so this instance is never trained; it only confirms the checkpoint loads.
model = BERTClassifier(bert_model_name, num_classes).to(device)

In [None]:
# Per-fold metric accumulators, appended to inside the cross-validation loop.
accuracy_scores, classification_reports, mcc_scores = [], [], []

# TRAINING, EVALUATION AND PREDICTION FUNCTIONS

In [None]:
def train(model, data_loader, optimizer, scheduler, device):
    """Run one training epoch with per-sample loss weighting.

    Each batch is a dict with 'input_ids', 'attention_mask' and 'labels'
    tensors, plus an optional 'effects' tensor of per-sample weights (as
    produced by TextClassificationDatasetT). When 'effects' is present, each
    sample's cross-entropy is multiplied by its own weight before averaging.
    Without it, the plain mean cross-entropy is used.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns class logits.
        data_loader: Iterable of batch dicts.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        scheduler: Learning-rate scheduler; stepped once per batch after the optimizer.
        device: Device the batch tensors are moved to.
    """
    model.train()
    # reduction='none' keeps one loss value per sample so each effect weights its
    # own sample. The default 'mean' reduction would first collapse the batch to a
    # scalar, which only scales the batch loss by the mean effect.
    criterion = nn.CrossEntropyLoss(reduction='none')
    for batch in data_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        per_sample_loss = criterion(logits, labels)
        if 'effects' in batch:
            # reshape(-1) guards against (batch, 1) effects broadcasting to (batch, batch).
            # NOTE(review): a negative effect flips the sign of that sample's loss
            # (gradient ascent on it) — confirm this is intended.
            per_sample_loss = per_sample_loss * batch['effects'].to(device).reshape(-1)
        loss = per_sample_loss.mean()
        loss.backward()
        optimizer.step()
        scheduler.step()

In [None]:
# Score a trained model on a labelled DataLoader.
def evaluate(model, data_loader, device):
    """Predict every batch of ``data_loader`` with gradients disabled.

    Predictions are the argmax of the model's logits; targets are the batch
    'labels'.

    Returns:
        Tuple ``(accuracy, report)`` from sklearn's ``accuracy_score`` and
        ``classification_report``.
    """
    model.eval()
    y_pred, y_true = [], []
    with torch.no_grad():
        for batch in data_loader:
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            )
            y_pred += torch.max(logits, dim=1).indices.cpu().tolist()
            y_true += batch['labels'].cpu().tolist()
    return accuracy_score(y_true, y_pred), classification_report(y_true, y_pred)

In [None]:
def predict(text, model, tokenizer, device, max_length=128):
    """Return the predicted class index (an int) for a single input string.

    Args:
        text: Raw input string.
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns class logits.
        tokenizer: Callable invoked with ``return_tensors='pt'``,
            ``padding='max_length'`` and ``truncation=True``.
        device: Device the encoded tensors are moved to.
        max_length: Padding/truncation length. NOTE: the default (128) differs
            from the notebook's ``max_length`` setting (256); pass it explicitly
            so predictions see the same context as training.
    """
    model.eval()
    enc = tokenizer(text, return_tensors='pt', max_length=max_length, padding='max_length', truncation=True)
    with torch.no_grad():
        logits = model(
            input_ids=enc['input_ids'].to(device),
            attention_mask=enc['attention_mask'].to(device),
        )
    return torch.max(logits, dim=1).indices.item()

# RESULTS

In [None]:
# Stratified K-fold cross-validation. For each fold:
#   1. fit an S-learner (RandomForest) of outcome ~ treatment + age on the
#      training split and estimate a per-sample treatment effect;
#   2. fine-tune a fresh BERT classifier using those effects as loss weights;
#   3. score the validation split (MCC, accuracy, classification report).
for fold, (train_index, val_index) in enumerate(skf.split(texts, labels)):
    print(f"Fold {fold + 1}/{num_folds}")

    # Split data into train and validation sets for this fold
    train_texts_fold = [texts[i] for i in train_index]
    val_texts_fold = [texts[i] for i in val_index]
    val_labels_fold = [labels[i] for i in val_index]

    # econml / sklearn take NumPy arrays: outcome Y (n,), treatment T (n,)
    # and the confounder as a 2-D feature matrix X (n, 1).
    train_labels_fold = np.asarray(labels)[train_index]
    train_treatment_fold = np.asarray(treatment)[train_index]
    train_confounding_fold = np.asarray(confounding)[train_index].reshape(-1, 1)

    # The confounder must be given to fit() as well as effect(); otherwise the
    # S-learner never sees age and effect(X) does not match the fitted model.
    est = SLearner(overall_model=RandomForestRegressor(random_state=42))
    est.fit(train_labels_fold, train_treatment_fold, X=train_confounding_fold)
    # Per-sample conditional treatment effects on the training rows; used as loss weights.
    train_slearner_fold = est.effect(X=train_confounding_fold)

    # Prepare datasets and dataloaders for this fold
    train_dataset_fold = TextClassificationDatasetT(train_texts_fold, train_labels_fold, train_slearner_fold, tokenizer, max_length)
    val_dataset_fold = TextClassificationDataset(val_texts_fold, val_labels_fold, tokenizer, max_length)
    train_dataloader_fold = DataLoader(train_dataset_fold, batch_size=batch_size, shuffle=True)
    val_dataloader_fold = DataLoader(val_dataset_fold, batch_size=batch_size)

    # Initialize and train a fresh model for this fold
    model = BERTClassifier(bert_model_name, num_classes).to(device)
    optimizer = AdamW(model.parameters(), lr=learning_rate, no_deprecation_warning=True)
    total_steps = len(train_dataloader_fold) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
    for _ in range(num_epochs):
        train(model, train_dataloader_fold, optimizer, scheduler, device)

    # Per-text predictions for MCC. Pass the notebook's max_length: predict()
    # defaults to 128 and would otherwise truncate differently from training/evaluate.
    predicted_labels = [
        predict(text, model, tokenizer, device, max_length=max_length)
        for text in val_texts_fold
    ]
    mcc = matthews_corrcoef(val_labels_fold, predicted_labels)
    mcc_scores.append(mcc)

    # Evaluate model for this fold
    accuracy, report = evaluate(model, val_dataloader_fold, device)
    accuracy_scores.append(accuracy)
    classification_reports.append(report)

    print(f"Validation Accuracy: {accuracy:.4f}")
    print(report)
    print(f"MCC: {mcc:.4f}")

# Cross-fold summary (mean ± standard deviation over folds).
print(f"Mean accuracy: {np.mean(accuracy_scores):.4f} ± {np.std(accuracy_scores):.4f}")
print(f"Mean MCC: {np.mean(mcc_scores):.4f} ± {np.std(mcc_scores):.4f}")

