In [None]:
# ===============================================================
# Colab Bulk Inference — LLaMA‑3.1‑8B + LoRA (hard‑merge)
# Input : CSV "ID;name;text" (UTF‑8‑SIG, semicolon)
# Output: CSV "ID;name;text;pred;pred_short;raw" (UTF‑8‑SIG, semicolon)
# Resume: Wenn Ausgabe existiert, überspringe bereits verarbeitete IDs
# Batching: Slice-basierte Chunks mit .iloc; nach jeder Batch gespeichert
# ===============================================================

# !pip -q install "transformers>=4.41.0" "accelerate>=0.30.0" sentencepiece peft huggingface_hub tqdm

import os, re, time, hashlib, warnings, csv
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

from pathlib import Path
import pandas as pd
import torch
from tqdm import tqdm
from huggingface_hub import login
from peft import PeftModel, PeftConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
transformers.logging.set_verbosity_error()
warnings.filterwarnings("ignore", message=".*generation flags are not valid.*")

# Configuration
INPUT_CSV   = "/content/augmentation_prepped_for_inference.csv"  # semicolon-separated input with columns ID;name;text
OUTPUT_CSV  = "/content/augmentation_prepped_for_inference_preds_llama.csv"  # appended slice by slice; also read to resume
BATCH_SIZE  = 2500  # rows per slice; the output CSV is written after every slice

# LoRA adapter repo and the base model it is merged into
# (BASE_MODEL is replaced by the adapter config's base_model_name_or_path when that differs).
HF_REPO_ID  = "YangZexi/llama-3.1-8B-Instruct-stance-lora-v2"
BASE_MODEL  = "meta-llama/Meta-Llama-3.1-8B-Instruct"

MAX_LEN_IN  = 512  # max_length passed to the tokenizer for the rendered prompt
MAX_NEW_TOK = 3  # max_new_tokens for generate(); the expected answer is a short code (Zu/Ab/Ne)
DO_SAMPLE   = False  # passed as do_sample to generate()
TEMPERATURE = 0.0  # only used when DO_SAMPLE=True
# NOTE(review): transformers requires a strictly positive temperature when sampling, so 0.0
# would raise as soon as DO_SAMPLE is switched on — choose a positive value first.

# Optional Hugging Face login: the Colab secret "HF_TOKEN" wins, the environment variable is the fallback.
def _read_hf_token():
    """Return the HF token from Colab's secret store, or from $HF_TOKEN if that lookup fails."""
    try:
        from google.colab import userdata
        return userdata.get("HF_TOKEN")
    except Exception:
        # Not running in Colab, or the secret is missing / access was not granted.
        return os.getenv("HF_TOKEN")


HF_TOKEN = _read_hf_token()
if HF_TOKEN:
    login(token=HF_TOKEN)

# Helferfunktionen
# Every character that is neither a word character nor whitespace (quotes, dots, dashes, brackets,
# Markdown "*"/"#", bullets, ...) is blanked out before the answer is parsed.
_CLEAN_RE = re.compile(r"[^\w\s]")

# (prefix, label, short code); checked in this order for each token.
_SHORT_CODES = (
    ("zu", "Zustimmung", "Zu"),
    ("ab", "Ablehnung", "Ab"),
    ("ne", "Neutral", "Ne"),
)


def parse_short_to_label(raw_output: str):
    """Map the model's raw short answer to a ``(label, short_code)`` pair.

    The output is lower-cased and punctuation is replaced by spaces. The first
    whitespace-separated token that starts with "zu", "ab" or "ne" decides the result:
    ("Zustimmung", "Zu"), ("Ablehnung", "Ab") or ("Neutral", "Ne").

    Returns ("Unklar", "") for non-string input, empty output, or output without
    such a token.
    """
    if not isinstance(raw_output, str):
        return "Unklar", ""
    tokens = _CLEAN_RE.sub(" ", raw_output.lower()).split()
    # Whole-token prefixes only: substring matching over the whole text misreads words such as
    # "abzulehnen" (contains "zu"), "habe" (contains "ab") or "keine" (contains "ne").
    for token in tokens:
        for prefix, label, short in _SHORT_CODES:
            if token.startswith(prefix):
                return label, short
    return "Unklar", ""

def pick_attn_impl():
    """Return the ``attn_implementation`` value for ``from_pretrained``.

    Returns "sdpa" when this PyTorch build provides
    ``torch.nn.functional.scaled_dot_product_attention``, otherwise "eager".
    """
    # Probe the attention function itself rather than torch.backends.cuda.sdp_kernel: that context
    # manager is deprecated, and its removal would silently force the slower eager path.
    return "sdpa" if hasattr(torch.nn.functional, "scaled_dot_product_attention") else "eager"

