In [None]:
!pip install sentence-transformers faiss-cpu scikit-learn gradio pandas



In [None]:
import pandas as pd

# Toy legal FAQ corpus: each record pairs a user question with the document that
# answers it. The search cells below embed / index only the "doc" column.
data = [
    dict(query=q, doc=d)
    for q, d in [
        ("What is a patent?", "A patent is an exclusive right granted for an invention."),
        ("How do I apply for a patent?", "To apply for a patent, submit an application to the national patent office."),
        ("What does a trademark protect?", "A trademark protects brand names and logos used on goods and services."),
        ("What is copyright?", "Copyright protects original works of authorship like books and music."),
        ("What is the duration of a patent?", "Patents generally last for 20 years from the filing date."),
        ("Can a patent be renewed?", "Patents cannot usually be renewed beyond their maximum term."),
        ("What happens after a patent expires?", "After a patent expires, the invention becomes public domain."),
        ("How to challenge a trademark?", "Trademark opposition is a legal process to prevent registration of a similar mark."),
        ("Who grants patents in the USA?", "The United States Patent and Trademark Office (USPTO) grants patents."),
        ("What is infringement of copyright?", "Using copyrighted work without permission may be considered infringement."),
    ]
]

df = pd.DataFrame(data, columns=["query", "doc"])
df.head()


Unnamed: 0,query,doc
0,What is a patent?,A patent is an exclusive right granted for an ...
1,How do I apply for a patent?,"To apply for a patent, submit an application t..."
2,What does a trademark protect?,A trademark protects brand names and logos use...
3,What is copyright?,Copyright protects original works of authorshi...
4,What is the duration of a patent?,Patents generally last for 20 years from the f...


In [None]:
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

In [None]:
# Pretrained base encoder; serves as the retrieval baseline before fine-tuning.
model = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')

In [None]:
# Embed every document once, in df order (row i of the result is df row i).
doc_embeddings = model.encode(df['doc'].to_list(), show_progress_bar=True)
# Width of each embedding vector; used to size the FAISS index.
dimension = doc_embeddings.shape[-1]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Embedding width of the current encoder; this value sizes the FAISS index.
dimension

384

In [None]:
# Exact (brute-force) L2 index over the document embeddings. Vectors are added in
# df order, so FAISS id i refers to df row i.
index = faiss.IndexFlatL2(dimension)
index.add(doc_embeddings)

In [None]:
def semantic_search(query, top_k=3):
    """Return the documents nearest to `query` in embedding space.

    Reads the notebook-level `model` (query encoder), `index` (FAISS index whose
    id i is the embedding of ``df.iloc[i]['doc']``) and `df` at call time, so
    rebuilding them (e.g. after fine-tuning) changes what this returns.

    Parameters
    ----------
    query : str
        Free-text search query.
    top_k : int, default 3
        Maximum number of documents to return. Values larger than the number of
        indexed documents are clamped; values <= 0 yield an empty result.

    Returns
    -------
    pandas.DataFrame
        A single 'doc' column, nearest document first, keeping df's row labels.
    """
    # FAISS pads the id array with -1 when k exceeds the number of indexed
    # vectors, and df.iloc[-1] would silently return the last document instead.
    # Clamp k to the index size and drop any -1 ids that still appear.
    k = min(int(top_k), index.ntotal)
    if k <= 0:
        return df.iloc[[]][['doc']]

    # FAISS expects a float32 (n_queries, dim) array.
    query_embedding = np.array(model.encode([query]), dtype=np.float32)
    _distances, neighbor_ids = index.search(query_embedding, k=k)

    hits = neighbor_ids[0]
    return df.iloc[hits[hits >= 0]][['doc']]

In [None]:
# Baseline retrieval with the pretrained (not yet fine-tuned) encoder and index.
semantic_search("Tell me about patent")

Unnamed: 0,doc
0,A patent is an exclusive right granted for an ...
8,The United States Patent and Trademark Office ...
6,"After a patent expires, the invention becomes ..."


In [None]:
#fine tuned model
from sentence_transformers import InputExample, losses
from torch.utils.data import DataLoader

In [None]:
from transformers import TrainingArguments

