In [None]:
import os
import re
import time
import pandas as pd
from tqdm import tqdm
from dotenv import load_dotenv
import google.genai as genai

# Load all keys.
# override=True lets values from the .env file replace variables already set
# in the process environment.
load_dotenv(override=True)
_KEY_ENV_VARS = ("GEMINI_KEY_5", "GEMINI_KEY_6")
# os.getenv returns None for an unset variable; drop unset/blank entries so
# key rotation in switch_api_key() can never land on a missing key.
API_KEYS = [key for key in (os.getenv(name) for name in _KEY_ENV_VARS) if key]
if not API_KEYS:
    raise RuntimeError(
        "No Gemini API key configured; set at least one of: "
        + ", ".join(_KEY_ENV_VARS)
    )
MODEL_ID = "models/gemini-2.5-flash"

# Initialize first client (key_index is the position in API_KEYS currently in use)
key_index = 0
client = genai.Client(api_key=API_KEYS[key_index])

# Utility
def switch_api_key():
    """Rotate to the next entry of API_KEYS and rebuild the global client.

    Returns:
        True when a further key existed and the module-level ``client`` was
        replaced; False when the current key is already the last one.
    """
    global key_index, client

    # Guard clause: nothing left to rotate to.
    if key_index + 1 >= len(API_KEYS):
        print("All API keys exhausted.")
        return False

    key_index = key_index + 1
    client = genai.Client(api_key=API_KEYS[key_index])
    print(f"Switched to API key #{key_index+1}")
    return True

In [None]:
# Batch classify function
def _parse_batch_labels(raw_output, n):
    """Turn the model's numbered ``N. Label | reason`` lines into per-text lists.

    Each parsed line is placed at position ``N - 1`` so that a skipped or
    malformed line cannot shift the labels of the texts that follow it.
    Positions with no usable line keep the defaults. Numbers outside ``1..n``
    and repeated numbers (first occurrence wins) are ignored.
    """
    labels = ["Neutral"] * n
    reasons = ["Tidak ada alasan diberikan."] * n
    filled = set()

    matches = re.findall(
        r'^\s*(\d+)\.\s*([A-Za-z]+)\s*\|\s*(.+)$',
        raw_output,
        flags=re.MULTILINE,
    )
    for number, label, reason in matches:
        pos = int(number) - 1
        if pos < 0 or pos >= n or pos in filled:
            continue
        filled.add(pos)
        label = label.capitalize()
        # Anything outside the three expected classes is treated as Neutral.
        if label not in ("Inflation", "Deflation", "Neutral"):
            label = "Neutral"
        labels[pos] = label
        reasons[pos] = reason.strip()

    return labels, reasons


