In [None]:
from datetime import datetime
import psutil
import os
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.functions import (
    col, lit, datediff, round, sum, when, 
    current_timestamp, broadcast, expr
)
from pyspark.storagelevel import StorageLevel

# Create (or reuse) the Spark session for this job.
# Shuffle partitions and default parallelism are both pinned to 10.
session_builder = SparkSession.builder.appName("FCT_Orders_Test")
session_builder = session_builder.config("spark.sql.shuffle.partitions", "10")
session_builder = session_builder.config("spark.default.parallelism", "10")
spark = session_builder.getOrCreate()

def get_table_from_db(
    table_name: str,
    spark: SparkSession,
    *,
    host: str = "upstream_data",
    port: str = "5432",
    db: str = "tpchdb",
    user: str = "tpchuser",
    password: str = "tpchpass",
):
    """Read a PostgreSQL table into a Spark DataFrame over JDBC.

    Args:
        table_name: Table to read, e.g. ``'public.orders'``.
        spark: Active Spark session used to perform the read.
        host: PostgreSQL host name.
        port: PostgreSQL port.
        db: Database name.
        user: Database user name.
        password: Database password.

    Returns:
        The DataFrame produced by ``spark.read.jdbc`` for ``table_name``.

    The keyword defaults reproduce the values that used to be hard-coded in
    the body, so existing two-argument calls behave exactly as before.
    NOTE(review): the default credentials look like development values;
    pass real ones explicitly instead of editing this function.
    """
    jdbc_url = f'jdbc:postgresql://{host}:{port}/{db}'
    connection_properties = {
        'user': user,
        'password': password,
        'driver': 'org.postgresql.Driver',
    }
    return spark.read.jdbc(
        url=jdbc_url, table=table_name, properties=connection_properties
    )

def log_dataframe_info(df, step_name: str):
    """Print a short diagnostic report for a DataFrame.

    The report contains the row count (obtained with ``df.count()``), the
    number of partitions of the underlying RDD, the resident memory (RSS) of
    the current Python process in MB, and the DataFrame schema, all under a
    banner carrying ``step_name``.
    """
    row_total = df.count()
    partition_total = df.rdd.getNumPartitions()

    # RSS of this Python process, converted from bytes to MB.
    current_process = psutil.Process(os.getpid())
    rss_mb = current_process.memory_info().rss / 1024 / 1024

    # Leading empty entry and trailing indent keep the printed layout
    # identical to a triple-quoted block indented by four spaces.
    report_lines = [
        "",
        f"    ====== {step_name} ======",
        f"    Count: {row_total}",
        f"    Partitions: {partition_total}",
        f"    Memory Usage (MB): {rss_mb:.2f}",
        f"    Schema: {df.schema.simpleString()}",
        "    ",
    ]
    print("\n".join(report_lines))

def process_batch(batch_df, orders_df, spark, batch_id):
    """Join one batch of LineItem rows with orders and build the fact rows.

    Args:
        batch_df: DataFrame of LineItem rows; the ``l_*`` columns selected
            below are read from it.
        orders_df: DataFrame of orders; the ``o_*`` columns selected below
            are read from it.
        spark: Not used in the body; kept so existing callers keep working.
        batch_id: Identifier used only in the log messages.

    Returns:
        The joined and projected DataFrame, repartitioned into 20 partitions
        by ``order_key`` and persisted. It has already been materialised by
        a ``count()``; releasing it with ``unpersist()`` is left to the
        caller.

    Raises:
        Exception: Any error raised while building or materialising the
            batch is printed and then re-raised unchanged.
    """
    print(f"Processing batch {batch_id}")

    try:
        # Inner join: line items without a matching order are dropped.
        joined_data = (
            batch_df
            .join(
                orders_df,
                batch_df.l_orderkey == orders_df.o_orderkey,
                'inner'
            )
            .select([
                # Keys
                col('o_orderkey').alias('order_key'),
                col('l_linenumber').alias('line_number'),
                col('o_custkey').alias('customer_key'),
                col('l_partkey').alias('part_key'),
                col('l_suppkey').alias('supplier_key'),

                # Dates
                col('o_orderdate').alias('order_date'),
                col('l_shipdate').alias('ship_date'),
                col('l_commitdate').alias('commit_date'),
                col('l_receiptdate').alias('receipt_date'),

                # Measures
                col('o_totalprice').alias('order_total_amount'),
                col('o_orderpriority').alias('order_priority'),
                col('o_orderstatus').alias('order_status'),
                col('l_quantity').alias('quantity'),
                col('l_extendedprice').alias('extended_price'),
                col('l_discount').alias('discount_percentage'),
                col('l_tax').alias('tax_percentage'),

                # Derived amounts, rounded to 2 decimals
                expr('ROUND(l_extendedprice * (1 - l_discount) * (1 + l_tax), 2)').alias('net_amount'),
                expr('ROUND(l_extendedprice * l_discount, 2)').alias('discount_amount'),
                expr('ROUND(l_extendedprice * (1 - l_discount) * l_tax, 2)').alias('tax_amount'),

                # Delivery metrics
                col('l_shipmode').alias('ship_mode'),
                col('l_returnflag').alias('return_flag'),
                col('l_linestatus').alias('line_status'),
                expr('datediff(l_shipdate, o_orderdate)').alias('shipping_delay_days'),
                expr('datediff(l_receiptdate, l_shipdate)').alias('delivery_delay_days'),
                expr('l_receiptdate > l_commitdate').alias('is_late_delivery'),

                # Metadata: current_timestamp() is already a Column, so no
                # lit() wrapper is needed.
                current_timestamp().alias('etl_inserted')
            ])
        )

        # StorageLevel.MEMORY_AND_DISK_SER is not defined in PySpark 3.x
        # (Python-side data is always stored serialized); MEMORY_AND_DISK is
        # the equivalent level and avoids an AttributeError.
        result = (
            joined_data
            .repartition(20, 'order_key')
            .persist(StorageLevel.MEMORY_AND_DISK)
        )

        # count() is an action: it materialises the persisted result.
        count = result.count()
        print(f"Batch {batch_id} produced {count} rows")

        return result

    except Exception as e:
        print(f"Error processing batch {batch_id}: {str(e)}")
        raise

