In [2]:
import pandas as pd

# Name of the free-text column and of the six binary target columns.
TEXT = "comment_text"
LABELS = [
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
]

# Training frame; it is indexed with TEXT and LABELS by the split cell below.
df = pd.read_parquet("data/processed/train.parquet")

In [4]:
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
# Multilabel-aware stratified K-fold with a fixed seed so the split is reproducible.
mskf = MultilabelStratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# Only the first of the 10 folds is consumed: one fold is held out for validation
# and the remaining nine are used for training.
train_idx, val_idx = next(mskf.split(df[TEXT], df[LABELS]))
X_train, X_val = df[TEXT].iloc[train_idx], df[TEXT].iloc[val_idx]
y_train, y_val = df[LABELS].iloc[train_idx], df[LABELS].iloc[val_idx]

# This import was indented under top-level code, which raises
# "IndentationError: unexpected indent" when the cell runs; it must sit at column 0.
# NOTE(review): `issparse` is not used in the cells visible here - confirm it is still needed.
from scipy.sparse import issparse


In [5]:
from sklearn.feature_extraction.text import TfidfVectorizer

# TF-IDF settings collected in one place so they are easy to read and tweak.
tfidf_params = dict(
    strip_accents="unicode",
    lowercase=True,
    stop_words="english",
    max_features=200_000,
    ngram_range=(1, 2),
    sublinear_tf=True,
)
tfidf = TfidfVectorizer(**tfidf_params)

# The vocabulary and IDF weights are learned from the training split only;
# the validation split is projected with the already-fitted vectorizer.
X_train_tfidf = tfidf.fit_transform(X_train)
X_val_tfidf = tfidf.transform(X_val)


In [6]:
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# One independent binary logistic regression per label (one-vs-rest).
# n_jobs belongs on the OneVsRestClassifier wrapper: that is where the per-label fits
# can run in parallel. On the inner LogisticRegression it had no effect, because each
# inner estimator only ever solves a single binary problem.
clf = OneVsRestClassifier(
    LogisticRegression(
        max_iter=400,
        C=4,
        class_weight="balanced",
    ),
    n_jobs=-1,
).fit(X_train_tfidf, y_train)


In [10]:
import joblib, os, json
# Persist the fitted vectorizer and classifier together so they are always loaded as a pair.
model_path = "models/tfidf_logreg.pkl"
model_bundle = {"tfidf": tfidf, "clf": clf}

os.makedirs("models", exist_ok=True)
# Kept as the last expression so the cell still displays the list of written files.
joblib.dump(model_bundle, model_path)

['models/tfidf_logreg.pkl']

In [None]:
import pandas as pd
import numpy as np
import joblib
from sklearn.metrics import (
    confusion_matrix,
    precision_recall_fscore_support,
    roc_auc_score,
    accuracy_score,
    classification_report,
)

# --- Load model ---
# joblib.load unpickles arbitrary objects; only load bundles produced by this notebook.
model_bundle = joblib.load("models/tfidf_logreg.pkl")
tfidf = model_bundle["tfidf"]
clf = model_bundle["clf"]

# --- Load data ---
LABELS = ["toxic","severe_toxic","obscene","threat","insult","identity_hate"]
truth = pd.read_csv("test_samples_with_labels.csv")

# --- Predict ---
X = truth["comment_text"].astype(str).tolist()
y_true = truth[LABELS].values

# Vectorize once and reuse the matrix for both calls below
# (the text was previously transformed twice).
X_test_tfidf = tfidf.transform(X)

# Hard 0/1 predictions and raw decision scores; the scores are used for ROC-AUC.
y_pred_bin = clf.predict(X_test_tfidf)
y_pred_prob = clf.decision_function(X_test_tfidf)

# --- Mask for missing labels ---
# A label value of -1 marks a missing annotation; those cells are excluded from the metrics.
mask_matrix = (y_true != -1)

# --- Evaluation ---
def get_valid(label_idx):
    """Return (y_true, y_pred_bin, y_pred_prob) for one label column.

    Only rows whose ground-truth entry in column ``label_idx`` is not -1
    (i.e. where ``mask_matrix`` is True) are kept. Reads the module-level
    arrays ``mask_matrix``, ``y_true``, ``y_pred_bin`` and ``y_pred_prob``.
    """
    keep_rows = mask_matrix[:, label_idx]
    return tuple(
        column_source[keep_rows, label_idx]
        for column_source in (y_true, y_pred_bin, y_pred_prob)
    )

print("="*40)
print("Per-label confusion matrices:")
for i, label in enumerate(LABELS):
    y_true_lbl, y_pred_lbl, _ = get_valid(i)
    # labels=[0, 1] pins the matrix to 2x2. Without it, a label whose valid rows contain
    # only one class yields a 1x1 matrix and the fixed row/column names below raise.
    cm = confusion_matrix(y_true_lbl, y_pred_lbl, labels=[0, 1])
    print(f"\n{label}\n", pd.DataFrame(cm, index=["True 0", "True 1"], columns=["Pred 0", "Pred 1"]))

print("\nPer-label Precision, Recall, F1, Accuracy, ROC-AUC:")
# One row per label: accuracy, positive-class precision/recall/F1, and ROC-AUC
# computed from the decision scores, all restricted to rows with a non-missing label.
stats = []
for i, label in enumerate(LABELS):
    y_true_lbl, y_pred_lbl, y_prob_lbl = get_valid(i)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true_lbl, y_pred_lbl, average='binary', zero_division=0
    )
    acc = accuracy_score(y_true_lbl, y_pred_lbl)
    try:
        auc = roc_auc_score(y_true_lbl, y_prob_lbl)
    except ValueError:
        # ROC-AUC is undefined when only one class is present among the valid rows.
        # Only that case is reported as NaN; any other error is allowed to surface
        # instead of being silently swallowed.
        auc = np.nan
    stats.append([label, acc, prec, rec, f1, auc])
