Import and create the sparksession

In [26]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_date, when, datediff, avg, count, round
from pyspark.sql.types import DoubleType
spark = SparkSession.builder.appName("Lung Cancer Data Analysis").getOrCreate()

Loading the csv file from Dataset folder

In [5]:
file_path = "./datasets/lung_cancer_dataset.csv"
df = spark.read.csv(file_path, header=True, inferSchema=True)
df.printSchema()
# df.show(5)

[Stage 7:====>                                                    (1 + 11) / 12]

root
 |-- id: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- gender: string (nullable = true)
 |-- country: string (nullable = true)
 |-- diagnosis_date: date (nullable = true)
 |-- cancer_stage: string (nullable = true)
 |-- family_history: string (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- bmi: double (nullable = true)
 |-- cholesterol_level: integer (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- asthma: integer (nullable = true)
 |-- cirrhosis: integer (nullable = true)
 |-- other_cancer: integer (nullable = true)
 |-- treatment_type: string (nullable = true)
 |-- end_treatment_date: date (nullable = true)
 |-- survived: integer (nullable = true)



                                                                                

#### Task 1
1. Remove duplicate rows
2. Convert numeric columns to correct types
3. Parse date columns
4. Convert all yes/no fields into binary 1/0

In [17]:

def clean_lung_cancer_data(df):
    df = df.dropDuplicates()
    
    date_columns = [col_name for col_name in df.columns if 'date' in col_name.lower()]
    for date_col in date_columns:
        df = df.withColumn(date_col, to_date(col(date_col), 'yyyy-MM-dd'))
    
    numeric_columns = ['age', 'BMI']
    for num_col in numeric_columns:
        df = df.withColumn(num_col, col(num_col).cast(DoubleType()))

    yes_no_cols = []
    for c in df.columns:
        unique_vals = df.select(c).distinct().rdd.flatMap(lambda x: x).collect()
        unique_vals = [str(v).lower() for v in unique_vals if v is not None]
        if set(unique_vals) <= {'yes', 'no'}:
            yes_no_cols.append(c)

    for col_name in yes_no_cols:
        df = df.withColumn(col_name, when(col(col_name) == 'yes', 1).otherwise(0))
    
    return df

df = clean_lung_cancer_data(df)
df.show(5)



+---+----+------+----------+--------------+------------+--------------+--------------+----+-----------------+------------+------+---------+------------+--------------+------------------+--------+
| id| age|gender|   country|diagnosis_date|cancer_stage|family_history|smoking_status| BMI|cholesterol_level|hypertension|asthma|cirrhosis|other_cancer|treatment_type|end_treatment_date|survived|
+---+----+------+----------+--------------+------------+--------------+--------------+----+-----------------+------------+------+---------+------------+--------------+------------------+--------+
|195|49.0|Female|    Latvia|    2018-01-25|   Stage III|             0| Former Smoker|38.8|              261|           1|     1|        0|           0|  Chemotherapy|        2019-12-31|       0|
|255|50.0|Female|Luxembourg|    2024-01-08|    Stage II|             0|Current Smoker|36.7|              254|           1|     0|        0|           0|     Radiation|        2025-05-30|       0|
|344|52.0|  Male|   

                                                                                

#### Task 2
1. Add a new column treatment_duration_days:
2. This should be the number of days between the diagnosis date and the treatment end date.
3. Return the average treatment duration for each treatment type.

In [23]:
df = df.withColumn(
        "treatment_duration_days",
        datediff(col("end_treatment_date"), col("diagnosis_date"))
    )
def analyze_treatment_duration(df):
    result = df.groupBy("treatment_type").agg(
        avg("treatment_duration_days").alias("average_treatment_duration")
    )
    
    return result

analyze_treatment_duration(df).show(5)
df.show(5)

                                                                                

+--------------+--------------------------+
|treatment_type|average_treatment_duration|
+--------------+--------------------------+
|     Radiation|        458.40320462900917|
|  Chemotherapy|        458.39540091909953|
|      Combined|         457.8152186120058|
|       Surgery|        457.73744630723684|
+--------------+--------------------------+





+---+----+------+----------+--------------+------------+--------------+--------------+----+-----------------+------------+------+---------+------------+--------------+------------------+--------+-----------------------+
| id| age|gender|   country|diagnosis_date|cancer_stage|family_history|smoking_status| BMI|cholesterol_level|hypertension|asthma|cirrhosis|other_cancer|treatment_type|end_treatment_date|survived|treatment_duration_days|
+---+----+------+----------+--------------+------------+--------------+--------------+----+-----------------+------------+------+---------+------------+--------------+------------------+--------+-----------------------+
|195|49.0|Female|    Latvia|    2018-01-25|   Stage III|             0| Former Smoker|38.8|              261|           1|     1|        0|           0|  Chemotherapy|        2019-12-31|       0|                    705|
|255|50.0|Female|Luxembourg|    2024-01-08|    Stage II|             0|Current Smoker|36.7|              254|           

                                                                                

#### Task 3
1. Groups patients by smoking_status
2. Calculates the average survival rate for each group
3. Returns the group with the highest survival rate

In [22]:
def get_highest_survival_by_smoking(df):
    result = df.groupBy("smoking_status").agg(
        avg("survived").alias("survival_rate")
    )
    top_group = result.orderBy(col("survival_rate").desc()).limit(1)
    
    return top_group

get_highest_survival_by_smoking(df).show(5)



+--------------+-------------------+
|smoking_status|      survival_rate|
+--------------+-------------------+
|  Never Smoked|0.22091034383684025|
+--------------+-------------------+



                                                                                

#### Task 4


In [27]:

def top_3_stage_iv_countries(df):
    total = df.groupBy("country").agg(count("*").alias("total_patients"))
    
    stage_iv = df.filter(col("cancer_stage") == "Stage IV") \
                 .groupBy("country") \
                 .agg(count("*").alias("stage_iv_count"))
    
    joined = total.join(stage_iv, on="country", how="left") \
                  .fillna(0, subset=["stage_iv_count"]) \
                  .withColumn("stage_iv_percentage", round((col("stage_iv_count") / col("total_patients")) * 100, 2))
    
    return joined.orderBy(col("stage_iv_percentage").desc()).limit(3)

top_3_stage_iv_countries(df).show(5)



                                                                                

+--------------+--------------+--------------+-------------------+
|       country|total_patients|stage_iv_count|stage_iv_percentage|
+--------------+--------------+--------------+-------------------+
|        Greece|         33052|          8429|               25.5|
|       Croatia|         33138|          8426|              25.43|
|Czech Republic|         32885|          8317|              25.29|
+--------------+--------------+--------------+-------------------+



#### Task 5


In [None]:
def analyze_high_risk_survivors(df):
  filtered = df.filter(
      (col("gender") == "male") &
      (col("cancer_stage").isin("Stage III", "Stage IV")) &
      (col("family_history") == 1) &
      (col("smoking_status") == "current") &
      (col("bmi") > 30) &
      (col("survived") == 1)
  )
  
  avg_age = filtered.agg(avg("age").alias("average_age")).collect()[0]["average_age"]
  
  hypertension_count = filtered.filter(col("hypertension") == 1).count()
  total_count = filtered.count()
  
  hypertension_percentage = (hypertension_count / total_count) * 100 if total_count > 0 else 0
  
  return avg_age, hypertension_percentage

                                                                                

(None, 0)


                                                                                