In [None]:
"""
phase3_moral_scoring.py

Full Phase 3 pipeline:
- Load Master Moral Vectors (pickle)
- Load cleaned Tamil texts (thirukkural_cleaned.csv, aathichoodi_cleaned.csv)
- Embed texts with IndicSBERT (batched, GPU-aware)
- Compute cosine similarity between each text embedding and each master moral vector
- Expand/save scores, compute dominant moral per line
- Optional: PCA/UMAP 2D visualization and summary stats

Assumptions:
- Master moral vectors are stored in a pickle file as a dict:
    { "care.virtue": np.array([...]), "care.vice": np.array([...]), ... }
- CSVs have a text column: "couplet" for Thirukkural and "text" for Aathichudi.
  If your column names differ, change the constants below.
"""

import os
import pickle
from tqdm import tqdm
import numpy as np
import pandas as pd

# embedding model
from sentence_transformers import SentenceTransformer

# cosine
from numpy.linalg import norm
from numpy import dot

# optional plotting / dimensionality reduction
try:
    from sklearn.decomposition import PCA
    import matplotlib.pyplot as plt
    import umap
    PLOTTING_AVAILABLE = True
except Exception:
    PLOTTING_AVAILABLE = False

# -------------------------
# CONFIG — edit paths if needed
# -------------------------
# Master-vector pickle written by the validation notebook; it may hold several
# languages, in which case load_master_vectors keeps only the Tamil entry.
MASTER_VECTOR_PICKLE = "master_vectors_all_languages.pkl"

# Cleaned input corpora, both kept under the processedDataTamil/ folder.
_PROCESSED_DATA_DIR = "processedDataTamil"
THIRUKKURAL_CSV = f"{_PROCESSED_DATA_DIR}/thirukkural_cleaned.csv"
AATHICHUDI_CSV = f"{_PROCESSED_DATA_DIR}/aathichoodi_cleaned.csv"

# Every Phase 3 artefact (embeddings, score CSVs, summaries, plots) goes here.
OUTPUT_DIR = "phase3_outputs"

# Sentence-embedding model, and how many texts are handed to it per call
# (lower EMBED_BATCH_SIZE when memory is tight).
MODEL_NAME = "l3cube-pune/indic-sentence-similarity-sbert"
EMBED_BATCH_SIZE = 64

# Name of the text column in each CSV — change these if your files differ.
THIRUKKURAL_TEXT_COL = "couplet"
AATHICHUDI_TEXT_COL = "text"

# Created at import time so every later save can assume the folder exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# -------------------------
# Helpers
# -------------------------
def load_master_vectors(pkl_path):
    """Load the master moral vectors from a pickle file.

    The pickle holds either a flat mapping ``{label: vector}`` (labels such as
    ``"care.virtue"``) or a multi-language mapping whose ``"tamil"`` entry is
    such a dict. In the latter case only the Tamil vectors are returned.

    Args:
        pkl_path: Path to the pickle file.

    Returns:
        A new dict mapping each moral label to a ``float32`` NumPy array.

    Raises:
        ValueError: If the (Tamil) mapping contains no vectors.
    """
    # SECURITY: pickle.load can execute arbitrary code while unpickling.
    # Only point this at files you produced yourself or otherwise trust.
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)

    # Multi-language pickles nest the per-label dicts under a language key.
    if "tamil" in data and isinstance(data["tamil"], dict):
        print("Detected multi-language pickle. Extracting Tamil vectors...")
        data = data["tamil"]

    # Build a fresh dict rather than rewriting the loaded one while iterating it.
    vectors = {label: np.array(vec, dtype=np.float32) for label, vec in data.items()}
    if not vectors:
        # Fail here with a clear message instead of inside np.vstack later on.
        raise ValueError(f"No master moral vectors found in {pkl_path}")

    print(f"Loaded {len(vectors)} master moral vectors from {pkl_path}")
    return vectors

