# In [None]:
# ----------------------------
# Script: elastic_net_feature_selection.py
# Purpose: Perform Elastic Net feature selection on gene expression data
# Author: [Fahiz Mohammed PP]
# ----------------------------

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import ElasticNetCV
from sklearn.preprocessing import StandardScaler

# ----------------------------
# Step 1: Load Data
# ----------------------------

# Load the expression matrix. The first CSV column becomes the index; the
# orientation (genes x samples or samples x genes) is resolved below.
data = pd.read_csv("data/batch_corrected_lncRNA_expression.csv", index_col=0)

# Load labels (tumor = 1, normal = 0). The first CSV column becomes the index
# (presumably the sample ID -- verify against the file).
labels = pd.read_csv("data/sample_labels.csv", index_col=0)

# Orient the matrix so that rows = samples and columns = genes: transpose when
# the row count does not match the number of labelled samples.
if data.shape[0] != len(labels):
    data = data.T
if data.shape[0] != len(labels):
    # Fail here with a clear message instead of deep inside the model fit.
    raise ValueError(
        f"Expression matrix shape {data.shape} matches the {len(labels)} "
        "labelled samples on neither axis."
    )

# Put the labels in the same sample order as the expression rows. IDs are
# compared as strings because a transposed matrix takes its index from the CSV
# header (always text) while the label index may have been parsed as numbers.
data_ids = data.index.astype(str)
label_ids = labels.index.astype(str)
if (
    data_ids.is_unique
    and label_ids.is_unique
    and data_ids.sort_values().equals(label_ids.sort_values())
):
    # Position of each expression row's sample ID within the label table.
    labels = labels.iloc[label_ids.get_indexer(data_ids)]
else:
    # IDs cannot be matched one-to-one; keep the positional pairing.
    print(
        "Warning: sample IDs in the label file do not match the expression "
        "matrix; assuming both files list samples in the same order."
    )

y = labels["label"].values  # Binary classification: 0 or 1

X = data.values
feature_names = data.columns

# ----------------------------
# Step 2: Preprocessing
# ----------------------------

# Standardise each feature (column of X) with StandardScaler so the penalty
# treats all features on a comparable scale.
scaler = StandardScaler().fit(X)
X_scaled = scaler.transform(X)

# ----------------------------
# Step 3: Elastic Net CV
# ----------------------------

# Cross-validated Elastic Net (5 folds). l1_ratio=0.9 keeps the penalty mostly
# L1, which drives many coefficients to exactly zero (sparse selection).
# NOTE(review): ElasticNetCV is a regressor and y holds 0/1 class labels, so
# this fits a penalised linear regression on the class indicator -- confirm
# that this is the intended model.
elastic_net = ElasticNetCV(
    l1_ratio=0.9,
    cv=5,
    max_iter=10000,
    random_state=42,
).fit(X_scaled, y)

# ----------------------------
# Step 4: Select Features
# ----------------------------

# Keep every feature with a non-zero coefficient. SelectFromModel's default
# threshold is the mean absolute coefficient unless the estimator is a pure L1
# model, so with l1_ratio=0.9 the threshold has to be set explicitly;
# otherwise only features at or above the mean would be kept. 1e-5 is the
# cut-off scikit-learn itself applies to L1-penalised estimators.
model = SelectFromModel(elastic_net, prefit=True, threshold=1e-5)
selected_idx = model.get_support(indices=True)
selected_features = feature_names[selected_idx]

# Save selected features, creating the output directory on a first run so the
# write does not fail on a missing folder.
os.makedirs("results", exist_ok=True)
pd.Series(selected_features).to_csv("results/selected_lncRNA_features.csv", index=False)

# ----------------------------
# Step 5: Plot Coefficients (Optional)
# ----------------------------

# Fitted coefficients and the positions of those that are not exactly zero.
coefs = elastic_net.coef_
non_zero_idx = np.flatnonzero(coefs)

# One bar per non-zero coefficient, labelled with the feature name.
bar_positions = range(len(non_zero_idx))
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(bar_positions, coefs[non_zero_idx])
ax.set_xticks(bar_positions)
ax.set_xticklabels(feature_names[non_zero_idx], rotation=90)
ax.set_title("Elastic Net Selected lncRNAs (Non-zero Coefficients)")
ax.set_ylabel("Coefficient Value")
fig.tight_layout()
fig.savefig("results/elastic_net_coefficients_plot.png")
plt.close(fig)

# Report how many features the selection step kept.
print(f"Selected {len(selected_features)} features using Elastic Net.")

