In [2]:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model, load_model

from transformers import TFBertModel, BertTokenizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd

import matplotlib.pyplot as plt
from lime.lime_text import LimeTextExplainer
import shap
from collections import defaultdict
import itertools

In [None]:
# Load the California housing data and hold out 20% of rows as a test split.
# RANDOM_SEED fixes the split and also seeds NumPy / TensorFlow so that the
# stochastic parts of the model (dropout, Flipout weight sampling) are
# repeatable across Restart-&-Run-All.

cali_housing_path = '../data/California_Houses.csv'
RANDOM_SEED = 492
TARGET_COL = 'Median_House_Value'

np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

cali_df = pd.read_csv(cali_housing_path)
y_series = cali_df[TARGET_COL]
y = cali_df[[TARGET_COL]]  # keep the target 2-D: a one-column DataFrame
features = [col for col in cali_df.columns if col != TARGET_COL]
X = cali_df[features]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=RANDOM_SEED)

# Reset to a clean 0..n-1 index so frames rebuilt from NumPy arrays later
# (e.g. after scaling) line up row-for-row with these splits.
X_train = X_train.reset_index(drop=True)
y_train = y_train.reset_index(drop=True)
X_test = X_test.reset_index(drop=True)
y_test = y_test.reset_index(drop=True)

# Sanity checks: the split is a partition of the data and X/y stay aligned.
assert len(X_train) + len(X_test) == len(cali_df)
assert len(X_train) == len(y_train) and len(X_test) == len(y_test)


In [None]:
# Standardise the features. The scaler is fitted on the training split only,
# and the same fitted transform is then applied to the test split.
scaler = StandardScaler()
X_train_scaled = pd.DataFrame(scaler.fit_transform(X_train), columns=X_train.columns)
X_test_scaled = pd.DataFrame(scaler.transform(X_test), columns=X_test.columns)

# BERT consumes text: serialise every scaled row as one string of its values,
# separated by single spaces.
X_train_texts, X_test_texts = (
    frame.astype(str).apply(' '.join, axis=1).tolist()
    for frame in (X_train_scaled, X_test_scaled)
)

In [None]:
# Fetch the pre-trained 'bert-base-uncased' tokenizer and TensorFlow encoder.
# NOTE(review): BERTRegression loads its own tokenizer and encoder in __init__,
# so these globals look unused by the model; confirm before keeping a second
# copy of BERT in memory.
tokenizer, bert_model = (
    BertTokenizer.from_pretrained('bert-base-uncased'),
    TFBertModel.from_pretrained('bert-base-uncased'),
)