def load_texts(csv_path, text_col):
    """Read *csv_path* and return ``(df, texts)`` for the column *text_col*.

    Rows whose text is missing (NaN) or blank are dropped from both the
    returned DataFrame and the text list, so the two stay row-aligned. The
    surviving rows keep their original index labels.

    Args:
        csv_path: Path to a UTF-8 CSV file (with or without a BOM).
        text_col: Name of the column holding the text to embed.

    Returns:
        Tuple ``(df, texts)``: the filtered DataFrame and a list of strings.

    Raises:
        ValueError: If *text_col* is not a column of the CSV.
    """
    # "utf-8-sig" reads plain UTF-8 unchanged and also strips a leading BOM,
    # which would otherwise be glued onto the first column name.
    df = pd.read_csv(csv_path, encoding="utf-8-sig")
    if text_col not in df.columns:
        raise ValueError(f"Column '{text_col}' not found in {csv_path}. Columns: {df.columns.tolist()}")

    raw = df[text_col]
    # astype(str) would turn NaN into the literal string "nan", which would
    # then be embedded and scored as though it were real text.
    has_text = raw.notna() & raw.astype(str).str.strip().ne("")
    dropped = int((~has_text).sum())
    if dropped:
        print(f"Dropping {dropped} row(s) with empty '{text_col}' in {csv_path}")
        df = df.loc[has_text]

    texts = df[text_col].astype(str).tolist()
    return df, texts

def load_model(model_name=MODEL_NAME):
    """Construct the sentence-embedding model shared by every corpus.

    Args:
        model_name: Model name or path handed straight to ``SentenceTransformer``.

    Returns:
        The loaded ``SentenceTransformer`` instance.
    """
    print("Loading model:", model_name)
    encoder = SentenceTransformer(model_name)
    print("Model loaded.")
    return encoder

def embed_texts(model, texts, batch_size=EMBED_BATCH_SIZE):
    """Encode *texts* with *model* in chunks of *batch_size*.

    Args:
        model: A ``SentenceTransformer`` (anything with a compatible ``encode``).
        texts: Sequence of strings to embed.
        batch_size: Number of texts per ``encode`` call.

    Returns:
        2-D NumPy array with one embedding row per input text, in input order.
        Empty input yields an empty ``(0, dim)`` array instead of an error.
    """
    if len(texts) == 0:
        # np.vstack([]) would raise "need at least one array to concatenate".
        dim = model.get_sentence_embedding_dimension() or 0
        return np.empty((0, dim), dtype=np.float32)

    chunks = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Embedding batches"):
        batch = texts[start:start + batch_size]
        # Forward batch_size so encode() doesn't re-split each chunk at its
        # own default internal batch size.
        chunks.append(
            model.encode(batch, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=False)
        )
    return np.vstack(chunks)

def cosine_similarity_matrix(matA, vecsB):
    """Cosine similarity between every row of *matA* and every vector in *vecsB*.

    Args:
        matA: ``(n, d)`` array of embeddings.
        vecsB: dict ``{label: vector}`` whose vectors each have ``d`` elements.

    Returns:
        DataFrame of shape ``(n, len(vecsB))`` with one column per label, in the
        dict's iteration order. A zero-length row or vector gets similarity 0
        (the same convention scikit-learn's normalisation uses) instead of NaN.
    """

    def _unit_rows(mat):
        # Scale each row to unit L2 length. All-zero rows are left as zeros,
        # because dividing them by 0 would produce NaN scores.
        mat = np.asarray(mat)
        norms = np.linalg.norm(mat, axis=1, keepdims=True)
        norms[norms == 0] = 1
        return mat / norms

    labels = list(vecsB)
    B = np.vstack([vecsB[label] for label in labels])  # (m, d)
    sims = _unit_rows(matA).dot(_unit_rows(B).T)  # (n, m)
    return pd.DataFrame(sims, columns=labels)

