In [1]:
import re
from pathlib import Path
from typing import Mapping, Optional, Tuple

import duckdb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

In [3]:
# Input locations (relative to this notebook) and the chromosome analysed below.
DATA_ROOT = Path('..', '..', 'igvf-pm', 'K562', 'leave-one-out')
PRED_CSV_DIR = DATA_ROOT.joinpath('cross-validation', 'slurm-normalized-mse', 'outputs')

DB_BASE = Path('..', '..', 'data')
DB = DB_BASE.joinpath("K562db")  # directory holding the Parquet dataset

CHROM = 'chr22'

In [5]:
# Per-chromosome prediction table (presumably written by the cross-validation run under PRED_CSV_DIR).
# index_col=False keeps the first CSV column as an ordinary column rather than the frame index.
preds = pd.read_csv(PRED_CSV_DIR / f'{CHROM}.csv', index_col=False)

# Later cells join on `index` and plot `true` / `predicted`; fail here, not deep inside SQL or matplotlib.
missing_pred_cols = {'index', 'true', 'predicted'} - set(preds.columns)
assert not missing_pred_cols, f"prediction CSV lacks columns: {sorted(missing_pred_cols)}"

In [None]:
# Inspect the loaded prediction table; later cells use its `index`, `true` and `predicted` columns.
preds

In [None]:
# Observed vs. predicted values on log2 scale for CHROM; points on the dashed
# diagonal are perfect predictions.
SCATTER_LIMS = (-1.5, 2.5)  # log2 range shown on both axes

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(np.log2(preds['true']), np.log2(preds['predicted']),
           marker=".", alpha=0.4, linewidths=0)
ax.plot(SCATTER_LIMS, SCATTER_LIMS, linestyle="--", color="grey", linewidth=1)
ax.set_xlim(SCATTER_LIMS)
ax.set_ylim(SCATTER_LIMS)
# The plotted values are log2-transformed, so label them as such.
ax.set_xlabel("log2(true)")
ax.set_ylabel("log2(predicted)")
ax.set_title(f"{CHROM}: true vs. predicted (n = {len(preds):,})")
plt.show()

In [8]:
# Every Parquet file anywhere below DB (recursive ** glob); hive_partitioning=True
# asks DuckDB to also read partition values from key=value directory names.
starrdb = duckdb.read_parquet(
    "{}/**/*.parquet".format(DB),
    hive_partitioning=True,
)

In [9]:
# Attach genomic coordinates to each prediction by joining on the shared
# `index` column, then sort along the genome for track plotting.
predsWithCoords = duckdb.sql("""
    SELECT s.chrom, s.start, s."end", p.true, p.predicted
    FROM preds AS p
    JOIN starrdb AS s ON (p.index = s.index)
    ORDER BY s.chrom, s.start
""").df()

# An inner join silently drops predictions without a coordinate row, and
# duplicates them if `index` is not unique in starrdb; surface either case.
assert len(predsWithCoords) == len(preds), (
    f"coordinate join changed row count: {len(preds)} predictions -> "
    f"{len(predsWithCoords)} rows"
)

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [None]:
# Predictions with chrom/start/end attached, ordered by chromosome and then start.
predsWithCoords

In [11]:
def plot_tracks(predsTable,
                interval: Tuple[str, int, int],
                ylim=(-1.2, 1.5),
                convFunc=np.log2,
                ylabel="log2FC",
                value_cols: Optional[Mapping[str, str]] = None):
    """Plot observed and predicted values over a genomic interval as two stacked tracks.

    Parameters
    ----------
    predsTable : pandas.DataFrame (or another object DuckDB can query by variable name)
        Must contain ``chrom``, ``start`` and ``end`` columns plus the two value
        columns named by ``value_cols``.
    interval : (chrom, start, end)
        Only rows on ``chrom`` lying entirely inside the interval
        (row ``start >= start`` and row ``end <= end``) are plotted.
    ylim : (float, float)
        y-axis limits, shared by both panels.
    convFunc : callable
        Transform applied to each value column before plotting (default ``np.log2``).
    ylabel : str
        Figure-level y-axis label describing the transformed values.
    value_cols : mapping, optional
        Maps the roles ``'true'`` and ``'predicted'`` to column names in
        ``predsTable``; defaults to ``{'true': 'true', 'predicted': 'predicted'}``.

    Returns
    -------
    (Figure, array of two Axes)
        Top panel shows the ``'true'`` column, bottom panel the ``'predicted'`` column.
    """
    # None sentinel avoids a shared mutable dict as the default argument.
    if value_cols is None:
        value_cols = {'true': 'true', 'predicted': 'predicted'}
    chrom, start, end = interval

    # Bind the interval as query parameters instead of splicing it into the SQL
    # text, and sort by position so each line is drawn left-to-right regardless
    # of the input row order. DuckDB resolves `predsTable` to the local variable.
    query = ('SELECT * FROM predsTable '
             'WHERE chrom = ? AND start >= ? AND "end" <= ? '
             'ORDER BY start')
    preds_interval = duckdb.execute(query, [chrom, start, end]).df()

    fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True)
    fig.subplots_adjust(hspace=0)  # stack panels flush so they read as aligned tracks
    for ax, role in zip(axs, ('true', 'predicted')):
        col = value_cols[role]
        ax.plot(preds_interval['start'], convFunc(preds_interval[col]))
        ax.set_ylabel(col)
    axs[0].set_ylim(ylim)  # sharey=True applies the same limits to the lower panel
    axs[0].set_title(f"{chrom}:{start:,}-{end:,}")
    axs[1].set_xlabel("Chromosome position")
    fig.supylabel(ylabel)
    return fig, axs

In [None]:
# Compare the observed and predicted tracks over a 300 kb window of CHROM.
interval = (CHROM, 35_300_000, 35_600_000)
plot_tracks(predsWithCoords, interval=interval)
plt.show()