# VeriRAG: Knowledge Graph-Augmented RAG for Verilog and Assertions

**What this implements.**
- **Ingestion → Parsing → RDF Knowledge Graph (KG) → Vector Store → Retrieval → Prompting → Generation.**
- Two generation paths: **RTL** (large synthesizable Verilog) and **SVA** (Verilog with immediate assertions).
- **Interactive loops** so you can issue multiple queries in one session.

**High-level flow.**
1. Upload a CSV of training examples (instructions + Verilog code, or raw modules).
2. Parse Verilog to extract **ports, signals, parameters, and operations** with Pyverilog + heuristics.
3. Turn the parsed structure into an **RDF KG** (Turtle format) using `rdflib`.
4. Summarize each module with the LLM and compute **OpenAI embeddings**.
5. Store text + embeddings + metadata in **ChromaDB**.
6. At query time, use **vector retrieval** + **KG-derived metadata** to build **LLM prompts** for RTL or SVA code generation.

**Prerequisites.**
- This notebook is designed for Google Colab environments (mounts Google Drive under `/content/drive`).
- You must provide a valid **OpenAI API key**.
- You should have a CSV file compatible with the expected schema (see Section 7 for details).

**Outputs.**
- Per-chunk Turtle files in `.../knowledge_graphs/` (RDF representation of modules).
- A persistently stored ChromaDB collection (vector index + metadata).
- LLM-generated Verilog modules (printed in the cell output for inspection/copying).


## 1. Imports & Logging
These imports match the original scripts. They enable:
- **RDF/KG** operations (`rdflib`)
- **Vector DB** persistence (`chromadb`)
- **OpenAI** chat + embedding clients (`openai`)
- **Verilog parsing** (`pyverilog`)
- **I/O and utilities** (CSV, JSON, regex, logging)

Logging is configured once to surface key status messages and parsing diagnostics.

In [None]:
import rdflib
from rdflib import Graph, Literal, RDF, RDFS, Namespace, URIRef
import chromadb
import openai
import logging
import json
import urllib.parse
import os
import re
import csv
import uuid
import io
import tiktoken
from contextlib import redirect_stderr
from google.colab import files, drive
import pyverilog.vparser.parser as vparser
from pyverilog.vparser.parser import parse, Description, ModuleDef, Ioport, Port

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

## 2. OpenAI API
**Why:** Credentials should not be hard-coded. This cell obtains the key from `OPENAI_API_KEY` or prompts securely.

**Used by:** summarization (module overviews), embeddings (vector search), and both generation paths (RTL/SVA).

In [None]:
import getpass

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
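if not OPENAI_API_KEY:
    try:
        OPENAI_API_KEY = getpass.getpass('Enter OPENAI_API_KEY (input hidden): ')
    except Exception as e:
        # Chain the original error so the cause of the failed prompt stays visible.
        raise RuntimeError('OPENAI_API_KEY is required for online-only mode.') from e
if not OPENAI_API_KEY:
    # An empty entry at the prompt is treated the same as a missing key.
    raise RuntimeError('OPENAI_API_KEY is required for online-only mode.')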

openai_client = openai.OpenAI(api_key=OPENAI_API_KEY)
print('OpenAI client initialized.')

## 3. Storage: Google Drive Mount & Directories
**What it does:**
- Mounts Google Drive at `/content/drive` (typical for Colab) so intermediate files persist across sessions.
- Creates the working directories:
  - `PERSIST_DIRECTORY`: top-level workspace for this project.
  - `LOCAL_DB_DIR`: scratch area (kept for parity with original code).
  - `knowledge_graphs/`: per-chunk Turtle files (RDF graphs).

**Note:** paths mirror the original scripts so any existing artifacts continue to work.

In [None]:
from pathlib import Path
drive.mount('/content/drive')
PERSIST_DIRECTORY = '/content/drive/MyDrive/GRAPHRAG5'
LOCAL_DB_DIR = '/content/rag_db'
os.makedirs(PERSIST_DIRECTORY, exist_ok=True)
os.makedirs(LOCAL_DB_DIR, exist_ok=True)
os.makedirs(os.path.join(PERSIST_DIRECTORY, 'knowledge_graphs'), exist_ok=True)
print('Dirs ready:', PERSIST_DIRECTORY)

## 4. Summarization Function (Module Overviews)
**Purpose:** produce concise, 100–200 word summaries of each Verilog module for retrieval cues.

**Inputs:**
- `code`: raw Verilog text for a single module.
- `instruction`: the instruction string paired with the module (if present in the CSV).

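**Output:** the summary string. Its first 100 characters are printed and the full text is returned. If the LLM call fails, the error is logged and the function returns the fallback string `'Summary generation failed.'`.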
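
The fallback behaviour can be illustrated without a network call. The sketch below is not the notebook's function: `summarize_with_fallback`, `echo_llm`, and `failing_llm` are hypothetical stand-ins for `summarize_code` and the OpenAI client, and the prompt is shortened. Only the sentinel string is taken from the cell that follows.

```python
import logging

FALLBACK_SUMMARY = 'Summary generation failed.'  # same sentinel string summarize_code returns

def summarize_with_fallback(call_llm, code, instruction):
    """Return call_llm's stripped reply, or the sentinel if the call raises."""
    prompt = f'Instruction: {instruction}\nVerilog code:\n{code}'
    try:
        return call_llm(prompt).strip()
    except Exception as exc:
        logging.error('Error generating summary: %s', exc)
        return FALLBACK_SUMMARY

def echo_llm(prompt):
    # Hypothetical canned reply standing in for a successful API call.
    return '  A 2-input AND gate.  '

def failing_llm(prompt):
    # Hypothetical failure standing in for a network or API error.
    raise TimeoutError('simulated API outage')

ok = summarize_with_fallback(echo_llm, 'module and2; endmodule', 'Build an AND gate')
bad = summarize_with_fallback(failing_llm, 'module and2; endmodule', 'Build an AND gate')
print(ok)
print(bad == FALLBACK_SUMMARY)
```

Because the sentinel is an ordinary string, a caller that wants to keep failed summaries out of the index can compare against it before embedding.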

In [None]:
def summarize_code(code, instruction):
    prompt = f"""
    Provide a detailed summary of the following Verilog module, including its functionality, inputs, outputs, parameters, and key operations. The instruction provided for the module is: "{instruction}".
    Verilog Code:
    ```verilog
    {code}
    ```
    Summary should be concise (100-200 words) and focus on:
    - Purpose of the module
    - Inputs and outputs (including widths)
    - Parameters (if any)
    - Main operations or logic
    - Any notable features (e.g., sequential, combinational, FSM)
    """
    try:
        response = openai_client.chat.completions.create(
            model='gpt-4',
            messages=[{'role': 'user', 'content': prompt}],
            max_tokens=500
        )
        summary = response.choices[0].message.content.strip()
        print(f'Generated summary: {summary[:100]}...')
        return summary
    except Exception as e:
        logging.error(f'Error generating summary: {str(e)}')
        return 'Summary generation failed.'

## 5. Verilog Parsing & Operation Classification
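**Parsing strategy:**
1. Try **Pyverilog** first to build an AST and extract the module name and ports. A parse failure is printed and does not stop processing.
2. Always run the **regex-based heuristics** afterwards to collect ports, signals, parameters, instantiations, and assignments. Ports already found by Pyverilog are not added twice.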
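
As an illustration of the heuristic step, the sketch below applies the same port-declaration pattern used in the parsing cell to a few module-body lines. The helper name `scan_port_lines` and the sample lines are made up for this example; the real parser also tracks module headers, signals, and parameters.

```python
import re

# Same pattern the heuristic parser applies to each line inside a module body.
PORT_RE = re.compile(
    r'^(input|output|inout)\s*(wire|reg)?\s*(\[[\w\-`]+:[0-9]+\])?\s*([\w,\s]+)\s*[,;]'
)

def scan_port_lines(lines):
    """Collect (name, direction, width) tuples from Verilog declaration lines."""
    ports = []
    for raw in lines:
        match = PORT_RE.match(raw.strip())
        if not match:
            continue
        direction = match.group(1)
        width = match.group(3) or '1'  # scalar ports get width '1'
        for name in match.group(4).split(','):
            if name.strip():
                ports.append((name.strip(), direction, width))
    return ports

body = [
    'input clk;',
    'input wire [7:0] a, b;',
    'output reg [WIDTH-1:0] sum;',
    'assign sum = a + b;',  # not a declaration, so it is skipped
]
ports = scan_port_lines(body)
for port in ports:
    print(port)
```

Widths are kept as the literal bracket text, so a parameterized range such as `[WIDTH-1:0]` survives unchanged and can later be linked to its parameter.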

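**Operation tagging:** `classify_operation()` assigns one coarse label per expression (e.g., `AND`, `OR`, `ADD`, `ASSIGN`) by testing for operators in a fixed order, so the first match wins. The label is stored on the operation node in the KG.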
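
The order of the checks decides the label for expressions that mix operators. The sketch below is a trimmed copy of the operator checks in `classify_operation()` (the assignment checks are left out), written as a rule table. `RULES` and `first_matching_label` are names invented for this example.

```python
# Ordered (label, test) pairs; a trimmed copy of the checks in classify_operation.
RULES = [
    ('AND', lambda e: '&' in e and '&&' not in e),
    ('OR', lambda e: '|' in e and '||' not in e),
    ('XOR', lambda e: '^' in e),
    ('ADD', lambda e: '+' in e),
    ('SUBTRACT', lambda e: '-' in e and '->' not in e),
    ('LSHIFT', lambda e: '<<' in e),
    ('RSHIFT', lambda e: '>>' in e),
    ('NOT', lambda e: '~' in e),
]

def first_matching_label(expr):
    """Return the label of the first rule whose test passes, else 'UNKNOWN'."""
    for label, test in RULES:
        if test(expr):
            return label
    return 'UNKNOWN'

samples = ['a & b', '(a & b) | c', '~a ^ b', 'x + (y << 2)', 'a && b', 'sel ? a : b']
labels = {expr: first_matching_label(expr) for expr in samples}
for expr, label in labels.items():
    print(f'{expr!r:16} -> {label}')
```

A mixed expression such as `(a & b) | c` is tagged `AND` because that rule is tested first, and logical operators (`&&`, `||`) and ternaries fall through to `UNKNOWN`. The labels are therefore a coarse hint, not a full description of the expression.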

**Outputs:**
- `module_name`
- `input_ports`, `output_ports`, `signals`, `parameters`
- `operations`: one dict per operation with `id`, `type`, `target`, `expression`, `operands`, and `context` (`combinational`, `sequential`, or `structural` for instantiations)
- `ast` (when Pyverilog succeeds)

In [None]:
def preprocess_verilog(code):
    def replace_idx(match):
        idx = match.group(1)
        return f'IDX{idx}'
    code = re.sub(r'`IDX\((\d+)\)', replace_idx, code)
    return code

def classify_operation(expr):
    expr = expr.strip()
    if '&' in expr and '&&' not in expr:
        op_type = 'AND'
        operands = re.findall(r'\b[\w\[\]:]+\b', expr.replace('&', ' '))
    elif '|' in expr and '||' not in expr:
        op_type = 'OR'
        operands = re.findall(r'\b[\w\[\]:]+\b', expr.replace('|', ' '))
    elif '^' in expr:
        op_type = 'XOR'
        operands = re.findall(r'\b[\w\[\]:]+\b', expr.replace('^', ' '))
    elif '+' in expr:
        op_type = 'ADD'
        operands = re.findall(r'\b[\w\[\]:]+\b', expr.replace('+', ' '))
    elif '-' in expr and '->' not in expr:
        op_type = 'SUBTRACT'
        operands = re.findall(r'\b[\w\[\]:]+\b', expr.replace('-', ' '))
    elif '<<' in expr:
        op_type = 'LSHIFT'
        operands = re.findall(r'\b[\w\[\]:]+\b', expr.replace('<<', ' '))
    elif '>>' in expr:
        op_type = 'RSHIFT'
        operands = re.findall(r'\b[\w\[\]:]+\b', expr.replace('>>', ' '))
    elif '~' in expr:
        op_type = 'NOT'
        operands = re.findall(r'\b[\w\[\]:]+\b', expr.replace('~', ' '))
    elif '<=' in expr:
        op_type = 'NON_BLOCKING_ASSIGN'
        operands = re.findall(r'\b[\w\[\]:]+\b', expr)
    elif '=' in expr and '==' not in expr:
        op_type = 'ASSIGN'
        operands = re.findall(r'\b[\w\[\]:]+\b', expr)
    else:
        op_type = 'UNKNOWN'
        operands = re.findall(r'\b[\w\[\]:]+\b', expr)
    return op_type, operands

def parse_verilog_code(code, temp_file='temp.v'):
    module_name = None
    input_ports = []
    output_ports = []
    signals = []
    parameters = []
    operations = []
    ast = None
    header_ports = []
    port_directions = {}

    code = preprocess_verilog(code)
    with open(temp_file, 'w') as f:
        f.write(code)

    try:
        f = io.StringIO()
        with redirect_stderr(f):
            ast, _ = parse([temp_file], preprocess_include=['/content/verilog/'], debug=False)
        if isinstance(ast.description, Description):
            for node in ast.description.definitions:
                if isinstance(node, ModuleDef):
                    module_name = node.name
                    if node.portlist:
                        for port in node.portlist.ports:
                            if isinstance(port, Ioport) and hasattr(port.first, 'name'):
                                port_name = port.first.name
                                width = '1'
                                if hasattr(port.first, 'width') and port.first.width:
                                    width = f'[{port.first.width.msb}:{port.first.width.lsb}]'
                                if isinstance(port.first, vparser.Input):
                                    input_ports.append((port_name, width))
                                elif isinstance(port.first, vparser.Output):
                                    output_ports.append((port_name, width))
                            elif isinstance(port, Port) and hasattr(port, 'name'):
                                header_ports.append((port.name, '1'))
    except Exception as e:
        print(f'Pyverilog parsing failed: {str(e)}')

    try:
        lines = code.splitlines()
        module_found = False
        in_module_decl = False
        port_section = []
        i = 0
        always_context = None

        while i < len(lines):
            line = lines[i].strip()
            if line.startswith('//') or not line:
                i += 1
                continue
            if line.startswith('module'):
                candidate_name = line.split()[1].split('(')[0].strip()
                module_name = candidate_name
                module_found = True
                if '(' in line:
                    in_module_decl = True
                    start_idx = line.index('(') + 1
                    if ')' in line:
                        port_section.append(line[start_idx:line.index(')')])
                        in_module_decl = False
                    else:
                        port_section.append(line[start_idx:])
                i += 1
                continue
            if in_module_decl:
                if ')' in line:
                    port_section.append(line[:line.index(')')])
                    in_module_decl = False
                else:
                    port_section.append(line)
                i += 1
                continue
            if module_found:
                port_match = re.match(r'^(input|output|inout)\s*(wire|reg)?\s*(\[[\w\-`]+:[0-9]+\])?\s*([\w,\s]+)\s*[,;]', line)
                if port_match:
                    direction = port_match.group(1)
                    width = port_match.group(3) if port_match.group(3) else '1'
                    port_names = [p.strip() for p in port_match.group(4).split(',') if p.strip()]
                    for port_name in port_names:
                        port_directions[port_name] = direction
                        if direction == 'input' and port_name not in [p[0] for p in input_ports]:
                            input_ports.append((port_name, width))
                        elif direction == 'output' and port_name not in [p[0] for p in output_ports]:
                            output_ports.append((port_name, width))
                signal_match = re.match(r'^(wire|reg)\s*(\[[\w\-`]+:[0-9]+\])?\s*([\w,\s]+)\s*;', line)
                if signal_match:
                    signal_type = signal_match.group(1)
                    width = signal_match.group(2) if signal_match.group(2) else '1'
                    signal_names = [s.strip() for s in signal_match.group(3).split(',')]
                    for signal_name in signal_names:
                        if signal_name not in [p[0] for p in input_ports + output_ports]:
                            signals.append((signal_name, signal_type, width))
                param_match = re.match(r'^parameter\s+(.+?);', line)
                if param_match:
                    param_str = param_match.group(1).strip()
                    param_pairs = re.split(r',\s*(?=\w+\s*=)', param_str)
                    for pair in param_pairs:
                        pair_match = re.match(r'(\w+)\s*=\s*([^,\s]+(?:\s*[^,\s]+)*)', pair.strip())
                        if pair_match:
                            param_name = pair_match.group(1).strip()
                            param_value = pair_match.group(2).strip()
                            print(f'Parsed parameter: name={param_name}, value={param_value}')
                            parameters.append((param_name, param_value))
                inst_match = re.match(r'^(\w+)\s+(\w+)\s*\(([^)]+)\);', line)
                if inst_match:
                    module_type = inst_match.group(1)
                    instance_name = inst_match.group(2)
                    ports = [p.strip() for p in inst_match.group(3).split(',')]
                    operations.append({
                        'id': str(uuid.uuid4()),
                        'type': 'INSTANTIATION',
                        'target': instance_name,
                        'expression': f"{module_type}({', '.join(ports)})",
                        'operands': ports,
                        'context': 'structural'
                    })
                assign_match = re.match(r'^assign\s+([\w\[\]:]+)\s*=\s*([^;]+);', line)
                if assign_match:
                    target = assign_match.group(1)
                    expr = assign_match.group(2).strip()
                    op_type, operands = classify_operation(expr)
                    operations.append({
                        'id': str(uuid.uuid4()),
                        'type': op_type,
                        'target': target,
                        'expression': expr,
                        'operands': operands,
                        'context': 'combinational'
                    })
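                if line.startswith('always @'):
                    # Edge-triggered sensitivity lists (posedge or negedge) are sequential;
                    # everything else, including @(*), is treated as combinational.
                    if 'posedge' in line or 'negedge' in line:
                        always_context = 'sequential'
                    else:
                        always_context = 'combinational'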
                    i += 1
                    while i < len(lines) and not lines[i].strip().startswith('endmodule'):
                        stmt = lines[i].strip()
                        if stmt and not stmt.startswith('//'):
                            nb_assign_match = re.match(r'^([\w\[\]:]+)\s*<\=\s*([^;]+);', stmt)
                            if nb_assign_match:
                                target = nb_assign_match.group(1).strip()
                                expr = nb_assign_match.group(2).strip()
                                op_type, operands = classify_operation(expr)
                                operations.append({
                                    'id': str(uuid.uuid4()),
                                    'type': op_type,
                                    'target': target,
                                    'expression': expr,
                                    'operands': operands,
                                    'context': always_context
                                })
                            block_assign_match = re.match(r'^([\w\[\]:]+)\s*=\s*([^;]+);', stmt)
                            if block_assign_match:
                                target = block_assign_match.group(1).strip()
                                expr = block_assign_match.group(2).strip()
                                op_type, operands = classify_operation(expr)
                                operations.append({
                                    'id': str(uuid.uuid4()),
                                    'type': op_type,
                                    'target': target,
                                    'expression': expr,
                                    'operands': operands,
                                    'context': always_context
                                })
                        i += 1
                    continue
            i += 1

        if port_section:
            port_text = ' '.join(port_section).replace(';', ',')
            port_list = [p.strip() for p in port_text.split(',') if p.strip() and not p.strip().startswith('//')]
            for port in port_list:
                match = re.match(r'^(input|output|inout)?\s*(wire|reg)?\s*(\[[\w\-`]+:[0-9]+\])?\s*(\w+)', port)
                if match and match.group(1):
                    width = match.group(3) if match.group(3) else '1'
                    port_name = match.group(4)
                    direction = match.group(1)
                    port_directions[port_name] = direction
                    if direction == 'input' and port_name not in [p[0] for p in input_ports]:
                        input_ports.append((port_name, width))
                    elif direction == 'output' and port_name not in [p[0] for p in output_ports]:
                        output_ports.append((port_name, width))
                else:
                    header_ports.append((port, '1'))

        for port_name, width in header_ports:
            direction = port_directions.get(port_name, 'input')
            if direction == 'input' and port_name not in [p[0] for p in input_ports]:
                input_ports.append((port_name, width))
            elif direction == 'output' and port_name not in [p[0] for p in output_ports]:
                output_ports.append((port_name, width))

        if not module_found:
            print('No valid module found in code')
        else:
            print(f'Parsed module: {module_name}')

    except Exception as e:
        print(f'Heuristic parsing failed: {str(e)}')

    if os.path.exists(temp_file):
        os.remove(temp_file)

    return module_name, input_ports, output_ports, signals, parameters, operations, ast

## 6. Entity Extraction & RDF Knowledge Graph (RDF/Turtle)
**Why RDF:** graph structure lets us express relationships between **modules, parameters, signals, and operations**.

**Entities:** `Module`, `Signal`, `Parameter`, `Operation`.
**Relations:** `hasInput`, `hasOutput`, `hasInternalSignal`, `hasParameter`, `performsOperation`, `usesSignal`, `producesSignal`, `dependsOnParameter`, `usesParameter`, `instantiates`.

**Output:** a `.ttl` file per chunk in `knowledge_graphs/`, which can be inspected with standard RDF tools. 
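
How names map to graph identifiers can be sketched with the standard library alone. The helper names `entity_uri` and `predicate_uri` are hypothetical. The namespace `http://example.org/hw#` and the `module_`/`signal_`/`param_` prefixes follow the cell below. Relationship records carry snake_case type strings (for example `depends_on_parameter`), while the relation names listed above are camelCase, so the sketch assumes a snake-to-camel conversion is what is intended.

```python
from urllib.parse import quote

BASE = 'http://example.org/hw#'  # namespace used for the KG in this notebook

def entity_uri(kind, name):
    """Build '<BASE><kind>_<percent-encoded name>', e.g. for kind 'signal'."""
    return f'{BASE}{kind}_{quote(name)}'

def predicate_uri(rel_type):
    """Map a snake_case relationship type to a camelCase predicate URI."""
    head, *tail = rel_type.split('_')
    return BASE + head + ''.join(part.capitalize() for part in tail)

module_uri = entity_uri('module', 'adder8')
signal_uri = entity_uri('signal', 'data[3]')  # brackets are percent-encoded
pred_uri = predicate_uri('depends_on_parameter')
print(module_uri)
print(signal_uri)
print(pred_uri)
```

Percent-encoding keeps names that contain brackets or other reserved characters usable as URI fragments, while the original name stays available through `rdfs:label`.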

In [None]:
def extract_entities(module_name, input_ports, output_ports, signals, parameters, operations, ast=None):
    modules = [{
        'name': module_name,
        'input_ports': [{'name': name, 'direction': 'input', 'width': width} for name, width in input_ports],
        'output_ports': [{'name': name, 'direction': 'output', 'width': width} for name, width in output_ports],
        'signals': [{'name': name, 'type': s_type, 'width': width} for name, s_type, width in signals],
        'parameters': [{'name': name, 'value': value} for name, value in parameters],
        'operations': operations
    }]

    signal_dict = {}
    for port in modules[0]['input_ports'] + modules[0]['output_ports'] + modules[0]['signals']:
        signal_dict[port['name']] = {
            'width': port['width'],
            'type': port.get('type', 'wire' if port.get('direction') in ['input', 'output'] else port.get('type')),
            'module': module_name,
            'direction': port.get('direction', 'internal')
        }

    param_dict = {param['name']: {
        'value': param['value'],
        'module': module_name
    } for param in modules[0]['parameters']}

    operation_dict = {op['id']: {
        'type': op['type'],
        'target': op['target'],
        'expression': op['expression'],
        'operands': op['operands'],
        'context': op['context'],
        'module': module_name
    } for op in operations}

    relationships = []
    for op in operations:
        op_id = op['id']
        for operand in op['operands']:
            operand_clean = re.sub(r'\[\d+:\d+\]', '', operand)
            if operand_clean in signal_dict:
                relationships.append({
                    'source': f"operation_{urllib.parse.quote(op_id)}",
                    'target': f"signal_{urllib.parse.quote(operand_clean)}",
                    'type': 'uses_signal'
                })
        target_clean = re.sub(r'\[\d+:\d+\]', '', op['target'])
        if target_clean in signal_dict:
            relationships.append({
                'source': f"operation_{urllib.parse.quote(op_id)}",
                'target': f"signal_{urllib.parse.quote(target_clean)}",
                'type': 'produces_signal'
            })
        for param_name in param_dict:
            if param_name in op['expression']:
                relationships.append({
                    'source': f"operation_{urllib.parse.quote(op_id)}",
                    'target': f"param_{urllib.parse.quote(param_name)}",
                    'type': 'depends_on_parameter'
                })
        if op['type'] == 'INSTANTIATION':
            module_type = op['expression'].split('(')[0]
            relationships.append({
                'source': f"module_{urllib.parse.quote(module_name)}",
                'target': f"module_{urllib.parse.quote(module_type)}",
                'type': 'instantiates'
            })

    for signal_name, signal_info in signal_dict.items():
        for param_name in param_dict:
            if param_name in signal_info['width']:
                relationships.append({
                    'source': f"signal_{urllib.parse.quote(signal_name)}",
                    'target': f"param_{urllib.parse.quote(param_name)}",
                    'type': 'uses_parameter'
                })

    return modules, signal_dict, param_dict, operation_dict, relationships

def create_knowledge_graph(modules, signals, parameters, operations, relationships, output_file):
    g = Graph()
    EX = Namespace('http://example.org/hw#')
    g.bind('ex', EX)

    g.add((EX.Module, RDF.type, RDFS.Class))
    g.add((EX.Signal, RDF.type, RDFS.Class))
    g.add((EX.Parameter, RDF.type, RDFS.Class))
    g.add((EX.Operation, RDF.type, RDFS.Class))
    g.add((EX.hasInput, RDF.type, RDF.Property))
    g.add((EX.hasOutput, RDF.type, RDF.Property))
    g.add((EX.hasInternalSignal, RDF.type, RDF.Property))
    g.add((EX.hasParameter, RDF.type, RDF.Property))
    g.add((EX.performsOperation, RDF.type, RDF.Property))
    g.add((EX.hasExpression, RDF.type, RDF.Property))
    g.add((EX.usesSignal, RDF.type, RDF.Property))
    g.add((EX.producesSignal, RDF.type, RDF.Property))
    g.add((EX.dependsOnParameter, RDF.type, RDF.Property))
    g.add((EX.usesParameter, RDF.type, RDF.Property))
    g.add((EX.instantiates, RDF.type, RDF.Property))

    for module in modules:
        module_uri = EX[f"module_{urllib.parse.quote(module['name'])}"]
        g.add((module_uri, RDF.type, EX.Module))
        g.add((module_uri, RDFS.label, Literal(module['name'])))

        for port in module['input_ports']:
            signal_uri = EX[f"signal_{urllib.parse.quote(port['name'])}"]
            g.add((signal_uri, RDF.type, EX.Signal))
            g.add((signal_uri, RDFS.label, Literal(port['name'])))
            g.add((signal_uri, EX.width, Literal(port['width'])))
            g.add((signal_uri, EX.direction, Literal('input')))
            g.add((module_uri, EX.hasInput, signal_uri))

        for port in module['output_ports']:
            signal_uri = EX[f"signal_{urllib.parse.quote(port['name'])}"]
            g.add((signal_uri, RDF.type, EX.Signal))
            g.add((signal_uri, RDFS.label, Literal(port['name'])))
            g.add((signal_uri, EX.width, Literal(port['width'])))
            g.add((signal_uri, EX.direction, Literal('output')))
            g.add((module_uri, EX.hasOutput, signal_uri))

        for signal_name, signal_info in signals.items():
            if signal_info['module'] == module['name'] and signal_info['direction'] == 'internal':
                signal_uri = EX[f"signal_{urllib.parse.quote(signal_name)}"]
                g.add((signal_uri, RDF.type, EX.Signal))
                g.add((signal_uri, RDFS.label, Literal(signal_name)))
                g.add((signal_uri, EX.width, Literal(signal_info['width'])))
                g.add((signal_uri, EX.signalType, Literal(signal_info['type'])))
                g.add((signal_uri, EX.direction, Literal('internal')))
                g.add((module_uri, EX.hasInternalSignal, signal_uri))

        for param_name, param_info in parameters.items():
            if param_info['module'] == module['name']:
                param_uri = EX[f"param_{urllib.parse.quote(param_name)}"]
                g.add((param_uri, RDF.type, EX.Parameter))
                g.add((param_uri, RDFS.label, Literal(param_name)))
                g.add((param_uri, EX.value, Literal(param_info['value'])))
                g.add((module_uri, EX.hasParameter, param_uri))

        for op_id, op_info in operations.items():
            if op_info['module'] == module['name']:
                op_uri = EX[f"operation_{urllib.parse.quote(op_id)}"]
                g.add((op_uri, RDF.type, EX.Operation))
                g.add((op_uri, RDFS.label, Literal(op_info['type'])))
                g.add((op_uri, EX.target, Literal(op_info['target'])))
                g.add((op_uri, EX.hasExpression, Literal(op_info['expression'])))
                g.add((op_uri, EX.context, Literal(op_info['context'])))
                for operand in op_info['operands']:
                    operand_clean = re.sub(r'\[\d+:\d+\]', '', operand)
                    if operand_clean in signals:
                        signal_uri = EX[f"signal_{urllib.parse.quote(operand_clean)}"]
                        g.add((op_uri, EX.usesSignal, signal_uri))
                g.add((module_uri, EX.performsOperation, op_uri))

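    for rel in relationships:
        source_uri = EX[rel['source']]
        target_uri = EX[rel['target']]
        # Convert snake_case ('uses_signal') to the declared camelCase property ('usesSignal').
        head, *tail = rel['type'].split('_')
        rel_type = EX[head + ''.join(part.capitalize() for part in tail)]
        g.add((source_uri, rel_type, target_uri))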

    g.serialize(destination=output_file, format='turtle')
    print(f'Knowledge graph saved to {output_file}')

## 7. CSV Processing & Embeddings
**Expected CSV schema:**
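- If your data uses the Llama 2-style instruction format, the `text` column contains cells shaped like `<s>[INST]...[/INST] ...</s>`.
  - The text between `[INST]` and `[/INST]` becomes `instruction`; everything after `[/INST]` (up to `</s>`) becomes `code`.
- If a cell contains a standalone Verilog module instead, the function extracts the first `module ... endmodule` block and assigns it the placeholder instruction `Process module code`.
- Cells are read from the `text` column if present, otherwise from `code`, otherwise from the first column.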
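
The two cell formats can be tried on in-memory strings before running the full ingestion. The sketch below reuses the two regular expressions from `process_csv`; the helper name `split_cell` and the sample cells are made up for illustration.

```python
import re

INST_RE = re.compile(r'<s>\[INST\](.*?)\[/INST\](.*)</s>', re.DOTALL)
MODULE_RE = re.compile(r'module\s+(.*?)\s+endmodule', re.DOTALL)

def split_cell(cell):
    """Return (instruction, code) for one CSV cell, or None if nothing matches."""
    inst = INST_RE.match(cell)
    if inst:
        return inst.group(1).strip(), inst.group(2).strip()
    module = MODULE_RE.search(cell)
    if module:
        # Bare modules get the same placeholder instruction the notebook uses.
        return 'Process module code', module.group(0).strip()
    return None

tagged = (
    '<s>[INST] Write a 2-input AND gate. [/INST] '
    'module and2(input a, input b, output y);\n'
    'assign y = a & b;\n'
    'endmodule </s>'
)
bare = '// helper\nmodule inv(input a, output y);\nassign y = ~a;\nendmodule\n'

print(split_cell(tagged))
print(split_cell(bare))
print(split_cell('no verilog here'))
```

Note that the tagged pattern is anchored at the start of the cell, so leading whitespace before `<s>` sends a cell down the bare-module path instead.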

**Embeddings:**
- For each chunk, build a text payload: *Instruction + Code + Summary*.
- Create embeddings with `text-embedding-3-small` and write to ChromaDB with the full document and metadata.
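
The payload layout and the handling of failed embeddings can be sketched without calling the API. `build_payload` and `drop_placeholders` are hypothetical helper names, and the vectors below are made-up values. The payload format and the all-zero placeholder of length 1536 mirror the cell that follows.

```python
EMBEDDING_DIM = 1536  # length of the zero placeholder used when an embedding call fails

def build_payload(chunk):
    """Concatenate instruction, code, and summary in the order used for embedding."""
    return (
        f"Instruction: {chunk['instruction']}\n"
        f"Code:\n{chunk['code']}\n"
        f"Summary:\n{chunk['summary']}"
    )

def drop_placeholders(chunks, embeddings):
    """Keep only (chunk, embedding) pairs whose vector is not all zeros."""
    return [(c, e) for c, e in zip(chunks, embeddings) if any(x != 0 for x in e)]

chunks = [
    {'id': 'c1', 'instruction': 'Build an inverter',
     'code': 'module inv; endmodule', 'summary': 'Inverts a.'},
    {'id': 'c2', 'instruction': 'Build a buffer',
     'code': 'module buf1; endmodule', 'summary': 'Copies a.'},
]
# Hypothetical vectors: a made-up non-zero embedding and a failed-call placeholder.
embeddings = [[0.1] * EMBEDDING_DIM, [0] * EMBEDDING_DIM]

kept = drop_placeholders(chunks, embeddings)
payload = build_payload(chunks[0])
print(payload)
print([c['id'] for c, _ in kept])
```

Filtering placeholders before insertion keeps chunks with failed embeddings out of the index, where an all-zero vector would carry no similarity information.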

**Chroma metadata fields:** `id`, `instruction`, `summary`, `row_index`, `knowledge_graph` (path to TTL), `module_name`.

In [None]:
def process_csv(csv_file):
    chunks = []
    with open(csv_file, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        print(f'CSV Headers: {reader.fieldnames}')
        for row in reader:
            if 'text' in row:
                chunk = row['text']
            elif 'code' in row:
                chunk = row['code']
            elif reader.fieldnames:
                chunk = row[reader.fieldnames[0]]
            else:
                print(f'Row {reader.line_num}: No valid column found')
                continue

            inst_match = re.match(r'<s>\[INST\](.*?)\[/INST\](.*)</s>', chunk, re.DOTALL)
            if inst_match:
                instruction = inst_match.group(1).strip()
                code = inst_match.group(2).strip()
                chunks.append({
                    'id': str(uuid.uuid4()),
                    'instruction': instruction,
                    'code': code,
                    'row_index': reader.line_num - 1
                })
            else:
                module_match = re.search(r'module\s+(.*?)\s+endmodule', chunk, re.DOTALL)
                if module_match:
                    code = module_match.group(0).strip()
                    chunks.append({
                        'id': str(uuid.uuid4()),
                        'instruction': 'Process module code',
                        'code': code,
                        'row_index': reader.line_num - 1
                    })
                else:
                    print(f'Invalid chunk format in row {reader.line_num}: {chunk[:50]}...')
    return chunks

def get_code_embedding(code):
    try:
        response = openai_client.embeddings.create(
            input=code,
            model='text-embedding-3-small'
        )
        return response.data[0].embedding
    except Exception as e:
        logging.error(f'Error generating embedding: {str(e)}')
        return None

def generate_code_embeddings(code_chunks):
    embeddings = []
    for chunk in code_chunks:
        text = f"Instruction: {chunk['instruction']}\nCode:\n{chunk['code']}\nSummary:\n{chunk['summary']}"
        embedding = get_code_embedding(text)
        if embedding:
            embeddings.append(embedding)
        else:
            embeddings.append([0] * 1536)
            print(f"Failed to generate embedding for chunk {chunk['id']}")
    return embeddings

def store_in_chroma(chunks, embeddings, chroma_path, collection_name='verilog_modules'):
    client_ch = chromadb.PersistentClient(path=chroma_path)
    try:
        client_ch.delete_collection(collection_name)
    except Exception:
        # The collection does not exist on the first run; nothing to delete.
        pass
    collection = client_ch.create_collection(collection_name)

    valid_chunks = []
    valid_embeddings = []
    for chunk, emb in zip(chunks, embeddings):
        if not all(x == 0 for x in emb):
            valid_chunks.append(chunk)
            valid_embeddings.append(emb)

    if valid_chunks:
        collection.add(
            embeddings=valid_embeddings,
            documents=[f"Instruction: {chunk['instruction']}\nCode:\n{chunk['code']}\nSummary:\n{chunk['summary']}" for chunk in valid_chunks],
            metadatas=[{
                'id': chunk['id'],
                'instruction': chunk['instruction'],
                'summary': chunk['summary'],
                'row_index': chunk['row_index'],
                'knowledge_graph': chunk['knowledge_graph'],
                'module_name': chunk.get('module_name', '')
            } for chunk in valid_chunks],
            ids=[chunk['id'] for chunk in valid_chunks]
        )
    return collection

## 8. Build KG + Vector Store (End-to-End for Ingestion)
**What happens in this section:**
1. Parse each CSV row into chunks.
2. Extract Verilog structure.
3. Summarize with LLM (short textual description per module).
4. Write RDF graphs to Turtle files.
5. Compute embeddings and persist a Chroma collection.

**Artifacts written:**
- `chunk_metadata.json`: consolidated, per-chunk metadata for downstream retrieval.
- `knowledge_graphs/kg_<id>.ttl`: RDF graph files.
- `verilog_chroma_db/`: a persistent ChromaDB collection.
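
For downstream use, `chunk_metadata.json` is a JSON list with one record per successfully parsed chunk. The sketch below writes and reloads a single hypothetical record in a temporary directory. The keys match what `main_kg` stores, while the values (including the relative graph path) are invented for illustration.

```python
import json
import os
import tempfile

# Hypothetical record with the same keys main_kg writes for each parsed chunk.
sample = [{
    'chunk_id': 'c1',
    'module_name': 'and2',
    'knowledge_graph_path': 'knowledge_graphs/kg_c1.ttl',
    'row_index': 1,
    'instruction': 'Write a 2-input AND gate.',
    'summary': 'Combinational AND of a and b.',
    'code': 'module and2(input a, input b, output y); assign y = a & b; endmodule',
}]

with tempfile.TemporaryDirectory() as workdir:
    path = os.path.join(workdir, 'chunk_metadata.json')
    with open(path, 'w') as f:
        json.dump(sample, f, indent=2)
    with open(path) as f:
        records = json.load(f)

# Index by module name for quick lookup of a module's graph file.
by_module = {rec['module_name']: rec for rec in records}
print(by_module['and2']['knowledge_graph_path'])
```

If several chunks define a module with the same name, the dictionary keeps only the last one; indexing by `chunk_id` avoids that collision.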

In [None]:
def main_kg(csv_file='extracted_texts1.csv'):
    chroma_path = os.path.join(PERSIST_DIRECTORY, 'verilog_chroma_db')
    metadata_file = os.path.join(PERSIST_DIRECTORY, 'chunk_metadata.json')

    chunks = process_csv(csv_file)
    print(f'Processed {len(chunks)} chunks from {csv_file}')

    # Create the Turtle output directory up front so graph files can be written.
    kg_dir = os.path.join(PERSIST_DIRECTORY, 'knowledge_graphs')
    os.makedirs(kg_dir, exist_ok=True)

    metadata = []
    parsed_chunks = []  # chunks that parsed successfully and carry summary + KG path
    for chunk in chunks:
        module_name, input_ports, output_ports, signals, parameters, operations, ast = parse_verilog_code(chunk['code'])
        if not module_name:
            print(f"Failed to parse module for chunk {chunk['id']}")
            continue
        chunk['module_name'] = module_name
        chunk['summary'] = summarize_code(chunk['code'], chunk['instruction'])
        modules, signals_dict, param_dict, operation_dict, relationships = extract_entities(
            module_name, input_ports, output_ports, signals, parameters, operations, ast
        )
        kg_file = os.path.join(kg_dir, f"kg_{chunk['id']}.ttl")
        create_knowledge_graph(modules, signals_dict, param_dict, operation_dict, relationships, kg_file)
        chunk['knowledge_graph'] = kg_file
        parsed_chunks.append(chunk)
        metadata.append({
            'chunk_id': chunk['id'],
            'module_name': module_name,
            'knowledge_graph_path': kg_file,
            'row_index': chunk['row_index'],
            'instruction': chunk['instruction'],
            'summary': chunk['summary'],
            'code': chunk['code']
        })

    with open(metadata_file, 'w') as f:
        json.dump(metadata, f, indent=2)
    print(f'Metadata saved to {metadata_file}')

    # Embed and store only parsed chunks: store_in_chroma reads chunk['summary'] and
    # chunk['knowledge_graph'], which chunks that failed to parse do not have.
    embeddings = generate_code_embeddings(parsed_chunks)
    print(f'Generated embeddings for {len(embeddings)} chunks')

    collection = store_in_chroma(parsed_chunks, embeddings, chroma_path)
    print(f'Chroma DB saved to {chroma_path}')

## 9. RTL Generator (Vector Retrieval → Prompt → LLM)
**Goal:** produce large, synthesizable Verilog based on the query. The prompt includes:
- Retrieved context (documents and KG metadata) for signal/operation awareness.
- Explicit requirements hard-coded in the prompt template (e.g., pipelining, FSMs, and a 500-line minimum).

**Retrieval fields:**
- `documents`: full text used for semantic similarity
- `metadatas`: `id`, `instruction`, `summary`, `row_index`, `knowledge_graph` (path to the Turtle file), and `module_name`

**Output:** printed RTL code block for manual review and copy.
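
`collection.query` returns one inner list per query embedding, which is why the code below indexes `results['documents'][0]`. The sketch here uses a hand-written dictionary in that shape (no Chroma call is made; the helper `iter_hits` and the sample values are hypothetical) to show how documents and metadata are paired:

```python
def iter_hits(results, query_index=0):
    """Yield (document, metadata) pairs for one query from a Chroma-style result dict."""
    documents = results.get('documents') or [[]]
    metadatas = results.get('metadatas') or [[]]
    if query_index >= len(documents):
        return
    for doc, meta in zip(documents[query_index], metadatas[query_index]):
        yield doc, meta

# Hand-written stand-in for a query result with a single query embedding.
fake_results = {
    'ids': [['chunk_0', 'chunk_1']],
    'documents': [[
        'Instruction: Design an adder.\nCode:\nmodule adder; endmodule',
        'Instruction: Design a FIFO.\nCode:\nmodule fifo; endmodule',
    ]],
    'metadatas': [[
        {'id': 'chunk_0', 'module_name': 'adder', 'knowledge_graph': 'kg_chunk_0.ttl'},
        {'id': 'chunk_1', 'module_name': 'fifo', 'knowledge_graph': 'kg_chunk_1.ttl'},
    ]],
}

hit_names = [meta['module_name'] for _, meta in iter_hits(fake_results)]
print(hit_names)
```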

In [None]:
import tiktoken  # required by count_tokens below
from rdflib import Graph, URIRef

def count_tokens(text, model='gpt-4o'):
    enc = tiktoken.encoding_for_model(model)
    return len(enc.encode(text))

def load_chunks_and_chroma(metadata_file, chroma_path, collection_name='verilog_modules'):
    with open(metadata_file, 'r') as f:
        chunks = json.load(f)
    client = chromadb.PersistentClient(path=chroma_path)
    collection = client.get_collection(collection_name)
    return chunks, collection

def get_code_embedding(code):
    response = openai_client.embeddings.create(input=code, model='text-embedding-3-small')
    return response.data[0].embedding

def query_vector_db(query_text, collection, n_results=10):
    query_embedding = get_code_embedding(query_text)
    return collection.query(query_embeddings=[query_embedding], n_results=n_results)

def get_module_info(g, module_name):
    uri = URIRef(f"http://example.org/hw#module_{urllib.parse.quote(module_name)}")
    if not any(s == uri for s, _, _ in g):
        return None
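    def run(q):
        results = g.query(q)
        # Use plain-string keys and keep unbound OPTIONAL variables as None
        # rather than the literal string 'None'.
        return [{str(var): (str(row[var]) if row[var] is not None else None) for var in results.vars}
                for row in results]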
    qstr = lambda x: x % urllib.parse.quote(module_name)
    return {
        'inputs': run(qstr("PREFIX ex: <http://example.org/hw#> PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> SELECT ?signal ?label ?width WHERE { ex:module_%s ex:hasInput ?signal . ?signal rdfs:label ?label . OPTIONAL { ?signal ex:width ?width . } }")),
        'outputs': run(qstr("PREFIX ex: <http://example.org/hw#> PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> SELECT ?signal ?label ?width WHERE { ex:module_%s ex:hasOutput ?signal . ?signal rdfs:label ?label . OPTIONAL { ?signal ex:width ?width . } }")),
        'signals': run(qstr("PREFIX ex: <http://example.org/hw#> PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> SELECT ?signal ?label ?width ?type WHERE { ex:module_%s ex:hasInternalSignal ?signal . ?signal rdfs:label ?label . OPTIONAL { ?signal ex:width ?width . } OPTIONAL { ?signal ex:signalType ?type . } }"))[:5],
        'parameters': run(qstr("PREFIX ex: <http://example.org/hw#> PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> SELECT ?param ?label ?value WHERE { ex:module_%s ex:hasParameter ?param . ?param rdfs:label ?label . ?param ex:value ?value . }")),
        'operations': run(qstr("PREFIX ex: <http://example.org/hw#> PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> SELECT ?operation ?label ?target ?expression WHERE { ex:module_%s ex:performsOperation ?operation . ?operation rdfs:label ?label . ?operation ex:target ?target . ?operation ex:hasExpression ?expression . }"))[:5]
    }

def construct_llm_prompt(query, results, chunks):
    prompt_template = """
# Verilog Module Generation
**Query**: {query}

**Context**: {module_details}

**Task**:
Generate a **full synthesizable Verilog RTL module** with **at least 500 lines**. You MUST:
- Use deep pipelining with at least 3 pipeline stages for every major block.
- Include multiple FSMs with at least 5 states each, documented inline.
- Expand register files to multiple arrays with hundreds of signals.
- Unroll loops for RAM initialization and complex data paths.
- Add extra case statements and conditionals to reach line count.
- Repeat non-critical always blocks with slight variations to pad size.
- Use meaningful signal names and thorough comments for each section.
Output only a single code block: `Verilog Module 500+ Lines`.
"""
    module_details = ''
    for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
        chunk = next((c for c in chunks if c['chunk_id'] == meta['id']), None)
        if not chunk:
            continue
        print(f"Processing module: {meta['module_name']} | Chunk ID: {meta['id']} | KG: {meta['knowledge_graph']}")
        g = Graph()
        g.parse(meta['knowledge_graph'], format='turtle')
        module_info = get_module_info(g, meta['module_name'])
        module_details += f"Module: {meta['module_name']}\nInstruction: {chunk.get('instruction','')}\nSummary: {chunk.get('summary','')}\nMetadata: {json.dumps(module_info, indent=2)}\n"
    prompt = prompt_template.format(query=query, module_details=module_details)
    return prompt

def call_llm(prompt):
    response = openai_client.chat.completions.create(
        model='gpt-4o',
        messages=[{'role': 'user', 'content': prompt}],
        max_tokens=16000
    )
    return response.choices[0].message.content

def query_llm_rtl(query, chunks, collection):
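    results = query_vector_db(query, collection)
    # The result dict is always truthy, so check that at least one hit came back.
    if results and results.get('ids') and results['ids'][0]: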
        prompt = construct_llm_prompt(query, results, chunks)
        output = call_llm(prompt)
        print('\nGenerated Large RTL Verilog:\n', output if output else 'No response.')
    else:
        print('No DB results.')

def main_rtl():
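    try:
        chunks, collection = load_chunks_and_chroma(
            os.path.join(PERSIST_DIRECTORY, 'chunk_metadata.json'),
            os.path.join(PERSIST_DIRECTORY, 'verilog_chroma_db'))
    except Exception as e:
        # load_chunks_and_chroma raises if the JSON file or the collection is missing.
        print(f'Load failed: {e}')
        return
    if not chunks:
        print('Load failed.')
        return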
    print(f'Loaded {len(chunks)} chunks.')
    while True:
        q = input("Query (or 'exit'): ")
        if q.strip().lower() == 'exit':
            break
        query_llm_rtl(q, chunks, collection)

## 10. SVA Generator (Vector Retrieval → Prompt → LLM)
**Goal:** generate a standalone, synthesizable Verilog module and embed **SystemVerilog immediate assertions**
that check operation correctness, bounds, input validity, and signal relations.

**Prompt contents:** retrieved context (documents + KG facts), a compact code example, and precise requirements.

**Output:** printed Verilog with assertions for manual review.
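
Because the model is asked to return a single code block, the printed response can be post-processed before review. A minimal sketch, assuming the response wraps the module in a Markdown fence (the helper `extract_code_block` is hypothetical and not part of the notebook):

```python
import re

FENCE = "`" * 3  # built programmatically so this example can itself sit inside a fence
FENCE_RE = re.compile(FENCE + r"[^\n]*\n(.*?)" + FENCE, re.DOTALL)

def extract_code_block(response_text):
    """Return the body of the first fenced block, or the stripped text if none is found."""
    match = FENCE_RE.search(response_text)
    return match.group(1).strip() if match else response_text.strip()

# Made-up response text in the shape the prompt asks for.
sample_response = (
    "Here is the module:\n"
    + FENCE + "verilog\n"
    + "module and2(input a, input b, output y);\n"
    + "  assign y = a & b;\n"
    + "  always @(*) assert (y == (a & b));\n"
    + "endmodule\n"
    + FENCE + "\n"
)

code = extract_code_block(sample_response)
print(code)
```

Falling back to the full text when no fence is found keeps the helper usable if the model ignores the formatting instruction.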

In [None]:
def load_chunks_and_chroma_sva(metadata_file, chroma_path, collection_name='verilog_modules'):
    try:
        with open(metadata_file, 'r') as f:
            chunks = json.load(f)
        client = chromadb.PersistentClient(path=chroma_path)
        collection = client.get_collection(collection_name)
        return chunks, collection
    except Exception as e:
        print(f'Failed to load chunks or Chroma collection: {e}')
        return None, None

def get_text_embedding_sva(code):
    try:
        response = openai_client.embeddings.create(
            input=code,
            model='text-embedding-3-small'
        )
        return response.data[0].embedding
    except Exception as e:
        logging.error(f'Error generating embedding: {str(e)}')
        return None

def query_vector_db_sva(query_text, collection, n_results=1):
    try:
        query_embedding = get_text_embedding_sva(query_text)
        if query_embedding:
            results = collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results
            )
            return results
        else:
            print('Failed to generate query embedding')
            return None
    except Exception as e:
        print(f'Query error: {e}')
        return None

def get_module_info_sva(g, module_name):
    module_uri = URIRef(f"http://example.org/hw#module_{urllib.parse.quote(module_name)}")
    if not any(s == module_uri for s, _, _ in g):
        print(f'No triples found for module {module_name}')
        return None

    input_query = """
    PREFIX ex: <http://example.org/hw#>
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    SELECT ?signal ?label ?width
    WHERE {
        ex:module_%s ex:hasInput ?signal .
        ?signal rdfs:label ?label .
        OPTIONAL { ?signal ex:width ?width . }
    }
    """ % urllib.parse.quote(module_name)

    output_query = """
    PREFIX ex: <http://example.org/hw#>
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    SELECT ?signal ?label ?width
    WHERE {
        ex:module_%s ex:hasOutput ?signal .
        ?signal rdfs:label ?label .
        OPTIONAL { ?signal ex:width ?width . }
    }
    """ % urllib.parse.quote(module_name)

    signal_query = """
    PREFIX ex: <http://example.org/hw#>
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    SELECT ?signal ?label ?width ?type
    WHERE {
        ex:module_%s ex:hasInternalSignal ?signal .
        ?signal rdfs:label ?label .
        OPTIONAL { ?signal ex:width ?width . }
        OPTIONAL { ?signal ex:signalType ?type . }
    }
    """ % urllib.parse.quote(module_name)

    parameter_query = """
    PREFIX ex: <http://example.org/hw#>
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    SELECT ?param ?label ?value
    WHERE {
        ex:module_%s ex:hasParameter ?param .
        ?param rdfs:label ?label .
        ?param ex:value ?value .
    }
    """ % urllib.parse.quote(module_name)

    operation_query = """
    PREFIX ex: <http://example.org/hw#>
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    SELECT ?operation ?label ?target ?expression
    WHERE {
        ex:module_%s ex:performsOperation ?operation .
        ?operation rdfs:label ?label .
        ?operation ex:target ?target .
        ?operation ex:hasExpression ?expression .
    }
    """ % urllib.parse.quote(module_name)

    def execute_query(query):
        results = g.query(query)
        output = []
        for row in results:
            row_dict = {}
            for var in results.vars:
                value = row[var]
                row_dict[str(var)] = str(value) if value is not None else None
            output.append(row_dict)
        return output

    try:
        inputs = execute_query(input_query)
        outputs = execute_query(output_query)
        signals = execute_query(signal_query)[:5]
        parameters = execute_query(parameter_query)
        operations = execute_query(operation_query)[:5]
    except Exception as e:
        print(f'Query error: {e}')
        return None

    # execute_query stores unbound OPTIONAL variables as None, so dict.get's default
    # never applies; use `or 'unknown'` to substitute the fallback.
    module_info = {
        'module': module_name,
        'inputs': [{'name': row['label'], 'width': row.get('width') or 'unknown'} for row in inputs],
        'outputs': [{'name': row['label'], 'width': row.get('width') or 'unknown'} for row in outputs],
        'signals': [{'name': row['label'], 'width': row.get('width') or 'unknown', 'type': row.get('type') or 'unknown'} for row in signals],
        'parameters': [{'name': row['label'], 'value': row['value']} for row in parameters],
        'operations': [{'type': row['label'], 'target': row['target'], 'expression': row['expression']} for row in operations]
    }
    return module_info

def construct_llm_prompt_sva(query, results, chunks):
    prompt_template = """
# Verilog Module Generation with Comprehensive Immediate Assertions
**Query**: {query}

**Context**: Relevant Verilog module details to guide the design.

{module_details}

**Task**:
- Generate a standalone, synthesizable Verilog module strictly adhering to the query.
- Ensure 100% syntactically correct Verilog code, compatible with EDA tools.
- Implement all logic directly, no external module instantiations.
- For case statements, define operation codes as `localparam` constants with numeric literals and use these constants in case items.
- Include concise comments for functionality, inputs, outputs, and logic.
- Include comprehensive SystemVerilog immediate assertions (`assert`) covering correctness, bounds, validity, and relationships.
- Output a single code block labeled `Verilog Module with Comprehensive Immediate Assertions`.
"""

    module_details = ''
    max_code_len = 300
    max_summary_len = 100
    for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
        chunk_id = meta['id']
        chunk = next((c for c in chunks if c['chunk_id'] == chunk_id), None)
        if not chunk:
            continue
        kg_file = meta['knowledge_graph']
        g = Graph()
        try:
            g.parse(kg_file, format='turtle')
        except Exception as e:
            print(f'Failed to parse {kg_file}: {e}')
            continue
        module_name = meta['module_name']
        module_info = get_module_info_sva(g, module_name)
        instruction = chunk.get('instruction', meta.get('instruction', 'Unknown'))
        code = chunk.get('code', 'Code not available')
        summary = chunk.get('summary', meta.get('summary', 'Summary not available'))
        code = code[:max_code_len] + ('...' if len(code) > max_code_len else '')
        summary = summary[:max_summary_len] + ('...' if len(summary) > max_summary_len else '')
        module_details += f"Module: {module_name}\nInstruction: {instruction}\nSummary: {summary}\nCode:\n```verilog\n{code}\n```\n"
        if module_info:
            module_details += (
                f"Metadata:\n  Inputs: {json.dumps(module_info['inputs'], indent=2)}\n"
                f"  Outputs: {json.dumps(module_info['outputs'], indent=2)}\n"
                f"  Signals: {json.dumps(module_info['signals'], indent=2)}\n"
                f"  Parameters: {json.dumps(module_info['parameters'], indent=2)}\n"
                f"  Operations: {json.dumps(module_info['operations'], indent=2)}\n"
            )
        module_details += '\n---\n'
    return prompt_template.format(query=query, module_details=module_details)

def call_llm_sva(prompt):
    try:
        # Prompt tokens plus max_tokens must fit within the chosen model's context window.
        response = openai_client.chat.completions.create(
            model='gpt-4',
            messages=[{'role': 'user', 'content': prompt}],
            max_tokens=6000
        )
        return response.choices[0].message.content
    except Exception as e:
        logging.error(f'Error calling LLM: {str(e)}')
        return None

def query_llm_sva(query, chunks, collection, n_results=1):
    results = query_vector_db_sva(query, collection, n_results)
    if results:
        prompt = construct_llm_prompt_sva(query, results, chunks)
        llm_response = call_llm_sva(prompt)
        if llm_response:
            print('\nGenerated Verilog Module with Assertions:\n')
            print(llm_response)
        else:
            print('Failed to get LLM response')
    else:
        print('Failed to query vector database')

def main_sva():
    metadata_file = os.path.join(PERSIST_DIRECTORY, 'chunk_metadata.json')
    chroma_path = os.path.join(PERSIST_DIRECTORY, 'verilog_chroma_db')
    chunks, collection = load_chunks_and_chroma_sva(metadata_file, chroma_path)
    if not chunks or not collection:
        print('Failed to load chunks or Chroma collection')
        return
    print(f'Loaded {len(chunks)} chunks and Chroma collection')
    while True:
        q = input("Query (or 'exit'): ")
        if q.strip().lower() == 'exit':
            break
        query_llm_sva(q, chunks, collection)


## 11. Interactive Entry (`while True` loop) & Usage Notes
**How to use this entrypoint:**
1. Run the cell.
2. Enter a mode:
   - `1`: Upload a CSV and build the KG + vector store (needed once per dataset).
   - `2`: Start the **RTL** interactive loop; it keeps prompting for queries until you type `exit`.
   - `3`: Start the **SVA** interactive loop; it behaves the same way and also ends on `exit`.

**Troubleshooting tips:**
- If `Load failed` (mode 2) or `Failed to load chunks or Chroma collection` (mode 3) appears, ensure you completed mode 1 and that `chunk_metadata.json` and the `verilog_chroma_db/` collection exist under the persistent directory.
- If embeddings fail for a row, the code logs it and continues; those rows are skipped when adding to Chroma.
- If Pyverilog cannot parse a module, the regex fallback tries to recover ports/signals.
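
Before starting mode 2 or 3, it can help to confirm that the mode 1 artifacts are on disk. A minimal sketch, assuming the layout listed in Section 8 (the helper `missing_artifacts` is hypothetical; `persist_dir` stands in for the notebook's `PERSIST_DIRECTORY`):

```python
import os
import tempfile

def missing_artifacts(persist_dir):
    """Return the expected ingestion artifacts that are absent under persist_dir."""
    expected = [
        os.path.join(persist_dir, 'chunk_metadata.json'),
        os.path.join(persist_dir, 'knowledge_graphs'),
        os.path.join(persist_dir, 'verilog_chroma_db'),
    ]
    return [path for path in expected if not os.path.exists(path)]

# Demonstrate on a throwaway directory: everything is missing at first,
# and nothing is missing once the three artifacts are created.
with tempfile.TemporaryDirectory() as tmp:
    before = [os.path.basename(p) for p in missing_artifacts(tmp)]
    os.makedirs(os.path.join(tmp, 'knowledge_graphs'))
    os.makedirs(os.path.join(tmp, 'verilog_chroma_db'))
    with open(os.path.join(tmp, 'chunk_metadata.json'), 'w') as f:
        f.write('[]')
    after = missing_artifacts(tmp)

print(before, after)
```

This only checks that the paths exist; it does not verify that the Chroma directory actually contains the `verilog_modules` collection.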

**Reproducibility:**
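- LLM outputs are non-deterministic across runs. To reduce variance, pass `temperature=0` (and optionally a fixed `seed`) to the `chat.completions.create` calls in `call_llm` and `call_llm_sva`; this narrows, but does not eliminate, run-to-run differences.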
- All intermediate artifacts (TTL, JSON, Chroma) are written under the persistent directory so you can re-use them in later sessions.
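
As a sketch of the first point, the request arguments could be assembled in one place so that both call sites share the same sampling settings. The helper `build_chat_kwargs` and the seed value are hypothetical; `temperature` and `seed` are existing Chat Completions parameters, and `seed` is documented as best-effort rather than a guarantee:

```python
def build_chat_kwargs(prompt, model='gpt-4o', max_tokens=16000, seed=1234):
    """Assemble chat-completion arguments with low-variance sampling settings."""
    return {
        'model': model,
        'messages': [{'role': 'user', 'content': prompt}],
        'max_tokens': max_tokens,
        'temperature': 0,   # prefer the most likely token at each step
        'seed': seed,       # hypothetical fixed seed; best-effort reproducibility only
    }

kwargs = build_chat_kwargs('Design a 4-bit counter.')
# In the notebook this dictionary would be passed as:
#   openai_client.chat.completions.create(**kwargs)
print(kwargs['temperature'], kwargs['seed'])
```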

In [None]:
def main():
    print('Select mode:')
    print('  1) Build KG + Chroma from CSV (main_kg)')
    print('  2) RTL generation (main_rtl)')
    print('  3) SVA generation (main_sva)')
    mode = input('Enter 1/2/3: ').strip()
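    if mode == '1':
        # files.upload() is Colab-only; import it here so modes 2/3 do not depend on it.
        from google.colab import files
        uploaded = files.upload()
        if not uploaded:
            print('No file uploaded.')
            return
        csv_file = next(iter(uploaded))
        main_kg(csv_file)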
    elif mode == '2':
        main_rtl()
    elif mode == '3':
        main_sva()
    else:
        print('Invalid choice.')

if __name__ == '__main__':
    main()