def tensor_md5(t):
    """Return the MD5 hex digest of *t* converted to float32 on the CPU.

    Returns "NA" if the value cannot be converted (e.g. it is not a tensor).
    """
    try:
        raw_bytes = t.detach().float().cpu().numpy().tobytes()
    except Exception:
        return "NA"
    return hashlib.md5(raw_bytes).hexdigest()

def build_messages(name: str, text: str):
    """Build the chat messages (system + user) asking for the stance of *text* towards *name*.

    The user prompt lists the three classes and asks for exactly one short code
    (Zu / Ab / Ne) as the answer.
    """
    system_content = (
        "Du bist ein Stance-Klassifizierer für politische Tweets. "
        "Kategorisiere die Haltung als genau eine der drei Klassen: Zustimmung, Ablehnung oder Neutral."
    )
    prompt_lines = [
        "### Aufgabe",
        f"Bewerte die Haltung des folgenden Tweets gegenüber \"{name}\".",
        "",
        f"Tweet: {text}",
        "",
        "### Antwortmöglichkeiten:",
        "• Zustimmung: explizit/implizit positiv/unterstützend.",
        "• Ablehnung: explizit/implizit negativ/kritisch.",
        "• Neutral: sachlich/ambivalent/keine erkennbare Haltung.",
        "### Ausgabeformat (Kurzform):",
        "Gib **genau eines** der folgenden Kürzel zurück (ohne Anführungszeichen, ohne Punkt):",
        "Zu",
        "Ab",
        "Ne",
    ]
    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": "\n".join(prompt_lines)},
    ]

# Load the base model and hard-merge the LoRA adapter into its weights.
peft_cfg = PeftConfig.from_pretrained(HF_REPO_ID)
if peft_cfg.base_model_name_or_path and peft_cfg.base_model_name_or_path != BASE_MODEL:
    print(f"Adapter base={peft_cfg.base_model_name_or_path} -> überschreibe BASE_MODEL.")
    BASE_MODEL = peft_cfg.base_model_name_or_path

dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
    # No pad token defined: reuse EOS (pad_token_id is derived from pad_token).
    tokenizer.pad_token = tokenizer.eos_token
# Decoder-only generation needs left padding if prompts are ever batched;
# single-prompt calls are not padded at all.
tokenizer.padding_side = "left"

base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    torch_dtype=dtype,
    # No trust_remote_code: Llama is a built-in architecture, and BASE_MODEL may have been
    # redirected above by the (third-party) adapter config, whose repo code must not be executed.
    attn_implementation=pick_attn_impl(),
)


def _lora_probe_key(state_keys, target_modules):
    """Return a ``.weight`` key of a LoRA-targeted module, else the first ``.weight`` key (or None).

    *target_modules* is the adapter's ``target_modules`` (list/set of module names, a single
    string, or None). Matching is by ``.<name>.`` substring, so regex-style targets fall back.
    """
    weight_keys = [k for k in state_keys if k.endswith(".weight")]
    if isinstance(target_modules, str):
        target_modules = [target_modules]
    targets = list(target_modules or [])
    for key in weight_keys:
        if any(f".{t}." in key for t in targets):
            return key
    return weight_keys[0] if weight_keys else None


# Probe a weight the adapter actually modifies: the first ".weight" key in general is not a
# LoRA target (typically the token embedding), so its hash never changes and the check
# always warns. It is also a very large tensor to copy to the CPU.
probe_name = _lora_probe_key(base.state_dict().keys(), getattr(peft_cfg, "target_modules", None))
pre_hash = tensor_md5(base.state_dict()[probe_name]) if probe_name else "NA"

model = PeftModel.from_pretrained(base, HF_REPO_ID)
model = model.merge_and_unload(safe_merge=True)
model.eval()

# merge_and_unload returns the plain base model, so the state_dict keys match the pre-merge keys.
post_hash = tensor_md5(model.state_dict()[probe_name]) if probe_name else "NA"
print("Merge OK" if pre_hash != post_hash else "Merge-Probe unverändert — Adapter prüfen.")

# Load the input and resume by ID: rows whose ID already appears in OUTPUT_CSV are skipped.
in_path  = Path(INPUT_CSV)
out_path = Path(OUTPUT_CSV)

