# Process data for Transformer

In [None]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm

In [None]:
# Read the CSV folder.
#
# Optional setup, kept for reference: unpack the folder from a zip archive
# first (paths are from the original /content/drive environment).
#
#   import zipfile
#   zip_path = '/content/drive/MyDrive/visit_1.zip'
#   save_path = '/content/visit_1'
#   os.makedirs(save_path, exist_ok=True)
#   with zipfile.ZipFile(zip_path, 'r') as zip_ref:
#       zip_ref.extractall(save_path)
extract_path = '../visit_1'
# Keep os.listdir() order: it determines the row order of the data built below.
csv_files = [name for name in os.listdir(extract_path) if name.endswith('.csv')]

In [None]:
# Read the Excel file that provides the 'ID' column and the nsrr_ahi_* label columns.
xlxs_path = '../shhs1_ahi_pruebas.xlsx'
xlxs_data = pd.read_excel(io=xlxs_path)

In [None]:
#process data
from joblib import Parallel, delayed

def process_csv(csv_file, extract_path):
    """Load one '<id>_extraction...' CSV and tag every row with its subject ID.

    Args:
        csv_file: file name inside *extract_path*.
        extract_path: directory containing the file.

    Returns:
        The CSV as a DataFrame with an extra 'ID' column holding the digit
        string captured before '_extraction', or an empty DataFrame when the
        file name does not match that pattern (the file is then not read).
    """
    hit = re.search(r'(\d+)_extraction', csv_file)
    if not hit:
        return pd.DataFrame()
    frame = pd.read_csv(os.path.join(extract_path, csv_file))
    frame['ID'] = hit.group(1)
    return frame

In [None]:
# Use parallel workers to load every '<id>_extraction*.csv' file.
num_processes = 6  # change to the number of worker processes you want to use

processed_data = Parallel(n_jobs=num_processes)(
    delayed(process_csv)(csv_file, extract_path)
    for csv_file in tqdm(csv_files)
)

# Files whose names don't match the ID pattern come back as empty frames; skip them.
frames = [df for df in processed_data if not df.empty]
if not frames:
    raise ValueError(f"No '<id>_extraction' CSV files could be loaded from {extract_path!r}")

channels = ['H.R.', 'SaO2', 'ABDO RES', 'THOR RES', 'AIRFLOW', 'ID']
# Select the channels per file (a channel a file lacks becomes NaN, as with a
# concat-then-select) and concatenate. The result is a fresh frame rather than a
# slice, so the later all_data['ID'] = ... assignment raises no
# SettingWithCopyWarning. ignore_index=True already gives a clean RangeIndex.
all_data = pd.concat([df.reindex(columns=channels) for df in frames], ignore_index=True)

# Release the per-file frames; their rows now live in all_data.
del processed_data, frames

In [None]:
# Build the training table: the signal channels are the features, AHI is the label.
xlxs_data['ID'] = xlxs_data['ID'].astype(int)
all_data['ID'] = all_data['ID'].astype(int)

merged_data = pd.merge(all_data, xlxs_data, on='ID', how='left')

# Give every row of a subject that subject's first non-null AHI value.
merged_data['nsrr_ahi_hp3r_aasm15'] = merged_data.groupby('ID')['nsrr_ahi_hp3r_aasm15'].transform('first')
merged_data['nsrr_ahi_hp4u_aasm15'] = merged_data.groupby('ID')['nsrr_ahi_hp4u_aasm15'].transform('first')

# Subjects absent from the spreadsheet (or with no hp3r AHI) have no label.
# Drop them: fillna(0) below would otherwise turn them into AHI 0, i.e.
# class 0 "no OSA", silently corrupting the labels.
missing_label = merged_data['nsrr_ahi_hp3r_aasm15'].isna()
if missing_label.any():
    n_subjects = merged_data.loc[missing_label, 'ID'].nunique()
    print(f"Dropping {int(missing_label.sum())} rows from {n_subjects} subjects without an AHI label")
    merged_data = merged_data[~missing_label]

# Remaining gaps (e.g. missing signal samples) are filled with 0.
merged_data = merged_data.fillna(0)