# Load the two source tables from PostgreSQL (orders first, then lineitem).
orders_data = get_table_from_db(table_name='public.orders', spark=spark)
lineitem_data = get_table_from_db(table_name='public.lineitem', spark=spark)

# Log the initial volumes; lineitem_count also drives the batching below.
orders_count, lineitem_count = orders_data.count(), lineitem_data.count()
print(f"Initial counts - Orders: {orders_count}, LineItem: {lineitem_count}")

# Prepare orders for the repeated joins (expected to be the smaller table):
# keep only the columns used downstream, partition by the join key and cache.
# StorageLevel.MEMORY_AND_DISK_SER is not defined in PySpark 3.x (Python-side
# data is always stored serialized); MEMORY_AND_DISK is the equivalent level
# and avoids an AttributeError.
orders_subset = (orders_data
    .select([
        'o_orderkey', 'o_custkey', 'o_orderdate',
        'o_totalprice', 'o_orderpriority', 'o_orderstatus'
    ])
    .repartition(50, 'o_orderkey')
    .persist(StorageLevel.MEMORY_AND_DISK)
)

# Process LineItem in fixed-size batches.
batch_size = 1000000  # 1 million rows per batch
num_batches = (lineitem_count + batch_size - 1) // batch_size  # ceiling division
print(f"Processing LineItem in {num_batches} batches of {batch_size} rows")

# Page over a total order. Sorting on l_orderkey alone leaves ties, so rows
# could move between pages from one query to the next and be duplicated or
# skipped. Assumes (l_orderkey, l_linenumber) is unique per row — TODO confirm
# against the table's key.
ordered_lineitem = lineitem_data.orderBy('l_orderkey', 'l_linenumber')

# Union of all batch results; stays None when there is nothing to process.
result_df = None
# Every persisted batch result, kept so the caches can be released at the end.
persisted_batches = []

for batch_id in range(num_batches):
    print(f"Processing batch {batch_id + 1}/{num_batches}")

    # Select the batch: skip the rows of earlier batches FIRST, then cap the
    # page size. limit(n).offset(k) would drop k rows from an n-row result
    # and therefore return nothing for every batch after the first.
    batch_df = (ordered_lineitem
        .offset(batch_id * batch_size)
        .limit(batch_size)
    )

    # Process the batch (process_batch returns a persisted DataFrame).
    batch_result = process_batch(
        batch_df,
        orders_subset,
        spark,
        batch_id
    )
    persisted_batches.append(batch_result)

    # Accumulate the results.
    if result_df is None:
        result_df = batch_result
    else:
        result_df = result_df.unionByName(batch_result)

if result_df is None:
    # No batches at all (empty lineitem table): nothing to show or log.
    print("\nNo LineItem rows found - nothing to display.")
else:
    # Show a few results.
    print("\nFinal Results Sample:")
    result_df.show(5)

    # Show the final statistics.
    log_dataframe_info(result_df, "Final Results")

# Final cleanup. The caches are released only after the actions above:
# result_df is lazy, so unpersisting earlier would force every batch to be
# recomputed from the source when show()/count() run.
for cached_batch in persisted_batches:
    cached_batch.unpersist()
orders_subset.unpersist()