This notebook uses the kaggle_data.csv file and a combined reddit_full.csv file. It was run on a Google Colab T4 GPU for speed.

Implemented by Amanda Pignataro.

Resources:

*   https://www.geeksforgeeks.org/nlp/distilbert-in-natural-language-processing/
*   https://medium.com/thenextlayer/building-a-text-classification-model-using-distilbert-703c1409696c
*   https://lewtun.github.io/transformerlab/experiments.distilbert.html
*   For debugging/Second part issues with TF and Keras: ChatGPT





**FIRST PART:** DistilBERT Implementation with 16-class Model

In [None]:
!pip install "transformers==4.44.2"



In [13]:
import os
##This sets the TF_USE_LEGACY_KERAS flag, which asks TensorFlow to use the legacy Keras implementation.
##NOTE(review): presumably this has to run before the cell that imports tensorflow to have any effect - confirm.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

In [None]:
##This will handle imports, GPU, label setup, load CSVs

import ast
import numpy as np
import pandas as pd
import tensorflow as tf

from sklearn.metrics import accuracy_score, f1_score, classification_report
from transformers import DistilBertTokenizerFast, TFDistilBertModel

##This will set up TF and GPU
print("TF version:", tf.__version__)

##This will ask TF to grow GPU memory on demand for every visible GPU
gpus = tf.config.list_physical_devices("GPU")
for g in gpus:
    try:
        tf.config.experimental.set_memory_growth(g, True)
    except Exception as exc:
        ##Still best-effort (e.g. the runtime may already be initialized), but report the failure instead of hiding it
        print(f"Could not enable memory growth on {g}: {exc}")
print("GPUs:", gpus)

##Input CSV files, one per domain
KAGGLE_CSV = "kaggle_data.csv"
REDDIT_CSV = "reddit_full.csv"

##The 16 MBTI types; the position of a type in this list is its integer class id
MBTI16 = [
    "ISTJ", "ISFJ", "INFJ", "INTJ",
    "ISTP", "ISFP", "INFP", "INTP",
    "ENTJ", "ENTP", "ENFJ", "ENFP",
    "ESTJ", "ESFJ", "ESTP", "ESFP",
]
##Label string -> class id lookup
lab2id = dict(zip(MBTI16, range(len(MBTI16))))

##This will help clean text and load CSV files
def liststr_to_str(x):
    """
    This will flatten a stringified Python list into one space-joined string.

    If x is a string starting with '[' that parses (via ast.literal_eval) to a
    list, e.g. "['post1','post2']", the items are joined with single spaces.
    Anything else, including strings that fail to parse or parse to a
    non-list, is returned as str(x).
    """
    looks_like_list = isinstance(x, str) and x.startswith('[')
    if not looks_like_list:
        return str(x)

    try:
        parsed = ast.literal_eval(x)
    except Exception:
        ##Not a valid Python literal (e.g. "[deleted]"): keep the raw text
        return str(x)

    if isinstance(parsed, list):
        return " ".join(str(item) for item in parsed)
    return str(x)

def load_df(path, text_col_guess=("posts","body","text"),
            label_col_guess=("type","class","label")):
    """
    This will load a CSV and return a frame with the columns text, label and y.

    The first name in text_col_guess / label_col_guess that exists in the CSV
    is used as the text / label column. Rows with missing text are dropped,
    list-like text is flattened with liststr_to_str, labels are stripped and
    upper-cased, rows whose label is not in MBTI16 are dropped, and y holds the
    int32 class id from lab2id.

    Raises AssertionError if no text column or no label column is found.
    """
    df = pd.read_csv(path)

    text_col  = next((c for c in text_col_guess  if c in df.columns), None)
    label_col = next((c for c in label_col_guess if c in df.columns), None)
    assert text_col and label_col, (
        f"Could not find text/label in {path}. "
        f"Columns: {df.columns.tolist()}"
    )

    df = df[[text_col, label_col]].rename(columns={
        text_col: "text",
        label_col: "label"
    })

    ##Missing text would otherwise become the literal string "nan" and be kept as a training example
    df = df.dropna(subset=["text"]).copy()
    df["text"]  = df["text"].map(liststr_to_str)

    ##Normalize whitespace/case so labels such as "infj" are not silently filtered out below
    df["label"] = df["label"].astype(str).str.strip().str.upper()

    df = df[df["label"].isin(MBTI16)].copy()
    df["y"] = df["label"].map(lab2id).astype("int32")
    return df

print("Loading Kaggle + Reddit CSVs...")
##Column-name guesses are tried in the order given for each file
df_k = load_df(KAGGLE_CSV, text_col_guess=("posts","body"), label_col_guess=("type","class"))
df_r = load_df(REDDIT_CSV, text_col_guess=("body","posts"), label_col_guess=("class","type"))

##This will make smaller subsets for DistilBERT

##This will set how many examples per domain to use (capped at the number of rows available)
N_K_BERT = min(len(df_k), 50000)
N_R_BERT = min(len(df_r), 100000)

##This will sample without replacement; the fixed random_state keeps the subset the same across runs
df_k_bert = df_k.sample(N_K_BERT, random_state=42).reset_index(drop=True)
df_r_bert = df_r.sample(N_R_BERT, random_state=42).reset_index(drop=True)

