In [None]:
!pip install datasets pandas pymongo sentence-transformers

In [None]:
!pip install -U transformers accelerate

### Loading the dataset directly from Hugging Face using the `datasets` library

In [None]:
from datasets import load_dataset

# Load only the "train" split of the embedded-movies dataset.
dataset = load_dataset(path="MongoDB/embedded_movies", split="train")

In [None]:
import pandas as pd

In [None]:
# Convert the loaded dataset into a pandas DataFrame for inspection and cleaning.
data = pd.DataFrame(dataset)

In [None]:
# Preview the first rows of the raw data.
data.head()

In [None]:
# List the columns available in the raw data.
data.columns

In [None]:
# `plot` text of the row labelled 0.
data.loc[0, "plot"]

In [None]:
# `fullplot` text of the row labelled 0.
data.loc[0, "fullplot"]

In [None]:
# Number of rows whose `fullplot` is missing.
data["fullplot"].isna().sum()

In [None]:
# (rows, columns) of the raw data.
data.shape

In [None]:
# Missing-value count for every column.
data.isna().sum()

In [None]:
# Keep only the rows that have a `fullplot`, since that is the text embedded later on.
dataset_df = data.dropna(subset=["fullplot"])

In [None]:
# Sanity check: count of missing `fullplot` values left in the cleaned frame.
dataset_df["fullplot"].isna().sum()

In [None]:
# Discard the existing `plot_embedding` column; it is recomputed further down
# with `embedding_model`. errors="ignore" keeps this cell re-runnable when the
# column has already been dropped.
dataset_df = dataset_df.drop(columns=["plot_embedding"], errors="ignore")

In [None]:
# Preview two rows of the cleaned frame.
dataset_df.head(2)

In [None]:
from sentence_transformers import SentenceTransformer

# Embedding model shared by the plot embeddings and the query embedding.
embedding_model = SentenceTransformer(model_name_or_path="thenlper/gte-large")

In [None]:
def get_embedding(text: str) -> list:
  """Embed `text` with the notebook-level `embedding_model`.

  Args:
    text: The string to embed.

  Returns:
    The embedding as a plain Python list, or an empty list when `text` is
    empty, whitespace-only, or not a string at all.
  """
  # Non-string values (e.g. None or a float NaN coming out of pandas) have no
  # .strip(); treat them like empty input instead of raising AttributeError.
  if not isinstance(text, str) or not text.strip():
    print("Attempted to get embedding for empty string")
    return []
  # .tolist() converts the encoder output into a plain list.
  return embedding_model.encode(text).tolist()

In [None]:
# Embed every `fullplot` and store the vectors in a `plot_embedding` column.
dataset_df["plot_embedding"] = dataset_df["fullplot"].map(get_embedding)

### Connecting to MongoDB

In [None]:
from google.colab import userdata
# Read the MongoDB connection string from Colab's user data store rather than hardcoding it.
mongo_db_uri = userdata.get('mongo_db_uri')

In [None]:
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

uri = mongo_db_uri

# Create the client against server API version '1'.
client = MongoClient(uri, server_api=ServerApi('1'))

# Ping the deployment to confirm the connection; any failure is printed
# rather than raised.
try:
    client.admin.command('ping')
except Exception as exc:
    print(exc)
else:
    print("Pinged your deployment. You successfully connected to MongoDB!")

In [None]:
# Handle to the "movie_db" database on the connected client.
db = client["movie_db"]

In [None]:
# Handle to the "collection01" collection inside `db`.
collection = db["collection01"]

In [None]:
# One dict per DataFrame row, ready to be inserted as MongoDB documents.
document = dataset_df.to_dict(orient="records")

In [None]:
# Insert only when the collection is still empty, so that re-running this cell
# (for example after a kernel restart) does not store every movie a second
# time and produce duplicate search hits.
# NOTE: a partially filled collection is left untouched; clear it manually
# if a fresh load is needed.
if collection.count_documents({}) == 0:
    insert_result = collection.insert_many(document)
    print(f"Inserted {len(insert_result.inserted_ids)} documents")
