In [3]:
# Imported in this cell so it also runs on a fresh kernel (Restart & Run All).
import torch

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

device: cuda


In [None]:
import os
import glob
import re
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

import torch
from sentence_transformers import SentenceTransformer
import chromadb
from langchain_experimental.text_splitter import SemanticChunker

# Pick the compute device: GPU when torch reports CUDA, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"device: {device}")

# Sentence-embedding model loaded once onto the selected device.
emb_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)

class TorchEmbedder:
    """Embeddings adapter around the module-level ``emb_model``.

    Provides ``embed_documents`` and ``embed_query``, the two methods of
    LangChain's ``Embeddings`` interface, so the object can be handed to
    components that expect that interface.
    """

    def embed_documents(self, docs):
        """Embed a list of texts.

        Args:
            docs: List of strings to embed.

        Returns:
            A list with one embedding (a list of floats) per input text.
        """
        return emb_model.encode(docs, convert_to_numpy=True).tolist()

    def embed_query(self, text):
        """Embed a single text.

        Args:
            text: The string to embed.

        Returns:
            The embedding of ``text`` as a list of floats.
        """
        return self.embed_documents([text])[0]

# Adapter instance that lets the chunker embed text through emb_model.
embedder = TorchEmbedder()
# Chunker that splits raw text using the embeddings produced by `embedder`.
splitter = SemanticChunker(embedder)

# Chroma client persisting to ./embData (relative to the working directory).
client = chromadb.PersistentClient(path="./embData")
# Fetch the "movies" collection, creating it when it does not exist yet.
collection = client.get_or_create_collection("movies")


def _sanitize_tag(s):
    return re.sub(r"[^0-9A-Za-z_-]", "_", s)


def process_file(file_path, encode_batch_size=32):
    """Chunk one character text file and embed every non-blank chunk.

    The parent folder name is split on its last underscore into
    ``(movie_name, imdb_id)``; when it contains no underscore the whole
    folder name is the movie name and the IMDb id is ``"unknown"``. The file
    name without its extension is used as the character name.

    Args:
        file_path: Path to a UTF-8 encoded text file.
        encode_batch_size: Batch size passed to ``emb_model.encode``.

    Returns:
        A tuple ``(ids, texts, embeddings, metadatas)`` of four parallel
        lists with one entry per kept chunk. All four are empty when the
        file yields no non-blank chunk.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    folder_name = os.path.basename(os.path.dirname(file_path))
    if "_" in folder_name:
        movie_name, imdb_id = folder_name.rsplit("_", 1)
    else:
        movie_name, imdb_id = folder_name, "unknown"

    filename = os.path.basename(file_path)
    character = os.path.splitext(filename)[0]

    with open(file_path, "r", encoding="utf-8") as f:
        raw_text = f.read()

    # semantic chunking
    chunks = splitter.create_documents([raw_text])

    id_prefix = f"{_sanitize_tag(folder_name)}__{_sanitize_tag(character)}"
    ids, texts, metadatas = [], [], []
    for i, chunk in enumerate(chunks):
        text = chunk.page_content
        # Skip blank chunks (e.g. from an empty file) so no empty documents
        # are embedded and stored. The original index is kept in the id so
        # ids of the remaining chunks are unaffected.
        if not text.strip():
            continue
        ids.append(f"{id_prefix}__{i}")
        texts.append(text)
        metadatas.append({
            "imdb_id": imdb_id,
            "movie": movie_name,
            "character": character,
            "filename": filename,
            "chunk_index": i
        })

    if not texts:
        return ids, texts, [], metadatas

    # embeddings
    # Pass the batch size to encode() itself: slicing the list by hand leaves
    # encode() at its own default batch size, so the caller's value would
    # never reach the model.
    embeddings = emb_model.encode(
        texts,
        batch_size=encode_batch_size,
        show_progress_bar=False,
        convert_to_numpy=True,
    ).tolist()

    return ids, texts, embeddings, metadatas


def _flush_buffer(buffer):
    """Write the buffered records to the Chroma collection, then empty the buffer.

    The buffer is emptied even when ``collection.add`` raises, so a rejected
    batch is not re-submitted together with every later result.

    Args:
        buffer: Dict holding the parallel lists ``"ids"``, ``"docs"``,
            ``"embs"`` and ``"metas"``. Mutated in place.

    Raises:
        Exception: Whatever ``collection.add`` raises is propagated.
    """
    if not buffer["ids"]:
        return
    try:
        collection.add(
            ids=buffer["ids"],
            documents=buffer["docs"],
            embeddings=buffer["embs"],
            metadatas=buffer["metas"]
        )
    finally:
        for key in buffer:
            buffer[key] = []


def process_all(parent_path, encode_batch_size=32, workers=None, batch_size=500):
    """Embed every ``<parent_path>/*/*.txt`` file and store the chunks in Chroma.

    Files are processed by ``process_file`` on a thread pool. Results are
    buffered and written to the collection whenever at least ``batch_size``
    chunks have accumulated, plus once more at the end for the remainder.
    A file that fails to process, or a batch that fails to store inside the
    loop, is reported on stdout and skipped; the run continues.

    Args:
        parent_path: Directory whose immediate sub-folders contain the
            ``.txt`` files.
        encode_batch_size: Forwarded to ``process_file``.
        workers: Thread count; ``None`` means ``min(32, cpu_count)``.
        batch_size: Minimum number of buffered chunks that triggers a write.

    Raises:
        Exception: An error raised by the final write of the remaining
            buffer is propagated.
    """
    files = sorted(glob.glob(os.path.join(parent_path, "*", "*.txt")))
    print(f"Found {len(files)} files to process.")

    if workers is None:
        workers = min(32, multiprocessing.cpu_count())

    results_buffer = {"ids": [], "docs": [], "embs": [], "metas": []}

    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Map each future to its file so failures can name the offending file.
        future_to_file = {
            executor.submit(process_file, file, encode_batch_size): file
            for file in files
        }

        with tqdm(total=len(future_to_file), desc="Processing Files", unit="file") as pbar:
            for future in as_completed(future_to_file):
                # Drop our reference so a finished future (and the embeddings
                # it holds) is not kept alive until the pool exits.
                file = future_to_file.pop(future)
                try:
                    ids, texts, embeddings, metadatas = future.result()
                except Exception as e:
                    print(f"Error processing {file}: {e}")
                else:
                    results_buffer["ids"].extend(ids)
                    results_buffer["docs"].extend(texts)
                    results_buffer["embs"].extend(embeddings)
                    results_buffer["metas"].extend(metadatas)

                    pending = len(results_buffer["ids"])
                    if pending >= batch_size:
                        try:
                            _flush_buffer(results_buffer)
                        except Exception as e:
                            print(f"Error storing batch of {pending} chunks: {e}")
                finally:
                    pbar.update(1)

    # Store whatever is left; an error here is raised to the caller.
    _flush_buffer(results_buffer)

    print("✅ Completed embedding + storage.")


if __name__ == "__main__":
    # Dataset root. Set the MOVIE_TEXTS_DIR environment variable to run on
    # another machine without editing this cell; the original location is
    # kept as the default.
    parent = os.environ.get(
        "MOVIE_TEXTS_DIR",
        r"C:\Users\anura\Desktop\NPN Hackathon\movie_characters\data\movie_character_texts\movie_character_texts",
    )
    process_all(parent, encode_batch_size=256, workers=None, batch_size=1000)


In [3]:
!pip install langchain_experimental