print("For BERT we will use:")
print("  Kaggle:", len(df_k_bert), "rows")
print("  Reddit:", len(df_r_bert), "rows")

##Integer class ids, row-aligned with the sampled frames
yk_bert = df_k_bert["y"].to_numpy().astype("int32")
yr_bert = df_r_bert["y"].to_numpy().astype("int32")


TF version: 2.19.0
GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Loading Kaggle + Reddit CSVs...
For BERT we will use:
  Kaggle: 50000 rows
  Reddit: 100000 rows


In [None]:
##This will set the DistilBERT tokenizer + encode texts on subsets

from transformers import DistilBertTokenizerFast

##Checkpoint name shared by the tokenizer and the model builder
MODEL_NAME = "distilbert-base-uncased"
##Every text is truncated/padded to this many tokens (see encode_texts)
MAX_LEN    = 128

print("Loading DistilBERT tokenizer:", MODEL_NAME)
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)

def encode_texts(texts, max_len=MAX_LEN):
    """
    This will encode a Series/list of texts into input_ids and attention_mask arrays.

    texts:   a pandas Series, list, or other iterable of texts (a single string
             is treated as one text); every element is converted with str().
    max_len: every text is truncated or padded to exactly this many tokens.

    Returns (input_ids, attention_mask) as NumPy arrays from the tokenizer.
    """
    if isinstance(texts, str):
        texts = [texts]

    ##str() per element works for a list as well as a Series (Series.astype does not exist on a list)
    as_strings = [str(t) for t in texts]

    enc = tokenizer(
        as_strings,
        truncation=True,
        padding="max_length",
        max_length=max_len,
        return_tensors="np"
    )
    return enc["input_ids"], enc["attention_mask"]

##Tokenize the sampled text of each domain; the arrays stay row-aligned with df_k_bert / df_r_bert
print("Encoding Kaggle texts (subset)...")
Xk_ids, Xk_mask = encode_texts(df_k_bert["text"])

print("Encoding Reddit texts (subset)...")
Xr_ids, Xr_mask = encode_texts(df_r_bert["text"])

##Sanity check on the shapes of the encoded arrays
print("Kaggle encodings:", Xk_ids.shape, Xk_mask.shape)
print("Reddit encodings:", Xr_ids.shape, Xr_mask.shape)


Loading DistilBERT tokenizer: distilbert-base-uncased


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Encoding Kaggle texts (subset)...
Encoding Reddit texts (subset)...
Kaggle encodings: (50000, 128) (50000, 128)
Reddit encodings: (100000, 128) (100000, 128)


In [None]:
##This will Train/val splits + tf.data datasets

##Prefetch setting and batch size used by make_bert_split for both the train and validation datasets
AUTOTUNE = tf.data.AUTOTUNE
BATCH    = 16

def make_bert_split(input_ids, attention_mask, y, frac=0.9, shuffle=True):
    """
    This will cut the encoded arrays into a leading train part (the first
    int(n * frac) rows) and a trailing validation part, and wrap both in
    batched, prefetched tf.data datasets.

    When shuffle is True the train rows are permuted with NumPy (unseeded)
    before the dataset is built; the train dataset is additionally shuffled
    with a 10000-element buffer. The validation rows keep their order.

    Returns (ds_tr, ds_va, (ids_va, mask_va, y_va)).
    """
    cut = int(len(y) * frac)

    train_ids,  val_ids  = input_ids[:cut],      input_ids[cut:]
    train_mask, val_mask = attention_mask[:cut], attention_mask[cut:]
    train_y,    val_y    = y[:cut],              y[cut:]

    if shuffle:
        order = np.random.permutation(len(train_y))
        train_ids, train_mask, train_y = train_ids[order], train_mask[order], train_y[order]

    def _as_dataset(ids, mask, labels):
        ##Features are keyed by the model's input names
        features = {"input_ids": ids, "attention_mask": mask}
        return tf.data.Dataset.from_tensor_slices((features, labels))

    ds_tr = _as_dataset(train_ids, train_mask, train_y).shuffle(10000).batch(BATCH).prefetch(AUTOTUNE)
    ds_va = _as_dataset(val_ids, val_mask, val_y).batch(BATCH).prefetch(AUTOTUNE)

    return ds_tr, ds_va, (val_ids, val_mask, val_y)

print("Building train/val splits...")
##One split per domain; the raw validation arrays are kept for the cross-domain evaluation cells
ds_k_tr_bert, ds_k_va_bert, (Xk_ids_va, Xk_mask_va, yk_va) = make_bert_split(Xk_ids, Xk_mask, yk_bert)
ds_r_tr_bert, ds_r_va_bert, (Xr_ids_va, Xr_mask_va, yr_va) = make_bert_split(Xr_ids, Xr_mask, yr_bert)


Building train/val splits...


In [None]:
##This will establish the DistilBERT classifier model

