# Preprocessing

## Imports

### (ATTENTION: To use standard pandas instead of the GPU version (cuDF), uncomment the pandas import and comment out the cudf import)

In [None]:
import cudf as pd
# import pandas as pd
from cudf import DataFrame
import os
import time
import numpy as np

## Global variables

In [None]:
# Wall-clock timestamp taken at the start of the run; a later cell subtracts it
# from the end timestamp to report the elapsed time.
start_time = time.time()

# Directory prefix for all input and output files. The trailing slash matters
# wherever a path is built by plain string formatting instead of os.path.join.
path_to_data = "data/"

## Actual Preprocessing

### Preprocessing healthy: dataset | Drop operation

In [None]:
# Location of the healthy expression table.
raw_healthy_csv_path = os.path.join(path_to_data, "raw_healthy_data.csv")

# The first CSV column becomes the index and the first row the header.
raw_healthy_df = pd.read_csv(raw_healthy_csv_path, sep=',', header=0, index_col=0)

# Remove the 'Description' column before the dtype check below.
raw_healthy_df = raw_healthy_df.drop(columns=['Description'])

# Abort when the remaining columns do not all share a single dtype.
has_mixed_dtypes = raw_healthy_df.dtypes.nunique() > 1
if has_mixed_dtypes:
    raise ValueError("DataFrame contains multiple data types, which is not supported.")

# Swap rows and columns: the former column labels become the index.
raw_healthy_df = raw_healthy_df.transpose()

### Preprocessing unhealthy: dataset | Drop operation

In [None]:
# Location of the unhealthy expression table.
raw_unhealthy_csv_path = os.path.join(path_to_data, "raw_unhealthy_data.csv")

# The first CSV column becomes the index and the first row the header.
# NOTE(review): nrows=1 reads only the first data row, so every later step
# (row-count matching, merge) works on a single-row frame. This looks like a
# leftover from exploring the file; confirm before relying on the merged output.
# NOTE(review): unlike the healthy frame, this frame is not transposed in the
# active code; confirm that its columns are genes.
raw_unhealthy_df = pd.read_csv(raw_unhealthy_csv_path, nrows=1, index_col=0, header=0, sep=',')

# Marker gene names; in this cell they are referenced only by commented-out
# exploration code.
b_all_gene_markers = [
    'CD19', 'CD22', 'CD79A', 'CD10', 'BCR-ABL1', 'ETV6-RUNX1',
    'TCF3-PBX1', 'KMT2A-AF4', 'CRLF2', 'IKZF1', 'PAX5', 'JAK1', 'JAK2',
    'ABL1', 'ABL2', 'CSF1R', 'PDGFRB', 'NRAS', 'KRAS', 'PTPN11'
]

# Preview of the loaded frame (notebook display only).
raw_unhealthy_df.head()

# all_genes_in_row  = raw_unhealthy_df.iloc[0]

# print(f"All genes in the first row: {all_genes_in_row.to_arrow().to_pylist()}")

# print(f"b_ALL gene markers: {b_all_gene_markers}")
# print(f"Genes in the first row: {all_genes_in_row.tolist()}")

# b_all_genes_mask = all_genes_in_row.isin(b_all_gene_markers)

# b_all_genes = all_genes_in_row[b_all_genes_mask]

# print(b_all_genes)

# genes_found = pd.Series(all_genes_in_row).isin(b_all_gene_markers)

# print(f"Genes found in the first row: {genes_found.tolist()}")

# raw_unhealthy_df.head()

# print(raw_unhealthy_df.iloc[1])

# print(f"Genes found in the DataFrame: {genes_found.tolist()}")

# has_b_all_marker = raw_unhealthy_df.iloc[0].isin(b_all_gene_markers)

# print(f"Does the DataFrame contain any B-ALL gene markers? {has_b_all_marker}")

# raw_unhealthy_df.drop('gene_name', inplace=True)
# raw_unhealthy_df.drop('gene_type', inplace=True)

# raw_unhealthy_df.head()

### Preprocessing healthy: gencode processor

In [None]:
# Keep only the part of each column name before the first '.'
# (presumably this strips a version suffix from the gene ID; confirm).
temp_column_names = raw_healthy_df.columns.str.split('.').str[0]

# True for the first occurrence of each shortened name, False for later repeats.
columns_to_keep_mask = ~temp_column_names.duplicated(keep='first')

print("Total columns before removal:", raw_healthy_df.shape[1])

# Select the first occurrences and relabel them with the shortened names.
raw_healthy_df = raw_healthy_df.loc[:, columns_to_keep_mask]
raw_healthy_df.columns = temp_column_names[columns_to_keep_mask]

print("Total columns after removal:", raw_healthy_df.shape[1])

### Preprocessing unhealthy: gencode processor

In [None]:
# Keep only the part of each column name before the first '.'
# (presumably this strips a version suffix from the gene ID; confirm).
temp_column_names = raw_unhealthy_df.columns.str.split('.').str[0]

