<a href="https://colab.research.google.com/github/AnuradhaWatane/pocAI/blob/main/ai_app_poc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# prompt: Create a project to fix Java vulnerabilities using a Retrieval Augmented Generation (RAG) system and a Large Language Model (LLM).

# This project outlines the core components for building a RAG system
# to help fix Java vulnerabilities using an LLM.
# Note: This is a foundational structure. A complete system would require
# a comprehensive vulnerability knowledge base, robust Java code parsing,
# and potentially fine-tuning of the LLM or embeddings.

# Install necessary libraries
# We'll use langchain for orchestration, sentence-transformers for embeddings,
# chromadb as a vector store, and transformers for the LLM (or integrate with an API).
!pip install -q langchain sentence-transformers chromadb transformers accelerate bitsandbytes
!pip install -U langchain-community

# Import necessary libraries
import os
from langchain.vectorstores import Chroma
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

# --- Configuration ---
# Define where to store the vector database (optional, can be in-memory)
VECTOR_DB_PATH = "./chroma_db"
# Define the embedding model to use
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Define the LLM model to use (choose a model suitable for text generation, potentially fine-tuned)
# Using a smaller model for demonstration, you might need a larger one or an API key for better results.
LLM_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Example small model

# --- Step 1: Prepare the Vulnerability Knowledge Base ---
# In a real application, this would involve collecting data from:
# - Public vulnerability databases (NVD, OSV)
# - Secure coding documentation (OWASP Cheatsheets)
# - Examples of vulnerable and fixed code snippets
# - Security advisories

# For this example, let's create a few dummy text files representing knowledge base entries.
knowledge_base_data = {
    "sql_injection_fix.txt": """
    SQL Injection Prevention:
    Always use Prepared Statements or parameterized queries when interacting with databases in Java.
    Do NOT concatenate user input directly into SQL queries.
    Example vulnerable code:
    String query = "SELECT * FROM users WHERE username = '" + userInput + "'";
    Statement stmt = conn.createStatement();
    ResultSet rs = stmt.executeQuery(query);

    Example fixed code using PreparedStatement:
    String query = "SELECT * FROM users WHERE username = ?";
    PreparedStatement pstmt = conn.prepareStatement(query);
    pstmt.setString(1, userInput);
    ResultSet rs = pstmt.executeQuery();
    """,
    "xss_fix.txt": """
    Cross-Site Scripting (XSS) Prevention in Java Web Applications:
    Sanitize or escape all user-generated content before rendering it in HTML.
    Use libraries like OWASP Java Encoder Project.
    Never directly output raw user input into HTML pages.
    Example vulnerable JSP/Servlet output:
    out.println("<h1>Welcome, " + userName + "</h1>");

    Example fixed output using encoding (in JSP):
    out.println("<h1>Welcome, " + fn:escapeXml(userName) + "</h1>");
    """,
    "path_traversal_fix.txt": """
    Path Traversal Prevention in Java:
    Sanitize user input that specifies file paths.
    Use canonical paths to check if the requested path is within an allowed directory.
    Avoid directly using user input in file system operations without validation.
    Example vulnerable code:
    File file = new File("/var/www/app/data/" + userSuppliedFilename);
    readFile(file);

    Example fixed approach:
    File baseDir = new File("/var/www/app/data/");
    File requestedFile = new File(baseDir, userSuppliedFilename);
    if (requestedFile.getCanonicalPath().startsWith(baseDir.getCanonicalPath())) {
        readFile(requestedFile);
    } else {
        # Handle invalid access attempt
    }
    """
}

# Write these to dummy files
if not os.path.exists("knowledge_base"):
    os.makedirs("knowledge_base")

for filename, content in knowledge_base_data.items():
    with open(os.path.join("knowledge_base", filename), "w") as f:
        f.write(content.strip())

print("Created dummy knowledge base files.")

# --- Step 2: Load and Process Knowledge Base Documents ---
# Load documents from the directory
loader = TextLoader(os.path.join("knowledge_base", list(knowledge_base_data.keys())[0]))
documents = loader.load()

# Load other documents
for filename in list(knowledge_base_data.keys())[1:]:
    loader = TextLoader(os.path.join("knowledge_base", filename))
    documents.extend(loader.load())

# Split documents into smaller chunks (optional but often helps retrieval)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
split_documents = text_splitter.split_documents(documents)

print(f"Loaded {len(documents)} documents and split into {len(split_documents)} chunks.")

# --- Step 3: Create Embeddings and Vector Store ---
# Initialize the embedding model
print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL_NAME)
print("Embedding model loaded.")

# Create the Chroma vector store from the documents and embeddings
# This will embed the documents and store them.
print(f"Creating vector store in {VECTOR_DB_PATH}...")
db = Chroma.from_documents(split_documents, embeddings, persist_directory=VECTOR_DB_PATH)
print("Vector store created and populated.")

# To load an existing database instead:
# db = Chroma(persist_directory=VECTOR_DB_PATH, embedding_function=embeddings)