def make_distilbert_classifier(num_labels=16, lr=3e-5):
    """
    This will build and compile a DistilBERT classifier over num_labels classes.

    A freshly loaded pretrained encoder (MODEL_NAME) is applied to the
    "input_ids" / "attention_mask" inputs of length MAX_LEN; the hidden state
    at position 0 goes through dropout (0.2) and a softmax Dense layer. The
    model is compiled with Adam(lr) and sparse categorical cross-entropy, so
    targets are integer class ids.
    """
    encoder = TFDistilBertModel.from_pretrained(MODEL_NAME)

    inputs = {
        "input_ids": tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids"),
        "attention_mask": tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask"),
    }

    encoded = encoder(inputs["input_ids"], attention_mask=inputs["attention_mask"])
    ##Position 0 of the sequence is used as the summary vector
    cls_vec = encoded.last_hidden_state[:, 0, :]

    dropped = tf.keras.layers.Dropout(0.2)(cls_vec)
    ##These are class probabilities (softmax), not raw logits
    probs = tf.keras.layers.Dense(num_labels, activation="softmax", name="type")(dropped)

    model = tf.keras.Model(inputs=inputs, outputs=probs)

    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

##Callbacks shared by both 16-class training runs: stop as soon as val_loss fails to
##improve for one epoch (patience=1) and roll back to the best weights seen
cb_bert = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=1,
        restore_best_weights=True
    )
]

In [None]:
##This will train DistilBERT on Kaggle + Reddit

##A separate, freshly initialized model is trained per domain with identical settings
print("\n=== Train DistilBERT on Kaggle ===")
m_k_bert = make_distilbert_classifier(num_labels=16, lr=3e-5)
hist_k_bert = m_k_bert.fit(
    ds_k_tr_bert,
    validation_data=ds_k_va_bert,
    epochs=2,
    callbacks=cb_bert,
    verbose=1
)

##Same recipe on the Reddit subset
print("\n=== Train DistilBERT on Reddit ===")
m_r_bert = make_distilbert_classifier(num_labels=16, lr=3e-5)
hist_r_bert = m_r_bert.fit(
    ds_r_tr_bert,
    validation_data=ds_r_va_bert,
    epochs=2,
    callbacks=cb_bert,
    verbose=1
)


=== Train DistilBERT on Kaggle ===


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertModel: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias']
- This IS expected if you are initializing TFDistilBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFDistilBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


Epoch 1/2
Epoch 2/2

=== Train DistilBERT on Reddit ===


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertModel: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias']
- This IS expected if you are initializing TFDistilBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFDistilBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


Epoch 1/2
Epoch 2/2


In [None]:
##This will set up our evaluation helper function

def eval_bert_model(model, ids, mask, y, tag, batch_size=64):
    """
    This will run the model on encoded inputs and print accuracy, macro-F1 and a
    per-class report under the heading `tag`.

    model:      classifier whose predict() returns one score per class.
    ids, mask:  encoded input_ids / attention_mask arrays.
    y:          true integer class ids (indices into MBTI16).
    batch_size: prediction batch size.

    Returns (accuracy, macro_f1).
    """
    ds = (
        tf.data.Dataset
          .from_tensor_slices({"input_ids": ids, "attention_mask": mask})
          .batch(batch_size)
    )
    yhat = model.predict(ds, verbose=0)
    pred = np.argmax(yhat, axis=1)

    acc = accuracy_score(y, pred)
    ##zero_division=0 keeps the same score for never-predicted classes but without the warning spam
    f1  = f1_score(y, pred, average="macro", zero_division=0)

    print(f"\n--- {tag} ---")
    print("Accuracy:", acc)
    print("Macro-F1:", f1)
    ##Explicit labels keep the report aligned with MBTI16 even when a class is absent
    ##from both y and pred (otherwise target_names would not match and sklearn raises)
    print(classification_report(
        y, pred,
        labels=list(range(len(MBTI16))),
        target_names=MBTI16,
        digits=3,
        zero_division=0,
    ))

    return acc, f1


In [None]:
##This will show the four experiments and ΔF1 summary

##Each model is scored on its own validation split (in-domain) and on the other domain's validation split (cross)

##1) Kaggle → Kaggle (in-domain)
acc_k_in_bert, f1_k_in_bert = eval_bert_model(
    m_k_bert,
    Xk_ids_va, Xk_mask_va, yk_va,
    "DistilBERT Kaggle → Kaggle (val)"
)

##2) Reddit → Reddit (in-domain)
acc_r_in_bert, f1_r_in_bert = eval_bert_model(
    m_r_bert,
    Xr_ids_va, Xr_mask_va, yr_va,
    "DistilBERT Reddit → Reddit (val)"
)

##3) Kaggle → Reddit (cross)
acc_k2r_bert, f1_k2r_bert = eval_bert_model(
    m_k_bert,
    Xr_ids_va, Xr_mask_va, yr_va,
    "DistilBERT Kaggle → Reddit (cross)"
)

##4) Reddit → Kaggle (cross)
acc_r2k_bert, f1_r2k_bert = eval_bert_model(
    m_r_bert,
    Xk_ids_va, Xk_mask_va, yk_va,
    "DistilBERT Reddit → Kaggle (cross)"
)

