In [1]:
import scanpy as sc
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, root_mean_squared_error, r2_score


def train_and_eval(X_train, y_train, X_val, y_val, prefix):
    """Fit a linear regression on the training split and report validation scores.

    Parameters
    ----------
    X_train, y_train
        Training features and targets passed to ``LinearRegression.fit``.
    X_val, y_val
        Validation features (used for prediction) and ground-truth targets
        (used for the metrics).
    prefix
        Label shown in the printed report and used as the stem of the
        returned dictionary keys.

    Returns
    -------
    dict
        ``{"<prefix>-true": y_val, "<prefix>-pred": <validation predictions>}``,
        kept so the caller can bootstrap the metrics afterwards.
    """
    # ``fit`` returns the estimator itself, so construction and fitting chain.
    regressor = LinearRegression().fit(X_train, y_train)
    val_pred = regressor.predict(X_val)

    mse = mean_squared_error(y_val, val_pred)
    rmse = root_mean_squared_error(y_val, val_pred)
    r2 = r2_score(y_val, val_pred)

    # Report; the trailing "\n" on the last line separates consecutive runs.
    print(f"Validation {prefix} MSE: {mse:.4f}")
    print(f"Validation {prefix} RMSE: {rmse:.4f}")
    print(f"Validation {prefix} R2: {r2:.4f}\n")

    # Ground truth first, then predictions, so the columns pair up in the CSV.
    return {f"{prefix}-true": y_val, f"{prefix}-pred": val_pred}


# Target columns in adata.obs -> short label used in printed output and in
# the CSV column names.
target_variables_short = {
    "pearson_correlation_min_max": "pear_corr_min_max",
    "pearson_correlation_sigmoid": "pear_corr_sigmoid",
    "cosine_similarity_min_max": "cos_sim_min_max",
    "cosine_similarity_sigmoid": "cos_sim_sigmoid",
}

# Feature representations to benchmark, in run/output order: the raw matrix
# (adata.X) first, then the PCA embedding (adata.obsm["X_pca"]).
feature_getters = {
    "raw": lambda ad: ad.X,
    "pca": lambda ad: ad.obsm["X_pca"],
}

adata = sc.read("train_val_adata.h5ad")

# Boolean-mask subsets selected by the "split" column of adata.obs.
adata_train = adata[adata.obs["split"] == "train"]
adata_val = adata[adata.obs["split"] == "validation"]

# Accumulates the "<rep>-<target>-true" / "<rep>-<target>-pred" arrays
# returned by every run.
predictions = {}

for rep, get_features in feature_getters.items():
    # Fetch the feature matrices once per representation instead of once per
    # target. NOTE(review): adata_train/adata_val are AnnData views, and
    # reading .X/.obsm on a view may subset (copy) the backing array on each
    # access, so this hoist avoids repeated work.
    X_train, X_val = get_features(adata_train), get_features(adata_val)

    for target_var, target_var_short in target_variables_short.items():
        y_train = adata_train.obs[target_var].values
        y_val = adata_val.obs[target_var].values

        predictions.update(
            train_and_eval(X_train, y_train, X_val, y_val, f"{rep}-{target_var_short}")
        )

# One column per key. Every array has one entry per validation observation,
# so the columns line up row-wise.
pd.DataFrame(predictions).to_csv("baseline_linreg.csv", index=False)

Validation raw-pear_corr_min_max MSE: 0.0015
Validation raw-pear_corr_min_max RMSE: 0.0392
Validation raw-pear_corr_min_max R2: 0.7243

Validation raw-pear_corr_sigmoid MSE: 0.0015
Validation raw-pear_corr_sigmoid RMSE: 0.0388
Validation raw-pear_corr_sigmoid R2: 0.7490

Validation raw-cos_sim_min_max MSE: 0.0028
Validation raw-cos_sim_min_max RMSE: 0.0525
Validation raw-cos_sim_min_max R2: 0.6915

Validation raw-cos_sim_sigmoid MSE: 0.0004
Validation raw-cos_sim_sigmoid RMSE: 0.0204
Validation raw-cos_sim_sigmoid R2: 0.9545

Validation pca-pear_corr_min_max MSE: 0.0023
Validation pca-pear_corr_min_max RMSE: 0.0474
Validation pca-pear_corr_min_max R2: 0.5953

Validation pca-pear_corr_sigmoid MSE: 0.0024
Validation pca-pear_corr_sigmoid RMSE: 0.0494
Validation pca-pear_corr_sigmoid R2: 0.5934

Validation pca-cos_sim_min_max MSE: 0.0038
Validation pca-cos_sim_min_max RMSE: 0.0619
Validation pca-cos_sim_min_max R2: 0.5708

Validation pca-cos_sim_sigmoid MSE: 0.0017
Validation pca-cos_sim_