In [None]:
# %%
import torch
torch.cuda.is_available()


True

In [None]:
# %%
# (Colab only) upload your XLSX
#from google.colab import files
#uploaded = files.upload()


In [None]:
# %%
!pip -q install transformers accelerate evaluate openpyxl scikit-learn pandas numpy torch


In [None]:
# %%
# Ensures a recent transformers release is installed; comment this line out if the environment already has one
!pip -q install "transformers>=4.38"

import re
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE


'cuda'

In [None]:
# %%
PATH_XLSX = "evrensel_isci_sendika_2024_dec2025_clean_fin_uncorrupted_real.xlsx"  # <-- change this
df = pd.read_excel(PATH_XLSX)

required_cols = ["EVENT_RELEVANT", "EVENT_ID", "title", "content","date"]
missing = [c for c in required_cols if c not in df.columns]
if missing:
    raise ValueError(f"Missing columns in XLSX: {missing}")

df.shape, df.columns.tolist()


((9186, 15),
 ['title',
  'date',
  'link',
  'content',
  'EVENT_RELEVANT',
  'EVENT_ID',
  'Unnamed: 6',
  'error',
  'Unnamed: 8',
  'Unnamed: 9',
  'Unnamed: 10',
  'Unnamed: 11',
  'Unnamed: 12',
  650,
  160])

In [None]:
# %%
def normalize_text(x):
    """Return ``x`` as a tidy single-line string.

    Missing values (anything ``pd.isna`` flags) become ``""``. Otherwise the
    value is stringified, non-breaking spaces are turned into plain spaces,
    every whitespace run is collapsed to one space and the ends are trimmed.
    No characters other than whitespace are altered.
    """
    if pd.isna(x):
        return ""
    without_nbsp = str(x).replace("\u00A0", " ")
    collapsed = re.sub(r"\s+", " ", without_nbsp)
    return collapsed.strip()

df["title"] = df["title"].apply(normalize_text)
df["content"] = df["content"].apply(normalize_text)

# final text fed into the model
df["text"] = (df["title"].astype(str) + "\n\n" + df["content"].astype(str)).str.strip()
df[["title","content","text"]].head(2)


Unnamed: 0,title,content,text
0,Bartın'da Hema'ya ait maden ocağında vagonları...,Bartın'ın Amasra ilçesindeki Hema Enerji şirke...,Bartın'da Hema'ya ait maden ocağında vagonları...
1,Bu soygun düzeni değişmeli,Pendik Marmara Eğitim ve Araştırma Hastanesind...,Bu soygun düzeni değişmeli\n\nPendik Marmara E...


In [None]:
# %%
# ----------------------------
# Cell 4 (fixed) — Define labeled rows (accept 0/1 as floats OR strings)
# ----------------------------
import numpy as np
import pandas as pd

def normalize_label(x):
    """Map a raw label cell to ``0``, ``1`` or ``np.nan``.

    * Missing input -> ``np.nan``.
    * Numeric input (Python or NumPy int/float) -> ``0`` or ``1`` when it
      equals exactly 0 or 1, otherwise ``np.nan``.
    * Anything else is stringified, trimmed and lower-cased, then matched
      against a small set of spellings for "no" (-> 0) and "yes" (-> 1);
      unrecognized text -> ``np.nan``.
    """
    if pd.isna(x):
        return np.nan

    # Numeric cells (e.g. 0.0 / 1.0 read from Excel)
    if isinstance(x, (int, np.integer, float, np.floating)):
        for target in (0, 1):
            if x == target:
                return target
        return np.nan

    # Textual cells
    token = str(x).strip().lower()
    spellings_by_value = (
        (0, {"0", "0.0", "no", "n", "false"}),
        (1, {"1", "1.0", "yes", "y", "true"}),
    )
    for value, spellings in spellings_by_value:
        if token in spellings:
            return value
    return np.nan

df["LABEL_CLEAN"] = df["EVENT_RELEVANT"].apply(normalize_label)
labeled_mask = df["LABEL_CLEAN"].notna()

print("Total rows:", len(df))
print("Labeled rows:", int(labeled_mask.sum()))
df.loc[labeled_mask, ["EVENT_RELEVANT","LABEL_CLEAN"]].head(10)


Total rows: 9186
Labeled rows: 738


Unnamed: 0,EVENT_RELEVANT,LABEL_CLEAN
0,0.0,0.0
1,0.0,0.0
2,0.0,0.0
3,0.0,0.0
4,0.0,0.0
5,0.0,0.0
6,0.0,0.0
7,0.0,0.0
8,0.0,0.0
9,0.0,0.0


In [None]:
# %%
# ----------------------------
# Cell 5 — Train/validation split (using LABEL_CLEAN from Cell 4)
# ----------------------------

# Keep only labeled rows (LABEL_CLEAN is 0/1, NaN otherwise)
df_labeled = df.loc[labeled_mask].copy()

# This is what the Trainer will learn on
df_labeled["label"] = df_labeled["LABEL_CLEAN"].astype(int)

# Stratified split so class balance is preserved in train/val
train_df, val_df = train_test_split(
    df_labeled,
    test_size=0.2,
    random_state=42,
    stratify=df_labeled["label"]
)

print("Train size:", len(train_df), "Val size:", len(val_df))
train_df["label"].value_counts(normalize=True), val_df["label"].value_counts(normalize=True)


Train size: 590 Val size: 148


(label
 0    0.783051
 1    0.216949
 Name: proportion, dtype: float64,
 label
 0    0.783784
 1    0.216216
 Name: proportion, dtype: float64)

In [None]:
# %%
# Inspect what EVENT_RELEVANT really looks like
s = df["EVENT_RELEVANT"]

print("dtype:", s.dtype)
print("non-null count:", s.notna().sum())

# show a sample of unique raw values (as-is)
u = s.dropna().unique()
print("unique values sample (up to 50):", u[:50])

# show stringified + stripped sample too
u_str = pd.Series(u).astype(str).str.strip()
print("stringified sample (up to 50):", u_str.head(50).tolist())


dtype: float64
non-null count: 738
unique values sample (up to 50): [0. 1.]
stringified sample (up to 50): ['0.0', '1.0']


In [None]:
# %%
MODEL_NAME = "dbmdz/bert-base-turkish-cased"  # BERTurk
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

MAX_LEN = 384  # 512 is allowed but slower

class TextClsDataset(Dataset):
    """Map-style dataset that tokenizes one text per lookup.

    ``texts`` and (optionally) ``labels`` are copied into lists up front.
    Each item is a dict of tensors built from whatever the tokenizer returns
    for that text (truncated to ``max_len`` tokens, not padded); when labels
    were supplied the dict also carries an integer ``"labels"`` tensor.
    """

    def __init__(self, texts, labels=None, tokenizer=None, max_len=384):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.texts = list(texts)
        self.labels = list(labels) if labels is not None else None

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, i):
        # Padding is left off here; batching code is expected to pad.
        encoding = self.tokenizer(
            self.texts[i],
            truncation=True,
            max_length=self.max_len,
            padding=False
        )
        sample = {field: torch.tensor(values) for field, values in encoding.items()}
        if self.labels is None:
            return sample
        sample["labels"] = torch.tensor(int(self.labels[i]))
        return sample

train_ds = TextClsDataset(train_df["text"], train_df["label"], tokenizer, MAX_LEN)
val_ds   = TextClsDataset(val_df["text"],   val_df["label"],   tokenizer, MAX_LEN)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# %%
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2).to(DEVICE)

pos = int(train_df["label"].sum())
neg = int(len(train_df) - pos)

# More weight to positive class if positives are rare
class_weights = torch.tensor([1.0, (neg / max(pos, 1))], dtype=torch.float, device=DEVICE)
class_weights


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dbmdz/bert-base-turkish-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tensor([1.0000, 3.6094], device='cuda:0')

In [None]:
# %%
# ----------------------------
# Cell 8 (updated) — Trainer setup + fine-tune (version compatible)
# ----------------------------
import inspect

def compute_metrics(eval_pred):
    """Compute ROC-AUC of the class-1 probability for an evaluation pass.

    ``eval_pred`` unpacks into ``(logits, labels)``. Logits are soft-maxed
    over the last axis and column 1 is used as the score. When ``labels``
    holds a single distinct value the metric is reported as NaN instead of
    calling ``roc_auc_score``.
    """
    logits, labels = eval_pred
    positive_scores = torch.softmax(torch.tensor(logits), dim=-1).numpy()[:, 1]
    if len(np.unique(labels)) > 1:
        auc = float(roc_auc_score(labels, positive_scores))
    else:
        auc = float("nan")
    return {"roc_auc": auc}

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Weighted loss
def custom_loss_fn(model, inputs, return_outputs=False):
    """Class-weighted cross-entropy loss for one batch.

    Args:
        model: callable invoked as ``model(**kwargs)``; its result must expose
            ``.logits``.
        inputs: mapping of keyword arguments for ``model`` plus a ``"labels"``
            entry. The mapping is NOT modified.
        return_outputs: when true, return ``(loss, outputs)`` instead of just
            the loss.

    Returns:
        The loss tensor, or ``(loss, model_outputs)`` when ``return_outputs``
        is true.

    The per-class weights come from the module-level ``class_weights`` tensor.
    """
    labels = inputs["labels"]
    # Forward everything except the labels, without popping from the caller's
    # dict: the same batch object may still be used after this call.
    model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
    outputs = model(**model_inputs)
    loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights)
    loss = loss_fct(outputs.logits, labels)
    return (loss, outputs) if return_outputs else loss