# True for the first occurrence of each shortened name, False for later repeats.
columns_to_keep_mask = ~temp_column_names.duplicated(keep='first')

print("Total columns before removal:", raw_unhealthy_df.shape[1])

# Select the first occurrences and relabel them with the shortened names.
raw_unhealthy_df = raw_unhealthy_df.loc[:, columns_to_keep_mask]
raw_unhealthy_df.columns = temp_column_names[columns_to_keep_mask]

print("Total columns after removal:", raw_unhealthy_df.shape[1])

### Preprocessing healthy: Convert dtypes to int32 and drop NaN

In [None]:
# Drop columns containing missing values BEFORE the integer cast. Casting first
# is the wrong order: plain pandas raises when NaN is converted to an integer
# dtype, and once values are integers the missing entries can no longer be
# reliably detected by dropna.
raw_healthy_df.dropna(axis=1, inplace=True)

# All remaining columns are complete, so the cast to 32-bit integers is safe.
raw_healthy_df = raw_healthy_df.astype(np.int32)

### Preprocessing unhealthy: Convert dtypes to int32 and drop NaN

In [None]:
# Drop columns containing missing values BEFORE the integer cast. Casting first
# is the wrong order: plain pandas raises when NaN is converted to an integer
# dtype, and once values are integers the missing entries can no longer be
# reliably detected by dropna.
raw_unhealthy_df.dropna(axis=1, inplace=True)

# All remaining columns are complete, so the cast to 32-bit integers is safe.
raw_unhealthy_df = raw_unhealthy_df.astype(np.int32)

### Preprocessing: selecting only common genes

In [None]:
# Column labels present in both frames.
matching_genes = raw_healthy_df.columns.intersection(raw_unhealthy_df.columns)

# The healthy frame is cut down to this many rows so both classes are the same size.
unhealthy_df_rows_length = len(raw_unhealthy_df)

# Restrict both frames to the shared columns.
raw_healthy_df_filtered = raw_healthy_df[matching_genes]

raw_unhealthy_df_filtered = raw_unhealthy_df[matching_genes]

# Keep the first `unhealthy_df_rows_length` healthy rows by POSITION.
# The previous form dropped `index[-(n_healthy - n_unhealthy):]`; when both
# frames already had the same length that slice became `index[-0:]`, i.e. every
# row, so the whole frame was dropped and the check below raised. Dropping by
# label could also remove extra rows if index labels repeat.
raw_healthy_df_filtered = raw_healthy_df_filtered.iloc[:unhealthy_df_rows_length]

# Still raises when the healthy frame has fewer rows than the unhealthy one.
if len(raw_healthy_df_filtered) != unhealthy_df_rows_length:
    raise ValueError("The number of rows in the healthy DataFrame does not match the unhealthy DataFrame after slicing.")

print(f"Healthy DataFrame rows after slicing: {len(raw_healthy_df_filtered)}")
print(f"Unhealthy DataFrame rows: {unhealthy_df_rows_length}")

print("Healthy DataFrame columns:", len(raw_healthy_df_filtered.columns))
print("Unhealthy DataFrame columns:", len(raw_unhealthy_df_filtered.columns))

# Rebind the working names to the filtered frames and release the temporaries.
raw_healthy_df = raw_healthy_df_filtered

raw_unhealthy_df = raw_unhealthy_df_filtered

del raw_healthy_df_filtered
del raw_unhealthy_df_filtered

### Preprocessing: add condition column

In [None]:
# Class label column: 0 marks rows from the unhealthy frame, 1 rows from the healthy frame.
raw_unhealthy_df["condition"] = 0
raw_healthy_df["condition"] = 1

### Preprocessing: change index name

In [None]:
# raw_unhealthy_df_filtered.index.name = "sample_id"
# raw_healthy_df_filtered.index.name = "sample_id"

## Debug checks (can be commented out)

In [None]:
# raw_healthy_df_filtered.info()
# raw_healthy_df_filtered.head()

In [None]:
# raw_healthy_df_filtered.info()
# raw_unhealthy_df_filtered.head()

## Final checks before merge

In [None]:
# Both frames must expose the same set of column names; order is irrelevant.
are_column_names_same_regardless_order = set(raw_healthy_df.columns) == set(raw_unhealthy_df.columns)

if not are_column_names_same_regardless_order:
    raise ValueError("Column names in healthy and unhealthy DataFrames do not match.")

# Frames are checked in a fixed order (healthy first) so the first failing
# check determines which error is raised.
frames_to_check = (("Healthy", raw_healthy_df), ("Unhealthy", raw_unhealthy_df))

# No frame may contain two identical rows.
for frame_label, frame in frames_to_check:
    if frame.duplicated().any():
        raise ValueError(f"{frame_label} DataFrame contains duplicate rows.")