# -------------------------
# Main processing function
# -------------------------
def _plot_embedding_map(embeddings, dominant_labels, output_prefix, max_legend=12):
    """Save a 2-D UMAP scatter of *embeddings*, coloured by dominant moral.

    Best effort: any failure is printed and swallowed, so a plotting problem
    never loses the scoring results that were already written.

    Args:
        embeddings: Embedding matrix, one row per text.
        dominant_labels: Series with one dominant-moral label per row.
        output_prefix: Used in the plot title and the output filename.
        max_legend: Maximum number of legend entries (most frequent first).
    """
    fig = None
    try:
        import matplotlib.patches as mpatches

        # Reduce to 2-D with UMAP; the fixed seed keeps layouts reproducible.
        reducer = umap.UMAP(n_components=2, random_state=42)
        emb2d = reducer.fit_transform(embeddings)

        labels = dominant_labels.astype(str)
        label2int = {lab: i for i, lab in enumerate(sorted(labels.unique()))}
        palette = plt.cm.tab20
        # Give each point an explicit RGBA colour from the same lookup the
        # legend uses. Passing integer codes with cmap= would rescale them over
        # [min, max], so point colours would no longer match tab20(i).
        point_colors = np.asarray([palette(label2int[lab] % 20) for lab in labels])

        fig = plt.figure(figsize=(8, 6))
        plt.scatter(emb2d[:, 0], emb2d[:, 1], c=point_colors, s=10)
        plt.title(f"UMAP 2D: {output_prefix}")

        # Legend shows the most frequent labels, capped to avoid clutter.
        top_labels = labels.value_counts().index[:max_legend]
        handles = [mpatches.Patch(color=palette(label2int[lab] % 20), label=lab) for lab in top_labels]
        plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc="upper left")
        plt.tight_layout()

        figpath = os.path.join(OUTPUT_DIR, f"{output_prefix}_umap.png")
        plt.savefig(figpath, dpi=300)
        print("Saved UMAP plot to", figpath)
    except Exception as e:
        print("Visualization failed:", e)
    finally:
        # Always release the figure, including when plotting failed midway.
        if fig is not None:
            plt.close(fig)


def process_file(csv_path, text_col, master_vectors, model, output_prefix):
    """Embed one corpus, score it against the master moral vectors, and save outputs.

    Writes into OUTPUT_DIR:
    ``{output_prefix}_embeddings.npy``, ``{output_prefix}_moral_scores.csv``,
    ``{output_prefix}_moral_summary.csv`` and, if plotting is available,
    ``{output_prefix}_umap.png``.

    Args:
        csv_path: CSV file containing the texts.
        text_col: Column of *csv_path* holding the text.
        master_vectors: dict ``{label: vector}`` of master moral vectors.
        model: Embedding model passed through to ``embed_texts``.
        output_prefix: Filename prefix for every output of this corpus.

    Returns:
        Tuple ``(result_df, embeddings)``. ``result_df`` holds the input
        columns, one score column per moral label, ``dominant_moral`` and
        ``dominant_score``.
    """
    print(f"\nProcessing {csv_path} ...")
    df, texts = load_texts(csv_path, text_col)
    print(f"Loaded {len(texts)} texts.")

    # 1) Embedding
    embeddings = embed_texts(model, texts)
    print("Embeddings shape:", embeddings.shape)

    # 2) Similarity scores, attached positionally to the source rows
    scores_df = cosine_similarity_matrix(embeddings, master_vectors)
    moral_cols = scores_df.columns.tolist()
    result_df = pd.concat([df.reset_index(drop=True), scores_df.reset_index(drop=True)], axis=1)

    # 3) Dominant moral (argmax over the score columns)
    result_df["dominant_moral"] = scores_df.idxmax(axis=1)
    result_df["dominant_score"] = scores_df.max(axis=1)

    # 4) Embeddings go to .npy rather than CSV to keep files small
    emb_path = os.path.join(OUTPUT_DIR, f"{output_prefix}_embeddings.npy")
    np.save(emb_path, embeddings)
    print("Saved embeddings to", emb_path)

    # 5) Per-row scores
    csv_out = os.path.join(OUTPUT_DIR, f"{output_prefix}_moral_scores.csv")
    result_df.to_csv(csv_out, index=False, encoding="utf-8")
    print("Saved moral scores to", csv_out)

    # 6) Mean score per moral label, highest first
    summary = result_df[moral_cols].mean().sort_values(ascending=False)
    summary_path = os.path.join(OUTPUT_DIR, f"{output_prefix}_moral_summary.csv")
    summary.to_csv(summary_path)
    print("Saved moral summary (means) to", summary_path)
    print("Top foundations by mean score:\n", summary.head(10))

    # 7) Optional visualisation
    if PLOTTING_AVAILABLE:
        _plot_embedding_map(embeddings, result_df["dominant_moral"], output_prefix)
    else:
        print("Plotting libraries not available; skipping visualization.")

    return result_df, embeddings