##ΔF1 summary: in-domain macro-F1 minus cross-domain macro-F1 for the same trained model
print("\n=== DistilBERT cross-domain drop (ΔF1 = within − cross) ===")
print(f"Kaggle-trained ΔF1: {f1_k_in_bert - f1_k2r_bert:.4f}")
print(f"Reddit-trained ΔF1: {f1_r_in_bert - f1_r2k_bert:.4f}")



--- DistilBERT Kaggle → Kaggle (val) ---
Accuracy: 0.2234
Macro-F1: 0.05316244062945563
              precision    recall  f1-score   support

        ISTJ      0.000     0.000     0.000       136
        ISFJ      0.000     0.000     0.000        91
        INFJ      0.227     0.393     0.288       876
        INTJ      0.196     0.016     0.029       637
        ISTP      0.000     0.000     0.000       196
        ISFP      0.000     0.000     0.000       173
        INFP      0.216     0.629     0.322      1006
        INTP      0.258     0.172     0.207       749
        ENTJ      0.000     0.000     0.000       139
        ENTP      0.000     0.000     0.000       396
        ENFJ      0.000     0.000     0.000       119
        ENFP      0.143     0.003     0.005       374
        ESTJ      0.000     0.000     0.000        27
        ESFJ      0.000     0.000     0.000        13
        ESTP      0.000     0.000     0.000        46
        ESFP      0.000     0.000     0.000   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



--- DistilBERT Reddit → Reddit (val) ---
Accuracy: 0.2959
Macro-F1: 0.07838966341325508
              precision    recall  f1-score   support

        ISTJ      0.000     0.000     0.000        98
        ISFJ      0.000     0.000     0.000        29
        INFJ      0.266     0.165     0.204      1218
        INTJ      0.267     0.285     0.276      2132
        ISTP      0.000     0.000     0.000       303
        ISFP      0.000     0.000     0.000        70
        INFP      0.245     0.099     0.141      1044
        INTP      0.325     0.685     0.441      2793
        ENTJ      0.000     0.000     0.000       271
        ENTP      0.208     0.097     0.132      1158
        ENFJ      0.000     0.000     0.000       120
        ENFP      0.182     0.037     0.061       596
        ESTJ      0.000     0.000     0.000        28
        ESFJ      0.000     0.000     0.000        25
        ESTP      0.000     0.000     0.000        79
        ESFP      0.000     0.000     0.000   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



--- DistilBERT Kaggle → Reddit (cross) ---
Accuracy: 0.1642
Macro-F1: 0.04295950338946858
              precision    recall  f1-score   support

        ISTJ      0.000     0.000     0.000        98
        ISFJ      0.000     0.000     0.000        29
        INFJ      0.178     0.385     0.243      1218
        INTJ      0.220     0.008     0.016      2132
        ISTP      0.000     0.000     0.000       303
        ISFP      0.000     0.000     0.000        70
        INFP      0.112     0.640     0.191      1044
        INTP      0.371     0.174     0.237      2793
        ENTJ      0.000     0.000     0.000       271
        ENTP      0.000     0.000     0.000      1158
        ENFJ      0.000     0.000     0.000       120
        ENFP      0.000     0.000     0.000       596
        ESTJ      0.000     0.000     0.000        28
        ESFJ      0.000     0.000     0.000        25
        ESTP      0.000     0.000     0.000        79
        ESFP      0.000     0.000     0.000 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



--- DistilBERT Reddit → Kaggle (cross) ---
Accuracy: 0.1744
Macro-F1: 0.05309070414596329
              precision    recall  f1-score   support

        ISTJ      0.000     0.000     0.000       136
        ISFJ      0.000     0.000     0.000        91
        INFJ      0.257     0.089     0.132       876
        INTJ      0.135     0.257     0.177       637
        ISTP      0.000     0.000     0.000       196
        ISFP      0.000     0.000     0.000       173
        INFP      0.321     0.087     0.138      1006
        INTP      0.176     0.677     0.279       749
        ENTJ      0.000     0.000     0.000       139
        ENTP      0.091     0.061     0.073       396
        ENFJ      0.000     0.000     0.000       119
        ENFP      0.180     0.029     0.051       374
        ESTJ      0.000     0.000     0.000        27
        ESFJ      0.000     0.000     0.000        13
        ESTP      0.000     0.000     0.000        46
        ESFP      0.000     0.000     0.000 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


**SECOND PART:** Re-train DistilBERT with 4 binary axes (IE, SN, TF, JP)


In [21]:
##This will handle imports, GPU, label setup, load CSVs

import ast
import numpy as np
import pandas as pd
import tensorflow as tf

from sklearn.metrics import accuracy_score, f1_score
from transformers import DistilBertTokenizerFast, TFDistilBertModel

##This will set up TF and GPU
print("TF version:", tf.__version__)

##This will ask TF to grow GPU memory on demand for every visible GPU
gpus = tf.config.list_physical_devices("GPU")
for g in gpus:
    try:
        tf.config.experimental.set_memory_growth(g, True)
    except Exception as exc:
        ##Still best-effort (e.g. the runtime may already be initialized), but report the failure instead of hiding it
        print(f"Could not enable memory growth on {g}: {exc}")
print("GPUs:", gpus)