df_stats = pd.DataFrame(stats, columns=["Label", "Accuracy", "Precision", "Recall", "F1", "ROC-AUC"]).set_index("Label")
print(df_stats.round(3))

print("\nMacro/Micro Precision, Recall, F1, ROC-AUC:")
# Multilabel averaging, computed only over cells with a non-missing label:
#   * macro = unweighted mean of the per-label (positive-class) scores;
#   * micro = positive-class TP/FP/FN pooled over every valid (row, label) cell.
# The previous version flattened everything into one binary vector and asked sklearn
# for average="macro"/"micro" on it. On a binary vector that averages over the classes
# 0 and 1 rather than over the labels, so "macro" mixed in the negative class and
# "micro" was simply accuracy. Its "macro" ROC-AUC was a single pooled AUC across the
# scores of six separately fitted classifiers.
label_prf, label_auc = [], []
flat_true, flat_pred = [], []
for i in range(len(LABELS)):
    y_true_lbl, y_pred_lbl, y_prob_lbl = get_valid(i)
    flat_true.extend(y_true_lbl)
    flat_pred.extend(y_pred_lbl)

    p_lbl, r_lbl, f_lbl, _ = precision_recall_fscore_support(
        y_true_lbl, y_pred_lbl, average="binary", zero_division=0
    )
    label_prf.append((p_lbl, r_lbl, f_lbl))

    try:
        label_auc.append(roc_auc_score(y_true_lbl, y_prob_lbl))
    except ValueError:
        # AUC is undefined when a label has a single class among its valid rows.
        label_auc.append(np.nan)

prec_macro, rec_macro, f1_macro = np.mean(label_prf, axis=0)

# Average the AUC over the labels where it is defined; NaN if it is defined for none.
label_auc = np.asarray(label_auc, dtype=float)
roc_auc_macro = np.nan if np.isnan(label_auc).all() else np.nanmean(label_auc)

# average="binary" on the pooled cells scores the positive class only, which is the
# multilabel micro-average.
prec_micro, rec_micro, f1_micro, _ = precision_recall_fscore_support(
    flat_true, flat_pred, average="binary", zero_division=0
)

print(f"Macro: Precision={prec_macro:.3f}, Recall={rec_macro:.3f}, F1={f1_macro:.3f}, ROC-AUC={roc_auc_macro:.3f}")
print(f"Micro: Precision={prec_micro:.3f}, Recall={rec_micro:.3f}, F1={f1_micro:.3f}")

# Positive-class row of sklearn's classification report, one row per label.
report_dict = {}
for i, label in enumerate(LABELS):
    y_true_lbl, y_pred_lbl, _ = get_valid(i)
    # labels=[0, 1] guarantees the report always contains a '1' entry. Without it the
    # lookup below raises KeyError whenever class 1 is absent from both the truth and
    # the predictions for this label.
    report = classification_report(
        y_true_lbl, y_pred_lbl, labels=[0, 1], output_dict=True, zero_division=0
    )
    report_dict[label] = report['1']
df_report = pd.DataFrame(report_dict).T[['precision', 'recall', 'f1-score', 'support']]
print("\nScikit-learn per-label classification report:")
print(df_report.round(3))

# Per-row accuracy over the labels that are present (not -1) in that row;
# rows with no valid label at all are left out.
sample_acc = [
    (truth_row[row_mask] == pred_row[row_mask]).mean()
    for truth_row, pred_row, row_mask in zip(y_true, y_pred_bin, mask_matrix)
    if row_mask.sum() != 0
]
print(f"\nMean samplewise accuracy (fraction of correct labels per row): {np.mean(sample_acc):.3f}")

# Exact-match accuracy, restricted to rows where every label is present.
all_valid_mask = mask_matrix.all(axis=1)
if all_valid_mask.any():
    exact_match = (y_true[all_valid_mask] == y_pred_bin[all_valid_mask]).all(axis=1)
    subset_acc = np.mean(exact_match)
    print(f"Subset accuracy (exact match on all labels, only full rows): {subset_acc:.3f}")

print("="*40)


Per-label confusion matrices:

toxic
         Pred 0  Pred 1
True 0     326      37
True 1       2      35

severe_toxic
         Pred 0  Pred 1
True 0     389       7
True 1       1       3

obscene
         Pred 0  Pred 1
True 0     361      16
True 1       4      19

threat
         Pred 0  Pred 1
True 0     395       3
True 1       0       2

insult
         Pred 0  Pred 1
True 0     361      17
True 1       6      16

identity_hate
         Pred 0  Pred 1
True 0     385      11
True 1       1       3

Per-label Precision, Recall, F1, Accuracy, ROC-AUC:
               Accuracy  Precision  Recall     F1  ROC-AUC
Label                                                     
toxic             0.902      0.486   0.946  0.642    0.975
severe_toxic      0.980      0.300   0.750  0.429    0.989
obscene           0.950      0.543   0.826  0.655    0.970
threat            0.992      0.400   1.000  0.571    1.000
insult            0.942      0.485   0.727  0.582    0.967
identity_hate     0.970