# No frame may contain a repeated column label.
for frame_label, frame in frames_to_check:
    if frame.columns.duplicated().any():
        raise ValueError(f"{frame_label} DataFrame contains duplicate columns.")

## Do merge

In [None]:
# Stack the healthy rows on top of the unhealthy rows (row-wise concatenation).
merged_df: DataFrame = pd.concat([raw_healthy_df, raw_unhealthy_df], axis=0)


# Persist the merged table; later sections read this file back.
# The path relies on path_to_data ending with a slash.
merged_df.to_parquet(f"{path_to_data}merged_data.pq")

In [None]:
# Wall-clock timestamp at the end of the run.
end_time = time.time()

# Report the seconds elapsed since start_time was recorded.
print(f"Data processing completed in {end_time - start_time:.2f} seconds.")

# Preprocessing V2

## Imports

In [None]:
import cudf as pd
from cudf import DataFrame
import time
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from pandas import melt, merge
import numpy as np
import cupy as cp

## Global Variables

In [None]:
# Wall-clock timestamp taken at the start of this section; the closing cell
# uses it to report the elapsed time.
start_time = time.time()

# Directory prefix for the Parquet inputs; the trailing slash is required
# because paths are built with f-strings.
path_to_data = "data/"

## Preprocessing: T-Cell extraction

In [None]:
# Load the two expression tables. Later cells rely on both having
# 'gene_name' and 'gene_type' columns.
df_mixed_all = pd.read_parquet(f"{path_to_data}ALL.pq")
df_b_all = pd.read_parquet(f"{path_to_data}B_ALL.pq")

# Gene names used to filter rows of the mixed ALL table by 'gene_name'.
t_cell_genes = [
    'CD2', 'CD3D', 'CD3E', 'CD3G', 'CD4', 'CD6', 'CD7', 'TCF7', 'GATA3',
    'HOXA1', 'HOXA2', 'HOXA3', 'HOXA5', 'HOXA6', 'HOXA9', 'HOXA11', 'HOXA13',
    'LMO3', 'LAG3', 'FOXP3', 'ITGAL', 'TNFRSF9', 'TCL1A', 'NOTCH3'
]

# Remove leading/trailing whitespace from the column labels of both tables.
df_mixed_all.columns = df_mixed_all.columns.str.strip()

df_b_all.columns = df_b_all.columns.str.strip()

In [None]:
# Keep the rows of the mixed ALL table whose 'gene_name' is in t_cell_genes.
# NOTE(review): this selects genes (rows), not samples; every sample column of
# df_mixed_all is retained. Confirm this matches the intent behind the name.
t_all_samples = df_mixed_all[df_mixed_all['gene_name'].isin(t_cell_genes)]

In [None]:
# Remove the two annotation columns so only the remaining columns are transposed.
t_all_samples_dropped_cols = t_all_samples.drop(columns=['gene_name', 'gene_type'])

df_b_all_dropped_cols = df_b_all.drop(columns=['gene_name', 'gene_type'])

# Transpose: the former columns become rows and the former row index becomes
# the column labels.
t_all_transposed = t_all_samples_dropped_cols.T

b_all_transposed = df_b_all_dropped_cols.T

In [None]:
# Preview of the transposed B-ALL frame (notebook display only).
b_all_transposed.head()

In [None]:
# Preview of the transposed T-cell-gene frame (notebook display only).
t_all_transposed.head()

In [None]:
# Stack the B-ALL rows on top of the T-cell-gene rows. Columns present in only
# one frame are filled with missing values in the rows of the other.
combined_df: DataFrame = pd.concat([b_all_transposed, t_all_transposed])

In [None]:
# Verify the lengths of the DataFrames
all_length = len(t_all_transposed)
b_all_length = len(b_all_transposed)
sum_length = all_length + b_all_length
combined_length = len(combined_df)

print(f"Length of ALL samples: {all_length}")
print(f"Length of B-ALL samples: {b_all_length}")
print(f"Sum of lengths: {sum_length}")
print(f"Length of combined DataFrame: {combined_length}")

if sum_length != combined_length:
    raise ValueError("The combined DataFrame does not match the sum of the individual DataFrames' lengths.")

In [None]:
# Move the row index into a regular column ...
combined_df.reset_index(inplace=True)

# ... and name that column 'sample_id'. This assumes reset_index produced a
# column called 'index', i.e. the index had no name (TODO confirm).
combined_df.rename(columns={'index': 'sample_id'}, inplace=True)

# Preview (notebook display only).
combined_df.head()

In [None]:
# Copy the frame to pandas and keep only its numeric columns; 'sample_id' is
# re-attached after the transform.
numeric_df = combined_df.to_pandas().select_dtypes(include="number")

# Move the numeric values onto the GPU as a CuPy array.
cupy_array = cp.asarray(numeric_df.values)

# log2(x + 1) transform; the +1 keeps zero values finite.
x_log2 = cp.log2(cupy_array + 1)

