# Notebook Initialization

In [13]:
# %load nb_init.py

from pathlib import Path
import pandas as pd

base_dir = Path.cwd().parent
config_dir = base_dir / "config"
data_dir = base_dir / "data"
input_dir = data_dir / "COVID19"
preprocessed_dir = data_dir / "preprocessed"
output_dir = data_dir / "output"

metadata_file = input_dir / "metadata.csv"
labels_file = input_dir / "unzip_filenames.csv"
preprocessed_labels_file = preprocessed_dir / "labels.parquet"

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

config_file = config_dir / "tfg.conf"

from pyhocon import ConfigFactory
config = ConfigFactory.parse_file(config_file).tfg


In [14]:
config

ConfigTree([('seed', 42),
            ('eda',
             ConfigTree([('csv_options',
                          ConfigTree([('header', 'true'),
                                      ('sep', ','),
                                      ('inferSchema', 'true')]))]))])

In [15]:
spark

# Load labels

In [16]:
labels = spark.read.parquet(str(preprocessed_labels_file))

labels.show(5)

+----------+-------+-------+-----+---------+
|patient_id|scan_id|n_slice|label|num_clips|
+----------+-------+-------+-----+---------+
|         0|   3131|    285|   CP|        5|
|         0|   3132|     42|   CP|        1|
|         0|   3133|    290|   CP|        5|
|         0|   3134|     37|   CP|        1|
|         0|   3135|    269|   CP|        4|
+----------+-------+-------+-----+---------+
only showing top 5 rows



In [20]:
TODO:
- Get unique patient ids into pandas
- Split the ids into train/test
- Divide labels into train_labels and test_labels

from sklearn.model_selection import train_test_split

train_test_split(labels, test_size=.1, random_state=config.get_int("seed"))

TypeError: Expected sequence or array-like, got <class 'pyspark.sql.dataframe.DataFrame'>

# Demographics
## Metadata file

In [56]:
!head -n2 {metadata_file}

patient_id,scan_id,Age,Sex(Male1/Female2),Critical_illness,Liver_function,Lung_function,Progression (Days)
1399,127,57,1,1,5,2,0.08


In [100]:
metadata = spark.read\
    .options(**csv_options)\
    .csv(str(metadata_file))

metadata.show(5)

+----------+-------+---+------------------+----------------+--------------+-------------+------------------+
|patient_id|scan_id|Age|Sex(Male1/Female2)|Critical_illness|Liver_function|Lung_function|Progression (Days)|
+----------+-------+---+------------------+----------------+--------------+-------------+------------------+
|      1399|    127| 57|                 1|               1|             5|            2|              0.08|
|      1297|     82| 55|                 1|               1|             3|            2|              0.88|
|      2255|    549|  3|                 1|               1|          null|         null|              0.02|
|      1184|     26|  5|                 2|               1|             0|            2|              0.02|
|      1186|     27|  2|                 2|               1|             2|            2|              0.02|
+----------+-------+---+------------------+----------------+--------------+-------------+------------------+
only showing top 5 

In [101]:
metadata.printSchema()

