In [None]:
import gradio as gr
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp
import networkx as nx
import hashlib
from datetime import datetime
from groq import Groq
import warnings
import random
from io import BytesIO
from PIL import Image
import plotly.express as px
import logging
import time
import json
import re
from diffusers import StableDiffusionPipeline
from torch import nn
import torch.nn.functional as F
from gtts import gTTS
import os

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Suppress every Python warning process-wide.
warnings.filterwarnings('ignore')

# Configuration
MODEL_NAME = "bert-base-uncased"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Label order must match the classifier's output indices.
CLASS_NAMES = ["Safe", "Suspicious", "Malicious"]

# Light theme colors
SIMCI_COLORS = {
    "bg_light": "#F5F5F5",
    "text": "#1F2937",
    "accent": "#2DD4BF",
    "warning": "#EF4444",
    "risk_high": "#EF4444",
    "risk_medium": "#F59E0B",
    "comment": "#6B7280",
    "string": "#10B981",
    "number": "#8B5CF6",
    "comparison": "#3B82F6"
}

# Initialize Groq API Client
# The key comes from the GROQ_API_KEY environment variable rather than being
# hard-coded in source. The empty-string fallback keeps module import working
# without a key; API calls will presumably fail authentication until one is set.
groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY", ""))

# Load BERT model for SQLi detection
def load_models():
    """Load the BERT tokenizer and a 3-label sequence classifier.

    The model is moved to ``DEVICE``.

    Returns:
        ``(tokenizer, model)``.

    NOTE(review): ``MODEL_NAME`` is the base ``bert-base-uncased`` checkpoint,
    so the 3-way classification head is presumably freshly initialised rather
    than trained for SQLi detection. Confirm whether a fine-tuned checkpoint
    was intended.
    """
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    classifier = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(CLASS_NAMES))
    classifier = classifier.to(DEVICE)  # nn.Module.to moves in place and returns self
    return tokenizer, classifier

bert_tokenizer, bert_model = load_models()

# SQL Injection Detection
def predict_sqli(query):
    """Classify ``query`` with the module-level BERT classifier.

    Args:
        query: SQL text to classify.

    Returns:
        ``(label, probabilities)``. ``label`` is the entry of ``CLASS_NAMES``
        with the highest softmax probability. ``probabilities`` maps each class
        name to its probability as a float. Blank or ``None`` input
        short-circuits to a certain "Safe" without running the model.
    """
    if not query or not query.strip():
        return "Safe", {"Safe": 1.0, "Suspicious": 0.0, "Malicious": 0.0}

    inputs = bert_tokenizer(query, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        logits = bert_model(**inputs).logits
    # A single query is tokenized, so take row 0 of the (1, num_labels) output
    # instead of arg-maxing over the flattened tensor.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    pred_class = int(torch.argmax(probs).item())
    # Pair probabilities with CLASS_NAMES directly rather than a hard-coded range(3).
    scores = {name: float(p) for name, p in zip(CLASS_NAMES, probs.tolist())}
    return CLASS_NAMES[pred_class], scores

# Community Pattern Library
# Each entry maps a pattern name to:
#   "pattern"  - a regex (the analysers in this file search it with re.IGNORECASE),
#   "desc"     - a human-readable description,
#   "severity" - one of "medium" / "high" / "critical".
# SQL keyword alternatives are guarded so they only match as standalone words
# (e.g. CAST must not fire on "broadcast", SLEEP must not fire on "asleep").
# The letter-only lookbehind still lets prefixed functions such as pg_sleep match.
COMMUNITY_PATTERNS = {
    "hex_encoded": {"pattern": r"\b0x[0-9a-fA-F]+", "desc": "Hex-encoded payload", "severity": "medium"},
    "nested_comment": {"pattern": r"/\*\*/", "desc": "Nested comment bypass", "severity": "high"},
    # NOTE(review): any 20+ char alphanumeric run matches (e.g. long identifiers), so expect false positives.
    "base64_payload": {"pattern": r"[A-Za-z0-9+/=]{20,}", "desc": "Base64-encoded injection", "severity": "high"},
    "inline_comment": {"pattern": r"#", "desc": "Inline comment to truncate query", "severity": "medium"},
    # Any numeric tautology n=n (1=1, 2 = 2, ...), not just the literal "1=1".
    "boolean_injection": {"pattern": r"\b(\d+)\s*=\s*\1\b", "desc": "Boolean-based tautology", "severity": "high"},
    "time_delay": {"pattern": r"(?<![A-Za-z])(SLEEP|BENCHMARK|WAITFOR)(?![A-Za-z])", "desc": "Time-based delay injection", "severity": "high"},
    "out_of_band": {"pattern": r"(DNSLOG|HTTPLOG)", "desc": "Out-of-band data exfiltration", "severity": "critical"},
    # Require the call parenthesis so ordinary words containing "cast"/"convert" do not match.
    "error_based": {"pattern": r"(?<![A-Za-z])(CONVERT|CAST)\s*\(", "desc": "Error-based injection", "severity": "high"},
    # \s* also covers tabs/newlines between the separator and the next statement.
    "stacked_query": {"pattern": r";\s*(SELECT|INSERT|UPDATE|DELETE|DROP)\b", "desc": "Stacked query injection", "severity": "critical"},
    "encoded_space": {"pattern": r"%20", "desc": "URL-encoded space bypass", "severity": "medium"}
}

COMMUNITY_PATTERN_DB = list(COMMUNITY_PATTERNS.items())


# GAN Implementation for Adversarial Query Generation
class Generator(nn.Module):
    """Three-layer MLP that maps a latent vector to a tanh-squashed vector.

    Args:
        input_dim: Size of the latent input (last dimension of ``x``).
        hidden_dim: Width of both hidden layers.
        output_dim: Size of the produced vector; each value lies in [-1, 1].
    """

    def __init__(self, input_dim=100, hidden_dim=256, output_dim=50):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, output_dim), nn.Tanh(),
        ]
        # The attribute name and layer order are part of the state_dict keys
        # ("model.0.weight", ...), so they must not change.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x``, whose last dimension must equal ``input_dim``."""
        return self.model(x)

class Discriminator(nn.Module):
    """MLP that scores an input vector with a single sigmoid probability.

    Args:
        input_dim: Size of the input vector (last dimension of ``x``).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, input_dim=50, hidden_dim=256):
        super().__init__()
        stack = []
        fan_in = input_dim
        # Two Linear+ReLU hidden stages, created in the same order as before
        # so parameter initialisation consumes the RNG identically.
        for _ in range(2):
            stack += [nn.Linear(fan_in, hidden_dim), nn.ReLU()]
            fan_in = hidden_dim
        stack += [nn.Linear(hidden_dim, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return one probability per row of ``x``, with shape (..., 1)."""
        return self.model(x)

# Initialize GAN
# NOTE(review): nothing in the visible part of this file trains or samples from
# these objects; confirm they are still needed.
generator, discriminator = Generator().to(DEVICE), Discriminator().to(DEVICE)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
criterion = nn.BCELoss()  # binary cross-entropy on the discriminator's sigmoid output


# Pattern Injection Visualization (3D Scatter Plot)
def _split_sql_tokens(query):
    """Split ``query`` on whitespace, keeping quoted runs (spaces included) in one token."""
    tokens = []
    current_token = ""
    in_string = False
    for char in query:
        if char in ("'", '"'):
            # Either quote character toggles string mode; quote kinds are not paired.
            in_string = not in_string
            current_token += char
        elif char.isspace() and not in_string:
            if current_token:
                tokens.append(current_token)
                current_token = ""
        else:
            current_token += char
    if current_token:
        tokens.append(current_token)
    return tokens


def _classify_sql_token(token):
    """Return ``(token_type, color)`` for one token; the first matching rule wins."""
    upper_token = token.upper()
    # Tokens are whitespace-split, so multi-word keywords are matched word by word.
    if upper_token in ("SELECT", "FROM", "WHERE", "JOIN", "UNION", "GROUP", "ORDER", "BY"):
        return "Keyword", "#00B7EB"  # Cyberpunk blue
    if upper_token in ("AND", "OR", "NOT"):
        return "Operator", "#FF007A"  # Cyberpunk pink
    if "'" in token or '"' in token:
        return "String", "#00FF9F"  # Cyberpunk green
    if any(marker in token for marker in ("--", "/*", "*/")):
        return "Comment", "#6B7280"
    if any(op in token for op in ("=", "<", ">", "!=")):
        return "Comparison", "#FFB800"  # Cyberpunk yellow
    if token.isdigit():
        return "Number", "#9B00FF"  # Cyberpunk purple
    return "Identifier", "#FFFFFF"


