# Model Training for Data Cleaning Pipeline

This notebook covers training and fine-tuning the models used in our data cleaning pipeline. It walks through:
1. Loading and preparing the data
2. Setting up the models and tokenizer
3. Splitting and tokenizing the data
4. Configuring training for the development status classifier
5. Analyzing the data distribution for validation

In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
from datasets import Dataset
from sklearn.model_selection import train_test_split
import yaml
import logging
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Setup logging: configure the root logger at INFO level and create a
# logger named after this module.
# NOTE(review): `logger` is not used by any of the cells shown here.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
# Load configuration
def load_config(path='../configs/model_config.yaml'):
    """Read the YAML model configuration and return it as a dict.

    Args:
        path: Location of the YAML file. The default is resolved relative to
            the current working directory.

    Returns:
        The parsed top-level mapping.

    Raises:
        ValueError: if the file is empty or its top level is not a mapping.
            Every later cell indexes ``config[...]``, so failing here gives a
            clearer error than a ``TypeError`` further down.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)  # safe_load: never construct arbitrary objects
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config

# Parse the YAML config once; later cells read model and training
# hyper-parameters from this dict.
config = load_config()
print("Configuration loaded successfully!")

## 1. Data Loading and Preparation

In [None]:
def load_and_prepare_data(file_path):
    """Load the raw CSV and derive the binary development-status label.

    Missing ``text`` values become ``''`` and missing ``country_code`` values
    become ``'UNKNOWN'``. ``development_status`` is mapped to an integer
    ``label`` column: ``'Developed'`` -> 1, ``'Developing'`` -> 0.

    Args:
        file_path: Path of the CSV file; it must have ``text``,
            ``country_code`` and ``development_status`` columns.

    Returns:
        pandas.DataFrame with an added integer ``label`` column and a fresh
        ``0..n-1`` index. Rows whose ``development_status`` is missing or is
        neither of the two known values are dropped, with a warning.
    """
    df = pd.read_csv(file_path)

    # Basic preprocessing
    df['text'] = df['text'].fillna('')
    df['country_code'] = df['country_code'].fillna('UNKNOWN')

    # Series.map yields NaN for any value not in the mapping (including NaN
    # itself). Left in place, those NaNs turn `label` into a float column and
    # crash int() conversion during dataset preparation, so drop them here.
    df['label'] = df['development_status'].map({
        'Developed': 1,
        'Developing': 0,
    })
    unlabeled = df['label'].isna()
    if unlabeled.any():
        logging.getLogger(__name__).warning(
            "Dropping %d of %d rows with a missing or unrecognized "
            "development_status",
            int(unlabeled.sum()),
            len(df),
        )
        df = df[~unlabeled].reset_index(drop=True)
    df['label'] = df['label'].astype(int)

    return df

# Load the dataset and preview the first rows.
df = load_and_prepare_data('../data/raw/input_dataset.csv')
print(f"Loaded {len(df)} records")
print("\nSample data:")
# display() is the IPython/Jupyter rich-output helper, not a Python builtin;
# this cell only runs inside a notebook kernel.
display(df.head())

## 2. Model Setup and Data Preparation

In [None]:
class ModelTrainer:
    """Hold the development-status model and tokenizer and build datasets.

    Args:
        config: Parsed configuration dict. Reads
            ``config['model_params']['device']``,
            ``config['model_params']['max_length']`` and the
            ``config['model_params']['development_status']`` sub-dict
            (``model_name``, ``num_labels``, ``cache_dir``).
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device(config['model_params']['device'])
        self.setup_models()

    def setup_models(self):
        """Load the development-status tokenizer and classification model.

        The model is moved to ``self.device`` right after loading.
        """
        model_config = self.config['model_params']['development_status']
        self.dev_status_tokenizer = AutoTokenizer.from_pretrained(
            model_config['model_name'],
            cache_dir=model_config['cache_dir'],
        )
        self.dev_status_model = AutoModelForSequenceClassification.from_pretrained(
            model_config['model_name'],
            num_labels=model_config['num_labels'],
            cache_dir=model_config['cache_dir'],
        ).to(self.device)

    def prepare_dataset(self, df, text_col, label_col):
        """Convert a DataFrame into a tokenized ``datasets.Dataset``.

        Args:
            df: Source rows.
            text_col: Column holding the input text; each value is ``str()``-ed.
            label_col: Column holding integer class labels.

        Returns:
            A ``Dataset`` with the tokenizer outputs plus a ``label`` column.
            The label column is kept on purpose: ``Trainer`` needs it to
            compute the training and evaluation loss.

        Raises:
            ValueError: if ``label_col`` contains missing values.
        """
        labels = df[label_col]
        missing = int(labels.isna().sum())
        if missing:
            raise ValueError(
                f"Column {label_col!r} has {missing} missing labels; "
                "drop or fill them before building the dataset."
            )

        # Build whole columns at once instead of iterating with iterrows().
        dataset = Dataset.from_dict({
            'text': [str(value) for value in df[text_col]],
            'label': [int(value) for value in labels],
        })

        max_length = self.config['model_params']['max_length']

        def tokenize_function(examples):
            return self.dev_status_tokenizer(
                examples['text'],
                padding='max_length',
                truncation=True,
                max_length=max_length,
            )

        # Remove only the raw text. Removing every original column would
        # also drop 'label' and leave Trainer with nothing to train against.
        return dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=['text'],
        )

