In [2]:
import numpy as np
import pandas as pd
import time
import re
import matplotlib.pyplot as plt
import seaborn as sns
import sys, os
sys.path.append("../..")
from ecit import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def parse_relation_from_txt(file_path, label="CI"):
    rules = []

    def normalize(var):
        return var.strip()
    

    with open(file_path, 'r') as f:
        for line in f:
            line = line.strip()
            if ';' not in line:
                continue

            line = re.sub(r'^\d+\.\s*', '', line)

            Z_vars = []
            Z_raw = re.findall(r'\((.*?)\)', line)
            if Z_raw:
                Z_vars = [normalize(z) for z in Z_raw[0].split(',')]
                line = re.sub(r'\(.*?\)', '', line)

            parts = [normalize(p) for p in line.split(';') if p.strip()]
            if len(parts) < 3 and Z_vars:
                X, Y = parts[0], parts[1]
            elif len(parts) == 3:
                X, Y = parts[0], parts[1]
                Z_vars = [normalize(parts[2])]
            else:
                continue

            rules.append({
                "X": X,
                "Y": Y,
                "Z": Z_vars,
                "label": label
            })

    return rules

ci_r = parse_relation_from_txt("CI.txt", label="CI")
ni_r = parse_relation_from_txt("NI.txt", label="NI")
all_r = ci_r + ni_r
all_r