def visualize_pattern_bar(query):
    """Build an animated 3D scatter plot of the tokens in ``query``.

    Tokens are laid out along x in query order, colored by type (keyword,
    operator, string, comment, comparison, number, identifier) and given a
    small random z offset. The Play button reveals them one at a time. Tokens
    containing a classic injection marker also get a red "x" marker.

    Args:
        query: SQL text to visualise.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` for blank input.
    """
    if not query or not query.strip():
        return None

    tokens = _split_sql_tokens(query)
    classified = [_classify_sql_token(token) for token in tokens]
    token_types = [ttype for ttype, _ in classified]
    colors = [color for _, color in classified]

    # Draw each token's vertical jitter once so it stays put across animation
    # frames (re-drawing per frame made tokens jump around).
    z_positions = [random.uniform(-0.5, 0.5) for _ in tokens]
    hover_texts = [
        f"<b>Token</b>: {token}<br><b>Type</b>: {ttype}<br><b>Position</b>: {pos + 1}"
        for pos, (token, ttype) in enumerate(zip(tokens, token_types))
    ]

    def tokens_trace(count):
        """Return one trace holding the first ``count`` tokens."""
        return go.Scatter3d(
            x=list(range(count)), y=[0] * count, z=z_positions[:count],
            mode="markers+text",
            text=tokens[:count],
            marker=dict(size=12, color=colors[:count], symbol="circle", line=dict(width=2, color="#1E1E1E"), opacity=0.9),
            textfont=dict(size=14, color=colors[:count], family="Inter"),
            hoverinfo="text",
            hovertext=hover_texts[:count],
            name="Tokens",
            showlegend=True
        )

    # NOTE: "OR 1=1" can never occur inside a single whitespace-split token.
    injection_markers = ("'", "--", ";", "UNION", "OR 1=1", "XP_", "EXEC")
    injection_points = [i for i, token in enumerate(tokens)
                        if any(marker in token.upper() for marker in injection_markers)]

    # Frames are matched to figure traces by index. All tokens therefore live in
    # trace 0, and every frame restyles only that trace, so the injection-marker
    # trace at index 1 is never overwritten.
    frames = [go.Frame(data=[tokens_trace(i + 1)], traces=[0], name=f"frame{i}")
              for i in range(len(tokens))]

    fig = go.Figure(data=[tokens_trace(1)], frames=frames)

    if injection_points:
        fig.add_trace(go.Scatter3d(
            x=injection_points,
            y=[0] * len(injection_points),
            # Same z as the flagged tokens so the markers sit on top of them.
            z=[z_positions[p] for p in injection_points],
            mode="markers",
            marker=dict(size=15, color="#FF0000", symbol="x", opacity=0.8, line=dict(width=2, color="#1E1E1E")),
            hoverinfo="text",
            hovertext="Injection Risk",
            showlegend=False
        ))

    fig.update_layout(
        title=dict(text="SQL Query Token Analysis (3D)", font=dict(size=22, color="#FFFFFF", family="Inter"), x=0.5, xanchor="center"),
        scene=dict(
            xaxis=dict(showticklabels=False, zeroline=False, backgroundcolor="rgba(0,0,0,0)"),
            yaxis=dict(showticklabels=False, zeroline=False, backgroundcolor="rgba(0,0,0,0)"),
            zaxis=dict(showticklabels=False, zeroline=False, backgroundcolor="rgba(0,0,0,0)"),
            bgcolor="#1E1E1E"
        ),
        plot_bgcolor="rgba(30,30,30,0.95)",
        paper_bgcolor="rgba(30,30,30,0.95)",
        height=600,
        margin=dict(l=40, r=40, t=80, b=40),
        showlegend=True,
        legend=dict(x=0.85, y=0.95, bgcolor="rgba(30,30,30,0.8)", bordercolor="#FFFFFF", borderwidth=1, font=dict(color="#FFFFFF")),
        updatemenus=[{
            "buttons": [
                {
                    "args": [None, {"frame": {"duration": 300, "redraw": True}, "fromcurrent": True}],
                    "label": "Play",
                    "method": "animate"
                },
                {
                    "args": [[None], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
                    "label": "Pause",
                    "method": "animate"
                }
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 87},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top"
        }],
        scene_camera=dict(eye=dict(x=1.5, y=1.5, z=0.5))
    )
    return fig

# Query Mutation for DNA Analysis
def visualize_mutation_sunburst(query):
    """Return a line chart of mutation risk scores built from placeholder data.

    NOTE: despite the name, this draws a ``go.Scatter`` line chart, not a
    sunburst. ``query`` is only logged: the labels and scores are fixed, so
    every call returns the same figure.

    Args:
        query: Query text (logged at DEBUG level only).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    logger.debug(f"Starting visualize_mutation_sunburst with query: {query}")

    # Fixed placeholder series, independent of the query.
    mutation_labels = ["Original"] + [f"Mutation {n}" for n in range(1, 6)]
    scores = list(range(10, 70, 10))
    logger.debug(f"Using placeholder data - Labels: {mutation_labels}, Risk Scores: {scores}")

    score_line = go.Scatter(
        x=mutation_labels,
        y=scores,
        mode="lines+markers",
        marker=dict(size=8, color="red"),
        line=dict(color="blue", width=2),
        name="Risk Scores"
    )
    fig = go.Figure(data=[score_line])

    small_font = dict(size=10)
    fig.update_layout(
        title="Mutation Risk Scores",
        xaxis=dict(title="Query", tickangle=45, tickfont=small_font),
        yaxis=dict(title="Risk Score (%)", range=[0, 100], tickfont=small_font),
        plot_bgcolor="white",
        paper_bgcolor="white",
        font=small_font,
        height=300,
        showlegend=False,
        margin=dict(l=20, r=20, t=50, b=50)
    )
    # The template is applied in a second call, after the explicit layout settings.
    fig.update_layout(template="plotly_white")

    logger.debug("Plot configured successfully")
    return fig

# Simplified Threat Impact Visualization (severity scatter)
def visualize_threat_impact(query):
    """Plot how many community patterns of each severity match ``query``.

    Every entry of ``COMMUNITY_PATTERNS`` is searched case-insensitively, and
    each hit adds one to its severity bucket. An invalid regex is logged and
    skipped.

    Args:
        query: SQL text to scan.

    Returns:
        A ``plotly.graph_objects.Figure`` with one marker per severity
        (low/medium/high/critical), or ``None`` for blank input.
    """
    if not query or not query.strip():
        return None

    # Integer counters that start at zero, so the plotted values are true match counts.
    severity_counts = {"low": 0, "medium": 0, "high": 0, "critical": 0}
    detected_patterns = []

    for pat_name, pat_info in COMMUNITY_PATTERNS.items():
        try:
            if re.search(pat_info["pattern"], query, re.IGNORECASE):
                detected_patterns.append(pat_name)
                severity_counts[pat_info["severity"]] += 1
        except re.error as e:
            logger.error(f"Invalid regex pattern '{pat_info['pattern']}' for {pat_name}: {str(e)}")

    labels = list(severity_counts.keys())
    values = list(severity_counts.values())
    colors = [
        SIMCI_COLORS["accent"],
        SIMCI_COLORS["risk_medium"],
        SIMCI_COLORS["risk_high"],
        SIMCI_COLORS["warning"]
    ]

    fig = go.Figure(data=[
        go.Scatter(
            x=labels,
            y=values,
            mode="markers+text",
            marker=dict(size=12, color=colors, symbol="circle"),
            text=values,
            textposition="top center",
            hovertemplate="<b>%{x}</b><br>Count: %{y}<extra></extra>"
        )
    ])

    fig.update_layout(
        title=dict(
            text="Threat Severity Distribution",
            font=dict(size=22, color=SIMCI_COLORS["text"], family="Asap"),
            x=0.5,
            xanchor="center"
        ),
        xaxis=dict(
            title="Severity",
            tickfont=dict(family="Asap", color=SIMCI_COLORS["text"])
        ),
        yaxis=dict(
            title="Count",
            tickfont=dict(family="Asap", color=SIMCI_COLORS["text"])
        ),
        plot_bgcolor=SIMCI_COLORS["bg_light"],
        paper_bgcolor=SIMCI_COLORS["bg_light"],
        font=dict(family="Asap", color=SIMCI_COLORS["text"]),
        height=400,
        margin=dict(t=80, l=40, r=40, b=40),
        annotations=[
            dict(
                text=f"Detected {len(detected_patterns)} threats",
                xref="paper",
                yref="paper",
                x=0.5,
                y=-0.1,
                showarrow=False,
                font=dict(size=14, color=SIMCI_COLORS["text"])
            )
        ]
    )
    return fig

# Query Analysis Visualization (Heatmap)
def visualize_sqli_heatmap(query):
    """Build an animated one-row heatmap of per-token injection risk for ``query``.

    ``query`` is split on whitespace and each token is scored:
    - 1.0 if it contains a classic injection marker (quote, ``--``, ``;``,
      UNION, EXEC, XP_);
    - 0.3 for SELECT/FROM/WHERE;
    - 0.1 otherwise.

    High-risk tokens also get a warning glyph. The Play button and slider
    reveal the scores token by token.

    Args:
        query: SQL text to score.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` for blank input.
    """
    if not query or not query.strip():
        return None

    tokens = query.split()
    # "OR 1=1" is kept for parity with other checks but cannot match one
    # whitespace-split token.
    risky_markers = ("'", "--", ";", "UNION", "OR 1=1", "EXEC", "XP_")
    risk_scores = []
    for token in tokens:
        upper_token = token.upper()
        if any(marker in upper_token for marker in risky_markers):
            risk_scores.append(1.0)
        elif upper_token in ("SELECT", "FROM", "WHERE"):
            risk_scores.append(0.3)
        else:
            risk_scores.append(0.1)

    # Numeric x positions (labelled with the tokens) keep duplicate tokens in
    # separate cells, and let the warning glyphs share the same x coordinates.
    # Categorical x would collapse duplicates and misplace numeric glyph x values.
    positions = list(range(len(tokens)))
    high_risk = [j for j, score in enumerate(risk_scores) if score == 1.0]
    colorscale = [[0, SIMCI_COLORS["accent"]], [0.3, SIMCI_COLORS["risk_medium"]], [1, SIMCI_COLORS["risk_high"]]]
    asap_font = dict(family="Asap", color=SIMCI_COLORS["text"])

    def heatmap(scores):
        """Return the heatmap trace for ``scores``."""
        return go.Heatmap(
            z=[scores],
            x=positions,
            # A fixed range keeps a given score the same color in every frame.
            zmin=0,
            zmax=1,
            colorscale=colorscale,
            showscale=True,
            # colorbar.title.side replaces the removed colorbar.titleside.
            colorbar=dict(title=dict(text="Risk Level", side="right"), tickfont=asap_font),
            customdata=[tokens],
            hovertemplate="<b>%{customdata}</b><br>Risk: %{z}<extra></extra>"
        )

    def warnings_trace(revealed):
        """Return one text trace with warning glyphs for high-risk tokens among the first ``revealed``."""
        shown = [j for j in high_risk if j < revealed]
        return go.Scatter(
            x=shown, y=[0.5] * len(shown), mode="text",
            text=["⚠️"] * len(shown), textfont=dict(size=16, color=SIMCI_COLORS["risk_high"], family="Asap"),
            hoverinfo="text", hovertext=[f"High Risk: {tokens[j]}" for j in shown],
            showlegend=False
        )

    # Every frame carries the same two traces (heatmap, warnings) as the
    # figure, so frame traces always line up with figure traces by index.
    frames = []
    for i in range(len(tokens)):
        frame_scores = risk_scores[:i + 1] + [0] * (len(tokens) - i - 1)
        frames.append(go.Frame(data=[heatmap(frame_scores), warnings_trace(i + 1)], name=f"frame{i}"))

    fig = go.Figure(data=[heatmap(risk_scores), warnings_trace(len(tokens))], frames=frames)

    fig.update_layout(
        title=dict(text="SQL Query Heatmap (Injection Risk)", font=dict(size=22, color=SIMCI_COLORS["text"], family="Asap"), x=0.5, xanchor="center"),
        height=450,
        plot_bgcolor=SIMCI_COLORS["bg_light"],
        paper_bgcolor=SIMCI_COLORS["bg_light"],
        xaxis=dict(tickangle=45, tickfont=asap_font, tickmode="array", tickvals=positions, ticktext=tokens),
        yaxis=dict(showticklabels=False),
        updatemenus=[{
            "buttons": [
                {
                    "args": [None, {"frame": {"duration": 300, "redraw": True}, "fromcurrent": True}],
                    "label": "Play",
                    "method": "animate"
                },
                {
                    "args": [[None], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
                    "label": "Pause",
                    "method": "animate"
                }
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 87},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top",
            "font": dict(family="Asap", color=SIMCI_COLORS["text"])
        }],
        sliders=[{
            "steps": [
                {"args": [[f"frame{i}"], {"frame": {"duration": 300, "redraw": True}, "mode": "immediate"}],
                 "label": f"Token {i+1}", "method": "animate"} for i in range(len(tokens))
            ],
            "x": 0.1, "len": 0.9, "currentvalue": {"prefix": "Progress: ", "font": dict(family="Asap", color=SIMCI_COLORS["text"])},
            "pad": {"b": 10, "t": 50},
            "font": dict(family="Asap", color=SIMCI_COLORS["text"])
        }]
    )
    return fig

# Query Sanitization
def sanitize_query(query, db_type="generic"):
    """Neutralise common SQL-injection fragments in ``query`` and report what changed.

    Built-in literal fragments (quote, comments, separator, UNION, OR 1=1,
    EXEC, XP_) are rewritten case-insensitively, in the order listed. Patterns
    from ``COMMUNITY_PATTERNS`` are regexes with no replacement text, so they
    are only *flagged* in the report and the query text is left unchanged.

    NOTE(review): doubling single quotes also alters legitimate string literals.
    The parameterized form returned alongside is the safer output.

    Args:
        query: SQL text to sanitize.
        db_type: Dialect passed to ``generate_parameterized_query``.

    Returns:
        ``(sanitized, explanation_markdown, parameterized_explanation)``, or
        ``("", "No query provided", "")`` for blank input.
    """
    if not query or not query.strip():
        return "", "No query provided", ""

    dangerous_patterns = {
        "'": {"replacement": "''", "desc": "Single quote - escaped"},
        "--": {"replacement": "", "desc": "SQL comment - removed"},
        ";": {"replacement": "", "desc": "Query separator - removed"},
        "/*": {"replacement": "", "desc": "Comment start - removed"},
        "*/": {"replacement": "", "desc": "Comment end - removed"},
        "UNION": {"replacement": "/*UNION*/", "desc": "UNION operator - commented out"},
        "OR 1=1": {"replacement": "/*OR 1=1*/", "desc": "Tautology - commented out"},
        "EXEC": {"replacement": "/*EXEC*/", "desc": "EXEC command - commented out"},
        "XP_": {"replacement": "/*XP_*/", "desc": "Extended procedure - commented out"}
    }

    sanitized = query
    replacements = []
    for fragment, info in dangerous_patterns.items():
        # Detection and replacement share one case-insensitive regex, so
        # "union" is rewritten exactly like "UNION" instead of only being reported.
        fragment_re = re.compile(re.escape(fragment), re.IGNORECASE)
        if fragment_re.search(sanitized):
            replacement = info["replacement"]
            sanitized = fragment_re.sub(lambda _match, text=replacement: text, sanitized)
            replacements.append((fragment, info["desc"]))

    # Community patterns are matched against the original query (not the
    # rewritten one, whose inserted /*...*/ markers would re-trigger them).
    flagged = []
    for pat_name, pat_info in COMMUNITY_PATTERNS.items():
        try:
            if re.search(pat_info["pattern"], query, re.IGNORECASE):
                flagged.append((pat_name, pat_info["desc"]))
        except re.error as e:
            logger.error(f"Invalid regex pattern '{pat_info['pattern']}' for {pat_name}: {str(e)}")

    param_query = generate_parameterized_query(query, db_type)

    explanation = f"**Sanitization Report**\nOriginal: {query[:100]}{'...' if len(query) > 100 else ''}\n\n"
    if replacements:
        explanation += "Sanitized patterns:\n" + "\n".join(f"- {p}: {d}" for p, d in replacements)
        explanation += f"\n\n**Sanitized Query**: {sanitized[:100]}{'...' if len(sanitized) > 100 else ''}"
    if flagged:
        if replacements:
            explanation += "\n\n"
        explanation += "Flagged patterns (not rewritten, review manually):\n" + "\n".join(f"- {n}: {d}" for n, d in flagged)
    if not replacements and not flagged:
        explanation += "No dangerous patterns found."

    return sanitized, explanation, param_query

# Parameterized Query
# A quoted SQL literal: '...' (with '' escapes) or "..." (with "" escapes).
# An unterminated quote runs to the end of the text, so a dangling injection
# quote is parameterized too instead of being left in the query.
_SQL_LITERAL_RE = re.compile(
    r"'(?P<single>(?:[^']|'')*)(?:'|$)|\"(?P<double>(?:[^\"]|\"\")*)(?:\"|$)"
)


def generate_parameterized_query(query, db_type="generic"):
    """Replace quoted literals in ``query`` with placeholders and describe the result.

    Only the literal itself is replaced, so ``name='John Smith'`` becomes
    ``name=?`` with parameter ``John Smith``. Surrounding text and whitespace
    are preserved.

    Args:
        query: SQL text. ``None`` is treated as empty.
        db_type: Placeholder dialect:
            - ``"mysql"`` uses ``%s``;
            - ``"postgresql"`` numbers placeholders ``$1``, ``$2``, ...;
            - anything else uses ``?``.

    Returns:
        A Markdown string:
        ``"**Parameterized Query**: <sql>\\n**Parameters**: <list>"``.
    """
    params = []

    def _placeholder(match):
        """Record the literal's unescaped value and return the dialect placeholder."""
        if match.group("single") is not None:
            value = match.group("single").replace("''", "'")
        else:
            value = match.group("double").replace('""', '"')
        params.append(value)
        if db_type == "mysql":
            return "%s"
        if db_type == "postgresql":
            return f"${len(params)}"  # one distinct index per parameter
        return "?"

    parameterized = _SQL_LITERAL_RE.sub(_placeholder, query or "")
    return f"**Parameterized Query**: {parameterized}\n**Parameters**: {params}"

# Simplified Attack DNA Visualization (Network Graph)
def extract_attack_dna(query):
    """Fingerprint ``query`` and draw a network graph of the injection patterns it contains.

    Built-in substring checks are combined with ``COMMUNITY_PATTERNS`` regexes
    (matched case-insensitively). A community entry with the same name as a
    built-in one (e.g. ``time_delay``) replaces it.

    Args:
        query: SQL text to analyse.

    Returns:
        ``(signature, active_patterns, figure)``:
        - ``signature`` is the first 16 hex chars of the query's SHA-256;
        - ``active_patterns`` lists the matching pattern names;
        - ``figure`` is a Plotly graph linking "Query Signature" to each
          pattern, colored by severity.
        Returns ``(None, None, None)`` for blank input.
    """
    if not query or not query.strip():
        return None, None, None

    signature = hashlib.sha256(query.encode()).hexdigest()[:16]
    patterns = {
        "quote_escape": {"active": "'" in query, "severity": "medium"},
        "comment": {"active": "--" in query or "/*" in query, "severity": "medium"},
        "union": {"active": "UNION" in query.upper(), "severity": "high"},
        # A semicolon followed by more statement text (with or without a space);
        # a bare trailing ";" is not a stacked query.
        "stacked": {"active": re.search(r";\s*\S", query) is not None, "severity": "high"},
        "tautology": {"active": "OR 1=1" in query.upper(), "severity": "high"},
        "command_exec": {"active": "EXEC" in query.upper(), "severity": "critical"},
        "xp_cmdshell": {"active": "XP_" in query.upper(), "severity": "critical"},
        "time_delay": {"active": "WAITFOR" in query.upper(), "severity": "high"}
    }
    patterns.update({k: {"active": re.search(v["pattern"], query, re.IGNORECASE) is not None, "severity": v["severity"]} for k, v in COMMUNITY_PATTERNS.items()})

    active_patterns = [p for p, info in patterns.items() if info["active"]]
    severity_colors = {
        "low": SIMCI_COLORS["accent"],
        "medium": SIMCI_COLORS["risk_medium"],
        "high": SIMCI_COLORS["risk_high"],
        "critical": SIMCI_COLORS["warning"]
    }

    # Star graph: the signature node points to every active pattern.
    G = nx.DiGraph()
    G.add_node("Query Signature")
    for pattern in active_patterns:
        G.add_node(pattern)
        G.add_edge("Query Signature", pattern)

    # A fixed seed keeps the layout stable across re-renders of the same query.
    pos = nx.spring_layout(G, seed=42)
    edge_x = []
    edge_y = []
    for src, dst in G.edges():
        x0, y0 = pos[src]
        x1, y1 = pos[dst]
        # None breaks the line between consecutive edges.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=1.5, color=SIMCI_COLORS["text"]),
        hoverinfo="none",
        mode="lines"
    )

    node_x = []
    node_y = []
    node_colors = []
    node_text = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)
        if node == "Query Signature":
            node_colors.append(SIMCI_COLORS["risk_high"])
        else:
            node_colors.append(severity_colors[patterns[node]["severity"]])

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode="markers+text",
        text=node_text,
        textposition="bottom center",
        hovertemplate="<b>%{text}</b><br>Severity: %{customdata}<extra></extra>",
        customdata=["Signature" if n == "Query Signature" else patterns[n]["severity"].title() for n in G.nodes()],
        marker=dict(
            size=15,
            color=node_colors,
            line=dict(width=2, color=SIMCI_COLORS["text"])
        )
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        title=dict(text="Attack DNA Network Graph", font=dict(size=22, color=SIMCI_COLORS["text"], family="Asap"), x=0.5, xanchor="center"),
        showlegend=False,
        plot_bgcolor=SIMCI_COLORS["bg_light"],
        paper_bgcolor=SIMCI_COLORS["bg_light"],
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        height=400,
        margin=dict(t=80, l=40, r=40, b=40),
        annotations=[
            dict(
                text=f"Patterns Detected: {', '.join(active_patterns) if active_patterns else 'None'}",
                xref="paper",
                yref="paper",
                x=0.5,
                y=-0.1,
                showarrow=False,
                font=dict(family="Asap", size=12, color=SIMCI_COLORS["text"])
            )
        ]
    )
    return signature, active_patterns, fig

# Stable Diffusion
# Lazily-created pipeline cached across calls to load_stable_diffusion_model().
sd_model = None

def load_stable_diffusion_model():
    """Return the cached Stable Diffusion pipeline, loading it on first use.

    On first call the pipeline is loaded and moved to ``DEVICE``. It uses
    float16 weights when CUDA is available (float32 otherwise) and enables
    attention slicing on CUDA. Load errors are logged and re-raised.
    """
    global sd_model
    if sd_model is not None:
        return sd_model

    model_id = "CompVis/stable-diffusion-v1-4"
    try:
        logger.info(f"Loading Stable Diffusion model: {model_id}")
        weights_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        sd_model = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=weights_dtype)
        sd_model = sd_model.to(DEVICE)
        if torch.cuda.is_available():
            # Trades some speed for lower peak GPU memory.
            sd_model.enable_attention_slicing()
        logger.info("Stable Diffusion model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load Stable Diffusion model: {str(e)}")
        raise
    return sd_model

def generate_attack_story_image(query, style="comic-strip"):
    """Render an illustrative image of ``query`` being flagged, using Stable Diffusion.

    Args:
        query: SQL text. Blank input logs a warning and returns ``None``.
        style: ``"comic-strip"``, ``"cyberpunk"`` or ``"realistic"``; unknown
            values fall back to ``"realistic"``.

    Returns:
        The first image produced by the pipeline. If loading the model or
        generating the image raises, the error is logged and a black 512x512
        placeholder image is returned instead.
    """
    if not query or not query.strip():
        logger.warning("No query provided for image generation")
        return None

    pred, _ = predict_sqli(query)
    styles = {
        "comic-strip": "A comic-strip style illustration",
        "cyberpunk": "A cyberpunk-themed digital art",
        "realistic": "A realistic depiction"
    }
    # Only mark the excerpt as truncated when it actually is.
    excerpt = query if len(query) <= 50 else f"{query[:50]}..."
    prompt = (
        f"{styles.get(style, styles['realistic'])} of a cybersecurity dashboard detecting a SQL injection attempt with the query '{excerpt}'. "
        f"The scene shows a glowing red warning for '{pred}' classification, with charts and alerts on a futuristic interface."
    )
    negative_prompt = "text artifacts,offensive"

    try:
        pipe = load_stable_diffusion_model()
        logger.debug(f"Generating image with prompt: {prompt[:100]}...")
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=30,
            guidance_scale=7.5,
            height=512,
            width=512
        ).images[0]
        logger.info("Image generated successfully")
        return image
    except Exception as e:
        # Best-effort: keep the UI responsive with a placeholder, but log the traceback.
        logger.exception(f"Error generating image: {str(e)}")
        return Image.new('RGB', (512, 512), color='black')

# Attack Story
def generate_attack_story(query, style="comic-strip"):
    """Return the illustrated "attack story" image for ``query``.

    Thin wrapper over ``generate_attack_story_image``.

    Args:
        query: SQL text. Blank input returns ``None``.
        style: Image style forwarded to ``generate_attack_story_image``. The
            default matches the previous fixed behaviour.

    Returns:
        An image (or the placeholder image on generation failure), or ``None``.
    """
    if not query or not query.strip():
        return None
    return generate_attack_story_image(query, style=style)

# Auto-explanation Engine
def generate_auto_explanation(query):
    """Ask the Groq LLM to explain why ``query`` received its BERT classification.

    The prompt includes the predicted label, per-class probabilities (as
    percentages) and a few simple substring pattern hits.

    Args:
        query: SQL text.

    Returns:
        The model's Markdown explanation. Returns ``"No query provided."``
        for blank/``None`` input, or ``"Unable to generate explanation."`` if
        the API call fails (the error is logged with traceback).
    """
    if not query or not query.strip():
        return "No query provided."

    pred, conf = predict_sqli(query)
    upper_query = query.upper()
    patterns = {
        "tautology": "OR 1=1" in upper_query,
        "union": "UNION" in upper_query,
        "comment": "--" in query or "/*" in query,
        "hex": "0x" in query.lower(),
        "delay": "SLEEP" in upper_query or "WAITFOR" in upper_query
    }
    detected = [name for name, hit in patterns.items() if hit]
    # Readable percentages instead of raw float reprs in the prompt.
    confidence = ", ".join(f"{label}: {prob:.1%}" for label, prob in conf.items())

    prompt = f"""
    Explain why the query '{query}' was classified as {pred}.
    Detected patterns: {', '.join(detected) or 'None'}.
    Confidence: {confidence}.
    Provide a concise, human-readable explanation in markdown format.
    Include mitigation suggestions.
    """
    try:
        response = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=300
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        logger.exception(f"Error generating explanation: {str(e)}")
        return "Unable to generate explanation."

# Text-to-Speech for Explanations
def text_to_speech(text):
    """Synthesize ``text`` to an MP3 with gTTS and return the file path.

    The characters ``#``, ``*``, ``_`` and ``-`` (Markdown markers) are
    stripped before synthesis. Each call writes to its own temporary file, so
    concurrent requests cannot overwrite one another's audio.

    NOTE(review): the temporary files are not deleted here; confirm
    cleanup happens elsewhere.

    Args:
        text: Text (typically Markdown) to speak. ``None`` is treated as empty.

    Returns:
        Path to the MP3 file. Returns ``None`` when there is nothing to speak
        or synthesis fails (the error is logged).
    """
    import tempfile

    try:
        clean_text = re.sub(r'[#*_-]', '', text or "")
        if not clean_text.strip():
            return None
        tts = gTTS(text=clean_text, lang='en')
        # delete=False: the file must outlive this function for the caller to serve it.
        with tempfile.NamedTemporaryFile(prefix="explanation_", suffix=".mp3", delete=False) as handle:
            audio_file = handle.name
        tts.save(audio_file)
        return audio_file
    except Exception as e:
        logger.error(f"Error generating speech: {str(e)}")
        return None

# Natural Language to Secure SQL
def _strip_code_fences(text):
    """Remove one surrounding Markdown code fence (```lang ... ```) from ``text``."""
    text = text.strip()
    text = re.sub(r"^```[\w+-]*[ \t]*\n?", "", text)
    text = re.sub(r"\n?```$", "", text)
    return text.strip()


def _build_sanitization_flow(detected_patterns):
    """Build the Sankey diagram Input -> Tokenization -> Parameterization -> Secure Query.

    When any patterns were detected, an extra Input -> Detected Patterns link
    is added.
    """
    nodes = ["Input", "Tokenization", "Parameterization", "Secure Query"]
    node_colors = [SIMCI_COLORS["accent"], SIMCI_COLORS["risk_medium"], SIMCI_COLORS["comparison"], SIMCI_COLORS["risk_high"]]
    if detected_patterns:
        nodes.append("Detected Patterns")
        node_colors.append(SIMCI_COLORS["risk_high"])

    links = [
        {"source": 0, "target": 1, "value": 1, "color": SIMCI_COLORS["accent"]},
        {"source": 1, "target": 2, "value": 1, "color": SIMCI_COLORS["risk_medium"]},
        {"source": 2, "target": 3, "value": 1, "color": SIMCI_COLORS["comparison"]}
    ]
    if detected_patterns:
        # Link width is proportional to the number of detected patterns.
        links.append({"source": 0, "target": len(nodes)-1, "value": len(detected_patterns), "color": SIMCI_COLORS["risk_high"]})

    fig = go.Figure(go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color=SIMCI_COLORS["text"], width=0.5),
            label=nodes,
            color=node_colors
        ),
        link=dict(
            source=[link["source"] for link in links],
            target=[link["target"] for link in links],
            value=[link["value"] for link in links],
            color=[link["color"] for link in links]
        )
    ))
    fig.update_layout(
        title=dict(
            text="Query Sanitization Flow",
            font=dict(size=22, color=SIMCI_COLORS["text"], family="Asap"),
            x=0.5,
            xanchor="center"
        ),
        height=400,
        plot_bgcolor=SIMCI_COLORS["bg_light"],
        paper_bgcolor=SIMCI_COLORS["bg_light"],
        font=dict(family="Asap", color=SIMCI_COLORS["text"]),
        annotations=[
            dict(
                text=f"Patterns Detected: {', '.join(detected_patterns) if detected_patterns else 'None'}",
                xref="paper",
                yref="paper",
                x=0.5,
                y=-0.1,
                showarrow=False,
                font=dict(size=12, color=SIMCI_COLORS["text"])
            )
        ]
    )
    return fig