##Input CSV files, one per domain
KAGGLE_CSV = "kaggle_data.csv"
REDDIT_CSV = "reddit_full.csv"

##This will setup MBTI 16 labels (for filtering)
MBTI16 = [
    "ISTJ","ISFJ","INFJ","INTJ",
    "ISTP","ISFP","INFP","INTP",
    "ENTJ","ENTP","ENFJ","ENFP",
    "ESTJ","ESFJ","ESTP","ESFP"
]

##Axis names in the order [IE, NS, TF, JP].
##NOTE(review): the second axis is spelled "NS" here but "SN" in the display names used at evaluation time - confirm which spelling is intended.
AXES = ["IE", "NS", "TF", "JP"]

##This will help to clean text and derive 4 axes
def liststr_to_str(x):
    """
    This will turn text that looks like "['post1','post2']" into "post1 post2".

    Only strings starting with '[' are parsed (with ast.literal_eval); when the
    result is a list its items are joined with single spaces. Every other
    input, including unparsable strings and non-list literals, comes back as
    str(x).
    """
    fallback = str(x)
    if not (isinstance(x, str) and x.startswith('[')):
        return fallback

    try:
        items = ast.literal_eval(x)
    except Exception:
        ##Not a valid Python literal: keep the raw text
        return fallback

    if not isinstance(items, list):
        return fallback
    return " ".join(str(item) for item in items)

def mbti_to_axes_vec(label):
    """
    This will map a 4-letter MBTI type (e.g., 'INFJ') to an int32 vector of
    four binary flags in the order [IE, NS, TF, JP].

    Each flag is 0 when the letter at that position is the first letter of the
    pair and 1 otherwise:

      - IE: I -> 0, E -> 1
      - NS: N -> 0, S -> 1
      - TF: T -> 0, F -> 1
      - JP: J -> 0, P -> 1

    The label is upper-cased first; an AssertionError is raised if it is not
    one of MBTI16.
    """
    mbti = str(label).upper()
    assert mbti in MBTI16, f"Unknown MBTI type: {mbti}"

    ##The letter that encodes 0 at each of the four positions
    zero_letters = "INTJ"
    flags = [0 if letter == zero else 1 for letter, zero in zip(mbti, zero_letters)]
    return np.array(flags, dtype="int32")

def load_df(path, text_col_guess=("posts","body","text"),
            label_col_guess=("type","class","label")):
    """
    This will load CSV, find text/label columns, clean the text, keep only valid MBTI16,
    and create 4 axis columns (ax_IE, ax_NS, ax_TF, ax_JP).

    The first name in text_col_guess / label_col_guess that exists in the CSV
    is used. Rows with missing text are dropped, labels are stripped and
    upper-cased, and the axis columns hold the 0/1 flags from mbti_to_axes_vec.

    Returns a frame with columns text, label, ax_IE, ax_NS, ax_TF, ax_JP
    (possibly with zero rows). Raises AssertionError if no text column or no
    label column is found.
    """
    df = pd.read_csv(path)

    text_col  = next((c for c in text_col_guess  if c in df.columns), None)
    label_col = next((c for c in label_col_guess if c in df.columns), None)
    assert text_col and label_col, (
        f"Could not find text/label in {path}. "
        f"Columns: {df.columns.tolist()}"
    )

    df = df[[text_col, label_col]].rename(columns={
        text_col: "text",
        label_col: "label"
    })

    ##Missing text would otherwise become the literal string "nan" and be kept as a training example
    df = df.dropna(subset=["text"]).copy()
    df["text"]  = df["text"].map(liststr_to_str)
    df["label"] = df["label"].astype(str).str.strip().str.upper()

    ##This will keep only the 16 standard MBTI types
    df = df[df["label"].isin(MBTI16)].copy()

    ##This will compute the 4 axes: each of the 16 types is converted once and then
    ##looked up per row. The reshape keeps a (0, 4) array when no rows survive the
    ##filter, where np.vstack on an empty sequence would raise.
    axes_by_type = {t: mbti_to_axes_vec(t) for t in MBTI16}
    axes_array = np.array(
        [axes_by_type[lab] for lab in df["label"]], dtype="int32"
    ).reshape(-1, 4)
    df["ax_IE"] = axes_array[:, 0]
    df["ax_NS"] = axes_array[:, 1]
    df["ax_TF"] = axes_array[:, 2]
    df["ax_JP"] = axes_array[:, 3]

    return df

print("Loading Kaggle + Reddit CSVs...")
##Column-name guesses are tried in the order given for each file
df_k = load_df(KAGGLE_CSV, text_col_guess=("posts","body"), label_col_guess=("type","class"))
df_r = load_df(REDDIT_CSV, text_col_guess=("body","posts"), label_col_guess=("class","type"))

print("Kaggle shape:", df_k.shape)
print("Reddit shape:", df_r.shape)

##This will make smaller subsets for DistilBERT

##This will set how many examples per domain to use (capped at the number of rows available)
N_K_BERT = min(len(df_k), 50000)
N_R_BERT = min(len(df_r), 100000)

