## Setting Up Dependencies

In [23]:
import os
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import pyarrow.parquet as pq
import spacy
from torch.utils.data import DataLoader, Dataset

# Load NLP pipeline for query analysis
nlp = spacy.load("en_core_web_lg")

# Legal dictionary terms - expand as needed
LEGAL_TERMS = {
    "habeas corpus", "mens rea", "actus reus", "stare decisis", 
    "prima facie", "de novo", "res judicata", "certiorari",
    "statutory", "U.S.C.", "CFR", "jurisdiction", "adjudicate"
}

# Regex patterns for legal citations
CITATION_PATTERNS = [
    r'\d+\s+U\.S\.C\.\s+§*\s*\d+',  # US Code
    r'\d+\s+C\.F\.R\.\s+§*\s*\d+',   # Code of Federal Regulations
    r'[A-Za-z]+\s+v\.\s+[A-Za-z]+',  # Case names
]


## Loading Embeddings from Parquet

In [24]:
class LegalEmbeddingLoader:
    """Loads embeddings from parquet files for both models and all granularity levels."""
    
    def __init__(self, base_path):
        self.base_path = base_path
        self.gemini_embeddings = {}
        self.voyager_embeddings = {}
        self.metadata = {}
        
    def load_embeddings(self):
        """Load the six specified embedding files."""
        file_mappings = {
            "gemini_chapters": "embeddings_gemini_text-005_chapters_semchunk.parquet",
            "voyager_chapters": "embeddings_voyage_per_chapter_semchunked.parquet",
            "gemini_pages": "embeddings_gemini_text-005_pages_semchunk.parquet",
            "voyager_pages": "embeddings_voyage_per_pages_semchunked.parquet",
            "gemini_sections": "embeddings_gemini_text-005.parquet",
            "voyager_sections": "embeddings_voyage.parquet",
        }

        for key, file_name in file_mappings.items():
            print(self.base_path)
            file_path = os.path.join(self.base_path, key.split("_")[-1], file_name)
            print(file_path)
            if not os.path.exists(file_path):
                print(f"File {file_name} not found. Skipping...")
                continue

            # Read parquet file
            table = pq.read_table(file_path)
            df = table.to_pandas()
            print(f"\nColumns in {file_name}: {df.columns.tolist()}")
            # Extract embeddings and metadata
            embeddings = np.stack(df["Embedding"].values)

            # Determine model and granularity
            model, granularity = key.split("_")

            # Store embeddings
            if model == "gemini":
                self.gemini_embeddings[granularity] = torch.tensor(embeddings, dtype=torch.float32)
            else:
                self.voyager_embeddings[granularity] = torch.tensor(embeddings, dtype=torch.float32)

            # Store metadata
            self.metadata[key] = df.drop('Embedding', axis=1)

            print(f"Loaded {file_name} ({model} - {granularity})")
        return self.gemini_embeddings, self.voyager_embeddings, self.metadata

    def get_embedding_dimensions(self):
        """Return the dimensions of embeddings for both models."""
        gemini_dim = {k: v.shape[1] for k, v in self.gemini_embeddings.items()}
        voyager_dim = {k: v.shape[1] for k, v in self.voyager_embeddings.items()}
        return gemini_dim, voyager_dim

In [25]:
# # Assuming LegalEmbeddingLoader class has been defined as in the code you provided
# loader = LegalEmbeddingLoader(base_path="New_Embeddings_2025")
# gemini_embeddings, voyager_embeddings, metadata = loader.load_embeddings()

# # Get the embedding dimensions
# gemini_dim, voyager_dim = loader.get_embedding_dimensions()

# # Print the dimensions for both models
# print("Gemini Embedding Dimensions:")
# print(gemini_dim)

# print("\nVoyager Embedding Dimensions:")
# print(voyager_dim)


## Query Analysis and Intent Recognition

