In [None]:
# Import necessary libraries
import os
import sys
import argparse
import logging
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
import numpy as np
from src.model.scalable_gmm import ScalableGMM, GMMConfig
from src.utils.metrics import MetricsCollector, PerformanceMetrics
from src.utils.validations import DataValidator
from src.data.generator import DataGenerator
from src.model.transformer import DataTransformer, TransformerConfig

# Set up logging
def setup_logging(log_level=logging.INFO):
    """Configure root logging for the job and return this module's logger.

    Args:
        log_level: Level handed to ``logging.basicConfig`` (INFO by default).

    Returns:
        The ``logging.Logger`` named after this module.

    Note:
        ``logging.basicConfig`` does nothing when the root logger already has
        handlers, so a repeat call (or one made after some other component
        configured logging) leaves the existing format and level in place.
    """
    line_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    timestamp_format = '%Y-%m-%d %H:%M:%S'
    logging.basicConfig(level=log_level, format=line_format, datefmt=timestamp_format)
    module_logger = logging.getLogger(__name__)
    return module_logger


logger = setup_logging()

# Parse arguments
def parse_arguments(argv=None):
    """Parse the job's command-line options.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before. Passing
            an explicit list lets the parser be exercised in tests or driven
            from an interactive session without touching ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with ``bucket_name``, ``input_path``,
        ``output_path`` (required) and ``batch_size``, ``sample_size``,
        ``validation_threshold``, ``validation_sample_size`` (defaulted).

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            an option value fails its type conversion.
    """
    parser = argparse.ArgumentParser(description='VGM Processing Job')
    parser.add_argument('--bucket_name', required=True, help='GCS bucket name')
    parser.add_argument('--input_path', required=True, help='Input data path')
    parser.add_argument('--output_path', required=True, help='Output path for transformed data')
    parser.add_argument('--batch_size', type=int, default=1000, help='Batch size for processing')
    parser.add_argument('--sample_size', type=int, default=100000, help='Sample size for GMM fitting')
    parser.add_argument('--validation_threshold', type=float, default=0.05, help='Threshold for validation tests')
    parser.add_argument('--validation_sample_size', type=int, default=10000, help='Sample size for validation')
    return parser.parse_args(argv)

# Parse CLI options once; the functions and statements below read their
# paths, sizes and thresholds from this module-level namespace.
# NOTE(review): if this runs inside a notebook kernel, sys.argv holds the
# kernel's own flags and parsing will likely fail — confirm how the job is launched.
args = parse_arguments()

# Create Spark session
def create_spark_session():
    """Build (or reuse) the SparkSession used by this job.

    The app name is ``VGM_Processing``, and the settings below are applied
    through the session builder before ``getOrCreate()``.

    NOTE(review): per the Spark configuration docs, ``spark.driver.memory``
    has no effect when set from application code in client mode, because the
    driver JVM is already running. Set it via ``--driver-memory`` or the
    cluster properties if it matters.

    Returns:
        The SparkSession returned by ``getOrCreate()``.
    """
    session_settings = {
        "spark.executor.memory": "8g",
        "spark.driver.memory": "4g",
        "spark.sql.adaptive.enabled": "true",
        "spark.dynamicAllocation.enabled": "true",
        "spark.shuffle.service.enabled": "true",
        "spark.sql.shuffle.partitions": "1000",
    }
    builder = SparkSession.builder.appName("VGM_Processing")
    for setting_key, setting_value in session_settings.items():
        builder = builder.config(setting_key, setting_value)
    return builder.getOrCreate()


spark = create_spark_session()

# Load and prepare data
def load_and_prepare_data():
    """Load the input parquet, keep only ``Amount``, and cache the result.

    The function reads ``args.input_path`` with the module-level ``spark``
    session and repartitions to 1000 partitions. It then marks the frame for
    caching and runs ``count()``, an action, so the cached data is
    materialized before later stages reuse it.

    Returns:
        The cached single-column Spark DataFrame.
    """
    source_path = args.input_path
    logger.info(f"Loading data from {source_path}")
    amounts = spark.read.parquet(source_path).select("Amount")
    amounts = amounts.repartition(1000)
    amounts.cache()
    row_count = amounts.count()
    logger.info(f"Loaded {row_count:,} records")
    return amounts


data = load_and_prepare_data()

