In [None]:
import gc
import tensorflow as tf
from tqdm.notebook import tqdm
import numpy as np
# Release memory: reset the Keras backend session first, then run a garbage
# collection pass. Because `gc.collect()` is the last expression of the cell,
# its integer return value is what the notebook displays.
tf.keras.backend.clear_session()
gc.collect()

In [None]:
import pandas as pd
from datasets import load_dataset

# Load the `ag_news` dataset through `datasets.load_dataset`.
dataset = load_dataset('ag_news')

# Fixed-size subsets: training rows 0..39999, validation rows 40000..44999
# (both taken from the 'train' split) and the first 5000 rows of 'test'.
train_dataset = dataset['train'].select(range(40_000))
valid_dataset = dataset['train'].select(range(40_000, 45_000))
test_dataset = dataset['test'].select(range(5_000))

# train_df = pd.read_csv('/kaggle/input/imdb-dataset-sentiment-analysis-in-csv-format/Train.csv')
# valid_df = pd.read_csv('/kaggle/input/imdb-dataset-sentiment-analysis-in-csv-format/Valid.csv')
# test_df = pd.read_csv('/kaggle/input/imdb-dataset-sentiment-analysis-in-csv-format/Test.csv')

# train_df = train_df.head(10)
# valid_df = valid_df.head(2)
# test_df = test_df.head(2)

In [None]:
# Display the shape of each split as (train, valid, test).
tuple(split.shape for split in (train_dataset, valid_dataset, test_dataset))

In [None]:
from transformers import BertTokenizer, TFBertModel
import tensorflow as tf

# Tokenizer loaded from the same checkpoint name as the encoder used below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def tokenize_texts(texts, max_length=256):
    """Tokenize ``texts`` with the module-level ``tokenizer``.

    Args:
        texts: Text input forwarded unchanged to ``tokenizer``.
        max_length: Forwarded as the tokenizer's ``max_length``; truncation
            is enabled.

    Returns:
        The tokenizer's output, requested with padding enabled and
        TensorFlow tensors (``return_tensors='tf'``).
    """
    tokenizer_options = {
        'padding': True,
        'truncation': True,
        'max_length': max_length,
        'return_tensors': 'tf',
    }
    return tokenizer(texts, **tokenizer_options)

# Alternative inputs built from CSV DataFrames (kept for reference):
# X_train_texts = train_df['text'].tolist()
# X_valid_texts = valid_df['text'].tolist()
# X_test_texts = test_df['text'].tolist()
# train_labels = tf.convert_to_tensor(train_df['label'].values, dtype=tf.int32)
# valid_labels = tf.convert_to_tensor(valid_df['label'].values, dtype=tf.int32)
# test_labels = tf.convert_to_tensor(test_df['label'].values, dtype=tf.int32)

# --- training split: raw strings, token tensors, int32 labels ---
X_train_texts = [row['text'] for row in train_dataset]
train_encodings = tokenize_texts(texts=X_train_texts)
train_labels = tf.convert_to_tensor(
    [row['label'] for row in train_dataset],
    dtype=tf.int32,
)

# --- validation split ---
X_valid_texts = [row['text'] for row in valid_dataset]
valid_encodings = tokenize_texts(texts=X_valid_texts)
valid_labels = tf.convert_to_tensor(
    [row['label'] for row in valid_dataset],
    dtype=tf.int32,
)

# --- test split ---
X_test_texts = [row['text'] for row in test_dataset]
test_encodings = tokenize_texts(texts=X_test_texts)
test_labels = tf.convert_to_tensor(
    [row['label'] for row in test_dataset],
    dtype=tf.int32,
)

In [None]:
class BertClassifier(tf.keras.Model):
    """Encoder followed by a softmax classification head on the first token.

    Args:
        bert_model: Encoder called as ``bert_model(**inputs)``; its output
            must expose ``last_hidden_state``.
        num_classes: Width of the softmax output layer. Defaults to 4, the
            value that was previously hard-coded.
    """

    def __init__(self, bert_model, num_classes=4):
        super().__init__()
        self.bert = bert_model
        self.classifier = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=None, **kwargs):
        """Return class probabilities for a batch of tokenized inputs.

        Args:
            inputs: Mapping of keyword arguments for the encoder (for example
                the dict form of the tokenizer output).
            training: Training-mode flag; forwarded explicitly to the encoder
                instead of being silently dropped.
            **kwargs: Accepted for call-signature compatibility; unused.

        Returns:
            Softmax output of the dense head applied to
            ``last_hidden_state[:, 0, :]``.
        """
        outputs = self.bert(**inputs, training=training)
        # Position 0 of the sequence axis is used as the sentence summary.
        cls_output = outputs.last_hidden_state[:, 0, :]
        return self.classifier(cls_output)
        