# -------------------------
# ENTRY POINT
# -------------------------
def main():
    """Run Phase 3 end to end.

    Loads the master moral vectors and the embedding model once. Then it
    scores each corpus whose CSV exists, skipping missing ones with a
    message. When both corpora were scored, it also writes a side-by-side
    table of per-label means to ``comparison_means.csv``.

    Raises:
        FileNotFoundError: If MASTER_VECTOR_PICKLE does not exist.
    """
    if not os.path.exists(MASTER_VECTOR_PICKLE):
        raise FileNotFoundError(f"Master vector pickle not found: {MASTER_VECTOR_PICKLE}")
    master_vectors = load_master_vectors(MASTER_VECTOR_PICKLE)

    model = load_model(MODEL_NAME)

    # (display name, CSV path, text column, output prefix), processed in order.
    corpora = (
        ("Thirukkural", THIRUKKURAL_CSV, THIRUKKURAL_TEXT_COL, "thirukkural"),
        ("Aathichudi", AATHICHUDI_CSV, AATHICHUDI_TEXT_COL, "aathichudi"),
    )
    scored = {}
    for display_name, csv_path, text_col, prefix in corpora:
        if not os.path.exists(csv_path):
            print(f"{display_name} CSV not found, skipping:", csv_path)
            scored[prefix] = None
            continue
        scored[prefix], _embeddings = process_file(csv_path, text_col, master_vectors, model, prefix)

    thiru_df = scored["thirukkural"]
    aath_df = scored["aathichudi"]
    if thiru_df is not None and aath_df is not None:
        moral_cols = list(master_vectors)
        compare = pd.concat(
            [
                thiru_df[moral_cols].mean().rename("thirukkural_mean"),
                aath_df[moral_cols].mean().rename("aathichudi_mean"),
            ],
            axis=1,
        )
        compare_path = os.path.join(OUTPUT_DIR, "comparison_means.csv")
        compare.to_csv(compare_path)
        print("Saved cross-text comparison means to", compare_path)

    print("\nAll done. Outputs are in:", OUTPUT_DIR)

if __name__ == "__main__":
    main()


Detected multi-language pickle. Extracting Tamil vectors...
Loaded 10 master moral vectors from master_vectors_all_languages.pkl
Loading model: l3cube-pune/indic-sentence-similarity-sbert
Model loaded.

Processing processedDataTamil/thirukkural_cleaned.csv ...
Loaded 1334 texts.


Embedding batches: 100%|██████████| 21/21 [00:03<00:00,  6.26it/s]
  warn(


Embeddings shape: (1334, 768)
Saved embeddings to phase3_outputs/thirukkural_embeddings.npy
Saved moral scores to phase3_outputs/thirukkural_moral_scores.csv
Saved moral summary (means) to phase3_outputs/thirukkural_moral_summary.csv
Top foundations by mean score:
 sanctity.vice       0.543980
sanctity.virtue     0.531287
authority.virtue    0.502406
loyalty.virtue      0.499667
care.vice           0.490899
care.virtue         0.490188
fairness.vice       0.483029
authority.vice      0.475291
loyalty.vice        0.467730
fairness.virtue     0.466591
dtype: float32
Saved UMAP plot to phase3_outputs/thirukkural_umap.png

Processing processedDataTamil/aathichoodi_cleaned.csv ...
Loaded 111 texts.


Embedding batches: 100%|██████████| 2/2 [00:00<00:00, 13.14it/s]

Embeddings shape: (111, 768)
Saved embeddings to phase3_outputs/aathichudi_embeddings.npy
Saved moral scores to phase3_outputs/aathichudi_moral_scores.csv
Saved moral summary (means) to phase3_outputs/aathichudi_moral_summary.csv
Top foundations by mean score:
 sanctity.virtue     0.452064
sanctity.vice       0.432831
care.virtue         0.428321
authority.virtue    0.424513
loyalty.virtue      0.411359
fairness.virtue     0.407207
fairness.vice       0.402675
care.vice           0.397296
authority.vice      0.396465
loyalty.vice        0.393054
dtype: float32



  warn(


Saved UMAP plot to phase3_outputs/aathichudi_umap.png
Saved cross-text comparison means to phase3_outputs/comparison_means.csv

All done. Outputs are in: phase3_outputs