# Rebuild a DataFrame with the original column labels and row index, then put
# 'sample_id' back as the first column.
# NOTE(review): the data travels device -> host -> device -> host -> device
# here; the result is unaffected but the copies could likely be avoided.
normalized_df = pd.DataFrame(cp.asnumpy(x_log2), columns=numeric_df.columns, index=numeric_df.index)
normalized_df = pd.concat([combined_df[['sample_id']], normalized_df], axis=1)

# Preview (notebook display only).
normalized_df.head()

In [None]:
# One 'T-ALL' label per row of t_all_transposed, indexed by that frame's index.
t_all_labels = pd.Series(['T-ALL'] * t_all_transposed.shape[0], index=t_all_transposed.index)

# One 'B-ALL' label per row of b_all_transposed, followed by the T-ALL labels
# (same order as the row-wise concatenation that built combined_df).
final_combined_labels = pd.concat([
    pd.Series(['B-ALL'] * b_all_transposed.shape[0], index=b_all_transposed.index),
    t_all_labels
])

# Convert to a pandas Series; the merge below uses the pandas `merge` function.
final_combined_labels = final_combined_labels.to_pandas()

In [None]:
# Reshape from wide to long with pandas: one row per (sample_id, Gene) pair,
# with the value in 'Expression'.
melted_combined_df = melt(normalized_df.to_pandas(), id_vars=['sample_id'], var_name='Gene', value_name='Expression')

# Preview (notebook display only).
melted_combined_df.head()

In [None]:
# Attach a 'Diagnosis' column by matching 'sample_id' against the label Series' index.
# NOTE(review): if a sample id occurs in both source frames, the label index
# holds it twice and the merge duplicates its rows; confirm ids are unique.
merged_combined_df = merge(melted_combined_df, final_combined_labels.rename("Diagnosis"), left_on='sample_id', right_index=True)

In [None]:
# Preview of the long-format table with diagnosis labels (notebook display only).
merged_combined_df.head()

In [None]:
# Values of the 'Gene' column to plot; they must match the column labels of
# the expression frame exactly.
genes_to_plot = [
    'ENSG00000000003.15',
    "ENSG00000005073.6"
]

# Independent copy of the long-format rows belonging to the selected genes.
plot_data_filtered_genes = merged_combined_df[merged_combined_df['Gene'].isin(genes_to_plot)].copy()

In [None]:
# --- Create the plot ---
plt.figure(figsize=(15, 6 * len(genes_to_plot))) # Adjust figure size dynamically based on number of genes
# Using a FacetGrid to create a separate plot for each gene
g = sns.catplot(
    data=plot_data_filtered_genes,
    x='Diagnosis',      # Categorical variable on the x-axis
    y='Expression',     # Numerical variable on the y-axis
    col='Gene',         # Create separate columns of plots for each gene
    col_wrap=len(genes_to_plot), # Wrap columns if you have many genes, adjust as needed
    kind='box',         # Use 'box' for box plot, or 'violin' for violin plot
    height=5,           # Height of each facet
    aspect=1.2,         # Aspect ratio of each facet
    palette='viridis',  # Color palette for diagnoses
    sharey=False,       # Allow each gene plot to have its own y-axis scale
)

# --- Customize the plots ---
g.set_axis_labels("Diagnosis", "Gene Expression (log2CPM or similar)")
g.set_titles(col_template="{col_name}") # Set title for each column (gene name)
g.set_xticklabels(rotation=45, ha='right') # Rotate x-axis labels for better readability if diagnosis names are long

# Add a main title for the entire figure
plt.suptitle('Gene Expression Across Diagnosis Groups', y=1.02, fontsize=16) # Adjust y for title position

plt.tight_layout(rect=[0, 0, 1, 0.98]) # Adjust layout to prevent title overlap
plt.show()

In [None]:
# Wall-clock timestamp at the end of this section.
end_time = time.time()

# Report the seconds elapsed since start_time was recorded.
print(f"Data processing for T-cell genes completed in {end_time - start_time:.2f} seconds.")

# Model training V2

## TODO:
1. Change condition from "subtype" to something else

## Imports

In [None]:
import cudf as pd
from cudf import DataFrame
import time
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from pandas import melt, merge
import numpy as np
import cupy as cp
import cuml as cm
from cuml.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from xgboost import XGBClassifier
from sklearn.utils import class_weight
from mpl_toolkits.mplot3d import Axes3D
from sklearn.impute import SimpleImputer

## Global Variables

In [None]:
# Wall-clock timestamp taken at the start of this section.
start_time = time.time()

# Directory prefix for the Parquet inputs; the trailing slash is required
# because paths are built with f-strings.
path_to_data = "data/"

# Class label strings written into the 'subtype' column.
# AML_LABEL is referenced only by commented-out code in the preparation cell.
T_ALL_LABEL = "T-ALL"
B_ALL_LABEL = "B-ALL"
ETP_LABEL = "ETP-ALL"
MPAL_LABEL = "MPAL"
AML_LABEL = "AML"