merged_data = merged_data.astype(float)

In [None]:
#Process features and labels
import numpy as np

# Features: the five signal channels per row; label: the raw hp3r AHI value (categorized below).
X = merged_data[['H.R.', 'SaO2', 'ABDO RES', 'THOR RES', 'AIRFLOW']].to_numpy()
Y = merged_data['nsrr_ahi_hp3r_aasm15'].to_numpy()

def categorize_ahi(ahi):
    """Map an AHI value to an OSA severity class.

    Returns 0 (no OSA) below 5, 1 (mild) below 15, 2 (moderate) below 30,
    and 3 (severe) otherwise, including any value that fails every ``<``
    comparison (e.g. NaN).
    """
    for label, upper in enumerate((5, 15, 30)):
        if ahi < upper:
            return label
    return 3

# Map each AHI value to its severity class in one vectorized call instead of
# dispatching one joblib task per signal row. With bins [5, 15, 30],
# np.digitize returns 0 for ahi < 5, 1 for 5 <= ahi < 15, 2 for 15 <= ahi < 30
# and 3 for ahi >= 30 -- the same thresholds as categorize_ahi.
# (Rebinding processed_data also releases whatever it held before.)
processed_data = np.digitize(Y, bins=[5, 15, 30])

Y = np.array(processed_data, dtype=np.int64)

In [None]:
# Persist features and labels for the Transformer.
directory = "../transformer_data"
os.makedirs(directory, exist_ok=True)
x_dir, y_dir = (os.path.join(directory, fname) for fname in ("x.npy", "y.npy"))
for target, array in ((x_dir, X), (y_dir, Y)):
    np.save(target, array)

# Process data for two-tower Transformer

In [None]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm
import numpy as np
from joblib import Parallel, delayed

In [None]:
# Read CSV file lists: raw visit signals and the Aout feature files.
def _csv_names(folder):
    """Return the names of the .csv entries in *folder*, in os.listdir() order."""
    return [name for name in os.listdir(folder) if name.endswith('.csv')]


visit_path = '../visit_1'
visit_files = _csv_names(visit_path)

aout_path = '../6000_filter'
aout_files = _csv_names(aout_path)

# Read xlxs (subject IDs and AHI label columns)
xlxs_path = '../shhs1_ahi_pruebas.xlsx'
xlxs_data = pd.read_excel(xlxs_path)

In [None]:
# visit
def process_visit_csv(csv_file, visit_path):
    """Load one visit CSV named '<id>_extraction_<n>...' and tag its rows.

    Args:
        csv_file: file name inside *visit_path*.
        visit_path: directory containing the file.

    Returns:
        The CSV as a DataFrame with two extra columns: 'ID' (the digit string
        from the name) and 'Extraction' (the extraction number as int). Returns
        an empty DataFrame when the name does not match the pattern.
    """
    found = re.search(r'(\d+)_extraction_(\d+)', csv_file)
    if found is None:
        return pd.DataFrame()
    subject_id, extraction = found.groups()
    extraction = int(extraction)
    frame = pd.read_csv(os.path.join(visit_path, csv_file))
    frame['ID'] = subject_id
    frame['Extraction'] = extraction
    return frame

In [None]:
num_processes = 30
# One task per visit file; files whose names don't match come back empty and are skipped.
visit_data = Parallel(n_jobs=num_processes)(
    delayed(process_visit_csv)(name, visit_path) for name in tqdm(visit_files)
)
visit_data = pd.concat(filter(lambda frame: not frame.empty, visit_data), ignore_index=True)