# Trainer subclass that swaps the default loss for the class-weighted one in custom_loss_fn
class WeightedTrainer(Trainer):
    """``Trainer`` whose ``compute_loss`` delegates to ``custom_loss_fn``."""
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs any additional keyword arguments the caller passes; they are ignored here.
        return custom_loss_fn(model, inputs, return_outputs=return_outputs)

# Keyword names differ across transformers releases:
#   TrainingArguments: `evaluation_strategy` (older) vs `eval_strategy` (newer)
#   Trainer:           `tokenizer` (older, now deprecated) vs `processing_class` (newer)
# Pick whichever the installed version accepts so the cell neither crashes on
# older releases nor emits a deprecation warning on newer ones.
_ta_params = inspect.signature(TrainingArguments.__init__).parameters
_eval_kw = "eval_strategy" if "eval_strategy" in _ta_params else "evaluation_strategy"

_trainer_params = inspect.signature(Trainer.__init__).parameters
_tokenizer_kw = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

args = TrainingArguments(
    output_dir="berturk_event_detect",
    learning_rate=2e-5,
    per_device_train_batch_size=8 if DEVICE=="cuda" else 4,
    per_device_eval_batch_size=16 if DEVICE=="cuda" else 8,
    num_train_epochs=3,
    weight_decay=0.01,
    save_strategy="epoch",          # must match the evaluation strategy for load_best_model_at_end
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="roc_auc",
    greater_is_better=True,
    fp16=(DEVICE=="cuda"),
    report_to=[],
    **{_eval_kw: "epoch"}
)

trainer = WeightedTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_tokenizer_kw: tokenizer}
)

