# Export Vietnamese Legal Embeddings

This notebook generates embeddings using Colab's free T4 GPU and exports them to a file.
You can then import the file into your local Qdrant instance without needing a GPU.

## Steps:
1. Run all cells to generate embeddings (~15-20 min)
2. Download the output file `embeddings.jsonl`
3. Place the file in your project's `colab/` folder and run `python colab/import_embeddings_local.py` on your machine

In [None]:
# Cell 1: Install dependencies
# Installs the third-party packages imported in Cell 3 (`datasets`,
# `sentence-transformers`, `tqdm`); `-q` keeps pip's output quiet.
# NOTE(review): `torch` is imported in Cell 3 but is not installed here --
# presumably it is preinstalled in the Colab runtime; confirm.
!pip install -q datasets sentence-transformers tqdm

In [None]:
# Cell 2: Configuration
# Settings read by the later cells: the model to load, the encoding batch
# size, the output path, and the default chunking parameters.
MODEL_NAME = "minhquan6203/paraphrase-vietnamese-law"  # passed to SentenceTransformer in Cell 5
BATCH_SIZE = 64  # Larger batch for GPU
OUTPUT_FILE = "embeddings.jsonl"  # written in Cell 8, one JSON object per line

# Chunking config (used as the default arguments of chunk_text in Cell 6)
CHUNK_SIZE = 512  # Characters per chunk
CHUNK_OVERLAP = 100  # Overlap between chunks

# Echo the key settings so they are visible in the notebook output.
print(f"Model: {MODEL_NAME}")
print(f"Output: {OUTPUT_FILE}")

In [None]:
# Cell 3: Check GPU & Import libraries
import torch
import re
import json
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tqdm

# Report CUDA availability, then choose the device string that Cell 5 passes
# to the model: "cuda" when a GPU is visible, otherwise "cpu".
print(f"GPU Available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("WARNING: No GPU detected! This will be slow.")
else:
    # Describe GPU 0: its name and total memory in GB (bytes / 1e9).
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"Using device: {device}")

In [None]:
# Cell 4: Load dataset
# Loads the legal corpus with `datasets.load_dataset`; the later cells read
# its 'train' split. The message below warns that the download is about 2.6GB.
print("Loading dataset: vietnamese-legal-corpus-20k-raw...")
print("(This may take a few minutes for 2.6GB download)")
dataset = load_dataset("52100303-TranPhuocSang/vietnamese-legal-corpus-20k-raw")
# Report how many documents the 'train' split contains.
print(f"Dataset loaded: {len(dataset['train'])} documents")

In [None]:
# Cell 5: Load embedding model
# Instantiates the SentenceTransformer named by MODEL_NAME (Cell 2) on the
# device selected in Cell 3, then prints the embedding dimension the model
# reports.
print(f"Loading model: {MODEL_NAME}...")
model = SentenceTransformer(MODEL_NAME, device=device)
print(f"Model loaded! Embedding dimension: {model.get_sentence_embedding_dimension()}")

In [None]:
# Cell 6: Text processing functions
def clean_text(text):
    """Strip HTML-style tags and collapse whitespace.

    Anything of the form ``<...>`` is deleted (replaced by nothing), every run
    of whitespace becomes a single space, and leading/trailing whitespace is
    trimmed. Falsy input (``None`` or ``""``) yields ``""``.
    """
    if text:
        without_tags = re.sub(r'<[^>]+>', '', text)
        single_spaced = re.sub(r'\s+', ' ', without_tags)
        return single_spaced.strip()
    return ""

def chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split text into overlapping chunks.

    Windows of at most ``chunk_size`` characters are taken from left to
    right. A window that does not reach the end of the text is shortened to
    end just after its last '.' or newline, provided that boundary lies past
    the middle of the window. Each following window starts ``overlap``
    characters before the previous one ended.

    Args:
        text: The string to split. ``None`` or ``""`` yields an empty list.
        chunk_size: Maximum number of characters per chunk.
        overlap: Number of characters shared by consecutive chunks.

    Returns:
        A list of chunk strings with surrounding whitespace stripped. Text
        shorter than 100 characters is returned unchanged as a single chunk.

    Raises:
        ValueError: If ``chunk_size`` is not positive and the text is long
            enough to need splitting (previously this looped forever).
    """
    if not text:
        return []
    if len(text) < 100:
        # Too short to be worth splitting; returned as-is (not stripped).
        return [text]
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive number of characters")

    text_len = len(text)
    chunks = []
    start = 0
    while start < text_len:
        end = start + chunk_size
        chunk = text[start:end]

        # Try to cut at a sentence boundary, but only for windows that stop
        # short of the end of the text, and only when the boundary is in the
        # second half of the window so chunks do not become very small.
        if end < text_len:
            cut_point = max(chunk.rfind('.'), chunk.rfind('\n'))
            if cut_point > chunk_size // 2:
                chunk = chunk[:cut_point + 1]
                end = start + cut_point + 1

        stripped = chunk.strip()
        if stripped:
            chunks.append(stripped)

        # This window already reached the end of the text. Stepping back by
        # `overlap` here would emit one more chunk made up solely of text
        # that the chunk just appended already contains.
        if end >= text_len:
            break

        # Step back by `overlap`, but always move forward: if the overlap is
        # at least as large as what was just consumed, continue from `end`
        # instead of re-reading the same window forever.
        next_start = end - overlap
        start = next_start if next_start > start else end

    return chunks

# Confirmation message so the notebook output shows that this cell ran.
print("Text processing functions loaded!")

In [None]:
# Cell 7: Process and chunk all documents
# Turns every document of the 'train' split into one record per chunk. Each
# record holds the chunk text, its position within the document, the number
# of chunks in that document, and the document's metadata.
print("Processing and chunking all documents...")

# Metadata fields copied straight from the document, defaulting to ''.
_copied_fields = (
    'official_number',
    'document_type',
    'document_field',
    'issued_date',
    'effective_date',
    'place_issue',
    'signer',
    'url',
)

all_chunks = []
for doc_no, record in enumerate(tqdm(dataset['train'], desc="Chunking")):
    body = clean_text(record.get('full_text', ''))
    heading = clean_text(record.get('title', ''))

    # Key order is kept deliberately: title first, source_id last.
    doc_meta = {
        'title': heading,
        **{field: record.get(field, '') for field in _copied_fields},
        # Fall back to the document's position when it has no source_id.
        'source_id': record.get('source_id', doc_no),
    }

    # Prefer the body text; use the title alone when the body is empty, and
    # skip documents that have neither.
    if body:
        pieces = chunk_text(body)
    elif heading:
        pieces = [heading]
    else:
        continue

    piece_count = len(pieces)
    all_chunks.extend(
        {
            'text': piece,
            'chunk_idx': position,
            'total_chunks': piece_count,
            **doc_meta
        }
        for position, piece in enumerate(pieces)
    )

print(f"Total chunks created: {len(all_chunks)}")

In [None]:
# Cell 8: Generate embeddings and save to file
# Encodes all_chunks one batch at a time and writes each batch to OUTPUT_FILE
# as soon as it is encoded: one JSON object per line with the keys 'id',
# 'vector' and 'payload'.
print(f"Generating embeddings and saving to {OUTPUT_FILE}...")
print(f"This will process {len(all_chunks)} chunks in batches of {BATCH_SIZE}")

with open(OUTPUT_FILE, 'w', encoding='utf-8') as out:
    batch_offsets = range(0, len(all_chunks), BATCH_SIZE)
    for batch_no, offset in enumerate(tqdm(batch_offsets, desc="Embedding")):
        records = all_chunks[offset:offset + BATCH_SIZE]

        # Embed the text of every record in this batch.
        vectors = model.encode(
            [record['text'] for record in records],
            batch_size=BATCH_SIZE,
            normalize_embeddings=True,
            show_progress_bar=False
        )

        # The id of a point is the chunk's position in all_chunks.
        for point_id, (record, vector) in enumerate(zip(records, vectors), start=offset):
            line = json.dumps(
                {
                    'id': point_id,
                    'vector': vector.tolist(),
                    'payload': record
                },
                ensure_ascii=False
            )
            out.write(line + '\n')

        # Every 50th batch (starting with the first), release cached GPU memory.
        if batch_no % 50 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()

print(f"\nDone! Embeddings saved to {OUTPUT_FILE}")

In [None]:
# Cell 9: Check output file
import os

# Summarize the export: size on disk and number of lines (Cell 8 writes one
# JSON object, i.e. one vector, per line).
file_size = os.path.getsize(OUTPUT_FILE) / (1024 * 1024 * 1024)  # GB

# Count lines inside a `with` block so the file handle is closed as soon as
# counting finishes, instead of being left open until garbage collection.
with open(OUTPUT_FILE, 'r', encoding='utf-8') as exported:
    line_count = sum(1 for _ in exported)

print(f"{'='*50}")
print("EXPORT COMPLETE!")
print(f"{'='*50}")
print(f"Output file: {OUTPUT_FILE}")
print(f"File size: {file_size:.2f} GB")
print(f"Total vectors: {line_count}")
print(f"{'='*50}")
print("\nNext steps:")
print(f"1. Download {OUTPUT_FILE} from Colab")
print("2. Place it in your project's colab/ folder")
print("3. Run: python colab/import_embeddings_local.py")

In [None]:
# Cell 10: Download file (run this to trigger download in Colab)
# NOTE(review): `google.colab` is presumably only importable inside a Colab
# runtime, which would explain why it is imported here rather than in
# Cell 3 -- confirm.
from google.colab import files
# Hand the exported file to Colab's download helper.
files.download(OUTPUT_FILE)