def batch_classify_inflation(texts, max_retries=3):
    """Classify a batch of news texts as Inflation / Deflation / Neutral.

    All texts are sent in one numbered prompt through the module-level
    ``client``; each text is truncated to its first 1000 characters.

    Args:
        texts: Sequence of news texts. Non-string entries (e.g. NaN from an
            empty CSV cell) are sent as empty text instead of raising.
        max_retries: Number of request attempts made with the current API key.

    Returns:
        ``(labels, reasons)``: two lists with exactly ``len(texts)`` entries,
        aligned by position with ``texts``. Texts the model did not answer get
        ``"Neutral"`` and a placeholder reason. If every attempt fails, all
        texts get ``"Neutral"``. An empty ``texts`` returns ``([], [])``
        without calling the API.

    Raises:
        RuntimeError: when the error text contains ``RESOURCE_EXHAUSTED`` and
            ``switch_api_key()`` reports that no further key is available.
    """
    n = len(texts)
    if n == 0:
        return [], []

    # Slicing a non-string (float NaN, None) would raise before the retry loop.
    safe_texts = [t if isinstance(t, str) else "" for t in texts]
    joined = "\n\n".join([f"{i+1}. {t[:1000]}" for i, t in enumerate(safe_texts)])
    prompt = f"""
    Kamu adalah analis ekonomi makro yang menilai berita tentang inflasi di Indonesia khususnya pada index Indeks Harga Konsumen (CPI) di Indonesia.

    Untuk setiap teks berita berikut, berikan:
    - Label: Inflation, Deflation, atau Neutral
    - Alasan singkat (1 kalimat) mengapa kamu memilih label itu.

    Format jawaban (gunakan tanda | sebagai pemisah label dan alasan):
    1. Inflation | Kenaikan harga minyak mendorong tekanan inflasi.
    2. Deflation | Penurunan permintaan menyebabkan harga turun.
    3. Neutral | Hanya laporan data tanpa indikasi tekanan harga.

    Teks berita:
    {joined}
    """

    for attempt in range(max_retries):
        is_last_attempt = attempt + 1 >= max_retries
        try:
            response = client.models.generate_content(
                model=MODEL_ID,
                contents=prompt,
            )

            # Prefer the convenience .text accessor, then the first candidate
            # part, and finally the response's string form.
            raw_output = getattr(response, "text", None)
            if raw_output is None and hasattr(response, "candidates"):
                raw_output = response.candidates[0].content.parts[0].text
            if raw_output is None:
                raw_output = str(response)

            return _parse_batch_labels(raw_output, n)

        except Exception as e:
            err = str(e)
            print(f"Exception in batch_classify_inflation: {type(e).__name__} → {err}")

            if "RESOURCE_EXHAUSTED" in err:
                print(f"API key #{key_index+1} quota exhausted.")
                if not switch_api_key():
                    raise RuntimeError("All keys exhausted – stop job now.") from e
                time.sleep(10)
                # Start over with the new key, keeping the caller's retry budget.
                return batch_classify_inflation(texts, max_retries=max_retries)

            elif "503" in err or "UNAVAILABLE" in err:
                # Linear back-off; pointless after the final attempt.
                if not is_last_attempt:
                    wait = 60 * (attempt + 1)
                    print(f"Gemini overloaded, waiting {wait}s before retry...")
                    time.sleep(wait)
                continue

            else:
                if not is_last_attempt:
                    print("Unhandled error, retrying after short delay...")
                    time.sleep(5)
                continue

    print("All retries failed. Defaulting to Neutral.")
    return ["Neutral"] * n, ["Server overloaded or quota limit."] * n

In [3]:
def _save_checkpoint_atomic(df, checkpoint_path):
    """Write ``df`` to ``checkpoint_path`` via a temp file plus ``os.replace``."""
    tmp_path = checkpoint_path + ".tmp"
    df.to_csv(tmp_path, index=False)
    os.replace(tmp_path, checkpoint_path)


