In [None]:
import pandas as pd
import numpy as np
import faiss
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModel
from sentence_transformers import SentenceTransformer
import torch

In [None]:
# Load Dataset
train = pd.read_csv("/teamspace/studios/this_studio/LegalTech-Palak/train.csv")
test = pd.read_csv("/teamspace/studios/this_studio/LegalTech-Palak/test.csv")

# Stack train and test into a single corpus.
# ignore_index=True relabels the rows 0..N-1. Without it each frame keeps the
# 0-based index that read_csv gave it, so the `doc_id` produced by
# df.iterrows() in later cells would be shared by one train row and one test row.
# Row order is unchanged, so anything keyed on row position is unaffected.
df = pd.concat([train, test], ignore_index=True)

In [None]:
# Load the Qwen3-Embedding-8B tokenizer and model
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

embed_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-8B")
embed_model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-8B")
embed_model = embed_model.to(device)
# Put the module in evaluation mode. Kept as the final bare expression so the
# notebook still echoes the model summary as the cell output.
embed_model.eval()

In [None]:
# Path to a locally stored Meta-Llama-3-8B-Instruct checkpoint. Only its tokenizer
# is loaded here; the chunking and prompt-budget helpers in later cells use it to
# count tokens and to cut documents into windows.
llama_model_id = "/teamspace/studios/this_studio/meta-llama/Meta-Llama-3-8B-Instruct"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)

In [None]:
# Token-Based Chunking
def tokenize_and_chunk(text, chunk_size=512, overlap=128):
    """Split ``text`` into overlapping windows of at most ``chunk_size`` tokens.

    The text is tokenized with ``llama_tokenizer`` (no truncation), sliced into
    windows that start every ``chunk_size - overlap`` tokens, and each window is
    decoded back to a string with special tokens stripped.

    Args:
        text: The document text to split.
        chunk_size: Maximum number of token ids per window. Must be positive.
        overlap: Number of token ids shared by consecutive windows. Must satisfy
            ``0 <= overlap < chunk_size`` so the window start always advances.

    Returns:
        A list of decoded chunk strings, in document order. Empty when the
        tokenizer produces no ids for ``text``.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is outside
            ``[0, chunk_size)``.
    """
    # Validate first: a zero stride makes range() raise an opaque error and a
    # negative stride would silently yield no chunks at all.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got overlap={overlap}, chunk_size={chunk_size}"
        )
    stride = chunk_size - overlap

    # encode() runs with the tokenizer's default special-token handling, while
    # decode() drops special tokens again.
    # NOTE(review): if the tokenizer prepends a BOS id it occupies one slot of the
    # first window - confirm whether that is intended.
    input_ids = llama_tokenizer.encode(text, truncation=False)

    # The number and order of chunks must stay stable: retrieve_chunks maps FAISS
    # row ids back onto the list built from this function's output.
    return [
        llama_tokenizer.decode(input_ids[start:start + chunk_size], skip_special_tokens=True)
        for start in range(0, len(input_ids), stride)
    ]

In [None]:
# Generate Chunks
# Two parallel lists: chunk_metadata[k] describes chunked_texts[k].
# retrieve_chunks later looks both up by the row id FAISS returns.
chunked_texts = []
chunk_metadata = []

for doc_id, row in tqdm(df.iterrows(), total=len(df)):
    # Missing inputs become an empty string
    raw_input = row["Input"]
    input_text = "" if pd.isna(raw_input) else str(raw_input)

    chunks = tokenize_and_chunk(input_text)
    for i, chunk in enumerate(chunks):
        chunked_texts.append(chunk)
        # Fall back to a synthetic title when the row has no "Title" entry
        title = row.get("Title", f"Case {doc_id}")
        chunk_metadata.append({
            "doc_id": doc_id,
            "chunk_index": i,
            "title": title,
            "original_text": input_text,
        })

In [None]:
# Load Precomputed Embeddings
chunk_embeddings = np.load("chunk_embeddings.npy")