def nl_to_secure_sql(nl_query):
    """Turn a natural-language request into a parameterized SQL query using the Groq LLM.

    Args:
        nl_query: Natural-language description of the desired query.

    Returns:
        ``(secure_query, sankey_figure, explanation_markdown)``.
        - Blank input returns ``("", None, "No input provided")``.
        - Any error returns ``("", None, "Error: <message>")`` and is logged.
        ``COMMUNITY_PATTERNS`` are matched against ``nl_query`` to decorate
        the figure.
    """
    if not nl_query or not nl_query.strip():
        return "", None, "No input provided"

    try:
        prompt = f"""
        You are a cybersecurity expert. Convert the following natural language query into a secure, parameterized SQL query:
        "{nl_query}"
        Return the result in the following format, separated by '###':
        1. The secure SQL query as plain text.
        2. A markdown explanation of the sanitization logic and why the query is secure.
        Example:
        ```
        SELECT * FROM customers WHERE signup_date > ?
        ###
        **Sanitization Logic**
        - Converted natural language to SQL with parameterized query.
        - Used `?` placeholder to prevent injection.
        - Ensured no dangerous keywords like UNION or ; were introduced.
        ```
        """
        response = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=300
        )
        # The prompt's example is fenced, so models often echo ``` fences;
        # strip them so they do not leak into the SQL text.
        result = _strip_code_fences(response.choices[0].message.content)

        # maxsplit=1: a "###" Markdown heading inside the explanation must not truncate it.
        parts = result.split("###", 1)
        secure_query = _strip_code_fences(parts[0])
        explanation = _strip_code_fences(parts[1]) if len(parts) > 1 else "No explanation provided"

        detected_patterns = []
        for pat_name, pat_info in COMMUNITY_PATTERNS.items():
            try:
                if re.search(pat_info["pattern"], nl_query, re.IGNORECASE):
                    detected_patterns.append(pat_name)
            except re.error as e:
                logger.error(f"Invalid regex pattern '{pat_info['pattern']}' for {pat_name}: {str(e)}")

        fig = _build_sanitization_flow(detected_patterns)
        return secure_query, fig, explanation
    except Exception as e:
        logger.exception(f"Error converting natural language to SQL: {str(e)}")
        return "", None, f"Error: {str(e)}"