[{'X': 'P38', 'Y': 'p44/42', 'Z': ['PKA', 'PKC'], 'label': 'CI'},
 {'X': 'P38', 'Y': 'pakts473', 'Z': ['PKA', 'PKC'], 'label': 'CI'},
 {'X': 'P38', 'Y': 'PIP2', 'Z': ['PKA', 'PKC'], 'label': 'CI'},
 {'X': 'P38', 'Y': 'PIP3', 'Z': ['PKA', 'PKC'], 'label': 'CI'},
 {'X': 'P38', 'Y': 'pjnk', 'Z': ['PKA', 'PKC'], 'label': 'CI'},
 {'X': 'P38', 'Y': 'plcg', 'Z': ['PKA', 'PKC'], 'label': 'CI'},
 {'X': 'P38', 'Y': 'pmek', 'Z': ['PKA', 'PKC'], 'label': 'CI'},
 {'X': 'P38', 'Y': 'praf', 'Z': ['PKA', 'PKC'], 'label': 'CI'},
 {'X': 'p44/42', 'Y': 'P38', 'Z': ['PKA', 'pmek'], 'label': 'CI'},
 {'X': 'p44/42', 'Y': 'PIP3', 'Z': ['PKA', 'pmek'], 'label': 'CI'},
 {'X': 'p44/42', 'Y': 'pjnk', 'Z': ['PKA', 'pmek'], 'label': 'CI'},
 {'X': 'p44/42', 'Y': 'PKC', 'Z': ['PKA', 'pmek'], 'label': 'CI'},
 {'X': 'p44/42', 'Y': 'plcg', 'Z': ['PKA', 'pmek'], 'label': 'CI'},
 {'X': 'p44/42', 'Y': 'praf', 'Z': ['PKA', 'pmek'], 'label': 'CI'},
 {'X': 'pakts473', 'Y': 'P38', 'Z': ['PKA', 'PIP3'], 'label': 'CI'},
 {'X': 

In [4]:
df1 = pd.read_excel("Data Files/1. cd3cd28.xls")
df2 = pd.read_excel("Data Files/2. cd3cd28icam2.xls")
df = pd.concat([df1, df2], axis=0, ignore_index=True)
df

Unnamed: 0,praf,pmek,plcg,PIP2,PIP3,p44/42,pakts473,PKA,PKC,P38,pjnk
0,26.4,13.2,8.82,18.30,58.80,6.61,17.0,414.0,17.00,44.9,40.0
1,35.9,16.5,12.30,16.80,8.13,18.60,32.5,352.0,3.37,16.5,61.5
2,59.4,44.1,14.60,10.20,13.00,14.90,32.5,403.0,11.40,31.9,19.5
3,73.0,82.8,23.10,13.50,1.29,5.83,11.8,528.0,13.70,28.6,23.1
4,33.7,19.8,5.19,9.73,24.80,21.10,46.1,305.0,4.66,25.7,81.3
...,...,...,...,...,...,...,...,...,...,...,...
1750,79.9,53.8,63.20,403.00,53.80,8.98,15.3,1027.0,26.40,57.3,92.2
1751,44.1,30.8,23.50,22.50,1.10,101.00,164.0,1459.0,11.80,25.7,32.2
1752,50.9,78.4,78.40,279.00,47.00,20.90,27.9,470.0,10.90,20.2,7.1
1753,126.0,83.5,21.30,24.10,10.60,75.00,165.0,3619.0,14.50,38.5,7.3


In [5]:
from tqdm import tqdm

def simu(cit, p_ensemble_list, k, alpha=0.05, show=True):
    TP = np.array([0]*len(p_ensemble_list))
    TN = np.array([0]*len(p_ensemble_list))
    FP = np.array([0]*len(p_ensemble_list))
    FN = np.array([0]*len(p_ensemble_list))

    for rule in tqdm(all_r, disable=not show):
        label = rule["label"]
        X = df[[rule["X"]]].to_numpy()
        Y = df[[rule["Y"]]].to_numpy()
        Z = df[rule["Z"]].to_numpy()
        dz = Z.shape[1]
        obj_ECIT = ECIT(np.hstack((X,Y,Z)), cit, p_ensemble_list, k)
        try:
            ps = obj_ECIT([0], [1], list(range(2, dz + 2)))
        except Exception as e:
            if show:
                print(f"First attempt failed on rule {rule}, retrying... Error: {e}")
            try:
                ps = obj_ECIT([0], [1], list(range(2, dz + 2)))
            except Exception as e:
                if show:
                    print(f"Second attempt failed, skipping rule. Error: {e}")
                continue
        ps = np.array(ps)
        if label == "CI":
            TP += ps > alpha
            FP += ps <= alpha
        else:
            TN += ps <= alpha
            FN += ps > alpha
    pre = TP / (TP + FP )
    rec = TP / (TP + FN )
    f1 = 2 * pre * rec / (pre + rec)

    results = np.array([TP, TN, FP, FN, pre, rec, f1])
    results = results.T
    if show:
        print(results)
    return results

In [6]:
def run_simu(cit_list, ens_list, t=10, alpha=0.05):
    results = {}
    for cit in cit_list:
        table = []
        for k, p_ensemble in ens_list:
            ens = np.zeros((len(p_ensemble), 7))
            ti = 1 if k==1 and cit.__name__!='lpcit' else t
            if cit.__name__ == "rcit": ti = 100
            for _ in tqdm(range(ti), desc=cit.__name__+str(k)):
                ens += simu(cit, p_ensemble, k, alpha, show=False)
            ens = ens / ti
            for en in ens:
                table.append(list(en))
        results[cit.__name__] = table
    return results

In [7]:
import warnings
from sklearn.exceptions import ConvergenceWarning
warnings.simplefilter("ignore", category=ConvergenceWarning)

np.random.seed(1)

cit_list = [rcit, kcit, lpcit, cmiknn, fisherz]
ens_list = [(1, [p_alpha2]), (5, [p_alpha175, p_alpha2])]
results = run_simu(cit_list, ens_list)
results

rcit1: 100%|██████████| 100/100 [02:43<00:00,  1.64s/it]
rcit5: 100%|██████████| 100/100 [14:48<00:00,  8.88s/it]
kcit1: 100%|██████████| 1/1 [05:01<00:00, 301.81s/it]
kcit5: 100%|██████████| 10/10 [07:19<00:00, 43.92s/it]
lpcit1: 100%|██████████| 10/10 [25:06<00:00, 150.68s/it]
lpcit5: 100%|██████████| 10/10 [1:06:49<00:00, 400.92s/it]
cmiknn1: 100%|██████████| 1/1 [35:28<00:00, 2128.39s/it]
cmiknn5: 100%|██████████| 10/10 [4:57:58<00:00, 1787.87s/it] 
fisherz1: 100%|██████████| 1/1 [00:00<00:00,  5.77it/s]
fisherz5: 100%|██████████| 10/10 [00:34<00:00,  3.42s/it]


{'rcit': [[34.21,
   31.33,
   15.79,
   18.67,
   0.6842,
   0.6469800237334105,
   0.664571678079773],
  [35.74,
   31.74,
   14.26,
   18.26,
   0.7147999999999999,
   0.6619087321663716,
   0.6867132917032802],
  [35.74,
   31.74,
   14.26,
   18.26,
   0.7147999999999999,
   0.6619087321663716,
   0.6867132917032802]],
 'kcit': [[31.0,
   34.0,
   19.0,
   16.0,
   0.62,
   0.6595744680851063,
   0.6391752577319586],
  [34.9,
   32.9,
   15.1,
   17.1,
   0.6980000000000001,
   0.6711767358634144,
   0.683394114236602],
  [34.9,
   32.9,
   15.1,
   17.1,
   0.6980000000000001,
   0.6711767358634144,
   0.683394114236602]],
 'lpcit': [[37.0,
   30.0,
   13.0,
   20.0,
   0.74,
   0.6487761971395491,
   0.6908585695495552],
  [41.9,
   28.8,
   8.1,
   21.2,
   0.8380000000000001,
   0.6639659318185448,
   0.7407517093061128],
  [41.9,
   28.8,
   8.1,
   21.2,
   0.8380000000000001,
   0.6639659318185448,
   0.7407517093061128]],
 'cmiknn': [[44.0,
   31.0,
   6.0,
   19.0,
   0.8