In [1]:
!pip install chromadb

Collecting chromadb
  Downloading chromadb-1.0.20-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.3 kB)
Collecting pybase64>=1.4.1 (from chromadb)
  Downloading pybase64-1.4.2-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl.metadata (8.7 kB)
Collecting posthog<6.0.0,>=2.4.0 (from chromadb)
  Downloading posthog-5.4.0-py3-none-any.whl.metadata (5.7 kB)
Collecting onnxruntime>=1.14.1 (from chromadb)
  Downloading onnxruntime-1.22.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.9 kB)
Collecting opentelemetry-exporter-otlp-proto-grpc>=1.2.0 (from chromadb)
  Downloading opentelemetry_exporter_otlp_proto_grpc-1.36.0-py3-none-any.whl.metadata (2.4 kB)
Collecting pypika>=0.48.9 (from chromadb)
  Downloading PyPika-0.48.9.tar.gz (67 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [

In [2]:
import json
import chromadb
import logging
from pydantic import BaseModel, Field
from typing import Optional, List
import uuid
from enum import Enum

In [3]:
class DocType(str, Enum):
    """Kinds of compliance document a chunk can be taken from.

    Because ``str`` is mixed in, each member compares equal to its string value
    (e.g. ``DocType.SUMMARY == "summary"``).
    """
    SUMMARY = "summary"
    PRECEDENT = "precedent"
    DEFINITION = "definition"
    LEGAL_TEXT = "legal_text"

In [4]:
# Module-level Chroma client persisted under ./chroma_db, a module logger, and a
# "compliance_rules" collection.
# NOTE(review): ChunkManager (defined in a later cell) opens its own PersistentClient and
# defaults to a collection named "regulations", so `collection` here appears to be unused
# by the classes below — confirm whether it is still needed.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
logger = logging.getLogger(__name__)
collection = chroma_client.get_or_create_collection(name="compliance_rules")

class ChunkModel(BaseModel):
    """A single retrievable chunk of regulatory text together with its metadata.

    Fields are declared as pydantic attributes, so validation and construction are
    handled by ``BaseModel``; no custom ``__init__`` is needed.
    """

    id: str = Field(
        default_factory=lambda: f"chunk_{uuid.uuid4().hex[:8]}",
        description="Auto-generated unique identifier for the chunk"
    )
    content: str = Field(..., description="The main text content of the chunk")
    regulation: str = Field(..., description="Regulation this chunk pertains to")
    jurisdiction: str = Field(..., description="Geographic jurisdiction")
    doc_type: DocType = Field(..., description="Type of document")
    keywords: List[str] = Field(default_factory=list, description="List of keywords for retrieval")
    source: Optional[str] = Field(None, description="Original source of the content")
    # Optional so that chunks built straight from a knowledge-base file (no vector yet)
    # remain valid; populated when a chunk is rebuilt from a store that returns vectors.
    embedding: Optional[List[float]] = Field(
        None, description="Embedding vector for the chunk, when one is available"
    )

    def to_dict(self) -> dict:
        """Return the chunk's metadata and content as a plain dict.

        ``doc_type`` is emitted as its string value. The embedding is deliberately
        not included, so the output shape is the same whether or not a vector is set.
        """
        return {
            "id": self.id,
            "content": self.content,
            "regulation": self.regulation,
            "jurisdiction": self.jurisdiction,
            "doc_type": self.doc_type.value,
            "keywords": self.keywords,
            "source": self.source
        }

    def generateFullText(self) -> str:
        """Build the single text string that represents this chunk for embedding.

        Labelled segments (regulation, jurisdiction, document type, content, then
        optional keywords and source, then id) are joined with ". ".
        """
        lines = [
            f"REGULATION: {self.regulation}",
            f"JURISDICTION: {self.jurisdiction}",
            f"DOCUMENT TYPE: {self.doc_type.value}",
            f"CONTENT: {self.content}"
        ]
        if self.keywords:
            keywords_str = ", ".join(self.keywords)
            lines.append(f"KEYWORDS: {keywords_str}")
        if self.source:
            lines.append(f"SOURCE: {self.source}")
        lines.append(f"ID: {self.id}")
        return ". ".join(lines)

    @classmethod
    def from_json_dict(cls, json_data: dict) -> "ChunkModel":
        """Create a chunk from a plain dict.

        Args:
            json_data: Mapping with required keys ``content``, ``regulation``,
                ``jurisdiction`` and ``doc_type``; optional keys ``id``, ``keywords``,
                ``source`` and ``embedding``. ``keywords`` may be a list or a
                comma-separated string.

        Returns:
            The validated ``ChunkModel``. When no ``id`` is supplied the field's
            default factory generates one.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``doc_type`` is not a valid ``DocType`` value.
        """
        keywords = json_data.get('keywords') or []
        if isinstance(keywords, str):
            # Accept the flattened "a, b, c" form used when keywords are stored as a
            # single metadata string.
            keywords = [kw.strip() for kw in keywords.split(",") if kw.strip()]

        embedding = json_data.get('embedding')
        if embedding is not None and hasattr(embedding, "tolist"):
            # numpy arrays are converted so the field validates as a list of floats.
            embedding = embedding.tolist()

        fields = {
            "content": json_data['content'],
            "regulation": json_data['regulation'],
            "jurisdiction": json_data['jurisdiction'],
            "doc_type": DocType(json_data['doc_type']),
            "keywords": keywords,
            "source": json_data.get('source'),
            "embedding": embedding,
        }
        # Only pass an id when one was provided; otherwise let the field's
        # default_factory generate it instead of duplicating that logic here.
        if json_data.get('id') is not None:
            fields["id"] = json_data['id']
        return cls(**fields)

In [5]:
from sentence_transformers import SentenceTransformer
import numpy as np
from typing import List
class EmbeddingGenerator:
    """Wraps a SentenceTransformer model to embed chunks and free-text queries."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load the sentence-transformer identified by ``model_name``."""
        self.model_name = model_name
        self.model = SentenceTransformer(model_name)

    def generate_embedding(self, chunk: ChunkModel):
        """Encode the chunk's full text (as built by ``generateFullText``)."""
        full_text = chunk.generateFullText()
        return self.model.encode(full_text)

    def generate_embedding_query(self, query: str):
        """Encode a raw query string with the same model."""
        encoder = self.model
        return encoder.encode(query)


In [6]:
import json
from typing import List, Dict, Any
from pydantic import BaseModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
class RegulationParser:
    """Turns a JSON knowledge base of regulations into a flat list of ``ChunkModel``s."""

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
        """Configure the text splitter used for both content and obligations.

        Args:
            chunk_size: Passed to ``RecursiveCharacterTextSplitter`` as ``chunk_size``.
            chunk_overlap: Passed to ``RecursiveCharacterTextSplitter`` as ``chunk_overlap``.
        """
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )

    def parse(self, file_path: str) -> List[ChunkModel]:
        """Read ``file_path`` and split every document into chunks.

        The file must be JSON with a top-level ``documents`` list. Each document's
        ``content`` and each entry of its optional ``key_obligations`` is split with
        the configured splitter.

        Every chunk gets a deterministic id derived from the document id and its
        position (``<doc>_content_<n>`` / ``<doc>_obligation_<i>_<j>``), so ingesting
        the same file again overwrites chunks in an upsert-based store instead of
        adding duplicates under fresh random ids.

        Args:
            file_path: Path to the UTF-8 encoded JSON knowledge base.

        Returns:
            All chunks, in document order (content chunks first, then obligations).

        Raises:
            KeyError: If a document lacks ``id``, ``content``, ``regulation``,
                ``jurisdiction`` or ``doc_type``.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        chunks: List[ChunkModel] = []
        for doc in data.get("documents", []):
            # Main body of the document.
            for idx, content_chunk in enumerate(self.splitter.split_text(doc["content"])):
                chunks.append(
                    self._make_chunk(doc, f"{doc['id']}_content_{idx}", content_chunk)
                )

            # Each obligation is split separately; i = obligation index, j = piece index.
            for i, obligation in enumerate(doc.get("key_obligations") or []):
                for j, chunk_text in enumerate(self.splitter.split_text(obligation)):
                    chunks.append(
                        self._make_chunk(doc, f"{doc['id']}_obligation_{i}_{j}", chunk_text)
                    )

        return chunks

    @staticmethod
    def _make_chunk(doc: dict, chunk_id: str, text: str) -> ChunkModel:
        """Build one chunk that inherits the parent document's metadata."""
        return ChunkModel.from_json_dict({
            "id": chunk_id,
            "content": text,
            "regulation": doc["regulation"],
            "jurisdiction": doc["jurisdiction"],
            "doc_type": doc["doc_type"],
            "keywords": doc.get("trigger_keywords") or [],
            "source": doc["id"]
        })


In [7]:
class ChunkManager:
    """Stores chunks in, and retrieves them from, a persistent Chroma collection."""

    def __init__(self, db_path: str = "./chroma_db", collection_name: str = "regulations"):
        """Open (or create) the collection ``collection_name`` under ``db_path``."""
        self.client = chromadb.PersistentClient(path=db_path)
        self.collection = self.client.get_or_create_collection(name=collection_name)
        # Created on the first add_chunks() call and then reused, so the embedding
        # model is loaded once rather than once per chunk.
        self._embedding_generator = None

    @staticmethod
    def _as_list(vector):
        """Return ``vector`` as a plain list (numpy arrays via ``tolist``); ``None`` stays ``None``."""
        if vector is None:
            return None
        return vector.tolist() if hasattr(vector, "tolist") else list(vector)

    def add_chunks(self, chunks: List["ChunkModel"]):
        """Embed and upsert ``chunks`` into the collection.

        Metadata is flattened to scalar values: ``doc_type`` is stored as the enum's
        string value (so it can be parsed back with ``DocType(...)``), ``keywords`` as
        a ", "-joined string, and ``source`` is omitted when it is ``None``.

        Args:
            chunks: Chunks to store. An empty list is a no-op.
        """
        if not chunks:
            return

        embedder = getattr(self, "_embedding_generator", None)
        if embedder is None:
            embedder = self._embedding_generator = EmbeddingGenerator()

        metadatas = []
        for c in chunks:
            meta = {
                "regulation": c.regulation,
                "jurisdiction": c.jurisdiction,
                # str(enum_member) would give "DocType.X", which cannot be parsed back.
                "doc_type": getattr(c.doc_type, "value", c.doc_type),
                "keywords": ", ".join(c.keywords) if c.keywords else "",
            }
            if c.source is not None:
                meta["source"] = c.source
            metadatas.append(meta)

        self.collection.upsert(
            documents=[c.content for c in chunks],
            ids=[c.id for c in chunks],
            embeddings=[self._as_list(embedder.generate_embedding(c)) for c in chunks],
            metadatas=metadatas,
        )

    def query_chunks(self, query_embedding: List[float], k: int = 5):
        """Return the ``k`` nearest stored chunks for ``query_embedding``.

        Args:
            query_embedding: Query vector (list or numpy array).
            k: Maximum number of results.

        Returns:
            ``{"results": [...]}`` where each entry has ``metadata``, ``document``,
            ``distance`` and ``embedding`` (a list, or ``None`` when the store
            returned no embeddings). On any error the failure is logged and an
            empty result list is returned (best-effort lookup).
        """
        try:
            results = self.collection.query(
                query_embeddings=[self._as_list(query_embedding)],
                n_results=k,
                include=["documents", "metadatas", "distances", "embeddings"]
            )
            if not results["metadatas"] or not results["metadatas"][0]:
                return {"results": []}

            # Only one query vector is sent, so only the first row of each field is used.
            metadatas = results["metadatas"][0]
            documents = results["documents"][0]
            distances = results["distances"][0]
            # `is not None` / len() instead of truthiness: the value may be a numpy array.
            embeddings = results.get("embeddings")
            row_embeddings = (
                embeddings[0] if embeddings is not None and len(embeddings) > 0 else None
            )

            return {
                "results": [
                    {
                        "metadata": metadatas[i],
                        "document": documents[i],
                        "distance": distances[i],
                        "embedding": (
                            self._as_list(row_embeddings[i])
                            if row_embeddings is not None else None
                        ),
                    }
                    for i in range(len(metadatas))
                ]
            }

        except Exception as e:
            logging.getLogger(__name__).exception("Error querying chunks: %s", e)
            return {"results": []}

    def get_all_chunks(self) -> List["ChunkModel"]:
        """Retrieve all chunks from the collection.

        Stored metadata is converted back to the shape ``ChunkModel.from_json_dict``
        expects: the ", "-joined keyword string becomes a list again, and a
        ``doc_type`` of the form ``"DocType.NAME"`` is mapped to that member's value.
        Keywords that themselves contain a comma cannot be recovered exactly from the
        flattened string.
        """
        # Ids are always part of the result and are not a valid `include` entry.
        results = self.collection.get(
            include=["documents", "metadatas", "embeddings"]
        )
        ids = results["ids"]
        embeddings = results.get("embeddings")
        if embeddings is None:
            embeddings = [None] * len(ids)

        all_chunks = []
        for doc, meta, cid, embedding in zip(results["documents"], results["metadatas"], ids, embeddings):
            meta = dict(meta or {})

            keywords = meta.get("keywords")
            if isinstance(keywords, str):
                meta["keywords"] = [kw.strip() for kw in keywords.split(",") if kw.strip()]

            doc_type = meta.get("doc_type")
            if isinstance(doc_type, str) and doc_type.startswith("DocType."):
                # Rows written with str(enum_member) hold the member name, not its value.
                meta["doc_type"] = DocType[doc_type.split(".", 1)[1]].value

            chunk_data = {
                "id": cid,
                "content": doc,
                **meta,  # spread metadata fields directly
                "embedding": self._as_list(embedding),
            }
            all_chunks.append(ChunkModel.from_json_dict(chunk_data))

        return all_chunks


In [8]:
import numpy as np

def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Compute cosine similarity between two vectors.

    Args:
        vec1: First vector (array-like of numbers).
        vec2: Second vector, same length as ``vec1``.

    Returns:
        The cosine of the angle between the vectors as a Python ``float``.
        If either vector has zero norm the similarity is undefined; ``0.0`` is
        returned instead of NaN so callers can still sort by the score.
    """
    a = np.asarray(vec1, dtype=float)
    b = np.asarray(vec2, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)

In [9]:
from typing import List, Optional
import numpy as np

class ChunkComparator:
    """Brute-force cosine-similarity search over every chunk held by a ``ChunkManager``."""

    def __init__(self, chunk_manager: ChunkManager):
        """Keep a reference to ``chunk_manager`` and create an embedding generator."""
        self.chunk_manager = chunk_manager
        self.embedding_generator = EmbeddingGenerator()

    def query_by_text(self, query_text: str, top_k: int = 5, metadata_filter: Optional[dict] = None) -> List[ChunkModel]:
        """Embed ``query_text`` and return the ``top_k`` most similar chunks."""
        query_embedding = self.embedding_generator.generate_embedding_query(query_text)
        return self.query_by_embedding(query_embedding, top_k=top_k, metadata_filter=metadata_filter)

    def query_by_embedding(self, query_embedding: List[float], top_k: int = 5, metadata_filter: Optional[dict] = None) -> List[ChunkModel]:
        """Rank all stored chunks against ``query_embedding`` by cosine similarity.

        Args:
            query_embedding: Query vector.
            top_k: Number of chunks to return.
            metadata_filter: Optional ``{attribute_name: value}`` mapping; only chunks
                whose attributes equal every given value are considered.

        Returns:
            Up to ``top_k`` chunks, most similar first.

        Raises:
            AttributeError: If ``metadata_filter`` names an attribute a chunk lacks.
        """
        all_chunks = self.chunk_manager.get_all_chunks()
        if metadata_filter:
            all_chunks = [
                c for c in all_chunks
                if all(getattr(c, key) == value for key, value in metadata_filter.items())
            ]

        query_vec = np.array(query_embedding)
        similarities = []
        for chunk in all_chunks:
            # Use the chunk's `embedding` attribute when it has one; otherwise embed
            # the chunk here so the comparison still works.
            stored = getattr(chunk, "embedding", None)
            if stored is None:
                stored = self.embedding_generator.generate_embedding(chunk)
            sim = cosine_similarity(query_vec, np.array(stored))
            similarities.append((sim, chunk))

        # Sort by similarity descending and return top_k
        similarities.sort(key=lambda pair: pair[0], reverse=True)
        return [chunk for _, chunk in similarities[:top_k]]


In [10]:
class RAGEngine:
  """Ties together embedding, vector storage and parsing to build RAG prompts."""

  def __init__(self, embedding_generator: "EmbeddingGenerator", chromaVectorStore: "ChunkManager", chunkGenerator: "RegulationParser") -> None:
    """Store the injected collaborators.

    Args:
      embedding_generator: Used to embed feature descriptions.
      chromaVectorStore: Store that chunks are written to and queried from.
      chunkGenerator: Parser that turns a knowledge-base file into chunks.
    """
    # Use the instances the caller passed in rather than constructing new ones, so
    # the caller's configuration (and already-loaded model) is what actually runs.
    self.embedding_generator = embedding_generator
    self.chromaVectorStore = chromaVectorStore
    self.chunkGenerator = chunkGenerator

  def initialize_database(self, filepath: str):
    """Parse the knowledge base at ``filepath`` and add its chunks to the store."""
    chunks = self.chunkGenerator.parse(filepath)
    self.chromaVectorStore.add_chunks(chunks)
    print("Chunks added to database...")

  def query_with_context(self, feature_description: str) -> str:
    """
    Simple RAG query that retrieves relevant compliance context and formats a prompt for LLM.

    Embeds ``feature_description``, fetches the 5 nearest chunks from the store and
    returns the prompt text with those chunks inserted as context.
    """
    query_embedding = self.embedding_generator.generate_embedding_query(feature_description)
    query_result = self.chromaVectorStore.query_chunks(query_embedding, k=5)
    relevant_chunks = query_result["results"]
    context = self._build_context(relevant_chunks)
    prompt = f"""
    Analyze this feature description for geo-compliance requirements.
    FEATURE DESCRIPTION:
    {feature_description}
    RELEVANT COMPLIANCE CONTEXT:
    {context}
    Answer these questions:
    1. Does this feature require geo-specific compliance logic? (Yes/No/Maybe)
    2. Why or why not? Provide clear reasoning based on the context.
    3. Which specific regulations apply, if any?

    Format your response as:
    Requires Geo Logic: [Yes/No/Maybe]
    Reasoning: [Your reasoning here]
    Related Regulations: [Comma-separated list or None]
    """
    return prompt

  def _build_context(self, chunks: list) -> str:
    """Format retrieved results as numbered sections, one per chunk.

    Args:
      chunks: Result dicts, each with a ``document`` entry.

    Returns:
      The sections joined by newlines; an empty string when ``chunks`` is empty.
    """
    context_lines = []
    for i, chunk in enumerate(chunks):
      context_lines.append(f"--- Chunk {i+1} ---")
      context_lines.append(chunk['document'])
    # Join after the loop so every retrieved chunk is included, not just the first.
    return "\n".join(context_lines)

In [11]:
!pip install transformers accelerate trl[sentencepiece] huggingface_hub --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m544.8/544.8 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [12]:
# Load dataset
import pandas as pd
from datasets import load_dataset, Dataset

DATA_PATH = "compliance_feature_analysis_with_desc.csv"
SPLIT_SEED = 42  # fixed so the train/test split is the same on every run

df = pd.read_csv(DATA_PATH)

# Model input: feature name followed by its description.
df['input'] = df['feature_name'] + ". " + df['feature_description']

# Target output: the compliance flag and the reasoning behind it.
df['output'] = (
    "Flag: " + df['flag'] + "\n" +
    "Reasoning: " + df['reasoning'] + "\n")
    # "Related Regulations: "+ df['related_regulations'] + "\n")

# Combine input + output into a single "text" column for SFTTrainer
df['text'] = df['input'] + "\n" + df['output']

# Convert to HuggingFace Dataset and split train/test
dataset = Dataset.from_pandas(df[['text']])
dataset = dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)

In [14]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTTrainer
import torch
from huggingface_hub import login

# Hugging Face login
# Interactive authentication with the Hugging Face Hub (renders a login widget in the
# notebook). The model-loading cell below asks for authenticated access when it fetches
# the checkpoint, so this has to run first.
login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…



tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

In [None]:
# Load LLaMA 3.2 1B Instruct
model_name = "meta-llama/Llama-3.2-1B-Instruct"

# `token=True` uses the credentials saved by the login step; it replaces the
# deprecated `use_auth_token` argument.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=True,
    device_map="auto",  # or "cpu" if you don't have GPU
)

In [15]:
# Configure LoRA
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "v_proj"],  # Common attention layers
    task_type="CAUSAL_LM"
)

# Wrap the base model with the LoRA adapters once, here. The wrapped model is what
# gets handed to the trainer, so the trainer is not also given `peft_config`.
model = get_peft_model(model, lora_config)

# Prepare Trainer
training_args = TrainingArguments(
    output_dir="./llama-geo-ft",
    per_device_train_batch_size=1,      # adjust based on Colab GPU memory
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    fp16=True,
    num_train_epochs=3,
    logging_steps=10,
    save_strategy="epoch",
    save_total_limit=2,
    report_to=[],
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    args=training_args
)

# Start Training (the trainer switches the model to training mode itself)
trainer.train()




Adding EOS to train dataset:   0%|          | 0/20 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/20 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/20 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

Step,Training Loss
10,4.772


TrainOutput(global_step=15, training_loss=4.6529277801513675, metrics={'train_runtime': 9.5405, 'train_samples_per_second': 6.289, 'train_steps_per_second': 1.572, 'total_flos': 23846781874176.0, 'train_loss': 4.6529277801513675, 'entropy': 4.409536123275757, 'num_tokens': 4077.0, 'mean_token_accuracy': 0.2541874282062054, 'epoch': 3.0})

In [None]:

def main():
    """Main function to run the Compliance RAG System.

    Builds the RAG components, loads the compliance knowledge base into the vector
    store, then repeatedly reads a feature description from the user, builds a
    context-augmented prompt and prints the language model's answer. Typing
    ``exit`` or ``quit`` (or sending EOF / interrupting) ends the loop.

    Relies on the notebook-level ``tokenizer`` and ``model`` created in earlier cells.
    """

    # Initialize components
    print("Initializing RAG System Components...")

    embedding_generator = EmbeddingGenerator()
    chroma_vector_store = ChunkManager()
    regulation_parser = RegulationParser()

    # Create RAG engine
    rag_engine = RAGEngine(
        embedding_generator=embedding_generator,
        chromaVectorStore=chroma_vector_store,
        chunkGenerator=regulation_parser
    )

    # Initialize database with compliance knowledge
    print("Loading compliance knowledge base...")
    rag_engine.initialize_database("./sample_data/compliance_knowledge_base.json")
    print("Database initialized successfully!")

    print("\n" + "="*50)
    print("RUNNING COMPLIANCE CHECKS...")
    print("="*50)

    # The model may still be in training mode from fine-tuning; switch to eval mode so
    # dropout is disabled while generating answers.
    model.eval()

    while True:
        # Get user input
        try:
            query = input("\nEnter a feature description: ").strip()
        except (EOFError, KeyboardInterrupt):
            # No more input available (or the cell was interrupted): stop cleanly.
            print("Exiting system. Goodbye!")
            break

        if query.lower() in {"exit", "quit"}:
            print("Exiting system. Goodbye!")
            break

        if not query:
            # Nothing to analyze; ask again instead of querying with an empty string.
            continue

        print(f"\nChecking compliance for: '{query}'")

        # Generate the prompt with context
        prompt = rag_engine.query_with_context(query)

        # Prepare input for LLaMA
        llama_input = f"{prompt}\nAnswer:"

        # Tokenize
        inputs = tokenizer(llama_input, return_tensors="pt").to(model.device)

        # Generate output
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )

        # Only decode newly generated tokens (exclude the prompt echo)
        generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        answer = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Drop blank lines and surrounding whitespace from the model's answer.
        answer = "\n".join(
            line.strip() for line in answer.splitlines() if line.strip()
        )

        print("\n=== LLaMA Output ===")
        print(answer.strip())


if __name__ == "__main__":
    main()

Initializing RAG System Components...


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Loading compliance knowledge base...
