In [None]:
import os
import pandas as pd
import joblib
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from scipy.stats import spearmanr

# INPUT DATA
# Read the training table from a CSV path relative to the working directory.
input_file = "Multi_features_Models_training_data.csv"
df = pd.read_csv(input_file)
print("✅ File loaded successfully!")
print(f"Shape: {df.shape}")
print("First 5 rows:")
# `display` is not imported here; it is expected from the IPython/Jupyter runtime.
display(df.head(5))

# FEATURE SLICING
# The k-mer and domain feature groups are selected by column POSITION:
# `kmer_count` columns starting at `start_idx`, immediately followed by
# `domain_count` domain columns. Length features and the target are
# selected by column name instead.
start_idx = 2
kmer_count = 84
end_idx = start_idx + kmer_count

domain_start = end_idx
domain_count = 1288
domain_end = domain_start + domain_count

X_kmers = df.iloc[:, slice(start_idx, end_idx)]
X_domains = df.iloc[:, slice(domain_start, domain_end)]

length_cols = ['Total_Chromosome_Length', 'Plasmid_Length']
X_lengths = df.loc[:, length_cols]

y = df.loc[:, 'Log1p_PIRACopyNumber']

# Report the shape of every block so a wrong column offset is easy to spot.
print(f"\nShapes check:")
for _label, _block in (
    ("K-mers:", X_kmers),
    ("Domains:", X_domains),
    ("Lengths:", X_lengths),
    ("Target:", y),
):
    print(f"  {_label:<10}{_block.shape}")

# OUTPUT LOCATIONS
# All artefacts go under one relative folder; trained models get a subfolder.
out_dir = "Final_files"  # GitHub-friendly relative folder
models_dir = os.path.join(out_dir, "models")
os.makedirs(models_dir, exist_ok=True)

# FEATURE SETS
# Single-column frames for the two length features, reused below.
_plasmid_len = X_lengths[['Plasmid_Length']]
_chrom_len = X_lengths[['Total_Chromosome_Length']]


def _side_by_side(*frames):
    """Concatenate the given frames column-wise, in the order given."""
    return pd.concat(list(frames), axis=1)


# Maps a feature-set name to its design matrix. Insertion order is the order
# in which the models are trained and the metrics rows are written.
feature_sets = {}
feature_sets['plasmid_length'] = _plasmid_len
feature_sets['domains'] = X_domains
feature_sets['kmers'] = X_kmers
feature_sets['domains+plasmid_len'] = _side_by_side(X_domains, _plasmid_len)
feature_sets['domains+kmers'] = _side_by_side(X_domains, X_kmers)
feature_sets['kmers+plasmid_len'] = _side_by_side(X_kmers, _plasmid_len)
feature_sets['domains+chrom_len'] = _side_by_side(X_domains, _chrom_len)
feature_sets['kmers+chrom_len'] = _side_by_side(X_kmers, _chrom_len)
feature_sets['domains+kmers+plen'] = _side_by_side(X_domains, X_kmers, _plasmid_len)
feature_sets['domains+plen+clen'] = _side_by_side(X_domains, _plasmid_len, _chrom_len)
feature_sets['all_features'] = _side_by_side(X_domains, X_kmers, X_lengths)

# MODEL TRAINING AND EVALUATION
# Every feature set is fitted on three 80/20 train/test splits. Replicate k
# uses seed k * 42 for both the split and the random forest, and the same
# seeds are reused for every feature set.
results = []
n_replicates = 3

for feat_name, X in feature_sets.items():
    # '+' is replaced in file names; the metrics table keeps the original name.
    file_stem = feat_name.replace('+', '_')

    for rep in range(1, n_replicates + 1):
        seed = rep * 42
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=seed
        )

        model = RandomForestRegressor(n_estimators=100, random_state=seed)
        model.fit(X_train, y_train)

        # Persist the fitted model before scoring it.
        model_file = f"{file_stem}_rep{rep}.joblib"
        joblib.dump(model, os.path.join(models_dir, model_file))

        # Score on the held-out 20 %.
        y_pred = model.predict(X_test)
        r2 = r2_score(y_test, y_pred)
        mae = mean_absolute_error(y_test, y_pred)
        mse = mean_squared_error(y_test, y_pred)
        spear = spearmanr(y_test, y_pred).correlation

        results.append(
            dict(
                feature_set=feat_name,
                replicate=rep,
                r2=r2,
                spearman=spear,
                mae=mae,
                mse=mse,
            )
        )

# SAVE EVALUATION METRICS
# One row per (feature set, replicate).
eval_file = os.path.join(out_dir, 'evaluation_metrics.csv')
eval_df = pd.DataFrame(results)
os.makedirs(out_dir, exist_ok=True)
eval_df.to_csv(eval_file, index=False)

print(f"\n✅ Models saved in: {models_dir}")
print(f"✅ Evaluation metrics saved to: {eval_file}")


# Next, we will generate a plot comparing the performance of models trained on different feature sets

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Global matplotlib styling, set one rc group at a time:
# Arial 12 pt text, 1.2 pt black axes, 6 x 1.2 pt major ticks, 300 dpi output.
plt.rc('font', family='Arial', size=12)
plt.rc('axes', linewidth=1.2, edgecolor='black')
for _tick_group in ('xtick.major', 'ytick.major'):
    plt.rc(_tick_group, size=6, width=1.2)
for _dpi_group in ('figure', 'savefig'):
    plt.rc(_dpi_group, dpi=300)