df = pd.read_csv(in_path, sep=";", encoding="utf-8-sig", dtype=str)
df = df[["ID", "name", "text"]].dropna(subset=["ID", "text"]).copy()

# A zero-byte output (e.g. a crash during the very first write) has no header and nothing to
# resume from; treat it like a missing file so the first slice rewrites it with a header.
has_output = out_path.exists() and out_path.stat().st_size > 0

processed_ids = set()
if has_output:
    print(f"Fortsetzen: bestehende Ausgabe gefunden -> {out_path}")
    try:
        done = pd.read_csv(out_path, sep=";", encoding="utf-8-sig", dtype=str, usecols=["ID"])
    except Exception as e:
        # Continuing with an empty ID set would re-run every row and append duplicates to this file.
        raise RuntimeError(f"Konnte bestehende Ausgabe nicht lesen: {e}") from e
    processed_ids = set(done["ID"].dropna().astype(str))

remain = df[~df["ID"].astype(str).isin(processed_ids)].reset_index(drop=True)
total  = len(df)
todo   = len(remain)
print(f"Gesamtzeilen: {total} | Bereits erledigt: {total - todo} | Zu erledigen: {todo}")

# Header only when starting a fresh (or empty) output file; otherwise slices are appended.
write_header = not has_output

# Single prediction
def _encode_prompt(name: str, text: str):
    """Render the chat prompt for one (name, text) pair and tokenize it to PyTorch tensors."""
    messages = build_messages(name, text)
    chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered chat template already contains the BOS token (<|begin_of_text|> for Llama 3.x);
    # letting the tokenizer add special tokens again would prepend a second BOS.
    # NOTE(review): rows written by earlier runs were generated with the duplicated BOS.
    return tokenizer(chat_text, return_tensors="pt", add_special_tokens=False)


def predict_one(name: str, text: str):
    """Classify the stance of *text* towards *name*.

    Generates at most MAX_NEW_TOK tokens and parses them with parse_short_to_label.
    Returns ``(label, short_code, raw_output)``.

    Prompts longer than MAX_LEN_IN tokens are shortened by trimming the end of the tweet
    text. Re-tokenization can leave the result a few tokens over the budget.
    """
    enc = _encode_prompt(name, text)
    overflow = enc["input_ids"].shape[1] - MAX_LEN_IN
    if overflow > 0:
        # Trim the tweet instead of truncating the rendered prompt from the right, which would
        # cut off the instructions and the assistant header the model must answer after.
        text_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        keep = max(len(text_ids) - overflow, 0)
        text = tokenizer.decode(text_ids[:keep], skip_special_tokens=True)
        enc = _encode_prompt(name, text)
    enc = enc.to(model.device)

    gen_kwargs = dict(max_new_tokens=MAX_NEW_TOK, do_sample=DO_SAMPLE, pad_token_id=tokenizer.eos_token_id)
    if DO_SAMPLE:
        gen_kwargs["temperature"] = TEMPERATURE

    with torch.inference_mode():
        out_ids = model.generate(**enc, **gen_kwargs)

    # Decode only the newly generated tokens (everything after the prompt).
    prompt_len = enc["input_ids"].shape[1]
    raw = tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True).strip()
    label, short = parse_short_to_label(raw)
    return label, short, raw

# Slice-based batch processing with .iloc; each slice is appended to OUTPUT_CSV once it is done.
_OUT_COLUMNS = ["ID", "name", "text", "pred", "pred_short", "raw"]


def _cell_str(value) -> str:
    """Return *value* stripped if it is a string, else "" (empty CSV cells are read as NaN)."""
    return value.strip() if isinstance(value, str) else ""


