# EUGENe DeepSTARR model evaluation
Adam Klie (last updated: *09/20/2023*)
***
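This notebook evaluates a trained DeepSTARR model on the deAlmeida22 test set with EUGENe: it loads a model checkpoint, predicts the Dev and Hk enrichment values, and plots the predictions against the measured values.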

# Set-up

In [None]:
import torch
from eugene import models
from eugene import settings
import seqdatasets
from eugene import preprocess as pp

# plot the predictions in a scatter plot
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import gaussian_kde
import numpy as np

# Add metrics to the plots
from sklearn.metrics import r2_score
from scipy.stats import pearsonr, spearmanr

# NOTE(review): machine-specific absolute path -- change it to a local directory
# before running this notebook on another system.
# Presumably this is where eugene/seqdatasets look for (or download) datasets -- confirm.
settings.dataset_dir = "/cellar/users/aklie/data/eugene"

# Load the model

In [None]:
# Restore the trained DeepSTARR model from its checkpoint.
# NOTE(review): machine-specific absolute path -- adjust when running elsewhere.
model = models.DeepSTARR.load_from_checkpoint(
    "/cellar/users/aklie/projects/ML4GLand/models/DeepSTARR/eugene/DeepSTARR.ckpt"
)
# Put the network in evaluation mode so layers such as dropout and batch norm
# behave deterministically at inference time, in case predict() does not do
# this itself. eval() returns the module, so reassigning keeps the cell output quiet.
model = model.eval()

# Load the test data

In [None]:
# Load the deAlmeida22 dataset, requesting the "test" portion (presumably the test split -- confirm).
sdata_test = seqdatasets.deAlmeida22("test")
# One-hot encode the sequences. The encoding is read back later as
# `sdata_test.ohe_seqs`, so this call presumably stores it on the object in place -- confirm.
pp.ohe_seqs_sdata(sdata_test)

# Get predictions

In [None]:
# Run the model over the one-hot encoded test sequences in batches of 128.
# .cpu() comes before .numpy() because a tensor living on a GPU cannot be
# converted to a NumPy array directly; for a CPU tensor it is a no-op.
preds = model.predict(sdata_test.ohe_seqs, batch_size=128).detach().cpu().numpy()

# NOTE: res_df is not copied here, so the prediction columns added below
# presumably also end up on sdata_test.seqs_annot.
res_df = sdata_test.seqs_annot
# Output column 0 is stored as the Dev prediction, column 1 as the Hk prediction.
res_df['Dev_log2_enrichment_scaled_pred'] = preds[:, 0]
res_df['Hk_log2_enrichment_scaled_pred'] = preds[:, 1]
res_df.head()

# Make plot

In [None]:
def _density_scatter(axis, frame, task):
    """Draw a density-coloured observed-vs-predicted scatter for one task.

    Parameters
    ----------
    axis : matplotlib Axes
        Panel to draw on.
    frame : pandas.DataFrame
        Must contain the columns ``f"{task}_log2_enrichment_scaled"`` and
        ``f"{task}_log2_enrichment_scaled_pred"``.
    task : str
        Column prefix and title label, e.g. ``"Dev"`` or ``"Hk"``.

    Returns
    -------
    dict
        The ``r2``, ``pearson`` and ``spearman`` values shown in the panel title.
    """
    observed_col = f"{task}_log2_enrichment_scaled"
    predicted_col = f"{task}_log2_enrichment_scaled_pred"

    # Work on plain arrays so the reordering below is strictly positional,
    # whatever index the DataFrame happens to carry (indexing a Series with
    # integer positions is label-based and only right for a default RangeIndex).
    observed = np.asarray(frame[observed_col])
    predicted = np.asarray(frame[predicted_col])

    # Colour every point by a Gaussian KDE estimate of the local point density.
    stacked = np.vstack([observed, predicted])
    density = gaussian_kde(stacked)(stacked)

    # Draw the densest points last so they are not hidden by sparse ones.
    order = density.argsort()
    axis.scatter(observed[order], predicted[order], c=density[order])

    # Agreement metrics between measured and predicted values.
    r2 = r2_score(observed, predicted)
    pearson = pearsonr(observed, predicted)[0]
    spearman = spearmanr(observed, predicted)[0]
    axis.set_title(f"{task} R2: {r2:.2f}\nPearson: {pearson:.2f}\nSpearman: {spearman:.2f}")
    axis.set_xlabel(observed_col)
    axis.set_ylabel(predicted_col)

    # Identity line (y = x). Joining (xmin, ymin) to (xmax, ymax) is only the
    # identity when both axes share the same limits, so build one common range
    # and apply it to both axes.
    x_lo, x_hi = axis.get_xlim()
    y_lo, y_hi = axis.get_ylim()
    lo, hi = min(x_lo, y_lo), max(x_hi, y_hi)
    axis.plot([lo, hi], [lo, hi], ls="--", c=".3")
    axis.set_xlim(lo, hi)
    axis.set_ylim(lo, hi)

    return {"r2": r2, "pearson": pearson, "spearman": spearman}


fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# One panel per task: developmental (Dev) on the left, housekeeping (Hk) on the right.
metrics = {}
for panel, task in zip(ax, ("Dev", "Hk")):
    metrics[task] = _density_scatter(panel, res_df, task)

plt.tight_layout()
plt.show()

# DONE!

---