In [None]:
%run ./Global_Configurations

In [None]:
%run ./Layer_Utilities_NB

In [None]:
class DataTransformer:
    """Convert DynamoDB-JSON struct columns into a flat, typed Spark DataFrame.

    The pipeline (see ``process_dataframe``) strips DynamoDB type descriptors
    from every column, flattens the columns listed in ``self.structure`` into
    ``<group>_<field>`` columns, adds a surrogate id for each
    (treatment_type, treatment_outcome_status) combination and finally casts
    every known column to its target type.

    NOTE(review): the Spark names used here (``DataFrame``, ``udf``, ``col``,
    ``to_json``, ``from_json``, ``Window``, the ``*Type`` classes, ...) and
    ``json`` are not imported in this cell; presumably they are brought into
    scope by the ``%run`` notebooks above - confirm.
    """

    # DynamoDB attribute-value type descriptors, in lookup priority order.
    # 'S' and 'N' come first so their precedence matches the previous
    # behaviour of this class.
    _TYPE_DESCRIPTORS = ('S', 'N', 'BOOL', 'NULL', 'M', 'L', 'SS', 'NS', 'BS', 'B')

    def __init__(self):
        # Source column -> {field inside that column -> flat output column}.
        # Only the keys drive flattening: the flat name is derived as
        # "<column>_<field>", which matches the values listed here.
        self.structure = {
            'treatment': {
                'id': 'treatment_id',
                'start_date': 'treatment_start_date',
                'completion_date': 'treatment_completion_date',
                'outcome_status': 'treatment_outcome_status',
                'outcome_date': 'treatment_outcome_date',
                'duration_in_days': 'treatment_duration_in_days',
                'cost': 'treatment_cost',
                'type': 'treatment_type'
            },
            'provider': {
                'id': 'provider_id',
                'full_name': 'provider_full_name',
                'speciality_id': 'provider_speciality_id',
                'speciality_name': 'provider_speciality_name',
                'affiliated_hospital': 'provider_affiliated_hospital'
            },
            'location': {
                'id': 'location_id',
                'country': 'location_country',
                'state': 'location_state',
                'city': 'location_city'
            },
            'patient': {
                'id': 'patient_id',
                'full_name': 'patient_full_name',
                'gender': 'patient_gender',
                'age': 'patient_age'
            },
            'disease': {
                'id': 'disease_id',
                'speciality_id': 'disease_speciality_id',
                'name': 'disease_name',
                'type': 'disease_type',
                'severity': 'disease_severity',
                'transmission_mode': 'disease_transmission_mode',
                'mortality_rate': 'disease_mortality_rate'
            },
            'metadata': {
                'added_at': 'metadata_added_at',
                'modified_at': 'metadata_modified_at'
            }
        }

    def _unwrap_attribute_value(self, value):
        """
        Strip the DynamoDB type descriptor from a single attribute value.

        A dict whose keys are all type descriptors (e.g. ``{'N': '1'}``) is
        treated as a typed value; any other dict is treated as a plain map of
        attributes and converted recursively. Non-dict values pass through.

        Args:
            value: Attribute value, possibly wrapped in a type descriptor

        Returns:
            The unwrapped value ('N' payloads are kept as given, i.e. strings)
        """
        if not isinstance(value, dict):
            return value

        if value and all(key in self._TYPE_DESCRIPTORS for key in value):
            for descriptor in self._TYPE_DESCRIPTORS:
                payload = value.get(descriptor)
                # Skip null payloads so {'S': None, 'N': '1'} (a struct
                # serialised with its null fields kept) resolves to '1'.
                if payload is None:
                    continue
                if descriptor == 'NULL':
                    return None
                if descriptor == 'M' and isinstance(payload, dict):
                    return self._remove_dynamodb_types(payload)
                if descriptor == 'L' and isinstance(payload, list):
                    return [self._unwrap_attribute_value(item) for item in payload]
                return payload
            return None

        # Plain nested map: its values are attribute values themselves.
        return self._remove_dynamodb_types(value)

    def _remove_dynamodb_types(self, dynamodb_dict: dict) -> dict:
        """
        Recursively removes DynamoDB type indicators from a dictionary.

        The dictionary itself is treated as a map of attribute name to
        attribute value; it is never unwrapped as a typed value, so the
        top-level keys are always preserved.

        Args:
            dynamodb_dict (dict): Dictionary with DynamoDB type indicators

        Returns:
            dict: Dictionary without DynamoDB type indicators
        """
        return {
            key: self._unwrap_attribute_value(value)
            for key, value in dynamodb_dict.items()
        }

    def _to_plain_json(self, raw):
        """
        Convert one DynamoDB-JSON cell value into a regular JSON string.

        Args:
            raw: JSON string (or already-parsed dict), or None for a null cell

        Returns:
            str: JSON text without type indicators, or None when raw is None
        """
        # to_json() yields null for null cells; propagate it instead of
        # failing inside the UDF.
        if raw is None:
            return None
        document = json.loads(raw) if isinstance(raw, str) else raw
        return json.dumps(self._remove_dynamodb_types(document))

    def convert_columns(self, df: "DataFrame") -> "DataFrame":
        """
        Converts all DynamoDB JSON columns in a DataFrame to regular JSON format.

        The ``treatment_id_partition_key`` column is dropped when present.
        Every remaining column is serialised with ``to_json``, stripped of
        type indicators and parsed back as a map of string to string.

        Args:
            df (DataFrame): DataFrame with DynamoDB JSON formatted columns

        Returns:
            DataFrame: DataFrame with regular JSON formatted columns
        """
        if 'treatment_id_partition_key' in df.columns:
            df = df.drop('treatment_id_partition_key')

        convert_udf = udf(lambda x: self._to_plain_json(x), StringType())
        plain_type = MapType(StringType(), StringType())

        # One projection for all columns: avoids the "<column>_temp" helper
        # columns (which could clobber a real column of that name) and keeps
        # the query plan small.
        return df.select(*[
            from_json(convert_udf(to_json(col(column))), plain_type).alias(column)
            for column in df.columns
        ])

    def _generate_flat_columns(self, prefix: str, struct_dict: dict, path: str = None) -> list:
        """
        Recursively generates flattened column names and their corresponding expressions.

        Flat names join the keys with underscores and expression paths join
        them with dots. The leaf values of ``struct_dict`` are not consulted.

        Args:
            prefix (str): Prefix for the column name
            struct_dict (dict): Dictionary representing the structure
            path (str): Dotted path of the enclosing struct; defaults to prefix

        Returns:
            list: List of tuples containing (flat_column_name, column_expression)
        """
        if path is None:
            path = prefix

        flat_columns = []
        for key, value in struct_dict.items():
            flat_name = f"{prefix}_{key}" if prefix else key
            # Track the dotted path separately from the underscore name so
            # nested levels resolve to "a.b.c" rather than "a_b.c".
            field_path = f"{path}.{key}" if path else key
            if isinstance(value, dict):
                flat_columns.extend(
                    self._generate_flat_columns(flat_name, value, field_path)
                )
            else:
                flat_columns.append((flat_name, field_path))
        return flat_columns

    def flatten_dataframe(self, df: "DataFrame") -> "DataFrame":
        """
        Flattens the converted DataFrame into a structure with simple column names.

        Only the fields listed in ``self.structure`` are selected; any other
        column of ``df`` is not carried over.

        Args:
            df (DataFrame): DataFrame with converted DynamoDB JSON columns

        Returns:
            DataFrame: Flattened DataFrame with simple column names
        """
        flat_columns = []
        for struct_name, struct_fields in self.structure.items():
            flat_columns.extend(
                self._generate_flat_columns(struct_name, struct_fields)
            )

        select_expr = [
            col(expr_path).alias(flat_name)
            for flat_name, expr_path in flat_columns
        ]

        return df.select(*select_expr)

    def cast_column_types_and_format(self, df: "DataFrame") -> "DataFrame":
        """
        Casts DataFrame columns to specified types.

        Columns missing from ``df`` are skipped and columns without a mapping
        are left untouched.

        Args:
            df (DataFrame): The Spark DataFrame to cast.

        Returns:
            DataFrame: The DataFrame with columns cast to the appropriate types.
        """

        # Define the mapping of column names to data types
        type_mapping = {
            'treatment_id': IntegerType(),
            'treatment_start_date': TimestampType(),
            'treatment_completion_date': TimestampType(),
            'treatment_outcome_status': StringType(),
            'treatment_outcome_date': TimestampType(),
            'treatment_duration_in_days': IntegerType(),
            'treatment_cost': DecimalType(10,2),
            'treatment_type': StringType(),
            'treatment_type_and_outcome_status_id': IntegerType(),
            'provider_id': IntegerType(),
            'provider_full_name': StringType(),
            'provider_speciality_id': IntegerType(),
            'provider_speciality_name': StringType(),
            'provider_affiliated_hospital': StringType(),
            'location_id': IntegerType(),
            'location_country': StringType(),
            'location_state': StringType(),
            'location_city': StringType(),
            'patient_id': IntegerType(),
            'patient_full_name': StringType(),
            'patient_gender': StringType(),
            'patient_age': IntegerType(),
            'disease_id': IntegerType(),
            'disease_speciality_id': IntegerType(),
            'disease_name': StringType(),
            'disease_type': StringType(),
            'disease_severity': StringType(),
            'disease_transmission_mode': StringType(),
            'disease_mortality_rate': DecimalType(5,2),
            'metadata_added_at': TimestampType(),
            'metadata_modified_at': TimestampType(),
        }

        # The metadata columns go through to_timestamp(); all others use cast().
        parsed_timestamps = ('metadata_added_at', 'metadata_modified_at')

        # Casting never adds or removes columns, so read the names once.
        present = set(df.columns)

        for column, dtype in type_mapping.items():
            if column not in present:
                continue
            if column in parsed_timestamps:
                df = df.withColumn(column, to_timestamp(df[column]))
            else:
                df = df.withColumn(column, df[column].cast(dtype))

        return df

    def add_treatment_type_and_outcome_status_id(self, df: "DataFrame") -> "DataFrame":
        """
        Adds a unique ID for each combination of treatment_type and treatment_outcome_status.

        The ID is the dense rank of the "<type>_<status>" string, stored in
        the ``treatment_type_and_outcome_status_id`` column.

        Args:
            df (DataFrame): DataFrame with treatment_type and
                treatment_outcome_status columns

        Returns:
            DataFrame: The DataFrame with the ID column added
        """
        # NOTE(review): concat_ws skips null inputs and joins with "_", so
        # distinct pairs can map to the same string (("a_b", "c") and
        # ("a", "b_c")) - confirm this cannot occur in the data.
        df = df.withColumn(
            "treatment_combination",
            concat_ws("_", col("treatment_type"), col("treatment_outcome_status"))
        )

        # NOTE(review): the window has no partitionBy, so Spark evaluates it
        # on a single partition - acceptable only for modest row counts.
        window_spec = Window.orderBy("treatment_combination")

        # Equal combinations share a rank and ranks are consecutive.
        df = df.withColumn(
            "treatment_type_and_outcome_status_id",
            dense_rank().over(window_spec)
        )

        # Drop the temporary combination column
        return df.drop("treatment_combination")

    def process_dataframe(self, df: "DataFrame") -> "DataFrame":
        """
        Complete processing pipeline: converts DynamoDB format, flattens the
        DataFrame, adds the treatment type/outcome ID and casts column types.

        Args:
            df (DataFrame): Input DataFrame with DynamoDB JSON columns

        Returns:
            DataFrame: Processed and flattened DataFrame
        """
        converted_df = self.convert_columns(df)
        flattened_df = self.flatten_dataframe(converted_df)
        result_df = self.add_treatment_type_and_outcome_status_id(flattened_df)
        return self.cast_column_types_and_format(result_df)