In [1]:
from pyspark.sql import SparkSession
 
spark = SparkSession.builder \
        .appName("comm") \
        .getOrCreate()

from pyspark.sql.types import StructType, StructField, DoubleType, StringType
from pyspark.ml.feature import StringIndexer
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, floor, concat, lit


schema = StructType([
    StructField(str(i), DoubleType(), True) for i in range(2700)
] + [
    StructField('2700', StringType(), True),
    StructField('2701', StringType(), True),
    StructField('2702', StringType(), True)
])

path = "./data/data_9s.csv"
df = spark.read.schema(schema).option("header", True).csv(path)

indexer = StringIndexer(inputCol='2702', outputCol='2704') 
indexed = indexer.fit(df).transform(df)
df = indexed.withColumn("subid", floor(col("2704") % 5)).withColumn("subgroup", concat(col("2701"), lit("_subroup"), col("subid")))
# show the last 10 columns of a piece of dataframe we are going to process
df.select(df.columns[-10:]).show(5)
df_org = df

+------+-----+-----+-----+----+-----+------+------+-----+--------------+
|  2696| 2697| 2698| 2699|2700| 2701|  2702|  2704|subid|      subgroup|
+------+-----+-----+-----+----+-----+------+------+-----+--------------+
| -50.0|-49.0|-48.0|-45.0|   N|train|A00001|1247.0|    2|train_subroup2|
|  10.0|  9.0|  9.0|  8.0|   N|train|A00001|1247.0|    2|train_subroup2|
|  46.0| 42.0| 36.0| 29.0|   N|train|A00001|1247.0|    2|train_subroup2|
|   1.0|  4.0|  7.0| 10.0|   N| test|A00002|1248.0|    3| test_subroup3|
|-112.0|-88.0|-66.0|-31.0|   N| test|A00002|1248.0|    3| test_subroup3|
+------+-----+-----+-----+----+-----+------+------+-----+--------------+
only showing top 5 rows



In [2]:
# %pip install pytictoc
# using test as a speed comparison
from pytictoc import TicToc
import numpy as np
import torch
from pyspark.ml.feature import OneHotEncoder, StringIndexer
from pyspark.ml import Pipeline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import torch
import os

# loop through train / valid / test
g = 'test'
if not os.path.exists(g): 
    os.makedirs(g)
df = df_org.filter((df_org['2701'] == g))


t = TicToc() #create instance of class
t.tic() #Start timer
# engineer and save patient id information
pid = df.select(df.columns[2704])
pid.show(5)
# pid.write.mode('overwrite').parquet("./"+str(g)+"/pid.parquet")

# engineer and save onehot outcome 
y = df.select(df.columns[2700])
# StringIndexer step is necessary if the original column is categorical (non-numeric)
string_indexer = StringIndexer(inputCol=y.columns[0], outputCol="indexed")
encoder = OneHotEncoder(inputCol="indexed", outputCol="encoded")
# Define a pipeline with the required stages
pipeline = Pipeline(stages=[string_indexer, encoder])
# Fit the pipeline to the data and transform
model = pipeline.fit(y)
y = model.transform(y)
# Show the result
y.show(5)
# y.write.mode('overwrite').parquet("./"+str(g)+"/y.parquet")


# engineer and convert 1d signal of each row of x into a 3d RGB spectrogram and flatten it in a list
x = df.select(df.columns[:2700])
# write a function to apply to each row of x
def x2spec(row):
    signal_data = list(row)

    # Create the spectrogram
    fig, ax = plt.subplots(figsize=(5, 3))
    ax.specgram(signal_data, Fs=300)
    plt.axis("off")
    plt.tight_layout()
    plt.close()
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    width, height = fig.get_size_inches() * fig.get_dpi()
    x_3d_spec_r = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3)

    # Convert to tensor, normalize and flatten
    x_3d_spec_r = torch.Tensor(x_3d_spec_r).permute(2, 0, 1).unsqueeze(0)
    x_3d_spec_r = x_3d_spec_r / 255.0
    x_1d_spec_r = x_3d_spec_r.flatten()

    # Convert the tensor to a list for PySpark compatibility
    return [x_1d_spec_r.tolist()]

# Apply the function to the RDD and convert to DataFrame
x_rdd = x.rdd.map(x2spec)
x_df = x_rdd.toDF(["spec_flat"])
x_df.show(2)
# x_df.write.mode('overwrite').parquet("./"+str(g)+"/x_spec_flat.parquet")

t.toc() #Time elapsed since t.tic() # Elapsed time is 15.892209 seconds.

+-----+
|subid|
+-----+
|    3|
|    3|
|    3|
|    3|
|    3|
+-----+
only showing top 5 rows

+----+-------+-------------+
|2700|indexed|      encoded|
+----+-------+-------------+
|   N|    0.0|(3,[0],[1.0])|
|   N|    0.0|(3,[0],[1.0])|
|   N|    0.0|(3,[0],[1.0])|
|   N|    0.0|(3,[0],[1.0])|
|   N|    0.0|(3,[0],[1.0])|
+----+-------+-------------+
only showing top 5 rows

+--------------------+
|           spec_flat|
+--------------------+
|[1.0, 1.0, 1.0, 1...|
|[1.0, 1.0, 1.0, 1...|
+--------------------+
only showing top 2 rows

Elapsed time is 15.892209 seconds.


In [15]:
# 20 cores
