# 🦜 BERTweet Embeddings for Antique Dataset

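This notebook generates document embeddings for the cleaned Antique dataset using the BERTweet model (`vinai/bertweet-base`) from Hugging Face. It is written for Google Colab: the upload and download cells use `google.colab.files`.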

**Prerequisites:**
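1. Upload `antique_cleaned_for_embeddings.json`, the export produced by your local text cleaning service. It must be a JSON object whose `documents` list holds entries with `id`, `text` (cleaned), and `original_text` fields.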
2. The text has been pre-cleaned using the enhanced text cleaning service

**Output:**
- One BERTweet embedding per document, taken from the hidden state of the first token (`<s>`)
- Embeddings saved in multiple formats for easy download and use


## 📋 Setup and Installation

In [None]:
# Install required packages
!pip install -q transformers torch numpy pandas scikit-learn tqdm
!pip install -q datasets accelerate

print("✅ Packages installed successfully!")

In [None]:
# Import required libraries
import json
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.auto import tqdm
import gc
import os
from datetime import datetime
import pickle
import warnings
# Silence all Python warnings for cleaner notebook output
# (note: this also hides deprecation warnings from the libraries).
warnings.filterwarnings('ignore')

# Pick the compute device: GPU when PyTorch can see one, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"🚀 Using device: {device}")

# Seed NumPy and PyTorch so random sampling later in the notebook is repeatable.
np.random.seed(42)
torch.manual_seed(42)

print("✅ Libraries imported successfully!")

## 📁 Upload and Load Cleaned Data

In [None]:
# Upload the cleaned dataset file
from google.colab import files

print("📁 Please upload your 'antique_cleaned_for_embeddings.json' file:")
uploaded = files.upload()

# If nothing was uploaded, `uploaded` has no keys and indexing the first key
# would fail with an unhelpful IndexError - report the real problem instead.
if not uploaded:
    raise RuntimeError(
        "No file was uploaded. Re-run this cell and select "
        "'antique_cleaned_for_embeddings.json'."
    )

# Prefer the expected file when several were selected; otherwise take the first.
data_file = (
    'antique_cleaned_for_embeddings.json'
    if 'antique_cleaned_for_embeddings.json' in uploaded
    else next(iter(uploaded))
)
print(f"✅ Uploaded file: {data_file}")

In [None]:
# Load the cleaned dataset
print("📖 Loading cleaned dataset...")

with open(data_file, 'r', encoding='utf-8') as f:
    data = json.load(f)

print("📊 Dataset Information:")
print(f"   Dataset: {data.get('dataset', 'Unknown')}")
print(f"   Total documents: {data.get('total_documents', 0):,}")
print(f"   Export timestamp: {data.get('export_timestamp', 'Unknown')}")
print(f"   Cleaning method: {data.get('cleaning_method', 'Unknown')}")

# Extract documents. Every later cell indexes into this list, so stop here with
# a clear message rather than an IndexError on documents[0].
documents = data['documents']
if not documents:
    raise ValueError(f"No documents found in {data_file!r}")

print("\n🔍 Sample document:")
print(f"   ID: {documents[0]['id']}")
print(f"   Cleaned text: {documents[0]['text'][:200]}...")
# `original_text` is only displayed here, so tolerate exports that omit it.
print(f"   Original text: {(documents[0].get('original_text') or 'N/A')[:200]}...")

print(f"\n✅ Loaded {len(documents):,} documents successfully!")

## 🤖 Load BERTweet Model

In [None]:
# Load BERTweet model and tokenizer
print("🤖 Loading BERTweet model...")

model_name = "vinai/bertweet-base"

# The slow (Python) tokenizer is requested explicitly via use_fast=False.
print("   Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# Load weights, move them to the selected device and switch to inference mode
# (eval() disables dropout, so repeated forward passes give the same output).
print("   Loading model...")
model = AutoModel.from_pretrained(model_name).to(device)
model.eval()

print("✅ BERTweet model loaded successfully!")
print(f"   Model: {model_name}")
print(f"   Parameters: {sum(param.numel() for param in model.parameters()):,}")
print(f"   Embedding dimension: {model.config.hidden_size}")
# NOTE(review): tokenizer.model_max_length may not match the model's real
# positional limit (model.config.max_position_embeddings) - confirm before relying on it.
print(f"   Max sequence length: {tokenizer.model_max_length}")

## 🔄 Text Preprocessing for BERTweet

In [None]:
def preprocess_text_for_bertweet(text, max_length=256):
    """
    Normalize whitespace and apply a coarse word-level length cap before tokenization.

    This is only a cheap pre-trim: words are not tokens (BERTweet's BPE can
    split one word into several sub-tokens), so exact token-level truncation
    is still left to the tokenizer.

    Args:
        text: Input text string. Empty or non-string input yields "".
        max_length: Maximum sequence length. The text keeps at most
            ``max_length - 10`` words, but never fewer than one.

    Returns:
        The text with whitespace runs collapsed to single spaces, capped in length.
    """
    if not text or not isinstance(text, str):
        return ""

    # str.split() with no argument strips the ends and collapses every
    # whitespace run, so one split does all the cleanup.
    words = text.split()

    # Reserve ~10 slots for special tokens (<s>, </s>, ...). Clamp to >= 1 so a
    # small max_length cannot make the slice bound zero or negative
    # (words[:-k] would silently drop the *end* of the text instead of capping it).
    word_limit = max(max_length - 10, 1)
    return ' '.join(words[:word_limit])

# Try the cleaner on the first document to eyeball its effect.
sample_text = documents[0]['text']
processed_text = preprocess_text_for_bertweet(sample_text)

print(
    "🔄 Text preprocessing example:\n"
    f"   Original: {sample_text[:150]}...\n"
    f"   Processed: {processed_text[:150]}...\n"
    f"   Length: {len(processed_text.split())} words"
)

print("✅ Text preprocessing function ready!")

## 🚀 Generate Embeddings

In [None]:
def generate_embeddings_batch(texts, batch_size=16, max_length=256, pooling="cls"):
    """
    Generate BERTweet embeddings for a list of texts.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device`` loaded in
    earlier cells.

    Args:
        texts: List of text strings.
        batch_size: Number of texts encoded per forward pass (must be >= 1).
        max_length: Requested maximum sequence length in tokens. It is capped at
            what the model can actually encode (max_position_embeddings - 2 for
            RoBERTa-style models such as BERTweet), because longer inputs index
            past the position-embedding table and crash the forward pass.
        pooling: "cls" (default) uses the first-token (<s>) hidden state;
            "mean" averages the hidden states of non-padding tokens.

    Returns:
        numpy array with one row per input text and model.config.hidden_size
        columns (an empty (0, hidden_size) array for empty input).

    Raises:
        ValueError: if batch_size < 1 or pooling is not "cls" or "mean".
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if pooling not in ("cls", "mean"):
        raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")

    # RoBERTa-style position ids start at padding_idx + 1 (= 2), so only
    # max_position_embeddings - 2 tokens (including <s> and </s>) fit.
    model_token_limit = model.config.max_position_embeddings - 2
    effective_max_length = min(max_length, model_token_limit, tokenizer.model_max_length)

    if len(texts) == 0:
        # np.vstack([]) would raise; return a correctly shaped empty matrix instead.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    all_embeddings = []

    for start in tqdm(range(0, len(texts), batch_size), desc="Generating embeddings"):
        batch_texts = texts[start:start + batch_size]

        # Whitespace cleanup + coarse word-level pre-trim; exact token-level
        # truncation is done by the tokenizer below.
        processed_texts = [preprocess_text_for_bertweet(text, max_length) for text in batch_texts]

        inputs = tokenizer(
            processed_texts,
            padding=True,
            truncation=True,
            max_length=effective_max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            hidden = outputs.last_hidden_state

            if pooling == "cls":
                # Hidden state of the first token (<s>) of each sequence.
                pooled = hidden[:, 0, :]
            else:
                # Average over real tokens only; padding positions are masked out.
                mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

            all_embeddings.append(pooled.cpu().numpy())

        # Drop GPU references before releasing cached blocks.
        del inputs, outputs, hidden, pooled
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return np.vstack(all_embeddings)

print("✅ Embedding generation function ready!")

In [None]:
# Generate embeddings for all documents
print("🚀 Starting embedding generation...")
print(f"   Processing {len(documents):,} documents")
print(f"   Using device: {device}")

# Parallel lists: row i of the embedding matrix will belong to doc_ids[i].
texts = [document['text'] for document in documents]
doc_ids = [document['id'] for document in documents]

# A GPU can afford larger batches; keep CPU batches small.
batch_size = 8
if torch.cuda.is_available():
    batch_size = 32
print(f"   Batch size: {batch_size}")

# Time the full embedding pass (wall clock).
start_time = datetime.now()
embeddings = generate_embeddings_batch(texts, batch_size=batch_size)
end_time = datetime.now()
processing_time = (end_time - start_time).total_seconds()

print("\n✅ Embedding generation completed!")
print(f"   Embeddings shape: {embeddings.shape}")
print(f"   Processing time: {processing_time:.2f} seconds")
print(f"   Average time per document: {processing_time/len(documents):.4f} seconds")
print(f"   Embedding dimension: {embeddings.shape[1]}")

# Global statistics over every value in the embedding matrix.
print("\n📊 Embedding Statistics:")
print(f"   Mean: {embeddings.mean():.6f}")
print(f"   Std: {embeddings.std():.6f}")
print(f"   Min: {embeddings.min():.6f}")
print(f"   Max: {embeddings.max():.6f}")

## 🔍 Quality Check and Validation

In [None]:
# Quality check: Test similarity between related documents
print("🔍 Quality Check: Testing embedding similarities...")

# Sample a few documents for similarity testing (drawn from the global NumPy
# RNG seeded in the setup cell).
n_samples = min(10, len(embeddings))
sample_indices = np.random.choice(len(embeddings), n_samples, replace=False)
sample_embeddings = embeddings[sample_indices]
sample_texts = [texts[i] for i in sample_indices]
sample_ids = [doc_ids[i] for i in sample_indices]

# Pairwise cosine similarities (n_samples x n_samples).
similarities = cosine_similarity(sample_embeddings)

# The diagonal compares each document with itself (1.0 for any non-zero
# embedding). Including it inflates the mean and distorts the std, so the
# statistics use only the off-diagonal pairs.
pair_similarities = similarities[~np.eye(n_samples, dtype=bool)]

print(f"\n📊 Similarity Matrix for {n_samples} sample documents:")
if pair_similarities.size == 0:
    print("   Fewer than 2 documents available - no pairs to compare.")
else:
    print(f"   Average similarity (excluding self-pairs): {pair_similarities.mean():.4f}")
    print(f"   Std similarity (excluding self-pairs): {pair_similarities.std():.4f}")

    # Most similar distinct pair: mask the diagonal with -inf so argmax can never pick it.
    similarities_no_diag = similarities.copy()
    np.fill_diagonal(similarities_no_diag, -np.inf)

    max_sim_idx = np.unravel_index(np.argmax(similarities_no_diag), similarities_no_diag.shape)
    max_similarity = similarities_no_diag[max_sim_idx]

    print("\n🔗 Most similar document pair:")
    print(f"   Similarity: {max_similarity:.4f}")
    print(f"   Doc 1 ({sample_ids[max_sim_idx[0]]}): {sample_texts[max_sim_idx[0]][:100]}...")
    print(f"   Doc 2 ({sample_ids[max_sim_idx[1]]}): {sample_texts[max_sim_idx[1]][:100]}...")

print("\n✅ Quality check completed!")

## 💾 Save Embeddings

In [None]:
# Prepare data for saving
print("💾 Preparing data for saving...")

# Create metadata
metadata = {
    'model_name': model_name,
    'embedding_dimension': embeddings.shape[1],
    'total_documents': len(documents),
    'generation_timestamp': datetime.now().isoformat(),
    'processing_time_seconds': processing_time,
    'device_used': str(device),
    'batch_size': batch_size,
    # generate_embeddings_batch ran with its default max_length (256), but a
    # RoBERTa-style model such as BERTweet can encode at most
    # max_position_embeddings - 2 tokens, so record the limit actually in effect.
    'max_sequence_length': min(256, model.config.max_position_embeddings - 2),
    'dataset_info': {
        'dataset': data.get('dataset', 'antique'),
        'cleaning_method': data.get('cleaning_method', 'enhanced'),
        'export_timestamp': data.get('export_timestamp')
    }
}

print("✅ Metadata prepared")
print(f"   Model: {metadata['model_name']}")
print(f"   Embedding dimension: {metadata['embedding_dimension']}")
print(f"   Total documents: {metadata['total_documents']:,}")

In [None]:
# Save as NumPy format
print("💾 Saving embeddings in NumPy format...")

# Row i of the embedding matrix corresponds to entry i of the ID array.
np.save('antique_bertweet_embeddings.npy', embeddings)
np.save('antique_bertweet_doc_ids.npy', np.array(doc_ids))

# Run configuration and provenance as human-readable JSON.
with open('antique_bertweet_metadata.json', 'w') as f:
    json.dump(metadata, f, indent=2)

print("\n".join([
    "✅ NumPy format saved:",
    "   - antique_bertweet_embeddings.npy",
    "   - antique_bertweet_doc_ids.npy",
    "   - antique_bertweet_metadata.json",
]))

In [None]:
# Save as Pickle format (includes everything in one file)
print("💾 Saving complete dataset in Pickle format...")

# Everything in one object; only the first 1000 original documents are kept,
# as a reference sample.
complete_data = dict(
    embeddings=embeddings,
    doc_ids=doc_ids,
    texts=texts,
    metadata=metadata,
    original_documents=documents[:1000],
)

with open('antique_bertweet_complete.pkl', 'wb') as f:
    pickle.dump(complete_data, f)

print("\n".join([
    "✅ Pickle format saved:",
    "   - antique_bertweet_complete.pkl",
    "   - Contains embeddings, IDs, texts, and metadata",
]))

In [None]:
# Save as CSV format for easy inspection
print("💾 Saving document mapping in CSV format...")

# One row per document, in embedding-matrix order; the L2 norm of each
# embedding row is appended as a quick sanity signal.
df = pd.DataFrame({
    'doc_id': doc_ids,
    'text': texts,
    'text_length': list(map(len, texts)),
    'word_count': [len(body.split()) for body in texts],
}).assign(embedding_norm=np.linalg.norm(embeddings, axis=1))

df.to_csv('antique_bertweet_documents.csv', index=False)

print("✅ CSV format saved:")
print("   - antique_bertweet_documents.csv")
print(f"   - {len(df):,} document records")

# Show summary statistics
print("\n📊 Document Statistics:")
print(f"   Average text length: {df['text_length'].mean():.1f} characters")
print(f"   Average word count: {df['word_count'].mean():.1f} words")
print(f"   Average embedding norm: {df['embedding_norm'].mean():.4f}")

## 📥 Download Files

In [None]:
# Download all generated files
print("📥 Downloading generated files...")

# Every artifact written by the save cells, in download order.
files_to_download = [
    'antique_bertweet_embeddings.npy',
    'antique_bertweet_doc_ids.npy',
    'antique_bertweet_metadata.json',
    'antique_bertweet_complete.pkl',
    'antique_bertweet_documents.csv'
]

# Report each artifact's on-disk size (MiB) before downloading.
print("📋 File Information:")
for filename in files_to_download:
    if not os.path.exists(filename):
        print(f"   {filename}: File not found")
        continue
    size_mb = os.path.getsize(filename) / (1024 * 1024)
    print(f"   {filename}: {size_mb:.2f} MB")

print("\n🚀 Ready to download! Run the next cell to download files.")

In [None]:
# Download files one by one
from google.colab import files

# Track outcomes so the final message reflects what actually happened instead
# of always claiming success.
downloaded, missing = [], []
for filename in files_to_download:
    if os.path.exists(filename):
        print(f"📥 Downloading {filename}...")
        files.download(filename)
        downloaded.append(filename)
    else:
        print(f"❌ {filename} not found")
        missing.append(filename)

if missing:
    print(
        f"\n⚠️ Downloaded {len(downloaded)} of {len(files_to_download)} files; "
        f"missing: {', '.join(missing)}"
    )
else:
    print("\n✅ All files downloaded successfully!")

## 📊 Final Summary

In [None]:
# Final summary
print("=" * 80)
print("🎉 BERTWEET EMBEDDING GENERATION COMPLETED")
print("=" * 80)
print()
print("📊 Summary:")
print(f"   Model: {model_name}")
print(f"   Documents processed: {len(documents):,}")
print(f"   Embedding dimension: {embeddings.shape[1]}")
print(f"   Total processing time: {processing_time:.2f} seconds")
print(f"   Average time per document: {processing_time/len(documents):.4f} seconds")
print(f"   Device used: {device}")
print()

# List only the artifacts that actually exist on disk.
print("📁 Generated Files:")
for filename in files_to_download:
    if not os.path.exists(filename):
        continue
    size_mb = os.path.getsize(filename) / (1024 * 1024)
    print(f"   ✅ {filename} ({size_mb:.2f} MB)")
print()

print("\n".join([
    "📋 Next Steps:",
    "   1. Download all generated files to your local machine",
    "   2. Load the embeddings in your search engine backend",
    "   3. Use embeddings for document retrieval and similarity search",
    "   4. Evaluate retrieval performance using your evaluation metrics",
]))
print()

# Copy-paste snippets showing how to read the saved artifacts back.
print("\n".join([
    "💡 Usage Examples:",
    "   # Load embeddings in Python:",
    "   embeddings = np.load('antique_bertweet_embeddings.npy')",
    "   doc_ids = np.load('antique_bertweet_doc_ids.npy')",
    "   ",
    "   # Or load complete dataset:",
    "   with open('antique_bertweet_complete.pkl', 'rb') as f:",
    "       data = pickle.load(f)",
]))
print()
print("=" * 80)
print("✅ EMBEDDING GENERATION SUCCESSFUL!")
print("=" * 80)

## 🔧 Memory Cleanup

In [None]:
# Clean up memory
print("🧹 Cleaning up memory...")

# Release every notebook global that holds (or copies) a large object.
# Deleting `embeddings` alone frees nothing while `complete_data` still
# references the same array. `uploaded` keeps the object returned by
# files.upload(), and `sample_embeddings` is a copy made by fancy indexing.
# globals().pop(name, None) also keeps this cell safe to re-run, whereas a
# plain `del name` raises NameError the second time.
for _name in (
    'embeddings', 'model', 'tokenizer', 'documents', 'texts', 'data',
    'complete_data', 'df', 'doc_ids', 'uploaded',
    'sample_embeddings', 'sample_texts', 'similarities',
    'similarities_no_diag', 'pair_similarities',
):
    globals().pop(_name, None)
del _name

# Force garbage collection
gc.collect()

# Clear GPU memory if available
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("🧹 GPU memory cleared")

print("✅ Memory cleanup completed!")