In [0]:
from pyspark.sql.functions import col, lit, round
from pyspark.sql.types import IntegerType
from pyspark.sql import DataFrame


In [0]:
class Transform:
    """Build aggregate treatment reports from a single input DataFrame.

    Every method reads ``self.input_df`` and returns a new DataFrame. The
    methods reference these input columns: ``treatment_type``,
    ``treatment_outcome_status``, ``treatment_duration``, ``treatment_cost``,
    ``disease_name`` and ``age``.
    """

    def __init__(self, input_df: DataFrame):
        """Store the DataFrame that all report methods aggregate over."""
        self.input_df = input_df

    def treatment_type_success_rate(self) -> DataFrame:
        """What is the treatment success rate by treatment type?

        Returns:
            One row per ``treatment_type`` that has at least one row whose
            ``treatment_outcome_status`` is ``"successful"``, with columns
            ``treatment_type``, ``treatment_outcome_status``,
            ``count_of_patients``, ``total_patients`` and ``success_rate``
            (``count_of_patients / total_patients * 100`` rounded to two
            decimals).
        """
        # Denominator: number of input rows per treatment type.
        total_patients_df = (
            self.input_df
            .groupBy("treatment_type")
            .count()
            .withColumnRenamed("count", "total_patients")
        )

        # Numerator: number of input rows per treatment type with a
        # successful outcome.
        successful_df = (
            self.input_df
            .groupBy("treatment_type", "treatment_outcome_status")
            .count()
            .withColumnRenamed("count", "count_of_patients")
            .filter(col("treatment_outcome_status") == "successful")
        )

        # Joining on the column *name* yields a single ``treatment_type``
        # column, so later expressions can use plain column names instead of
        # depending on DataFrame alias qualifiers surviving a projection.
        success_df = (
            successful_df
            .join(total_patients_df, on="treatment_type", how="inner")
            .withColumn(
                "success_rate",
                round((col("count_of_patients") / col("total_patients")) * 100, 2),
            )
        )

        return success_df

    def treatment_duration_treatment_type(self) -> DataFrame:
        """What is the total treatment duration by treatment type?

        ``treatment_duration`` is cast to an integer before summing.
        NOTE(review): if the source column can hold fractional values the
        integer cast will lose precision - confirm against the input schema.

        Returns:
            One row per ``treatment_type`` with columns ``treatment_type``
            and ``total_duration``.
        """
        treatment_df = (
            self.input_df
            .withColumn("treatment_duration", col("treatment_duration").cast(IntegerType()))
            .groupBy("treatment_type")
            .sum("treatment_duration")
            # ``DataFrame.alias`` names the DataFrame, not the aggregate
            # column, so the column has to be renamed explicitly.
            .withColumnRenamed("sum(treatment_duration)", "total_duration")
        )

        return treatment_df

    def treatment_cost_by_disease_name(self) -> DataFrame:
        """What is the total treatment cost by health condition (disease name)?

        ``treatment_cost`` is cast to an integer before summing.
        NOTE(review): the question this report was written for asked for the
        *average* cost, but the computation is (and has always been) a sum -
        confirm which one consumers expect before changing the aggregate.

        Returns:
            One row per ``disease_name`` with columns ``disease_name`` and
            ``total_treatment_cost``.
        """
        treatment_disease_cost_df = (
            self.input_df
            .withColumn("treatment_cost", col("treatment_cost").cast(IntegerType()))
            .groupBy("disease_name")
            .sum("treatment_cost")
            # ``DataFrame.alias`` names the DataFrame, not the aggregate
            # column, so the column has to be renamed explicitly.
            .withColumnRenamed("sum(treatment_cost)", "total_treatment_cost")
        )

        return treatment_disease_cost_df

    def success_rate_by_age(self) -> DataFrame:
        """How many successful treatments are there by disease and age?

        NOTE(review): despite the method name, this returns raw counts of
        successful outcomes, not a rate - no denominator is computed here.

        Returns:
            One row per (``disease_name``, ``age``) combination that has a
            ``"successful"`` outcome, with columns ``disease_name``, ``age``,
            ``treatment_outcome_status`` and ``count``, ordered by ``age``.
        """
        success_rate_age = (
            self.input_df
            # Only successful rows contribute to the result, so drop the rest
            # before grouping; the surviving groups and counts are unchanged.
            .filter(col("treatment_outcome_status") == "successful")
            .groupBy("disease_name", "age", "treatment_outcome_status")
            .count()
            .orderBy("age")
        )

        return success_rate_age

    def disease_deceased_count(self) -> DataFrame:
        """What is the total count of the deceased outcome by disease name?

        Returns:
            One row per ``disease_name`` that has a ``"deceased"`` outcome,
            with columns ``disease_name``, ``treatment_outcome_status`` and
            ``count``.
        """
        deceased_count = (
            self.input_df
            # Only deceased rows contribute to the result, so drop the rest
            # before grouping; the surviving groups and counts are unchanged.
            .filter(col("treatment_outcome_status") == "deceased")
            .groupBy("disease_name", "treatment_outcome_status")
            .count()
        )

        return deceased_count




