In [45]:
import os
import pickle
import itertools
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import numpy as np
import pandas as pd
import igraph as ig
import networkx as nx

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from knots_tools import HumanChromosomeDtype, DatasetDtype, get_loop_pair_infos
import datasources
import modvis

# Lookup CCD stats

In [None]:
# Load the table of all CCDs and key it by (dataset, chromosome, ccd_id) so
# that a single CCD can be looked up with .loc.
ccds = pd.read_csv(
    'data/all_ccds.csv',
    dtype={'chromosome': HumanChromosomeDtype, 'dataset': DatasetDtype},
).set_index(['dataset', 'chromosome', 'ccd_id'])
ccds

In [None]:
# Size-related columns to show when inspecting CCDs.
_cols = ['n_edges', 'n_nodes', 'length']
# Look up a single CCD by its full (dataset, chromosome, ccd_id) key.
ccds.loc[('GM12878', 'chr7', 100)][_cols]

In [None]:
# GM12878 chr7 CCDs with between 300 and 400 edges (exclusive), ordered by size.
(
    ccds.loc[('GM12878', 'chr7')][_cols]
    .sort_values(by=_cols)
    .query('n_edges > 300 & n_edges < 400')
)

# Monte Carlo modeling

In [None]:
def read_batch_model_from_pickle(file):
    """Load one batch-model pickle and return the fields used downstream.

    The pickle is expected to hold a 4-tuple ``(args, result, exc, n_restarts)``
    where ``args`` is itself a 4-tuple ``(i_rep, to_keep, restraint_set, path)``.
    When ``exc`` is None, ``result`` is unpacked as
    ``(bead_coords, used_restraints, raw_coords, init_str_points)``; otherwise
    ``result`` is ignored.

    Parameters
    ----------
    file : path-like
        Path of the pickle file to read.

    Returns
    -------
    pd.Series
        Entries ``used_restraints``, ``raw_coords``, ``kept_idx`` (the
        ``to_keep`` item of ``args``) and ``exc``. ``used_restraints`` and
        ``raw_coords`` are None when ``exc`` is not None.

    Raises
    ------
    ValueError / TypeError
        If the unpickled object does not have the tuple layout described above.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only read files you trust.
    """
    with open(file, 'rb') as handle:
        payload = pickle.load(handle)

    # The underscore-prefixed names are not used, but they are still unpacked
    # so that a payload with an unexpected layout fails loudly.
    run_args, run_result, exc, _n_restarts = payload
    _i_rep, to_keep, _restraint_set, _path = run_args

    if exc is None:
        _bead_coords, used_restraints, raw_coords, _init_str_points = run_result
    else:
        used_restraints = None
        raw_coords = None

    fields = {
        'used_restraints': used_restraints,
        'raw_coords': raw_coords,
        'kept_idx': to_keep,
        'exc': exc
    }
    return pd.Series(fields)

# Discover the batch pickles under ./batch. The named groups of the file-name
# pattern provide the dataset / chromosome / ccd_id / group / n_restraints /
# run_idx fields; the three numeric ones are parsed as int.
batch_models = (
    datasources.DataSources(
        './batch',
        ['dataset', 'chromosome', 'ccd_id', 'group', 'n_restraints', 'run_idx'],
    )
    .add(
        'model_file',
        'ccd_graph_(?P<dataset>\\w+)_(?P<chromosome>\\w+)_(?P<ccd_id>\\d+)_(?P<group>\\w+)_k(?P<n_restraints>\\d+)_(?P<run_idx>\\d+)\\.pkl$',
        parsers={'ccd_id': int, 'n_restraints': int, 'run_idx': int},
    )
    .get_paths_as_dataframe()
    .reset_index()
)

# Cast the key columns to the project-defined dtypes (same ones used for `ccds`).
batch_models = batch_models.astype(
    {'chromosome': HumanChromosomeDtype, 'dataset': DatasetDtype}
)

# Read every pickle and append its contents as extra columns next to the path.
batch_models = pd.concat(
    [batch_models, batch_models['model_file'].apply(read_batch_model_from_pickle)],
    axis=1,
)

# Flag the models whose pickle recorded an exception.
batch_models['has_exc'] = batch_models['exc'].notnull()
batch_models.info()
batch_models.head()

In [None]:
# Number of failed models in each run, where a run is identified by
# (dataset, chromosome, ccd_id, run_idx).
error_counts = (
    batch_models
    .groupby(['dataset', 'chromosome', 'ccd_id', 'run_idx'], observed=True)
    ['has_exc']
    .sum()
)
error_counts

In [None]:
# Keep only the runs that contain at least one failed model.
# A boolean mask is used rather than integer positions from np.where: passing
# integer positions to Series `[]` relies on a positional fallback that pandas
# has deprecated, and the lookup fails once the keys are treated as labels of
# this MultiIndex.
errored_runs = error_counts[error_counts > 0].reset_index()
errored_runs