# Query Execution Timeline
def query_execution_timeline(query):
    """Build an illustrative Gantt-style chart of SQL query execution stages.

    The stage offsets and durations are fixed illustrative values; ``query`` is
    only checked for emptiness and does not influence the chart contents.

    Args:
        query: The SQL text entered by the user.

    Returns:
        A ``plotly.graph_objects.Figure`` with one horizontal bar per stage, or
        ``None`` when ``query`` is empty, ``None`` or whitespace-only.
    """
    if not (query and query.strip()):
        return None

    # (stage name, start offset in seconds, duration in seconds, hover text, bar colour)
    stage_specs = (
        ("Parsing", 0, 2, "Tokenizing query", SIMCI_COLORS["accent"]),
        ("Binding", 2, 1.5, "Parameter substitution", SIMCI_COLORS["risk_medium"]),
        ("Execution", 3.5, 3, "Running query", SIMCI_COLORS["string"]),
        ("Output", 6.5, 1, "Returning results", SIMCI_COLORS["risk_high"]),
    )
    text_font = dict(family="Asap", color=SIMCI_COLORS["text"])

    timeline = go.Figure()
    for name, start, duration, description, colour in stage_specs:
        # `base` shifts each bar so it begins at the stage's start offset.
        timeline.add_trace(go.Bar(
            x=[duration],
            y=[name],
            base=[start],
            orientation="h",
            marker=dict(color=colour),
            text=name,
            textposition="auto",
            hovertemplate="<b>%{y}</b><br>Start: %{base}s<br>Duration: %{x}s<br>%{customdata}<extra></extra>",
            customdata=[description],
        ))

    timeline.update_layout(
        title=dict(
            text="Query Execution Timeline",
            font=dict(size=22, color=SIMCI_COLORS["text"], family="Asap"),
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(title="Time (seconds)", tickfont=text_font),
        yaxis=dict(title="Stage", tickfont=text_font),
        plot_bgcolor=SIMCI_COLORS["bg_light"],
        paper_bgcolor=SIMCI_COLORS["bg_light"],
        font=text_font,
        barmode="stack",
        height=400,
    )
    return timeline

# Threat Persona
def generate_threat_persona(query, community_patterns=None):
    """Classify the likely attacker profile behind an SQL query.

    Built-in indicators are grouped into three tiers with fixed weights
    (simple 0.2, moderate 0.5, advanced 0.8). Community-submitted regexes are
    weighted by their declared ``severity`` ("low" 0.2, "medium" 0.5, anything
    else such as "high"/"critical" 0.8). The weights of all matching patterns
    are summed and capped at 1.0.

    Args:
        query: SQL text to inspect.
        community_patterns: Mapping of key -> {"pattern": regex, "severity": str, ...}.
            Defaults to the module-level ``COMMUNITY_PATTERNS``.

    Returns:
        Tuple ``(persona, complexity_score, markdown_explanation)``. The persona is
        "Script Kiddie" (score < 0.3), "Botnet Scanner" (< 0.7) or
        "Expert Exploiter". Empty input yields ``("Unknown", 0.0, "No query provided")``.
    """
    if not query or not query.strip():
        return "Unknown", 0.0, "No query provided"

    if community_patterns is None:
        community_patterns = COMMUNITY_PATTERNS

    # (regex, weight) pairs. Built-in tiers are kept separate from community
    # entries so a submitted pattern whose text happens to be "simple",
    # "moderate" or "advanced" can no longer overwrite a built-in tier.
    weighted_patterns = [(p, 0.2) for p in ("'", "--")]
    weighted_patterns += [(p, 0.5) for p in ("UNION", "OR 1=1")]
    weighted_patterns += [(p, 0.8) for p in ("EXEC", "XP_", ";", "WAITFOR", "CONVERT", "CAST")]

    # Community patterns previously all scored as "advanced" (0.8) regardless of
    # the severity the submitter chose; honour the declared severity instead.
    severity_weights = {"low": 0.2, "medium": 0.5}
    for entry in community_patterns.values():
        severity = str(entry.get("severity", "")).lower()
        weighted_patterns.append((entry["pattern"], severity_weights.get(severity, 0.8)))

    complexity_score = 0.0
    detected_patterns = []
    for pattern, weight in weighted_patterns:
        try:
            matched = re.search(pattern, query, re.IGNORECASE)
        except re.error as e:
            # Community regexes are user-supplied; skip broken ones instead of failing.
            logger.error(f"Invalid regex pattern '{pattern}' in threat persona: {str(e)}")
            continue
        if matched:
            detected_patterns.append(pattern)
            complexity_score += weight

    complexity_score = min(complexity_score, 1.0)
    if complexity_score < 0.3:
        persona = "Script Kiddie"
    elif complexity_score < 0.7:
        persona = "Botnet Scanner"
    else:
        persona = "Expert Exploiter"

    explanation = f"**Threat Persona Analysis**\n- Patterns: {', '.join(detected_patterns) or 'None'}\n- Complexity Score: {complexity_score:.2f}\n- Persona: {persona}"
    return persona, complexity_score, explanation

# Browser-Based Honeypot
def honeypot_browser_simulation(input_query):
    """Classify a simulated form input and chart the result.

    Args:
        input_query: Text a visitor typed into the simulated form.

    Returns:
        ``(figure, markdown)`` where ``figure`` stacks a confidence pie chart
        (row 1) above a single-point attack timeline (row 2). Empty or
        whitespace-only input returns ``(None, "No input provided")``.
    """
    if not input_query or not input_query.strip():
        return None, "No input provided"

    pred, conf = predict_sqli(input_query)
    text_font = dict(family="Asap", color=SIMCI_COLORS["text"])

    # Only traces survive into the combined subplot figure, so they are built
    # directly. The original code also built two stand-alone figures whose
    # titles, legends and annotations were silently discarded.
    confidence_pie = go.Pie(
        labels=list(conf.keys()),
        values=list(conf.values()),
        textinfo="label+percent",
        # NOTE(review): colours are positional and assume conf is ordered
        # Safe, Suspicious, Malicious — confirm against predict_sqli.
        marker=dict(
            colors=[SIMCI_COLORS["accent"], SIMCI_COLORS["risk_medium"], SIMCI_COLORS["risk_high"]],
            line=dict(color=SIMCI_COLORS["text"], width=2),
        ),
        hovertemplate="%{label}: %{value:.3f}<extra></extra>",
        pull=[0.1 if k == pred else 0 for k in conf.keys()],  # pop out the predicted class
        textfont=dict(family="Asap", size=14, color=SIMCI_COLORS["text"]),
    )

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    attack_point = go.Scatter(
        x=[timestamp],
        y=[pred],
        mode="markers+text",
        marker=dict(size=15, color=SIMCI_COLORS["risk_high"] if pred == "Malicious" else SIMCI_COLORS["accent"]),
        text=[pred],
        textposition="top center",
        hovertemplate="<b>%{y}</b><br>Time: %{x}<br>Query: %{customdata}<extra></extra>",
        customdata=[input_query],
        name="Attack",
        showlegend=False,  # keep the legend for the pie's class labels only
    )

    fig = sp.make_subplots(
        rows=2, cols=1,
        subplot_titles=("Confidence Scores", "Attack Timeline"),
        vertical_spacing=0.15,
        specs=[[{"type": "domain"}], [{"type": "xy"}]]  # pie needs a "domain" cell
    )
    fig.add_trace(confidence_pie, row=1, col=1)
    fig.add_trace(attack_point, row=2, col=1)

    fig.update_layout(
        height=600,
        plot_bgcolor=SIMCI_COLORS["bg_light"],
        paper_bgcolor=SIMCI_COLORS["bg_light"],
        showlegend=True,
        font=text_font,
        title=dict(
            text="Honeypot Simulation Results",
            font=dict(size=22, color=SIMCI_COLORS["text"], family="Asap"),
            x=0.5,
            xanchor="center"
        )
    )
    fig.update_xaxes(title_text="Timestamp", row=2, col=1, tickfont=text_font)
    fig.update_yaxes(title_text="Classification", row=2, col=1, tickfont=text_font)

    return fig, f"**Honeypot Result**: Classified as {pred}"

# Honeypot Simulation
def honeypot_simulation(num_queries=10, malicious_ratio=0.4):
    """Simulate a batch of honeypot hits and classify each with the SQLi model.

    Args:
        num_queries: Number of synthetic requests to generate.
        malicious_ratio: Probability that a request is drawn from the malicious
            payload list instead of the safe one.

    Returns:
        ``(DataFrame, Figure)``: one row per request with columns Query,
        Classification, Confidence, Timestamp and IP, plus a bar chart of
        classification counts.
    """
    safe_patterns = [
        "SELECT * FROM users WHERE username = ?",
        "INSERT INTO logs (event) VALUES (?)",
        "UPDATE settings SET value = ? WHERE key = ?"
    ]
    malicious_patterns = [
        "admin'--",
        "1' OR '1'='1",
        "1' UNION SELECT username, password FROM users--",
        "1; DROP TABLE users--",
        "0x414243; SLEEP(5)--",
        "1' AND 1=(SELECT COUNT(*) FROM information_schema.tables)--"
    ]
    columns = ["Query", "Classification", "Confidence", "Timestamp", "IP"]

    results = []
    for _ in range(num_queries):
        is_malicious = random.random() < malicious_ratio
        query = random.choice(malicious_patterns if is_malicious else safe_patterns)
        pred, conf = predict_sqli(query)
        results.append({
            "Query": query,
            "Classification": pred,
            "Confidence": f"{max(conf.values()):.2f}",
            # Spread fake hits over the last hour so the table looks like a log.
            "Timestamp": (datetime.now() - pd.Timedelta(minutes=random.randint(0, 60))).strftime("%Y-%m-%d %H:%M:%S"),
            "IP": f"192.168.{random.randint(0, 255)}.{random.randint(0, 255)}"
        })
    # Explicit columns keep the frame well-formed even when no rows were generated.
    df = pd.DataFrame(results, columns=columns)

    classification_counts = df["Classification"].value_counts()
    # value_counts() sorts by frequency, not by class, so colours must be looked
    # up per label; positional assignment painted classes with the wrong colour.
    class_colors = {
        "Safe": SIMCI_COLORS["accent"],
        "Suspicious": SIMCI_COLORS["risk_medium"],
        "Malicious": SIMCI_COLORS["risk_high"],
    }
    bar_colors = [class_colors.get(label, SIMCI_COLORS["comment"]) for label in classification_counts.index]

    fig = go.Figure(data=[
        go.Bar(
            x=classification_counts.index,
            y=classification_counts.values,
            marker_color=bar_colors,
            text=classification_counts.values,
            textposition="auto",
            hovertemplate="<b>%{x}</b><br>Count: %{y}<extra></extra>"
        )
    ])
    fig.update_layout(
        title=dict(text="Honeypot Simulation Classifications", font=dict(size=22, color=SIMCI_COLORS["text"], family="Asap"), x=0.5, xanchor="center"),
        xaxis=dict(title="Classification", tickfont=dict(family="Asap", color=SIMCI_COLORS["text"])),
        yaxis=dict(title="Count", tickfont=dict(family="Asap", color=SIMCI_COLORS["text"])),
        plot_bgcolor=SIMCI_COLORS["bg_light"],
        paper_bgcolor=SIMCI_COLORS["bg_light"],
        font=dict(family="Asap", color=SIMCI_COLORS["text"]),
        height=400
    )

    return df, fig

# Enhanced Red Team Simulator
def generate_adversarial_query(attack_type="basic"):
    """Generate an SQL injection payload for red-team testing and chart its traits.

    Args:
        attack_type: One of "basic", "blind", "obfuscated", "error_based",
            "stacked" or "advanced_obfuscation". Unknown values fall back to the
            "basic" prompt.

    Returns:
        ``(query, figure)`` where ``figure`` is a bar chart marking which attack
        traits were found in ``query`` (0.8 = present, 0.2 = absent). On any
        failure returns ``("Error generating adversarial query: ...", None)``.
    """
    try:
        prompts = {
            "basic": "Generate a realistic SQL injection attack query with common patterns like tautologies or UNION attacks.",
            "blind": "Generate a blind SQL injection query using conditional responses or time-based delays.",
            "obfuscated": "Generate an obfuscated SQL injection query using hex encoding or nested comments.",
            "error_based": "Generate an error-based SQL injection query exploiting database error messages.",
            "stacked": "Generate a stacked query SQL injection with multiple statements.",
            "advanced_obfuscation": "Generate a highly obfuscated SQL injection query combining multiple encoding techniques."
        }
        prompt = f"""
        You are a cybersecurity red team AI. {prompts.get(attack_type, prompts['basic'])}
        Provide only the query, no explanation.
        """
        if attack_type == "advanced_obfuscation":
            noise = torch.randn(1, 100).to(DEVICE)
            with torch.no_grad():
                # NOTE(review): the generator's output is not used; the query is
                # assembled from random tokens below. The call is kept so that a
                # failing `generator` still surfaces as an error, as before.
                generator(noise)
            tokens = ["SELECT", "FROM", "WHERE", "users", "OR", "1=1", "--"]
            query = " ".join(random.choices(tokens, k=10))
        else:
            response = groq_client.chat.completions.create(
                model="llama-3.3-70b-versatile",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.9,
                max_tokens=100
            )
            query = response.choices[0].message.content.strip()
            # Chat models often wrap the answer in a ```sql fence despite the
            # prompt; keep only the fenced body so the checks below see raw SQL.
            fenced = re.fullmatch(r"```[\w-]*\s*(.*?)\s*```", query, re.DOTALL)
            if fenced:
                query = fenced.group(1)

        upper_query = query.upper()
        patterns = {
            # Matches "OR 1=1", "' OR '1'='1", "OR 'a'='a'" and similar tautologies.
            "tautology": re.search(r"\bOR\s+'?(\w+)'?\s*=\s*'?\1(?!\w)", upper_query) is not None,
            "union": "UNION" in upper_query,
            "comment": "--" in query or "/*" in query,
            "hex": "0x" in query.lower(),
            "delay": "SLEEP" in upper_query or "WAITFOR" in upper_query,
            # Previously tested the literal text "(CONVERT|CAST)", which never matched.
            "error": "CONVERT" in upper_query or "CAST" in upper_query,
            "stacked": ";" in query,
            "encoded": "%20" in query.lower()
        }
        pattern_names = list(patterns.keys())
        severities = [0.8 if v else 0.2 for v in patterns.values()]
        colors = [SIMCI_COLORS["risk_high"] if v else SIMCI_COLORS["accent"] for v in patterns.values()]

        fig = go.Figure(go.Bar(
            x=pattern_names,
            y=severities,
            marker_color=colors,
            text=[f"{s:.2f}" for s in severities],
            textposition="auto",
            hovertemplate="<b>%{x}</b><br>Severity: %{y:.2f}<extra></extra>"
        ))

        fig.update_layout(
            title=dict(text="Adversarial Query Pattern Complexity", font=dict(size=22, color=SIMCI_COLORS["text"], family="Asap"), x=0.5, xanchor="center"),
            xaxis=dict(title="Pattern", tickfont=dict(family="Asap", color=SIMCI_COLORS["text"])),
            yaxis=dict(title="Severity", tickfont=dict(family="Asap", color=SIMCI_COLORS["text"])),
            plot_bgcolor=SIMCI_COLORS["bg_light"],
            paper_bgcolor=SIMCI_COLORS["bg_light"],
            font=dict(family="Asap", color=SIMCI_COLORS["text"]),
            height=400
        )

        return query, fig
    except Exception as e:
        return f"Error generating adversarial query: {str(e)}", None

# Voice-to-SQL with CSV Export
def voice_to_sql(audio_input, language="en"):
    """Transcribe recorded audio, convert it to secure SQL and export a CSV report.

    Args:
        audio_input: Path to the recorded audio file (Gradio ``type="filepath"``).
        language: Language code passed to the transcription API; "en" selects
            the English-only model.

    Returns:
        ``(secure_query, figure, markdown_explanation, csv_path)``. When there is
        no audio, or any step fails, the query and CSV path are empty/``None`` and
        the explanation carries the message.
    """
    if not audio_input:
        return "", None, "No audio input provided", None

    try:
        model_id = "distil-whisper-large-v3-en" if language == "en" else "whisper-large-v3-turbo"
        with open(audio_input, "rb") as file:
            transcription = groq_client.audio.transcriptions.create(
                file=file,
                model=model_id,
                language=language
            )
        text_input = transcription.text
        secure_query, fig, explanation = nl_to_secure_sql(text_input)

        # Use one timestamp for both the CSV row and the file name so they agree.
        now = datetime.now()
        csv_data = pd.DataFrame({
            # Fixed: the format previously read "%Y-%m-d", producing a literal "d".
            "Timestamp": [now.strftime("%Y-%m-%d %H:%M:%S")],
            "Transcription": [text_input],
            "Secure_Query": [secure_query],
            "Explanation": [explanation]
        })
        csv_file = f"voice_to_sql_{now.strftime('%Y%m%d_%H%M%S')}.csv"
        csv_data.to_csv(csv_file, index=False)

        return secure_query, fig, f"**Transcription**: {text_input}\n\n{explanation}", csv_file
    except Exception as e:
        return "", None, f"Error: {str(e)}. Please ensure audio input is valid.", None

# SQLiChatBot
class SQLiChatBot:
    """Stateful wrapper around the Groq chat API for SQL-injection Q&A.

    The bot keeps its own conversation history. Each request sends the system
    prompt, the five most recent stored messages and the new user message.
    """

    def __init__(self):
        # Alternating {"role": "user"/"assistant", "content": str} messages.
        self.conversation_history = []

    def respond(self, message):
        """Answer ``message`` using the LLM and record the exchange.

        Args:
            message: The user's question.

        Returns:
            The model's reply. Empty input returns a prompt to ask a question,
            and API failures return an "Error: ..." string. A failed request
            leaves the history untouched.
        """
        if not message or not message.strip():
            return "Please enter a question about SQL injection security."

        try:
            system_prompt = """
            You are AskSQLiBot, a cybersecurity expert specializing in SQL injection prevention.
            Provide technical yet clear explanations about SQL injection vulnerabilities, mitigation strategies, or query analysis.
            Use analogies or visuals when helpful, suggest secure alternatives, and allow feedback to improve detection.
            Keep responses concise and professional.
            """
            user_turn = {"role": "user", "content": message}
            # 5 stored messages + the new user turn = the 6 most recent messages.
            # The history is only updated after a successful call, so a failed
            # request no longer leaves an unanswered user turn behind.
            context = self.conversation_history[-5:]
            response = groq_client.chat.completions.create(
                model="llama-3.3-70b-versatile",
                messages=[
                    {"role": "system", "content": system_prompt},
                    *context,
                    user_turn
                ],
                temperature=0.7,
                max_tokens=500
            )
            bot_response = response.choices[0].message.content
            self.conversation_history.extend([user_turn, {"role": "assistant", "content": bot_response}])
            return bot_response
        except Exception as e:
            return f"Error: {str(e)}. Please ensure your Groq API key is valid and check https://console.groq.com for details."

sqlibot = SQLiChatBot()

# Attack Simulation Playground
def attack_simulation_playground(query):
    """Run a query through the detector and sanitizer and summarise the result.

    Args:
        query: SQL text to simulate.

    Returns:
        A Markdown report with the prediction, per-class confidence, the
        sanitized and parameterized queries and the sanitizer's explanation.
        Empty or whitespace-only input returns a prompt to enter a query,
        matching the empty-input handling of the other handlers.
    """
    if not query or not query.strip():
        return "Please enter a query to simulate."

    pred, conf = predict_sqli(query)
    sanitized, explanation, param_query = sanitize_query(query)
    # Render scores as percentages instead of the raw dict repr.
    confidence = ", ".join(f"{label}: {score:.2%}" for label, score in conf.items())
    return (
        f"**Simulation Result**\n- Prediction: {pred}\n- Confidence: {confidence}\n"
        f"- Sanitized Query: {sanitized}\n- Parameterized: {param_query}\n\n{explanation}"
    )

# Community Pattern Submission
def submit_community_pattern(pattern, desc, severity):
    """Validate and register a community-contributed detection regex.

    Args:
        pattern: Regular expression to add. It is stored verbatim.
        desc: Human-readable description (surrounding whitespace is stripped).
        severity: Severity label, stored stripped and lower-cased.

    Returns:
        ``(status_message, pattern_table)``. ``pattern_table`` is a DataFrame of
        all registered patterns on success, and ``None`` when validation fails.
    """
    # Whitespace-only fields count as missing.
    if not (pattern and pattern.strip()) or not (desc and desc.strip()) or not (severity and severity.strip()):
        return "Please provide pattern, description, and severity.", None

    try:
        re.compile(pattern)
    except re.error as e:
        return f"Invalid regex pattern: {str(e)}", None

    # NOTE(review): submitted regexes are later run against every analysed query;
    # a catastrophically backtracking pattern could stall analysis. Consider a
    # length limit or a match timeout.
    entry = {"pattern": pattern, "desc": desc.strip(), "severity": severity.strip().lower()}
    COMMUNITY_PATTERNS[pattern] = entry
    COMMUNITY_PATTERN_DB.append((pattern, dict(entry)))  # separate copy, as before
    pattern_df = pd.DataFrame(
        [(k, v["desc"], v["severity"]) for k, v in COMMUNITY_PATTERNS.items()],
        columns=["Pattern", "Description", "Severity"]
    )
    return f"Pattern '{pattern}' submitted successfully.", pattern_df

# Main Analysis Function
def _analysis_placeholder(message, warnings_text):
    """Build the 17-value fallback output used when the analysis cannot run."""
    return ["Safe", {"Safe": 1.0, "Suspicious": 0.0, "Malicious": 0.0}, None, "", message, "", "Unknown", None,
            None, "Unknown", 0.0, message, None, None, message, None, warnings_text]


def analyze_query(query, db_type="generic", img_style="comic-strip"):
    """Run the full analysis pipeline for the Query Analysis tab.

    Args:
        query: SQL text to analyse.
        db_type: Database dialect forwarded to ``sanitize_query``.
        img_style: Image style selected in the UI. It is not currently forwarded
            to ``generate_attack_story``.

    Returns:
        A list of exactly 17 values, in the order of the ``analyze_btn.click``
        outputs: prediction, confidence scores, pattern chart, sanitized query,
        sanitization report, parameterized query, DNA signature, DNA graph,
        heatmap, persona, threat score, persona analysis, threat-impact chart,
        story image, explanation, audio explanation, warnings.
    """
    if not query or not query.strip():
        # Previously returned only 15 values, which made Gradio reject the update.
        return _analysis_placeholder("No query provided", "No warnings")

    warnings_list = []
    try:
        pred_class, probs = predict_sqli(query)
        pattern_bar = visualize_pattern_bar(query)
        heatmap = visualize_sqli_heatmap(query)
        sanitized, sanitize_explanation, param_query = sanitize_query(query, db_type)
        dna_sig, _, dna_viz = extract_attack_dna(query)
        persona, threat_score, persona_explanation = generate_threat_persona(query)
        story_image = generate_attack_story(query)
        threat_impact = visualize_threat_impact(query)
        explanation = generate_auto_explanation(query)
        audio_explanation = text_to_speech(explanation)

        return [pred_class, probs, pattern_bar, sanitized, sanitize_explanation, param_query, dna_sig, dna_viz, heatmap, persona,
                threat_score, persona_explanation, threat_impact, story_image, explanation, audio_explanation, "\n".join(warnings_list)]
    except Exception as e:
        # Keep the traceback in the logs; the UI only shows the short message.
        logger.exception("Query analysis failed")
        warnings_list.append(f"Error in analysis: {str(e)}")
        return _analysis_placeholder("Analysis failed", "\n".join(warnings_list))

# Gradio Interface
# Assemble the Gradio UI. Components are declared tab by tab, then wired to
# their handler functions under "Event Handlers"; `demo` is launched at the end
# of the file.
with gr.Blocks(
    title="SQLI Guard - Advanced Protection Suite",
    theme='SebastianBravo/simci_css@0.0.2',
    css="""
        .gradio-container {max-width: 1200px !important; font-family: 'Asap', 'ui-sans-serif', sans-serif;}
        @media (max-width: 768px) { .gradio-container {padding: 10px;} }
    """
) as demo:
    gr.Markdown(
        """
        # 🛡️ SQLI Guard - Advanced AI-Powered SQL Injection Protection
        Detect, visualize, simulate, and prevent SQL injection attacks with state-of-the-art AI.
        """
    )

    # Global selectors, rendered above the tabs and fed into analyze_query.
    # NOTE(review): analyze_query accepts img_style but does not use it.
    db_type_selector = gr.Dropdown(label="Database Type", choices=["generic", "mysql", "postgresql"], value="generic")
    img_style_selector = gr.Dropdown(label="Image Style", choices=["comic-strip", "cyberpunk", "realistic"], value="comic-strip")

    with gr.Tabs():
        with gr.TabItem("🔍 Query Analysis"):
            with gr.Row():
                with gr.Column(scale=2):
                    query_input = gr.Textbox(label="Enter SQL Query", lines=3, placeholder="SELECT * FROM users WHERE id = '1'")
                    analyze_btn = gr.Button("Analyze Query")
                    gr.Examples(
                        examples=[
                            "SELECT * FROM users WHERE username = 'admin' AND password = 'password'",
                            "SELECT * FROM products WHERE id = 1; DROP TABLE users--",
                            "admin'--",
                            "1' OR '1'='1",
                            "1' UNION SELECT username, password FROM users"
                        ],
                        inputs=query_input,
                        label="Example SQL Queries"
                    )
                with gr.Column(scale=3):
                    prediction_output = gr.Label(label="Prediction")
                    confidence_output = gr.JSON(label="Confidence Scores")
                    pattern_bar_output = gr.Plot(label="Detected Injection Patterns (3D Scatter)")
                    heatmap_output = gr.Plot(label="Query Heatmap (Injection Risk)")
                    warnings_output = gr.Textbox(label="Warnings", interactive=False, lines=3)

        # Output-only tab: populated by analyze_btn from the Query Analysis tab.
        with gr.TabItem("🛡️ Protection Tools"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Query Sanitization")
                    sanitized_output = gr.Textbox(label="Sanitized Query", interactive=False)
                    param_query_output = gr.Textbox(label="Parameterized Query", interactive=False)
                    sanitize_explanation = gr.Textbox(label="Sanitization Report", interactive=False, lines=5)
                with gr.Column():
                    gr.Markdown("### Attack DNA Analysis")
                    dna_sig_output = gr.Textbox(label="Attack Signature", interactive=False)
                    dna_viz_output = gr.Plot(label="DNA Network Graph")


        with gr.TabItem("🎮 Honeypot Simulation"):
            honeypot_btn = gr.Button("Run Honeypot Simulation")
            honeypot_output = gr.DataFrame(label="Honeypot Results", headers=["Query", "Classification", "Confidence", "Timestamp", "IP"])
            honeypot_viz_output = gr.Plot(label="Classification Summary")

        with gr.TabItem("💬 Security Assistant"):
            # NOTE(review): every session shares the single module-level `sqlibot`,
            # so conversation history is shared across users; the `history`
            # argument Gradio passes in is ignored.
            gr.ChatInterface(
                fn=lambda msg, history: sqlibot.respond(msg),
                examples=[
                    "What is SQL injection?",
                    "How to prevent SQL injection in Python?",
                    "Analyze this query: SELECT * FROM users WHERE id = 1",
                    "Explain blind SQL injection"
                ],
                title="AskSQLiBot",
                description="Ask about SQL injection prevention, analysis, or best practices."
            )

        with gr.TabItem("🧠 Advanced Features"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Natural Language to Secure SQL")
                    nl_input = gr.Textbox(label="Enter Natural Language Query", placeholder="Get user name and email from customers where signup date is after 2023")
                    nl_btn = gr.Button("Convert to Secure SQL")
                    nl_output = gr.Textbox(label="Secure SQL Query", interactive=False)
                    nl_viz_output = gr.Plot(label="Sanitization Flow")
                    nl_explanation = gr.Textbox(label="Explanation", interactive=False, lines=5)
                with gr.Column():
                    gr.Markdown("### Red Team Simulator")
                    attack_type_selector = gr.Dropdown(label="Attack Type", choices=["basic", "blind", "obfuscated", "error_based", "stacked", "advanced_obfuscation"], value="basic")
                    red_team_btn = gr.Button("Generate Adversarial Query")
                    red_team_output = gr.Textbox(label="Adversarial Query", interactive=False)
                    red_team_viz = gr.Plot(label="Pattern Complexity")
                    # NOTE(review): this feedback box is not wired to any event handler below.
                    red_team_feedback = gr.Textbox(label="Feedback (Success/Fail)", placeholder="Was the query successful? Enter feedback...")

        # Output-only tab: populated by analyze_btn from the Query Analysis tab.
        with gr.TabItem("🔐 Threat Analysis"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Threat Persona")
                    persona_output = gr.Textbox(label="Persona", interactive=False)
                    threat_score_output = gr.Textbox(label="Threat Score", interactive=False)
                    persona_explanation = gr.Textbox(label="Persona Analysis", interactive=False, lines=5)
                with gr.Column():
                    gr.Markdown("### Attack Story")
                    threat_impact_output = gr.Plot(label="Threat Severity Distribution")
                    story_image_output = gr.Image(label="Attack Story Comic (Stable Diffusion)")

        # Output-only tab: populated by analyze_btn from the Query Analysis tab.
        with gr.TabItem("🔍 Auto-explanation"):
            gr.Markdown("### AI-Generated Explanation")
            explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=5)
            audio_explanation_output = gr.Audio(label="Listen to Explanation", interactive=False)

        # This tab has no input of its own; xray_btn reads query_input from the
        # Query Analysis tab (see the handler below).
        with gr.TabItem("⏱️ Query Execution Timeline"):
            xray_btn = gr.Button("Analyze Execution Stages")
            xray_output = gr.Plot(label="Query Execution Timeline")

        with gr.TabItem("🌐 Browser Honeypot"):
            honeypot_input = gr.Textbox(label="Simulate Form Input", placeholder="Enter username or malicious input")
            honeypot_browser_btn = gr.Button("Simulate Attack")
            honeypot_browser_viz = gr.Plot(label="Attack Flow and Timeline")
            honeypot_browser_result = gr.Textbox(label="Result", interactive=False)

        with gr.TabItem("🎙️ Voice-to-SQL"):
            # type="filepath" hands voice_to_sql a path it opens with open(..., "rb").
            voice_input = gr.Audio(label="Record Voice Input", type="filepath")
            voice_language = gr.Dropdown(label="Language", choices=["en", "es", "fr", "de"], value="en")
            voice_btn = gr.Button("Convert to SQL")
            voice_output = gr.Textbox(label="Secure SQL Query", interactive=False)
            voice_viz_output = gr.Plot(label="Sanitization Flow")
            voice_explanation = gr.Textbox(label="Explanation", interactive=False, lines=5)
            voice_csv_output = gr.File(label="Download CSV Report", interactive=False)

        with gr.TabItem("🤝 Community Patterns"):
            gr.Markdown("### Contribute to the Pattern Library")
            pattern_input = gr.Textbox(label="New Pattern", placeholder="Enter pattern (e.g., 0x[0-9a-fA-F]+)")
            pattern_desc = gr.Textbox(label="Description", placeholder="Describe the pattern")
            pattern_severity = gr.Dropdown(label="Severity", choices=["low", "medium", "high", "critical"], value="medium")
            pattern_submit_btn = gr.Button("Submit Pattern")
            pattern_output = gr.Textbox(label="Submission Status", interactive=False)
            # Snapshot of COMMUNITY_PATTERNS taken when the UI is built; it is
            # replaced by the table that submit_community_pattern returns.
            pattern_table = gr.DataFrame(label="Existing Patterns", value=pd.DataFrame(
                [(k, v["desc"], v["severity"]) for k, v in COMMUNITY_PATTERNS.items()],
                columns=["Pattern", "Description", "Severity"]
            ))

        with gr.TabItem("🎮 Attack Simulation Playground"):
            playground_input = gr.Textbox(label="Enter Query to Simulate", lines=3, placeholder="SELECT * FROM users WHERE id = '1'")
            playground_btn = gr.Button("Run Simulation")
            playground_output = gr.Textbox(label="Simulation Result", interactive=False, lines=5)

    # Event Handlers
    # analyze_query must return exactly 17 values, in the same order as this
    # outputs list (spanning the Analysis, Protection, Threat and Auto-explanation tabs).
    analyze_btn.click(
        fn=analyze_query,
        inputs=[query_input, db_type_selector, img_style_selector],
        outputs=[
            prediction_output, confidence_output, pattern_bar_output, sanitized_output, sanitize_explanation,
            param_query_output, dna_sig_output, dna_viz_output, heatmap_output, persona_output, threat_score_output,
            persona_explanation, threat_impact_output, story_image_output, explanation_output, audio_explanation_output, warnings_output
        ]
    )

    # honeypot_simulation is called with no UI inputs.
    honeypot_btn.click(
        fn=honeypot_simulation,
        inputs=None,
        outputs=[honeypot_output, honeypot_viz_output]
    )

    nl_btn.click(fn=nl_to_secure_sql, inputs=[nl_input], outputs=[nl_output, nl_viz_output, nl_explanation])

    red_team_btn.click(fn=generate_adversarial_query, inputs=[attack_type_selector], outputs=[red_team_output, red_team_viz])

    # Uses the query typed in the Query Analysis tab.
    xray_btn.click(fn=query_execution_timeline, inputs=[query_input], outputs=xray_output)

    honeypot_browser_btn.click(fn=honeypot_browser_simulation, inputs=[honeypot_input], outputs=[honeypot_browser_viz, honeypot_browser_result])

    voice_btn.click(fn=voice_to_sql, inputs=[voice_input, voice_language], outputs=[voice_output, voice_viz_output, voice_explanation, voice_csv_output])

    playground_btn.click(fn=attack_simulation_playground, inputs=playground_input, outputs=playground_output)

    pattern_submit_btn.click(
        fn=submit_community_pattern,
        inputs=[pattern_input, pattern_desc, pattern_severity],
        outputs=[pattern_output, pattern_table]
    )

# Launch the app with a public share link; debug=True keeps the call in the
# foreground and surfaces errors in the console.
try:
    demo.launch(share=True, debug=True)
except Exception as e:
    # logger.exception records the full traceback, which print() discarded.
    logger.exception(f"Error launching Gradio app: {str(e)}")