root
 |-- patient_id: integer (nullable = true)
 |-- scan_id: integer (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Sex(Male1/Female2): integer (nullable = true)
 |-- Critical_illness: integer (nullable = true)
 |-- Liver_function: integer (nullable = true)
 |-- Lung_function: integer (nullable = true)
 |-- Progression (Days): double (nullable = true)



In [102]:
id_expr = "CONCAT(patient_id, '::', scan_id) AS id"
metadata_exprs = [
    id_expr,
    "patient_id AS patient_id",
    "scan_id AS scan_id",
    "Age AS age",
    "`Sex(Male1/Female2)` AS gender",
    "Critical_illness AS critical_illness",
    "Liver_function AS liver_function",
    "Lung_function AS lung_function",
    "`Progression (Days)` AS progression_days",
]

metadata = metadata.selectExpr(*metadata_exprs)

metadata.show(5)

+---------+----------+-------+---+------+----------------+--------------+-------------+----------------+
|       id|patient_id|scan_id|age|gender|critical_illness|liver_function|lung_function|progression_days|
+---------+----------+-------+---+------+----------------+--------------+-------------+----------------+
|1399::127|      1399|    127| 57|     1|               1|             5|            2|            0.08|
| 1297::82|      1297|     82| 55|     1|               1|             3|            2|            0.88|
|2255::549|      2255|    549|  3|     1|               1|          null|         null|            0.02|
| 1184::26|      1184|     26|  5|     2|               1|             0|            2|            0.02|
| 1186::27|      1186|     27|  2|     2|               1|             2|            2|            0.02|
+---------+----------+-------+---+------+----------------+--------------+-------------+----------------+
only showing top 5 rows



In [81]:
metadata.count()

408

## Labels

In [103]:
labels = spark.read.options(**csv_options).csv(str(labels_file))

labels.show(5)

+--------+-----+----------+-------+-------+
|zip_file|label|patient_id|scan_id|n_slice|
+--------+-----+----------+-------+-------+
|CP-1.zip|   CP|         0|   3131|    285|
|CP-1.zip|   CP|         0|   3132|     42|
|CP-1.zip|   CP|         0|   3133|    290|
|CP-1.zip|   CP|         0|   3134|     37|
|CP-1.zip|   CP|         0|   3135|    269|
+--------+-----+----------+-------+-------+
only showing top 5 rows



In [104]:
labels.printSchema()

root
 |-- zip_file: string (nullable = true)
 |-- label: string (nullable = true)
 |-- patient_id: integer (nullable = true)
 |-- scan_id: integer (nullable = true)
 |-- n_slice: integer (nullable = true)



In [105]:
labels_expr = [
    id_expr,
    "patient_id AS patient_id",
    "scan_id AS scan_id",
    "n_slice AS n_slice",
    "label",
]

labels = labels.selectExpr(*labels_expr)

labels.show(5)

+-------+----------+-------+-------+-----+
|     id|patient_id|scan_id|n_slice|label|
+-------+----------+-------+-------+-----+
|0::3131|         0|   3131|    285|   CP|
|0::3132|         0|   3132|     42|   CP|
|0::3133|         0|   3133|    290|   CP|
|0::3134|         0|   3134|     37|   CP|
|0::3135|         0|   3135|    269|   CP|
+-------+----------+-------+-------+-----+
only showing top 5 rows



In [80]:
labels.count()

4178

## Check overlap between labels and metadata

Do we have demographics for patients for which we have data?

In [145]:
total_labels = labels.count()
total_labels_with_demo = labels.join(metadata, ["patient_id"], "left_semi").count()

print(f"We have demographics for {total_labels_with_demo} / {total_labels} observations ({100 * total_labels_with_demo / total_labels:.2f}%)")

We have demographics for 378 / 4178 observations (9.05%)


In [146]:
labels_with_metadata = labels.join(metadata, ["patient_id"], "left_semi")

labels_with_metadata.count()

378

In [147]:
labels_with_metadata.groupBy("label").count().show()

+------+-----+
| label|count|
+------+-----+
|    CP|  170|
|   NCP|   13|
|Normal|  195|
+------+-----+



Is there also overlap on patient_id AND scan_id level?

In [142]:
labels.join(metadata, ["patient_id", "scan_id"], "left_semi").groupBy("label").count().show()

+-----+-----+
|label|count|
+-----+-----+
+-----+-----+



In [107]:
labels_pd = labels.toPandas()

labels_pd

Unnamed: 0,id,patient_id,scan_id,n_slice,label
0,0::3131,0,3131,285,CP
1,0::3132,0,3132,42,CP
2,0::3133,0,3133,290,CP
3,0::3134,0,3134,37,CP
4,0::3135,0,3135,269,CP
...,...,...,...,...,...
4173,1919::374,1919,374,99,Normal
4174,1920::375,1920,375,100,Normal
4175,1921::376,1921,376,80,Normal
4176,1922::377,1922,377,87,Normal


In [117]:
labels.select("patient_id").count(), labels.select("patient_id").distinct().count()

(4178, 2742)

## Do any patient_ids have more than 1 label?

In [118]:
from pyspark.sql import functions as F

In [120]:
labels\
    .groupBy("patient_id")\
    .agg(F.countDistinct("label").alias("num_labels"))\
    .filter("num_labels > 1")\
    .count()

0

## Check number labels with / without unique patient ids

In [138]:
total_labels = labels.count()
total_slices = labels.selectExpr("sum(n_slice) AS total").first().total

total_labels, total_slices

(4178, 411529)

In [139]:
labels\
    .groupBy("label")\
    .agg(
        F.count("*").alias("count"),
        F.sum("n_slice").alias("n_slice")
    )\
    .withColumn("count_pct", F.expr(f"ROUND(count / {total_labels}, 4)"))\
    .withColumn("slice_pct", F.expr(f"ROUND(n_slice / {total_slices}, 4)"))\
    .orderBy("label")\
    .show()

+------+-----+-------+---------+---------+
| label|count|n_slice|count_pct|slice_pct|
+------+-----+-------+---------+---------+
|    CP| 1556| 159702|   0.3724|   0.3881|
|   NCP| 1544| 156071|   0.3696|   0.3792|
|Normal| 1078|  95756|    0.258|   0.2327|
+------+-----+-------+---------+---------+



In [149]:
unique_patients = labels.select("patient_id").distinct().count()
labels\
    .dropDuplicates(["patient_id"])\
    .groupBy("label")\
    .count()\
    .withColumn("pct", F.expr(f"ROUND(count / {unique_patients}, 4)"))\
    .orderBy("label")\
    .show()

+------+-----+------+
| label|count|   pct|
+------+-----+------+
|    CP|  964|0.3516|
|   NCP|  929|0.3388|
|Normal|  849|0.3096|
+------+-----+------+



## Check labels with metadata only

In [151]:
total_labels_with_metadata = labels_with_metadata.count()
total_slices_with_metadata = labels_with_metadata.selectExpr("sum(n_slice) AS total").first().total

print(f"Total labels with metadata: {total_labels_with_metadata}")
print(f"Total slices with metadata: {total_slices_with_metadata}")

labels_with_metadata\
    .groupBy("label")\
    .agg(
        F.count("*").alias("count"),
        F.sum("n_slice").alias("n_slice")
    )\
    .withColumn("count_pct", F.expr(f"ROUND(count / {total_labels_with_metadata}, 4)"))\
    .withColumn("slice_pct", F.expr(f"ROUND(n_slice / {total_slices_with_metadata}, 4)"))\
    .orderBy("label")\
    .show()

Total labels with metadata: 378
Total slices with metadata: 31616
+------+-----+-------+---------+---------+
| label|count|n_slice|count_pct|slice_pct|
+------+-----+-------+---------+---------+
|    CP|  170|  16084|   0.4497|   0.5087|
|   NCP|   13|    661|   0.0344|   0.0209|
|Normal|  195|  14871|   0.5159|   0.4704|
+------+-----+-------+---------+---------+



In [152]:
unique_patients_with_metadata = labels_with_metadata.select("patient_id").distinct().count()

print(f"There are {unique_patients_with_metadata} unique patients with metadata")
labels_with_metadata\
    .dropDuplicates(["patient_id"])\
    .groupBy("label")\
    .count()\
    .withColumn("pct", F.expr(f"ROUND(count / {unique_patients_with_metadata}, 4)"))\
    .orderBy("label")\
    .show()

There are 276 unique patients with metadata
+------+-----+------+
| label|count|   pct|
+------+-----+------+
|    CP|   99|0.3587|
|   NCP|   13|0.0471|
|Normal|  164|0.5942|
+------+-----+------+



## Conclusion

There is almost no metadata for patients with NCP (there's metadata only for 13). It could be usable if we only want to consider e.g. CP VS Normal, but won't be useful for NCP.

In [29]:
import sweetviz as sv

In [108]:
report = sv.analyze(labels_pd)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), la…




HBox(children=(HTML(value=''), FloatProgress(value=0.0, layout=Layout(flex='2'), max=5.0), HTML(value='')), la…




HBox(children=(HTML(value=''), FloatProgress(value=0.0, layout=Layout(flex='2'), max=1.0), HTML(value='')), la…




In [109]:
report.show_html()

Report SWEETVIZ_REPORT.html was generated! NOTEBOOK/COLAB USERS: the web browser MAY not pop up, regardless, the report IS saved in your notebook/colab files.