# Load FAISS Index
index = faiss.read_index("faiss_index.index")

# FAISS search returns row numbers, and retrieve_chunks uses them to index
# chunked_texts / chunk_metadata. If the files on disk were produced from a
# different chunking run, lookups would silently return the wrong text, so
# fail loudly here instead.
if not (len(chunk_embeddings) == index.ntotal == len(chunked_texts)):
    raise ValueError(
        "Precomputed artefacts are out of sync with the current chunks: "
        f"{len(chunk_embeddings)} embedding rows, {index.ntotal} index vectors, "
        f"{len(chunked_texts)} chunks. Rebuild the embeddings and the FAISS index."
    )

In [None]:
def get_chunk_cls_embedding(text, pooling="first"):
    """Embed a single text with ``embed_model`` and return a 1-D NumPy vector.

    Args:
        text: One string. The tokenizer truncates it to at most 512 tokens.
        pooling: Which position of the final hidden states becomes the embedding.
            ``"first"`` (default) takes position 0, which is what this notebook
            has always done. ``"last"`` takes the final position.
            NOTE(review): the Qwen3-Embedding model card pools the *last* token;
            for a causal decoder, position 0 presumably only attends to the
            first token, so ``"first"`` may carry little information about the
            rest of the text - confirm. Switching the pooling requires
            re-embedding every chunk and rebuilding the FAISS index, because
            queries and indexed vectors must be pooled the same way.

    Returns:
        ``numpy.ndarray`` of shape ``(hidden_size,)`` for a single input text.

    Raises:
        ValueError: If ``pooling`` is not ``"first"`` or ``"last"``.
    """
    if pooling not in ("first", "last"):
        raise ValueError(f"pooling must be 'first' or 'last', got {pooling!r}")
    # A single text is tokenized without padding, so index -1 is its last real token.
    token_position = 0 if pooling == "first" else -1

    inputs = embed_tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = embed_model(**inputs)
        pooled = outputs.last_hidden_state[:, token_position, :]
    # Drop the batch dimension and move the vector to host memory
    return pooled.squeeze(0).cpu().numpy()

In [None]:
# DO NOT RUN AGAIN
# Qwen Embedding (CLS Token per Chunk)
# Re-definition of get_chunk_cls_embedding (it is also defined in an earlier cell);
# whichever cell ran last is the one in effect.
def get_chunk_cls_embedding(text, pooling="first"):
    """Embed a single text with ``embed_model`` and return a 1-D NumPy vector.

    Args:
        text: One string. The tokenizer truncates it to at most 512 tokens.
        pooling: Which position of the final hidden states becomes the embedding.
            ``"first"`` (default) takes position 0, which is what this notebook
            has always done. ``"last"`` takes the final position.
            NOTE(review): the Qwen3-Embedding model card pools the *last* token;
            for a causal decoder, position 0 presumably only attends to the
            first token, so ``"first"`` may carry little information about the
            rest of the text - confirm. Switching the pooling requires
            re-embedding every chunk and rebuilding the FAISS index, because
            queries and indexed vectors must be pooled the same way.

    Returns:
        ``numpy.ndarray`` of shape ``(hidden_size,)`` for a single input text.

    Raises:
        ValueError: If ``pooling`` is not ``"first"`` or ``"last"``.
    """
    if pooling not in ("first", "last"):
        raise ValueError(f"pooling must be 'first' or 'last', got {pooling!r}")
    # A single text is tokenized without padding, so index -1 is its last real token.
    token_position = 0 if pooling == "first" else -1

    inputs = embed_tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = embed_model(**inputs)
        pooled = outputs.last_hidden_state[:, token_position, :]
    # Drop the batch dimension and move the vector to host memory
    return pooled.squeeze(0).cpu().numpy()

# Embed every chunk, one forward pass per chunk. Row k of the resulting matrix
# belongs to chunked_texts[k].
chunk_embeddings = []
for chunk in tqdm(chunked_texts, desc="Embedding chunks"):
    chunk_embeddings.append(get_chunk_cls_embedding(chunk))

