# Fine-Tuning Language Models for Text Classification: A Deep Practical Guide

Welcome to this comprehensive guide on fine-tuning pre-trained transformer language models for text classification tasks!

## 🎯 What You'll Learn

This notebook provides a complete, hands-on approach to:

1. **Understanding the Why**: Learn why fine-tuning pre-trained language models is an effective approach to text classification
2. **Complete Implementation**: Step-by-step code for the entire fine-tuning pipeline
3. **Thai Language Focus**: Special considerations for Thai text processing
4. **Best Practices**: Production-ready techniques and optimization strategies
5. **Practical Examples**: Real-world applications with sentiment analysis and topic classification

## 🔧 Key Features

- **Model Selection**: Choose the right pre-trained model for your task
- **Data Preparation**: Handle Thai text preprocessing and tokenization
- **Training Pipeline**: Complete fine-tuning with Hugging Face Transformers
- **Evaluation Metrics**: Comprehensive model assessment
- **Hyperparameter Optimization**: Systematic tuning with Optuna
- **Deployment Ready**: Save and load models for production use

## 📋 Prerequisites

- Basic Python programming knowledge
- Understanding of machine learning concepts
- Familiarity with PyTorch (helpful but not required)
- GPU access recommended for faster training

Let's start building your text classification expertise! 🚀

## 1. Import Required Libraries and Set Up Environment

First, let's import all the necessary libraries and set up our environment for reproducible results.
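
Why fixing seeds matters can be seen with nothing more than Python's standard `random` module. In the minimal sketch below (the helper `draw` is hypothetical and not part of the pipeline), re-seeding the generator with the same value replays exactly the same sequence of draws. The `set_seed` function in the next cell applies the same idea to `random`, NumPy, and PyTorch together.

```python
import random

def draw(seed: int, n: int = 5) -> list:
    """Hypothetical helper: seed the global generator, then draw n floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]

run_a = draw(42)
run_b = draw(42)  # same seed: identical sequence
run_c = draw(7)   # different seed: a different sequence

print(run_a == run_b)  # True
print(run_a == run_c)  # False
```

Note that GPU kernels can still introduce nondeterminism, which is why the cell below also sets the cuDNN flags.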

In [None]:
# Core ML and deep learning libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
import random
import os
from typing import List, Dict, Tuple, Optional, Union

# Hugging Face Transformers
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    TrainingArguments, 
    Trainer,
    DataCollatorWithPadding,
    EarlyStoppingCallback
)

# Dataset handling
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, 
    f1_score, 
    precision_recall_fscore_support, 
    confusion_matrix,
    classification_report
)

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Progress tracking
from tqdm.auto import tqdm

# Thai language processing
try:
    from pythainlp import word_tokenize
    from attacut import tokenize as attacut_tokenize
    print("✅ Thai NLP libraries loaded successfully")
except ImportError:
    print("⚠️ Thai NLP libraries not found. Install with: pip install pythainlp attacut")

# Hyperparameter optimization (optional)
try:
    import optuna
    print("✅ Optuna loaded for hyperparameter optimization")
except ImportError:
    print("⚠️ Optuna not found. Install with: pip install optuna")

# Set random seeds for reproducibility
def set_seed(seed: int = 42):
    """Set random seeds for reproducible results."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)  # Only affects subprocesses; this interpreter's hash seed is fixed at startup
    
    # For deterministic behavior (may reduce performance)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Check device availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️ Using device: {device}")

if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name()}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("   Note: GPU not available. Training will be slower on CPU.")

# Configure visualization settings
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("\n🚀 Environment setup complete!")

## 2. Select and Load a Pre-trained Model

Choosing the right pre-trained model is crucial for success. For Thai text classification, we'll use WangchanBERTa, a RoBERTa-based model pre-trained specifically on Thai text.

### Model Selection Criteria:

1. **Language Support**: Thai-specific models often outperform multilingual ones on Thai text
2. **Task Compatibility**: Models designed for sequence classification
3. **Resource Requirements**: Balance between performance and computational cost
4. **Domain Relevance**: Models trained on similar domains when available

### Popular Thai Language Models:

- **WangchanBERTa**: Thai RoBERTa-based model (recommended for Thai text)
- **multilingual-BERT**: Supports Thai but less specialized
- **XLM-RoBERTa**: Cross-lingual model with Thai support

In [None]:
# Model configuration
MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"
MAX_LENGTH = 512
NUM_LABELS = 2  # Will be updated based on actual dataset

print(f"📋 Model Configuration:")
print(f"   Model: {MODEL_NAME}")
print(f"   Max Length: {MAX_LENGTH}")
print(f"   Initial Num Labels: {NUM_LABELS}")

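# Load tokenizer
print("\n🔤 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Never request more tokens than the checkpoint supports
MAX_LENGTH = min(MAX_LENGTH, tokenizer.model_max_length)
print(f"   Effective max length: {MAX_LENGTH}")

# Add padding token if not present (reuses the EOS token)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("   Using EOS token as padding token")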

print(f"✅ Tokenizer loaded successfully")
print(f"   Vocabulary size: {tokenizer.vocab_size:,}")
print(f"   Special tokens: {len(tokenizer.special_tokens_map)}")

# Test tokenization with Thai text
sample_thai_text = "ผมชอบภาพยนตร์เรื่องนี้มาก สนุกและน่าตื่นเต้น"
sample_english_text = "I love this movie, it's exciting and entertaining"

print(f"\n🔍 Tokenization Example:")
print(f"Thai text: {sample_thai_text}")

# Tokenize and show results
thai_tokens = tokenizer.tokenize(sample_thai_text)
thai_encoded = tokenizer.encode(sample_thai_text, add_special_tokens=True)

print(f"Tokens: {thai_tokens}")
print(f"Token IDs: {thai_encoded}")
print(f"Decoded: {tokenizer.decode(thai_encoded)}")

# Function to load model (we'll do this after preparing the dataset to know exact num_labels)
def load_model(num_labels: int):
    """Load pre-trained model for sequence classification."""
    print(f"\n🤖 Loading model with {num_labels} labels...")
    
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        output_attentions=False,
        output_hidden_states=False,
    )
    
    # Move model to device
    model.to(device)
    
    print(f"✅ Model loaded successfully")
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"   Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    
    return model

print("\n📝 Note: Model will be loaded after dataset preparation to determine the correct number of labels.")
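
`AutoModelForSequenceClassification` places a newly initialized classification head on top of the pre-trained encoder, and that head outputs `num_labels` logits. This is why the model is loaded only once the label set is known. The dependency-free sketch below illustrates just that last step as a single linear layer over a pooled sentence vector followed by softmax. It is a simplification: the RoBERTa-style head used by WangchanBERTa also applies dropout, a dense layer, and tanh. The 4-dimensional vector, the weights, and the function names are made up for illustration.

```python
import math
import random

def init_head(hidden_size: int, num_labels: int, seed: int = 0):
    """Randomly initialize a linear head: one weight row per label, zero bias."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, 0.02) for _ in range(hidden_size)] for _ in range(num_labels)]
    bias = [0.0] * num_labels
    return weights, bias

def head_forward(pooled, weights, bias):
    """logits[k] = dot(weights[k], pooled) + bias[k]"""
    return [sum(w * h for w, h in zip(row, pooled)) + b for row, b in zip(weights, bias)]

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

pooled = [0.1, -0.3, 0.5, 0.2]  # toy sentence vector (base-size encoders use 768 dims)
weights, bias = init_head(hidden_size=4, num_labels=2)
logits = head_forward(pooled, weights, bias)
probs = softmax(logits)
print(len(logits), round(sum(probs), 6))  # 2 1.0
```

Because the head starts from random weights, its predictions are meaningless until it is fine-tuned. Transformers warns about newly initialized weights when the model loads for this reason.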

## 3. Prepare and Explore the Dataset

Data quality is the foundation of successful fine-tuning. Let's create sample datasets and explore their characteristics.

### Key Dataset Considerations:

1. **Quality over Quantity**: Clean, relevant data is more valuable than large, noisy datasets
2. **Class Balance**: Check for imbalanced classes and plan mitigation strategies
3. **Representative Samples**: Ensure training data represents real-world usage
4. **Proper Splits**: Maintain consistent distributions across train/validation/test sets
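
The class-balance check in point 2 needs only a few lines. The minimal sketch below (the helper `class_balance` is hypothetical) reports per-class proportions and the ratio between the largest and smallest class, using the topic labels from the sample dataset created in the next code cell. A ratio of 1.0 means perfectly balanced. There is no universal cutoff for "imbalanced", so treat any threshold as a judgment call.

```python
from collections import Counter

def class_balance(labels):
    """Return per-class proportions and the majority/minority size ratio."""
    counts = Counter(labels)
    total = sum(counts.values())
    proportions = {label: count / total for label, count in counts.items()}
    imbalance_ratio = max(counts.values()) / min(counts.values())
    return proportions, imbalance_ratio

# Labels of the sample topic dataset
topic_labels = ['health', 'sports', 'technology', 'economy', 'travel', 'technology',
                'health', 'sports', 'technology', 'economy', 'travel', 'technology']
proportions, ratio = class_balance(topic_labels)
print(round(proportions['technology'], 3), ratio)  # 0.333 2.0
```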

In [None]:
def create_sample_datasets():
    """Create sample Thai text classification datasets for demonstration."""
    
    # Thai Sentiment Analysis Dataset
    sentiment_data = {
        'text': [
            'ผมชอบภาพยนตร์เรื่องนี้มาก สนุกและน่าตื่นเต้น',
            'ร้านอาหารนี้แย่มาก อาหารไม่อร่อย บริการไม่ดี',
            'สินค้าคุณภาพดีมาก ราคาไม่แพง คุ้มค่าเงิน แนะนำเลย',
            'บริการแย่มาก พนักงานไม่เป็นมิตร ไม่อยากมาอีก',
            'โรงแรมนี้สะอาด สะดวกสบาย ราคาดี วิวสวย',
            'หนังสือเล่มนี้น่าเบื่อมาก เนื้อหาไม่น่าสนใจ ไม่แนะนำ',
            'อาหารอร่อยมาก บรรยากาศดี ราคาถูก จะมาอีก',
            'สินค้าไม่ตรงตามที่โฆษณา คุณภาพต่ำ เสียเงิน',
            'การบริการดีมาก เจ้าหน้าที่ใส่ใจลูกค้า พอใจมาก',
            'ของเก่าแล้ว ไม่คุ้มค่า ผิดหวังมาก',
            'คุณภาพเยี่ยม ส่งเร็ว บริการดี ประทับใจมาก',
            'ไม่ดีเลย แย่มาก ไม่คุ้มค่าเงิน จะไม่ซื้ออีก',
            'สวยงาม ใช้งานง่าย คุ้มค่ามาก ชอบมาก',
            'ไม่พอใจ ไม่ตรงตามความต้องการ จะคืนสินค้า',
            'ยอดเยี่ยม ทุกอย่างดีมาก จะแนะนำเพื่อน',
            'แย่มาก ผิดหวัง ไม่ควรซื้อ เสียเงินเปล่า',
            'น่าใช้มาก ดีไซน์สวย คุณภาพดี ราคาเหมาะสม',
            'ไม่ได้เรื่อง ไม่น่าเชื่อถือ บริการแย่',
            'ประทับใจมาก บริการดีเยี่ยม จะใช้บริการอีก',
            'ไม่ดี คุณภาพต่ำ ไม่คุ้มค่า จะไม่ซื้ออีก'
        ],
        'label': [
            'positive', 'negative', 'positive', 'negative', 'positive',
            'negative', 'positive', 'negative', 'positive', 'negative',
            'positive', 'negative', 'positive', 'negative', 'positive',
            'negative', 'positive', 'negative', 'positive', 'negative'
        ]
    }
    
    # Thai Topic Classification Dataset  
    topic_data = {
        'text': [
            'การแพร่ระบาดของโควิด-19 ส่งผลกระทบต่อเศรษฐกิจโลก ทำให้หลายประเทศเข้าสู่ภาวะถดถอย',
            'นักฟุตบอลทีมชาติไทยเตรียมความพร้อมสำหรับการแข่งขันฟุตบอลโลกรอบคัดเลือก',
            'เทคโนโลยี AI และ Machine Learning กำลังเปลี่ยนแปลงวิธีการทำงานในหลายอุตสาหกรรม',
            'ราคาน้ำมันปรับตัวสูงขึ้นอย่างต่อเนื่อง ส่งผลต่อต้นทุนการผลิตและการขนส่ง',
            'การท่องเที่ยวในประเทศไทยเริ่มฟื้นตัวหลังจากสถานการณ์โควิด-19 คลี่คลาย',
            'การพัฒนาแอปพลิเคชันมือถือสำหรับธุรกิจอีคอมเมิร์ซเป็นที่นิยมมากขึ้น',
            'นโยบายสาธารณสุขใหม่เพื่อรองรับการแพร่ระบาดของโรคติดเชื้อในอนาคต',
            'การแข่งขันกีฬาโอลิมปิกเป็นเวทีสำคัญสำหรับนักกีฬาจากทั่วโลก',
            'นวัตกรรมในด้าน Blockchain และ Cryptocurrency กำลังได้รับความสนใจ',
            'ภาวะเงินเฟ้อส่งผลต่อการใช้จ่าย บริโภค และการลงทุนของประชาชน',
            'เทศกาลท่องเที่ยวและวัฒนธรรมท้องถิ่นช่วยกระตุ้นเศรษฐกิจชุมชน',
            'การใช้ Big Data ในการวิเคราะห์พฤติกรรมผู้บริโภคช่วยเพิ่มประสิทธิภาพทางธุรกิจ'
        ],
        'label': [
            'health', 'sports', 'technology', 'economy', 'travel', 'technology',
            'health', 'sports', 'technology', 'economy', 'travel', 'technology'
        ]
    }
    
    return {
        'sentiment': pd.DataFrame(sentiment_data),
        'topic': pd.DataFrame(topic_data)
    }

# Create sample datasets
datasets = create_sample_datasets()

print("📊 Sample Datasets Created:")
print("\n1. Sentiment Analysis Dataset:")
print(f"   Total samples: {len(datasets['sentiment'])}")
print(f"   Classes: {datasets['sentiment']['label'].unique()}")
print(f"   Class distribution:")
sentiment_counts = datasets['sentiment']['label'].value_counts()
for label, count in sentiment_counts.items():
    print(f"     {label}: {count} ({count/len(datasets['sentiment']):.1%})")

print("\n2. Topic Classification Dataset:")
print(f"   Total samples: {len(datasets['topic'])}")
print(f"   Classes: {datasets['topic']['label'].unique()}")
print(f"   Class distribution:")
topic_counts = datasets['topic']['label'].value_counts()
for label, count in topic_counts.items():
    print(f"     {label}: {count} ({count/len(datasets['topic']):.1%})")

# Let's use sentiment analysis for our main example
df = datasets['sentiment'].copy()
print(f"\n🎯 Using sentiment analysis dataset for fine-tuning example")

# Display sample data
print("\n📋 Sample Data:")
for i in range(3):
    print(f"{i+1}. Text: {df.iloc[i]['text']}")
    print(f"   Label: {df.iloc[i]['label']}")
    print()

In [None]:
# Visualize class distribution
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Sentiment analysis distribution
sentiment_counts.plot(kind='bar', ax=ax1, color=['skyblue', 'lightcoral'])
ax1.set_title('Sentiment Analysis - Class Distribution')
ax1.set_xlabel('Sentiment')
ax1.set_ylabel('Count')
ax1.tick_params(axis='x', rotation=0)

# Topic classification distribution
topic_counts.plot(kind='bar', ax=ax2, color=['lightgreen', 'orange', 'purple', 'yellow', 'steelblue'])  # One color per topic (5 classes)
ax2.set_title('Topic Classification - Class Distribution')
ax2.set_xlabel('Topic')
ax2.set_ylabel('Count')
ax2.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

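# Text length analysis
# Thai is written without spaces between words, so str.split() counts
# space-separated chunks (roughly phrases), not words. Use a Thai word
# tokenizer such as pythainlp's word_tokenize for true word counts.
df['text_length'] = df['text'].str.len()             # Unicode code points, incl. vowel/tone marks
df['word_count'] = df['text'].str.split().str.len()  # Space-separated chunks

print("📏 Text Statistics:")
print(f"Average text length: {df['text_length'].mean():.1f} characters")
print(f"Average space-separated chunks: {df['word_count'].mean():.1f}")
print(f"Max text length: {df['text_length'].max()} characters")
print(f"Min text length: {df['text_length'].min()} characters")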

# Plot text length distribution
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.hist(df['text_length'], bins=15, alpha=0.7, color='skyblue', edgecolor='black')
plt.xlabel('Text Length (characters)')
plt.ylabel('Frequency')
plt.title('Distribution of Text Lengths')

plt.subplot(1, 2, 2)
plt.hist(df['word_count'], bins=15, alpha=0.7, color='lightgreen', edgecolor='black')
plt.xlabel('Space-separated Chunks per Text')
plt.ylabel('Frequency')
plt.title('Distribution of Space-separated Chunk Counts')

plt.tight_layout()
plt.show()
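
Two Thai-specific quirks affect the statistics above. This standard-library sketch makes both visible; the helper `spacing_length` is hypothetical. First, `str.split()` only splits on whitespace, so a sentence with many words but one space yields two "words". Second, `len()` counts Unicode code points. Thai above/below vowels and tone marks are separate combining code points (category `Mn`), so character counts run higher than the number of visible letter positions.

```python
import unicodedata

def spacing_length(text: str) -> int:
    """Count code points that occupy their own position, skipping non-spacing marks (Mn)."""
    return sum(1 for ch in text if unicodedata.category(ch) != 'Mn')

sample = "ผมชอบภาพยนตร์เรื่องนี้มาก สนุกและน่าตื่นเต้น"
print(len(sample.split()))                # 2 space-separated chunks, though many words
print(len("น่า"), spacing_length("น่า"))  # 3 2: the tone mark is a combining character
```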

## 4. Preprocess and Tokenize Data

Proper preprocessing and tokenization are crucial for model performance. We'll handle Thai-specific text processing, create label mappings, and prepare data for training.
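
The two core steps in the next cell, whitespace normalization and building the `label2id`/`id2label` maps, can be sketched in plain Python. The function names are hypothetical. One design difference is deliberate: the sketch sorts the labels, so ids do not depend on row order. The notebook's `df['label'].unique()` uses order of first appearance instead, which is also deterministic for a fixed DataFrame.

```python
def preprocess(text: str) -> str:
    """Collapse runs of whitespace (spaces, tabs, newlines) into single spaces."""
    return ' '.join(text.split())

def build_label_maps(labels):
    """Map each label to a stable integer id; sorting makes ids independent of row order."""
    unique = sorted(set(labels))
    label2id = {label: idx for idx, label in enumerate(unique)}
    id2label = {idx: label for label, idx in label2id.items()}
    return label2id, id2label

print(preprocess("  อาหารอร่อย \t\n บริการดี "))  # อาหารอร่อย บริการดี
label2id, id2label = build_label_maps(['positive', 'negative', 'positive'])
print(label2id)  # {'negative': 0, 'positive': 1}
```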

In [None]:
def preprocess_thai_text(text: str) -> str:
    """
    Preprocess Thai text for better model performance.
    
    Args:
        text: Input Thai text
        
    Returns:
        Preprocessed text
    """
    # Remove excessive whitespace
    text = ' '.join(text.split())
    
    # Optional: Apply Thai word segmentation (uncomment if needed)
    # text = ' '.join(word_tokenize(text, engine='attacut'))
    
    return text

# Create label mappings
unique_labels = df['label'].unique()
label2id = {label: idx for idx, label in enumerate(unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}

print(f"🏷️ Label Mappings:")
print(f"   Labels: {unique_labels}")
print(f"   label2id: {label2id}")
print(f"   id2label: {id2label}")

# Update NUM_LABELS
NUM_LABELS = len(unique_labels)
print(f"   Number of labels: {NUM_LABELS}")

# Preprocess text and create numeric labels
df['text_processed'] = df['text'].apply(preprocess_thai_text)
df['labels'] = df['label'].map(label2id)

print(f"\n✅ Data preprocessing completed")
print(f"   Original text column: 'text'")
print(f"   Processed text column: 'text_processed'")
print(f"   Numeric labels column: 'labels'")

# Split data into train, validation, and test sets
print(f"\n🔀 Splitting data...")

# First split: train+val vs test (80-20)
train_val_df, test_df = train_test_split(
    df, 
    test_size=0.2, 
    random_state=42, 
    stratify=df['labels']
)

# Second split: train vs val (75-25 of train_val, resulting in 60-20-20 overall)
train_df, val_df = train_test_split(
    train_val_df, 
    test_size=0.25, 
    random_state=42, 
    stratify=train_val_df['labels']
)

print(f"   Train set: {len(train_df)} samples ({len(train_df)/len(df):.1%})")
print(f"   Validation set: {len(val_df)} samples ({len(val_df)/len(df):.1%})")
print(f"   Test set: {len(test_df)} samples ({len(test_df)/len(df):.1%})")

# Verify class distribution in splits
print(f"\n📊 Class distribution in splits:")
for split_name, split_df in [('Train', train_df), ('Validation', val_df), ('Test', test_df)]:
    dist = split_df['labels'].value_counts().sort_index()
    dist_pct = (dist / len(split_df) * 100).round(1)
    print(f"   {split_name}: {dict(zip([id2label[i] for i in dist.index], dist_pct))}")

# Tokenization function
def tokenize_texts(texts, max_length=MAX_LENGTH):
    """Tokenize a list of texts."""
    return tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors='pt'
    )

# Test tokenization
print(f"\n🔤 Testing tokenization:")
sample_text = df['text_processed'].iloc[0]
sample_encoding = tokenizer(
    sample_text,
    truncation=True,
    padding=True,
    max_length=MAX_LENGTH,
    return_tensors='pt'
)

print(f"   Sample text: {sample_text}")
print(f"   Input IDs shape: {sample_encoding['input_ids'].shape}")
print(f"   Attention mask shape: {sample_encoding['attention_mask'].shape}")
print(f"   Token count: {sample_encoding['input_ids'].shape[1]}")

# Show tokenized version
tokens = tokenizer.convert_ids_to_tokens(sample_encoding['input_ids'][0])
print(f"   First 10 tokens: {tokens[:10]}")
print(f"   Last 10 tokens: {tokens[-10:]}")
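
`train_test_split(..., stratify=...)` keeps class proportions the same in every split. The simplified pure-Python sketch below reproduces the 60/20/20 arithmetic on a 20-sample, 10/10 dataset shaped like the sentiment data. The helper `stratified_split` is hypothetical and does not follow scikit-learn's exact per-class allocation algorithm. It simply takes `round(test_frac * class_size)` items from each shuffled class.

```python
import random
from collections import defaultdict

def stratified_split(labels, test_frac, seed=42):
    """Split indices so each class contributes round(test_frac * class_size) items to the test side."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for label in sorted(by_class):
        idxs = by_class[label][:]
        rng.shuffle(idxs)
        n_test = round(len(idxs) * test_frac)
        test_idx.extend(idxs[:n_test])
        train_idx.extend(idxs[n_test:])
    return sorted(train_idx), sorted(test_idx)

labels = ['positive', 'negative'] * 10  # 20 samples, 10 per class

# First split: train+val vs test (80/20)
train_val, test = stratified_split(labels, test_frac=0.2)

# Second split: train vs val (75/25 of train+val, i.e. 60/20/20 overall)
tv_labels = [labels[i] for i in train_val]
train_local, val_local = stratified_split(tv_labels, test_frac=0.25)
train = [train_val[i] for i in train_local]
val = [train_val[i] for i in val_local]

print(len(train), len(val), len(test))  # 12 4 4
```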

## 5. Create Data Loaders

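Now we'll convert our preprocessed data into Hugging Face `Dataset` objects, tokenize them, and create a data collator that pads each batch dynamically. The `Trainer` builds the PyTorch `DataLoader`s from these during training.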

In [None]:
# Create Dataset objects from DataFrames
def create_dataset(df, text_col='text_processed', label_col='labels'):
    """Create a Hugging Face Dataset from DataFrame."""
    return Dataset.from_dict({
        'text': df[text_col].tolist(),
        'labels': df[label_col].tolist()
    })

# Create datasets for each split
train_dataset = create_dataset(train_df)
val_dataset = create_dataset(val_df)
test_dataset = create_dataset(test_df)

print(f"📦 Created Dataset objects:")
print(f"   Train dataset: {len(train_dataset)} samples")
print(f"   Validation dataset: {len(val_dataset)} samples")
print(f"   Test dataset: {len(test_dataset)} samples")

# Tokenization function for datasets
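def tokenize_function(examples):
    """Tokenize a batch of examples; padding is left to the data collator."""
    # No padding or return_tensors here: padding at map time would pad every
    # map batch to its longest text and defeat DataCollatorWithPadding.
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=MAX_LENGTH,
    )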

# Apply tokenization to all datasets
print(f"\n⚙️ Tokenizing datasets...")
train_dataset = train_dataset.map(tokenize_function, batched=True)
val_dataset = val_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

print(f"✅ Tokenization completed")

# Create DatasetDict for convenient handling
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset
})

print(f"\n📊 Dataset summary:")
print(dataset_dict)

# Inspect a sample from the training dataset
print(f"\n🔍 Sample from training dataset:")
sample = train_dataset[0]
print(f"   Text: {sample['text'][:100]}...")
print(f"   Label: {sample['labels']} ({id2label[sample['labels']]})")
print(f"   Input IDs length: {len(sample['input_ids'])}")
print(f"   Attention mask length: {len(sample['attention_mask'])}")

# Create data collator for dynamic padding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

print(f"\n🎯 Data collator created for dynamic padding")
print(f"   This will handle variable-length sequences efficiently during training")
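
Conceptually, dynamic padding pads each batch only to its own longest sequence and marks the padding in an attention mask. The minimal sketch below (hypothetical `pad_batch`) assumes right-padding. The token ids and `pad_id=1` are made up for illustration; in practice, read `tokenizer.pad_token_id`. The real `DataCollatorWithPadding` delegates to the tokenizer's `pad` method and returns tensors.

```python
def pad_batch(batch_ids, pad_id):
    """Pad every sequence to the longest one in this batch and build attention masks."""
    max_len = max(len(ids) for ids in batch_ids)
    input_ids, attention_mask = [], []
    for ids in batch_ids:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 0 marks padding
    return {'input_ids': input_ids, 'attention_mask': attention_mask}

batch = pad_batch([[0, 512, 87, 2], [0, 9, 2]], pad_id=1)
print(batch['input_ids'])       # [[0, 512, 87, 2], [0, 9, 2, 1]]
print(batch['attention_mask'])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```

Padding per batch rather than to `MAX_LENGTH` keeps short batches short, which saves compute when most texts are much shorter than the limit.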

## 6. Configure Training Hyperparameters

Hyperparameter selection significantly impacts training success. Let's set up a comprehensive configuration with explanations for each parameter.
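
With the `Trainer`'s default linear scheduler, `warmup_ratio` sets how many optimizer steps ramp the learning rate up from 0 before it decays linearly back to 0. The sketch below computes that learning rate per step. The function name is hypothetical; in transformers the corresponding logic lives in `get_linear_schedule_with_warmup`, and this sketch mirrors it under the assumption of the default `lr_scheduler_type='linear'`.

```python
import math

def linear_warmup_lr(step: int, base_lr: float, total_steps: int, warmup_ratio: float) -> float:
    """Learning rate at a given optimizer step under linear warmup followed by linear decay."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))

for step in (0, 5, 10, 55, 100):
    print(step, linear_warmup_lr(step, base_lr=2e-5, total_steps=100, warmup_ratio=0.1))
```

With 100 steps and `warmup_ratio=0.1`, the rate peaks at `2e-5` at step 10 and is halfway back down at step 55.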

In [None]:
# Training hyperparameters
training_config = {
    # Core training parameters
    'learning_rate': 2e-5,          # Common starting point for BERT-style models
    'batch_size': 8,                # Per-device batch size; use 16 or 32 if GPU memory allows
    'num_epochs': 3,                # Usually 2-5 epochs for fine-tuning
    'warmup_ratio': 0.1,            # Linear warmup over the first 10% of optimizer steps
    'weight_decay': 0.01,           # Decoupled weight decay (AdamW)
    
    # Training stability
    'max_grad_norm': 1.0,           # Gradient clipping
    'gradient_accumulation_steps': 2, # Simulate a larger batch size
    
    # Evaluation and saving
    # The 12-sample demo train split gives only one optimizer step per epoch on a
    # single device (batch 8, accumulation 2), so evaluate and save every step here.
    # On real datasets, larger intervals such as 50/100 are more typical.
    'evaluation_strategy': 'steps',  # Evaluate during training
    'eval_steps': 1,                # Optimizer steps between evaluations
    'save_steps': 1,                # Steps between checkpoints (must be a multiple of eval_steps)
    'logging_steps': 1,             # Steps between logging
    
    # Model selection
    'load_best_model_at_end': True,
    'metric_for_best_model': 'f1',
    'greater_is_better': True,
    
    # Performance optimization
    'fp16': torch.cuda.is_available(),  # Mixed precision training
    'dataloader_num_workers': 0,       # Number of data loading workers
    'disable_tqdm': False,             # Show progress bars
    
    # Early stopping
    'early_stopping_patience': 3,      # Stop if no improvement for 3 evaluations
    
    # Output configuration
    'output_dir': './results',
    'logging_dir': './logs',
    'report_to': 'none',               # Disable wandb/tensorboard (None enables every installed integration in transformers 4.x)
}

print("⚙️ Training Configuration:")
for key, value in training_config.items():
    print(f"   {key}: {value}")

# Calculate effective batch size
effective_batch_size = training_config['batch_size'] * training_config['gradient_accumulation_steps']
print(f"\n📊 Effective batch size: {effective_batch_size}")

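# Estimate training steps (approximates how Trainer counts optimizer steps on one device)
import math

batches_per_epoch = math.ceil(len(train_dataset) / training_config['batch_size'])
steps_per_epoch = max(batches_per_epoch // training_config['gradient_accumulation_steps'], 1)
total_steps = steps_per_epoch * training_config['num_epochs']
warmup_steps = math.ceil(total_steps * training_config['warmup_ratio'])

print(f"   Steps per epoch: {steps_per_epoch}")
print(f"   Total training steps: {total_steps}")
print(f"   Warmup steps: {warmup_steps}")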

# Define compute metrics function
def compute_metrics(eval_pred):
    """Compute evaluation metrics from model logits and integer labels."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    
    # zero_division=0 avoids warnings when a class is never predicted (likely on tiny splits)
    accuracy = accuracy_score(labels, predictions)
    f1_macro = f1_score(labels, predictions, average='macro', zero_division=0)
    f1_weighted = f1_score(labels, predictions, average='weighted', zero_division=0)
    f1_micro = f1_score(labels, predictions, average='micro', zero_division=0)  # Equals accuracy for single-label tasks
    
    # Support-weighted precision/recall (micro-averaged values would just repeat accuracy)
    precision, recall, _, _ = precision_recall_fscore_support(
        labels, predictions, average='weighted', zero_division=0
    )
    
    return {
        'accuracy': accuracy,
        'f1': f1_weighted,  # Primary metric for model selection
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'precision': precision,
        'recall': recall
    }

print(f"\n📈 Evaluation metrics configured:")
print("   - Accuracy: Overall correctness")
print("   - F1-weighted: Per-class F1 averaged by class support")
print("   - F1-macro: Equal weight to all classes (most sensitive to minority classes)")
print("   - F1-micro: Equals accuracy for single-label classification")
print("   - Precision & Recall: Support-weighted averages")

## 7. Fine-Tune the Model

Now we'll load the model with the correct number of labels and start the fine-tuning process.

In [None]:
# Load the model with correct number of labels
print(f"🤖 Loading model for {NUM_LABELS} classes...")
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    id2label=id2label,
    label2id=label2id,
    output_attentions=False,
    output_hidden_states=False,
)

# Move model to device
model.to(device)

print(f"✅ Model loaded successfully")
print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"   Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Create training arguments
training_args = TrainingArguments(
    output_dir=training_config['output_dir'],
    learning_rate=training_config['learning_rate'],
    per_device_train_batch_size=training_config['batch_size'],
    per_device_eval_batch_size=training_config['batch_size'],
    num_train_epochs=training_config['num_epochs'],
    warmup_ratio=training_config['warmup_ratio'],
    weight_decay=training_config['weight_decay'],
    max_grad_norm=training_config['max_grad_norm'],
    gradient_accumulation_steps=training_config['gradient_accumulation_steps'],
    eval_strategy=training_config['evaluation_strategy'],  # Older transformers releases name this argument evaluation_strategy
    eval_steps=training_config['eval_steps'],
    save_steps=training_config['save_steps'],
    logging_steps=training_config['logging_steps'],
    logging_dir=training_config['logging_dir'],
    load_best_model_at_end=training_config['load_best_model_at_end'],
    metric_for_best_model=training_config['metric_for_best_model'],
    greater_is_better=training_config['greater_is_better'],
    fp16=training_config['fp16'],
    dataloader_num_workers=training_config['dataloader_num_workers'],
    disable_tqdm=training_config['disable_tqdm'],
    report_to=training_config['report_to'],
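    save_strategy='steps',  # Must match the evaluation strategy when load_best_model_at_end=True
    save_total_limit=2,     # Keep at most 2 checkpoints; the best one is always retained
    seed=42,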
)

# Early stopping callback
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=training_config['early_stopping_patience']
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

print(f"\n🎯 Trainer configured successfully")
print(f"   Training dataset: {len(train_dataset)} samples")
print(f"   Validation dataset: {len(val_dataset)} samples")
print(f"   Early stopping patience: {training_config['early_stopping_patience']} evaluations")

# Start training
print(f"\n🚀 Starting fine-tuning...")
print(f"   This may take several minutes depending on your hardware")
print(f"   Monitor the progress below:")

# Train the model
train_result = trainer.train()

print(f"\n✅ Training completed!")
print(f"   Training time: {train_result.metrics['train_runtime']:.2f} seconds")
print(f"   Training samples per second: {train_result.metrics['train_samples_per_second']:.2f}")
print(f"   Final training loss: {train_result.metrics['train_loss']:.4f}")

# Save the model
trainer.save_model()
tokenizer.save_pretrained(training_config['output_dir'])

print(f"\n💾 Model saved to: {training_config['output_dir']}")
print("   Files saved:")
print("   - model.safetensors (model weights; pytorch_model.bin in older transformers versions)")
print("   - config.json (model configuration)")
print("   - tokenizer files")
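
`EarlyStoppingCallback` counts consecutive evaluations without improvement and stops training once that count reaches `early_stopping_patience`. The sketch below is a simplified, hypothetical re-implementation (`early_stop_index`) for a higher-is-better metric such as F1. An evaluation counts as an improvement only if it beats the best score so far by more than `threshold`, so an exact tie does not reset the counter. In the real callback, the best score is tracked in the trainer state rather than locally.

```python
def early_stop_index(scores, patience, threshold=0.0):
    """Return the evaluation index at which training would stop, or None if it never stops."""
    best = None
    bad_evals = 0
    for i, score in enumerate(scores):
        if best is None or score - best > threshold:
            best = score          # improvement: reset the patience counter
            bad_evals = 0
        else:
            bad_evals += 1        # no improvement (ties included)
            if bad_evals >= patience:
                return i
    return None

print(early_stop_index([0.50, 0.60, 0.60, 0.59, 0.58], patience=3))  # 4
print(early_stop_index([0.50, 0.60, 0.65], patience=3))              # None
```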

## 8. Evaluate Model Performance

Let's comprehensively evaluate our fine-tuned model using various metrics and visualizations.
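
The F1 variants reported by `compute_metrics` differ only in how per-class scores are combined. This from-scratch sketch uses a hypothetical `f1_report` and a made-up four-sample example. Macro-F1 averages classes equally. Weighted-F1 weights each class by its support (its count in `y_true`). Micro-F1 pools all true/false positives, and for single-label classification it always equals accuracy.

```python
from collections import Counter

def f1_report(y_true, y_pred):
    """Per-class F1 plus macro, weighted, and micro averages, and accuracy."""
    labels = sorted(set(y_true) | set(y_pred))
    support = Counter(y_true)
    per_class = {}
    tp_total = fp_total = fn_total = 0
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        tp_total += tp
        fp_total += fp
        fn_total += fn
        denom = 2 * tp + fp + fn  # F1 = 2TP / (2TP + FP + FN)
        per_class[c] = 2 * tp / denom if denom else 0.0
    macro = sum(per_class.values()) / len(labels)
    weighted = sum(per_class[c] * support[c] for c in labels) / len(y_true)
    micro = 2 * tp_total / (2 * tp_total + fp_total + fn_total)
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    return {'per_class': per_class, 'macro': macro, 'weighted': weighted,
            'micro': micro, 'accuracy': accuracy}

scores = f1_report([0, 0, 0, 1], [0, 0, 1, 1])
print({k: round(v, 4) for k, v in scores.items() if k != 'per_class'})
# {'macro': 0.7333, 'weighted': 0.7667, 'micro': 0.75, 'accuracy': 0.75}
```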

In [None]:
# Evaluate on test set
print("📊 Evaluating on test set...")
test_results = trainer.evaluate(test_dataset)

print(f"🎯 Test Set Results:")
for metric, value in test_results.items():
    if metric.startswith('eval_'):
        metric_name = metric.replace('eval_', '')
        print(f"   {metric_name}: {value:.4f}")

# Get predictions for detailed analysis
print(f"\n🔍 Getting predictions for detailed analysis...")
predictions = trainer.predict(test_dataset)
y_pred = np.argmax(predictions.predictions, axis=1)
y_true = predictions.label_ids

# Generate classification report
print(f"\n📋 Classification Report:")
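report = classification_report(
    y_true, y_pred,
    labels=list(id2label.keys()),         # Fixed label order, even if a class is missing from the split
    target_names=list(id2label.values()),
    digits=4,
    zero_division=0
)
print(report)

# Create confusion matrix (same fixed label order)
cm = confusion_matrix(y_true, y_pred, labels=list(id2label.keys()))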

# Plot confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(
    cm, 
    annot=True, 
    fmt='d', 
    cmap='Blues',
    xticklabels=list(id2label.values()),
    yticklabels=list(id2label.values())
)
plt.title('Confusion Matrix - Test Set')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.tight_layout()
plt.show()

# Calculate per-class metrics
print(f"\n📈 Per-Class Performance:")
for i, class_name in id2label.items():
    class_mask = (y_true == i)
    if class_mask.sum() > 0:
        class_acc = (y_pred[class_mask] == y_true[class_mask]).mean()
        print(f"   {class_name}: {class_acc:.4f} accuracy ({class_mask.sum()} samples)")

# Analyze misclassifications
print(f"\n🔍 Misclassification Analysis:")
misclassified = (y_pred != y_true)
misclassified_indices = np.where(misclassified)[0]

print(f"   Total misclassifications: {misclassified.sum()}")
print(f"   Misclassification rate: {misclassified.mean():.4f}")

if len(misclassified_indices) > 0:
    print(f"\n   Sample misclassifications:")
    for i, idx in enumerate(misclassified_indices[:3]):  # Show first 3
        true_label = id2label[y_true[idx]]
        pred_label = id2label[y_pred[idx]]
        text = test_dataset[idx]['text']
        print(f"   {i+1}. Text: {text[:100]}...")
        print(f"      True: {true_label} | Predicted: {pred_label}")
        print()

# Get prediction probabilities for confidence analysis
prediction_probs = torch.softmax(torch.tensor(predictions.predictions), dim=1)
max_probs = prediction_probs.max(dim=1).values
confidence_scores = max_probs.numpy()

print(f"📊 Prediction Confidence Analysis:")
print(f"   Mean confidence: {confidence_scores.mean():.4f}")
print(f"   Median confidence: {np.median(confidence_scores):.4f}")
print(f"   Min confidence: {confidence_scores.min():.4f}")
print(f"   Max confidence: {confidence_scores.max():.4f}")

# Plot confidence distribution
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.hist(confidence_scores, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
plt.xlabel('Confidence Score')
plt.ylabel('Frequency')
plt.title('Distribution of Prediction Confidence')

plt.subplot(1, 2, 2)
# Confidence vs accuracy
correct_predictions = (y_pred == y_true)
plt.scatter(confidence_scores[correct_predictions], [1]*correct_predictions.sum(), 
           alpha=0.6, color='green', label='Correct', s=10)
plt.scatter(confidence_scores[~correct_predictions], [0]*misclassified.sum(), 
           alpha=0.6, color='red', label='Incorrect', s=10)
plt.xlabel('Confidence Score')
plt.ylabel('Correctness')
plt.title('Confidence vs Correctness')
plt.legend()
plt.yticks([0, 1], ['Incorrect', 'Correct'])

plt.tight_layout()
plt.show()

# Create summary metrics dictionary
final_metrics = {
    'test_accuracy': test_results['eval_accuracy'],
    'test_f1_weighted': test_results['eval_f1'],
    'test_f1_macro': test_results['eval_f1_macro'],
    'test_precision': test_results['eval_precision'],
    'test_recall': test_results['eval_recall'],
    'misclassification_rate': misclassified.mean(),
    'mean_confidence': confidence_scores.mean(),
    'total_samples': len(y_true)
}

print(f"\n📊 Final Performance Summary:")
for metric, value in final_metrics.items():
    print(f"   {metric}: {value:.4f}")

# Save results
results_df = pd.DataFrame({
    'text': [test_dataset[i]['text'] for i in range(len(test_dataset))],
    'true_label': [id2label[label] for label in y_true],
    'predicted_label': [id2label[label] for label in y_pred],
    'confidence': confidence_scores,
    'correct': y_pred == y_true
})

results_df.to_csv('test_results.csv', index=False)
print(f"\n💾 Detailed results saved to 'test_results.csv'")
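
The confusion matrix and the per-class accuracy printed above are closely related. The minimal sketch below uses hypothetical helpers and toy labels. It builds the matrix with rows as true labels and columns as predicted labels (scikit-learn's convention), and shows that the notebook's per-class accuracy is that class's recall: the diagonal entry divided by its row sum.

```python
def confusion(y_true, y_pred, num_labels):
    """Rows are true labels, columns are predicted labels."""
    matrix = [[0] * num_labels for _ in range(num_labels)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix

def per_class_recall(matrix):
    """Diagonal over row sum; NaN for classes absent from y_true."""
    return [row[i] / sum(row) if sum(row) else float('nan') for i, row in enumerate(matrix)]

cm_demo = confusion([0, 0, 1, 1], [0, 1, 1, 1], num_labels=2)
print(cm_demo)                    # [[1, 1], [0, 2]]
print(per_class_recall(cm_demo))  # [0.5, 1.0]
```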

## 9. Optimize and Tune Hyperparameters

Now let's explore hyperparameter optimization to potentially improve our model's performance using Optuna.
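
Optuna's default sampler (TPE) uses the results of earlier trials to decide where to sample next. For intuition about the search space itself, the sketch below runs plain random search over the same ranges that the `objective` function in the next cell uses. It is a swapped-in, simpler technique, not what Optuna does. The learning rate is drawn log-uniformly (uniform in log10 space), which matches `suggest_float(..., log=True)`. `toy_score` is a made-up stand-in for "fine-tune, then return validation F1", which is far too expensive to run here.

```python
import math
import random

def sample_params(rng):
    """Draw one configuration from the same ranges as the Optuna objective."""
    return {
        'learning_rate': 10 ** rng.uniform(math.log10(1e-6), math.log10(1e-4)),  # log-uniform
        'batch_size': rng.choice([8, 16]),
        'num_epochs': rng.randint(2, 4),  # inclusive on both ends
        'warmup_ratio': rng.uniform(0.0, 0.2),
        'weight_decay': rng.uniform(0.0, 0.3),
    }

def toy_score(params):
    """Hypothetical stand-in for validation F1 that peaks near lr = 3e-5."""
    return -abs(math.log10(params['learning_rate']) - math.log10(3e-5))

def random_search(n_trials, seed=42):
    rng = random.Random(seed)
    trials = [sample_params(rng) for _ in range(n_trials)]
    best = max(trials, key=toy_score)
    return best, trials

best, trials = random_search(20)
print(f"best learning_rate ~ {best['learning_rate']:.2e}")
```

Log-uniform sampling spends as many trials between 1e-6 and 1e-5 as between 1e-5 and 1e-4. A uniform draw would place about 90% of trials in the upper decade.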

In [None]:
# Define hyperparameter search space
def create_hyperparameter_search_space():
    """Define the search space for hyperparameter optimization."""
    return {
        'learning_rate': (1e-6, 1e-4),     # Learning rate range
        'batch_size': [8, 16],             # Batch size options
        'num_epochs': (2, 4),              # Number of epochs
        'warmup_ratio': (0.0, 0.2),       # Warmup ratio
        'weight_decay': (0.0, 0.3),       # Weight decay
    }

# Hyperparameter optimization function
def hyperparameter_optimization(n_trials=10):
    """
    Perform hyperparameter optimization using Optuna.
    
    Args:
        n_trials: Number of optimization trials
    
    Returns:
        Best hyperparameters and study object
    """
    try:
        import optuna
    except ImportError:
        print("❌ Optuna not installed. Skipping hyperparameter optimization.")
        print("   Install with: pip install optuna")
        return None, None
    
    def objective(trial):
        """Objective function for Optuna optimization."""
        
        # Sample hyperparameters
        learning_rate = trial.suggest_float('learning_rate', 1e-6, 1e-4, log=True)
        batch_size = trial.suggest_categorical('batch_size', [8, 16])
        num_epochs = trial.suggest_int('num_epochs', 2, 4)
        warmup_ratio = trial.suggest_float('warmup_ratio', 0.0, 0.2)
        weight_decay = trial.suggest_float('weight_decay', 0.0, 0.3)
        
        # Create model for this trial
        trial_model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_NAME,
            num_labels=NUM_LABELS,
            id2label=id2label,
            label2id=label2id,
        )
        trial_model.to(device)
        
        # Configure training arguments
        trial_args = TrainingArguments(
            output_dir=f'./hp_search_trial_{trial.number}',
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=num_epochs,
            warmup_ratio=warmup_ratio,
            weight_decay=weight_decay,
            evaluation_strategy='epoch',
            save_strategy='epoch',
            logging_steps=50,
            load_best_model_at_end=True,
            metric_for_best_model='f1',
            greater_is_better=True,
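            report_to="none",  # the string "none" disables integrations; None falls back to the default ("all installed" in transformers 4.x)
            save_total_limit=1,  # limit checkpoints kept per trial to save disk space (the best one is still kept)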
            disable_tqdm=True,  # Disable progress bars for cleaner output
            fp16=torch.cuda.is_available(),
        )
        
        # Create trainer
        trial_trainer = Trainer(
            model=trial_model,
            args=trial_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )
        
        # Train and evaluate
        trial_trainer.train()
        eval_results = trial_trainer.evaluate()
        
        # Clean up
        del trial_model
        del trial_trainer
        torch.cuda.empty_cache()
        
        return eval_results['eval_f1']
    
    # Create study
    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=n_trials)
    
    return study.best_params, study

# Run hyperparameter optimization (optional - can be time-consuming)
run_hp_optimization = False  # Set to True to run optimization

if run_hp_optimization:
    print("🔧 Starting hyperparameter optimization...")
    print("   This may take significant time depending on n_trials")
    
    best_params, study = hyperparameter_optimization(n_trials=5)  # Reduced for demo
    
    if best_params:
        print(f"\n✅ Hyperparameter optimization completed!")
        print(f"   Best parameters: {best_params}")
        print(f"   Best F1 score: {study.best_value:.4f}")
        
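        # Visualize optimization results. optuna was only imported inside
        # hyperparameter_optimization(), so import it here too; plotting requires plotly.
        import optuna
        try:
            fig = optuna.visualization.plot_optimization_history(study)
            fig.show()
        except ImportError:
            print("   Install plotly to display the optimization history plot")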
        
        print(f"\n💡 Recommendations:")
        print(f"   Use these parameters for your final model training")
        print(f"   Consider running more trials for better results")
    
else:
    print("🔧 Hyperparameter optimization skipped")
    print("   Set run_hp_optimization = True to run optimization")
    print("   Recommended search space:")
    search_space = create_hyperparameter_search_space()
    for param, space in search_space.items():
        print(f"     {param}: {space}")

# Manual hyperparameter experimentation
print(f"\n🧪 Manual Hyperparameter Experimentation:")
print("   Here are some common hyperparameter combinations to try:")

hp_combinations = [
    {'learning_rate': 2e-5, 'batch_size': 16, 'num_epochs': 3, 'warmup_ratio': 0.1},
    {'learning_rate': 3e-5, 'batch_size': 8, 'num_epochs': 4, 'warmup_ratio': 0.06},
    {'learning_rate': 1e-5, 'batch_size': 32, 'num_epochs': 2, 'warmup_ratio': 0.15},
    {'learning_rate': 5e-5, 'batch_size': 16, 'num_epochs': 3, 'warmup_ratio': 0.0},
]

for i, combo in enumerate(hp_combinations, 1):
    print(f"   Combination {i}: {combo}")

print(f"\n💡 Hyperparameter Tuning Tips:")
print("   1. Start with learning rates between 1e-5 and 5e-5")
print("   2. Use smaller batch sizes if you have limited GPU memory")
print("   3. Try 2-4 epochs to avoid overfitting")
print("   4. Use warmup for training stability")
print("   5. Apply weight decay (0.01-0.1) for regularization")
print("   6. Monitor validation metrics to detect overfitting")

# Regularization techniques
print(f"\n🛡️ Additional Regularization Techniques:")
print("   - Dropout: Already included in the pre-trained model")
print("   - Weight decay: Configurable in training arguments")
print("   - Early stopping: Implemented in our training")
print("   - Data augmentation: Consider for small datasets")
print("   - Learning rate scheduling: Trainer decays the LR linearly by default; change it via lr_scheduler_type")
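
The warmup and scheduling tips interact as follows. The learning rate climbs from 0 to its peak over the first `ceil(warmup_ratio * total_steps)` optimizer steps, then decays linearly back to 0 by the final step. The plain-Python sketch below mirrors the multiplier used by Hugging Face's `get_linear_schedule_with_warmup`, which backs the default `lr_scheduler_type='linear'`. The step count assumes a single device, no gradient accumulation, and a kept final partial batch. The dataset size is made up for illustration.

```python
import math


def total_training_steps(num_examples: int, batch_size: int, num_epochs: int) -> int:
    """Optimizer steps for one device, no gradient accumulation, last partial batch kept."""
    return math.ceil(num_examples / batch_size) * num_epochs


def linear_warmup_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier: linear ramp 0 -> 1 during warmup, then linear decay 1 -> 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Hypothetical setup: 1,000 training examples, batch size 16, 3 epochs, warmup_ratio 0.1
peak_lr = 2e-5
total = total_training_steps(1_000, 16, 3)  # ceil(1000 / 16) * 3 = 63 * 3 = 189
warmup = math.ceil(0.1 * total)             # ceil(18.9) = 19

for step in (0, warmup // 2, warmup, total // 2, total):
    lr = peak_lr * linear_warmup_multiplier(step, warmup, total)
    print(f"step {step:>3}: lr = {lr:.2e}")
```

A larger `warmup_ratio` spends more of a short fine-tuning run at reduced learning rates, which is one reason the search space above caps it at 0.2.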

## 10. Save and Deploy the Fine-Tuned Model

Let's save our model properly and demonstrate how to load it for inference in production environments.
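One detail of the metadata file is worth seeing in isolation. JSON object keys are always strings, so an `id2label` dict with integer keys comes back as `'0'`, `'1'`, ... after a round trip. That is why the code below stores string keys and looks labels up with `str(predicted_class_id)`. Here is a small sketch with a hypothetical two-class mapping, plus a converter back to integer keys:

```python
import json

# Hypothetical label mapping for a two-class sentiment model
id2label = {0: "negative", 1: "positive"}

# JSON object keys must be strings, so integer keys are converted on dump
restored = json.loads(json.dumps(id2label))
print(restored)  # {'0': 'negative', '1': 'positive'}


def int_keys(mapping: dict) -> dict:
    """Convert string keys produced by a JSON round trip back to ints."""
    return {int(k): v for k, v in mapping.items()}


print(int_keys(restored) == id2label)  # True
```

The model's own `config.json` does not have this problem at load time: `from_pretrained` converts `id2label` keys back to integers, which is why `model.config.id2label[predicted_class]` works with an `int` index.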

In [None]:
# Save the final model
model_save_path = "./final_model"
print(f"💾 Saving final model to: {model_save_path}")

# Save model and tokenizer
trainer.save_model(model_save_path)
tokenizer.save_pretrained(model_save_path)

# Save additional metadata
import json

model_metadata = {
    'model_name': MODEL_NAME,
    'num_labels': NUM_LABELS,
    'label2id': label2id,
    'id2label': {str(k): v for k, v in id2label.items()},  # JSON requires string keys
    'max_length': MAX_LENGTH,
    'training_config': training_config,
    'final_metrics': final_metrics,
    'classes': list(label2id.keys())
}

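with open(f"{model_save_path}/model_metadata.json", 'w', encoding='utf-8') as f:
    # Convert NumPy scalars/arrays (e.g. in metrics) to plain Python values; stringify anything else JSON can't encode
    json.dump(model_metadata, f, indent=2, ensure_ascii=False,
              default=lambda o: o.tolist() if hasattr(o, 'tolist') else str(o))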

print("✅ Model saved successfully!")
print("   Files saved:")
# List what was actually written: newer transformers versions store weights as
# model.safetensors (older ones as pytorch_model.bin), and tokenizer files depend on the tokenizer type
for fname in sorted(os.listdir(model_save_path)):
    print(f"   - {fname}")

# Demonstrate loading the saved model
print(f"\n🔄 Demonstrating model loading...")

# Load model and tokenizer
loaded_model = AutoModelForSequenceClassification.from_pretrained(model_save_path)
loaded_tokenizer = AutoTokenizer.from_pretrained(model_save_path)
loaded_model.to(device)

# Load metadata
with open(f"{model_save_path}/model_metadata.json", 'r', encoding='utf-8') as f:
    loaded_metadata = json.load(f)

print("✅ Model loaded successfully!")
print(f"   Model: {loaded_metadata['model_name']}")
print(f"   Classes: {loaded_metadata['classes']}")

# Create inference pipeline
from transformers import pipeline

# Create text classification pipeline
classifier = pipeline(
    "text-classification",
    model=loaded_model,
    tokenizer=loaded_tokenizer,
    device=0 if torch.cuda.is_available() else -1
)

print(f"\n🚀 Inference pipeline created!")

# Test inference with new examples
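test_texts = [
    "สินค้านี้ดีมาก คุณภาพเยี่ยม แนะนำเลย",  # "Great product, excellent quality, recommended" (expected: positive)
    "ไม่พอใจ บริการแย่มาก ไม่คุ้มค่า",        # "Not satisfied, terrible service, not worth it" (expected: negative)
    "ราคาเหมาะสม ใช้งานได้ปกติ",              # "Fair price, works normally" (neutral/mildly positive)
]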

print(f"\n🔍 Testing inference:")
for i, text in enumerate(test_texts, 1):
    result = classifier(text)
    print(f"{i}. Text: {text}")
    print(f"   Prediction: {result[0]['label']} (confidence: {result[0]['score']:.4f})")
    print()

# Custom inference function for more control
def predict_sentiment(text, return_probabilities=False):
    """
    Custom inference function with preprocessing.
    
    Args:
        text: Input text to classify
        return_probabilities: Whether to return probabilities for all classes
        
    Returns:
        Prediction result
    """
    # Preprocess text
    processed_text = preprocess_thai_text(text)
    
    # Tokenize
    inputs = loaded_tokenizer(
        processed_text,
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
        return_tensors='pt'
    )
    
    # Move to device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Get predictions
    loaded_model.eval()
    with torch.no_grad():
        outputs = loaded_model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=-1)
        
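        if return_probabilities:
            # Order label names by class id so they line up with the probability vector
            id2label_map = loaded_metadata['id2label']
            return {
                'probabilities': probabilities.cpu().numpy()[0],
                'labels': [id2label_map[str(i)] for i in range(len(id2label_map))]
            }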
        else:
            predicted_class_id = torch.argmax(probabilities, dim=-1).item()
            predicted_label = loaded_metadata['id2label'][str(predicted_class_id)]
            confidence = probabilities.max().item()
            
            return {
                'label': predicted_label,
                'confidence': confidence
            }

print(f"🧪 Testing custom inference function:")
for i, text in enumerate(test_texts, 1):
    result = predict_sentiment(text)
    prob_result = predict_sentiment(text, return_probabilities=True)
    
    print(f"{i}. Text: {text}")
    print(f"   Prediction: {result['label']} (confidence: {result['confidence']:.4f})")
    probs = {label: round(float(p), 4) for label, p in zip(prob_result['labels'], prob_result['probabilities'])}
    print(f"   All probabilities: {probs}")
    print()

# Create a simple API endpoint example (using FastAPI)
api_code = '''
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

app = FastAPI()

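# Load the model once at startup and make sure dropout is disabled for inference
model = AutoModelForSequenceClassification.from_pretrained("./final_model")
tokenizer = AutoTokenizer.from_pretrained("./final_model")
model.eval()

class TextInput(BaseModel):
    text: str

# A plain `def` endpoint runs in FastAPI's threadpool, so blocking model
# inference does not stall the event loop the way it would inside `async def`
@app.post("/predict")
def predict(input_data: TextInput):
    # NOTE: apply the same text preprocessing used during training (the notebook's preprocess_thai_text)
    # Tokenize and predict
    inputs = tokenizer(input_data.text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.softmax(outputs.logits, dim=-1)
        predicted_class = torch.argmax(probabilities, dim=-1).item()
        confidence = probabilities.max().item()
    
    return {
        "prediction": model.config.id2label[predicted_class],
        "confidence": float(confidence)
    }

# Run with: uvicorn api_example:app --reload   (--reload is meant for development only)
'''

print(f"📡 API Deployment Example:")
print("   The FastAPI code in `api_code` is written to 'api_example.py' in the next step; start the server with:")
print("   uvicorn api_example:app --reload")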
print(f"   Then test with: curl -X POST http://localhost:8000/predict -H 'Content-Type: application/json' -d '{{\"text\": \"Your Thai text here\"}}'")

# Save API code
with open("api_example.py", "w", encoding="utf-8") as f:
    f.write(api_code)

print(f"\n💾 API example saved as 'api_example.py'")

# Performance optimization tips
print(f"\n⚡ Performance Optimization Tips:")
print("   1. Model Quantization: Reduce model size and inference time")
print("   2. ONNX Conversion: Convert to ONNX for faster inference")
print("   3. TensorRT: Use NVIDIA TensorRT for GPU optimization")
print("   4. Batch Processing: Process multiple texts together")
print("   5. Caching: Cache frequent predictions")
print("   6. Model Distillation: Create smaller student models")

print(f"\n🎯 Deployment Checklist:")
print("   ✅ Model saved with all necessary files")
print("   ✅ Inference pipeline tested")
print("   ✅ Custom prediction function created")
print("   ✅ API example provided")
print("   ✅ Performance optimization tips listed")
print("\n🚀 Your model is ready for production deployment!")
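
Tips 4 and 5 (batch processing and caching) can be combined in a few lines. The sketch below deduplicates incoming texts, answers repeats from an in-memory cache, and sends only the misses to the model in fixed-size batches. `CachedBatchPredictor` and `predict_batch_fn` are hypothetical names. `predict_batch_fn` stands in for any function that maps a list of texts to a list of results, for example a wrapper that calls the `classifier` pipeline on a list. The dummy predictor exists only so the sketch runs without a model.

```python
from typing import Callable, Dict, List, Sequence


class CachedBatchPredictor:
    """Sketch: cache predictions per normalized text and run cache misses in batches."""

    def __init__(self, predict_batch_fn: Callable[[List[str]], List[dict]], batch_size: int = 16):
        self.predict_batch_fn = predict_batch_fn  # hypothetical: list of texts -> list of results
        self.batch_size = batch_size
        self.cache: Dict[str, dict] = {}
        self.model_calls = 0  # number of batched calls actually made

    @staticmethod
    def _normalize(text: str) -> str:
        # Collapse runs of whitespace so trivially different inputs share a cache entry
        return " ".join(text.split())

    def predict(self, texts: Sequence[str]) -> List[dict]:
        keys = [self._normalize(t) for t in texts]
        # Unique texts not yet cached, in first-seen order
        misses = list(dict.fromkeys(k for k in keys if k not in self.cache))
        for start in range(0, len(misses), self.batch_size):
            batch = misses[start:start + self.batch_size]
            results = self.predict_batch_fn(batch)
            if len(results) != len(batch):
                raise ValueError("predict_batch_fn must return one result per input text")
            self.model_calls += 1
            self.cache.update(zip(batch, results))
        return [self.cache[k] for k in keys]


# Dummy predictor standing in for the real model (labels by text length, demo only)
def dummy_predict_batch(batch: List[str]) -> List[dict]:
    return [{"label": "positive" if len(t) % 2 else "negative", "score": 0.5} for t in batch]


predictor = CachedBatchPredictor(dummy_predict_batch, batch_size=2)
texts = ["good", "bad  product", "good", "bad product", "great", "ok"]
out = predictor.predict(texts)
print(len(out), predictor.model_calls)  # prints "6 2": 6 results from 2 batched calls (4 unique texts)
```

This cache grows without bound and never expires. In a long-running service you would cap its size (for example with an LRU policy) and clear it whenever a new model version is deployed.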

## 11. Monitor Model and Handle Data Drift

Production deployment requires ongoing monitoring to ensure model performance remains stable over time. Let's implement monitoring strategies and data drift detection.
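The monitor below flags drift when the mean text length or word count moves by more than 20%. Means can hide shape changes, for instance a mix of much shorter and much longer texts that averages out. A common complement is the Population Stability Index (PSI), which compares binned distributions. This sketch swaps in that technique and is not part of the original monitor. The bin edges, the 1e-4 floor for empty bins, and the usual 0.1 / 0.25 rule-of-thumb cut-offs are assumptions to tune for your data.

```python
import bisect
import math
from typing import List, Sequence


def bin_proportions(values: Sequence[float], edges: Sequence[float], floor: float = 1e-4) -> List[float]:
    """Share of values in each bin defined by sorted inner `edges` (len(edges) + 1 bins).

    Empty bins get a small floor so the log ratio in PSI stays finite.
    """
    counts = [0] * (len(edges) + 1)
    for v in values:
        counts[bisect.bisect_right(edges, v)] += 1
    total = len(values)
    return [max(c / total, floor) for c in counts]


def population_stability_index(reference: Sequence[float], current: Sequence[float],
                               edges: Sequence[float]) -> float:
    """PSI = sum over bins of (cur - ref) * ln(cur / ref)."""
    ref = bin_proportions(reference, edges)
    cur = bin_proportions(current, edges)
    return sum((c - r) * math.log(c / r) for r, c in zip(ref, cur))


# Hypothetical text lengths (characters): training reference vs. two production windows
edges = [10, 20, 40, 80]
reference_lengths = [12, 15, 18, 22, 25, 30, 35, 45, 50, 60]
similar_window = [14, 16, 19, 21, 27, 33, 38, 42, 55, 65]
shifted_window = [5, 6, 7, 8, 9, 90, 95, 120, 150, 200]

for name, window in [("similar", similar_window), ("shifted", shifted_window)]:
    psi = population_stability_index(reference_lengths, window, edges)
    # Rule of thumb: < 0.1 stable, 0.1-0.25 moderate shift, > 0.25 large shift
    print(f"{name}: PSI = {psi:.3f}")
```

With windows this small PSI is noisy. In practice you would compute it over hundreds of predictions and derive the bin edges from quantiles of the training data.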

In [None]:
import datetime
from collections import defaultdict

class ModelMonitor:
    """
    A simple monitoring system for text classification models in production.
    """
    
    def __init__(self, model_name: str = "Thai-Sentiment-Classifier"):
        self.model_name = model_name
        self.predictions_log = []
        self.performance_metrics = defaultdict(list)
        # Confidence buckets: >= 'high' is high, ['medium', 'high') is medium,
        # and anything below 'medium' is treated as low confidence (candidates for review)
        self.confidence_thresholds = {
            'high': 0.8,
            'medium': 0.6,
        }
        
    def log_prediction(self, text: str, prediction: str, confidence: float,
                       timestamp: Optional[datetime.datetime] = None):
        """Log a single prediction with metadata."""
        if timestamp is None:
            timestamp = datetime.datetime.now()
            
        log_entry = {
            'timestamp': timestamp,
            'text': text,
            'prediction': prediction,
            'confidence': confidence,
            'text_length': len(text),
            # Whitespace-separated chunks: Thai has no spaces between words, so this counts
            # phrases rather than words (use a Thai word tokenizer such as PyThaiNLP for true counts)
            'word_count': len(text.split())
        }
        
        self.predictions_log.append(log_entry)
        
    def analyze_confidence_distribution(self, recent_days: int = 7):
        """Analyze confidence distribution over recent predictions."""
        cutoff_date = datetime.datetime.now() - datetime.timedelta(days=recent_days)
        recent_predictions = [
            p for p in self.predictions_log 
            if p['timestamp'] >= cutoff_date
        ]
        
        if not recent_predictions:
            print(f"⚠️ No predictions found in the last {recent_days} days")
            return
            
        confidences = [p['confidence'] for p in recent_predictions]
        
        high_conf = sum(1 for c in confidences if c >= self.confidence_thresholds['high'])
        medium_conf = sum(1 for c in confidences if self.confidence_thresholds['medium'] <= c < self.confidence_thresholds['high'])
        low_conf = sum(1 for c in confidences if c < self.confidence_thresholds['medium'])
        
        total = len(confidences)
        
        print(f"📊 Confidence Distribution (Last {recent_days} days):")
        print(f"   Total predictions: {total}")
        print(f"   High confidence (≥{self.confidence_thresholds['high']}): {high_conf} ({high_conf/total:.1%})")
        print(f"   Medium confidence ({self.confidence_thresholds['medium']}-{self.confidence_thresholds['high']}): {medium_conf} ({medium_conf/total:.1%})")
        print(f"   Low confidence (<{self.confidence_thresholds['medium']}): {low_conf} ({low_conf/total:.1%})")
        
        if low_conf / total > 0.2:  # More than 20% low confidence
            print("   ⚠️ Warning: High proportion of low-confidence predictions!")
            print("   Consider model retraining or additional data collection.")
            
        return {
            'high': high_conf / total,
            'medium': medium_conf / total,
            'low': low_conf / total
        }
    
    def detect_input_drift(self, reference_stats: dict = None, recent_days: int = 7):
        """Detect potential input drift by comparing text characteristics."""
        cutoff_date = datetime.datetime.now() - datetime.timedelta(days=recent_days)
        recent_predictions = [
            p for p in self.predictions_log 
            if p['timestamp'] >= cutoff_date
        ]
        
        if not recent_predictions:
            print(f"⚠️ No recent predictions to analyze")
            return
            
        # Calculate current statistics
        current_stats = {
            'avg_text_length': np.mean([p['text_length'] for p in recent_predictions]),
            'avg_word_count': np.mean([p['word_count'] for p in recent_predictions]),
            'max_text_length': max([p['text_length'] for p in recent_predictions]),
            'min_text_length': min([p['text_length'] for p in recent_predictions])
        }
        
        print(f"📏 Input Characteristics (Last {recent_days} days):")
        for stat, value in current_stats.items():
            print(f"   {stat}: {value:.2f}")
            
        # Compare with reference if provided
        if reference_stats:
            print(f"\n🔍 Drift Detection:")
            drift_detected = False
            
            for stat in ['avg_text_length', 'avg_word_count']:
                if stat in reference_stats:
                    reference_value = reference_stats[stat]
                    current_value = current_stats[stat]
                    change_pct = abs(current_value - reference_value) / reference_value * 100
                    
                    print(f"   {stat}: {current_value:.2f} vs {reference_value:.2f} (Δ{change_pct:.1f}%)")
                    
                    if change_pct > 20:  # More than 20% change
                        print(f"     ⚠️ Significant drift detected!")
                        drift_detected = True
                        
            if drift_detected:
                print(f"\n🚨 Input drift detected! Consider:")
                print("   - Reviewing recent data quality")
                print("   - Collecting new training data")
                print("   - Retraining the model")
        
        return current_stats
    
    def generate_monitoring_report(self):
        """Generate a comprehensive monitoring report."""
        if not self.predictions_log:
            print("📋 No predictions to report")
            return
            
        # Overall statistics
        total_predictions = len(self.predictions_log)
        date_range = (min(p['timestamp'] for p in self.predictions_log),
                     max(p['timestamp'] for p in self.predictions_log))
        
        print(f"📋 Model Monitoring Report - {self.model_name}")
        print(f"   Report generated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"   Total predictions: {total_predictions}")
        print(f"   Date range: {date_range[0].strftime('%Y-%m-%d')} to {date_range[1].strftime('%Y-%m-%d')}")
        
        # Confidence analysis
        confidences = [p['confidence'] for p in self.predictions_log]
        print(f"\n📊 Overall Confidence Statistics:")
        print(f"   Mean: {np.mean(confidences):.4f}")
        print(f"   Median: {np.median(confidences):.4f}")
        print(f"   Std: {np.std(confidences):.4f}")
        print(f"   Min: {np.min(confidences):.4f}")
        print(f"   Max: {np.max(confidences):.4f}")
        
        # Prediction distribution
        predictions = [p['prediction'] for p in self.predictions_log]
        pred_counts = pd.Series(predictions).value_counts()
        print(f"\n📈 Prediction Distribution:")
        for pred, count in pred_counts.items():
            print(f"   {pred}: {count} ({count/total_predictions:.1%})")

# Initialize monitoring system
monitor = ModelMonitor("Thai-Sentiment-Classifier")

print("🔍 Model Monitoring System Initialized")

# Simulate some predictions for demonstration
print("\n📝 Simulating predictions for monitoring demo...")

# Generate sample monitoring data
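sample_monitoring_data = [
    ("สินค้าดีมาก คุณภาพเยี่ยม", "positive", 0.95),   # "Very good product, excellent quality"
    ("ไม่พอใจ บริการแย่", "negative", 0.88),          # "Not satisfied, bad service"
    ("ราคาเหมาะสม ใช้ได้", "positive", 0.72),         # "Fair price, usable"
    ("ไม่ดี คุณภาพต่ำ", "negative", 0.83),            # "Not good, low quality"
    ("ปกติ ไม่มีปัญหา", "positive", 0.65),            # "Normal, no problems"
    ("แย่มาก ไม่แนะนำ", "negative", 0.91),            # "Very bad, not recommended"
    ("โอเค พอใช้ได้", "positive", 0.55),              # "OK, good enough" (low confidence)
    ("ไม่แน่ใจ อาจจะดี", "positive", 0.52),           # "Not sure, might be good" (low confidence)
]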

# Log predictions
base_time = datetime.datetime.now() - datetime.timedelta(days=5)
for i, (text, pred, conf) in enumerate(sample_monitoring_data):
    timestamp = base_time + datetime.timedelta(hours=i*3)
    monitor.log_prediction(text, pred, conf, timestamp)

print(f"✅ Logged {len(sample_monitoring_data)} sample predictions")

# Analyze confidence distribution
print("\n" + "="*50)
confidence_dist = monitor.analyze_confidence_distribution(recent_days=7)

# Reference statistics for drift detection (from training data)
training_reference_stats = {
    'avg_text_length': df['text_length'].mean(),
    'avg_word_count': df['word_count'].mean(),
}

print("\n" + "="*50)
current_stats = monitor.detect_input_drift(
    reference_stats=training_reference_stats, 
    recent_days=7
)

# Generate full monitoring report
print("\n" + "="*50)
monitor.generate_monitoring_report()

# Create monitoring dashboard visualization
def create_monitoring_dashboard(monitor):
    """Create visualizations for monitoring dashboard."""
    
    if not monitor.predictions_log:
        print("No data to visualize")
        return
    
    # Extract data
    timestamps = [p['timestamp'] for p in monitor.predictions_log]
    confidences = [p['confidence'] for p in monitor.predictions_log]
    predictions = [p['prediction'] for p in monitor.predictions_log]
    
    # Create plots
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
    
    # Confidence over time
    ax1.plot(timestamps, confidences, marker='o', linestyle='-', alpha=0.7)
    ax1.axhline(y=monitor.confidence_thresholds['high'], color='g', linestyle='--', label='High threshold')
    ax1.axhline(y=monitor.confidence_thresholds['medium'], color='orange', linestyle='--', label='Medium threshold')
    ax1.set_title('Prediction Confidence Over Time')
    ax1.set_xlabel('Time')
    ax1.set_ylabel('Confidence')
    ax1.legend()
    ax1.tick_params(axis='x', rotation=45)
    
    # Confidence distribution
    ax2.hist(confidences, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    ax2.axvline(x=monitor.confidence_thresholds['high'], color='g', linestyle='--', label='High threshold')
    ax2.axvline(x=monitor.confidence_thresholds['medium'], color='orange', linestyle='--', label='Medium threshold')
    ax2.set_title('Confidence Distribution')
    ax2.set_xlabel('Confidence')
    ax2.set_ylabel('Frequency')
    ax2.legend()
    
    # Prediction distribution
    pred_counts = pd.Series(predictions).value_counts()
    pred_counts.plot(kind='bar', ax=ax3, color=['lightcoral', 'lightblue'])
    ax3.set_title('Prediction Distribution')
    ax3.set_xlabel('Prediction')
    ax3.set_ylabel('Count')
    ax3.tick_params(axis='x', rotation=0)
    
    # Confidence by prediction
    df_monitor = pd.DataFrame(monitor.predictions_log)
    for pred in df_monitor['prediction'].unique():
        subset = df_monitor[df_monitor['prediction'] == pred]
        ax4.scatter(subset.index, subset['confidence'], label=pred, alpha=0.7)
    ax4.set_title('Confidence by Prediction Type')
    ax4.set_xlabel('Prediction Index')
    ax4.set_ylabel('Confidence')
    ax4.legend()
    
    plt.tight_layout()
    plt.show()

print("\n📊 Creating monitoring dashboard...")
create_monitoring_dashboard(monitor)

# Production monitoring recommendations
print(f"\n🎯 Production Monitoring Recommendations:")
print("1. Real-time Monitoring:")
print("   - Set up alerts for low confidence predictions")
print("   - Monitor prediction distribution shifts")
print("   - Track response times and system performance")

print("\n2. Data Quality Checks:")
print("   - Validate input text format and encoding")
print("   - Check for unusual characters or patterns")
print("   - Monitor text length distributions")

print("\n3. Performance Monitoring:")
print("   - Track model accuracy on labeled validation sets")
print("   - Monitor user feedback and corrections")
print("   - A/B test model versions")

print("\n4. Automated Actions:")
print("   - Flag low-confidence predictions for human review")
print("   - Trigger retraining when drift is detected")
print("   - Scale infrastructure based on prediction volume")

print("\n5. Logging and Storage:")
print("   - Store predictions with metadata")
print("   - Log model versions and configurations")
print("   - Maintain audit trails for compliance")

# Save monitoring code as a standalone module
monitoring_code = '''
"""
Production monitoring module for text classification models.
"""

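import datetime
import json
import logging
from typing import Optional

class ProductionModelMonitor:
    """Production-ready model monitoring system."""
    
    def __init__(self, model_name: str, log_file: str = "model_predictions.log"):
        self.model_name = model_name
        self.log_file = log_file
        self.setup_logging()
    
    def setup_logging(self):
        """Attach a dedicated UTF-8 file handler for prediction records.

        logging.basicConfig() does nothing if the root logger already has handlers
        (common in notebooks and web servers), so configure a named logger directly.
        """
        self.logger = logging.getLogger(f"model_monitor.{self.model_name}")
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False  # keep prediction records out of the root/application log
        if not self.logger.handlers:  # avoid duplicate handlers if the monitor is re-created
            handler = logging.FileHandler(self.log_file, encoding="utf-8")
            handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
            self.logger.addHandler(handler)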
    
    def log_prediction(self, 
                      text: str, 
                      prediction: str, 
                      confidence: float,
                      user_id: Optional[str] = None,
                      session_id: Optional[str] = None):
        """Log prediction with full metadata."""
        
        log_data = {
            'model_name': self.model_name,
            'timestamp': datetime.datetime.now().isoformat(),
            'text': text,
            'prediction': prediction,
            'confidence': confidence,
            'text_length': len(text),
            'word_count': len(text.split()),
            'user_id': user_id,
            'session_id': session_id
        }
        
        self.logger.info(json.dumps(log_data, ensure_ascii=False))
    
    def should_flag_for_review(self, confidence: float, threshold: float = 0.6) -> bool:
        """Determine if prediction should be flagged for human review."""
        return confidence < threshold

# Usage example:
# monitor = ProductionModelMonitor("thai-sentiment-v1")
# monitor.log_prediction("sample text", "positive", 0.85, user_id="user123")
'''

with open("production_monitoring.py", "w", encoding="utf-8") as f:
    f.write(monitoring_code)

print(f"\n💾 Production monitoring code saved as 'production_monitoring.py'")
print("\n🎉 Monitoring setup complete! Your model is ready for production deployment with comprehensive monitoring.")
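
The first recommendation above, alerting on low-confidence predictions, needs a precise notion of "too many recently". One way to express it is a sliding window over the last N confidences that fires when the low-confidence share exceeds a limit. This is a minimal sketch. `LowConfidenceAlert` is a hypothetical name, and the window size is an assumption. The 0.6 threshold matches the monitor's 'medium' cut-off, and the 20% limit matches `analyze_confidence_distribution`.

```python
from collections import deque


class LowConfidenceAlert:
    """Fire when the share of low-confidence predictions in a sliding window is too high."""

    def __init__(self, window_size: int = 100, threshold: float = 0.6, max_low_share: float = 0.2):
        self.window = deque(maxlen=window_size)  # oldest entries drop off automatically
        self.threshold = threshold
        self.max_low_share = max_low_share

    def update(self, confidence: float) -> bool:
        """Record one prediction; return True if the alert condition holds."""
        self.window.append(confidence < self.threshold)
        # Wait for a full window so a handful of early predictions cannot trigger it
        if len(self.window) < self.window.maxlen:
            return False
        return sum(self.window) / len(self.window) > self.max_low_share


# Replay the simulated confidences from the demo above with a small window
alert = LowConfidenceAlert(window_size=5, threshold=0.6, max_low_share=0.2)
stream = [0.95, 0.88, 0.72, 0.83, 0.65, 0.91, 0.55, 0.52]
fired = [alert.update(c) for c in stream]
print(fired)  # [False, False, False, False, False, False, False, True]
```

Unlike the 7-day summary in `analyze_confidence_distribution`, a window of recent predictions reacts within N requests. That makes it a better fit for paging or automatic routing to human review.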

## 🎉 Conclusion and Next Steps

Congratulations! You've completed the comprehensive guide to fine-tuning language models for text classification. 

### 🏆 What You've Accomplished

1. ✅ **Environment Setup**: Configured all necessary libraries and dependencies
2. ✅ **Model Selection**: Chose appropriate pre-trained models for Thai text
3. ✅ **Data Preparation**: Created, explored, and preprocessed datasets
4. ✅ **Model Fine-tuning**: Implemented complete training pipeline
5. ✅ **Evaluation**: Comprehensive performance assessment with multiple metrics
6. ✅ **Hyperparameter Optimization**: Systematic tuning strategies
7. ✅ **Deployment**: Production-ready model saving and inference
8. ✅ **Monitoring**: Ongoing performance tracking and drift detection

### 🚀 Key Takeaways

- **Quality Data Matters**: Clean, representative datasets are crucial for success
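- **Hyperparameter Tuning**: Systematic search (e.g. with Optuna) can improve results, but each trial is a full fine-tuning run, so start from common defaults (learning rate 1e-5 to 5e-5, 2-4 epochs)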
- **Evaluation is Critical**: Use multiple metrics to assess model performance
- **Production Readiness**: Proper saving, loading, and monitoring are essential
- **Continuous Improvement**: Regular monitoring and retraining maintain performance

### 📈 Next Steps for Advanced Users

1. **Advanced Techniques**:
   - Implement multi-task learning
   - Explore few-shot learning approaches
   - Try ensemble methods for better performance

2. **Scale Up**:
   - Use larger datasets for better generalization
   - Experiment with larger pre-trained models
   - Implement distributed training

3. **Domain-Specific Improvements**:
   - Create domain-specific vocabularies
   - Implement custom preprocessing for your use case
   - Explore task-specific architectures

4. **Production Enhancement**:
   - Implement A/B testing frameworks
   - Add comprehensive logging and alerting
   - Optimize for low-latency inference

### 🛠️ Tools and Resources for Further Learning

- **Hugging Face Transformers**: Official documentation and tutorials
- **Papers with Code**: Latest research and implementations
- **Thai NLP Resources**: PyThaiNLP, AI Research Thailand
- **Monitoring Tools**: Weights & Biases, MLflow, TensorBoard

### 📚 Recommended Reading

- "Natural Language Processing with Transformers" by Lewis Tunstall, Leandro von Werra, and Thomas Wolf
- "Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow" by Aurélien Géron
- Research papers on Thai NLP and text classification

### 🤝 Community and Support

- Join Thai NLP communities and forums
- Contribute to open-source projects
- Share your experiences and learn from others

Thank you for following this comprehensive guide! Your journey in fine-tuning language models for text classification has just begun. Keep experimenting, learning, and building amazing NLP applications! 🎯