# Pretrained encoder wrapped by the classification model defined above.
bert_model = TFBertModel.from_pretrained('bert-base-uncased')
bert_classifier = BertClassifier(bert_model)

# The model outputs softmax probabilities and labels are integer class ids,
# hence sparse categorical cross-entropy with its default settings.
bert_classifier.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

In [None]:
# Mini-batch size shared by all three pipelines; adjust based on your hardware.
batch_size = 8

# NOTE: from here on `train_dataset` / `valid_dataset` / `test_dataset` are
# rebound to tf.data pipelines of (features dict, label) pairs.

# Training pipeline: shuffled with a 1000-element buffer before batching.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((dict(train_encodings), train_labels))
    .shuffle(1000)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation and test pipelines: batched in their original order.
valid_dataset = (
    tf.data.Dataset.from_tensor_slices((dict(valid_encodings), valid_labels))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

test_dataset = (
    tf.data.Dataset.from_tensor_slices((dict(test_encodings), test_labels))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Fine-tune for two epochs on the training pipeline, with the validation
# pipeline passed as `validation_data`.
history = bert_classifier.fit(
    x=train_dataset,
    epochs=2,
    validation_data=valid_dataset,
    verbose=1,
)

print("Training history:", history.history)

In [None]:
def compute_c_hp(texts, class_indices, bert_classifier, tokenizer, batch_size=8):
    """Compute gradient-times-activation ("C-HP") features on the [CLS] vector.

    For every text, the classifier output for the class given in
    ``class_indices`` is differentiated with respect to the [CLS] vector
    (``last_hidden_state[:, 0, :]``), and that gradient is multiplied
    element-wise with the vector itself.

    Args:
        texts: Sequence of input strings (must support ``len`` and slicing).
        class_indices: Sequence of integer class ids, one per text, selecting
            which classifier output is differentiated.
        bert_classifier: Model exposing ``.bert`` (the encoder) and
            ``.classifier`` (the head applied to the [CLS] vector).
        tokenizer: Callable tokenizer able to return TensorFlow tensors.
        batch_size: Number of texts processed per pass.

    Returns:
        ``np.ndarray`` with one feature row per text.

    Raises:
        ValueError: If ``texts`` is empty, ``class_indices`` has a different
            length than ``texts``, ``batch_size`` is not positive, or the
            gradient could not be computed.
    """
    # Validate up front so misuse fails with a clear message instead of an
    # opaque error from np.vstack / tf.stack further down.
    num_texts = len(texts)
    if num_texts == 0:
        raise ValueError("texts is empty; there is nothing to compute C-HP features for.")
    if len(class_indices) != num_texts:
        raise ValueError(
            f"class_indices has {len(class_indices)} entries but texts has {num_texts}."
        )
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    bert_model = bert_classifier.bert
    c_hp_features = []

    # Process texts in batches
    for start in tqdm(range(0, num_texts, batch_size), desc="Computing C-HP features"):
        batch_texts = texts[start:start + batch_size]
        batch_indices = class_indices[start:start + batch_size]

        # Tokenize the batch with the same settings used for prediction.
        inputs = tokenizer(
            batch_texts,
            return_tensors='tf',
            padding=True,
            truncation=True,
            max_length=256
        )

        # Run the encoder on the tokenizer output directly, i.e. the same
        # forward pass `predict_with_bert` uses to produce class predictions.
        # NOTE(review): pre-embedding ids via `get_input_embeddings()` and
        # re-feeding them as `inputs_embeds` appears to apply position/segment
        # embeddings and LayerNorm twice for HF TF BERT; confirm against the
        # installed transformers version.
        outputs = bert_model(**dict(inputs))
        cls_output = outputs.last_hidden_state[:, 0, :]  # [CLS] token

        # Only the classifier head needs to be differentiated, so the tape
        # covers just that part and explicitly watches the [CLS] vector.
        with tf.GradientTape() as tape:
            tape.watch(cls_output)
            preds = bert_classifier.classifier(cls_output)
            # Pick, for each row, the output of its requested class.
            batch_indices_tensor = tf.constant(batch_indices, dtype=tf.int32)
            batch_indices_range = tf.range(tf.shape(preds)[0])
            selected_preds = tf.gather_nd(
                preds,
                tf.stack([batch_indices_range, batch_indices_tensor], axis=1)
            )
            # Sum (not mean): each row's gradient must not be scaled by the
            # number of texts that happen to share its batch, otherwise a
            # single-sample call yields features on a different scale than
            # the same sample processed inside a full batch.
            loss = tf.reduce_sum(selected_preds)

        # Compute gradients with respect to [CLS] output
        grads = tape.gradient(loss, cls_output)
        if grads is None:
            raise ValueError("Gradients is None. Check if the model is trainable or if cls_output is properly watched.")

        # Compute C-HP: Element-wise multiplication of gradients and [CLS] embedding
        c_hp = grads * cls_output
        c_hp_flat = tf.reshape(c_hp, [tf.shape(c_hp)[0], -1])
        c_hp_features.append(c_hp_flat.numpy())

    # Stack all features into a single array
    return np.vstack(c_hp_features)

def compute_raw_features(texts, class_indices, bert_classifier, tokenizer, batch_size=8):
    """Return the encoder's [CLS] vector for every text ("raw" features).

    Args:
        texts: Sequence of input strings (must support ``len`` and slicing).
        class_indices: Unused. Accepted only so the signature mirrors
            ``compute_c_hp`` and existing callers keep working.
        bert_classifier: Model exposing ``.bert`` (the encoder).
        tokenizer: Callable tokenizer able to return TensorFlow tensors.
        batch_size: Number of texts processed per forward pass.

    Returns:
        ``np.ndarray`` with one row per text, holding
        ``last_hidden_state[:, 0, :]`` of the encoder output.

    Raises:
        ValueError: If ``texts`` is empty or ``batch_size`` is not positive.
    """
    num_texts = len(texts)
    if num_texts == 0:
        raise ValueError("texts is empty; there is nothing to compute raw features for.")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    bert_model = bert_classifier.bert
    raw_features = []

    # Process texts in batches
    for start in tqdm(range(0, num_texts, batch_size), desc="Computing raw features"):
        batch_texts = texts[start:start + batch_size]

        # Tokenize the batch with the same settings used for prediction.
        inputs = tokenizer(
            batch_texts,
            return_tensors='tf',
            padding=True,
            truncation=True,
            max_length=256
        )

        # Plain forward pass on the tokenizer output: no gradients are needed
        # here, so no GradientTape is opened and the encoder sees the same
        # inputs as in `predict_with_bert`.
        outputs = bert_model(**dict(inputs))
        cls_output = outputs.last_hidden_state[:, 0, :]  # [CLS] token

        cls_flat = tf.reshape(cls_output, [tf.shape(cls_output)[0], -1])
        raw_features.append(cls_flat.numpy())

    # Stack all features into a single array
    return np.vstack(raw_features)

def predict_with_bert(bert_classifier, texts, tokenizer, batch_size=8):
    """Predict a class index for every text, processing ``batch_size`` at a time.

    Args:
        bert_classifier: Callable model that maps the dict form of the
            tokenizer output to per-class scores of shape (batch, classes).
        texts: Sequence of input strings (must support ``len`` and slicing).
        tokenizer: Callable tokenizer able to return TensorFlow tensors.
        batch_size: Number of texts per forward pass.

    Returns:
        1-D ``np.ndarray`` of arg-max class indices, one per text. An empty
        ``texts`` yields an empty int64 array instead of an error.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    num_texts = len(texts)
    if num_texts == 0:
        # np.concatenate([]) would raise; an empty prediction vector is the
        # natural result for no inputs.
        return np.empty((0,), dtype=np.int64)

    predictions = []

    # Process texts in batches
    for start in tqdm(range(0, num_texts, batch_size), desc="Predicting with BERT"):
        batch_texts = texts[start:start + batch_size]

        # Tokenize the batch
        inputs = tokenizer(
            batch_texts,
            return_tensors='tf',
            padding=True,
            truncation=True,
            max_length=256
        )

        # Predict for current batch and keep the highest-scoring class per row.
        batch_predictions = bert_classifier(dict(inputs))
        predictions.append(np.argmax(batch_predictions, axis=1))

    # Concatenate all predictions into a single array
    return np.concatenate(predictions)

In [None]:
from sklearn.metrics import accuracy_score

# Class predictions of the fine-tuned model on the raw test texts.
y_test_pred = predict_with_bert(
    bert_classifier=bert_classifier,
    texts=X_test_texts,
    tokenizer=tokenizer,
)

# Accuracy of those predictions against the true test labels.
print('acc', accuracy_score(y_true=test_labels, y_pred=y_test_pred))

In [None]:
# Model predictions on the training texts; they select which class output is
# differentiated for each training example.
y_train_pred = predict_with_bert(
    bert_classifier=bert_classifier,
    texts=X_train_texts,
    tokenizer=tokenizer,
)

# C-HP features for the training texts, using the model's own predictions.
c_hp_features_train = compute_c_hp(
    X_train_texts,
    y_train_pred,
    bert_classifier,
    tokenizer,
    batch_size=8,
)

# Same for the test texts.
c_hp_features_test = compute_c_hp(
    X_test_texts,
    y_test_pred,
    bert_classifier,
    tokenizer,
    batch_size=8,
)

In [None]:
from sklearn.neighbors import KNeighborsClassifier

# 1-nearest-neighbour "twin" of the network in C-HP feature space. It is
# fitted on the model's *predicted* training labels, not the true ones.
knn = KNeighborsClassifier(n_neighbors=1).fit(c_hp_features_train, y_train_pred)
y_knn_pred = knn.predict(c_hp_features_test)

# Accuracy against the true labels, then agreement with the network itself.
print('acc', accuracy_score(y_true=test_labels, y_pred=y_knn_pred))
print('agreement:', accuracy_score(y_true=y_test_pred, y_pred=y_knn_pred))

In [None]:
# Index of the test example to explain. Named `sample_idx` rather than `id`
# so the builtin `id()` is not shadowed for the rest of the kernel session.
sample_idx = 0
sample_text = X_test_texts[sample_idx]
sample_classIdx = y_test_pred[sample_idx]
true_label = test_labels[sample_idx].numpy()

# sample_classIdx, true_label

# C-HP features of this single example, for the class the model predicted.
sample_chp = compute_c_hp(
    texts=[sample_text],
    class_indices=[sample_classIdx],
    bert_classifier=bert_classifier,
    tokenizer=tokenizer,
    batch_size=8
)

In [None]:
# Nearest training example(s) to the sample in the feature space `knn` was
# fitted on.
distances, indices = knn.kneighbors(sample_chp)

print(f"\nTest sample: {sample_text}")
print(f"Prediction: {knn.predict(sample_chp)[0]}")
print("Similar cases (indices):", indices[0])
print("Distances:", distances[0])
print("\nSimilar texts:")
for idx in indices[0]:
    # `train_labels[idx]` is a scalar tensor; formatting it directly prints
    # the tensor repr, so convert it to a plain int to show just the label.
    print(f"- {X_train_texts[idx]} (Label: {int(train_labels[idx])})")

In [None]:
# Baseline features: the plain [CLS] vectors for the training texts.
raw_features_train = compute_raw_features(
    X_train_texts,
    y_train_pred,
    bert_classifier,
    tokenizer,
    batch_size=8,
)

# Same for the test texts.
raw_features_test = compute_raw_features(
    X_test_texts,
    y_test_pred,
    bert_classifier,
    tokenizer,
    batch_size=8,
)

In [None]:
from sklearn.neighbors import KNeighborsClassifier

# Same 1-nearest-neighbour comparison on the raw [CLS] features.
# NOTE: this rebinds `knn` and `y_knn_pred`, replacing the C-HP based ones.
knn = KNeighborsClassifier(n_neighbors=1).fit(raw_features_train, y_train_pred)
y_knn_pred = knn.predict(raw_features_test)

# Accuracy against the true labels, then agreement with the network itself.
print('acc', accuracy_score(y_true=test_labels, y_pred=y_knn_pred))
print('agreement:', accuracy_score(y_true=y_test_pred, y_pred=y_knn_pred))