# Stack into a single (num_chunks, dim) float32 matrix
chunk_embeddings = np.vstack(chunk_embeddings).astype("float32")

# Persist the matrix under the file name the "Load Precomputed Embeddings" cell
# reads, so that cell keeps working after a kernel restart.
np.save("chunk_embeddings.npy", chunk_embeddings)

In [None]:
# DO NOT RUN AGAIN
# (Rebuilds the index from `chunk_embeddings`; an earlier cell loads the copy
# saved on disk instead.)
# ========== 7. Build FAISS Index ==========
# Vectors are added in row order, so FAISS row k corresponds to chunk_embeddings[k];
# retrieve_chunks relies on that order when mapping hits back to chunked_texts.
embedding_dim = chunk_embeddings.shape[1]  # vector width = number of columns
index = faiss.IndexFlatL2(embedding_dim)  # flat index compared by L2 distance
index.add(chunk_embeddings)

In [None]:
# DO NOT RUN AGAIN
# Write the index to disk, then read it straight back so that `index` is the
# deserialised copy stored in "faiss_index.index".
faiss.write_index(index, "faiss_index.index")
index = faiss.read_index("faiss_index.index")

In [None]:
# Retrieve Chunks
def retrieve_chunks(query, top_k=6):
    """Return the chunks whose embeddings are nearest to ``query`` in the FAISS index.

    The query is embedded with ``get_chunk_cls_embedding`` and searched against
    the module-level ``index``; each hit is mapped back to its text and metadata
    through ``chunked_texts`` / ``chunk_metadata``.

    Args:
        query: Query text.
        top_k: Maximum number of chunks to return.

    Returns:
        A list of dicts with the keys ``chunk``, ``doc_id``, ``title``,
        ``chunk_index`` and ``original_text``, in the order FAISS returned them.
        The list is shorter than ``top_k`` when the index holds fewer vectors,
        and empty when ``top_k`` is not positive.
    """
    if top_k <= 0:
        return []

    # FAISS expects a contiguous float32 matrix of shape (n_queries, dim)
    query_vec = np.ascontiguousarray(get_chunk_cls_embedding(query), dtype="float32").reshape(1, -1)
    _, indices = index.search(query_vec, top_k)

    results = []
    for row_id in indices[0]:
        # FAISS fills unused result slots with -1; indexing a Python list with -1
        # would silently return the LAST chunk, so skip those slots.
        if row_id < 0:
            continue
        row_id = int(row_id)
        meta = chunk_metadata[row_id]
        results.append({
            "chunk": chunked_texts[row_id],
            "doc_id": meta["doc_id"],
            "title": meta["title"],
            "chunk_index": meta["chunk_index"],
            "original_text": meta["original_text"]
        })
    return results

In [None]:
def truncate_chunks_to_fit_prompt(query_text, retrieved_chunks, max_tokens=4000, reserved_tokens=512):
    """Partition retrieved chunks into those that fit the token budget and those that do not.

    The budget is ``max_tokens`` minus the token count of a fixed prompt header
    plus ``query_text``, minus ``reserved_tokens``. Chunks are visited in the
    given order; a chunk is kept when adding its token count stays within the
    budget, otherwise it is set aside and the scan continues, so a later,
    shorter chunk can still be kept.

    Note: the budget only counts the header below and the raw chunk text. The
    per-chunk labels and task instructions that ``create_rag_prompt`` adds are
    not counted here.

    Args:
        query_text: Text of the case being queried.
        retrieved_chunks: Iterable of dicts that each have a ``"chunk"`` string.
        max_tokens: Overall token limit for the prompt.
        reserved_tokens: Tokens held back from the budget.

    Returns:
        ``(selected_chunks, excluded_chunks)``: two lists, each preserving the
        input order.
    """
    static_prompt = "You are a legal assistant AI trained to analyze legal documents.\n\nContext:\n"
    fixed_cost = len(llama_tokenizer.encode(static_prompt + query_text))
    token_budget = max_tokens - fixed_cost - reserved_tokens

    kept, dropped = [], []
    used = 0
    for candidate in retrieved_chunks:
        cost = len(llama_tokenizer.encode(candidate["chunk"]))
        if used + cost <= token_budget:
            kept.append(candidate)
            used += cost
        else:
            dropped.append(candidate)

    return kept, dropped