##Fixed random_state keeps the sampled subset the same across runs
df_k_bert = df_k.sample(N_K_BERT, random_state=42).reset_index(drop=True)
df_r_bert = df_r.sample(N_R_BERT, random_state=42).reset_index(drop=True)

print("For BERT (axes) we will use:")
print("  Kaggle:", len(df_k_bert), "rows")
print("  Reddit:", len(df_r_bert), "rows")


##(N, 4) float32 targets in the column order [ax_IE, ax_NS, ax_TF, ax_JP]
yk_axes_bert = df_k_bert[["ax_IE","ax_NS","ax_TF","ax_JP"]].to_numpy().astype("float32")
yr_axes_bert = df_r_bert[["ax_IE","ax_NS","ax_TF","ax_JP"]].to_numpy().astype("float32")

TF version: 2.19.0
GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Loading Kaggle + Reddit CSVs...
Kaggle shape: (410915, 6)
Reddit shape: (1651100, 6)
For BERT (axes) we will use:
  Kaggle: 50000 rows
  Reddit: 100000 rows


In [22]:
##This will set the DistilBERT tokenizer + encode texts on subsets

##Checkpoint name shared by the tokenizer and the model builder
MODEL_NAME = "distilbert-base-uncased"
##Every text is truncated/padded to this many tokens (see encode_texts)
MAX_LEN  = 128

print("Loading DistilBERT tokenizer:", MODEL_NAME)
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)

def encode_texts(texts, max_len=MAX_LEN):
    """
    This will encode a Series/list of texts into input_ids and attention_mask arrays.

    texts:   a pandas Series, list, or other iterable of texts (a single string
             is treated as one text); every element is converted with str().
    max_len: every text is truncated or padded to exactly this many tokens.

    Returns (input_ids, attention_mask) as NumPy arrays from the tokenizer.
    """
    if isinstance(texts, str):
        texts = [texts]

    ##str() per element works for a list as well as a Series (Series.astype does not exist on a list)
    as_strings = [str(t) for t in texts]

    enc = tokenizer(
        as_strings,
        truncation=True,
        padding="max_length",
        max_length=max_len,
        return_tensors="np"
    )
    return enc["input_ids"], enc["attention_mask"]

##Tokenize the sampled text of each domain; the arrays stay row-aligned with df_k_bert / df_r_bert
print("Encoding Kaggle texts (subset)...")
Xk_ids, Xk_mask = encode_texts(df_k_bert["text"])

print("Encoding Reddit texts (subset)...")
Xr_ids, Xr_mask = encode_texts(df_r_bert["text"])

##Sanity check on the shapes of the encoded arrays
print("Kaggle encodings:", Xk_ids.shape, Xk_mask.shape)
print("Reddit encodings:", Xr_ids.shape, Xr_mask.shape)


Loading DistilBERT tokenizer: distilbert-base-uncased
Encoding Kaggle texts (subset)...




Encoding Reddit texts (subset)...
Kaggle encodings: (50000, 128) (50000, 128)
Reddit encodings: (100000, 128) (100000, 128)


In [23]:
##This will train/val splits + tf.data datasets (axes)

##Prefetch setting and batch size used by make_bert_split for both the train and validation datasets
AUTOTUNE = tf.data.AUTOTUNE
BATCH    = 16

def make_bert_split(input_ids, attention_mask, y_axes, frac=0.9, shuffle=True):
    """
    This will cut the encoded arrays into a leading train part (the first
    int(n * frac) rows) and a trailing validation part, and wrap both in
    batched, prefetched tf.data datasets.

    y_axes is expected to be shape (N, 4) with binary labels for [IE, SN, TF, JP].

    When shuffle is True the train rows are permuted with NumPy (unseeded)
    before the dataset is built; the train dataset is additionally shuffled
    with a 10000-element buffer. The validation rows keep their order.

    Returns (ds_tr, ds_va, (ids_va, mask_va, y_va)).
    """
    cut = int(len(y_axes) * frac)

    train_ids,  val_ids  = input_ids[:cut],      input_ids[cut:]
    train_mask, val_mask = attention_mask[:cut], attention_mask[cut:]
    train_y,    val_y    = y_axes[:cut],         y_axes[cut:]

    if shuffle:
        order = np.random.permutation(len(train_y))
        train_ids, train_mask, train_y = train_ids[order], train_mask[order], train_y[order]

    def _as_dataset(ids, mask, labels):
        ##Features are keyed by the model's input names
        features = {"input_ids": ids, "attention_mask": mask}
        return tf.data.Dataset.from_tensor_slices((features, labels))

    ds_tr = _as_dataset(train_ids, train_mask, train_y).shuffle(10000).batch(BATCH).prefetch(AUTOTUNE)
    ds_va = _as_dataset(val_ids, val_mask, val_y).batch(BATCH).prefetch(AUTOTUNE)

    return ds_tr, ds_va, (val_ids, val_mask, val_y)

print("Building train/val splits for axes...")
##One split per domain; the raw validation arrays are kept for the cross-domain evaluation cells
ds_k_tr_axes, ds_k_va_axes, (Xk_ids_va, Xk_mask_va, yk_axes_va) = make_bert_split(
    Xk_ids, Xk_mask, yk_axes_bert
)
ds_r_tr_axes, ds_r_va_axes, (Xr_ids_va, Xr_mask_va, yr_axes_va) = make_bert_split(
    Xr_ids, Xr_mask, yr_axes_bert
)