In [None]:
class BayesianDense(tf.keras.layers.Layer):
    """Dense layer whose kernel is a variational posterior, sampled via Flipout.

    Thin wrapper around ``tfp.layers.DenseFlipout``. The KL term between the
    kernel posterior and a zero-mean Normal prior is produced by
    ``kl_divergence_fn``, which is handed to DenseFlipout as its
    ``kernel_divergence_fn``.

    Args:
        units: Number of output units.
        prior_stddev: Standard deviation of the zero-mean Normal prior placed on
            each kernel weight. The default of 1.0 reproduces DenseFlipout's
            built-in prior.
        activation: Activation applied to the layer output. Defaults to 'relu'.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, prior_stddev=1.0, activation='relu', **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.prior_stddev = prior_stddev
        self.activation = activation
        self.dense_flipout = tfp.layers.DenseFlipout(
            units,
            activation=activation,
            # prior_stddev used to be stored but never used; DenseFlipout
            # silently kept its default N(0, 1) prior. Wire it through here.
            kernel_prior_fn=self._make_kernel_prior_fn(prior_stddev),
            kernel_divergence_fn=self.kl_divergence_fn
        )

    @staticmethod
    def _make_kernel_prior_fn(prior_stddev):
        """Build a DenseFlipout prior factory: N(0, prior_stddev**2) on every weight."""
        def kernel_prior_fn(dtype, shape, name, trainable, add_variable_fn):
            del name, trainable, add_variable_fn  # fixed prior: no variables to create
            normal = tfp.distributions.Normal(
                loc=tf.zeros(shape, dtype),
                scale=dtype.as_numpy_dtype(prior_stddev)
            )
            # Fold all weight dimensions into a single event, so the KL against
            # the posterior reduces to one scalar.
            return tfp.distributions.Independent(
                normal, reinterpreted_batch_ndims=tf.size(normal.batch_shape_tensor())
            )
        return kernel_prior_fn

    def kl_divergence_fn(self, q, p, _):
        """Return KL(q || p), divided by the number of batch elements of ``q``.

        Args:
            q: Kernel posterior distribution.
            p: Kernel prior distribution.
            _: Posterior sample supplied by DenseFlipout (unused; the KL is analytic).
        """
        kl = tfp.distributions.kl_divergence(q, p)
        # batch_shape_tensor() also works when the static batch shape is only
        # partly known. For a posterior with no batch dims the product is 1.
        # NOTE(review): ELBO-style training usually divides the KL by the number
        # of training examples instead; confirm the intended scaling.
        num_batch_elements = tf.reduce_prod(q.batch_shape_tensor())
        return kl / tf.cast(num_batch_elements, kl.dtype)

    def call(self, inputs):
        """Apply the Flipout dense layer; each call draws a fresh weight perturbation."""
        return self.dense_flipout(inputs)

    def get_config(self):
        """Return the constructor arguments so Keras can re-create this layer."""
        config = super().get_config()
        config.update({
            'units': self.units,
            'prior_stddev': self.prior_stddev,
            'activation': self.activation,
        })
        return config
    

In [None]:
class BERTRegression(Model):
    """Regression model: BERT pooled output -> dropout -> BayesianDense -> 1 unit.

    Predictive uncertainty is estimated by Monte-Carlo sampling of the
    stochastic forward pass (dropout plus Flipout). Individual predictions are
    explained with LIME over the input text.

    Args:
        max_length: Token length that ``explain_lime`` pads/truncates inputs to.
        dense_size: Number of units in the Bayesian hidden layer.
        dropout_rate: Dropout rate applied to the pooled BERT output.
        num_samples: Number of Monte-Carlo forward passes used by
            ``predict_with_uncertainty``.
        num_features: Number of tokens LIME reports per explanation.
    """

    def __init__(self, max_length, dense_size, dropout_rate=0.1, num_samples=10, num_features=13):
        super().__init__()
        self.max_length = max_length
        self.bert_model = TFBertModel.from_pretrained('bert-base-uncased')
        self.bert_model.trainable = True  # fine-tune the encoder, not just the head
        self.dropout = Dropout(dropout_rate)
        self.dense_bayes_layer = BayesianDense(dense_size)
        self.output_layer = Dense(1, activation='linear')
        self.num_mc_samples = num_samples
        # This notebook feeds space-separated numeric strings. LIME's default
        # r'\W+' splitter would break "-0.123" into "0" / "123" and explain
        # digit fragments, so split on whitespace only.
        self.explainer = LimeTextExplainer(split_expression=r'\s+')
        self.num_features = num_features
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def call(self, inputs, training=True):
        """Forward pass.

        Args:
            inputs: ``(input_ids, attention_mask, token_type_ids)`` tensors.
            training: Enables dropout. Defaults to True, so a plain ``model(x)``
                call stays stochastic for Monte-Carlo sampling.
        Returns:
            Regression output with a single column.
        """
        input_ids, attention_mask, token_type_ids = inputs
        bert_output = self.bert_model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            training=training
        )[1]  # element 1 of the BERT output = pooled representation
        dropout_output = self.dropout(bert_output, training=training)
        hidden_output = self.dense_bayes_layer(dropout_output)
        return self.output_layer(hidden_output)

    def predict_with_uncertainty(self, inputs, num_samples=None):
        """Monte-Carlo estimate of the predictive mean and standard deviation.

        Args:
            inputs: ``(input_ids, attention_mask, token_type_ids)`` tensors.
            num_samples: Number of stochastic forward passes. Defaults to the
                ``num_samples`` given to the constructor.
        Returns:
            ``(mean, stddev)`` computed over the sample axis of the stacked
            predictions.
        Raises:
            ValueError: If the number of samples is less than 1.
        """
        # Bug fix: the constructor stores this count as ``num_mc_samples``;
        # the old ``self.num_samples`` lookup raised AttributeError.
        n_samples = self.num_mc_samples if num_samples is None else num_samples
        if n_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {n_samples}")
        predictions = tf.stack([self(inputs, training=True) for _ in range(n_samples)])
        prediction_mean = tf.reduce_mean(predictions, axis=0)
        prediction_stddev = tf.math.reduce_std(predictions, axis=0)
        return prediction_mean, prediction_stddev

    def explain_lime(self, text_instance, batch_size=32):
        """Explain the model's prediction for one input string with LIME.

        Args:
            text_instance: The input string to explain.
            batch_size: Number of LIME-perturbed strings encoded and scored per
                forward pass. This keeps the thousands of perturbations out of a
                single oversized BERT batch.
        Returns:
            A one-element list holding the LIME explanation. The single
            regression output is explained as label 0.
        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        def predict_function(texts):
            # Score perturbed strings in chunks; returns an (n_texts, 1) array.
            texts = list(texts)
            if not texts:
                return np.zeros((0, 1), dtype=np.float32)
            batch_means = []
            for start in range(0, len(texts), batch_size):
                encoded = self.tokenizer.batch_encode_plus(
                    texts[start:start + batch_size],
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='tf'
                )
                mean, _ = self.predict_with_uncertainty([
                    encoded['input_ids'],
                    encoded['attention_mask'],
                    encoded['token_type_ids']
                ])
                batch_means.append(mean.numpy())
            return np.concatenate(batch_means, axis=0)

        exp = self.explainer.explain_instance(
            text_instance,
            predict_function,
            num_features=self.num_features,
            # The model emits one output column. LIME's default labels=(1,)
            # would index a non-existent second column (IndexError).
            labels=(0,)
        )
        return [exp]