start_ts = time.time()
for start in range(0, todo, BATCH_SIZE):
    batch = remain.iloc[start:start + BATCH_SIZE]
    batch_no = start // BATCH_SIZE + 1

    rows = []
    try:
        with tqdm(total=len(batch), desc=f"Batch {batch_no}", dynamic_ncols=True) as pbar:
            for _, r in batch.iterrows():
                # Normalize the cells outside the try below: NaN is truthy, so `(x or "").strip()`
                # raised for empty names and left name/text holding the previous row's values.
                name = _cell_str(r["name"])
                text = _cell_str(r["text"])
                try:
                    pred, pred_short, raw = predict_one(name, text)
                except Exception as e:
                    # Error rows are written like any other row and are therefore skipped on
                    # resume; filter on raw starting with "ERROR:" to re-run them.
                    pred, pred_short, raw = "Unklar", "", f"ERROR: {e}"

                rows.append({
                    "ID": str(r["ID"]),
                    "name": name,
                    "text": text,
                    "pred": pred,
                    "pred_short": pred_short,
                    "raw": raw,
                })
                pbar.update(1)
    finally:
        # Persist whatever this slice produced, even when interrupted (e.g. KeyboardInterrupt), so
        # the ID-based resume does not redo it. The first write creates the file with a header;
        # later writes append.
        if rows:
            pd.DataFrame(rows, columns=_OUT_COLUMNS).to_csv(
                out_path,
                sep=";",
                index=False,
                encoding="utf-8-sig",
                mode=("w" if write_header else "a"),
                header=write_header,
                quoting=csv.QUOTE_MINIMAL,
            )
            write_header = False

    elapsed = time.time() - start_ts
    done_now = min(start + BATCH_SIZE, todo)
    print(f"Batch {batch_no} gespeichert -> {out_path.name} | Erledigt: {done_now}/{todo} | {elapsed/60:.1f} min")

print(f"Abgeschlossen. Ausgabe -> {out_path}")

adapter_config.json:   0%|          | 0.00/949 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/336M [00:00<?, ?B/s]

❗ Merge probe unchanged — check adapter.
🔁 Resume: found existing output → /content/augmentation_prepped_for_inference_preds_llama.csv
📦 Total rows: 493704 | Already done: 95000 | To do: 398704


🔍 Batch 1: 100%|██████████| 2500/2500 [03:19<00:00, 12.52it/s]


💾 Saved batch 1 → augmentation_prepped_for_inference_preds_llama.csv | Done: 2500/398704 | 3.3 min


🔍 Batch 2: 100%|██████████| 2500/2500 [03:19<00:00, 12.56it/s]


💾 Saved batch 2 → augmentation_prepped_for_inference_preds_llama.csv | Done: 5000/398704 | 6.6 min


🔍 Batch 3: 100%|██████████| 2500/2500 [03:24<00:00, 12.23it/s]


💾 Saved batch 3 → augmentation_prepped_for_inference_preds_llama.csv | Done: 7500/398704 | 10.1 min


🔍 Batch 4: 100%|██████████| 2500/2500 [03:20<00:00, 12.50it/s]


💾 Saved batch 4 → augmentation_prepped_for_inference_preds_llama.csv | Done: 10000/398704 | 13.4 min


🔍 Batch 5: 100%|██████████| 2500/2500 [03:16<00:00, 12.71it/s]


💾 Saved batch 5 → augmentation_prepped_for_inference_preds_llama.csv | Done: 12500/398704 | 16.7 min


🔍 Batch 6: 100%|██████████| 2500/2500 [03:18<00:00, 12.63it/s]


💾 Saved batch 6 → augmentation_prepped_for_inference_preds_llama.csv | Done: 15000/398704 | 20.0 min


🔍 Batch 7: 100%|██████████| 2500/2500 [03:22<00:00, 12.36it/s]


💾 Saved batch 7 → augmentation_prepped_for_inference_preds_llama.csv | Done: 17500/398704 | 23.3 min


🔍 Batch 8: 100%|██████████| 2500/2500 [03:21<00:00, 12.40it/s]


💾 Saved batch 8 → augmentation_prepped_for_inference_preds_llama.csv | Done: 20000/398704 | 26.7 min


🔍 Batch 9: 100%|██████████| 2500/2500 [03:22<00:00, 12.32it/s]


💾 Saved batch 9 → augmentation_prepped_for_inference_preds_llama.csv | Done: 22500/398704 | 30.1 min


🔍 Batch 10: 100%|██████████| 2500/2500 [03:24<00:00, 12.21it/s]


💾 Saved batch 10 → augmentation_prepped_for_inference_preds_llama.csv | Done: 25000/398704 | 33.5 min


🔍 Batch 11: 100%|██████████| 2500/2500 [03:17<00:00, 12.66it/s]


💾 Saved batch 11 → augmentation_prepped_for_inference_preds_llama.csv | Done: 27500/398704 | 36.8 min


🔍 Batch 12: 100%|██████████| 2500/2500 [03:15<00:00, 12.77it/s]


💾 Saved batch 12 → augmentation_prepped_for_inference_preds_llama.csv | Done: 30000/398704 | 40.0 min


🔍 Batch 13: 100%|██████████| 2500/2500 [03:16<00:00, 12.71it/s]


💾 Saved batch 13 → augmentation_prepped_for_inference_preds_llama.csv | Done: 32500/398704 | 43.3 min