def safe_batch_label(df, batch_size=5, checkpoint_path="checkpoint_tempo.csv",
                     delay_sec=6, retry_wait=60, max_retries=3, checkpoint_every=100):
    """Label ``df["clean_text"]`` in batches with periodic CSV checkpoints.

    ``df`` is modified in place: ``label`` and ``label_reason`` columns are
    added if missing and filled batch by batch. If ``checkpoint_path`` exists
    and has the same number of rows as ``df``, its labels are loaded and work
    resumes at the first row whose ``label`` is missing. Rows are addressed by
    position throughout, so ``df`` may carry any index.

    Args:
        df: DataFrame with a ``clean_text`` column.
        batch_size: Number of rows sent per ``batch_classify_inflation`` call.
        checkpoint_path: CSV file used to persist and resume progress.
        delay_sec: Seconds to sleep between batches.
        retry_wait: Unused; kept so existing callers keep working.
        max_retries: Forwarded to ``batch_classify_inflation``.
        checkpoint_every: Save a checkpoint once at least this many rows have
            been labeled since the previous save.

    Returns:
        The same ``df`` object. On normal completion it is also written to
        ``result/df_tempo_labeled.csv``. If ``batch_classify_inflation``
        raises ``RuntimeError``, a checkpoint is saved and the partially
        labeled ``df`` is returned early.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    for col in ("label", "label_reason"):
        if col not in df.columns:
            df[col] = [None] * len(df)

    start_i = 0

    # Resume safely
    if os.path.exists(checkpoint_path):
        print(f"Resuming from checkpoint: {checkpoint_path}")
        df_ckpt = pd.read_csv(checkpoint_path)
        if len(df_ckpt) == len(df):
            for col in ("label", "label_reason"):
                if col in df_ckpt.columns:
                    # to_numpy() copies by position; plain Series assignment
                    # would align on index labels instead.
                    df[col] = df_ckpt[col].astype(object).to_numpy()
        else:
            print(f"Checkpoint has {len(df_ckpt)} rows but input has {len(df)}; ignoring it.")
        unlabeled = df["label"].isna().to_numpy()
        if not unlabeled.any():
            print("All rows already labeled — nothing to resume.")
            return df
        # argmax on the boolean array gives the POSITION of the first True.
        start_i = int(unlabeled.argmax())
        print(f"Resuming from row {start_i}")

    # An all-NaN column read from CSV is float; force object dtype so string
    # labels can be written into it.
    for col in ("label", "label_reason"):
        df[col] = df[col].astype(object)

    total = len(df)
    label_pos = df.columns.get_loc("label")
    reason_pos = df.columns.get_loc("label_reason")
    rows_since_save = 0

    # Main loop
    for i in tqdm(range(start_i, total, batch_size)):
        stop = min(i + batch_size, total)
        batch = df["clean_text"].iloc[i:stop].tolist()
        n_batch = len(batch)

        try:
            labels, reasons = batch_classify_inflation(batch, max_retries=max_retries)
        except RuntimeError as e:
            print(f"{e}. Saving checkpoint and stopping.")
            _save_checkpoint_atomic(df, checkpoint_path)
            print(f"Saved progress before stopping at row {i}.")
            return df

        if not labels:
            labels = ["Neutral"] * n_batch
            reasons = ["Failed request"] * n_batch

        # Match batch size
        if len(labels) != n_batch:
            labels = (list(labels) + ["Neutral"] * n_batch)[:n_batch]
        if len(reasons) != n_batch:
            reasons = (list(reasons) + ["Unknown"] * n_batch)[:n_batch]

        # Positional write, consistent with the positional read above.
        df.iloc[i:stop, label_pos] = labels
        df.iloc[i:stop, reason_pos] = reasons
        rows_since_save += n_batch

        # Save after the first batch (surfaces an unwritable path early),
        # every `checkpoint_every` labeled rows, and at the end.
        is_last = stop >= total
        if i == start_i or rows_since_save >= checkpoint_every or is_last:
            _save_checkpoint_atomic(df, checkpoint_path)
            rows_since_save = 0
            print(f"Checkpoint saved safely at row {i}")

        # No need to throttle after the final batch.
        if not is_last:
            time.sleep(delay_sec)

    os.makedirs("result", exist_ok=True)
    out_path = "result/df_tempo_labeled.csv"
    df.to_csv(out_path, index=False)
    print(f"Labeling complete → saved to {out_path}")
    return df

In [None]:
# RUN
# Read the prepared articles and label them. safe_batch_label fills the
# label columns on the frame it is given and returns that same frame, so
# df_tempo_labeled and df_tempo refer to one object.
df_tempo = pd.read_csv(
    "result/df_tempo.csv"
)
df_tempo_labeled = safe_batch_label(df=df_tempo)

  0%|          | 0/8 [00:00<?, ?it/s]

Checkpoint saved safely at row 0


 88%|████████▊ | 7/8 [02:31<00:20, 20.40s/it]

Checkpoint saved safely at row 35


100%|██████████| 8/8 [02:45<00:00, 20.64s/it]

Labeling complete → saved to result/df_tempo_labeled.csv