# Fit VGM model
def fit_vgm_model():
    """Fit a 10-component ScalableGMM (VGM) and return it.

    This honours ``--sample_size`` ("Sample size for GMM fitting"), which was
    previously parsed but never used. When ``data`` has more rows than
    ``args.sample_size``, the model is fitted on a random sample of roughly
    that many rows. Otherwise, or when ``sample_size`` is not positive, the
    whole dataset is used.

    Returns:
        The fitted ``ScalableGMM`` instance.
    """
    gmm_config = GMMConfig(
        n_components=10,
        batch_size=args.batch_size,
        eps=0.005,
    )
    vgm = ScalableGMM(gmm_config)

    total_records = data.count()  # cheap: data was cached when it was loaded
    target_rows = args.sample_size
    if 0 < target_rows < total_records:
        # DataFrame.sample() keeps each row independently with probability
        # `fraction`, so the sample size is approximate. The fixed seed makes
        # repeated runs (and any re-evaluation during fitting) reproducible.
        fraction = target_rows / total_records
        fit_data = data.sample(withReplacement=False, fraction=fraction, seed=42)
        logger.info(f"Fitting VGM on a ~{target_rows:,}-row sample of {total_records:,} records")
    else:
        fit_data = data
        logger.info(f"Fitting VGM on all {total_records:,} records")

    vgm.fit(fit_data)
    return vgm


vgm = fit_vgm_model()

# Transform data
def transform_data():
    """Apply the fitted module-level ``vgm`` to the full ``data`` frame.

    Returns:
        The result of ``vgm.transform(data)``. It is used downstream as a
        Spark DataFrame (``.limit`` and ``.write``).
    """
    logger.info("Transforming full dataset")
    return vgm.transform(data)


transformed_data = transform_data()

# Validate transformation
def _as_finite_sorted(values, label):
    """Return the finite entries of ``values`` as a sorted 1-D float array."""
    arr = np.asarray(values, dtype=float).ravel()
    arr = arr[np.isfinite(arr)]  # NaN/inf would corrupt the empirical CDFs
    if arr.size == 0:
        raise ValueError(f"{label} contains no finite values")
    return np.sort(arr)


def validate_transformation(validator, original_data, transformed_data, inverse_transformed_data):
    """Score how closely inverse-transformed values match the originals.

    The score is ``1 - D``, where ``D`` is the two-sample Kolmogorov-Smirnov
    statistic (the largest vertical gap between the two empirical CDFs). It
    lies in ``[0, 1]``, and ``1.0`` means the two samples have identical
    empirical distributions. The comparison is distributional, so it does
    not require the two samples to be row-aligned.

    Args:
        validator: Accepted for interface compatibility; not consulted.
            NOTE(review): DataValidator's API is not visible here — wire its
            checks in once confirmed.
        original_data: Array-like of original values.
        transformed_data: Unused; accepted for interface compatibility.
        inverse_transformed_data: Array-like of values after a transform /
            inverse-transform round trip.

    Returns:
        float in ``[0, 1]``. Non-finite entries in either input are ignored.

    Raises:
        ValueError: if either sample has no finite values.
    """
    original = _as_finite_sorted(original_data, "original_data")
    inverse = _as_finite_sorted(inverse_transformed_data, "inverse_transformed_data")

    # Evaluate both empirical CDFs at every observed point. The maximum gap
    # is always attained at one of these points.
    grid = np.concatenate([original, inverse])
    cdf_original = np.searchsorted(original, grid, side="right") / original.size
    cdf_inverse = np.searchsorted(inverse, grid, side="right") / inverse.size
    ks_statistic = float(np.max(np.abs(cdf_original - cdf_inverse)))
    return 1.0 - ks_statistic

# --- Validation sample ----------------------------------------------------
# Cap the sample at the dataset size (data is cached, so count() is cheap).
# NOTE(review): the original and inverse samples come from separate limit()
# calls on different DataFrames, so their rows are not guaranteed to line up
# one-to-one. Element-wise comparisons between them may be misleading.
sample_size = min(args.validation_sample_size, data.count())
sample_original = (
    data.limit(sample_size)
    .toPandas()["Amount"]
    .values
)
sample_transformed = transformed_data.limit(sample_size)
sample_inverse = (
    vgm.inverse_transform(sample_transformed)
    .toPandas()["Amount"]
    .values
)

# --- Round-trip validation ------------------------------------------------
validator = DataValidator(threshold=args.validation_threshold)
accuracy = validate_transformation(
    validator,
    original_data=sample_original,
    transformed_data=None,
    inverse_transformed_data=sample_inverse,
)
logger.info(f"Transformation accuracy: {accuracy:.2f}")

# --- Persist output -------------------------------------------------------
# mode("overwrite") replaces whatever already exists at the output path.
logger.info(f"Saving transformed data to {args.output_path}")
(
    transformed_data.write
    .mode("overwrite")
    .parquet(args.output_path)
)