In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score
import shap

# Single seeded generator shared by every random draw in this script, so the
# generated data depends on the order in which the statements below run.
rng = np.random.default_rng(7)

# -------------------------
# 1) Data generator: "real" signal + talisman token (spurious)
# -------------------------
# Vocabularies carrying the label signal (positive / negative sentiment).
pos_words = ["great", "amazing", "love", "excellent", "wonderful", "pleasant", "fantastic"]
neg_words = ["bad", "terrible", "hate", "awful", "horrible", "nasty", "poor"]
# Filler words added to every sentence regardless of its label.
neutral = ["movie", "product", "service", "today", "really", "quite", "very", "actually", "just"]

SPUR = "zxq"   # talisman token (meaningless, but the model will love it)

def make_sentence(label, spur_prob):
    """Generate one synthetic review as a space-separated string.

    The sentence mixes 2-4 sentiment words that agree with ``label`` (drawn
    with replacement from ``pos_words`` when ``label == 1``, otherwise from
    ``neg_words``) with 3-6 words from ``neutral``, shuffled together. With
    probability ``spur_prob`` the ``SPUR`` token is then inserted at a
    uniformly chosen position.

    All randomness is drawn from the module-level ``rng``.
    """
    # "Real" signal: a handful of words consistent with the label.
    vocabulary = pos_words if label == 1 else neg_words
    n_core = rng.integers(2, 5)
    tokens = rng.choice(vocabulary, size=n_core, replace=True).tolist()

    # Label-independent padding, then mix everything up.
    n_fluff = rng.integers(3, 7)
    tokens += rng.choice(neutral, size=n_fluff, replace=True).tolist()
    rng.shuffle(tokens)

    # Spurious correlate: the caller passes a different probability per
    # class, so the token shows up more often for one of them.
    if rng.random() < spur_prob:
        # Stamp it in anywhere, including the very start or the very end.
        position = rng.integers(0, len(tokens) + 1)
        tokens.insert(position, SPUR)

    return " ".join(tokens)

def build_dataset(n, spur_for_pos=0.9, spur_for_neg=0.1):
    """Build ``n`` labelled synthetic sentences.

    Labels are drawn uniformly from {0, 1} (0 = negative, 1 = positive).
    Each sentence is produced by ``make_sentence`` with the spurious-token
    probability ``spur_for_pos`` for positive labels and ``spur_for_neg``
    for negative ones.

    Returns a ``(texts, labels)`` pair: a NumPy array of sentence strings
    and the integer label array, aligned by index.
    """
    labels = rng.integers(0, 2, size=n)  # 0 = negative, 1 = positive
    sentences = [
        make_sentence(int(lab), spur_for_pos if lab == 1 else spur_for_neg)
        for lab in labels
    ]
    return np.array(sentences), labels

# Train env: the SPUR token is strongly correlated with the positive class
X, y = build_dataset(4000, spur_for_pos=0.95, spur_for_neg=0.05)
# Hold out 25% as validation; it comes from the same (spurious) distribution.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=42)

# -------------------------
# 2) Model: simple and brutally effective (LogReg + TF-IDF)
# -------------------------
# Unigrams and bigrams, so pairs such as "<word> zxq" become features too.
clf = make_pipeline(
    TfidfVectorizer(ngram_range=(1,2), min_df=2),
    LogisticRegression(max_iter=2000)
)

clf.fit(X_train, y_train)
pred_val = clf.predict(X_val)
print("VAL accuracy:", accuracy_score(y_val, pred_val))
print(classification_report(y_val, pred_val, digits=3))

# -------------------------
# 3) XAI: SHAP for the linear model in TF-IDF space
# -------------------------
# make_pipeline names each step after its lowercased class name.
vectorizer = clf.named_steps["tfidfvectorizer"]
linear_model = clf.named_steps["logisticregression"]

X_val_vec = vectorizer.transform(X_val)

# The vectorized training set serves as the background data for the explainer.
explainer = shap.LinearExplainer(linear_model, vectorizer.transform(X_train), feature_perturbation="interventional")
shap_vals = explainer(X_val_vec)

