# Train a ridge regression model that maps Decima predictions onto scooby target counts

In [1]:
import scanpy as sc
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
import os
import torch
import pickle

In [2]:
# Input locations: the Decima AnnData (read with scanpy below), the target
# count table, and the gene names used as that table's row index.
decima_path = "/gstore/data/resbioai/grelu/decima/20240823/data.h5ad"
targets_path = "count_target_test_no_neighbor.pq"
gene_names_path = "gene_names.pq"

## Load target counts

In [None]:
# Gene identifiers stored in column 0 of the parquet file; presumably in the
# same order as the rows of the targets file (TODO confirm against the writer).
gene_names = pd.read_parquet(gene_names_path)[0].to_list()
len(gene_names)  # shown as the cell output

In [None]:
# Load the target counts from `targets_path` and label the rows with gene names.
# The parquet is converted to a plain array first so the index is assigned
# positionally: passing a DataFrame to `pd.DataFrame(..., index=...)` would
# instead reindex by label and silently produce NaN rows on a mismatch.
# NOTE(review): assumes one row per gene, in the same order as `gene_names` --
# confirm against whatever wrote these files.
targets = pd.read_parquet(targets_path).to_numpy()
if targets.shape[0] != len(gene_names):
    raise ValueError(
        f"{targets_path} has {targets.shape[0]} rows but "
        f"{gene_names_path} lists {len(gene_names)} gene names"
    )
targets = pd.DataFrame(targets, index=gene_names)

## Load Decima predictions

In [None]:
# Load the Decima AnnData. Later cells take gene identifiers from `var_names`,
# train/val/test labels from `var['dataset']`, and features from `layers['preds']`.
ad = sc.read(decima_path)

## Subset to common genes

In [None]:
# Genes present both in the Decima AnnData and in the target table.
# Sorted so the gene order -- and therefore the row order of every array saved
# at the end of the notebook -- is reproducible; iteration order of a set of
# strings depends on the interpreter's hash seed and changes between runs.
common_genes = sorted(set(ad.var_names).intersection(gene_names))
len(common_genes)  # shown as the cell output

In [None]:
# Restrict Decima to the shared genes (genes live on the var axis, as the use of
# `var_names` for gene matching shows). This creates an AnnData view, not a copy.
ad = ad[:, common_genes]
print(ad.shape)

## Split into train / val / test sets

In [None]:
# Partition the genes by the split label stored in `var['dataset']`.
ad_train, ad_val, ad_test = (
    ad[:, ad.var['dataset'] == split_name] for split_name in ('train', 'val', 'test')
)
print(ad_train.shape, ad_val.shape, ad_test.shape)

In [None]:
# Target rows for each split, looked up by gene name so they line up with the
# gene order of the corresponding Decima subset.
targets_train, targets_val, targets_test = (
    targets.loc[subset.var_names].values for subset in (ad_train, ad_val, ad_test)
)
print(targets_train.shape, targets_val.shape, targets_test.shape)

## Create matrices for ridge regression

In [None]:
# Features: the Decima `preds` layer transposed so that rows are genes and
# columns are Decima observations.
X_train, X_val, X_test = (
    subset.layers['preds'].T for subset in (ad_train, ad_val, ad_test)
)
print(X_train.shape, X_val.shape, X_test.shape)

In [None]:
# Regression targets on a log(1 + count) scale. np.log1p computes the same
# quantity as np.log(x + 1) but stays accurate when x is close to zero.
Y_train = np.log1p(targets_train)
Y_val = np.log1p(targets_val)
Y_test = np.log1p(targets_test)
print(Y_train.shape, Y_val.shape, Y_test.shape)

In [None]:
# Save the test-gene order (matches the rows of the saved test predictions).
# Stored as a fixed-width string array rather than the Index's object array,
# so it can be read back with np.load's default allow_pickle=False.
np.save('decima_test_genes.npy', np.asarray(ad_test.var_names, dtype=str))

## Train Ridge regression

In [None]:
# Fit one multi-output ridge regression: each gene's Decima predictions
# (features) -> its log target counts (one output per target column).
model = Ridge(alpha=250)  # NOTE(review): alpha is fixed by hand; tuning is not shown here
model.fit(X_train, Y_train)
yhat = model.predict(X_val)
# Check against the validation targets instead of hard-coded gene/column counts.
assert yhat.shape == Y_val.shape, (yhat.shape, Y_val.shape)

# Pearson correlation across validation genes, computed separately for each
# target column (presumably one per pseudobulk), then averaged.
per_pb_corrs = [np.corrcoef(pred, obs)[0, 1] for pred, obs in zip(yhat.T, Y_val.T)]
print(np.round(np.mean(per_pb_corrs), 4))

## Predict log counts on the test set and evaluate them

In [None]:
# Predict log counts for the test genes; `yhat` is saved at the end of the notebook.
yhat = model.predict(X_test)
assert yhat.shape == Y_test.shape, (yhat.shape, Y_test.shape)

# Mean correlation across test genes, computed per target column.
print(np.mean([np.corrcoef(pred, obs)[0, 1] for pred, obs in zip(yhat.T, Y_test.T)]))
# Mean correlation across target columns, computed per gene, over every test gene.
# nanmean skips genes whose correlation is undefined (constant predictions or
# constant targets make np.corrcoef return NaN).
print(np.nanmean([np.corrcoef(pred, obs)[0, 1] for pred, obs in zip(yhat, Y_test)]))

## Save

In [None]:
# Persist the fitted regressor (pickled) and the test-set predictions.
filename = 'regression_model.pkl'
with open(filename, mode='wb') as model_file:
    pickle.dump(model, model_file)

np.save('decima_test_preds.npy', yhat)