In [None]:
# ===================== COMPLETE DISTILBERT-BASED ATS EMAIL CLASSIFIER =====================
# Copy–Paste Ready | Small Model | Resume Training | No Optimizer Bloat
# ========================================================================================

# ---------- STEP 0: IMPORT LIBRARIES ----------
import os
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns

from transformers import (
    DistilBertTokenizerFast,
    TFDistilBertForSequenceClassification
)

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight

from tensorflow.keras.callbacks import (
    EarlyStopping,
    ReduceLROnPlateau,
    ModelCheckpoint
)

# ---------- STEP 1: LOAD & CLEAN DATA ----------
# Read the raw CSV and force the two expected column names.
df = pd.read_csv("E:\\Talentprism\\data\\final_training_data_with_missing_classes.csv")
df.columns = ["label", "text"]

# Drop rows missing either field, de-duplicate on the text column, then
# shuffle reproducibly (fixed seed) in one chained expression.
df = (
    df.dropna(subset=["label", "text"])
    .drop_duplicates(subset=["text"])
    .sample(frac=1, random_state=42)
)

print("Original Class Distribution:")
print(df["label"].value_counts())

# ---------- STEP 2: BALANCE DATASET ----------
# Cap every class at 450 rows; classes with fewer rows are kept whole.
# The groups are iterated explicitly instead of going through
# groupby().apply(): apply() operating on the grouping column is deprecated
# in pandas, and once the column is excluded the "label" values would be
# lost after reset_index(drop=True), breaking the value_counts() below.
# Iteration yields each group with all of its columns on every version and
# produces the same rows in the same (label-sorted) order.
df_balanced = pd.concat(
    [
        group.sample(min(len(group), 450), random_state=42)
        for _, group in df.groupby("label")
    ],
    ignore_index=True,
)

print("\nBalanced Class Distribution:")
print(df_balanced["label"].value_counts())

# ---------- STEP 3: LABEL ENCODING ----------
# Fit the encoder on the string labels, then store the integer ids alongside
# them; the fitted encoder is reused later to turn predictions into names.
label_encoder = LabelEncoder().fit(df_balanced["label"])
df_balanced["label_encoded"] = label_encoder.transform(df_balanced["label"])

# Number of distinct encoded ids = size of the classification head.
num_classes = df_balanced["label_encoded"].nunique()
print("\nClasses:", label_encoder.classes_)

# ---------- STEP 4: TRAIN / VAL / TEST SPLIT ----------
# Stage 1: hold out 30% of the balanced rows, stratified by class id.
X_train, X_temp, y_train, y_temp = train_test_split(
    df_balanced["text"], df_balanced["label_encoded"],
    test_size=0.3, random_state=42,
    stratify=df_balanced["label_encoded"],
)

# Stage 2: split the hold-out again; 33% of it becomes the test set and the
# remainder the validation set (still stratified).
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp,
    test_size=0.33, random_state=42,
    stratify=y_temp,
)

print(f"\nTrain: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")

# ---------- STEP 5: LOAD DISTILBERT TOKENIZER ----------
# Fast tokenizer for the same pretrained checkpoint the model is built from.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# ---------- STEP 6: TOKENIZATION ----------
def encode_texts(texts, tokenizer, max_len=256):
    """Tokenize a batch of texts into padded, truncated TensorFlow tensors.

    Args:
        texts: The texts to encode. May be a pandas Series / NumPy array
            (anything exposing ``.tolist()``), any other iterable of strings
            such as a list or tuple, or a single string (treated as a batch
            of one).
        tokenizer: Callable tokenizer; it is invoked with the list of texts
            plus ``padding``, ``truncation``, ``max_length`` and
            ``return_tensors`` keyword arguments.
        max_len: Maximum sequence length passed as ``max_length``.

    Returns:
        Whatever ``tokenizer`` returns for the batch (requested with
        ``return_tensors="tf"``).
    """
    if isinstance(texts, str):
        # list("abc") would explode a lone string into characters.
        batch = [texts]
    elif hasattr(texts, "tolist"):
        batch = texts.tolist()
    else:
        batch = list(texts)

    return tokenizer(
        batch,
        padding=True,
        truncation=True,
        max_length=max_len,
        return_tensors="tf"
    )

# Tokenize the three splits with identical settings (default max_len).
train_enc, val_enc, test_enc = (
    encode_texts(split, tokenizer) for split in (X_train, X_val, X_test)
)

# ---------- STEP 7: CLASS WEIGHTS ----------
# "balanced" weighting computed from the label frequencies of the training
# split only.
class_weights = compute_class_weight(
    class_weight="balanced", classes=np.unique(y_train), y=y_train
)

# Position i in the weight array becomes the weight for class id i.
class_weight_dict = {
    position: weight for position, weight in enumerate(class_weights)
}
print("\nClass Weights:", class_weight_dict)

# ---------- STEP 8: LOAD DISTILBERT MODEL ----------
# Pretrained DistilBERT with a sequence-classification head sized to the
# number of classes; from_pt=True requests loading from the PyTorch weights.
model = TFDistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=num_classes, from_pt=True
)

# ---------- STEP 9: COMPILE MODEL ----------
# Labels are integer ids and the model emits logits, hence the sparse
# cross-entropy configured with from_logits=True.
optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5)
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    metrics=["accuracy"],
)

# ---------- STEP 10: CALLBACKS ----------
# All three callbacks monitor validation loss.