## Training

### Preparation

In [None]:
# Load the expression tables. Both are expected to carry 'gene_name' and
# 'gene_type' columns, which are dropped further down.
df_mixed_all = pd.read_parquet(f"{path_to_data}ALL.pq")
df_b_all = pd.read_parquet(f"{path_to_data}B_ALL.pq")
# df_aml = pd.read_parquet(f"{path_to_data}AML.pq")

# Gene-name lists used to select rows of df_mixed_all for each label.
t_cell_genes = [
    'CD2', 'CD3D', 'CD3E', 'CD3G', 'CD4', 'CD6', 'CD7', 'TCF7', 'GATA3',
    'HOXA1', 'HOXA2', 'HOXA3', 'HOXA5', 'HOXA6', 'HOXA9', 'HOXA11', 'HOXA13',
    'LMO3', 'LAG3', 'FOXP3', 'ITGAL', 'TNFRSF9', 'TCL1A', 'NOTCH3'
]

etp_markers = [
    "CD7", "CD34", "KIT", "CD117", "CD33", "CD13", "HLA-DR", "WT1", "LYL1",
    "MEF2C", "BAALC", "HHEX", "FLT3", "DNMT3A", "EZH2", "PHF6"
]

mpal_markers = [
    "CD3", "CD19", "CD33", "CD13", "MPO", "CD22", "CD79A", "CD10",
    "KMT2A", "RUNX1", "NPM1", "IKZF1", "CRLF2", "CSF1R", "PDGFRB"
]

# Each selection keeps the ROWS (genes) whose 'gene_name' is in the list and
# fails fast when nothing matches.
# NOTE(review): these filters select genes, not samples. After the transpose
# below, each of the three frames holds every sample column of df_mixed_all, so
# the same samples appear under three different labels, distinguished only by
# which gene columns are populated. Confirm this is the intended training set.
t_all_samples = df_mixed_all[df_mixed_all['gene_name'].isin(t_cell_genes)]

if t_all_samples.empty:
    raise ValueError("No T-cell genes found in the mixed ALL DataFrame.")

etp_samples = df_mixed_all[df_mixed_all['gene_name'].isin(etp_markers)]

if etp_samples.empty:
    raise ValueError("No ETP-MPAL markers found in the mixed ALL DataFrame.")

mpal_samples = df_mixed_all[df_mixed_all['gene_name'].isin(mpal_markers)]

if mpal_samples.empty:
    raise ValueError("No MPAL markers found in the mixed ALL DataFrame.")

# Remove the two annotation columns so only the remaining columns are transposed.
t_all_samples_dropped_cols = t_all_samples.drop(columns=['gene_name', 'gene_type'])
etp_samples = etp_samples.drop(columns=['gene_name', 'gene_type'])
mpal_samples = mpal_samples.drop(columns=['gene_name', 'gene_type'])

df_b_all_dropped_cols = df_b_all.drop(columns=['gene_name', 'gene_type'])
# df_aml_dropped_cols = df_aml.drop(columns=['gene_name', 'gene_type'])

# Transpose: the former columns become rows and the former row index becomes
# the column labels.
t_all_transposed = t_all_samples_dropped_cols.T
etp_transposed = etp_samples.T
mpal_transposed = mpal_samples.T
b_all_transposed = df_b_all_dropped_cols.T
# aml_transposed = df_aml_dropped_cols.T

# Tag every row of each frame with its class label.
t_all_transposed["subtype"] = T_ALL_LABEL
etp_transposed["subtype"] = ETP_LABEL
mpal_transposed["subtype"] = MPAL_LABEL
b_all_transposed["subtype"] = B_ALL_LABEL
# aml_transposed["subtype"] = AML_LABEL

# Move the row index into a regular column ...
t_all_transposed.reset_index(inplace=True)
etp_transposed.reset_index(inplace=True)
mpal_transposed.reset_index(inplace=True)
b_all_transposed.reset_index(inplace=True)
# aml_transposed.reset_index(inplace=True)

# ... and name it 'sample_id'. This assumes reset_index produced a column
# called 'index', i.e. the index had no name (TODO confirm).
t_all_transposed.rename(columns={'index': 'sample_id'}, inplace=True)
etp_transposed.rename(columns={'index': 'sample_id'}, inplace=True)
mpal_transposed.rename(columns={'index': 'sample_id'}, inplace=True)
b_all_transposed.rename(columns={'index': 'sample_id'}, inplace=True)
# aml_transposed.rename(columns={'index': 'sample_id'}, inplace=True)

In [None]:
# Stack the four labelled frames row-wise. Columns missing from a frame are
# filled with missing values in that frame's rows.
combined_df: DataFrame = pd.concat([b_all_transposed, t_all_transposed, etp_transposed, mpal_transposed])

In [None]:
# Replace every missing value with 0, in place. This includes the entries for
# columns a source frame did not have.
combined_df.fillna(0, inplace=True)

