In [2]:
import numpy as np
import pandas as pd
import logging
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt

from lab_scripts.models.baselines import baseline_linear
from lab_scripts.models.baselines import baseline_mean

In [3]:
# Read the processed multiome training files from disk: one object for the
# gene-expression (GEX) file and one for the ATAC file.
adata_gex = ad.read_h5ad(
    "data/official/explore/multiome/multiome_gex_processed_training.h5ad"
)
adata_atac = ad.read_h5ad(
    "data/official/explore/multiome/multiome_atac_processed_training.h5ad"
)

In [4]:
# Hold out batch "s2d4" as the test set; every other batch is used for training.
is_test_batch = adata_gex.obs["batch"] == "s2d4"
train_cells = adata_gex.obs_names[~is_test_batch]
test_cells = adata_gex.obs_names[is_test_batch]

In [5]:
# Inputs handed to the prediction method: ATAC (mod1) and GEX (mod2) for the
# training cells, and ATAC only for the held-out test cells.
input_train_mod1, input_train_mod2 = adata_atac[train_cells], adata_gex[train_cells]
input_test_mod1 = adata_atac[test_cells]

# Reference GEX values for the test cells; these are handed to the metric only.
true_test_mod2 = adata_gex[test_cells]

In [6]:
# Run the linear baseline once: it receives the training ATAC/GEX pair plus the
# test ATAC data and returns its prediction for the test cells' second modality.
pred_test_mod2 = baseline_linear.fit_predict(input_train_mod1, input_train_mod2, input_test_mod1)

In [7]:
# Densify the predicted matrix so the notebook displays its values as a sanity check.
pred_test_mod2.X.toarray()

array([[0.01849598, 0.05648608, 0.12098093, ..., 2.9789195 , 0.16666089,
        0.01165176],
       [0.0088969 , 0.0592952 , 0.09541178, ..., 3.0214188 , 0.14971086,
        0.00796152],
       [0.01026978, 0.05236913, 0.09989016, ..., 3.1053221 , 0.13488059,
        0.00804165],
       ...,
       [0.01445955, 0.0459733 , 0.2051502 , ..., 3.4173088 , 0.1045588 ,
        0.00479589],
       [0.02190602, 0.05895532, 0.13886905, ..., 2.695336  , 0.14503315,
        0.00887601],
       [0.02549585, 0.0608732 , 0.13784692, ..., 2.7734642 , 0.1414175 ,
        0.00894089]], dtype=float32)

In [8]:
def _as_dense(matrix):
    """Return ``matrix`` densified via ``toarray()`` when it offers one, else unchanged."""
    return matrix.toarray() if hasattr(matrix, "toarray") else matrix


def calculate_rmse(true_test_mod2, pred_test_mod2):
    """Compute the RMSE between predicted and reference GEX values.

    Parameters
    ----------
    true_test_mod2
        Object whose ``layers["log_norm"]`` holds the reference values
        (anything exposing ``toarray()`` is densified; dense input is used as is).
    pred_test_mod2
        Object whose ``X`` holds the predictions (densified the same way) and
        whose ``var["feature_types"]`` column identifies the modality.

    Returns
    -------
    The value returned by ``mean_squared_error(..., squared=False)``.

    Raises
    ------
    NotImplementedError
        If the first entry of ``pred_test_mod2.var["feature_types"]`` is not
        ``"GEX"``.
    """
    # Explicit capability check instead of a bare ``except``: genuine errors
    # raised while densifying are no longer silently swallowed.
    predicted = _as_dense(pred_test_mod2.X)

    # ``.iloc[0]`` is explicitly positional; a plain ``[0]`` on a label-indexed
    # Series relies on a deprecated pandas fallback.
    first_feature_type = pred_test_mod2.var["feature_types"].iloc[0]
    if first_feature_type != "GEX":
        raise NotImplementedError("Only set up to calculate RMSE for GEX data")

    expected = _as_dense(true_test_mod2.layers["log_norm"])
    # NOTE(review): the ``squared`` argument is deprecated in recent
    # scikit-learn releases - confirm the installed version before upgrading.
    return mean_squared_error(expected, predicted, squared=False)

In [9]:
# Compare the baselines on the held-out test cells. Each method is paired with
# a label so every printed RMSE can be attributed to the method that produced it.
baseline_methods = {
    "baseline_linear": baseline_linear.fit_predict,
    "baseline_mean": baseline_mean.fit_predict,
}
for method_name, method in baseline_methods.items():
    # Run prediction
    pred_test_mod2 = method(input_train_mod1, input_train_mod2, input_test_mod1)
    # Calculate RMSE
    rmse = calculate_rmse(true_test_mod2, pred_test_mod2)
    # Print results, prefixed with the method label
    print(f'{method_name} had a RMSE of {rmse:.4f}')

 had a RMSE of 0.3955
 had a RMSE of 0.2535