Building train/val splits for axes...


In [32]:
##This will create the DistilBERT classifier model (4-axis, multi-label)

def make_distilbert_axes_classifier(num_axes=4, lr=3e-5):
    """
    This will build and compile a DistilBERT multi-label classifier for the MBTI axes.
    Outputs shape (batch, num_axes) with sigmoid activation, trained with
    binary cross-entropy and reported with binary accuracy.

    Bypassing the transformers input_processing wrapper here
    by calling the underlying DistilBERT encoder (bert_model.distilbert)
    inside a Keras Lambda layer.

    NOTE(review): Keras Lambda layers are meant for stateless computation, so the
    encoder's weights are probably not registered as trainable weights of the
    returned model (i.e. only the Dense head may actually be trained) - confirm
    by inspecting model.trainable_weights.
    """
    bert_model = TFDistilBertModel.from_pretrained(MODEL_NAME)

    ##This will get hidden size from config (DistilBERT uses `dim`), falling back to `hidden_size`
    hidden_size = getattr(bert_model.config, "dim", None)
    if hidden_size is None:
        hidden_size = getattr(bert_model.config, "hidden_size")

    ##This will use keras functional inputs
    input_ids      = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids")
    attention_mask = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask")

    ##This will wrap the DistilBERT encoder in a Lambda to make sure it works with Keras
    def distilbert_encoder(inputs):
        ids, mask = inputs

        ##This will call the encoder directly
        outputs = bert_model.distilbert(
            input_ids=ids,
            attention_mask=mask,
        )


        ##The encoder may return either a plain tuple or an output object
        if isinstance(outputs, tuple):
            hidden_states = outputs[0]
        else:
            hidden_states = outputs.last_hidden_state
        return hidden_states


    ##This will use an explicit output_shape for Lambda b/c of Keras
    sequence_output = tf.keras.layers.Lambda(
        distilbert_encoder,
        name="distilbert_encoder",
        output_shape=(MAX_LEN, hidden_size),
    )([input_ids, attention_mask])


    ##Position 0 of the sequence is used as the summary vector
    cls_tok = sequence_output[:, 0, :]

    h = tf.keras.layers.Dropout(0.2)(cls_tok)
    ##Despite the name, `logits` holds sigmoid probabilities (one per axis)
    logits = tf.keras.layers.Dense(num_axes, activation="sigmoid", name="axes")(h)

    model = tf.keras.Model(
        inputs={"input_ids": input_ids, "attention_mask": attention_mask},
        outputs=logits
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss="binary_crossentropy",   # multi-label loss
        metrics=[tf.keras.metrics.BinaryAccuracy(name="binary_accuracy")]
    )
    return model


In [33]:
##This will train DistilBERT axis models on Kaggle + Reddit

##cb_axes was used below without being defined anywhere in this notebook, so a fresh
##kernel run failed with NameError. It is defined here with the same early-stopping
##settings as the 16-class runs: stop when val_loss fails to improve for one epoch
##and roll back to the best weights.
cb_axes = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=1,
        restore_best_weights=True
    )
]

##A separate, freshly initialized model is trained per domain with identical settings
print("\n=== Train DistilBERT (4 axes) on Kaggle ===")
m_k_axes = make_distilbert_axes_classifier(num_axes=4, lr=3e-5)
hist_k_axes = m_k_axes.fit(
    ds_k_tr_axes,
    validation_data=ds_k_va_axes,
    epochs=2,
    callbacks=cb_axes,
    verbose=1
)

print("\n=== Train DistilBERT (4 axes) on Reddit ===")
m_r_axes = make_distilbert_axes_classifier(num_axes=4, lr=3e-5)
hist_r_axes = m_r_axes.fit(
    ds_r_tr_axes,
    validation_data=ds_r_va_axes,
    epochs=2,
    callbacks=cb_axes,
    verbose=1
)



=== Train DistilBERT (4 axes) on Kaggle ===


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertModel: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing TFDistilBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFDistilBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