In [None]:
# Sample identifiers, kept separately from the feature matrix.
sample_ids = combined_df["sample_id"]

# Target labels and feature matrix (all columns except the id and the label).
y = combined_df["subtype"]
x = combined_df.drop(columns=["sample_id", "subtype"])

# Encode the string labels as integers.
le = cm.preprocessing.LabelEncoder()
y_encoded = le.fit_transform(y)

# 80/20 train/test split with a fixed seed, stratified on the encoded labels.
x_train, x_test, y_train, y_test = cm.model_selection.train_test_split(
    x, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
)

### Variance filter

In [None]:
# Number of highest-variance feature columns to keep.
N = 1000

# Per-column variance, computed on the training split only.
var = x_train.var()

# Labels of the N columns with the largest variance, as a Python list.
top_n_cols = var.nlargest(N).index.to_arrow().to_pylist()

# Apply the same column selection to both splits.
x_train = x_train[top_n_cols]
x_test = x_test[top_n_cols]

### Scaler

In [None]:
# Standardize the features: the scaler is fitted on the training split and the
# same fitted transform is then applied to the test split.
scaler = cm.preprocessing.StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_test_scaled = scaler.transform(x_test)

### Add Noise

In [None]:
# Host (NumPy) copy of the scaled training matrix; the noise is added to this copy.
x_train_noisy = x_train_scaled.to_numpy()

# Gaussian noise (mean 0, standard deviation 0.1), one value per matrix entry.
# A seeded generator replaces the unseeded np.random.normal call so training is
# reproducible, consistent with random_state=42 used for the train/test split.
noise_rng = np.random.default_rng(42)
noise = noise_rng.normal(loc=0, scale=0.1, size=x_train_noisy.shape)

# Perturb the training features in place; the test features stay untouched.
x_train_noisy += noise

### XGBoost

In [None]:
# Gradient-boosted tree classifier.
# NOTE(review): `use_label_encoder` and tree_method="gpu_hist" are deprecated in
# recent XGBoost releases; check both against the installed version.
xgb = XGBClassifier(
    use_label_encoder=False,
    eval_metric="mlogloss",
    scale_pos_weight=None,
    tree_method="gpu_hist",   # Use "hist" for CPU
    n_estimators=200,
    learning_rate=0.1
)

# Per-class weights keyed by ENCODED label. Which subtype each integer stands
# for depends on the LabelEncoder fitted earlier; verify the mapping
# (e.g. via le.classes_) before changing these values.
weight_map = {
    0: 1.0,
    1: 2.08,
    2: 2.08,
    3: 2.08
}

# Convert labels to per-sample weights
sample_weight = np.array([weight_map[label] for label in y_train.to_arrow().to_pylist()])

# Train on the noise-perturbed features with the per-sample weights.
xgb.fit(x_train_noisy, y_train, sample_weight=sample_weight)

### Prediction

In [None]:
# Predict encoded class labels for the scaled test split (no noise added).
y_pred = xgb.predict(x_test_scaled)

### Evaluation

In [None]:
# Host copy of the true labels for the scikit-learn metric functions.
y_test_np = y_test.to_numpy()

# Overall fraction of correct predictions.
test_accuracy = accuracy_score(y_test_np, y_pred)
print(f"Accuracy: {test_accuracy:.4f}")

# Counts of true vs. predicted classes.
test_confusion = confusion_matrix(y_test_np, y_pred)
print("Confusion Matrix:")
print(test_confusion)

# Precision, recall and F1 for every class.
test_report = classification_report(y_test_np, y_pred)
print("\nClassification Report:")
print(test_report)

## Statistics and Plots

### PCA

In [None]:
# Full feature matrix (all rows) and the labels as a plain Python list.
# Note that `y` is rebound here from the label Series to a list.
x_full = combined_df.drop(columns=["sample_id", "subtype"])
y = combined_df["subtype"].to_arrow().to_pylist()

# Keep the 2000 columns with the largest variance
gene_variances = x_full.var(axis=0)
top_genes = gene_variances.sort_values(ascending=False).head(2000).index

# Filter to only those genes
X_reduced = x_full[top_genes.to_arrow().to_pylist()]

# Impute missing gene values with column mean
# NOTE(review): combined_df was already filled with 0 in an earlier cell, so
# there is presumably nothing left to impute here; confirm.
imputer = SimpleImputer(strategy="mean")
x_imputed = imputer.fit_transform(X_reduced.to_pandas().values)

# Standardize each column
scaler = cm.preprocessing.StandardScaler()
X_scaled = scaler.fit_transform(x_imputed)  # or X_filled

# Project onto the first two principal components
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)

# Wrap the two components in a DataFrame and attach the labels for plotting
pca_df = pd.DataFrame(X_pca, columns=["PC1", "PC2"])
pca_df["subtype"] = y