In [26]:
class LegalQueryAnalyzer:
    """Analyzes legal queries to determine intent and model weights."""
    
    def __init__(self, legal_terms=LEGAL_TERMS, citation_patterns=CITATION_PATTERNS, model_name='bert-base-uncased'):
        self.legal_terms = legal_terms
        self.citation_patterns = citation_patterns
        self.model_name = "sentence-transformers/all-mpnet-base-v2"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name)
        
    def analyze_query(self, query):
        """
        Analyze query characteristics to determine model weights.
        Returns a dictionary of features and recommended weights.
        """
        # Process with spaCy
        doc = nlp(query)
        
        # Feature extraction
        features = {
            'legal_term_density': self._calculate_legal_term_density(query),
            'citation_count': self._count_citations(query),
            'structural_complexity': self._assess_complexity(doc),
            'query_length': len(doc),
            'jurisdiction_signals': self._detect_jurisdiction(doc)
        }
        
        # Calculate recommended weights
        weights = self._determine_weights(features)
        
        return {
            'features': features,
            'weights': weights,
            'query_embedding': self._get_query_embedding(query)
        }
    
    def _calculate_legal_term_density(self, query):
        """Calculate the density of legal terminology in the query."""
        # Normalize and tokenize query
        query_lower = query.lower()
        total_tokens = len(query_lower.split())
        
        # Count legal terms
        legal_term_count = sum(1 for term in self.legal_terms if term.lower() in query_lower)
        
        # Calculate density
        if total_tokens > 0:
            return (legal_term_count / total_tokens) * 100
        return 0
    
    def _count_citations(self, query):
        """Count legal citations in the query."""
        citation_count = 0
        for pattern in self.citation_patterns:
            citation_count += len(re.findall(pattern, query))
        return citation_count
    
    def _assess_complexity(self, doc):
        """
        Assess the structural complexity of the query.
        Returns a score from 0-1 based on:
        - Number of clauses
        - Presence of legal conditionals
        - Sentence structure complexity
        """
        # Count clauses
        clause_markers = ["if", "when", "whether", "notwithstanding", "provided that"]
        clause_count = sum(1 for token in doc if token.text.lower() in clause_markers)
        
        # Check for complex legal conditionals
        has_conditionals = any(cm in doc.text.lower() for cm in clause_markers)
        
        # Assess syntactic complexity (simplified)
        depth = max((token.dep_.count('_') for token in doc), default=0)
        
        # Calculate complexity score (0-1)
        complexity = min(1.0, (clause_count * 0.2) + (0.3 if has_conditionals else 0) + (depth * 0.1))
        
        return complexity
    
    def _detect_jurisdiction(self, doc):
        """
        Detect jurisdictional signals in the query.
        Returns a dictionary of jurisdictional features.
        """
        # Look for jurisdictional entities
        jurisdictions = {
            'federal': 0,
            'state': 0,
            'international': 0,
            'specific_court': None
        }
        
        # Check for federal signals
        federal_terms = ["federal", "U.S.", "United States", "SCOTUS", "Supreme Court"]
        jurisdictions['federal'] = any(term.lower() in doc.text.lower() for term in federal_terms)
        
        # Check for state signals
        state_names = [
    "Alabama", "Alaska", "Arizona", "Arkansas", "California", "Colorado", "Connecticut", "Delaware", 
    "Florida", "Georgia", "Hawaii", "Idaho", "Illinois", "Indiana", "Iowa", "Kansas", "Kentucky", 
    "Louisiana", "Maine", "Maryland", "Massachusetts", "Michigan", "Minnesota", "Mississippi", 
    "Missouri", "Montana", "Nebraska", "Nevada", "New Hampshire", "New Jersey", "New Mexico", 
    "New York", "North Carolina", "North Dakota", "Ohio", "Oklahoma", "Oregon", "Pennsylvania", 
    "Rhode Island", "South Carolina", "South Dakota", "Tennessee", "Texas", "Utah", "Vermont", 
    "Virginia", "Washington", "West Virginia", "Wisconsin", "Wyoming"
]
  # Add all states
        jurisdictions['state'] = any(state in doc.text for state in state_names)
        
        # Check for international signals
        international_terms = ["international", "foreign", "treaty", "convention"]
        jurisdictions['international'] = any(term.lower() in doc.text.lower() for term in international_terms)
        
        # Look for specific courts
        court_patterns = ["Circuit", "District Court", "Supreme Court"]
        for pattern in court_patterns:
            if pattern in doc.text:
                jurisdictions['specific_court'] = pattern
                break
                
        return jurisdictions
    
    def _determine_weights(self, features):
        """
        Determine the optimal weights for each model based on features.
        Uses a rule-based approach initially, could be replaced with ML model.
        """
        # Default weights slightly favor specialized model
        gemini_weight = 0.4
        voyager_weight = 0.6
        
        # Adjust for legal density and citations
        if features['legal_term_density'] > 5 or features['citation_count'] > 0:
            # Increase weight for legal model
            voyager_weight += 0.15
            gemini_weight -= 0.15
        
        # Adjust for complexity
        if features['structural_complexity'] > 0.7:
            voyager_weight += 0.1
            gemini_weight -= 0.1
        
        # Adjust for jurisdictional specificity
        if features['jurisdiction_signals']['specific_court']:
            voyager_weight += 0.1
            gemini_weight -= 0.1
        
        # Ensure weights are valid
        voyager_weight = min(max(voyager_weight, 0.1), 0.9)
        gemini_weight = 1.0 - voyager_weight
        
        return {
            'gemini': gemini_weight,
            'voyager': voyager_weight
        }
    
    def _get_query_embedding(self, query):
        """Generate embedding for the query using the pretrained model."""
        inputs = self.tokenizer(query, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state.mean(dim=1)


## Attention-Based Fusion Implementation

In [27]:
class MultiLevelAttention(nn.Module):
    """
    Enhanced attention mechanism with dimension fixes
    """
    
    def __init__(self, output_dim=768):
        super(MultiLevelAttention, self).__init__()
        
        # Model dimensions
        self.gemini_dim = 768
        self.voyager_dim = 1024
        self.output_dim = output_dim
        self.granularities = ['sections', 'chapters', 'pages']
        self.cross_attention_weights = {}
        # Query projection
        self.query_projector = nn.Linear(768, output_dim)
        
        # Document projectors
        self.gemini_projector = nn.Linear(self.gemini_dim, output_dim)
        self.voyager_projector = nn.Linear(self.voyager_dim, output_dim)
        
        # Aggregation
        self.aggregation_layer = nn.Linear(output_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, gemini_embeddings, voyager_embeddings, query_embedding, weights):
        # Project query
        query_projected = self.query_projector(query_embedding)
        
        granularity_results = {}
        diagnostics = {}
        
        for granularity in self.granularities:
            print(f"Processing granularity: {granularity}")
            gemini_emb = gemini_embeddings[granularity]
            voyager_emb = voyager_embeddings[granularity]
            
            # Project document embeddings
            gemini_proj = self.gemini_projector(gemini_emb)
            voyager_proj = self.voyager_projector(voyager_emb)
            

            gemini_proj = gemini_proj[:, :self.output_dim]  # Ensure [num_docs, 768]
            voyager_proj = voyager_proj[:, :self.output_dim]
            # Calculate similarities (fixed dimensions)
            gemini_similarity = F.cosine_similarity(
                gemini_proj,
                query_projected.unsqueeze(0),
                dim=1
            ).unsqueeze(0)  # Add batch dimension for softmax
            
            voyager_similarity = F.cosine_similarity(
                voyager_proj,
                query_projected.unsqueeze(0),
                dim=1
            ).unsqueeze(0)
            
            # Normalize similarities
            gemini_weights = F.softmax(gemini_similarity, dim=-1)
            voyager_weights = F.softmax(voyager_similarity, dim=-1)
            print(f"Gemini weights shape: {gemini_weights.shape}")
            print(f"Gemini proj shape: {gemini_proj.shape}")
            # Weighted sum with proper dimensions
            gemini_weighted = torch.matmul(
                gemini_weights,
                gemini_proj.T
            ).squeeze(0)  # Remove batch dimension
            
            voyager_weighted = torch.matmul(
                voyager_weights,
                voyager_proj.T
            ).squeeze(0)
            
            # Combine model embeddings
            new_size = (weights['gemini'] * gemini_weighted).shape[-1] - (weights['voyager'] * voyager_weighted).shape[-1]  # Padding size for smaller model
            print(f"New size: {new_size}")
            granularity_combined = (
                weights['gemini'] * gemini_weighted +
                F.pad(weights['voyager'] * voyager_weighted, (0, new_size), "constant", 0)
                
            )
            
            granularity_results[granularity] = granularity_combined
            print(f"Granularity Completed: {granularity}")
        # Aggregate results
        # Unified dimension handling
        max_dim = max(granularity_results[gran].shape[-1] for gran in self.granularities)

        combined_embedding = torch.stack([
            F.pad(gran_res, (0, max_dim - gran_res.shape[-1])) 
            for gran_res in granularity_results.values()
        ], dim=0).mean(dim=0)

        # combined_embedding = torch.stack([
        #             granularity_results[gran] for gran in self.granularities
        #         ], dim=0).mean(dim=0)
        self.cross_attention_weights = {
            'gemini': gemini_weights.detach().cpu(),
            'voyager': voyager_weights.detach().cpu(),
            'combined': granularity_combined.detach().cpu()
        }

        # Final processing
        # Modified final processing with dimension enforcement
        fused_embedding = self.layer_norm(
            self.aggregation_layer(
                nn.Linear(2176, 768)(combined_embedding)  # Project to correct dimension
            )
        )
        
        return fused_embedding, {
            'processed_embeddings': granularity_results,
            'diagnostics': diagnostics,
            'cross_attention_weights': self.cross_attention_weights 
        }


## Complete Legal RAG System Implementation

In [31]:
class LegalAttentionRAG:
    """
    Complete Legal RAG system using attention-based fusion of embeddings.
    """
    
    def __init__(self, embedding_path, query_embedding_dim=768, 
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        print(f"Initializing LegalAttentionRAG on device: {device}")
        
        # Initialize embedding loader
        self.embedding_loader = LegalEmbeddingLoader(embedding_path)
        self.gemini_embeddings, self.voyager_embeddings, self.metadata = self.embedding_loader.load_embeddings()
        # Add to __init__
        self.gemini_projector = nn.Linear(768, 768)  # Maintain original dimension
        self.voyager_projector = nn.Linear(1024, 768)  # Project to match Gemini

        # Get embedding dimensions and verify
        gemini_dims, voyager_dims = self.embedding_loader.get_embedding_dimensions()
        print(f"Gemini Embedding Dimensions:\n{gemini_dims}")
        print(f"Voyager Embedding Dimensions:\n{voyager_dims}")
        
        # Initialize query analyzer
        self.query_analyzer = LegalQueryAnalyzer()
        
        # Initialize attention fusion model - using fixed version that doesn't rely on dimension dicts
        self.attention_model = MultiLevelAttention(output_dim=query_embedding_dim).to(device)
        
        # Move embeddings to device
        for key in self.gemini_embeddings:
            self.gemini_embeddings[key] = self.gemini_embeddings[key].to(device)
            self.voyager_embeddings[key] = self.voyager_embeddings[key].to(device)
            
        print("LegalAttentionRAG system initialized successfully")

            
    def process_query(self, query, top_k=5):
        """Process a legal query using the simplified attention fusion."""
        # 1. Analyze query to determine intent and weights
        query_analysis = self.query_analyzer.analyze_query(query)
        weights = query_analysis['weights']
        features = query_analysis['features']
        query_embedding = query_analysis['query_embedding'].to(self.device)
        
        print(f"Query embedding shape: {query_embedding.shape}")
        print(f"Model weights: Gemini={weights['gemini']:.2f}, Voyager={weights['voyager']:.2f}")
        
        # 2. Apply attention-based fusion
        try:
            fused_embedding, attention_info = self.attention_model(
                self.gemini_embeddings,
                self.voyager_embeddings,
                query_embedding,
                weights
            )
            
            # Log dimensions for debugging
            if 'diagnostics' in attention_info:
                print("Embedding shapes:")
                for key, value in attention_info['diagnostics'].items():
                    print(f"  {key}: {value}")
                    
        except Exception as e:
            print(f"Error in attention model: {e}")
            # Log shapes to help diagnose the issue
            print(f"Query shape: {query_embedding.shape}")
            for granularity in self.gemini_embeddings.keys():
                print(f"Gemini {granularity}: {self.gemini_embeddings[granularity].shape}")
                print(f"Voyager {granularity}: {self.voyager_embeddings[granularity].shape}")
            raise


        
        # 3. Retrieve most relevant documents using the fused embedding
        results = {}
        # Modified similarity calculation with dimension alignment
        for granularity in ['sections', 'chapters', 'pages']:
            gemini_emb = self.gemini_projector(self.gemini_embeddings[granularity])  # [n_docs, 768]
            voyager_emb = self.voyager_projector(self.voyager_embeddings[granularity])  # [n_docs, 768]
            
            # Calculate similarity scores with proper dimension alignment
            gemini_scores = F.cosine_similarity(
                fused_embedding.unsqueeze(0).unsqueeze(1),  # [1, 1, 768]
                gemini_emb.unsqueeze(0).unsqueeze(0),       # [1, 1, n_docs, 768]
                dim=3
            ).squeeze()
            
            voyager_scores = F.cosine_similarity(
                fused_embedding.unsqueeze(0).unsqueeze(1),  # [1, 1, 768]
                voyager_emb.unsqueeze(0).unsqueeze(0),      # [1, 1, n_docs, 768]
                dim=3
            ).squeeze()
    
    # Rest of your code remains the same

            
            # Weighted combination of scores
            new_size = (weights['gemini'] * gemini_scores).shape[-1] - (weights['voyager'] * voyager_scores).shape[-1]  # Padding size for smaller model
            print(f"New size: {new_size}")
            combined_scores = (
                weights['gemini'] * gemini_scores + 
                F.pad(weights['voyager'] * voyager_scores, (0, new_size), "constant", 0)
            )
            
            # Get top-k results
            top_indices = combined_scores.argsort(descending=True)[:top_k]
            
            # Retrieve metadata for these indices
            gemini_meta = self.metadata[f"gemini_{granularity}"].iloc[top_indices.cpu().numpy()]
            voyager_meta = self.metadata[f"voyager_{granularity}"].iloc[top_indices.cpu().numpy()]
            
            # Store results
            results[granularity] = {
                'indices': top_indices.cpu().detach().numpy(),
                'scores': combined_scores[top_indices].cpu().detach().numpy(),
                'gemini_metadata': gemini_meta,
                'voyager_metadata': voyager_meta
            }
        # 4. Return results with diagnostic information
        return {
            'query_analysis': {
                'features': features,
                'weights': weights
            },
            'attention_weights': attention_info.get('cross_attention_weights'),
            'results': results
        }
    
    def explain_weights(self, query):
        """Generate an explanation of why particular weights were chosen."""
        analysis = self.query_analyzer.analyze_query(query)
        features = analysis['features']
        weights = analysis['weights']
        
        explanation = {
            'query': query,
            'gemini_weight': weights['gemini'],
            'voyager_weight': weights['voyager'],
            'reasoning': []
        }
        
        # Explain each feature's contribution
        if features['legal_term_density'] > 5:
            explanation['reasoning'].append(
                f"High legal terminology density ({features['legal_term_density']:.1f}%) "
                f"increased Voyager Law 2 weight by 15%"
            )
            
        if features['citation_count'] > 0:
            explanation['reasoning'].append(
                f"Presence of {features['citation_count']} legal citations "
                f"increased Voyager Law 2 weight by 15%"
            )
            
        if features['structural_complexity'] > 0.7:
            explanation['reasoning'].append(
                f"High query complexity score ({features['structural_complexity']:.2f}) "
                f"increased Voyager Law 2 weight by 10%"
            )
            
        if features['jurisdiction_signals']['specific_court']:
            explanation['reasoning'].append(
                f"Specific court reference ({features['jurisdiction_signals']['specific_court']}) "
                f"increased Voyager Law 2 weight by 10%"
            )
            
        if not explanation['reasoning']:
            explanation['reasoning'].append(
                "Used default weights slightly favoring legal expertise (Voyager: 60%, Gemini: 40%)"
            )
            
        return explanation


## Direct Fusion RAG

In [29]:
class DirectFusionRAG:
    """
    Simple weighted combination of retrieval results from both models.
    This is a fallback approach if attention-based fusion is problematic.
    """
    def __init__(self, embedding_path, query_embedding_dim=768, 
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        print(f"Initializing LegalSimpleRAG on device: {device}")
        
        # Initialize embedding loader
        self.embedding_loader = LegalEmbeddingLoader(embedding_path)
        self.gemini_embeddings, self.voyager_embeddings, self.metadata = self.embedding_loader.load_embeddings()
        
        self.query_projection = nn.Linear(768, 1024).to(self.device)

        # Get embedding dimensions and verify
        gemini_dims, voyager_dims = self.embedding_loader.get_embedding_dimensions()
        print(f"Gemini Embedding Dimensions:\n{gemini_dims}")
        print(f"Voyager Embedding Dimensions:\n{voyager_dims}")
        
        # Initialize query analyzer
        self.query_analyzer = LegalQueryAnalyzer()
        
        # Initialize attention fusion model - using fixed version that doesn't rely on dimension dicts
        self.attention_model = MultiLevelAttention(output_dim=query_embedding_dim).to(device)
        
        # Move embeddings to device
        for key in self.gemini_embeddings:
            self.gemini_embeddings[key] = self.gemini_embeddings[key].to(device)
            self.voyager_embeddings[key] = self.voyager_embeddings[key].to(device)
            
        print("LegalSimpleRAG system initialized successfully")

        
    def process_query(self, query, top_k=5):
        self.query_analyzer = LegalQueryAnalyzer()
        # 1. Analyze query for weights
        query_analysis = self.query_analyzer.analyze_query(query)
        weights = query_analysis['weights']
        query_embedding = query_analysis['query_embedding'].to(self.device)
        
        results = {}
        
        # 2. For each granularity, retrieve from both models separately
        for granularity in self.gemini_embeddings.keys():
            # Get embeddings
            gemini_emb = self.gemini_embeddings[granularity]
            linear = nn.Linear(768, 1024)  # Define a linear layer to project gemini_emb to 1024 dimensions
            gemini_emb = linear(gemini_emb)
            print(f"Projected gemini_emb shape: {gemini_emb.shape}") 
            voyager_emb = self.voyager_embeddings[granularity]
            print(f"voyager_emb shape: {voyager_emb.shape}")
            
            if query_embedding.shape[-1] != 768:
                print(f"Warning: Unexpected query_embedding shape {query_embedding.shape}, expected (1, 768)")
            
            input_dim = query_embedding.shape[-1]
            linear = nn.Linear(input_dim, 1024).to(self.device)
            query_embedding = linear(query_embedding)
            print(f"Projected query_embedding shape: {query_embedding.shape}")
            print(f"Query Embeddings: {query_embedding.shape}")
            # Calculate similarity scores directly
            gemini_scores = F.cosine_similarity(
                query_embedding.unsqueeze(0),
                gemini_emb
            )
            print("gemin Done")
            voyager_scores = F.cosine_similarity(
                query_embedding.unsqueeze(0),
                voyager_emb
            )
            print("Voyage Done")
            # Weight scores based on query analysis
            combined_scores = (
                weights['gemini'] * gemini_scores +
                weights['voyager'] * voyager_scores
            )
            print(f"Shape of combined_scores: {combined_scores.shape}")
            # Flatten combined_scores to 1D
            combined_scores = combined_scores.flatten()
            print(f"Shape of combined_scores after Flatten: {combined_scores.shape}")
            # Get top results
            top_indices = combined_scores.argsort(descending=True)[:top_k]
            # Ensure top_indices are within bounds of combined_scores
            top_indices = top_indices[top_indices < combined_scores.shape[0]]

            # Now safely access top_results
            top_results = combined_scores[top_indices]
            # Store results
            combined_scores = combined_scores.detach().numpy()
            results[granularity] = {
                'indices': top_indices,
                'scores': combined_scores[top_indices]
            }
        
        return results


## Usage Example: Processing Legal Queries

In [39]:
# Example usage of the complete system
def example_usage():
    # Initialize the system
    legal_rag = LegalAttentionRAG(embedding_path="New_Embeddings_2025")
    
    # Example queries representing different types of legal questions
    queries = [
        "What are the elements of wire fraud under 18 U.S.C. § 1343?",
        "Explain the concept of mens rea in criminal law",
        "What is the difference between Title 18 and Title 26?",
        "Has the Supreme Court ruled on the constitutionality of 18 U.S.C. § 1512(c)(2)?"
    ]
    
    for query in queries:
        print(f"\nProcessing query: '{query}'")
        
        # Get explanation of weight determination
        explanation = legal_rag.explain_weights(query)
        print("\nQuery Weight Analysis:")
        print(f"Gemini weight: {explanation['gemini_weight']:.2f}")
        print(f"Voyager weight: {explanation['voyager_weight']:.2f}")
        print("Reasoning:")
        for reason in explanation['reasoning']:
            print(f"- {reason}")
            
        # Process query and get results using Attention and Direct fusion RAG
        technique= "Attention"
        import pandas as pd

        def get_dict_skeleton_with_df(d):
            if isinstance(d, dict):
                return {k: get_dict_skeleton_with_df(v) for k, v in d.items()}
            elif isinstance(d, list):
                return [get_dict_skeleton_with_df(d[0]) if d else []]  # Handle lists
            elif isinstance(d, pd.DataFrame):
                return {"DataFrame": list(d.columns)}  # Extract DataFrame column names
            else:
                return type(d).__name__  # Get the type of the value
        if technique == "Direct":

            direct_fusion_rag = DirectFusionRAG(embedding_path="New_Embeddings_2025")  # Initialize DirectFusionRAG
            results = direct_fusion_rag.process_query(query)
            print(results)
            print("Direct RAG DONE")
            print("\nTop Results:")
            for granularity, data in results.items():
                print(f"\n{granularity.upper()}:")
                for i, index in enumerate(data['indices'].tolist()):  # Convert tensor to list
                    score = data['scores'][i]
                    print(f"  {i+1}. Index: {index} (Score: {score:.4f})")

        if technique=="Attention":
            results = legal_rag.process_query(query, top_k=3)
            print("\n#########Top Results#######:")
            print(get_dict_skeleton_with_df(results))
            print("----------END OF RESULTS---------")
        # Display top results from different granularities
            for granularity in results['results']:
                print(f"\n{granularity.upper()}:")
                for i in range(len(results['results'][granularity]['indices'])):
                    score = results['results'][granularity]['scores'][i]
                    # Get the actual index from indices array
                    idx = results['results'][granularity]['indices'][i]
                    
                    # Access metadata using proper keys from DataFrame structure
                    metadata = results['results'][granularity]['voyager_metadata'].iloc[i]
                    
                    # Use .get() with fallback values for missing keys
                    print(f"  {i+1}. Section {metadata.get('Section', metadata.get('chunk', 'N/A'))[:50]}... (Score: {score:.4f})")
                    print(f"     URL: {metadata.get('Url', 'No URL available')}")
                    print(f"     Content: {metadata.get('Content', metadata.get('chunk', 'No text available'))[:150]}...")



example_usage()


Initializing LegalAttentionRAG on device: cpu
New_Embeddings_2025
New_Embeddings_2025\chapters\embeddings_gemini_text-005_chapters_semchunk.parquet

Columns in embeddings_gemini_text-005_chapters_semchunk.parquet: ['chunk', 'Embedding']
Loaded embeddings_gemini_text-005_chapters_semchunk.parquet (gemini - chapters)
New_Embeddings_2025
New_Embeddings_2025\chapters\embeddings_voyage_per_chapter_semchunked.parquet

Columns in embeddings_voyage_per_chapter_semchunked.parquet: ['chunk', 'Embedding']
Loaded embeddings_voyage_per_chapter_semchunked.parquet (voyager - chapters)
New_Embeddings_2025
New_Embeddings_2025\pages\embeddings_gemini_text-005_pages_semchunk.parquet

Columns in embeddings_gemini_text-005_pages_semchunk.parquet: ['chunk', 'Embedding']
Loaded embeddings_gemini_text-005_pages_semchunk.parquet (gemini - pages)
New_Embeddings_2025
New_Embeddings_2025\pages\embeddings_voyage_per_pages_semchunked.parquet

Columns in embeddings_voyage_per_pages_semchunked.parquet: ['chunk', 'Emb