else:
    print("Collection already contains documents; skipping insert")

### Data Retrieval

In [None]:
# Example question used for both retrieval and response generation.
user_query = "Which one is a good horror movie based on story and title to watch and why?"

In [None]:
def vector_search(user_query, collection, index_name="vector_index", num_candidates=150, limit=4):
  """Run a `$vectorSearch` aggregation for `user_query` on `collection`.

  Args:
    user_query: Text that is embedded with `get_embedding` and used as the
      query vector.
    collection: Collection object whose `aggregate` method runs the pipeline.
    index_name: Name of the vector search index to query.
    num_candidates: Value sent as `numCandidates` in the `$vectorSearch` stage.
    limit: Value sent as `limit` in the `$vectorSearch` stage.

  Returns:
    A list with the documents produced by the pipeline, each projected to
    `fullplot`, `title`, `genres` and a `score` taken from the
    `vectorSearchScore` metadata. An empty list is returned when no
    embedding could be produced for the query.
  """
  query_embedding = get_embedding(user_query)
  # get_embedding reports failure with an empty list, never None, so test for
  # emptiness. Return an empty list (not a message string) so callers that
  # iterate over the hits keep working.
  if not query_embedding:
    return []

  search_stage = {
    "$vectorSearch": {
      "index": index_name,
      "queryVector": query_embedding,
      "path": "plot_embedding",
      "numCandidates": num_candidates,
      "limit": limit,
    }
  }
  project_stage = {
    "$project": {
      "fullplot": 1,
      "title": 1,
      "genres": 1,
      "score": {"$meta": "vectorSearchScore"},
    }
  }
  return list(collection.aggregate([search_stage, project_stage]))

In [None]:
# Show the raw search hits for the example query.
vector_search(user_query, collection)

In [None]:
def get_search_result(user_query, collection):
  """Format the vector-search hits for `user_query` as a single text block.

  Every hit contributes one newline-terminated line with its title and full
  plot; a missing field is rendered as 'N/A'. With no hits the result is an
  empty string.
  """
  hits = vector_search(user_query, collection)
  formatted_hits = [
    f"Title: { hit.get('title', 'N/A') }, Plot: { hit.get('fullplot', 'N/A') }\n"
    for hit in hits
  ]
  return "".join(formatted_hits)

In [None]:
# Retrieve and format the search context that is later placed into the prompt.
retrieved_info = get_search_result(user_query, collection)
print(retrieved_info)

### Response Generation

In [None]:
!pip install -U huggingface_hub

In [None]:
from google.colab import userdata
# Read the Hugging Face access token from Colab's user data store rather than hardcoding it.
hugging_face_token = userdata.get('hugging_face_access_token')

In [None]:
from huggingface_hub import login, notebook_login

# Authenticate with the token already read into `hugging_face_token`, so the
# notebook can run top to bottom without waiting on the interactive widget.
# Fall back to the widget only when no token is available.
if hugging_face_token:
    login(token=hugging_face_token)
else:
    notebook_login()

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "google/gemma-2b-it"

# Tokenizer and model are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# device_map="auto" leaves the placement of the weights to the library.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

In [None]:
# Combine the user question with the retrieved movie plots into one prompt.
prompt = f"Query: {user_query}\nContinue to answer the query by using the Search Results:\n{retrieved_info}."

# Tokenize input and move the tensors to the device the model actually lives on.
# The model is loaded with device_map="auto", so a hard-coded "cuda" would
# fail on a CPU-only runtime.
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generate output (use max_new_tokens to control length)
response = model.generate(**model_inputs, max_new_tokens=500, do_sample=True)

# Decode tokens back to text
output_text = tokenizer.batch_decode(response, skip_special_tokens=True)[0]

print(output_text)