In [None]:
#Aout
def process_aout_csv(csv_file, aout_path):
    """Load one Aout CSV as a NumPy array with extraction number and subject ID appended.

    The file name must contain '<id>_extraction_<n>'. The CSV is read without a
    header; the returned array holds its values followed by two extra columns,
    the extraction number and then the subject ID, repeated on every row.

    Args:
        csv_file: file name inside *aout_path*.
        aout_path: directory containing the file.

    Returns:
        The combined 2-D array, or an empty array when the name does not match
        the pattern or the file cannot be read (the error is printed).
    """
    match = re.search(r'(\d+)_extraction_(\d+)', csv_file)
    if not match:
        return np.array([])

    # Parse the ID as an int. hstack-ing a column of ID *strings* onto the
    # numeric data would promote the whole array to a unicode-string dtype,
    # stringifying every feature and making 'ID'/'Extraction' sort as text later.
    current_id = int(match.group(1))
    extraction_num = int(match.group(2))
    file_path = os.path.join(aout_path, csv_file)

    try:
        current_data = pd.read_csv(file_path, header=None).to_numpy()
    except Exception as e:  # best effort: report the bad file and skip it
        print(f"Error reading {csv_file}: {e}")
        return np.array([])

    n_rows = current_data.shape[0]
    extraction_column = np.full((n_rows, 1), extraction_num)
    id_column = np.full((n_rows, 1), current_id)
    return np.hstack((current_data, extraction_column, id_column))

In [None]:
# Load every Aout CSV in parallel; files that fail or don't match the name
# pattern come back as empty arrays and are dropped.
processed_data = [
    block
    for block in Parallel(n_jobs=num_processes)(
        delayed(process_aout_csv)(csv_file, aout_path) for csv_file in aout_files
    )
    if block.size > 0
]

if not processed_data:
    aout_data = pd.DataFrame()
    print("No data processed.")
else:
    all_data = np.vstack(processed_data)
    # Feature_0..Feature_{k-1}, then the two columns appended by process_aout_csv.
    columns = [f'Feature_{idx}' for idx in range(all_data.shape[1] - 2)]
    columns += ['Extraction', 'ID']
    aout_data = pd.DataFrame(all_data, columns=columns)
    # NOTE(review): presumably an unwanted trailing column in the source files -- confirm.
    if 'Feature_1872' in aout_data.columns:
        aout_data = aout_data.drop(columns=['Feature_1872'])
    print(aout_data.head())

In [None]:
# Keep only subjects present in both the Aout and the visit data.
aout_id_int = aout_data['ID'].astype(int)
visit_id_int = visit_data['ID'].astype(int)
common_ids = set(aout_id_int) & set(visit_id_int)
visit_data = visit_data[visit_id_int.isin(common_ids)]
aout_data = aout_data[aout_id_int.isin(common_ids)]
del aout_id_int, visit_id_int

In [None]:
# Replace missing values with 0 in both tables and report their sizes.
visit_data, aout_data = visit_data.fillna(0), aout_data.fillna(0)
print(visit_data.shape, aout_data.shape, sep='\n')
print(aout_data['ID'])

directory = "../transformer_data"

In [None]:
# Reduce the Aout features from the original 12 channels to 5 channels.
# If the Aout matrix (LRIA analysis) is computed directly from 5 channels,
# skip this cell.

# Order rows by subject, then extraction. key= compares the values as numbers,
# so the order stays numeric even if these columns hold strings ('10' after '9').
aout_data = aout_data.sort_values(by=['ID', 'Extraction'], key=lambda col: col.astype(float))
print(aout_data['ID'])
print(aout_data['Extraction'])

# Delete columns 'ID' and 'Extraction'
aout_data = aout_data.drop(columns=['ID', 'Extraction'])

# Delete the first 1728 feature columns
aout_data = aout_data.iloc[:, 1728:]

# 25 columns, laid out below as one 5x5 matrix per row
selected_columns = [
    0, 1, 8, 9, 10, 12, 13, 20, 21, 22,
    96, 97, 104, 105, 106, 108, 109, 116, 117,
    118, 120, 121, 128, 129, 130
]
aout_data = aout_data.iloc[:, selected_columns]

aout_data_array = aout_data.to_numpy(dtype=float)

swap_pairs = [
    (0, 1), (2, 3), (5, 6), (7, 8),
    (10, 11), (12, 13), (15, 16),
    (17, 18), (20, 21), (22, 23)
]

for i, j in swap_pairs:
    aout_data_array[:, [i, j]] = aout_data_array[:, [j, i]]