trainer.train()


  trainer = WeightedTrainer(


Epoch,Training Loss,Validation Loss,Roc Auc
1,0.713,0.674348,0.856681
2,0.6477,0.344386,0.939925
3,0.2563,0.312034,0.939655


TrainOutput(global_step=222, training_loss=0.48938286411869636, metrics={'train_runtime': 23.8399, 'train_samples_per_second': 74.245, 'train_steps_per_second': 9.312, 'total_flos': 349279925990400.0, 'train_loss': 0.48938286411869636, 'epoch': 3.0})

In [None]:
# %%
val_out = trainer.predict(val_ds)
val_logits = val_out.predictions
val_labels = val_out.label_ids

val_probs = torch.softmax(torch.tensor(val_logits), dim=-1).numpy()[:, 1]
val_preds = (val_probs >= 0.5).astype(int)

print("ROC-AUC:", roc_auc_score(val_labels, val_probs) if len(np.unique(val_labels)) > 1 else "NA")
print(classification_report(val_labels, val_preds, digits=3))


ROC-AUC: 0.9399245689655173
              precision    recall  f1-score   support

           0      0.972     0.914     0.942       116
           1      0.744     0.906     0.817        32

    accuracy                          0.912       148
   macro avg      0.858     0.910     0.880       148
weighted avg      0.923     0.912     0.915       148



In [None]:
# %%
# ---- Run model on full df (prediction) ----
full_ds = TextClsDataset(df["text"], labels=None, tokenizer=tokenizer, max_len=MAX_LEN)
full_out = trainer.predict(full_ds)
full_logits = full_out.predictions
full_probs = torch.softmax(torch.tensor(full_logits), dim=-1).numpy()[:, 1]

import numpy as np
from sklearn.metrics import precision_recall_curve

def pick_threshold_min_recall(y_true, y_prob, min_recall=0.85):
    """Pick the decision threshold with the best precision subject to a recall floor.

    Args:
        y_true: binary ground-truth labels.
        y_prob: predicted scores for the positive class.
        min_recall: minimum recall the chosen threshold must achieve.

    Returns:
        float threshold. Among thresholds whose recall is >= ``min_recall`` the
        one with the highest precision is returned (first one on ties). If no
        threshold reaches the floor, the threshold with the highest recall is
        returned instead.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)

    # precision[i] / recall[i] describe predictions with score >= thresholds[i].
    # The arrays carry one extra trailing point (precision=1, recall=0) that has
    # no threshold, so drop the LAST element to align them with `thresholds`.
    prec_at_t = precision[:-1]
    rec_at_t = recall[:-1]

    eligible = np.flatnonzero(rec_at_t >= min_recall)
    if eligible.size > 0:
        best = eligible[np.argmax(prec_at_t[eligible])]
        return float(thresholds[best])

    # fallback if target recall unattainable: maximize recall (usually a very low threshold)
    return float(thresholds[np.argmax(rec_at_t)])

# --- use on your validation predictions ---
# Single source of truth for the recall floor, so the printed label cannot
# drift away from the value actually passed to the selector.
MIN_RECALL = 0.80
BEST_T = pick_threshold_min_recall(val_labels, val_probs, min_recall=MIN_RECALL)
print(f"Chosen threshold (min_recall={MIN_RECALL:.2f}):", BEST_T)

df["EVENT_PRED"] = (full_probs >= BEST_T).astype(int)
df["EVENT_PRED"].value_counts()


Chosen threshold (min_recall=0.90): 0.5911701917648315


Unnamed: 0_level_0,count
EVENT_PRED,Unnamed: 1_level_1
0,7885
1,1301


In [None]:
# %%
# If needed
!pip -q install spacy


In [None]:
# %%
# ----------------------------
# Cell S2 (updated) — Parse Turkish publication date into PUB_DATE
# ----------------------------

# Turkish month names (with ASCII-folded spellings) -> month number
MONTHS_TR = {
    "ocak":1, "şubat":2, "subat":2, "mart":3, "nisan":4, "mayıs":5, "mayis":5,
    "haziran":6, "temmuz":7, "ağustos":8, "agustos":8, "eylül":9, "eylul":9,
    "ekim":10, "kasım":11, "kasim":11, "aralık":12, "aralik":12
}

def parse_tr_pub_date(x):
    """Parse a Turkish publication-date cell into a midnight ``pd.Timestamp``.

    Accepts either a string such as ``"4 aralık 2025 10:25 Güncelleme: 10:31"``
    (the first ``<day> <month name> <year>`` occurrence wins; time of day is
    discarded) or a date/datetime object (its calendar date is used).

    Returns ``pd.NaT`` for missing input, for text with no recognizable date,
    for unknown month names, and for impossible dates such as 31 February.
    """
    import datetime as _dt

    if pd.isna(x):
        return pd.NaT

    # Cells that were already parsed as dates (e.g. by the Excel reader)
    if isinstance(x, _dt.date):
        try:
            return pd.Timestamp(year=x.year, month=x.month, day=x.day)
        except (ValueError, OverflowError):
            return pd.NaT

    # str.lower() turns "İ" into "i" + U+0307 (combining dot); remove the dot so
    # upper-case input like "NİSAN" still matches the month table.
    s = str(x).lower().replace("\u0307", "").strip()
    s = re.sub(r"\s+", " ", s)

    # common formats:
    # "4 aralık 2025 10:25 Güncelleme: 10:31"
    m = re.search(r"(\d{1,2})\s+([a-zçğıöşü]+)\s+(\d{4})", s)
    if not m:
        return pd.NaT
    mon = MONTHS_TR.get(m.group(2))
    if mon is None:
        return pd.NaT
    try:
        return pd.Timestamp(year=int(m.group(3)), month=mon, day=int(m.group(1)))
    except (ValueError, OverflowError):
        # impossible calendar date or out-of-range year
        return pd.NaT

df["PUB_DATE"] = df["date"].apply(parse_tr_pub_date)
df[["date","PUB_DATE"]].head(10)


Unnamed: 0,date,PUB_DATE
0,2 Ocak 2024 11:02 — — Güncelleme: 10:13,2024-01-02
1,2 Ocak 2024 04:30,2024-01-02
2,1 Ocak 2024 10:36,2024-01-01
3,1 Ocak 2024 03:00,2024-01-01
4,31 Aralık 2023 23:07,2023-12-31
5,31 Aralık 2023 15:18,2023-12-31
6,31 Aralık 2023 06:34 — — Güncelleme: 1 Ocak 20...,2023-12-31
7,31 Aralık 2023 05:34 — — Güncelleme: 11:21,2023-12-31
8,30 Aralık 2023 17:58,2023-12-30
9,30 Aralık 2023 17:40,2023-12-30


In [None]:
# %%
# ----------------------------
# (kept) Phrase mining config (you already had this; kept for minimal change)
# We'll still compute ORG_KEYS_FILTERED (baseline org anchor),
# but employer anchoring will be rebuilt later (EMPLOYER_KEYS).
# ----------------------------
from collections import Counter

CONTENT_CHARS = 800  # only scan first N chars of content for speed + relevance
MIN_FREQ = 3         # keep phrases that appear at least this many times
MAX_PHRASES = 5000   # cap to avoid huge ruler
NGRAM_MIN = 2
NGRAM_MAX = 4

# Light stopwords (add more if needed)
STOP = set("""
ve veya ile için gibi üzere da de ki mi mı mu mü
işçi işçileri işçilerin işçisine işçileriyle
toplu sözleşme toplu iş sözleşmesi sözleşme tis
grev greve grevde grevin
ücret zam maaş ücretler
sendika sendikası sendikadan sendikaya direniş emekçi emekçileri
emekçilerin ama çünkü direnişi
""".split())

def ngrams(words, n):
    """Yield each run of ``n`` consecutive items of ``words`` as a tuple, left to right."""
    last_start = len(words) - n
    start = 0
    while start <= last_start:
        yield tuple(words[start:start + n])
        start += 1

def tokenize_for_phrases(text):
    """Split ``text`` into whitespace-delimited tokens for phrase mining.

    Curly apostrophes and backticks are unified to ``'`` first; every character
    that is not a word character, whitespace, a Turkish letter, ``'`` or ``-``
    is then replaced by a space before splitting.
    """
    unified = text.replace("’","'").replace("`","'")
    stripped = re.sub(r"[^\w\sçğıöşüÇĞİÖŞÜ'-]", " ", unified)
    return stripped.split()

# Count lower-cased n-grams (NGRAM_MIN..NGRAM_MAX tokens) over title + leading content
phrase_counts = Counter()

for t, c in zip(df["title"].astype(str), df["content"].astype(str)):
    text = (t + " " + c[:CONTENT_CHARS]).strip()
    words = tokenize_for_phrases(text)

    for n in range(NGRAM_MIN, NGRAM_MAX+1):
        for ng in ngrams(words, n):
            # drop n-grams that contain any stopword
            if any(w.lower() in STOP for w in ng):
                continue
            # heuristic: must contain at least one token starting with uppercase
            if not any(w[:1].isupper() for w in ng if w):
                continue
            phrase_counts[" ".join(ng).lower()] += 1

# Rank by frequency BEFORE applying the cap, so MAX_PHRASES keeps the most
# frequent phrases rather than whichever happened to be seen first.
# (most_common() is ordered by count, ties in first-seen order.)
firm_phrases = [p for p, cnt in phrase_counts.most_common() if cnt >= MIN_FREQ][:MAX_PHRASES]
len(firm_phrases), firm_phrases[:20]


(5000,
 ["kaybetti bartın'ın",
  "bartın'ın amasra",
  'amasra ilçesindeki',
  'ferdi özgün',
  'kaybetti i̇ş',
  'i̇ş cinayeti',
  "26 aralık'ta",
  'geldi saat',
  'saat 16',
  'sıkıştı mesai',
  'mesai arkadaşları',
  '112 acil',
  'acil sağlık',
  'i̇lk müdahalesinin',
  'ambulansla bartın',
  'bartın devlet',
  "devlet hastanesi'ne",
  "hastanesi'ne kaldırılan",
  'bilkent şehir',
  "şehir hastanesi'ne"])

In [None]:
# %%
# ----------------------------
# Filter ORG_KEYS: remove union/confed/general-institution anchors
# ----------------------------

BAD_ORG = set([
    "disk", "dsk", "dİsk", "türk iş", "turk is", "hak iş", "kesk",
    "genel başkanı", "genel baskani", "genel başkan", "sube baskani", "şube başkanı",
    "emek partisi", "emep", "sosyal guvenlik", "devlet hastanesi",
    "organize sanayi",
    "belediye", "bakanlık", "bakanligi", "bakan",
    "haber merkezi", "ajans", "gazetesi", "servisi"
])

def normalize_org(s: str) -> str:
    """Lower-case ``s``, trim it, collapse whitespace runs and delete punctuation.

    Characters that are neither word characters nor whitespace are removed
    outright (not replaced by a space).
    """
    collapsed = re.sub(r"\s+", " ", s.lower().strip())
    return re.sub(r"[^\w\sçğıöşü0-9]", "", collapsed, flags=re.UNICODE)

def is_bad_org(k: str) -> bool:
    """Tell whether organisation key ``k`` should be discarded as an anchor.

    A key is rejected when it normalizes to an empty string, when its
    normalized form is listed in ``BAD_ORG``, or when the normalized form
    contains any of a fixed list of substrings (union, confederation, branch,
    industrial-zone and news/agency/service markers).
    """
    key = normalize_org(k)
    if not key or key in BAD_ORG:
        return True
    # “contains” filters
    blocked_fragments = (
        "sendika", "konfederasyon", "şube", "sube",
        "organize sanayi",
        "haber", "servis", "ajans",
    )
    return any(fragment in key for fragment in blocked_fragments)


In [None]:
# %%
from collections import Counter
import re

# ----------------------------
# Tune knobs
# ----------------------------
CONTENT_CHARS = 1200      # scan first N chars for speed
WIN_CHARS = 220           # +/- window around triggers for mining
MAX_WINS = 8              # max windows per doc
MIN_FREQ = 3
MAX_PHRASES = 5000
NGRAM_MIN = 2
NGRAM_MAX = 5

STOP = set("""
ve veya ile için gibi üzere da de ki mi mı mu mü
işçi işçileri grev grevi direniş direnişi eylem açıklama basın
sendika sendikası işçilerden işçilerin mücadele talep sözleşme toplu iş emekçi
emekçiler emekçileri örgütlü ama fakat lakin çünkü
""".split())

# ---- trigger regex: reuse ALL_TRIG, which must already exist in the kernel ----
# ALL_TRIG is compiled in a different cell; fail with an actionable message
# instead of a bare NameError when that cell has not been run yet.
try:
    TRIG_RE = ALL_TRIG
except NameError as exc:
    raise NameError(
        "ALL_TRIG is not defined yet: run the cell that compiles ALL_TRIG "
        "('Triggers + patterns') before this phrase-mining cell."
    ) from exc

def clean_for_phrase_mining(text: str) -> str:
    """Normalize ``text`` before trigger search: falsy input becomes ``""``,
    non-breaking spaces become spaces, curly apostrophes and backticks become
    ``'``, whitespace runs collapse to one space and the ends are trimmed."""
    cleaned = text if text else ""
    for odd, plain in (("\u00A0", " "), ("’", "'"), ("`", "'")):
        cleaned = cleaned.replace(odd, plain)
    return re.sub(r"\s+", " ", cleaned).strip()

def tokenize_simple(text: str):
    """Return the maximal runs of ASCII letters, Turkish letters and digits in ``text``, in order."""
    token_re = re.compile(r"[A-Za-zÇĞİÖŞÜçğıöşü0-9]+")
    return [hit.group(0) for hit in token_re.finditer(text)]

def is_titlecase_like(tok: str) -> bool:
    """True for tokens of at least two characters that are either fully
    upper-case (per ``str.isupper``) or start with an upper-case character and
    contain a lower-case character somewhere after it."""
    if len(tok) < 2:
        return False
    capitalized = tok[0].isupper() and any(ch.islower() for ch in tok[1:])
    return tok.isupper() or capitalized

def trigger_windows_text(title: str, content: str, win_chars=WIN_CHARS, max_wins=MAX_WINS):
    """Return text snippets surrounding trigger matches in title + leading content.

    The cleaned title is joined to the first ``CONTENT_CHARS`` characters of
    the cleaned content. For each ``TRIG_RE`` match, the slice from
    ``win_chars`` before its start to ``win_chars`` after its end (clamped to
    the text) is collected; collection stops once ``max_wins`` snippets exist.
    An empty list means no trigger was found, so the caller mines nothing.
    """
    head = clean_for_phrase_mining(title)
    body = clean_for_phrase_mining(content)[:CONTENT_CHARS]
    merged = f"{head} {body}".strip()
    total = len(merged)

    windows = []
    for hit in TRIG_RE.finditer(merged):
        lo = max(0, hit.start() - win_chars)
        hi = min(total, hit.end() + win_chars)
        windows.append(merged[lo:hi])
        if len(windows) >= max_wins:
            break
    return windows

phrase_counts = Counter()

for t, c in zip(df["title"].astype(str), df["content"].astype(str)):
    wins = trigger_windows_text(t, c)
    if not wins:
        continue

    for w in wins:
        toks = tokenize_simple(w)
        flags = [is_titlecase_like(tok) for tok in toks]

        i = 0
        while i < len(toks):
            if not flags[i]:
                i += 1
                continue
            j = i
            while j < len(toks) and flags[j]:
                j += 1

            span = toks[i:j]  # consecutive titlecase-like tokens inside trigger window

            for n in range(NGRAM_MIN, NGRAM_MAX + 1):
                for k in range(0, len(span) - n + 1):
                    ng = span[k:k+n]
                    ng_l = [w.lower() for w in ng]

                    if any(w in STOP for w in ng_l):
                        continue
                    if all(w.isdigit() for w in ng):
                        continue

                    phrase = " ".join(ng)
                    phrase_counts[phrase] += 1

            i = j

candidates = [(p, cnt) for p, cnt in phrase_counts.items() if cnt >= MIN_FREQ]
candidates.sort(key=lambda x: x[1], reverse=True)

print("Candidate phrases (trigger-window, freq>=MIN_FREQ):", len(candidates))
print("Top 30:")
for p, cnt in candidates[:30]:
    print(cnt, "-", p)

firm_phrases = [p for p, cnt in candidates[:MAX_PHRASES]]
print("firm_phrases kept:", len(firm_phrases))


NameError: name 'ALL_TRIG' is not defined


[38;5;1m✘ No compatible package found for 'tr_core_news_md' (spaCy v3.8.11)[0m



OSError: [E050] Can't find model 'or tr_core_news_md'. It doesn't seem to be a Python package or a valid path to a data directory.

In [None]:
# %%
ORG_MAX_PER_DOC = 5

# Text handed to the NER model: title plus the first CONTENT_CHARS characters of content
ner_texts = [
    (t + " " + c[:CONTENT_CHARS]).strip()
    for t, c in zip(df["title"].astype(str), df["content"].astype(str))
]

org_keys = []
# NOTE(review): `nlp` is not created in the visible cells; this assumes it is a
# loaded spaCy pipeline. nlp.pipe() processes texts in batches and yields docs in
# input order, which avoids one full pipeline call per article.
for doc in nlp.pipe(ner_texts, batch_size=64):
    ents = (normalize_org(ent.text) for ent in doc.ents if ent.label_ == "ORG")
    kept = [e for e in ents if e and (not is_bad_org(e))]
    # de-dup preserving first-seen order, then cap per document
    org_keys.append(list(dict.fromkeys(kept))[:ORG_MAX_PER_DOC])

df["ORG_KEYS_FILTERED"] = org_keys
df[["title","ORG_KEYS_FILTERED"]].head(10)


In [None]:
# %%
# ----------------------------
# Cell: Narrow to Collective Bargaining / Wage-related strikes
# ----------------------------

KEEP_PATTERNS = [
    r"\btoplu sözleşme\b", r"\btis\b", r"\btoplu iş sözleşmesi\b",
    r"\bücret\b", r"\bzam\b", r"\bmaaş\b", r"\bücret artış\b",
    r"\bpazarlık\b", r"\bgörüşme\b", r"\bmüzakere\b",
    r"\bgrev\b", r"\bgrevde\b", r"\bgreve çıktı\b",
    r"\bsözleşme\s*sürüyor\b", r"\banlaşma\b", r"\bdireniş\b"
]

DROP_PATTERNS = [
    r"\bişten çıkar", r"\bişten at", r"\bişten çıkarıl", r"\bkovuldu\b",
    r"\bsendikalaş", r"\bsendika üye", r"\bsendika üyeli",
    r"\biş kaz", r"\bölüm\b.*\b(ölüm|yaralı)\b", r"\bgöçük\b",
    r"\bgözalt", r"\btutuk", r"\bdava\b", r"\bmahkeme\b",
    r"\bziyaret\b", r"\bdayanışma\b", r"\banma\b", r"\bbasın açıklama\b",
    r"\bsendikalaşma\b"
]

keep_re = re.compile("|".join(KEEP_PATTERNS), flags=re.IGNORECASE)
drop_re = re.compile("|".join(DROP_PATTERNS), flags=re.IGNORECASE)

def is_cb_wage_article(title, content):
    """True when ``"<title> <content>"`` matches ``keep_re`` and does not match
    ``drop_re``; a drop match always wins."""
    combined = f"{title} {content}"
    has_drop = drop_re.search(combined) is not None
    has_keep = keep_re.search(combined) is not None
    return has_keep and not has_drop

df["IS_CB_WAGE"] = df.apply(lambda r: is_cb_wage_article(r["title"], r["content"]), axis=1)

# Start from EVENT_PRED (general)
df["EVENT_PRED_CB"] = 0
mask_rel = df["EVENT_PRED"] == 1
df["EVENT_PRED_CB"] = ((df["EVENT_PRED"] == 1) & (df["IS_CB_WAGE"])).astype(int)

#df.loc[mask_pred, "EVENT_PRED_CB"] = df.loc[mask_pred, "text"].apply(lambda x: int(bool(KEEP_RE.search(str(x)))))
print("Pred-relevant (all):", int((df["EVENT_PRED"]==1).sum()))
print("Pred-relevant (CB/Wage):", int((df["EVENT_PRED_CB"]==1).sum()))
df["EVENT_PRED_CB"].value_counts()


In [None]:
# %%
from collections import Counter
import re

def title_tokens(title):
    """Lower-case ``str(title)`` and return its runs of ASCII/Turkish letters.

    NOTE: ``str.lower()`` maps "İ" to "i" followed by U+0307 (combining dot),
    which the pattern does not match, so a word containing "İ" is split at
    that point (e.g. "İşçi" -> "i", "şçi").
    """
    lowered = str(title).lower()
    return re.findall(r"[A-Za-zÇĞİÖŞÜçğıöşü]+", lowered)

mask_cb = df["EVENT_PRED_CB"] == 1
N = int(mask_cb.sum())

dfreq = Counter()
for t in df.loc[mask_cb, "title"]:
    seen = set(w for w in title_tokens(t) if len(w) >= 4)
    for w in seen:
        dfreq[w] += 1

# keep tokens that appear in <= 5% of CB titles
RARE_MAX = max(1, int(0.05 * max(N, 1)))
rare_tokens = {w for w,cnt in dfreq.items() if cnt <= RARE_MAX}
len(rare_tokens), list(sorted(list(rare_tokens)))[:30]


In [None]:
# %% [markdown]
# ============================================================
# UPDATED HELPERS (NEW)
# These were missing in your notebook but are required for a "full code" version:
# - encode_texts()  -> embedding encoder for linking
# - connected_components() -> union-find
# ============================================================


In [None]:
# %%
from transformers import AutoModel

# Sentence embedding encoder (mean pooling + L2 norm)
# Minimal and stable: uses BERTurk backbone (same family as your classifier).
_enc_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
_enc_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
_enc_model.eval()

@torch.no_grad()
def encode_texts(texts, batch_size=8, max_len=256):
    """Embed texts with the module-level encoder (mean pooling + L2 normalization).

    Args:
        texts: iterable of items; each is stringified. Any iterable works
            (list, pandas Series, generator) because it is materialized first.
        batch_size: number of texts per forward pass.
        max_len: token limit per text (longer inputs are truncated).

    Returns:
        ``np.ndarray`` of shape ``(len(texts), hidden_size)`` whose rows have
        unit L2 norm. For empty input the result has zero rows and the
        encoder's own hidden size as width.
    """
    texts = [str(t) for t in texts]
    if not texts:
        # Width comes from the loaded model rather than a hard-coded 768
        return np.zeros((0, _enc_model.config.hidden_size), dtype=np.float32)

    vecs = []
    for i in range(0, len(texts), batch_size):
        enc = _enc_tokenizer(
            texts[i:i+batch_size],
            truncation=True,
            max_length=max_len,
            padding=True,
            return_tensors="pt"
        ).to(DEVICE)

        out = _enc_model(**enc)
        last = out.last_hidden_state  # (B, T, H)
        mask = enc["attention_mask"].unsqueeze(-1).float()  # (B, T, 1)

        # Mean over real tokens only; padding positions are zeroed by the mask
        pooled = (last * mask).sum(dim=1) / torch.clamp(mask.sum(dim=1), min=1e-6)
        pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)  # unit length

        vecs.append(pooled.detach().cpu().numpy())
    return np.vstack(vecs)

def connected_components(n, edges):
    """Group nodes ``0..n-1`` into connected components using union-find.

    Args:
        n: number of nodes.
        edges: iterable of ``(a, b)`` node-index pairs; self-loops are harmless.

    Returns:
        List of components, each a list of node indices in ascending order.
        Components appear in order of their smallest member; isolated nodes
        form singleton components.
    """
    # Imported locally so the function does not rely on an import that lives in
    # another notebook cell.
    from collections import defaultdict

    parent = list(range(n))
    rank = [0]*n

    def find(x):
        # Path halving: point each visited node at its grandparent
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a,b):
        ra, rb = find(a), find(b)
        if ra == rb:
            return
        # Union by rank: attach the shallower tree under the deeper one
        if rank[ra] < rank[rb]:
            parent[ra] = rb
        elif rank[ra] > rank[rb]:
            parent[rb] = ra
        else:
            parent[rb] = ra
            rank[ra] += 1

    for a,b in edges:
        union(a,b)

    comps = defaultdict(list)
    for i in range(n):
        comps[find(i)].append(i)

    return list(comps.values())


In [None]:
# %% [markdown]
# ============================================================
# UPDATED LINKING (DENEME1 improvement)
# - Uses EMPLOYER_KEYS instead of ORG_KEYS_FILTERED as the primary anchor
# - Uses TIME DECAY (no hard cutoff) to allow long strikes
# - Keeps your "rare title token overlap" gate
# ============================================================


In [None]:
# %%
import math
from collections import defaultdict

def rare_title_tokens(title, min_len=4):
    """Return the set of tokens of ``title`` that are at least ``min_len``
    characters long and are members of the module-level ``rare_tokens`` set.

    Tokenization is delegated to ``title_tokens`` — the same function that was
    used to build ``rare_tokens`` — so the two can never drift apart.
    """
    return {t for t in title_tokens(title) if len(t) >= min_len and t in rare_tokens}

def _safe_days_diff(d1, d2):
    if pd.isna(d1) or pd.isna(d2):
        return None
    try:
        return abs((pd.to_datetime(d1) - pd.to_datetime(d2)).days)
    except Exception:
        return None

def _time_decay(days_diff, tau_days=25.0):
    if days_diff is None:
        return 1.0
    return math.exp(-float(days_diff) / float(tau_days))

def assign_event_ids_hybrid(
    df_in,
    rel_flag_col="EVENT_PRED_CB",
    date_col="PUB_DATE",
    employer_col="EMPLOYER_KEYS",
    sim_short=0.94,
    sim_emp=0.965,
    tau_days=25.0,
    max_rel=6000,
    EMP_MAX_BUCKET=40,
    TITLE_OVERLAP_K=2,
    # --- NEW knobs (minimal additions) ---
    MIN_CLUSTER_SIZE=2,                 # 2 kills phantom EVs; set to 1 if you want singletons
    DROP_EMPTY_EMPLOYER_CLUSTERS=True   # drops clusters where nobody has employer keys
):
    """Cluster relevant articles into events and write the result to ``EVENT_ID``.

    Works on a copy of ``df_in``. Rows with ``rel_flag_col == 1`` are embedded
    with ``encode_texts``; two rows are linked when their cosine similarity,
    multiplied by ``exp(-days_apart / tau_days)``, reaches a threshold AND
    their titles share at least ``TITLE_OVERLAP_K`` rare tokens. Linked rows
    are merged with ``connected_components``.

    Args:
        df_in: frame with at least ``EVENT_ID``, ``text``, ``title``,
            ``rel_flag_col`` and ``date_col`` columns.
        rel_flag_col: 0/1 column selecting rows to link.
        date_col: publication-date column used for the time decay.
        employer_col: column holding a list/tuple of employer keys per row;
            may be absent. Cells that are not a list/tuple are treated as empty.
        sim_short: threshold for the all-pairs pass.
        sim_emp: threshold for pairs that share an employer key.
        tau_days: time-decay constant in days.
        max_rel: abort (return unchanged copy) if more relevant rows than this.
        EMP_MAX_BUCKET: employer keys shared by more rows than this are ignored.
        TITLE_OVERLAP_K: minimum number of shared rare title tokens per link.
        MIN_CLUSTER_SIZE: clusters smaller than this receive no id.
        DROP_EMPTY_EMPLOYER_CLUSTERS: skip clusters in which no row has a
            non-empty employer list.

    Returns:
        Copy of ``df_in`` with ``EVENT_ID`` normalized (stripped strings / NaN)
        and filled for every surviving cluster: the most frequent existing id
        in the cluster if there is one, otherwise a fresh ``EV000001``-style id
        that does not collide with any id already present.

    Raises:
        ValueError: if ``rel_flag_col`` is missing.
    """
    df_out = df_in.copy()

    # Normalize EVENT_ID to stripped strings / NaN. Force object dtype so that
    # writing string ids into an all-NaN (float) column is well defined.
    df_out["EVENT_ID"] = (
        df_out["EVENT_ID"]
        .apply(lambda x: str(x).strip() if not pd.isna(x) else np.nan)
        .astype(object)
    )

    if rel_flag_col not in df_out.columns:
        raise ValueError(f"Missing column {rel_flag_col}. Did you create EVENT_PRED_CB first?")

    rel = df_out[df_out[rel_flag_col] == 1].copy()
    if len(rel) == 0:
        print(f"No rows where {rel_flag_col} == 1. Nothing to link.")
        return df_out

    if len(rel) > max_rel:
        print(f"Too many relevant rows ({len(rel)}). Increase threshold or lower corpus slice.")
        return df_out

    # Stable ordering
    rel["_has_date"] = rel[date_col].notna()
    rel = rel.sort_values(by=[date_col, "_has_date"], ascending=[True, False])

    idx = list(rel.index)
    n = len(idx)

    # Embeddings once
    texts = rel["text"].tolist()
    E = encode_texts(texts, batch_size=32 if DEVICE == "cuda" else 8, max_len=256)

    # Precompute title token sets
    title_tok_sets = [rare_title_tokens(t) for t in rel["title"].tolist()]
    rel_dates = rel[date_col].tolist()

    def _linked(i, j, threshold):
        """True when positions i and j pass the title gate and the decayed-similarity threshold."""
        # Cheap set intersection first; the dot product is only needed for survivors.
        if len(title_tok_sets[i] & title_tok_sets[j]) < TITLE_OVERLAP_K:
            return False
        base_sim = float(np.dot(E[i], E[j]))
        dd = _safe_days_diff(rel_dates[i], rel_dates[j])
        return base_sim * _time_decay(dd, tau_days=tau_days) >= threshold

    # Inverted index for EMPLOYER keys
    emp_to_pos = defaultdict(list)
    rel_emp_lists = rel[employer_col].tolist() if employer_col in rel.columns else [None] * n
    for pos, keys in enumerate(rel_emp_lists):
        # NaN / None / stray strings are "no employer keys" (a NaN cell would
        # otherwise raise, and a string would be iterated character by character).
        if not isinstance(keys, (list, tuple)):
            continue
        # dict.fromkeys de-duplicates so a repeated key cannot pair a row with itself
        for k in dict.fromkeys(keys):
            if k:
                emp_to_pos[k].append(pos)

    edges = set()

    # A) Employer-bucket edges (high precision)
    for emp, positions in emp_to_pos.items():
        if len(positions) <= 1:
            continue
        if len(positions) > EMP_MAX_BUCKET:
            continue

        for a_i in range(len(positions)):
            i = positions[a_i]
            for a_j in range(a_i + 1, len(positions)):
                j = positions[a_j]
                if _linked(i, j, sim_emp):
                    edges.add((min(i, j), max(i, j)))

    # B) Non-employer edges (all pairs)
    # NOTE(review): pass B uses the same title gate with threshold `sim_short`.
    # With the defaults (sim_short < sim_emp) every pair accepted by pass A is
    # also accepted here, so pass A only matters if sim_emp is set below
    # sim_short — confirm which threshold was meant to be the stricter one.
    for i in range(n):
        for j in range(i + 1, n):
            if _linked(i, j, sim_short):
                edges.add((i, j))

    comps = connected_components(n, list(edges))
    clusters = [[idx[pos] for pos in comp] for comp in comps]

    # --- filter phantom clusters BEFORE assigning EV IDs ---
    filtered = []
    for cluster in clusters:
        # 1) must have at least MIN_CLUSTER_SIZE articles
        if len(cluster) < MIN_CLUSTER_SIZE:
            continue

        # 2) optionally drop clusters where everyone has empty employer keys
        if DROP_EMPTY_EMPLOYER_CLUSTERS and employer_col in df_out.columns:
            emp_nonempty = df_out.loc[cluster, employer_col].apply(
                lambda x: isinstance(x, (list, tuple)) and len(x) > 0
            ).sum()
            if int(emp_nonempty) == 0:
                continue

        filtered.append(cluster)

    clusters = filtered
    # -----------------------------------------------------------

    # Assign EVENT_IDs
    used_ids = set(df_out["EVENT_ID"].dropna())
    new_counter = 1
    for cluster in clusters:
        existing = df_out.loc[cluster, "EVENT_ID"].dropna()
        if len(existing) > 0:
            # reuse the id most rows of the cluster already carry
            chosen = existing.value_counts().idxmax()
        else:
            chosen = f"EV{new_counter:06d}"
            # never hand out an id that already exists in the frame
            while chosen in used_ids:
                new_counter += 1
                chosen = f"EV{new_counter:06d}"
            new_counter += 1
            used_ids.add(chosen)
        df_out.loc[cluster, "EVENT_ID"] = chosen

    return df_out


In [None]:
# %% [markdown]
# ============================================================
# UPDATED EMPLOYER EXTRACTION (DENEME1 improvement)
# Minimal change philosophy:
# - Keep ORG_KEYS_FILTERED as a baseline
# - Rebuild EMPLOYER_KEYS with trigger-window + canonicalization + alias
# - Keep unions separately (UNION_KEYS)
# ============================================================


In [None]:
# %%
# Canonicalization + alias (grow alias dict over time)
LEGAL_RE = re.compile(
    r"\b(a\.?ş\.?|aş|anonim|şirketi|şti|ltd|limited|inc|corp|co|holding|sanayi|ticaret|ve)\b",
    flags=re.IGNORECASE
)
GENERIC_TAIL_RE = re.compile(
    r"\b(fabrika(sı|si|da|de|nda|nde)?|işletme(si|de|da|nde|nda)?|tesis(leri|de|da|nde|nda)?|işyeri(nde|ne|ni)?)\b",
    flags=re.IGNORECASE
)

def _fold_tr(s: str) -> str:
    if s is None or (isinstance(s, float) and pd.isna(s)):
        return ""
    s = str(s).strip().lower()
    s = s.replace("’", "'").replace("`", "'")
    s = re.sub(r"\s+", " ", s)
    return s

def canonical_employer(name: str) -> str:
    s = _fold_tr(name)
    if not s:
        return ""
    s = re.sub(r"[^\w\sçğıöşü0-9'-]", " ", s, flags=re.UNICODE)
    s = LEGAL_RE.sub(" ", s)
    s = GENERIC_TAIL_RE.sub(" ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

# %%
def recanon_list(xs):
    """Canonicalise a cell of organisation names.

    Accepts None / NaN (-> []), a list, or any other value, which is
    stringified and split on ";". Each item goes through canonical_employer
    (no alias mapping); empty results are dropped and duplicates removed
    while keeping first-seen order.
    """
    if xs is None or (isinstance(xs, float) and pd.isna(xs)):
        return []
    if isinstance(xs, list):
        items = xs
    else:
        # e.g. a previously exported "a; b" string
        items = [piece.strip() for piece in str(xs).split(";") if piece.strip()]
    canonical = (canonical_employer(item) for item in items)
    # dict.fromkeys keeps insertion order, giving an order-preserving de-dup
    return list(dict.fromkeys(key for key in canonical if key))

# Make ORG_KEYS_FILTERED speak the same "language" as EMPLOYER_KEYS by
# re-canonicalising every cell. Assign straight back to the column: no
# throwaway helper column and no in-place drop are needed.
df["ORG_KEYS_FILTERED"] = df["ORG_KEYS_FILTERED"].apply(recanon_list)

print(df["ORG_KEYS_FILTERED"].head())


# Alias table is currently disabled. Example entries kept for reference:
#   "schneider elektrik": "schneider electric"
#   "purmo metal": "purmo"
def apply_employer_alias(k: str) -> str:
    """Return the canonical key for ``k``.

    No alias mapping is applied at present: the result is exactly
    ``canonical_employer(k)``.
    """
    canonical_key = canonical_employer(k)
    return canonical_key

# ----------------------------
# Triggers + patterns
# IMPORTANT: every alternative must be joined with | inside the SAME outer
# group. A phrase placed after the closing parenthesis is concatenated, which
# would require it to follow every trigger and stop the pattern from matching.
# ----------------------------
ALL_TRIG = re.compile(
    r"(grev(e)? çıktı|grev başladı|iş bırak(tı|ıyor)|üretim(i)? durdur|"
    r"grevde|grev sürüyor|(\d+)\.? ?gün(ü)?nde|"
    r"anlaşma sağlandı|grev bitti|grev sona erdi|protokol imzalandı|"
    r"kabul edildi|reddedildi|imza(landı)?|uzlaş(ma)?|anlaş(ma)?|"
    r"direnişi sürüyor|direniş(i)?|"
    r"kazanımla sonuçlandı)",
    re.IGNORECASE
)

# Name patterns: one to seven capitalised tokens followed by a facility word
# (P_FACILITY) or by "işçi/işçiler/işçileri" (P_WORKERS). The captured group
# "name" is the candidate organisation name.
# Inside the token class the hyphen is escaped: written as `'-.` it forms a
# character range that also admits "(", ")", "*", "+" and ",", letting names
# swallow commas and parentheses.
P_FACILITY = re.compile(
    r"(?P<name>[A-ZÇĞİÖŞÜ][\w’'.\-]+(?:\s+[A-ZÇĞİÖŞÜ][\w’'.\-]+){0,6})\s+"
    r"(fabrika(sı|sinda|sında|da|de|nda|nde)?|işyeri(nde|ne|ni)?|tesis(leri|de|da|nde|nda)?)",
    re.UNICODE
)
P_WORKERS = re.compile(
    r"(?P<name>[A-ZÇĞİÖŞÜ][\w’'.\-]+(?:\s+[A-ZÇĞİÖŞÜ][\w’'.\-]+){0,6})\s+işçi(leri|ler)?",
    re.UNICODE
)

# ----------------------------
# Drop lists used by the looks_like_* / should_drop_orgish predicates.
# All entries are lower-case; several appear both with Turkish letters and in
# an ASCII spelling.
# ----------------------------

# Exact-match keys treated as unions (see looks_like_union).
UNION_TERMS = {
    "disk",
    "kesk",
    "türk iş", "turk is",
    "hak iş",
    "türk metal",
    "metal iş",
    "birleşik metal iş", "birlesik metal is",
    "genel iş",
    "petrol iş",
    "tek gıda iş", "tek gida is",
    "tüm bel sen", "tum bel sen",
    "tüm bel-sen", "tum bel-sen",
    "birtek sen",
}

# Substring-match person names (see looks_like_person).
# NOTE: best practice is to store *names only* (not role phrases),
# because role phrases are already handled by DROP_IF_CONTAINS.
PERSON_TERMS = {
    "hilal tok",
    "hilal tok istanbul",
    "ramis sağlam içerik",
    "özkan atar",
    "genel başkan özkan atar",
    "bölge temsilcisi hayrettin çakmak",
    "izbb başkanı cemil tugay",
    "hasret gültekin kozan",
    "sevda karaca",
    "iskender bayhan",
}

# Any key containing one of these substrings is rejected (see should_drop_orgish).
DROP_IF_CONTAINS = [
    # roles / bylines / org-like non-employers
    "genel başkan",
    "genel baskan",
    "başkanı",
    "baskani",
    "şube başkanı",
    "sube baskani",
    "belediyesi",
    "çalışan",
    "asgari",
    "vade",
    "partisi",
    "chp",
    "akp",
    "mhp",
    "emep",
    "servisi",
    "muhabir",
    "haber merkezi",
    "gazetesi",
    "ajans",
    "ücret",
    "zam",
    "tis",
    "sözleşme",
    "grev",
    "işçi",
]

# Tokens that mark a key as a sentence fragment when they appear as its first
# or second word (see looks_like_sentence_starter).
STARTER_BAD = {
    "yapılan", "yapilan",
    "önünde", "onunde",
    "grevdeki",
    "sosyal",
    "sabah",
    "aynı", "ayni",
    "en",
    "daha",
    "gece",
    "temel",
    "kamu",
    "türkiye", "turkiye",
    "açıklamada", "aciklamada",
    "geçtiğimiz", "gectigimiz",
    "günlerde", "gunlerde",
    "çok",
}

def looks_like_sentence_starter(k: str) -> bool:
    """Return True when ``k`` looks like the start of a sentence, not a name.

    An empty key is rejected outright. Otherwise the key is rejected when its
    first or second folded token is listed in STARTER_BAD
    (e.g. "yapılan açıklamada").
    """
    leading = _fold_tr(k).split()[:2]
    if not leading:
        return True
    return any(token in STARTER_BAD for token in leading)

def looks_like_non_entity_phrase(k: str) -> bool:
    """Morphology-based junk filter.

    Returns True for an empty key, for a key ending in one of the
    -deki/-daki/-teki/-taki suffixes, or for a key containing one of a few
    time / statement words ("açıklamada", "saatlerinde", "günlerde" and
    their ASCII spellings).
    """
    folded = _fold_tr(k)
    if not folded:
        return True
    adjectival_suffixes = ("deki", "daki", "teki", "taki")
    if folded.endswith(adjectival_suffixes):
        return True
    filler_words = ("açıklamada", "aciklamada", "saatlerinde", "günlerde", "gunlerde")
    return any(word in folded for word in filler_words)


def looks_like_union(k: str) -> bool:
    """Return True when the folded key is exactly one of UNION_TERMS."""
    folded = _fold_tr(k)
    return folded in UNION_TERMS

def looks_like_person(k: str) -> bool:
    """Return True when the folded key contains any entry of PERSON_TERMS.

    Substring containment also covers an exact match, so keys such as
    "<name> içerik" are caught as well. An empty key is never a person.
    """
    folded = _fold_tr(k)
    return bool(folded) and any(person in folded for person in PERSON_TERMS)

def should_drop_orgish(k: str) -> bool:
    """Decide whether an organisation-like key must be discarded.

    A key is dropped when, after folding, it is shorter than three characters,
    is flagged by any of the looks_like_* predicates, or contains one of the
    DROP_IF_CONTAINS substrings.
    """
    folded = _fold_tr(k)
    if len(folded) < 3:  # also covers the empty key
        return True
    predicates = (
        looks_like_union,
        looks_like_person,
        looks_like_sentence_starter,
        looks_like_non_entity_phrase,
    )
    if any(predicate(folded) for predicate in predicates):
        return True
    # NOTE(review): plain substring test, so short entries such as "zam" or
    # "tis" also hit longer unrelated words — confirm this is intended.
    return any(fragment in folded for fragment in DROP_IF_CONTAINS)


def _as_list(x):
    if x is None or (isinstance(x, float) and pd.isna(x)):
        return []
    if isinstance(x, list):
        return [str(i).strip() for i in x if str(i).strip()]
    s = str(x).strip()
    if not s:
        return []
    return [p.strip() for p in re.split(r"[;|,]\s*", s) if p.strip()]

def trigger_windows(text: str, window_chars: int = 180, max_wins: int = 6):
    """Return text snippets surrounding ALL_TRIG matches.

    Each snippet spans ``window_chars`` characters before the match start to
    ``window_chars`` after the match end, clipped to the text bounds.
    Collection stops once ``max_wins`` snippets have been gathered.
    """
    haystack = text or ""
    limit = len(haystack)
    snippets = []
    for hit in ALL_TRIG.finditer(haystack):
        left = max(0, hit.start() - window_chars)
        right = min(limit, hit.end() + window_chars)
        snippets.append(haystack[left:right])
        if len(snippets) >= max_wins:
            break
    return snippets

def extract_employer_candidates_from_text(title: str, content: str, org_list=None):
    """Mine employer and union keys from the text around strike triggers.

    A candidate is kept only when it
    1) is captured by P_FACILITY or P_WORKERS inside a trigger window, and
    2) canonicalises to a key that is also present in ``org_list``.

    The second condition keeps sentence fragments from becoming employers.

    Returns ``(employers, unions)``: two de-duplicated lists in first-seen
    order. Union-like keys go to ``unions``; person-like or otherwise
    droppable keys are discarded.
    """
    full_text = (str(title or "") + " " + str(content or "")).strip()
    known_orgs = {apply_employer_alias(org) for org in (org_list or [])}

    employers, unions = [], []
    for window in trigger_windows(full_text, window_chars=180, max_wins=6):
        for pattern in (P_FACILITY, P_WORKERS):
            for hit in pattern.finditer(window):
                key = apply_employer_alias(hit.group("name").strip())
                # hard gate: the key must already be a known organisation
                if not key or key not in known_orgs:
                    continue
                if looks_like_union(key):
                    unions.append(key)
                elif not (looks_like_person(key) or should_drop_orgish(key)):
                    employers.append(key)

    # dict.fromkeys gives an order-preserving de-dup
    return list(dict.fromkeys(employers)), list(dict.fromkeys(unions))



def build_employer_and_union_keys(df_in, org_col="ORG_KEYS_FILTERED", max_emp_per_doc=3):
    """Add EMPLOYER_KEYS and UNION_KEYS list columns to a copy of ``df_in``.

    For every row the organisation cell (``org_col``; treated as empty when
    the column is absent) is canonicalised and split into union keys and
    employer keys, dropping person-like and junk keys. Keys mined from the
    title/content are appended, both lists are de-duplicated in first-seen
    order, and the employer list is cut to ``max_emp_per_doc`` entries.

    Reads the "title" and "content" columns. Returns the new frame; the
    input frame is not modified.
    """
    df_out = df_in.copy()
    if org_col in df_out.columns:
        org_cells = df_out[org_col]
    else:
        org_cells = [None] * len(df_out)

    employer_keys, union_keys = [], []
    rows = zip(df_out["title"].astype(str), df_out["content"].astype(str), org_cells)
    for title, content, org_cell in rows:
        org_list = _as_list(org_cell)

        emp_from_org, uni_from_org = [], []
        for org in org_list:
            key = apply_employer_alias(org)
            if not key:
                continue
            if looks_like_union(key):
                uni_from_org.append(key)
            elif not (looks_like_person(key) or should_drop_orgish(key)):
                emp_from_org.append(key)

        # NOTE(review): mined keys are gated on membership in org_list and go
        # through the same filters, so they are always already present in the
        # *_from_org lists; appending them after those lists changes nothing.
        # Confirm whether mined keys were meant to take priority.
        emp_mined, uni_mined = extract_employer_candidates_from_text(title, content, org_list)

        employers = list(dict.fromkeys(emp_from_org + emp_mined))
        unions = list(dict.fromkeys(uni_from_org + uni_mined))
        employer_keys.append(employers[:max_emp_per_doc])
        union_keys.append(unions)

    df_out["EMPLOYER_KEYS"] = employer_keys
    df_out["UNION_KEYS"] = union_keys
    return df_out


In [None]:
# %%
# ---- Run linking FIRST (wave-level), then rebuild EMPLOYER_KEYS, then firm-splitting ----

# The linker reads EMPLOYER_KEYS, so the employer/union keys are built before
# linking. build_employer_and_union_keys() returns a new frame built from a
# copy of its input, so `df` itself is left untouched.
df_linked = build_employer_and_union_keys(
    df,
    org_col="ORG_KEYS_FILTERED",
    max_emp_per_doc=3,
)

df_linked = assign_event_ids_hybrid(
    df_linked,
    rel_flag_col="EVENT_PRED_CB",
    date_col="PUB_DATE",
    employer_col="EMPLOYER_KEYS",
    sim_short=0.94,
    sim_emp=0.965,
    tau_days=25.0,     # key knob: higher = more tolerant for long strikes
    EMP_MAX_BUCKET=40,
    TITLE_OVERLAP_K=2,
)

# Quick summary over the articles predicted as relevant
mask_rel = df_linked["EVENT_PRED_CB"] == 1
print("Pred-relevant (CB/Wage):", int(mask_rel.sum()))
print("Unique EVENT_ID among predicted relevant:", int(df_linked.loc[mask_rel, "EVENT_ID"].nunique()))
df_linked.loc[mask_rel].groupby("EVENT_ID").size().describe()


In [None]:
# %%
from collections import Counter

mask_cb = df_linked["EVENT_PRED_CB"] == 1

# Frequency of every employer key across the predicted CB/Wage articles
c_emp = Counter(
    k
    for ks in df_linked.loc[mask_cb, "EMPLOYER_KEYS"]
    for k in (ks or [])
)

print("\nTop 30 EMPLOYER_KEYS:")
for k, v in c_emp.most_common(30):
    print(v, "-", k)

print("\nShare of CB/Wage docs with NO employer key:",
      float((df_linked.loc[mask_cb, "EMPLOYER_KEYS"].apply(len) == 0).mean()))


In [None]:
# %%
from collections import Counter

mask_rel = df_linked["EVENT_PRED_CB"] == 1

# Frequency of every ORG key across the predicted CB articles. A cell may be
# a list or a ";"-joined string; None / NaN cells are skipped.
c = Counter()
for ks in df_linked.loc[mask_rel, "ORG_KEYS_FILTERED"]:
    if ks is None or (isinstance(ks, float) and pd.isna(ks)):
        continue
    parts = ks if isinstance(ks, list) else str(ks).split(";")
    c.update(p for p in (str(k).strip() for k in parts) if p)

print("Top 30 ORG_KEYS_FILTERED (among predicted CB events):")
for k, v in c.most_common(30):
    print(f"{v:>4} - {k}")


In [None]:
# %%
from collections import Counter
import pandas as pd

# ----------------------------
# Split each wave-level EVENT_ID into firm-level strike IDs using EMPLOYER_KEYS
# Requires df_linked["EMPLOYER_KEYS"] to exist (built above).
# ----------------------------

def split_wave_into_employers(
    df_in,
    wave_col="EVENT_ID",
    emp_col="EMPLOYER_KEYS",
    min_emp_mentions=1,
    multi_firm_mode=True
):
    """Derive a firm-level id (EVENT_ID_FIRM) from each wave-level event id.

    Within every ``wave_col`` group, employer keys mentioned in at least
    ``min_emp_mentions`` rows are retained. A row receives
    "<wave>__<first retained key>", or "<wave>_UNK" when it has no retained
    key. Rows whose wave id is missing are not part of any group and keep
    EVENT_ID_FIRM = None.

    ``multi_firm_mode`` is accepted for compatibility but does not change the
    result: a row with several keys always uses the first one.

    Returns a copy of ``df_in`` with the extra column.
    """
    out = df_in.copy()
    out["EVENT_ID_FIRM"] = None

    for wave_id, wave_rows in out.groupby(wave_col):
        # how many rows of this wave mention each employer key
        mentions = Counter(k for ks in wave_rows[emp_col] for k in (ks or []))
        retained = {k for k, n in mentions.items() if n >= min_emp_mentions}

        for idx, row in wave_rows.iterrows():
            row_keys = [k for k in (row.get(emp_col) or []) if k in retained]
            suffix = f"__{row_keys[0]}" if row_keys else "_UNK"
            out.loc[idx, "EVENT_ID_FIRM"] = f"{wave_id}{suffix}"

    return out

# Apply the firm-level split to the linked frame; with min_emp_mentions=1
# every employer key that occurs in a wave is retained.
df_linked = split_wave_into_employers(
    df_linked,
    wave_col="EVENT_ID",
    emp_col="EMPLOYER_KEYS",
    min_emp_mentions=1,
    multi_firm_mode=True
)

# Summary over the articles predicted as CB/Wage events.
# `mask` remains defined at notebook level after this cell runs.
mask = df_linked["EVENT_PRED_CB"] == 1
print("CB/Wage articles:", int(mask.sum()))
print("Unique firm-level strike events:", df_linked.loc[mask, "EVENT_ID_FIRM"].nunique())
# Distribution of articles per firm-level event (displayed as the cell output)
df_linked.loc[mask].groupby("EVENT_ID_FIRM").size().describe()


In [None]:
# %% [markdown]
# ============================================================
# UPDATED UNK ABSORPTION (DENEME1 improvement)
# No hard date cutoff is applied: each candidate match is down-weighted by a time-decay factor instead.
# ============================================================


In [None]:
# %%
def absorb_unk_into_employers(
    df_in,
    wave_col="EVENT_ID",
    firm_col="EVENT_ID_FIRM",
    date_col="PUB_DATE",
    text_col="text",
    sim_thresh=0.90,
    tau_days=25.0,
    rel_flag_col="EVENT_PRED_CB",
):
    """Re-attach "_UNK" articles to a named firm event of the same wave.

    Only rows with ``rel_flag_col == 1`` are considered. Within each wave,
    every row whose ``firm_col`` ends in "_UNK" is scored against every other
    row of that wave as ``dot(embedding_unk, embedding_firm) * time_decay``.
    The best-scoring row's firm id is adopted when the score reaches
    ``sim_thresh``; ties keep the earliest candidate. No hard date cutoff is
    used: temporal distance only lowers the score.

    Parameters
    ----------
    df_in : DataFrame; not modified.
    wave_col, firm_col, date_col, text_col : column names to read.
    sim_thresh : minimum decayed score required to absorb a row.
    tau_days : passed to ``_time_decay`` as ``tau_days``.
    rel_flag_col : column flagging relevant articles (value 1). Defaults to
        "EVENT_PRED_CB", the column that was previously hard-coded.

    Returns
    -------
    A copy of ``df_in`` with ``date_col`` coerced to datetime and
    ``firm_col`` updated for the absorbed rows.
    """
    out = df_in.copy()
    out[date_col] = pd.to_datetime(out[date_col], errors="coerce")

    base = out[out[rel_flag_col] == 1].copy()
    batch_size = 32 if DEVICE == "cuda" else 8

    for _, g in base.groupby(wave_col):
        is_unk = g[firm_col].str.endswith("_UNK", na=False)
        firm = g[~is_unk]
        unk = g[is_unk]

        # nothing to absorb, or nothing to absorb into
        if len(firm) == 0 or len(unk) == 0:
            continue

        # assumes encode_texts returns one vector per text, indexable by row;
        # presumably L2-normalised so the dot product is a cosine — TODO confirm
        E_firm = encode_texts(firm[text_col].tolist(), batch_size=batch_size, max_len=256)
        E_unk = encode_texts(unk[text_col].tolist(), batch_size=batch_size, max_len=256)

        firm_dates = firm[date_col].tolist()
        firm_ids = firm[firm_col].tolist()
        unk_dates = unk[date_col].tolist()

        for u_i, row_idx in enumerate(unk.index):
            ud = unk_dates[u_i]
            best_score = -1.0
            best_firm = None

            for f_i, (fd, candidate_id) in enumerate(zip(firm_dates, firm_ids)):
                base_sim = float(np.dot(E_unk[u_i], E_firm[f_i]))
                dd = _safe_days_diff(ud, fd)
                score = base_sim * _time_decay(dd, tau_days=tau_days)

                # strict ">" keeps the first candidate on ties
                if score > best_score:
                    best_score = score
                    best_firm = candidate_id

            if best_firm is not None and best_score >= sim_thresh:
                out.loc[row_idx, firm_col] = best_firm

    return out

# Absorb "_UNK" articles into named firm events of the same wave; a row is
# moved only when its time-decayed similarity reaches sim_thresh.
df_linked = absorb_unk_into_employers(
    df_linked,
    sim_thresh=0.90,
    tau_days=25.0
)

# Post-absorption summary over the predicted CB/Wage articles
mask_cb = df_linked["EVENT_PRED_CB"]==1
print("UNK count after absorption:", int(df_linked.loc[mask_cb, "EVENT_ID_FIRM"].str.endswith("_UNK", na=False).sum()))
print("Unique firm events:", df_linked.loc[mask_cb, "EVENT_ID_FIRM"].nunique())
# Distribution of articles per firm-level event (displayed as the cell output)
df_linked.loc[mask_cb].groupby("EVENT_ID_FIRM").size().describe()


In [None]:
# %%
# ============================================================
# REPLACEMENT EXPORT CELL — 2-sheet Excel in your format
# Sheet 1: Firm_Level_Strikes (one row per EVENT_ID_FIRM)
# Sheet 2: Articles_By_Firm_Event (article-level mapping)
# NOW includes: EMPLOYER_KEYS, UNION_KEYS in Sheet 2
# ============================================================

from openpyxl import Workbook
import pandas as pd

# Export inputs are bound explicitly instead of being probed with
# try/except NameError: a `df_out` or `mask` left in the kernel by an earlier
# cell would otherwise be exported silently, even if it predates the
# absorption step or was built for a different frame.
df_out = df_linked
mask = df_out["EVENT_PRED_CB"] == 1

def list_to_str(x):
    """Render a cell value as text for the spreadsheet.

    A list becomes its non-blank items joined with "; "; a missing value
    (as judged by ``pd.isna``) becomes ""; anything else becomes ``str(x)``.
    """
    if isinstance(x, list):
        rendered = (str(item) for item in x)
        return "; ".join(text for text in rendered if text.strip())
    return "" if pd.isna(x) else str(x)

def flatten_unique_keys(series_of_lists):
    """Collect the distinct keys found in an iterable of cells.

    Each cell may be a list of keys or an already ";"-joined string; None and
    NaN cells are ignored. Keys are stripped, blanks dropped, and the unique
    keys are returned sorted and joined with "; ".
    """
    unique = set()
    for cell in series_of_lists:
        if cell is None or (isinstance(cell, float) and pd.isna(cell)):
            continue
        parts = cell if isinstance(cell, list) else str(cell).split(";")
        unique.update(key for key in (str(p).strip() for p in parts) if key)
    return "; ".join(sorted(unique))

# ---- Sheet 1: firm-level events ----
# One row per EVENT_ID_FIRM (rows with a missing id form their own group
# because dropna=False).
events_firm = (
    df_out.loc[mask]
    .groupby("EVENT_ID_FIRM", dropna=False)
    .agg(
        # first / last publication date seen for the event
        start=("PUB_DATE", "min"),
        end=("PUB_DATE", "max"),
        # inclusive day span between first and last date; "" when the group has no valid date
        duration=("PUB_DATE", lambda x: (x.max() - x.min()).days + 1 if x.notna().any() else ""),
        # number of non-null titles in the group
        n_articles=("title", "count"),
        # distinct keys across the group's articles, joined with "; "
        firms=("ORG_KEYS_FILTERED", flatten_unique_keys),
        employers=("EMPLOYER_KEYS", flatten_unique_keys),
        unions=("UNION_KEYS", flatten_unique_keys),
    )
    .reset_index()
)

# Coerce the date columns to datetime (unparseable values become NaT)
for c in ["start", "end"]:
    events_firm[c] = pd.to_datetime(events_firm[c], errors="coerce")

wb = Workbook()
ws1 = wb.active
ws1.title = "Firm_Level_Strikes"
# header row
ws1.append(list(events_firm.columns))

# NOTE(review): list_to_str returns a str for every value, so dates and counts
# are presumably stored as text cells rather than date/number cells — confirm
# this is the intended sheet format.
for _, row in events_firm.iterrows():
    ws1.append([list_to_str(v) for v in row.tolist()])

# ---- Sheet 2: article-level mapping ----
ws2 = wb.create_sheet("Articles_By_Firm_Event")
cols = [
    "EVENT_ID_FIRM",
    "PUB_DATE",
    "title",
    "ORG_KEYS_FILTERED",
    "EMPLOYER_KEYS",
    "UNION_KEYS",
    "EVENT_ID",
    "EVENT_RELEVANT",
]
# header row
ws2.append(cols)

# Work on a copy so the list columns of df_out are not overwritten
tmp = df_out.loc[mask, cols].copy()
tmp["PUB_DATE"] = pd.to_datetime(tmp["PUB_DATE"], errors="coerce")
# list-valued key columns are flattened to "; "-joined strings
for col in ["ORG_KEYS_FILTERED", "EMPLOYER_KEYS", "UNION_KEYS"]:
    tmp[col] = tmp[col].apply(list_to_str)

for _, r in tmp.iterrows():
    ws2.append([list_to_str(v) for v in r.tolist()])

# Save the workbook to the working directory
path = "firm_level_strikes_7.xlsx"
wb.save(path)
print("Saved:", path)


In [None]:
# %%
# Optional quick sanity sampling: up to two articles per named (non-UNK)
# firm event, then a reproducible random sample of at most 20 of those rows.
mask_named = (
    (df_linked["EVENT_PRED_CB"] == 1)
    & df_linked["EVENT_ID_FIRM"].notna()
    & ~df_linked["EVENT_ID_FIRM"].str.endswith("_UNK", na=False)
)
preview = (
    df_linked.loc[mask_named]
    .groupby("EVENT_ID_FIRM")
    .head(2)[["EVENT_ID_FIRM", "PUB_DATE", "title"]]
)
# Cap the sample size: sampling without replacement raises a ValueError when
# more rows are requested than are available.
preview.sample(n=min(20, len(preview)), random_state=42)