# Where the figure is written (relative, GitHub-friendly).
save_path_pdf = os.path.join(out_dir, "Figure_2B.pdf")

# Read the evaluation metrics back from the CSV in the output folder.
eval_df = pd.read_csv(os.path.join(out_dir, "evaluation_metrics.csv"))

# Turn feature-set names into display labels. Rules are applied in order:
#   1. '+' becomes '&'
#   2. every '&' gets exactly one space on each side
#   3. runs of whitespace collapse to a single space
# Leading/trailing whitespace is stripped at the end.
_label_rules = (
    (r'\+', '&'),
    (r'\s*&\s*', ' & '),
    (r'\s+', ' '),
)
_labels = eval_df['feature_set']
for _pattern, _replacement in _label_rules:
    _labels = _labels.str.replace(_pattern, _replacement, regex=True)
eval_df['feature_set'] = _labels.str.strip()

def count_features(name: str) -> int:
    """Return the number of '&'-separated components in *name*, at most 4.

    A name containing no '&' (including the empty string) counts as one
    component.
    """
    # n separators delimit n + 1 components.
    component_total = name.count('&') + 1
    return component_total if component_total < 4 else 4

def _is_all_features(names):
    """Return a boolean mask marking the combined 'all features' set.

    The set is stored as 'all_features' (underscore). Underscores are treated
    as spaces so both 'all_features' and 'all features' are recognised;
    comparing against 'all features' alone never matched the stored name.
    """
    normalised = names.str.lower().str.replace('_', ' ', regex=False).str.strip()
    return normalised == 'all features'


# Compute feature complexity (number of '&'-separated components, capped at 4).
# The combined set has no '&' in its name, so it is forced to 4 explicitly.
eval_df['complexity'] = eval_df['feature_set'].apply(count_features)
eval_df.loc[_is_all_features(eval_df['feature_set']), 'complexity'] = 4

# Compute mean and std of R² grouped by feature set
stats = (
    eval_df
    .groupby('feature_set')['r2']
    .agg(['mean', 'std'])
    .reset_index()
)

# Assign complexity, with the same override for the combined set
stats['complexity'] = stats['feature_set'].apply(count_features)
stats.loc[_is_all_features(stats['feature_set']), 'complexity'] = 4

# Sort by mean R² for plotting (lowest first, so the best bar ends up on top)
stats = stats.sort_values('mean', ascending=True).reset_index(drop=True)

# Feature complexity colors
feature_colors = {1:'#4E79A7', 2:'#F28E2B', 3:'#E15759', 4:'#76B7B2'}
colors = stats['complexity'].map(feature_colors)

# Ensure "All Features" bar is teal. The index was reset above, so index
# labels and positions coincide and label-based assignment is safe.
all_features_idx = stats.index[_is_all_features(stats['feature_set'])]
colors.loc[all_features_idx] = feature_colors[4]

# Plot: one horizontal bar per feature set, in the row order of `stats`.
plt.figure(figsize=(10, 6), dpi=1200)
y_pos = np.arange(stats.shape[0])

# Bars show the mean R² across replicates, coloured by feature complexity.
bars = plt.barh(
    y=y_pos,
    width=stats['mean'],
    height=0.7,
    color=colors,
    edgecolor='black',
)

# Horizontal error bars: ± one standard deviation of R² across replicates.
_errorbar_style = dict(fmt='none', ecolor='black', elinewidth=1.5, capsize=5)
plt.errorbar(x=stats['mean'], y=y_pos, xerr=stats['std'], **_errorbar_style)

# Overlay each replicate's R² as a semi-transparent dot on its bar.
_point_style = dict(color='black', s=30, alpha=0.5, edgecolor='none', linewidth=0)
for _row_pos, _set_name in zip(y_pos, stats['feature_set']):
    _replicate_r2 = eval_df.loc[eval_df['feature_set'] == _set_name, 'r2'].to_numpy()
    _row_coords = np.full(_replicate_r2.shape, _row_pos, dtype=_replicate_r2.dtype)
    plt.scatter(_replicate_r2, _row_coords, **_point_style)

# Axes: feature-set names on y, R² on x over a fixed 0.5-0.8 window.
_x_ticks = [0.5, 0.6, 0.7, 0.8]
plt.yticks(y_pos, stats['feature_set'], fontsize=12)
plt.xlabel(r'$R^2$', fontsize=12)
plt.xlim(_x_ticks[0], _x_ticks[-1])
plt.xticks(_x_ticks, fontsize=12)
plt.tick_params(axis='both', which='major', labelsize=12)
plt.grid(False)

# Legend: one colour patch per complexity level, singular label for 1.
legend_handles = []
for _level, _level_color in feature_colors.items():
    _plural = 's' if _level > 1 else ''
    legend_handles.append(Patch(color=_level_color, label=f'{_level} feature{_plural}'))

# Place the legend outside the axes, to the right.
plt.legend(
    handles=legend_handles,
    title='Feature complexity',
    title_fontsize=12,
    fontsize=12,
    frameon=False,
    loc='center left',
    bbox_to_anchor=(1, 0.5),
)

# Keep the right 15 % of the figure free for the legend.
plt.tight_layout(rect=[0, 0, 0.85, 1])

# Save plot
# Only a PDF path (`save_path_pdf`) is defined in this file, so only the PDF
# is written and reported. The previous message referenced TIFF and SVG paths
# that are never defined and raised NameError after the figure was saved.
plt.savefig(save_path_pdf, dpi=600, bbox_inches='tight')

plt.show()
print(f"✅ Plot saved to {save_path_pdf}")