# One 5x5 matrix per row: aout_data_3d[r, a, b] == aout_data_array[r, 5 * a + b].
# Sized from the data rather than a hard-coded row count, which raised
# IndexError on fewer rows and silently dropped any extra rows.
aout_data_3d = aout_data_array.reshape(-1, 5, 5)
print(aout_data_3d.shape)

np.save(os.path.join(directory, "aout_data_3d.npy"), aout_data_3d)

In [None]:
# label: one hp3r AHI value per subject, in ascending subject-ID order.
unique_ids = sorted(common_ids)
num_ids = len(unique_ids)
num_extractions = 3

xlxs_data['ID'] = xlxs_data['ID'].astype(int)
# NOTE(review): .loc raises KeyError if a subject is missing from the spreadsheet
# and yields extra rows if an ID appears more than once -- confirm IDs are unique.
labels = (
    xlxs_data[xlxs_data['ID'].isin(unique_ids)]
    .set_index('ID')
    .loc[unique_ids, 'nsrr_ahi_hp3r_aasm15']
    .values
)

def categorize_ahi(ahi):
    """Map an AHI value to an OSA severity class.

    Returns 0 below 5, 1 below 15, 2 below 30, and 3 otherwise, including
    any value that fails every ``<`` comparison (e.g. NaN).
    """
    for label, upper in enumerate((5, 15, 30)):
        if ahi < upper:
            return label
    return 3

# One class per subject, repeated once per extraction: [a, a, a, b, b, b, ...].
per_subject_class = np.array([categorize_ahi(value) for value in labels])
expanded_labels = np.repeat(per_subject_class, num_extractions)
np.save(os.path.join(directory, "labels.npy"), expanded_labels)

In [None]:
# Visit: stack the raw channel signals into an array of shape
# (num_ids * num_extractions, num_samples, num_channels); row
# i * num_extractions + (extraction - 1) holds subject unique_ids[i].
channels = ['H.R.', 'SaO2', 'ABDO RES', 'THOR RES', 'AIRFLOW']
visit_data = visit_data[visit_data['Extraction'].isin([1, 2, 3])]

# 'ID' holds the digit strings captured from the file names. Sort subjects
# numerically so rows line up with labels.npy (built from sorted(common_ids))
# and aout_data_3d (sorted by ID, then Extraction). .unique() alone keeps the
# arbitrary os.listdir() file order and misaligns features and labels.
visit_id_int = visit_data['ID'].astype(int)
unique_ids = sorted(visit_id_int.unique())
num_ids = len(unique_ids)
num_extractions = 3
num_samples = 36000  # rows per extraction; shorter extractions are zero-padded
num_channels = len(channels)


visit_data_array = np.zeros((num_ids * num_extractions, num_samples, num_channels))

# Group once, instead of re-scanning the whole frame for every subject.
_visit_groups = visit_data.groupby(visit_id_int, sort=False)


def process_id_data(current_id):
    """Return one zero-padded (num_samples, num_channels) array per extraction 1..num_extractions.

    *current_id* may be an int or a digit string. A missing extraction yields
    an all-zero array; an extraction longer than num_samples raises ValueError.
    """
    id_data = _visit_groups.get_group(int(current_id))
    result = []
    for extraction_num in range(1, num_extractions + 1):
        extraction_data = id_data[id_data['Extraction'] == extraction_num]
        temp_array = np.zeros((num_samples, num_channels))
        temp_array[:len(extraction_data)] = extraction_data[channels].to_numpy(dtype=float)
        result.append(temp_array)
    return result


# Filled sequentially, straight into the output. With joblib's process backend,
# a notebook-defined function can be pickled together with the globals it
# references (here the whole visit_data frame) for every task. Collecting all
# results first also held a second full copy of visit_data_array in memory.
for i, current_id in enumerate(tqdm(unique_ids)):
    for extraction_idx, temp_array in enumerate(process_id_data(current_id)):
        visit_data_array[i * num_extractions + extraction_idx] = temp_array

print(visit_data_array.shape)
np.save(os.path.join(directory, "visit_data.npy"), visit_data_array)