# As an array, so it can be fancy-indexed with the argsort results below.
feature_names = np.array(vectorizer.get_feature_names_out())

def top_features_for_instance(i, k=12):
    """Return the ``k`` strongest SHAP attributions for validation row ``i``.

    Reads the module-level ``shap_vals`` and ``feature_names``. The result
    is a list of ``(feature_name, shap_value)`` pairs ordered by decreasing
    absolute SHAP value; the values keep their sign.
    """
    row = shap_vals.values[i]
    # Ascending sort by magnitude, reversed so the largest come first.
    by_magnitude = np.argsort(np.abs(row))[::-1]
    top = by_magnitude[:k]
    return [(name, value) for name, value in zip(feature_names[top], row[top])]

# show a few examples where SPUR occurs and acts as the "explanation"
# (first three validation rows whose text contains the token)
hits = [i for i, txt in enumerate(X_val) if SPUR in txt][:3]
for i in hits:
    print("\n--- Example text:", X_val[i])
    print("Pred:", int(clf.predict([X_val[i]])[0]), "Prob(pos):", float(clf.predict_proba([X_val[i]])[0,1]))
    print("Top SHAP features:")
    for f, v in top_features_for_instance(i, k=10):
        print(f"  {f:20s}  {v:+.4f}")

# -------------------------
# 4) Extreme case: "the world changes" (we flip the token's correlation)
# -------------------------
# Test env: SPUR now appears more often with negatives (flip!)
X_test, y_test = build_dataset(1500, spur_for_pos=0.05, spur_for_neg=0.95)
pred_test = clf.predict(X_test)
print("\nTEST (flipped spurious) accuracy:", accuracy_score(y_test, pred_test))

# find cases where the model is CONFIDENT and WRONG
probs = clf.predict_proba(X_test)[:,1]
wrong = np.where(pred_test != y_test)[0]
# Rank the misclassified rows by distance of P(pos) from 0.5, largest first.
conf_wrong = wrong[np.argsort(np.abs(probs[wrong] - 0.5))[::-1]][:5]  # most confident mistakes

# SHAP for the test set (reuses the explainer built on the training data)
X_test_vec = vectorizer.transform(X_test)
shap_test = explainer(X_test_vec)

for i in conf_wrong:
    print("\n### CONFIDENT WRONG")
    print("Text:", X_test[i])
    print("True:", int(y_test[i]), "Pred:", int(pred_test[i]), "Prob(pos):", float(probs[i]))
    vals = shap_test.values[i]
    # Ten features with the largest absolute SHAP value for this row.
    idx = np.argsort(np.abs(vals))[::-1][:10]
    print("Top SHAP features driving decision:")
    for f, v in zip(feature_names[idx], vals[idx]):
        print(f"  {f:20s}  {v:+.4f}")


VAL accuracy: 1.0
              precision    recall  f1-score   support

           0      1.000     1.000     1.000       487
           1      1.000     1.000     1.000       513

    accuracy                          1.000      1000
   macro avg      1.000     1.000     1.000      1000
weighted avg      1.000     1.000     1.000      1000






--- Example text: love today zxq amazing excellent actually product actually really product amazing
Pred: 1 Prob(pos): 0.9706152991335267
Top SHAP features:
  amazing               +1.0556
  excellent             +0.4487
  love                  +0.4469
  today zxq             +0.2489
  fantastic             -0.2385
  zxq amazing           +0.2125
  hate                  +0.1941
  zxq                   +0.1739
  pleasant              -0.1739
  wonderful             -0.1624

--- Example text: just amazing movie really amazing wonderful great today zxq service
Pred: 1 Prob(pos): 0.9819973940129617
Top SHAP features:
  amazing               +1.1023
  great                 +0.4852
  wonderful             +0.4712
  today zxq             +0.2587
  fantastic             -0.2385
  zxq service           +0.2065
  hate                  +0.1941
  zxq                   +0.1929
  pleasant              -0.1739
  love                  -0.1705

--- Example text: product zxq just very today wonderful e