# NOTE(review): `training_args` is never passed to model.fit(...) anywhere in this
# notebook, so the report_to setting below does not reach the fine-tuning run.
# The fit output instead warns about the WANDB_DISABLED environment variable,
# which is what is set in the following cell. Consider removing this cell, or
# wiring these arguments into the training call if the installed
# sentence-transformers version supports that.
training_args = TrainingArguments(
    output_dir="./output",
    report_to="none",  # disables wandb, tensorboard, etc.
)

In [None]:
import os

# Opt out of Weights & Biases logging before model.fit(...) runs.
# NOTE: the training output warns that this variable is deprecated in favour of
# report_to="none" on the training arguments.
os.environ.update({"WANDB_DISABLED": "true"})

In [None]:
# Fine-tuning pairs for CosineSimilarityLoss: `label` is the target cosine
# similarity between the two texts (1.0 = relevant pair, 0.0 = unrelated pair).
# Labels are written as floats because the loss regresses a float similarity onto
# them. Integer labels can be collated into an integer tensor, which MSE-based
# losses reject in some sentence-transformers / PyTorch versions.
train_examples = [
    InputExample(texts=["What is a patent?", "A patent is an exclusive right granted for an invention."], label=1.0),
    InputExample(texts=["How to apply for a patent?", "To apply for a patent, submit an application to the national patent office."], label=1.0),
    InputExample(texts=["What is copyright?", "Photosynthesis is a process used by plants."], label=0.0),  # Negative
    InputExample(texts=["What is trademark?", "Trademark protects brand identity"], label=1.0),
]

In [None]:
# Mini-batches of two (query, doc) pairs, reshuffled every epoch.
train_dataloader = DataLoader(train_examples, batch_size=2, shuffle=True)
# Trains the cosine similarity of each pair toward its label.
train_loss = losses.CosineSimilarityLoss(model=model)

In [None]:
# Single pass over the training pairs; `model` is trained in place, then written
# to disk so the next cell can reload it.
model.fit(
    epochs=1,
    show_progress_bar=True,
    train_objectives=[(train_dataloader, train_loss)],
)
model.save(path="fine-tuned-legal-sbert")

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss


In [None]:
# Reload the checkpoint written by model.save(...); from here on `model` is the
# fine-tuned encoder used by every search function.
model = SentenceTransformer(model_name_or_path="fine-tuned-legal-sbert")

In [None]:
# Re-embed the corpus with the fine-tuned encoder and rebuild the index; the
# previous index still holds base-model vectors and would mismatch new queries.
doc_embeddings = model.encode(df['doc'].to_list(), show_progress_bar=True)
index = faiss.IndexFlatL2(doc_embeddings.shape[-1])
index.add(doc_embeddings)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Same retrieval function, now backed by the fine-tuned encoder and rebuilt index.
# NOTE(review): this uses a different query than the baseline cell, so the two
# outputs are not a like-for-like comparison of base vs fine-tuned models.
semantic_search("What to do after my patent expired?")

Unnamed: 0,doc
6,"After a patent expires, the invention becomes ..."
4,Patents generally last for 20 years from the f...
0,A patent is an exclusive right granted for an ...


In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer

# Keyword side of the hybrid search: a TF-IDF model fitted on the document corpus.
vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(df['doc'].to_list())

def get_tfidf_score(query):
    """Keyword relevance of `query` against every document.

    Returns a 1-D array with one score per row of `df` (same order): the dot
    product of the query's TF-IDF vector with each document's TF-IDF vector.
    With TfidfVectorizer's default L2 row normalisation this is cosine similarity.
    Words unseen during fitting are ignored, so an all-unknown query scores 0.
    """
    query_vec = vectorizer.transform([query])
    return (tfidf_matrix @ query_vec.T).toarray().ravel()

