# Thai Sentence Embedding Model Training

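This notebook fine-tunes a multilingual SentenceTransformer model for Thai sentence embeddings using the Thai split of the XNLI dataset, and evaluates it on the English STS-B validation set from GLUE. Because the evaluation data is English, the reported scores are a general embedding-quality proxy rather than a direct measure of Thai performance.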

## Features
- Loads and preprocesses Thai XNLI data for similarity learning
- Uses `bert-base-multilingual-cased` as the base model
- Evaluates with STS-B (the Semantic Textual Similarity Benchmark from GLUE, English validation split)
- Saves the trained model for downstream Thai NLP tasks

## Setup and Installation

In [1]:
# Install required packages
!pip install datasets sentence-transformers python-dotenv

# Import libraries
import getpass
import logging
import os
import shutil
import sys

import torch
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting sentence-transformers
Collecting sentence-transformers
  Downloading sentence_transformers-4.1.0-py3-none-any.whl.metadata (13 kB)
Collecting python-dotenv
  Downloading python_dotenv-1.1.1-py3-none-any.whl.metadata (24 kB)
  Downloading sentence_transformers-4.1.0-py3-none-any.whl.metadata (13 kB)
Collecting python-dotenv
  Downloading python_dotenv-1.1.1-py3-none-any.whl.metadata (24 kB)
Collecting pyarrow>=15.0.0 (from datasets)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-20.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading pyarrow-20.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting tqdm>=4.66.3 (from datasets)
  D

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 

## Authentication Setup

Enter your Hugging Face token to download datasets and models. You can get your token from https://huggingface.co/settings/tokens

In [1]:
# Set up Hugging Face authentication
# Strip surrounding whitespace: tokens pasted from a browser often carry a
# trailing space or newline that would make the token invalid.
hf_token = getpass.getpass("Enter your Hugging Face token: ").strip()

# Kept for backward compatibility with anything that reads this custom name.
os.environ["HF_HUB_TOKEN"] = hf_token
# HF_TOKEN is the variable the Hugging Face libraries look up for
# authentication. Only export it when a token was actually entered so that an
# empty value does not mask credentials stored by `huggingface-cli login`.
if hf_token:
    os.environ["HF_TOKEN"] = hf_token

# Set up cache directories
# NOTE(review): the Hugging Face libraries appear to resolve their cache paths
# when first imported; if the import cell has already run, these assignments
# may not take effect until the kernel is restarted — confirm.
os.environ["HF_HOME"] = "/content/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/content/hf_cache/datasets"
os.environ["HF_TRANSFORMERS_CACHE"] = "/content/hf_cache/transformers"

print("Authentication and cache setup complete!")

NameError: name 'getpass' is not defined

## Logging Setup

In [None]:
def setup_logging():
    """Configure root logging to stdout at INFO level and return a logger.

    ``force=True`` removes any handlers already attached to the root logger
    before installing the stdout handler. Without it ``basicConfig`` does
    nothing when the root logger is already configured (which hosted notebook
    kernels commonly do, and which always happens when this cell is re-run),
    so INFO messages would be silently dropped.

    Returns:
        logging.Logger: the logger named after the current module.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[logging.StreamHandler(sys.stdout)],
        force=True,
    )
    return logging.getLogger(__name__)

logger = setup_logging()
logger.info("Logging setup complete!")

## Data Loading and Preprocessing

In [None]:
def load_thai_data():
    """Build the Thai training set from XNLI with binary similarity targets.

    Each premise/hypothesis pair of the Thai XNLI train split becomes a
    ``sentence1``/``sentence2`` pair. The NLI class id is converted to a
    float target: 1.0 for entailment, 0.0 for neutral and contradiction.

    Returns:
        Dataset: columns ``sentence1``, ``sentence2`` and ``label``.

    Raises:
        Exception: any failure is logged and then re-raised unchanged.
    """
    logger.info("Loading Thai XNLI dataset...")
    try:
        thai_train = load_dataset("xnli", "th")["train"]

        # Class id -> similarity target (entailment counts as similar,
        # neutral and contradiction as dissimilar).
        label_to_score = {0: 1.0, 1: 0.0, 2: 0.0}
        targets = list(map(label_to_score.__getitem__, thai_train["label"]))

        columns = {
            "sentence1": thai_train["premise"],
            "sentence2": thai_train["hypothesis"],
            "label": targets,
        }
        pairs = Dataset.from_dict(columns)
        logger.info(f"Loaded {len(pairs)} Thai training examples")
        return pairs
    except Exception as err:
        logger.error(f"Error loading Thai data: {err}")
        raise

# Load Thai training data
train_dataset = load_thai_data()

In [None]:
def load_validation_data():
    """Create a cosine-similarity evaluator from the GLUE STS-B validation split.

    The gold labels are divided by 5.0 before being handed to the evaluator
    so that they sit on a 0-1 scale.

    Returns:
        EmbeddingSimilarityEvaluator: evaluator registered as ``stsb_eval``.

    Raises:
        Exception: any failure is logged and then re-raised unchanged.
    """
    logger.info("Loading STSB validation data...")
    try:
        stsb_val = load_dataset("glue", "stsb", split="validation")

        evaluator_kwargs = {
            "sentences1": stsb_val["sentence1"],
            "sentences2": stsb_val["sentence2"],
            # Rescale the gold scores onto 0-1.
            "scores": list(map(lambda raw: raw / 5.0, stsb_val["label"])),
            "main_similarity": "cosine",
            "name": "stsb_eval",
        }
        stsb_evaluator = EmbeddingSimilarityEvaluator(**evaluator_kwargs)
        logger.info(f"Created evaluator with {len(stsb_val)} validation examples")
        return stsb_evaluator
    except Exception as err:
        logger.error(f"Error loading validation data: {err}")
        raise

# Load validation data
evaluator = load_validation_data()

## Model Setup

In [None]:
def create_model_and_loss():
    """Instantiate the base encoder and the loss that will train it.

    Returns:
        tuple: ``(SentenceTransformer, CosineSimilarityLoss)`` where the loss
        is bound to the returned model instance.
    """
    base_checkpoint = "bert-base-multilingual-cased"
    logger.info(f"Loading model: {base_checkpoint}")
    encoder = SentenceTransformer(base_checkpoint)
    # The loss holds a reference to the same encoder object it will optimise.
    return encoder, losses.CosineSimilarityLoss(model=encoder)

# Create model and loss
embedding_model, train_loss = create_model_and_loss()

## Training Configuration

In [None]:
def setup_training_args():
    """Configure training arguments for fine-tuning the model.

    Two settings differ from a naive configuration and are worth noting:

    * ``eval_strategy="steps"`` is set explicitly. ``eval_steps`` alone does
      not enable evaluation; without a strategy the evaluator would never run
      during training and no STS-B scores would be logged.
    * ``fp16`` is enabled only when a CUDA device is available, because
      requesting fp16 mixed precision on a CPU-only runtime raises an error
      when the arguments are constructed.

    Returns:
        SentenceTransformerTrainingArguments: arguments writing checkpoints
        and logs under ``/content/thai_embedding_model``.
    """
    return SentenceTransformerTrainingArguments(
        output_dir="/content/thai_embedding_model",
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        warmup_steps=500,
        fp16=torch.cuda.is_available(),  # mixed precision needs a GPU
        eval_strategy="steps",  # required for eval_steps to take effect
        eval_steps=500,
        logging_steps=100,
        save_steps=1000,
        load_best_model_at_end=False,  # Disable automatic best model loading
        metric_for_best_model="eval_stsb_eval_spearman_cosine",
        greater_is_better=True,
        dataloader_drop_last=False,
        learning_rate=2e-5,
    )

# Setup training arguments
args = setup_training_args()
logger.info("Training arguments configured!")

## Training Process

In [None]:
# Assemble the trainer from the model, loss, data and evaluator prepared in
# the earlier cells.
trainer_kwargs = dict(
    model=embedding_model,
    args=args,
    train_dataset=train_dataset,
    loss=train_loss,
    evaluator=evaluator,
)
trainer = SentenceTransformerTrainer(**trainer_kwargs)

logger.info("Trainer created successfully!")

In [None]:
# Run fine-tuning; the two banner messages bracket the run in the log output.
start_banner, end_banner = "==== Training started ====", "==== Training finished ===="
logger.info(start_banner)
trainer.train()
logger.info(end_banner)

## Model Evaluation

In [None]:
# Score the fine-tuned model with the evaluator and log what it returns.
logger.info("Running final evaluation...")
result = evaluator(embedding_model)
if not isinstance(result, dict):
    # Non-dict return value: log it as-is.
    logger.info(f"Evaluation result: {result}")
else:
    for metric_name, metric_value in result.items():
        # Floats are shown with four decimals; anything else uses its default text form.
        shown = f"{metric_value:.4f}" if isinstance(metric_value, float) else f"{metric_value}"
        logger.info(f"{metric_name}: {shown}")

## Save Model

In [None]:
# Save model
final_model_path = "/content/thai_sentence_transformer_final"
embedding_model.save(final_model_path)
logger.info(f"Model saved to: {final_model_path}")

# Create a zip file for easy download.
# The archive is built in Python rather than by shelling out to `zip` so that:
#   * the archive location is derived from final_model_path instead of being
#     repeated as a literal that can drift out of sync,
#   * re-running the cell rewrites the archive instead of adding to an old one
#     (which would keep files that no longer exist in the model directory),
#   * a failure raises an exception instead of printing a success message.
# The archive contains a single top-level folder named after the model directory.
archive_path = shutil.make_archive(
    base_name=final_model_path,
    format="zip",
    root_dir=os.path.dirname(final_model_path),
    base_dir=os.path.basename(final_model_path),
)
print(f"Model packaged for download at: {archive_path}")

## Usage Example

In [None]:
# Test the trained model

# Load the trained model back from disk to confirm the saved copy is usable
model = SentenceTransformer(final_model_path)

# Test with Thai sentences
thai_sentences = [
    "สวัสดีครับ",
    "ขอบคุณมากครับ",
    "ประโยคภาษาไทย",
    "การเรียนรู้ของเครื่อง"
]

# Generate embeddings
embeddings = model.encode(thai_sentences)

print(f"Generated embeddings for {len(thai_sentences)} Thai sentences")
print(f"Embedding shape: {embeddings.shape}")
print("\nSample sentences:")
for position, sentence in enumerate(thai_sentences, start=1):
    print(f"{position}. {sentence}")

# Calculate similarity between first two sentences.
# The model's own similarity function is used so the cell needs no packages
# beyond those installed at the top of the notebook. Slices keep the inputs
# two-dimensional, giving a 1x1 result matrix.
similarity = model.similarity(embeddings[:1], embeddings[1:2])[0][0].item()
print(f"\nSimilarity between '{thai_sentences[0]}' and '{thai_sentences[1]}': {similarity:.4f}")

## Download Results

You can download the trained model and training logs from the Files panel on the left. Look for:
- `/content/thai_sentence_transformer_final.zip` - The complete trained model
- `/content/thai_embedding_model/` - Training checkpoints and logs

## Next Steps

1. Download the model to your local machine
2. Use it for Thai sentence similarity tasks
3. Fine-tune further on your specific domain data
4. Deploy for production use

## License

This project is for research and educational use. Please check the original model licenses before commercial use.