In [None]:
# Scatter plot of the two principal components; colour and marker style both
# encode the subtype.
plt.figure(figsize=(8, 6))
sns.scatterplot(data=pca_df.to_pandas(), x="PC1", y="PC2", hue="subtype", style="subtype", s=60)
plt.title("PCA (Missing Genes Imputed for Plotting)")
plt.grid(True)
plt.tight_layout()
plt.show()

# Share of the total variance captured by each of the two components.
explained = pca.explained_variance_ratio_
print(f"Explained variance: PC1 = {explained[0]*100:.2f}%, PC2 = {explained[1]*100:.2f}%")

# Statistics and Plots

## Imports

In [None]:
import cudf as pd
import cuml as sklearn
import matplotlib.pyplot as plt
import seaborn as sns

## Global Variables

In [None]:
# Directory prefix for the Parquet input; the trailing slash is required
# because the path is built with an f-string.
path_to_data = "data/"

## Plots

### PCA Scatterplot

In [None]:
# Load the merged table written by the preprocessing section.
df = pd.read_parquet(f"{path_to_data}merged_data.pq")

# Feature matrix: every column except the class label.
x = df.drop(columns=["condition"])

# Class label (the preprocessing section writes 1 for healthy, 0 for unhealthy).
y = df["condition"]

# Per-gene mean in each class and the absolute difference between the two.
gene_columns = x.columns
mean_healthy = x[y == 1][gene_columns].mean()
mean_unhealthy = x[y == 0][gene_columns].mean()
mean_diff = (mean_healthy - mean_unhealthy).abs()

# NOTE(review): head() shows the first five genes in column order, not the
# five largest differences that the message suggests.
print("\nHead of Mean Differences (for top 5 genes):")
print(mean_diff.head())

# Number of genes to keep, ranked by absolute mean difference.
k_genes = 10000

top_k_genes = mean_diff.nlargest(k_genes).index.to_pandas()

x_selected = x[top_k_genes]

print(f"\nOriginal number of genes: {x.shape[1]}")

print(f"Number of genes after aggressive selection (top {k_genes} by mean difference): {x_selected.shape[1]}")

# Standardize the selected genes (here `sklearn` is the alias given to cuml).
scaler = sklearn.preprocessing.StandardScaler()

x_scaled = scaler.fit_transform(x_selected)

print("Shape of x_scaled:", x_scaled.shape)

# Project onto the first two principal components.
pca = sklearn.decomposition.PCA(n_components=2)
pca_df = pca.fit_transform(x_scaled)

print("Shape of principal components:", pca_df.shape)

# Name the components, restore the original row index so the label column
# aligns on it, then convert to pandas for plotting.
pca_df.columns = ["PC1", "PC2"]
pca_df.index = x_selected.index
pca_df["condition"] = y
pca_df = pca_df.to_pandas()

print(f"Shape of PCA DataFrame: {pca_df.shape}")
print("\nExplained Variance Ratio:")
print(f"PC1: {pca.explained_variance_ratio_[0]:.4f}")
print(f"PC2: {pca.explained_variance_ratio_[1]:.4f}")
print(f"Total Explained Variance (PC1 + PC2): {pca.explained_variance_ratio_.sum():.4f}")

# Scatter plot of the two components, coloured by class label.
print("Generating PCA plot...")
plt.figure(figsize=(10, 8))
sns.scatterplot(data=pca_df, x="PC1", y="PC2", hue="condition", palette='viridis', alpha=0.7, s=50)

# Axis labels include the share of variance explained by each component.
plt.title('PCA of Gene Expression Data (Healthy vs. Unhealthy Samples)')
plt.xlabel(f'Principal Component 1 ({pca.explained_variance_ratio_[0]*100:.2f}% Variance Explained)')
plt.ylabel(f'Principal Component 2 ({pca.explained_variance_ratio_[1]*100:.2f}% Variance Explained)')
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend(title='Condition')
plt.show()

# Model training

## Imports

In [None]:
import cudf as pd
import cuml as sklearn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.feature_selection import VarianceThreshold
import matplotlib.pyplot as plt
import seaborn as sns

## Global Variables

In [None]:
# Directory prefix for the Parquet input; the trailing slash is required
# because the path is built with an f-string.
path_to_data = "data/"

## KNN

### Preparation

In [None]:
print("Loading merged data from Parquet file...")

# Merged table written by the preprocessing section.
df = pd.read_parquet(f"{path_to_data}merged_data.pq")

# Feature matrix: every column except the class label.
x = df.drop(columns=["condition"])

# Labels of all feature columns; narrowed by the filtering cells that follow.
gene_columns = x.columns

# Class label (the preprocessing section writes 1 for healthy, 0 for unhealthy).
y = df["condition"]

print("Data loaded successfully.")

### Variance thresholding

In [None]:
print("Applying variance thresholding...")

print(f"Original number of genes: {len(gene_columns)}")

# Host (NumPy) copy of the features for the selector.
x_np = x.to_numpy()