In [None]:
def hybrid_search(query, alpha=0.5, top_k=3):
    """Rank documents by a weighted mix of semantic and keyword relevance.

    ``final = alpha * semantic + (1 - alpha) * keyword``, where

    * ``semantic`` is the cosine similarity recovered from the FAISS squared-L2
      distance (``1 - d / 2``, exact for unit-length embeddings), and
    * ``keyword`` is ``get_tfidf_score(query)`` (TF-IDF cosine similarity).

    Every indexed document is scored, so a strong keyword match is not dropped
    just because it falls outside the semantic top_k.

    Reads the notebook-level `model`, `index`, `df` and `get_tfidf_score` at call
    time; index id i, df row i and keyword score i must describe the same document.

    Parameters
    ----------
    query : str
        Free-text search query.
    alpha : float, default 0.5
        Weight of the semantic score (1 = semantic only, 0 = keyword only).
    top_k : int, default 3
        Maximum number of results; clamped to the corpus size.

    Returns
    -------
    list[tuple[str, float]]
        Up to `top_k` ``(document text, final score)`` pairs, best first.

    Raises
    ------
    ValueError
        If the keyword scores and the FAISS index cover different numbers of documents.
    """
    n_docs = index.ntotal
    if n_docs == 0 or top_k <= 0:
        return []

    # FAISS expects a float32 (n_queries, dim) array.
    query_embedding = np.array(model.encode([query]), dtype=np.float32)
    # Request every document: a flat index scans the whole corpus anyway, and the
    # keyword side needs a semantic score for each document it might promote.
    distances, ids = index.search(query_embedding, k=n_docs)

    # IndexFlatL2 reports *squared* L2 distance. For unit-length vectors
    # d = 2 - 2*cos, so 1 - d/2 is cosine similarity in [-1, 1], the same scale
    # as the TF-IDF cosine scores that it is blended with.
    # NOTE(review): assumes the encoder emits L2-normalised embeddings; confirm
    # when swapping models.
    sem_scores = [0.0] * n_docs
    for dist, doc_id in zip(distances[0], ids[0]):
        if doc_id >= 0:  # FAISS uses -1 for "no result"
            sem_scores[doc_id] = 1.0 - float(dist) / 2.0

    keyword_scores = get_tfidf_score(query)
    if len(keyword_scores) != n_docs:
        raise ValueError(
            f"keyword scores cover {len(keyword_scores)} documents but the index "
            f"holds {n_docs}; rebuild the TF-IDF model and the index from the same df"
        )

    final_scores = [
        alpha * sem + (1 - alpha) * float(kw)
        for sem, kw in zip(sem_scores, keyword_scores)
    ]
    # Stable sort: ties keep df order.
    ranked = sorted(range(n_docs), key=lambda i: final_scores[i], reverse=True)[:top_k]
    return [(df.iloc[i]['doc'], final_scores[i]) for i in ranked]


In [None]:
# Show hybrid results as "<rounded score> -> <document>", highest score first.
for doc, score in hybrid_search("What to do after my patent expired?"):
    print(f"{round(score, 3)} -> {doc}")


0.256 -> After a patent expires, the invention becomes public domain.
0.008 -> A patent is an exclusive right granted for an invention.
-0.019 -> Patents generally last for 20 years from the filing date.


In [None]:
import gradio as gr

def search_interface(query, alpha=0.5):
    """Gradio callback: run hybrid_search and format the hits as plain text.

    Parameters
    ----------
    query : str
        Text from the search box (may be empty or None if the box is cleared).
    alpha : float, default 0.5
        Semantic weight forwarded to hybrid_search (1 = semantic only,
        0 = keyword only).

    Returns
    -------
    str
        One "<doc> (Score: <rounded score>)" entry per hit, separated by blank
        lines; a short prompt for an empty query; or a notice when nothing matched.
    """
    # An empty query would still embed "" and return arbitrary top documents.
    if not query or not query.strip():
        return "Please enter a search query."

    results = hybrid_search(query, alpha)
    if not results:
        return "No matching documents."
    return "\n\n".join(f"{doc} (Score: {round(score, 3)})" for doc, score in results)

# Stop Gradio apps left over from a previous run of this cell; otherwise every
# re-run starts another server (and another public share link) in the kernel.
gr.close_all()

demo = gr.Interface(
    fn=search_interface,
    inputs=[
        gr.Textbox(label="Search Query"),
        gr.Slider(
            0,
            1,
            value=0.5,
            label="Semantic vs Keyword Weight",
            info="alpha: 1 = semantic only, 0 = keyword only",
        ),
    ],
    outputs="text",
    title="Domain-Aware Hybrid Search (Legal)",
)
# Keeping the `demo` handle allows stopping the app explicitly with demo.close().
demo.launch()


It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://1ea9ba196e77b68228.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


