### VECTOR TABLE CREATION & GEN AI IMPLEMENTATION

This notebook provides an end-to-end system for farm advisory using a combination of:

- Gold-layer aggregated farm data (soil, crop, market, pest, rainfall).  
- Vector-based semantic search for user queries.  
- AI-generated recommendations using **FLAN-T5-Large**.  
- Visualizations of top crops by yield, profitability, rainfall trends, and price heatmaps.  
- Automatic audit logging of user queries and AI responses.


**Features:**  
1. Query by city, crop, or relative time (e.g., "tomorrow", "next week", "in 3 days").  
2. Fuzzy matching for cities and crops to handle typos.  
3. AI-generated, concise, actionable recommendations for farmers.  
4. Vector similarity fallback when exact data is unavailable.  
5. Interactive visualizations:
   - Bar charts for top yield and profitability per crop.
   - Line plots for rainfall trends.
   - Heatmaps for monthly crop prices.
6. Audit logs stored in Delta tables for monitoring user queries and system responses.


**Usage:**  
- Run the cells in order on a Databricks cluster with enough CPU/GPU memory to load **FLAN-T5-Large**.  
- Update the `catalog_name`, `schema_name_gold`, and `audit_schema_name` variables to match your workspace.  
- Enter a query in the last cell (`user_query = "..."`) and run to get AI recommendations and visualizations.

In [0]:
%pip install fuzzywuzzy[speedup]

In [0]:
%pip install sentence-transformers

In [0]:
# Workspace locations: edit these to match your Databricks workspace.
catalog_name = "databricks_free_edition"
audit_schema_name = "audit_logs"          # schema that holds the chat audit table
schema_name_silver = "databricks_silver"
schema_name_gold = "databricks_gold"      # schema that holds the farm vector index
schema_name = schema_name_gold            # alias; same value as schema_name_gold

In [0]:

# -----------------------------
# Imports
# -----------------------------
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, DoubleType
from pyspark.sql.functions import col
import pandas as pd, numpy as np, re, time, contextlib, io, traceback
from datetime import datetime, timedelta
from difflib import get_close_matches
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# -----------------------------
# TABLE NAMES (derived from the catalog/schema settings cell)
# -----------------------------
vector_table, audit_table = (
    f"{catalog_name}.{schema_name_gold}.farm_vector_index",
    f"{catalog_name}.{audit_schema_name}.farm_chat_audit",
)

# -----------------------------
# INIT SPARK and make sure the audit schema exists
# -----------------------------
spark = SparkSession.builder.getOrCreate()  # returns the active session when one already exists
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog_name}.{audit_schema_name}")

# -----------------------------
# LOAD VECTOR TABLE
# -----------------------------
pdf = spark.table(vector_table).toPandas()
# Unparseable dates become NaT instead of aborting the load.
pdf["date"] = pd.to_datetime(pdf["date"], errors="coerce")

# -----------------------------
# PREPARE EMBEDDINGS MATRIX (fast similarity)
# -----------------------------
def _to_float_vector(value):
    """Convert one stored embedding to a float32 NumPy vector.

    Embeddings may arrive as array-likes or in their text form (e.g. "[0.1, 0.2]").
    Text is parsed with ``ast.literal_eval``, which accepts only Python literals;
    unlike ``eval``, table contents cannot execute code here.
    """
    import ast

    if isinstance(value, str):
        value = ast.literal_eval(value)
    return np.asarray(value, dtype=np.float32)

# Row i of emb_matrix is the embedding of pdf.iloc[i]. It stays None when there is no
# embedding column or no rows, in which case the vector fallback returns no rows.
emb_matrix = None
if "embedding" in pdf.columns and len(pdf) > 0:
    emb_matrix = np.stack(pdf["embedding"].map(_to_float_vector).to_list())

# -----------------------------
# LOAD MODELS
# -----------------------------
# Sentence embedder that encodes user queries for the vector-similarity fallback.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# Seq2seq generator (FLAN-T5-Large) that writes the recommendation paragraph.
model_name = "google/flan-t5-large"
tokenizer, gen_model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