# VarianceThreshold is imported from the scikit-learn package itself; the
# `cuml as sklearn` alias does not affect `from sklearn... import` statements.
# NOTE(review): the selector is fitted on all rows, before the train/test
# split, so the selection also sees the test rows; confirm this is acceptable.
selector = VarianceThreshold(threshold=0.1)
selector.fit(x_np)

# Boolean mask over the columns: True where the variance exceeds the threshold.
selector_gene_mask = selector.get_support()

# Keep only the columns that passed the threshold.
gene_columns_temp = gene_columns[selector_gene_mask]
x_filtered_variance = x[gene_columns_temp]

print(f"Number of genes after variance thresholding: {len(gene_columns_temp)}")

### Expression filtering

In [None]:
print("Low expression filtering...")

print(f"Number of genes before low expression filtering: {len(x_filtered_variance.columns)}")

# Mean of every remaining column across all rows.
gene_means = x_filtered_variance.mean()

# Columns whose mean is not above this value are discarded.
low_expression_threshold = 0.1

# `gene_columns` is rebound here to the surviving column labels; the
# feature-selection cell further down indexes with this narrowed list.
gene_columns = gene_means[gene_means > low_expression_threshold].index.to_pandas()

x_filtered_low_expression = x_filtered_variance[gene_columns]

print(f"Number of genes after low expression filtering: {len(gene_columns)}")

### Train test split

In [None]:
print("Splitting data into training and testing sets...")

# 80/20 split with a fixed seed, stratified on the class label.
split_parts = sklearn.model_selection.train_test_split(
    x_filtered_low_expression, y, test_size=0.2, random_state=42, stratify=y
)
x_train, x_test, y_train, y_test = split_parts

# Report the shape of every part, in the order train features/labels, test features/labels.
named_parts = (
    ("x_train", x_train),
    ("y_train", y_train),
    ("x_test", x_test),
    ("y_test", y_test),
)
for part_name, part in named_parts:
    print(f"{part_name} shape:", part.shape)

### Aggressive feature selection

In [None]:
print("Aggressive gene selection based on mean differences...")

print("Number of genes before selection:", x_train.shape[1])

# Per-gene mean within each class, computed on the training split only.
mean_healthy = x_train[y_train == 1][gene_columns].mean()

mean_unhealthy = x_train[y_train == 0][gene_columns].mean()

# Absolute difference between the two class means, per gene.
mean_diff = (mean_healthy - mean_unhealthy).abs()

# Number of genes to keep, ranked by absolute mean difference.
k_genes = 2000

print(f"Selecting top {k_genes} genes based on mean differences...")

top_k_genes = mean_diff.nlargest(k_genes).index.to_pandas()

# Apply the same gene selection to both splits.
x_train_selected = x_train[top_k_genes]
x_test_selected = x_test[top_k_genes]

print(f"Number of genes after aggressive selection: {x_train_selected.shape[1]}")

### Scaling gene expression data

In [None]:
print("Scaling selected gene expression data...")

# Standardize the features: the scaler is fitted on the training split and the
# same fitted transform is then applied to the test split.
scaler = sklearn.preprocessing.StandardScaler()
x_train_scaled = scaler.fit_transform(x_train_selected)
x_test_scaled = scaler.transform(x_test_selected)

print("Shape of x_train_scaled:", x_train_scaled.shape)
print("Shape of x_test_scaled:", x_test_scaled.shape)

## Logistic regression

### Logistic regression (regularized to prevent perfect separation)

In [None]:
print("Training logistical regression model...")

# L2-penalized logistic regression. C is the inverse regularization strength,
# so the small value used here means a strong penalty.
logreg_model = sklearn.linear_model.LogisticRegression(
    penalty='l2',
    C=0.0001, 
    solver='qn',
    max_iter=1000, 
)

# Fit on the scaled training features.
logreg_model.fit(x_train_scaled, y_train)

print("Model training completed.")

### Logistic regression prediction and evaluation

In [None]:
# Predicted class for every row of the scaled test split.
y_pred_logreg = logreg_model.predict(x_test_scaled)

# Host (NumPy) copies for the scikit-learn metric functions.
y_test_np = y_test.to_numpy()
y_pred_logreg_np = y_pred_logreg.to_numpy()

# Binary-classification metrics; the positive class is label 1.
accuracy = accuracy_score(y_test_np, y_pred_logreg_np)
precision = precision_score(y_test_np, y_pred_logreg_np, average='binary')
recall = recall_score(y_test_np, y_pred_logreg_np, average='binary')
f1 = f1_score(y_test_np, y_pred_logreg_np, average='binary')

# Print a header followed by one line per metric, four decimals each.
print("Logistic Regression Model Performance:")
metric_summary = (
    ("Accuracy", accuracy),
    ("Precision", precision),
    ("Recall", recall),
    ("F1 Score", f1),
)
for metric_name, metric_value in metric_summary:
    print(f"{metric_name}: {metric_value:.4f}")