# Initialize trainer
# NOTE: this name holds the project ModelTrainer (model + tokenizer holder),
# not a transformers.Trainer.
trainer = ModelTrainer(config)
print("Model trainer initialized successfully!")

## 3. Data Splitting and Tokenization

In [None]:
# Split data: hold out 20% of rows for evaluation, seeded from the config so
# the split is reproducible.
# NOTE(review): the split is not stratified on `label`; consider
# `stratify=df['label']` if the two classes are imbalanced.
train_df, eval_df = train_test_split(
    df,
    test_size=0.2,
    random_state=config['model_params']['seed']
)

# Prepare datasets (tokenized with the ModelTrainer's dev-status tokenizer)
train_dataset = trainer.prepare_dataset(train_df, 'text', 'label')
eval_dataset = trainer.prepare_dataset(eval_df, 'text', 'label')

print(f"Training samples: {len(train_dataset)}")
print(f"Evaluation samples: {len(eval_dataset)}")

## 4. Training Setup

In [None]:
def setup_training(config, train_dataset, eval_dataset, model_trainer=None):
    """Build a Hugging Face ``Trainer`` for the development-status classifier.

    Args:
        config: Parsed configuration dict. Reads ``model_params``
            (``num_epochs``, ``batch_size``, ``learning_rate``,
            ``warmup_steps``, ``weight_decay``) and
            ``logging.logging_steps``.
        train_dataset: Tokenized training split; it must include ``label``.
        eval_dataset: Tokenized evaluation split; it must include ``label``.
        model_trainer: Object exposing ``dev_status_model`` and
            ``dev_status_tokenizer``. Defaults to the notebook-global
            ``trainer`` (the ``ModelTrainer`` built in the setup cell).

    Returns:
        A configured ``transformers.Trainer``. Training is not started here.
    """
    if model_trainer is None:
        # Read the global ModelTrainer. The HF Trainer below is bound to a
        # different name, so `trainer` is not turned into an unassigned local.
        model_trainer = trainer

    def compute_metrics(eval_pred):
        """Return eval accuracy so ``metric_for_best_model`` can be resolved."""
        logits, labels = eval_pred
        if isinstance(logits, tuple):  # some models return extra outputs
            logits = logits[0]
        predictions = np.argmax(logits, axis=-1)
        return {'accuracy': float((predictions == labels).mean())}

    # NOTE(review): `evaluation_strategy` and Trainer's `tokenizer=` are
    # deprecated in newer transformers releases (`eval_strategy`,
    # `processing_class`); confirm against the pinned version.
    training_args = TrainingArguments(
        output_dir="../models/dev_status_model",
        num_train_epochs=config['model_params']['num_epochs'],
        per_device_train_batch_size=config['model_params']['batch_size'],
        per_device_eval_batch_size=config['model_params']['batch_size'],
        learning_rate=config['model_params']['learning_rate'],
        warmup_steps=config['model_params']['warmup_steps'],
        weight_decay=config['model_params']['weight_decay'],
        logging_dir="../logs",
        logging_steps=config['logging']['logging_steps'],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        # Requires `compute_metrics` to report an "accuracy" key.
        metric_for_best_model="accuracy",
        greater_is_better=True,
    )

    hf_trainer = Trainer(
        model=model_trainer.dev_status_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=model_trainer.dev_status_tokenizer,
        data_collator=DataCollatorWithPadding(model_trainer.dev_status_tokenizer),
        compute_metrics=compute_metrics,
    )

    return hf_trainer

# Setup training: `model_trainer` holds the object returned by
# setup_training() (the transformers Trainer), distinct from the ModelTrainer
# stored in `trainer`.
# NOTE(review): no cell shown here calls `model_trainer.train()`.
model_trainer = setup_training(config, train_dataset, eval_dataset)
print("Training setup completed!")

## 5. Data Analysis and Validation

In [None]:
def analyze_data_distribution(df):
    """Plot and print the label distribution and text-length statistics.

    Shows a count plot of ``development_status`` and a 50-bin histogram of
    the character length of each ``text`` value. Then prints the normalized
    label frequencies and ``describe()`` of the text lengths.

    Args:
        df: DataFrame with ``development_status`` and string ``text`` columns.
            It is not modified.
    """
    # Plot label distribution
    plt.figure(figsize=(10, 6))
    sns.countplot(data=df, x='development_status')
    plt.title('Distribution of Development Status')
    plt.xticks(rotation=45)
    plt.show()

    # Compute lengths into a local Series rather than adding a column to the
    # caller's DataFrame. The Series is named 'text_length' so the histogram
    # axis label and describe() output keep the same name.
    text_length = df['text'].str.len().rename('text_length')

    # Text length distribution
    plt.figure(figsize=(10, 6))
    sns.histplot(x=text_length, bins=50)
    plt.title('Distribution of Text Length')
    plt.show()

    # Print statistics
    print("\nLabel Distribution:")
    print(df['development_status'].value_counts(normalize=True))

    print("\nText Length Statistics:")
    print(text_length.describe())

# Analyze data: plot and print label and text-length distributions for the
# full prepared dataset.
analyze_data_distribution(df)