Epoch 1/2
[1m2813/2813[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m190s[0m 64ms/step - binary_accuracy: 0.6284 - loss: 0.6287 - val_binary_accuracy: 0.6953 - val_loss: 0.5762
Epoch 2/2
[1m2813/2813[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m173s[0m 61ms/step - binary_accuracy: 0.6923 - loss: 0.5792 - val_binary_accuracy: 0.6985 - val_loss: 0.5739

=== Train DistilBERT (4 axes) on Reddit ===


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertModel: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing TFDistilBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFDistilBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


Epoch 1/2
[1m5625/5625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m359s[0m 62ms/step - binary_accuracy: 0.7109 - loss: 0.5543 - val_binary_accuracy: 0.7487 - val_loss: 0.5171
Epoch 2/2
[1m5625/5625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m346s[0m 61ms/step - binary_accuracy: 0.7455 - loss: 0.5212 - val_binary_accuracy: 0.7496 - val_loss: 0.5142


In [41]:
##This will print evaluation for 4 axes (IE, SN, TF, JP)

from sklearn.metrics import accuracy_score, f1_score

##Names printed per axis; position j labels column j of the target/prediction arrays.
##NOTE(review): the second axis is shown as "SN" although the label columns call it "NS" (0 = N, 1 = S) - the values are the same, only the display name differs.
AXIS_DISPLAY_NAMES = ["IE", "SN", "TF", "JP"]

def eval_bert_axes_model(model, ids, mask, y_axes, run_name,
                         batch_size=64, thresh=0.5):
    """
    This will score a multi-label axis model and print one line per axis, e.g.:

        --- Kaggle → Kaggle (val): Per-axis metrics ---
        IE: acc=0.600  macroF1=0.531
        SN: acc=0.638  macroF1=0.505
        TF: acc=0.599  macroF1=0.597
        JP: acc=0.541  macroF1=0.536

    Predicted scores are turned into 0/1 decisions with `score >= thresh`;
    column j of y_axes and of the predictions is reported under
    AXIS_DISPLAY_NAMES[j].

    Returns:
        macro_f1_over_axes (mean of the per-axis macro-F1 values), axis_metrics_dict
    """
    features = {"input_ids": ids, "attention_mask": mask}
    ds = tf.data.Dataset.from_tensor_slices(features).batch(batch_size)

    scores = model.predict(ds, verbose=0)  # presumably (N, 4): one score per axis
    decisions = (scores >= thresh).astype("int32")

    print(f"\n--- {run_name}: Per-axis metrics ---")

    axis_metrics = {}
    per_axis_f1 = []
    for col, axis_name in enumerate(AXIS_DISPLAY_NAMES):
        truth, guess = y_axes[:, col], decisions[:, col]

        acc_axis      = accuracy_score(truth, guess)
        f1_macro_axis = f1_score(truth, guess, average="macro")

        per_axis_f1.append(f1_macro_axis)
        axis_metrics[axis_name] = {"accuracy": acc_axis, "macroF1": f1_macro_axis}

        print(f"{axis_name}: acc={acc_axis:0.3f}  macroF1={f1_macro_axis:0.3f}")

    return float(np.mean(per_axis_f1)), axis_metrics

In [42]:
##This will execute the four DistilBERT axis experiments and the ΔF1 summary.
##Each model is scored on its own validation split (in-domain) and on the other domain's validation split (cross).

##1) Kaggle → Kaggle (in-domain)
f1_k_in_axes, metrics_k_in_axes = eval_bert_axes_model(
    m_k_axes,
    Xk_ids_va, Xk_mask_va, yk_axes_va,
    "Kaggle → Kaggle (val)"
)

##2) Reddit → Reddit (in-domain)
f1_r_in_axes, metrics_r_in_axes = eval_bert_axes_model(
    m_r_axes,
    Xr_ids_va, Xr_mask_va, yr_axes_va,
    "Reddit → Reddit (val)"
)

##3) Kaggle → Reddit (cross)
f1_k2r_axes, metrics_k2r_axes = eval_bert_axes_model(
    m_k_axes,
    Xr_ids_va, Xr_mask_va, yr_axes_va,
    "Kaggle → Reddit (cross)"
)

##4) Reddit → Kaggle (cross)
f1_r2k_axes, metrics_r2k_axes = eval_bert_axes_model(
    m_r_axes,
    Xk_ids_va, Xk_mask_va, yk_axes_va,
    "Reddit → Kaggle (cross)"
)

##ΔF1 summary: mean per-axis macro-F1 in-domain minus the same model's cross-domain value
print("\n=== DistilBERT (axes) cross-domain drop (ΔF1 = in-domain − cross) ===")
print(f"Kaggle-trained ΔF1 (axes): {f1_k_in_axes - f1_k2r_axes:.4f}")
print(f"Reddit-trained ΔF1 (axes): {f1_r_in_axes - f1_r2k_axes:.4f}")



--- Kaggle → Kaggle (val): Per-axis metrics ---
IE: acc=0.773  macroF1=0.436
SN: acc=0.859  macroF1=0.462
TF: acc=0.572  macroF1=0.526
JP: acc=0.590  macroF1=0.373

--- Reddit → Reddit (val): Per-axis metrics ---
IE: acc=0.769  macroF1=0.435
SN: acc=0.933  macroF1=0.483
TF: acc=0.688  macroF1=0.417
JP: acc=0.608  macroF1=0.389

--- Kaggle → Reddit (cross): Per-axis metrics ---
IE: acc=0.769  macroF1=0.435
SN: acc=0.933  macroF1=0.483
TF: acc=0.436  macroF1=0.431
JP: acc=0.607  macroF1=0.385

--- Reddit → Kaggle (cross): Per-axis metrics ---
IE: acc=0.773  macroF1=0.436
SN: acc=0.859  macroF1=0.462
TF: acc=0.467  macroF1=0.323
JP: acc=0.593  macroF1=0.375

=== DistilBERT (axes) cross-domain drop (ΔF1 = in-domain − cross) ===
Kaggle-trained ΔF1 (axes): 0.0160
Reddit-trained ΔF1 (axes): 0.0320