In [None]:
# Keep models with at least two restraints that belong to a run in which no
# model failed.
# A run is identified by the full (dataset, chromosome, ccd_id, run_idx) key,
# which is how error_counts / errored_runs are grouped. Matching on run_idx
# alone would also discard the same run_idx in every other CCD whenever run
# indices repeat across CCDs; matching on the full key gives the same result
# if run_idx happens to be globally unique.
_run_keys = ['dataset', 'chromosome', 'ccd_id', 'run_idx']
_in_errored_run = pd.MultiIndex.from_frame(batch_models[_run_keys]).isin(
    pd.MultiIndex.from_frame(errored_runs[_run_keys])
)
ok_models = batch_models[(batch_models.n_restraints >= 2) & ~_in_errored_run]
# The exception and file-path columns are no longer needed downstream.
ok_models = ok_models.drop(columns=['exc', 'model_file'])
ok_models

def _summarize_linking(idx):
    """Summarise the loop-pair linking of the model stored at ``ok_models.loc[idx]``.

    Reads the module-level ``ok_models`` frame, so it must be defined (and
    visible to the worker process) before this function is called.

    Parameters
    ----------
    idx : index label
        Label of a single row of ``ok_models``; the label must be unique.

    Returns
    -------
    dict
        ``idx`` plus counts and aggregates computed from the result of
        ``get_loop_pair_infos`` (its ``overlap`` and ``abs_linking_number``
        columns): number of loop pairs, number of pairs with zero overlap,
        maximum and sum of the absolute linking number, number of pairs with
        a positive absolute linking number, and number of such pairs that
        also have zero overlap.
    """
    row = ok_models.loc[idx]
    # A non-unique label would yield a DataFrame instead of a single row.
    assert isinstance(row, pd.Series)

    pair_infos = get_loop_pair_infos(row['raw_coords'], row['used_restraints'])
    is_disjoint = pair_infos['overlap'] == 0
    abs_linking = pair_infos['abs_linking_number']
    is_linked = abs_linking > 0

    summary = {
        'idx': idx,
        'n_loop_pairs': len(pair_infos),
        'n_disjoint_pairs': is_disjoint.sum(),
        'max_abs_linking': abs_linking.max(),
        'sum_abs_linking': abs_linking.sum(),
        'n_linked_pairs': is_linked.sum(),
        'n_linked_disjoint_pairs': (is_linked & is_disjoint).sum()
    }
    return summary

def calculate_linking_in_batch_data(n_workers=None):
    """Run ``_summarize_linking`` for every row of ``ok_models`` in a process pool.

    Parameters
    ----------
    n_workers : int or None, optional
        Passed as ``max_workers`` to ``ProcessPoolExecutor``.

    Returns
    -------
    pd.DataFrame
        One row per ``ok_models`` index label, indexed by ``idx`` and sorted
        by it, with the summary values returned by ``_summarize_linking`` as
        columns.

    Notes
    -----
    Any exception raised in a worker is re-raised here when its result is
    collected.
    """
    summaries = {}
    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        pending = [
            pool.submit(_summarize_linking, row_label)
            for row_label in ok_models.index
        ]
        # Results arrive in completion order; they are re-sorted below.
        for finished in tqdm(as_completed(pending), total=len(pending)):
            summary = finished.result()
            summaries[summary['idx']] = summary

    return (
        pd.DataFrame.from_dict(summaries, orient='index')
        .set_index('idx')
        .sort_index()
    )


# Attach the per-model linking summary as extra columns (aligned on the row
# index) and flag the models that contain at least one linked loop pair.
ok_models = pd.concat(
    [ok_models, calculate_linking_in_batch_data()],
    axis=1,
).assign(is_linked=lambda frame: frame['max_abs_linking'] > 0)
ok_models.info()
ok_models.head()

In [None]:
# Shared plot settings: the order of the groups in the legend, and the first
# three tab10 colours reordered so the groups get colours 2, 0 and 1.
hue_order = ['nolink', 'control', 'minor']
_pal = sns.color_palette('tab10', 3)
palette = [_pal[2], _pal[0], _pal[1]]

# Number of linked models per restraint count, split by group.
sns.countplot(
    data=ok_models.loc[ok_models['is_linked']],
    x='n_restraints',
    hue='group',
    hue_order=hue_order,
    palette=palette,
);

In [None]:
# Linked vs. not-linked model counts for every (n_restraints, group) pair;
# combinations that do not occur are shown as 0.
(
    ok_models
    .groupby(['n_restraints', 'group'])
    ['is_linked']
    .value_counts()
    .unstack(fill_value=0)
)

In [None]:
# Number of linked loop pairs per model, by restraint count and group.
sns.boxplot(
    data=ok_models,
    x='n_restraints',
    y='n_linked_pairs',
    hue='group',
    hue_order=hue_order,
    palette=palette,
);

In [None]:
# Sum of absolute linking numbers per model, by restraint count and group.
sns.boxplot(
    data=ok_models,
    x='n_restraints',
    y='sum_abs_linking',
    hue='group',
    hue_order=hue_order,
    palette=palette,
);

In [None]:
# Maximum absolute linking number per model, by restraint count and group.
sns.boxplot(
    data=ok_models,
    x='n_restraints',
    y='max_abs_linking',
    hue='group',
    hue_order=hue_order,
    palette=palette,
);