def generate_text(prompt, max_tokens=300):
    """Generate text for ``prompt`` with deterministic beam search.

    Uses 4-beam search without sampling, plus a repetition penalty to discourage loops.
    ``temperature`` is deliberately not passed: it only affects sampling, and
    transformers warns when it is set together with ``do_sample=False``.

    Args:
        prompt: Input text; truncated to 1024 tokens.
        max_tokens: Maximum number of newly generated tokens.

    Returns:
        The decoded best sequence with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    # Follow the model's device so this keeps working if gen_model is moved to a GPU.
    inputs = inputs.to(gen_model.device)
    outputs = gen_model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        num_beams=4,
        repetition_penalty=1.2,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# -----------------------------
# HELPERS: fuzzy, date, vector query
# -----------------------------
def resolve_fuzzy_match(query, choices, cutoff=0.6):
    """Return the closest entry of ``choices`` to ``query``, lower-cased, or None.

    Matching is case-insensitive and uses ``difflib.get_close_matches`` with the given
    ``cutoff``. The result is the *lower-cased* form of the best choice; callers that
    need the original casing must map it back.

    Note: difflib's ratio is 2*M/(len(a)+len(b)), so a whole sentence compared with a
    short name rarely reaches a 0.5 cutoff.

    Args:
        query: Text to match; None or "" yields None.
        choices: Candidate strings.
        cutoff: Minimum similarity ratio in [0, 1].
    """
    if not query or not choices:
        return None
    lowered = [c.lower() for c in choices]
    matches = get_close_matches(query.lower(), lowered, n=1, cutoff=cutoff)
    return matches[0] if matches else None

def resolve_city_strong(query, cities=None):
    """Find the city mentioned in ``query`` and return it in its original casing.

    First tries whole-word, case-insensitive matches, checking the longest names first
    so that e.g. "New York" wins over "York" and "Ely" does not match inside "likely".
    Falls back to fuzzy matching (cutoff 0.5) against the same city list.

    Args:
        query: Free-text query; None is treated as "".
        cities: Candidate city names. Defaults to the distinct non-null ``city`` values
            of the loaded ``pdf`` table.

    Returns:
        The matched city name, or None.
    """
    q = (query or "").lower()
    if cities is None:
        cities = pdf["city"].dropna().unique().tolist()

    # sorted() is stable, so equal-length names keep their table order.
    for city in sorted(cities, key=len, reverse=True):
        # Lookarounds instead of \b so names starting/ending with punctuation still match.
        if re.search(rf"(?<!\w){re.escape(city.lower())}(?!\w)", q):
            return city

    matched_lower = resolve_fuzzy_match(q, cities, cutoff=0.5)
    if matched_lower:
        for orig in cities:
            if orig.lower() == matched_lower:
                return orig
    return None

def extract_relative_date(query, reference_date=None):
    """Resolve a relative time phrase in ``query`` to a calendar date.

    Recognised (case-insensitive, checked in this order): "tomorrow" (+1 day),
    "next week" (+7 days), "in N day(s)" and "in N week(s)".

    Args:
        query: Free-text query; None is treated as "".
        reference_date: Date the offset is added to; defaults to today's local date.

    Returns:
        The resolved ``datetime.date``, or None when no phrase is found or the offset
        falls outside the representable date range.
    """
    q = (query or "").lower()
    base = datetime.now().date() if reference_date is None else reference_date
    if "tomorrow" in q:
        return base + timedelta(days=1)
    if "next week" in q:
        return base + timedelta(days=7)
    m = re.search(r"in (\d+) (day|week)s?", q)
    if m:
        amount = int(m.group(1))
        try:
            offset = timedelta(weeks=amount) if m.group(2) == "week" else timedelta(days=amount)
            return base + offset
        except OverflowError:
            # User-typed numbers can be arbitrarily large; treat those as "no date".
            return None
    return None

def query_vector_db_fast(user_query, top_k=5, df=pdf, emb_matrix_local=emb_matrix):
    """Return the ``top_k`` rows of ``df`` most cosine-similar to ``user_query``.

    ``emb_matrix_local`` must be row-aligned with ``df`` (row i embeds ``df.iloc[i]``).
    The defaults satisfy this, but they are bound when this function is defined, so
    pass both explicitly if ``pdf`` is reloaded.

    Returns:
        A copy of the selected rows, most similar first. It is empty when there are no
        embeddings or ``df`` is empty.
    """
    if emb_matrix_local is None or df.empty:
        return df.head(0)
    query_vec = np.asarray(embed_model.encode(user_query), dtype=np.float32)

    # Zero-length vectors would divide to NaN; leave them as zeros (similarity 0) instead.
    row_norms = np.linalg.norm(emb_matrix_local, axis=1, keepdims=True)
    emb_unit = emb_matrix_local / np.where(row_norms == 0, 1.0, row_norms)
    query_norm = np.linalg.norm(query_vec)
    query_unit = query_vec / query_norm if query_norm else query_vec

    sims = emb_unit @ query_unit
    # Stable sort: equal similarities keep table order, so results are deterministic.
    top_idx = np.argsort(-sims, kind="stable")[:top_k]
    return df.iloc[top_idx].copy()

# -----------------------------
# AUDIT TABLE (init + append)
# -----------------------------
# One schema for both creating the table and appending rows, so the two cannot drift apart.
AUDIT_SCHEMA = StructType([
    StructField("timestamp", TimestampType(), True),
    StructField("user_query", StringType(), True),
    StructField("user_location", StringType(), True),
    StructField("matched_city", StringType(), True),
    StructField("matched_crop", StringType(), True),
    StructField("answer", StringType(), True),
    StructField("response_time", DoubleType(), True),
])

def init_audit_table():
    """Create ``audit_table`` empty with ``AUDIT_SCHEMA`` unless it already exists."""
    if not spark.catalog.tableExists(audit_table):
        spark.createDataFrame([], schema=AUDIT_SCHEMA).write.saveAsTable(audit_table)

def append_audit_safe(entry: dict):
    """Append one audit record built from ``entry`` to ``audit_table``.

    Missing/None text fields become "", a missing/None ``response_time`` becomes 0.0,
    and a missing/None ``timestamp`` becomes the current time. The row is written with
    the explicit ``AUDIT_SCHEMA`` rather than a schema inferred from a single dict,
    so column types never depend on the values of that one row.
    """
    values = {
        "timestamp": entry.get("timestamp") or datetime.now(),
        "user_query": str(entry.get("user_query") or ""),
        "user_location": str(entry.get("user_location") or ""),
        "matched_city": str(entry.get("matched_city") or ""),
        "matched_crop": str(entry.get("matched_crop") or ""),
        "answer": str(entry.get("answer") or ""),
        "response_time": float(entry.get("response_time") or 0.0),
    }
    row = tuple(values[field.name] for field in AUDIT_SCHEMA.fields)
    spark.createDataFrame([row], schema=AUDIT_SCHEMA).write.mode("append").saveAsTable(audit_table)

init_audit_table()

# -----------------------------
# SENTENCE CLEANER (robust)
# -----------------------------
def clean_and_build_paragraph(raw_text, min_words=30, max_sentences=5, max_fallback_words=60):
    """Build a short answer paragraph from raw model output.

    The text is split into sentences (per line, on whitespace after ".", "!" or "?").
    Sentences are de-duplicated case-insensitively, keeping first occurrences, because
    generated text can repeat itself. The first ``max_sentences`` unique sentences form
    the answer.

    If that answer has fewer than ``min_words`` words, all unique sentences are used
    instead, truncated to ``max(min_words, max_fallback_words)`` words when they exceed
    ``min_words``. The fallback never re-admits repeated sentences.

    Args:
        raw_text: Generated text; None is treated as "".
        min_words: Word count below which the fallback is used.
        max_sentences: Maximum number of sentences in the normal answer.
        max_fallback_words: Word cap for the fallback (never below ``min_words``).

    Returns:
        The cleaned paragraph ("" for empty input).
    """
    raw = (raw_text or "").strip()

    sentences = []
    for line in raw.split("\n"):
        line = line.strip()
        if not line:
            continue
        sentences.extend(p.strip() for p in re.split(r'(?<=[.!?])\s+', line) if p.strip())

    unique = []
    seen = set()
    for sentence in sentences:
        key = sentence.lower()
        if key not in seen:
            seen.add(key)
            unique.append(sentence)

    answer = " ".join(unique[:max_sentences])
    if len(answer.split()) >= min_words:
        return answer

    # Too short: widen to every unique sentence (not the raw text, which may be a loop).
    words = " ".join(unique).split()
    if len(words) > min_words:
        return " ".join(words[:max(min_words, max_fallback_words)])
    return " ".join(unique)

# -----------------------------
# MAIN: ai_response
# -----------------------------
def _match_crop(user_query, crops):
    """Return the crop (original casing) mentioned in ``user_query``, or None.

    Substring matches win, longest name first, so a multi-word crop beats a shorter
    name it contains. Otherwise fuzzy matching (cutoff 0.5) is tried, and its
    lower-cased result is mapped back to the stored casing.
    """
    q = (user_query or "").lower()
    for crop in sorted(crops, key=len, reverse=True):
        if crop.lower() in q:
            return crop
    matched_lower = resolve_fuzzy_match(q, crops, cutoff=0.5)
    if matched_lower is None:
        return None
    return next((c for c in crops if c.lower() == matched_lower), matched_lower)

def _pick_best_crop(grp):
    """Return the crop in ``grp`` with the highest mean profitability.

    Falls back to mean yield, then to the first row, when the ranking column is all-NaN
    (``idxmax`` has no row to pick in that case).
    """
    for column in ("profitabilityIndex", "yieldPredictionScore"):
        values = grp[column]
        if values.notna().any():
            return grp.loc[values.idxmax(), "market_cropName"]
    return grp["market_cropName"].iloc[0]

def ai_response(user_query, user_location=None, top_k=5):
    """Answer a farmer's question with a generated recommendation and audit it.

    Steps:
    1. Detect the city and crop named in the query and filter ``pdf`` to them.
    2. If the query names a relative date, keep rows within ±15 days of the available
       row date closest to it.
    3. If nothing matches, fall back to vector similarity, restricted to the matched
       city; if no hit is in that city, use all of that city's rows.
    4. Prompt FLAN-T5 with a numeric summary of the best rows.
    5. Append the exchange to the audit table and print the answer.

    Args:
        user_query: Free-text question.
        user_location: Label used in the answer when no city is detected.
        top_k: Maximum number of rows used as context and returned.

    Returns:
        ``(answer, top)``: the answer text and the at most ``top_k`` rows (sorted by
        ``yieldPredictionScore`` descending) it was based on.
    """
    start_t = time.time()

    # detect city & crop
    matched_city = resolve_city_strong(user_query)
    matched_crop = _match_crop(user_query, pdf["market_cropName"].dropna().unique().tolist())
    city_key = matched_city.lower() if matched_city else None

    # initial filter
    df = pdf.copy()
    if matched_city:
        df = df[df["city"].str.lower() == city_key]
    if matched_crop:
        df = df[df["market_cropName"].str.lower() == matched_crop.lower()]

    # relative date with +/- window around the closest available date
    qd = extract_relative_date(user_query)
    if qd is not None and not df.empty:
        avail = df["date"].dropna().sort_values()
        if not avail.empty:
            closest = min(avail, key=lambda x: abs(pd.Timestamp(x).date() - qd))
            window = timedelta(days=15)
            df = df[(df["date"] >= (closest - window)) & (df["date"] <= (closest + window))]

    # vector fallback preferring matched city
    if df.empty:
        vec = query_vector_db_fast(user_query, top_k=50, df=pdf, emb_matrix_local=emb_matrix)
        if matched_city and "city" in vec.columns:
            vec_city = vec[vec["city"].str.lower() == city_key]
            if not vec_city.empty:
                df = vec_city.head(top_k)
            else:
                # Hits from other cities would be dropped by the city filter below, so use
                # the matched city's own rows instead (crop/date constraints relaxed).
                df = pdf[pdf["city"].str.lower() == city_key]
        else:
            df = vec.head(top_k)

    # final enforce city filter for plotting clarity
    if matched_city and "city" in df.columns:
        df = df[df["city"].str.lower() == city_key]

    top = df.sort_values("yieldPredictionScore", ascending=False).head(top_k)

    # build summary and prompt
    if top.empty:
        answer = "I couldn't find relevant data for that location or crop. Try a different city or crop name."
    else:
        grp = top.groupby("market_cropName").agg({
            "yieldPredictionScore": "mean",
            "profitabilityIndex": "mean",
            "rainfall_rainfallMm": "mean"
        }).reset_index()

        best_crop = _pick_best_crop(grp)
        city_text = matched_city or (user_location or "your area")

        avg_yield = float(grp["yieldPredictionScore"].mean())
        avg_profit = float(grp["profitabilityIndex"].mean())
        avg_rain = float(grp["rainfall_rainfallMm"].mean())

        # Heuristic labels from absolute averages (not a time-series trend).
        yield_trend = "increasing" if avg_yield > 30 else "stable" if avg_yield > 20 else "declining"
        rain_comment = ("favorable for most field crops" if avg_rain > 20 else
                        "slightly below ideal levels" if avg_rain > 5 else
                        "very low; irrigation may be needed")

        summary = (
            f"In {city_text}, {best_crop} looks most promising. Average yield ≈ {avg_yield:.2f}, "
            f"profitability ≈ {avg_profit:.2f}, and recent rainfall ≈ {avg_rain:.1f} mm ({rain_comment}). "
            f"Yield trend appears {yield_trend}. Recommend timely sowing, balanced nutrition, and pest monitoring."
        )

        context_lines = []
        for _, r in top.iterrows():
            context_lines.append(
                f"{r['market_cropName']} (city:{r['city']}): Yield={r['yieldPredictionScore']:.2f}, "
                f"Profit={r['profitabilityIndex']:.2f}, Rain={r.get('rainfall_rainfallMm',0):.1f}mm"
            )
        context = "\n".join(context_lines)

        prompt = (
            "You are an expert agricultural advisor. Using the numeric summary and context below, "
            "write a clear 4-sentence recommendation paragraph for farmers in simple language. "
            f"Start by naming the best crop for {city_text}, explain why (mention yield, profit, rainfall), "
            "and end with one practical risk or alternate option.\n\n"
            f"Summary:\n{summary}\n\nContext:\n{context}\n\nQuestion: {user_query}"
        )

        raw = generate_text(prompt, max_tokens=260)
        answer = clean_and_build_paragraph(raw, min_words=30, max_sentences=5)

    # audit
    rt = round(time.time() - start_t, 2)
    append_audit_safe({
        "timestamp": datetime.now(),
        "user_query": user_query,
        "user_location": user_location,
        "matched_city": matched_city,
        "matched_crop": matched_crop,
        "answer": answer,
        "response_time": rt
    })

    print(answer)
    return answer, top

# -----------------------------
# PLOTTING HELPERS
# -----------------------------
def _show_and_close(fig, title):
    """Apply the shared title/tick layout to ``fig``, display it, then close it.

    Closing after display stops unclosed figures (which pyplot keeps alive) from
    accumulating across repeated queries.
    """
    plt.title(title)
    plt.xticks(rotation=30)
    plt.tight_layout()
    display(fig)
    plt.close(fig)

def plot_top_matches(top_matches):
    """Plot yield, profitability and, when dated, rainfall per crop for ``top_matches``.

    If the rows span several cities, only the most frequent city is plotted, so each
    chart describes a single place. Prints a message and returns for None/empty input.
    """
    if top_matches is None or top_matches.empty:
        print("No data available for plotting.")
        return

    if len(top_matches["city"].dropna().unique()) > 1:
        main_city = top_matches["city"].mode().iloc[0]
        top_matches = top_matches[top_matches["city"] == main_city]
    city_label = top_matches["city"].iloc[0]

    fig = plt.figure(figsize=(10, 5))
    sns.barplot(x="market_cropName", y="yieldPredictionScore", data=top_matches)
    _show_and_close(fig, f"Yield Score — {city_label}")

    fig = plt.figure(figsize=(10, 5))
    sns.barplot(x="market_cropName", y="profitabilityIndex", data=top_matches)
    _show_and_close(fig, f"Profitability — {city_label}")

    if "date" in top_matches.columns and top_matches["date"].notna().any():
        fig = plt.figure(figsize=(12, 5))
        sns.lineplot(x="date", y="rainfall_rainfallMm", hue="market_cropName", data=top_matches)
        _show_and_close(fig, f"Rainfall Trend — {city_label}")

def plot_price_heatmap(df):
    """Show a heatmap of mean ``crop_cropPricePerQuintal`` per crop and calendar month.

    Months from different years are pooled. Rows whose ``date`` cannot be parsed are
    ignored. If no plottable data remains, a message is printed instead.
    """
    if df is None or df.empty:
        print("No data for heatmap.")
        return
    month_order = ["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]

    d = df.copy()
    # Coerce (as the table loader does) so one malformed date does not abort the plot.
    d["date"] = pd.to_datetime(d["date"], errors="coerce")
    d = d.dropna(subset=["date"])
    if d.empty:
        print("No data for heatmap.")
        return
    # Fixed English labels; strftime("%b") follows the process locale and would then
    # fail to line up with month_order.
    d["month"] = d["date"].dt.month.map(lambda m: month_order[m - 1])

    heat = (d.groupby(["market_cropName", "month"])["crop_cropPricePerQuintal"]
            .mean().reset_index()
            .pivot(index="market_cropName", columns="month", values="crop_cropPricePerQuintal"))
    heat = heat.reindex(columns=[m for m in month_order if m in heat.columns])
    if heat.empty:
        print("No data for heatmap.")
        return

    fig = plt.figure(figsize=(14, 8))
    sns.heatmap(heat, annot=True, fmt=".1f", linewidths=.5, cmap="YlGnBu")
    plt.title("Average Crop Price per Quintal — By Month")
    plt.tight_layout()
    display(fig)
    # Close so figures do not accumulate across repeated calls.
    plt.close(fig)





In [0]:
# Example query (typos are intentional, to exercise fuzzy matching).
user_query = "wich crop shuld i grow in California next week"
answer, top = ai_response(user_query)

# Chart the rows the recommendation was based on.
plot_top_matches(top)
plot_price_heatmap(top)