# Stop after 2 epochs without improvement and roll back to the best weights.
early_stop = EarlyStopping(
    monitor="val_loss", patience=2, restore_best_weights=True
)

# Multiply the learning rate by 0.3 after 1 stagnant epoch, floor at 1e-7.
reduce_lr = ReduceLROnPlateau(
    monitor="val_loss", factor=0.3, patience=1, min_lr=1e-7, verbose=1
)

# Weights-only checkpoint of the best epoch, written under ./checkpoints.
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

checkpoint_cb = ModelCheckpoint(
    filepath=os.path.join(checkpoint_dir, "ckpt"),
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,     # IMPORTANT
    verbose=1,
)

# ---------- STEP 11: RESUME FROM CHECKPOINT IF EXISTS ----------
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
if latest_ckpt:
    print(f"Resuming from checkpoint: {latest_ckpt}")
    model.load_weights(latest_ckpt)

    # A freshly constructed ModelCheckpoint has no memory of earlier runs, so
    # with save_best_only=True the first epoch after a resume would overwrite
    # the restored checkpoint even when it is worse. Seed the callback with
    # the restored weights' validation loss so only real improvements are
    # written to disk.
    resumed_metrics = model.evaluate(
        dict(val_enc), y_val, batch_size=16, verbose=0, return_dict=True
    )
    checkpoint_cb.best = resumed_metrics["loss"]
    print(f"Validation loss of restored weights: {checkpoint_cb.best:.4f}")

# ---------- STEP 12: TRAIN MODEL ----------
# Fine-tune for up to 5 epochs in batches of 16, validating on the held-out
# validation split after each epoch; class weights and the three callbacks
# defined earlier are applied.
history = model.fit(
    dict(train_enc),
    y_train,
    batch_size=16,
    epochs=5,
    validation_data=(dict(val_enc), y_val),
    class_weight=class_weight_dict,
    callbacks=[checkpoint_cb, early_stop, reduce_lr],
    verbose=1,
)

# ---------- STEP 13: EVALUATE ----------
# Predict logits for the test split and take the arg-max class per row.
test_preds = model.predict(dict(test_enc))
y_pred = np.argmax(test_preds.logits, axis=1)

# Pin the label set to every encoded class id so the report always has one
# row per entry of target_names. Without `labels`, scikit-learn infers the
# label set from y_test/y_pred, and the name list no longer lines up when a
# class is absent from both.
all_class_ids = np.arange(len(label_encoder.classes_))

print("\n--- Classification Report ---")
print(classification_report(
    y_test,
    y_pred,
    labels=all_class_ids,
    target_names=label_encoder.classes_,
    digits=4
))

# ---------- STEP 14: CONFUSION MATRIX ----------
# Fix the matrix to one row/column per encoded class id. Left to infer the
# labels, confusion_matrix only covers classes seen in y_test/y_pred, and the
# tick labels below (which list every class) would then mislabel the axes.
cm_class_ids = np.arange(len(label_encoder.classes_))
cm = confusion_matrix(y_test, y_pred, labels=cm_class_ids)

plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=label_encoder.classes_,
    yticklabels=label_encoder.classes_
)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()

# ---------- STEP 15: TRAINING CURVES ----------
# Two side-by-side panels (loss, then accuracy), each plotting the training
# and validation series recorded in the fit history.
plt.figure(figsize=(12, 4))

for panel, (metric, title) in enumerate(
    (("loss", "Loss"), ("accuracy", "Accuracy")), start=1
):
    plt.subplot(1, 2, panel)
    plt.plot(history.history[metric], label=f"Train {title}")
    plt.plot(history.history[f"val_{metric}"], label=f"Val {title}")
    plt.legend()
    plt.title(title)

plt.show()

# ---------- STEP 16: CONFIDENCE-AWARE PREDICTION ----------
def predict_email(text, threshold=0.7):
    """Classify one text and flag the prediction when confidence is low.

    Uses the module-level ``tokenizer``, ``model`` and ``label_encoder``.

    Args:
        text: Input passed to the tokenizer (truncated to 256 tokens).
        threshold: Softmax probability below which a warning is attached.

    Returns:
        dict with ``"predicted_class"`` (entry of ``label_encoder.classes_``),
        ``"confidence"`` (float probability of that class) and, only when
        the confidence is below ``threshold``, a ``"warning"`` message.
    """
    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="tf",
    )

    # Softmax over the class axis turns the logits into probabilities.
    class_probs = tf.nn.softmax(model(encoded).logits, axis=1)
    best = tf.argmax(class_probs, axis=1).numpy()[0]
    score = float(class_probs[0][best])

    outcome = {
        "predicted_class": label_encoder.classes_[best],
        "confidence": score,
    }
    if score < threshold:
        outcome["warning"] = "Low confidence prediction"
    return outcome

# Smoke-test the prediction helper on a sample message.
print(predict_email("Dear HR, I am applying for the Data Scientist position."))

# ---------- STEP 17: SAVE FINAL ARTIFACTS ----------
model.save_weights("ats_distilbert_weights")      # ~250 MB
tokenizer.save_pretrained("ats_tokenizer")        # <5 MB

# The weights alone only yield class indices. Persist the index -> label
# order (one label per line, line i = class id i) so predictions can be
# decoded later without refitting the encoder on the training data.
# NOTE: assumes label names contain no newline characters.
with open("ats_label_classes.txt", "w", encoding="utf-8") as labels_file:
    labels_file.write("\n".join(str(name) for name in label_encoder.classes_))

# ===================== END OF COMPLETE CODE =====================