# --- Step 4: Set up the LLM and Retrieval Chain ---
# Initialize the LLM
# We use HuggingFacePipeline to wrap a model from the Hugging Face Hub.
# Ensure you have enough RAM/GPU memory for the chosen model.
# You might need to configure quantization (load_in_8bit/load_in_4bit) for larger models.

print(f"Loading LLM model: {LLM_MODEL_NAME}")
try:
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        torch_dtype=torch.float16, # Use float16 for efficiency if possible
        load_in_8bit=True, # Or load_in_4bit=True for larger models
        device_map="auto" # Automatically determine where to put model layers (GPU/CPU)
    )

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512, # Limit the generated output length
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15,
        device_map="auto"
    )

    llm = HuggingFacePipeline(pipeline=pipe)
    print("LLM model loaded and pipeline created.")

except Exception as e:
    print(f"Error loading LLM model: {e}")
    print("You might need to install `accelerate` and `bitsandbytes` for quantization.")
    print(f"Attempting to load without quantization (may fail if model is too large)...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_NAME, device_map="auto")
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15,
            device_map="auto"
        )
        llm = HuggingFacePipeline(pipeline=pipe)
        print("LLM model loaded without quantization.")
    except Exception as e2:
        print(f"Failed to load LLM without quantization: {e2}")
        llm = None # Set LLM to None if loading fails completely

if llm:
    # Create a retriever from the vector store
    retriever = db.as_retriever(search_kwargs={"k": 3}) # Retrieve top 3 relevant documents

    # Create the Retrieval-Augmented Generation chain
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff", # 'stuff' combines all retrieved docs into one prompt
        retriever=retriever,
        return_source_documents=True # Optional: return the documents used by the LLM
    )

    print("RAG chain created.")

    # --- Step 5: Query the System ---

    # Example Query: Ask for help fixing a SQL Injection vulnerability
    query = "How do I fix a SQL Injection vulnerability in Java?"

    print(f"\nQuery: {query}")
    response = qa_chain({"query": query})

    print("\n--- Response ---")
    print(response['result'])

    if 'source_documents' in response:
        print("\n--- Source Documents ---")
        for i, doc in enumerate(response['source_documents']):
            print(f"Document {i+1}:")
            print(f"  Source: {doc.metadata.get('source', 'N/A')}")
            # print(f"  Content: {doc.page_content[:200]}...") # Print snippet of content
            print("-" * 20)

    # Example Query 2: Ask for help with a specific code snippet (requires Java parsing)
    # Note: The current setup doesn't parse Java code directly. You would typically
    # feed a description of the vulnerability found in the code or the code itself
    # as the query, and the RAG system would retrieve relevant fixes from the KB.
    # Integrating a static analysis tool here would be beneficial.

    # query_code_snippet = """
    # public class Insecure {
    #     public void processRequest(HttpServletRequest request, HttpServletResponse response) {
    #         String username = request.getParameter("username");
    #         String sql = "SELECT * FROM users WHERE username = '" + username + "'"; # Vulnerable line
    #         # ... execute sql ...
    #     }
    # }
    # """
    # query_code_desc = "Help me fix the SQL Injection vulnerability in the provided Java code snippet."
    # print(f"\nQuery: {query_code_desc}")
    # # In a real system, you might analyze the code, identify the vulnerability type (SQL Injection),
    # # and formulate a query like "How to fix SQL Injection in Java?"
    # # or combine the analysis results with the query.
    # response_code = qa_chain({"query": "How to fix SQL Injection in Java code? Provide a code example."})
    # print("\n--- Response (Code Fix) ---")
    # print(response_code['result'])


else:
    print("\nLLM did not load successfully. Cannot run the RAG chain.")

# --- Step 6: Further Development Ideas ---
# - Integrate a Java static analysis tool (e.g., using pylspclient to interface with an LSP server like Eclipse JDT Language Server, or using external tools).
# - Automatically extract vulnerability type and location from code analysis results.
# - Improve the knowledge base with more detailed examples, specific framework vulnerabilities, etc.
# - Use a larger, more capable LLM (consider using paid APIs like OpenAI GPT, Anthropic Claude, or Google Gemini if budget allows and terms are acceptable).
# - Fine-tune the embedding model or LLM on Java security datasets.
# - Implement a user interface (e.g., Gradio or a simple web app) for interaction.
# - Add logic to apply fixes or suggest code modifications directly.
# - Include validation steps to check if the suggested fix actually addresses the vulnerability.

Created dummy knowledge base files.
Loaded 3 documents and split into 3 chunks.
Loading embedding model: sentence-transformers/all-MiniLM-L6-v2


  embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL_NAME)
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Embedding model loaded.
Creating vector store in ./chroma_db...
Vector store created and populated.
Loading LLM model: TinyLlama/TinyLlama-1.1B-Chat-v1.0


tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Error loading LLM model: Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` requires `accelerate`. You can install it with `pip install accelerate`
You might need to install `accelerate` and `bitsandbytes` for quantization.
Attempting to load without quantization (may fail if model is too large)...
Failed to load LLM without quantization: Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` requires `accelerate`. You can install it with `pip install accelerate`

LLM did not load successfully. Cannot run the RAG chain.