🔍 Batch 14: 100%|██████████| 2500/2500 [03:16<00:00, 12.72it/s]


💾 Saved batch 14 → augmentation_prepped_for_inference_preds_llama.csv | Done: 35000/398704 | 46.6 min


🔍 Batch 15: 100%|██████████| 2500/2500 [03:18<00:00, 12.59it/s]


💾 Saved batch 15 → augmentation_prepped_for_inference_preds_llama.csv | Done: 37500/398704 | 49.9 min


🔍 Batch 16: 100%|██████████| 2500/2500 [03:15<00:00, 12.81it/s]


💾 Saved batch 16 → augmentation_prepped_for_inference_preds_llama.csv | Done: 40000/398704 | 53.2 min


🔍 Batch 17: 100%|██████████| 2500/2500 [03:20<00:00, 12.45it/s]


💾 Saved batch 17 → augmentation_prepped_for_inference_preds_llama.csv | Done: 42500/398704 | 56.5 min


🔍 Batch 18: 100%|██████████| 2500/2500 [03:18<00:00, 12.60it/s]


💾 Saved batch 18 → augmentation_prepped_for_inference_preds_llama.csv | Done: 45000/398704 | 59.8 min


🔍 Batch 19: 100%|██████████| 2500/2500 [03:24<00:00, 12.21it/s]


💾 Saved batch 19 → augmentation_prepped_for_inference_preds_llama.csv | Done: 47500/398704 | 63.2 min


🔍 Batch 20: 100%|██████████| 2500/2500 [03:23<00:00, 12.31it/s]


💾 Saved batch 20 → augmentation_prepped_for_inference_preds_llama.csv | Done: 50000/398704 | 66.6 min


🔍 Batch 21: 100%|██████████| 2500/2500 [03:21<00:00, 12.42it/s]


💾 Saved batch 21 → augmentation_prepped_for_inference_preds_llama.csv | Done: 52500/398704 | 70.0 min


🔍 Batch 22: 100%|██████████| 2500/2500 [03:19<00:00, 12.53it/s]


💾 Saved batch 22 → augmentation_prepped_for_inference_preds_llama.csv | Done: 55000/398704 | 73.3 min


🔍 Batch 23: 100%|██████████| 2500/2500 [03:19<00:00, 12.55it/s]


💾 Saved batch 23 → augmentation_prepped_for_inference_preds_llama.csv | Done: 57500/398704 | 76.6 min


🔍 Batch 24: 100%|██████████| 2500/2500 [03:30<00:00, 11.90it/s]


💾 Saved batch 24 → augmentation_prepped_for_inference_preds_llama.csv | Done: 60000/398704 | 80.1 min


🔍 Batch 25: 100%|██████████| 2500/2500 [03:27<00:00, 12.05it/s]


💾 Saved batch 25 → augmentation_prepped_for_inference_preds_llama.csv | Done: 62500/398704 | 83.6 min


🔍 Batch 26: 100%|██████████| 2500/2500 [03:19<00:00, 12.50it/s]


💾 Saved batch 26 → augmentation_prepped_for_inference_preds_llama.csv | Done: 65000/398704 | 86.9 min


🔍 Batch 27: 100%|██████████| 2500/2500 [03:21<00:00, 12.43it/s]


💾 Saved batch 27 → augmentation_prepped_for_inference_preds_llama.csv | Done: 67500/398704 | 90.3 min


🔍 Batch 28: 100%|██████████| 2500/2500 [03:17<00:00, 12.66it/s]


💾 Saved batch 28 → augmentation_prepped_for_inference_preds_llama.csv | Done: 70000/398704 | 93.6 min


🔍 Batch 29: 100%|██████████| 2500/2500 [03:19<00:00, 12.50it/s]


💾 Saved batch 29 → augmentation_prepped_for_inference_preds_llama.csv | Done: 72500/398704 | 96.9 min


🔍 Batch 30: 100%|██████████| 2500/2500 [03:16<00:00, 12.70it/s]


💾 Saved batch 30 → augmentation_prepped_for_inference_preds_llama.csv | Done: 75000/398704 | 100.2 min


🔍 Batch 31: 100%|██████████| 2500/2500 [03:17<00:00, 12.68it/s]


💾 Saved batch 31 → augmentation_prepped_for_inference_preds_llama.csv | Done: 77500/398704 | 103.5 min


🔍 Batch 32:  60%|██████    | 1506/2500 [02:00<01:18, 12.61it/s]

KeyboardInterrupt: 