In [None]:
# Build RAG Prompt
def create_rag_prompt(query_text, retrieved_chunks):
    """Assemble the full LLM prompt for one case.

    The prompt consists of a role line, a ``Context`` section listing every
    retrieved chunk (numbered from 1 and labelled with its title), the new
    case text, the task instructions and the expected output format.

    Args:
        query_text: Text of the case to classify.
        retrieved_chunks: Sequence of dicts with ``"title"`` and ``"chunk"`` keys.

    Returns:
        The prompt as a single string.
    """
    sections = []
    for position, item in enumerate(retrieved_chunks, start=1):
        sections.append(f"Chunk {position} (from {item['title']}):\n{item['chunk']}")
    context = "\n\n".join(sections)

    role_line = "You are a legal assistant AI trained to analyze legal documents.\n\n"
    task_block = (
        "Task:\n"
        "1. Predict whether the appeal will be accepted (1) or rejected (0).\n"
        "2. Identify the most relevant sentence(s) from the chunks.\n"
        "3. Explain your reasoning briefly (max 2 lines).\n\n"
    )
    format_block = (
        "Output format:\n"
        "Label: <0 or 1>\n"
        "Explanation: <brief explanation>"
    )

    pieces = [
        role_line,
        f"Context:\n{context}\n\n",
        f"New Case:\n{query_text}\n\n",
        task_block,
        format_block,
    ]
    return "".join(pieces)

In [None]:
import pandas as pd
from tqdm import tqdm
import os
import csv

OUTPUT_CSV_PATH = "Qwen_prompts_rag_LLM.csv"

# One dict per non-empty case; converted to a DataFrame once the loop finishes
records = []

print("Generating prompts from truncated chunks...")

for doc_id, row in tqdm(df.iterrows(), total=len(df)):
    # Missing inputs become "", and blank inputs are skipped entirely
    raw_input = row["Input"]
    query_text = "" if pd.isna(raw_input) else str(raw_input)
    if query_text.strip() == "":
        continue

    # Nearest chunks for this case, then the subset that fits the token budget
    retrieved = retrieve_chunks(query_text, top_k=6)
    selected_chunks, excluded_chunks = truncate_chunks_to_fit_prompt(query_text, retrieved)

    # Final prompt built from the chunks that survived truncation
    prompt = create_rag_prompt(query_text, selected_chunks)

    # Short preview (first 100 characters each) of the chunks that were dropped
    excluded_preview = "; ".join(c["chunk"][:100] for c in excluded_chunks)

    record = {
        "doc_id": doc_id,
        "title": row.get("Title", f"Case {doc_id}"),
        "query": query_text,
        "num_retrieved_chunks": len(retrieved),
        "num_selected_chunks": len(selected_chunks),
        "excluded_chunks": excluded_preview,
        "prompt": prompt,
        "label": row.get("Label", ""),
    }
    records.append(record)

# Write all prompts in one go; every non-numeric field is quoted
df_prompts = pd.DataFrame(records)
df_prompts.to_csv(OUTPUT_CSV_PATH, index=False, quoting=csv.QUOTE_NONNUMERIC)

print(f"Saved {len(df_prompts)} prompts to {OUTPUT_CSV_PATH}")

In [None]:
import pandas as pd

# Reload the prompts written by the export cell above. The previous hard-coded
# absolute path pointed at a user-specific desktop folder and a differently
# named file, so it could not be resolved in the environment where the rest of
# this notebook runs.
# NOTE(review): if a separate, non-Qwen prompts file was intended here, point
# this at that file through a configurable relative path instead.
df_prompts = pd.read_csv(OUTPUT_CSV_PATH)