In [None]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel

class InLegalBERTEmbeddings:
    """Sentence-embedding wrapper around a Hugging Face encoder model.

    Each text is tokenized (truncated to 512 tokens), run through the model,
    and reduced to one vector by averaging the last hidden state over the
    real (non-padding) tokens.

    The class exposes ``embed_documents`` and ``embed_query`` (plus
    ``__call__``) so it can be handed to a vector store as an embedding
    object.
    NOTE(review): LangChain vector stores call ``embed_query`` at search time
    (some versions already at build time) and may invoke a non-``Embeddings``
    object as a plain callable -- confirm against the installed version.
    """

    def __init__(self, model_name="law-ai/InLegalBERT"):
        """Load the tokenizer and model identified by ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def embed_documents(self, texts, batch_size=16):
        """Embed several texts.

        Args:
            texts: Iterable of strings to embed.
            batch_size: Number of texts sent through the model per forward
                pass. Must be at least 1.

        Returns:
            A NumPy array with one row per input text. For empty input an
            empty 1-D array is returned.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")

        texts = list(texts)
        pooled_batches = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]

            # Pad to the longest text in the batch; the attention mask marks
            # which positions are real tokens and which are padding.
            inputs = self.tokenizer(batch, return_tensors='pt', truncation=True, padding=True, max_length=512)

            with torch.no_grad():
                outputs = self.model(**inputs)

            hidden = outputs.last_hidden_state
            mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)

            # Mean pooling over real tokens only, so padding added for
            # batching does not change a text's embedding.
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)
            pooled_batches.append((summed / counts).numpy())

        if not pooled_batches:
            return np.array([])
        return np.concatenate(pooled_batches, axis=0)

    def embed_query(self, text):
        """Embed a single query string and return its 1-D vector."""
        return self.embed_documents([text])[0]

    def __call__(self, text):
        """Allow the instance to be used as a plain ``text -> vector`` function."""
        return self.embed_query(text)

In [None]:
# Instantiate the embedding wrapper; its constructor loads the tokenizer and
# model for "law-ai/InLegalBERT" via from_pretrained.
embeddings = InLegalBERTEmbeddings(model_name="law-ai/InLegalBERT")

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
import pandas as pd
# Load the CSV data into a DataFrame.
# NOTE(review): this is an absolute, machine-specific Windows path; adjust it
# before running the notebook on another machine.
csv_path = r"D:\Work\projects\Major Project\BharatLAW\data\FIR_DATASET.csv"
df = pd.read_csv(csv_path)

# Combine relevant columns into a single text string for embedding
def create_document(row):
    """Render one dataset row as a single pipe-delimited line of text.

    The section label is the part of ``row['URL']`` after its last ``/``;
    the ``Description``, ``Offense`` and ``Punishment`` values are inserted
    as-is.
    """
    section = row['URL'].rsplit('/', 1)[-1]
    fields = (
        f"Section: {section}",
        f"Description: {row['Description']}",
        f"Offense: {row['Offense']}",
        f"Punishment: {row['Punishment']}",
    )
    return " | ".join(fields)

# Build one combined text string per CSV row and store it in a new column.
df['document'] = df.apply(create_document, axis=1)

# Split the row texts into overlapping chunks (chunk_size=1000,
# chunk_overlap=200) and wrap each chunk in a Document object.
# NOTE(review): sizes are presumably measured in characters (the splitter's
# default length function) -- confirm against the installed LangChain version.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
docs = text_splitter.create_documents(df['document'].tolist())

# Embed every chunk with the `embeddings` object created in an earlier cell
# and build a FAISS vector store from the results.
vectorstore = FAISS.from_documents(docs, embeddings)

# Persist the vector store to the local folder "ipc_embed_db_updated"
# (relative to the current working directory).
vectorstore.save_local("ipc_embed_db_updated")