# In [None]:
## Tahap 4

# =========================================
# TAHAP 4 - SOLUTION REUSE (Prediksi Amar Putusan)
# Menggunakan hasil IndoBERT + Voting
# =========================================

import json
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel

# === Load the cleaned fiduciary-case dataset
dataset_path = '/content/drive/MyDrive/Penalaran Komputer/CSV4/putusan_fidusia_cleaned_FINAL_BERSIH_FIX.csv'
df = pd.read_csv(dataset_path)

# Corpus texts to embed and their case numbers; both lists share row order with df.
documents = list(df['text_pdf_cleaned'].astype(str))
case_ids = list(df['nomor'].astype(str))

# === Load the pretrained IndoBERT tokenizer and encoder (same checkpoint for both)
bert_checkpoint = "indobenchmark/indobert-base-p1"
tokenizer = AutoTokenizer.from_pretrained(bert_checkpoint)
model = AutoModel.from_pretrained(bert_checkpoint)

# === Compute a text embedding with IndoBERT (mean pooling over tokens)
def masked_mean_pooling(hidden_states, attention_mask):
    """Average token vectors along the sequence axis, skipping padding positions.

    Args:
        hidden_states: tensor of shape (batch, seq_len, hidden).
        attention_mask: tensor of shape (batch, seq_len); 1 for real tokens,
            0 for padding.

    Returns:
        Tensor of shape (batch, hidden).
    """
    # Broadcast the mask over the hidden axis and match the float dtype.
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    summed = (hidden_states * mask).sum(dim=1)
    # Clamp avoids division by zero for a (degenerate) all-padding row.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


@torch.no_grad()
def get_bert_embedding(text):
    """Return the mean-pooled IndoBERT embedding of ``text`` as a numpy array.

    The input is truncated to 512 tokens. Pooling uses the attention mask, so
    padding tokens never dilute the average. This matters if this function is
    ever called with a batch of texts, since ``padding=True`` is set. For a
    single string, the batch axis is squeezed away.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    outputs = model(**inputs)
    pooled = masked_mean_pooling(outputs.last_hidden_state, inputs["attention_mask"])
    return pooled.squeeze().numpy()

# === Embed every document once up front (slow, but reused for every query)
doc_embeddings = np.array(list(map(get_bert_embedding, documents)))

# === Load the evaluation queries from the JSON file
query_path = "/content/drive/MyDrive/Penalaran Komputer/data/eval/queries_fidusia.json"
with open(query_path, encoding="utf-8") as query_file:
    queries_data = json.load(query_file)

# Keep only the "query" text of each entry, preserving file order.
queries = [entry["query"] for entry in queries_data]

# === Map case number (nomor) -> amar putusan, the "solution" that gets reused
def extract_amar(amar_value, full_text):
    """Return the lower-cased amar (verdict) text for one case.

    Uses ``amar_value`` when it holds non-blank text. Otherwise it falls back to
    the part of ``full_text`` that starts at the spaced heading
    "m e n g a d i l i". Returns "amar tidak ditemukan" when neither source
    yields an amar.
    """
    # Missing CSV cells arrive as NaN; str(NaN) == "nan" would otherwise be kept
    # as a real amar and the fallback below would never run.
    if not pd.isna(amar_value):
        amar_text = str(amar_value).lower()
        if amar_text.strip():
            return amar_text
    text = str(full_text).lower()
    idx = text.find("m e n g a d i l i")
    return text[idx:] if idx != -1 else "amar tidak ditemukan"


case_solutions = {}
for _, row in df.iterrows():
    # Key by str(nomor) so lookups with the string IDs in case_ids always match,
    # even if pandas parsed the column with a non-string dtype.
    case_solutions[str(row["nomor"])] = extract_amar(
        row.get("amar", ""), row["text_pdf_cleaned"]
    )

# === Retrieve the most similar cases using IndoBERT embeddings
def top_k_indices(scores, k):
    """Return indices of the ``k`` highest ``scores``, best first.

    ``k <= 0`` yields an empty array. A ``k`` larger than ``len(scores)``
    yields every index.
    """
    if k <= 0:
        # Guard: slicing with [-0:] would otherwise return the whole array.
        return np.array([], dtype=np.intp)
    # Reverse the ascending argsort and take the first k. This gives the same
    # order as taking the last k and reversing them.
    return np.argsort(scores)[::-1][:k]


def retrieve_bert(query, k=5):
    """Return the ``k`` cases most similar to ``query``.

    Returns:
        List of ``(case_id, cosine_similarity)`` tuples, most similar first.
    """
    q_vec = get_bert_embedding(query).reshape(1, -1)
    sims = cosine_similarity(q_vec, doc_embeddings).flatten()
    return [(case_ids[i], sims[i]) for i in top_k_indices(sims, k)]

# === Predict the amar for a query by voting over retrieved cases
def vote_amar(candidates, solutions, use_weighting=True):
    """Return the amar with the highest vote among retrieved candidates.

    Args:
        candidates: iterable of ``(case_id, similarity)`` pairs.
        solutions: mapping of case_id -> amar text.
        use_weighting: if True, each vote counts its similarity; otherwise 1.

    Returns:
        The winning amar text. Ties go to the amar seen first. Returns
        "amar tidak ditemukan" when there are no candidates.
    """
    votes = defaultdict(float)
    for case_id, sim in candidates:
        amar = solutions.get(case_id, "amar tidak ditemukan")
        votes[amar] += sim if use_weighting else 1
    if not votes:
        # max() on an empty sequence would raise ValueError.
        return "amar tidak ditemukan"
    return max(votes.items(), key=lambda item: item[1])[0]


def predict_amar(query, use_weighting=True, k=5):
    """Predict the amar for ``query`` from its ``k`` most similar cases.

    NOTE(review): votes are keyed by the full amar text, so unless several
    retrieved cases share an identical amar string, this effectively returns
    the amar of the single best-scoring case.

    Returns:
        ``(predicted_amar, top_k)``, where ``top_k`` is the list of
        ``(case_id, similarity)`` pairs that were retrieved.
    """
    top_k = retrieve_bert(query, k=k)
    predicted = vote_amar(top_k, case_solutions, use_weighting=use_weighting)
    return predicted, top_k

# === Run the prediction for every query loaded from queries.json
results = []
for query_id, query_text in enumerate(queries, start=1):
    predicted, retrieved = predict_amar(query_text)
    # Record the prediction together with the IDs of the retrieved cases.
    results.append({
        "query_id": query_id,
        "query": query_text,
        "predicted_amar": predicted,
        "top_5_case_ids": [case_id for case_id, _score in retrieved],
    })

# === Save the predictions to CSV and show them in the notebook
output_path = '/content/drive/MyDrive/Penalaran Komputer/data/results/prediksi_amar_fidusia.csv'
results_df = pd.DataFrame(results)
# to_csv does not create missing folders, so make sure the results directory exists.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
results_df.to_csv(output_path, index=False)

print("✅ Prediksi amar disimpan di:", output_path)


display(output_path)

display(results_df)